jit.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573
  1. import modelverse_kernel.primitives as primitive_functions
  2. import modelverse_jit.tree_ir as tree_ir
  3. KWARGS_PARAMETER_NAME = "remainder"
  4. """The name of the kwargs parameter in jitted functions."""
  5. class JitCompilationFailedException(Exception):
  6. """A type of exception that is raised when the jit fails to compile a function."""
  7. pass
  8. class ModelverseJit(object):
  9. """A high-level interface to the modelverse JIT compiler."""
  10. def __init__(self, max_instructions=None):
  11. self.todo_entry_points = set()
  12. self.no_jit_entry_points = set()
  13. self.jitted_entry_points = {}
  14. self.jit_globals = {
  15. 'PrimitiveFinished' : primitive_functions.PrimitiveFinished
  16. }
  17. self.jit_count = 0
  18. self.max_instructions = 30 if max_instructions is None else max_instructions
  19. def mark_entry_point(self, body_id):
  20. """Marks the node with the given identifier as a function entry point."""
  21. if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
  22. self.todo_entry_points.add(body_id)
  23. def is_entry_point(self, body_id):
  24. """Tells if the node with the given identifier is a function entry point."""
  25. return body_id in self.todo_entry_points or \
  26. body_id in self.no_jit_entry_points or \
  27. body_id in self.jitted_entry_points
  28. def is_jittable_entry_point(self, body_id):
  29. """Tells if the node with the given identifier is a function entry point that
  30. has not been marked as non-jittable."""
  31. return body_id in self.todo_entry_points or \
  32. body_id in self.jitted_entry_points
  33. def mark_no_jit(self, body_id):
  34. """Informs the JIT that the node with the given identifier is a function entry
  35. point that must never be jitted."""
  36. self.no_jit_entry_points.add(body_id)
  37. if body_id in self.todo_entry_points:
  38. self.todo_entry_points.remove(body_id)
  39. def register_compiled(self, body_id, compiled_function, function_name=None):
  40. """Registers a compiled entry point with the JIT."""
  41. if function_name is None:
  42. function_name = 'jit_func%d' % self.jit_count
  43. self.jit_count += 1
  44. self.jitted_entry_points[body_id] = function_name
  45. self.jit_globals[function_name] = compiled_function
  46. if body_id in self.todo_entry_points:
  47. self.todo_entry_points.remove(body_id)
  48. def jit_compile(self, body_id, parameter_list):
  49. """Tries to jit the function defined by the given entry point id and parameter list."""
  50. # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
  51. # pylint: disable=I0011,W0122
  52. if body_id in self.jitted_entry_points:
  53. # We have already compiled this function.
  54. raise primitive_functions.PrimitiveFinished(
  55. self.jit_globals[self.jitted_entry_points[body_id]])
  56. elif body_id in self.no_jit_entry_points:
  57. # We're not allowed to jit this function or have tried and failed before.
  58. raise JitCompilationFailedException(
  59. 'Cannot jit function at %d because it is marked non-jittable.' % body_id)
  60. try:
  61. gen = AnalysisState(self.max_instructions).analyze(body_id)
  62. inp = None
  63. while True:
  64. inp = yield gen.send(inp)
  65. except primitive_functions.PrimitiveFinished as ex:
  66. constructed_body = ex.result
  67. except JitCompilationFailedException as ex:
  68. self.mark_no_jit(body_id)
  69. raise JitCompilationFailedException(
  70. '%s (function at %d)' % (ex.message, body_id))
  71. # Wrap the IR in a function definition, give it a unique name.
  72. constructed_function = tree_ir.DefineFunctionInstruction(
  73. 'jit_func%d' % self.jit_count,
  74. parameter_list + ['**' + KWARGS_PARAMETER_NAME],
  75. constructed_body.simplify())
  76. self.jit_count += 1
  77. # Convert the function definition to Python code, and compile it.
  78. exec(str(constructed_function), self.jit_globals)
  79. # Extract the compiled function from the JIT global state.
  80. compiled_function = self.jit_globals[constructed_function.name]
  81. # Save the compiled function so we can reuse it later.
  82. self.jitted_entry_points[body_id] = constructed_function.name
  83. print(constructed_function)
  84. raise primitive_functions.PrimitiveFinished(compiled_function)
  85. class AnalysisState(object):
  86. """The state of a bytecode analysis call graph."""
  87. def __init__(self, max_instructions=None):
  88. self.analyzed_instructions = set()
  89. self.max_instructions = max_instructions
  90. def get_local_name(self, local_id):
  91. """Gets the name for a local with the given id."""
  92. return 'local%d' % local_id
  93. def retrieve_user_root(self):
  94. """Creates an instruction that stores the user_root variable
  95. in a local."""
  96. return tree_ir.StoreLocalInstruction(
  97. 'user_root',
  98. tree_ir.LoadIndexInstruction(
  99. tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
  100. tree_ir.LiteralInstruction('user_root')))
  101. def analyze(self, instruction_id):
  102. """Tries to build an intermediate representation from the instruction with the
  103. given id."""
  104. # Check the analyzed_instructions set for instruction_id to avoid
  105. # infinite loops.
  106. if instruction_id in self.analyzed_instructions:
  107. raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
  108. elif (self.max_instructions is not None and
  109. len(self.analyzed_instructions) > self.max_instructions):
  110. raise JitCompilationFailedException('Maximal number of instructions exceeded.')
  111. self.analyzed_instructions.add(instruction_id)
  112. instruction_val, = yield [("RV", [instruction_id])]
  113. instruction_val = instruction_val["value"]
  114. if instruction_val in self.instruction_analyzers:
  115. gen = self.instruction_analyzers[instruction_val](self, instruction_id)
  116. try:
  117. inp = None
  118. while True:
  119. inp = yield gen.send(inp)
  120. except StopIteration:
  121. raise Exception(
  122. "Instruction analyzer (for '%s') finished without returning a value!" %
  123. (instruction_val))
  124. except primitive_functions.PrimitiveFinished as outer_e:
  125. # Check if the instruction has a 'next' instruction.
  126. next_instr, = yield [("RD", [instruction_id, "next"])]
  127. if next_instr is None:
  128. raise outer_e
  129. else:
  130. gen = self.analyze(next_instr)
  131. try:
  132. inp = None
  133. while True:
  134. inp = yield gen.send(inp)
  135. except primitive_functions.PrimitiveFinished as inner_e:
  136. raise primitive_functions.PrimitiveFinished(
  137. tree_ir.CompoundInstruction(
  138. outer_e.result,
  139. inner_e.result))
  140. else:
  141. raise JitCompilationFailedException(
  142. "Unknown instruction type: '%s'" % (instruction_val))
  143. def analyze_all(self, instruction_ids):
  144. """Tries to compile a list of IR trees from the given list of instruction ids."""
  145. results = []
  146. for inst in instruction_ids:
  147. gen = self.analyze(inst)
  148. try:
  149. inp = None
  150. while True:
  151. inp = yield gen.send(inp)
  152. except primitive_functions.PrimitiveFinished as ex:
  153. results.append(ex.result)
  154. raise primitive_functions.PrimitiveFinished(results)
  155. def analyze_return(self, instruction_id):
  156. """Tries to analyze the given 'return' instruction."""
  157. retval_id, = yield [("RD", [instruction_id, 'value'])]
  158. if retval_id is None:
  159. raise primitive_functions.PrimitiveFinished(
  160. tree_ir.ReturnInstruction(
  161. tree_ir.EmptyInstruction()))
  162. else:
  163. gen = self.analyze(retval_id)
  164. try:
  165. inp = None
  166. while True:
  167. inp = yield gen.send(inp)
  168. except primitive_functions.PrimitiveFinished as ex:
  169. raise primitive_functions.PrimitiveFinished(
  170. tree_ir.ReturnInstruction(ex.result))
  171. def analyze_if(self, instruction_id):
  172. """Tries to analyze the given 'if' instruction."""
  173. cond, true, false = yield [
  174. ("RD", [instruction_id, "cond"]),
  175. ("RD", [instruction_id, "then"]),
  176. ("RD", [instruction_id, "else"])]
  177. gen = self.analyze_all(
  178. [cond, true]
  179. if false is None
  180. else [cond, true, false])
  181. try:
  182. inp = None
  183. while True:
  184. inp = yield gen.send(inp)
  185. except primitive_functions.PrimitiveFinished as ex:
  186. if false is None:
  187. cond_r, true_r = ex.result
  188. false_r = tree_ir.EmptyInstruction()
  189. else:
  190. cond_r, true_r, false_r = ex.result
  191. raise primitive_functions.PrimitiveFinished(
  192. tree_ir.SelectInstruction(
  193. tree_ir.ReadValueInstruction(cond_r),
  194. true_r,
  195. false_r))
  196. def analyze_while(self, instruction_id):
  197. """Tries to analyze the given 'while' instruction."""
  198. cond, body = yield [
  199. ("RD", [instruction_id, "cond"]),
  200. ("RD", [instruction_id, "body"])]
  201. gen = self.analyze_all([cond, body])
  202. try:
  203. inp = None
  204. while True:
  205. inp = yield gen.send(inp)
  206. except primitive_functions.PrimitiveFinished as ex:
  207. cond_r, body_r = ex.result
  208. raise primitive_functions.PrimitiveFinished(
  209. tree_ir.LoopInstruction(
  210. tree_ir.CompoundInstruction(
  211. tree_ir.SelectInstruction(
  212. tree_ir.ReadValueInstruction(cond_r),
  213. tree_ir.EmptyInstruction(),
  214. tree_ir.BreakInstruction()),
  215. body_r)))
  216. def analyze_constant(self, instruction_id):
  217. """Tries to analyze the given 'constant' (literal) instruction."""
  218. node_id, = yield [("RD", [instruction_id, "node"])]
  219. raise primitive_functions.PrimitiveFinished(
  220. tree_ir.LiteralInstruction(node_id))
  221. def analyze_output(self, instruction_id):
  222. """Tries to analyze the given 'output' instruction."""
  223. # The plan is to basically generate this tree:
  224. #
  225. # value = <some tree>
  226. # last_output, last_output_link, new_last_output = \
  227. # yield [("RD", [user_root, "last_output"]),
  228. # ("RDE", [user_root, "last_output"]),
  229. # ("CN", []),
  230. # ]
  231. # _, _, _, _ = \
  232. # yield [("CD", [last_output, "value", value]),
  233. # ("CD", [last_output, "next", new_last_output]),
  234. # ("CD", [user_root, "last_output", new_last_output]),
  235. # ("DE", [last_output_link])
  236. # ]
  237. # yield None
  238. value_id, = yield [("RD", [instruction_id, "value"])]
  239. gen = self.analyze(value_id)
  240. try:
  241. inp = None
  242. while True:
  243. inp = yield gen.send(inp)
  244. except primitive_functions.PrimitiveFinished as ex:
  245. value_local = tree_ir.StoreLocalInstruction('value', ex.result)
  246. store_user_root = self.retrieve_user_root()
  247. last_output = tree_ir.StoreLocalInstruction(
  248. 'last_output',
  249. tree_ir.ReadDictionaryValueInstruction(
  250. store_user_root.create_load(),
  251. tree_ir.LiteralInstruction('last_output')))
  252. last_output_link = tree_ir.StoreLocalInstruction(
  253. 'last_output_link',
  254. tree_ir.ReadDictionaryEdgeInstruction(
  255. store_user_root.create_load(),
  256. tree_ir.LiteralInstruction('last_output')))
  257. new_last_output = tree_ir.StoreLocalInstruction(
  258. 'new_last_output',
  259. tree_ir.CreateNodeInstruction())
  260. result = tree_ir.create_block(
  261. value_local,
  262. store_user_root,
  263. last_output,
  264. last_output_link,
  265. new_last_output,
  266. tree_ir.CreateDictionaryEdgeInstruction(
  267. last_output.create_load(),
  268. tree_ir.LiteralInstruction('value'),
  269. value_local.create_load()),
  270. tree_ir.CreateDictionaryEdgeInstruction(
  271. last_output.create_load(),
  272. tree_ir.LiteralInstruction('next'),
  273. new_last_output.create_load()),
  274. tree_ir.CreateDictionaryEdgeInstruction(
  275. store_user_root.create_load(),
  276. tree_ir.LiteralInstruction('last_output'),
  277. new_last_output.create_load()),
  278. tree_ir.DeleteEdgeInstruction(last_output_link.create_load()),
  279. tree_ir.NopInstruction())
  280. raise primitive_functions.PrimitiveFinished(result)
  281. def analyze_input(self, _):
  282. """Tries to analyze the given 'input' instruction."""
  283. # The plan is to generate this tree:
  284. #
  285. # value = None
  286. # while True:
  287. # if value is None:
  288. # yield None # nop
  289. # else:
  290. # break
  291. #
  292. # _input = yield [("RD", [user_root, "input"])]
  293. # value = yield [("RD", [_input, "value"])]
  294. #
  295. # _next = yield [("RD", [_input, "next"])]
  296. # yield [("CD", [user_root, "input", _next])]
  297. # yield [("DN", [_input])]
  298. user_root = self.retrieve_user_root()
  299. _input = tree_ir.StoreLocalInstruction(
  300. '_input',
  301. tree_ir.ReadDictionaryValueInstruction(
  302. user_root.create_load(),
  303. tree_ir.LiteralInstruction('input')))
  304. value = tree_ir.StoreLocalInstruction(
  305. 'value',
  306. tree_ir.ReadDictionaryValueInstruction(
  307. _input.create_load(),
  308. tree_ir.LiteralInstruction('value')))
  309. raise primitive_functions.PrimitiveFinished(
  310. tree_ir.CompoundInstruction(
  311. tree_ir.create_block(
  312. user_root,
  313. value.create_store(tree_ir.LiteralInstruction(None)),
  314. tree_ir.LoopInstruction(
  315. tree_ir.create_block(
  316. tree_ir.SelectInstruction(
  317. tree_ir.BinaryInstruction(
  318. value.create_load(),
  319. 'is',
  320. tree_ir.LiteralInstruction(None)),
  321. tree_ir.NopInstruction(),
  322. tree_ir.BreakInstruction()),
  323. _input,
  324. value)),
  325. tree_ir.CreateDictionaryEdgeInstruction(
  326. user_root.create_load(),
  327. tree_ir.LiteralInstruction('input'),
  328. tree_ir.ReadDictionaryValueInstruction(
  329. _input.create_load(),
  330. tree_ir.LiteralInstruction('next'))),
  331. tree_ir.DeleteNodeInstruction(_input.create_load())),
  332. value.create_load()))
  333. def analyze_resolve(self, instruction_id):
  334. """Tries to analyze the given 'resolve' instruction."""
  335. var_id, = yield [("RD", [instruction_id, "var"])]
  336. var_name, = yield [("RV", [var_id])]
  337. # To resolve a variable, we'll do something along the
  338. # lines of:
  339. #
  340. # if 'local_var' in locals():
  341. # tmp = local_var
  342. # else:
  343. # _globals, = yield [("RD", [user_root, "globals"])]
  344. # global_var, = yield [("RD", [_globals, var_name])]
  345. #
  346. # if global_var is None:
  347. # raise Exception("Runtime error: global '%s' not found" % (var_name))
  348. #
  349. # tmp = global_var
  350. user_root = self.retrieve_user_root()
  351. global_var = tree_ir.StoreLocalInstruction(
  352. 'global_var',
  353. tree_ir.ReadDictionaryValueInstruction(
  354. tree_ir.ReadDictionaryValueInstruction(
  355. user_root.create_load(),
  356. tree_ir.LiteralInstruction('globals')),
  357. tree_ir.LiteralInstruction(var_name)))
  358. err_block = tree_ir.SelectInstruction(
  359. tree_ir.BinaryInstruction(
  360. global_var.create_load(),
  361. 'is',
  362. tree_ir.LiteralInstruction(None)),
  363. tree_ir.RaiseInstruction(
  364. tree_ir.CallInstruction(
  365. tree_ir.LoadLocalInstruction('Exception'),
  366. [tree_ir.LiteralInstruction(
  367. "Runtime error: global '%s' not found" % var_name)
  368. ])),
  369. tree_ir.EmptyInstruction())
  370. name = self.get_local_name(var_id)
  371. raise primitive_functions.PrimitiveFinished(
  372. tree_ir.SelectInstruction(
  373. tree_ir.LocalExistsInstruction(name),
  374. tree_ir.LoadLocalInstruction(name),
  375. tree_ir.CompoundInstruction(
  376. tree_ir.create_block(
  377. user_root,
  378. global_var,
  379. err_block),
  380. global_var.create_load())))
  381. def analyze_declare(self, instruction_id):
  382. """Tries to analyze the given 'declare' function."""
  383. var_id, = yield [("RD", [instruction_id, "var"])]
  384. name = self.get_local_name(var_id)
  385. # The following logic declares a local:
  386. #
  387. # if 'local_name' not in locals():
  388. # local_name, = yield [("CN", [])]
  389. raise primitive_functions.PrimitiveFinished(
  390. tree_ir.SelectInstruction(
  391. tree_ir.LocalExistsInstruction(name),
  392. tree_ir.EmptyInstruction(),
  393. tree_ir.StoreLocalInstruction(
  394. name,
  395. tree_ir.CreateNodeInstruction())))
  396. def analyze_global(self, instruction_id):
  397. """Tries to analyze the given 'global' (declaration) instruction."""
  398. var_id, = yield [("RD", [instruction_id, "var"])]
  399. var_name, = yield [("RV", [var_id])]
  400. # To resolve a variable, we'll do something along the
  401. # lines of:
  402. #
  403. # _globals, = yield [("RD", [user_root, "globals"])]
  404. # global_var = yield [("RD", [_globals, var_name])]
  405. #
  406. # if global_var is None:
  407. # global_var, = yield [("CN", [])]
  408. # yield [("CD", [_globals, var_name, global_var])]
  409. #
  410. # tmp = global_var
  411. user_root = self.retrieve_user_root()
  412. _globals = tree_ir.StoreLocalInstruction(
  413. '_globals',
  414. tree_ir.ReadDictionaryValueInstruction(
  415. user_root.create_load(),
  416. tree_ir.LiteralInstruction('globals')))
  417. global_var = tree_ir.StoreLocalInstruction(
  418. 'global_var',
  419. tree_ir.ReadDictionaryValueInstruction(
  420. _globals.create_load(),
  421. tree_ir.LiteralInstruction(var_name)))
  422. raise primitive_functions.PrimitiveFinished(
  423. tree_ir.CompoundInstruction(
  424. tree_ir.create_block(
  425. user_root,
  426. _globals,
  427. global_var,
  428. tree_ir.SelectInstruction(
  429. tree_ir.BinaryInstruction(
  430. global_var.create_load(),
  431. 'is',
  432. tree_ir.LiteralInstruction(None)),
  433. tree_ir.create_block(
  434. global_var.create_store(
  435. tree_ir.CreateNodeInstruction()),
  436. tree_ir.CreateDictionaryEdgeInstruction(
  437. _globals.create_load(),
  438. tree_ir.LiteralInstruction(var_name),
  439. global_var.create_load())),
  440. tree_ir.EmptyInstruction())),
  441. global_var.create_load()))
  442. def analyze_assign(self, instruction_id):
  443. """Tries to analyze the given 'assign' instruction."""
  444. var_id, value_id = yield [("RD", [instruction_id, "var"]),
  445. ("RD", [instruction_id, "value"])]
  446. try:
  447. gen = self.analyze_all([var_id, value_id])
  448. inp = None
  449. while True:
  450. inp = yield gen.send(inp)
  451. except primitive_functions.PrimitiveFinished as ex:
  452. var_r, value_r = ex.result
  453. # Assignments work like this:
  454. #
  455. # value_link = yield [("RDE", [variable, "value"])]
  456. # _, _ = yield [("CD", [variable, "value", value]),
  457. # ("DE", [value_link])]
  458. variable = tree_ir.StoreLocalInstruction('variable', var_r)
  459. value = tree_ir.StoreLocalInstruction('value', value_r)
  460. value_link = tree_ir.StoreLocalInstruction(
  461. 'value_link',
  462. tree_ir.ReadDictionaryEdgeInstruction(
  463. variable.create_load(),
  464. tree_ir.LiteralInstruction('value')))
  465. raise primitive_functions.PrimitiveFinished(
  466. tree_ir.create_block(
  467. variable,
  468. value,
  469. value_link,
  470. tree_ir.CreateDictionaryEdgeInstruction(
  471. variable.create_load(),
  472. tree_ir.LiteralInstruction('value'),
  473. value.create_load()),
  474. tree_ir.DeleteEdgeInstruction(
  475. value_link.create_load())))
  476. def analyze_access(self, instruction_id):
  477. """Tries to analyze the given 'access' instruction."""
  478. var_id, = yield [("RD", [instruction_id, "var"])]
  479. try:
  480. gen = self.analyze(var_id)
  481. inp = None
  482. while True:
  483. inp = yield gen.send(inp)
  484. except primitive_functions.PrimitiveFinished as ex:
  485. var_r = ex.result
  486. # Accessing a variable is pretty easy. It really just boils
  487. # down to reading the value corresponding to the 'value' key
  488. # of the variable.
  489. #
  490. # value, = yield [("RD", [returnvalue, "value"])]
  491. raise primitive_functions.PrimitiveFinished(
  492. tree_ir.ReadDictionaryValueInstruction(
  493. var_r,
  494. tree_ir.LiteralInstruction('value')))
  495. instruction_analyzers = {
  496. 'if' : analyze_if,
  497. 'while' : analyze_while,
  498. 'return' : analyze_return,
  499. 'constant' : analyze_constant,
  500. 'resolve' : analyze_resolve,
  501. 'declare' : analyze_declare,
  502. 'global' : analyze_global,
  503. 'assign' : analyze_assign,
  504. 'access' : analyze_access,
  505. 'output' : analyze_output,
  506. 'input' : analyze_input
  507. }