Просмотр исходного кода

Rename class 'ReturnType' to 'ReturnBehavior' + add check on fields.

Joeri Exelmans 5 лет назад
Родитель
Сommit
021c634211
1 измененных файлов с 32 добавлено и 32 удалено
  1. 32 32
      src/sccd/syntax/statement.py

+ 32 - 32
src/sccd/syntax/statement.py

@@ -7,7 +7,7 @@ class Return:
     val: Any = None
 
 @dataclass(frozen=True)
-class ReturnType:
+class ReturnBehavior:
 
     class When(Enum):
         ALWAYS = auto()
@@ -17,37 +17,37 @@ class ReturnType:
     when: When
     type: Optional[type] = None
 
-    def may_return(self) -> bool:
-        return self.when == When.ALWAYS or self.when == When.SOME_BRANCHES
+    def __post_init__(self):
+        assert (self.when == ReturnBehavior.When.NEVER) == (self.type is None)
 
-    # Check if two branches have combinable ReturnTypes and if so, combine them.
+    # Check if two branches have combinable ReturnBehaviors and if so, combine them.
     @staticmethod
-    def combine_branches(one: 'ReturnType', two: 'ReturnType') -> 'ReturnType':
+    def combine_branches(one: 'ReturnBehavior', two: 'ReturnBehavior') -> 'ReturnBehavior':
         if one == two:
             # Whether ALWAYS/SOME_BRANCHES/NEVER, when both branches
             # have the same 'when' and the same type, the combination
             # is valid and has that type too :)
             return one
-        if one.when == ReturnType.When.NEVER:
+        if one.when == ReturnBehavior.When.NEVER:
             # two will not be NEVER
-            return ReturnType(ReturnType.When.SOME_BRANCHES, two.type)
-        if two.when == ReturnType.When.NEVER:
+            return ReturnBehavior(ReturnBehavior.When.SOME_BRANCHES, two.type)
+        if two.when == ReturnBehavior.When.NEVER:
             # one will not be NEVER
-            return ReturnType(ReturnType.When.SOME_BRANCHES, one.type)
+            return ReturnBehavior(ReturnBehavior.When.SOME_BRANCHES, one.type)
         # Only remaining case: ALWAYS & SOME_BRANCHES.
         # Now the types must match:
         if one.type != two.type:
             raise StaticTypeError("Branches have different return types: %s and %s" % (str(one.type), str(two.type)))
-        return ReturnType(ReturnType.When.SOME_BRANCHES, one.type)
+        return ReturnBehavior(ReturnBehavior.When.SOME_BRANCHES, one.type)
 
     @staticmethod
-    def sequence(earlier: 'ReturnType', later: 'ReturnType') -> 'ReturnType':
-        if earlier.when == ReturnType.When.NEVER:
+    def sequence(earlier: 'ReturnBehavior', later: 'ReturnBehavior') -> 'ReturnBehavior':
+        if earlier.when == ReturnBehavior.When.NEVER:
             return later
-        if earlier.when == ReturnType.When.SOME_BRANCHES:
-            if later.when == ReturnType.When.NEVER:
+        if earlier.when == ReturnBehavior.When.SOME_BRANCHES:
+            if later.when == ReturnBehavior.When.NEVER:
                 return earlier
-            if later.when == ReturnType.When.SOME_BRANCHES:
+            if later.when == ReturnBehavior.When.SOME_BRANCHES:
                 if earlier.type != later.type:
                     raise StaticTypeError("Earlier statement may return %s, later statement may return %s" % (str(earlier.type), str(later.type)))
                 return earlier
@@ -64,7 +64,7 @@ class Statement(ABC):
         pass
 
     @abstractmethod
-    def init_stmt(self, scope: Scope) -> ReturnType:
+    def init_stmt(self, scope: Scope) -> ReturnBehavior:
         pass
 
     @abstractmethod
@@ -77,10 +77,10 @@ class Assignment(Statement):
     operator: str # token value from the grammar.
     rhs: Expression
 
-    def init_stmt(self, scope: Scope) -> ReturnType:
+    def init_stmt(self, scope: Scope) -> ReturnBehavior:
         rhs_t = self.rhs.init_rvalue(scope)
         self.lhs.init_lvalue(scope, rhs_t)
-        return ReturnType(ReturnType.When.NEVER)
+        return ReturnBehavior(ReturnBehavior.When.NEVER)
 
     def exec(self, ctx: EvalContext) -> Return:
         rhs_val = self.rhs.eval(ctx)
@@ -120,12 +120,12 @@ class Block(Statement):
     stmts: List[Statement]
     scope: Optional[Scope] = None
 
-    def init_stmt(self, scope: Scope) -> ReturnType:
+    def init_stmt(self, scope: Scope) -> ReturnBehavior:
         self.scope = Scope("local", scope)
-        earlier = ReturnType(ReturnType.When.NEVER)
+        earlier = ReturnBehavior(ReturnBehavior.When.NEVER)
         for i, stmt in enumerate(self.stmts):
             later = stmt.init_stmt(self.scope)
-            earlier = ReturnType.sequence(earlier, later)
+            earlier = ReturnBehavior.sequence(earlier, later)
         return earlier
 
     def exec(self, ctx: EvalContext) -> Return:
@@ -148,9 +148,9 @@ class Block(Statement):
 class ExpressionStatement(Statement):
     expr: Expression
 
-    def init_stmt(self, scope: Scope) -> ReturnType:
+    def init_stmt(self, scope: Scope) -> ReturnBehavior:
         self.expr.init_rvalue(scope)
-        return ReturnType(ReturnType.When.NEVER)
+        return ReturnBehavior(ReturnBehavior.When.NEVER)
 
     def exec(self, ctx: EvalContext) -> Return:
         self.expr.eval(ctx)
@@ -163,9 +163,9 @@ class ExpressionStatement(Statement):
 class ReturnStatement(Statement):
     expr: Expression
 
-    def init_stmt(self, scope: Scope) -> ReturnType:
+    def init_stmt(self, scope: Scope) -> ReturnBehavior:
         t = self.expr.init_rvalue(scope)
-        return ReturnType(ReturnType.When.ALWAYS, t)
+        return ReturnBehavior(ReturnBehavior.When.ALWAYS, t)
 
     def exec(self, ctx: EvalContext) -> Return:
         val = self.expr.eval(ctx)
@@ -180,14 +180,14 @@ class IfStatement(Statement):
     if_body: Statement
     else_body: Optional[Statement] = None
 
-    def init_stmt(self, scope: Scope) -> ReturnType:
+    def init_stmt(self, scope: Scope) -> ReturnBehavior:
         cond_t = self.cond.init_rvalue(scope)
         if_ret = self.if_body.init_stmt(scope)
         if self.else_body is None:
-            else_ret = ReturnType(ReturnType.When.NEVER)
+            else_ret = ReturnBehavior(ReturnBehavior.When.NEVER)
         else:
             else_ret = self.else_body.init_stmt(scope)
-        return ReturnType.combine_branches(if_ret, else_ret)
+        return ReturnBehavior.combine_branches(if_ret, else_ret)
 
     def exec(self, ctx: EvalContext) -> Return:
         val = self.cond.eval(ctx)
@@ -216,19 +216,19 @@ class Function(Statement):
     scope: Optional[Scope] = None
     return_type: Optional[type] = None
 
-    def init_stmt(self, scope: Scope) -> ReturnType:
+    def init_stmt(self, scope: Scope) -> ReturnBehavior:
         self.scope = Scope("function_params", scope)
         # Reserve space for arguments on stack
         for p in self.params:
             p.init_param(self.scope)
         ret = self.body.init_stmt(self.scope)
-        if ret.when == ReturnType.When.ALWAYS:
+        if ret.when == ReturnBehavior.When.ALWAYS:
             self.return_type = ret.type
-        elif ret.when == ReturnType.When.SOME_BRANCHES:
+        elif ret.when == ReturnBehavior.When.SOME_BRANCHES:
             raise StaticTypeError("Cannot statically infer function return type: Some branches return %s, others return nothing." % str(ret.type))
 
         # Execution of function declaration doesn't return (or do) anything
-        return ReturnType(ReturnType.When.NEVER)
+        return ReturnBehavior(ReturnBehavior.When.NEVER)
 
     def exec(self, ctx: EvalContext) -> Return:
         # Execution of function declaration doesn't do anything