jit.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778
  1. import modelverse_kernel.primitives as primitive_functions
  2. import modelverse_jit.tree_ir as tree_ir
  3. KWARGS_PARAMETER_NAME = "remainder"
  4. """The name of the kwargs parameter in jitted functions."""
  5. class JitCompilationFailedException(Exception):
  6. """A type of exception that is raised when the jit fails to compile a function."""
  7. pass
  8. class ModelverseJit(object):
  9. """A high-level interface to the modelverse JIT compiler."""
  10. def __init__(self, max_instructions=None):
  11. self.todo_entry_points = set()
  12. self.no_jit_entry_points = set()
  13. self.jitted_entry_points = {}
  14. self.jitted_parameters = {}
  15. self.jit_globals = {
  16. 'PrimitiveFinished' : primitive_functions.PrimitiveFinished
  17. }
  18. self.jit_count = 0
  19. self.max_instructions = 30 if max_instructions is None else max_instructions
  20. self.jit_enabled = True
  21. def set_jit_enabled(self, is_enabled=True):
  22. """Enables or disables the JIT."""
  23. self.jit_enabled = is_enabled
  24. def mark_entry_point(self, body_id):
  25. """Marks the node with the given identifier as a function entry point."""
  26. if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
  27. self.todo_entry_points.add(body_id)
  28. def is_entry_point(self, body_id):
  29. """Tells if the node with the given identifier is a function entry point."""
  30. return body_id in self.todo_entry_points or \
  31. body_id in self.no_jit_entry_points or \
  32. body_id in self.jitted_entry_points
  33. def is_jittable_entry_point(self, body_id):
  34. """Tells if the node with the given identifier is a function entry point that
  35. has not been marked as non-jittable. This only returns `True` if the JIT
  36. is enabled and the function entry point has been marked jittable, or if
  37. the function has already been compiled."""
  38. return ((self.jit_enabled and body_id in self.todo_entry_points) or
  39. self.has_compiled(body_id))
  40. def has_compiled(self, body_id):
  41. """Tests if the function belonging to the given body node has been compiled yet."""
  42. return body_id in self.jitted_entry_points
  43. def get_compiled_name(self, body_id):
  44. """Gets the name of the compiled version of the given body node in the JIT
  45. global state."""
  46. return self.jitted_entry_points[body_id]
  47. def mark_no_jit(self, body_id):
  48. """Informs the JIT that the node with the given identifier is a function entry
  49. point that must never be jitted."""
  50. self.no_jit_entry_points.add(body_id)
  51. if body_id in self.todo_entry_points:
  52. self.todo_entry_points.remove(body_id)
  53. def generate_function_name(self):
  54. """Generates a new function name,"""
  55. function_name = 'jit_func%d' % self.jit_count
  56. self.jit_count += 1
  57. return function_name
  58. def register_compiled(self, body_id, compiled_function, function_name=None):
  59. """Registers a compiled entry point with the JIT."""
  60. if function_name is None:
  61. function_name = self.generate_function_name()
  62. self.jitted_entry_points[body_id] = function_name
  63. self.jit_globals[function_name] = compiled_function
  64. if body_id in self.todo_entry_points:
  65. self.todo_entry_points.remove(body_id)
  66. def jit_parameters(self, body_id):
  67. """Acquires the parameter list for the given body id node."""
  68. if body_id not in self.jitted_parameters:
  69. signature_id, = yield [("RRD", [body_id, "body"])]
  70. signature_id = signature_id[0]
  71. param_set_id, = yield [("RD", [signature_id, "params"])]
  72. if param_set_id is None:
  73. self.jitted_parameters[body_id] = ([], [])
  74. else:
  75. param_name_ids, = yield [("RDK", [param_set_id])]
  76. param_names = yield [("RV", [n]) for n in param_name_ids]
  77. param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
  78. self.jitted_parameters[body_id] = (param_vars, param_names)
  79. raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])
  80. def jit_compile(self, user_root, body_id):
  81. """Tries to jit the function defined by the given entry point id and parameter list."""
  82. # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
  83. # pylint: disable=I0011,W0122
  84. if body_id in self.jitted_entry_points:
  85. # We have already compiled this function.
  86. raise primitive_functions.PrimitiveFinished(
  87. self.jit_globals[self.jitted_entry_points[body_id]])
  88. elif body_id in self.no_jit_entry_points:
  89. # We're not allowed to jit this function or have tried and failed before.
  90. raise JitCompilationFailedException(
  91. 'Cannot jit function at %d because it is marked non-jittable.' % body_id)
  92. # Generate a name for the function we're about to analyze, and pretend that
  93. # it already exists. (we need to do this for recursive functions)
  94. function_name = self.generate_function_name()
  95. self.jitted_entry_points[body_id] = function_name
  96. self.jit_globals[function_name] = None
  97. try:
  98. gen = self.jit_parameters(body_id)
  99. inp = None
  100. while True:
  101. inp = yield gen.send(inp)
  102. except primitive_functions.PrimitiveFinished as ex:
  103. parameter_ids, parameter_list = ex.result
  104. param_dict = dict(zip(parameter_ids, parameter_list))
  105. body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
  106. try:
  107. gen = AnalysisState(
  108. self, user_root, body_param_dict, self.max_instructions).analyze(body_id)
  109. inp = None
  110. while True:
  111. inp = yield gen.send(inp)
  112. except primitive_functions.PrimitiveFinished as ex:
  113. constructed_body = ex.result
  114. except JitCompilationFailedException as ex:
  115. self.mark_no_jit(body_id)
  116. del self.jitted_entry_points[body_id]
  117. raise JitCompilationFailedException(
  118. '%s (function at %d)' % (ex.message, body_id))
  119. # Write a prologue and prepend it to the generated function body.
  120. prologue_statements = []
  121. for (key, val) in param_dict.items():
  122. arg_ptr = tree_ir.StoreLocalInstruction(
  123. body_param_dict[key],
  124. tree_ir.CreateNodeInstruction())
  125. prologue_statements.append(arg_ptr)
  126. prologue_statements.append(
  127. tree_ir.CreateDictionaryEdgeInstruction(
  128. arg_ptr.create_load(),
  129. tree_ir.LiteralInstruction('value'),
  130. tree_ir.LoadLocalInstruction(val)))
  131. constructed_body = tree_ir.create_block(
  132. *(prologue_statements + [constructed_body]))
  133. # Wrap the IR in a function definition, give it a unique name.
  134. constructed_function = tree_ir.DefineFunctionInstruction(
  135. function_name,
  136. parameter_list + ['**' + KWARGS_PARAMETER_NAME],
  137. constructed_body.simplify())
  138. # Convert the function definition to Python code, and compile it.
  139. exec(str(constructed_function), self.jit_globals)
  140. # Extract the compiled function from the JIT global state.
  141. compiled_function = self.jit_globals[function_name]
  142. # print(constructed_function)
  143. raise primitive_functions.PrimitiveFinished(compiled_function)
  144. class AnalysisState(object):
  145. """The state of a bytecode analysis call graph."""
  146. def __init__(self, jit, user_root, local_mapping, max_instructions=None):
  147. self.analyzed_instructions = set()
  148. self.function_vars = set()
  149. self.local_vars = set()
  150. self.max_instructions = max_instructions
  151. self.user_root = user_root
  152. self.jit = jit
  153. self.local_mapping = local_mapping
  154. def get_local_name(self, local_id):
  155. """Gets the name for a local with the given id."""
  156. if local_id not in self.local_mapping:
  157. self.local_mapping[local_id] = 'local%d' % local_id
  158. return self.local_mapping[local_id]
  159. def register_local_var(self, local_id):
  160. """Registers the given variable node id as a local."""
  161. if local_id in self.function_vars:
  162. raise JitCompilationFailedException(
  163. "Local is used as target of function call.")
  164. self.local_vars.add(local_id)
  165. def register_function_var(self, local_id):
  166. """Registers the given variable node id as a function."""
  167. if local_id in self.local_vars:
  168. raise JitCompilationFailedException(
  169. "Local is used as target of function call.")
  170. self.function_vars.add(local_id)
  171. def retrieve_user_root(self):
  172. """Creates an instruction that stores the user_root variable
  173. in a local."""
  174. return tree_ir.StoreLocalInstruction(
  175. 'user_root',
  176. tree_ir.LoadIndexInstruction(
  177. tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
  178. tree_ir.LiteralInstruction('user_root')))
  179. def analyze(self, instruction_id):
  180. """Tries to build an intermediate representation from the instruction with the
  181. given id."""
  182. # Check the analyzed_instructions set for instruction_id to avoid
  183. # infinite loops.
  184. if instruction_id in self.analyzed_instructions:
  185. raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
  186. elif (self.max_instructions is not None and
  187. len(self.analyzed_instructions) > self.max_instructions):
  188. raise JitCompilationFailedException('Maximum number of instructions exceeded.')
  189. self.analyzed_instructions.add(instruction_id)
  190. instruction_val, = yield [("RV", [instruction_id])]
  191. instruction_val = instruction_val["value"]
  192. if instruction_val in self.instruction_analyzers:
  193. gen = self.instruction_analyzers[instruction_val](self, instruction_id)
  194. try:
  195. inp = None
  196. while True:
  197. inp = yield gen.send(inp)
  198. except StopIteration:
  199. raise Exception(
  200. "Instruction analyzer (for '%s') finished without returning a value!" %
  201. (instruction_val))
  202. except primitive_functions.PrimitiveFinished as outer_e:
  203. # Check if the instruction has a 'next' instruction.
  204. next_instr, = yield [("RD", [instruction_id, "next"])]
  205. if next_instr is None:
  206. raise outer_e
  207. else:
  208. gen = self.analyze(next_instr)
  209. try:
  210. inp = None
  211. while True:
  212. inp = yield gen.send(inp)
  213. except primitive_functions.PrimitiveFinished as inner_e:
  214. raise primitive_functions.PrimitiveFinished(
  215. tree_ir.CompoundInstruction(
  216. outer_e.result,
  217. inner_e.result))
  218. else:
  219. raise JitCompilationFailedException(
  220. "Unknown instruction type: '%s'" % (instruction_val))
  221. def analyze_all(self, instruction_ids):
  222. """Tries to compile a list of IR trees from the given list of instruction ids."""
  223. results = []
  224. for inst in instruction_ids:
  225. gen = self.analyze(inst)
  226. try:
  227. inp = None
  228. while True:
  229. inp = yield gen.send(inp)
  230. except primitive_functions.PrimitiveFinished as ex:
  231. results.append(ex.result)
  232. raise primitive_functions.PrimitiveFinished(results)
  233. def analyze_return(self, instruction_id):
  234. """Tries to analyze the given 'return' instruction."""
  235. retval_id, = yield [("RD", [instruction_id, 'value'])]
  236. if retval_id is None:
  237. raise primitive_functions.PrimitiveFinished(
  238. tree_ir.ReturnInstruction(
  239. tree_ir.EmptyInstruction()))
  240. else:
  241. gen = self.analyze(retval_id)
  242. try:
  243. inp = None
  244. while True:
  245. inp = yield gen.send(inp)
  246. except primitive_functions.PrimitiveFinished as ex:
  247. raise primitive_functions.PrimitiveFinished(
  248. tree_ir.ReturnInstruction(ex.result))
  249. def analyze_if(self, instruction_id):
  250. """Tries to analyze the given 'if' instruction."""
  251. cond, true, false = yield [
  252. ("RD", [instruction_id, "cond"]),
  253. ("RD", [instruction_id, "then"]),
  254. ("RD", [instruction_id, "else"])]
  255. gen = self.analyze_all(
  256. [cond, true]
  257. if false is None
  258. else [cond, true, false])
  259. try:
  260. inp = None
  261. while True:
  262. inp = yield gen.send(inp)
  263. except primitive_functions.PrimitiveFinished as ex:
  264. if false is None:
  265. cond_r, true_r = ex.result
  266. false_r = tree_ir.EmptyInstruction()
  267. else:
  268. cond_r, true_r, false_r = ex.result
  269. raise primitive_functions.PrimitiveFinished(
  270. tree_ir.SelectInstruction(
  271. tree_ir.ReadValueInstruction(cond_r),
  272. true_r,
  273. false_r))
  274. def analyze_while(self, instruction_id):
  275. """Tries to analyze the given 'while' instruction."""
  276. cond, body = yield [
  277. ("RD", [instruction_id, "cond"]),
  278. ("RD", [instruction_id, "body"])]
  279. gen = self.analyze_all([cond, body])
  280. try:
  281. inp = None
  282. while True:
  283. inp = yield gen.send(inp)
  284. except primitive_functions.PrimitiveFinished as ex:
  285. cond_r, body_r = ex.result
  286. raise primitive_functions.PrimitiveFinished(
  287. tree_ir.LoopInstruction(
  288. tree_ir.CompoundInstruction(
  289. tree_ir.SelectInstruction(
  290. tree_ir.ReadValueInstruction(cond_r),
  291. tree_ir.EmptyInstruction(),
  292. tree_ir.BreakInstruction()),
  293. body_r)))
  294. def analyze_constant(self, instruction_id):
  295. """Tries to analyze the given 'constant' (literal) instruction."""
  296. node_id, = yield [("RD", [instruction_id, "node"])]
  297. raise primitive_functions.PrimitiveFinished(
  298. tree_ir.LiteralInstruction(node_id))
  299. def analyze_output(self, instruction_id):
  300. """Tries to analyze the given 'output' instruction."""
  301. # The plan is to basically generate this tree:
  302. #
  303. # value = <some tree>
  304. # last_output, last_output_link, new_last_output = \
  305. # yield [("RD", [user_root, "last_output"]),
  306. # ("RDE", [user_root, "last_output"]),
  307. # ("CN", []),
  308. # ]
  309. # _, _, _, _ = \
  310. # yield [("CD", [last_output, "value", value]),
  311. # ("CD", [last_output, "next", new_last_output]),
  312. # ("CD", [user_root, "last_output", new_last_output]),
  313. # ("DE", [last_output_link])
  314. # ]
  315. # yield None
  316. value_id, = yield [("RD", [instruction_id, "value"])]
  317. gen = self.analyze(value_id)
  318. try:
  319. inp = None
  320. while True:
  321. inp = yield gen.send(inp)
  322. except primitive_functions.PrimitiveFinished as ex:
  323. value_local = tree_ir.StoreLocalInstruction('value', ex.result)
  324. store_user_root = self.retrieve_user_root()
  325. last_output = tree_ir.StoreLocalInstruction(
  326. 'last_output',
  327. tree_ir.ReadDictionaryValueInstruction(
  328. store_user_root.create_load(),
  329. tree_ir.LiteralInstruction('last_output')))
  330. last_output_link = tree_ir.StoreLocalInstruction(
  331. 'last_output_link',
  332. tree_ir.ReadDictionaryEdgeInstruction(
  333. store_user_root.create_load(),
  334. tree_ir.LiteralInstruction('last_output')))
  335. new_last_output = tree_ir.StoreLocalInstruction(
  336. 'new_last_output',
  337. tree_ir.CreateNodeInstruction())
  338. result = tree_ir.create_block(
  339. value_local,
  340. store_user_root,
  341. last_output,
  342. last_output_link,
  343. new_last_output,
  344. tree_ir.CreateDictionaryEdgeInstruction(
  345. last_output.create_load(),
  346. tree_ir.LiteralInstruction('value'),
  347. value_local.create_load()),
  348. tree_ir.CreateDictionaryEdgeInstruction(
  349. last_output.create_load(),
  350. tree_ir.LiteralInstruction('next'),
  351. new_last_output.create_load()),
  352. tree_ir.CreateDictionaryEdgeInstruction(
  353. store_user_root.create_load(),
  354. tree_ir.LiteralInstruction('last_output'),
  355. new_last_output.create_load()),
  356. tree_ir.DeleteEdgeInstruction(last_output_link.create_load()),
  357. tree_ir.NopInstruction())
  358. raise primitive_functions.PrimitiveFinished(result)
  359. def analyze_input(self, _):
  360. """Tries to analyze the given 'input' instruction."""
  361. # The plan is to generate this tree:
  362. #
  363. # value = None
  364. # while True:
  365. # if value is None:
  366. # yield None # nop
  367. # else:
  368. # break
  369. #
  370. # _input = yield [("RD", [user_root, "input"])]
  371. # value = yield [("RD", [_input, "value"])]
  372. #
  373. # _next = yield [("RD", [_input, "next"])]
  374. # yield [("CD", [user_root, "input", _next])]
  375. # yield [("DN", [_input])]
  376. user_root = self.retrieve_user_root()
  377. _input = tree_ir.StoreLocalInstruction(
  378. '_input',
  379. tree_ir.ReadDictionaryValueInstruction(
  380. user_root.create_load(),
  381. tree_ir.LiteralInstruction('input')))
  382. value = tree_ir.StoreLocalInstruction(
  383. 'value',
  384. tree_ir.ReadDictionaryValueInstruction(
  385. _input.create_load(),
  386. tree_ir.LiteralInstruction('value')))
  387. raise primitive_functions.PrimitiveFinished(
  388. tree_ir.CompoundInstruction(
  389. tree_ir.create_block(
  390. user_root,
  391. value.create_store(tree_ir.LiteralInstruction(None)),
  392. tree_ir.LoopInstruction(
  393. tree_ir.create_block(
  394. tree_ir.SelectInstruction(
  395. tree_ir.BinaryInstruction(
  396. value.create_load(),
  397. 'is',
  398. tree_ir.LiteralInstruction(None)),
  399. tree_ir.NopInstruction(),
  400. tree_ir.BreakInstruction()),
  401. _input,
  402. value)),
  403. tree_ir.CreateDictionaryEdgeInstruction(
  404. user_root.create_load(),
  405. tree_ir.LiteralInstruction('input'),
  406. tree_ir.ReadDictionaryValueInstruction(
  407. _input.create_load(),
  408. tree_ir.LiteralInstruction('next'))),
  409. tree_ir.DeleteNodeInstruction(_input.create_load())),
  410. value.create_load()))
  411. def analyze_resolve(self, instruction_id):
  412. """Tries to analyze the given 'resolve' instruction."""
  413. var_id, = yield [("RD", [instruction_id, "var"])]
  414. var_name, = yield [("RV", [var_id])]
  415. # To resolve a variable, we'll do something along the
  416. # lines of:
  417. #
  418. # if 'local_var' in locals():
  419. # tmp = local_var
  420. # else:
  421. # _globals, = yield [("RD", [user_root, "globals"])]
  422. # global_var, = yield [("RD", [_globals, var_name])]
  423. #
  424. # if global_var is None:
  425. # raise Exception("Runtime error: global '%s' not found" % (var_name))
  426. #
  427. # tmp = global_var
  428. name = self.get_local_name(var_id)
  429. if var_name is None:
  430. raise primitive_functions.PrimitiveFinished(
  431. tree_ir.LoadLocalInstruction(name))
  432. user_root = self.retrieve_user_root()
  433. global_var = tree_ir.StoreLocalInstruction(
  434. 'global_var',
  435. tree_ir.ReadDictionaryValueInstruction(
  436. tree_ir.ReadDictionaryValueInstruction(
  437. user_root.create_load(),
  438. tree_ir.LiteralInstruction('globals')),
  439. tree_ir.LiteralInstruction(var_name)))
  440. err_block = tree_ir.SelectInstruction(
  441. tree_ir.BinaryInstruction(
  442. global_var.create_load(),
  443. 'is',
  444. tree_ir.LiteralInstruction(None)),
  445. tree_ir.RaiseInstruction(
  446. tree_ir.CallInstruction(
  447. tree_ir.LoadGlobalInstruction('Exception'),
  448. [tree_ir.LiteralInstruction(
  449. "Runtime error: global '%s' not found" % var_name)
  450. ])),
  451. tree_ir.EmptyInstruction())
  452. raise primitive_functions.PrimitiveFinished(
  453. tree_ir.SelectInstruction(
  454. tree_ir.LocalExistsInstruction(name),
  455. tree_ir.LoadLocalInstruction(name),
  456. tree_ir.CompoundInstruction(
  457. tree_ir.create_block(
  458. user_root,
  459. global_var,
  460. err_block),
  461. global_var.create_load())))
  462. def analyze_declare(self, instruction_id):
  463. """Tries to analyze the given 'declare' function."""
  464. var_id, = yield [("RD", [instruction_id, "var"])]
  465. self.register_local_var(var_id)
  466. name = self.get_local_name(var_id)
  467. # The following logic declares a local:
  468. #
  469. # if 'local_name' not in locals():
  470. # local_name, = yield [("CN", [])]
  471. raise primitive_functions.PrimitiveFinished(
  472. tree_ir.SelectInstruction(
  473. tree_ir.LocalExistsInstruction(name),
  474. tree_ir.EmptyInstruction(),
  475. tree_ir.StoreLocalInstruction(
  476. name,
  477. tree_ir.CreateNodeInstruction())))
  478. def analyze_global(self, instruction_id):
  479. """Tries to analyze the given 'global' (declaration) instruction."""
  480. var_id, = yield [("RD", [instruction_id, "var"])]
  481. var_name, = yield [("RV", [var_id])]
  482. # To resolve a variable, we'll do something along the
  483. # lines of:
  484. #
  485. # _globals, = yield [("RD", [user_root, "globals"])]
  486. # global_var = yield [("RD", [_globals, var_name])]
  487. #
  488. # if global_var is None:
  489. # global_var, = yield [("CN", [])]
  490. # yield [("CD", [_globals, var_name, global_var])]
  491. #
  492. # tmp = global_var
  493. user_root = self.retrieve_user_root()
  494. _globals = tree_ir.StoreLocalInstruction(
  495. '_globals',
  496. tree_ir.ReadDictionaryValueInstruction(
  497. user_root.create_load(),
  498. tree_ir.LiteralInstruction('globals')))
  499. global_var = tree_ir.StoreLocalInstruction(
  500. 'global_var',
  501. tree_ir.ReadDictionaryValueInstruction(
  502. _globals.create_load(),
  503. tree_ir.LiteralInstruction(var_name)))
  504. raise primitive_functions.PrimitiveFinished(
  505. tree_ir.CompoundInstruction(
  506. tree_ir.create_block(
  507. user_root,
  508. _globals,
  509. global_var,
  510. tree_ir.SelectInstruction(
  511. tree_ir.BinaryInstruction(
  512. global_var.create_load(),
  513. 'is',
  514. tree_ir.LiteralInstruction(None)),
  515. tree_ir.create_block(
  516. global_var.create_store(
  517. tree_ir.CreateNodeInstruction()),
  518. tree_ir.CreateDictionaryEdgeInstruction(
  519. _globals.create_load(),
  520. tree_ir.LiteralInstruction(var_name),
  521. global_var.create_load())),
  522. tree_ir.EmptyInstruction())),
  523. global_var.create_load()))
  524. def analyze_assign(self, instruction_id):
  525. """Tries to analyze the given 'assign' instruction."""
  526. var_id, value_id = yield [("RD", [instruction_id, "var"]),
  527. ("RD", [instruction_id, "value"])]
  528. try:
  529. gen = self.analyze_all([var_id, value_id])
  530. inp = None
  531. while True:
  532. inp = yield gen.send(inp)
  533. except primitive_functions.PrimitiveFinished as ex:
  534. var_r, value_r = ex.result
  535. # Assignments work like this:
  536. #
  537. # value_link = yield [("RDE", [variable, "value"])]
  538. # _, _ = yield [("CD", [variable, "value", value]),
  539. # ("DE", [value_link])]
  540. variable = tree_ir.StoreLocalInstruction('variable', var_r)
  541. value = tree_ir.StoreLocalInstruction('value', value_r)
  542. value_link = tree_ir.StoreLocalInstruction(
  543. 'value_link',
  544. tree_ir.ReadDictionaryEdgeInstruction(
  545. variable.create_load(),
  546. tree_ir.LiteralInstruction('value')))
  547. raise primitive_functions.PrimitiveFinished(
  548. tree_ir.create_block(
  549. variable,
  550. value,
  551. value_link,
  552. tree_ir.CreateDictionaryEdgeInstruction(
  553. variable.create_load(),
  554. tree_ir.LiteralInstruction('value'),
  555. value.create_load()),
  556. tree_ir.DeleteEdgeInstruction(
  557. value_link.create_load())))
  558. def analyze_access(self, instruction_id):
  559. """Tries to analyze the given 'access' instruction."""
  560. var_id, = yield [("RD", [instruction_id, "var"])]
  561. try:
  562. gen = self.analyze(var_id)
  563. inp = None
  564. while True:
  565. inp = yield gen.send(inp)
  566. except primitive_functions.PrimitiveFinished as ex:
  567. var_r = ex.result
  568. # Accessing a variable is pretty easy. It really just boils
  569. # down to reading the value corresponding to the 'value' key
  570. # of the variable.
  571. #
  572. # value, = yield [("RD", [returnvalue, "value"])]
  573. raise primitive_functions.PrimitiveFinished(
  574. tree_ir.ReadDictionaryValueInstruction(
  575. var_r,
  576. tree_ir.LiteralInstruction('value')))
  577. def analyze_direct_call(self, callee_id, callee_name, first_parameter_id):
  578. """Tries to analyze a direct 'call' instruction."""
  579. self.register_function_var(callee_id)
  580. body_id, = yield [("RD", [callee_id, "body"])]
  581. # Analyze the parameter list.
  582. try:
  583. gen = self.jit.jit_parameters(body_id)
  584. inp = None
  585. while True:
  586. inp = yield gen.send(inp)
  587. except primitive_functions.PrimitiveFinished as ex:
  588. _, parameter_names = ex.result
  589. # Compile the callee.
  590. try:
  591. gen = self.jit.jit_compile(self.user_root, body_id)
  592. inp = None
  593. while True:
  594. inp = yield gen.send(inp)
  595. except primitive_functions.PrimitiveFinished as ex:
  596. pass
  597. # Get the callee's name.
  598. compiled_func_name = self.jit.get_compiled_name(body_id)
  599. # Analyze the argument dictionary.
  600. try:
  601. gen = self.analyze_argument_dict(first_parameter_id)
  602. inp = None
  603. while True:
  604. inp = yield gen.send(inp)
  605. except primitive_functions.PrimitiveFinished as ex:
  606. arg_dict = ex.result
  607. # Construct the argument list from the parameter list and
  608. # argument dictionary.
  609. arg_list = []
  610. for param_name in parameter_names:
  611. if param_name in arg_dict:
  612. arg_list.append(arg_dict[param_name])
  613. else:
  614. raise JitCompilationFailedException(
  615. "Cannot JIT-compile function call to '%s' with missing argument for "
  616. "formal parameter '%s'." % (callee_name, param_name))
  617. raise primitive_functions.PrimitiveFinished(
  618. tree_ir.JitCallInstruction(
  619. tree_ir.LoadGlobalInstruction(compiled_func_name),
  620. arg_list,
  621. tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
  622. def analyze_argument_dict(self, first_argument_id):
  623. """Analyzes the parameter-to-argument mapping started by the specified first argument
  624. node."""
  625. next_param = first_argument_id
  626. argument_dict = {}
  627. while next_param is not None:
  628. param_name_id, = yield [("RD", [next_param, "name"])]
  629. param_name, = yield [("RV", [param_name_id])]
  630. param_val_id, = yield [("RD", [next_param, "value"])]
  631. try:
  632. gen = self.analyze(param_val_id)
  633. inp = None
  634. while True:
  635. inp = yield gen.send(inp)
  636. except primitive_functions.PrimitiveFinished as ex:
  637. argument_dict[param_name] = ex.result
  638. next_param, = yield [("RD", [next_param, "next_param"])]
  639. raise primitive_functions.PrimitiveFinished(argument_dict)
  640. def analyze_call(self, instruction_id):
  641. """Tries to analyze the given 'call' instruction."""
  642. func_id, first_param_id, = yield [("RD", [instruction_id, "func"]),
  643. ("RD", [instruction_id, "params"])]
  644. # Figure out what the 'func' instruction's type is.
  645. func_instruction_op, = yield [("RV", [func_id])]
  646. if func_instruction_op['value'] == 'access':
  647. # Calls to 'access(resolve(var))' instructions are translated to direct calls.
  648. access_value_id, = yield [("RD", [func_id, "var"])]
  649. access_value_op, = yield [("RV", [access_value_id])]
  650. if access_value_op['value'] == 'resolve':
  651. resolved_var_id, = yield [("RD", [access_value_id, "var"])]
  652. resolved_var_name, = yield [("RV", [resolved_var_id])]
  653. # Try to look the name up as a global.
  654. _globals, = yield [("RD", [self.user_root, "globals"])]
  655. global_var, = yield [("RD", [_globals, resolved_var_name])]
  656. global_val, = yield [("RD", [global_var, "value"])]
  657. if global_val is None:
  658. raise JitCompilationFailedException(
  659. "Cannot JIT function calls that target an unknown value.")
  660. else:
  661. gen = self.analyze_direct_call(
  662. global_val, resolved_var_name, first_param_id)
  663. inp = None
  664. while True:
  665. inp = yield gen.send(inp)
  666. # PrimitiveFinished exception will bubble up from here.
  667. raise JitCompilationFailedException("Cannot JIT indirect function calls yet.")
  668. instruction_analyzers = {
  669. 'if' : analyze_if,
  670. 'while' : analyze_while,
  671. 'return' : analyze_return,
  672. 'constant' : analyze_constant,
  673. 'resolve' : analyze_resolve,
  674. 'declare' : analyze_declare,
  675. 'global' : analyze_global,
  676. 'assign' : analyze_assign,
  677. 'access' : analyze_access,
  678. 'output' : analyze_output,
  679. 'input' : analyze_input,
  680. 'call' : analyze_call
  681. }