Browse Source

Progress with Rust code generation for action language. Some earlier tests broken.

Joeri Exelmans 4 years ago
parent
commit
f4fc142494

+ 237 - 112
src/sccd/action_lang/codegen/rust.py

@@ -1,17 +1,88 @@
 from sccd.action_lang.static.statement import *
+from collections import defaultdict
 
 class UnsupportedFeature(Exception):
     pass
 
+def ident_scope_type(scope):
+    return "Scope_%s" % (scope.name)
+
+def ident_scope_constructor(scope):
+    return "new_" + ident_scope_type(scope)
+
+@dataclass(frozen=True)
+class ScopeCommit:
+    type_name: str
+    supertype_name: str
+    start: int
+    end: int
+
+@dataclass
+class ScopeStackEntry:
+    scope: Scope
+    committed: int = 0
+
+class ScopeHelper():
+    def __init__(self):
+        self.scope_stack = []
+        self.scope_structs = defaultdict(dict)
+        self.scope_names = {}
+
+    def root(self):
+        return self.scope_stack[0].scope
+
+    def push(self, scope):
+        self.scope_stack.append( ScopeStackEntry(scope) )
+
+    def pop(self):
+        self.scope_stack.pop()
+
+    def current(self):
+        return self.scope_stack[-1]
+
+    def basename(self, scope):
+        return self.scope_names.setdefault(scope, "Scope%d_%s" % (len(self.scope_names), scope.name))
+    
+    def type(self, scope, end):
+        if end == 0:
+            return "Empty"
+        else:
+            return self.basename(scope) + "_l" + str(end)
+
+    def commit(self, offset, writer):
+        start = self.current().committed
+        end = offset
+        type_name = self.type(self.current().scope, end)
+
+        if start != end  and  end > 0:
+            if start == 0:
+                supertype_name = "Empty"
+            else:
+                supertype_name = self.scope_structs[self.current().scope][start].type_name
+
+            commit = ScopeCommit(type_name, supertype_name, start, end)
+            self.scope_structs[self.current().scope][end] = commit
+
+            writer.writeln("let mut scope = %s {" % type_name)
+            writer.writeln("  _base: scope,")
+            for v in self.current().scope.variables[start:end]:
+                writer.writeln("  %s," % v.name)
+            writer.writeln("};")
+
+        self.current().committed = end
+        return type_name
 
 class ActionLangRustGenerator(Visitor):
     def __init__(self, w):
         self.w = w
-        self.scopes = {}
-        self.scopes_written = set()
+        self.scope = ScopeHelper()
+        self.functions_to_write = [] # Function and Rust identifier
+
+        self.function_types = {} # maps Function to Rust type
 
     def default(self, what):
-        raise UnsupportedFeature(what)
+        self.w.wno("<%s>" % what)
+        # raise UnsupportedFeature(what)
 
     def debug_print_stack(self):
         # Print Python stack in Rust file as a comment
@@ -19,44 +90,95 @@ class ActionLangRustGenerator(Visitor):
         for line in ''.join(traceback.format_stack()).split('\n'):
             self.w.writeln("// "+line)
 
+    def write_parent_params(self, scope, with_identifiers=True):
+        ctr = 1
+        while scope is not self.scope.root():
+            if ctr > scope.deepest_lookup:
+                break
+            if with_identifiers:
+                self.w.wno("parent%d: " % ctr)
+            self.w.wno("&mut %s, " % self.scope.type(scope.parent, scope.parent_offset))
+            ctr += 1
+            scope = scope.parent
+
+    def write_parent_call_params(self, scope, skip=0):
+        ctr = 0
+        while scope is not self.scope.root():
+            if ctr == skip:
+                break
+            ctr += 1
+            scope = scope.parent
+
+        while scope is not self.scope.root():
+            if ctr > scope.deepest_lookup:
+                break
+            if ctr == skip:
+                self.w.wno("&mut scope, ")
+            else:
+                self.w.wno("&mut parent%d, " % ctr-skip)
+            ctr += 1
+            scope = scope.parent
 
-    def ident_scope(self, scope):
-        return self.scopes.setdefault(scope, "Scope%d_%s" % (len(self.scopes), scope.name))
+    # This is not a visit method because Scopes may be encountered whenever there's a function call, but they are written as structs and constructor functions, which can only be written at the module level.
+    # When compiling Rust code, the Visitable.accept method must be called on the root of the AST, to write code wherever desired (e.g. in a main-function) followed by 'write_scope' at the module level.
+    def write_decls(self):
 
-    def ident_new_scope(self, scope):
-        return "new_" + self.ident_scope(scope)
+        function_types = {}
 
-    def write_scopes(self):
-        def scopes_to_write():
-            to_write = []
-            for s in self.scopes:
-                if s not in self.scopes_written:
-                    to_write.append(s)
-            return to_write
+        # Write functions
+        for function, identifier in self.functions_to_write:
+            scope = function.scope
+            # self.w.write("fn %s(parent_scope: &mut %s, " % (identifier, self.scope.type(scope.parent, scope.parent_offset)))
+            self.w.write("fn %s(" % (identifier))
 
-        while True:
-            to_write = scopes_to_write()
-            if len(to_write) == 0:
-                return
-            else:
-                for scope in to_write:
-                    scope.accept(self)
+            self.write_parent_params(scope)
+
+            for p in function.params_decl:
+                p.accept(self)
+
+            # self.w.wno(") -> (%s," % self.scope.type(scope, scope.size()))
+            self.w.wno(") -> ")
+            self.write_return_type(function)
+            self.w.wnoln(" {")
+            self.w.indent()
+            self.w.writeln("let scope = Empty{};")
 
+            self.scope.push(function.scope)
+            # Parameters are part of function's scope
+            self.scope.commit(len(function.params_decl), self.w)
+            function.body.accept(self)
+            self.scope.pop()
+
+            self.w.dedent()
+            self.w.writeln("}")
+            self.w.writeln()
+
+        # Write function scopes (as structs)
+        for scope, structs in self.scope.scope_structs.items():
+            for end, commit in structs.items():
+                self.w.writeln("inherit_struct! {")
+                self.w.indent()
+                self.w.writeln("%s (%s) {" % (commit.type_name, commit.supertype_name))
+                for v in scope.variables[commit.start: commit.end]:
+                    self.w.write("  %s: " % v.name)
+                    v.type.accept(self)
+                    self.w.wnoln(", ")
+                self.w.writeln("}")
+                self.w.dedent()
+                self.w.writeln("}")
+                self.w.writeln()
 
     def visit_Block(self, stmt):
-        # self.w.writeln("{")
         for s in stmt.stmts:
             s.accept(self)
-        # self.w.writeln("}")
 
     def visit_Assignment(self, stmt):
-        if not stmt.is_initialization:
-            self.w.write('') # indent
-            stmt.lhs.accept(self)
-            self.w.wno(" = ")
-            stmt.rhs.accept(self)
-            self.w.wnoln(";")
-            self.w.writeln("eprintln!(\"%s\");" % termcolor.colored(stmt.render(),'blue'))
+        #self.w.write('') # indent
+        stmt.lhs.accept(self)
+        self.w.wno(" = ")
+        stmt.rhs.accept(self)
+        self.w.wnoln(";")
+        self.w.writeln("eprintln!(\"%s\");" % termcolor.colored(stmt.render(),'blue'))
 
     def visit_IfStatement(self, stmt):
         self.w.write("if ")
@@ -75,7 +197,21 @@ class ActionLangRustGenerator(Visitor):
 
     def visit_ReturnStatement(self, stmt):
         self.w.write("return ")
+
+        return_type = stmt.expr.get_type()
+        returns_closure_obj = (
+            isinstance(return_type, SCCDFunction) and
+            return_type.function.scope.parent is stmt.scope
+        )
+
+        # self.w.write("return (scope, ")
+        if returns_closure_obj:
+            self.w.wno("(scope, ")
         stmt.expr.accept(self)
+        # self.w.wnoln(");")
+        if returns_closure_obj:
+            self.w.wno(")")
+
         self.w.wnoln(";")
 
     def visit_ExpressionStatement(self, stmt):
@@ -129,102 +265,91 @@ class ActionLangRustGenerator(Visitor):
         expr.formal_type.accept(self)
 
     def visit_FunctionDeclaration(self, expr):
-        self.w.wno("|")
-        for i, p in enumerate(expr.params_decl):
-            p.accept(self)
-            if i != len(expr.params_decl)-1:
-                self.w.wno(", ")
-        self.w.wnoln("| {")
-        self.w.indent()
-        self.w.writeln("let scope = %s();" % self.ident_new_scope(expr.scope))
-        expr.body.accept(self)
-        self.w.dedent()
-        self.w.write("}")
+        function_identifier = "f%d_%s" % (len(self.functions_to_write), expr.scope.name)
+        self.functions_to_write.append( (expr, function_identifier) )
+        self.w.wno(function_identifier)
 
     def visit_FunctionCall(self, expr):
-        self.w.wno("(")
-        expr.function.accept(self)
-        self.w.wno(")(")
+        if isinstance(expr.function.get_type(), SCCDClosureObject):
+            self.w.wno("call_closure!(")
+            expr.function.accept(self) # an Identifier or a FunctionDeclaration (=anonymous function)
+            self.w.wno(", ")
+
+            self.write_parent_call_params(expr.function_being_called.scope, skip=1)
+        else:
+            self.w.wno("(")
+            expr.function.accept(self) # an Identifier or a FunctionDeclaration (=anonymous function)
+            self.w.wno(")(")
+
+            # Parent scope mut refs
+            self.write_parent_call_params(expr.function_being_called.scope, skip=1)
+
+        # Call parameters
         for p in expr.params:
             p.accept(self)
             self.w.wno(", ")
         self.w.wno(")")
 
+        # if not isinstance(expr.get_type(), SCCDClosureObject):
+        #     # In our generated code, a function always returns a pair of
+        #     #   0) the called function's scope
+        #     #   1) the returned value <- pick this
+        #     self.w.wno(".1")
+
+
     def visit_Identifier(self, lval):
-        self.w.wno("scope."+lval.name)
-
-    def visit_Scope(self, scope):
-        # Map variable to template param name (for functions)
-        mapping = {}
-        for i,v in enumerate(scope.variables):
-            if isinstance(v.type, SCCDFunction):
-                mapping[i] = "F%d" % len(mapping)
-
-        def write_template_params_with_trait():
-            for i,v in enumerate(scope.variables):
-                if i in mapping:
-                    self.w.wno("%s: " % mapping[i])
-                    v.type.accept(self)
-                    self.w.wno(", ")
 
-        def write_template_params():
-            for i,v in enumerate(scope.variables):
-                if i in mapping:
-                    self.w.wno("%s, " % mapping[i])
+        if lval.is_lvalue:
+            # self.debug_print_stack()
+            # self.w.writeln("// offset: %d, branches: %s" % (lval.offset, self.scope.current().scope.children.keys()))
+            if lval.offset in self.scope.current().scope.children:
+                # a child scope exists at the current offset (typically because we encountered a function declaration) - so we must commit our scope
+                self.scope.commit(lval.offset, self.w)
 
-        def write_opaque_params():
-            for i,v in enumerate(scope.variables):
-                if i in mapping:
-                    self.w.wno("impl ")
-                    v.type.accept(self)
-                    self.w.wno(", ")
-
-        # Write type
-        self.w.write("struct %s<" % self.ident_scope(scope))
-        write_template_params_with_trait()
-        self.w.wnoln("> {")
-        for i,v in enumerate(scope.variables):
-            self.w.write("  %s: " % v.name)
-            if i in mapping:
-                self.w.wno(mapping[i])
-            else:
-                v.type.accept(self)
-            self.w.wnoln(",")
-        self.w.writeln("}")
+            self.w.write('') # indent
 
-        # Write type-parameterless constructor:
-        self.w.write("fn %s() -> %s<" % (self.ident_new_scope(scope), self.ident_scope(scope)))
-        write_opaque_params()
-        self.w.wnoln("> {")
-        self.w.indent()
-        for v in scope.variables:
-            if v.initial_value is not None:
-                self.w.writeln("eprintln!(\"%s\");" % termcolor.colored("(init) %s = %s;" % (v.name, v.initial_value.render()),'blue'))
-        self.w.writeln("%s {" % self.ident_scope(scope))
-        self.w.indent()
-        for v in scope.variables:
-            if v.initial_value is not None:
-                self.w.write("%s: " % v.name)
-                v.initial_value.accept(self)
-                self.w.wnoln(",")
-            else:
-                self.w.writeln("%s: Default::default()," % v.name)
-        self.w.dedent()
-        self.w.writeln("}")
-        self.w.dedent()
-        self.w.writeln("}")
-        self.w.writeln()
+        if lval.is_init:
+            self.w.wno("let mut ")
+        else:
+            if lval.offset < 0:
+                self.w.wno("parent%d." % self.scope.current().scope.nested_levels(lval.offset))
+            elif lval.offset < self.scope.current().committed:
+                self.w.wno("scope.")
 
-        self.scopes_written.add(scope)
+        self.w.wno(lval.name)
 
-    def visit_SCCDFunction(self, type):
-        self.w.wno("FnMut(")
-        for i, t in enumerate(type.param_types):
-            t.accept(self)
-            if i != len(type.param_types)-1:
-                self.w.wno(", ")
+    def visit_SCCDClosureObject(self, type):
+        # self.w.wno("<closure type> ")
+
+        self.w.wno("(%s, " % self.scope.type(type.scope, type.scope.parent_offset))
+        type.function_type.accept(self)
         self.w.wno(")")
 
+    def write_return_type(self, function: FunctionDeclaration):
+        # self.w.wno("(%s, " % self.scope.type(function.scope, function.scope.size()))
+        # type.function_type.accept(self)
+        if function.return_type is None:
+            self.w.wno("Empty")
+        else:
+            function.return_type.accept(self)
+        # self.w.wno(")")
+
+    def visit_SCCDFunction(self, type):
+        # self.w.wno("<function type> ")
+        scope = type.function.scope
+        # self.w.wno("fn(&mut %s, " % (self.scope.type(scope.parent, scope.parent_offset)))
+        self.w.wno("fn(")
+
+        self.write_parent_params(scope, with_identifiers=False)
+
+        for p in type.param_types:
+            p.accept(self)
+            self.w.wno(", ")
+
+        self.w.wno(") -> ")
+        self.write_return_type(type.function)
+
     def visit__SCCDSimpleType(self, type):
         self.w.wno(type.name
-            .replace("int", "i32"))
+            .replace("int", "i32")
+            .replace("float", "f64"))

+ 100 - 19
src/sccd/action_lang/static/expression.py

@@ -45,6 +45,11 @@ class Expression(ABC, Visitable):
     def init_expr(self, scope: Scope) -> SCCDType:
         pass
 
+    # Returns static type of expression.
+    @abstractmethod
+    def get_type(self) -> SCCDType:
+        pass
+
     # Evaluate the expression.
     # Evaluation may have side effects.
     @abstractmethod
@@ -75,15 +80,30 @@ class LValue(Expression):
 @dataclass
 class Identifier(LValue):
     name: str
+
     offset: Optional[int] = None
+    type: Optional[SCCDType] = None
+    is_init: Optional[bool] = None
+    is_lvalue: Optional[bool] = None
+
+    # is_function_call_result: Optional[SCCDFunctionCallResult] = None
 
     def init_expr(self, scope: Scope) -> SCCDType:
-        self.offset, type = scope.get_rvalue(self.name)
-        return type
+        self.offset, self.type = scope.get_rvalue(self.name)
+        self.is_init = False
+        self.is_lvalue = False
+        return self.type
+
+    def get_type(self) -> SCCDType:
+        return self.type
 
     def init_lvalue(self, scope: Scope, rhs_t: SCCDType, rhs: Expression) -> bool:
-        self.offset, is_init = scope.put_lvalue(self.name, rhs_t, rhs)
-        return is_init
+        # if isinstance(rhs_t, SCCDFunctionCallResult):
+            # self.is_function_call_result = rhs_t
+            # rhs_t = rhs_t.return_type
+        self.offset, self.is_init = scope.put_lvalue(self.name, rhs_t, rhs)
+        self.is_lvalue = True
+        return self.is_init
 
     def assign(self, memory: MemoryInterface, value: Any):
         memory.store(self.offset, value)
@@ -94,18 +114,30 @@ class Identifier(LValue):
     def render(self):
         return self.name
 
+
 @dataclass
 class FunctionCall(Expression):
-    function: Expression
+    function: Expression # an identifier, or another function call
     params: List[Expression]
 
+    return_type: Optional[SCCDType] = None
+    function_being_called: Optional['FunctionDeclaration'] = None
+
     def init_expr(self, scope: Scope) -> SCCDType:
         function_type = self.function.init_expr(scope)
+
+        # A FunctionCall can be a call on a regular function, or a closure object
+        if isinstance(function_type, SCCDClosureObject):
+            # For static analysis, we treat calls on closure objects just like calls on regular functions.
+            function_type = function_type.function_type
+
         if not isinstance(function_type, SCCDFunction):
             raise StaticTypeError("Function call: Expression '%s' is not a function" % self.function.render())
 
+        self.function_being_called = function_type.function
+
         formal_types = function_type.param_types
-        return_type = function_type.return_type
+        self.return_type = function_type.return_type
 
         actual_types = [p.init_expr(scope) for p in self.params]
         if len(formal_types) != len(actual_types):
@@ -113,7 +145,12 @@ class FunctionCall(Expression):
         for i, (formal, actual) in enumerate(zip(formal_types, actual_types)):
             if formal != actual:
                 raise StaticTypeError("Function call, argument %d: %s is not expected type %s, instead is %s" % (i, self.params[i].render(), str(formal), str(actual)))
-        return return_type
+
+        # The type of a function call is the return type of the function called
+        return self.return_type
+
+    def get_type(self) -> SCCDType:
+        return self.return_type
 
     def eval(self, memory: MemoryInterface):
         f = self.function.eval(memory)
@@ -136,11 +173,13 @@ class ParamDecl(Visitable):
     def render(self):
         return self.name + ":" + str(self.formal_type)
 
-@dataclass
+@dataclass(eq=False) # eq=False: make it hashable (plus, we don't need auto eq)
 class FunctionDeclaration(Expression):
     params_decl: List[ParamDecl]
     body: 'Statement'
     scope: Optional[Scope] = None
+    return_type: Optional[SCCDType] = None
+    type: Optional[SCCDFunction] = None
 
     def init_expr(self, scope: Scope) -> SCCDType:
         self.scope = Scope("function", scope)
@@ -148,8 +187,16 @@ class FunctionDeclaration(Expression):
         for p in self.params_decl:
             p.init_param(self.scope)
         ret = self.body.init_stmt(self.scope)
-        return_type = ret.get_return_type()
-        return SCCDFunction([p.formal_type for p in self.params_decl], return_type)
+        self.return_type = ret.get_return_type()
+
+        if isinstance(self.return_type, SCCDFunction) and self.return_type.function.scope.parent is self.scope:
+            # Called function returns a closure object
+            self.return_type = SCCDClosureObject(self.scope, function_type=self.return_type)
+        self.type = SCCDFunction([p.formal_type for p in self.params_decl], self.return_type, function=self)
+        return self.type
+
+    def get_type(self) -> SCCDType:
+        return self.type
 
     def eval(self, memory: MemoryInterface):
         context: 'StackFrame' = memory.current_frame()
@@ -165,7 +212,7 @@ class FunctionDeclaration(Expression):
 
     def render(self) -> str:
         return "func(%s) [...]" % ", ".join(p.render() for p in self.params_decl) # todo
-        
+
 @dataclass
 class ArrayIndexed(LValue):
     array: Expression
@@ -180,6 +227,9 @@ class ArrayIndexed(LValue):
             raise StaticTypeError("Array indexation: Expression '%s' is not an integer" % self.index_type.render())
         return array_type.element_type
 
+    def get_type(self) -> SCCDType:
+        return self.array.get_type().element_type
+
     def init_lvalue(self, scope: Scope, rhs_t: SCCDType, rhs: Expression) -> bool:
         if not isinstance(self.array, LValue):
             raise StaticTypeError("Array indexation as LValue: Expression '%s' must be an LValue" % self.array.render())
@@ -206,6 +256,9 @@ class StringLiteral(Expression):
     def init_expr(self, scope: Scope) -> SCCDType:
         return SCCDString
 
+    def get_type(self) -> SCCDType:
+        return SCCDString
+
     def eval(self, memory: MemoryInterface):
         return self.string
 
@@ -220,6 +273,9 @@ class IntLiteral(Expression):
     def init_expr(self, scope: Scope) -> SCCDType:
         return SCCDInt
 
+    def get_type(self) -> SCCDType:
+        return SCCDInt
+
     def eval(self, memory: MemoryInterface):
         return self.i
 
@@ -233,6 +289,9 @@ class FloatLiteral(Expression):
     def init_expr(self, scope: Scope) -> SCCDType:
         return SCCDFloat
 
+    def get_type(self) -> SCCDType:
+        return SCCDFloat
+
     def eval(self, memory: MemoryInterface):
         return self.f
 
@@ -246,6 +305,9 @@ class BoolLiteral(Expression):
     def init_expr(self, scope: Scope) -> SCCDType:
         return SCCDBool
 
+    def get_type(self) -> SCCDType:
+        return SCCDBool
+
     def eval(self, memory: MemoryInterface):
         return self.b
 
@@ -259,6 +321,9 @@ class DurationLiteral(Expression):
     def init_expr(self, scope: Scope) -> SCCDType:
         return SCCDDuration
 
+    def get_type(self) -> SCCDType:
+        return SCCDDuration
+
     def eval(self, memory: MemoryInterface):
         return self.d
 
@@ -280,14 +345,17 @@ class Array(Expression):
 
         return SCCDArray(self.element_type)
 
+    def get_type(self) -> SCCDType:
+        return SCCDArray(self.element_type)
+
     def eval(self, memory: MemoryInterface):
         return [e.eval(memory) for e in self.elements]
 
     def render(self):
         return '['+','.join([e.render() for e in self.elements])+']'
 
-# Does not add anything semantically, but ensures that when rendering an expression,
-# the parenthesis are not lost
+# A group of parentheses in the concrete syntax.
+# Does not add anything semantically, but allows us to go back from abstract to concrete textual syntax without weird rules
 @dataclass
 class Group(Expression):
     subexpr: Expression
@@ -295,6 +363,9 @@ class Group(Expression):
     def init_expr(self, scope: Scope) -> SCCDType:
         return self.subexpr.init_expr(scope)
 
+    def get_type(self) -> SCCDType:
+        return self.subexpr.get_type()
+
     def eval(self, memory: MemoryInterface):
         return self.subexpr.eval(memory)
 
@@ -307,6 +378,8 @@ class BinaryExpression(Expression):
     operator: str # token name from the grammar.
     rhs: Expression
 
+    type: Optional[SCCDType] = None
+
     def init_expr(self, scope: Scope) -> SCCDType:
         lhs_t = self.lhs.init_expr(scope)
         rhs_t = self.rhs.init_expr(scope)
@@ -339,7 +412,7 @@ class BinaryExpression(Expression):
         def exp():
             return lhs_t.exp(rhs_t)
 
-        t = {
+        self.type = {
             "and": logical,
             "or":  logical,
             "==":  eq,
@@ -357,10 +430,13 @@ class BinaryExpression(Expression):
             "**":  exp,
         }[self.operator]()
 
-        if t is None:
+        if self.type is None:
             raise StaticTypeError("Illegal types for '%s'-operation: %s and %s" % (self.operator, lhs_t, rhs_t))
 
-        return t
+        return self.type
+
+    def get_type(self) -> SCCDType:
+        return self.type
 
     def eval(self, memory: MemoryInterface):
         return {
@@ -389,6 +465,8 @@ class UnaryExpression(Expression):
     operator: str # token value from the grammar.
     expr: Expression
 
+    type: Optional[SCCDType] = None
+
     def init_expr(self, scope: Scope) -> SCCDType:
         expr_type = self.expr.init_expr(scope)
 
@@ -400,15 +478,18 @@ class UnaryExpression(Expression):
             if expr_type.is_neg():
                 return expr_type
 
-        t = {
+        self.type = {
             "not": logical,
             "-":   neg,
         }[self.operator]()
 
-        if t is None:
+        if self.type is None:
             raise StaticTypeError("Illegal type for unary '%s'-expression: %s" % (self.operator, expr_type))
 
-        return t
+        return self.type
+
+    def get_type(self) -> SCCDType:
+        return self.type
 
     def eval(self, memory: MemoryInterface):
         return {

+ 27 - 4
src/sccd/action_lang/static/scope.py

@@ -7,7 +7,7 @@ from sccd.common.exceptions import *
 from sccd.util.visitable import *
 import itertools
 import termcolor
-
+from collections import defaultdict
 
 class ScopeError(ModelStaticError):
   def __init__(self, scope, msg):
@@ -33,17 +33,28 @@ class _Variable(ABC):
   def __str__(self):
     return "+%d: %s%s: %s" % (self.offset, "(const) "if self.const else "", self.name, str(self.type))
 
+# _scope_ctr = 0
 
 # Stateless stuff we know about a scope (= set of named values)
 class Scope(Visitable):
   __slots__ = ["name", "parent", "parent_offset", "names", "variables"]
 
   def __init__(self, name: str, parent: 'Scope'):
+    # global _scope_ctr
+    # self.id = _scope_ctr # just a unique ID within the AST (for code generation)
+    # _scope_ctr += 1
+
+
     self.name = name
+
     self.parent = parent
-    if parent:
+    self.children = defaultdict(list) # mapping from offset to child scope
+
+    if parent is not None:
       # Position of the start of this scope, seen from the parent scope
-      self.parent_offset = self.parent.size()
+      self.parent_offset = parent.size()
+      # Append to parent
+      parent.children[self.parent_offset].append(self)
     else:
       self.parent_offset = None # value should never be used
 
@@ -53,6 +64,8 @@ class Scope(Visitable):
     # All non-constant values, ordered by memory position
     self.variables: List[_Variable] = []
 
+    self.deepest_lookup = 0
+
 
   def size(self) -> int:
     return len(self.variables)
@@ -81,10 +94,20 @@ class Scope(Visitable):
       return (self, offset, self.names[name])
     except KeyError:
       if self.parent is not None:
-        return self.parent._internal_lookup(name, offset - self.parent_offset)
+        got_it = self.parent._internal_lookup(name, offset - self.parent_offset)
+        if got_it:
+          scope, off, v = got_it
+          self.deepest_lookup = max(self.deepest_lookup, self.nested_levels(off))
+        return got_it
       else:
         return None
 
+  def nested_levels(self, offset):
+    if offset >= 0:
+      return 0
+    else:
+      return 1 + self.parent.nested_levels(offset + self.parent_offset)
+
   # Create name in this scope
   # Precondition: _internal_lookup of name returns 'None'
   def _internal_add(self, name, type, const, initial_value: 'Expression') -> int:

+ 8 - 1
src/sccd/action_lang/static/statement.py

@@ -27,7 +27,7 @@ class ReturnBehavior:
     def __post_init__(self):
         assert (self.when == ReturnBehavior.When.NEVER) == (self.type is None)
 
-    def get_return_type(self) -> type:
+    def get_return_type(self) -> Optional[SCCDType]:
         if self.when == ReturnBehavior.When.ALWAYS:
             return self.type
         elif self.when == ReturnBehavior.When.SOME_BRANCHES:
@@ -102,6 +102,10 @@ class Assignment(Statement):
     def init_stmt(self, scope: Scope) -> ReturnBehavior:
         rhs_t = self.rhs.init_expr(scope)
         self.is_initialization = self.lhs.init_lvalue(scope, rhs_t, self.rhs)
+        # Very common case of assignment of a function to an identifier:
+        # Make the function's scope name a little bit more expressive
+        if isinstance(self.rhs, FunctionDeclaration) and isinstance(self.lhs, Identifier):
+            self.rhs.scope.name += "_" + self.lhs.name
         return NeverReturns
 
     def exec(self, memory: MemoryInterface) -> Return:
@@ -162,7 +166,10 @@ class ExpressionStatement(Statement):
 class ReturnStatement(Statement):
     expr: Expression
 
+    scope: Optional[Scope] = None
+
     def init_stmt(self, scope: Scope) -> ReturnBehavior:
+        self.scope = scope
         t = self.expr.init_expr(scope)
         if t is None:
             raise StaticTypeError("Return statement: Expression does not evaluate to a value.")

+ 45 - 0
src/sccd/action_lang/static/types.py

@@ -3,6 +3,7 @@ from dataclasses import *
 from typing import *
 import termcolor
 from sccd.util.visitable import *
+from functools import reduce
 
 class SCCDType(ABC, Visitable):
     @abstractmethod
@@ -101,6 +102,7 @@ class _SCCDSimpleType(SCCDType):
 class SCCDFunction(SCCDType):
     param_types: List[SCCDType]
     return_type: Optional[SCCDType] = None
+    function: Optional['FunctionDeclaration'] = None
 
     def _str(self):
         if self.param_types:
@@ -111,6 +113,9 @@ class SCCDFunction(SCCDType):
             s += " -> " + self.return_type._str()
         return s
 
+    def __eq__(self, other):
+        return isinstance(other, SCCDFunction) and self.param_types == other.param_types and self.return_type == other.return_type
+
 @dataclass(frozen=True, repr=False)
 class SCCDArray(SCCDType):
     element_type: SCCDType
@@ -123,6 +128,46 @@ class SCCDArray(SCCDType):
             return True
         return False
 
+@dataclass(frozen=True, repr=False)
+class SCCDTuple(SCCDType):
+    element_types: List[SCCDType]
+
+    def _str(self):
+        return "(" + ", ".join(t._str for t in self.element_types) + ")"
+
+    def is_eq(self, other):
+        return instance(other, SCCDTuple) and len(self.element_types) == len(other.element_types) and reduce(lambda x,y: x and y, (t1.is_eq(t2) for (t1, t2) in zip(self.element_types, other.element_types)))
+
+@dataclass(frozen=True, repr=False)
+class SCCDClosureObject(SCCDType):
+    scope: 'Scope'
+    function_type: SCCDFunction
+
+    def _str(self):
+        return "Closure(scope=%s, func=%s)" % (self.scope.name, self.function_type._str())
+
+
+# @dataclass(frozen=True, repr=False)
+# class SCCDFunctionCallResult(SCCDType):
+#     function_type: SCCDFunction
+#     return_type: SCCDType
+
+#     def _str(self):
+#         return "CallResult(%s)" % self.return_type._str()
+
+#     def is_eq(self, other):
+#         return return_type.is_eq(other)
+
+# @dataclass(frozen=True, repr=False)
+# class SCCDScope(SCCDType):
+#     scope: 'Scope'
+
+#     def _str(self):
+#         return "Scope(%s)" % scope.name
+
+#     def is_eq(self, other):
+#         return self.scope is other.scope
+
 SCCDBool = _SCCDSimpleType("bool", eq=True, bool_cast=True)
 SCCDInt = _SCCDSimpleType("int", neg=True, summable=True, eq=True, ord=True, bool_cast=True)
 SCCDFloat = _SCCDSimpleType("float", neg=True, summable=True, eq=True, ord=True)

+ 42 - 0
src/sccd/statechart/codegen/libstatechart.rs

@@ -165,3 +165,45 @@ Controller<EventType, OutputCallback> {
     }
   }
 }
+
+use std::ops::Deref;
+use std::ops::DerefMut;
+
+// This macro lets a struct "inherit" the data members of another struct
+// The inherited struct is added as a struct member and the Deref and DerefMut
+// traits are implemented to return a reference to the base struct
+macro_rules! inherit_struct {
+    ($name: ident ($base: ty) { $($element: ident: $ty: ty),* $(,)? } ) => {
+        struct $name {
+            _base: $base,
+            $($element: $ty),*
+        }
+        impl Deref for $name {
+            type Target = $base;
+            fn deref(&self) -> &$base {
+                &self._base
+            }
+        }
+        impl DerefMut for $name {
+            fn deref_mut(&mut self) -> &mut $base {
+                &mut self._base
+            }
+        }
+    }
+}
+
+// "Base struct" for all scopes
+struct Empty{}
+
+// A closure object is a pair of a functions first argument and that function.
+// The call may be part of an larger expression, and therefore we cannot just write 'let' statements to assign the pair's elements to identifiers which we need for the call.
+// This macro does exactly that, in an anonymous Rust closure, which is immediately called.
+macro_rules! call_closure {
+  ($closure: expr, $($param: expr),*  $(,)?) => {
+    (||{
+      let scope = &mut $closure.0;
+      let function = &mut $closure.1;
+      return function(scope, $($param),* );
+    })()
+  };
+}

+ 53 - 65
src/sccd/statechart/codegen/rust.py

@@ -9,7 +9,7 @@ from sccd.statechart.static import priority
 from sccd.util.indenting_writer import *
 
 # Hardcoded limit on number of sub-rounds of combo and big step to detect never-ending superrounds.
-# TODO: make this a model parameter
+# TODO: make this a model parameter, also allowing for +infinity
 LIMIT = 1000
 
 # Conversion functions from abstract syntax elements to identifiers in Rust
@@ -90,7 +90,7 @@ class StatechartRustGenerator(ActionLangRustGenerator):
         # Write 'current state' types
         if isinstance(state.type, AndState):
             self.w.writeln("// And-state")
-            # TODO: Only annotate Copy for states that will be recorded by deep history.
+            # We need Copy for states that will be recorded as history.
             self.w.writeln("#[derive(Default, Copy, Clone)]")
             self.w.writeln("struct %s {" % ident_type(state))
             for child in state.real_children:
@@ -115,76 +115,69 @@ class StatechartRustGenerator(ActionLangRustGenerator):
             self.w.writeln("}")
 
         # Implement trait 'State': enter/exit
-        # self.w.writeln("impl<'a, OutputCallback: FnMut(OutEvent)> State<Timers, Controller<InEvent, OutputCallback>> for %s {" % ident_type(state))
-        # self.w.writeln("impl<'a, OutputCallback: FnMut(OutEvent)> %s {" % ident_type(state))
         self.w.writeln("impl %s {" % ident_type(state))
 
         # Enter actions: Executes enter actions of only this state
-        # self.w.writeln("  fn enter_actions(timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
-        self.w.writeln("  fn enter_actions<OutputCallback: FnMut(OutEvent)>(timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
+        self.w.writeln("  fn enter_actions<OutputCallback: FnMut(OutEvent)>(timers: &mut Timers, data: &mut DataModel, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
         self.w.writeln("    eprintln!(\"enter %s\");" % state.full_name);
+        self.w.writeln("    let scope = data;")
         self.w.indent(); self.w.indent()
         for a in state.enter:
             a.accept(self)
-        # compile_actions(state.enter, w)
         self.w.dedent(); self.w.dedent()
         for a in state.after_triggers:
             self.w.writeln("    timers[%d] = ctrl.set_timeout(%d, InEvent::%s);" % (a.after_id, a.delay.opt, ident_event_type(a.enabling[0].name)))
         self.w.writeln("  }")
 
         # Enter actions: Executes exit actions of only this state
-        # self.w.writeln("  fn exit_actions(timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
-        self.w.writeln("  fn exit_actions<OutputCallback: FnMut(OutEvent)>(timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
-        self.w.writeln("    eprintln!(\"exit %s\");" % state.full_name);
+        self.w.writeln("  fn exit_actions<OutputCallback: FnMut(OutEvent)>(timers: &mut Timers, data: &mut DataModel, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
+        self.w.writeln("    let scope = data;")
         for a in state.after_triggers:
             self.w.writeln("    ctrl.unset_timeout(timers[%d]);" % (a.after_id))
         self.w.indent(); self.w.indent()
         for a in state.exit:
             a.accept(self)
-        # compile_actions(state.exit, w)
+        self.w.writeln("    eprintln!(\"exit %s\");" % state.full_name);
         self.w.dedent(); self.w.dedent()
         self.w.writeln("  }")
 
         # Enter default: Executes enter actions of entering this state and its default substates, recursively
-        # self.w.writeln("  fn enter_default(timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
-        self.w.writeln("  fn enter_default<OutputCallback: FnMut(OutEvent)>(timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
-        self.w.writeln("    %s::enter_actions(timers, internal, ctrl);" % (ident_type(state)))
+        self.w.writeln("  fn enter_default<OutputCallback: FnMut(OutEvent)>(timers: &mut Timers, data: &mut DataModel, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
+        self.w.writeln("    %s::enter_actions(timers, data, internal, ctrl);" % (ident_type(state)))
         if isinstance(state.type, AndState):
             for child in state.real_children:
-                self.w.writeln("    %s::enter_default(timers, internal, ctrl);" % (ident_type(child)))
+                self.w.writeln("    %s::enter_default(timers, data, internal, ctrl);" % (ident_type(child)))
         elif isinstance(state.type, OrState):
-            self.w.writeln("    %s::enter_default(timers, internal, ctrl);" % (ident_type(state.type.default_state)))
+            self.w.writeln("    %s::enter_default(timers, data, internal, ctrl);" % (ident_type(state.type.default_state)))
         self.w.writeln("  }")
 
         # Exit current: Executes exit actions of this state and current children, recursively
-        # self.w.writeln("  fn exit_current(&self, timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
-        self.w.writeln("  fn exit_current<OutputCallback: FnMut(OutEvent)>(&self, timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
+        self.w.writeln("  fn exit_current<OutputCallback: FnMut(OutEvent)>(&self, timers: &mut Timers, data: &mut DataModel, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
         # first, children (recursion):
         if isinstance(state.type, AndState):
             for child in state.real_children:
-                self.w.writeln("    self.%s.exit_current(timers, internal, ctrl);" % (ident_field(child)))
+                self.w.writeln("    self.%s.exit_current(timers, data, internal, ctrl);" % (ident_field(child)))
         elif isinstance(state.type, OrState):
             self.w.writeln("    match self {")
             for child in state.real_children:
-                self.w.writeln("      Self::%s(s) => { s.exit_current(timers, internal, ctrl); }," % (ident_enum_variant(child)))
+                self.w.writeln("      Self::%s(s) => { s.exit_current(timers, data, internal, ctrl); }," % (ident_enum_variant(child)))
             self.w.writeln("    }")
         # then, parent:
-        self.w.writeln("    %s::exit_actions(timers, internal, ctrl);" % (ident_type(state)))
+        self.w.writeln("    %s::exit_actions(timers, data, internal, ctrl);" % (ident_type(state)))
         self.w.writeln("  }")
 
         # Exit current: Executes enter actions of this state and current children, recursively
-        # self.w.writeln("  fn enter_current(&self, timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
-        self.w.writeln("  fn enter_current<OutputCallback: FnMut(OutEvent)>(&self, timers: &mut Timers, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
+        self.w.writeln("  fn enter_current<OutputCallback: FnMut(OutEvent)>(&self, timers: &mut Timers, data: &mut DataModel, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>) {")
         # first, parent:
-        self.w.writeln("    %s::enter_actions(timers, internal, ctrl);" % (ident_type(state)))
+        self.w.writeln("    %s::enter_actions(timers, data, internal, ctrl);" % (ident_type(state)))
         # then, children (recursion):
         if isinstance(state.type, AndState):
             for child in state.real_children:
-                self.w.writeln("    self.%s.enter_current(timers, internal, ctrl);" % (ident_field(child)))
+                self.w.writeln("    self.%s.enter_current(timers, data, internal, ctrl);" % (ident_field(child)))
         elif isinstance(state.type, OrState):
             self.w.writeln("    match self {")
             for child in state.real_children:
-                self.w.writeln("      Self::%s(s) => { s.enter_current(timers, internal, ctrl); }," % (ident_enum_variant(child)))
+                self.w.writeln("      Self::%s(s) => { s.enter_current(timers, data, internal, ctrl); }," % (ident_enum_variant(child)))
             self.w.writeln("    }")
         self.w.writeln("  }")
 
@@ -192,6 +185,8 @@ class StatechartRustGenerator(ActionLangRustGenerator):
         self.w.writeln()
 
     def visit_Statechart(self, sc):
+        self.scope.push(sc.scope)
+
         if sc.semantics.concurrency == Concurrency.MANY:
             raise UnsupportedFeature("concurrency")
 
@@ -251,9 +246,6 @@ class StatechartRustGenerator(ActionLangRustGenerator):
             # self.w.writeln("}")
         self.w.writeln()
 
-        # Write state types
-        tree.root.accept(self)
-
         syntactic_maximality = (
             sc.semantics.big_step_maximality == Maximality.SYNTACTIC
             or sc.semantics.combo_step_maximality == Maximality.SYNTACTIC)
@@ -286,38 +278,39 @@ class StatechartRustGenerator(ActionLangRustGenerator):
         #     self.w.writeln("const ARENA_UNSTABLE: Arenas = false; // inapplicable to chosen semantics - all transition targets considered stable")
         self.w.writeln()
 
-        # Write datamodel type
-        # sc.scope.accept(self)
-        # RustGenerator(w).visit_Scope(sc.scope)
-        # self.w.writeln("struct DataModel {")
-
-        # self.w.writeln("  ")
-        # self.w.writeln(")")
-
         # Write statechart type
-        self.w.writeln("pub struct Statechart {")
-        self.w.writeln("  current_state: %s," % ident_type(tree.root))
-        # We always store a history value as 'deep' (also for shallow history).
-        # TODO: We may save a tiny bit of space in some rare cases by storing shallow history as only the exited child of the Or-state.
-        for h in tree.history_states:
-            self.w.writeln("  %s: %s," % (ident_history_field(h), ident_type(h.parent)))
-        self.w.writeln("  timers: Timers,")
-        self.w.writeln("  data: %s," % self.ident_scope(sc.scope))
-        self.w.writeln("}")
-
         self.w.writeln("impl Default for Statechart {")
         self.w.writeln("  fn default() -> Self {")
+        self.w.writeln("    // Initialize data model")
+        self.w.indent(); self.w.indent();
+        self.w.writeln("    let scope = Empty{};")
+        if sc.datamodel is not None:
+            sc.datamodel.accept(self)
+        datamodel_type = self.scope.commit(sc.scope.size(), self.w)
+        self.w.dedent(); self.w.dedent();
         self.w.writeln("    Self {")
         self.w.writeln("      current_state: Default::default(),")
         for h in tree.history_states:
             self.w.writeln("      %s: Default::default()," % (ident_history_field(h)))
         self.w.writeln("      timers: Default::default(),")
-        self.w.writeln("      data: %s()," % self.ident_new_scope(sc.scope))
+        self.w.writeln("      data: scope,")
         self.w.writeln("    }")
         self.w.writeln("  }")
         self.w.writeln("}")
+        self.w.writeln("type DataModel = %s;" % datamodel_type)
+        self.w.writeln("pub struct Statechart {")
+        self.w.writeln("  current_state: %s," % ident_type(tree.root))
+        # We always store a history value as 'deep' (also for shallow history).
+        # TODO: We may save a tiny bit of space in some rare cases by storing shallow history as only the exited child of the Or-state.
+        for h in tree.history_states:
+            self.w.writeln("  %s: %s," % (ident_history_field(h), ident_type(h.parent)))
+        self.w.writeln("  timers: Timers,")
+        self.w.writeln("  data: DataModel,")
+        self.w.writeln("}")
         self.w.writeln()
 
+        self.write_decls()
+
         # Function fair_step: a single "Take One" Maximality 'round' (= nonoverlapping arenas allowed to fire 1 transition)
         self.w.writeln("fn fair_step<OutputCallback: FnMut(OutEvent)>(sc: &mut Statechart, input: Option<InEvent>, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>, dirty: Arenas) -> Arenas {")
         self.w.writeln("  let mut fired: Arenas = ARENA_NONE;")
@@ -352,7 +345,7 @@ class StatechartRustGenerator(ActionLangRustGenerator):
 
                     if len(exit_path) == 1:
                         # Exit s:
-                        self.w.writeln("%s.exit_current(&mut sc.timers, internal, ctrl);" % (ident_var(s)))
+                        self.w.writeln("%s.exit_current(&mut sc.timers, &mut sc.data, internal, ctrl);" % (ident_var(s)))
                     else:
                         # Exit children:
                         if isinstance(s.type, AndState):
@@ -360,12 +353,12 @@ class StatechartRustGenerator(ActionLangRustGenerator):
                                 if exit_path[1] is c:
                                     write_exit(exit_path[1:]) # continue recursively
                                 else:
-                                    self.w.writeln("%s.exit_current(&mut sc.timers, internal, ctrl);" % (ident_var(c)))
+                                    self.w.writeln("%s.exit_current(&mut sc.timers, &mut sc.data, internal, ctrl);" % (ident_var(c)))
                         elif isinstance(s.type, OrState):
                             write_exit(exit_path[1:]) # continue recursively with the next child on the exit path
 
                         # Exit s:
-                        self.w.writeln("%s::exit_actions(&mut sc.timers, internal, ctrl);" % (ident_type(s)))
+                        self.w.writeln("%s::exit_actions(&mut sc.timers, &mut sc.data, internal, ctrl);" % (ident_type(s)))
 
                     # Store history
                     if s.deep_history:
@@ -389,19 +382,19 @@ class StatechartRustGenerator(ActionLangRustGenerator):
                     if len(enter_path) == 1:
                         # Target state.
                         if isinstance(s, HistoryState):
-                            self.w.writeln("sc.%s.enter_current(&mut sc.timers, internal, ctrl); // Enter actions for history state" %(ident_history_field(s)))
+                            self.w.writeln("sc.%s.enter_current(&mut sc.timers, &mut sc.data, internal, ctrl); // Enter actions for history state" %(ident_history_field(s)))
                         else:
-                            self.w.writeln("%s::enter_default(&mut sc.timers, internal, ctrl);" % (ident_type(s)))
+                            self.w.writeln("%s::enter_default(&mut sc.timers, &mut sc.data, internal, ctrl);" % (ident_type(s)))
                     else:
                         # Enter s:
-                        self.w.writeln("%s::enter_actions(&mut sc.timers, internal, ctrl);" % (ident_type(s)))
+                        self.w.writeln("%s::enter_actions(&mut sc.timers, &mut sc.data, internal, ctrl);" % (ident_type(s)))
                         # Enter children:
                         if isinstance(s.type, AndState):
                             for c in s.children:
                                 if enter_path[1] is c:
                                     write_enter(enter_path[1:]) # continue recursively
                                 else:
-                                    self.w.writeln("%s::enter_default(&mut sc.timers, internal, ctrl);" % (ident_type(c)))
+                                    self.w.writeln("%s::enter_default(&mut sc.timers, &mut sc.data, internal, ctrl);" % (ident_type(c)))
                         elif isinstance(s.type, OrState):
                             if len(s.children) > 0:
                                 write_enter(enter_path[1:]) # continue recursively with the next child on the enter path
@@ -474,6 +467,8 @@ class StatechartRustGenerator(ActionLangRustGenerator):
                         self.w.writeln("if %s {" % " && ".join(condition))
                         self.w.indent()
 
+                    self.w.writeln("let parent1 = scope;")
+
                     if t.guard is not None:
                         self.w.write("if ")
                         t.guard.accept(self)
@@ -641,12 +636,7 @@ class StatechartRustGenerator(ActionLangRustGenerator):
         # Implement 'SC' trait
         self.w.writeln("impl<OutputCallback: FnMut(OutEvent)> SC<InEvent, Controller<InEvent, OutputCallback>> for Statechart {")
         self.w.writeln("  fn init(&mut self, ctrl: &mut Controller<InEvent, OutputCallback>) {")
-        if sc.datamodel is not None:
-            self.w.indent(); self.w.indent();
-            self.w.writeln("let scope = &mut self.data;")
-            sc.datamodel.accept(self)
-            self.w.dedent(); self.w.dedent();
-        self.w.writeln("    %s::enter_default(&mut self.timers, &mut Default::default(), ctrl)" % (ident_type(tree.root)))
+        self.w.writeln("    %s::enter_default(&mut self.timers, &mut self.data, &mut Default::default(), ctrl)" % (ident_type(tree.root)))
         self.w.writeln("  }")
         self.w.writeln("  fn big_step(&mut self, input: Option<InEvent>, c: &mut Controller<InEvent, OutputCallback>) {")
         self.w.writeln("    let mut internal: InternalLifeline = Default::default();")
@@ -655,10 +645,8 @@ class StatechartRustGenerator(ActionLangRustGenerator):
         self.w.writeln("}")
         self.w.writeln()
 
-
-        # Write 'instance' scope (and all other scopes, recursively)
-        self.write_scopes()
-        self.w.writeln()
+        # Write state types
+        tree.root.accept(self)
 
         if DEBUG:
             self.w.writeln("use std::mem::size_of;")

+ 2 - 2
src/sccd/statechart/parser/xml.py

@@ -27,7 +27,7 @@ def statechart_parser_rules(globals, path, load_external = True, parse_f = parse
     if ext_file is None:
       statechart = Statechart(
         semantics=SemanticConfiguration(),
-        scope=Scope("instance", parent=BuiltIn),
+        scope=Scope("statechart", parent=BuiltIn),
         datamodel=None,
         internal_events=Bitmap(),
         internally_raised_events=Bitmap(),
@@ -210,7 +210,7 @@ def statechart_parser_rules(globals, path, load_external = True, parse_f = parse
           if parent is root:
             raise XmlError("Root cannot be source of a transition.")
 
-          scope = Scope("event_params", parent=statechart.scope)
+          scope = Scope("transition", parent=statechart.scope)
           target_string = require_attribute(el, "target")
           transition = Transition(source=parent, target_string=target_string, scope=scope)
 

+ 1 - 1
src/sccd/test/codegen/rust.py

@@ -19,7 +19,7 @@ def compile_test(variants: List[TestVariant], w: IndentingWriter):
     w.writeln("#![allow(unused_labels)]")
     w.writeln("#![allow(unused_variables)]")
     w.writeln("#![allow(dead_code)]")
-
+    w.writeln("#![allow(unused_parens)]")
 
     with open(rustlib, 'r') as file:
         data = file.read()