|
@@ -8,7 +8,13 @@ class Return:
|
|
|
|
|
|
@dataclass
|
|
|
class ReturnType:
|
|
|
- ret: bool
|
|
|
+
|
|
|
+ class When(Enum):
|
|
|
+ ALWAYS = auto()
|
|
|
+ SOME_BRANCHES = auto()
|
|
|
+ NEVER = auto()
|
|
|
+
|
|
|
+ when: When
|
|
|
type: Optional[type] = None
|
|
|
|
|
|
# A statement is NOT an expression.
|
|
@@ -35,7 +41,7 @@ class Assignment(Statement):
|
|
|
def init_stmt(self, scope: Scope) -> ReturnType:
|
|
|
rhs_t = self.rhs.init_rvalue(scope)
|
|
|
self.lhs.init_lvalue(scope, rhs_t)
|
|
|
- return ReturnType(False)
|
|
|
+ return ReturnType(ReturnType.When.NEVER)
|
|
|
|
|
|
def exec(self, ctx: EvalContext) -> Return:
|
|
|
rhs_val = self.rhs.eval(ctx)
|
|
@@ -77,11 +83,24 @@ class Block(Statement):
|
|
|
|
|
|
def init_stmt(self, scope: Scope) -> ReturnType:
|
|
|
self.scope = Scope("local", scope)
|
|
|
- for stmt in self.stmts:
|
|
|
+ earlier_return_type = ReturnType(ReturnType.When.NEVER)
|
|
|
+ for i, stmt in enumerate(self.stmts):
|
|
|
ret = stmt.init_stmt(self.scope)
|
|
|
- if ret.ret:
|
|
|
- break
|
|
|
- return ret
|
|
|
+ if ret.when == ReturnType.When.ALWAYS:
|
|
|
+ if earlier_return_type.when == ReturnType.When.SOME_BRANCHES:
|
|
|
+ if earlier_return_type.type != ret.type:
|
|
|
+ raise Exception("Not all branches have same return type: %s and %s" % (str(ret.type), str(earlier_return_type.type)))
|
|
|
+ # A return statement is encountered, don't init the rest of the statements since they are unreachable
|
|
|
+ if i < len(self.stmts)-1:
|
|
|
+ print_debug("Warning: statements after return statement ignored.")
|
|
|
+ return ret
|
|
|
+ elif ret.when == ReturnType.When.SOME_BRANCHES:
|
|
|
+ if earlier_return_type.when == ReturnType.When.SOME_BRANCHES:
|
|
|
+ if earlier_return_type.type != ret.type:
|
|
|
+ raise Exception("Not all branches have same return type: %s and %s" % (str(ret.type), str(earlier_return_type.type)))
|
|
|
+ earlier_return_type = ret
|
|
|
+
|
|
|
+ return earlier_return_type
|
|
|
|
|
|
def exec(self, ctx: EvalContext) -> Return:
|
|
|
ctx.memory.grow_stack(self.scope)
|
|
@@ -105,7 +124,7 @@ class ExpressionStatement(Statement):
|
|
|
|
|
|
def init_stmt(self, scope: Scope) -> ReturnType:
|
|
|
self.expr.init_rvalue(scope)
|
|
|
- return ReturnType(False)
|
|
|
+ return ReturnType(ReturnType.When.NEVER)
|
|
|
|
|
|
def exec(self, ctx: EvalContext) -> Return:
|
|
|
self.expr.eval(ctx)
|
|
@@ -120,7 +139,7 @@ class ReturnStatement(Statement):
|
|
|
|
|
|
def init_stmt(self, scope: Scope) -> ReturnType:
|
|
|
t = self.expr.init_rvalue(scope)
|
|
|
- return ReturnType(True, t)
|
|
|
+ return ReturnType(ReturnType.When.ALWAYS, t)
|
|
|
|
|
|
def exec(self, ctx: EvalContext) -> Return:
|
|
|
val = self.expr.eval(ctx)
|
|
@@ -138,7 +157,10 @@ class IfStatement(Statement):
|
|
|
cond_t = self.cond.init_rvalue(scope)
|
|
|
# todo: assert cond_t is bool
|
|
|
ret = self.body.init_stmt(scope) # return type is only if cond evaluates to True...
|
|
|
- return Return(False)
|
|
|
+ if ret.when == ReturnType.When.NEVER:
|
|
|
+ return ReturnType(ReturnType.When.NEVER)
|
|
|
+ else:
|
|
|
+ return ReturnType(ReturnType.When.SOME_BRANCHES, ret.type)
|
|
|
|
|
|
def exec(self, ctx: EvalContext) -> Return:
|
|
|
val = self.cond.eval(ctx)
|
|
@@ -172,10 +194,14 @@ class Function(Statement):
|
|
|
# Reserve space for arguments on stack
|
|
|
for p in self.params:
|
|
|
p.init_param(self.scope)
|
|
|
- self.return_type = self.body.init_stmt(self.scope).type
|
|
|
-
|
|
|
- # Execution of function declaration doesn't do anything
|
|
|
- return ReturnType(False)
|
|
|
+ ret = self.body.init_stmt(self.scope)
|
|
|
+ if ret.when == ReturnType.When.ALWAYS:
|
|
|
+ self.return_type = ret.type
|
|
|
+ elif ret.when == ReturnType.When.SOME_BRANCHES:
|
|
|
+ raise Exception("Cannot statically infer function return type: Some branches return %s, others return nothing." % str(ret.type))
|
|
|
+
|
|
|
+ # Execution of function declaration doesn't return (or do) anything
|
|
|
+ return ReturnType(ReturnType.When.NEVER)
|
|
|
|
|
|
def exec(self, ctx: EvalContext) -> Return:
|
|
|
# Execution of function declaration doesn't do anything
|