Browse Source

Implemented user-defined functions. Parameter types explicit, return types statically inferred.

Joeri Exelmans 5 years ago
parent
commit
225a6d0c8d

+ 14 - 5
src/sccd/execution/builtin_scope.py

@@ -1,12 +1,21 @@
-import math
-from sccd.execution.statechart_state import *
-
+from sccd.syntax.scope import *
 
 builtin_scope = Scope("builtin", None)
 
 def _in_state(ctx: EvalContext, state_list: List[str]) -> bool:
+  from sccd.execution.statechart_state import StatechartState
+
   return StatechartState.in_state(ctx.current_state, state_list)
 
-builtin_scope.add_function("INSTATE", _in_state)
+builtin_scope.add_python_function("INSTATE", _in_state)
+
+def _log10(ctx: EvalContext, i: int) -> float:
+  import math
+  return math.log10(i)
+
+builtin_scope.add_python_function("log10", _log10)
+
+def _float_to_int(ctx: EvalContext, x: float) -> int:
+  return int(x)
 
-builtin_scope.add_function("log10", lambda _1,_2,_3,i: math.log10(i))
+builtin_scope.add_python_function("float_to_int", _float_to_int)

+ 28 - 14
src/sccd/parser/expression_parser.py

@@ -17,6 +17,8 @@ class _ExpressionTransformer(Transformer):
     super().__init__()
     self.globals: Globals = None
 
+  # Expression and statement parsing
+
   array = Array
 
   block = Block
@@ -79,6 +81,12 @@ class _ExpressionTransformer(Transformer):
   def expression_stmt(self, node):
     return ExpressionStatement(node[0])
 
+  def return_stmt(self, node):
+    return ReturnStatement(node[0])
+
+
+  # Event declaration parsing
+
   def event_decl_list(self, node):
     pos_events = []
     neg_events = []
@@ -94,40 +102,46 @@ class _ExpressionTransformer(Transformer):
   def event_decl(self, node):
     return EventDecl(name=node[0], params=node[1])
 
-  event_params = list
+  params_decl = list
 
-  def event_param_decl(self, node):
+  def param_decl(self, node):
     type = {
       "int": int,
       "str": str,
       "Duration": Duration
     }[node[1]]
-    return Param(name=node[0], type=type)
+    return Param(name=node[0].value, type=type)
+
+  def func_decl(self, node):
+    return (node[0], node[1])
 
 # Global variables so we don't have to rebuild our parser every time
 # Obviously not thread-safe
 _transformer = _ExpressionTransformer()
-_parser = Lark(_action_lang_grammar, parser="lalr", start=["expr", "block", "duration", "event_decl_list", "state_ref", "semantic_choice"], transformer=_transformer)
+_parser = Lark(_action_lang_grammar, parser="lalr", start=["expr", "block", "duration", "event_decl_list", "func_decl", "state_ref", "semantic_choice"], transformer=_transformer)
 
 # Exported functions:
 
-def parse_expression(globals: Globals, expr: str) -> Expression:
+def parse_expression(globals: Globals, text: str) -> Expression:
   _transformer.globals = globals
-  return _parser.parse(expr, start="expr")
+  return _parser.parse(text, start="expr")
 
-def parse_duration(globals: Globals, expr:str) -> Duration:
+def parse_duration(globals: Globals, text: str) -> Duration:
   _transformer.globals = globals
-  return _parser.parse(expr, start="duration")
+  return _parser.parse(text, start="duration")
 
-def parse_block(globals: Globals, block: str) -> Statement:
+def parse_block(globals: Globals, text: str) -> Statement:
   _transformer.globals = globals
-  return _parser.parse(block, start="block")
+  return _parser.parse(text, start="block")
+
+def parse_events_decl(text: str):
+  return _parser.parse(text, start="event_decl_list")
 
-def parse_events_decl(events_decl: str):
-  return _parser.parse(events_decl, start="event_decl_list")
+def parse_func_decl(text: str) -> Tuple[str, List[Param]]:
+  return _parser.parse(text, start="func_decl")
 
-def parse_state_ref(state_ref: str):
-  return _parser.parse(state_ref, start="state_ref")
+def parse_state_ref(text: str):
+  return _parser.parse(text, start="state_ref")
 
 def parse_semantic_choice(choice: str):
   return _parser.parse(choice, start="semantic_choice")

+ 11 - 4
src/sccd/parser/grammar/action_language.g

@@ -29,15 +29,20 @@ event_decl_list: neg_event_decl ("," neg_event_decl)*
 ?neg_event_decl: event_decl -> pos
                | "not" event_decl -> neg
 
-?event_decl: IDENTIFIER event_params
+?event_decl: IDENTIFIER params_decl
 
-event_params: ( "(" event_param_decl ("," event_param_decl)* ")" )?
+params_decl: ( "(" param_decl ("," param_decl)* ")" )?
 
-?event_param_decl: IDENTIFIER ":" TYPE
+// param_decl rule shared with function declaration
+?param_decl: IDENTIFIER ":" TYPE
 
 TYPE: "int" | "str" | "Duration"
 
 
+// Function declaration parsing
+
+func_decl: IDENTIFIER params_decl
+
 // Expression parsing
 
 // We use the same operators and operator precedence rules as Python
@@ -130,10 +135,12 @@ TIME_D: "d" // for zero-duration
 
 // Statement parsing
 
-block: (stmt ";")*
+?block: (stmt ";")*
 
 ?stmt: assignment
      | expr -> expression_stmt
+     | "return" expr -> return_stmt
+     | "{" block "}" -> block
 
 assignment: lhs assign_operator expr
 

+ 22 - 4
src/sccd/parser/statechart_parser.py

@@ -152,7 +152,7 @@ class ActionParser(XmlParser):
     scope = self.scope.require()
     actions = self.actions.require()
 
-    block = parse_block(globals, block=el.text)
+    block = parse_block(globals, el.text)
     block.init_stmt(scope)
     a = Code(block)
     actions.append(a)
@@ -362,7 +362,7 @@ class TreeParser(StateParser):
         if event is not None:
           self._raise(t_el, "Can only specify one of attributes 'after', 'event'.", None)
         try:
-          after_expr = parse_expression(globals, expr=after)
+          after_expr = parse_expression(globals, after)
           after_type = after_expr.init_rvalue(t_scope)
           if after_type != Duration:
             msg = "Expression is '%s' type. Expected 'Duration' type." % str(after_type)
@@ -386,7 +386,7 @@ class TreeParser(StateParser):
       # Guard
       if cond is not None:
         try:
-          expr = parse_expression(globals, expr=cond)
+          expr = parse_expression(globals, cond)
           expr.init_rvalue(t_scope)
         except Exception as e:
           self._raise(t_el, "cond=\"%s\": %s" % (cond, str(e)), e)
@@ -435,11 +435,29 @@ class StatechartParser(TreeParser):
     id = el.get("id")
     expr = el.get("expr")
 
-    parsed = parse_expression(globals, expr=expr)
+    parsed = parse_expression(globals, expr)
     rhs_type = parsed.init_rvalue(scope)
     val = parsed.eval(_blank_eval_context)
     scope.add_variable_w_initial(name=id, initial=val)
 
+  def end_func(self, el):
+    globals = self.globals.require()
+    scope = self.scope.require()
+
+    id = el.get("id")
+    text = el.text
+
+    name, params = parse_func_decl(id)
+
+    # print("name:", name)
+    # print("params:", params)
+
+    body = parse_block(globals, text)
+    func = Function(params, body)
+    func.init_stmt(scope)
+
+    scope.add_function(name, func)
+
   def start_datamodel(self, el):
     statechart = self.statechart.require()
     self.scope.push(statechart.scope)

+ 3 - 2
src/sccd/syntax/expression.py

@@ -80,9 +80,10 @@ class FunctionCall(Expression):
         self.type = return_type
 
         actual_types = [p.init_rvalue(scope) for p in self.parameters]
-        for formal, actual in zip(formal_types, actual_types):
+        for i, (formal, actual) in enumerate(zip(formal_types, actual_types)):
             if formal != actual:
-                raise Exception("Function call: Actual types '%s' differ from formal types '%s'" % (actual_types, formal_types))
+                print(self.function)
+                raise Exception("Function call, argument %d: %s is not expected type %s, instead is %s" % (i, self.parameters[i].render(), str(formal), str(actual)))
         return self.type
 
     def eval(self, ctx: EvalContext):

+ 18 - 1
src/sccd/syntax/scope.py

@@ -169,6 +169,14 @@ class Scope:
     self.named_values[name] = c
     return c
 
+  def add_variable(self, name: str, expected_type: type) -> Variable:
+    assert not self.frozen
+    self._assert_name_available(name)
+    variable = Variable(name=name, type=expected_type, offset=self.total_size())
+    self.named_values[name] = variable
+    self.variables.append(variable)
+    return variable
+
   def add_variable_w_initial(self, name: str, initial: Any) -> Variable:
     assert not self.frozen
     self._assert_name_available(name)
@@ -187,7 +195,7 @@ class Scope:
     self.variables.append(param)
     return param
 
-  def add_function(self, name: str, function: Callable) -> Constant:
+  def add_python_function(self, name: str, function: Callable) -> Constant:
     sig = signature(function)
     return_type = sig.return_annotation
     args = list(sig.parameters.values())[1:] # hide 'EvalContext' parameter to user
@@ -197,3 +205,12 @@ class Scope:
     c = Constant(name=name, type=function_type, value=function)
     self.named_values[name] = c
     return c
+
+  def add_function(self, name: str, function: 'Function') -> Constant:
+    return_type = function.return_type
+    param_types = [p.type for p in function.params]
+    function_type = Callable[param_types, return_type]
+
+    c = Constant(name=name, type=function_type, value=function)
+    self.named_values[name] = c
+    return c

+ 85 - 12
src/sccd/syntax/statement.py

@@ -1,15 +1,25 @@
 from typing import *
 from sccd.syntax.expression import *
 
+@dataclass
+class Return:
+    ret: bool
+    val: Any = None
+
+@dataclass
+class ReturnType:
+    ret: bool
+    type: Optional[type] = None
+
 # A statement is NOT an expression.
 class Statement(ABC):
     # Execution typically has side effects.
     @abstractmethod
-    def exec(self, ctx: EvalContext):
+    def exec(self, ctx: EvalContext) -> Return:
         pass
 
     @abstractmethod
-    def init_stmt(self, scope: Scope):
+    def init_stmt(self, scope: Scope) -> ReturnType:
         pass
 
     @abstractmethod
@@ -22,11 +32,12 @@ class Assignment(Statement):
     operator: str # token value from the grammar.
     rhs: Expression
 
-    def init_stmt(self, scope: Scope):
+    def init_stmt(self, scope: Scope) -> ReturnType:
         rhs_t = self.rhs.init_rvalue(scope)
         self.lhs.init_lvalue(scope, rhs_t)
+        return ReturnType(False)
 
-    def exec(self, ctx: EvalContext):
+    def exec(self, ctx: EvalContext) -> Return:
         rhs_val = self.rhs.eval(ctx)
         variable = self.lhs.eval_lvalue(ctx)
 
@@ -54,6 +65,8 @@ class Assignment(Statement):
             "/=": divide,
         }[self.operator]()
 
+        return Return(False)
+
     def render(self) -> str:
         return self.lhs.render() + ' ' + self.operator + ' ' + self.rhs.render()
 
@@ -62,16 +75,22 @@ class Block(Statement):
     stmts: List[Statement]
     scope: Optional[Scope] = None
 
-    def init_stmt(self, scope: Scope):
+    def init_stmt(self, scope: Scope) -> ReturnType:
         self.scope = Scope("local", scope)
         for stmt in self.stmts:
-            stmt.init_stmt(self.scope)
+            ret = stmt.init_stmt(self.scope)
+            if ret.ret:
+                break
+        return ret
 
-    def exec(self, ctx: EvalContext):
+    def exec(self, ctx: EvalContext) -> Return:
         ctx.memory.grow_stack(self.scope)
         for stmt in self.stmts:
-            stmt.exec(ctx)
+            ret = stmt.exec(ctx)
+            if ret.ret:
+                break
         ctx.memory.shrink_stack()
+        return ret
 
     def render(self) -> str:
         result = ""
@@ -84,11 +103,13 @@ class Block(Statement):
 class ExpressionStatement(Statement):
     expr: Expression
 
-    def init_stmt(self, scope: Scope):
+    def init_stmt(self, scope: Scope) -> ReturnType:
         self.expr.init_rvalue(scope)
+        return ReturnType(False)
 
-    def exec(self, ctx: EvalContext):
+    def exec(self, ctx: EvalContext) -> Return:
         self.expr.eval(ctx)
+        return Return(False)
 
     def render(self) -> str:
         return self.expr.render()
@@ -97,5 +118,57 @@ class ExpressionStatement(Statement):
 class ReturnStatement(Statement):
     expr: Expression
 
-    def init_stmt(self, scope: Scope):
-        pass
+    def init_stmt(self, scope: Scope) -> ReturnType:
+        t = self.expr.init_rvalue(scope)
+        return ReturnType(True, t)
+
+    def exec(self, ctx: EvalContext):
+        val = self.expr.eval(ctx)
+        return Return(True, val)
+
+    def render(self) -> str:
+        return "return " + self.expr.render()
+
+# Used in EventDecl and Function
+@dataclass
+class Param:
+    name: str
+    type: type
+
+    variable: Optional[Variable] = None
+
+    def init_param(self, scope: Scope):
+        self.variable = scope.add_variable(self.name, self.type)
+
+@dataclass
+class Function(Statement):
+    params: List[Param]
+    body: Block
+    scope: Optional[Scope] = None
+    return_type: Optional[type] = None
+
+    def init_stmt(self, scope: Scope) -> ReturnType:
+        self.scope = Scope("function_params", scope)
+        # Reserve space for arguments on stack
+        for p in self.params:
+            p.init_param(self.scope)
+        self.return_type = self.body.init_stmt(self.scope).type
+
+        # Execution of function declaration doesn't do anything
+        return ReturnType(False)
+
+    def exec(self, ctx: EvalContext) -> Return:
+        # Execution of function declaration doesn't do anything
+        return Return(False)
+
+    def __call__(self, ctx: EvalContext, *params) -> Any:
+        ctx.memory.grow_stack(self.scope)
+        # Copy arguments to stack
+        for val, p in zip(params, self.params):
+            p.variable.store(ctx, val)
+        ret = self.body.exec(ctx)
+        ctx.memory.shrink_stack()
+        return ret.val
+
+    def render(self) -> str:
+        return "" # todo

+ 0 - 8
src/sccd/syntax/tree.py

@@ -81,14 +81,6 @@ class ParallelState(State):
                 targets.extend(c.getEffectiveTargetStates(instance))
         return targets
 
-@dataclass
-class Param:
-    name: str 
-    type: type
-
-    def render(self) -> str:
-        return self.name + ': ' + str(self.type)
-
 @dataclass
 class EventDecl:
     name: str

+ 1 - 1
test/test_files/day_atlee/statechart_fig1_redialer.xml

@@ -15,7 +15,7 @@
     </func>
 
     <func id="numdigits(i:int)">
-      return int(log10(i)) + 1;
+      return float_to_int(log10(i)) + 1;
     </func>
   </datamodel>
 

+ 33 - 0
test/test_files/features/functions/fail_parameter_type.xml

@@ -0,0 +1,33 @@
+<?xml version="1.0" ?>
+<test>
+  <statechart>
+    <semantics
+      big_step_maximality="take_many"
+      concurrency="single"
+      input_event_lifeline="first_combo_step"/>
+
+    <datamodel>
+      <func id="digit(i:int, pos:int)">
+        pow = 10 ** pos;
+        return i // pow % 10;
+      </func>
+
+      <func id="numdigits(i:int)">
+        return float_to_int(log10(i)) + 1;
+      </func>
+    </datamodel>
+
+    <tree>
+      <state initial="ready">
+        <state id="ready">
+          <!-- illegal condition: sole parameter of numdigits is 'int', however string is given -->
+          <transition port="in" event="start" target="../final" cond='numdigits("123") == 3'>
+            <raise port="out" event="ok"/>
+          </transition>
+        </state>
+
+        <state id="final"/>
+      </state>
+    </tree>
+  </statechart>
+</test>

+ 33 - 0
test/test_files/features/functions/fail_return_type.xml

@@ -0,0 +1,33 @@
+<?xml version="1.0" ?>
+<test>
+  <statechart>
+    <semantics
+      big_step_maximality="take_many"
+      concurrency="single"
+      input_event_lifeline="first_combo_step"/>
+
+    <datamodel>
+      <func id="digit(i:int, pos:int)">
+        pow = 10 ** pos;
+        return i // pow % 10;
+      </func>
+
+      <func id="numdigits(i:int)">
+        return float_to_int(log10(i)) + 1;
+      </func>
+    </datamodel>
+
+    <tree>
+      <state initial="ready">
+        <state id="ready">
+          <!-- illegal condition: return type of numdigits inferred to be 'int', however RHS of '=='-expression is string -->
+          <transition port="in" event="start" target="../final" cond='numdigits(123) == "3"'>
+            <raise port="out" event="ok"/>
+          </transition>
+        </state>
+
+        <state id="final"/>
+      </state>
+    </tree>
+  </statechart>
+</test>

+ 43 - 0
test/test_files/features/functions/test_functions.xml

@@ -0,0 +1,43 @@
+<?xml version="1.0" ?>
+<test>
+  <statechart>
+    <semantics
+      big_step_maximality="take_many"
+      concurrency="single"
+      input_event_lifeline="first_combo_step"/>
+
+    <datamodel>
+      <func id="digit(i:int, pos:int)">
+        pow = 10 ** pos;
+        return i // pow % 10;
+      </func>
+
+      <func id="numdigits(i:int)">
+        return float_to_int(log10(i)) + 1;
+      </func>
+    </datamodel>
+
+    <tree>
+      <state initial="ready">
+        <state id="ready">
+          <transition port="in" event="start" target="../final"
+            cond="numdigits(123) == 3 and digit(123, 1) == 2">
+            <raise port="out" event="ok"/>
+          </transition>
+        </state>
+
+        <state id="final"/>
+      </state>
+    </tree>
+  </statechart>
+
+  <input>
+    <input_event port="in" name="start" time="0 d"/>
+  </input>
+
+  <output>
+    <big_step>
+      <event port="out" name="ok"/>
+    </big_step>
+  </output>
+</test>