statement.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. from typing import *
  2. from sccd.action_lang.static.expression import *
  3. from sccd.util.debug import *
  4. @dataclass(frozen=True)
  5. class Return:
  6. ret: bool
  7. val: Any = None
  8. def __post_init__(self):
  9. assert self.ret == (self.val is not None)
  10. DontReturn = Return(False)
  11. DoReturn = lambda v: Return(True, v)
  12. @dataclass(frozen=True)
  13. class ReturnBehavior:
  14. class When(Enum):
  15. ALWAYS = auto()
  16. SOME_BRANCHES = auto()
  17. NEVER = auto()
  18. when: When
  19. type: Optional[SCCDType] = None
  20. def __post_init__(self):
  21. assert (self.when == ReturnBehavior.When.NEVER) == (self.type is None)
  22. def get_return_type(self) -> Optional[SCCDType]:
  23. if self.when == ReturnBehavior.When.ALWAYS:
  24. return self.type
  25. elif self.when == ReturnBehavior.When.SOME_BRANCHES:
  26. raise StaticTypeError("Cannot statically infer return type: Some branches return %s, others return nothing." % str(self.type))
  27. # Check if two branches have combinable ReturnBehaviors and if so, combine them.
  28. @staticmethod
  29. def combine_branches(one: 'ReturnBehavior', two: 'ReturnBehavior') -> 'ReturnBehavior':
  30. if one == two:
  31. # Whether ALWAYS/SOME_BRANCHES/NEVER, when both branches
  32. # have the same 'when' and the same type, the combination
  33. # is valid and has that type too :)
  34. return one
  35. if one.when == ReturnBehavior.When.NEVER:
  36. # two will not be NEVER
  37. return ReturnBehavior(ReturnBehavior.When.SOME_BRANCHES, two.type)
  38. if two.when == ReturnBehavior.When.NEVER:
  39. # one will not be NEVER
  40. return ReturnBehavior(ReturnBehavior.When.SOME_BRANCHES, one.type)
  41. # Only remaining case: ALWAYS & SOME_BRANCHES.
  42. # Now the types must match:
  43. if one.type != two.type:
  44. raise StaticTypeError("Return types differ: One branch returns %s, the other %s" % (str(one.type), str(two.type)))
  45. return ReturnBehavior(ReturnBehavior.When.SOME_BRANCHES, one.type)
  46. # If a statement with known ReturnBehavior is followed by another statement with known ReturnBehavior, what is the ReturnBehavior of their sequence? Also, raises if their sequence is illegal.
  47. @staticmethod
  48. def sequence(earlier: 'ReturnBehavior', later: 'ReturnBehavior') -> 'ReturnBehavior':
  49. if earlier.when == ReturnBehavior.When.NEVER:
  50. return later
  51. if earlier.when == ReturnBehavior.When.SOME_BRANCHES:
  52. if later.when == ReturnBehavior.When.NEVER:
  53. return earlier
  54. if later.when == ReturnBehavior.When.SOME_BRANCHES:
  55. if earlier.type != later.type:
  56. raise StaticTypeError("Return types differ: Earlier statement may return %s, later statement may return %s" % (str(earlier.type), str(later.type)))
  57. return earlier
  58. if earlier.type != later.type:
  59. raise StaticTypeError("Return types differ: Earlier statement may return %s, later statement returns %s" % (str(earlier.type), str(later.type)))
  60. return later
  61. raise StaticTypeError("Earlier statement always returns %s, cannot be followed by another statement" % str(earlier.type))
  62. NeverReturns = ReturnBehavior(ReturnBehavior.When.NEVER)
  63. AlwaysReturns = lambda t: ReturnBehavior(ReturnBehavior.When.ALWAYS, t)
  64. # A statement is NOT an expression.
  65. class Statement(ABC, Visitable):
  66. # Run static analysis on the statement.
  67. # Looks up identifiers in the given scope, and adds new identifiers to the scope.
  68. @abstractmethod
  69. def init_stmt(self, scope: Scope) -> ReturnBehavior:
  70. pass
  71. # Execute the statement.
  72. # Execution typically has side effects.
  73. @abstractmethod
  74. def exec(self, memory: MemoryInterface) -> Return:
  75. pass
  76. @abstractmethod
  77. def render(self) -> str:
  78. pass
  79. @dataclass
  80. class Assignment(Statement):
  81. lhs: LValue
  82. rhs: Expression
  83. # Did the assignment create a new variable in its scope?
  84. is_initialization: Optional[bool] = None
  85. def init_stmt(self, scope: Scope) -> ReturnBehavior:
  86. rhs_t = self.rhs.init_expr(scope)
  87. self.is_initialization = self.lhs.init_lvalue(scope, rhs_t, self.rhs)
  88. # Very common case of assignment of a function to an identifier:
  89. # Make the function's scope name a little bit more expressive
  90. if isinstance(self.rhs, FunctionDeclaration) and isinstance(self.lhs, Identifier):
  91. self.rhs.scope.name = "fn_"+self.lhs.name
  92. return NeverReturns
  93. def exec(self, memory: MemoryInterface) -> Return:
  94. rhs_val = self.rhs.eval(memory)
  95. self.lhs.assign(memory, rhs_val)
  96. if DEBUG:
  97. print(" "+termcolor.colored(self.lhs.render() + ' = ' + str(rhs_val), 'grey'))
  98. return DontReturn
  99. def render(self) -> str:
  100. return self.lhs.render() + ' = ' + self.rhs.render() #+ '⁏'
  101. @dataclass
  102. class Block(Statement):
  103. stmts: List[Statement]
  104. def init_stmt(self, scope: Scope) -> ReturnBehavior:
  105. so_far = NeverReturns
  106. for i, stmt in enumerate(self.stmts):
  107. now_what = stmt.init_stmt(scope)
  108. so_far = ReturnBehavior.sequence(so_far, now_what)
  109. return so_far
  110. def exec(self, memory: MemoryInterface) -> Return:
  111. ret = DontReturn
  112. for stmt in self.stmts:
  113. # if DEBUG:
  114. # print(" "+termcolor.colored(stmt.render(), 'grey'))
  115. ret = stmt.exec(memory)
  116. if ret.ret:
  117. break
  118. return ret
  119. def render(self) -> str:
  120. result = ""
  121. for stmt in self.stmts:
  122. result += stmt.render() + '⁏ '
  123. return result
  124. # e.g. a function call
  125. @dataclass
  126. class ExpressionStatement(Statement):
  127. expr: Expression
  128. def init_stmt(self, scope: Scope) -> ReturnBehavior:
  129. self.expr.init_expr(scope)
  130. return NeverReturns
  131. def exec(self, memory: MemoryInterface) -> Return:
  132. self.expr.eval(memory)
  133. return DontReturn
  134. def render(self) -> str:
  135. return self.expr.render()
  136. @dataclass
  137. class ReturnStatement(Statement):
  138. expr: Expression
  139. scope: Optional[Scope] = None
  140. def init_stmt(self, scope: Scope) -> ReturnBehavior:
  141. self.scope = scope
  142. t = self.expr.init_expr(scope)
  143. if t is None:
  144. raise StaticTypeError("Return statement: Expression does not evaluate to a value.")
  145. return AlwaysReturns(t)
  146. def exec(self, memory: MemoryInterface) -> Return:
  147. val = self.expr.eval(memory)
  148. return DoReturn(val)
  149. def render(self) -> str:
  150. return "return " + self.expr.render() #+ ";"
  151. @dataclass
  152. class IfStatement(Statement):
  153. cond: Expression
  154. if_body: Statement
  155. else_body: Optional[Statement] = None
  156. def init_stmt(self, scope: Scope) -> ReturnBehavior:
  157. cond_t = self.cond.init_expr(scope)
  158. if_ret = self.if_body.init_stmt(scope)
  159. if self.else_body is None:
  160. else_ret = NeverReturns
  161. else:
  162. else_ret = self.else_body.init_stmt(scope)
  163. return ReturnBehavior.combine_branches(if_ret, else_ret)
  164. def exec(self, memory: MemoryInterface) -> Return:
  165. val = self.cond.eval(memory)
  166. if val:
  167. return self.if_body.exec(memory)
  168. elif self.else_body is not None:
  169. return self.else_body.exec(memory)
  170. return DontReturn
  171. def render(self) -> str:
  172. return "if (%s) [[" % self.cond.render() + self.if_body.render() + "]]"
  173. @dataclass
  174. class ImportStatement(Statement):
  175. module_name: str
  176. # Offsets and values of imported stuff
  177. declarations: Optional[List[Tuple[int, Any]]] = None
  178. def init_stmt(self, scope: Scope) -> ReturnBehavior:
  179. import importlib
  180. self.module = importlib.import_module(self.module_name)
  181. self.declarations = []
  182. for name, (value, type) in self.module.SCCD_EXPORTS.items():
  183. offset = scope.declare(name, type, const=True)
  184. if isinstance(type, SCCDFunction):
  185. # Function values are a bit special, in the action language they are secretly passed a MemoryInterface object as first parameter, followed by the other (visible) parameters.
  186. # I don't really like this solution, but it works for now.
  187. def make_wrapper(func):
  188. def wrapper(memory: MemoryInterface, *params):
  189. return func(*params)
  190. return wrapper
  191. self.declarations.append((offset, make_wrapper(value)))
  192. else:
  193. self.declarations.append((offset, value))
  194. return NeverReturns
  195. def exec(self, memory: MemoryInterface) -> Return:
  196. for offset, value in self.declarations:
  197. memory.store(offset, value)
  198. return DontReturn
  199. def render(self) -> str:
  200. return "import " + self.module_name #+ ";"