jit.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795
  1. import modelverse_kernel.primitives as primitive_functions
  2. import modelverse_jit.tree_ir as tree_ir
  3. KWARGS_PARAMETER_NAME = "remainder"
  4. """The name of the kwargs parameter in jitted functions."""
  5. class JitCompilationFailedException(Exception):
  6. """A type of exception that is raised when the jit fails to compile a function."""
  7. pass
  8. class ModelverseJit(object):
  9. """A high-level interface to the modelverse JIT compiler."""
  10. def __init__(self, max_instructions=None, compiled_function_lookup=None):
  11. self.todo_entry_points = set()
  12. self.no_jit_entry_points = set()
  13. self.jitted_entry_points = {}
  14. self.jitted_parameters = {}
  15. self.jit_globals = {
  16. 'PrimitiveFinished' : primitive_functions.PrimitiveFinished
  17. }
  18. self.jit_count = 0
  19. self.max_instructions = 30 if max_instructions is None else max_instructions
  20. self.compiled_function_lookup = compiled_function_lookup
  21. self.jit_enabled = True
  22. def set_jit_enabled(self, is_enabled=True):
  23. """Enables or disables the JIT."""
  24. self.jit_enabled = is_enabled
  25. def mark_entry_point(self, body_id):
  26. """Marks the node with the given identifier as a function entry point."""
  27. if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
  28. self.todo_entry_points.add(body_id)
  29. def is_entry_point(self, body_id):
  30. """Tells if the node with the given identifier is a function entry point."""
  31. return body_id in self.todo_entry_points or \
  32. body_id in self.no_jit_entry_points or \
  33. body_id in self.jitted_entry_points
  34. def is_jittable_entry_point(self, body_id):
  35. """Tells if the node with the given identifier is a function entry point that
  36. has not been marked as non-jittable. This only returns `True` if the JIT
  37. is enabled and the function entry point has been marked jittable, or if
  38. the function has already been compiled."""
  39. return ((self.jit_enabled and body_id in self.todo_entry_points) or
  40. self.has_compiled(body_id))
  41. def has_compiled(self, body_id):
  42. """Tests if the function belonging to the given body node has been compiled yet."""
  43. return body_id in self.jitted_entry_points
  44. def get_compiled_name(self, body_id):
  45. """Gets the name of the compiled version of the given body node in the JIT
  46. global state."""
  47. return self.jitted_entry_points[body_id]
  48. def mark_no_jit(self, body_id):
  49. """Informs the JIT that the node with the given identifier is a function entry
  50. point that must never be jitted."""
  51. self.no_jit_entry_points.add(body_id)
  52. if body_id in self.todo_entry_points:
  53. self.todo_entry_points.remove(body_id)
  54. def generate_function_name(self, suggested_name=None):
  55. """Generates a new function name or picks the suggested name if it is still
  56. available."""
  57. if suggested_name is not None and suggested_name not in self.jit_globals:
  58. self.jit_count += 1
  59. return suggested_name
  60. else:
  61. function_name = 'jit_func%d' % self.jit_count
  62. self.jit_count += 1
  63. return function_name
  64. def register_compiled(self, body_id, compiled_function, function_name=None):
  65. """Registers a compiled entry point with the JIT."""
  66. function_name = self.generate_function_name(function_name)
  67. self.jitted_entry_points[body_id] = function_name
  68. self.jit_globals[function_name] = compiled_function
  69. if body_id in self.todo_entry_points:
  70. self.todo_entry_points.remove(body_id)
  71. def lookup_compiled_function(self, name):
  72. """Looks up a compiled function by name. Returns a matching function,
  73. or None if no function was found."""
  74. if name in self.jit_globals:
  75. return self.jit_globals[name]
  76. elif self.compiled_function_lookup is not None:
  77. return self.compiled_function_lookup(name)
  78. else:
  79. return None
  80. def jit_parameters(self, body_id):
  81. """Acquires the parameter list for the given body id node."""
  82. if body_id not in self.jitted_parameters:
  83. signature_id, = yield [("RRD", [body_id, "body"])]
  84. signature_id = signature_id[0]
  85. param_set_id, = yield [("RD", [signature_id, "params"])]
  86. if param_set_id is None:
  87. self.jitted_parameters[body_id] = ([], [])
  88. else:
  89. param_name_ids, = yield [("RDK", [param_set_id])]
  90. param_names = yield [("RV", [n]) for n in param_name_ids]
  91. param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
  92. self.jitted_parameters[body_id] = (param_vars, param_names)
  93. raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])
  94. def jit_compile(self, user_root, body_id, suggested_name=None):
  95. """Tries to jit the function defined by the given entry point id and parameter list."""
  96. # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
  97. # pylint: disable=I0011,W0122
  98. if body_id in self.jitted_entry_points:
  99. # We have already compiled this function.
  100. raise primitive_functions.PrimitiveFinished(
  101. self.jit_globals[self.jitted_entry_points[body_id]])
  102. elif body_id in self.no_jit_entry_points:
  103. # We're not allowed to jit this function or have tried and failed before.
  104. raise JitCompilationFailedException(
  105. 'Cannot jit function at %d because it is marked non-jittable.' % body_id)
  106. # Generate a name for the function we're about to analyze, and pretend that
  107. # it already exists. (we need to do this for recursive functions)
  108. function_name = self.generate_function_name(suggested_name)
  109. self.jitted_entry_points[body_id] = function_name
  110. self.jit_globals[function_name] = None
  111. try:
  112. gen = self.jit_parameters(body_id)
  113. inp = None
  114. while True:
  115. inp = yield gen.send(inp)
  116. except primitive_functions.PrimitiveFinished as ex:
  117. parameter_ids, parameter_list = ex.result
  118. param_dict = dict(zip(parameter_ids, parameter_list))
  119. body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
  120. try:
  121. gen = AnalysisState(
  122. self, user_root, body_param_dict, self.max_instructions).analyze(body_id)
  123. inp = None
  124. while True:
  125. inp = yield gen.send(inp)
  126. except primitive_functions.PrimitiveFinished as ex:
  127. constructed_body = ex.result
  128. except JitCompilationFailedException as ex:
  129. self.mark_no_jit(body_id)
  130. del self.jitted_entry_points[body_id]
  131. raise JitCompilationFailedException(
  132. '%s (function at %d)' % (ex.message, body_id))
  133. # Write a prologue and prepend it to the generated function body.
  134. prologue_statements = []
  135. for (key, val) in param_dict.items():
  136. arg_ptr = tree_ir.StoreLocalInstruction(
  137. body_param_dict[key],
  138. tree_ir.CreateNodeInstruction())
  139. prologue_statements.append(arg_ptr)
  140. prologue_statements.append(
  141. tree_ir.CreateDictionaryEdgeInstruction(
  142. arg_ptr.create_load(),
  143. tree_ir.LiteralInstruction('value'),
  144. tree_ir.LoadLocalInstruction(val)))
  145. constructed_body = tree_ir.create_block(
  146. *(prologue_statements + [constructed_body]))
  147. # Wrap the IR in a function definition, give it a unique name.
  148. constructed_function = tree_ir.DefineFunctionInstruction(
  149. function_name,
  150. parameter_list + ['**' + KWARGS_PARAMETER_NAME],
  151. constructed_body.simplify())
  152. # Convert the function definition to Python code, and compile it.
  153. exec(str(constructed_function), self.jit_globals)
  154. # Extract the compiled function from the JIT global state.
  155. compiled_function = self.jit_globals[function_name]
  156. print(constructed_function)
  157. raise primitive_functions.PrimitiveFinished(compiled_function)
  158. class AnalysisState(object):
  159. """The state of a bytecode analysis call graph."""
  160. def __init__(self, jit, user_root, local_mapping, max_instructions=None):
  161. self.analyzed_instructions = set()
  162. self.function_vars = set()
  163. self.local_vars = set()
  164. self.max_instructions = max_instructions
  165. self.user_root = user_root
  166. self.jit = jit
  167. self.local_mapping = local_mapping
  168. def get_local_name(self, local_id):
  169. """Gets the name for a local with the given id."""
  170. if local_id not in self.local_mapping:
  171. self.local_mapping[local_id] = 'local%d' % local_id
  172. return self.local_mapping[local_id]
  173. def register_local_var(self, local_id):
  174. """Registers the given variable node id as a local."""
  175. if local_id in self.function_vars:
  176. raise JitCompilationFailedException(
  177. "Local is used as target of function call.")
  178. self.local_vars.add(local_id)
  179. def register_function_var(self, local_id):
  180. """Registers the given variable node id as a function."""
  181. if local_id in self.local_vars:
  182. raise JitCompilationFailedException(
  183. "Local is used as target of function call.")
  184. self.function_vars.add(local_id)
  185. def retrieve_user_root(self):
  186. """Creates an instruction that stores the user_root variable
  187. in a local."""
  188. return tree_ir.StoreLocalInstruction(
  189. 'user_root',
  190. tree_ir.LoadIndexInstruction(
  191. tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
  192. tree_ir.LiteralInstruction('user_root')))
  193. def analyze(self, instruction_id):
  194. """Tries to build an intermediate representation from the instruction with the
  195. given id."""
  196. # Check the analyzed_instructions set for instruction_id to avoid
  197. # infinite loops.
  198. if instruction_id in self.analyzed_instructions:
  199. raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
  200. elif (self.max_instructions is not None and
  201. len(self.analyzed_instructions) > self.max_instructions):
  202. raise JitCompilationFailedException('Maximum number of instructions exceeded.')
  203. self.analyzed_instructions.add(instruction_id)
  204. instruction_val, = yield [("RV", [instruction_id])]
  205. instruction_val = instruction_val["value"]
  206. if instruction_val in self.instruction_analyzers:
  207. gen = self.instruction_analyzers[instruction_val](self, instruction_id)
  208. try:
  209. inp = None
  210. while True:
  211. inp = yield gen.send(inp)
  212. except StopIteration:
  213. raise Exception(
  214. "Instruction analyzer (for '%s') finished without returning a value!" %
  215. (instruction_val))
  216. except primitive_functions.PrimitiveFinished as outer_e:
  217. # Check if the instruction has a 'next' instruction.
  218. next_instr, = yield [("RD", [instruction_id, "next"])]
  219. if next_instr is None:
  220. raise outer_e
  221. else:
  222. gen = self.analyze(next_instr)
  223. try:
  224. inp = None
  225. while True:
  226. inp = yield gen.send(inp)
  227. except primitive_functions.PrimitiveFinished as inner_e:
  228. raise primitive_functions.PrimitiveFinished(
  229. tree_ir.CompoundInstruction(
  230. outer_e.result,
  231. inner_e.result))
  232. else:
  233. raise JitCompilationFailedException(
  234. "Unknown instruction type: '%s'" % (instruction_val))
  235. def analyze_all(self, instruction_ids):
  236. """Tries to compile a list of IR trees from the given list of instruction ids."""
  237. results = []
  238. for inst in instruction_ids:
  239. gen = self.analyze(inst)
  240. try:
  241. inp = None
  242. while True:
  243. inp = yield gen.send(inp)
  244. except primitive_functions.PrimitiveFinished as ex:
  245. results.append(ex.result)
  246. raise primitive_functions.PrimitiveFinished(results)
  247. def analyze_return(self, instruction_id):
  248. """Tries to analyze the given 'return' instruction."""
  249. retval_id, = yield [("RD", [instruction_id, 'value'])]
  250. if retval_id is None:
  251. raise primitive_functions.PrimitiveFinished(
  252. tree_ir.ReturnInstruction(
  253. tree_ir.EmptyInstruction()))
  254. else:
  255. gen = self.analyze(retval_id)
  256. try:
  257. inp = None
  258. while True:
  259. inp = yield gen.send(inp)
  260. except primitive_functions.PrimitiveFinished as ex:
  261. raise primitive_functions.PrimitiveFinished(
  262. tree_ir.ReturnInstruction(ex.result))
  263. def analyze_if(self, instruction_id):
  264. """Tries to analyze the given 'if' instruction."""
  265. cond, true, false = yield [
  266. ("RD", [instruction_id, "cond"]),
  267. ("RD", [instruction_id, "then"]),
  268. ("RD", [instruction_id, "else"])]
  269. gen = self.analyze_all(
  270. [cond, true]
  271. if false is None
  272. else [cond, true, false])
  273. try:
  274. inp = None
  275. while True:
  276. inp = yield gen.send(inp)
  277. except primitive_functions.PrimitiveFinished as ex:
  278. if false is None:
  279. cond_r, true_r = ex.result
  280. false_r = tree_ir.EmptyInstruction()
  281. else:
  282. cond_r, true_r, false_r = ex.result
  283. raise primitive_functions.PrimitiveFinished(
  284. tree_ir.SelectInstruction(
  285. tree_ir.ReadValueInstruction(cond_r),
  286. true_r,
  287. false_r))
  288. def analyze_while(self, instruction_id):
  289. """Tries to analyze the given 'while' instruction."""
  290. cond, body = yield [
  291. ("RD", [instruction_id, "cond"]),
  292. ("RD", [instruction_id, "body"])]
  293. gen = self.analyze_all([cond, body])
  294. try:
  295. inp = None
  296. while True:
  297. inp = yield gen.send(inp)
  298. except primitive_functions.PrimitiveFinished as ex:
  299. cond_r, body_r = ex.result
  300. raise primitive_functions.PrimitiveFinished(
  301. tree_ir.LoopInstruction(
  302. tree_ir.CompoundInstruction(
  303. tree_ir.SelectInstruction(
  304. tree_ir.ReadValueInstruction(cond_r),
  305. tree_ir.EmptyInstruction(),
  306. tree_ir.BreakInstruction()),
  307. body_r)))
  308. def analyze_constant(self, instruction_id):
  309. """Tries to analyze the given 'constant' (literal) instruction."""
  310. node_id, = yield [("RD", [instruction_id, "node"])]
  311. raise primitive_functions.PrimitiveFinished(
  312. tree_ir.LiteralInstruction(node_id))
  313. def analyze_output(self, instruction_id):
  314. """Tries to analyze the given 'output' instruction."""
  315. # The plan is to basically generate this tree:
  316. #
  317. # value = <some tree>
  318. # last_output, last_output_link, new_last_output = \
  319. # yield [("RD", [user_root, "last_output"]),
  320. # ("RDE", [user_root, "last_output"]),
  321. # ("CN", []),
  322. # ]
  323. # _, _, _, _ = \
  324. # yield [("CD", [last_output, "value", value]),
  325. # ("CD", [last_output, "next", new_last_output]),
  326. # ("CD", [user_root, "last_output", new_last_output]),
  327. # ("DE", [last_output_link])
  328. # ]
  329. # yield None
  330. value_id, = yield [("RD", [instruction_id, "value"])]
  331. gen = self.analyze(value_id)
  332. try:
  333. inp = None
  334. while True:
  335. inp = yield gen.send(inp)
  336. except primitive_functions.PrimitiveFinished as ex:
  337. value_local = tree_ir.StoreLocalInstruction('value', ex.result)
  338. store_user_root = self.retrieve_user_root()
  339. last_output = tree_ir.StoreLocalInstruction(
  340. 'last_output',
  341. tree_ir.ReadDictionaryValueInstruction(
  342. store_user_root.create_load(),
  343. tree_ir.LiteralInstruction('last_output')))
  344. last_output_link = tree_ir.StoreLocalInstruction(
  345. 'last_output_link',
  346. tree_ir.ReadDictionaryEdgeInstruction(
  347. store_user_root.create_load(),
  348. tree_ir.LiteralInstruction('last_output')))
  349. new_last_output = tree_ir.StoreLocalInstruction(
  350. 'new_last_output',
  351. tree_ir.CreateNodeInstruction())
  352. result = tree_ir.create_block(
  353. value_local,
  354. store_user_root,
  355. last_output,
  356. last_output_link,
  357. new_last_output,
  358. tree_ir.CreateDictionaryEdgeInstruction(
  359. last_output.create_load(),
  360. tree_ir.LiteralInstruction('value'),
  361. value_local.create_load()),
  362. tree_ir.CreateDictionaryEdgeInstruction(
  363. last_output.create_load(),
  364. tree_ir.LiteralInstruction('next'),
  365. new_last_output.create_load()),
  366. tree_ir.CreateDictionaryEdgeInstruction(
  367. store_user_root.create_load(),
  368. tree_ir.LiteralInstruction('last_output'),
  369. new_last_output.create_load()),
  370. tree_ir.DeleteEdgeInstruction(last_output_link.create_load()),
  371. tree_ir.NopInstruction())
  372. raise primitive_functions.PrimitiveFinished(result)
  373. def analyze_input(self, _):
  374. """Tries to analyze the given 'input' instruction."""
  375. # The plan is to generate this tree:
  376. #
  377. # value = None
  378. # while True:
  379. # if value is None:
  380. # yield None # nop
  381. # else:
  382. # break
  383. #
  384. # _input = yield [("RD", [user_root, "input"])]
  385. # value = yield [("RD", [_input, "value"])]
  386. #
  387. # _next = yield [("RD", [_input, "next"])]
  388. # yield [("CD", [user_root, "input", _next])]
  389. # yield [("DN", [_input])]
  390. user_root = self.retrieve_user_root()
  391. _input = tree_ir.StoreLocalInstruction(
  392. '_input',
  393. tree_ir.ReadDictionaryValueInstruction(
  394. user_root.create_load(),
  395. tree_ir.LiteralInstruction('input')))
  396. value = tree_ir.StoreLocalInstruction(
  397. 'value',
  398. tree_ir.ReadDictionaryValueInstruction(
  399. _input.create_load(),
  400. tree_ir.LiteralInstruction('value')))
  401. raise primitive_functions.PrimitiveFinished(
  402. tree_ir.CompoundInstruction(
  403. tree_ir.create_block(
  404. user_root,
  405. value.create_store(tree_ir.LiteralInstruction(None)),
  406. tree_ir.LoopInstruction(
  407. tree_ir.create_block(
  408. tree_ir.SelectInstruction(
  409. tree_ir.BinaryInstruction(
  410. value.create_load(),
  411. 'is',
  412. tree_ir.LiteralInstruction(None)),
  413. tree_ir.NopInstruction(),
  414. tree_ir.BreakInstruction()),
  415. _input,
  416. value)),
  417. tree_ir.CreateDictionaryEdgeInstruction(
  418. user_root.create_load(),
  419. tree_ir.LiteralInstruction('input'),
  420. tree_ir.ReadDictionaryValueInstruction(
  421. _input.create_load(),
  422. tree_ir.LiteralInstruction('next'))),
  423. tree_ir.DeleteNodeInstruction(_input.create_load())),
  424. value.create_load()))
  425. def analyze_resolve(self, instruction_id):
  426. """Tries to analyze the given 'resolve' instruction."""
  427. var_id, = yield [("RD", [instruction_id, "var"])]
  428. var_name, = yield [("RV", [var_id])]
  429. # To resolve a variable, we'll do something along the
  430. # lines of:
  431. #
  432. # if 'local_var' in locals():
  433. # tmp = local_var
  434. # else:
  435. # _globals, = yield [("RD", [user_root, "globals"])]
  436. # global_var, = yield [("RD", [_globals, var_name])]
  437. #
  438. # if global_var is None:
  439. # raise Exception("Runtime error: global '%s' not found" % (var_name))
  440. #
  441. # tmp = global_var
  442. name = self.get_local_name(var_id)
  443. if var_name is None:
  444. raise primitive_functions.PrimitiveFinished(
  445. tree_ir.LoadLocalInstruction(name))
  446. user_root = self.retrieve_user_root()
  447. global_var = tree_ir.StoreLocalInstruction(
  448. 'global_var',
  449. tree_ir.ReadDictionaryValueInstruction(
  450. tree_ir.ReadDictionaryValueInstruction(
  451. user_root.create_load(),
  452. tree_ir.LiteralInstruction('globals')),
  453. tree_ir.LiteralInstruction(var_name)))
  454. err_block = tree_ir.SelectInstruction(
  455. tree_ir.BinaryInstruction(
  456. global_var.create_load(),
  457. 'is',
  458. tree_ir.LiteralInstruction(None)),
  459. tree_ir.RaiseInstruction(
  460. tree_ir.CallInstruction(
  461. tree_ir.LoadGlobalInstruction('Exception'),
  462. [tree_ir.LiteralInstruction(
  463. "Runtime error: global '%s' not found" % var_name)
  464. ])),
  465. tree_ir.EmptyInstruction())
  466. raise primitive_functions.PrimitiveFinished(
  467. tree_ir.SelectInstruction(
  468. tree_ir.LocalExistsInstruction(name),
  469. tree_ir.LoadLocalInstruction(name),
  470. tree_ir.CompoundInstruction(
  471. tree_ir.create_block(
  472. user_root,
  473. global_var,
  474. err_block),
  475. global_var.create_load())))
  476. def analyze_declare(self, instruction_id):
  477. """Tries to analyze the given 'declare' function."""
  478. var_id, = yield [("RD", [instruction_id, "var"])]
  479. self.register_local_var(var_id)
  480. name = self.get_local_name(var_id)
  481. # The following logic declares a local:
  482. #
  483. # if 'local_name' not in locals():
  484. # local_name, = yield [("CN", [])]
  485. raise primitive_functions.PrimitiveFinished(
  486. tree_ir.SelectInstruction(
  487. tree_ir.LocalExistsInstruction(name),
  488. tree_ir.EmptyInstruction(),
  489. tree_ir.StoreLocalInstruction(
  490. name,
  491. tree_ir.CreateNodeInstruction())))
  492. def analyze_global(self, instruction_id):
  493. """Tries to analyze the given 'global' (declaration) instruction."""
  494. var_id, = yield [("RD", [instruction_id, "var"])]
  495. var_name, = yield [("RV", [var_id])]
  496. # To resolve a variable, we'll do something along the
  497. # lines of:
  498. #
  499. # _globals, = yield [("RD", [user_root, "globals"])]
  500. # global_var = yield [("RD", [_globals, var_name])]
  501. #
  502. # if global_var is None:
  503. # global_var, = yield [("CN", [])]
  504. # yield [("CD", [_globals, var_name, global_var])]
  505. #
  506. # tmp = global_var
  507. user_root = self.retrieve_user_root()
  508. _globals = tree_ir.StoreLocalInstruction(
  509. '_globals',
  510. tree_ir.ReadDictionaryValueInstruction(
  511. user_root.create_load(),
  512. tree_ir.LiteralInstruction('globals')))
  513. global_var = tree_ir.StoreLocalInstruction(
  514. 'global_var',
  515. tree_ir.ReadDictionaryValueInstruction(
  516. _globals.create_load(),
  517. tree_ir.LiteralInstruction(var_name)))
  518. raise primitive_functions.PrimitiveFinished(
  519. tree_ir.CompoundInstruction(
  520. tree_ir.create_block(
  521. user_root,
  522. _globals,
  523. global_var,
  524. tree_ir.SelectInstruction(
  525. tree_ir.BinaryInstruction(
  526. global_var.create_load(),
  527. 'is',
  528. tree_ir.LiteralInstruction(None)),
  529. tree_ir.create_block(
  530. global_var.create_store(
  531. tree_ir.CreateNodeInstruction()),
  532. tree_ir.CreateDictionaryEdgeInstruction(
  533. _globals.create_load(),
  534. tree_ir.LiteralInstruction(var_name),
  535. global_var.create_load())),
  536. tree_ir.EmptyInstruction())),
  537. global_var.create_load()))
  538. def analyze_assign(self, instruction_id):
  539. """Tries to analyze the given 'assign' instruction."""
  540. var_id, value_id = yield [("RD", [instruction_id, "var"]),
  541. ("RD", [instruction_id, "value"])]
  542. try:
  543. gen = self.analyze_all([var_id, value_id])
  544. inp = None
  545. while True:
  546. inp = yield gen.send(inp)
  547. except primitive_functions.PrimitiveFinished as ex:
  548. var_r, value_r = ex.result
  549. # Assignments work like this:
  550. #
  551. # value_link = yield [("RDE", [variable, "value"])]
  552. # _, _ = yield [("CD", [variable, "value", value]),
  553. # ("DE", [value_link])]
  554. variable = tree_ir.StoreLocalInstruction('variable', var_r)
  555. value = tree_ir.StoreLocalInstruction('value', value_r)
  556. value_link = tree_ir.StoreLocalInstruction(
  557. 'value_link',
  558. tree_ir.ReadDictionaryEdgeInstruction(
  559. variable.create_load(),
  560. tree_ir.LiteralInstruction('value')))
  561. raise primitive_functions.PrimitiveFinished(
  562. tree_ir.create_block(
  563. variable,
  564. value,
  565. value_link,
  566. tree_ir.CreateDictionaryEdgeInstruction(
  567. variable.create_load(),
  568. tree_ir.LiteralInstruction('value'),
  569. value.create_load()),
  570. tree_ir.DeleteEdgeInstruction(
  571. value_link.create_load())))
  572. def analyze_access(self, instruction_id):
  573. """Tries to analyze the given 'access' instruction."""
  574. var_id, = yield [("RD", [instruction_id, "var"])]
  575. try:
  576. gen = self.analyze(var_id)
  577. inp = None
  578. while True:
  579. inp = yield gen.send(inp)
  580. except primitive_functions.PrimitiveFinished as ex:
  581. var_r = ex.result
  582. # Accessing a variable is pretty easy. It really just boils
  583. # down to reading the value corresponding to the 'value' key
  584. # of the variable.
  585. #
  586. # value, = yield [("RD", [returnvalue, "value"])]
  587. raise primitive_functions.PrimitiveFinished(
  588. tree_ir.ReadDictionaryValueInstruction(
  589. var_r,
  590. tree_ir.LiteralInstruction('value')))
  591. def analyze_direct_call(self, callee_id, callee_name, first_parameter_id):
  592. """Tries to analyze a direct 'call' instruction."""
  593. self.register_function_var(callee_id)
  594. body_id, = yield [("RD", [callee_id, "body"])]
  595. # Analyze the parameter list.
  596. try:
  597. gen = self.jit.jit_parameters(body_id)
  598. inp = None
  599. while True:
  600. inp = yield gen.send(inp)
  601. except primitive_functions.PrimitiveFinished as ex:
  602. _, parameter_names = ex.result
  603. compiled_func = self.jit.lookup_compiled_function(callee_name)
  604. if compiled_func is None:
  605. # Compile the callee.
  606. try:
  607. gen = self.jit.jit_compile(self.user_root, body_id, callee_name)
  608. inp = None
  609. while True:
  610. inp = yield gen.send(inp)
  611. except primitive_functions.PrimitiveFinished as ex:
  612. pass
  613. else:
  614. self.jit.register_compiled(body_id, compiled_func, callee_name)
  615. # Get the callee's name.
  616. compiled_func_name = self.jit.get_compiled_name(body_id)
  617. # Analyze the argument dictionary.
  618. try:
  619. gen = self.analyze_argument_dict(first_parameter_id)
  620. inp = None
  621. while True:
  622. inp = yield gen.send(inp)
  623. except primitive_functions.PrimitiveFinished as ex:
  624. arg_dict = ex.result
  625. # Construct the argument list from the parameter list and
  626. # argument dictionary.
  627. arg_list = []
  628. for param_name in parameter_names:
  629. if param_name in arg_dict:
  630. arg_list.append(arg_dict[param_name])
  631. else:
  632. raise JitCompilationFailedException(
  633. "Cannot JIT-compile function call to '%s' with missing argument for "
  634. "formal parameter '%s'." % (callee_name, param_name))
  635. raise primitive_functions.PrimitiveFinished(
  636. tree_ir.JitCallInstruction(
  637. tree_ir.LoadGlobalInstruction(compiled_func_name),
  638. arg_list,
  639. tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
  640. def analyze_argument_dict(self, first_argument_id):
  641. """Analyzes the parameter-to-argument mapping started by the specified first argument
  642. node."""
  643. next_param = first_argument_id
  644. argument_dict = {}
  645. while next_param is not None:
  646. param_name_id, = yield [("RD", [next_param, "name"])]
  647. param_name, = yield [("RV", [param_name_id])]
  648. param_val_id, = yield [("RD", [next_param, "value"])]
  649. try:
  650. gen = self.analyze(param_val_id)
  651. inp = None
  652. while True:
  653. inp = yield gen.send(inp)
  654. except primitive_functions.PrimitiveFinished as ex:
  655. argument_dict[param_name] = ex.result
  656. next_param, = yield [("RD", [next_param, "next_param"])]
  657. raise primitive_functions.PrimitiveFinished(argument_dict)
  658. def analyze_call(self, instruction_id):
  659. """Tries to analyze the given 'call' instruction."""
  660. func_id, first_param_id, = yield [("RD", [instruction_id, "func"]),
  661. ("RD", [instruction_id, "params"])]
  662. # Figure out what the 'func' instruction's type is.
  663. func_instruction_op, = yield [("RV", [func_id])]
  664. if func_instruction_op['value'] == 'access':
  665. # Calls to 'access(resolve(var))' instructions are translated to direct calls.
  666. access_value_id, = yield [("RD", [func_id, "var"])]
  667. access_value_op, = yield [("RV", [access_value_id])]
  668. if access_value_op['value'] == 'resolve':
  669. resolved_var_id, = yield [("RD", [access_value_id, "var"])]
  670. resolved_var_name, = yield [("RV", [resolved_var_id])]
  671. # Try to look the name up as a global.
  672. _globals, = yield [("RD", [self.user_root, "globals"])]
  673. global_var, = yield [("RD", [_globals, resolved_var_name])]
  674. global_val, = yield [("RD", [global_var, "value"])]
  675. if global_val is None:
  676. raise JitCompilationFailedException(
  677. "Cannot JIT function calls that target an unknown value.")
  678. else:
  679. gen = self.analyze_direct_call(
  680. global_val, resolved_var_name, first_param_id)
  681. inp = None
  682. while True:
  683. inp = yield gen.send(inp)
  684. # PrimitiveFinished exception will bubble up from here.
  685. raise JitCompilationFailedException("Cannot JIT indirect function calls yet.")
  686. instruction_analyzers = {
  687. 'if' : analyze_if,
  688. 'while' : analyze_while,
  689. 'return' : analyze_return,
  690. 'constant' : analyze_constant,
  691. 'resolve' : analyze_resolve,
  692. 'declare' : analyze_declare,
  693. 'global' : analyze_global,
  694. 'assign' : analyze_assign,
  695. 'access' : analyze_access,
  696. 'output' : analyze_output,
  697. 'input' : analyze_input,
  698. 'call' : analyze_call
  699. }