# bytecode_interpreter.py (14 KB)
  1. """Interprets parsed bytecode graphs."""
  2. import modelverse_jit.bytecode_ir as bytecode_ir
  3. import modelverse_jit.runtime as jit_runtime
  4. import modelverse_jit.source_map as source_map
  5. import modelverse_kernel.primitives as primitive_functions
  6. class BreakException(Exception):
  7. """A type of exception that is used to interpret 'break' instructions:
  8. the 'break' instructions throw a BreakException, which is then handled
  9. by the appropriate 'while' instruction."""
  10. def __init__(self, loop):
  11. Exception.__init__(self)
  12. self.loop = loop
  13. class ContinueException(Exception):
  14. """A type of exception that is used to interpret 'continue' instructions:
  15. the 'continue' instructions throw a ContinueException, which is then handled
  16. by the appropriate 'while' instruction."""
  17. def __init__(self, loop):
  18. Exception.__init__(self)
  19. self.loop = loop
  20. class InterpreterState(object):
  21. """The state of the bytecode interpreter."""
  22. def __init__(self, gc_root_node, keyword_arg_dict, src_map, nop_period=10):
  23. self.gc_root_node = gc_root_node
  24. self.nop_period = nop_period
  25. self.keyword_arg_dict = keyword_arg_dict
  26. self.src_map = src_map
  27. self.current_result = None
  28. self.nop_phase = 0
  29. self.local_vars = {}
  30. def import_local(self, node_id, value):
  31. """Imports the given value as a local in this interpreter state."""
  32. local_node, = yield [("CN", [])]
  33. yield [
  34. ("CE", [self.gc_root_node, local_node]),
  35. ("CD", [local_node, "value", value])]
  36. self.local_vars[node_id] = local_node
  37. raise primitive_functions.PrimitiveFinished(None)
  38. def schedule_nop(self):
  39. """Increments the nop-phase. If a nop should be performed, then True is returned.
  40. Otherwise, False."""
  41. self.nop_phase += 1
  42. if self.nop_phase == self.nop_period:
  43. self.nop_phase = 0
  44. return True
  45. else:
  46. return False
  47. def update_result(self, new_result):
  48. """Sets the current result to the given value, if it is not None."""
  49. if new_result is not None:
  50. self.current_result = new_result
  51. def get_task_root(self):
  52. """Gets the task root node id."""
  53. return self.keyword_arg_dict['task_root']
  54. def get_kernel(self):
  55. """Gets the Modelverse kernel instance."""
  56. return self.keyword_arg_dict['mvk']
  57. def interpret(self, instruction):
  58. """Interprets the given instruction and returns the current result."""
  59. instruction_type = type(instruction)
  60. if instruction_type in InterpreterState.INTERPRETERS:
  61. old_debug_info = self.src_map.debug_information
  62. debug_info = instruction.debug_information
  63. if debug_info is not None:
  64. self.src_map.debug_information = debug_info
  65. # Interpret the instruction.
  66. yield [("CALL_ARGS",
  67. [InterpreterState.INTERPRETERS[instruction_type], (self, instruction)])]
  68. self.src_map.debug_information = old_debug_info
  69. # Maybe perform a nop.
  70. if self.schedule_nop():
  71. yield None
  72. # Interpret the next instruction.
  73. next_instruction = instruction.next_instruction
  74. if next_instruction is not None:
  75. yield [("TAIL_CALL_ARGS", [self.interpret, (next_instruction,)])]
  76. else:
  77. raise primitive_functions.PrimitiveFinished(self.current_result)
  78. else:
  79. raise jit_runtime.JitCompilationFailedException(
  80. 'Unknown bytecode instruction: %r' % instruction)
  81. def interpret_select(self, instruction):
  82. """Interprets the given 'select' instruction."""
  83. cond_node, = yield [("CALL_ARGS", [self.interpret, (instruction.condition,)])]
  84. cond_val, = yield [("RV", [cond_node])]
  85. if cond_val:
  86. yield [("TAIL_CALL_ARGS", [self.interpret, (instruction.if_clause,)])]
  87. elif instruction.else_clause is not None:
  88. yield [("TAIL_CALL_ARGS", [self.interpret, (instruction.else_clause,)])]
  89. else:
  90. raise primitive_functions.PrimitiveFinished(None)
  91. def interpret_while(self, instruction):
  92. """Interprets the given 'while' instruction."""
  93. def __handle_break(exception):
  94. if exception.loop == instruction:
  95. # End the loop.
  96. raise primitive_functions.PrimitiveFinished(None)
  97. else:
  98. # Propagate the exception to the next 'while' loop.
  99. raise exception
  100. def __handle_continue(exception):
  101. if exception.loop == instruction:
  102. # Restart the loop.
  103. yield [("TAIL_CALL_ARGS", [self.interpret_while, (instruction,)])]
  104. else:
  105. # Propagate the exception to the next 'while' loop.
  106. raise exception
  107. yield [("TRY", [])]
  108. yield [("CATCH", [BreakException, __handle_break])]
  109. yield [("CATCH", [ContinueException, __handle_continue])]
  110. while 1:
  111. cond_node, = yield [("CALL_ARGS", [self.interpret, (instruction.condition,)])]
  112. cond_val, = yield [("RV", [cond_node])]
  113. if cond_val:
  114. yield [("CALL_ARGS", [self.interpret, (instruction.body,)])]
  115. else:
  116. break
  117. yield [("END_TRY", [])]
  118. raise primitive_functions.PrimitiveFinished(None)
  119. def interpret_break(self, instruction):
  120. """Interprets the given 'break' instruction."""
  121. raise BreakException(instruction.loop)
  122. def interpret_continue(self, instruction):
  123. """Interprets the given 'continue' instruction."""
  124. raise ContinueException(instruction.loop)
  125. def interpret_return(self, instruction):
  126. """Interprets the given 'return' instruction."""
  127. if instruction.value is None:
  128. raise primitive_functions.InterpretedFunctionFinished(None)
  129. else:
  130. return_node, = yield [("CALL_ARGS", [self.interpret, (instruction.value,)])]
  131. raise primitive_functions.InterpretedFunctionFinished(return_node)
  132. def interpret_call(self, instruction):
  133. """Interprets the given 'call' instruction."""
  134. target, = yield [("CALL_ARGS", [self.interpret, (instruction.target,)])]
  135. named_args = {}
  136. for name, arg_instruction in instruction.argument_list:
  137. arg, = yield [("CALL_ARGS", [self.interpret, (arg_instruction,)])]
  138. named_args[name] = arg
  139. kwargs = {'function_id': target, 'named_arguments': named_args}
  140. kwargs.update(self.keyword_arg_dict)
  141. result, = yield [("CALL_KWARGS", [jit_runtime.call_function, kwargs])]
  142. if result is not None:
  143. yield [("CE", [self.gc_root_node, result])]
  144. self.update_result(result)
  145. raise primitive_functions.PrimitiveFinished(None)
  146. def interpret_constant(self, instruction):
  147. """Interprets the given 'constant' instruction."""
  148. self.update_result(instruction.constant_id)
  149. raise primitive_functions.PrimitiveFinished(None)
  150. def interpret_input(self, instruction):
  151. """Interprets the given 'input' instruction."""
  152. result, = yield [("CALL_KWARGS", [jit_runtime.get_input, self.keyword_arg_dict])]
  153. self.update_result(result)
  154. yield [("CE", [self.gc_root_node, result])]
  155. raise primitive_functions.PrimitiveFinished(None)
  156. def interpret_output(self, instruction):
  157. """Interprets the given 'output' instruction."""
  158. output_value, = yield [("CALL_ARGS", [self.interpret, (instruction.value,)])]
  159. task_root = self.get_task_root()
  160. last_output, last_output_link, new_last_output = yield [
  161. ("RD", [task_root, "last_output"]),
  162. ("RDE", [task_root, "last_output"]),
  163. ("CN", [])
  164. ]
  165. yield [
  166. ("CD", [last_output, "value", output_value]),
  167. ("CD", [last_output, "next", new_last_output]),
  168. ("CD", [task_root, "last_output", new_last_output]),
  169. ("DE", [last_output_link])
  170. ]
  171. yield None
  172. raise primitive_functions.PrimitiveFinished(None)
  173. def interpret_declare(self, instruction):
  174. """Interprets a 'declare' (local) instruction."""
  175. node_id = instruction.variable.node_id
  176. if node_id in self.local_vars:
  177. raise primitive_functions.PrimitiveFinished(None)
  178. else:
  179. local_node, = yield [("CN", [])]
  180. yield [("CE", [self.gc_root_node, local_node])]
  181. self.local_vars[node_id] = local_node
  182. raise primitive_functions.PrimitiveFinished(None)
  183. def interpret_global(self, instruction):
  184. """Interprets a (declare) 'global' instruction."""
  185. var_name = instruction.variable.name
  186. task_root = self.get_task_root()
  187. _globals, = yield [("RD", [task_root, "globals"])]
  188. global_var, = yield [("RD", [_globals, var_name])]
  189. if global_var is None:
  190. global_var, = yield [("CN", [])]
  191. yield [("CD", [_globals, var_name, global_var])]
  192. raise primitive_functions.PrimitiveFinished(None)
  193. def interpret_resolve(self, instruction):
  194. """Interprets a 'resolve' instruction."""
  195. node_id = instruction.variable.node_id
  196. if node_id in self.local_vars:
  197. self.update_result(self.local_vars[node_id])
  198. raise primitive_functions.PrimitiveFinished(None)
  199. else:
  200. task_root = self.get_task_root()
  201. var_name = instruction.variable.name
  202. _globals, = yield [("RD", [task_root, "globals"])]
  203. global_var, = yield [("RD", [_globals, var_name])]
  204. if global_var is None:
  205. raise Exception(jit_runtime.GLOBAL_NOT_FOUND_MESSAGE_FORMAT % var_name)
  206. mvk = self.get_kernel()
  207. if mvk.suggest_function_names and mvk.jit.get_global_body_id(var_name) is None:
  208. global_val, = yield [("RD", [global_var, "value"])]
  209. if global_val is not None:
  210. func_body, = yield [("RD", [global_val, "body"])]
  211. if func_body is not None:
  212. mvk.jit.register_global(func_body, var_name)
  213. self.update_result(global_var)
  214. yield [("CE", [self.gc_root_node, global_var])]
  215. raise primitive_functions.PrimitiveFinished(None)
  216. def interpret_access(self, instruction):
  217. """Interprets an 'access' instruction."""
  218. pointer_node, = yield [("CALL_ARGS", [self.interpret, (instruction.pointer,)])]
  219. value_node, = yield [("RD", [pointer_node, "value"])]
  220. self.update_result(value_node)
  221. yield [("CE", [self.gc_root_node, value_node])]
  222. raise primitive_functions.PrimitiveFinished(None)
  223. def interpret_assign(self, instruction):
  224. """Interprets an 'assign' instruction."""
  225. pointer_node, = yield [("CALL_ARGS", [self.interpret, (instruction.pointer,)])]
  226. value_node, = yield [("CALL_ARGS", [self.interpret, (instruction.value,)])]
  227. value_link, = yield [("RDE", [pointer_node, "value"])]
  228. yield [
  229. ("CD", [pointer_node, "value", value_node]),
  230. ("DE", [value_link])]
  231. raise primitive_functions.PrimitiveFinished(None)
  232. INTERPRETERS = {
  233. bytecode_ir.SelectInstruction: interpret_select,
  234. bytecode_ir.WhileInstruction: interpret_while,
  235. bytecode_ir.BreakInstruction: interpret_break,
  236. bytecode_ir.ContinueInstruction: interpret_continue,
  237. bytecode_ir.ReturnInstruction: interpret_return,
  238. bytecode_ir.CallInstruction: interpret_call,
  239. bytecode_ir.ConstantInstruction: interpret_constant,
  240. bytecode_ir.InputInstruction: interpret_input,
  241. bytecode_ir.OutputInstruction: interpret_output,
  242. bytecode_ir.DeclareInstruction: interpret_declare,
  243. bytecode_ir.GlobalInstruction: interpret_global,
  244. bytecode_ir.ResolveInstruction: interpret_resolve,
  245. bytecode_ir.AccessInstruction: interpret_access,
  246. bytecode_ir.AssignInstruction: interpret_assign
  247. }
  248. def interpret_bytecode_function(function_name, body_bytecode, local_arguments, keyword_arguments):
  249. """Interprets the bytecode function with the given name, body, named arguments and
  250. keyword arguments."""
  251. src_map = source_map.ManualSourceMap()
  252. yield [("DEBUG_INFO", [function_name, src_map, jit_runtime.BYTECODE_INTERPRETER_ORIGIN_NAME])]
  253. task_root = keyword_arguments['task_root']
  254. gc_root_node, = yield [("CN", [])]
  255. gc_root_edge, = yield [("CE", [task_root, gc_root_node])]
  256. interpreter = InterpreterState(gc_root_node, keyword_arguments, src_map)
  257. for param_id, arg_node in local_arguments.items():
  258. yield [("CALL_ARGS", [interpreter.import_local, (param_id, arg_node)])]
  259. def __handle_return(exception):
  260. yield [("DE", [gc_root_edge])]
  261. raise primitive_functions.PrimitiveFinished(exception.result)
  262. def __handle_break(_):
  263. raise jit_runtime.UnreachableCodeException(
  264. "Function '%s' tries to break out of a loop that is not currently executing." %
  265. function_name)
  266. def __handle_continue(_):
  267. raise jit_runtime.UnreachableCodeException(
  268. "Function '%s' tries to continue a loop that is not currently executing." %
  269. function_name)
  270. # Perform a nop before interpreting the function.
  271. yield None
  272. yield [("TRY", [])]
  273. yield [("CATCH", [primitive_functions.InterpretedFunctionFinished, __handle_return])]
  274. yield [("CATCH", [BreakException, __handle_break])]
  275. yield [("CATCH", [ContinueException, __handle_continue])]
  276. yield [("CALL_ARGS", [interpreter.interpret, (body_bytecode,)])]
  277. yield [("END_TRY", [])]
  278. raise jit_runtime.UnreachableCodeException("Function '%s' failed to return." % function_name)