jit.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. import math
  2. import keyword
  3. from collections import defaultdict
  4. import modelverse_kernel.primitives as primitive_functions
  5. import modelverse_jit.runtime as jit_runtime
  6. # Import JitCompilationFailedException because it used to be defined
  7. # in this module.
  8. JitCompilationFailedException = jit_runtime.JitCompilationFailedException
  9. def map_and_simplify_generator(function, instruction):
  10. """Applies the given mapping function to every instruction in the tree
  11. that has the given instruction as root, and simplifies it on-the-fly.
  12. This is at least as powerful as first mapping and then simplifying, as
  13. maps and simplifications are interspersed.
  14. This function assumes that function creates a generator that returns by
  15. raising a primitive_functions.PrimitiveFinished."""
  16. # First handle the children by mapping on them and then simplifying them.
  17. new_children = []
  18. for inst in instruction.get_children():
  19. new_inst, = yield [("CALL_ARGS", [map_and_simplify_generator, (function, inst)])]
  20. new_children.append(new_inst)
  21. # Then apply the function to the top-level node.
  22. transformed, = yield [("CALL_ARGS", [function, (instruction.create(new_children),)])]
  23. # Finally, simplify the transformed top-level node.
  24. raise primitive_functions.PrimitiveFinished(transformed.simplify_node())
  25. def expand_constant_read(instruction):
  26. """Tries to replace a read of a constant node by a literal."""
  27. if isinstance(instruction, tree_ir.ReadValueInstruction) and \
  28. isinstance(instruction.node_id, tree_ir.LiteralInstruction):
  29. val, = yield [("RV", [instruction.node_id.literal])]
  30. raise primitive_functions.PrimitiveFinished(tree_ir.LiteralInstruction(val))
  31. else:
  32. raise primitive_functions.PrimitiveFinished(instruction)
  33. def optimize_tree_ir(instruction):
  34. """Optimizes an IR tree."""
  35. return map_and_simplify_generator(expand_constant_read, instruction)
  36. def create_bare_function(function_name, parameter_list, function_body):
  37. """Creates a function definition from the given function name, parameter list
  38. and function body. No prolog is included."""
  39. # Wrap the IR in a function definition, give it a unique name.
  40. return tree_ir.DefineFunctionInstruction(
  41. function_name,
  42. parameter_list + ['**' + jit_runtime.KWARGS_PARAMETER_NAME],
  43. function_body)
  44. def create_function(
  45. function_name, parameter_list, param_dict,
  46. body_param_dict, function_body, source_map_name=None,
  47. compatible_temporary_protects=False):
  48. """Creates a function from the given function name, parameter list,
  49. variable-to-parameter name map, variable-to-local name map and
  50. function body. An optional source map can be included, too."""
  51. # Write a prologue and prepend it to the generated function body.
  52. prolog_statements = []
  53. # If the source map is not None, then we should generate a "DEBUG_INFO"
  54. # request.
  55. if source_map_name is not None:
  56. prolog_statements.append(
  57. tree_ir.RegisterDebugInfoInstruction(
  58. tree_ir.LiteralInstruction(function_name),
  59. tree_ir.LoadGlobalInstruction(source_map_name),
  60. tree_ir.LiteralInstruction(jit_runtime.BASELINE_JIT_ORIGIN_NAME)))
  61. # Create a LOCALS_NODE_NAME node, and connect it to the user root.
  62. prolog_statements.append(
  63. tree_ir.create_new_local_node(
  64. jit_runtime.LOCALS_NODE_NAME,
  65. tree_ir.LoadIndexInstruction(
  66. tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
  67. tree_ir.LiteralInstruction('task_root')),
  68. jit_runtime.LOCALS_EDGE_NAME))
  69. for (key, val) in list(param_dict.items()):
  70. arg_ptr = tree_ir.create_new_local_node(
  71. body_param_dict[key],
  72. tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME))
  73. prolog_statements.append(arg_ptr)
  74. prolog_statements.append(
  75. tree_ir.CreateDictionaryEdgeInstruction(
  76. tree_ir.LoadLocalInstruction(body_param_dict[key]),
  77. tree_ir.LiteralInstruction('value'),
  78. tree_ir.LoadLocalInstruction(val)))
  79. constructed_body = tree_ir.create_block(
  80. *(prolog_statements + [function_body]))
  81. # Shield temporaries from the GC.
  82. constructed_body = tree_ir.protect_temporaries_from_gc(
  83. constructed_body,
  84. tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME),
  85. compatible_temporary_protects)
  86. return create_bare_function(function_name, parameter_list, constructed_body)
  87. def print_value(val):
  88. """A thin wrapper around 'print'."""
  89. print(val)
  90. class ModelverseJit(object):
  91. """A high-level interface to the modelverse JIT compiler."""
  92. def __init__(self, max_instructions=None, compiled_function_lookup=None):
  93. self.todo_entry_points = set()
  94. self.no_jit_entry_points = set()
  95. self.jitted_parameters = {}
  96. self.jit_globals = {
  97. 'PrimitiveFinished' : primitive_functions.PrimitiveFinished,
  98. }
  99. # jitted_entry_points maps body ids to values in jit_globals.
  100. self.jitted_entry_points = {}
  101. # global_functions maps global value names to body ids.
  102. self.global_functions = {}
  103. # global_functions_inv maps body ids to global value names.
  104. self.global_functions_inv = {}
  105. # bytecode_graphs maps body ids to their parsed bytecode graphs.
  106. self.bytecode_graphs = {}
  107. # jitted_function_aliases maps body ids to known aliases.
  108. self.jitted_function_aliases = defaultdict(set)
  109. self.jit_count = 0
  110. self.max_instructions = max_instructions
  111. self.compiled_function_lookup = compiled_function_lookup
  112. self.compilation_dependencies = {}
  113. self.jit_enabled = True
  114. self.direct_calls_allowed = True
  115. self.tracing_enabled = False
  116. self.source_maps_enabled = True
  117. self.input_function_enabled = False
  118. self.nop_insertion_enabled = True
  119. self.jit_success_log_function = None
  120. self.jit_code_log_function = None
  121. def set_jit_enabled(self, is_enabled=True):
  122. """Enables or disables the JIT."""
  123. self.jit_enabled = is_enabled
  124. def allow_direct_calls(self, is_allowed=True):
  125. """Allows or disallows direct calls from jitted to jitted code."""
  126. self.direct_calls_allowed = is_allowed
  127. def use_input_function(self, is_enabled=True):
  128. """Configures the JIT to compile 'input' instructions as function calls."""
  129. self.input_function_enabled = is_enabled
  130. def enable_tracing(self, is_enabled=True):
  131. """Enables or disables tracing for jitted code."""
  132. self.tracing_enabled = is_enabled
  133. def enable_source_maps(self, is_enabled=True):
  134. """Enables or disables the creation of source maps for jitted code. Source maps
  135. convert lines in the generated code to debug information.
  136. Source maps are enabled by default."""
  137. self.source_maps_enabled = is_enabled
  138. def enable_nop_insertion(self, is_enabled=True):
  139. """Enables or disables nop insertion for jitted code. If enabled, the JIT will
  140. insert nops at loop back-edges. Inserting nops sacrifices performance to
  141. keep the jitted code from blocking the thread of execution and consuming
  142. all resources; nops give the Modelverse server an opportunity to interrupt
  143. the currently running code."""
  144. self.nop_insertion_enabled = is_enabled
  145. def set_jit_success_log(self, log_function=print_value):
  146. """Configures this JIT instance with a function that prints output to a log.
  147. Success and failure messages for specific functions are then sent to said log."""
  148. self.jit_success_log_function = log_function
  149. def set_jit_code_log(self, log_function=print_value):
  150. """Configures this JIT instance with a function that prints output to a log.
  151. Function definitions of jitted functions are then sent to said log."""
  152. self.jit_code_log_function = log_function
  153. def set_function_body_compiler(self, compile_function_body):
  154. """Sets the function that the JIT uses to compile function bodies."""
  155. self.compile_function_body = compile_function_body
  156. def mark_entry_point(self, body_id):
  157. """Marks the node with the given identifier as a function entry point."""
  158. if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
  159. self.todo_entry_points.add(body_id)
  160. def is_entry_point(self, body_id):
  161. """Tells if the node with the given identifier is a function entry point."""
  162. return body_id in self.todo_entry_points or \
  163. body_id in self.no_jit_entry_points or \
  164. body_id in self.jitted_entry_points
  165. def is_jittable_entry_point(self, body_id):
  166. """Tells if the node with the given identifier is a function entry point that
  167. has not been marked as non-jittable. This only returns `True` if the JIT
  168. is enabled and the function entry point has been marked jittable, or if
  169. the function has already been compiled."""
  170. return ((self.jit_enabled and body_id in self.todo_entry_points) or
  171. self.has_compiled(body_id))
  172. def has_compiled(self, body_id):
  173. """Tests if the function belonging to the given body node has been compiled yet."""
  174. return body_id in self.jitted_entry_points
  175. def get_compiled_name(self, body_id):
  176. """Gets the name of the compiled version of the given body node in the JIT
  177. global state."""
  178. if body_id in self.jitted_entry_points:
  179. return self.jitted_entry_points[body_id]
  180. else:
  181. return None
  182. def mark_no_jit(self, body_id):
  183. """Informs the JIT that the node with the given identifier is a function entry
  184. point that must never be jitted."""
  185. self.no_jit_entry_points.add(body_id)
  186. if body_id in self.todo_entry_points:
  187. self.todo_entry_points.remove(body_id)
  188. def generate_name(self, infix, suggested_name=None):
  189. """Generates a new name or picks the suggested name if it is still
  190. available."""
  191. if suggested_name is not None \
  192. and suggested_name not in self.jit_globals \
  193. and not keyword.iskeyword(suggested_name):
  194. self.jit_count += 1
  195. return suggested_name
  196. else:
  197. function_name = 'jit_%s%d' % (infix, self.jit_count)
  198. self.jit_count += 1
  199. return function_name
  200. def generate_function_name(self, body_id, suggested_name=None):
  201. """Generates a new function name or picks the suggested name if it is still
  202. available."""
  203. if suggested_name is None:
  204. suggested_name = self.get_global_name(body_id)
  205. return self.generate_name('func', suggested_name)
  206. def register_global(self, body_id, global_name):
  207. """Associates the given body id with the given global name."""
  208. self.global_functions[global_name] = body_id
  209. self.global_functions_inv[body_id] = global_name
  210. def get_global_name(self, body_id):
  211. """Gets the name of the global function with the given body id.
  212. Returns None if no known global exists with the given id."""
  213. if body_id in self.global_functions_inv:
  214. return self.global_functions_inv[body_id]
  215. else:
  216. return None
  217. def get_global_body_id(self, global_name):
  218. """Gets the body id of the global function with the given name.
  219. Returns None if no known global exists with the given name."""
  220. if global_name in self.global_functions:
  221. return self.global_functions[global_name]
  222. else:
  223. return None
  224. def register_compiled(self, body_id, compiled_function, function_name=None):
  225. """Registers a compiled entry point with the JIT."""
  226. # Get the function's name.
  227. actual_function_name = self.generate_function_name(body_id,
  228. function_name)
  229. # Map the body id to the given parameter list.
  230. self.jitted_entry_points[body_id] = actual_function_name
  231. self.jit_globals[actual_function_name] = compiled_function
  232. if function_name is not None:
  233. self.register_global(body_id, function_name)
  234. if body_id in self.todo_entry_points:
  235. self.todo_entry_points.remove(body_id)
  236. def import_value(self, value, suggested_name=None):
  237. """Imports the given value into the JIT's global scope, with the given suggested name.
  238. The actual name of the value (within the JIT's global scope) is returned."""
  239. actual_name = self.generate_name('import', suggested_name)
  240. self.jit_globals[actual_name] = value
  241. return actual_name
  242. def __lookup_compiled_body_impl(self, body_id):
  243. """Looks up a compiled function by body id. Returns a matching function,
  244. or None if no function was found."""
  245. if body_id is not None and body_id in self.jitted_entry_points:
  246. return self.jit_globals[self.jitted_entry_points[body_id]]
  247. else:
  248. return None
  249. def __lookup_external_body_impl(self, global_name, body_id):
  250. """Looks up an external function by global name. Returns a matching function,
  251. or None if no function was found."""
  252. if global_name is not None and self.compiled_function_lookup is not None:
  253. result = self.compiled_function_lookup(global_name)
  254. if result is not None and body_id is not None:
  255. self.register_compiled(body_id, result, global_name)
  256. return result
  257. else:
  258. return None
  259. def lookup_compiled_body(self, body_id):
  260. """Looks up a compiled function by body id. Returns a matching function,
  261. or None if no function was found."""
  262. result = self.__lookup_compiled_body_impl(body_id)
  263. if result is not None:
  264. return result
  265. else:
  266. global_name = self.get_global_name(body_id)
  267. return self.__lookup_external_body_impl(global_name, body_id)
  268. def lookup_compiled_function(self, global_name):
  269. """Looks up a compiled function by global name. Returns a matching function,
  270. or None if no function was found."""
  271. body_id = self.get_global_body_id(global_name)
  272. result = self.__lookup_compiled_body_impl(body_id)
  273. if result is not None:
  274. return result
  275. else:
  276. return self.__lookup_external_body_impl(global_name, body_id)
  277. def jit_signature(self, body_id):
  278. """Acquires the signature for the given body id node, which consists of the
  279. parameter variables, parameter name and a flag that tells if the given function
  280. is mutable."""
  281. if body_id not in self.jitted_parameters:
  282. signature_id, = yield [("RRD", [body_id, jit_runtime.FUNCTION_BODY_KEY])]
  283. signature_id = signature_id[0]
  284. param_set_id, is_mutable = yield [
  285. ("RD", [signature_id, "params"]),
  286. ("RD", [signature_id, jit_runtime.MUTABLE_FUNCTION_KEY])]
  287. if param_set_id is None:
  288. self.jitted_parameters[body_id] = ([], [], is_mutable)
  289. else:
  290. param_name_ids, = yield [("RDK", [param_set_id])]
  291. param_names = yield [("RV", [n]) for n in param_name_ids]
  292. #NOTE Patch up strange links...
  293. param_names = [i for i in param_names if i is not None]
  294. param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
  295. #NOTE that variables might not be in the correct order, as we just read them out!
  296. lst = sorted([(name, var) for name, var in zip(param_names, param_vars)])
  297. param_vars = [i[1] for i in lst]
  298. param_names = [i[0] for i in lst]
  299. self.jitted_parameters[body_id] = (param_vars, param_names, is_mutable)
  300. raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])
  301. def jit_parse_bytecode(self, body_id):
  302. """Parses the given function body as a bytecode graph."""
  303. if body_id in self.bytecode_graphs:
  304. raise primitive_functions.PrimitiveFinished(self.bytecode_graphs[body_id])
  305. parser = bytecode_parser.BytecodeParser()
  306. result, = yield [("CALL_ARGS", [parser.parse_instruction, (body_id,)])]
  307. self.bytecode_graphs[body_id] = result
  308. raise primitive_functions.PrimitiveFinished(result)
  309. def check_jittable(self, body_id, suggested_name=None):
  310. """Checks if the function with the given body id is obviously non-jittable. If it's
  311. non-jittable, then a `JitCompilationFailedException` exception is thrown."""
  312. if body_id is None:
  313. raise ValueError('body_id cannot be None: ' + suggested_name)
  314. elif body_id in self.no_jit_entry_points:
  315. # We're not allowed to jit this function or have tried and failed before.
  316. raise JitCompilationFailedException(
  317. 'Cannot jit function %s at %d because it is marked non-jittable.' % (
  318. '' if suggested_name is None else "'" + suggested_name + "'",
  319. body_id))
  320. elif not self.jit_enabled:
  321. # We're not allowed to jit anything.
  322. raise JitCompilationFailedException(
  323. 'Cannot jit function %s at %d because the JIT has been disabled.' % (
  324. '' if suggested_name is None else "'" + suggested_name + "'",
  325. body_id))
  326. def jit_recompile(self, task_root, body_id, function_name, compile_function_body=None):
  327. """Replaces the function with the given name by compiling the bytecode at the given
  328. body id."""
  329. if compile_function_body is None:
  330. compile_function_body = self.compile_function_body
  331. self.check_jittable(body_id, function_name)
  332. # Generate a name for the function we're about to analyze, and pretend that
  333. # it already exists. (we need to do this for recursive functions)
  334. self.jitted_entry_points[body_id] = function_name
  335. self.jit_globals[function_name] = None
  336. (_, _, is_mutable), = yield [
  337. ("CALL_ARGS", [self.jit_signature, (body_id,)])]
  338. dependencies = set([body_id])
  339. self.compilation_dependencies[body_id] = dependencies
  340. def handle_jit_exception(exception):
  341. # If analysis fails, then a JitCompilationFailedException will be thrown.
  342. print("EXCEPTION with mutable")
  343. del self.compilation_dependencies[body_id]
  344. for dep in dependencies:
  345. self.mark_no_jit(dep)
  346. if dep in self.jitted_entry_points:
  347. del self.jitted_entry_points[dep]
  348. failure_message = "%s (function '%s' at %d)" % (
  349. str(exception), function_name, body_id)
  350. if self.jit_success_log_function is not None:
  351. self.jit_success_log_function('JIT compilation failed: %s' % failure_message)
  352. raise JitCompilationFailedException(failure_message)
  353. # Try to analyze the function's body.
  354. yield [("TRY", [])]
  355. yield [("CATCH", [JitCompilationFailedException, handle_jit_exception])]
  356. if is_mutable:
  357. # We can't just JIT mutable functions. That'd be dangerous.
  358. raise JitCompilationFailedException(
  359. "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
  360. compiled_function, = yield [
  361. ("CALL_ARGS", [compile_function_body, (self, function_name, body_id, task_root)])]
  362. yield [("END_TRY", [])]
  363. del self.compilation_dependencies[body_id]
  364. if self.jit_success_log_function is not None:
  365. assert self.jitted_entry_points[body_id] == function_name
  366. self.jit_success_log_function(
  367. "JIT compilation successful: (function '%s' at %d)" % (function_name, body_id))
  368. raise primitive_functions.PrimitiveFinished(compiled_function)
  369. def get_source_map_name(self, function_name):
  370. """Gets the name of the given jitted function's source map. None is returned if source maps
  371. are disabled."""
  372. if self.source_maps_enabled:
  373. return function_name + "_source_map"
  374. else:
  375. return None
  376. def get_can_rejit_name(self, function_name):
  377. """Gets the name of the given jitted function's can-rejit flag."""
  378. return function_name + "_can_rejit"
  379. def jit_define_function(self, function_name, function_def):
  380. """Converts the given tree-IR function definition to Python code, defines it,
  381. and extracts the resulting function."""
  382. # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
  383. # pylint: disable=I0011,W0122
  384. if self.jit_code_log_function is not None:
  385. self.jit_code_log_function(function_def)
  386. # Convert the function definition to Python code, and compile it.
  387. code_generator = tree_ir.PythonGenerator()
  388. function_def.generate_python_def(code_generator)
  389. source_map_name = self.get_source_map_name(function_name)
  390. if source_map_name is not None:
  391. self.jit_globals[source_map_name] = code_generator.source_map_builder.source_map
  392. exec(str(code_generator), self.jit_globals)
  393. # Extract the compiled function from the JIT global state.
  394. return self.jit_globals[function_name]
  395. def jit_delete_function(self, function_name):
  396. """Deletes the function with the given function name."""
  397. del self.jit_globals[function_name]
  398. def jit_compile(self, task_root, body_id, suggested_name=None):
  399. """Tries to jit the function defined by the given entry point id and parameter list."""
  400. if body_id is None:
  401. raise ValueError('body_id cannot be None: ' + str(suggested_name))
  402. elif body_id in self.jitted_entry_points:
  403. raise primitive_functions.PrimitiveFinished(
  404. self.jit_globals[self.jitted_entry_points[body_id]])
  405. compiled_func = self.lookup_compiled_body(body_id)
  406. if compiled_func is not None:
  407. raise primitive_functions.PrimitiveFinished(compiled_func)
  408. # Generate a name for the function we're about to analyze, and 're-compile'
  409. # it for the first time.
  410. function_name = self.generate_function_name(body_id, suggested_name)
  411. yield [("TAIL_CALL_ARGS", [self.jit_recompile, (task_root, body_id, function_name)])]
  412. def jit_rejit(self, task_root, body_id, function_name, compile_function_body=None):
  413. """Re-compiles the given function. If compilation fails, then the can-rejit
  414. flag is set to false."""
  415. old_jitted_func = self.jitted_entry_points[body_id]
  416. def __handle_jit_failed(_):
  417. self.jit_globals[self.get_can_rejit_name(function_name)] = False
  418. self.jitted_entry_points[body_id] = old_jitted_func
  419. self.no_jit_entry_points.remove(body_id)
  420. raise primitive_functions.PrimitiveFinished(None)
  421. yield [("TRY", [])]
  422. yield [("CATCH", [jit_runtime.JitCompilationFailedException, __handle_jit_failed])]
  423. jitted_function, = yield [
  424. ("CALL_ARGS",
  425. [self.jit_recompile, (task_root, body_id, function_name, compile_function_body)])]
  426. yield [("END_TRY", [])]
  427. # Update all aliases.
  428. for function_alias in self.jitted_function_aliases[body_id]:
  429. self.jit_globals[function_alias] = jitted_function
  430. def new_compile(self, body_id):
  431. print("Compiling body ID " + str(body_id))
  432. raise JitCompilationFailedException("Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
  433. #raise primitive_functions.PrimitiveFinished("pass")