# bytecode_interpreter.py (13 KB)
  1. """Interprets parsed bytecode graphs."""
  2. import modelverse_jit.bytecode_ir as bytecode_ir
  3. import modelverse_jit.runtime as jit_runtime
  4. import modelverse_kernel.primitives as primitive_functions
  5. class BreakException(Exception):
  6. """A type of exception that is used to interpret 'break' instructions:
  7. the 'break' instructions throw a BreakException, which is then handled
  8. by the appropriate 'while' instruction."""
  9. def __init__(self, loop):
  10. Exception.__init__(self)
  11. self.loop = loop
  12. class ContinueException(Exception):
  13. """A type of exception that is used to interpret 'continue' instructions:
  14. the 'continue' instructions throw a ContinueException, which is then handled
  15. by the appropriate 'while' instruction."""
  16. def __init__(self, loop):
  17. Exception.__init__(self)
  18. self.loop = loop
  19. class InterpreterState(object):
  20. """The state of the bytecode interpreter."""
  21. def __init__(self, gc_root_node, keyword_arg_dict, nop_period=20):
  22. self.gc_root_node = gc_root_node
  23. self.nop_period = nop_period
  24. self.keyword_arg_dict = keyword_arg_dict
  25. self.current_result = None
  26. self.nop_phase = 0
  27. self.local_vars = {}
  28. def import_local(self, node_id, value):
  29. """Imports the given value as a local in this interpreter state."""
  30. local_node, = yield [("CN", [])]
  31. yield [
  32. ("CE", [self.gc_root_node, local_node]),
  33. ("CD", [local_node, "value", value])]
  34. self.local_vars[node_id] = local_node
  35. raise primitive_functions.PrimitiveFinished(None)
  36. def schedule_nop(self):
  37. """Increments the nop-phase. If a nop should be performed, then True is returned.
  38. Otherwise, False."""
  39. self.nop_phase += 1
  40. if self.nop_phase == self.nop_period:
  41. self.nop_phase = 0
  42. return True
  43. else:
  44. return False
  45. def update_result(self, new_result):
  46. """Sets the current result to the given value, if it is not None."""
  47. if new_result is not None:
  48. self.current_result = new_result
  49. def get_task_root(self):
  50. """Gets the task root node id."""
  51. return self.keyword_arg_dict['task_root']
  52. def get_kernel(self):
  53. """Gets the Modelverse kernel instance."""
  54. return self.keyword_arg_dict['mvk']
  55. def interpret(self, instruction):
  56. """Interprets the given instruction and returns the current result."""
  57. instruction_type = type(instruction)
  58. if instruction_type in InterpreterState.INTERPRETERS:
  59. # Interpret the instruction.
  60. yield [("CALL_ARGS",
  61. [InterpreterState.INTERPRETERS[instruction_type], (self, instruction)])]
  62. # Maybe perform a nop.
  63. if self.schedule_nop():
  64. yield None
  65. # Interpret the next instruction.
  66. next_instruction = instruction.next_instruction
  67. if next_instruction is not None:
  68. yield [("TAIL_CALL_ARGS", [self.interpret, (next_instruction,)])]
  69. else:
  70. raise primitive_functions.PrimitiveFinished(self.current_result)
  71. else:
  72. raise jit_runtime.JitCompilationFailedException(
  73. 'Unknown bytecode instruction: %r' % instruction)
  74. def interpret_select(self, instruction):
  75. """Interprets the given 'select' instruction."""
  76. cond_node, = yield [("CALL_ARGS", [self.interpret, (instruction.condition,)])]
  77. cond_val, = yield [("RV", [cond_node])]
  78. if cond_val:
  79. yield [("TAIL_CALL_ARGS", [self.interpret, (instruction.if_clause,)])]
  80. elif instruction.else_clause is not None:
  81. yield [("TAIL_CALL_ARGS", [self.interpret, (instruction.else_clause,)])]
  82. else:
  83. raise primitive_functions.PrimitiveFinished(None)
  84. def interpret_while(self, instruction):
  85. """Interprets the given 'while' instruction."""
  86. def __handle_break(exception):
  87. if exception.loop == instruction:
  88. # End the loop.
  89. raise primitive_functions.PrimitiveFinished(None)
  90. else:
  91. # Propagate the exception to the next 'while' loop.
  92. raise exception
  93. def __handle_continue(exception):
  94. if exception.loop == instruction:
  95. # Restart the loop.
  96. yield [("TAIL_CALL_ARGS", [self.interpret, (instruction,)])]
  97. else:
  98. # Propagate the exception to the next 'while' loop.
  99. raise exception
  100. yield [("TRY", [])]
  101. yield [("CATCH", [BreakException, __handle_break])]
  102. yield [("CATCH", [ContinueException, __handle_continue])]
  103. while 1:
  104. cond_node, = yield [("CALL_ARGS", [self.interpret, (instruction.condition,)])]
  105. cond_val, = yield [("RV", [cond_node])]
  106. if cond_val:
  107. yield [("CALL_ARGS", [self.interpret, (instruction.body,)])]
  108. else:
  109. break
  110. yield [("END_TRY", [])]
  111. raise primitive_functions.PrimitiveFinished(None)
  112. def interpret_break(self, instruction):
  113. """Interprets the given 'break' instruction."""
  114. raise BreakException(instruction.loop)
  115. def interpret_continue(self, instruction):
  116. """Interprets the given 'continue' instruction."""
  117. raise ContinueException(instruction.loop)
  118. def interpret_return(self, instruction):
  119. """Interprets the given 'return' instruction."""
  120. if instruction.value is None:
  121. raise primitive_functions.InterpretedFunctionFinished(None)
  122. else:
  123. return_node, = yield [("CALL_ARGS", [self.interpret, (instruction.value,)])]
  124. raise primitive_functions.InterpretedFunctionFinished(return_node)
  125. def interpret_call(self, instruction):
  126. """Interprets the given 'call' instruction."""
  127. target, = yield [("CALL_ARGS", [self.interpret, (instruction.target,)])]
  128. named_args = {}
  129. for name, arg_instruction in instruction.argument_list:
  130. arg, = yield [("CALL_ARGS", [self.interpret, (arg_instruction,)])]
  131. named_args[name] = arg
  132. kwargs = {'function_id': target, 'named_arguments': named_args}
  133. kwargs.update(self.keyword_arg_dict)
  134. result, = yield [("CALL_KWARGS", [jit_runtime.call_function, kwargs])]
  135. if result is not None:
  136. yield [("CE", [self.gc_root_node, result])]
  137. self.update_result(result)
  138. raise primitive_functions.PrimitiveFinished(None)
  139. def interpret_constant(self, instruction):
  140. """Interprets the given 'constant' instruction."""
  141. self.update_result(instruction.constant_id)
  142. raise primitive_functions.PrimitiveFinished(None)
  143. def interpret_input(self, instruction):
  144. """Interprets the given 'input' instruction."""
  145. result, = yield [("CALL_KWARGS", [jit_runtime.get_input, self.keyword_arg_dict])]
  146. self.update_result(result)
  147. yield [("CE", [self.gc_root_node, result])]
  148. raise primitive_functions.PrimitiveFinished(None)
  149. def interpret_output(self, instruction):
  150. """Interprets the given 'output' instruction."""
  151. output_value, = yield [("CALL_ARGS", [self.interpret, (instruction.value,)])]
  152. task_root = self.get_task_root()
  153. last_output, last_output_link, new_last_output = yield [
  154. ("RD", [task_root, "last_output"]),
  155. ("RDE", [task_root, "last_output"]),
  156. ("CN", [])
  157. ]
  158. yield [
  159. ("CD", [last_output, "value", output_value]),
  160. ("CD", [last_output, "next", new_last_output]),
  161. ("CD", [task_root, "last_output", new_last_output]),
  162. ("DE", [last_output_link])
  163. ]
  164. yield None
  165. raise primitive_functions.PrimitiveFinished(None)
  166. def interpret_declare(self, instruction):
  167. """Interprets a 'declare' (local) instruction."""
  168. node_id = instruction.variable.node_id
  169. if node_id in self.local_vars:
  170. self.update_result(self.local_vars[node_id])
  171. raise primitive_functions.PrimitiveFinished(None)
  172. else:
  173. local_node, = yield [("CN", [])]
  174. yield [("CE", [self.gc_root_node, local_node])]
  175. self.update_result(local_node)
  176. self.local_vars[node_id] = local_node
  177. raise primitive_functions.PrimitiveFinished(None)
  178. def interpret_global(self, instruction):
  179. """Interprets a (declare) 'global' instruction."""
  180. var_name = instruction.variable.name
  181. task_root = self.get_task_root()
  182. _globals, = yield [("RD", [task_root, "globals"])]
  183. global_var, = yield [("RD", [_globals, var_name])]
  184. if global_var is None:
  185. global_var, = yield [("CN", [])]
  186. yield [("CD", [_globals, var_name, global_var])]
  187. self.update_result(global_var)
  188. yield [("CE", [self.gc_root_node, global_var])]
  189. raise primitive_functions.PrimitiveFinished(None)
  190. def interpret_resolve(self, instruction):
  191. """Interprets a 'resolve' instruction."""
  192. node_id = instruction.variable.node_id
  193. if node_id in self.local_vars:
  194. self.update_result(self.local_vars[node_id])
  195. raise primitive_functions.PrimitiveFinished(None)
  196. else:
  197. task_root = self.get_task_root()
  198. var_name = instruction.variable.name
  199. _globals, = yield [("RD", [task_root, "globals"])]
  200. global_var, = yield [("RD", [_globals, var_name])]
  201. if global_var is None:
  202. raise Exception(jit_runtime.GLOBAL_NOT_FOUND_MESSAGE_FORMAT % var_name)
  203. mvk = self.get_kernel()
  204. if mvk.suggest_function_names and mvk.jit.get_global_body_id(var_name) is None:
  205. global_val, = yield [("RD", [global_var, "value"])]
  206. if global_val is not None:
  207. func_body, = yield [("RD", [global_val, "body"])]
  208. if func_body is not None:
  209. mvk.jit.register_global(func_body, var_name)
  210. self.update_result(global_var)
  211. yield [("CE", [self.gc_root_node, global_var])]
  212. raise primitive_functions.PrimitiveFinished(None)
  213. def interpret_access(self, instruction):
  214. """Interprets an 'access' instruction."""
  215. pointer_node, = yield [("CALL_ARGS", [self.interpret, (instruction.pointer,)])]
  216. value_node, = yield [("RD", [pointer_node, "value"])]
  217. self.update_result(value_node)
  218. yield [("CE", [self.gc_root_node, value_node])]
  219. raise primitive_functions.PrimitiveFinished(None)
  220. def interpret_assign(self, instruction):
  221. """Interprets an 'assign' instruction."""
  222. pointer_node, = yield [("CALL_ARGS", [self.interpret, (instruction.pointer,)])]
  223. value_node, = yield [("CALL_ARGS", [self.interpret, (instruction.value,)])]
  224. value_link, = yield [("RDE", [pointer_node, "value"])]
  225. yield [
  226. ("CD", [pointer_node, "value", value_node]),
  227. ("DE", [value_link])]
  228. raise primitive_functions.PrimitiveFinished(None)
  229. INTERPRETERS = {
  230. bytecode_ir.SelectInstruction: interpret_select,
  231. bytecode_ir.WhileInstruction: interpret_while,
  232. bytecode_ir.BreakInstruction: interpret_break,
  233. bytecode_ir.ContinueInstruction: interpret_continue,
  234. bytecode_ir.ReturnInstruction: interpret_return,
  235. bytecode_ir.CallInstruction: interpret_call,
  236. bytecode_ir.ConstantInstruction: interpret_constant,
  237. bytecode_ir.InputInstruction: interpret_input,
  238. bytecode_ir.OutputInstruction: interpret_output,
  239. bytecode_ir.DeclareInstruction: interpret_declare,
  240. bytecode_ir.GlobalInstruction: interpret_global,
  241. bytecode_ir.ResolveInstruction: interpret_resolve,
  242. bytecode_ir.AccessInstruction: interpret_access,
  243. bytecode_ir.AssignInstruction: interpret_assign
  244. }
  245. def interpret_bytecode_function(function_name, body_bytecode, local_arguments, keyword_arguments):
  246. """Interprets the bytecode function with the given name, body, named arguments and
  247. keyword arguments."""
  248. yield [("DEBUG_INFO", [function_name, None, jit_runtime.BYTECODE_INTERPRETER_ORIGIN_NAME])]
  249. task_root = keyword_arguments['task_root']
  250. gc_root_node, = yield [("CN", [])]
  251. gc_root_edge, = yield [("CE", [task_root, gc_root_node])]
  252. interpreter = InterpreterState(gc_root_node, keyword_arguments)
  253. for param_id, arg_node in local_arguments.items():
  254. yield [("CALL_ARGS", [interpreter.import_local, (param_id, arg_node)])]
  255. def __handle_return(exception):
  256. yield [("DE", [gc_root_edge])]
  257. raise primitive_functions.PrimitiveFinished(exception.result)
  258. def __handle_break(_):
  259. raise jit_runtime.UnreachableCodeException(
  260. "Function '%s' tries to break out of a loop that is not currently executing." %
  261. function_name)
  262. def __handle_continue(_):
  263. raise jit_runtime.UnreachableCodeException(
  264. "Function '%s' tries to continue a loop that is not currently executing." %
  265. function_name)
  266. # Perform a nop before interpreting the function.
  267. yield None
  268. yield [("TRY", [])]
  269. yield [("CATCH", [primitive_functions.InterpretedFunctionFinished, __handle_return])]
  270. yield [("CATCH", [BreakException, __handle_break])]
  271. yield [("CATCH", [ContinueException, __handle_continue])]
  272. yield [("CALL_ARGS", [interpreter.interpret, (body_bytecode,)])]
  273. yield [("END_TRY", [])]
  274. raise jit_runtime.UnreachableCodeException("Function '%s' failed to return." % function_name)