runtime.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. import modelverse_kernel.primitives as primitive_functions
  2. import modelverse_jit.source_map as source_map
  3. class JitCompilationFailedException(Exception):
  4. """A type of exception that is raised when the jit fails to compile a function."""
  5. pass
  6. MUTABLE_FUNCTION_KEY = "mutable"
  7. """A dictionary key for functions that are mutable."""
  8. FUNCTION_BODY_KEY = "body"
  9. """A dictionary key for function bodies."""
  10. KWARGS_PARAMETER_NAME = "kwargs"
  11. """The name of the kwargs parameter in jitted functions."""
  12. CALL_FUNCTION_NAME = "__call_function"
  13. """The name of the '__call_function' function, in the jitted function scope."""
  14. GET_INPUT_FUNCTION_NAME = "__get_input"
  15. """The name of the '__get_input' function, in the jitted function scope."""
  16. JIT_THUNK_CONSTANT_FUNCTION_NAME = "__jit_thunk_constant_function"
  17. """The name of the jit_thunk_constant_function function in the JIT's global context."""
  18. JIT_THUNK_GLOBAL_FUNCTION_NAME = "__jit_thunk_global"
  19. """The name of the jit_thunk_global function in the JIT's global context."""
  20. JIT_REJIT_FUNCTION_NAME = "__jit_rejit"
  21. """The name of the rejit function in the JIT's global context."""
  22. JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME = "__jit_compile_function_body_fast"
  23. """The name of the compile_function_body_fast function in the JIT's global context."""
  24. UNREACHABLE_FUNCTION_NAME = "__unreachable"
  25. """The name of the unreachable function in the JIT's global context."""
  26. LOCALS_NODE_NAME = "jit_locals"
  27. """The name of the node that is connected to all JIT locals in a given function call."""
  28. LOCALS_EDGE_NAME = "jit_locals_edge"
  29. """The name of the edge that connects the LOCALS_NODE_NAME node to a user root."""
  30. GLOBAL_NOT_FOUND_MESSAGE_FORMAT = "Not found as global: %s"
  31. """The format of the 'not found as global' message. Takes a single argument."""
  32. REFERENCE_INTERPRETER_ORIGIN_NAME = "reference-interpreter"
  33. """The origin name for functions that are interpreted by the reference interpreter."""
  34. BYTECODE_INTERPRETER_ORIGIN_NAME = "bytecode-interpreter"
  35. """The origin name for functions that were produced by the bytecode interpreter."""
  36. BASELINE_JIT_ORIGIN_NAME = "baseline-jit"
  37. """The origin name for functions that were produced by the baseline JIT."""
  38. FAST_JIT_ORIGIN_NAME = "fast-jit"
  39. """The origin name for functions that were produced by the fast JIT."""
  40. UNKNOWN_FUNCTION_REPR = "unknown function"
  41. """The representation used for unknown functions in stack traces."""
  42. UNKNOWN_LOCATION_REPR = "[unknown location] "
  43. """The representation used for unknown locations in stack traces."""
  44. def format_stack_frame(function_name, debug_info, origin='unknown'):
  45. """Formats a stack frame, which consists of a function name, debug
  46. information and an origin."""
  47. if function_name is None:
  48. function_name = UNKNOWN_FUNCTION_REPR
  49. if debug_info is None:
  50. debug_info = UNKNOWN_LOCATION_REPR
  51. return '%sin %s (%s)' % (debug_info, function_name, origin)
  52. def format_trace_message(debug_info, function_name, origin='unknown'):
  53. """Creates a formatted trace message."""
  54. return 'TRACE: %s' % format_stack_frame(function_name, debug_info, origin)
  55. def call_function(function_id, named_arguments, **kwargs):
  56. """Runs the function with the given id, passing it the specified argument dictionary."""
  57. task_root = kwargs['task_root']
  58. kernel = kwargs['mvk']
  59. body_id, is_mutable = yield [
  60. ("RD", [function_id, FUNCTION_BODY_KEY]),
  61. ("RD", [function_id, MUTABLE_FUNCTION_KEY])]
  62. # Try to jit the function here. We might be able to avoid building the stack
  63. # frame.
  64. def handle_jit_failed(_):
  65. """Interprets the function."""
  66. interpreter_args = {'body_id' : body_id, 'named_arguments' : named_arguments}
  67. interpreter_args.update(kwargs)
  68. yield [("TAIL_CALL_KWARGS", [interpret_function_body, interpreter_args])]
  69. if is_mutable is not None:
  70. kernel.jit.mark_no_jit(body_id)
  71. yield [("TAIL_CALL_ARGS", [handle_jit_failed, (None,)])]
  72. else:
  73. kernel.jit.mark_entry_point(body_id)
  74. yield [("TRY", [])]
  75. yield [("CATCH", [JitCompilationFailedException, handle_jit_failed])]
  76. # Try to compile.
  77. compiled_func, = yield [("CALL_ARGS", [kernel.jit_compile, (task_root, body_id)])]
  78. yield [("END_TRY", [])]
  79. # Add the keyword arguments to the argument dictionary.
  80. named_arguments.update(kwargs)
  81. # Run the function.
  82. yield [("TAIL_CALL_KWARGS", [compiled_func, named_arguments])]
  83. def interpret_function(function_id, named_arguments, **kwargs):
  84. """Makes the interpreter run the function with the given id for the specified
  85. argument dictionary."""
  86. body_id, = yield [("RD", [function_id, FUNCTION_BODY_KEY])]
  87. args = {'body_id' : body_id, named_arguments : named_arguments}
  88. args.update(kwargs)
  89. yield [("TAIL_CALL_KWARGS", [interpret_function_body, args])]
def interpret_function_body(body_id, named_arguments, **kwargs):
    """Makes the interpreter run the function body with the given id for the specified
    argument dictionary.

    This is a generator that communicates with the kernel through yielded
    request lists. It:

    1. builds a new stack frame whose "prev" is the task's current frame and
       whose "IP" is body_id, and links it to task_root under "frame";
    2. binds every entry of named_arguments to the matching parameter variable
       reported by kernel.jit.jit_signature(body_id);
    3. appends a debug-info record for the task and emits a DEBUG_INFO request;
    4. repeatedly runs kernel.execute_rule for the task, forwarding each result.

    The final loop has no normal exit; it is expected to be ended by
    primitive_functions.InterpretedFunctionFinished, which the installed handler
    translates into primitive_functions.PrimitiveFinished(ex.result).

    Reads kwargs 'task_root', 'mvk' (used as the kernel) and 'taskname'.
    Raises KeyError if named_arguments contains a name that is not one of the
    body's parameter names.
    """
    task_root = kwargs['task_root']
    kernel = kwargs['mvk']
    # The caller's frame and its current instruction pointer; the latter is
    # stored as the new frame's "caller".
    user_frame, = yield [("RD", [task_root, "frame"])]
    inst, = yield [("RD", [user_frame, "IP"])]
    kernel.jit.mark_entry_point(body_id)

    # Create a new stack frame.
    # frame_link is the result of RDE on (task_root, "frame") -- presumably the
    # existing task_root -> frame edge; it is deleted (DE) below once the new
    # frame has been linked in its place. new_phase is created via CNV with the
    # value "init"; the remaining CN requests create fresh nodes.
    frame_link, new_phase, new_frame, new_evalstack, new_symbols, \
        new_returnvalue, intrinsic_return = \
            yield [("RDE", [task_root, "frame"]),
                   ("CNV", ["init"]),
                   ("CN", []),
                   ("CN", []),
                   ("CN", []),
                   ("CN", []),
                   ("CN", [])
                  ]

    # Wire up the new frame's fields (ten requests, all results discarded).
    _, _, _, _, _, _, _, _, _, _ = \
        yield [("CD", [task_root, "frame", new_frame]),
               ("CD", [new_frame, "evalstack", new_evalstack]),
               ("CD", [new_frame, "symbols", new_symbols]),
               ("CD", [new_frame, "returnvalue", new_returnvalue]),
               ("CD", [new_frame, "caller", inst]),
               ("CD", [new_frame, "phase", new_phase]),
               ("CD", [new_frame, "IP", body_id]),
               ("CD", [new_frame, "prev", user_frame]),
               ("CD", [
                   new_frame,
                   primitive_functions.EXCEPTION_RETURN_KEY,
                   intrinsic_return]),
               ("DE", [frame_link])
              ]

    # Put the parameters in the new stack frame's symbol table.
    # jit_signature yields a 3-tuple; only the parameter variables and names
    # are used here.
    (parameter_vars, parameter_names, _), = yield [
        ("CALL_ARGS", [kernel.jit.jit_signature, (body_id,)])]
    parameter_dict = dict(zip(parameter_names, parameter_vars))
    for (key, value) in named_arguments.items():
        # KeyError here means the caller passed an argument the body does not declare.
        param_var = parameter_dict[key]
        # Each argument gets a fresh variable node holding the value under "value";
        # an edge new_symbols -> variable is created, and that edge is in turn
        # connected to the parameter variable.
        variable, = yield [("CN", [])]
        yield [("CD", [variable, "value", value])]
        symbol_edge, = yield [("CE", [new_symbols, variable])]
        yield [("CE", [symbol_edge, param_var])]

    taskname = kwargs['taskname']

    # Append a debug info record and set up a source map.
    # The source map refers to the record just appended (index len - 1).
    kernel.debug_info[taskname].append(UNKNOWN_LOCATION_REPR)
    src_map = source_map.InterpreterSourceMap(
        kernel, taskname, len(kernel.debug_info[taskname]) - 1)
    function_name = kernel.jit.get_global_name(body_id)
    if function_name is None:
        function_name = UNKNOWN_FUNCTION_REPR
    yield [("DEBUG_INFO", [function_name, src_map, REFERENCE_INTERPRETER_ORIGIN_NAME])]

    def exception_handler(ex):
        """Convert InterpretedFunctionFinished into PrimitiveFinished with the same result."""
        # print('Returning from interpreted function. Result: %s' % ex.result)
        raise primitive_functions.PrimitiveFinished(ex.result)

    # Create an exception handler to catch and translate InterpretedFunctionFinished.
    # No END_TRY follows: the loop below only terminates via an exception.
    yield [("TRY", [])]
    yield [("CATCH", [primitive_functions.InterpretedFunctionFinished, exception_handler])]
    while 1:
        result, = yield [("CALL_ARGS", [kernel.execute_rule, (taskname,)])]
        # An instruction has completed. Forward it.
        yield result
  150. result, = yield [("CALL_ARGS", [kernel.execute_rule, (taskname,)])]
  151. # An instruction has completed. Forward it.
  152. yield result
  153. class UnreachableCodeException(Exception):
  154. """The type of exception that is thrown when supposedly unreachable code is executed."""
  155. pass
  156. def unreachable():
  157. """Marks unreachable code."""
  158. raise UnreachableCodeException('An unreachable statement was reached.')
  159. def get_input(**parameters):
  160. """Retrieves input."""
  161. mvk = parameters["mvk"]
  162. task_root = parameters["task_root"]
  163. while 1:
  164. yield [("CALL_ARGS", [mvk.input_init, (task_root,)])]
  165. # Finished
  166. if mvk.success:
  167. # Got some input, so we can access it
  168. raise primitive_functions.PrimitiveFinished(mvk.input_value)
  169. else:
  170. # No input, so yield None but don't stop
  171. yield None