jit.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792
  1. import modelverse_kernel.primitives as primitive_functions
  2. import modelverse_jit.tree_ir as tree_ir
  3. KWARGS_PARAMETER_NAME = "remainder"
  4. """The name of the kwargs parameter in jitted functions."""
  5. class JitCompilationFailedException(Exception):
  6. """A type of exception that is raised when the jit fails to compile a function."""
  7. pass
  8. class ModelverseJit(object):
  9. """A high-level interface to the modelverse JIT compiler."""
  10. def __init__(self, max_instructions=None, compiled_function_lookup=None):
  11. self.todo_entry_points = set()
  12. self.no_jit_entry_points = set()
  13. self.jitted_entry_points = {}
  14. self.jitted_parameters = {}
  15. self.jit_globals = {
  16. 'PrimitiveFinished' : primitive_functions.PrimitiveFinished
  17. }
  18. self.jit_count = 0
  19. self.max_instructions = 30 if max_instructions is None else max_instructions
  20. self.compiled_function_lookup = compiled_function_lookup
  21. self.jit_enabled = True
  22. def set_jit_enabled(self, is_enabled=True):
  23. """Enables or disables the JIT."""
  24. self.jit_enabled = is_enabled
  25. def mark_entry_point(self, body_id):
  26. """Marks the node with the given identifier as a function entry point."""
  27. if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
  28. self.todo_entry_points.add(body_id)
  29. def is_entry_point(self, body_id):
  30. """Tells if the node with the given identifier is a function entry point."""
  31. return body_id in self.todo_entry_points or \
  32. body_id in self.no_jit_entry_points or \
  33. body_id in self.jitted_entry_points
  34. def is_jittable_entry_point(self, body_id):
  35. """Tells if the node with the given identifier is a function entry point that
  36. has not been marked as non-jittable. This only returns `True` if the JIT
  37. is enabled and the function entry point has been marked jittable, or if
  38. the function has already been compiled."""
  39. return ((self.jit_enabled and body_id in self.todo_entry_points) or
  40. self.has_compiled(body_id))
  41. def has_compiled(self, body_id):
  42. """Tests if the function belonging to the given body node has been compiled yet."""
  43. return body_id in self.jitted_entry_points
  44. def get_compiled_name(self, body_id):
  45. """Gets the name of the compiled version of the given body node in the JIT
  46. global state."""
  47. return self.jitted_entry_points[body_id]
  48. def mark_no_jit(self, body_id):
  49. """Informs the JIT that the node with the given identifier is a function entry
  50. point that must never be jitted."""
  51. self.no_jit_entry_points.add(body_id)
  52. if body_id in self.todo_entry_points:
  53. self.todo_entry_points.remove(body_id)
  54. def generate_function_name(self):
  55. """Generates a new function name,"""
  56. function_name = 'jit_func%d' % self.jit_count
  57. self.jit_count += 1
  58. return function_name
  59. def register_compiled(self, body_id, compiled_function, function_name=None):
  60. """Registers a compiled entry point with the JIT."""
  61. if function_name is None:
  62. function_name = self.generate_function_name()
  63. self.jitted_entry_points[body_id] = function_name
  64. self.jit_globals[function_name] = compiled_function
  65. if body_id in self.todo_entry_points:
  66. self.todo_entry_points.remove(body_id)
  67. def lookup_compiled_function(self, name):
  68. """Looks up a compiled function by name. Returns a matching function,
  69. or None if no function was found."""
  70. if name in self.jit_globals:
  71. return self.jit_globals[name]
  72. elif self.compiled_function_lookup is not None:
  73. return self.compiled_function_lookup(name)
  74. else:
  75. return None
  76. def jit_parameters(self, body_id):
  77. """Acquires the parameter list for the given body id node."""
  78. if body_id not in self.jitted_parameters:
  79. signature_id, = yield [("RRD", [body_id, "body"])]
  80. signature_id = signature_id[0]
  81. param_set_id, = yield [("RD", [signature_id, "params"])]
  82. if param_set_id is None:
  83. self.jitted_parameters[body_id] = ([], [])
  84. else:
  85. param_name_ids, = yield [("RDK", [param_set_id])]
  86. param_names = yield [("RV", [n]) for n in param_name_ids]
  87. param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
  88. self.jitted_parameters[body_id] = (param_vars, param_names)
  89. raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])
  90. def jit_compile(self, user_root, body_id):
  91. """Tries to jit the function defined by the given entry point id and parameter list."""
  92. # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
  93. # pylint: disable=I0011,W0122
  94. if body_id in self.jitted_entry_points:
  95. # We have already compiled this function.
  96. raise primitive_functions.PrimitiveFinished(
  97. self.jit_globals[self.jitted_entry_points[body_id]])
  98. elif body_id in self.no_jit_entry_points:
  99. # We're not allowed to jit this function or have tried and failed before.
  100. raise JitCompilationFailedException(
  101. 'Cannot jit function at %d because it is marked non-jittable.' % body_id)
  102. # Generate a name for the function we're about to analyze, and pretend that
  103. # it already exists. (we need to do this for recursive functions)
  104. function_name = self.generate_function_name()
  105. self.jitted_entry_points[body_id] = function_name
  106. self.jit_globals[function_name] = None
  107. try:
  108. gen = self.jit_parameters(body_id)
  109. inp = None
  110. while True:
  111. inp = yield gen.send(inp)
  112. except primitive_functions.PrimitiveFinished as ex:
  113. parameter_ids, parameter_list = ex.result
  114. param_dict = dict(zip(parameter_ids, parameter_list))
  115. body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
  116. try:
  117. gen = AnalysisState(
  118. self, user_root, body_param_dict, self.max_instructions).analyze(body_id)
  119. inp = None
  120. while True:
  121. inp = yield gen.send(inp)
  122. except primitive_functions.PrimitiveFinished as ex:
  123. constructed_body = ex.result
  124. except JitCompilationFailedException as ex:
  125. self.mark_no_jit(body_id)
  126. del self.jitted_entry_points[body_id]
  127. raise JitCompilationFailedException(
  128. '%s (function at %d)' % (ex.message, body_id))
  129. # Write a prologue and prepend it to the generated function body.
  130. prologue_statements = []
  131. for (key, val) in param_dict.items():
  132. arg_ptr = tree_ir.StoreLocalInstruction(
  133. body_param_dict[key],
  134. tree_ir.CreateNodeInstruction())
  135. prologue_statements.append(arg_ptr)
  136. prologue_statements.append(
  137. tree_ir.CreateDictionaryEdgeInstruction(
  138. arg_ptr.create_load(),
  139. tree_ir.LiteralInstruction('value'),
  140. tree_ir.LoadLocalInstruction(val)))
  141. constructed_body = tree_ir.create_block(
  142. *(prologue_statements + [constructed_body]))
  143. # Wrap the IR in a function definition, give it a unique name.
  144. constructed_function = tree_ir.DefineFunctionInstruction(
  145. function_name,
  146. parameter_list + ['**' + KWARGS_PARAMETER_NAME],
  147. constructed_body.simplify())
  148. # Convert the function definition to Python code, and compile it.
  149. exec(str(constructed_function), self.jit_globals)
  150. # Extract the compiled function from the JIT global state.
  151. compiled_function = self.jit_globals[function_name]
  152. # print(constructed_function)
  153. raise primitive_functions.PrimitiveFinished(compiled_function)
  154. class AnalysisState(object):
  155. """The state of a bytecode analysis call graph."""
  156. def __init__(self, jit, user_root, local_mapping, max_instructions=None):
  157. self.analyzed_instructions = set()
  158. self.function_vars = set()
  159. self.local_vars = set()
  160. self.max_instructions = max_instructions
  161. self.user_root = user_root
  162. self.jit = jit
  163. self.local_mapping = local_mapping
  164. def get_local_name(self, local_id):
  165. """Gets the name for a local with the given id."""
  166. if local_id not in self.local_mapping:
  167. self.local_mapping[local_id] = 'local%d' % local_id
  168. return self.local_mapping[local_id]
  169. def register_local_var(self, local_id):
  170. """Registers the given variable node id as a local."""
  171. if local_id in self.function_vars:
  172. raise JitCompilationFailedException(
  173. "Local is used as target of function call.")
  174. self.local_vars.add(local_id)
  175. def register_function_var(self, local_id):
  176. """Registers the given variable node id as a function."""
  177. if local_id in self.local_vars:
  178. raise JitCompilationFailedException(
  179. "Local is used as target of function call.")
  180. self.function_vars.add(local_id)
  181. def retrieve_user_root(self):
  182. """Creates an instruction that stores the user_root variable
  183. in a local."""
  184. return tree_ir.StoreLocalInstruction(
  185. 'user_root',
  186. tree_ir.LoadIndexInstruction(
  187. tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
  188. tree_ir.LiteralInstruction('user_root')))
  189. def analyze(self, instruction_id):
  190. """Tries to build an intermediate representation from the instruction with the
  191. given id."""
  192. # Check the analyzed_instructions set for instruction_id to avoid
  193. # infinite loops.
  194. if instruction_id in self.analyzed_instructions:
  195. raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
  196. elif (self.max_instructions is not None and
  197. len(self.analyzed_instructions) > self.max_instructions):
  198. raise JitCompilationFailedException('Maximum number of instructions exceeded.')
  199. self.analyzed_instructions.add(instruction_id)
  200. instruction_val, = yield [("RV", [instruction_id])]
  201. instruction_val = instruction_val["value"]
  202. if instruction_val in self.instruction_analyzers:
  203. gen = self.instruction_analyzers[instruction_val](self, instruction_id)
  204. try:
  205. inp = None
  206. while True:
  207. inp = yield gen.send(inp)
  208. except StopIteration:
  209. raise Exception(
  210. "Instruction analyzer (for '%s') finished without returning a value!" %
  211. (instruction_val))
  212. except primitive_functions.PrimitiveFinished as outer_e:
  213. # Check if the instruction has a 'next' instruction.
  214. next_instr, = yield [("RD", [instruction_id, "next"])]
  215. if next_instr is None:
  216. raise outer_e
  217. else:
  218. gen = self.analyze(next_instr)
  219. try:
  220. inp = None
  221. while True:
  222. inp = yield gen.send(inp)
  223. except primitive_functions.PrimitiveFinished as inner_e:
  224. raise primitive_functions.PrimitiveFinished(
  225. tree_ir.CompoundInstruction(
  226. outer_e.result,
  227. inner_e.result))
  228. else:
  229. raise JitCompilationFailedException(
  230. "Unknown instruction type: '%s'" % (instruction_val))
  231. def analyze_all(self, instruction_ids):
  232. """Tries to compile a list of IR trees from the given list of instruction ids."""
  233. results = []
  234. for inst in instruction_ids:
  235. gen = self.analyze(inst)
  236. try:
  237. inp = None
  238. while True:
  239. inp = yield gen.send(inp)
  240. except primitive_functions.PrimitiveFinished as ex:
  241. results.append(ex.result)
  242. raise primitive_functions.PrimitiveFinished(results)
  243. def analyze_return(self, instruction_id):
  244. """Tries to analyze the given 'return' instruction."""
  245. retval_id, = yield [("RD", [instruction_id, 'value'])]
  246. if retval_id is None:
  247. raise primitive_functions.PrimitiveFinished(
  248. tree_ir.ReturnInstruction(
  249. tree_ir.EmptyInstruction()))
  250. else:
  251. gen = self.analyze(retval_id)
  252. try:
  253. inp = None
  254. while True:
  255. inp = yield gen.send(inp)
  256. except primitive_functions.PrimitiveFinished as ex:
  257. raise primitive_functions.PrimitiveFinished(
  258. tree_ir.ReturnInstruction(ex.result))
  259. def analyze_if(self, instruction_id):
  260. """Tries to analyze the given 'if' instruction."""
  261. cond, true, false = yield [
  262. ("RD", [instruction_id, "cond"]),
  263. ("RD", [instruction_id, "then"]),
  264. ("RD", [instruction_id, "else"])]
  265. gen = self.analyze_all(
  266. [cond, true]
  267. if false is None
  268. else [cond, true, false])
  269. try:
  270. inp = None
  271. while True:
  272. inp = yield gen.send(inp)
  273. except primitive_functions.PrimitiveFinished as ex:
  274. if false is None:
  275. cond_r, true_r = ex.result
  276. false_r = tree_ir.EmptyInstruction()
  277. else:
  278. cond_r, true_r, false_r = ex.result
  279. raise primitive_functions.PrimitiveFinished(
  280. tree_ir.SelectInstruction(
  281. tree_ir.ReadValueInstruction(cond_r),
  282. true_r,
  283. false_r))
  284. def analyze_while(self, instruction_id):
  285. """Tries to analyze the given 'while' instruction."""
  286. cond, body = yield [
  287. ("RD", [instruction_id, "cond"]),
  288. ("RD", [instruction_id, "body"])]
  289. gen = self.analyze_all([cond, body])
  290. try:
  291. inp = None
  292. while True:
  293. inp = yield gen.send(inp)
  294. except primitive_functions.PrimitiveFinished as ex:
  295. cond_r, body_r = ex.result
  296. raise primitive_functions.PrimitiveFinished(
  297. tree_ir.LoopInstruction(
  298. tree_ir.CompoundInstruction(
  299. tree_ir.SelectInstruction(
  300. tree_ir.ReadValueInstruction(cond_r),
  301. tree_ir.EmptyInstruction(),
  302. tree_ir.BreakInstruction()),
  303. body_r)))
  304. def analyze_constant(self, instruction_id):
  305. """Tries to analyze the given 'constant' (literal) instruction."""
  306. node_id, = yield [("RD", [instruction_id, "node"])]
  307. raise primitive_functions.PrimitiveFinished(
  308. tree_ir.LiteralInstruction(node_id))
  309. def analyze_output(self, instruction_id):
  310. """Tries to analyze the given 'output' instruction."""
  311. # The plan is to basically generate this tree:
  312. #
  313. # value = <some tree>
  314. # last_output, last_output_link, new_last_output = \
  315. # yield [("RD", [user_root, "last_output"]),
  316. # ("RDE", [user_root, "last_output"]),
  317. # ("CN", []),
  318. # ]
  319. # _, _, _, _ = \
  320. # yield [("CD", [last_output, "value", value]),
  321. # ("CD", [last_output, "next", new_last_output]),
  322. # ("CD", [user_root, "last_output", new_last_output]),
  323. # ("DE", [last_output_link])
  324. # ]
  325. # yield None
  326. value_id, = yield [("RD", [instruction_id, "value"])]
  327. gen = self.analyze(value_id)
  328. try:
  329. inp = None
  330. while True:
  331. inp = yield gen.send(inp)
  332. except primitive_functions.PrimitiveFinished as ex:
  333. value_local = tree_ir.StoreLocalInstruction('value', ex.result)
  334. store_user_root = self.retrieve_user_root()
  335. last_output = tree_ir.StoreLocalInstruction(
  336. 'last_output',
  337. tree_ir.ReadDictionaryValueInstruction(
  338. store_user_root.create_load(),
  339. tree_ir.LiteralInstruction('last_output')))
  340. last_output_link = tree_ir.StoreLocalInstruction(
  341. 'last_output_link',
  342. tree_ir.ReadDictionaryEdgeInstruction(
  343. store_user_root.create_load(),
  344. tree_ir.LiteralInstruction('last_output')))
  345. new_last_output = tree_ir.StoreLocalInstruction(
  346. 'new_last_output',
  347. tree_ir.CreateNodeInstruction())
  348. result = tree_ir.create_block(
  349. value_local,
  350. store_user_root,
  351. last_output,
  352. last_output_link,
  353. new_last_output,
  354. tree_ir.CreateDictionaryEdgeInstruction(
  355. last_output.create_load(),
  356. tree_ir.LiteralInstruction('value'),
  357. value_local.create_load()),
  358. tree_ir.CreateDictionaryEdgeInstruction(
  359. last_output.create_load(),
  360. tree_ir.LiteralInstruction('next'),
  361. new_last_output.create_load()),
  362. tree_ir.CreateDictionaryEdgeInstruction(
  363. store_user_root.create_load(),
  364. tree_ir.LiteralInstruction('last_output'),
  365. new_last_output.create_load()),
  366. tree_ir.DeleteEdgeInstruction(last_output_link.create_load()),
  367. tree_ir.NopInstruction())
  368. raise primitive_functions.PrimitiveFinished(result)
  369. def analyze_input(self, _):
  370. """Tries to analyze the given 'input' instruction."""
  371. # The plan is to generate this tree:
  372. #
  373. # value = None
  374. # while True:
  375. # if value is None:
  376. # yield None # nop
  377. # else:
  378. # break
  379. #
  380. # _input = yield [("RD", [user_root, "input"])]
  381. # value = yield [("RD", [_input, "value"])]
  382. #
  383. # _next = yield [("RD", [_input, "next"])]
  384. # yield [("CD", [user_root, "input", _next])]
  385. # yield [("DN", [_input])]
  386. user_root = self.retrieve_user_root()
  387. _input = tree_ir.StoreLocalInstruction(
  388. '_input',
  389. tree_ir.ReadDictionaryValueInstruction(
  390. user_root.create_load(),
  391. tree_ir.LiteralInstruction('input')))
  392. value = tree_ir.StoreLocalInstruction(
  393. 'value',
  394. tree_ir.ReadDictionaryValueInstruction(
  395. _input.create_load(),
  396. tree_ir.LiteralInstruction('value')))
  397. raise primitive_functions.PrimitiveFinished(
  398. tree_ir.CompoundInstruction(
  399. tree_ir.create_block(
  400. user_root,
  401. value.create_store(tree_ir.LiteralInstruction(None)),
  402. tree_ir.LoopInstruction(
  403. tree_ir.create_block(
  404. tree_ir.SelectInstruction(
  405. tree_ir.BinaryInstruction(
  406. value.create_load(),
  407. 'is',
  408. tree_ir.LiteralInstruction(None)),
  409. tree_ir.NopInstruction(),
  410. tree_ir.BreakInstruction()),
  411. _input,
  412. value)),
  413. tree_ir.CreateDictionaryEdgeInstruction(
  414. user_root.create_load(),
  415. tree_ir.LiteralInstruction('input'),
  416. tree_ir.ReadDictionaryValueInstruction(
  417. _input.create_load(),
  418. tree_ir.LiteralInstruction('next'))),
  419. tree_ir.DeleteNodeInstruction(_input.create_load())),
  420. value.create_load()))
  421. def analyze_resolve(self, instruction_id):
  422. """Tries to analyze the given 'resolve' instruction."""
  423. var_id, = yield [("RD", [instruction_id, "var"])]
  424. var_name, = yield [("RV", [var_id])]
  425. # To resolve a variable, we'll do something along the
  426. # lines of:
  427. #
  428. # if 'local_var' in locals():
  429. # tmp = local_var
  430. # else:
  431. # _globals, = yield [("RD", [user_root, "globals"])]
  432. # global_var, = yield [("RD", [_globals, var_name])]
  433. #
  434. # if global_var is None:
  435. # raise Exception("Runtime error: global '%s' not found" % (var_name))
  436. #
  437. # tmp = global_var
  438. name = self.get_local_name(var_id)
  439. if var_name is None:
  440. raise primitive_functions.PrimitiveFinished(
  441. tree_ir.LoadLocalInstruction(name))
  442. user_root = self.retrieve_user_root()
  443. global_var = tree_ir.StoreLocalInstruction(
  444. 'global_var',
  445. tree_ir.ReadDictionaryValueInstruction(
  446. tree_ir.ReadDictionaryValueInstruction(
  447. user_root.create_load(),
  448. tree_ir.LiteralInstruction('globals')),
  449. tree_ir.LiteralInstruction(var_name)))
  450. err_block = tree_ir.SelectInstruction(
  451. tree_ir.BinaryInstruction(
  452. global_var.create_load(),
  453. 'is',
  454. tree_ir.LiteralInstruction(None)),
  455. tree_ir.RaiseInstruction(
  456. tree_ir.CallInstruction(
  457. tree_ir.LoadGlobalInstruction('Exception'),
  458. [tree_ir.LiteralInstruction(
  459. "Runtime error: global '%s' not found" % var_name)
  460. ])),
  461. tree_ir.EmptyInstruction())
  462. raise primitive_functions.PrimitiveFinished(
  463. tree_ir.SelectInstruction(
  464. tree_ir.LocalExistsInstruction(name),
  465. tree_ir.LoadLocalInstruction(name),
  466. tree_ir.CompoundInstruction(
  467. tree_ir.create_block(
  468. user_root,
  469. global_var,
  470. err_block),
  471. global_var.create_load())))
  472. def analyze_declare(self, instruction_id):
  473. """Tries to analyze the given 'declare' function."""
  474. var_id, = yield [("RD", [instruction_id, "var"])]
  475. self.register_local_var(var_id)
  476. name = self.get_local_name(var_id)
  477. # The following logic declares a local:
  478. #
  479. # if 'local_name' not in locals():
  480. # local_name, = yield [("CN", [])]
  481. raise primitive_functions.PrimitiveFinished(
  482. tree_ir.SelectInstruction(
  483. tree_ir.LocalExistsInstruction(name),
  484. tree_ir.EmptyInstruction(),
  485. tree_ir.StoreLocalInstruction(
  486. name,
  487. tree_ir.CreateNodeInstruction())))
  488. def analyze_global(self, instruction_id):
  489. """Tries to analyze the given 'global' (declaration) instruction."""
  490. var_id, = yield [("RD", [instruction_id, "var"])]
  491. var_name, = yield [("RV", [var_id])]
  492. # To resolve a variable, we'll do something along the
  493. # lines of:
  494. #
  495. # _globals, = yield [("RD", [user_root, "globals"])]
  496. # global_var = yield [("RD", [_globals, var_name])]
  497. #
  498. # if global_var is None:
  499. # global_var, = yield [("CN", [])]
  500. # yield [("CD", [_globals, var_name, global_var])]
  501. #
  502. # tmp = global_var
  503. user_root = self.retrieve_user_root()
  504. _globals = tree_ir.StoreLocalInstruction(
  505. '_globals',
  506. tree_ir.ReadDictionaryValueInstruction(
  507. user_root.create_load(),
  508. tree_ir.LiteralInstruction('globals')))
  509. global_var = tree_ir.StoreLocalInstruction(
  510. 'global_var',
  511. tree_ir.ReadDictionaryValueInstruction(
  512. _globals.create_load(),
  513. tree_ir.LiteralInstruction(var_name)))
  514. raise primitive_functions.PrimitiveFinished(
  515. tree_ir.CompoundInstruction(
  516. tree_ir.create_block(
  517. user_root,
  518. _globals,
  519. global_var,
  520. tree_ir.SelectInstruction(
  521. tree_ir.BinaryInstruction(
  522. global_var.create_load(),
  523. 'is',
  524. tree_ir.LiteralInstruction(None)),
  525. tree_ir.create_block(
  526. global_var.create_store(
  527. tree_ir.CreateNodeInstruction()),
  528. tree_ir.CreateDictionaryEdgeInstruction(
  529. _globals.create_load(),
  530. tree_ir.LiteralInstruction(var_name),
  531. global_var.create_load())),
  532. tree_ir.EmptyInstruction())),
  533. global_var.create_load()))
  534. def analyze_assign(self, instruction_id):
  535. """Tries to analyze the given 'assign' instruction."""
  536. var_id, value_id = yield [("RD", [instruction_id, "var"]),
  537. ("RD", [instruction_id, "value"])]
  538. try:
  539. gen = self.analyze_all([var_id, value_id])
  540. inp = None
  541. while True:
  542. inp = yield gen.send(inp)
  543. except primitive_functions.PrimitiveFinished as ex:
  544. var_r, value_r = ex.result
  545. # Assignments work like this:
  546. #
  547. # value_link = yield [("RDE", [variable, "value"])]
  548. # _, _ = yield [("CD", [variable, "value", value]),
  549. # ("DE", [value_link])]
  550. variable = tree_ir.StoreLocalInstruction('variable', var_r)
  551. value = tree_ir.StoreLocalInstruction('value', value_r)
  552. value_link = tree_ir.StoreLocalInstruction(
  553. 'value_link',
  554. tree_ir.ReadDictionaryEdgeInstruction(
  555. variable.create_load(),
  556. tree_ir.LiteralInstruction('value')))
  557. raise primitive_functions.PrimitiveFinished(
  558. tree_ir.create_block(
  559. variable,
  560. value,
  561. value_link,
  562. tree_ir.CreateDictionaryEdgeInstruction(
  563. variable.create_load(),
  564. tree_ir.LiteralInstruction('value'),
  565. value.create_load()),
  566. tree_ir.DeleteEdgeInstruction(
  567. value_link.create_load())))
  568. def analyze_access(self, instruction_id):
  569. """Tries to analyze the given 'access' instruction."""
  570. var_id, = yield [("RD", [instruction_id, "var"])]
  571. try:
  572. gen = self.analyze(var_id)
  573. inp = None
  574. while True:
  575. inp = yield gen.send(inp)
  576. except primitive_functions.PrimitiveFinished as ex:
  577. var_r = ex.result
  578. # Accessing a variable is pretty easy. It really just boils
  579. # down to reading the value corresponding to the 'value' key
  580. # of the variable.
  581. #
  582. # value, = yield [("RD", [returnvalue, "value"])]
  583. raise primitive_functions.PrimitiveFinished(
  584. tree_ir.ReadDictionaryValueInstruction(
  585. var_r,
  586. tree_ir.LiteralInstruction('value')))
  587. def analyze_direct_call(self, callee_id, callee_name, first_parameter_id):
  588. """Tries to analyze a direct 'call' instruction."""
  589. self.register_function_var(callee_id)
  590. body_id, = yield [("RD", [callee_id, "body"])]
  591. # Analyze the parameter list.
  592. try:
  593. gen = self.jit.jit_parameters(body_id)
  594. inp = None
  595. while True:
  596. inp = yield gen.send(inp)
  597. except primitive_functions.PrimitiveFinished as ex:
  598. _, parameter_names = ex.result
  599. compiled_func = self.jit.lookup_compiled_function(callee_name)
  600. if compiled_func is None:
  601. # Compile the callee.
  602. try:
  603. gen = self.jit.jit_compile(self.user_root, body_id)
  604. inp = None
  605. while True:
  606. inp = yield gen.send(inp)
  607. except primitive_functions.PrimitiveFinished as ex:
  608. pass
  609. else:
  610. self.jit.register_compiled(body_id, compiled_func, callee_name)
  611. # Get the callee's name.
  612. compiled_func_name = self.jit.get_compiled_name(body_id)
  613. # Analyze the argument dictionary.
  614. try:
  615. gen = self.analyze_argument_dict(first_parameter_id)
  616. inp = None
  617. while True:
  618. inp = yield gen.send(inp)
  619. except primitive_functions.PrimitiveFinished as ex:
  620. arg_dict = ex.result
  621. # Construct the argument list from the parameter list and
  622. # argument dictionary.
  623. arg_list = []
  624. for param_name in parameter_names:
  625. if param_name in arg_dict:
  626. arg_list.append(arg_dict[param_name])
  627. else:
  628. raise JitCompilationFailedException(
  629. "Cannot JIT-compile function call to '%s' with missing argument for "
  630. "formal parameter '%s'." % (callee_name, param_name))
  631. raise primitive_functions.PrimitiveFinished(
  632. tree_ir.JitCallInstruction(
  633. tree_ir.LoadGlobalInstruction(compiled_func_name),
  634. arg_list,
  635. tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
  636. def analyze_argument_dict(self, first_argument_id):
  637. """Analyzes the parameter-to-argument mapping started by the specified first argument
  638. node."""
  639. next_param = first_argument_id
  640. argument_dict = {}
  641. while next_param is not None:
  642. param_name_id, = yield [("RD", [next_param, "name"])]
  643. param_name, = yield [("RV", [param_name_id])]
  644. param_val_id, = yield [("RD", [next_param, "value"])]
  645. try:
  646. gen = self.analyze(param_val_id)
  647. inp = None
  648. while True:
  649. inp = yield gen.send(inp)
  650. except primitive_functions.PrimitiveFinished as ex:
  651. argument_dict[param_name] = ex.result
  652. next_param, = yield [("RD", [next_param, "next_param"])]
  653. raise primitive_functions.PrimitiveFinished(argument_dict)
  654. def analyze_call(self, instruction_id):
  655. """Tries to analyze the given 'call' instruction."""
  656. func_id, first_param_id, = yield [("RD", [instruction_id, "func"]),
  657. ("RD", [instruction_id, "params"])]
  658. # Figure out what the 'func' instruction's type is.
  659. func_instruction_op, = yield [("RV", [func_id])]
  660. if func_instruction_op['value'] == 'access':
  661. # Calls to 'access(resolve(var))' instructions are translated to direct calls.
  662. access_value_id, = yield [("RD", [func_id, "var"])]
  663. access_value_op, = yield [("RV", [access_value_id])]
  664. if access_value_op['value'] == 'resolve':
  665. resolved_var_id, = yield [("RD", [access_value_id, "var"])]
  666. resolved_var_name, = yield [("RV", [resolved_var_id])]
  667. # Try to look the name up as a global.
  668. _globals, = yield [("RD", [self.user_root, "globals"])]
  669. global_var, = yield [("RD", [_globals, resolved_var_name])]
  670. global_val, = yield [("RD", [global_var, "value"])]
  671. if global_val is None:
  672. raise JitCompilationFailedException(
  673. "Cannot JIT function calls that target an unknown value.")
  674. else:
  675. gen = self.analyze_direct_call(
  676. global_val, resolved_var_name, first_param_id)
  677. inp = None
  678. while True:
  679. inp = yield gen.send(inp)
  680. # PrimitiveFinished exception will bubble up from here.
  681. raise JitCompilationFailedException("Cannot JIT indirect function calls yet.")
  682. instruction_analyzers = {
  683. 'if' : analyze_if,
  684. 'while' : analyze_while,
  685. 'return' : analyze_return,
  686. 'constant' : analyze_constant,
  687. 'resolve' : analyze_resolve,
  688. 'declare' : analyze_declare,
  689. 'global' : analyze_global,
  690. 'assign' : analyze_assign,
  691. 'access' : analyze_access,
  692. 'output' : analyze_output,
  693. 'input' : analyze_input,
  694. 'call' : analyze_call
  695. }