jit.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580
  1. import modelverse_kernel.primitives as primitive_functions
  2. import modelverse_jit.tree_ir as tree_ir
  3. KWARGS_PARAMETER_NAME = "remainder"
  4. """The name of the kwargs parameter in jitted functions."""
  5. class JitCompilationFailedException(Exception):
  6. """A type of exception that is raised when the jit fails to compile a function."""
  7. pass
  8. class ModelverseJit(object):
  9. """A high-level interface to the modelverse JIT compiler."""
  10. def __init__(self, max_instructions=None):
  11. self.todo_entry_points = set()
  12. self.no_jit_entry_points = set()
  13. self.jitted_entry_points = {}
  14. self.jit_globals = {
  15. 'PrimitiveFinished' : primitive_functions.PrimitiveFinished
  16. }
  17. self.jit_count = 0
  18. self.max_instructions = 30 if max_instructions is None else max_instructions
  19. self.jit_enabled = True
  20. def set_jit_enabled(self, is_enabled=True):
  21. """Enables or disables the JIT."""
  22. self.jit_enabled = is_enabled
  23. def mark_entry_point(self, body_id):
  24. """Marks the node with the given identifier as a function entry point."""
  25. if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
  26. self.todo_entry_points.add(body_id)
  27. def is_entry_point(self, body_id):
  28. """Tells if the node with the given identifier is a function entry point."""
  29. return body_id in self.todo_entry_points or \
  30. body_id in self.no_jit_entry_points or \
  31. body_id in self.jitted_entry_points
  32. def is_jittable_entry_point(self, body_id):
  33. """Tells if the node with the given identifier is a function entry point that
  34. has not been marked as non-jittable. This only returns `True` if the JIT
  35. is enabled"""
  36. return self.jit_enabled and (
  37. body_id in self.todo_entry_points or
  38. body_id in self.jitted_entry_points)
  39. def mark_no_jit(self, body_id):
  40. """Informs the JIT that the node with the given identifier is a function entry
  41. point that must never be jitted."""
  42. self.no_jit_entry_points.add(body_id)
  43. if body_id in self.todo_entry_points:
  44. self.todo_entry_points.remove(body_id)
  45. def register_compiled(self, body_id, compiled_function, function_name=None):
  46. """Registers a compiled entry point with the JIT."""
  47. if function_name is None:
  48. function_name = 'jit_func%d' % self.jit_count
  49. self.jit_count += 1
  50. self.jitted_entry_points[body_id] = function_name
  51. self.jit_globals[function_name] = compiled_function
  52. if body_id in self.todo_entry_points:
  53. self.todo_entry_points.remove(body_id)
  54. def jit_compile(self, body_id, parameter_list):
  55. """Tries to jit the function defined by the given entry point id and parameter list."""
  56. # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
  57. # pylint: disable=I0011,W0122
  58. if body_id in self.jitted_entry_points:
  59. # We have already compiled this function.
  60. raise primitive_functions.PrimitiveFinished(
  61. self.jit_globals[self.jitted_entry_points[body_id]])
  62. elif body_id in self.no_jit_entry_points:
  63. # We're not allowed to jit this function or have tried and failed before.
  64. raise JitCompilationFailedException(
  65. 'Cannot jit function at %d because it is marked non-jittable.' % body_id)
  66. try:
  67. gen = AnalysisState(self.max_instructions).analyze(body_id)
  68. inp = None
  69. while True:
  70. inp = yield gen.send(inp)
  71. except primitive_functions.PrimitiveFinished as ex:
  72. constructed_body = ex.result
  73. except JitCompilationFailedException as ex:
  74. self.mark_no_jit(body_id)
  75. raise JitCompilationFailedException(
  76. '%s (function at %d)' % (ex.message, body_id))
  77. # Wrap the IR in a function definition, give it a unique name.
  78. constructed_function = tree_ir.DefineFunctionInstruction(
  79. 'jit_func%d' % self.jit_count,
  80. parameter_list + ['**' + KWARGS_PARAMETER_NAME],
  81. constructed_body.simplify())
  82. self.jit_count += 1
  83. # Convert the function definition to Python code, and compile it.
  84. exec(str(constructed_function), self.jit_globals)
  85. # Extract the compiled function from the JIT global state.
  86. compiled_function = self.jit_globals[constructed_function.name]
  87. # Save the compiled function so we can reuse it later.
  88. self.jitted_entry_points[body_id] = constructed_function.name
  89. print(constructed_function)
  90. raise primitive_functions.PrimitiveFinished(compiled_function)
  91. class AnalysisState(object):
  92. """The state of a bytecode analysis call graph."""
  93. def __init__(self, max_instructions=None):
  94. self.analyzed_instructions = set()
  95. self.max_instructions = max_instructions
  96. def get_local_name(self, local_id):
  97. """Gets the name for a local with the given id."""
  98. return 'local%d' % local_id
  99. def retrieve_user_root(self):
  100. """Creates an instruction that stores the user_root variable
  101. in a local."""
  102. return tree_ir.StoreLocalInstruction(
  103. 'user_root',
  104. tree_ir.LoadIndexInstruction(
  105. tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
  106. tree_ir.LiteralInstruction('user_root')))
  107. def analyze(self, instruction_id):
  108. """Tries to build an intermediate representation from the instruction with the
  109. given id."""
  110. # Check the analyzed_instructions set for instruction_id to avoid
  111. # infinite loops.
  112. if instruction_id in self.analyzed_instructions:
  113. raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
  114. elif (self.max_instructions is not None and
  115. len(self.analyzed_instructions) > self.max_instructions):
  116. raise JitCompilationFailedException('Maximal number of instructions exceeded.')
  117. self.analyzed_instructions.add(instruction_id)
  118. instruction_val, = yield [("RV", [instruction_id])]
  119. instruction_val = instruction_val["value"]
  120. if instruction_val in self.instruction_analyzers:
  121. gen = self.instruction_analyzers[instruction_val](self, instruction_id)
  122. try:
  123. inp = None
  124. while True:
  125. inp = yield gen.send(inp)
  126. except StopIteration:
  127. raise Exception(
  128. "Instruction analyzer (for '%s') finished without returning a value!" %
  129. (instruction_val))
  130. except primitive_functions.PrimitiveFinished as outer_e:
  131. # Check if the instruction has a 'next' instruction.
  132. next_instr, = yield [("RD", [instruction_id, "next"])]
  133. if next_instr is None:
  134. raise outer_e
  135. else:
  136. gen = self.analyze(next_instr)
  137. try:
  138. inp = None
  139. while True:
  140. inp = yield gen.send(inp)
  141. except primitive_functions.PrimitiveFinished as inner_e:
  142. raise primitive_functions.PrimitiveFinished(
  143. tree_ir.CompoundInstruction(
  144. outer_e.result,
  145. inner_e.result))
  146. else:
  147. raise JitCompilationFailedException(
  148. "Unknown instruction type: '%s'" % (instruction_val))
  149. def analyze_all(self, instruction_ids):
  150. """Tries to compile a list of IR trees from the given list of instruction ids."""
  151. results = []
  152. for inst in instruction_ids:
  153. gen = self.analyze(inst)
  154. try:
  155. inp = None
  156. while True:
  157. inp = yield gen.send(inp)
  158. except primitive_functions.PrimitiveFinished as ex:
  159. results.append(ex.result)
  160. raise primitive_functions.PrimitiveFinished(results)
  161. def analyze_return(self, instruction_id):
  162. """Tries to analyze the given 'return' instruction."""
  163. retval_id, = yield [("RD", [instruction_id, 'value'])]
  164. if retval_id is None:
  165. raise primitive_functions.PrimitiveFinished(
  166. tree_ir.ReturnInstruction(
  167. tree_ir.EmptyInstruction()))
  168. else:
  169. gen = self.analyze(retval_id)
  170. try:
  171. inp = None
  172. while True:
  173. inp = yield gen.send(inp)
  174. except primitive_functions.PrimitiveFinished as ex:
  175. raise primitive_functions.PrimitiveFinished(
  176. tree_ir.ReturnInstruction(ex.result))
  177. def analyze_if(self, instruction_id):
  178. """Tries to analyze the given 'if' instruction."""
  179. cond, true, false = yield [
  180. ("RD", [instruction_id, "cond"]),
  181. ("RD", [instruction_id, "then"]),
  182. ("RD", [instruction_id, "else"])]
  183. gen = self.analyze_all(
  184. [cond, true]
  185. if false is None
  186. else [cond, true, false])
  187. try:
  188. inp = None
  189. while True:
  190. inp = yield gen.send(inp)
  191. except primitive_functions.PrimitiveFinished as ex:
  192. if false is None:
  193. cond_r, true_r = ex.result
  194. false_r = tree_ir.EmptyInstruction()
  195. else:
  196. cond_r, true_r, false_r = ex.result
  197. raise primitive_functions.PrimitiveFinished(
  198. tree_ir.SelectInstruction(
  199. tree_ir.ReadValueInstruction(cond_r),
  200. true_r,
  201. false_r))
  202. def analyze_while(self, instruction_id):
  203. """Tries to analyze the given 'while' instruction."""
  204. cond, body = yield [
  205. ("RD", [instruction_id, "cond"]),
  206. ("RD", [instruction_id, "body"])]
  207. gen = self.analyze_all([cond, body])
  208. try:
  209. inp = None
  210. while True:
  211. inp = yield gen.send(inp)
  212. except primitive_functions.PrimitiveFinished as ex:
  213. cond_r, body_r = ex.result
  214. raise primitive_functions.PrimitiveFinished(
  215. tree_ir.LoopInstruction(
  216. tree_ir.CompoundInstruction(
  217. tree_ir.SelectInstruction(
  218. tree_ir.ReadValueInstruction(cond_r),
  219. tree_ir.EmptyInstruction(),
  220. tree_ir.BreakInstruction()),
  221. body_r)))
  222. def analyze_constant(self, instruction_id):
  223. """Tries to analyze the given 'constant' (literal) instruction."""
  224. node_id, = yield [("RD", [instruction_id, "node"])]
  225. raise primitive_functions.PrimitiveFinished(
  226. tree_ir.LiteralInstruction(node_id))
  227. def analyze_output(self, instruction_id):
  228. """Tries to analyze the given 'output' instruction."""
  229. # The plan is to basically generate this tree:
  230. #
  231. # value = <some tree>
  232. # last_output, last_output_link, new_last_output = \
  233. # yield [("RD", [user_root, "last_output"]),
  234. # ("RDE", [user_root, "last_output"]),
  235. # ("CN", []),
  236. # ]
  237. # _, _, _, _ = \
  238. # yield [("CD", [last_output, "value", value]),
  239. # ("CD", [last_output, "next", new_last_output]),
  240. # ("CD", [user_root, "last_output", new_last_output]),
  241. # ("DE", [last_output_link])
  242. # ]
  243. # yield None
  244. value_id, = yield [("RD", [instruction_id, "value"])]
  245. gen = self.analyze(value_id)
  246. try:
  247. inp = None
  248. while True:
  249. inp = yield gen.send(inp)
  250. except primitive_functions.PrimitiveFinished as ex:
  251. value_local = tree_ir.StoreLocalInstruction('value', ex.result)
  252. store_user_root = self.retrieve_user_root()
  253. last_output = tree_ir.StoreLocalInstruction(
  254. 'last_output',
  255. tree_ir.ReadDictionaryValueInstruction(
  256. store_user_root.create_load(),
  257. tree_ir.LiteralInstruction('last_output')))
  258. last_output_link = tree_ir.StoreLocalInstruction(
  259. 'last_output_link',
  260. tree_ir.ReadDictionaryEdgeInstruction(
  261. store_user_root.create_load(),
  262. tree_ir.LiteralInstruction('last_output')))
  263. new_last_output = tree_ir.StoreLocalInstruction(
  264. 'new_last_output',
  265. tree_ir.CreateNodeInstruction())
  266. result = tree_ir.create_block(
  267. value_local,
  268. store_user_root,
  269. last_output,
  270. last_output_link,
  271. new_last_output,
  272. tree_ir.CreateDictionaryEdgeInstruction(
  273. last_output.create_load(),
  274. tree_ir.LiteralInstruction('value'),
  275. value_local.create_load()),
  276. tree_ir.CreateDictionaryEdgeInstruction(
  277. last_output.create_load(),
  278. tree_ir.LiteralInstruction('next'),
  279. new_last_output.create_load()),
  280. tree_ir.CreateDictionaryEdgeInstruction(
  281. store_user_root.create_load(),
  282. tree_ir.LiteralInstruction('last_output'),
  283. new_last_output.create_load()),
  284. tree_ir.DeleteEdgeInstruction(last_output_link.create_load()),
  285. tree_ir.NopInstruction())
  286. raise primitive_functions.PrimitiveFinished(result)
  287. def analyze_input(self, _):
  288. """Tries to analyze the given 'input' instruction."""
  289. # The plan is to generate this tree:
  290. #
  291. # value = None
  292. # while True:
  293. # if value is None:
  294. # yield None # nop
  295. # else:
  296. # break
  297. #
  298. # _input = yield [("RD", [user_root, "input"])]
  299. # value = yield [("RD", [_input, "value"])]
  300. #
  301. # _next = yield [("RD", [_input, "next"])]
  302. # yield [("CD", [user_root, "input", _next])]
  303. # yield [("DN", [_input])]
  304. user_root = self.retrieve_user_root()
  305. _input = tree_ir.StoreLocalInstruction(
  306. '_input',
  307. tree_ir.ReadDictionaryValueInstruction(
  308. user_root.create_load(),
  309. tree_ir.LiteralInstruction('input')))
  310. value = tree_ir.StoreLocalInstruction(
  311. 'value',
  312. tree_ir.ReadDictionaryValueInstruction(
  313. _input.create_load(),
  314. tree_ir.LiteralInstruction('value')))
  315. raise primitive_functions.PrimitiveFinished(
  316. tree_ir.CompoundInstruction(
  317. tree_ir.create_block(
  318. user_root,
  319. value.create_store(tree_ir.LiteralInstruction(None)),
  320. tree_ir.LoopInstruction(
  321. tree_ir.create_block(
  322. tree_ir.SelectInstruction(
  323. tree_ir.BinaryInstruction(
  324. value.create_load(),
  325. 'is',
  326. tree_ir.LiteralInstruction(None)),
  327. tree_ir.NopInstruction(),
  328. tree_ir.BreakInstruction()),
  329. _input,
  330. value)),
  331. tree_ir.CreateDictionaryEdgeInstruction(
  332. user_root.create_load(),
  333. tree_ir.LiteralInstruction('input'),
  334. tree_ir.ReadDictionaryValueInstruction(
  335. _input.create_load(),
  336. tree_ir.LiteralInstruction('next'))),
  337. tree_ir.DeleteNodeInstruction(_input.create_load())),
  338. value.create_load()))
  339. def analyze_resolve(self, instruction_id):
  340. """Tries to analyze the given 'resolve' instruction."""
  341. var_id, = yield [("RD", [instruction_id, "var"])]
  342. var_name, = yield [("RV", [var_id])]
  343. # To resolve a variable, we'll do something along the
  344. # lines of:
  345. #
  346. # if 'local_var' in locals():
  347. # tmp = local_var
  348. # else:
  349. # _globals, = yield [("RD", [user_root, "globals"])]
  350. # global_var, = yield [("RD", [_globals, var_name])]
  351. #
  352. # if global_var is None:
  353. # raise Exception("Runtime error: global '%s' not found" % (var_name))
  354. #
  355. # tmp = global_var
  356. user_root = self.retrieve_user_root()
  357. global_var = tree_ir.StoreLocalInstruction(
  358. 'global_var',
  359. tree_ir.ReadDictionaryValueInstruction(
  360. tree_ir.ReadDictionaryValueInstruction(
  361. user_root.create_load(),
  362. tree_ir.LiteralInstruction('globals')),
  363. tree_ir.LiteralInstruction(var_name)))
  364. err_block = tree_ir.SelectInstruction(
  365. tree_ir.BinaryInstruction(
  366. global_var.create_load(),
  367. 'is',
  368. tree_ir.LiteralInstruction(None)),
  369. tree_ir.RaiseInstruction(
  370. tree_ir.CallInstruction(
  371. tree_ir.LoadLocalInstruction('Exception'),
  372. [tree_ir.LiteralInstruction(
  373. "Runtime error: global '%s' not found" % var_name)
  374. ])),
  375. tree_ir.EmptyInstruction())
  376. name = self.get_local_name(var_id)
  377. raise primitive_functions.PrimitiveFinished(
  378. tree_ir.SelectInstruction(
  379. tree_ir.LocalExistsInstruction(name),
  380. tree_ir.LoadLocalInstruction(name),
  381. tree_ir.CompoundInstruction(
  382. tree_ir.create_block(
  383. user_root,
  384. global_var,
  385. err_block),
  386. global_var.create_load())))
  387. def analyze_declare(self, instruction_id):
  388. """Tries to analyze the given 'declare' function."""
  389. var_id, = yield [("RD", [instruction_id, "var"])]
  390. name = self.get_local_name(var_id)
  391. # The following logic declares a local:
  392. #
  393. # if 'local_name' not in locals():
  394. # local_name, = yield [("CN", [])]
  395. raise primitive_functions.PrimitiveFinished(
  396. tree_ir.SelectInstruction(
  397. tree_ir.LocalExistsInstruction(name),
  398. tree_ir.EmptyInstruction(),
  399. tree_ir.StoreLocalInstruction(
  400. name,
  401. tree_ir.CreateNodeInstruction())))
  402. def analyze_global(self, instruction_id):
  403. """Tries to analyze the given 'global' (declaration) instruction."""
  404. var_id, = yield [("RD", [instruction_id, "var"])]
  405. var_name, = yield [("RV", [var_id])]
  406. # To resolve a variable, we'll do something along the
  407. # lines of:
  408. #
  409. # _globals, = yield [("RD", [user_root, "globals"])]
  410. # global_var = yield [("RD", [_globals, var_name])]
  411. #
  412. # if global_var is None:
  413. # global_var, = yield [("CN", [])]
  414. # yield [("CD", [_globals, var_name, global_var])]
  415. #
  416. # tmp = global_var
  417. user_root = self.retrieve_user_root()
  418. _globals = tree_ir.StoreLocalInstruction(
  419. '_globals',
  420. tree_ir.ReadDictionaryValueInstruction(
  421. user_root.create_load(),
  422. tree_ir.LiteralInstruction('globals')))
  423. global_var = tree_ir.StoreLocalInstruction(
  424. 'global_var',
  425. tree_ir.ReadDictionaryValueInstruction(
  426. _globals.create_load(),
  427. tree_ir.LiteralInstruction(var_name)))
  428. raise primitive_functions.PrimitiveFinished(
  429. tree_ir.CompoundInstruction(
  430. tree_ir.create_block(
  431. user_root,
  432. _globals,
  433. global_var,
  434. tree_ir.SelectInstruction(
  435. tree_ir.BinaryInstruction(
  436. global_var.create_load(),
  437. 'is',
  438. tree_ir.LiteralInstruction(None)),
  439. tree_ir.create_block(
  440. global_var.create_store(
  441. tree_ir.CreateNodeInstruction()),
  442. tree_ir.CreateDictionaryEdgeInstruction(
  443. _globals.create_load(),
  444. tree_ir.LiteralInstruction(var_name),
  445. global_var.create_load())),
  446. tree_ir.EmptyInstruction())),
  447. global_var.create_load()))
  448. def analyze_assign(self, instruction_id):
  449. """Tries to analyze the given 'assign' instruction."""
  450. var_id, value_id = yield [("RD", [instruction_id, "var"]),
  451. ("RD", [instruction_id, "value"])]
  452. try:
  453. gen = self.analyze_all([var_id, value_id])
  454. inp = None
  455. while True:
  456. inp = yield gen.send(inp)
  457. except primitive_functions.PrimitiveFinished as ex:
  458. var_r, value_r = ex.result
  459. # Assignments work like this:
  460. #
  461. # value_link = yield [("RDE", [variable, "value"])]
  462. # _, _ = yield [("CD", [variable, "value", value]),
  463. # ("DE", [value_link])]
  464. variable = tree_ir.StoreLocalInstruction('variable', var_r)
  465. value = tree_ir.StoreLocalInstruction('value', value_r)
  466. value_link = tree_ir.StoreLocalInstruction(
  467. 'value_link',
  468. tree_ir.ReadDictionaryEdgeInstruction(
  469. variable.create_load(),
  470. tree_ir.LiteralInstruction('value')))
  471. raise primitive_functions.PrimitiveFinished(
  472. tree_ir.create_block(
  473. variable,
  474. value,
  475. value_link,
  476. tree_ir.CreateDictionaryEdgeInstruction(
  477. variable.create_load(),
  478. tree_ir.LiteralInstruction('value'),
  479. value.create_load()),
  480. tree_ir.DeleteEdgeInstruction(
  481. value_link.create_load())))
  482. def analyze_access(self, instruction_id):
  483. """Tries to analyze the given 'access' instruction."""
  484. var_id, = yield [("RD", [instruction_id, "var"])]
  485. try:
  486. gen = self.analyze(var_id)
  487. inp = None
  488. while True:
  489. inp = yield gen.send(inp)
  490. except primitive_functions.PrimitiveFinished as ex:
  491. var_r = ex.result
  492. # Accessing a variable is pretty easy. It really just boils
  493. # down to reading the value corresponding to the 'value' key
  494. # of the variable.
  495. #
  496. # value, = yield [("RD", [returnvalue, "value"])]
  497. raise primitive_functions.PrimitiveFinished(
  498. tree_ir.ReadDictionaryValueInstruction(
  499. var_r,
  500. tree_ir.LiteralInstruction('value')))
  501. instruction_analyzers = {
  502. 'if' : analyze_if,
  503. 'while' : analyze_while,
  504. 'return' : analyze_return,
  505. 'constant' : analyze_constant,
  506. 'resolve' : analyze_resolve,
  507. 'declare' : analyze_declare,
  508. 'global' : analyze_global,
  509. 'assign' : analyze_assign,
  510. 'access' : analyze_access,
  511. 'output' : analyze_output,
  512. 'input' : analyze_input
  513. }