Просмотр исходного кода

Progress towards supporting INSTATE-macro in Rust

Joeri Exelmans 4 лет назад
Родитель
Сommit
667e5b22f3

+ 2 - 0
notes.txt

@@ -37,6 +37,8 @@ Long-term vision:
 
 Random notes:
 
+  - Durations in action language are a single type, consisting of an integer value and a unit. Perhaps it would be better to have a separate duration type for every unit. Especially in Rust it is better to leave the units to the type system, and only work with 'the numbers' in the machine code :)
+
   - Statechart interface:
       - want YAKINDU style: strict separation between:
         - input events

+ 5 - 5
src/sccd/action_lang/codegen/rust.py

@@ -78,8 +78,6 @@ class ActionLangRustGenerator(Visitor):
         self.scope = ScopeHelper()
         self.functions_to_write = [] # Function and Rust identifier
 
-        self.function_types = {} # maps Function to Rust type
-
     def default(self, what):
         # self.w.wno("<%s>" % what)
         raise UnsupportedFeature(what)
@@ -121,8 +119,6 @@ class ActionLangRustGenerator(Visitor):
     # When compiling Rust code, the Visitable.accept method must be called on the root of the AST, to write code wherever desired (e.g. in a main-function) followed by 'write_scope' at the module level.
     def write_decls(self):
 
-        function_types = {}
-
         # Write functions
         for function, identifier in self.functions_to_write:
             scope = function.scope
@@ -144,8 +140,11 @@ class ActionLangRustGenerator(Visitor):
             self.w.writeln("let scope = action_lang::Empty{};")
 
             self.scope.push(function.scope)
+            
             # Parameters are part of function's scope
             self.scope.commit(len(function.params_decl), self.w)
+
+            # Visit the body. This may cause new functions to be added to self.functions_to_write, which we are iterating over (which is allowed in Python), so those will also be dealt with in this loop.
             function.body.accept(self)
             self.scope.pop()
 
@@ -341,4 +340,5 @@ class ActionLangRustGenerator(Visitor):
     def visit__SCCDSimpleType(self, type):
         self.w.wno(type.name
             .replace("int", "i32")
-            .replace("float", "f64"))
+            .replace("float", "f64")
+        )

+ 4 - 0
src/sccd/action_lang/parser/action_lang.g

@@ -39,12 +39,16 @@
      | "(" expr ")"             -> group
      | literal
      | func_call
+     | macro_call
      | array_indexed
      | func_decl
      | array
 
 IDENTIFIER: /[A-Za-z_][A-Za-z_0-9]*/ 
 
+MACRO_IDENTIFIER: "@" IDENTIFIER
+macro_call: MACRO_IDENTIFIER "(" param_list ")"
+
 func_call: atom "(" param_list ")"
 param_list: ( expr ("," expr)* )?  -> params
 

+ 3 - 0
src/sccd/action_lang/parser/text.py

@@ -52,6 +52,9 @@ class ExpressionTransformer(Transformer):
   def func_call(self, node):
     return FunctionCall(node[0], node[1].children)
 
+  def macro_call(self, node):
+    return MacroCall(node[0], node[1].children)
+
   def array_indexed(self, node):
     return ArrayIndexed(node[0], node[1])
 

+ 65 - 24
src/sccd/statechart/codegen/rust.py

@@ -33,7 +33,7 @@ def ident_enum_variant(state: State) -> str:
     # We know the direct children of a state must have unique names relative to each other,
     # and enum variants are scoped locally, so we can use the short name here.
     # Furthermore, the XML parser asserts that state ids are valid identifiers in Rust.
-    return "S" + state.short_name
+    return "S_" + state.short_name
 
 def ident_field(state: State) -> str:
     return "s" + snake_case(state)
@@ -72,6 +72,35 @@ class StatechartRustGenerator(ActionLangRustGenerator):
         super().__init__(w)
         self.globals = globals
 
+        self.parallel_state_cache = {}
+
+        self.state_stack = []
+
+    def get_parallel_states(self, state):
+        try:
+            return self.parallel_state_cache[state]
+        except KeyError:
+            parallel_states = []
+            while state.parent is not None:
+                # print("state:" , state.full_name)
+                # print("parent:" , state.parent.full_name, type(state.parent))
+                if isinstance(state.parent.type, AndState):
+                    # print("parent is And-state")
+                    for sibling in state.parent.children:
+                        # print("sibling: ", sibling.full_name)
+                        if sibling is not state:
+                            parallel_states.append(sibling)
+                state = state.parent
+            self.parallel_state_cache[state] = parallel_states
+            return parallel_states
+
+    def get_parallel_states_tuple(self):
+        parallel_states = self.get_parallel_states(self.state_stack[-1])
+        return "(" + ", ".join("*"+ident_var(s) for s in parallel_states) + ")"
+
+    def visit_SCCDStateConfiguration(self, type):
+        self.w.wno("(%s)" % ", ".join(ident_type(s) for s in self.get_parallel_states(type.state)))
+
     def visit_RaiseOutputEvent(self, a):
         # TODO: evaluate event parameters
         if DEBUG:
@@ -84,9 +113,14 @@ class StatechartRustGenerator(ActionLangRustGenerator):
         self.w.writeln("internal.raise().%s = Some(%s{});" % (ident_event_field(a.name), (ident_event_type(a.name))))
 
     def visit_Code(self, a):
-            a.block.accept(self)
+            self.w.write()
+            a.block.accept(self) # block is a function
+            self.w.wno("(%s, scope);" % self.get_parallel_states_tuple()) # call it!
+            self.w.wnoln()
 
     def visit_State(self, state):
+        self.state_stack.append(state)
+
         # visit children first
         for c in state.real_children:
             c.accept(self)
@@ -190,6 +224,8 @@ class StatechartRustGenerator(ActionLangRustGenerator):
         self.w.writeln("}")
         self.w.writeln()
 
+        self.state_stack.pop()
+
     def visit_Statechart(self, sc):
         self.scope.push(sc.scope)
 
@@ -305,7 +341,7 @@ class StatechartRustGenerator(ActionLangRustGenerator):
         datamodel_type = self.scope.commit(sc.scope.size(), self.w)
         self.w.dedent(); self.w.dedent();
         self.w.writeln("    Self {")
-        self.w.writeln("      current_state: Default::default(),")
+        self.w.writeln("      configuration: Default::default(),")
         for h in tree.history_states:
             self.w.writeln("      %s: Default::default()," % (ident_history_field(h)))
         self.w.writeln("      timers: Default::default(),")
@@ -315,7 +351,7 @@ class StatechartRustGenerator(ActionLangRustGenerator):
         self.w.writeln("}")
         self.w.writeln("type DataModel = %s;" % datamodel_type)
         self.w.writeln("pub struct Statechart<TimerId> {")
-        self.w.writeln("  current_state: %s," % ident_type(tree.root))
+        self.w.writeln("  configuration: %s," % ident_type(tree.root))
         # We always store a history value as 'deep' (also for shallow history).
         # TODO: We may save a tiny bit of space in some rare cases by storing shallow history as only the exited child of the Or-state.
         for h in tree.history_states:
@@ -325,17 +361,16 @@ class StatechartRustGenerator(ActionLangRustGenerator):
         self.w.writeln("}")
         self.w.writeln()
 
-        self.write_decls()
-
         # Function fair_step: a single "Take One" Maximality 'round' (= nonoverlapping arenas allowed to fire 1 transition)
         self.w.writeln("fn fair_step<TimerId: Copy, Sched: statechart::Scheduler<InEvent, TimerId>, OutputCallback: FnMut(statechart::OutEvent)>(sc: &mut Statechart<TimerId>, input: Option<InEvent>, internal: &mut InternalLifeline, sched: &mut Sched, output: &mut OutputCallback, dirty: Arenas) -> Arenas {")
         self.w.writeln("  let mut fired: Arenas = ARENA_NONE;")
         self.w.writeln("  let mut scope = &mut sc.data;")
-        self.w.writeln("  let %s = &mut sc.current_state;" % ident_var(tree.root))
+        self.w.writeln("  let %s = &mut sc.configuration;" % ident_var(tree.root))
         self.w.indent()
 
         transitions_written = []
         def write_transitions(state: State):
+            self.state_stack.append(state)
 
             # Many of the states to exit can be computed statically (i.e. they are always the same)
             # The one we cannot compute statically are:
@@ -361,7 +396,7 @@ class StatechartRustGenerator(ActionLangRustGenerator):
 
                     if len(exit_path) == 1:
                         # Exit s:
-                        self.w.writeln("%s.exit_current(&mut sc.timers, *parent1, internal, sched, output);" % (ident_var(s)))
+                        self.w.writeln("%s.exit_current(&mut sc.timers, scope, internal, sched, output);" % (ident_var(s)))
                     else:
                         # Exit children:
                         if isinstance(s.type, AndState):
@@ -369,12 +404,12 @@ class StatechartRustGenerator(ActionLangRustGenerator):
                                 if exit_path[1] is c:
                                     write_exit(exit_path[1:]) # continue recursively
                                 else:
-                                    self.w.writeln("%s.exit_current(&mut sc.timers, *parent1, internal, sched, output);" % (ident_var(c)))
+                                    self.w.writeln("%s.exit_current(&mut sc.timers, scope, internal, sched, output);" % (ident_var(c)))
                         elif isinstance(s.type, OrState):
                             write_exit(exit_path[1:]) # continue recursively with the next child on the exit path
 
                         # Exit s:
-                        self.w.writeln("%s::exit_actions(&mut sc.timers, *parent1, internal, sched, output);" % (ident_type(s)))
+                        self.w.writeln("%s::exit_actions(&mut sc.timers, scope, internal, sched, output);" % (ident_type(s)))
 
                     # Store history
                     if s.deep_history:
@@ -398,19 +433,19 @@ class StatechartRustGenerator(ActionLangRustGenerator):
                     if len(enter_path) == 1:
                         # Target state.
                         if isinstance(s, HistoryState):
-                            self.w.writeln("sc.%s.enter_current(&mut sc.timers, *parent1, internal, sched, output); // Enter actions for history state" %(ident_history_field(s)))
+                            self.w.writeln("sc.%s.enter_current(&mut sc.timers, scope, internal, sched, output); // Enter actions for history state" %(ident_history_field(s)))
                         else:
-                            self.w.writeln("%s::enter_default(&mut sc.timers, *parent1, internal, sched, output);" % (ident_type(s)))
+                            self.w.writeln("%s::enter_default(&mut sc.timers, scope, internal, sched, output);" % (ident_type(s)))
                     else:
                         # Enter s:
-                        self.w.writeln("%s::enter_actions(&mut sc.timers, *parent1, internal, sched, output);" % (ident_type(s)))
+                        self.w.writeln("%s::enter_actions(&mut sc.timers, scope, internal, sched, output);" % (ident_type(s)))
                         # Enter children:
                         if isinstance(s.type, AndState):
                             for c in s.children:
                                 if enter_path[1] is c:
                                     write_enter(enter_path[1:]) # continue recursively
                                 else:
-                                    self.w.writeln("%s::enter_default(&mut sc.timers, *parent1, internal, sched, output);" % (ident_type(c)))
+                                    self.w.writeln("%s::enter_default(&mut sc.timers, scope, internal, sched, output);" % (ident_type(c)))
                         elif isinstance(s.type, OrState):
                             if len(s.children) > 0:
                                 write_enter(enter_path[1:]) # continue recursively with the next child on the enter path
@@ -478,19 +513,21 @@ class StatechartRustGenerator(ActionLangRustGenerator):
                             elif bit(e.id) & internal_events:
                                 condition.append("let Some(%s) = &internal.current().%s" % (ident_event_type(e.name), ident_event_field(e.name)))
                             else:
-                                # Bug in SCCD :(
-                                raise Exception("Illegal event ID")
+                                raise Exception("Illegal event ID - Bug in SCCD :(")
                         self.w.writeln("if %s {" % " && ".join(condition))
                         self.w.indent()
 
-                    self.w.writeln("let parent1 = &mut scope;")
-
-                    if t.scope.size() > 0:
-                        raise UnsupportedFeature("Event parameters")
-
                     if t.guard is not None:
+                        if t.guard.scope.size() > 1:
+                            raise UnsupportedFeature("Guard reads an event parameter")
                         self.w.write("if ")
-                        t.guard.accept(self)
+                        t.guard.accept(self) # guard is a function...
+                        self.w.wno("(") # call it!
+                        self.w.wno(self.get_parallel_states_tuple())
+                        self.w.wno(", ")
+                        # TODO: write event parameters here
+                        self.write_parent_call_params(t.guard.scope)
+                        self.w.wno(")")
                         self.w.wnoln(" {")
                         self.w.indent()
 
@@ -570,10 +607,10 @@ class StatechartRustGenerator(ActionLangRustGenerator):
 
                         self.w.writeln("'%s: loop {" % ident_arena_label(state))
                         self.w.indent()
-                        self.w.writeln("match %s {" % ident_var(state))
+                        self.w.writeln("match *%s {" % ident_var(state))
                         for child in state.real_children:
                             self.w.indent()
-                            self.w.writeln("%s::%s(%s) => {" % (ident_type(state), ident_enum_variant(child), ident_var(child)))
+                            self.w.writeln("%s::%s(ref mut %s) => {" % (ident_type(state), ident_enum_variant(child), ident_var(child)))
                             self.w.indent()
                             write_transitions(child)
                             self.w.dedent()
@@ -601,6 +638,8 @@ class StatechartRustGenerator(ActionLangRustGenerator):
             else:
                 raise UnsupportedFeature("Priority semantics %s" % sc.semantics.hierarchical_priority)
 
+            self.state_stack.pop()
+
         write_transitions(tree.root)
 
         self.w.dedent()
@@ -668,6 +707,8 @@ class StatechartRustGenerator(ActionLangRustGenerator):
         # Write state types
         tree.root.accept(self)
 
+        self.write_decls()
+
         if DEBUG:
             self.w.writeln("use std::mem::size_of;")
             self.w.writeln("fn debug_print_sizes<TimerId: Copy>() {")

+ 64 - 61
src/sccd/statechart/parser/xml.py

@@ -1,12 +1,12 @@
 from typing import *
 import re
 from lark.exceptions import *
+from sccd.statechart.static.types import *
 from sccd.statechart.static.statechart import *
 from sccd.statechart.static.tree import *
 from sccd.statechart.dynamic.builtin_scope import *
 from sccd.util.xml_parser import *
 from sccd.statechart.parser.text import *
-
 class SkipFile(Exception):
   pass
 
@@ -88,66 +88,6 @@ def statechart_parser_rules(globals, path, load_external = True, parse_f = parse
       transitions = [] # All of the statechart's transitions accumulate here, cause we still need to find their targets, which we can't do before the entire state tree has been built. We find their targets when encoutering the </root> closing tag.
       after_id = 0 # After triggers need unique IDs within the scope of the statechart model
 
-      # A transition's guard expression and action statements can read the transition's event parameters, and also possibly the current state configuration. We therefore now wrap these into a function with a bunch of parameters for those values that we want to bring into scope.
-      def wrap_transition_params(expr_or_stmt, trigger: Trigger):
-        if isinstance(expr_or_stmt, Statement):
-          # Transition's action code
-          body = expr_or_stmt
-        elif isinstance(expr_or_stmt, Expression):
-          # Transition's guard
-          body = ReturnStatement(expr=expr_or_stmt)
-        else:
-          raise Exception("Unexpected error in parser")
-        # The joy of writing expressions in abstract syntax:
-        wrapped = FunctionDeclaration(
-          params_decl=
-            # The param '@conf' (which, on purpose, is an illegal identifier in textual concrete syntax, to prevent naming collisions) will contain the statechart's configuration as a bitmap (SCCDInt). This parameter is currently only used in the expansion of the INSTATE-macro.
-            [ParamDecl(name="@conf", formal_type=SCCDInt)]
-            # Plus all the parameters of the enabling events of the transition's trigger:
-            + [param for event in trigger.enabling for param in event.params_decl],
-          body=body)
-        return wrapped
-
-      def actions_rules(scope, wrap_trigger: Trigger = EMPTY_TRIGGER):
-
-        def parse_raise(el):
-          params = []
-          def parse_param(el):
-            expr_text = require_attribute(el, "expr")
-            expr = parse_expression(globals, expr_text)
-            function = wrap_transition_params(expr, trigger=wrap_trigger)
-            function.init_expr(scope)
-            params.append(function)
-
-          def finish_raise():
-            event_name = require_attribute(el, "event")
-            try:
-              port = statechart.event_outport[event_name]
-            except KeyError:
-              # Legacy fallback: read port from attribute
-              port = el.get("port")
-            if port is None:
-              # internal event
-              event_id = globals.events.assign_id(event_name)
-              statechart.internally_raised_events |= bit(event_id)
-              return RaiseInternalEvent(event_id=event_id, name=event_name, params=params)
-            else:
-              # output event - no ID in global namespace
-              statechart.event_outport[event_name] = port
-              globals.outports.assign_id(port)
-              return RaiseOutputEvent(name=event_name, params=params, outport=port)
-          return ([("param*", parse_param)], finish_raise)
-
-        def parse_code(el):
-          def finish_code():
-            block = parse_block(globals, el.text)
-            function = wrap_transition_params(block, trigger=wrap_trigger)
-            function.init_expr(scope)
-            return Code(function)
-          return ([], finish_code)
-
-        return {"raise": parse_raise, "code": parse_code}
-
       def get_default_state(el, state, children_dict):
         have_initial = False
 
@@ -173,6 +113,68 @@ def statechart_parser_rules(globals, path, load_external = True, parse_f = parse
 
       def state_child_rules(parent, sibling_dict: Dict[str, State]):
 
+        # A transition's guard expression and action statements can read the transition's event parameters, and also possibly the current state configuration. We therefore now wrap these into a function with a bunch of parameters for those values that we want to bring into scope.
+        def wrap_transition_params(expr_or_stmt, trigger: Trigger):
+          if isinstance(expr_or_stmt, Statement):
+            # Transition's action code
+            body = expr_or_stmt
+          elif isinstance(expr_or_stmt, Expression):
+            # Transition's guard
+            body = ReturnStatement(expr=expr_or_stmt)
+          else:
+            raise Exception("Unexpected error in parser")
+          # The joy of writing expressions in abstract syntax:
+          wrapped = FunctionDeclaration(
+            params_decl=
+              # The param '@conf' (which, on purpose, is an illegal identifier in textual concrete syntax, to prevent naming collisions) will contain the statechart's configuration as a bitmap (SCCDInt). This parameter is currently only used in the expansion of the INSTATE-macro.
+              [ParamDecl(name="_conf", formal_type=SCCDStateConfiguration(state=parent))]
+              # Plus all the parameters of the enabling events of the transition's trigger:
+              + [param for event in trigger.enabling for param in event.params_decl],
+            body=body)
+          return wrapped
+
+        def actions_rules(scope, wrap_trigger: Trigger = EMPTY_TRIGGER):
+
+          def parse_raise(el):
+            params = []
+            def parse_param(el):
+              expr_text = require_attribute(el, "expr")
+              expr = parse_expression(globals, expr_text)
+              function = wrap_transition_params(expr, trigger=wrap_trigger)
+              function.init_expr(scope)
+              function.scope.name = "event_param"
+              params.append(function)
+
+            def finish_raise():
+              event_name = require_attribute(el, "event")
+              try:
+                port = statechart.event_outport[event_name]
+              except KeyError:
+                # Legacy fallback: read port from attribute
+                port = el.get("port")
+              if port is None:
+                # internal event
+                event_id = globals.events.assign_id(event_name)
+                statechart.internally_raised_events |= bit(event_id)
+                return RaiseInternalEvent(event_id=event_id, name=event_name, params=params)
+              else:
+                # output event - no ID in global namespace
+                statechart.event_outport[event_name] = port
+                globals.outports.assign_id(port)
+                return RaiseOutputEvent(name=event_name, params=params, outport=port)
+            return ([("param*", parse_param)], finish_raise)
+
+          def parse_code(el):
+            def finish_code():
+              block = parse_block(globals, el.text)
+              function = wrap_transition_params(block, trigger=wrap_trigger)
+              function.init_expr(scope)
+              function.scope.name = "code"
+              return Code(function)
+            return ([], finish_code)
+
+          return {"raise": parse_raise, "code": parse_code}
+
         def common(el, constructor):
           short_name = require_attribute(el, "id")
           match = re.match("[A-Za-z_][A-Za-z_0-9]*", short_name)
@@ -270,6 +272,7 @@ def statechart_parser_rules(globals, path, load_external = True, parse_f = parse
             guard_expr = parse_expression(globals, cond)
             guard_function = wrap_transition_params(guard_expr, transition.trigger)
             guard_type = guard_function.init_expr(statechart.scope)
+            guard_function.scope.name = "guard"
 
             if guard_type.return_type is not SCCDBool:
               raise XmlError("Guard should be an expression evaluating to 'bool'.")

+ 11 - 0
src/sccd/statechart/static/types.py

@@ -0,0 +1,11 @@
+from sccd.action_lang.static.types import SCCDType
+
+# In the Python interpreter, a state configuration is a 'Bitmap' (basically just an 'int')
+# In generated Rust code, a state configuration is the statechart-specific type called 'Root'.
+class SCCDStateConfiguration(SCCDType):
+    
+    def __init__(self, state):
+        self.state = state
+
+    def _str(self):
+        return "sconf"

+ 1 - 0
src/sccd/test/codegen/rust.py

@@ -1,6 +1,7 @@
 from sccd.test.static.syntax import *
 from sccd.util.indenting_writer import *
 from sccd.cd.codegen.rust import ClassDiagramRustGenerator
+from sccd.action_lang.codegen.rust import UnsupportedFeature
 from sccd.statechart.codegen.rust import ident_event_type
 
 class TestRustGenerator(ClassDiagramRustGenerator):

+ 2 - 2
src/sccd/util/indenting_writer.py

@@ -22,8 +22,8 @@ class IndentingWriter:
     def wno(self, s):
         self.out.write(s)
 
-    def wnoln(self, s):
+    def wnoln(self, s=""):
         self.out.write(s + '\n')
 
-    def write(self, s):
+    def write(self, s=""):
         self.out.write(' '*self.state + s)