jit.py 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839
  1. import modelverse_kernel.primitives as primitive_functions
  2. import modelverse_jit.tree_ir as tree_ir
  3. KWARGS_PARAMETER_NAME = "remainder"
  4. """The name of the kwargs parameter in jitted functions."""
  5. class JitCompilationFailedException(Exception):
  6. """A type of exception that is raised when the jit fails to compile a function."""
  7. pass
  8. class ModelverseJit(object):
  9. """A high-level interface to the modelverse JIT compiler."""
  10. def __init__(self, max_instructions=None, compiled_function_lookup=None):
  11. self.todo_entry_points = set()
  12. self.no_jit_entry_points = set()
  13. self.jitted_entry_points = {}
  14. self.jitted_parameters = {}
  15. self.jit_globals = {
  16. 'PrimitiveFinished' : primitive_functions.PrimitiveFinished
  17. }
  18. self.jit_count = 0
  19. self.max_instructions = 30 if max_instructions is None else max_instructions
  20. self.compiled_function_lookup = compiled_function_lookup
  21. self.jit_intrinsics = {}
  22. self.compilation_dependencies = {}
  23. self.jit_enabled = True
  24. def set_jit_enabled(self, is_enabled=True):
  25. """Enables or disables the JIT."""
  26. self.jit_enabled = is_enabled
  27. def mark_entry_point(self, body_id):
  28. """Marks the node with the given identifier as a function entry point."""
  29. if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
  30. self.todo_entry_points.add(body_id)
  31. def is_entry_point(self, body_id):
  32. """Tells if the node with the given identifier is a function entry point."""
  33. return body_id in self.todo_entry_points or \
  34. body_id in self.no_jit_entry_points or \
  35. body_id in self.jitted_entry_points
  36. def is_jittable_entry_point(self, body_id):
  37. """Tells if the node with the given identifier is a function entry point that
  38. has not been marked as non-jittable. This only returns `True` if the JIT
  39. is enabled and the function entry point has been marked jittable, or if
  40. the function has already been compiled."""
  41. return ((self.jit_enabled and body_id in self.todo_entry_points) or
  42. self.has_compiled(body_id))
  43. def has_compiled(self, body_id):
  44. """Tests if the function belonging to the given body node has been compiled yet."""
  45. return body_id in self.jitted_entry_points
  46. def get_compiled_name(self, body_id):
  47. """Gets the name of the compiled version of the given body node in the JIT
  48. global state."""
  49. return self.jitted_entry_points[body_id]
  50. def mark_no_jit(self, body_id):
  51. """Informs the JIT that the node with the given identifier is a function entry
  52. point that must never be jitted."""
  53. self.no_jit_entry_points.add(body_id)
  54. if body_id in self.todo_entry_points:
  55. self.todo_entry_points.remove(body_id)
  56. def generate_function_name(self, suggested_name=None):
  57. """Generates a new function name or picks the suggested name if it is still
  58. available."""
  59. if suggested_name is not None and suggested_name not in self.jit_globals:
  60. self.jit_count += 1
  61. return suggested_name
  62. else:
  63. function_name = 'jit_func%d' % self.jit_count
  64. self.jit_count += 1
  65. return function_name
  66. def register_compiled(self, body_id, compiled_function, function_name=None):
  67. """Registers a compiled entry point with the JIT."""
  68. function_name = self.generate_function_name(function_name)
  69. self.jitted_entry_points[body_id] = function_name
  70. self.jit_globals[function_name] = compiled_function
  71. if body_id in self.todo_entry_points:
  72. self.todo_entry_points.remove(body_id)
  73. def lookup_compiled_function(self, name):
  74. """Looks up a compiled function by name. Returns a matching function,
  75. or None if no function was found."""
  76. if name in self.jit_globals:
  77. return self.jit_globals[name]
  78. elif self.compiled_function_lookup is not None:
  79. return self.compiled_function_lookup(name)
  80. else:
  81. return None
  82. def register_intrinsic(self, name, apply_intrinsic):
  83. """Registers the given intrisic with the JIT. This will make the JIT replace calls to
  84. the function with the given entry point by an application of the specified function."""
  85. self.jit_intrinsics[name] = apply_intrinsic
  86. def register_binary_intrinsic(self, name, operator):
  87. """Registers an intrinsic with the JIT that represents the given binary operation."""
  88. self.register_intrinsic(name, lambda lhs, rhs: tree_ir.CreateNodeWithValueInstruction(
  89. tree_ir.BinaryInstruction(
  90. tree_ir.ReadValueInstruction(lhs),
  91. operator,
  92. tree_ir.ReadValueInstruction(rhs))))
  93. def register_unary_intrinsic(self, name, operator):
  94. """Registers an intrinsic with the JIT that represents the given unary operation."""
  95. self.register_intrinsic(name, lambda val: tree_ir.CreateNodeWithValueInstruction(
  96. tree_ir.UnaryInstruction(
  97. operator,
  98. tree_ir.ReadValueInstruction(val))))
  99. def jit_parameters(self, body_id):
  100. """Acquires the parameter list for the given body id node."""
  101. if body_id not in self.jitted_parameters:
  102. signature_id, = yield [("RRD", [body_id, "body"])]
  103. signature_id = signature_id[0]
  104. param_set_id, = yield [("RD", [signature_id, "params"])]
  105. if param_set_id is None:
  106. self.jitted_parameters[body_id] = ([], [])
  107. else:
  108. param_name_ids, = yield [("RDK", [param_set_id])]
  109. param_names = yield [("RV", [n]) for n in param_name_ids]
  110. param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
  111. self.jitted_parameters[body_id] = (param_vars, param_names)
  112. raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])
  113. def jit_compile(self, user_root, body_id, suggested_name=None):
  114. """Tries to jit the function defined by the given entry point id and parameter list."""
  115. # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
  116. # pylint: disable=I0011,W0122
  117. if body_id in self.jitted_entry_points:
  118. # We have already compiled this function.
  119. raise primitive_functions.PrimitiveFinished(
  120. self.jit_globals[self.jitted_entry_points[body_id]])
  121. elif body_id in self.no_jit_entry_points:
  122. # We're not allowed to jit this function or have tried and failed before.
  123. raise JitCompilationFailedException(
  124. 'Cannot jit function %s at %d because it is marked non-jittable.' % (
  125. '' if suggested_name is None else "'" + suggested_name + "'",
  126. body_id))
  127. # Generate a name for the function we're about to analyze, and pretend that
  128. # it already exists. (we need to do this for recursive functions)
  129. function_name = self.generate_function_name(suggested_name)
  130. self.jitted_entry_points[body_id] = function_name
  131. self.jit_globals[function_name] = None
  132. try:
  133. gen = self.jit_parameters(body_id)
  134. inp = None
  135. while True:
  136. inp = yield gen.send(inp)
  137. except primitive_functions.PrimitiveFinished as ex:
  138. parameter_ids, parameter_list = ex.result
  139. param_dict = dict(zip(parameter_ids, parameter_list))
  140. body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
  141. dependencies = set([body_id])
  142. self.compilation_dependencies[body_id] = dependencies
  143. try:
  144. gen = AnalysisState(
  145. self, body_id, user_root, body_param_dict,
  146. self.max_instructions).analyze(body_id)
  147. inp = None
  148. while True:
  149. inp = yield gen.send(inp)
  150. except primitive_functions.PrimitiveFinished as ex:
  151. del self.compilation_dependencies[body_id]
  152. constructed_body = ex.result
  153. except JitCompilationFailedException as ex:
  154. del self.compilation_dependencies[body_id]
  155. for dep in dependencies:
  156. self.mark_no_jit(dep)
  157. if dep in self.jitted_entry_points:
  158. del self.jitted_entry_points[dep]
  159. raise JitCompilationFailedException(
  160. "%s (function '%s' at %d)" % (ex.message, function_name, body_id))
  161. # Write a prologue and prepend it to the generated function body.
  162. prologue_statements = []
  163. for (key, val) in param_dict.items():
  164. arg_ptr = tree_ir.StoreLocalInstruction(
  165. body_param_dict[key],
  166. tree_ir.CreateNodeInstruction())
  167. prologue_statements.append(arg_ptr)
  168. prologue_statements.append(
  169. tree_ir.CreateDictionaryEdgeInstruction(
  170. arg_ptr.create_load(),
  171. tree_ir.LiteralInstruction('value'),
  172. tree_ir.LoadLocalInstruction(val)))
  173. constructed_body = tree_ir.create_block(
  174. *(prologue_statements + [constructed_body]))
  175. # Wrap the IR in a function definition, give it a unique name.
  176. constructed_function = tree_ir.DefineFunctionInstruction(
  177. function_name,
  178. parameter_list + ['**' + KWARGS_PARAMETER_NAME],
  179. constructed_body.simplify())
  180. # Convert the function definition to Python code, and compile it.
  181. exec(str(constructed_function), self.jit_globals)
  182. # Extract the compiled function from the JIT global state.
  183. compiled_function = self.jit_globals[function_name]
  184. print(constructed_function)
  185. raise primitive_functions.PrimitiveFinished(compiled_function)
  186. class AnalysisState(object):
  187. """The state of a bytecode analysis call graph."""
  188. def __init__(self, jit, body_id, user_root, local_mapping, max_instructions=None):
  189. self.analyzed_instructions = set()
  190. self.function_vars = set()
  191. self.local_vars = set()
  192. self.body_id = body_id
  193. self.max_instructions = max_instructions
  194. self.user_root = user_root
  195. self.jit = jit
  196. self.local_mapping = local_mapping
  197. def get_local_name(self, local_id):
  198. """Gets the name for a local with the given id."""
  199. if local_id not in self.local_mapping:
  200. self.local_mapping[local_id] = 'local%d' % local_id
  201. return self.local_mapping[local_id]
  202. def register_local_var(self, local_id):
  203. """Registers the given variable node id as a local."""
  204. if local_id in self.function_vars:
  205. raise JitCompilationFailedException(
  206. "Local is used as target of function call.")
  207. self.local_vars.add(local_id)
  208. def register_function_var(self, local_id):
  209. """Registers the given variable node id as a function."""
  210. if local_id in self.local_vars:
  211. raise JitCompilationFailedException(
  212. "Local is used as target of function call.")
  213. self.function_vars.add(local_id)
  214. def retrieve_user_root(self):
  215. """Creates an instruction that stores the user_root variable
  216. in a local."""
  217. return tree_ir.StoreLocalInstruction(
  218. 'user_root',
  219. tree_ir.LoadIndexInstruction(
  220. tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
  221. tree_ir.LiteralInstruction('user_root')))
  222. def analyze(self, instruction_id):
  223. """Tries to build an intermediate representation from the instruction with the
  224. given id."""
  225. # Check the analyzed_instructions set for instruction_id to avoid
  226. # infinite loops.
  227. if instruction_id in self.analyzed_instructions:
  228. raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
  229. elif (self.max_instructions is not None and
  230. len(self.analyzed_instructions) > self.max_instructions):
  231. raise JitCompilationFailedException('Maximum number of instructions exceeded.')
  232. self.analyzed_instructions.add(instruction_id)
  233. instruction_val, = yield [("RV", [instruction_id])]
  234. instruction_val = instruction_val["value"]
  235. if instruction_val in self.instruction_analyzers:
  236. gen = self.instruction_analyzers[instruction_val](self, instruction_id)
  237. try:
  238. inp = None
  239. while True:
  240. inp = yield gen.send(inp)
  241. except StopIteration:
  242. raise Exception(
  243. "Instruction analyzer (for '%s') finished without returning a value!" %
  244. (instruction_val))
  245. except primitive_functions.PrimitiveFinished as outer_e:
  246. # Check if the instruction has a 'next' instruction.
  247. next_instr, = yield [("RD", [instruction_id, "next"])]
  248. if next_instr is None:
  249. raise outer_e
  250. else:
  251. gen = self.analyze(next_instr)
  252. try:
  253. inp = None
  254. while True:
  255. inp = yield gen.send(inp)
  256. except primitive_functions.PrimitiveFinished as inner_e:
  257. raise primitive_functions.PrimitiveFinished(
  258. tree_ir.CompoundInstruction(
  259. outer_e.result,
  260. inner_e.result))
  261. else:
  262. raise JitCompilationFailedException(
  263. "Unknown instruction type: '%s'" % (instruction_val))
  264. def analyze_all(self, instruction_ids):
  265. """Tries to compile a list of IR trees from the given list of instruction ids."""
  266. results = []
  267. for inst in instruction_ids:
  268. gen = self.analyze(inst)
  269. try:
  270. inp = None
  271. while True:
  272. inp = yield gen.send(inp)
  273. except primitive_functions.PrimitiveFinished as ex:
  274. results.append(ex.result)
  275. raise primitive_functions.PrimitiveFinished(results)
  276. def analyze_return(self, instruction_id):
  277. """Tries to analyze the given 'return' instruction."""
  278. retval_id, = yield [("RD", [instruction_id, 'value'])]
  279. if retval_id is None:
  280. raise primitive_functions.PrimitiveFinished(
  281. tree_ir.ReturnInstruction(
  282. tree_ir.EmptyInstruction()))
  283. else:
  284. gen = self.analyze(retval_id)
  285. try:
  286. inp = None
  287. while True:
  288. inp = yield gen.send(inp)
  289. except primitive_functions.PrimitiveFinished as ex:
  290. raise primitive_functions.PrimitiveFinished(
  291. tree_ir.ReturnInstruction(ex.result))
  292. def analyze_if(self, instruction_id):
  293. """Tries to analyze the given 'if' instruction."""
  294. cond, true, false = yield [
  295. ("RD", [instruction_id, "cond"]),
  296. ("RD", [instruction_id, "then"]),
  297. ("RD", [instruction_id, "else"])]
  298. gen = self.analyze_all(
  299. [cond, true]
  300. if false is None
  301. else [cond, true, false])
  302. try:
  303. inp = None
  304. while True:
  305. inp = yield gen.send(inp)
  306. except primitive_functions.PrimitiveFinished as ex:
  307. if false is None:
  308. cond_r, true_r = ex.result
  309. false_r = tree_ir.EmptyInstruction()
  310. else:
  311. cond_r, true_r, false_r = ex.result
  312. raise primitive_functions.PrimitiveFinished(
  313. tree_ir.SelectInstruction(
  314. tree_ir.ReadValueInstruction(cond_r),
  315. true_r,
  316. false_r))
  317. def analyze_while(self, instruction_id):
  318. """Tries to analyze the given 'while' instruction."""
  319. cond, body = yield [
  320. ("RD", [instruction_id, "cond"]),
  321. ("RD", [instruction_id, "body"])]
  322. gen = self.analyze_all([cond, body])
  323. try:
  324. inp = None
  325. while True:
  326. inp = yield gen.send(inp)
  327. except primitive_functions.PrimitiveFinished as ex:
  328. cond_r, body_r = ex.result
  329. raise primitive_functions.PrimitiveFinished(
  330. tree_ir.LoopInstruction(
  331. tree_ir.CompoundInstruction(
  332. tree_ir.SelectInstruction(
  333. tree_ir.ReadValueInstruction(cond_r),
  334. tree_ir.EmptyInstruction(),
  335. tree_ir.BreakInstruction()),
  336. body_r)))
  337. def analyze_constant(self, instruction_id):
  338. """Tries to analyze the given 'constant' (literal) instruction."""
  339. node_id, = yield [("RD", [instruction_id, "node"])]
  340. raise primitive_functions.PrimitiveFinished(
  341. tree_ir.LiteralInstruction(node_id))
  342. def analyze_output(self, instruction_id):
  343. """Tries to analyze the given 'output' instruction."""
  344. # The plan is to basically generate this tree:
  345. #
  346. # value = <some tree>
  347. # last_output, last_output_link, new_last_output = \
  348. # yield [("RD", [user_root, "last_output"]),
  349. # ("RDE", [user_root, "last_output"]),
  350. # ("CN", []),
  351. # ]
  352. # _, _, _, _ = \
  353. # yield [("CD", [last_output, "value", value]),
  354. # ("CD", [last_output, "next", new_last_output]),
  355. # ("CD", [user_root, "last_output", new_last_output]),
  356. # ("DE", [last_output_link])
  357. # ]
  358. # yield None
  359. value_id, = yield [("RD", [instruction_id, "value"])]
  360. gen = self.analyze(value_id)
  361. try:
  362. inp = None
  363. while True:
  364. inp = yield gen.send(inp)
  365. except primitive_functions.PrimitiveFinished as ex:
  366. value_local = tree_ir.StoreLocalInstruction('value', ex.result)
  367. store_user_root = self.retrieve_user_root()
  368. last_output = tree_ir.StoreLocalInstruction(
  369. 'last_output',
  370. tree_ir.ReadDictionaryValueInstruction(
  371. store_user_root.create_load(),
  372. tree_ir.LiteralInstruction('last_output')))
  373. last_output_link = tree_ir.StoreLocalInstruction(
  374. 'last_output_link',
  375. tree_ir.ReadDictionaryEdgeInstruction(
  376. store_user_root.create_load(),
  377. tree_ir.LiteralInstruction('last_output')))
  378. new_last_output = tree_ir.StoreLocalInstruction(
  379. 'new_last_output',
  380. tree_ir.CreateNodeInstruction())
  381. result = tree_ir.create_block(
  382. value_local,
  383. store_user_root,
  384. last_output,
  385. last_output_link,
  386. new_last_output,
  387. tree_ir.CreateDictionaryEdgeInstruction(
  388. last_output.create_load(),
  389. tree_ir.LiteralInstruction('value'),
  390. value_local.create_load()),
  391. tree_ir.CreateDictionaryEdgeInstruction(
  392. last_output.create_load(),
  393. tree_ir.LiteralInstruction('next'),
  394. new_last_output.create_load()),
  395. tree_ir.CreateDictionaryEdgeInstruction(
  396. store_user_root.create_load(),
  397. tree_ir.LiteralInstruction('last_output'),
  398. new_last_output.create_load()),
  399. tree_ir.DeleteEdgeInstruction(last_output_link.create_load()),
  400. tree_ir.NopInstruction())
  401. raise primitive_functions.PrimitiveFinished(result)
  402. def analyze_input(self, _):
  403. """Tries to analyze the given 'input' instruction."""
  404. # The plan is to generate this tree:
  405. #
  406. # value = None
  407. # while True:
  408. # if value is None:
  409. # yield None # nop
  410. # else:
  411. # break
  412. #
  413. # _input = yield [("RD", [user_root, "input"])]
  414. # value = yield [("RD", [_input, "value"])]
  415. #
  416. # _next = yield [("RD", [_input, "next"])]
  417. # yield [("CD", [user_root, "input", _next])]
  418. # yield [("DN", [_input])]
  419. user_root = self.retrieve_user_root()
  420. _input = tree_ir.StoreLocalInstruction(
  421. '_input',
  422. tree_ir.ReadDictionaryValueInstruction(
  423. user_root.create_load(),
  424. tree_ir.LiteralInstruction('input')))
  425. value = tree_ir.StoreLocalInstruction(
  426. 'value',
  427. tree_ir.ReadDictionaryValueInstruction(
  428. _input.create_load(),
  429. tree_ir.LiteralInstruction('value')))
  430. raise primitive_functions.PrimitiveFinished(
  431. tree_ir.CompoundInstruction(
  432. tree_ir.create_block(
  433. user_root,
  434. value.create_store(tree_ir.LiteralInstruction(None)),
  435. tree_ir.LoopInstruction(
  436. tree_ir.create_block(
  437. tree_ir.SelectInstruction(
  438. tree_ir.BinaryInstruction(
  439. value.create_load(),
  440. 'is',
  441. tree_ir.LiteralInstruction(None)),
  442. tree_ir.NopInstruction(),
  443. tree_ir.BreakInstruction()),
  444. _input,
  445. value)),
  446. tree_ir.CreateDictionaryEdgeInstruction(
  447. user_root.create_load(),
  448. tree_ir.LiteralInstruction('input'),
  449. tree_ir.ReadDictionaryValueInstruction(
  450. _input.create_load(),
  451. tree_ir.LiteralInstruction('next'))),
  452. tree_ir.DeleteNodeInstruction(_input.create_load())),
  453. value.create_load()))
  454. def analyze_resolve(self, instruction_id):
  455. """Tries to analyze the given 'resolve' instruction."""
  456. var_id, = yield [("RD", [instruction_id, "var"])]
  457. var_name, = yield [("RV", [var_id])]
  458. # To resolve a variable, we'll do something along the
  459. # lines of:
  460. #
  461. # if 'local_var' in locals():
  462. # tmp = local_var
  463. # else:
  464. # _globals, = yield [("RD", [user_root, "globals"])]
  465. # global_var, = yield [("RD", [_globals, var_name])]
  466. #
  467. # if global_var is None:
  468. # raise Exception("Runtime error: global '%s' not found" % (var_name))
  469. #
  470. # tmp = global_var
  471. name = self.get_local_name(var_id)
  472. if var_name is None:
  473. raise primitive_functions.PrimitiveFinished(
  474. tree_ir.LoadLocalInstruction(name))
  475. user_root = self.retrieve_user_root()
  476. global_var = tree_ir.StoreLocalInstruction(
  477. 'global_var',
  478. tree_ir.ReadDictionaryValueInstruction(
  479. tree_ir.ReadDictionaryValueInstruction(
  480. user_root.create_load(),
  481. tree_ir.LiteralInstruction('globals')),
  482. tree_ir.LiteralInstruction(var_name)))
  483. err_block = tree_ir.SelectInstruction(
  484. tree_ir.BinaryInstruction(
  485. global_var.create_load(),
  486. 'is',
  487. tree_ir.LiteralInstruction(None)),
  488. tree_ir.RaiseInstruction(
  489. tree_ir.CallInstruction(
  490. tree_ir.LoadGlobalInstruction('Exception'),
  491. [tree_ir.LiteralInstruction(
  492. "Runtime error: global '%s' not found" % var_name)
  493. ])),
  494. tree_ir.EmptyInstruction())
  495. raise primitive_functions.PrimitiveFinished(
  496. tree_ir.SelectInstruction(
  497. tree_ir.LocalExistsInstruction(name),
  498. tree_ir.LoadLocalInstruction(name),
  499. tree_ir.CompoundInstruction(
  500. tree_ir.create_block(
  501. user_root,
  502. global_var,
  503. err_block),
  504. global_var.create_load())))
  505. def analyze_declare(self, instruction_id):
  506. """Tries to analyze the given 'declare' function."""
  507. var_id, = yield [("RD", [instruction_id, "var"])]
  508. self.register_local_var(var_id)
  509. name = self.get_local_name(var_id)
  510. # The following logic declares a local:
  511. #
  512. # if 'local_name' not in locals():
  513. # local_name, = yield [("CN", [])]
  514. raise primitive_functions.PrimitiveFinished(
  515. tree_ir.SelectInstruction(
  516. tree_ir.LocalExistsInstruction(name),
  517. tree_ir.EmptyInstruction(),
  518. tree_ir.StoreLocalInstruction(
  519. name,
  520. tree_ir.CreateNodeInstruction())))
  521. def analyze_global(self, instruction_id):
  522. """Tries to analyze the given 'global' (declaration) instruction."""
  523. var_id, = yield [("RD", [instruction_id, "var"])]
  524. var_name, = yield [("RV", [var_id])]
  525. # To resolve a variable, we'll do something along the
  526. # lines of:
  527. #
  528. # _globals, = yield [("RD", [user_root, "globals"])]
  529. # global_var = yield [("RD", [_globals, var_name])]
  530. #
  531. # if global_var is None:
  532. # global_var, = yield [("CN", [])]
  533. # yield [("CD", [_globals, var_name, global_var])]
  534. #
  535. # tmp = global_var
  536. user_root = self.retrieve_user_root()
  537. _globals = tree_ir.StoreLocalInstruction(
  538. '_globals',
  539. tree_ir.ReadDictionaryValueInstruction(
  540. user_root.create_load(),
  541. tree_ir.LiteralInstruction('globals')))
  542. global_var = tree_ir.StoreLocalInstruction(
  543. 'global_var',
  544. tree_ir.ReadDictionaryValueInstruction(
  545. _globals.create_load(),
  546. tree_ir.LiteralInstruction(var_name)))
  547. raise primitive_functions.PrimitiveFinished(
  548. tree_ir.CompoundInstruction(
  549. tree_ir.create_block(
  550. user_root,
  551. _globals,
  552. global_var,
  553. tree_ir.SelectInstruction(
  554. tree_ir.BinaryInstruction(
  555. global_var.create_load(),
  556. 'is',
  557. tree_ir.LiteralInstruction(None)),
  558. tree_ir.create_block(
  559. global_var.create_store(
  560. tree_ir.CreateNodeInstruction()),
  561. tree_ir.CreateDictionaryEdgeInstruction(
  562. _globals.create_load(),
  563. tree_ir.LiteralInstruction(var_name),
  564. global_var.create_load())),
  565. tree_ir.EmptyInstruction())),
  566. global_var.create_load()))
  567. def analyze_assign(self, instruction_id):
  568. """Tries to analyze the given 'assign' instruction."""
  569. var_id, value_id = yield [("RD", [instruction_id, "var"]),
  570. ("RD", [instruction_id, "value"])]
  571. try:
  572. gen = self.analyze_all([var_id, value_id])
  573. inp = None
  574. while True:
  575. inp = yield gen.send(inp)
  576. except primitive_functions.PrimitiveFinished as ex:
  577. var_r, value_r = ex.result
  578. # Assignments work like this:
  579. #
  580. # value_link = yield [("RDE", [variable, "value"])]
  581. # _, _ = yield [("CD", [variable, "value", value]),
  582. # ("DE", [value_link])]
  583. variable = tree_ir.StoreLocalInstruction('variable', var_r)
  584. value = tree_ir.StoreLocalInstruction('value', value_r)
  585. value_link = tree_ir.StoreLocalInstruction(
  586. 'value_link',
  587. tree_ir.ReadDictionaryEdgeInstruction(
  588. variable.create_load(),
  589. tree_ir.LiteralInstruction('value')))
  590. raise primitive_functions.PrimitiveFinished(
  591. tree_ir.create_block(
  592. variable,
  593. value,
  594. value_link,
  595. tree_ir.CreateDictionaryEdgeInstruction(
  596. variable.create_load(),
  597. tree_ir.LiteralInstruction('value'),
  598. value.create_load()),
  599. tree_ir.DeleteEdgeInstruction(
  600. value_link.create_load())))
  601. def analyze_access(self, instruction_id):
  602. """Tries to analyze the given 'access' instruction."""
  603. var_id, = yield [("RD", [instruction_id, "var"])]
  604. try:
  605. gen = self.analyze(var_id)
  606. inp = None
  607. while True:
  608. inp = yield gen.send(inp)
  609. except primitive_functions.PrimitiveFinished as ex:
  610. var_r = ex.result
  611. # Accessing a variable is pretty easy. It really just boils
  612. # down to reading the value corresponding to the 'value' key
  613. # of the variable.
  614. #
  615. # value, = yield [("RD", [returnvalue, "value"])]
  616. raise primitive_functions.PrimitiveFinished(
  617. tree_ir.ReadDictionaryValueInstruction(
  618. var_r,
  619. tree_ir.LiteralInstruction('value')))
  620. def analyze_direct_call(self, callee_id, callee_name, first_parameter_id):
  621. """Tries to analyze a direct 'call' instruction."""
  622. self.register_function_var(callee_id)
  623. body_id, = yield [("RD", [callee_id, "body"])]
  624. # Make this function dependent on the callee.
  625. if body_id in self.jit.compilation_dependencies:
  626. self.jit.compilation_dependencies[body_id].add(self.body_id)
  627. # Analyze the parameter list.
  628. try:
  629. gen = self.jit.jit_parameters(body_id)
  630. inp = None
  631. while True:
  632. inp = yield gen.send(inp)
  633. except primitive_functions.PrimitiveFinished as ex:
  634. _, parameter_names = ex.result
  635. is_intrinsic = callee_name in self.jit.jit_intrinsics
  636. if not is_intrinsic:
  637. compiled_func = self.jit.lookup_compiled_function(callee_name)
  638. if compiled_func is None:
  639. # Compile the callee.
  640. try:
  641. gen = self.jit.jit_compile(self.user_root, body_id, callee_name)
  642. inp = None
  643. while True:
  644. inp = yield gen.send(inp)
  645. except primitive_functions.PrimitiveFinished as ex:
  646. pass
  647. else:
  648. self.jit.register_compiled(body_id, compiled_func, callee_name)
  649. # Get the callee's name.
  650. compiled_func_name = self.jit.get_compiled_name(body_id)
  651. # Analyze the argument dictionary.
  652. try:
  653. gen = self.analyze_argument_dict(first_parameter_id)
  654. inp = None
  655. while True:
  656. inp = yield gen.send(inp)
  657. except primitive_functions.PrimitiveFinished as ex:
  658. arg_dict = ex.result
  659. # Construct the argument list from the parameter list and
  660. # argument dictionary.
  661. arg_list = []
  662. for param_name in parameter_names:
  663. if param_name in arg_dict:
  664. arg_list.append(arg_dict[param_name])
  665. else:
  666. raise JitCompilationFailedException(
  667. "Cannot JIT-compile function call to '%s' with missing argument for "
  668. "formal parameter '%s'." % (callee_name, param_name))
  669. if is_intrinsic:
  670. raise primitive_functions.PrimitiveFinished(
  671. self.jit.jit_intrinsics[callee_name](*arg_list))
  672. else:
  673. raise primitive_functions.PrimitiveFinished(
  674. tree_ir.JitCallInstruction(
  675. tree_ir.LoadGlobalInstruction(compiled_func_name),
  676. arg_list,
  677. tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
  678. def analyze_argument_dict(self, first_argument_id):
  679. """Analyzes the parameter-to-argument mapping started by the specified first argument
  680. node."""
  681. next_param = first_argument_id
  682. argument_dict = {}
  683. while next_param is not None:
  684. param_name_id, = yield [("RD", [next_param, "name"])]
  685. param_name, = yield [("RV", [param_name_id])]
  686. param_val_id, = yield [("RD", [next_param, "value"])]
  687. try:
  688. gen = self.analyze(param_val_id)
  689. inp = None
  690. while True:
  691. inp = yield gen.send(inp)
  692. except primitive_functions.PrimitiveFinished as ex:
  693. argument_dict[param_name] = ex.result
  694. next_param, = yield [("RD", [next_param, "next_param"])]
  695. raise primitive_functions.PrimitiveFinished(argument_dict)
  696. def analyze_call(self, instruction_id):
  697. """Tries to analyze the given 'call' instruction."""
  698. func_id, first_param_id, = yield [("RD", [instruction_id, "func"]),
  699. ("RD", [instruction_id, "params"])]
  700. # Figure out what the 'func' instruction's type is.
  701. func_instruction_op, = yield [("RV", [func_id])]
  702. if func_instruction_op['value'] == 'access':
  703. # Calls to 'access(resolve(var))' instructions are translated to direct calls.
  704. access_value_id, = yield [("RD", [func_id, "var"])]
  705. access_value_op, = yield [("RV", [access_value_id])]
  706. if access_value_op['value'] == 'resolve':
  707. resolved_var_id, = yield [("RD", [access_value_id, "var"])]
  708. resolved_var_name, = yield [("RV", [resolved_var_id])]
  709. # Try to look the name up as a global.
  710. _globals, = yield [("RD", [self.user_root, "globals"])]
  711. global_var, = yield [("RD", [_globals, resolved_var_name])]
  712. global_val, = yield [("RD", [global_var, "value"])]
  713. if global_val is None:
  714. raise JitCompilationFailedException(
  715. "Cannot JIT function calls that target an unknown value.")
  716. else:
  717. gen = self.analyze_direct_call(
  718. global_val, resolved_var_name, first_param_id)
  719. inp = None
  720. while True:
  721. inp = yield gen.send(inp)
  722. # PrimitiveFinished exception will bubble up from here.
  723. raise JitCompilationFailedException("Cannot JIT indirect function calls yet.")
  724. instruction_analyzers = {
  725. 'if' : analyze_if,
  726. 'while' : analyze_while,
  727. 'return' : analyze_return,
  728. 'constant' : analyze_constant,
  729. 'resolve' : analyze_resolve,
  730. 'declare' : analyze_declare,
  731. 'global' : analyze_global,
  732. 'assign' : analyze_assign,
  733. 'access' : analyze_access,
  734. 'output' : analyze_output,
  735. 'input' : analyze_input,
  736. 'call' : analyze_call
  737. }