Przeglądaj źródła

Implemented INSTATE macro in Rust

Joeri Exelmans 4 lat temu
rodzic
commit
cd3b7a143f

+ 3 - 3
rust/src/action_lang.rs

@@ -41,10 +41,10 @@ pub struct Empty{}
 #[macro_export]
 macro_rules! call_closure {
   ($closure: expr, $($param: expr),*  $(,)?) => {
-    (||{
+    {
       let scope = &mut $closure.0;
       let function = &mut $closure.1;
-      return function($($param),* scope);
-    })()
+      function($($param),* scope)
+    }
   };
 }

+ 14 - 9
src/sccd/action_lang/codegen/rust.py

@@ -10,6 +10,12 @@ def ident_scope_type(scope):
 def ident_scope_constructor(scope):
     return "new_" + ident_scope_type(scope)
 
+def ident_local(name):
+    if name[0] == '@':
+        return "builtin_" + name[1:]
+    else:
+        return "local_" + name
+
 @dataclass(frozen=True)
 class ScopeCommit:
     type_name: str
@@ -66,7 +72,7 @@ class ScopeHelper():
             writer.writeln("let mut scope = %s {" % type_name)
             writer.writeln("  _base: scope,")
             for v in self.current().scope.variables[start:end]:
-                writer.writeln("  %s: local_%s," % (v.name, v.name))
+                writer.writeln("  %s," % ident_local(v.name))
             writer.writeln("};")
 
         self.current().committed = end
@@ -159,9 +165,9 @@ class ActionLangRustGenerator(Visitor):
                 self.w.indent()
                 self.w.writeln("%s (%s) {" % (commit.type_name, commit.supertype_name))
                 for v in scope.variables[commit.start: commit.end]:
-                    self.w.write("  %s: " % v.name)
+                    self.w.write("  %s: " % ident_local(v.name))
                     v.type.accept(self)
-                    self.w.wnoln(", ")
+                    self.w.wnoln(",")
                 self.w.writeln("}")
                 self.w.dedent()
                 self.w.writeln("}")
@@ -256,8 +262,7 @@ class ActionLangRustGenerator(Visitor):
         # self.w.wno(") ")
 
     def visit_ParamDecl(self, expr):
-        self.w.wno("local_")
-        self.w.wno(expr.name)
+        self.w.wno(ident_local(expr.name))
         self.w.wno(": ")
         expr.formal_type.accept(self)
 
@@ -302,16 +307,16 @@ class ActionLangRustGenerator(Visitor):
 
         if lval.is_init:
             self.w.wno("let mut ")
-            self.w.wno("local_" + lval.name)
+            self.w.wno(ident_local(lval.name))
         else:
             if lval.offset < 0:
                 self.w.wno("parent%d." % self.scope.current().scope.nested_levels(lval.offset))
-                self.w.wno(lval.name)
+                self.w.wno(ident_local(lval.name))
             elif lval.offset < self.scope.current().committed:
                 self.w.wno("scope.")
-                self.w.wno(lval.name)
+                self.w.wno(ident_local(lval.name))
             else:
-                self.w.wno("local_" + lval.name)
+                self.w.wno(ident_local(lval.name))
 
     def visit_SCCDClosureObject(self, type):
         self.w.wno("(%s, " % self.scope.type(type.scope, type.scope.size()))

+ 0 - 3
src/sccd/action_lang/parser/text.py

@@ -9,11 +9,9 @@ class Transformer(lark.Transformer):
     self.macros = defaultdict(list)
 
   def set_macro(self, macro_id, constructor):
-    # print("registered macro", macro_id, constructor)
     self.macros[macro_id].append(constructor)
 
   def unset_macro(self, macro_id):
-    # print("unregistered macro", macro_id)
     self.macros[macro_id].pop()
 
   array = Array
@@ -63,7 +61,6 @@ class Transformer(lark.Transformer):
     try:
       constructor = self.macros[macro_id][-1]
     except IndexError as e:
-      print(self.macros)
       raise Exception("Unknown macro: %s" % macro_id) from e
 
     return constructor(params)

+ 61 - 6
src/sccd/statechart/codegen/rust.py

@@ -74,6 +74,7 @@ class StatechartRustGenerator(ActionLangRustGenerator):
 
         self.parallel_state_cache = {}
 
+        self.tree = None
         self.state_stack = []
 
     def get_parallel_states(self, state):
@@ -82,12 +83,8 @@ class StatechartRustGenerator(ActionLangRustGenerator):
         except KeyError:
             parallel_states = []
             while state.parent is not None:
-                # print("state:" , state.full_name)
-                # print("parent:" , state.parent.full_name, type(state.parent))
                 if isinstance(state.parent.type, AndState):
-                    # print("parent is And-state")
                     for sibling in state.parent.children:
-                        # print("sibling: ", sibling.full_name)
                         if sibling is not state:
                             parallel_states.append(sibling)
                 state = state.parent
@@ -96,10 +93,65 @@ class StatechartRustGenerator(ActionLangRustGenerator):
 
     def get_parallel_states_tuple(self):
         parallel_states = self.get_parallel_states(self.state_stack[-1])
-        return "(" + ", ".join("*"+ident_var(s) for s in parallel_states) + ")"
+        return "(" + ", ".join("*"+ident_var(s) for s in parallel_states) + ", )"
+
+    def visit_InStateMacroExpansion(self, instate):
+        source = instate.ref.source
+        target = instate.ref.target
+
+        self.w.wnoln("{ // macro expansion for INSTATE(\"%s\")" % target.full_name)
+        self.w.indent()
+
+        # Non-exhaustive set of current states, given that 'source' is a current state
+        parents = self.get_parallel_states(source)
+
+        # Deconstruct state configuration tuple
+        self.w.write("let (")
+        for parent in parents:
+            self.w.wno("ref ")
+            self.w.wno(ident_var(parent))
+            self.w.wno(", ")
+        self.w.wnoln(") = %s;" % ident_local("@conf"))
+
+        for parent in parents + [source]:
+            if is_ancestor(parent=target, child=parent):
+                # Special case: target is parent of a current state
+                self.w.writeln("true") # Always a current state
+                return
+
+            if not is_ancestor(parent=parent, child=target):
+                continue # Skip
+
+            # Sequence of states. First is parent. Next is the child of parent on the path to target. Last item is target.
+            path = list(self.tree.bitmap_to_states((parent.state_id_bitmap | parent.descendants) & (target.ancestors | target.state_id_bitmap)))
+
+            def write_path(path, target):
+                parent = path[0];
+                if parent is target:
+                    self.w.writeln("true")
+                else:
+                    child = path[1]
+                    if isinstance(parent.type, OrState):
+                        self.w.writeln("if let %s::%s(ref %s) = %s {" % (ident_type(parent), ident_enum_variant(child), ident_var(child), ident_var(parent)))
+                        self.w.indent()
+                        write_path(path[1:], target)
+                        self.w.dedent()
+                        self.w.writeln("} else { false }")
+                    elif isinstance(parent.type, AndState):
+                        self.w.writeln("let ref %s = %s.%s;" % (ident_var(child), ident_var(parent), ident_field(child)))
+                        write_path(path[1:], target)
+                    else:
+                        raise Exception("The impossible has happened")
+
+            write_path(path, target)
+
+        self.w.dedent()
+        self.w.write("}")
+
+        # self.w.wno("false")
 
     def visit_SCCDStateConfiguration(self, type):
-        self.w.wno("(%s)" % ", ".join(ident_type(s) for s in self.get_parallel_states(type.state)))
+        self.w.wno("(%s, )" % ", ".join(ident_type(s) for s in self.get_parallel_states(type.state)))
 
     def visit_RaiseOutputEvent(self, a):
         # TODO: evaluate event parameters
@@ -245,6 +297,7 @@ class StatechartRustGenerator(ActionLangRustGenerator):
         priority_ordered_transitions = priority.priority_and_concurrency(sc) # may raise error
 
         tree = sc.tree
+        self.tree = tree
 
         self.w.writeln("type Timers<TimerId> = [TimerId; %d];" % tree.timer_count)
         self.w.writeln()
@@ -731,3 +784,5 @@ class StatechartRustGenerator(ActionLangRustGenerator):
             self.w.writeln("  eprintln!(\"------------------------\");")
             self.w.writeln("}")
             self.w.writeln()
+
+        self.tree = None

+ 6 - 4
src/sccd/statechart/parser/xml.py

@@ -7,7 +7,7 @@ from sccd.statechart.static.tree import *
 from sccd.statechart.dynamic.builtin_scope import *
 from sccd.util.xml_parser import *
 from sccd.statechart.parser.text import *
-from sccd.statechart.static.in_state import InState
+from sccd.statechart.static.in_state import InStateMacroExpansion
 
 class SkipFile(Exception):
   pass
@@ -118,9 +118,11 @@ def statechart_parser_rules(globals, path, load_external = True, parse_f = parse
       def state_child_rules(parent, sibling_dict: Dict[str, State]):
 
         def macro_in_state(params):
-          refs = [StateRef(source=parent, path=text_parser.parse_path(p.string)) for p in params]
-          refs_to_resolve.extend(refs)
-          return InState(state_refs=refs)
+          if len(params) != 1:
+            raise XmlError("Macro @in: Expected 1 parameter")
+          ref= StateRef(source=parent, path=text_parser.parse_path(params[0].string))
+          refs_to_resolve.append(ref)
+          return InStateMacroExpansion(ref=ref)
 
         text_parser.parser.options.transformer.set_macro("@in", macro_in_state)
 

+ 4 - 9
src/sccd/statechart/static/in_state.py

@@ -2,10 +2,9 @@ from dataclasses import *
 from sccd.action_lang.static.expression import *
 from sccd.statechart.static.state_ref import StateRef
 
-# Macro expansion for @in
 @dataclass
-class InState(Expression):
-    state_refs: List[StateRef]
+class InStateMacroExpansion(Expression):
+    ref: StateRef
 
     offset: Optional[int] = None
 
@@ -18,11 +17,7 @@ class InState(Expression):
 
     def eval(self, memory: MemoryInterface):
         state_configuration = memory.load(self.offset)
-        # print("state_configuration:", state_configuration)
-        # print("INSTATE ", [(r.target, r.target.state_id_bitmap) for r in self.state_refs], " ??")
-        result = reduce(lambda x,y: x and y, (bool(ref.target.state_id_bitmap & state_configuration) for ref in self.state_refs))
-        # print(result)
-        return result
+        return self.ref.target.state_id_bitmap & state_configuration
 
     def render(self):
-        return "@in(" + ",".join(ref.target.full_name for ref in self.state_refs)
+        return "@in(" + self.ref.target.full_name + ')'

+ 4 - 2
src/sccd/statechart/static/tree.py

@@ -426,10 +426,12 @@ class StateTree:
 
     def lca(self, s1: State, s2: State) -> State:
         # Intersection between source & target ancestors, last member in depth-first sorted state list.
-        return self.state_list[bm_highest_bit(s1.ancestors & s2.ancestors)]
+        return self.state_list[bm_highest_bit((s1.ancestors | s1.state_id_bitmap) & (s2.ancestors | s2.state_id_bitmap))]
 
 def states_to_bitmap(states: Iterable[State]) -> Bitmap:
     return bm_from_list(s.state_id for s in states)
 
+# Is parent ancestor of child? Also returns true when parent IS child.
+# If this function returns True, and child is a current state, then parent will be too.
 def is_ancestor(parent: State, child: State) -> bool:
-    return bm_has(child.ancestors, parent.state_id)
+    return bm_has(child.ancestors | child.state_id_bitmap, parent.state_id)

+ 4 - 8
test_files/features/history/test_deep.xml

@@ -5,7 +5,7 @@
     <semantics
       big_step_maximality="take_many"
       combo_step_maximality="take_many"
-      internal_event_lifeline="queue"/>
+      internal_event_lifeline="next_combo_step"/>
 
     <inport name="in">
       <event name="start"/>
@@ -53,19 +53,19 @@
             </transition>
           </state>
           <state id="step1">
-            <transition cond='INSTATE(["/parallel/orthogonal/wrapper/state_2/inner_4"])' target="../step2">
+            <transition cond='@in("/parallel/orthogonal/wrapper/state_2/inner_4")' target="../step2">
               <raise port="out" event="check1" />
               <raise event="to_outer" />
             </transition>
           </state>
           <state id="step2">
-            <transition cond='INSTATE(["/parallel/orthogonal/outer"])' target="../step3">
+            <transition cond='@in("/parallel/orthogonal/outer")' target="../step3">
               <raise port="out" event="check2" />
               <raise event="to_history" />
             </transition>
           </state>
           <state id="step3">
-            <transition cond='INSTATE(["/parallel/orthogonal/wrapper/state_2/inner_4"])' target="../end">
+            <transition cond='@in("/parallel/orthogonal/wrapper/state_2/inner_4")' target="../end">
               <raise port="out" event="check3" />
             </transition>
           </state>
@@ -82,11 +82,7 @@
   <output>
     <big_step>
       <event port="out" name="check1"/>
-    </big_step>
-    <big_step>
       <event port="out" name="check2"/>
-    </big_step>
-    <big_step>
       <event port="out" name="check3"/>
     </big_step>
   </output>

+ 17 - 0
test_files/features/instate/fail_instate.xml

@@ -0,0 +1,17 @@
+<test>
+  <statechart>
+    <datamodel>
+      instate_b = func {
+        # Illegal: macro @in only available in guards and actions
+
+        return @in("/b");
+      };
+    </datamodel>
+    <root initial="a">
+      <state id="a">
+      </state>
+      <state id="b">
+      </state>
+    </root>
+  </statechart>
+</test>