Browse Source

Bugfixes + plotting extensions + LotkaVolterra example

Randy Paredis 3 years ago
parent
commit
36dc8eff63

BIN
CBD-startingpoint.zip


File diff suppressed because it is too large
+ 1 - 0
examples/LotkeVolterra/diagram.io


+ 69 - 0
examples/LotkeVolterra/diagram.py

@@ -0,0 +1,69 @@
+#!/usr/bin/python3
+# This file was automatically generated from drawio2cbd with the command:
+#   /home/red/git/DrawioConvert/__main__.py -F CBD -e LotkeVolterra -vaSrg -E delta=0.1 -t 100 diagram.io
+
+from CBD.Core import *
+from CBD.lib.std import *
+from CBD.lib.endpoints import SignalCollectorBlock
+
+DELTA_T = 0.1
+
+class LotkeVolterra(CBD):
+    def __init__(self, block_name, alpha=(1.5), beta=(0.7), delta=(0.2), gamma=(0.2)):
+        CBD.__init__(self, block_name, input_ports=[], output_ports=[])
+
+        # Create the Blocks
+        self.addBlock(ConstantBlock("alpha", value=(alpha)))
+        self.addBlock(ConstantBlock("gamma", value=(-gamma)))
+        self.addBlock(ConstantBlock("delta", value=(delta)))
+        self.addBlock(ProductBlock("ax"))
+        self.addBlock(IntegratorBlock("x"))
+        self.addBlock(IntegratorBlock("y"))
+        self.addBlock(ProductBlock("xy"))
+        self.addBlock(ConstantBlock("beta", value=(-beta)))
+        self.addBlock(ProductBlock("gy"))
+        self.addBlock(AdderBlock("sum1"))
+        self.addBlock(ProductBlock("bxy"))
+        self.addBlock(ProductBlock("dxy"))
+        self.addBlock(AdderBlock("sum2"))
+        self.addBlock(ConstantBlock("IC", value=(10)))
+        self.addBlock(SignalCollectorBlock("rabbits"))
+        self.addBlock(SignalCollectorBlock("foxes"))
+
+        # Create the Connections
+        self.addConnection("x", "ax", output_port_name='OUT1', input_port_name='IN2')
+        self.addConnection("x", "xy", output_port_name='OUT1', input_port_name='IN1')
+        self.addConnection("y", "xy", output_port_name='OUT1', input_port_name='IN2')
+        self.addConnection("y", "gy", output_port_name='OUT1', input_port_name='IN1')
+        self.addConnection("gamma", "gy", output_port_name='OUT1', input_port_name='IN2')
+        self.addConnection("xy", "bxy", output_port_name='OUT1', input_port_name='IN2')
+        self.addConnection("xy", "dxy", output_port_name='OUT1', input_port_name='IN1')
+        self.addConnection("beta", "bxy", output_port_name='OUT1', input_port_name='IN1')
+        self.addConnection("alpha", "ax", output_port_name='OUT1', input_port_name='IN1')
+        self.addConnection("ax", "sum1", output_port_name='OUT1', input_port_name='IN1')
+        self.addConnection("bxy", "sum1", output_port_name='OUT1', input_port_name='IN2')
+        self.addConnection("sum1", "x", output_port_name='OUT1', input_port_name='IN1')
+        self.addConnection("sum1", "rabbits", output_port_name='OUT1', input_port_name='IN1')
+        self.addConnection("delta", "dxy", output_port_name='OUT1', input_port_name='IN2')
+        self.addConnection("dxy", "sum2", output_port_name='OUT1', input_port_name='IN1')
+        self.addConnection("gy", "sum2", output_port_name='OUT1', input_port_name='IN2')
+        self.addConnection("sum2", "y", output_port_name='OUT1', input_port_name='IN1')
+        self.addConnection("sum2", "foxes", output_port_name='OUT1', input_port_name='IN1')
+        self.addConnection("IC", "y", output_port_name='OUT1', input_port_name='IC')
+        self.addConnection("IC", "x", output_port_name='OUT1', input_port_name='IC')
+
+
+class SinGen(CBD):
+    def __init__(self, block_name):
+        CBD.__init__(self, block_name, input_ports=[], output_ports=[])
+
+        # Create the Blocks
+        self.addBlock(TimeBlock("time"))
+        self.addBlock(GenericBlock("sin", block_operator=("sin")))
+        self.addBlock(SignalCollectorBlock("mySin"))
+
+        # Create the Connections
+        self.addConnection("time", "sin", output_port_name='OUT1', input_port_name='IN1')
+        self.addConnection("sin", "mySin", output_port_name='OUT1', input_port_name='IN1')
+
+

+ 39 - 0
examples/LotkeVolterra/diagram_experiment.py

@@ -0,0 +1,39 @@
+#!/usr/bin/python3
+# This file was automatically generated from drawio2cbd with the command:
+#   /home/red/git/DrawioConvert/__main__.py -F CBD -e LotkeVolterra -fvaSrg -E delta=0.1 -t 100 diagram.io
+
+from diagram import *
+from CBD.simulator import Simulator
+from CBD.realtime.plotting import plot, LinePlot
+from CBD.converters.latexify import CBD2Latex
+
+DELTA_T = 0.01
+END = 10.0
+
+cbd = LotkeVolterra("LotkeVolterra")
+
+cbd.addFixedRateClock("clock", DELTA_T)
+cbd.addConnection("clock-clock", "x", output_port_name="delta", input_port_name='delta_t')
+cbd.addConnection("clock-clock", "y", output_port_name="delta", input_port_name='delta_t')
+
+# cbd.flatten()
+c2l = CBD2Latex(cbd)
+c2l.simplify()
+print(c2l.render())
+
+# cbd = SinGen("SinGen")
+
+# Run the Simulation
+sim = Simulator(cbd)
+sim.setDeltaT(DELTA_T)
+sim.run(END)
+
+# TODO: Process Your Simulation Results
+import matplotlib.pyplot as plt
+fig = plt.figure()
+ax = fig.add_subplot(111)
+plot(cbd.findBlock("rabbits")[0], (fig, ax), LinePlot(color='red', label='rabbits'))
+plot(cbd.findBlock("foxes")[0], (fig, ax), LinePlot(color='blue', label='foxes'))
+# plot(cbd.findBlock("mySin")[0], (fig, ax), LinePlot(color='red'))
+ax.legend()
+plt.show()

+ 4 - 2
src/CBD/depGraph.py

@@ -324,10 +324,12 @@ def gvDepGraph(model, curIt):
 		curIt (int):    The iteration for which the dependency graph will
 						be constructed.
 	"""
-	depGraph = createDepGraph(model, curIt)
+	m2 = model.clone()
+	m2.flatten()
+	depGraph = createDepGraph(m2, curIt)
 	nodes = []
 	edges = []
-	for block in model.getBlocks():
+	for block in m2.getBlocks():
 		nodes.append('{name} [label="{type}({name})"];'.format(name=block.getBlockName(), type=block.getBlockType()))
 		for inf in depGraph.getInfluencers(block):
 			edges.append("{} -> {};".format(block.getBlockName(), inf.getBlockName()))

+ 82 - 9
src/CBD/realtime/plotting.py

@@ -28,6 +28,9 @@ try:
 except ImportError:
 	_BOKEH_FOUND = False
 
+__all__ = ['Backend', 'PlotKind', 'PlotHandler', 'PlotManager', 'plot', 'follow', 'set_xlim', 'set_ylim',
+           'Arrow', 'StepPlot', 'ScatterPlot', 'LinePlot']
+
 # TODO: Bokeh (see TODOs)
 # TODO: More Plot Kinds
 
@@ -146,13 +149,14 @@ class PlotHandler:
 		interval (int):     Amount of milliseconds between plot refreshes.
 		frames (int):       The amount of frames for the animation. Only used if the
 							animation needs to be saved.
+		static (bool):      When :code:`True`, no animation will be created.
 
 	See Also:
 		- :class:`Backend`
 		- :class:`PlotKind`
 		- :class:`PlotManager`
 	"""
-	def __init__(self, object, figure, kind, backend=Backend.MPL, interval=100, frames=None):
+	def __init__(self, object, figure, kind, backend=Backend.MPL, interval=100, frames=None, static=False):
 		assert Backend.exists(backend), "Invalid backend."
 		self.object = object
 		self.kind = kind
@@ -173,14 +177,15 @@ class PlotHandler:
 		}
 
 		# backend info:
-		if Backend.compare("MPL", backend) or Backend.compare("SNS", backend):
-			self.__ani = animation.FuncAnimation(figure[0], lambda _: self.update(),
-			                                     interval=interval, frames=frames)
-			figure[0].canvas.mpl_connect('close_event', lambda evt: self.__close_event())
-		elif Backend.compare("BOKEH", backend):
-			curdoc().add_periodic_callback(lambda: self.update(), interval)
-			# TODO (is this even possible?):
-			curdoc().on_session_destroyed(lambda ctx: self.__close_event())
+		if not static:
+			if Backend.compare("MPL", backend) or Backend.compare("SNS", backend):
+				self.__ani = animation.FuncAnimation(figure[0], lambda _: self.update(),
+				                                     interval=interval, frames=frames)
+				figure[0].canvas.mpl_connect('close_event', lambda evt: self.__close_event())
+			elif Backend.compare("BOKEH", backend):
+				curdoc().add_periodic_callback(lambda: self.update(), interval)
+				# TODO (is this even possible?):
+				curdoc().on_session_destroyed(lambda ctx: self.__close_event())
 
 	def signal(self, name, *args):
 		"""
@@ -792,6 +797,74 @@ class PlotManager:
 			handler.stop()
 
 
+def plot(object, figure, kind, backend=Backend.get("MPL"), margin=(0.02, 0.02)):
+	"""
+	Plot data on a figure after the simulation has finished.
+	This will automatically plot the full graph, so all axis scaling needs to be done afterwards.
+
+	Args:
+		object (Any):       The object from which data needs to be polled. By default,
+							the :code:`data_xy` attribute will be used on the object,
+							which should result in a 2xN array in the form of
+							:code:`([x1, x2, x3...], [y1, y2, y3...])`.
+		figure:             The figure object required for plotting. This has been made
+							externally to allow for full manipulation of the plotting
+							framework itself.
+		kind (PlotKind):    What needs to be plotted. See :class:`PlotKind` for more
+							info.
+		backend (Backend):  The backend to use.
+		margin (tuple):     A margin for the limits of the axes.
+
+	See Also:
+		:class:`PlotHandler`
+	"""
+	assert Backend.exists(backend), "Invalid backend."
+	ph = PlotHandler(object, figure, kind, backend, static=True)
+	ph.update()
+
+	x, y = ph.get_data()
+	xlim = min(x), max(x)
+	ylim = min(y), max(y)
+	xlim = xlim[0] - margin[0], xlim[1] + margin[0]
+	ylim = ylim[0] - margin[1], ylim[1] + margin[1]
+
+	set_xlim(figure, backend, xlim)
+	set_ylim(figure, backend, ylim)
+
+
+def set_xlim(figure, backend, values):
+	"""
+	Shorthand method for setting the x limits of the figure.
+
+	Args:
+		figure:             The figure to alter.
+		backend (Backend):  The backend to use.
+		values (tuple):     The new limits for the x-axis.
+	"""
+	if Backend.compare("MPL", backend) or Backend.compare("SNS", backend):
+		figure[1].set_xlim(values)
+	elif Backend.compare("BKH", backend):
+		lower, upper = values
+		figure.x_range.start = lower
+		figure.x_range.end = upper
+
+
+def set_ylim(figure, backend, values):
+	"""
+	Shorthand method for setting the y limits of the figure.
+
+	Args:
+		figure:             The figure to alter.
+		backend (Backend):  The backend to use.
+		values (tuple):     The new limits for the y-axis.
+	"""
+	if Backend.compare("MPL", backend) or Backend.compare("SNS"):
+		figure[1].set_ylim(values)
+	elif Backend.compare("BKH", backend):
+		lower, upper = values
+		figure.y_range.start = lower
+		figure.y_range.end = upper
+
 # TEMPORARY TEST OBJECT:
 import math
 class __Block:

+ 1 - 0
src/CBD/solver.py

@@ -169,6 +169,7 @@ class LinearSolver(Solver):
 				M1[i, indexdict[dblock]] = - 1
 			elif block.getBlockType() == "DelayBlock":
 				# If a delay is in a strong component, this is the first iteration
+				# FIXME: turn this into a normal error?
 				assert curIteration == 0
 				# And so the dependency is the IC
 				# M2 can stay 0 because we have an equation of the type -x = -ic <=> -x + ic = 0