Bladeren bron

Rust: progress with code generation for function closures

Joeri Exelmans 4 jaren geleden
bovenliggende
commit
22af155800

+ 92 - 20
src/sccd/action_lang/codegen/rust.py

@@ -3,6 +3,9 @@ from sccd.action_lang.static.statement import *
 class UnsupportedFeature(Exception):
     pass
 
+def ident_scope(scope):
+    return "Scope_" + scope.name
+
 class ActionLangRustGenerator(Visitor):
     def __init__(self, w):
         self.w = w
@@ -18,12 +21,13 @@ class ActionLangRustGenerator(Visitor):
         # self.w.writeln("}")
 
     def visit_Assignment(self, stmt):
-        self.w.write('')
-        stmt.lhs.accept(self)
-        self.w.wno(" = ")
-        stmt.rhs.accept(self)
-        self.w.wnoln(";")
-        self.w.writeln("eprintln!(\"%s\");" % termcolor.colored(stmt.render(),'blue'))
+        if not stmt.is_initialization:
+            self.w.write('') # indent
+            stmt.lhs.accept(self)
+            self.w.wno(" = ")
+            stmt.rhs.accept(self)
+            self.w.wnoln(";")
+            self.w.writeln("eprintln!(\"%s\");" % termcolor.colored(stmt.render(),'blue'))
 
     def visit_IfStatement(self, stmt):
         self.w.wno("if ")
@@ -41,9 +45,9 @@ class ActionLangRustGenerator(Visitor):
             self.w.writeln("}")
 
     def visit_ReturnStatement(self, stmt):
-        self.w.wno("return ")
+        self.w.write("return ")
         stmt.expr.accept(self)
-        self.w.writeln(";")
+        self.w.wnoln(";")
 
     def visit_BoolLiteral(self, expr):
         self.w.wno("true" if expr.b else "false")
@@ -62,23 +66,22 @@ class ActionLangRustGenerator(Visitor):
         self.w.wno("]")
 
     def visit_BinaryExpression(self, expr):
-        self.w.wno(" (")
-
         if expr.operator == "**":
             raise UnsupportedFeature("exponent operator")
         else:
+            # always put parentheses
+            self.w.wno("(")
             expr.lhs.accept(self)
             self.w.wno(" %s " % expr.operator
                 .replace('and', '&&')
                 .replace('or', '||')
                 .replace('//', '/')) # integer division
             expr.rhs.accept(self)
-
-        self.w.wno(") ")
+            self.w.wno(")")
 
     def visit_UnaryExpression(self, expr):
         self.w.wno(expr.operator
-            .replace('not', '!'))
+            .replace('not', '! '))
         expr.expr.accept(self)
 
     def visit_Group(self, expr):
@@ -86,17 +89,26 @@ class ActionLangRustGenerator(Visitor):
         expr.subexpr.accept(self)
         # self.w.wno(") ")
 
+    def visit_ParamDecl(self, expr):
+        self.w.wno(expr.name)
+        self.w.wno(": ")
+        expr.formal_type.accept(self)
+
     def visit_FunctionDeclaration(self, expr):
         self.w.wno("|")
         for p in expr.params_decl:
             p.accept(self)
             self.w.wno(", ")
-        self.w.wnoln("| = {")
+        self.w.wnoln("| {")
         self.w.indent()
+        self.w.writeln("let scope = %s::default();" % ident_scope(expr.scope))
         expr.body.accept(self)
         self.w.dedent()
         self.w.write("}")
-        self.scopes.append(expr.scope)
+
+        # should write scope type later
+        if expr.scope not in self.scopes:
+            self.scopes.append(expr.scope)
 
     def visit_FunctionCall(self, expr):
         self.w.wno("(")
@@ -108,18 +120,78 @@ class ActionLangRustGenerator(Visitor):
         self.w.wno(")")
 
     def visit_Identifier(self, lval):
-        self.w.wno("sc.data."+lval.name)
+        self.w.wno("scope."+lval.name)
 
     def visit_Scope(self, scope):
-        self.w.writeln("#[derive(Default)]")
-        self.w.writeln("struct Scope_%s {" % scope.name)
-        for v in scope.variables:
+        # Map variable to template param name (for functions)
+        mapping = {}
+        for i,v in enumerate(scope.variables):
+            if isinstance(v.type, SCCDFunction):
+                mapping[i] = "F%d" % len(mapping)
+
+
+        def write_template_params_with_trait():
+            for i,v in enumerate(scope.variables):
+                if i in mapping:
+                    self.w.wno("%s: " % mapping[i])
+                    v.type.accept(self)
+                    self.w.wno(", ")
+
+        def write_template_params():
+            for i,v in enumerate(scope.variables):
+                if i in mapping:
+                    self.w.wno("%s, " % mapping[i])
+
+        # Write type
+        self.w.write("struct %s<" % ident_scope(scope))
+        write_template_params_with_trait()
+        self.w.wnoln("> {")
+        for i,v in enumerate(scope.variables):
             self.w.write("  %s: " % v.name)
-            v.type.accept(self)
+            if i in mapping:
+                self.w.wno(mapping[i])
+            else:
+                v.type.accept(self)
             self.w.wnoln(",")
         self.w.writeln("}")
+
+        # Impl trait Default:
+        # self.w.writeln("impl Default for %s {" % ident_scope(scope))
+        self.w.write("impl<")
+        write_template_params_with_trait()
+        self.w.wno("> Default for %s<" % ident_scope(scope))
+        write_template_params()
+        self.w.wnoln("> {")
+        self.w.indent()
+        self.w.writeln("fn default() -> Self {")
+        self.w.indent()
+        for v in scope.variables:
+            if v.initial_value is not None:
+                self.w.writeln("eprintln!(\"%s\");" % termcolor.colored("(init) %s = %s;" % (v.name, v.initial_value.render()),'blue'))
+        self.w.writeln("Self {")
+        self.w.indent()
+        for v in scope.variables:
+            if v.initial_value is not None:
+                self.w.write("%s: " % v.name)
+                v.initial_value.accept(self)
+                self.w.wnoln(",")
+            else:
+                self.w.writeln("%s: Default::default()," % v.name)
+        self.w.dedent()
+        self.w.writeln("}")
+        self.w.dedent()
+        self.w.writeln("}")
+        self.w.dedent()
+        self.w.writeln("}")
         self.w.writeln()
 
+    def visit_SCCDFunction(self, type):
+        self.w.wno("FnMut(")
+        for i, t in enumerate(type.param_types):
+            t.accept(self)
+            if i != len(type.param_types)-1:
+                self.w.wno(", ")
+        self.w.wno(")")
 
     def visit__SCCDSimpleType(self, type):
         self.w.wno(type.name

+ 9 - 7
src/sccd/action_lang/static/expression.py

@@ -59,9 +59,10 @@ class Expression(ABC, Visitable):
 # Either 'init_expr' or 'init_lvalue' is called to initialize the LValue.
 # Then either 'eval' or 'eval_lvalue' can be called any number of times.
 class LValue(Expression):
-    # Initialize the LValue as an LValue. 
+    # Initialize the LValue as an LValue.
+    # Returns whether LValue was initialized, or just re-assigned another value.
     @abstractmethod
-    def init_lvalue(self, scope: Scope, rhs_type: SCCDType):
+    def init_lvalue(self, scope: Scope, rhs_t: SCCDType, rhs: Expression) -> bool:
         pass
 
     # Should return offset relative to current context stack frame.
@@ -80,8 +81,9 @@ class Identifier(LValue):
         self.offset, type = scope.get_rvalue(self.name)
         return type
 
-    def init_lvalue(self, scope: Scope, type):
-        self.offset = scope.put_lvalue(self.name, type)
+    def init_lvalue(self, scope: Scope, rhs_t: SCCDType, rhs: Expression) -> bool:
+        self.offset, is_init = scope.put_lvalue(self.name, rhs_t, rhs)
+        return is_init
 
     def assign(self, memory: MemoryInterface, value: Any):
         memory.store(self.offset, value)
@@ -123,7 +125,7 @@ class FunctionCall(Expression):
 
 # Used in EventDecl and FunctionDeclaration
 @dataclass
-class ParamDecl:
+class ParamDecl(Visitable):
     name: str
     formal_type: SCCDType
     offset: Optional[int] = None
@@ -178,11 +180,11 @@ class ArrayIndexed(LValue):
             raise StaticTypeError("Array indexation: Expression '%s' is not an integer" % self.index_type.render())
         return array_type.element_type
 
-    def init_lvalue(self, scope: Scope, type):
+    def init_lvalue(self, scope: Scope, rhs_t: SCCDType, rhs: Expression) -> bool:
         if not isinstance(self.array, LValue):
             raise StaticTypeError("Array indexation as LValue: Expression '%s' must be an LValue" % self.array.render())
 
-        self.array.init_lvalue(scope, SCCDArray(element_type=type))
+        return self.array.init_lvalue(scope, SCCDArray(element_type=type), rhs)
 
     def assign(self, memory: MemoryInterface, value):
         self.array.eval(memory)[self.index.eval(memory)] = value

+ 17 - 9
src/sccd/action_lang/static/scope.py

@@ -17,12 +17,13 @@ class ScopeError(ModelStaticError):
 # Stateless stuff we know about a variable existing within a scope.
 @dataclass(frozen=True)
 class _Variable(ABC):
-  __slots__ = ["_name", "offset", "type", "const"]
+  __slots__ = ["_name", "offset", "type", "const", "initial_value"]
   
   _name: str # only used to print error messages
   offset: int # Offset within variable's scope. Always >= 0.
   type: SCCDType
   const: bool
+  initial_value: 'Expression'
 
   @property
   def name(self):
@@ -85,9 +86,10 @@ class Scope(Visitable):
         return None
 
   # Create name in this scope
-  def _internal_add(self, name, type, const) -> int:
+  # Precondition: _internal_lookup of name returns 'None'
+  def _internal_add(self, name, type, const, initial_value: 'Expression') -> int:
     offset = len(self.variables)
-    var = _Variable(name, offset, type, const)
+    var = _Variable(name, offset, type, const, initial_value)
     self.names[name] = var
     self.variables.append(var)
     return offset
@@ -95,19 +97,25 @@ class Scope(Visitable):
   # This is what we do when we encounter an assignment expression:
   # Add name to current scope if it doesn't exist yet in current or any parent scope.
   # Or assign to existing variable, if the name already exists, if the types match.
-  # Returns offset relative to the beginning of this scope (may be a postive or negative number).
-  def put_lvalue(self, name: str, type: SCCDType) -> int:
+  # Returns tuple:
+  #  - offset relative to the beginning of this scope (may be a postive or negative number).
+  #  - whether a new variable was declared (and initialized)
+  def put_lvalue(self, name: str, type: SCCDType, value: 'Expression') -> (int, bool):
     found = self._internal_lookup(name)
     if found:
       scope, scope_offset, var = found
       if var.type == type:
+        # Cannot assign to const
         if var.const:
           raise ScopeError(self, "Cannot assign to %s: %s of scope '%s': Variable is constant." % (var.name, str(var.type), scope.name))
-        return scope_offset + var.offset
+        # Assign to existing variable
+        return (scope_offset + var.offset, False)
       else:
+        # Types don't match
         raise ScopeError(self, "Cannot assign %s to %s: %s of scope '%s'" %(str(type), var.name, str(var.type), scope.name))
-
-    return self._internal_add(name, type, const=False)
+    else:
+      # Declare new variable
+      return (self._internal_add(name, type, const=False, initial_value=value), True)
 
   # Lookup name in this scope and its ancestors. Raises exception if not found.
   # Returns offset relative to the beginning of this scope, just like put_lvalue, and also the type of the variable.
@@ -128,4 +136,4 @@ class Scope(Visitable):
       scope, scope_offset, var = found
       raise ScopeError(self, "Cannot declare '%s' in scope '%s': Name already exists in scope '%s'" % (var.name, self.name, scope.name))
 
-    return self._internal_add(name, type, const)
+    return self._internal_add(name, type, const, initial_value=None)

+ 4 - 1
src/sccd/action_lang/static/statement.py

@@ -96,9 +96,12 @@ class Assignment(Statement):
     lhs: LValue
     rhs: Expression
 
+    # Did the assignment create a new variable in its scope?
+    is_initialization: Optional[bool] = None
+
     def init_stmt(self, scope: Scope) -> ReturnBehavior:
         rhs_t = self.rhs.init_expr(scope)
-        self.lhs.init_lvalue(scope, rhs_t)
+        self.is_initialization = self.lhs.init_lvalue(scope, rhs_t, self.rhs)
         return NeverReturns
 
     def exec(self, memory: MemoryInterface) -> Return:

+ 114 - 118
src/sccd/statechart/codegen/rust.py

@@ -82,6 +82,115 @@ class StatechartRustGenerator(ActionLangRustGenerator):
     def visit_Code(self, a):
             a.block.accept(self)
 
+    def visit_State(self, state):
+        # visit children first
+        for c in state.real_children:
+            c.accept(self)
+
+        # Write 'current state' types
+        if isinstance(state.type, AndState):
+            self.w.writeln("// And-state")
+            # TODO: Only annotate Copy for states that will be recorded by deep history.
+            self.w.writeln("#[derive(Default, Copy, Clone)]")
+            self.w.writeln("struct %s {" % ident_type(state))
+            for child in state.real_children:
+                self.w.writeln("  %s: %s," % (ident_field(child), ident_type(child)))
+            self.w.writeln("}")
+        elif isinstance(state.type, OrState):
+            self.w.writeln("// Or-state")
+            self.w.writeln("#[derive(Copy, Clone)]")
+            self.w.writeln("enum %s {" % ident_type(state))
+            for child in state.real_children:
+                self.w.writeln("  %s(%s)," % (ident_enum_variant(child), ident_type(child)))
+            self.w.writeln("}")
+
+        # Write "default" constructor
+        # We use Rust's Default-trait to record default states,
+        # this way, constructing a state instance without parameters will initialize it as the default state.
+        if isinstance(state.type, OrState):
+            self.w.writeln("impl Default for %s {" % ident_type(state))
+            self.w.writeln("  fn default() -> Self {")
+            self.w.writeln("    Self::%s(Default::default())" % (ident_enum_variant(state.type.default_state)))
+            self.w.writeln("  }")
+            self.w.writeln("}")
+
+        # Implement trait 'State': enter/exit
+        # self.w.writeln("impl<'a, OutputCallback: FnMut(OutEvent)> State<Timers, Controller<InEvent, OutputCallback>> for %s {" % ident_type(state))
+        # self.w.writeln("impl<'a, OutputCallback: FnMut(OutEvent)> %s {" % ident_type(state))
+        self.w.writeln("impl %s {" % ident_type(state))
+
+        # Enter actions: Executes enter actions of only this state
+        # self.w.writeln("  fn enter_actions(timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
+        self.w.writeln("  fn enter_actions<OutputCallback: FnMut(OutEvent)>(timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
+        self.w.writeln("    eprintln!(\"enter %s\");" % state.full_name);
+        self.w.indent(); self.w.indent()
+        for a in state.enter:
+            a.accept(self)
+        # compile_actions(state.enter, w)
+        self.w.dedent(); self.w.dedent()
+        for a in state.after_triggers:
+            self.w.writeln("    timers[%d] = ctrl.set_timeout(%d, InEvent::%s);" % (a.after_id, a.delay.opt, ident_event_type(a.enabling[0].name)))
+        self.w.writeln("  }")
+
+        # Enter actions: Executes exit actions of only this state
+        # self.w.writeln("  fn exit_actions(timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
+        self.w.writeln("  fn exit_actions<OutputCallback: FnMut(OutEvent)>(timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
+        self.w.writeln("    eprintln!(\"exit %s\");" % state.full_name);
+        for a in state.after_triggers:
+            self.w.writeln("    ctrl.unset_timeout(timers[%d]);" % (a.after_id))
+        self.w.indent(); self.w.indent()
+        for a in state.exit:
+            a.accept(self)
+        # compile_actions(state.exit, w)
+        self.w.dedent(); self.w.dedent()
+        self.w.writeln("  }")
+
+        # Enter default: Executes enter actions of entering this state and its default substates, recursively
+        # self.w.writeln("  fn enter_default(timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
+        self.w.writeln("  fn enter_default<OutputCallback: FnMut(OutEvent)>(timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
+        self.w.writeln("    %s::enter_actions(timers, internal, ctrl);" % (ident_type(state)))
+        if isinstance(state.type, AndState):
+            for child in state.real_children:
+                self.w.writeln("    %s::enter_default(timers, internal, ctrl);" % (ident_type(child)))
+        elif isinstance(state.type, OrState):
+            self.w.writeln("    %s::enter_default(timers, internal, ctrl);" % (ident_type(state.type.default_state)))
+        self.w.writeln("  }")
+
+        # Exit current: Executes exit actions of this state and current children, recursively
+        # self.w.writeln("  fn exit_current(&self, timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
+        self.w.writeln("  fn exit_current<OutputCallback: FnMut(OutEvent)>(&self, timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
+        # first, children (recursion):
+        if isinstance(state.type, AndState):
+            for child in state.real_children:
+                self.w.writeln("    self.%s.exit_current(timers, internal, ctrl);" % (ident_field(child)))
+        elif isinstance(state.type, OrState):
+            self.w.writeln("    match self {")
+            for child in state.real_children:
+                self.w.writeln("      Self::%s(s) => { s.exit_current(timers, internal, ctrl); }," % (ident_enum_variant(child)))
+            self.w.writeln("    }")
+        # then, parent:
+        self.w.writeln("    %s::exit_actions(timers, internal, ctrl);" % (ident_type(state)))
+        self.w.writeln("  }")
+
+        # Exit current: Executes enter actions of this state and current children, recursively
+        # self.w.writeln("  fn enter_current(&self, timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
+        self.w.writeln("  fn enter_current<OutputCallback: FnMut(OutEvent)>(&self, timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
+        # first, parent:
+        self.w.writeln("    %s::enter_actions(timers, internal, ctrl);" % (ident_type(state)))
+        # then, children (recursion):
+        if isinstance(state.type, AndState):
+            for child in state.real_children:
+                self.w.writeln("    self.%s.enter_current(timers, internal, ctrl);" % (ident_field(child)))
+        elif isinstance(state.type, OrState):
+            self.w.writeln("    match self {")
+            for child in state.real_children:
+                self.w.writeln("      Self::%s(s) => { s.enter_current(timers, internal, ctrl); }," % (ident_enum_variant(child)))
+            self.w.writeln("    }")
+        self.w.writeln("  }")
+
+        self.w.writeln("}")
+        self.w.writeln()
+
     def visit_Statechart(self, sc):
         if sc.semantics.concurrency == Concurrency.MANY:
             raise UnsupportedFeature("concurrency")
@@ -142,122 +251,8 @@ class StatechartRustGenerator(ActionLangRustGenerator):
             # self.w.writeln("}")
         self.w.writeln()
 
-        # Write 'current state' types
-        def write_state_type(state: State, children: List[State]):
-            if isinstance(state.type, AndState):
-                self.w.writeln("// And-state")
-                # TODO: Only annotate Copy for states that will be recorded by deep history.
-                self.w.writeln("#[derive(Default, Copy, Clone)]")
-                self.w.writeln("struct %s {" % ident_type(state))
-                for child in children:
-                    self.w.writeln("  %s: %s," % (ident_field(child), ident_type(child)))
-                self.w.writeln("}")
-            elif isinstance(state.type, OrState):
-                self.w.writeln("// Or-state")
-                self.w.writeln("#[derive(Copy, Clone)]")
-                self.w.writeln("enum %s {" % ident_type(state))
-                for child in children:
-                    self.w.writeln("  %s(%s)," % (ident_enum_variant(child), ident_type(child)))
-                self.w.writeln("}")
-            return state
-
-        # Write "default" constructor
-        def write_default(state: State, children: List[State]):
-            # We use Rust's Default-trait to record default states,
-            # this way, constructing a state instance without parameters will initialize it as the default state.
-            if isinstance(state.type, OrState):
-                self.w.writeln("impl Default for %s {" % ident_type(state))
-                self.w.writeln("  fn default() -> Self {")
-                self.w.writeln("    Self::%s(Default::default())" % (ident_enum_variant(state.type.default_state)))
-                self.w.writeln("  }")
-                self.w.writeln("}")
-            return state
-
-        # Implement trait 'State': enter/exit
-        def write_enter_exit(state: State, children: List[State]):
-            # self.w.writeln("impl<'a, OutputCallback: FnMut(OutEvent)> State<Timers, Controller<InEvent, OutputCallback>> for %s {" % ident_type(state))
-            # self.w.writeln("impl<'a, OutputCallback: FnMut(OutEvent)> %s {" % ident_type(state))
-            self.w.writeln("impl %s {" % ident_type(state))
-
-            # Enter actions: Executes enter actions of only this state
-            # self.w.writeln("  fn enter_actions(timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
-            self.w.writeln("  fn enter_actions<OutputCallback: FnMut(OutEvent)>(timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
-            self.w.writeln("    eprintln!(\"enter %s\");" % state.full_name);
-            self.w.indent(); self.w.indent()
-            for a in state.enter:
-                a.accept(self)
-            # compile_actions(state.enter, w)
-            self.w.dedent(); self.w.dedent()
-            for a in state.after_triggers:
-                self.w.writeln("    timers[%d] = ctrl.set_timeout(%d, InEvent::%s);" % (a.after_id, a.delay.opt, ident_event_type(a.enabling[0].name)))
-            self.w.writeln("  }")
-
-            # Enter actions: Executes exit actions of only this state
-            # self.w.writeln("  fn exit_actions(timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
-            self.w.writeln("  fn exit_actions<OutputCallback: FnMut(OutEvent)>(timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
-            self.w.writeln("    eprintln!(\"exit %s\");" % state.full_name);
-            for a in state.after_triggers:
-                self.w.writeln("    ctrl.unset_timeout(timers[%d]);" % (a.after_id))
-            self.w.indent(); self.w.indent()
-            for a in state.exit:
-                a.accept(self)
-            # compile_actions(state.exit, w)
-            self.w.dedent(); self.w.dedent()
-            self.w.writeln("  }")
-
-            # Enter default: Executes enter actions of entering this state and its default substates, recursively
-            # self.w.writeln("  fn enter_default(timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
-            self.w.writeln("  fn enter_default<OutputCallback: FnMut(OutEvent)>(timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
-            self.w.writeln("    %s::enter_actions(timers, internal, ctrl);" % (ident_type(state)))
-            if isinstance(state.type, AndState):
-                for child in children:
-                    self.w.writeln("    %s::enter_default(timers, internal, ctrl);" % (ident_type(child)))
-            elif isinstance(state.type, OrState):
-                self.w.writeln("    %s::enter_default(timers, internal, ctrl);" % (ident_type(state.type.default_state)))
-            self.w.writeln("  }")
-
-            # Exit current: Executes exit actions of this state and current children, recursively
-            # self.w.writeln("  fn exit_current(&self, timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
-            self.w.writeln("  fn exit_current<OutputCallback: FnMut(OutEvent)>(&self, timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
-            # first, children (recursion):
-            if isinstance(state.type, AndState):
-                for child in children:
-                    self.w.writeln("    self.%s.exit_current(timers, internal, ctrl);" % (ident_field(child)))
-            elif isinstance(state.type, OrState):
-                self.w.writeln("    match self {")
-                for child in children:
-                    self.w.writeln("      Self::%s(s) => { s.exit_current(timers, internal, ctrl); }," % (ident_enum_variant(child)))
-                self.w.writeln("    }")
-            # then, parent:
-            self.w.writeln("    %s::exit_actions(timers, internal, ctrl);" % (ident_type(state)))
-            self.w.writeln("  }")
-
-            # Exit current: Executes enter actions of this state and current children, recursively
-            # self.w.writeln("  fn enter_current(&self, timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
-            self.w.writeln("  fn enter_current<OutputCallback: FnMut(OutEvent)>(&self, timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
-            # first, parent:
-            self.w.writeln("    %s::enter_actions(timers, internal, ctrl);" % (ident_type(state)))
-            # then, children (recursion):
-            if isinstance(state.type, AndState):
-                for child in children:
-                    self.w.writeln("    self.%s.enter_current(timers, internal, ctrl);" % (ident_field(child)))
-            elif isinstance(state.type, OrState):
-                self.w.writeln("    match self {")
-                for child in children:
-                    self.w.writeln("      Self::%s(s) => { s.enter_current(timers, internal, ctrl); }," % (ident_enum_variant(child)))
-                self.w.writeln("    }")
-            self.w.writeln("  }")
-
-            self.w.writeln("}")
-            self.w.writeln()
-            return state
-
-        visit_tree(tree.root, lambda s: s.real_children,
-            child_first=[
-                write_state_type,
-                write_default,
-                write_enter_exit,
-            ])
+        # Write state types
+        tree.root.accept(self)
 
         syntactic_maximality = (
             sc.semantics.big_step_maximality == Maximality.SYNTACTIC
@@ -326,6 +321,7 @@ class StatechartRustGenerator(ActionLangRustGenerator):
         # Function fair_step: a single "Take One" Maximality 'round' (= nonoverlapping arenas allowed to fire 1 transition)
         self.w.writeln("fn fair_step<OutputCallback: FnMut(OutEvent)>(sc: &mut Statechart, input: Option<InEvent>, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>, dirty: Arenas) -> Arenas {")
         self.w.writeln("  let mut fired: Arenas = ARENA_NONE;")
+        self.w.writeln("  let scope = &mut sc.data;")
         self.w.writeln("  let %s = &mut sc.current_state;" % ident_var(tree.root))
         self.w.indent()
 
@@ -647,10 +643,10 @@ class StatechartRustGenerator(ActionLangRustGenerator):
         self.w.writeln("  fn init(&mut self, ctrl: &mut Controller<InEvent, OutputCallback>) {")
         if sc.datamodel is not None:
             self.w.indent(); self.w.indent();
-            self.w.writeln("let sc: &mut Self = self;")
+            self.w.writeln("let scope = &mut self.data;")
             sc.datamodel.accept(self)
-            self.scopes.append(sc.scope)
             self.w.dedent(); self.w.dedent();
+        self.scopes.append(sc.scope)
         self.w.writeln("    %s::enter_default(&mut self.timers, &mut Default::default(), ctrl)" % (ident_type(tree.root)))
         self.w.writeln("  }")
         self.w.writeln("  fn big_step(&mut self, input: Option<InEvent>, c: &mut Controller<InEvent, OutputCallback>) {")

+ 1 - 1
src/sccd/statechart/static/tree.py

@@ -35,7 +35,7 @@ class AbstractState:
     __repr__ = __str__
 
 @dataclass(eq=False)
-class State(AbstractState):
+class State(AbstractState, Visitable):
     type: 'StateType' = None
 
     real_children: List['State'] = field(default_factory=list) # children, but not including pseudo-states such as history

+ 39 - 0
test/test_files/features/action_lang/test_functions2.xml

@@ -0,0 +1,39 @@
+<?xml version="1.0" ?>
+<test>
+  <!-- a simpler functions test - no imports -->
+  <statechart>
+    <datamodel>
+      add42 = func(i: int) {
+        return i + 42;
+      };
+    </datamodel>
+
+    <inport name="in">
+      <event name="start"/>
+    </inport>
+
+    <outport name="out">
+      <event name="ok"/>
+    </outport>
+
+    <root initial="ready">
+      <state id="ready">
+        <transition event="start" cond="add42(10) == 52" target="../final">
+          <raise event="ok"/>
+        </transition>
+      </state>
+
+      <state id="final"/>
+    </root>
+  </statechart>
+
+  <input>
+    <event port="in" name="start" time="0 d"/>
+  </input>
+
+  <output>
+    <big_step>
+      <event port="out" name="ok"/>
+    </big_step>
+  </output>
+</test>