Prechádzať zdrojové kódy

Rust: first steps in implementing action language compiler

Joeri Exelmans 4 rokov pred
rodič
commit
4b23f11d71

+ 0 - 0
src/sccd/action_lang/codegen/__init__.py


+ 105 - 0
src/sccd/action_lang/codegen/rust.py

@@ -0,0 +1,105 @@
+from sccd.action_lang.static.statement import *
+
+class UnsupportedFeature(Exception):
+    pass
+
+class RustGenerator(Visitor):
+    def __init__(self, w):
+        self.w = w
+
+    def default(self, what):
+        raise UnsupportedFeature(what)
+
+    def visit_Block(self, stmt):
+        # self.w.writeln("{")
+        for s in stmt.stmts:
+            s.accept(self)
+        # self.w.writeln("}")
+
+    def visit_Assignment(self, stmt):
+        self.w.write('')
+        stmt.lhs.accept(self)
+        self.w.wno(" = ")
+        stmt.rhs.accept(self)
+        self.w.wnoln(";")
+
+    def visit_IfStatement(self, stmt):
+        self.w.wno("if ")
+        stmt.cond.accept(self)
+        self.w.wnoln(" {")
+        self.w.indent()
+        stmt.if_body.accept(self)
+        self.w.dedent()
+        self.w.writeln("}")
+        if stmt.else_body is not None:
+            self.w.writeln("else {")
+            self.w.indent()
+            stmt.else_body.accept(self)
+            self.w.dedent()
+            self.w.writeln("}")
+
+    def visit_ReturnStatement(self, stmt):
+        self.w.wno("return ")
+        stmt.expr.accept(self)
+        self.w.writeln(";")
+
+    def visit_BoolLiteral(self, expr):
+        self.w.wno("true" if expr.b else "false")
+
+    def visit_IntLiteral(self, expr):
+        self.w.wno(str(expr.i))
+
+    def visit_StringLiteral(self, expr):
+        self.w.wno(expr.string)
+
+    def visit_ArrayLiteral(self, expr):
+        self.w.wno("[")
+        for el in expr.elements:
+            el.accept(self)
+            self.w.wno(", ")
+        self.w.wno("]")
+
+    def visit_BinaryExpression(self, expr):
+        self.w.wno(" (")
+
+        if expr.operator == "**":
+            self.w.wno(" pow(")
+            expr.lhs.accept(self)
+            self.w.wno(", ")
+            expr.rhs.accept(self)
+            self.w.wno(")")
+        else:
+            expr.lhs.accept(self)
+            self.w.wno(" %s " % expr.operator
+                .replace('and', '&&')
+                .replace('or', '||')
+                .replace('//', '/')) # integer division
+            expr.rhs.accept(self)
+
+        self.w.wno(") ")
+
+    def visit_UnaryExpression(self, expr):
+        self.w.wno(expr.operator
+            .replace('not', '!'))
+        expr.expr.accept(self)
+
+    def visit_Group(self, expr):
+        # self.w.wno(" (")
+        expr.subexpr.accept(self)
+        # self.w.wno(") ")
+
+    def visit_Identifier(self, lval):
+        self.w.wno(lval.name)
+
+    def visit_Scope(self, scope):
+        self.w.writeln("struct Scope_%s {" % scope.name)
+        for v in scope.variables:
+            self.w.write("  %s: " % v.name)
+            v.type.accept(self)
+            self.w.wnoln(",")
+        self.w.writeln("}")
+        self.w.writeln()
+
+    def visit__SCCDSimpleType(self, type):
+        self.w.wno(type.name
+            .replace("int", "i32"))

+ 2 - 1
src/sccd/action_lang/static/expression.py

@@ -2,6 +2,7 @@ from abc import *
 from typing import *
 from dataclasses import *
 from sccd.util.duration import *
+from sccd.util.visitable import *
 from sccd.action_lang.static.scope import *
 
 class MemoryInterface(ABC):
@@ -35,7 +36,7 @@ class MemoryInterface(ABC):
 class StaticTypeError(ModelStaticError):
     pass
 
-class Expression(ABC):
+class Expression(ABC, Visitable):
     # Run static analysis on the expression.
     # Must be called exactly once on each expression, before any call to eval is made.
     # Determines the static type of the expression. May throw if there is a type error.

+ 4 - 2
src/sccd/action_lang/static/scope.py

@@ -4,6 +4,7 @@ from dataclasses import *
 from inspect import signature
 from sccd.action_lang.static.types import *
 from sccd.common.exceptions import *
+from sccd.util.visitable import *
 import itertools
 import termcolor
 
@@ -25,14 +26,15 @@ class _Variable(ABC):
 
   @property
   def name(self):
-    return termcolor.colored(self._name, 'yellow')
+    return self._name
+    # return termcolor.colored(self._name, 'yellow')
 
   def __str__(self):
     return "+%d: %s%s: %s" % (self.offset, "(const) "if self.const else "", self.name, str(self.type))
 
 
 # Stateless stuff we know about a scope (= set of named values)
-class Scope:
+class Scope(Visitable):
   __slots__ = ["name", "parent", "parent_offset", "names", "variables"]
 
   def __init__(self, name: str, parent: 'Scope'):

+ 2 - 4
src/sccd/action_lang/static/statement.py

@@ -74,7 +74,7 @@ NeverReturns = ReturnBehavior(ReturnBehavior.When.NEVER)
 AlwaysReturns = lambda t: ReturnBehavior(ReturnBehavior.When.ALWAYS, t)
 
 # A statement is NOT an expression.
-class Statement(ABC):
+class Statement(ABC, Visitable):
     # Run static analysis on the statement.
     # Looks up identifiers in the given scope, and adds new identifiers to the scope.
     @abstractmethod
@@ -112,8 +112,6 @@ class Assignment(Statement):
     def render(self) -> str:
         return self.lhs.render() + ' = ' + self.rhs.render() #+ '⁏'
 
-
-
 @dataclass
 class Block(Statement):
     stmts: List[Statement]
@@ -122,7 +120,7 @@ class Block(Statement):
         so_far = NeverReturns
         for i, stmt in enumerate(self.stmts):
             now_what = stmt.init_stmt(scope)
-            so_far = ReturnBehavior.sequence(so_far, now_what)            
+            so_far = ReturnBehavior.sequence(so_far, now_what)
         return so_far
 
     def exec(self, memory: MemoryInterface) -> Return:

+ 2 - 1
src/sccd/action_lang/static/types.py

@@ -2,8 +2,9 @@ from abc import *
 from dataclasses import *
 from typing import *
 import termcolor
+from sccd.util.visitable import *
 
-class SCCDType(ABC):
+class SCCDType(ABC, Visitable):
     @abstractmethod
     def _str(self):
         pass

+ 24 - 16
src/sccd/statechart/codegen/rust.py

@@ -1,4 +1,6 @@
 from typing import *
+import io
+from sccd.action_lang.codegen.rust import *
 from sccd.statechart.static.tree import *
 from sccd.util.visit_tree import *
 from sccd.statechart.static.statechart import *
@@ -10,9 +12,6 @@ from sccd.util.indenting_writer import *
 # TODO: make this a model parameter
 LIMIT = 1000
 
-class UnsupportedFeature(Exception):
-    pass
-
 # Conversion functions from abstract syntax elements to identifiers in Rust
 
 def snake_case(state: State) -> str:
@@ -34,7 +33,7 @@ def ident_enum_variant(state: State) -> str:
     # We know the direct children of a state must have unique names relative to each other,
     # and enum variants are scoped locally, so we can use the short name here.
     # Furthermore, the XML parser asserts that state ids are valid identifiers in Rust.
-    return state.short_name
+    return "S" + state.short_name
 
 def ident_field(state: State) -> str:
     return "s" + snake_case(state)
@@ -75,8 +74,10 @@ def compile_actions(actions: List[Action], w: IndentingWriter):
             w.writeln("(ctrl.output)(OutEvent{port:\"%s\", event:\"%s\"});" % (a.outport, a.name))
         elif isinstance(a, RaiseInternalEvent):
             w.writeln("internal.raise().%s = Some(%s{});" % (ident_event_field(a.name), (ident_event_type(a.name))))
+        elif isinstance(a, Code):
+            a.block.accept(RustGenerator(w))
         else:
-            raise UnsupportedFeature(str(type(a)))
+            raise UnsupportedFeature(str(type(a).__qualname__))
 
 def compile_statechart(sc: Statechart, globals: Globals, w: IndentingWriter):
 
@@ -128,8 +129,6 @@ def compile_statechart(sc: Statechart, globals: Globals, w: IndentingWriter):
 
         if internal_same_round:
             w.writeln("type InternalLifeline = SameRoundLifeline<Internal>;")
-            # w.writeln("struct InternalLifeline {")
-            # w.writeln("}")
         else:
             w.writeln("type InternalLifeline = NextRoundLifeline<Internal>;")
     elif internal_type == "queue":
@@ -286,6 +285,13 @@ def compile_statechart(sc: Statechart, globals: Globals, w: IndentingWriter):
     #     w.writeln("const ARENA_UNSTABLE: Arenas = false; // inapplicable to chosen semantics - all transition targets considered stable")
     w.writeln()
 
+    # Write datamodel type
+    RustGenerator(w).visit_Scope(sc.scope)
+    # w.writeln("struct DataModel {")
+
+    # w.writeln("  ")
+    # w.writeln(")")
+
     # Write statechart type
     w.writeln("pub struct Statechart {")
     w.writeln("  current_state: %s," % ident_type(tree.root))
@@ -294,6 +300,7 @@ def compile_statechart(sc: Statechart, globals: Globals, w: IndentingWriter):
     for h in tree.history_states:
         w.writeln("  %s: %s," % (ident_history_field(h), ident_type(h.parent)))
     w.writeln("  timers: Timers,")
+    w.writeln("  data: Scope_instance,")
     w.writeln("}")
 
     w.writeln("impl Default for Statechart {")
@@ -465,7 +472,10 @@ def compile_statechart(sc: Statechart, globals: Globals, w: IndentingWriter):
                     w.indent()
 
                 if t.guard is not None:
-                    raise UnsupportedFeature("Guard conditions currently unsupported")
+                    w.write("if ")
+                    t.guard.accept(RustGenerator(w))
+                    w.wnoln(" {")
+                    w.indent()
 
                 # 1. Execute transition's actions
 
@@ -510,6 +520,10 @@ def compile_statechart(sc: Statechart, globals: Globals, w: IndentingWriter):
                 # This arena is done:
                 w.writeln("break '%s;" % (ident_arena_label(t.arena)))
 
+                if t.guard is not None:
+                    w.dedent()
+                    w.writeln("}")
+
                 if t.trigger is not EMPTY_TRIGGER:
                     w.dedent()
                     w.writeln("}")
@@ -530,9 +544,6 @@ def compile_statechart(sc: Statechart, globals: Globals, w: IndentingWriter):
                     write_transitions(child)
             elif isinstance(state.type, OrState):
                 if state.type.default_state is not None:
-                    # if syntactic_maximality and state in arenas:
-                    #     w.writeln("if dirty & %s == ARENA_NONE {" % ident_arena_const(state))
-                    #     w.indent()
                     if state in arenas:
                         w.writeln("if (fired | dirty) & %s == ARENA_NONE {" % ident_arena_const(state))
                         w.indent()
@@ -556,9 +567,6 @@ def compile_statechart(sc: Statechart, globals: Globals, w: IndentingWriter):
                     if state in arenas:
                         w.dedent()
                         w.writeln("}")
-                    # if syntactic_maximality and state in arenas:
-                    #     w.dedent()
-                    #     w.writeln("}")
 
         if sc.semantics.hierarchical_priority == HierarchicalPriority.SOURCE_PARENT:
             parent()
@@ -582,8 +590,7 @@ def compile_statechart(sc: Statechart, globals: Globals, w: IndentingWriter):
 
     # Write combo step and big step function
     def write_stepping_function(name: str, title: str, maximality: Maximality, substep: str, cycle_input: bool, cycle_internal: bool):
-        w.write("fn %s<OutputCallback: FnMut(OutEvent)>(sc: &mut Statechart, input: Option<InEvent>, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>, dirty: Arenas)" % (name))
-        w.writeln(" -> Arenas {")
+        w.writeln("fn %s<OutputCallback: FnMut(OutEvent)>(sc: &mut Statechart, input: Option<InEvent>, internal: &mut InternalLifeline, ctrl: &mut Controller<InEvent, OutputCallback>, dirty: Arenas) -> Arenas {" % (name))
         w.writeln("  let mut ctr: u16 = 0;")
         if maximality == Maximality.TAKE_ONE:
             w.writeln("  // %s Maximality: Take One" % title)
@@ -657,3 +664,4 @@ def compile_statechart(sc: Statechart, globals: Globals, w: IndentingWriter):
         w.writeln("  eprintln!(\"info: Arenas: {} bytes\", size_of::<Arenas>());")
         w.writeln("  eprintln!(\"------------------------\");")
         w.writeln("}")
+        w.writeln()

+ 2 - 2
src/sccd/test/codegen/rust.py

@@ -46,9 +46,9 @@ def compile_test(variants: List[TestVariant], w: IndentingWriter):
         w.writeln("sc.init(&mut controller);")
         for i in v.input:
             if len(i.events) > 1:
-                raise Exception("Multiple simultaneous input events not supported")
+                raise UnsupportedFeature("Multiple simultaneous input events not supported")
             elif len(i.events) == 0:
-                raise Exception("Test declares empty bag of input events - not supported")
+                raise UnsupportedFeature("Test declares empty bag of input events")
             w.writeln("controller.set_timeout(%d, InEvent::%s);" % (i.timestamp.opt, ident_event_type(i.events[0].name)))
 
         w.writeln("controller.run_until(&mut sc, Until::Eternity);")

+ 7 - 0
src/sccd/util/indenting_writer.py

@@ -18,5 +18,12 @@ class IndentingWriter:
         else:
             self.out.write(' '*self.state + s + '\n')
 
+    # "write no indent"
+    def wno(self, s):
+        self.out.write(s)
+
+    def wnoln(self, s):
+        self.out.write(s + '\n')
+
     def write(self, s):
         self.out.write(' '*self.state + s)

+ 15 - 0
src/sccd/util/visitable.py

@@ -0,0 +1,15 @@
+import abc
+
+class Visitor:
+    @abc.abstractmethod
+    def default(self):
+        pass
+
+class Visitable:
+    def accept(self, visitor: Visitor):
+        typename = type(self).__qualname__.replace(".", "_")
+        lookup = "visit_" + typename
+        try:
+            return getattr(visitor, lookup)(self)
+        except AttributeError:
+            return visitor.default(what)

+ 3 - 2
test/test_files/features/action_lang/test_expressions.xml

@@ -42,10 +42,11 @@
           <transition cond="2 * 3 == 6" target="../s4"/>
         </state>
         <state id="s4">
-          <transition cond="21 // 3 == 7" target="../s5"/>
+          <transition cond="21 // 3 == 7" target="../s6"/>
         </state>
         <state id="s5">
-          <transition cond="256 == 2 ** 2 ** 3" target="../s6"/>
+          <!-- support for exponent operator dropped -->
+          <!-- <transition cond="256 == 2 ** 2 ** 3" target="../s6"/> -->
         </state>
         <state id="s6">
           <transition cond="5 % 2 == 1" target="../s7"/>