jit.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. import modelverse_kernel.primitives as primitive_functions
  2. import modelverse_jit.bytecode_ir as bytecode_ir
  3. import modelverse_jit.bytecode_parser as bytecode_parser
  4. import modelverse_jit.bytecode_to_tree as bytecode_to_tree
  5. import modelverse_jit.tree_ir as tree_ir
  6. import modelverse_jit.runtime as jit_runtime
  7. import keyword
  8. # Import JitCompilationFailedException because it used to be defined
  9. # in this module.
  10. JitCompilationFailedException = jit_runtime.JitCompilationFailedException
  11. def map_and_simplify_generator(function, instruction):
  12. """Applies the given mapping function to every instruction in the tree
  13. that has the given instruction as root, and simplifies it on-the-fly.
  14. This is at least as powerful as first mapping and then simplifying, as
  15. maps and simplifications are interspersed.
  16. This function assumes that function creates a generator that returns by
  17. raising a primitive_functions.PrimitiveFinished."""
  18. # First handle the children by mapping on them and then simplifying them.
  19. new_children = []
  20. for inst in instruction.get_children():
  21. new_inst, = yield [("CALL_ARGS", [map_and_simplify_generator, (function, inst)])]
  22. new_children.append(new_inst)
  23. # Then apply the function to the top-level node.
  24. transformed, = yield [("CALL_ARGS", [function, (instruction.create(new_children),)])]
  25. # Finally, simplify the transformed top-level node.
  26. raise primitive_functions.PrimitiveFinished(transformed.simplify_node())
  27. def expand_constant_read(instruction):
  28. """Tries to replace a read of a constant node by a literal."""
  29. if isinstance(instruction, tree_ir.ReadValueInstruction) and \
  30. isinstance(instruction.node_id, tree_ir.LiteralInstruction):
  31. val, = yield [("RV", [instruction.node_id.literal])]
  32. raise primitive_functions.PrimitiveFinished(tree_ir.LiteralInstruction(val))
  33. else:
  34. raise primitive_functions.PrimitiveFinished(instruction)
  35. def optimize_tree_ir(instruction):
  36. """Optimizes an IR tree."""
  37. return map_and_simplify_generator(expand_constant_read, instruction)
  38. def print_value(val):
  39. """A thin wrapper around 'print'."""
  40. print(val)
  41. class ModelverseJit(object):
  42. """A high-level interface to the modelverse JIT compiler."""
  43. def __init__(self, max_instructions=None, compiled_function_lookup=None):
  44. self.todo_entry_points = set()
  45. self.no_jit_entry_points = set()
  46. self.jitted_entry_points = {}
  47. self.jitted_parameters = {}
  48. self.jit_globals = {
  49. 'PrimitiveFinished' : primitive_functions.PrimitiveFinished,
  50. jit_runtime.CALL_FUNCTION_NAME : jit_runtime.call_function,
  51. jit_runtime.GET_INPUT_FUNCTION_NAME : jit_runtime.get_input
  52. }
  53. self.bytecode_graphs = {}
  54. self.jit_count = 0
  55. self.max_instructions = max_instructions
  56. self.compiled_function_lookup = compiled_function_lookup
  57. # jit_intrinsics is a function name -> intrinsic map.
  58. self.jit_intrinsics = {}
  59. self.compilation_dependencies = {}
  60. self.jit_enabled = True
  61. self.direct_calls_allowed = True
  62. self.tracing_enabled = False
  63. self.input_function_enabled = False
  64. self.nop_insertion_enabled = True
  65. self.jit_success_log_function = None
  66. self.jit_code_log_function = None
  67. def set_jit_enabled(self, is_enabled=True):
  68. """Enables or disables the JIT."""
  69. self.jit_enabled = is_enabled
  70. def allow_direct_calls(self, is_allowed=True):
  71. """Allows or disallows direct calls from jitted to jitted code."""
  72. self.direct_calls_allowed = is_allowed
  73. def use_input_function(self, is_enabled=True):
  74. """Configures the JIT to compile 'input' instructions as function calls."""
  75. self.input_function_enabled = is_enabled
  76. def enable_tracing(self, is_enabled=True):
  77. """Enables or disables tracing for jitted code."""
  78. self.tracing_enabled = is_enabled
  79. def enable_nop_insertion(self, is_enabled=True):
  80. """Enables or disables nop insertion for jitted code. The JIT will insert nops at loop
  81. back-edges. Inserting nops sacrifices performance to keep the jitted code from
  82. blocking the thread of execution by consuming all resources; nops give the
  83. Modelverse server an opportunity to interrupt the currently running code."""
  84. self.nop_insertion_enabled = is_enabled
  85. def set_jit_success_log(self, log_function=print_value):
  86. """Configures this JIT instance with a function that prints output to a log.
  87. Success and failure messages for specific functions are then sent to said log."""
  88. self.jit_success_log_function = log_function
  89. def set_jit_code_log(self, log_function=print_value):
  90. """Configures this JIT instance with a function that prints output to a log.
  91. Function definitions of jitted functions are then sent to said log."""
  92. self.jit_code_log_function = log_function
  93. def mark_entry_point(self, body_id):
  94. """Marks the node with the given identifier as a function entry point."""
  95. if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
  96. self.todo_entry_points.add(body_id)
  97. def is_entry_point(self, body_id):
  98. """Tells if the node with the given identifier is a function entry point."""
  99. return body_id in self.todo_entry_points or \
  100. body_id in self.no_jit_entry_points or \
  101. body_id in self.jitted_entry_points
  102. def is_jittable_entry_point(self, body_id):
  103. """Tells if the node with the given identifier is a function entry point that
  104. has not been marked as non-jittable. This only returns `True` if the JIT
  105. is enabled and the function entry point has been marked jittable, or if
  106. the function has already been compiled."""
  107. return ((self.jit_enabled and body_id in self.todo_entry_points) or
  108. self.has_compiled(body_id))
  109. def has_compiled(self, body_id):
  110. """Tests if the function belonging to the given body node has been compiled yet."""
  111. return body_id in self.jitted_entry_points
  112. def get_compiled_name(self, body_id):
  113. """Gets the name of the compiled version of the given body node in the JIT
  114. global state."""
  115. return self.jitted_entry_points[body_id]
  116. def mark_no_jit(self, body_id):
  117. """Informs the JIT that the node with the given identifier is a function entry
  118. point that must never be jitted."""
  119. self.no_jit_entry_points.add(body_id)
  120. if body_id in self.todo_entry_points:
  121. self.todo_entry_points.remove(body_id)
  122. def generate_name(self, infix, suggested_name=None):
  123. """Generates a new name or picks the suggested name if it is still
  124. available."""
  125. if suggested_name is not None \
  126. and suggested_name not in self.jit_globals \
  127. and not keyword.iskeyword(suggested_name):
  128. self.jit_count += 1
  129. return suggested_name
  130. else:
  131. function_name = 'jit_%s%d' % (infix, self.jit_count)
  132. self.jit_count += 1
  133. return function_name
  134. def generate_function_name(self, suggested_name=None):
  135. """Generates a new function name or picks the suggested name if it is still
  136. available."""
  137. return self.generate_name('func', suggested_name)
  138. def register_compiled(self, body_id, compiled_function, function_name=None):
  139. """Registers a compiled entry point with the JIT."""
  140. # Get the function's name.
  141. function_name = self.generate_function_name(function_name)
  142. # Map the body id to the given parameter list.
  143. self.jitted_entry_points[body_id] = function_name
  144. self.jit_globals[function_name] = compiled_function
  145. if body_id in self.todo_entry_points:
  146. self.todo_entry_points.remove(body_id)
  147. def import_value(self, value, suggested_name=None):
  148. """Imports the given value into the JIT's global scope, with the given suggested name.
  149. The actual name of the value (within the JIT's global scope) is returned."""
  150. actual_name = self.generate_name('import', suggested_name)
  151. self.jit_globals[actual_name] = value
  152. return actual_name
  153. def lookup_compiled_function(self, name):
  154. """Looks up a compiled function by name. Returns a matching function,
  155. or None if no function was found."""
  156. if name is None:
  157. return None
  158. elif name in self.jit_globals:
  159. return self.jit_globals[name]
  160. elif self.compiled_function_lookup is not None:
  161. return self.compiled_function_lookup(name)
  162. else:
  163. return None
  164. def get_intrinsic(self, name):
  165. """Tries to find an intrinsic version of the function with the
  166. given name."""
  167. if name in self.jit_intrinsics:
  168. return self.jit_intrinsics[name]
  169. else:
  170. return None
  171. def register_intrinsic(self, name, intrinsic_function):
  172. """Registers the given intrisic with the JIT. This will make the JIT replace calls to
  173. the function with the given entry point by an application of the specified function."""
  174. self.jit_intrinsics[name] = intrinsic_function
  175. def register_binary_intrinsic(self, name, operator):
  176. """Registers an intrinsic with the JIT that represents the given binary operation."""
  177. self.register_intrinsic(name, lambda a, b: tree_ir.CreateNodeWithValueInstruction(
  178. tree_ir.BinaryInstruction(
  179. tree_ir.ReadValueInstruction(a),
  180. operator,
  181. tree_ir.ReadValueInstruction(b))))
  182. def register_unary_intrinsic(self, name, operator):
  183. """Registers an intrinsic with the JIT that represents the given unary operation."""
  184. self.register_intrinsic(name, lambda a: tree_ir.CreateNodeWithValueInstruction(
  185. tree_ir.UnaryInstruction(
  186. operator,
  187. tree_ir.ReadValueInstruction(a))))
  188. def register_cast_intrinsic(self, name, target_type):
  189. """Registers an intrinsic with the JIT that represents a unary conversion operator."""
  190. self.register_intrinsic(name, lambda a: tree_ir.CreateNodeWithValueInstruction(
  191. tree_ir.CallInstruction(
  192. tree_ir.LoadGlobalInstruction(target_type.__name__),
  193. [tree_ir.ReadValueInstruction(a)])))
  194. def jit_signature(self, body_id):
  195. """Acquires the signature for the given body id node, which consists of the
  196. parameter variables, parameter name and a flag that tells if the given function
  197. is mutable."""
  198. if body_id not in self.jitted_parameters:
  199. signature_id, = yield [("RRD", [body_id, jit_runtime.FUNCTION_BODY_KEY])]
  200. signature_id = signature_id[0]
  201. param_set_id, is_mutable = yield [
  202. ("RD", [signature_id, "params"]),
  203. ("RD", [signature_id, jit_runtime.MUTABLE_FUNCTION_KEY])]
  204. if param_set_id is None:
  205. self.jitted_parameters[body_id] = ([], [], is_mutable)
  206. else:
  207. param_name_ids, = yield [("RDK", [param_set_id])]
  208. param_names = yield [("RV", [n]) for n in param_name_ids]
  209. param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
  210. self.jitted_parameters[body_id] = (param_vars, param_names, is_mutable)
  211. raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])
  212. def jit_parse_bytecode(self, body_id):
  213. """Parses the given function body as a bytecode graph."""
  214. if body_id in self.bytecode_graphs:
  215. raise primitive_functions.PrimitiveFinished(self.bytecode_graphs[body_id])
  216. parser = bytecode_parser.BytecodeParser()
  217. result, = yield [("CALL_ARGS", [parser.parse_instruction, (body_id,)])]
  218. self.bytecode_graphs[body_id] = result
  219. raise primitive_functions.PrimitiveFinished(result)
  220. def jit_compile(self, user_root, body_id, suggested_name=None):
  221. """Tries to jit the function defined by the given entry point id and parameter list."""
  222. # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
  223. # pylint: disable=I0011,W0122
  224. if body_id is None:
  225. raise ValueError('body_id cannot be None')
  226. elif body_id in self.jitted_entry_points:
  227. # We have already compiled this function.
  228. raise primitive_functions.PrimitiveFinished(
  229. self.jit_globals[self.jitted_entry_points[body_id]])
  230. elif body_id in self.no_jit_entry_points:
  231. # We're not allowed to jit this function or have tried and failed before.
  232. raise JitCompilationFailedException(
  233. 'Cannot jit function %s at %d because it is marked non-jittable.' % (
  234. '' if suggested_name is None else "'" + suggested_name + "'",
  235. body_id))
  236. elif not self.jit_enabled:
  237. # We're not allowed to jit anything.
  238. raise JitCompilationFailedException(
  239. 'Cannot jit function %s at %d because the JIT has been disabled.' % (
  240. '' if suggested_name is None else "'" + suggested_name + "'",
  241. body_id))
  242. # Generate a name for the function we're about to analyze, and pretend that
  243. # it already exists. (we need to do this for recursive functions)
  244. function_name = self.generate_function_name(suggested_name)
  245. self.jitted_entry_points[body_id] = function_name
  246. self.jit_globals[function_name] = None
  247. (parameter_ids, parameter_list, is_mutable), = yield [
  248. ("CALL_ARGS", [self.jit_signature, (body_id,)])]
  249. param_dict = dict(zip(parameter_ids, parameter_list))
  250. body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
  251. dependencies = set([body_id])
  252. self.compilation_dependencies[body_id] = dependencies
  253. def handle_jit_exception(exception):
  254. # If analysis fails, then a JitCompilationFailedException will be thrown.
  255. del self.compilation_dependencies[body_id]
  256. for dep in dependencies:
  257. self.mark_no_jit(dep)
  258. if dep in self.jitted_entry_points:
  259. del self.jitted_entry_points[dep]
  260. failure_message = "%s (function '%s' at %d)" % (
  261. exception.message, function_name, body_id)
  262. if self.jit_success_log_function is not None:
  263. self.jit_success_log_function('JIT compilation failed: %s' % failure_message)
  264. raise JitCompilationFailedException(failure_message)
  265. # Try to analyze the function's body.
  266. yield [("TRY", [])]
  267. yield [("CATCH", [JitCompilationFailedException, handle_jit_exception])]
  268. if is_mutable:
  269. # We can't just JIT mutable functions. That'd be dangerous.
  270. raise JitCompilationFailedException(
  271. "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
  272. body_bytecode, = yield [("CALL_ARGS", [self.jit_parse_bytecode, (body_id,)])]
  273. state = bytecode_to_tree.AnalysisState(
  274. self, body_id, user_root, body_param_dict,
  275. self.max_instructions)
  276. constructed_body, = yield [("CALL_ARGS", [state.analyze, (body_bytecode,)])]
  277. yield [("END_TRY", [])]
  278. del self.compilation_dependencies[body_id]
  279. # Write a prologue and prepend it to the generated function body.
  280. prologue_statements = []
  281. # Create a LOCALS_NODE_NAME node, and connect it to the user root.
  282. prologue_statements.append(
  283. tree_ir.create_new_local_node(
  284. jit_runtime.LOCALS_NODE_NAME,
  285. tree_ir.LoadIndexInstruction(
  286. tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
  287. tree_ir.LiteralInstruction('user_root')),
  288. jit_runtime.LOCALS_EDGE_NAME))
  289. for (key, val) in param_dict.items():
  290. arg_ptr = tree_ir.create_new_local_node(
  291. body_param_dict[key],
  292. tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME))
  293. prologue_statements.append(arg_ptr)
  294. prologue_statements.append(
  295. tree_ir.CreateDictionaryEdgeInstruction(
  296. tree_ir.LoadLocalInstruction(body_param_dict[key]),
  297. tree_ir.LiteralInstruction('value'),
  298. tree_ir.LoadLocalInstruction(val)))
  299. constructed_body = tree_ir.create_block(
  300. *(prologue_statements + [constructed_body]))
  301. # Optimize the function's body.
  302. constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
  303. # Shield temporaries from the GC.
  304. constructed_body = tree_ir.protect_temporaries_from_gc(
  305. constructed_body, tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME))
  306. # Wrap the IR in a function definition, give it a unique name.
  307. constructed_function = tree_ir.DefineFunctionInstruction(
  308. function_name,
  309. parameter_list + ['**' + jit_runtime.KWARGS_PARAMETER_NAME],
  310. constructed_body)
  311. # Convert the function definition to Python code, and compile it.
  312. exec(str(constructed_function), self.jit_globals)
  313. # Extract the compiled function from the JIT global state.
  314. compiled_function = self.jit_globals[function_name]
  315. if self.jit_success_log_function is not None:
  316. self.jit_success_log_function(
  317. "JIT compilation successful: (function '%s' at %d)" % (function_name, body_id))
  318. if self.jit_code_log_function is not None:
  319. self.jit_code_log_function(constructed_function)
  320. raise primitive_functions.PrimitiveFinished(compiled_function)