jit.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. import modelverse_kernel.primitives as primitive_functions
  2. import modelverse_jit.tree_ir as tree_ir
  3. KWARGS_PARAMETER_NAME = "remainder"
  4. """The name of the kwargs parameter in jitted functions."""
  5. class JitCompilationFailedException(Exception):
  6. """A type of exception that is raised when the jit fails to compile a function."""
  7. pass
  8. class ModelverseJit(object):
  9. """A high-level interface to the modelverse JIT compiler."""
  10. def __init__(self):
  11. self.todo_entry_points = set()
  12. self.no_jit_entry_points = set()
  13. self.jitted_entry_points = {}
  14. self.jit_globals = {
  15. 'PrimitiveFinished' : primitive_functions.PrimitiveFinished
  16. }
  17. self.jit_count = 0
  18. def mark_entry_point(self, body_id):
  19. """Marks the node with the given identifier as a function entry point."""
  20. if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
  21. self.todo_entry_points.add(body_id)
  22. def is_entry_point(self, body_id):
  23. """Tells if the node with the given identifier is a function entry point."""
  24. return body_id in self.todo_entry_points or \
  25. body_id in self.no_jit_entry_points or \
  26. body_id in self.jitted_entry_points
  27. def is_jittable_entry_point(self, body_id):
  28. """Tells if the node with the given identifier is a function entry point that
  29. has not been marked as non-jittable."""
  30. return body_id in self.todo_entry_points or \
  31. body_id in self.jitted_entry_points
  32. def mark_no_jit(self, body_id):
  33. """Informs the JIT that the node with the given identifier is a function entry
  34. point that must never be jitted."""
  35. self.no_jit_entry_points.add(body_id)
  36. if body_id in self.todo_entry_points:
  37. self.todo_entry_points.remove(body_id)
  38. def register_compiled(self, body_id, compiled):
  39. """Registers a compiled entry point with the JIT."""
  40. self.jitted_entry_points[body_id] = compiled
  41. if body_id in self.todo_entry_points:
  42. self.todo_entry_points.remove(body_id)
  43. def try_jit(self, body_id, parameter_list):
  44. """Tries to jit the function defined by the given entry point id and parameter list."""
  45. if body_id in self.jitted_entry_points:
  46. # We have already compiled this function.
  47. raise primitive_functions.PrimitiveFinished(self.jitted_entry_points[body_id])
  48. elif body_id in self.no_jit_entry_points:
  49. # We're not allowed to jit this function or have tried and failed before.
  50. raise JitCompilationFailedException(
  51. 'Cannot jit function at %d because it is marked non-jittable.' % body_id)
  52. gen = AnalysisState().analyze(body_id)
  53. try:
  54. inp = None
  55. while True:
  56. inp = yield gen.send(inp)
  57. except primitive_functions.PrimitiveFinished as ex:
  58. constructed_body = ex.result
  59. except JitCompilationFailedException as ex:
  60. self.mark_no_jit(body_id)
  61. raise JitCompilationFailedException(
  62. '%s (function at %d)' % (ex.message, body_id))
  63. # Wrap the IR in a function definition, give it a unique name.
  64. constructed_function = tree_ir.DefineFunctionInstruction(
  65. 'jit_func%d' % self.jit_count,
  66. parameter_list + ['**' + KWARGS_PARAMETER_NAME],
  67. constructed_body.simplify())
  68. self.jit_count += 1
  69. # Convert the function definition to Python code, and compile it.
  70. exec(str(constructed_function), self.jit_globals)
  71. # Extract the compiled function from the JIT global state.
  72. compiled_function = self.jit_globals[constructed_function.name]
  73. print(constructed_function)
  74. # Save the compiled function so we can reuse it later.
  75. self.jitted_entry_points[body_id] = compiled_function
  76. raise primitive_functions.PrimitiveFinished(compiled_function)
  77. class AnalysisState(object):
  78. """The state of a bytecode analysis call graph."""
  79. def __init__(self):
  80. self.analyzed_instructions = set()
  81. def get_local_name(self, local_id):
  82. """Gets the name for a local with the given id."""
  83. return 'local%d' % local_id
  84. def retrieve_user_root(self):
  85. """Creates an instruction that stores the user_root variable
  86. in a local."""
  87. return tree_ir.StoreLocalInstruction(
  88. 'user_root',
  89. tree_ir.LoadIndexInstruction(
  90. tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
  91. tree_ir.LiteralInstruction('user_root')))
  92. def analyze(self, instruction_id):
  93. """Tries to build an intermediate representation from the instruction with the
  94. given id."""
  95. # Check the analyzed_instructions set for instruction_id to avoid
  96. # infinite loops.
  97. if instruction_id in self.analyzed_instructions:
  98. raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
  99. self.analyzed_instructions.add(instruction_id)
  100. instruction_val, = yield [("RV", [instruction_id])]
  101. instruction_val = instruction_val["value"]
  102. if instruction_val in self.instruction_analyzers:
  103. gen = self.instruction_analyzers[instruction_val](self, instruction_id)
  104. try:
  105. inp = None
  106. while True:
  107. inp = yield gen.send(inp)
  108. except StopIteration:
  109. raise Exception(
  110. "Instruction analyzer (for '%s') finished without returning a value!" %
  111. (instruction_val))
  112. except primitive_functions.PrimitiveFinished as outer_e:
  113. # Check if the instruction has a 'next' instruction.
  114. next_instr, = yield [("RD", [instruction_id, "next"])]
  115. if next_instr is None:
  116. raise outer_e
  117. else:
  118. gen = self.analyze(next_instr)
  119. try:
  120. inp = None
  121. while True:
  122. inp = yield gen.send(inp)
  123. except primitive_functions.PrimitiveFinished as inner_e:
  124. raise primitive_functions.PrimitiveFinished(
  125. tree_ir.CompoundInstruction(
  126. outer_e.result,
  127. inner_e.result))
  128. else:
  129. raise JitCompilationFailedException(
  130. "Unknown instruction type: '%s'" % (instruction_val))
  131. def analyze_all(self, instruction_ids):
  132. """Tries to compile a list of IR trees from the given list of instruction ids."""
  133. results = []
  134. for inst in instruction_ids:
  135. gen = self.analyze(inst)
  136. try:
  137. inp = None
  138. while True:
  139. inp = yield gen.send(inp)
  140. except primitive_functions.PrimitiveFinished as ex:
  141. results.append(ex.result)
  142. raise primitive_functions.PrimitiveFinished(results)
  143. def analyze_return(self, instruction_id):
  144. """Tries to analyze the given 'return' instruction."""
  145. retval_id, = yield [("RD", [instruction_id, 'value'])]
  146. if retval_id is None:
  147. raise primitive_functions.PrimitiveFinished(
  148. tree_ir.ReturnInstruction(
  149. tree_ir.EmptyInstruction()))
  150. else:
  151. gen = self.analyze(retval_id)
  152. try:
  153. inp = None
  154. while True:
  155. inp = yield gen.send(inp)
  156. except primitive_functions.PrimitiveFinished as ex:
  157. raise primitive_functions.PrimitiveFinished(
  158. tree_ir.ReturnInstruction(ex.result))
  159. def analyze_if(self, instruction_id):
  160. """Tries to analyze the given 'if' instruction."""
  161. cond, true, false = yield [
  162. ("RD", [instruction_id, "cond"]),
  163. ("RD", [instruction_id, "then"]),
  164. ("RD", [instruction_id, "else"])]
  165. gen = self.analyze_all(
  166. [cond, true]
  167. if false is None
  168. else [cond, true, false])
  169. try:
  170. inp = None
  171. while True:
  172. inp = yield gen.send(inp)
  173. except primitive_functions.PrimitiveFinished as ex:
  174. if false is None:
  175. cond_r, true_r = ex.result
  176. false_r = tree_ir.EmptyInstruction()
  177. else:
  178. cond_r, true_r, false_r = ex.result
  179. raise primitive_functions.PrimitiveFinished(
  180. tree_ir.SelectInstruction(
  181. tree_ir.ReadValueInstruction(cond_r),
  182. true_r,
  183. false_r))
  184. def analyze_while(self, instruction_id):
  185. """Tries to analyze the given 'while' instruction."""
  186. cond, body = yield [
  187. ("RD", [instruction_id, "cond"]),
  188. ("RD", [instruction_id, "body"])]
  189. gen = self.analyze_all([cond, body])
  190. try:
  191. inp = None
  192. while True:
  193. inp = yield gen.send(inp)
  194. except primitive_functions.PrimitiveFinished as ex:
  195. cond_r, body_r = ex.result
  196. raise primitive_functions.PrimitiveFinished(
  197. tree_ir.LoopInstruction(
  198. tree_ir.CompoundInstruction(
  199. tree_ir.SelectInstruction(
  200. tree_ir.ReadValueInstruction(cond_r),
  201. tree_ir.EmptyInstruction(),
  202. tree_ir.BreakInstruction()),
  203. body_r)))
  204. def analyze_constant(self, instruction_id):
  205. """Tries to analyze the given 'constant' (literal) instruction."""
  206. node_id, = yield [("RD", [instruction_id, "node"])]
  207. raise primitive_functions.PrimitiveFinished(
  208. tree_ir.LiteralInstruction(node_id))
  209. def analyze_output(self, instruction_id):
  210. """Tries to analyze the given 'output' instruction."""
  211. # The plan is to basically generate this tree:
  212. #
  213. # value = <some tree>
  214. # last_output, last_output_link, new_last_output = \
  215. # yield [("RD", [user_root, "last_output"]),
  216. # ("RDE", [user_root, "last_output"]),
  217. # ("CN", []),
  218. # ]
  219. # _, _, _, _ = \
  220. # yield [("CD", [last_output, "value", value]),
  221. # ("CD", [last_output, "next", new_last_output]),
  222. # ("CD", [user_root, "last_output", new_last_output]),
  223. # ("DE", [last_output_link])
  224. # ]
  225. # yield None
  226. value_id, = yield [("RD", [instruction_id, "value"])]
  227. gen = self.analyze(value_id)
  228. try:
  229. inp = None
  230. while True:
  231. inp = yield gen.send(inp)
  232. except primitive_functions.PrimitiveFinished as ex:
  233. value_local = tree_ir.StoreLocalInstruction('value', ex.result)
  234. store_user_root = tree_ir.StoreLocalInstruction(
  235. 'user_root',
  236. tree_ir.LoadIndexInstruction(
  237. tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
  238. tree_ir.LiteralInstruction('user_root')))
  239. last_output = tree_ir.StoreLocalInstruction(
  240. 'last_output',
  241. tree_ir.ReadDictionaryValueInstruction(
  242. store_user_root.create_load(),
  243. tree_ir.LiteralInstruction('last_output')))
  244. last_output_link = tree_ir.StoreLocalInstruction(
  245. 'last_output_link',
  246. tree_ir.ReadDictionaryEdgeInstruction(
  247. store_user_root.create_load(),
  248. tree_ir.LiteralInstruction('last_output')))
  249. new_last_output = tree_ir.StoreLocalInstruction(
  250. 'new_last_output',
  251. tree_ir.CreateNodeInstruction())
  252. result = tree_ir.create_block(
  253. value_local,
  254. store_user_root,
  255. last_output,
  256. last_output_link,
  257. new_last_output,
  258. tree_ir.CreateDictionaryEdgeInstruction(
  259. last_output.create_load(),
  260. tree_ir.LiteralInstruction('value'),
  261. value_local.create_load()),
  262. tree_ir.CreateDictionaryEdgeInstruction(
  263. last_output.create_load(),
  264. tree_ir.LiteralInstruction('next'),
  265. new_last_output.create_load()),
  266. tree_ir.CreateDictionaryEdgeInstruction(
  267. store_user_root.create_load(),
  268. tree_ir.LiteralInstruction('last_output'),
  269. new_last_output.create_load()),
  270. tree_ir.DeleteEdgeInstruction(last_output_link.create_load()),
  271. tree_ir.NopInstruction())
  272. raise primitive_functions.PrimitiveFinished(result)
  273. def analyze_resolve(self, instruction_id):
  274. """Tries to analyze the given 'resolve' instruction."""
  275. var_id, = yield [("RD", [instruction_id, "var"])]
  276. # To resolve a variable, we'll do something along the
  277. # lines of:
  278. #
  279. # if 'local_var' in locals():
  280. # tmp = local_var
  281. # else:
  282. # _globals, var_name = yield [("RD", [user_root, "globals"]),
  283. # ("RV", [var_id])
  284. # ]
  285. # global_var = yield [("RD", [_globals, var_name])]
  286. #
  287. # if global_var is None:
  288. # raise Exception("Runtime error: global '%s' not found" % (var_name))
  289. #
  290. # tmp = global_var
  291. user_root = self.retrieve_user_root()
  292. var_name = tree_ir.StoreLocalInstruction(
  293. 'var_name',
  294. tree_ir.ReadValueInstruction(
  295. tree_ir.LiteralInstruction(var_id)))
  296. global_var = tree_ir.StoreLocalInstruction(
  297. 'global_var',
  298. tree_ir.ReadDictionaryValueInstruction(
  299. tree_ir.ReadDictionaryValueInstruction(
  300. user_root.create_load(),
  301. tree_ir.LiteralInstruction('globals')),
  302. var_name.create_load()))
  303. err_block = tree_ir.SelectInstruction(
  304. tree_ir.BinaryInstruction(
  305. global_var.create_load(),
  306. 'is',
  307. tree_ir.LiteralInstruction(None)),
  308. tree_ir.RaiseInstruction(
  309. tree_ir.CallInstruction(
  310. tree_ir.LoadLocalInstruction('Exception'),
  311. [tree_ir.BinaryInstruction(
  312. tree_ir.LiteralInstruction("Runtime error: global '%s' not found"),
  313. '%',
  314. var_name.create_load())
  315. ])),
  316. tree_ir.EmptyInstruction())
  317. name = self.get_local_name(var_id)
  318. raise primitive_functions.PrimitiveFinished(
  319. tree_ir.SelectInstruction(
  320. tree_ir.LocalExistsInstruction(name),
  321. tree_ir.LoadLocalInstruction(name),
  322. tree_ir.CompoundInstruction(
  323. tree_ir.create_block(
  324. user_root,
  325. var_name,
  326. global_var,
  327. err_block),
  328. global_var.create_load())))
  329. def analyze_declare(self, instruction_id):
  330. """Tries to analyze the given 'declare' function."""
  331. var_id, = yield [("RD", [instruction_id, "var"])]
  332. name = self.get_local_name(var_id)
  333. # The following logic declares a local:
  334. #
  335. # if 'local_name' not in locals():
  336. # local_name, = yield [("CN", [])]
  337. raise primitive_functions.PrimitiveFinished(
  338. tree_ir.SelectInstruction(
  339. tree_ir.LocalExistsInstruction(name),
  340. tree_ir.EmptyInstruction(),
  341. tree_ir.StoreLocalInstruction(
  342. name,
  343. tree_ir.CreateNodeInstruction())))
  344. def analyze_global(self, instruction_id):
  345. """Tries to analyze the given 'global' (declaration) instruction."""
  346. var_id, = yield [("RD", [instruction_id, "var"])]
  347. # To resolve a variable, we'll do something along the
  348. # lines of:
  349. #
  350. # _globals, var_name = yield [("RD", [user_root, "globals"]),
  351. # ("RV", [var_id])
  352. # ]
  353. # global_var = yield [("RD", [_globals, var_name])]
  354. #
  355. # if global_var is None:
  356. # global_var, = yield [("CN", [])]
  357. # yield [("CD", [_globals, var_name, global_var])]
  358. #
  359. # tmp = global_var
  360. user_root = self.retrieve_user_root()
  361. var_name = tree_ir.StoreLocalInstruction(
  362. 'var_name',
  363. tree_ir.ReadValueInstruction(
  364. tree_ir.LiteralInstruction(var_id)))
  365. _globals = tree_ir.StoreLocalInstruction(
  366. '_globals',
  367. tree_ir.ReadDictionaryValueInstruction(
  368. user_root.create_load(),
  369. tree_ir.LiteralInstruction('globals')))
  370. global_var = tree_ir.StoreLocalInstruction(
  371. 'global_var',
  372. tree_ir.ReadDictionaryValueInstruction(
  373. _globals.create_load(),
  374. var_name.create_load()))
  375. raise primitive_functions.PrimitiveFinished(
  376. tree_ir.CompoundInstruction(
  377. tree_ir.create_block(
  378. user_root,
  379. var_name,
  380. _globals,
  381. global_var,
  382. tree_ir.SelectInstruction(
  383. tree_ir.BinaryInstruction(
  384. global_var.create_load(),
  385. 'is',
  386. tree_ir.LiteralInstruction(None)),
  387. tree_ir.create_block(
  388. global_var.create_store(
  389. tree_ir.CreateNodeInstruction()),
  390. tree_ir.CreateDictionaryEdgeInstruction(
  391. _globals.create_load(),
  392. var_name.create_load(),
  393. global_var.create_load())),
  394. tree_ir.EmptyInstruction())),
  395. global_var.create_load()))
  396. def analyze_assign(self, instruction_id):
  397. """Tries to analyze the given 'assign' instruction."""
  398. var_id, value_id = yield [("RD", [instruction_id, "var"]),
  399. ("RD", [instruction_id, "value"])]
  400. try:
  401. gen = self.analyze_all([var_id, value_id])
  402. inp = None
  403. while True:
  404. inp = yield gen.send(inp)
  405. except primitive_functions.PrimitiveFinished as ex:
  406. var_r, value_r = ex.result
  407. # Assignments work like this:
  408. #
  409. # value_link = yield [("RDE", [variable, "value"])]
  410. # _, _ = yield [("CD", [variable, "value", value]),
  411. # ("DE", [value_link])]
  412. variable = tree_ir.StoreLocalInstruction('variable', var_r)
  413. value = tree_ir.StoreLocalInstruction('value', value_r)
  414. value_link = tree_ir.StoreLocalInstruction(
  415. 'value_link',
  416. tree_ir.ReadDictionaryEdgeInstruction(
  417. variable.create_load(),
  418. tree_ir.LiteralInstruction('value')))
  419. raise primitive_functions.PrimitiveFinished(
  420. tree_ir.create_block(
  421. variable,
  422. value,
  423. value_link,
  424. tree_ir.CreateDictionaryEdgeInstruction(
  425. variable.create_load(),
  426. tree_ir.LiteralInstruction('value'),
  427. value.create_load()),
  428. tree_ir.DeleteEdgeInstruction(
  429. value_link.create_load())))
  430. instruction_analyzers = {
  431. 'if' : analyze_if,
  432. 'while' : analyze_while,
  433. 'return' : analyze_return,
  434. 'constant' : analyze_constant,
  435. 'resolve' : analyze_resolve,
  436. 'declare' : analyze_declare,
  437. 'global' : analyze_global,
  438. 'assign' : analyze_assign,
  439. 'output' : analyze_output
  440. }