tree.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. import termcolor
  2. from typing import *
  3. import itertools
  4. from sccd.statechart.static.action import *
  5. from sccd.util.bitmap import *
  6. from sccd.util import timer
  7. from sccd.util.visit_tree import *
  8. from sccd.util.freezable import *
  9. class State(Freezable):
  10. __slots__ = ["short_name", "parent", "stable", "children", "default_state", "transitions", "enter", "exit", "opt"]
  11. def __init__(self, short_name: str, parent: Optional['State']):
  12. super().__init__()
  13. self.short_name: str = short_name # value of 'id' attribute in XML
  14. self.parent: Optional['State'] = parent # only None if root state
  15. self.stable: bool = False # whether this is a stable stabe. this field is ignored if maximality semantics is not set to SYNTACTIC
  16. self.children: List['State'] = []
  17. self.default_state: 'State' = None # child state pointed to by 'initial' attribute
  18. self.transitions: List['Transition'] = []
  19. self.enter: List[Action] = []
  20. self.exit: List[Action] = []
  21. self.opt: Optional['StateOptimization'] = None
  22. if self.parent is not None:
  23. self.parent.children.append(self)
  24. def _static_additional_target_states(self, exclude: 'State') -> Tuple[Bitmap, List['HistoryState']]:
  25. return (self.opt.state_id_bitmap, [])
  26. def __repr__(self):
  27. return "State(\"%s\")" % (self.short_name)
  28. class HistoryState(State):
  29. __slots__ = ["history_id"]
  30. def __init__(self, short_name: str, parent: Optional['State']):
  31. super().__init__(short_name, parent)
  32. self.history_id: Optional[int] = None
  33. # Set of states that may be history values.
  34. @abstractmethod
  35. def history_mask(self) -> Bitmap:
  36. pass
  37. def _static_additional_target_states(self, exclude: 'State') -> Tuple[Bitmap, List['HistoryState']]:
  38. assert False # history state cannot have children and therefore should never occur in a "enter path"
  39. class ShallowHistoryState(HistoryState):
  40. def history_mask(self) -> Bitmap:
  41. # Only direct children of parent:
  42. return states_to_bitmap(self.parent.children)
  43. def __repr__(self):
  44. return "ShallowHistoryState(\"%s\")" % (self.short_name)
  45. class DeepHistoryState(HistoryState):
  46. def history_mask(self) -> Bitmap:
  47. # All descendants of parent:
  48. return self.parent.opt.descendants
  49. def __repr__(self):
  50. return "DeepHistoryState(\"%s\")" % (self.short_name)
  51. class ParallelState(State):
  52. def _static_additional_target_states(self, exclude: 'State') -> Tuple[Bitmap, List['HistoryState']]:
  53. return (self.opt.ts_static & ~exclude.opt.ts_static, [s for s in self.opt.ts_dynamic if s not in exclude.opt.ts_dynamic])
  54. def __repr__(self):
  55. return "ParallelState(\"%s\")" % (self.short_name)
  56. @dataclass
  57. class EventDecl:
  58. __slots__ = ["id", "name", "params_decl"]
  59. id: int
  60. name: str
  61. params_decl: List[ParamDecl]
  62. def render(self) -> str:
  63. if self.params_decl:
  64. return self.name + '(' + ', '.join(p.render() for p in self.params_decl) + ')'
  65. else:
  66. return self.name
  67. @dataclass
  68. class Trigger:
  69. __slots__ = ["enabling", "enabling_bitmap"]
  70. enabling: List[EventDecl]
  71. def __post_init__(self):
  72. # Optimization: Require 'enabling' to be sorted!
  73. assert sorted(self.enabling, key=lambda e: e.id) == self.enabling
  74. self.enabling_bitmap = bm_from_list(e.id for e in self.enabling)
  75. def check(self, events_bitmap: Bitmap) -> bool:
  76. return (self.enabling_bitmap & events_bitmap) == self.enabling_bitmap
  77. def render(self) -> str:
  78. return ' ∧ '.join(e.render() for e in self.enabling)
  79. def copy_params_to_stack(self, ctx: EvalContext):
  80. # Both 'ctx.events' and 'self.enabling' are sorted by event ID,
  81. # this way we have to iterate over each of both lists at most once.
  82. iterator = iter(self.enabling)
  83. try:
  84. event_decl = next(iterator)
  85. offset = 0
  86. for e in ctx.events:
  87. if e.id < event_decl.id:
  88. continue
  89. else:
  90. while e.id > event_decl.id:
  91. event_decl = next(iterator)
  92. for p in e.params:
  93. ctx.memory.store(offset, p)
  94. offset += 1
  95. except StopIteration:
  96. pass
  97. @dataclass
  98. class NegatedTrigger(Trigger):
  99. __slots__ = ["disabling", "disabling_bitmap"]
  100. disabling: List[EventDecl]
  101. def __post_init__(self):
  102. Trigger.__post_init__(self)
  103. self.disabling_bitmap = bm_from_list(e.id for e in self.disabling)
  104. def check(self, events_bitmap: Bitmap) -> bool:
  105. return Trigger.check(self, events_bitmap) and not (self.disabling_bitmap & events_bitmap)
  106. def render(self) -> str:
  107. return Trigger.render(self) + ' ∧ ' + ' ∧ '.join('¬'+e.render() for e in self.disabling)
  108. class AfterTrigger(Trigger):
  109. def __init__(self, id: int, name: str, after_id: int, delay: Expression):
  110. enabling = [EventDecl(id=id, name=name, params_decl=[])]
  111. super().__init__(enabling)
  112. self.id = id
  113. self.name = name
  114. self.after_id = after_id # unique ID for AfterTrigger
  115. self.delay = delay
  116. def render(self) -> str:
  117. return "after("+self.delay.render()+")"
  118. # Override.
  119. # An 'after'-event also has 1 parameter, but it is not accessible to the user,
  120. # hence the override.
  121. def copy_params_to_stack(self, ctx: EvalContext):
  122. pass
  123. class Transition(Freezable):
  124. __slots__ = ["source", "targets", "scope", "target_string", "guard", "actions", "trigger", "opt"]
  125. def __init__(self, source: State, targets: List[State], scope: Scope, target_string: Optional[str] = None):
  126. super().__init__()
  127. self.source: State = source
  128. self.targets: List[State] = targets
  129. self.scope: Scope = scope
  130. self.target_string: Optional[str] = target_string
  131. self.guard: Optional[Expression] = None
  132. self.actions: List[Action] = []
  133. self.trigger: Optional[Trigger] = None
  134. self.opt: Optional['TransitionOptimization'] = None
  135. def __str__(self):
  136. return termcolor.colored("%s 🡪 %s" % (self.source.opt.full_name, self.targets[0].opt.full_name), 'green')
  137. # Data that is generated for each state.
  138. class StateOptimization(Freezable):
  139. __slots__ = ["full_name", "depth", "state_id", "state_id_bitmap", "ancestors", "descendants", "history", "ts_static", "ts_dynamic", "after_triggers"]
  140. def __init__(self):
  141. super().__init__()
  142. self.full_name: str = ""
  143. self.depth: int -1 # Root is 0, root's children are 1, and so on
  144. self.state_id: int = -1
  145. self.state_id_bitmap: Bitmap = Bitmap() # bitmap with only state_id-bit set
  146. self.ancestors: Bitmap = Bitmap()
  147. self.descendants: Bitmap = Bitmap()
  148. # Subset of children that are HistoryState.
  149. # For each item, the second element of the tuple is the "history mask".
  150. self.history: List[Tuple[HistoryState, Bitmap]] = []
  151. # Subset of descendants that are always entered when this state is the target of a transition
  152. self.ts_static: Bitmap = Bitmap()
  153. # Subset of descendants that are history states AND are in the subtree of states automatically entered if this state is the target of a transition.
  154. self.ts_dynamic: List[HistoryState] = []
  155. # Triggers of outgoing transitions that are AfterTrigger.
  156. self.after_triggers: List[AfterTrigger] = []
  157. # Data that is generated for each transition.
  158. class TransitionOptimization(Freezable):
  159. __slots__ = ["arena", "arena_bitmap", "enter_states_static", "enter_states_dynamic"]
  160. def __init__(self, arena: State, arena_bitmap: Bitmap, enter_states_static: Bitmap, enter_states_dynamic: List[HistoryState]):
  161. super().__init__()
  162. self.arena: State = arena
  163. self.arena_bitmap: Bitmap = arena_bitmap
  164. self.enter_states_static: Bitmap = enter_states_static # The "enter set" can be computed partially statically, and if there are no history states in it, entirely statically
  165. self.enter_states_dynamic: List[HistoryState] = enter_states_dynamic # The part of the "enter set" that cannot be computed statically.
  166. self.freeze()
  167. class StateTree(Freezable):
  168. __slots__ = ["root", "transition_list", "state_list", "state_dict", "after_triggers", "stable_bitmap", "history_states"]
  169. def __init__(self, root: State, transition_list: List[Transition], state_list: List[State], state_dict: Dict[str, State], after_triggers: List[AfterTrigger], stable_bitmap: Bitmap, history_states: List[HistoryState]):
  170. super().__init__()
  171. self.root: State = root
  172. self.transition_list: List[Transition] = transition_list # depth-first document order
  173. self.state_list: List[State] = state_list # depth-first document order
  174. self.state_dict: Dict[str, State] = state_dict # mapping from 'full name' to State
  175. self.after_triggers: List[AfterTrigger] = after_triggers # all after-triggers in the statechart
  176. self.stable_bitmap: Bitmap = stable_bitmap # set of states that are syntactically marked 'stable'
  177. self.history_states: List[HistoryState] = history_states # all the history states in the statechart
  178. self.freeze()
  179. # Reduce a list of states to a set of states, as a bitmap
  180. def states_to_bitmap(state_list: List[State]) -> Bitmap:
  181. return reduce(lambda x,y: x|y, (s.opt.state_id_bitmap for s in state_list), Bitmap())
  182. def optimize_tree(root: State) -> StateTree:
  183. with timer.Context("optimize tree"):
  184. transition_list = []
  185. after_triggers = []
  186. history_states = []
  187. def init_opt():
  188. next_id = 0
  189. def f(state: State, _=None):
  190. state.opt = StateOptimization()
  191. nonlocal next_id
  192. state.opt.state_id = next_id
  193. state.opt.state_id_bitmap = bit(next_id)
  194. next_id += 1
  195. for t in state.transitions:
  196. transition_list.append(t)
  197. if t.trigger and isinstance(t.trigger, AfterTrigger):
  198. state.opt.after_triggers.append(t.trigger)
  199. after_triggers.append(t.trigger)
  200. if isinstance(state, HistoryState):
  201. state.history_id = len(history_states)
  202. history_states.append(state)
  203. return f
  204. def assign_depth(state: State, parent_depth: int = 0):
  205. state.opt.depth = parent_depth + 1
  206. return parent_depth + 1
  207. def assign_full_name(state: State, parent_full_name: str = ""):
  208. if state is root:
  209. full_name = '/'
  210. elif state.parent is root:
  211. full_name = '/' + state.short_name
  212. else:
  213. full_name = parent_full_name + '/' + state.short_name
  214. state.opt.full_name = full_name
  215. return full_name
  216. state_dict = {}
  217. state_list = []
  218. stable_bitmap = Bitmap()
  219. def add_to_list(state: State ,_=None):
  220. nonlocal stable_bitmap
  221. state_dict[state.opt.full_name] = state
  222. state_list.append(state)
  223. if state.stable:
  224. stable_bitmap |= state.opt.state_id_bitmap
  225. def set_ancestors(state: State, ancestors=[]):
  226. state.opt.ancestors = states_to_bitmap(ancestors)
  227. return ancestors + [state]
  228. def set_descendants(state: State, children_descendants):
  229. descendants = reduce(lambda x,y: x|y, children_descendants, Bitmap())
  230. state.opt.descendants = descendants
  231. return state.opt.state_id_bitmap | descendants
  232. def set_static_target_states(state: State, _):
  233. if isinstance(state, ParallelState):
  234. state.opt.ts_static = reduce(lambda x,y: x|y, (s.opt.ts_static for s in state.children), state.opt.state_id_bitmap)
  235. state.opt.ts_dynamic = list(itertools.chain.from_iterable(c.opt.ts_dynamic for c in state.children if not isinstance(c, HistoryState)))
  236. elif isinstance(state, HistoryState):
  237. state.opt.ts_static = Bitmap()
  238. state.opt.ts_dynamic = [state]
  239. else: # "regular" state:
  240. if state.default_state:
  241. state.opt.ts_static = state.opt.state_id_bitmap | state.default_state.opt.ts_static
  242. state.opt.ts_dynamic = state.default_state.opt.ts_dynamic
  243. else:
  244. state.opt.ts_static = state.opt.state_id_bitmap
  245. state.opt.ts_dynamic = []
  246. def add_history(state: State, _= None):
  247. for c in state.children:
  248. if isinstance(c, HistoryState):
  249. state.opt.history.append((c, c.history_mask()))
  250. def freeze(state: State, _=None):
  251. state.freeze()
  252. state.opt.freeze()
  253. visit_tree(root, lambda s: s.children,
  254. before_children=[
  255. init_opt(),
  256. assign_full_name,
  257. assign_depth,
  258. add_to_list,
  259. set_ancestors,
  260. ],
  261. after_children=[
  262. set_descendants,
  263. add_history,
  264. set_static_target_states,
  265. freeze,
  266. ])
  267. for t in transition_list:
  268. # Arena can be computed statically. First compute Lowest-common ancestor:
  269. # Intersection between source & target ancestors, last member in depth-first sorted state list.
  270. lca_id = bm_highest_bit(t.source.opt.ancestors & t.targets[0].opt.ancestors)
  271. lca = state_list[lca_id]
  272. arena = lca
  273. # Arena must be an Or-state:
  274. while isinstance(arena, (ParallelState, HistoryState)):
  275. arena = arena.parent
  276. # Exit states can be efficiently computed at runtime based on the set of current states.
  277. # Enter states are more complex but luckily, can be computed *partially* statically:
  278. # As a start, we calculate the enter path:
  279. # The enter path is the path from arena to the target state (not including the arena state itself).
  280. # Enter path is the intersection between:
  281. # 1) the transition's target and its ancestors, and
  282. # 2) the arena's descendants
  283. enter_path = (t.targets[0].opt.state_id_bitmap | t.targets[0].opt.ancestors) & arena.opt.descendants
  284. # All states on the enter path will be entered, but on the enter path, there may also be AND-states whose children are not on the enter path, but should also be entered.
  285. enter_path_iter = bm_items(enter_path)
  286. state_id = next(enter_path_iter, None)
  287. enter_states_static = Bitmap()
  288. enter_states_dynamic = []
  289. while state_id is not None:
  290. state = state_list[state_id]
  291. next_state_id = next(enter_path_iter, None)
  292. if next_state_id:
  293. # an intermediate state on the path from arena to target
  294. next_state = state_list[next_state_id]
  295. static, dynamic = state._static_additional_target_states(next_state)
  296. enter_states_static |= static
  297. enter_states_dynamic += dynamic
  298. else:
  299. # the actual target of the transition
  300. enter_states_static |= state.opt.ts_static
  301. enter_states_dynamic += state.opt.ts_dynamic
  302. state_id = next_state_id
  303. t.opt = TransitionOptimization(
  304. arena=arena,
  305. arena_bitmap=arena.opt.descendants | arena.opt.state_id_bitmap,
  306. enter_states_static=enter_states_static,
  307. enter_states_dynamic=enter_states_dynamic)
  308. t.freeze()
  309. return StateTree(root, transition_list, state_list, state_dict, after_triggers, stable_bitmap, history_states)
  310. def priority_source_parent(tree: StateTree) -> List[Transition]:
  311. # Tree's transition list already ordered parent-first
  312. return tree.transition_list
  313. # The following 3 priority implementations all do a stable sort with a partial order-key
  314. def priority_source_child(tree: StateTree) -> List[Transition]:
  315. return list(sorted(tree.transition_list, key=lambda t: -t.source.opt.depth))
  316. def priority_arena_parent(tree: StateTree) -> List[Transition]:
  317. return list(sorted(tree.transition_list, key=lambda t: t.opt.arena.opt.depth))
  318. def priority_arena_child(tree: StateTree) -> List[Transition]:
  319. return list(sorted(tree.transition_list, key=lambda t: -t.opt.arena.opt.depth))
  320. def concurrency_arena_orthogonal(tree: StateTree):
  321. with timer.Context("concurrency_arena_orthogonal"):
  322. import collections
  323. nonoverlapping = collections.defaultdict(list)
  324. for t1,t2 in itertools.combinations(tree.transition_list, r=2):
  325. if not (t1.opt.arena_bitmap & t2.opt.arena_bitmap):
  326. nonoverlapping[t1].append(t2)
  327. nonoverlapping[t2].append(t1)
  328. for t, ts in nonoverlapping.items():
  329. print(str(t), "does not overlap with", ",".join(str(t) for t in ts))
  330. print(len(nonoverlapping), "nonoverlapping pairs of transitions")
  331. def concurrency_src_dst_orthogonal(tree: StateTree):
  332. with timer.Context("concurrency_src_dst_orthogonal"):
  333. import collections
  334. nonoverlapping = collections.defaultdict(list)
  335. for t1,t2 in itertools.combinations(tree.transition_list, r=2):
  336. lca_src = tree.state_list[bm_highest_bit(t1.source.opt.ancestors & t2.source.opt.ancestors)]
  337. lca_dst = tree.state_list[bm_highest_bit(t1.targets[0].opt.ancestors & t2.targets[0].opt.ancestors)]
  338. if isinstance(lca_src, ParallelState) and isinstance(lca_dst, ParallelState):
  339. nonoverlapping[t1].append(t2)
  340. nonoverlapping[t2].append(t1)
  341. for t, ts in nonoverlapping.items():
  342. print(str(t), "does not overlap with", ",".join(str(t) for t in ts))
  343. print(len(nonoverlapping), "nonoverlapping pairs of transitions")