ソースを参照

more work on StEL

rparedis 3 年 前
コミット
c11bef88da
3 ファイル変更297 行追加14 行削除
  1. 8 0
      src/CBD/lib/std.py
  2. 173 14
      src/CBD/state_events/locators.py
  3. 116 0
      src/test/stelTest.py

+ 8 - 0
src/CBD/lib/std.py

@@ -1158,6 +1158,14 @@ class DummyClock(BaseBlock):
 		"""
 		return self.__delta_t
 
+	def setDeltaT(self, dt):
+		"""
+		Sets the time-delta.
+		Args:
+			dt (float): New time delta.
+		"""
+		self.__delta_t = dt
+
 	def _rewind(self):
 		self.__start_time -= self.__delta_t
 

+ 173 - 14
src/CBD/state_events/locators.py

@@ -2,8 +2,12 @@
 This module contains the standard State Event locators.
 """
 
+import math
 from CBD.state_events import Direction
 
+__all__ = ['PreCrossingStateEventLocator', 'PostCrossingStateEventLocator', 'LinearStateEventLocator',
+           'BisectionStateEventLocator', 'RegulaFalsiStateEventLocator', 'ITPStateEventLocator']
+
 class StateEventLocator:
 	"""
 	Computes the exact level crossing time and locates when a state event must be scheduled.
@@ -29,7 +33,7 @@ class StateEventLocator:
 
 	def detect(self, prev, curr, direction=Direction.ANY):
 		"""
-		Detects that a crossing happened between prev and curr.
+		Detects that a crossing through zero happened between prev and curr.
 
 		Args:
 			prev (numeric):         The previous value.
@@ -96,13 +100,24 @@ class StateEventLocator:
 		Returns:
 			The signal value of the output at the given time, shifted towards 0.
 		"""
+		if callable(output_name):
+			return output_name(time) - level
 		assert time >= self.t_lower
 		self.sim._rewind()
-		# TODO: actually update h
-		h = time - self.t_lower
+		self.setDeltaT(time - self.t_lower)
 		self.sim._lcc_compute()
 		return self.sim.model.getSignal(output_name)[-1].value - level
 
+	def setDeltaT(self, dt):
+		"""
+		'Forces' the time-delta to be this value for the next computation.
+		Args:
+			dt (float): New time-delta.
+		"""
+		# TODO: make this work for non-fixed rate clocks?
+		clock = self.sim.model.getClock()
+		clock.getBlockConnectedToInput("h").block.setValue(dt)
+
 	def run(self, output_name, level=0.0, direction=Direction.ANY):
 		"""
 		Executes the locator for an output.
@@ -118,12 +133,16 @@ class StateEventLocator:
 		Returns:
 			The detected time at which the crossing is suspected to occur.
 		"""
+		h = self.sim.model.getClock().getDeltaT()
+
 		sig = self.sim.model.getSignal(output_name)
-		p1 = sig[-2].time, sig[-2].value
-		p2 = sig[-1].time, sig[-1].value
+		p1 = sig[-2].time, sig[-2].value - level
+		p2 = sig[-1].time, sig[-1].value - level
 		self.t_lower = p1[0]
 		t_crossing = self.algorithm(p1, p2, output_name, level, direction)
-		# TODO: reset delta to old value
+
+		# Reset time-delta after crossing
+		self.setDeltaT(h)
 		return t_crossing
 
 	# TODO: is the direction even required? Isn't it automatically maintained?
@@ -133,13 +152,24 @@ class StateEventLocator:
 		in sub-classes. Should only ever be called if a crossing exists.
 
 		Args:
-			p1 (tuple):         The (time, value) coordinate before the crossing.
-			p2 (tuple):         The (time, value) coordinate after the crossing.
-			output_name (str):  The output port name for which the crossing point
-								must be computed.
+			p1 (tuple):         The (time, value) coordinate before the crossing,
+								shifted towards zero.
+			p2 (tuple):         The (time, value) coordinate after the crossing,
+								shifted towards zero.
+			output_name:        The output port name for which the crossing point
+								must be computed, if a CBD is given. Otherwise, a
+								single-argument callable :math`f(t)` is accepted
+								as well.
 			level (float):      The level through which the crossing must be
 								identified. Defaults to 0.
-			direction (Direction):  The direction of the crossing to detect.
+			direction (Direction):  The direction of the crossing to detect. This
+								value ensures a valid crossing is identified if there
+								are multiple between :attr:`p1` and :attr:`p2`. Will
+								only provide an acceptable result if the direction of
+								the crossing can be identified. For instance, if
+								there is a crossing from below, according to the
+								:meth:`detect` function, the algorithm will usually
+								not accurately identify any crossings from above.
 								Defaults to :attr:`Direction.ANY`.
 
 		Returns:
@@ -191,6 +221,37 @@ class LinearStateEventLocator(StateEventLocator):
 		return (t2 - t1) / (y2 - y1) * (level - y1) + t1
 
 
+class BisectionStateEventLocator(StateEventLocator):
+	"""
+	Uses the bisection method to compute the crossing. This method is more accurate
+	than a linear algorithm :class:`LinearStateEventLocator`, but less accurate than
+	regula falsi (:class:`RegulaFalsiStateEventLocator`).
+
+	Args:
+		n (int):    The maximal amount of iterations to compute. Roughly very 3 iterations,
+					a decimal place of accuracy is gained. Defaults to 10.
+	"""
+	def __init__(self, n=10):
+		assert n > 0, "There must be at least 1 iteration for this method."
+		super(BisectionStateEventLocator, self).__init__()
+		self.n = n
+
+	def algorithm(self, p1, p2, output_name, level=0.0, direction=Direction.ANY):
+		tc = p1[0]
+		for i in range(self.n):
+			tc = (p1[0] + p2[0]) / 2
+			yc = self._function(output_name, tc, level)
+
+			if self.detect(p1[1], yc, direction):
+				p2 = tc, yc
+			elif self.detect(yc, p2[1], direction):
+				p1 = tc, yc
+			else:
+				break
+				# raise ValueError("Cannot find a viable crossing.")
+		return tc
+
+
 class RegulaFalsiStateEventLocator(StateEventLocator):
 	"""
 	Implements the Illinois algorithm for finding the root for a crossing problem.
@@ -217,17 +278,24 @@ class RegulaFalsiStateEventLocator(StateEventLocator):
 		t2, y2 = p2
 		tn, yn = t1, y1
 
+		y1 -= level
+		y2 -= level
+
 		side = 0
 		for i in range(self.n):
 			if abs(t1 - t2) < self.eps * abs(t1 + t2): break
-			tn = (y1 * t2 - y2 * t1) / (y1 - y2)
+			if abs(y1 - y2) < self.eps:
+				tn = (t2 - t1) / 2 + t1
+			else:
+				tn = (y1 * t2 - y2 * t1) / (y1 - y2)
 			yn = self._function(output_name, tn, level)
-			if yn * y2 > 0:
+
+			if self.detect(y1, yn, direction):
 				t2, y2 = tn, yn
 				if side == -1:
 					y1 /= 2
 				side = -1
-			elif yn * y1 > 0:
+			elif self.detect(yn, y2, direction):
 				t1, y1 = tn, yn
 				if side == 1:
 					y2 /= 2
@@ -235,3 +303,94 @@ class RegulaFalsiStateEventLocator(StateEventLocator):
 			else:
 				break
 		return tn
+
+
+class ITPStateEventLocator(StateEventLocator):
+	r"""
+	Implements the Interpolation-Truncation-Projection algorithm for finding
+	the root of a function.
+
+	Args:
+		eps (float):    Minimal interval size. Defaults to 1e-5.
+		k1 (float):     First truncation size hyperparameter. Must be in the
+						range of :math:`(0, \infty)`. Defaults to 0.1.
+		k2 (float):     Second truncation size hyperparameter. Must be in the
+						range of :math:`[1, 1 + \frac{1}{2}(1 + \sqrt{5})]`.
+						Defaults to 1.5.
+		n0 (float):     Slack variable to control the size of the interval for
+						the projection step. Must be in :math:`[0, \infty)`.
+						When 0, the average number of iterations will be less
+						than that of the bisection method. Defaults to 0.
+
+	See Also:
+		https://en.wikipedia.org/wiki/ITP_method
+	"""
+	def __init__(self, eps=1e-5, k1=0.1, k2=1.5, n0=0):
+		assert 0 < k1, "For ITP, k1 must be strictly positive."
+		assert 1 <= k2 <= (1 + (1. + 5 ** 0.5) / 2.), "For ITP, k2 must be in [1, 1 + phi]."
+		assert 0 <= n0, "For ITP, n0 must be positive or zero."
+
+		super(ITPStateEventLocator, self).__init__()
+
+		self.eps = eps
+		self.k1 = k1
+		self.k2 = k2
+		self.n0 = n0
+
+	def algorithm(self, p1, p2, output_name, level=0.0, direction=Direction.ANY):
+		sign = lambda x: 1 if x > 0 else (-1 if x < 0 else 0)
+
+		a, ya = p1
+		b, yb = p2
+
+		ya -= level
+		yb -= level
+
+		if ya == 0:
+			return a
+		if yb == 0:
+			return b
+
+		# Preprocessing
+		nh = math.ceil(math.log((b - a) / (2 * self.eps), 2))
+		nm = nh + self.n0
+		j = 0
+
+		while (b - a) > (2 * self.eps):
+			xh = (b - a) / 2 + a
+			r = self.eps * 2 ** (nm - j) - (b - a) / 2
+			d = self.k1 * (b - a) ** self.k2
+
+			# Interpolation
+			if abs(yb - ya) < self.eps:
+				xf = xh
+			else:
+				xf = (yb * a - ya * b) / (yb - ya)
+
+			# Truncation
+			s = sign(xh - xf)
+			if d <= abs(xh - xf):
+				xt = xf + s * d
+			else:
+				xt = xh
+
+			# Projection
+			if abs(xt - xh) <= r:
+				xI = xt
+			else:
+				xI = xh - s * r
+
+			# Update Interval
+			yI = self._function(output_name, xI, level)
+			if (ya - yb) * yI < 0 and self.detect(ya, yI, direction):
+				b = xI
+				yb = yI
+			elif (ya - yb) * yI > 0 and self.detect(yI, yb, direction):
+				a = xI
+				ya = yI
+			else:
+				a = xI
+				b = xI
+			j += 1
+
+		return (a + b) / 2

+ 116 - 0
src/test/stelTest.py

@@ -0,0 +1,116 @@
+#!/usr/bin/env python
+"""
+Unit tests for the state event locators
+"""
+
+import unittest
+import math
+
+from CBD.state_events.locators import *
+from CBD.state_events import Direction
+
+class StELTestCase(unittest.TestCase):
+	def setUp(self) -> None:
+		self.func = lambda t: (math.sin(t) * math.cos(2*t)) + 6
+		self.level = 6
+		self.eps = 1e-5
+
+		x1 = -0.421
+		y1 = self.func(x1)
+		x2 = 0.421
+		y2 = self.func(x2)
+		x3 = 2.721
+		y3 = self.func(x2)
+
+		self.p1 = x1, y1
+		self.p2 = x2, y2
+		self.p3 = x3, y3
+
+	def testPre(self):
+		stel = PreCrossingStateEventLocator()
+
+		a = stel.algorithm(self.p1, self.p3, self.func, self.level, Direction.ANY)
+		b = stel.algorithm(self.p1, self.p3, self.func, self.level, Direction.FROM_BELOW)
+		c = stel.algorithm(self.p1, self.p3, self.func, self.level, Direction.FROM_ABOVE)
+		d = stel.algorithm(self.p2, self.p3, self.func, self.level, Direction.ANY)
+
+		self.assertAlmostEqual(self.p1[0], a, 5)
+		self.assertAlmostEqual(self.p1[0], b, 5)
+		self.assertAlmostEqual(self.p1[0], c, 5)
+		self.assertAlmostEqual(self.p2[0], d, 5)
+
+	def testPost(self):
+		stel = PostCrossingStateEventLocator()
+
+		a = stel.algorithm(self.p1, self.p3, self.func, self.level, Direction.ANY)
+		b = stel.algorithm(self.p1, self.p3, self.func, self.level, Direction.FROM_BELOW)
+		c = stel.algorithm(self.p1, self.p3, self.func, self.level, Direction.FROM_ABOVE)
+		d = stel.algorithm(self.p1, self.p2, self.func, self.level, Direction.ANY)
+
+		self.assertAlmostEqual(self.p3[0], a, 5)
+		self.assertAlmostEqual(self.p3[0], b, 5)
+		self.assertAlmostEqual(self.p3[0], c, 5)
+		self.assertAlmostEqual(self.p2[0], d, 5)
+
+	def testLinear(self):
+		stel = LinearStateEventLocator()
+
+		a = stel.algorithm(self.p1, self.p3, self.func, self.level, Direction.ANY)
+		b = stel.algorithm(self.p1, self.p3, self.func, self.level, Direction.FROM_BELOW)
+		c = stel.algorithm(self.p1, self.p3, self.func, self.level, Direction.FROM_ABOVE)
+		d = stel.algorithm(self.p2, self.p3, self.func, self.level, Direction.ANY)
+
+		mid = (self.p3[0] - self.p1[0]) / 2 + self.p1[0]
+
+		self.assertAlmostEqual(mid, a, 5)
+		self.assertAlmostEqual(mid, b, 5)
+		self.assertAlmostEqual(mid, c, 5)
+		self.assertAlmostEqual(self.p2[0], d, 5)
+
+	def testBisection(self):
+		stel = BisectionStateEventLocator(200)
+
+		a = stel.algorithm(self.p1, self.p3, self.func, self.level, Direction.ANY)
+		b = stel.algorithm(self.p1, self.p3, self.func, self.level, Direction.FROM_BELOW)
+		c = stel.algorithm(self.p2, self.p3, self.func, self.level, Direction.ANY)
+		d = stel.algorithm(self.p2, self.p3, self.func, self.level, Direction.FROM_ABOVE)
+
+		x2 = math.pi * 0.25
+		x3 = math.pi * 0.75
+
+		self.assertAlmostEqual(x2, a, 5)
+		self.assertAlmostEqual(x3, b, 5)
+		self.assertAlmostEqual(x2, c, 5)
+		self.assertAlmostEqual(x2, d, 5)
+
+	def testRegulaFalsi(self):
+		stel = RegulaFalsiStateEventLocator()
+
+		a = stel.algorithm(self.p1, self.p3, self.func, self.level, Direction.ANY)
+		b = stel.algorithm(self.p1, self.p3, self.func, self.level, Direction.FROM_BELOW)
+		c = stel.algorithm(self.p2, self.p3, self.func, self.level, Direction.ANY)
+		d = stel.algorithm(self.p2, self.p3, self.func, self.level, Direction.FROM_ABOVE)
+
+		x2 = math.pi * 0.25
+		x3 = math.pi * 0.75
+
+		self.assertAlmostEqual(x3, a, 5)
+		self.assertAlmostEqual(x3, b, 5)
+		self.assertAlmostEqual(x2, c, 5)
+		self.assertAlmostEqual(x2, d, 5)
+
+	def testITP(self):
+		stel = RegulaFalsiStateEventLocator()
+
+		a = stel.algorithm(self.p1, self.p3, self.func, self.level, Direction.ANY)
+		b = stel.algorithm(self.p1, self.p3, self.func, self.level, Direction.FROM_BELOW)
+		c = stel.algorithm(self.p2, self.p3, self.func, self.level, Direction.ANY)
+		d = stel.algorithm(self.p2, self.p3, self.func, self.level, Direction.FROM_ABOVE)
+
+		x2 = math.pi * 0.25
+		x3 = math.pi * 0.75
+
+		self.assertAlmostEqual(x3, a, 5)
+		self.assertAlmostEqual(x3, b, 5)
+		self.assertAlmostEqual(x2, c, 5)
+		self.assertAlmostEqual(x2, d, 5)