Browse Source

Proper block-of-statements static return type inference checking all branches and raising error if different branches have different return types.

Joeri Exelmans 5 years ago
parent
commit
52823174d2
1 changed files with 39 additions and 13 deletions
  1. 39 13
      src/sccd/syntax/statement.py

+ 39 - 13
src/sccd/syntax/statement.py

@@ -8,7 +8,13 @@ class Return:
 
 @dataclass
 class ReturnType:
-    ret: bool
+
+    class When(Enum):
+        ALWAYS = auto()
+        SOME_BRANCHES = auto()
+        NEVER = auto()
+
+    when: When
     type: Optional[type] = None
 
 # A statement is NOT an expression.
@@ -35,7 +41,7 @@ class Assignment(Statement):
     def init_stmt(self, scope: Scope) -> ReturnType:
         rhs_t = self.rhs.init_rvalue(scope)
         self.lhs.init_lvalue(scope, rhs_t)
-        return ReturnType(False)
+        return ReturnType(ReturnType.When.NEVER)
 
     def exec(self, ctx: EvalContext) -> Return:
         rhs_val = self.rhs.eval(ctx)
@@ -77,11 +83,24 @@ class Block(Statement):
 
     def init_stmt(self, scope: Scope) -> ReturnType:
         self.scope = Scope("local", scope)
-        for stmt in self.stmts:
+        earlier_return_type = ReturnType(ReturnType.When.NEVER)
+        for i, stmt in enumerate(self.stmts):
             ret = stmt.init_stmt(self.scope)
-            if ret.ret:
-                break
-        return ret
+            if ret.when == ReturnType.When.ALWAYS:
+                if earlier_return_type.when == ReturnType.When.SOME_BRANCHES:
+                    if earlier_return_type.type != ret.type:
+                        raise Exception("Not all branches have same return type: %s and %s" % (str(ret.type), str(earlier_return_type.type)))
+                # A return statement is encountered, don't init the rest of the statements since they are unreachable
+                if i < len(self.stmts)-1:
+                    print_debug("Warning: statements after return statement ignored.")
+                return ret
+            elif ret.when == ReturnType.When.SOME_BRANCHES:
+                if earlier_return_type.when == ReturnType.When.SOME_BRANCHES:
+                    if earlier_return_type.type != ret.type:
+                        raise Exception("Not all branches have same return type: %s and %s" % (str(ret.type), str(earlier_return_type.type)))
+                earlier_return_type = ret
+
+        return earlier_return_type
 
     def exec(self, ctx: EvalContext) -> Return:
         ctx.memory.grow_stack(self.scope)
@@ -105,7 +124,7 @@ class ExpressionStatement(Statement):
 
     def init_stmt(self, scope: Scope) -> ReturnType:
         self.expr.init_rvalue(scope)
-        return ReturnType(False)
+        return ReturnType(ReturnType.When.NEVER)
 
     def exec(self, ctx: EvalContext) -> Return:
         self.expr.eval(ctx)
@@ -120,7 +139,7 @@ class ReturnStatement(Statement):
 
     def init_stmt(self, scope: Scope) -> ReturnType:
         t = self.expr.init_rvalue(scope)
-        return ReturnType(True, t)
+        return ReturnType(ReturnType.When.ALWAYS, t)
 
     def exec(self, ctx: EvalContext) -> Return:
         val = self.expr.eval(ctx)
@@ -138,7 +157,10 @@ class IfStatement(Statement):
         cond_t = self.cond.init_rvalue(scope)
         # todo: assert cond_t is bool
         ret = self.body.init_stmt(scope) # return type is only if cond evaluates to True...
-        return Return(False)
+        if ret.when == ReturnType.When.NEVER:
+            return ReturnType(ReturnType.When.NEVER)
+        else:
+            return ReturnType(ReturnType.When.SOME_BRANCHES, ret.type)
 
     def exec(self, ctx: EvalContext) -> Return:
         val = self.cond.eval(ctx)
@@ -172,10 +194,14 @@ class Function(Statement):
         # Reserve space for arguments on stack
         for p in self.params:
             p.init_param(self.scope)
-        self.return_type = self.body.init_stmt(self.scope).type
-
-        # Execution of function declaration doesn't do anything
-        return ReturnType(False)
+        ret = self.body.init_stmt(self.scope)
+        if ret.when == ReturnType.When.ALWAYS:
+            self.return_type = ret.type
+        elif ret.when == ReturnType.When.SOME_BRANCHES:
+            raise Exception("Cannot statically infer function return type: Some branches return %s, others return nothing." % str(ret.type))
+
+        # Execution of function declaration doesn't return (or do) anything
+        return ReturnType(ReturnType.When.NEVER)
 
     def exec(self, ctx: EvalContext) -> Return:
         # Execution of function declaration doesn't do anything