|
|
@@ -0,0 +1,216 @@
|
|
|
+"""
|
|
|
+This module contains all the logic for Runge-Kutta preprocessing.
|
|
|
+"""
|
|
|
+from CBD.CBD import CBD
|
|
|
+from CBD.lib.std import *
|
|
|
+
|
|
|
+class RKPreprocessor:
|
|
|
+ def __init__(self, tableau):
|
|
|
+ self._tableau = tableau
|
|
|
+
|
|
|
+ def preprocess(self, original):
|
|
|
+ # 1. Detect all IVPs and group them in their own blocks
|
|
|
+ model = original.clone()
|
|
|
+ model.flatten(ignore=[IntegratorBlock])
|
|
|
+ blocks = model.getBlocks()
|
|
|
+ IVP = []
|
|
|
+ for block in blocks:
|
|
|
+ if isinstance(block, IntegratorBlock):
|
|
|
+ i = len(IVP)
|
|
|
+ ivp = CBD("IVP-%d" % i, ["time"], ["OUT1"])
|
|
|
+ collection = self.collect(block, ["IN1"], [IntegratorBlock, DelayBlock, TimeBlock])
|
|
|
+ for child in collection:
|
|
|
+ ivp.addBlock(child.clone())
|
|
|
+ for child in collection:
|
|
|
+ for name_input, link in child.getLinksIn().items():
|
|
|
+ lbn = link.block.getBlockName()
|
|
|
+ lop = link.output_port
|
|
|
+ if not ivp.hasBlock(lbn):
|
|
|
+ if link.block.getBlockType() == "TimeBlock":
|
|
|
+ lbn = 'time'
|
|
|
+ else:
|
|
|
+ lbn = name_input + "-" + child.getBlockName()
|
|
|
+ ivp.addInputPort(lbn)
|
|
|
+ lop = None
|
|
|
+ ivp.addConnection(lbn, child.getBlockName(), name_input, lop)
|
|
|
+ fin = block.getBlockConnectedToInput("IN1")
|
|
|
+ ivp.addConnection(fin.block.getBlockName(), "OUT1", None, fin.output_port)
|
|
|
+ IVP.append((block, ivp))
|
|
|
+
|
|
|
+ # 2. Foreach IVP: create an RK-model, based on the given tableau
|
|
|
+ RKs = []
|
|
|
+ for block, ivp in IVP:
|
|
|
+ RKs.append((block, self.create_RK(ivp)))
|
|
|
+
|
|
|
+ # 3. Substitute the RK-model collection in the original CBD
|
|
|
+ # TODO: multiple outputs
|
|
|
+ # TODO: multiple IVP?
|
|
|
+ outputs = original.getSignals().keys()
|
|
|
+ new_model = CBD(original.getBlockName(), [], outputs)
|
|
|
+ new_model.addBlock(TimeBlock("time"))
|
|
|
+ new_model.addBlock(DeltaTBlock("delta_t"))
|
|
|
+ for integ, rk in RKs:
|
|
|
+ rkname = rk.getBlockName()
|
|
|
+ new_model.addBlock(rk)
|
|
|
+
|
|
|
+ # TODO: more complex IC?
|
|
|
+ ic = integ.getBlockConnectedToInput("IC").block
|
|
|
+ new_model.addBlock(ic.clone())
|
|
|
+ new_model.addConnection(ic.getBlockName(), rkname, "IC")
|
|
|
+
|
|
|
+ new_model.addConnection("delta_t", rkname, "h")
|
|
|
+ new_model.addConnection("time", rkname, "t")
|
|
|
+ for outp in outputs:
|
|
|
+ conn = original.getBlockByName(outp).getBlockConnectedToInput("IN1").block
|
|
|
+ if conn.getBlockName() == integ.getBlockName():
|
|
|
+ new_model.addConnection(rkname, outp)
|
|
|
+ return new_model
|
|
|
+
|
|
|
+
|
|
|
+ def collect(self, start, sport=None, finish=None):
|
|
|
+ if finish is None:
|
|
|
+ finish = []
|
|
|
+ collection = [x[1].block for x in start.getLinksIn().items() if \
|
|
|
+ ((sport is not None and x[0] in sport) or (sport is None)) and not isinstance(x[1].block, tuple(finish))]
|
|
|
+ n_collection = [x.getBlockName() for x in collection]
|
|
|
+ for block in collection:
|
|
|
+ ccoll = self.collect(block, None, finish)
|
|
|
+ for child in ccoll:
|
|
|
+ cname = child.getBlockName()
|
|
|
+ if cname not in n_collection:
|
|
|
+ n_collection.append(cname)
|
|
|
+ collection.append(child)
|
|
|
+ return collection
|
|
|
+
|
|
|
+ def create_RK(self, f):
|
|
|
+ RK = CBD("RK", ["h", "t", "IC"], ["OUT1"])
|
|
|
+ fy = [x for x in f.getInputPortNames() if x != 'time']
|
|
|
+ weights = self._tableau.getWeights()[0]
|
|
|
+ RK.addBlock(AdderBlock("RKSum", len(weights)))
|
|
|
+ RK.addBlock(DelayBlock("delay"))
|
|
|
+ for i in range(len(weights)):
|
|
|
+ j = i + 1
|
|
|
+ RK.addBlock(self.create_K(j, f.clone()))
|
|
|
+ RK.addBlock(ProductBlock("Mult_%d" % j))
|
|
|
+ RK.addBlock(ConstantBlock("B_%d" % j, weights[i]))
|
|
|
+ RK.addConnection("h", "RK-K_%d" % j, "h")
|
|
|
+ RK.addConnection("t", "RK-K_%d" % j, "t")
|
|
|
+ for y in fy:
|
|
|
+ RK.addConnection("delay", "RK-K_%d" % j, y)
|
|
|
+ RK.addConnection("B_%d" % j, "Mult_%d" % j)
|
|
|
+ RK.addConnection("RK-K_%d" % j, "Mult_%d" % j)
|
|
|
+ RK.addConnection("Mult_%d" % j, "RKSum")
|
|
|
+ for s in range(i):
|
|
|
+ RK.addConnection("RK-K_%d" % (s+1), "RK-K_%d" % j, "k_%d" % (s+1))
|
|
|
+
|
|
|
+ # Initial Conditions
|
|
|
+ RK.addBlock(NegatorBlock("neg"))
|
|
|
+ RK.addBlock(AdderBlock("ICSum"))
|
|
|
+ RK.addConnection("RKSum", "neg")
|
|
|
+ RK.addConnection("neg", "ICSum")
|
|
|
+ RK.addConnection("IC", "ICSum")
|
|
|
+ RK.addConnection("ICSum", "delay", "IC")
|
|
|
+
|
|
|
+ # Loop
|
|
|
+ RK.addBlock(AdderBlock("YSum"))
|
|
|
+ RK.addConnection("delay", "YSum")
|
|
|
+ RK.addConnection("YSum", "delay", "IN1")
|
|
|
+ RK.addConnection("RKSum", "YSum")
|
|
|
+ RK.addConnection("delay", "OUT1")
|
|
|
+
|
|
|
+ return RK
|
|
|
+
|
|
|
+ def create_K(self, s, f):
|
|
|
+ input_ports = ["h", "t"] + ["k_%d" % (i+1) for i in range(s-1)]
|
|
|
+ fy = [x for x in f.getInputPortNames() if x != "time"]
|
|
|
+ input_ports += fy
|
|
|
+ K = CBD("RK-K_%d" % s, input_ports, ["OUT1"])
|
|
|
+ K.addBlock(f)
|
|
|
+
|
|
|
+ # Time parameter
|
|
|
+ K.addBlock(ConstantBlock("C", self._tableau.getNodes()[s-1]))
|
|
|
+ K.addBlock(ProductBlock("CMult"))
|
|
|
+ K.addBlock(AdderBlock("CSum"))
|
|
|
+ K.addConnection("h", "CMult")
|
|
|
+ K.addConnection("C", "CMult")
|
|
|
+ K.addConnection("t", "CSum")
|
|
|
+ K.addConnection("CMult", "CSum")
|
|
|
+ K.addConnection("CSum", f.getBlockName(), "time")
|
|
|
+
|
|
|
+ # Y parameters
|
|
|
+ if s - 1 > 0:
|
|
|
+ K.addBlock(AdderBlock("KSum", s - 1))
|
|
|
+ for i in range(s-1):
|
|
|
+ j = i + 1
|
|
|
+ K.addBlock(ConstantBlock("A_%d" % j, self._tableau.getA(s-1, j)))
|
|
|
+ K.addBlock(ProductBlock("Mult_%d" % j))
|
|
|
+ K.addConnection("A_%d" % j, "Mult_%d" % j)
|
|
|
+ K.addConnection("k_%d" % j, "Mult_%d" % j)
|
|
|
+ K.addConnection("Mult_%d" % j, "KSum")
|
|
|
+ for y in fy:
|
|
|
+ K.addInputPort(y)
|
|
|
+ K.addBlock(AdderBlock("YSum-%s" % y))
|
|
|
+ K.addConnection(y, "YSum-%s" % y)
|
|
|
+ K.addConnection("KSum", "YSum-%s" % y)
|
|
|
+ K.addConnection("YSum-%s" % y, f.getBlockName(), y)
|
|
|
+ else:
|
|
|
+ for y in fy:
|
|
|
+ K.addInputPort(y)
|
|
|
+ K.addConnection(y, f.getBlockName(), y)
|
|
|
+
|
|
|
+ # Finishing Up
|
|
|
+ K.addBlock(ProductBlock("FMult"))
|
|
|
+ K.addConnection("h", "FMult")
|
|
|
+ K.addConnection(f.getBlockName(), "FMult")
|
|
|
+ K.addConnection("FMult", "OUT1")
|
|
|
+
|
|
|
+ return K
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ from CBD.stepsize import ButcherTableau as BT
|
|
|
+ from CBD.CBDDraw import draw
|
|
|
+ DELTA_T = 0.1
|
|
|
+
|
|
|
+ class Test(CBD):
|
|
|
+ def __init__(self, block_name):
|
|
|
+ CBD.__init__(self, block_name, input_ports=[], output_ports=['y'])
|
|
|
+
|
|
|
+ # Create the Blocks
|
|
|
+ self.addBlock(IntegratorBlock("int"))
|
|
|
+ self.addBlock(ConstantBlock("IC", value=(0)))
|
|
|
+ self.addBlock(ProductBlock("mult"))
|
|
|
+ self.addBlock(AdderBlock("sum"))
|
|
|
+ self.addBlock(ConstantBlock("one", value=(1)))
|
|
|
+ self.addBlock(ConstantBlock("time", value=(DELTA_T)))
|
|
|
+
|
|
|
+ # Create the Connections
|
|
|
+ self.addConnection("IC", "int", output_port_name='OUT1', input_port_name='IC')
|
|
|
+ self.addConnection("int", "mult", output_port_name='OUT1', input_port_name='IN1')
|
|
|
+ self.addConnection("int", "mult", output_port_name='OUT1', input_port_name='IN2')
|
|
|
+ self.addConnection("int", "y", output_port_name='OUT1')
|
|
|
+ self.addConnection("mult", "sum", output_port_name='OUT1', input_port_name='IN2')
|
|
|
+ self.addConnection("one", "sum", output_port_name='OUT1', input_port_name='IN1')
|
|
|
+ self.addConnection("sum", "int", output_port_name='OUT1', input_port_name='IN1')
|
|
|
+ self.addConnection("time", "int", output_port_name='OUT1', input_port_name='delta_t')
|
|
|
+
|
|
|
+ prep = RKPreprocessor(BT.Heun())
|
|
|
+ model = prep.preprocess(Test("Test"))
|
|
|
+ draw(model.findBlock("RK.RK-K_2")[0], "test.dot")
|
|
|
+ # model = Test("Test")
|
|
|
+
|
|
|
+ # from CBD.simulator import Simulator
|
|
|
+ # sim = Simulator(model)
|
|
|
+ # sim.setDeltaT(0.1)
|
|
|
+ # sim.run(1.4)
|
|
|
+ #
|
|
|
+ # s = model.getSignal("y")
|
|
|
+ # L = len(s)
|
|
|
+ #
|
|
|
+ # print("+------------+------------+")
|
|
|
+ # print("| TIME | VALUE |")
|
|
|
+ # print("+------------+------------+")
|
|
|
+ # for i in range(L):
|
|
|
+ # t, v = s[i]
|
|
|
+ # print(f"| {t:10.7f} | {v:10.7f} |")
|
|
|
+ # print("+------------+------------+")
|