# jit.py
  1. import modelverse_kernel.primitives as primitive_functions
  2. import modelverse_jit.tree_ir as tree_ir
  3. KWARGS_PARAMETER_NAME = "remainder"
  4. """The name of the kwargs parameter in jitted functions."""
  5. class JitCompilationFailedException(Exception):
  6. """A type of exception that is raised when the jit fails to compile a function."""
  7. pass
  8. class ModelverseJit(object):
  9. """A high-level interface to the modelverse JIT compiler."""
  10. def __init__(self, max_instructions=None):
  11. self.todo_entry_points = set()
  12. self.no_jit_entry_points = set()
  13. self.jitted_entry_points = {}
  14. self.jit_globals = {
  15. 'PrimitiveFinished' : primitive_functions.PrimitiveFinished
  16. }
  17. self.jit_count = 0
  18. self.max_instructions = 30 if max_instructions is None else max_instructions
  19. def mark_entry_point(self, body_id):
  20. """Marks the node with the given identifier as a function entry point."""
  21. if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
  22. self.todo_entry_points.add(body_id)
  23. def is_entry_point(self, body_id):
  24. """Tells if the node with the given identifier is a function entry point."""
  25. return body_id in self.todo_entry_points or \
  26. body_id in self.no_jit_entry_points or \
  27. body_id in self.jitted_entry_points
  28. def is_jittable_entry_point(self, body_id):
  29. """Tells if the node with the given identifier is a function entry point that
  30. has not been marked as non-jittable."""
  31. return body_id in self.todo_entry_points or \
  32. body_id in self.jitted_entry_points
  33. def mark_no_jit(self, body_id):
  34. """Informs the JIT that the node with the given identifier is a function entry
  35. point that must never be jitted."""
  36. self.no_jit_entry_points.add(body_id)
  37. if body_id in self.todo_entry_points:
  38. self.todo_entry_points.remove(body_id)
  39. def register_compiled(self, body_id, compiled):
  40. """Registers a compiled entry point with the JIT."""
  41. self.jitted_entry_points[body_id] = compiled
  42. if body_id in self.todo_entry_points:
  43. self.todo_entry_points.remove(body_id)
  44. def try_jit(self, body_id, parameter_list):
  45. """Tries to jit the function defined by the given entry point id and parameter list."""
  46. if body_id in self.jitted_entry_points:
  47. # We have already compiled this function.
  48. raise primitive_functions.PrimitiveFinished(self.jitted_entry_points[body_id])
  49. elif body_id in self.no_jit_entry_points:
  50. # We're not allowed to jit this function or have tried and failed before.
  51. raise JitCompilationFailedException(
  52. 'Cannot jit function at %d because it is marked non-jittable.' % body_id)
  53. try:
  54. gen = AnalysisState(self.max_instructions).analyze(body_id)
  55. inp = None
  56. while True:
  57. inp = yield gen.send(inp)
  58. except primitive_functions.PrimitiveFinished as ex:
  59. constructed_body = ex.result
  60. except JitCompilationFailedException as ex:
  61. self.mark_no_jit(body_id)
  62. raise JitCompilationFailedException(
  63. '%s (function at %d)' % (ex.message, body_id))
  64. # Wrap the IR in a function definition, give it a unique name.
  65. constructed_function = tree_ir.DefineFunctionInstruction(
  66. 'jit_func%d' % self.jit_count,
  67. parameter_list + ['**' + KWARGS_PARAMETER_NAME],
  68. constructed_body.simplify())
  69. self.jit_count += 1
  70. # Convert the function definition to Python code, and compile it.
  71. exec(str(constructed_function), self.jit_globals)
  72. # Extract the compiled function from the JIT global state.
  73. compiled_function = self.jit_globals[constructed_function.name]
  74. print(constructed_function)
  75. # Save the compiled function so we can reuse it later.
  76. self.jitted_entry_points[body_id] = compiled_function
  77. raise primitive_functions.PrimitiveFinished(compiled_function)
  78. class AnalysisState(object):
  79. """The state of a bytecode analysis call graph."""
  80. def __init__(self, max_instructions=None):
  81. self.analyzed_instructions = set()
  82. self.max_instructions = max_instructions
  83. def get_local_name(self, local_id):
  84. """Gets the name for a local with the given id."""
  85. return 'local%d' % local_id
  86. def retrieve_user_root(self):
  87. """Creates an instruction that stores the user_root variable
  88. in a local."""
  89. return tree_ir.StoreLocalInstruction(
  90. 'user_root',
  91. tree_ir.LoadIndexInstruction(
  92. tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
  93. tree_ir.LiteralInstruction('user_root')))
  94. def analyze(self, instruction_id):
  95. """Tries to build an intermediate representation from the instruction with the
  96. given id."""
  97. # Check the analyzed_instructions set for instruction_id to avoid
  98. # infinite loops.
  99. if instruction_id in self.analyzed_instructions:
  100. raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
  101. elif (self.max_instructions is not None and
  102. len(self.analyzed_instructions) > self.max_instructions):
  103. raise JitCompilationFailedException('Maximal number of instructions exceeded.')
  104. self.analyzed_instructions.add(instruction_id)
  105. instruction_val, = yield [("RV", [instruction_id])]
  106. instruction_val = instruction_val["value"]
  107. if instruction_val in self.instruction_analyzers:
  108. gen = self.instruction_analyzers[instruction_val](self, instruction_id)
  109. try:
  110. inp = None
  111. while True:
  112. inp = yield gen.send(inp)
  113. except StopIteration:
  114. raise Exception(
  115. "Instruction analyzer (for '%s') finished without returning a value!" %
  116. (instruction_val))
  117. except primitive_functions.PrimitiveFinished as outer_e:
  118. # Check if the instruction has a 'next' instruction.
  119. next_instr, = yield [("RD", [instruction_id, "next"])]
  120. if next_instr is None:
  121. raise outer_e
  122. else:
  123. gen = self.analyze(next_instr)
  124. try:
  125. inp = None
  126. while True:
  127. inp = yield gen.send(inp)
  128. except primitive_functions.PrimitiveFinished as inner_e:
  129. raise primitive_functions.PrimitiveFinished(
  130. tree_ir.CompoundInstruction(
  131. outer_e.result,
  132. inner_e.result))
  133. else:
  134. raise JitCompilationFailedException(
  135. "Unknown instruction type: '%s'" % (instruction_val))
  136. def analyze_all(self, instruction_ids):
  137. """Tries to compile a list of IR trees from the given list of instruction ids."""
  138. results = []
  139. for inst in instruction_ids:
  140. gen = self.analyze(inst)
  141. try:
  142. inp = None
  143. while True:
  144. inp = yield gen.send(inp)
  145. except primitive_functions.PrimitiveFinished as ex:
  146. results.append(ex.result)
  147. raise primitive_functions.PrimitiveFinished(results)
  148. def analyze_return(self, instruction_id):
  149. """Tries to analyze the given 'return' instruction."""
  150. retval_id, = yield [("RD", [instruction_id, 'value'])]
  151. if retval_id is None:
  152. raise primitive_functions.PrimitiveFinished(
  153. tree_ir.ReturnInstruction(
  154. tree_ir.EmptyInstruction()))
  155. else:
  156. gen = self.analyze(retval_id)
  157. try:
  158. inp = None
  159. while True:
  160. inp = yield gen.send(inp)
  161. except primitive_functions.PrimitiveFinished as ex:
  162. raise primitive_functions.PrimitiveFinished(
  163. tree_ir.ReturnInstruction(ex.result))
  164. def analyze_if(self, instruction_id):
  165. """Tries to analyze the given 'if' instruction."""
  166. cond, true, false = yield [
  167. ("RD", [instruction_id, "cond"]),
  168. ("RD", [instruction_id, "then"]),
  169. ("RD", [instruction_id, "else"])]
  170. gen = self.analyze_all(
  171. [cond, true]
  172. if false is None
  173. else [cond, true, false])
  174. try:
  175. inp = None
  176. while True:
  177. inp = yield gen.send(inp)
  178. except primitive_functions.PrimitiveFinished as ex:
  179. if false is None:
  180. cond_r, true_r = ex.result
  181. false_r = tree_ir.EmptyInstruction()
  182. else:
  183. cond_r, true_r, false_r = ex.result
  184. raise primitive_functions.PrimitiveFinished(
  185. tree_ir.SelectInstruction(
  186. tree_ir.ReadValueInstruction(cond_r),
  187. true_r,
  188. false_r))
  189. def analyze_while(self, instruction_id):
  190. """Tries to analyze the given 'while' instruction."""
  191. cond, body = yield [
  192. ("RD", [instruction_id, "cond"]),
  193. ("RD", [instruction_id, "body"])]
  194. gen = self.analyze_all([cond, body])
  195. try:
  196. inp = None
  197. while True:
  198. inp = yield gen.send(inp)
  199. except primitive_functions.PrimitiveFinished as ex:
  200. cond_r, body_r = ex.result
  201. raise primitive_functions.PrimitiveFinished(
  202. tree_ir.LoopInstruction(
  203. tree_ir.CompoundInstruction(
  204. tree_ir.SelectInstruction(
  205. tree_ir.ReadValueInstruction(cond_r),
  206. tree_ir.EmptyInstruction(),
  207. tree_ir.BreakInstruction()),
  208. body_r)))
  209. def analyze_constant(self, instruction_id):
  210. """Tries to analyze the given 'constant' (literal) instruction."""
  211. node_id, = yield [("RD", [instruction_id, "node"])]
  212. raise primitive_functions.PrimitiveFinished(
  213. tree_ir.LiteralInstruction(node_id))
  214. def analyze_output(self, instruction_id):
  215. """Tries to analyze the given 'output' instruction."""
  216. # The plan is to basically generate this tree:
  217. #
  218. # value = <some tree>
  219. # last_output, last_output_link, new_last_output = \
  220. # yield [("RD", [user_root, "last_output"]),
  221. # ("RDE", [user_root, "last_output"]),
  222. # ("CN", []),
  223. # ]
  224. # _, _, _, _ = \
  225. # yield [("CD", [last_output, "value", value]),
  226. # ("CD", [last_output, "next", new_last_output]),
  227. # ("CD", [user_root, "last_output", new_last_output]),
  228. # ("DE", [last_output_link])
  229. # ]
  230. # yield None
  231. value_id, = yield [("RD", [instruction_id, "value"])]
  232. gen = self.analyze(value_id)
  233. try:
  234. inp = None
  235. while True:
  236. inp = yield gen.send(inp)
  237. except primitive_functions.PrimitiveFinished as ex:
  238. value_local = tree_ir.StoreLocalInstruction('value', ex.result)
  239. store_user_root = tree_ir.StoreLocalInstruction(
  240. 'user_root',
  241. tree_ir.LoadIndexInstruction(
  242. tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
  243. tree_ir.LiteralInstruction('user_root')))
  244. last_output = tree_ir.StoreLocalInstruction(
  245. 'last_output',
  246. tree_ir.ReadDictionaryValueInstruction(
  247. store_user_root.create_load(),
  248. tree_ir.LiteralInstruction('last_output')))
  249. last_output_link = tree_ir.StoreLocalInstruction(
  250. 'last_output_link',
  251. tree_ir.ReadDictionaryEdgeInstruction(
  252. store_user_root.create_load(),
  253. tree_ir.LiteralInstruction('last_output')))
  254. new_last_output = tree_ir.StoreLocalInstruction(
  255. 'new_last_output',
  256. tree_ir.CreateNodeInstruction())
  257. result = tree_ir.create_block(
  258. value_local,
  259. store_user_root,
  260. last_output,
  261. last_output_link,
  262. new_last_output,
  263. tree_ir.CreateDictionaryEdgeInstruction(
  264. last_output.create_load(),
  265. tree_ir.LiteralInstruction('value'),
  266. value_local.create_load()),
  267. tree_ir.CreateDictionaryEdgeInstruction(
  268. last_output.create_load(),
  269. tree_ir.LiteralInstruction('next'),
  270. new_last_output.create_load()),
  271. tree_ir.CreateDictionaryEdgeInstruction(
  272. store_user_root.create_load(),
  273. tree_ir.LiteralInstruction('last_output'),
  274. new_last_output.create_load()),
  275. tree_ir.DeleteEdgeInstruction(last_output_link.create_load()),
  276. tree_ir.NopInstruction())
  277. raise primitive_functions.PrimitiveFinished(result)
  278. def analyze_resolve(self, instruction_id):
  279. """Tries to analyze the given 'resolve' instruction."""
  280. var_id, = yield [("RD", [instruction_id, "var"])]
  281. # To resolve a variable, we'll do something along the
  282. # lines of:
  283. #
  284. # if 'local_var' in locals():
  285. # tmp = local_var
  286. # else:
  287. # _globals, var_name = yield [("RD", [user_root, "globals"]),
  288. # ("RV", [var_id])
  289. # ]
  290. # global_var = yield [("RD", [_globals, var_name])]
  291. #
  292. # if global_var is None:
  293. # raise Exception("Runtime error: global '%s' not found" % (var_name))
  294. #
  295. # tmp = global_var
  296. user_root = self.retrieve_user_root()
  297. var_name = tree_ir.StoreLocalInstruction(
  298. 'var_name',
  299. tree_ir.ReadValueInstruction(
  300. tree_ir.LiteralInstruction(var_id)))
  301. global_var = tree_ir.StoreLocalInstruction(
  302. 'global_var',
  303. tree_ir.ReadDictionaryValueInstruction(
  304. tree_ir.ReadDictionaryValueInstruction(
  305. user_root.create_load(),
  306. tree_ir.LiteralInstruction('globals')),
  307. var_name.create_load()))
  308. err_block = tree_ir.SelectInstruction(
  309. tree_ir.BinaryInstruction(
  310. global_var.create_load(),
  311. 'is',
  312. tree_ir.LiteralInstruction(None)),
  313. tree_ir.RaiseInstruction(
  314. tree_ir.CallInstruction(
  315. tree_ir.LoadLocalInstruction('Exception'),
  316. [tree_ir.BinaryInstruction(
  317. tree_ir.LiteralInstruction("Runtime error: global '%s' not found"),
  318. '%',
  319. var_name.create_load())
  320. ])),
  321. tree_ir.EmptyInstruction())
  322. name = self.get_local_name(var_id)
  323. raise primitive_functions.PrimitiveFinished(
  324. tree_ir.SelectInstruction(
  325. tree_ir.LocalExistsInstruction(name),
  326. tree_ir.LoadLocalInstruction(name),
  327. tree_ir.CompoundInstruction(
  328. tree_ir.create_block(
  329. user_root,
  330. var_name,
  331. global_var,
  332. err_block),
  333. global_var.create_load())))
  334. def analyze_declare(self, instruction_id):
  335. """Tries to analyze the given 'declare' function."""
  336. var_id, = yield [("RD", [instruction_id, "var"])]
  337. name = self.get_local_name(var_id)
  338. # The following logic declares a local:
  339. #
  340. # if 'local_name' not in locals():
  341. # local_name, = yield [("CN", [])]
  342. raise primitive_functions.PrimitiveFinished(
  343. tree_ir.SelectInstruction(
  344. tree_ir.LocalExistsInstruction(name),
  345. tree_ir.EmptyInstruction(),
  346. tree_ir.StoreLocalInstruction(
  347. name,
  348. tree_ir.CreateNodeInstruction())))
  349. def analyze_global(self, instruction_id):
  350. """Tries to analyze the given 'global' (declaration) instruction."""
  351. var_id, = yield [("RD", [instruction_id, "var"])]
  352. # To resolve a variable, we'll do something along the
  353. # lines of:
  354. #
  355. # _globals, var_name = yield [("RD", [user_root, "globals"]),
  356. # ("RV", [var_id])
  357. # ]
  358. # global_var = yield [("RD", [_globals, var_name])]
  359. #
  360. # if global_var is None:
  361. # global_var, = yield [("CN", [])]
  362. # yield [("CD", [_globals, var_name, global_var])]
  363. #
  364. # tmp = global_var
  365. user_root = self.retrieve_user_root()
  366. var_name = tree_ir.StoreLocalInstruction(
  367. 'var_name',
  368. tree_ir.ReadValueInstruction(
  369. tree_ir.LiteralInstruction(var_id)))
  370. _globals = tree_ir.StoreLocalInstruction(
  371. '_globals',
  372. tree_ir.ReadDictionaryValueInstruction(
  373. user_root.create_load(),
  374. tree_ir.LiteralInstruction('globals')))
  375. global_var = tree_ir.StoreLocalInstruction(
  376. 'global_var',
  377. tree_ir.ReadDictionaryValueInstruction(
  378. _globals.create_load(),
  379. var_name.create_load()))
  380. raise primitive_functions.PrimitiveFinished(
  381. tree_ir.CompoundInstruction(
  382. tree_ir.create_block(
  383. user_root,
  384. var_name,
  385. _globals,
  386. global_var,
  387. tree_ir.SelectInstruction(
  388. tree_ir.BinaryInstruction(
  389. global_var.create_load(),
  390. 'is',
  391. tree_ir.LiteralInstruction(None)),
  392. tree_ir.create_block(
  393. global_var.create_store(
  394. tree_ir.CreateNodeInstruction()),
  395. tree_ir.CreateDictionaryEdgeInstruction(
  396. _globals.create_load(),
  397. var_name.create_load(),
  398. global_var.create_load())),
  399. tree_ir.EmptyInstruction())),
  400. global_var.create_load()))
  401. def analyze_assign(self, instruction_id):
  402. """Tries to analyze the given 'assign' instruction."""
  403. var_id, value_id = yield [("RD", [instruction_id, "var"]),
  404. ("RD", [instruction_id, "value"])]
  405. try:
  406. gen = self.analyze_all([var_id, value_id])
  407. inp = None
  408. while True:
  409. inp = yield gen.send(inp)
  410. except primitive_functions.PrimitiveFinished as ex:
  411. var_r, value_r = ex.result
  412. # Assignments work like this:
  413. #
  414. # value_link = yield [("RDE", [variable, "value"])]
  415. # _, _ = yield [("CD", [variable, "value", value]),
  416. # ("DE", [value_link])]
  417. variable = tree_ir.StoreLocalInstruction('variable', var_r)
  418. value = tree_ir.StoreLocalInstruction('value', value_r)
  419. value_link = tree_ir.StoreLocalInstruction(
  420. 'value_link',
  421. tree_ir.ReadDictionaryEdgeInstruction(
  422. variable.create_load(),
  423. tree_ir.LiteralInstruction('value')))
  424. raise primitive_functions.PrimitiveFinished(
  425. tree_ir.create_block(
  426. variable,
  427. value,
  428. value_link,
  429. tree_ir.CreateDictionaryEdgeInstruction(
  430. variable.create_load(),
  431. tree_ir.LiteralInstruction('value'),
  432. value.create_load()),
  433. tree_ir.DeleteEdgeInstruction(
  434. value_link.create_load())))
  435. def analyze_access(self, instruction_id):
  436. """Tries to analyze the given 'access' instruction."""
  437. var_id, = yield [("RD", [instruction_id, "var"])]
  438. try:
  439. gen = self.analyze(var_id)
  440. inp = None
  441. while True:
  442. inp = yield gen.send(inp)
  443. except primitive_functions.PrimitiveFinished as ex:
  444. var_r = ex.result
  445. # Accessing a variable is pretty easy. It really just boils
  446. # down to reading the value corresponding to the 'value' key
  447. # of the variable.
  448. #
  449. # value, = yield [("RD", [returnvalue, "value"])]
  450. raise primitive_functions.PrimitiveFinished(
  451. tree_ir.ReadDictionaryValueInstruction(
  452. var_r,
  453. tree_ir.LiteralInstruction('value')))
  454. instruction_analyzers = {
  455. 'if' : analyze_if,
  456. 'while' : analyze_while,
  457. 'return' : analyze_return,
  458. 'constant' : analyze_constant,
  459. 'resolve' : analyze_resolve,
  460. 'declare' : analyze_declare,
  461. 'global' : analyze_global,
  462. 'assign' : analyze_assign,
  463. 'access' : analyze_access,
  464. 'output' : analyze_output
  465. }