locators.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. """
  2. This module contains the standard State Event locators.
  3. """
  4. import math
  5. from CBD.state_events import Direction
  6. __all__ = ['PreCrossingStateEventLocator', 'PostCrossingStateEventLocator', 'LinearStateEventLocator',
  7. 'BisectionStateEventLocator', 'RegulaFalsiStateEventLocator', 'ITPStateEventLocator']
  8. class StateEventLocator:
  9. """
  10. Computes the exact level crossing time and locates when a state event must be scheduled.
  11. Attributes:
  12. sim (CBD.simulator.Simulator): The simulator to which the locator belongs.
  13. t_lower (float): The lower range of the level crossing. It is certain
  14. that the crossing happens at a time later than (or
  15. equal to) this time.
  16. """
  17. def __init__(self):
  18. self.sim = None
  19. self.t_lower = 0.0
  20. def setSimulator(self, sim):
  21. """
  22. Sets the simulator to the event locator.
  23. Args:
  24. sim (CBD.simulator.Simulator): The current simulator.
  25. """
  26. self.sim = sim
  27. def detect(self, prev, curr, direction=Direction.ANY):
  28. """
  29. Detects that a crossing through zero happened between prev and curr.
  30. Args:
  31. prev (numeric): The previous value.
  32. curr (numeric): The current value.
  33. direction (Direction): The direction of the crossing to detect.
  34. Defaults to :attr:`Direction.ANY`.
  35. Returns:
  36. :code:`True` when the crossing happened, otherwise :code:`False`.
  37. """
  38. if direction == Direction.FROM_BELOW:
  39. return prev <= 0 <= curr
  40. if direction == Direction.FROM_ABOVE:
  41. return prev >= 0 >= curr
  42. if direction == Direction.ANY:
  43. return (prev <= 0 <= curr) or (prev >= 0 >= curr)
  44. return False
  45. def detect_signal(self, output_name, level=0.0, direction=Direction.ANY):
  46. """
  47. Detects that an output port has a crossing through a specific level.
  48. Args:
  49. output_name (str): The name of the output port.
  50. level (numeric): The level through which the value must go.
  51. Defaults to 0.
  52. direction (Direction): The direction of the crossing to detect.
  53. Defaults to :attr:`Direction.ANY`.
  54. Returns:
  55. :code:`True` when the crossing happened, otherwise :code:`False`.
  56. """
  57. sig = self.sim.model.getSignal(output_name)
  58. if len(sig) < 2:
  59. # No crossing possible (yet)
  60. return False
  61. prev = sig[-2].value - level
  62. curr = sig[-1].value - level
  63. return self.detect(prev, curr, direction)
  64. def _function(self, output_name, time, level=0.0):
  65. """
  66. The internal function. Whenever an algorithm requires the computation of the
  67. CBD model at another time, this function can be executed.
  68. Note:
  69. The CBD will remain at the computed time afterwards. Use
  70. :meth:`CBD.simulator._rewind` to undo the actions of this
  71. function.
  72. Args:
  73. output_name (str): The output port name for which the crossing point must
  74. be computed.
  75. time (float): The time at which the CBD must be computed. Must be
  76. larger than the lower bound time.
  77. level (float): The level through which the crossing must be identified.
  78. This mainly shifts the signal towards 0, as most algorithms
  79. are basically root finders. If the algorithm incorporates
  80. the level itself, keep this value at 0 for correct behaviour.
  81. Defaults to 0.
  82. Returns:
  83. The signal value of the output at the given time, shifted towards 0.
  84. """
  85. if callable(output_name):
  86. return output_name(time) - level
  87. assert time >= self.t_lower
  88. self.sim._rewind()
  89. self.setDeltaT(time - self.t_lower)
  90. self.sim._lcc_compute()
  91. return self.sim.model.getSignal(output_name)[-1].value - level
  92. def setDeltaT(self, dt):
  93. """
  94. 'Forces' the time-delta to be this value for the next computation.
  95. Args:
  96. dt (float): New time-delta.
  97. """
  98. # TODO: make this work for non-fixed rate clocks?
  99. clock = self.sim.model.getClock()
  100. clock.getBlockConnectedToInput("h").block.setValue(dt)
  101. def run(self, output_name, level=0.0, direction=Direction.ANY):
  102. """
  103. Executes the locator for an output.
  104. Args:
  105. output_name (str): The output port name for which the crossing
  106. point must be computed.
  107. level (float): The level through which the crossing must be
  108. identified. Defaults to 0.
  109. direction (Direction): The direction of the crossing to detect.
  110. Defaults to :attr:`Direction.ANY`.
  111. Returns:
  112. The detected time at which the crossing is suspected to occur.
  113. """
  114. h = self.sim.model.getClock().getDeltaT()
  115. sig = self.sim.model.getSignal(output_name)
  116. p1 = sig[-2].time, sig[-2].value - level
  117. p2 = sig[-1].time, sig[-1].value - level
  118. self.t_lower = p1[0]
  119. t_crossing = self.algorithm(p1, p2, output_name, level, direction)
  120. # Reset time-delta after crossing
  121. self.setDeltaT(h)
  122. return t_crossing
  123. # TODO: is the direction even required? Isn't it automatically maintained?
  124. def algorithm(self, p1, p2, output_name, level=0.0, direction=Direction.ANY):
  125. """
  126. The algorithm that identifies the locator functionality. Must be implemented
  127. in sub-classes. Should only ever be called if a crossing exists.
  128. Args:
  129. p1 (tuple): The (time, value) coordinate before the crossing,
  130. shifted towards zero.
  131. p2 (tuple): The (time, value) coordinate after the crossing,
  132. shifted towards zero.
  133. output_name: The output port name for which the crossing point
  134. must be computed, if a CBD is given. Otherwise, a
  135. single-argument callable :math`f(t)` is accepted
  136. as well.
  137. level (float): The level through which the crossing must be
  138. identified. Defaults to 0.
  139. direction (Direction): The direction of the crossing to detect. This
  140. value ensures a valid crossing is identified if there
  141. are multiple between :attr:`p1` and :attr:`p2`. Will
  142. only provide an acceptable result if the direction of
  143. the crossing can be identified. For instance, if
  144. there is a crossing from below, according to the
  145. :meth:`detect` function, the algorithm will usually
  146. not accurately identify any crossings from above.
  147. Defaults to :attr:`Direction.ANY`.
  148. Returns:
  149. A suspected time of the crossing.
  150. """
  151. raise NotImplementedError()
  152. class PreCrossingStateEventLocator(StateEventLocator):
  153. """
  154. Assumes that the crossing happens at the start of the interval.
  155. Can be used if a precise detection is not a requirement.
  156. This implementation computes a rough under-estimate.
  157. """
  158. def algorithm(self, p1, p2, output_name, level=0.0, direction=Direction.ANY):
  159. return p1[0]
  160. class PostCrossingStateEventLocator(StateEventLocator):
  161. """
  162. Assumes that the crossing happens at the end of the interval.
  163. Can be used if a precise detection is not a requirement.
  164. This implementation computes a rough over-estimate.
  165. Corresponds to the :code:`if` statement in `Modelica <https://modelica.org/>`_,
  166. whereas the other locators can be seen as the :code:`when` statement.
  167. """
  168. def algorithm(self, p1, p2, output_name, level=0.0, direction=Direction.ANY):
  169. return p2[0]
  170. class LinearStateEventLocator(StateEventLocator):
  171. """
  172. Uses simple linear interpolation to compute the time of the crossing.
  173. This is usually a rough, yet centered estimate.
  174. This locator should only be used if it is known that the signal is
  175. (mostly) linear between the lower and upper bounds.
  176. """
  177. def algorithm(self, p1, p2, output_name, level=0.0, direction=Direction.ANY):
  178. t1, y1 = p1
  179. t2, y2 = p2
  180. if y1 == y2:
  181. return t1
  182. # Use the equation of a line between two points
  183. # Formula is easier if x and y axes are swapped.
  184. return (t2 - t1) / (y2 - y1) * (level - y1) + t1
  185. class BisectionStateEventLocator(StateEventLocator):
  186. """
  187. Uses the bisection method to compute the crossing. This method is more accurate
  188. than a linear algorithm :class:`LinearStateEventLocator`, but less accurate than
  189. regula falsi (:class:`RegulaFalsiStateEventLocator`).
  190. Args:
  191. n (int): The maximal amount of iterations to compute. Roughly very 3 iterations,
  192. a decimal place of accuracy is gained. Defaults to 10.
  193. """
  194. def __init__(self, n=10):
  195. assert n > 0, "There must be at least 1 iteration for this method."
  196. super(BisectionStateEventLocator, self).__init__()
  197. self.n = n
  198. def algorithm(self, p1, p2, output_name, level=0.0, direction=Direction.ANY):
  199. tc = p1[0]
  200. for i in range(self.n):
  201. tc = (p1[0] + p2[0]) / 2
  202. yc = self._function(output_name, tc, level)
  203. if self.detect(p1[1], yc, direction):
  204. p2 = tc, yc
  205. elif self.detect(yc, p2[1], direction):
  206. p1 = tc, yc
  207. else:
  208. break
  209. # raise ValueError("Cannot find a viable crossing.")
  210. return tc
  211. class RegulaFalsiStateEventLocator(StateEventLocator):
  212. """
  213. Implements the Illinois algorithm for finding the root for a crossing problem.
  214. Args:
  215. eps (float): Half of the upper bound for the relative error.
  216. Defaults to 1e-5.
  217. n (int): The maximal amount of iterations to compute. Defaults to
  218. 5 million iterations.
  219. See Also:
  220. https://en.wikipedia.org/wiki/Regula_falsi
  221. """
  222. def __init__(self, eps=1e-5, n=5_000_000):
  223. super(RegulaFalsiStateEventLocator, self).__init__()
  224. self.eps = eps
  225. self.n = n
  226. def algorithm(self, p1, p2, output_name, level=0.0, direction=Direction.ANY):
  227. # direction unused, because the algorithm will automatically maintain
  228. # the crossing direction
  229. t1, y1 = p1
  230. t2, y2 = p2
  231. tn, yn = t1, y1
  232. y1 -= level
  233. y2 -= level
  234. side = 0
  235. for i in range(self.n):
  236. if abs(t1 - t2) < self.eps * abs(t1 + t2): break
  237. if abs(y1 - y2) < self.eps:
  238. tn = (t2 - t1) / 2 + t1
  239. else:
  240. tn = (y1 * t2 - y2 * t1) / (y1 - y2)
  241. yn = self._function(output_name, tn, level)
  242. if self.detect(y1, yn, direction):
  243. t2, y2 = tn, yn
  244. if side == -1:
  245. y1 /= 2
  246. side = -1
  247. elif self.detect(yn, y2, direction):
  248. t1, y1 = tn, yn
  249. if side == 1:
  250. y2 /= 2
  251. side = 1
  252. else:
  253. break
  254. return tn
  255. class ITPStateEventLocator(StateEventLocator):
  256. r"""
  257. Implements the Interpolation-Truncation-Projection algorithm for finding
  258. the root of a function.
  259. Args:
  260. eps (float): Minimal interval size. Defaults to 1e-5.
  261. k1 (float): First truncation size hyperparameter. Must be in the
  262. range of :math:`(0, \infty)`. Defaults to 0.1.
  263. k2 (float): Second truncation size hyperparameter. Must be in the
  264. range of :math:`[1, 1 + \frac{1}{2}(1 + \sqrt{5})]`.
  265. Defaults to 1.5.
  266. n0 (float): Slack variable to control the size of the interval for
  267. the projection step. Must be in :math:`[0, \infty)`.
  268. When 0, the average number of iterations will be less
  269. than that of the bisection method. Defaults to 0.
  270. See Also:
  271. https://en.wikipedia.org/wiki/ITP_method
  272. """
  273. def __init__(self, eps=1e-5, k1=0.1, k2=1.5, n0=0):
  274. assert 0 < k1, "For ITP, k1 must be strictly positive."
  275. assert 1 <= k2 <= (1 + (1. + 5 ** 0.5) / 2.), "For ITP, k2 must be in [1, 1 + phi]."
  276. assert 0 <= n0, "For ITP, n0 must be positive or zero."
  277. super(ITPStateEventLocator, self).__init__()
  278. self.eps = eps
  279. self.k1 = k1
  280. self.k2 = k2
  281. self.n0 = n0
  282. def algorithm(self, p1, p2, output_name, level=0.0, direction=Direction.ANY):
  283. sign = lambda x: 1 if x > 0 else (-1 if x < 0 else 0)
  284. a, ya = p1
  285. b, yb = p2
  286. ya -= level
  287. yb -= level
  288. if ya == 0:
  289. return a
  290. if yb == 0:
  291. return b
  292. # Preprocessing
  293. nh = math.ceil(math.log((b - a) / (2 * self.eps), 2))
  294. nm = nh + self.n0
  295. j = 0
  296. while (b - a) > (2 * self.eps):
  297. xh = (b - a) / 2 + a
  298. r = self.eps * 2 ** (nm - j) - (b - a) / 2
  299. d = self.k1 * (b - a) ** self.k2
  300. # Interpolation
  301. if abs(yb - ya) < self.eps:
  302. xf = xh
  303. else:
  304. xf = (yb * a - ya * b) / (yb - ya)
  305. # Truncation
  306. s = sign(xh - xf)
  307. if d <= abs(xh - xf):
  308. xt = xf + s * d
  309. else:
  310. xt = xh
  311. # Projection
  312. if abs(xt - xh) <= r:
  313. xI = xt
  314. else:
  315. xI = xh - s * r
  316. # Update Interval
  317. yI = self._function(output_name, xI, level)
  318. if (ya - yb) * yI < 0 and self.detect(ya, yI, direction):
  319. b = xI
  320. yb = yI
  321. elif (ya - yb) * yI > 0 and self.detect(yI, yb, direction):
  322. a = xI
  323. ya = yI
  324. else:
  325. a = xI
  326. b = xI
  327. j += 1
  328. return (a + b) / 2