# jit.py
  1. import math
  2. import keyword
  3. import time
  4. from collections import defaultdict
  5. import modelverse_kernel.primitives as primitive_functions
  6. import modelverse_jit.bytecode_parser as bytecode_parser
  7. import modelverse_jit.bytecode_to_tree as bytecode_to_tree
  8. import modelverse_jit.bytecode_to_cfg as bytecode_to_cfg
  9. import modelverse_jit.bytecode_ir as bytecode_ir
  10. import modelverse_jit.bytecode_interpreter as bytecode_interpreter
  11. import modelverse_jit.cfg_optimization as cfg_optimization
  12. import modelverse_jit.cfg_to_tree as cfg_to_tree
  13. import modelverse_jit.cfg_ir as cfg_ir
  14. import modelverse_jit.tree_ir as tree_ir
  15. import modelverse_jit.runtime as jit_runtime
# JitCompilationFailedException used to be defined in this module. It now
# lives in jit_runtime; this module-level alias keeps the old name available
# here.
JitCompilationFailedException = jit_runtime.JitCompilationFailedException
  19. def map_and_simplify_generator(function, instruction):
  20. """Applies the given mapping function to every instruction in the tree
  21. that has the given instruction as root, and simplifies it on-the-fly.
  22. This is at least as powerful as first mapping and then simplifying, as
  23. maps and simplifications are interspersed.
  24. This function assumes that function creates a generator that returns by
  25. raising a primitive_functions.PrimitiveFinished."""
  26. # First handle the children by mapping on them and then simplifying them.
  27. new_children = []
  28. for inst in instruction.get_children():
  29. new_inst, = yield [("CALL_ARGS", [map_and_simplify_generator, (function, inst)])]
  30. new_children.append(new_inst)
  31. # Then apply the function to the top-level node.
  32. transformed, = yield [("CALL_ARGS", [function, (instruction.create(new_children),)])]
  33. # Finally, simplify the transformed top-level node.
  34. raise primitive_functions.PrimitiveFinished(transformed.simplify_node())
  35. def expand_constant_read(instruction):
  36. """Tries to replace a read of a constant node by a literal."""
  37. if isinstance(instruction, tree_ir.ReadValueInstruction) and \
  38. isinstance(instruction.node_id, tree_ir.LiteralInstruction):
  39. val, = yield [("RV", [instruction.node_id.literal])]
  40. raise primitive_functions.PrimitiveFinished(tree_ir.LiteralInstruction(val))
  41. else:
  42. raise primitive_functions.PrimitiveFinished(instruction)
  43. def optimize_tree_ir(instruction):
  44. """Optimizes an IR tree."""
  45. return map_and_simplify_generator(expand_constant_read, instruction)
  46. def create_bare_function(function_name, parameter_list, function_body):
  47. """Creates a function definition from the given function name, parameter list
  48. and function body. No prolog is included."""
  49. # Wrap the IR in a function definition, give it a unique name.
  50. return tree_ir.DefineFunctionInstruction(
  51. function_name,
  52. parameter_list + ['**' + jit_runtime.KWARGS_PARAMETER_NAME],
  53. function_body)
  54. def create_function(
  55. function_name, parameter_list, param_dict,
  56. body_param_dict, function_body, source_map_name=None,
  57. compatible_temporary_protects=False):
  58. """Creates a function from the given function name, parameter list,
  59. variable-to-parameter name map, variable-to-local name map and
  60. function body. An optional source map can be included, too."""
  61. # Write a prologue and prepend it to the generated function body.
  62. prolog_statements = []
  63. # If the source map is not None, then we should generate a "DEBUG_INFO"
  64. # request.
  65. if source_map_name is not None:
  66. prolog_statements.append(
  67. tree_ir.RegisterDebugInfoInstruction(
  68. tree_ir.LiteralInstruction(function_name),
  69. tree_ir.LoadGlobalInstruction(source_map_name),
  70. tree_ir.LiteralInstruction(jit_runtime.BASELINE_JIT_ORIGIN_NAME)))
  71. # Create a LOCALS_NODE_NAME node, and connect it to the user root.
  72. prolog_statements.append(
  73. tree_ir.create_new_local_node(
  74. jit_runtime.LOCALS_NODE_NAME,
  75. tree_ir.LoadIndexInstruction(
  76. tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
  77. tree_ir.LiteralInstruction('task_root')),
  78. jit_runtime.LOCALS_EDGE_NAME))
  79. for (key, val) in param_dict.items():
  80. arg_ptr = tree_ir.create_new_local_node(
  81. body_param_dict[key],
  82. tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME))
  83. prolog_statements.append(arg_ptr)
  84. prolog_statements.append(
  85. tree_ir.CreateDictionaryEdgeInstruction(
  86. tree_ir.LoadLocalInstruction(body_param_dict[key]),
  87. tree_ir.LiteralInstruction('value'),
  88. tree_ir.LoadLocalInstruction(val)))
  89. constructed_body = tree_ir.create_block(
  90. *(prolog_statements + [function_body]))
  91. # Shield temporaries from the GC.
  92. constructed_body = tree_ir.protect_temporaries_from_gc(
  93. constructed_body,
  94. tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME),
  95. compatible_temporary_protects)
  96. return create_bare_function(function_name, parameter_list, constructed_body)
def print_value(val):
    """A thin wrapper around 'print'.

    This is the default log function of the JIT's `set_jit_*_log` methods."""
    print(val)
  100. class ModelverseJit(object):
  101. """A high-level interface to the modelverse JIT compiler."""
  102. def __init__(self, max_instructions=None, compiled_function_lookup=None):
  103. self.todo_entry_points = set()
  104. self.no_jit_entry_points = set()
  105. self.jitted_parameters = {}
  106. self.jit_globals = {
  107. 'PrimitiveFinished' : primitive_functions.PrimitiveFinished,
  108. jit_runtime.CALL_FUNCTION_NAME : jit_runtime.call_function,
  109. jit_runtime.GET_INPUT_FUNCTION_NAME : jit_runtime.get_input,
  110. jit_runtime.JIT_THUNK_CONSTANT_FUNCTION_NAME : self.jit_thunk_constant_function,
  111. jit_runtime.JIT_THUNK_GLOBAL_FUNCTION_NAME : self.jit_thunk_global,
  112. jit_runtime.JIT_REJIT_FUNCTION_NAME : self.jit_rejit,
  113. jit_runtime.JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME : compile_function_body_fast,
  114. jit_runtime.UNREACHABLE_FUNCTION_NAME : jit_runtime.unreachable
  115. }
  116. # jitted_entry_points maps body ids to values in jit_globals.
  117. self.jitted_entry_points = {}
  118. # global_functions maps global value names to body ids.
  119. self.global_functions = {}
  120. # global_functions_inv maps body ids to global value names.
  121. self.global_functions_inv = {}
  122. # bytecode_graphs maps body ids to their parsed bytecode graphs.
  123. self.bytecode_graphs = {}
  124. # jitted_function_aliases maps body ids to known aliases.
  125. self.jitted_function_aliases = defaultdict(set)
  126. self.jit_count = 0
  127. self.max_instructions = max_instructions
  128. self.compiled_function_lookup = compiled_function_lookup
  129. # jit_intrinsics is a function name -> intrinsic map.
  130. self.jit_intrinsics = {}
  131. # cfg_jit_intrinsics is a function name -> intrinsic map.
  132. self.cfg_jit_intrinsics = {}
  133. self.compilation_dependencies = {}
  134. self.jit_enabled = True
  135. self.direct_calls_allowed = True
  136. self.tracing_enabled = False
  137. self.source_maps_enabled = True
  138. self.input_function_enabled = False
  139. self.nop_insertion_enabled = True
  140. self.thunks_enabled = True
  141. self.jit_success_log_function = None
  142. self.jit_code_log_function = None
  143. self.jit_timing_log = None
  144. self.compile_function_body = compile_function_body_baseline
  145. def set_jit_enabled(self, is_enabled=True):
  146. """Enables or disables the JIT."""
  147. self.jit_enabled = is_enabled
  148. def allow_direct_calls(self, is_allowed=True):
  149. """Allows or disallows direct calls from jitted to jitted code."""
  150. self.direct_calls_allowed = is_allowed
  151. def use_input_function(self, is_enabled=True):
  152. """Configures the JIT to compile 'input' instructions as function calls."""
  153. self.input_function_enabled = is_enabled
  154. def enable_tracing(self, is_enabled=True):
  155. """Enables or disables tracing for jitted code."""
  156. self.tracing_enabled = is_enabled
  157. def enable_source_maps(self, is_enabled=True):
  158. """Enables or disables the creation of source maps for jitted code. Source maps
  159. convert lines in the generated code to debug information.
  160. Source maps are enabled by default."""
  161. self.source_maps_enabled = is_enabled
  162. def enable_nop_insertion(self, is_enabled=True):
  163. """Enables or disables nop insertion for jitted code. If enabled, the JIT will
  164. insert nops at loop back-edges. Inserting nops sacrifices performance to
  165. keep the jitted code from blocking the thread of execution and consuming
  166. all resources; nops give the Modelverse server an opportunity to interrupt
  167. the currently running code."""
  168. self.nop_insertion_enabled = is_enabled
  169. def enable_thunks(self, is_enabled=True):
  170. """Enables or disables thunks for jitted code. Thunks delay the compilation of
  171. functions until they are actually used. Thunks generally reduce start-up
  172. time.
  173. Thunks are enabled by default."""
  174. self.thunks_enabled = is_enabled
  175. def set_jit_success_log(self, log_function=print_value):
  176. """Configures this JIT instance with a function that prints output to a log.
  177. Success and failure messages for specific functions are then sent to said log."""
  178. self.jit_success_log_function = log_function
  179. def set_jit_code_log(self, log_function=print_value):
  180. """Configures this JIT instance with a function that prints output to a log.
  181. Function definitions of jitted functions are then sent to said log."""
  182. self.jit_code_log_function = log_function
  183. def set_function_body_compiler(self, compile_function_body):
  184. """Sets the function that the JIT uses to compile function bodies."""
  185. self.compile_function_body = compile_function_body
  186. def set_jit_timing_log(self, log_function=print_value):
  187. """Configures this JIT instance with a function that prints output to a log.
  188. The time it takes to compile functions is then sent to this log."""
  189. self.jit_timing_log = log_function
  190. def mark_entry_point(self, body_id):
  191. """Marks the node with the given identifier as a function entry point."""
  192. if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
  193. self.todo_entry_points.add(body_id)
  194. def is_entry_point(self, body_id):
  195. """Tells if the node with the given identifier is a function entry point."""
  196. return body_id in self.todo_entry_points or \
  197. body_id in self.no_jit_entry_points or \
  198. body_id in self.jitted_entry_points
  199. def is_jittable_entry_point(self, body_id):
  200. """Tells if the node with the given identifier is a function entry point that
  201. has not been marked as non-jittable. This only returns `True` if the JIT
  202. is enabled and the function entry point has been marked jittable, or if
  203. the function has already been compiled."""
  204. return ((self.jit_enabled and body_id in self.todo_entry_points) or
  205. self.has_compiled(body_id))
  206. def has_compiled(self, body_id):
  207. """Tests if the function belonging to the given body node has been compiled yet."""
  208. return body_id in self.jitted_entry_points
  209. def get_compiled_name(self, body_id):
  210. """Gets the name of the compiled version of the given body node in the JIT
  211. global state."""
  212. if body_id in self.jitted_entry_points:
  213. return self.jitted_entry_points[body_id]
  214. else:
  215. return None
  216. def mark_no_jit(self, body_id):
  217. """Informs the JIT that the node with the given identifier is a function entry
  218. point that must never be jitted."""
  219. self.no_jit_entry_points.add(body_id)
  220. if body_id in self.todo_entry_points:
  221. self.todo_entry_points.remove(body_id)
  222. def generate_name(self, infix, suggested_name=None):
  223. """Generates a new name or picks the suggested name if it is still
  224. available."""
  225. if suggested_name is not None \
  226. and suggested_name not in self.jit_globals \
  227. and not keyword.iskeyword(suggested_name):
  228. self.jit_count += 1
  229. return suggested_name
  230. else:
  231. function_name = 'jit_%s%d' % (infix, self.jit_count)
  232. self.jit_count += 1
  233. return function_name
  234. def generate_function_name(self, body_id, suggested_name=None):
  235. """Generates a new function name or picks the suggested name if it is still
  236. available."""
  237. if suggested_name is None:
  238. suggested_name = self.get_global_name(body_id)
  239. return self.generate_name('func', suggested_name)
  240. def register_global(self, body_id, global_name):
  241. """Associates the given body id with the given global name."""
  242. self.global_functions[global_name] = body_id
  243. self.global_functions_inv[body_id] = global_name
  244. def get_global_name(self, body_id):
  245. """Gets the name of the global function with the given body id.
  246. Returns None if no known global exists with the given id."""
  247. if body_id in self.global_functions_inv:
  248. return self.global_functions_inv[body_id]
  249. else:
  250. return None
  251. def get_global_body_id(self, global_name):
  252. """Gets the body id of the global function with the given name.
  253. Returns None if no known global exists with the given name."""
  254. if global_name in self.global_functions:
  255. return self.global_functions[global_name]
  256. else:
  257. return None
  258. def register_compiled(self, body_id, compiled_function, function_name=None):
  259. """Registers a compiled entry point with the JIT."""
  260. # Get the function's name.
  261. actual_function_name = self.generate_function_name(body_id, function_name)
  262. # Map the body id to the given parameter list.
  263. self.jitted_entry_points[body_id] = actual_function_name
  264. self.jit_globals[actual_function_name] = compiled_function
  265. if function_name is not None:
  266. self.register_global(body_id, function_name)
  267. if body_id in self.todo_entry_points:
  268. self.todo_entry_points.remove(body_id)
  269. def import_value(self, value, suggested_name=None):
  270. """Imports the given value into the JIT's global scope, with the given suggested name.
  271. The actual name of the value (within the JIT's global scope) is returned."""
  272. actual_name = self.generate_name('import', suggested_name)
  273. self.jit_globals[actual_name] = value
  274. return actual_name
  275. def __lookup_compiled_body_impl(self, body_id):
  276. """Looks up a compiled function by body id. Returns a matching function,
  277. or None if no function was found."""
  278. if body_id is not None and body_id in self.jitted_entry_points:
  279. return self.jit_globals[self.jitted_entry_points[body_id]]
  280. else:
  281. return None
  282. def __lookup_external_body_impl(self, global_name, body_id):
  283. """Looks up an external function by global name. Returns a matching function,
  284. or None if no function was found."""
  285. if global_name is not None and self.compiled_function_lookup is not None:
  286. result = self.compiled_function_lookup(global_name)
  287. if result is not None and body_id is not None:
  288. self.register_compiled(body_id, result, global_name)
  289. return result
  290. else:
  291. return None
  292. def lookup_compiled_body(self, body_id):
  293. """Looks up a compiled function by body id. Returns a matching function,
  294. or None if no function was found."""
  295. result = self.__lookup_compiled_body_impl(body_id)
  296. if result is not None:
  297. return result
  298. else:
  299. global_name = self.get_global_name(body_id)
  300. return self.__lookup_external_body_impl(global_name, body_id)
  301. def lookup_compiled_function(self, global_name):
  302. """Looks up a compiled function by global name. Returns a matching function,
  303. or None if no function was found."""
  304. body_id = self.get_global_body_id(global_name)
  305. result = self.__lookup_compiled_body_impl(body_id)
  306. if result is not None:
  307. return result
  308. else:
  309. return self.__lookup_external_body_impl(global_name, body_id)
  310. def get_intrinsic(self, name):
  311. """Tries to find an intrinsic version of the function with the
  312. given name."""
  313. if name in self.jit_intrinsics:
  314. return self.jit_intrinsics[name]
  315. else:
  316. return None
  317. def get_cfg_intrinsic(self, name):
  318. """Tries to find an intrinsic version of the function with the
  319. given name that is specialized for CFGs."""
  320. if name in self.cfg_jit_intrinsics:
  321. return self.cfg_jit_intrinsics[name]
  322. else:
  323. return None
  324. def register_intrinsic(self, name, intrinsic_function, cfg_intrinsic_function=None):
  325. """Registers the given intrisic with the JIT. This will make the JIT replace calls to
  326. the function with the given entry point by an application of the specified function."""
  327. self.jit_intrinsics[name] = intrinsic_function
  328. if cfg_intrinsic_function is not None:
  329. self.register_cfg_intrinsic(name, cfg_intrinsic_function)
  330. def register_cfg_intrinsic(self, name, cfg_intrinsic_function):
  331. """Registers the given intrisic with the JIT. This will make the JIT replace calls to
  332. the function with the given entry point by an application of the specified function."""
  333. self.cfg_jit_intrinsics[name] = cfg_intrinsic_function
  334. def register_binary_intrinsic(self, name, operator):
  335. """Registers an intrinsic with the JIT that represents the given binary operation."""
  336. self.register_intrinsic(
  337. name,
  338. lambda a, b:
  339. tree_ir.CreateNodeWithValueInstruction(
  340. tree_ir.BinaryInstruction(
  341. tree_ir.ReadValueInstruction(a),
  342. operator,
  343. tree_ir.ReadValueInstruction(b))),
  344. lambda original_def, a, b:
  345. original_def.redefine(
  346. cfg_ir.CreateNode(
  347. original_def.insert_before(
  348. cfg_ir.Binary(
  349. original_def.insert_before(cfg_ir.Read(a)),
  350. operator,
  351. original_def.insert_before(cfg_ir.Read(b)))))))
  352. def register_unary_intrinsic(self, name, operator):
  353. """Registers an intrinsic with the JIT that represents the given unary operation."""
  354. self.register_intrinsic(
  355. name,
  356. lambda a:
  357. tree_ir.CreateNodeWithValueInstruction(
  358. tree_ir.UnaryInstruction(
  359. operator,
  360. tree_ir.ReadValueInstruction(a))),
  361. lambda original_def, a:
  362. original_def.redefine(
  363. cfg_ir.CreateNode(
  364. original_def.insert_before(
  365. cfg_ir.Unary(
  366. operator,
  367. original_def.insert_before(cfg_ir.Read(a)))))))
  368. def register_cast_intrinsic(self, name, target_type):
  369. """Registers an intrinsic with the JIT that represents a unary conversion operator."""
  370. self.register_intrinsic(
  371. name,
  372. lambda a:
  373. tree_ir.CreateNodeWithValueInstruction(
  374. tree_ir.CallInstruction(
  375. tree_ir.LoadGlobalInstruction(target_type.__name__),
  376. [tree_ir.ReadValueInstruction(a)])),
  377. lambda original_def, a:
  378. original_def.redefine(
  379. cfg_ir.CreateNode(
  380. original_def.insert_before(
  381. cfg_ir.create_pure_simple_call(
  382. target_type.__name__,
  383. original_def.insert_before(cfg_ir.Read(a)))))))
  384. def jit_signature(self, body_id):
  385. """Acquires the signature for the given body id node, which consists of the
  386. parameter variables, parameter name and a flag that tells if the given function
  387. is mutable."""
  388. if body_id not in self.jitted_parameters:
  389. signature_id, = yield [("RRD", [body_id, jit_runtime.FUNCTION_BODY_KEY])]
  390. signature_id = signature_id[0]
  391. param_set_id, is_mutable = yield [
  392. ("RD", [signature_id, "params"]),
  393. ("RD", [signature_id, jit_runtime.MUTABLE_FUNCTION_KEY])]
  394. if param_set_id is None:
  395. self.jitted_parameters[body_id] = ([], [], is_mutable)
  396. else:
  397. param_name_ids, = yield [("RDK", [param_set_id])]
  398. param_names = yield [("RV", [n]) for n in param_name_ids]
  399. param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
  400. self.jitted_parameters[body_id] = (param_vars, param_names, is_mutable)
  401. raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])
  402. def jit_parse_bytecode(self, body_id):
  403. """Parses the given function body as a bytecode graph."""
  404. if body_id in self.bytecode_graphs:
  405. raise primitive_functions.PrimitiveFinished(self.bytecode_graphs[body_id])
  406. parser = bytecode_parser.BytecodeParser()
  407. result, = yield [("CALL_ARGS", [parser.parse_instruction, (body_id,)])]
  408. self.bytecode_graphs[body_id] = result
  409. raise primitive_functions.PrimitiveFinished(result)
  410. def check_jittable(self, body_id, suggested_name=None):
  411. """Checks if the function with the given body id is obviously non-jittable. If it's
  412. non-jittable, then a `JitCompilationFailedException` exception is thrown."""
  413. if body_id is None:
  414. raise ValueError('body_id cannot be None')
  415. elif body_id in self.no_jit_entry_points:
  416. # We're not allowed to jit this function or have tried and failed before.
  417. raise JitCompilationFailedException(
  418. 'Cannot jit function %s at %d because it is marked non-jittable.' % (
  419. '' if suggested_name is None else "'" + suggested_name + "'",
  420. body_id))
  421. elif not self.jit_enabled:
  422. # We're not allowed to jit anything.
  423. raise JitCompilationFailedException(
  424. 'Cannot jit function %s at %d because the JIT has been disabled.' % (
  425. '' if suggested_name is None else "'" + suggested_name + "'",
  426. body_id))
  427. def jit_recompile(self, task_root, body_id, function_name, compile_function_body=None):
  428. """Replaces the function with the given name by compiling the bytecode at the given
  429. body id."""
  430. if self.jit_timing_log is not None:
  431. start_time = time.time()
  432. if compile_function_body is None:
  433. compile_function_body = self.compile_function_body
  434. self.check_jittable(body_id, function_name)
  435. # Generate a name for the function we're about to analyze, and pretend that
  436. # it already exists. (we need to do this for recursive functions)
  437. self.jitted_entry_points[body_id] = function_name
  438. self.jit_globals[function_name] = None
  439. (_, _, is_mutable), = yield [
  440. ("CALL_ARGS", [self.jit_signature, (body_id,)])]
  441. dependencies = set([body_id])
  442. self.compilation_dependencies[body_id] = dependencies
  443. def handle_jit_exception(exception):
  444. # If analysis fails, then a JitCompilationFailedException will be thrown.
  445. del self.compilation_dependencies[body_id]
  446. for dep in dependencies:
  447. self.mark_no_jit(dep)
  448. if dep in self.jitted_entry_points:
  449. del self.jitted_entry_points[dep]
  450. failure_message = "%s (function '%s' at %d)" % (
  451. exception.message, function_name, body_id)
  452. if self.jit_success_log_function is not None:
  453. self.jit_success_log_function('JIT compilation failed: %s' % failure_message)
  454. raise JitCompilationFailedException(failure_message)
  455. # Try to analyze the function's body.
  456. yield [("TRY", [])]
  457. yield [("CATCH", [JitCompilationFailedException, handle_jit_exception])]
  458. if is_mutable:
  459. # We can't just JIT mutable functions. That'd be dangerous.
  460. raise JitCompilationFailedException(
  461. "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
  462. compiled_function, = yield [
  463. ("CALL_ARGS", [compile_function_body, (self, function_name, body_id, task_root)])]
  464. yield [("END_TRY", [])]
  465. del self.compilation_dependencies[body_id]
  466. if self.jit_success_log_function is not None:
  467. assert self.jitted_entry_points[body_id] == function_name
  468. self.jit_success_log_function(
  469. "JIT compilation successful: (function '%s' at %d)" % (function_name, body_id))
  470. if self.jit_timing_log is not None:
  471. end_time = time.time()
  472. compile_time = end_time - start_time
  473. self.jit_timing_log('Compile time for %s:%f' % (function_name, compile_time))
  474. raise primitive_functions.PrimitiveFinished(compiled_function)
  475. def get_source_map_name(self, function_name):
  476. """Gets the name of the given jitted function's source map. None is returned if source maps
  477. are disabled."""
  478. if self.source_maps_enabled:
  479. return function_name + "_source_map"
  480. else:
  481. return None
  482. def get_can_rejit_name(self, function_name):
  483. """Gets the name of the given jitted function's can-rejit flag."""
  484. return function_name + "_can_rejit"
    def jit_define_function(self, function_name, function_def):
        """Converts the given tree-IR function definition to Python code, defines it,
        and extracts the resulting function.

        The generated source is executed with `jit_globals` as its global
        namespace, and the value that is bound to `function_name` in that
        namespace afterward is returned. If source maps are enabled, the source
        map that was built during code generation is stored in `jit_globals`
        before the code is executed. The definition is sent to the code log,
        if one is configured."""
        # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
        # pylint: disable=I0011,W0122
        if self.jit_code_log_function is not None:
            self.jit_code_log_function(function_def)
        # Convert the function definition to Python code, and compile it.
        code_generator = tree_ir.PythonGenerator()
        function_def.generate_python_def(code_generator)
        source_map_name = self.get_source_map_name(function_name)
        if source_map_name is not None:
            self.jit_globals[source_map_name] = code_generator.source_map_builder.source_map
        # NOTE(review): exec runs whatever the code generator produced; this
        # presumably only ever sees JIT-generated source -- confirm that no
        # untrusted text can reach it.
        exec(str(code_generator), self.jit_globals)
        # Extract the compiled function from the JIT global state.
        return self.jit_globals[function_name]
  501. def jit_delete_function(self, function_name):
  502. """Deletes the function with the given function name."""
  503. del self.jit_globals[function_name]
  504. def jit_compile(self, task_root, body_id, suggested_name=None):
  505. """Tries to jit the function defined by the given entry point id and parameter list."""
  506. if body_id is None:
  507. raise ValueError('body_id cannot be None')
  508. elif body_id in self.jitted_entry_points:
  509. raise primitive_functions.PrimitiveFinished(
  510. self.jit_globals[self.jitted_entry_points[body_id]])
  511. compiled_func = self.lookup_compiled_body(body_id)
  512. if compiled_func is not None:
  513. raise primitive_functions.PrimitiveFinished(compiled_func)
  514. # Generate a name for the function we're about to analyze, and 're-compile'
  515. # it for the first time.
  516. function_name = self.generate_function_name(body_id, suggested_name)
  517. yield [("TAIL_CALL_ARGS", [self.jit_recompile, (task_root, body_id, function_name)])]
  518. def jit_rejit(self, task_root, body_id, function_name, compile_function_body=None):
  519. """Re-compiles the given function. If compilation fails, then the can-rejit
  520. flag is set to false."""
  521. old_jitted_func = self.jitted_entry_points[body_id]
  522. def __handle_jit_failed(_):
  523. self.jit_globals[self.get_can_rejit_name(function_name)] = False
  524. self.jitted_entry_points[body_id] = old_jitted_func
  525. self.no_jit_entry_points.remove(body_id)
  526. raise primitive_functions.PrimitiveFinished(None)
  527. yield [("TRY", [])]
  528. yield [("CATCH", [jit_runtime.JitCompilationFailedException, __handle_jit_failed])]
  529. jitted_function, = yield [
  530. ("CALL_ARGS",
  531. [self.jit_recompile, (task_root, body_id, function_name, compile_function_body)])]
  532. yield [("END_TRY", [])]
  533. # Update all aliases.
  534. for function_alias in self.jitted_function_aliases[body_id]:
  535. self.jit_globals[function_alias] = jitted_function
    def jit_thunk(self, get_function_body, global_name=None):
        """Creates a thunk from the given IR tree that computes the function's body id.
        This thunk is a function that will invoke the function whose body id is retrieved.
        The thunk's name in the JIT's global context is returned.

        `global_name`, if given, is used as the suggested name for the thunk and
        is associated with the body id once the thunk has computed it."""
        # The general idea is to first create a function that looks a bit like this:
        #
        #     def jit_get_function_body(**kwargs):
        #         raise primitive_functions.PrimitiveFinished(<get_function_body>)
        #
        get_function_body_name = self.generate_name('get_function_body')
        get_function_body_func_def = create_function(
            get_function_body_name, [], {}, {}, tree_ir.ReturnInstruction(get_function_body))
        get_function_body_func = self.jit_define_function(
            get_function_body_name, get_function_body_func_def)
        # Next, we want to create a thunk that invokes said function, and then replaces itself.
        thunk_name = self.generate_name('thunk', global_name)
        def __jit_thunk(**kwargs):
            # Compute the body id, and delete the function that computes the body id; we won't
            # be needing it anymore after this call.
            body_id, = yield [("CALL_KWARGS", [get_function_body_func, kwargs])]
            self.jit_delete_function(get_function_body_name)
            # Try to associate the global name with the body id, if that's at all possible.
            if global_name is not None:
                self.register_global(body_id, global_name)
            compiled_function = self.lookup_compiled_body(body_id)
            if compiled_function is not None:
                # Replace this thunk by the compiled function.
                self.jit_globals[thunk_name] = compiled_function
                # Remember the thunk's name as an alias, so that jit_rejit can
                # point it at a re-compiled version of the function later on.
                self.jitted_function_aliases[body_id].add(thunk_name)
            else:
                def __handle_jit_exception(_):
                    # Replace this thunk by a different thunk: one that calls the interpreter
                    # directly, without checking if the function is jittable.
                    (_, parameter_names, _), = yield [
                        ("CALL_ARGS", [self.jit_signature, (body_id,)])]
                    def __interpreter_thunk(**new_kwargs):
                        # Only the function's own parameters are passed as named arguments.
                        named_arg_dict = {name : new_kwargs[name] for name in parameter_names}
                        return jit_runtime.interpret_function_body(
                            body_id, named_arg_dict, **new_kwargs)
                    self.jit_globals[thunk_name] = __interpreter_thunk
                    # NOTE(review): the handler installs the interpreter thunk but does not
                    # invoke it or return a value itself; what the kernel does once a CATCH
                    # handler finishes is not visible here -- confirm.
                yield [("TRY", [])]
                yield [("CATCH", [JitCompilationFailedException, __handle_jit_exception])]
                # Compile the function under the thunk's name, so that the compiled
                # function takes the thunk's place in the JIT's global scope.
                compiled_function, = yield [
                    ("CALL_ARGS",
                     [self.jit_recompile, (kwargs['task_root'], body_id, thunk_name)])]
                yield [("END_TRY", [])]
            # Call the compiled function.
            yield [("TAIL_CALL_KWARGS", [compiled_function, kwargs])]
        self.jit_globals[thunk_name] = __jit_thunk
        return thunk_name
  586. def jit_thunk_constant_body(self, body_id):
  587. """Creates a thunk from the given body id.
  588. This thunk is a function that will invoke the function whose body id is given.
  589. The thunk's name in the JIT's global context is returned."""
  590. self.lookup_compiled_body(body_id)
  591. compiled_name = self.get_compiled_name(body_id)
  592. if compiled_name is not None:
  593. # We might have compiled the function with the given body id already. In that case,
  594. # we need not bother with constructing the thunk; we can return the compiled function
  595. # right away.
  596. return compiled_name
  597. else:
  598. # Looks like we'll just have to build that thunk after all.
  599. return self.jit_thunk(tree_ir.LiteralInstruction(body_id))
  600. def jit_thunk_constant_function(self, body_id):
  601. """Creates a thunk from the given function id.
  602. This thunk is a function that will invoke the function whose function id is given.
  603. The thunk's name in the JIT's global context is returned."""
  604. return self.jit_thunk(
  605. tree_ir.ReadDictionaryValueInstruction(
  606. tree_ir.LiteralInstruction(body_id),
  607. tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)))
  608. def jit_thunk_global(self, global_name):
  609. """Creates a thunk from given global name.
  610. This thunk is a function that will invoke the function whose body id is given.
  611. The thunk's name in the JIT's global context is returned."""
  612. # We might have compiled the function with the given name already. In that case,
  613. # we need not bother with constructing the thunk; we can return the compiled function
  614. # right away.
  615. body_id = self.get_global_body_id(global_name)
  616. if body_id is not None:
  617. self.lookup_compiled_body(body_id)
  618. compiled_name = self.get_compiled_name(body_id)
  619. if compiled_name is not None:
  620. return compiled_name
  621. # Looks like we'll just have to build that thunk after all.
  622. # We want to look up the global function like so
  623. #
  624. # _globals, = yield [("RD", [kwargs['task_root'], "globals"])]
  625. # global_var, = yield [("RD", [_globals, global_name])]
  626. # function_id, = yield [("RD", [global_var, "value"])]
  627. # body_id, = yield [("RD", [function_id, jit_runtime.FUNCTION_BODY_KEY])]
  628. #
  629. return self.jit_thunk(
  630. tree_ir.ReadDictionaryValueInstruction(
  631. tree_ir.ReadDictionaryValueInstruction(
  632. tree_ir.ReadDictionaryValueInstruction(
  633. tree_ir.ReadDictionaryValueInstruction(
  634. tree_ir.LoadIndexInstruction(
  635. tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
  636. tree_ir.LiteralInstruction('task_root')),
  637. tree_ir.LiteralInstruction('globals')),
  638. tree_ir.LiteralInstruction(global_name)),
  639. tree_ir.LiteralInstruction('value')),
  640. tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)),
  641. global_name)
  642. def compile_function_body_interpret(jit, function_name, body_id, task_root, header=None):
  643. """Create a function that invokes the interpreter on the given function."""
  644. (parameter_ids, parameter_list, _), = yield [
  645. ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
  646. param_dict = dict(zip(parameter_ids, parameter_list))
  647. body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
  648. def __interpret_function(**kwargs):
  649. if header is not None:
  650. (done, result), = yield [("CALL_KWARGS", [header, kwargs])]
  651. if done:
  652. raise primitive_functions.PrimitiveFinished(result)
  653. local_args = {}
  654. inner_kwargs = dict(kwargs)
  655. for param_id, name in param_dict.items():
  656. local_args[param_id] = inner_kwargs[name]
  657. del inner_kwargs[name]
  658. yield [("TAIL_CALL_ARGS",
  659. [bytecode_interpreter.interpret_bytecode_function,
  660. (function_name, body_bytecode, local_args, inner_kwargs)])]
  661. jit.jit_globals[function_name] = __interpret_function
  662. raise primitive_functions.PrimitiveFinished(__interpret_function)
  663. def compile_function_body_baseline(
  664. jit, function_name, body_id, task_root,
  665. header=None, compatible_temporary_protects=False):
  666. """Have the baseline JIT compile the function with the given name and body id."""
  667. (parameter_ids, parameter_list, _), = yield [
  668. ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
  669. param_dict = dict(zip(parameter_ids, parameter_list))
  670. body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
  671. body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
  672. state = bytecode_to_tree.AnalysisState(
  673. jit, body_id, task_root, body_param_dict,
  674. jit.max_instructions)
  675. constructed_body, = yield [("CALL_ARGS", [state.analyze, (body_bytecode,)])]
  676. if header is not None:
  677. constructed_body = tree_ir.create_block(header, constructed_body)
  678. # Optimize the function's body.
  679. constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
  680. # Wrap the tree IR in a function definition.
  681. constructed_function = create_function(
  682. function_name, parameter_list, param_dict,
  683. body_param_dict, constructed_body, jit.get_source_map_name(function_name),
  684. compatible_temporary_protects)
  685. # Convert the function definition to Python code, and compile it.
  686. raise primitive_functions.PrimitiveFinished(
  687. jit.jit_define_function(function_name, constructed_function))
  688. def compile_function_body_fast(jit, function_name, body_id, _):
  689. """Have the fast JIT compile the function with the given name and body id."""
  690. (parameter_ids, parameter_list, _), = yield [
  691. ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
  692. param_dict = dict(zip(parameter_ids, parameter_list))
  693. body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
  694. bytecode_analyzer = bytecode_to_cfg.AnalysisState(jit, function_name, param_dict)
  695. bytecode_analyzer.analyze(body_bytecode)
  696. entry_point, = yield [
  697. ("CALL_ARGS", [cfg_optimization.optimize, (bytecode_analyzer.entry_point, jit)])]
  698. if jit.jit_code_log_function is not None:
  699. jit.jit_code_log_function(
  700. "CFG for function '%s' at '%d':\n%s" % (
  701. function_name, body_id,
  702. '\n'.join(map(str, cfg_ir.get_all_reachable_blocks(entry_point)))))
  703. # Lower the CFG to tree IR.
  704. constructed_body = cfg_to_tree.lower_flow_graph(entry_point, jit)
  705. # Optimize the tree that was generated.
  706. constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
  707. constructed_function = create_bare_function(function_name, parameter_list, constructed_body)
  708. # Convert the function definition to Python code, and compile it.
  709. raise primitive_functions.PrimitiveFinished(
  710. jit.jit_define_function(function_name, constructed_function))
  711. def favor_large_functions(body_bytecode):
  712. """Computes the initial temperature of a function based on the size of
  713. its body bytecode. Larger functions are favored and the temperature
  714. is incremented by one on every call."""
  715. # The rationale for this heuristic is that it does some damage control:
  716. # we can afford to decide (wrongly) not to fast-jit a small function,
  717. # because we can just fast-jit that function later on. Since the function
  718. # is so small, it will (hopefully) not be able to deal us a heavy blow in
  719. # terms of performance.
  720. #
  721. # If we decide not to fast-jit a large function however, we might end up
  722. # in a situation where said function runs for a long time before we
  723. # realize that we really should have jitted it. And that's exactly what
  724. # this heuristic tries to avoid.
  725. return len(body_bytecode.get_reachable()), 1
  726. def favor_small_functions(body_bytecode):
  727. """Computes the initial temperature of a function based on the size of
  728. its body bytecode. Smaller functions are favored and the temperature
  729. is incremented by one on every call."""
  730. # The rationale for this heuristic is that small functions are easy to
  731. # fast-jit, because they probably won't trigger the non-linear complexity
  732. # of fast-jit's algorithms. So it might be cheaper to fast-jit small
  733. # functions and get a performance boost from that than to fast-jit large
  734. # functions.
  735. return ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD - len(body_bytecode.get_reachable()), 1
# Weight given to instructions inside loop bodies by the loop-based temperature
# heuristics: `favor_loops` multiplies the size of each loop body by this value, and
# `favor_small_loops` multiplies the square root of that size by this value squared.
ADAPTIVE_JIT_LOOP_INSTRUCTION_MULTIPLIER = 4

ADAPTIVE_BASELINE_JIT_TEMPERATURE_THRESHOLD = 100
"""The threshold temperature at which the adaptive JIT will use the baseline JIT."""

ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD = 250
"""The threshold temperature at which the adaptive JIT will use the fast JIT."""
  741. def favor_loops(body_bytecode):
  742. """Computes the initial temperature of a function. Code within a loop makes
  743. the function hotter; code outside loops makes the function colder. The
  744. temperature is incremented by one on every call."""
  745. reachable_instructions = body_bytecode.get_reachable()
  746. # First set the temperature to the negative number of instructions.
  747. temperature = ADAPTIVE_BASELINE_JIT_TEMPERATURE_THRESHOLD - len(reachable_instructions)
  748. for instruction in reachable_instructions:
  749. if isinstance(instruction, bytecode_ir.WhileInstruction):
  750. # Then increase the temperature by the number of instructions reachable
  751. # from loop bodies. Note that the algorithm will count nested loops twice.
  752. # This is actually by design.
  753. loop_body_instructions = instruction.body.get_reachable(
  754. lambda x: not isinstance(
  755. x, (bytecode_ir.BreakInstruction, bytecode_ir.ContinueInstruction)))
  756. temperature += ADAPTIVE_JIT_LOOP_INSTRUCTION_MULTIPLIER * len(loop_body_instructions)
  757. return temperature, 1
  758. def favor_small_loops(body_bytecode):
  759. """Computes the initial temperature of a function. Code within a loop makes
  760. the function hotter; code outside loops makes the function colder. The
  761. temperature is incremented by one on every call."""
  762. reachable_instructions = body_bytecode.get_reachable()
  763. # First set the temperature to the negative number of instructions.
  764. temperature = ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD - 50 - len(reachable_instructions)
  765. for instruction in reachable_instructions:
  766. if isinstance(instruction, bytecode_ir.WhileInstruction):
  767. # Then increase the temperature by the number of instructions reachable
  768. # from loop bodies. Note that the algorithm will count nested loops twice.
  769. # This is actually by design.
  770. loop_body_instructions = instruction.body.get_reachable(
  771. lambda x: not isinstance(
  772. x, (bytecode_ir.BreakInstruction, bytecode_ir.ContinueInstruction)))
  773. temperature += (
  774. (ADAPTIVE_JIT_LOOP_INSTRUCTION_MULTIPLIER ** 2) *
  775. int(math.sqrt(len(loop_body_instructions))))
  776. return temperature, max(int(math.log(len(reachable_instructions), 2)), 1)
  777. class AdaptiveJitState(object):
  778. """Shared state for adaptive JIT compilation."""
  779. def __init__(
  780. self, temperature_counter_name,
  781. temperature_increment, can_rejit_name):
  782. self.temperature_counter_name = temperature_counter_name
  783. self.temperature_increment = temperature_increment
  784. self.can_rejit_name = can_rejit_name
  785. def compile_interpreter(
  786. self, jit, function_name, body_id, task_root):
  787. """Compiles the given function as a function that controls the temperature counter
  788. and calls the interpreter."""
  789. def __increment_temperature(**kwargs):
  790. if jit.jit_globals[self.can_rejit_name]:
  791. temperature_counter_val = jit.jit_globals[self.temperature_counter_name]
  792. temperature_counter_val += self.temperature_increment
  793. jit.jit_globals[self.temperature_counter_name] = temperature_counter_val
  794. if temperature_counter_val >= ADAPTIVE_BASELINE_JIT_TEMPERATURE_THRESHOLD:
  795. if temperature_counter_val >= ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD:
  796. yield [
  797. ("CALL_ARGS",
  798. [jit.jit_rejit,
  799. (task_root, body_id, function_name, compile_function_body_fast)])]
  800. else:
  801. yield [
  802. ("CALL_ARGS",
  803. [jit.jit_rejit,
  804. (task_root, body_id, function_name, self.compile_baseline)])]
  805. result, = yield [("CALL_KWARGS", [jit.jit_globals[function_name], kwargs])]
  806. raise primitive_functions.PrimitiveFinished((True, result))
  807. raise primitive_functions.PrimitiveFinished((False, None))
  808. yield [
  809. ("TAIL_CALL_ARGS",
  810. [compile_function_body_interpret,
  811. (jit, function_name, body_id, task_root, __increment_temperature)])]
  812. def compile_baseline(
  813. self, jit, function_name, body_id, task_root):
  814. """Compiles the given function with the baseline JIT, and inserts logic that controls
  815. the temperature counter."""
  816. (_, parameter_list, _), = yield [
  817. ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
  818. # This tree represents the following logic:
  819. #
  820. # if can_rejit:
  821. # global temperature_counter
  822. # temperature_counter = temperature_counter + temperature_increment
  823. # if temperature_counter >= ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD:
  824. # yield [("CALL_KWARGS", [jit_runtime.JIT_REJIT_FUNCTION_NAME, {...}])]
  825. # yield [("TAIL_CALL_KWARGS", [function_name, {...}])]
  826. header = tree_ir.SelectInstruction(
  827. tree_ir.LoadGlobalInstruction(self.can_rejit_name),
  828. tree_ir.create_block(
  829. tree_ir.DeclareGlobalInstruction(self.temperature_counter_name),
  830. tree_ir.IgnoreInstruction(
  831. tree_ir.StoreGlobalInstruction(
  832. self.temperature_counter_name,
  833. tree_ir.BinaryInstruction(
  834. tree_ir.LoadGlobalInstruction(self.temperature_counter_name),
  835. '+',
  836. tree_ir.LiteralInstruction(self.temperature_increment)))),
  837. tree_ir.SelectInstruction(
  838. tree_ir.BinaryInstruction(
  839. tree_ir.LoadGlobalInstruction(self.temperature_counter_name),
  840. '>=',
  841. tree_ir.LiteralInstruction(ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD)),
  842. tree_ir.create_block(
  843. tree_ir.RunGeneratorFunctionInstruction(
  844. tree_ir.LoadGlobalInstruction(jit_runtime.JIT_REJIT_FUNCTION_NAME),
  845. tree_ir.DictionaryLiteralInstruction([
  846. (tree_ir.LiteralInstruction('task_root'),
  847. bytecode_to_tree.load_task_root()),
  848. (tree_ir.LiteralInstruction('body_id'),
  849. tree_ir.LiteralInstruction(body_id)),
  850. (tree_ir.LiteralInstruction('function_name'),
  851. tree_ir.LiteralInstruction(function_name)),
  852. (tree_ir.LiteralInstruction('compile_function_body'),
  853. tree_ir.LoadGlobalInstruction(
  854. jit_runtime.JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME))]),
  855. result_type=tree_ir.NO_RESULT_TYPE),
  856. bytecode_to_tree.create_return(
  857. tree_ir.create_jit_call(
  858. tree_ir.LoadGlobalInstruction(function_name),
  859. [(name, tree_ir.LoadLocalInstruction(name))
  860. for name in parameter_list],
  861. tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME)))),
  862. tree_ir.EmptyInstruction())),
  863. tree_ir.EmptyInstruction())
  864. # Compile with the baseline JIT, and insert the header.
  865. yield [
  866. ("TAIL_CALL_ARGS",
  867. [compile_function_body_baseline,
  868. (jit, function_name, body_id, task_root, header, True)])]
  869. def compile_function_body_adaptive(
  870. jit, function_name, body_id, task_root,
  871. temperature_heuristic=favor_loops):
  872. """Compile the function with the given name and body id. An execution engine is picked
  873. automatically, and the function may be compiled again at a later time."""
  874. # The general idea behind this compilation technique is to first use the baseline JIT
  875. # to compile a function, and then switch to the fast JIT when we determine that doing
  876. # so would be a good idea. We maintain a 'temperature' counter, which has an initial value
  877. # and gets incremented every time the function is executed.
  878. body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
  879. initial_temperature, temperature_increment = temperature_heuristic(body_bytecode)
  880. if jit.jit_success_log_function is not None:
  881. jit.jit_success_log_function(
  882. "Initial temperature for '%s': %d" % (function_name, initial_temperature))
  883. if initial_temperature >= ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD:
  884. # Initial temperature exceeds the fast-jit threshold.
  885. # Compile this thing with fast-jit right away.
  886. if jit.jit_success_log_function is not None:
  887. jit.jit_success_log_function(
  888. "Compiling '%s' with fast-jit." % function_name)
  889. yield [
  890. ("TAIL_CALL_ARGS",
  891. [compile_function_body_fast, (jit, function_name, body_id, task_root)])]
  892. temperature_counter_name = jit.import_value(
  893. initial_temperature, function_name + "_temperature_counter")
  894. can_rejit_name = jit.get_can_rejit_name(function_name)
  895. jit.jit_globals[can_rejit_name] = True
  896. state = AdaptiveJitState(temperature_counter_name, temperature_increment, can_rejit_name)
  897. if initial_temperature >= ADAPTIVE_BASELINE_JIT_TEMPERATURE_THRESHOLD:
  898. # Initial temperature exceeds the baseline JIT threshold.
  899. # Compile this thing with baseline JIT right away.
  900. if jit.jit_success_log_function is not None:
  901. jit.jit_success_log_function(
  902. "Compiling '%s' with baseline-jit." % function_name)
  903. yield [
  904. ("TAIL_CALL_ARGS",
  905. [state.compile_baseline, (jit, function_name, body_id, task_root)])]
  906. else:
  907. # Looks like we'll use the interpreter initially.
  908. if jit.jit_success_log_function is not None:
  909. jit.jit_success_log_function(
  910. "Compiling '%s' with bytecode-interpreter." % function_name)
  911. yield [
  912. ("TAIL_CALL_ARGS",
  913. [state.compile_interpreter, (jit, function_name, body_id, task_root)])]