jit.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063
  1. import modelverse_kernel.primitives as primitive_functions
  2. import modelverse_jit.tree_ir as tree_ir
  3. import modelverse_jit.runtime as jit_runtime
  4. import keyword
  # JitCompilationFailedException used to be defined in this module. It now
  # lives in the runtime module; re-exporting it here keeps existing imports of
  # this module's JitCompilationFailedException working.
  JitCompilationFailedException = jit_runtime.JitCompilationFailedException

  KWARGS_PARAMETER_NAME = "kwargs"
  """Name of the ``**kwargs`` parameter that every jitted function takes."""

  CALL_FUNCTION_NAME = "__call_function"
  """Name under which the '__call_function' helper is bound in the jitted
  function scope."""

  GET_INPUT_FUNCTION_NAME = "__get_input"
  """Name under which the '__get_input' helper is bound in the jitted function
  scope."""

  LOCALS_NODE_NAME = "jit_locals"
  """Name of the node to which all JIT locals of a single function call are
  connected."""

  LOCALS_EDGE_NAME = "jit_locals_edge"
  """Name of the edge that attaches the LOCALS_NODE_NAME node to a user root."""
  def get_parameter_names(compiled_function):
      """Gets the given compiled function's parameter names.

      Supported inputs:
        * plain Python functions and lambdas: all positional parameter names;
        * bound methods: the parameter names without the bound receiver
          (callers never pass it explicitly);
        * classes: the parameter names of the class's Python-level __init__,
          without its 'self' parameter.

      Returns a tuple of strings.

      Raises:
          ValueError: if the parameter names cannot be determined, e.g. for a
              builtin function or for a class that has no Python-level
              __init__. (Previously such inputs recursed forever, because every
              object has an '__init__' attribute.)
      """
      import inspect

      if inspect.isclass(compiled_function):
          # Calling a class calls its constructor; the constructor's first
          # parameter ('self') is supplied implicitly.
          constructor = getattr(compiled_function, '__init__', None)
          if constructor is None or not hasattr(constructor, '__code__'):
              raise ValueError(
                  "'compiled_function' must be a function or a type.")
          return get_parameter_names(constructor)[1:]

      code = getattr(compiled_function, '__code__', None)
      if code is None:
          raise ValueError("'compiled_function' must be a function or a type.")

      # co_varnames lists the parameters first, followed by other locals.
      names = code.co_varnames[:code.co_argcount]
      if (inspect.ismethod(compiled_function)
              and getattr(compiled_function, '__self__', None) is not None):
          # A bound method already carries its receiver.
          names = names[1:]
      return names
  def apply_intrinsic(intrinsic_function, named_args):
      """Applies an intrinsic to a sequence of (parameter name, argument) pairs.

      If the argument names appear exactly in the intrinsic's parameter order,
      the arguments are passed straight through as keyword arguments.
      Otherwise each argument is first stored in a fresh local, so the
      arguments are still evaluated in the order given. The intrinsic is then
      applied to loads of those locals, preceded by the block of stores.

      named_args is iterated more than once, so it must be a sequence rather
      than a one-shot iterator.
      """
      expected_order = tuple(get_parameter_names(intrinsic_function))
      given_order = tuple(name for name, _ in named_args)
      if expected_order == given_order:
          # Evaluation order already matches the parameter order.
          return intrinsic_function(**dict(named_args))

      # Preserve evaluation order by spilling every argument into a local.
      temporaries = [
          (name, tree_ir.StoreLocalInstruction(None, arg))
          for name, arg in named_args]
      loads = dict((name, store.create_load()) for name, store in temporaries)
      prologue = tree_ir.create_block(*[store for _, store in temporaries])
      return tree_ir.CompoundInstruction(
          prologue,
          intrinsic_function(**loads))
  def map_and_simplify_generator(function, instruction):
      """Maps `function` over the IR tree rooted at `instruction`, bottom-up,
      simplifying every node right after it has been mapped.

      Interleaving the mapping and simplification steps is at least as
      powerful as mapping the whole tree first and simplifying afterwards.

      This is a generator driven by the Modelverse kernel. `function` must
      itself create such a generator, one that finishes by raising
      primitive_functions.PrimitiveFinished with the mapped node. This
      generator finishes the same way, with the mapped and simplified tree.
      """
      # Recursively map and simplify every child first.
      mapped_children = []
      for child in instruction.get_children():
          mapped_child, = yield [
              ("CALL_ARGS", [map_and_simplify_generator, (function, child)])]
          mapped_children.append(mapped_child)

      # Rebuild this node around its new children and map the result.
      rebuilt = instruction.create(mapped_children)
      mapped_node, = yield [("CALL_ARGS", [function, (rebuilt,)])]

      # Simplify the mapped node before handing it back to the caller.
      raise primitive_functions.PrimitiveFinished(mapped_node.simplify_node())
  def expand_constant_read(instruction):
      """Replaces a read of a constant node by a literal holding its value.

      An instruction qualifies when it is a ReadValueInstruction whose node id
      is a LiteralInstruction. In that case the node's value is read once
      through an "RV" request and embedded as a LiteralInstruction. Every other
      instruction is passed through unchanged.

      Generator; finishes by raising primitive_functions.PrimitiveFinished.
      """
      is_constant_read = (
          isinstance(instruction, tree_ir.ReadValueInstruction)
          and isinstance(instruction.node_id, tree_ir.LiteralInstruction))
      if not is_constant_read:
          raise primitive_functions.PrimitiveFinished(instruction)

      constant_value, = yield [("RV", [instruction.node_id.literal])]
      raise primitive_functions.PrimitiveFinished(
          tree_ir.LiteralInstruction(constant_value))
  def optimize_tree_ir(instruction):
      """Returns a generator that optimizes the IR tree rooted at `instruction`.

      Reads of constant nodes are replaced by literals (expand_constant_read),
      and every node is simplified as the tree is rebuilt bottom-up
      (map_and_simplify_generator).
      """
      return map_and_simplify_generator(
          function=expand_constant_read,
          instruction=instruction)
  def print_value(val):
      """Prints `val` to standard output.

      This is the default log function of ModelverseJit.set_jit_success_log and
      ModelverseJit.set_jit_code_log.
      """
      print(val)
  class ModelverseJit(object):
      """A high-level interface to the modelverse JIT compiler.

      The JIT tracks function entry points (identified by their body node ids),
      compiles them on request into Python functions, and keeps those functions
      in a private global scope (`jit_globals`) in which all jitted code runs.
      """

      def __init__(self, max_instructions=None, compiled_function_lookup=None):
          """Creates a JIT instance.

          Args:
              max_instructions: the instruction budget passed to the analysis of
                  every function body, or None for no limit.
              compiled_function_lookup: optional callable mapping a function name
                  to a compiled function. lookup_compiled_function consults it
                  for names that are not in the JIT's own global scope.
          """
          # Entry points that have been seen but not compiled yet.
          self.todo_entry_points = set()
          # Entry points that must never be jitted (or whose compilation failed).
          self.no_jit_entry_points = set()
          # Body id -> name of the compiled function inside jit_globals.
          self.jitted_entry_points = {}
          # Body id -> (parameter variable ids, parameter names, is_mutable).
          self.jitted_parameters = {}
          # The global scope in which jitted functions are defined and run.
          self.jit_globals = {
              'PrimitiveFinished' : primitive_functions.PrimitiveFinished,
              CALL_FUNCTION_NAME : jit_runtime.call_function,
              GET_INPUT_FUNCTION_NAME : jit_runtime.get_input
          }
          # Counter used to derive unique generated names.
          self.jit_count = 0
          self.max_instructions = max_instructions
          self.compiled_function_lookup = compiled_function_lookup
          # jit_intrinsics is a function name -> intrinsic map.
          self.jit_intrinsics = {}
          # Body id being compiled -> set of body ids that are all marked
          # non-jittable if that compilation fails.
          self.compilation_dependencies = {}
          self.jit_enabled = True
          self.direct_calls_allowed = True
          self.tracing_enabled = False
          self.input_function_enabled = False
          self.nop_insertion_enabled = True
          self.jit_success_log_function = None
          self.jit_code_log_function = None

      def set_jit_enabled(self, is_enabled=True):
          """Enables or disables the JIT."""
          self.jit_enabled = is_enabled

      def allow_direct_calls(self, is_allowed=True):
          """Allows or disallows direct calls from jitted to jitted code."""
          self.direct_calls_allowed = is_allowed

      def use_input_function(self, is_enabled=True):
          """Configures the JIT to compile 'input' instructions as function calls."""
          self.input_function_enabled = is_enabled

      def enable_tracing(self, is_enabled=True):
          """Enables or disables tracing for jitted code."""
          self.tracing_enabled = is_enabled

      def enable_nop_insertion(self, is_enabled=True):
          """Enables or disables nop insertion for jitted code.

          The JIT inserts nops at loop back-edges. This sacrifices performance to
          keep jitted code from blocking the thread of execution by consuming
          all resources: each nop gives the Modelverse server an opportunity to
          interrupt the currently running code.
          """
          self.nop_insertion_enabled = is_enabled

      def set_jit_success_log(self, log_function=print_value):
          """Configures a log function that receives success and failure
          messages for the compilation of specific functions."""
          self.jit_success_log_function = log_function

      def set_jit_code_log(self, log_function=print_value):
          """Configures a log function that receives the function definitions
          of jitted functions."""
          self.jit_code_log_function = log_function

      def mark_entry_point(self, body_id):
          """Marks the node with the given identifier as a function entry point.

          Has no effect if the entry point is non-jittable or already compiled.
          """
          if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
              self.todo_entry_points.add(body_id)

      def is_entry_point(self, body_id):
          """Tells if the node with the given identifier is a function entry point."""
          return body_id in self.todo_entry_points or \
              body_id in self.no_jit_entry_points or \
              body_id in self.jitted_entry_points

      def is_jittable_entry_point(self, body_id):
          """Tells if the node with the given identifier is a function entry point
          that has not been marked as non-jittable.

          Returns True only if the JIT is enabled and the entry point is still
          waiting to be compiled, or if the function has already been compiled.
          """
          return ((self.jit_enabled and body_id in self.todo_entry_points) or
                  self.has_compiled(body_id))

      def has_compiled(self, body_id):
          """Tests if the function belonging to the given body node has been compiled yet."""
          return body_id in self.jitted_entry_points

      def get_compiled_name(self, body_id):
          """Gets the name of the compiled version of the given body node in the
          JIT global state. Raises KeyError if it has not been compiled."""
          return self.jitted_entry_points[body_id]

      def mark_no_jit(self, body_id):
          """Informs the JIT that the node with the given identifier is a function
          entry point that must never be jitted."""
          self.no_jit_entry_points.add(body_id)
          if body_id in self.todo_entry_points:
              self.todo_entry_points.remove(body_id)

      def generate_name(self, infix, suggested_name=None):
          """Generates a new global name, or picks the suggested name.

          The suggested name is used only if it is a valid (ASCII) Python
          identifier, is not a keyword and is not yet bound in jit_globals.
          Otherwise a name of the form 'jit_<infix><counter>' is generated,
          skipping any counter value whose name is already taken.
          """
          import re
          if suggested_name is not None \
                  and suggested_name not in self.jit_globals \
                  and not keyword.iskeyword(suggested_name) \
                  and re.match(r'[A-Za-z_][A-Za-z0-9_]*\Z', suggested_name):
              self.jit_count += 1
              return suggested_name

          # A previously accepted suggested name may look exactly like a
          # generated one, so make sure the generated name is actually free.
          while True:
              function_name = 'jit_%s%d' % (infix, self.jit_count)
              self.jit_count += 1
              if function_name not in self.jit_globals:
                  return function_name

      def generate_function_name(self, suggested_name=None):
          """Generates a new function name or picks the suggested name if it is
          still available."""
          return self.generate_name('func', suggested_name)

      def register_compiled(self, body_id, compiled_function, function_name=None):
          """Registers an externally compiled entry point with the JIT, under
          the given name if it is available or a generated one otherwise."""
          function_name = self.generate_function_name(function_name)
          self.jitted_entry_points[body_id] = function_name
          self.jit_globals[function_name] = compiled_function
          if body_id in self.todo_entry_points:
              self.todo_entry_points.remove(body_id)

      def import_value(self, value, suggested_name=None):
          """Imports the given value into the JIT's global scope, with the given
          suggested name. Returns the actual name of the value within the JIT's
          global scope."""
          actual_name = self.generate_name('import', suggested_name)
          self.jit_globals[actual_name] = value
          return actual_name

      def lookup_compiled_function(self, name):
          """Looks up a compiled function by name.

          Consults the JIT's global scope first, then compiled_function_lookup
          (if configured). Returns None if no function was found.
          """
          if name is None:
              return None
          elif name in self.jit_globals:
              return self.jit_globals[name]
          elif self.compiled_function_lookup is not None:
              return self.compiled_function_lookup(name)
          else:
              return None

      def get_intrinsic(self, name):
          """Returns the intrinsic registered for the function with the given
          name, or None if there is none."""
          return self.jit_intrinsics.get(name)

      def register_intrinsic(self, name, intrinsic_function):
          """Registers the given intrinsic with the JIT.

          Calls to the function with the given name are then replaced by an
          application of intrinsic_function.
          """
          self.jit_intrinsics[name] = intrinsic_function

      def register_binary_intrinsic(self, name, operator):
          """Registers an intrinsic with the JIT that represents the given binary operation."""
          self.register_intrinsic(name, lambda a, b: tree_ir.CreateNodeWithValueInstruction(
              tree_ir.BinaryInstruction(
                  tree_ir.ReadValueInstruction(a),
                  operator,
                  tree_ir.ReadValueInstruction(b))))

      def register_unary_intrinsic(self, name, operator):
          """Registers an intrinsic with the JIT that represents the given unary operation."""
          self.register_intrinsic(name, lambda a: tree_ir.CreateNodeWithValueInstruction(
              tree_ir.UnaryInstruction(
                  operator,
                  tree_ir.ReadValueInstruction(a))))

      def register_cast_intrinsic(self, name, target_type):
          """Registers an intrinsic with the JIT that represents a unary conversion
          operator: target_type is called (by its __name__) on the argument's value."""
          self.register_intrinsic(name, lambda a: tree_ir.CreateNodeWithValueInstruction(
              tree_ir.CallInstruction(
                  tree_ir.LoadGlobalInstruction(target_type.__name__),
                  [tree_ir.ReadValueInstruction(a)])))

      def jit_signature(self, body_id):
          """Acquires the signature for the given body id node.

          Generator. It finishes by raising PrimitiveFinished with a tuple
          (parameter variable ids, parameter names, is_mutable). Results are
          cached in jitted_parameters.
          """
          if body_id not in self.jitted_parameters:
              signature_id, = yield [("RRD", [body_id, jit_runtime.FUNCTION_BODY_KEY])]
              signature_id = signature_id[0]
              param_set_id, is_mutable = yield [
                  ("RD", [signature_id, "params"]),
                  ("RD", [signature_id, jit_runtime.MUTABLE_FUNCTION_KEY])]
              if param_set_id is None:
                  self.jitted_parameters[body_id] = ([], [], is_mutable)
              else:
                  param_name_ids, = yield [("RDK", [param_set_id])]
                  param_names = yield [("RV", [n]) for n in param_name_ids]
                  param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
                  self.jitted_parameters[body_id] = (param_vars, param_names, is_mutable)

          raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])

      def jit_compile(self, user_root, body_id, suggested_name=None):
          """Tries to jit the function whose body starts at the node body_id.

          Generator. It finishes by raising PrimitiveFinished with the compiled
          Python function, or with the existing one if body_id was compiled
          before.

          Raises:
              ValueError: if body_id is None.
              JitCompilationFailedException: if the function is marked
                  non-jittable, the JIT is disabled, the function is mutable, or
                  its body cannot be analyzed. In the last two cases the function
                  and its recorded dependencies are marked non-jittable.
          """
          # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
          # pylint: disable=I0011,W0122
          if body_id is None:
              raise ValueError('body_id cannot be None')
          elif body_id in self.jitted_entry_points:
              # We have already compiled this function.
              raise primitive_functions.PrimitiveFinished(
                  self.jit_globals[self.jitted_entry_points[body_id]])
          elif body_id in self.no_jit_entry_points:
              # We're not allowed to jit this function or have tried and failed before.
              raise JitCompilationFailedException(
                  'Cannot jit function %s at %d because it is marked non-jittable.' % (
                      '' if suggested_name is None else "'" + suggested_name + "'",
                      body_id))
          elif not self.jit_enabled:
              # We're not allowed to jit anything.
              raise JitCompilationFailedException(
                  'Cannot jit function %s at %d because the JIT has been disabled.' % (
                      '' if suggested_name is None else "'" + suggested_name + "'",
                      body_id))

          # Generate a name for the function we're about to analyze, and pretend that
          # it already exists. (we need to do this for recursive functions)
          function_name = self.generate_function_name(suggested_name)
          self.jitted_entry_points[body_id] = function_name
          self.jit_globals[function_name] = None

          (parameter_ids, parameter_list, is_mutable), = yield [
              ("CALL_ARGS", [self.jit_signature, (body_id,)])]

          param_dict = dict(zip(parameter_ids, parameter_list))
          body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
          dependencies = set([body_id])
          self.compilation_dependencies[body_id] = dependencies

          def handle_jit_exception(exception):
              # If analysis fails, then a JitCompilationFailedException will be thrown.
              # Undo this compilation's bookkeeping and mark the function, plus every
              # body id recorded in its dependency set, as non-jittable.
              del self.compilation_dependencies[body_id]
              for dep in dependencies:
                  self.mark_no_jit(dep)
                  if dep in self.jitted_entry_points:
                      del self.jitted_entry_points[dep]

              # Format the exception itself: 'exception.message' only exists on
              # Python 2, and raises AttributeError on Python 3.
              failure_message = "%s (function '%s' at %d)" % (
                  exception, function_name, body_id)
              if self.jit_success_log_function is not None:
                  self.jit_success_log_function('JIT compilation failed: %s' % failure_message)
              raise JitCompilationFailedException(failure_message)

          # Try to analyze the function's body.
          yield [("TRY", [])]
          yield [("CATCH", [JitCompilationFailedException, handle_jit_exception])]
          if is_mutable:
              # We can't just JIT mutable functions. That'd be dangerous.
              raise JitCompilationFailedException(
                  "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)

          state = AnalysisState(
              self, body_id, user_root, body_param_dict,
              self.max_instructions)
          constructed_body, = yield [("CALL_ARGS", [state.analyze, (body_id,)])]
          yield [("END_TRY", [])]
          del self.compilation_dependencies[body_id]

          # Write a prologue and prepend it to the generated function body.
          prologue_statements = []
          # Create a LOCALS_NODE_NAME node, and connect it to the user root.
          prologue_statements.append(
              tree_ir.create_new_local_node(
                  LOCALS_NODE_NAME,
                  tree_ir.LoadIndexInstruction(
                      tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
                      tree_ir.LiteralInstruction('user_root')),
                  LOCALS_EDGE_NAME))
          # Give every parameter a '<name>_ptr' node whose 'value' edge points
          # at the argument that was passed in.
          for (key, val) in param_dict.items():
              arg_ptr = tree_ir.create_new_local_node(
                  body_param_dict[key],
                  tree_ir.LoadLocalInstruction(LOCALS_NODE_NAME))
              prologue_statements.append(arg_ptr)
              prologue_statements.append(
                  tree_ir.CreateDictionaryEdgeInstruction(
                      tree_ir.LoadLocalInstruction(body_param_dict[key]),
                      tree_ir.LiteralInstruction('value'),
                      tree_ir.LoadLocalInstruction(val)))

          constructed_body = tree_ir.create_block(
              *(prologue_statements + [constructed_body]))

          # Optimize the function's body.
          constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]

          # Shield temporaries from the GC.
          constructed_body = tree_ir.protect_temporaries_from_gc(
              constructed_body, tree_ir.LoadLocalInstruction(LOCALS_NODE_NAME))

          # Wrap the IR in a function definition, give it a unique name.
          constructed_function = tree_ir.DefineFunctionInstruction(
              function_name,
              parameter_list + ['**' + KWARGS_PARAMETER_NAME],
              constructed_body)
          # Convert the function definition to Python code, and compile it.
          exec(str(constructed_function), self.jit_globals)
          # Extract the compiled function from the JIT global state.
          compiled_function = self.jit_globals[function_name]

          if self.jit_success_log_function is not None:
              self.jit_success_log_function(
                  "JIT compilation successful: (function '%s' at %d)" % (function_name, body_id))

          if self.jit_code_log_function is not None:
              self.jit_code_log_function(constructed_function)

          raise primitive_functions.PrimitiveFinished(compiled_function)
  342. class AnalysisState(object):
  343. """The state of a bytecode analysis call graph."""
  344. def __init__(self, jit, body_id, user_root, local_mapping, max_instructions=None):
  345. self.analyzed_instructions = set()
  346. self.function_vars = set()
  347. self.local_vars = set()
  348. self.body_id = body_id
  349. self.max_instructions = max_instructions
  350. self.user_root = user_root
  351. self.jit = jit
  352. self.local_mapping = local_mapping
  353. self.function_name = jit.jitted_entry_points[body_id]
  354. self.enclosing_loop_instruction = None
  355. def get_local_name(self, local_id):
  356. """Gets the name for a local with the given id."""
  357. if local_id not in self.local_mapping:
  358. self.local_mapping[local_id] = 'local%d' % local_id
  359. return self.local_mapping[local_id]
  360. def register_local_var(self, local_id):
  361. """Registers the given variable node id as a local."""
  362. if local_id in self.function_vars:
  363. raise JitCompilationFailedException(
  364. "Local is used as target of function call.")
  365. self.local_vars.add(local_id)
  366. def register_function_var(self, local_id):
  367. """Registers the given variable node id as a function."""
  368. if local_id in self.local_vars:
  369. raise JitCompilationFailedException(
  370. "Local is used as target of function call.")
  371. self.function_vars.add(local_id)
  def retrieve_user_root(self):
      """Creates an instruction that stores kwargs['user_root'] in the
      Python local 'user_root'."""
      kwargs_load = tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)
      root_key = tree_ir.LiteralInstruction('user_root')
      return tree_ir.StoreLocalInstruction(
          'user_root',
          tree_ir.LoadIndexInstruction(kwargs_load, root_key))
  def load_kernel(self):
      """Creates an instruction that loads the Modelverse kernel, which jitted
      functions receive as kwargs['mvk']."""
      kwargs_load = tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)
      kernel_key = tree_ir.LiteralInstruction('mvk')
      return tree_ir.LoadIndexInstruction(kwargs_load, kernel_key)
  def analyze(self, instruction_id):
      """Tries to build an intermediate representation from the instruction
      with the given id and the chain of instructions after it.

      Generator driven by the Modelverse kernel. It finishes by raising
      primitive_functions.PrimitiveFinished with the IR tree. When the
      instruction has a 'next' instruction, the result is a
      CompoundInstruction of this instruction's tree and the tree for the rest
      of the chain.

      Raises JitCompilationFailedException if:
        * the instruction was already analyzed (not a tree);
        * analyzing it would exceed max_instructions analyzed instructions;
        * no analyzer is registered for its type.
      """
      # Check the analyzed_instructions set for instruction_id to avoid
      # infinite loops.
      if instruction_id in self.analyzed_instructions:
          raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
      elif (self.max_instructions is not None and
            len(self.analyzed_instructions) >= self.max_instructions):
          # Adding this instruction would bring the total to
          # len(analyzed_instructions) + 1. Using '>' here allowed
          # max_instructions + 1 instructions (off-by-one).
          raise JitCompilationFailedException('Maximum number of instructions exceeded.')

      self.analyzed_instructions.add(instruction_id)
      instruction_val, = yield [("RV", [instruction_id])]
      instruction_val = instruction_val["value"]
      if instruction_val not in self.instruction_analyzers:
          raise JitCompilationFailedException(
              "Unknown instruction type: '%s'" % (instruction_val))

      # If tracing is enabled, then this would be an appropriate time to
      # retrieve the debug information.
      debug_info = None
      if self.jit.tracing_enabled:
          debug_info, = yield [("RD", [instruction_id, "__debug"])]
          if debug_info is not None:
              debug_info, = yield [("RV", [debug_info])]

      # Analyze the instruction itself.
      outer_result, = yield [
          ("CALL_ARGS", [self.instruction_analyzers[instruction_val], (self, instruction_id)])]
      if self.jit.tracing_enabled:
          outer_result = tree_ir.with_debug_info_trace(outer_result, debug_info, self.function_name)

      # Check if the instruction has a 'next' instruction.
      next_instr, = yield [("RD", [instruction_id, "next"])]
      if next_instr is None:
          raise primitive_functions.PrimitiveFinished(outer_result)

      next_result, = yield [("CALL_ARGS", [self.analyze, (next_instr,)])]
      raise primitive_functions.PrimitiveFinished(
          tree_ir.CompoundInstruction(
              outer_result,
              next_result))
  def analyze_all(self, instruction_ids):
      """Analyzes the given instruction ids in order.

      Generator. It finishes by raising PrimitiveFinished with the list of IR
      trees, in the same order as instruction_ids.
      """
      trees = []
      for next_id in instruction_ids:
          tree, = yield [("CALL_ARGS", [self.analyze, (next_id,)])]
          trees.append(tree)
      raise primitive_functions.PrimitiveFinished(trees)
  def analyze_return(self, instruction_id):
      """Tries to analyze the given 'return' instruction.

      A missing 'value' edge returns an EmptyInstruction. The returned
      expression is combined with deletion of the LOCALS_EDGE_NAME edge, which
      attaches this call's locals node to the user root. Deleting it presumably
      releases the locals for garbage collection (TODO confirm).
      """
      value_id, = yield [("RD", [instruction_id, 'value'])]
      if value_id is None:
          returned = tree_ir.EmptyInstruction()
      else:
          returned, = yield [("CALL_ARGS", [self.analyze, (value_id,)])]

      detach_locals = tree_ir.DeleteEdgeInstruction(
          tree_ir.LoadLocalInstruction(LOCALS_EDGE_NAME))
      raise primitive_functions.PrimitiveFinished(
          tree_ir.ReturnInstruction(
              tree_ir.CompoundInstruction(returned, detach_locals)))
  def analyze_if(self, instruction_id):
      """Tries to analyze the given 'if' instruction.

      The condition, the 'then' branch and the optional 'else' branch are
      analyzed in that order. A missing 'else' branch becomes an
      EmptyInstruction. The result is a SelectInstruction on the value of the
      condition node.
      """
      cond_id, then_id, else_id = yield [
          ("RD", [instruction_id, "cond"]),
          ("RD", [instruction_id, "then"]),
          ("RD", [instruction_id, "else"])]
      has_else = else_id is not None
      to_analyze = [cond_id, then_id, else_id] if has_else else [cond_id, then_id]
      analyzed, = yield [("CALL_ARGS", [self.analyze_all, (to_analyze,)])]

      if has_else:
          cond_tree, then_tree, else_tree = analyzed
      else:
          cond_tree, then_tree = analyzed
          else_tree = tree_ir.EmptyInstruction()

      raise primitive_functions.PrimitiveFinished(
          tree_ir.SelectInstruction(
              tree_ir.ReadValueInstruction(cond_tree),
              then_tree,
              else_tree))
  def analyze_while(self, instruction_id):
      """Tries to analyze the given 'while' instruction.

      Produces a LoopInstruction whose body first checks the condition and
      breaks out of the loop when the condition node's value is false, then
      runs the analyzed loop body. If nop insertion is enabled on the JIT, a
      nop is appended to every iteration.

      While the body is analyzed, enclosing_loop_instruction is set to this
      instruction's id. This is presumably for break/continue analysis (not
      visible here). The previous value is restored afterwards.
      """
      cond, body = yield [
          ("RD", [instruction_id, "cond"]),
          ("RD", [instruction_id, "body"])]

      # Analyze the condition.
      cond_r, = yield [("CALL_ARGS", [self.analyze, (cond,)])]

      # Store the old enclosing loop on the stack, and make this loop the
      # new enclosing loop.
      old_loop_instruction = self.enclosing_loop_instruction
      self.enclosing_loop_instruction = instruction_id
      body_r, = yield [("CALL_ARGS", [self.analyze, (body,)])]
      # Restore the old enclosing loop.
      self.enclosing_loop_instruction = old_loop_instruction

      if self.jit.nop_insertion_enabled:
          def create_loop_body(check, loop_body):
              # Use the parameter (the original lambda ignored it and closed
              # over body_r). The nop at the back-edge gives the server a
              # chance to interrupt long-running loops.
              return tree_ir.create_block(
                  check,
                  loop_body,
                  tree_ir.NopInstruction())
      else:
          create_loop_body = tree_ir.CompoundInstruction

      raise primitive_functions.PrimitiveFinished(
          tree_ir.LoopInstruction(
              create_loop_body(
                  tree_ir.SelectInstruction(
                      tree_ir.ReadValueInstruction(cond_r),
                      tree_ir.EmptyInstruction(),
                      tree_ir.BreakInstruction()),
                  body_r)))
  def analyze_constant(self, instruction_id):
      """Tries to analyze the given 'constant' (literal) instruction.

      The id of the node referenced by the instruction's 'node' edge is
      embedded in the IR as a LiteralInstruction.
      """
      constant_node_id, = yield [("RD", [instruction_id, "node"])]
      literal = tree_ir.LiteralInstruction(constant_node_id)
      raise primitive_functions.PrimitiveFinished(literal)
  def analyze_output(self, instruction_id):
      """Tries to analyze the given 'output' instruction.

      The generated IR appends the analyzed value to the user root's output
      list:
        * the current 'last_output' node gets a 'value' edge to the value and
          a 'next' edge to a fresh node;
        * the user root's 'last_output' edge is moved to that fresh node;
        * a nop follows.

      The generated tree corresponds to:

          value = <some tree>
          last_output, last_output_link, new_last_output = \\
                      yield [("RD", [user_root, "last_output"]),
                             ("RDE", [user_root, "last_output"]),
                             ("CN", []),
                            ]
          _, _, _, _ = \\
                      yield [("CD", [last_output, "value", value]),
                             ("CD", [last_output, "next", new_last_output]),
                             ("CD", [user_root, "last_output", new_last_output]),
                             ("DE", [last_output_link])
                            ]
          yield None
      """
      value_id, = yield [("RD", [instruction_id, "value"])]
      analyzed_value, = yield [("CALL_ARGS", [self.analyze, (value_id,)])]

      # Locals, constructed in the same order as the statements run.
      store_value = tree_ir.StoreLocalInstruction('value', analyzed_value)
      store_root = self.retrieve_user_root()
      store_last = tree_ir.StoreLocalInstruction(
          'last_output',
          tree_ir.ReadDictionaryValueInstruction(
              store_root.create_load(),
              tree_ir.LiteralInstruction('last_output')))
      store_link = tree_ir.StoreLocalInstruction(
          'last_output_link',
          tree_ir.ReadDictionaryEdgeInstruction(
              store_root.create_load(),
              tree_ir.LiteralInstruction('last_output')))
      store_new = tree_ir.StoreLocalInstruction(
          'new_last_output',
          tree_ir.CreateNodeInstruction())

      # Graph updates that append the value to the output list.
      attach_value = tree_ir.CreateDictionaryEdgeInstruction(
          store_last.create_load(),
          tree_ir.LiteralInstruction('value'),
          store_value.create_load())
      attach_next = tree_ir.CreateDictionaryEdgeInstruction(
          store_last.create_load(),
          tree_ir.LiteralInstruction('next'),
          store_new.create_load())
      move_last = tree_ir.CreateDictionaryEdgeInstruction(
          store_root.create_load(),
          tree_ir.LiteralInstruction('last_output'),
          store_new.create_load())
      drop_old_link = tree_ir.DeleteEdgeInstruction(store_link.create_load())

      raise primitive_functions.PrimitiveFinished(
          tree_ir.create_block(
              store_value,
              store_root,
              store_last,
              store_link,
              store_new,
              attach_value,
              attach_next,
              move_last,
              drop_old_link,
              tree_ir.NopInstruction()))
  def analyze_input(self, _):
      """Tries to analyze the given 'input' instruction.

      The instruction id (the ignored `_` parameter) is not needed, because an
      'input' instruction has no operands here. Generator; finishes by raising
      primitive_functions.PrimitiveFinished with the IR tree.

      If the JIT's input_function_enabled flag is set, the result is a JIT call
      to the GET_INPUT_FUNCTION_NAME global. Otherwise an explicit IR tree is
      built that:
        * polls the user root's 'input' node until it has a 'value';
        * advances the user root's 'input' edge to the next node;
        * connects the value to the LOCALS_NODE_NAME node;
        * deletes the consumed input node;
        * evaluates to the value.
      """
      # Possible alternative to the explicit syntax tree:
      if self.jit.input_function_enabled:
          raise primitive_functions.PrimitiveFinished(
              tree_ir.create_jit_call(
                  tree_ir.LoadGlobalInstruction(GET_INPUT_FUNCTION_NAME),
                  [],
                  tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))

      # The plan is to generate this tree:
      #
      #     value = None
      #     while True:
      #         _input = yield [("RD", [user_root, "input"])]
      #         value = yield [("RD", [_input, "value"])]
      #
      #         if value is None:
      #             kwargs['mvk'].success = False # to avoid blocking
      #             yield None # nop/interrupt
      #         else:
      #             break
      #
      #     _next = yield [("RD", [_input, "next"])]
      #     yield [("CD", [user_root, "input", _next])]
      #     yield [("CE", [jit_locals, value])]
      #     yield [("DN", [_input])]

      user_root = self.retrieve_user_root()
      # Both locals are created with a None name. The construction order and
      # the create_load()/create_store() calls below are kept exactly as-is
      # in case tree_ir assigns those names as a side effect (NOTE(review):
      # not visible here — confirm in tree_ir).
      _input = tree_ir.StoreLocalInstruction(
          None,
          tree_ir.ReadDictionaryValueInstruction(
              user_root.create_load(),
              tree_ir.LiteralInstruction('input')))
      value = tree_ir.StoreLocalInstruction(
          None,
          tree_ir.ReadDictionaryValueInstruction(
              _input.create_load(),
              tree_ir.LiteralInstruction('value')))
      raise primitive_functions.PrimitiveFinished(
          tree_ir.CompoundInstruction(
              tree_ir.create_block(
                  user_root,
                  value.create_store(tree_ir.LiteralInstruction(None)),
                  # Poll until the current input node carries a value.
                  tree_ir.LoopInstruction(
                      tree_ir.create_block(
                          _input,
                          value,
                          tree_ir.SelectInstruction(
                              tree_ir.BinaryInstruction(
                                  value.create_load(),
                                  'is',
                                  tree_ir.LiteralInstruction(None)),
                              # No input yet: set kwargs['mvk'].success to
                              # False ("to avoid blocking", per the plan
                              # above) and yield a nop so the server can
                              # interrupt.
                              tree_ir.create_block(
                                  tree_ir.StoreMemberInstruction(
                                      self.load_kernel(),
                                      'success',
                                      tree_ir.LiteralInstruction(False)),
                                  tree_ir.NopInstruction()),
                              tree_ir.BreakInstruction()))),
                  # Advance the user root's 'input' edge to the next input node.
                  tree_ir.CreateDictionaryEdgeInstruction(
                      user_root.create_load(),
                      tree_ir.LiteralInstruction('input'),
                      tree_ir.ReadDictionaryValueInstruction(
                          _input.create_load(),
                          tree_ir.LiteralInstruction('next'))),
                  # Connect the value to this call's locals node; presumably
                  # this keeps it alive once the input node is deleted
                  # (TODO confirm).
                  tree_ir.CreateEdgeInstruction(
                      tree_ir.LoadLocalInstruction(LOCALS_NODE_NAME),
                      value.create_load()),
                  # The consumed input node is no longer needed.
                  tree_ir.DeleteNodeInstruction(_input.create_load())),
              # The whole instruction evaluates to the input value.
              value.create_load()))
  def analyze_resolve(self, instruction_id):
      """Tries to analyze the given 'resolve' instruction.

      A variable whose 'var' node has no value (name) is compiled as a plain
      local load. Named variables are resolved at run time, roughly like:

          if 'local_var' in locals():
              tmp = local_var
          else:
              _globals, = yield [("RD", [user_root, "globals"])]
              global_var, = yield [("RD", [_globals, var_name])]

              if global_var is None:
                  raise Exception("Not found as global: %s" % (var_name))

              tmp = global_var
      """
      var_id, = yield [("RD", [instruction_id, "var"])]
      var_name, = yield [("RV", [var_id])]
      local_name = self.get_local_name(var_id)
      if var_name is None:
          raise primitive_functions.PrimitiveFinished(
              tree_ir.LoadLocalInstruction(local_name))

      user_root = self.retrieve_user_root()
      globals_read = tree_ir.ReadDictionaryValueInstruction(
          user_root.create_load(),
          tree_ir.LiteralInstruction('globals'))
      global_store = tree_ir.StoreLocalInstruction(
          'global_var',
          tree_ir.ReadDictionaryValueInstruction(
              globals_read,
              tree_ir.LiteralInstruction(var_name)))

      # Raise at run time if the name is not bound in the globals dictionary.
      is_unbound = tree_ir.BinaryInstruction(
          global_store.create_load(),
          'is',
          tree_ir.LiteralInstruction(None))
      raise_unbound = tree_ir.RaiseInstruction(
          tree_ir.CallInstruction(
              tree_ir.LoadGlobalInstruction('Exception'),
              [tree_ir.LiteralInstruction(
                  "Not found as global: %s" % var_name)]))
      check_bound = tree_ir.SelectInstruction(
          is_unbound, raise_unbound, tree_ir.EmptyInstruction())

      # A local of the same name, if present at run time, takes precedence.
      local_exists = tree_ir.LocalExistsInstruction(local_name)
      local_load = tree_ir.LoadLocalInstruction(local_name)
      global_lookup = tree_ir.CompoundInstruction(
          tree_ir.create_block(user_root, global_store, check_bound),
          global_store.create_load())
      raise primitive_functions.PrimitiveFinished(
          tree_ir.SelectInstruction(local_exists, local_load, global_lookup))
  def analyze_declare(self, instruction_id):
      """Tries to analyze the given 'declare' instruction.

      Registers the declared variable as a local and emits code equivalent to:

          if 'local_name' not in locals():
              local_name, = yield [("CN", [])]
              yield [("CE", [LOCALS_NODE_NAME, local_name])]
      """
      var_id, = yield [("RD", [instruction_id, "var"])]
      self.register_local_var(var_id)
      local_name = self.get_local_name(var_id)

      already_declared = tree_ir.LocalExistsInstruction(local_name)
      create_local = tree_ir.create_new_local_node(
          local_name,
          tree_ir.LoadLocalInstruction(LOCALS_NODE_NAME))
      declaration = tree_ir.SelectInstruction(
          already_declared,
          tree_ir.EmptyInstruction(),
          create_local)
      raise primitive_functions.PrimitiveFinished(declaration)
  def analyze_global(self, instruction_id):
      """Tries to analyze the given 'global' (declaration) instruction.

      Declaring a global looks up the name in the user's globals dictionary
      and creates a fresh node for it if the name is not bound yet:

          _globals, = yield [("RD", [user_root, "globals"])]
          global_var = yield [("RD", [_globals, var_name])]

          if global_var is None:
              global_var, = yield [("CN", [])]
              yield [("CD", [_globals, var_name, global_var])]

          tmp = global_var

      The generated compound instruction evaluates to the global's node.
      """
      var_id, = yield [("RD", [instruction_id, "var"])]
      var_name, = yield [("RV", [var_id])]

      user_root = self.retrieve_user_root()
      globals_store = tree_ir.StoreLocalInstruction(
          '_globals',
          tree_ir.ReadDictionaryValueInstruction(
              user_root.create_load(),
              tree_ir.LiteralInstruction('globals')))
      global_store = tree_ir.StoreLocalInstruction(
          'global_var',
          tree_ir.ReadDictionaryValueInstruction(
              globals_store.create_load(),
              tree_ir.LiteralInstruction(var_name)))

      is_undefined = tree_ir.BinaryInstruction(
          global_store.create_load(),
          'is',
          tree_ir.LiteralInstruction(None))
      # Create the global's node and bind it under var_name.
      define_global = tree_ir.create_block(
          global_store.create_store(tree_ir.CreateNodeInstruction()),
          tree_ir.CreateDictionaryEdgeInstruction(
              globals_store.create_load(),
              tree_ir.LiteralInstruction(var_name),
              global_store.create_load()))
      ensure_defined = tree_ir.SelectInstruction(
          is_undefined, define_global, tree_ir.EmptyInstruction())

      raise primitive_functions.PrimitiveFinished(
          tree_ir.CompoundInstruction(
              tree_ir.create_block(
                  user_root, globals_store, global_store, ensure_defined),
              global_store.create_load()))
  def analyze_assign(self, instruction_id):
      """Tries to analyze the given 'assign' instruction.

      The variable's 'value' edge is replaced, roughly like:

          value_link = yield [("RDE", [variable, "value"])]
          _, _ = yield [("CD", [variable, "value", value]),
                        ("DE", [value_link])]

      The old edge is read before the new one is created, so the delete
      targets the previous value link.
      """
      var_id, value_id = yield [("RD", [instruction_id, "var"]),
                                ("RD", [instruction_id, "value"])]
      analyzed, = yield [("CALL_ARGS", [self.analyze_all, ([var_id, value_id],)])]
      target_tree, value_tree = analyzed

      target = tree_ir.StoreLocalInstruction(None, target_tree)
      new_value = tree_ir.StoreLocalInstruction(None, value_tree)
      old_link = tree_ir.StoreLocalInstruction(
          'value_link',
          tree_ir.ReadDictionaryEdgeInstruction(
              target.create_load(),
              tree_ir.LiteralInstruction('value')))
      relink = tree_ir.CreateDictionaryEdgeInstruction(
          target.create_load(),
          tree_ir.LiteralInstruction('value'),
          new_value.create_load())
      unlink = tree_ir.DeleteEdgeInstruction(old_link.create_load())

      raise primitive_functions.PrimitiveFinished(
          tree_ir.create_block(target, new_value, old_link, relink, unlink))
  def analyze_access(self, instruction_id):
      """Tries to analyze the given 'access' instruction.

      Accessing a variable reads the entry under the 'value' key of the
      variable node, i.e.:

          value, = yield [("RD", [returnvalue, "value"])]
      """
      var_id, = yield [("RD", [instruction_id, "var"])]
      var_tree, = yield [("CALL_ARGS", [self.analyze, (var_id,)])]
      value_read = tree_ir.ReadDictionaryValueInstruction(
          var_tree,
          tree_ir.LiteralInstruction('value'))
      raise primitive_functions.PrimitiveFinished(value_read)
  def analyze_direct_call(self, callee_id, callee_name, first_parameter_id):
      """Tries to analyze a direct 'call' instruction.

      The callee is either applied as an intrinsic or emitted as a JIT call to
      its compiled function. The callee is compiled first if the JIT has no
      compiled function for it yet.
      """
      self.register_function_var(callee_id)
      body_id, = yield [("RD", [callee_id, jit_runtime.FUNCTION_BODY_KEY])]

      # Make this function dependent on the callee.
      dependencies = self.jit.compilation_dependencies
      if body_id in dependencies:
          dependencies[body_id].add(self.body_id)

      intrinsic = self.jit.get_intrinsic(callee_name)
      if intrinsic is None:
          known_func = self.jit.lookup_compiled_function(callee_name)
          if known_func is not None:
              self.jit.register_compiled(body_id, known_func, callee_name)
          else:
              # Compile the callee.
              yield [("CALL_ARGS", [self.jit.jit_compile, (self.user_root, body_id, callee_name)])]

          compiled_func_name = self.jit.get_compiled_name(body_id)
          # Corner case: for calls like 'call(constant(9), ...)' the
          # `callee_name` is `None`, but the compiled name of the node may
          # still identify an intrinsic, which is preferred over a call.
          intrinsic = self.jit.get_intrinsic(compiled_func_name)

      named_args, = yield [("CALL_ARGS", [self.analyze_arguments, (first_parameter_id,)])]

      if intrinsic is None:
          # `compiled_func_name` is always bound here: `intrinsic` can only
          # still be None if the branch above ran.
          jit_call = tree_ir.create_jit_call(
              tree_ir.LoadGlobalInstruction(compiled_func_name),
              named_args,
              tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME))
          raise primitive_functions.PrimitiveFinished(jit_call)

      raise primitive_functions.PrimitiveFinished(
          apply_intrinsic(intrinsic, named_args))
  def analyze_arguments(self, first_argument_id):
      """Analyzes the parameter-to-argument mapping started by the specified
      first argument node.

      Follows the 'next_param' chain until it ends (None). Finishes with a list
      of (parameter name, analyzed argument tree) pairs in chain order.
      """
      named_args = []
      param_node = first_argument_id
      while param_node is not None:
          name_id, = yield [("RD", [param_node, "name"])]
          param_name, = yield [("RV", [name_id])]
          value_id, = yield [("RD", [param_node, "value"])]
          arg_tree, = yield [("CALL_ARGS", [self.analyze, (value_id,)])]
          named_args.append((param_name, arg_tree))
          param_node, = yield [("RD", [param_node, "next_param"])]
      raise primitive_functions.PrimitiveFinished(named_args)
  def analyze_indirect_call(self, func_id, first_arg_id):
      """Analyzes a call to an unknown function.

      Emits a call to the __call_function helper, which runs the interpreter:

          __call_function(function_id, { first_param_name : first_param_val, ... }, **kwargs)
      """
      callee_tree, = yield [("CALL_ARGS", [self.analyze, (func_id,)])]
      named_args, = yield [("CALL_ARGS", [self.analyze_arguments, (first_arg_id,)])]

      dict_entries = []
      for arg_name, arg_tree in named_args:
          dict_entries.append((tree_ir.LiteralInstruction(arg_name), arg_tree))
      arguments_dict = tree_ir.DictionaryLiteralInstruction(dict_entries)

      helper_call = tree_ir.create_jit_call(
          tree_ir.LoadGlobalInstruction(CALL_FUNCTION_NAME),
          [('function_id', callee_tree), ('named_arguments', arguments_dict)],
          tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME))
      raise primitive_functions.PrimitiveFinished(helper_call)
  def try_analyze_direct_call(self, func_id, first_param_id):
      """Tries to analyze the given 'call' instruction as a direct call.

      Two callee shapes are translated into direct calls:

      * 'access(resolve(var))' where `var` is a name currently bound to a
        global that has a value, and
      * 'constant(node)'.

      Raises:
          JitCompilationFailedException: if direct calls are disabled, or the
              callee has any other shape / does not name a known global value.
      """
      if not self.jit.direct_calls_allowed:
          raise JitCompilationFailedException('Direct calls are not allowed by the JIT.')

      # Figure out what the 'func' instruction's type is.
      func_instruction_op, = yield [("RV", [func_id])]
      if func_instruction_op['value'] == 'access':
          # 'access(resolve(var))' instructions are translated to direct calls.
          access_value_id, = yield [("RD", [func_id, "var"])]
          access_value_op, = yield [("RV", [access_value_id])]
          if access_value_op['value'] == 'resolve':
              resolved_var_id, = yield [("RD", [access_value_id, "var"])]
              resolved_var_name, = yield [("RV", [resolved_var_id])]
              # A resolved variable without a name is a local (analyze_resolve
              # compiles it as a local load), so it can never be found in the
              # globals dictionary; skip the lookup instead of querying with a
              # None key.
              if resolved_var_name is not None:
                  # Try to look up the name as a global.
                  _globals, = yield [("RD", [self.user_root, "globals"])]
                  global_var, = yield [("RD", [_globals, resolved_var_name])]
                  # Only read the global's value if the name is actually bound.
                  if global_var is not None:
                      global_val, = yield [("RD", [global_var, "value"])]
                      if global_val is not None:
                          result, = yield [("CALL_ARGS", [self.analyze_direct_call, (
                              global_val, resolved_var_name, first_param_id)])]
                          raise primitive_functions.PrimitiveFinished(result)
      elif func_instruction_op['value'] == 'constant':
          # 'const(func_id)' instructions are also translated to direct calls.
          function_val_id, = yield [("RD", [func_id, "node"])]
          result, = yield [("CALL_ARGS", [self.analyze_direct_call, (
              function_val_id, None, first_param_id)])]
          raise primitive_functions.PrimitiveFinished(result)

      raise JitCompilationFailedException(
          "Cannot JIT function calls that target an unknown value as direct calls.")
  def analyze_call(self, instruction_id):
      """Tries to analyze the given 'call' instruction.

      A direct call is attempted first. If that analysis raises
      JitCompilationFailedException, the call is compiled as an indirect call.
      """
      func_id, first_param_id = yield [("RD", [instruction_id, "func"]),
                                       ("RD", [instruction_id, "params"])]

      def fall_back_to_indirect_call(exception):
          # The direct-call attempt failed; compile an indirect call instead.
          indirect_gen = self.analyze_indirect_call(func_id, first_param_id)
          indirect_result, = yield [("CALL", [indirect_gen])]
          raise primitive_functions.PrimitiveFinished(indirect_result)

      yield [("TRY", [])]
      yield [("CATCH", [JitCompilationFailedException, fall_back_to_indirect_call])]
      direct_result, = yield [("CALL_ARGS", [self.try_analyze_direct_call, (func_id, first_param_id)])]
      yield [("END_TRY", [])]
      raise primitive_functions.PrimitiveFinished(direct_result)
  def analyze_break(self, instruction_id):
      """Tries to analyze the given 'break' instruction.

      Only a 'break' whose target is `self.enclosing_loop_instruction` can be
      compiled; any other target raises JitCompilationFailedException.
      """
      loop_id, = yield [("RD", [instruction_id, "while"])]
      if loop_id == self.enclosing_loop_instruction:
          raise primitive_functions.PrimitiveFinished(tree_ir.BreakInstruction())
      raise JitCompilationFailedException(
          "Multilevel 'break' is not supported by the baseline JIT.")
  def analyze_continue(self, instruction_id):
      """Tries to analyze the given 'continue' instruction.

      Only a 'continue' whose target is `self.enclosing_loop_instruction` can
      be compiled; any other target raises JitCompilationFailedException.
      """
      loop_id, = yield [("RD", [instruction_id, "while"])]
      if loop_id == self.enclosing_loop_instruction:
          raise primitive_functions.PrimitiveFinished(tree_ir.ContinueInstruction())
      raise JitCompilationFailedException(
          "Multilevel 'continue' is not supported by the baseline JIT.")
  # Maps each instruction opcode to the analyzer function that handles it.
  instruction_analyzers = dict([
      ('if', analyze_if),
      ('while', analyze_while),
      ('return', analyze_return),
      ('constant', analyze_constant),
      ('resolve', analyze_resolve),
      ('declare', analyze_declare),
      ('global', analyze_global),
      ('assign', analyze_assign),
      ('access', analyze_access),
      ('output', analyze_output),
      ('input', analyze_input),
      ('call', analyze_call),
      ('break', analyze_break),
      ('continue', analyze_continue),
  ])