|
@@ -7,7 +7,7 @@ class Return:
|
|
|
val: Any = None
|
|
val: Any = None
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
@dataclass(frozen=True)
|
|
|
-class ReturnType:
|
|
|
|
|
|
|
+class ReturnBehavior:
|
|
|
|
|
|
|
|
class When(Enum):
|
|
class When(Enum):
|
|
|
ALWAYS = auto()
|
|
ALWAYS = auto()
|
|
@@ -17,37 +17,37 @@ class ReturnType:
|
|
|
when: When
|
|
when: When
|
|
|
type: Optional[type] = None
|
|
type: Optional[type] = None
|
|
|
|
|
|
|
|
- def may_return(self) -> bool:
|
|
|
|
|
- return self.when == When.ALWAYS or self.when == When.SOME_BRANCHES
|
|
|
|
|
|
|
+ def __post_init__(self):
|
|
|
|
|
+ assert (self.when == ReturnBehavior.When.NEVER) == (self.type is None)
|
|
|
|
|
|
|
|
- # Check if two branches have combinable ReturnTypes and if so, combine them.
|
|
|
|
|
|
|
+ # Check if two branches have combinable ReturnBehaviors and if so, combine them.
|
|
|
@staticmethod
|
|
@staticmethod
|
|
|
- def combine_branches(one: 'ReturnType', two: 'ReturnType') -> 'ReturnType':
|
|
|
|
|
|
|
+ def combine_branches(one: 'ReturnBehavior', two: 'ReturnBehavior') -> 'ReturnBehavior':
|
|
|
if one == two:
|
|
if one == two:
|
|
|
# Whether ALWAYS/SOME_BRANCHES/NEVER, when both branches
|
|
# Whether ALWAYS/SOME_BRANCHES/NEVER, when both branches
|
|
|
# have the same 'when' and the same type, the combination
|
|
# have the same 'when' and the same type, the combination
|
|
|
# is valid and has that type too :)
|
|
# is valid and has that type too :)
|
|
|
return one
|
|
return one
|
|
|
- if one.when == ReturnType.When.NEVER:
|
|
|
|
|
|
|
+ if one.when == ReturnBehavior.When.NEVER:
|
|
|
# two will not be NEVER
|
|
# two will not be NEVER
|
|
|
- return ReturnType(ReturnType.When.SOME_BRANCHES, two.type)
|
|
|
|
|
- if two.when == ReturnType.When.NEVER:
|
|
|
|
|
|
|
+ return ReturnBehavior(ReturnBehavior.When.SOME_BRANCHES, two.type)
|
|
|
|
|
+ if two.when == ReturnBehavior.When.NEVER:
|
|
|
# one will not be NEVER
|
|
# one will not be NEVER
|
|
|
- return ReturnType(ReturnType.When.SOME_BRANCHES, one.type)
|
|
|
|
|
|
|
+ return ReturnBehavior(ReturnBehavior.When.SOME_BRANCHES, one.type)
|
|
|
# Only remaining case: ALWAYS & SOME_BRANCHES.
|
|
# Only remaining case: ALWAYS & SOME_BRANCHES.
|
|
|
# Now the types must match:
|
|
# Now the types must match:
|
|
|
if one.type != two.type:
|
|
if one.type != two.type:
|
|
|
raise StaticTypeError("Branches have different return types: %s and %s" % (str(one.type), str(two.type)))
|
|
raise StaticTypeError("Branches have different return types: %s and %s" % (str(one.type), str(two.type)))
|
|
|
- return ReturnType(ReturnType.When.SOME_BRANCHES, one.type)
|
|
|
|
|
|
|
+ return ReturnBehavior(ReturnBehavior.When.SOME_BRANCHES, one.type)
|
|
|
|
|
|
|
|
@staticmethod
|
|
@staticmethod
|
|
|
- def sequence(earlier: 'ReturnType', later: 'ReturnType') -> 'ReturnType':
|
|
|
|
|
- if earlier.when == ReturnType.When.NEVER:
|
|
|
|
|
|
|
+ def sequence(earlier: 'ReturnBehavior', later: 'ReturnBehavior') -> 'ReturnBehavior':
|
|
|
|
|
+ if earlier.when == ReturnBehavior.When.NEVER:
|
|
|
return later
|
|
return later
|
|
|
- if earlier.when == ReturnType.When.SOME_BRANCHES:
|
|
|
|
|
- if later.when == ReturnType.When.NEVER:
|
|
|
|
|
|
|
+ if earlier.when == ReturnBehavior.When.SOME_BRANCHES:
|
|
|
|
|
+ if later.when == ReturnBehavior.When.NEVER:
|
|
|
return earlier
|
|
return earlier
|
|
|
- if later.when == ReturnType.When.SOME_BRANCHES:
|
|
|
|
|
|
|
+ if later.when == ReturnBehavior.When.SOME_BRANCHES:
|
|
|
if earlier.type != later.type:
|
|
if earlier.type != later.type:
|
|
|
raise StaticTypeError("Earlier statement may return %s, later statement may return %s" % (str(earlier.type), str(later.type)))
|
|
raise StaticTypeError("Earlier statement may return %s, later statement may return %s" % (str(earlier.type), str(later.type)))
|
|
|
return earlier
|
|
return earlier
|
|
@@ -64,7 +64,7 @@ class Statement(ABC):
|
|
|
pass
|
|
pass
|
|
|
|
|
|
|
|
@abstractmethod
|
|
@abstractmethod
|
|
|
- def init_stmt(self, scope: Scope) -> ReturnType:
|
|
|
|
|
|
|
+ def init_stmt(self, scope: Scope) -> ReturnBehavior:
|
|
|
pass
|
|
pass
|
|
|
|
|
|
|
|
@abstractmethod
|
|
@abstractmethod
|
|
@@ -77,10 +77,10 @@ class Assignment(Statement):
|
|
|
operator: str # token value from the grammar.
|
|
operator: str # token value from the grammar.
|
|
|
rhs: Expression
|
|
rhs: Expression
|
|
|
|
|
|
|
|
- def init_stmt(self, scope: Scope) -> ReturnType:
|
|
|
|
|
|
|
+ def init_stmt(self, scope: Scope) -> ReturnBehavior:
|
|
|
rhs_t = self.rhs.init_rvalue(scope)
|
|
rhs_t = self.rhs.init_rvalue(scope)
|
|
|
self.lhs.init_lvalue(scope, rhs_t)
|
|
self.lhs.init_lvalue(scope, rhs_t)
|
|
|
- return ReturnType(ReturnType.When.NEVER)
|
|
|
|
|
|
|
+ return ReturnBehavior(ReturnBehavior.When.NEVER)
|
|
|
|
|
|
|
|
def exec(self, ctx: EvalContext) -> Return:
|
|
def exec(self, ctx: EvalContext) -> Return:
|
|
|
rhs_val = self.rhs.eval(ctx)
|
|
rhs_val = self.rhs.eval(ctx)
|
|
@@ -120,12 +120,12 @@ class Block(Statement):
|
|
|
stmts: List[Statement]
|
|
stmts: List[Statement]
|
|
|
scope: Optional[Scope] = None
|
|
scope: Optional[Scope] = None
|
|
|
|
|
|
|
|
- def init_stmt(self, scope: Scope) -> ReturnType:
|
|
|
|
|
|
|
+ def init_stmt(self, scope: Scope) -> ReturnBehavior:
|
|
|
self.scope = Scope("local", scope)
|
|
self.scope = Scope("local", scope)
|
|
|
- earlier = ReturnType(ReturnType.When.NEVER)
|
|
|
|
|
|
|
+ earlier = ReturnBehavior(ReturnBehavior.When.NEVER)
|
|
|
for i, stmt in enumerate(self.stmts):
|
|
for i, stmt in enumerate(self.stmts):
|
|
|
later = stmt.init_stmt(self.scope)
|
|
later = stmt.init_stmt(self.scope)
|
|
|
- earlier = ReturnType.sequence(earlier, later)
|
|
|
|
|
|
|
+ earlier = ReturnBehavior.sequence(earlier, later)
|
|
|
return earlier
|
|
return earlier
|
|
|
|
|
|
|
|
def exec(self, ctx: EvalContext) -> Return:
|
|
def exec(self, ctx: EvalContext) -> Return:
|
|
@@ -148,9 +148,9 @@ class Block(Statement):
|
|
|
class ExpressionStatement(Statement):
|
|
class ExpressionStatement(Statement):
|
|
|
expr: Expression
|
|
expr: Expression
|
|
|
|
|
|
|
|
- def init_stmt(self, scope: Scope) -> ReturnType:
|
|
|
|
|
|
|
+ def init_stmt(self, scope: Scope) -> ReturnBehavior:
|
|
|
self.expr.init_rvalue(scope)
|
|
self.expr.init_rvalue(scope)
|
|
|
- return ReturnType(ReturnType.When.NEVER)
|
|
|
|
|
|
|
+ return ReturnBehavior(ReturnBehavior.When.NEVER)
|
|
|
|
|
|
|
|
def exec(self, ctx: EvalContext) -> Return:
|
|
def exec(self, ctx: EvalContext) -> Return:
|
|
|
self.expr.eval(ctx)
|
|
self.expr.eval(ctx)
|
|
@@ -163,9 +163,9 @@ class ExpressionStatement(Statement):
|
|
|
class ReturnStatement(Statement):
|
|
class ReturnStatement(Statement):
|
|
|
expr: Expression
|
|
expr: Expression
|
|
|
|
|
|
|
|
- def init_stmt(self, scope: Scope) -> ReturnType:
|
|
|
|
|
|
|
+ def init_stmt(self, scope: Scope) -> ReturnBehavior:
|
|
|
t = self.expr.init_rvalue(scope)
|
|
t = self.expr.init_rvalue(scope)
|
|
|
- return ReturnType(ReturnType.When.ALWAYS, t)
|
|
|
|
|
|
|
+ return ReturnBehavior(ReturnBehavior.When.ALWAYS, t)
|
|
|
|
|
|
|
|
def exec(self, ctx: EvalContext) -> Return:
|
|
def exec(self, ctx: EvalContext) -> Return:
|
|
|
val = self.expr.eval(ctx)
|
|
val = self.expr.eval(ctx)
|
|
@@ -180,14 +180,14 @@ class IfStatement(Statement):
|
|
|
if_body: Statement
|
|
if_body: Statement
|
|
|
else_body: Optional[Statement] = None
|
|
else_body: Optional[Statement] = None
|
|
|
|
|
|
|
|
- def init_stmt(self, scope: Scope) -> ReturnType:
|
|
|
|
|
|
|
+ def init_stmt(self, scope: Scope) -> ReturnBehavior:
|
|
|
cond_t = self.cond.init_rvalue(scope)
|
|
cond_t = self.cond.init_rvalue(scope)
|
|
|
if_ret = self.if_body.init_stmt(scope)
|
|
if_ret = self.if_body.init_stmt(scope)
|
|
|
if self.else_body is None:
|
|
if self.else_body is None:
|
|
|
- else_ret = ReturnType(ReturnType.When.NEVER)
|
|
|
|
|
|
|
+ else_ret = ReturnBehavior(ReturnBehavior.When.NEVER)
|
|
|
else:
|
|
else:
|
|
|
else_ret = self.else_body.init_stmt(scope)
|
|
else_ret = self.else_body.init_stmt(scope)
|
|
|
- return ReturnType.combine_branches(if_ret, else_ret)
|
|
|
|
|
|
|
+ return ReturnBehavior.combine_branches(if_ret, else_ret)
|
|
|
|
|
|
|
|
def exec(self, ctx: EvalContext) -> Return:
|
|
def exec(self, ctx: EvalContext) -> Return:
|
|
|
val = self.cond.eval(ctx)
|
|
val = self.cond.eval(ctx)
|
|
@@ -216,19 +216,19 @@ class Function(Statement):
|
|
|
scope: Optional[Scope] = None
|
|
scope: Optional[Scope] = None
|
|
|
return_type: Optional[type] = None
|
|
return_type: Optional[type] = None
|
|
|
|
|
|
|
|
- def init_stmt(self, scope: Scope) -> ReturnType:
|
|
|
|
|
|
|
+ def init_stmt(self, scope: Scope) -> ReturnBehavior:
|
|
|
self.scope = Scope("function_params", scope)
|
|
self.scope = Scope("function_params", scope)
|
|
|
# Reserve space for arguments on stack
|
|
# Reserve space for arguments on stack
|
|
|
for p in self.params:
|
|
for p in self.params:
|
|
|
p.init_param(self.scope)
|
|
p.init_param(self.scope)
|
|
|
ret = self.body.init_stmt(self.scope)
|
|
ret = self.body.init_stmt(self.scope)
|
|
|
- if ret.when == ReturnType.When.ALWAYS:
|
|
|
|
|
|
|
+ if ret.when == ReturnBehavior.When.ALWAYS:
|
|
|
self.return_type = ret.type
|
|
self.return_type = ret.type
|
|
|
- elif ret.when == ReturnType.When.SOME_BRANCHES:
|
|
|
|
|
|
|
+ elif ret.when == ReturnBehavior.When.SOME_BRANCHES:
|
|
|
raise StaticTypeError("Cannot statically infer function return type: Some branches return %s, others return nothing." % str(ret.type))
|
|
raise StaticTypeError("Cannot statically infer function return type: Some branches return %s, others return nothing." % str(ret.type))
|
|
|
|
|
|
|
|
# Execution of function declaration doesn't return (or do) anything
|
|
# Execution of function declaration doesn't return (or do) anything
|
|
|
- return ReturnType(ReturnType.When.NEVER)
|
|
|
|
|
|
|
+ return ReturnBehavior(ReturnBehavior.When.NEVER)
|
|
|
|
|
|
|
|
def exec(self, ctx: EvalContext) -> Return:
|
|
def exec(self, ctx: EvalContext) -> Return:
|
|
|
# Execution of function declaration doesn't do anything
|
|
# Execution of function declaration doesn't do anything
|