Bläddra i källkod

Multiple State Events

rparedis 3 år sedan
förälder
incheckning
da19d90d1f

+ 1 - 0
doc/CBD.rst

@@ -16,6 +16,7 @@ Subpackages
     CBD.tracers
     CBD.converters
     CBD.loopsolvers
+    CBD.state_events
     CBD.preprocessing
 
 Submodules

+ 7 - 0
doc/CBD.state_events.locators.rst

@@ -0,0 +1,7 @@
+CBD.state_events.locators module
+================================
+
+.. automodule:: CBD.state_events.locators
+    :members:
+    :undoc-members:
+    :show-inheritance:

+ 15 - 0
doc/CBD.state_events.rst

@@ -0,0 +1,15 @@
+CBD.state_events module
+=======================
+
+.. automodule:: CBD.state_events
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+
+   CBD.state_events.locators
+

+ 6 - 3
examples/scripts/BouncingBall/BouncingBall.py

@@ -4,7 +4,7 @@ from CBD.lib.endpoints import SignalCollectorBlock
 
 class BouncingBall(CBD):
 	def __init__(self, k=0.7):
-		super(BouncingBall, self).__init__("BouncingBall", output_ports=["height"])
+		super(BouncingBall, self).__init__("BouncingBall", output_ports=["height", "velocity"])
 		self.k = k
 
 		self.addBlock(ConstantBlock("g", -9.81))
@@ -12,12 +12,15 @@ class BouncingBall(CBD):
 		self.addBlock(ConstantBlock("y0", 100))
 		self.addBlock(IntegratorBlock("v"))
 		self.addBlock(IntegratorBlock("y"))
-		self.addBlock(SignalCollectorBlock("plot"))
+		self.addBlock(SignalCollectorBlock("plot1"))
+		self.addBlock(SignalCollectorBlock("plot2"))
 
 		self.addConnection("g", "v")
 		self.addConnection("v", "y")
+		self.addConnection("v", "velocity")
 		self.addConnection("y", "height")
-		self.addConnection("y", "plot")
+		self.addConnection("y", "plot1")
+		self.addConnection("v", "plot2")
 		self.addConnection("v0", "v", input_port_name="IC")
 		self.addConnection("y0", "y", input_port_name="IC")
 

+ 26 - 11
examples/scripts/BouncingBall/BouncingBall_experiment.py

@@ -7,22 +7,36 @@ bb = BouncingBall()
 
 DELTA = 0.1
 TIME = 15.0
-# TIME = 4
 
 fig = plt.figure(figsize=(5, 5), dpi=100)
-ax = fig.add_subplot(111)
-ax.set_xlim((0, TIME))
-ax.set_ylim((0, 105))
-plot = fig, ax
+ax1 = fig.add_subplot(121)
+ax2 = fig.add_subplot(122)
+ax1.set_xlim((0, TIME))
+ax1.set_ylim((-1, 105))
+ax1.set_ylabel("height")
+ax2.set_xlim((0, TIME))
+ax2.set_ylim((-50, 50))
+ax2.set_ylabel("velocity")
+plot1 = fig, ax1
+plot2 = fig, ax2
+
+ax1.plot((0, TIME+DELTA), (0, 0), c='purple', ls='--', lw=0.1)
 
 manager = PlotManager()
-manager.register("height", bb.getBlockByName('plot'), plot, LinePlot(color='red'))
+manager.register("height", bb.getBlockByName('plot1'), plot1, LinePlot(color='red'))
+manager.register("velocity", bb.getBlockByName('plot2'), plot2, LinePlot(color='blue'))
 
 from CBD.state_events import StateEvent, Direction
-from CBD.state_events.locators import ITPStateEventLocator
+from CBD.state_events.locators import *
 
-def bounce(_, model):
-	model.bounce()
+def bounce(e, t, model):
+	if e.output_name == "height":
+		model.bounce()
+		print("BOUNCE AT:", t)
+	else:
+		y = model.getSignal("height")[-1].value
+		model.getBlockByName("y0").setValue(y)
+		model.getBlockByName("v0").setValue(-41)
 
 lengs = []
 times = []
@@ -34,6 +48,7 @@ def poststep(o, t, st):
 sim = Simulator(bb)
 sim.setStateEventLocator(ITPStateEventLocator())
 sim.registerStateEvent(StateEvent("height", direction=Direction.FROM_ABOVE, event=bounce))
+sim.registerStateEvent(StateEvent("velocity", direction=Direction.FROM_ABOVE, level=-40.0, event=bounce))
 sim.connect("poststep", poststep)
 sim.setRealTime()
 sim.setDeltaT(DELTA)
@@ -52,12 +67,12 @@ ax = fig.add_subplot(111)
 ax.set_xticks(np.arange(0, TIME, DELTA), minor=True)
 ax.set_xlim((0, TIME+DELTA))
 
-ax.bar(times, lengs, width=0.1, color='green', label='duration')
+ax.bar(times, lengs, width=0.05, color='green', label='duration')
 # ax.plot((0, TIME+DELTA), (DELTA, DELTA), c='red', ls='--')
 
 ax2 = ax.twinx()
 ax2.set_ylim((0, 105))
-ax2.plot(*bb.getBlockByName('plot').data_xy, c='blue', label='simulation')
+ax2.plot(*bb.getBlockByName('plot1').data_xy, c='blue', label='simulation')
 
 handles, labels = ax.get_legend_handles_labels()
 handles2, labels2 = ax2.get_legend_handles_labels()

+ 1 - 1
src/CBD/realtime/accurate_time.py

@@ -14,7 +14,7 @@ def time():
 		# better precision on windows, but deprecated since 3.3
 		return python_time.clock()
 	else:
-		return python_time.process_time()
+		return python_time.perf_counter()
 
 def sleep(t):
 	"""

+ 20 - 10
src/CBD/simulator.py

@@ -526,28 +526,38 @@ class Simulator:
 		self.__sim_data[1] = self.__scheduler.obtain(self.__sim_data[0], curIt, simT)
 		self.__computeBlocks(self.__sim_data[1], self.__sim_data[0], self.__sim_data[2])
 		self.__sim_data[2] += 1
-		# TODO: multiple LCC
+
+		# State Event Location
+		lcc = float('inf')
+		lcc_evt = None
 		for event in self.__state_events:
 			if not event.fired and self.__stel.detect_signal(event.output_name, event.level, event.direction):
 				event.fired = True
 				t = self.__stel.run(event.output_name, event.level, event.direction)
 
-				# Reset the model
-				self.model._rewind()
-				event.event(t, self.model)
-
-				# reset to allow for new IC computation
-				self.model.clearSignals()
-				self.model.getClock().setStartTime(t)
-				self.__sim_data[0] = None
-				self.__sim_data[2] = 0
+				lcc_evt = event
+				lcc = min(lcc, t)
 			elif event.fired:
 				event.fired = False
+
+		if lcc != float('inf'):
+			# TODO: Ideally, the model was cached here and should not be recomputed
+			self.__stel._function(lcc_evt.output_name, lcc, lcc_evt.level, noop=False)
+
+			lcc_evt.event(lcc_evt, lcc, self.model)
+			self.model._rewind()
+
+			# reset to allow for new IC computation
+			self.model.clearSignals()
+			self.model.getClock().setStartTime(lcc)
+			self.__sim_data[0] = None
+			self.__sim_data[2] = 0
 		post = time.time()
 		self.signal("poststep", pre, post, self.getTime())
 
 	def _lcc_compute(self):
 		self.__computeBlocks(self.__sim_data[1], self.__sim_data[0], self.__sim_data[2])
+		self.__sim_data[2] += 1
 
 	def _rewind(self):
 		self.__sim_data[2] -= 1

+ 2 - 1
src/CBD/state_events/__init__.py

@@ -31,7 +31,8 @@ class StateEvent:
 		direction (Direction):  The direction of the crossing.
 								Defaults to :attr:`Direction.ANY`.
 		event (callable):       A function that must be executed if the event
-								occurs. It takes two arguments: time and model.
+								occurs. It takes three arguments: event, time and
+								model.
 								In this function, it is allowed to alter any
 								and all attributes/properties/components of the
 								model. Defaults to a no-op.

+ 17 - 11
src/CBD/state_events/locators.py

@@ -76,16 +76,11 @@ class StateEventLocator:
 
 		return self.detect(prev, curr, direction)
 
-	def _function(self, output_name, time, level=0.0):
+	def _function(self, output_name, time, level=0.0, noop=True):
 		"""
 		The internal function. Whenever an algorithm requires the computation of the
 		CBD model at another time, this function can be executed.
 
-		Note:
-			The CBD will remain at the computed time afterwards. Use
-			:meth:`CBD.simulator._rewind` to undo the actions of this
-			function.
-
 		Args:
 			output_name (str):  The output port name for which the crossing point must
 								be computed.
@@ -96,6 +91,8 @@ class StateEventLocator:
 								are basically root finders. If the algorithm incorporates
 								the level itself, keep this value at 0 for correct behaviour.
 								Defaults to 0.
+			noop (bool):        When :code:`True`, this function will be a no-op. Otherwise,
+								the model will remain at the given time.
 
 		Returns:
 			The signal value of the output at the given time, shifted towards 0.
@@ -103,10 +100,19 @@ class StateEventLocator:
 		if callable(output_name):
 			return output_name(time) - level
 		assert time >= self.t_lower
+
+		h = self.sim.model.getClock().getDeltaT()
 		self.sim._rewind()
 		self.setDeltaT(time - self.t_lower)
 		self.sim._lcc_compute()
-		return self.sim.model.getSignal(output_name)[-1].value - level
+		s = self.sim.model.getSignal(output_name)[-1].value - level
+
+		if noop:
+			self.sim._rewind()
+		self.setDeltaT(h)
+		if noop:
+			self.sim._lcc_compute()
+		return s
 
 	def setDeltaT(self, dt):
 		"""
@@ -133,16 +139,16 @@ class StateEventLocator:
 		Returns:
 			The detected time at which the crossing is suspected to occur.
 		"""
-		h = self.sim.model.getClock().getDeltaT()
 
 		sig = self.sim.model.getSignal(output_name)
 		p1 = sig[-2].time, sig[-2].value - level
 		p2 = sig[-1].time, sig[-1].value - level
 		self.t_lower = p1[0]
+
+		# begin the algorithm on the left
+		self.sim._rewind()
 		t_crossing = self.algorithm(p1, p2, output_name, level, direction)
 
-		# Reset time-delta after crossing
-		self.setDeltaT(h)
 		return t_crossing
 
 	def algorithm(self, p1, p2, output_name, level=0.0, direction=Direction.ANY):
@@ -238,7 +244,7 @@ class BisectionStateEventLocator(StateEventLocator):
 	def algorithm(self, p1, p2, output_name, level=0.0, direction=Direction.ANY):
 		tc = p1[0]
 		for i in range(self.n):
-			tc = (p1[0] + p2[0]) / 2
+			tc = ((p2[0] - p1[0]) / 2) + p1[0]
 			yc = self._function(output_name, tc, level)
 
 			if self.detect(p1[1], yc, direction):