jit.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. import modelverse_kernel.primitives as primitive_functions
  2. import modelverse_jit.tree_ir as tree_ir
  3. KWARGS_PARAMETER_NAME = "remainder"
  4. """The name of the kwargs parameter in jitted functions."""
  5. class JitCompilationFailedException(Exception):
  6. """A type of exception that is raised when the jit fails to compile a function."""
  7. pass
  8. class ModelverseJit(object):
  9. """A high-level interface to the modelverse JIT compiler."""
  10. def __init__(self):
  11. self.todo_entry_points = set()
  12. self.no_jit_entry_points = set()
  13. self.jitted_entry_points = {}
  14. self.jit_globals = {
  15. 'PrimitiveFinished' : primitive_functions.PrimitiveFinished
  16. }
  17. self.jit_count = 0
  18. def mark_entry_point(self, body_id):
  19. """Marks the node with the given identifier as a function entry point."""
  20. if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
  21. self.todo_entry_points.add(body_id)
  22. def is_entry_point(self, body_id):
  23. """Tells if the node with the given identifier is a function entry point."""
  24. return body_id in self.todo_entry_points or \
  25. body_id in self.no_jit_entry_points or \
  26. body_id in self.jitted_entry_points
  27. def is_jittable_entry_point(self, body_id):
  28. """Tells if the node with the given identifier is a function entry point that
  29. has not been marked as non-jittable."""
  30. return body_id in self.todo_entry_points or \
  31. body_id in self.jitted_entry_points
  32. def mark_no_jit(self, body_id):
  33. """Informs the JIT that the node with the given identifier is a function entry
  34. point that must never be jitted."""
  35. self.no_jit_entry_points.add(body_id)
  36. if body_id in self.todo_entry_points:
  37. self.todo_entry_points.remove(body_id)
  38. def register_compiled(self, body_id, compiled):
  39. """Registers a compiled entry point with the JIT."""
  40. self.jitted_entry_points[body_id] = compiled
  41. if body_id in self.todo_entry_points:
  42. self.todo_entry_points.remove(body_id)
  43. def try_jit(self, body_id, parameter_list):
  44. """Tries to jit the function defined by the given entry point id and parameter list."""
  45. if body_id in self.jitted_entry_points:
  46. # We have already compiled this function.
  47. raise primitive_functions.PrimitiveFinished(self.jitted_entry_points[body_id])
  48. elif body_id in self.no_jit_entry_points:
  49. # We're not allowed to jit this function or have tried and failed before.
  50. raise JitCompilationFailedException(
  51. 'Cannot jit function at %d because it is marked non-jittable.' % body_id)
  52. gen = AnalysisState().analyze(body_id)
  53. try:
  54. inp = None
  55. while True:
  56. inp = yield gen.send(inp)
  57. except primitive_functions.PrimitiveFinished as ex:
  58. constructed_body = ex.result
  59. except JitCompilationFailedException as ex:
  60. self.mark_no_jit(body_id)
  61. raise JitCompilationFailedException(
  62. '%s (function at %d)' % (ex.message, body_id))
  63. # Wrap the IR in a function definition, give it a unique name.
  64. constructed_function = tree_ir.DefineFunctionInstruction(
  65. 'jit_func%d' % self.jit_count,
  66. parameter_list + ['**' + KWARGS_PARAMETER_NAME],
  67. constructed_body)
  68. self.jit_count += 1
  69. # Convert the function definition to Python code, and compile it.
  70. exec(str(constructed_function), self.jit_globals)
  71. # Extract the compiled function from the JIT global state.
  72. compiled_function = self.jit_globals[constructed_function.name]
  73. print(constructed_function)
  74. # Save the compiled function so we can reuse it later.
  75. self.jitted_entry_points[body_id] = compiled_function
  76. raise primitive_functions.PrimitiveFinished(compiled_function)
  77. class AnalysisState(object):
  78. """The state of a bytecode analysis call graph."""
  79. def __init__(self):
  80. self.analyzed_instructions = set()
  81. def analyze(self, instruction_id):
  82. """Tries to build an intermediate representation from the instruction with the
  83. given id."""
  84. # Check the analyzed_instructions set for instruction_id to avoid
  85. # infinite loops.
  86. if instruction_id in self.analyzed_instructions:
  87. raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
  88. self.analyzed_instructions.add(instruction_id)
  89. instruction_val, = yield [("RV", [instruction_id])]
  90. instruction_val = instruction_val["value"]
  91. if instruction_val in self.instruction_analyzers:
  92. gen = self.instruction_analyzers[instruction_val](self, instruction_id)
  93. try:
  94. inp = None
  95. while True:
  96. inp = yield gen.send(inp)
  97. except StopIteration:
  98. raise Exception(
  99. "Instruction analyzer (for '%s') finished without returning a value!" %
  100. (instruction_val))
  101. except primitive_functions.PrimitiveFinished as outer_e:
  102. # Check if the instruction has a 'next' instruction.
  103. next_instr, = yield [("RD", [instruction_id, "next"])]
  104. if next_instr is None:
  105. raise outer_e
  106. else:
  107. gen = self.analyze(next_instr)
  108. try:
  109. inp = None
  110. while True:
  111. inp = yield gen.send(inp)
  112. except primitive_functions.PrimitiveFinished as inner_e:
  113. raise primitive_functions.PrimitiveFinished(
  114. tree_ir.CompoundInstruction(
  115. outer_e.result,
  116. inner_e.result))
  117. else:
  118. raise JitCompilationFailedException(
  119. "Unknown instruction type: '%s'" % (instruction_val))
  120. def analyze_all(self, instruction_ids):
  121. """Tries to compile a list of IR trees from the given list of instruction ids."""
  122. results = []
  123. for inst in instruction_ids:
  124. gen = self.analyze(inst)
  125. try:
  126. inp = None
  127. while True:
  128. inp = yield gen.send(inp)
  129. except primitive_functions.PrimitiveFinished as ex:
  130. results.append(ex.result)
  131. raise primitive_functions.PrimitiveFinished(results)
  132. def analyze_return(self, instruction_id):
  133. """Tries to analyze the given 'return' instruction."""
  134. retval_id, = yield [("RD", [instruction_id, 'value'])]
  135. if retval_id is None:
  136. raise primitive_functions.PrimitiveFinished(
  137. tree_ir.ReturnInstruction(
  138. tree_ir.EmptyInstruction()))
  139. else:
  140. gen = self.analyze(retval_id)
  141. try:
  142. inp = None
  143. while True:
  144. inp = yield gen.send(inp)
  145. except primitive_functions.PrimitiveFinished as ex:
  146. raise primitive_functions.PrimitiveFinished(
  147. tree_ir.ReturnInstruction(ex.result))
  148. def analyze_if(self, instruction_id):
  149. """Tries to analyze the given 'if' instruction."""
  150. cond, true, false = yield [
  151. ("RD", [instruction_id, "cond"]),
  152. ("RD", [instruction_id, "then"]),
  153. ("RD", [instruction_id, "else"])]
  154. gen = self.analyze_all(
  155. [cond, true]
  156. if false is None
  157. else [cond, true, false])
  158. try:
  159. inp = None
  160. while True:
  161. inp = yield gen.send(inp)
  162. except primitive_functions.PrimitiveFinished as ex:
  163. if false is None:
  164. cond_r, true_r = ex.result
  165. false_r = tree_ir.EmptyInstruction()
  166. else:
  167. cond_r, true_r, false_r = ex.result
  168. raise primitive_functions.PrimitiveFinished(
  169. tree_ir.SelectInstruction(
  170. tree_ir.ReadValueInstruction(cond_r),
  171. true_r,
  172. false_r))
  173. def analyze_while(self, instruction_id):
  174. """Tries to analyze the given 'while' instruction."""
  175. cond, body = yield [
  176. ("RD", [instruction_id, "cond"]),
  177. ("RD", [instruction_id, "body"])]
  178. gen = self.analyze_all([cond, body])
  179. try:
  180. inp = None
  181. while True:
  182. inp = yield gen.send(inp)
  183. except primitive_functions.PrimitiveFinished as ex:
  184. cond_r, body_r = ex.result
  185. raise primitive_functions.PrimitiveFinished(
  186. tree_ir.LoopInstruction(
  187. tree_ir.CompoundInstruction(
  188. tree_ir.SelectInstruction(
  189. tree_ir.ReadValueInstruction(cond_r),
  190. tree_ir.EmptyInstruction(),
  191. tree_ir.BreakInstruction()),
  192. body_r)))
  193. def analyze_constant(self, instruction_id):
  194. """Tries to analyze the given 'constant' (literal) instruction."""
  195. node_id, = yield [("RD", [instruction_id, "node"])]
  196. raise primitive_functions.PrimitiveFinished(
  197. tree_ir.LiteralInstruction(node_id))
  198. def analyze_output(self, instruction_id):
  199. """Tries to analyze the given 'output' instruction."""
  200. # The plan is to basically generate this tree:
  201. #
  202. # value = <some tree>
  203. # last_output, last_output_link, new_last_output = \
  204. # yield [("RD", [user_root, "last_output"]),
  205. # ("RDE", [user_root, "last_output"]),
  206. # ("CN", []),
  207. # ]
  208. # _, _, _, _ = \
  209. # yield [("CD", [last_output, "value", value]),
  210. # ("CD", [last_output, "next", new_last_output]),
  211. # ("CD", [user_root, "last_output", new_last_output]),
  212. # ("DE", [last_output_link])
  213. # ]
  214. # yield None
  215. value_id, = yield [("RD", [instruction_id, "value"])]
  216. gen = self.analyze(value_id)
  217. try:
  218. inp = None
  219. while True:
  220. inp = yield gen.send(inp)
  221. except primitive_functions.PrimitiveFinished as ex:
  222. value_local = tree_ir.StoreLocalInstruction('value', ex.result)
  223. store_user_root = tree_ir.StoreLocalInstruction(
  224. 'user_root',
  225. tree_ir.LoadIndexInstruction(
  226. tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
  227. tree_ir.LiteralInstruction('user_root')))
  228. last_output = tree_ir.StoreLocalInstruction(
  229. 'last_output',
  230. tree_ir.ReadDictionaryValueInstruction(
  231. store_user_root.create_load(),
  232. tree_ir.LiteralInstruction('last_output')))
  233. last_output_link = tree_ir.StoreLocalInstruction(
  234. 'last_output_link',
  235. tree_ir.ReadDictionaryEdgeInstruction(
  236. store_user_root.create_load(),
  237. tree_ir.LiteralInstruction('last_output')))
  238. new_last_output = tree_ir.StoreLocalInstruction(
  239. 'new_last_output',
  240. tree_ir.CreateNodeInstruction())
  241. result = tree_ir.create_block(
  242. value_local,
  243. store_user_root,
  244. last_output,
  245. last_output_link,
  246. new_last_output,
  247. tree_ir.CreateDictionaryEdgeInstruction(
  248. last_output.create_load(),
  249. tree_ir.LiteralInstruction('value'),
  250. value_local.create_load()),
  251. tree_ir.CreateDictionaryEdgeInstruction(
  252. last_output.create_load(),
  253. tree_ir.LiteralInstruction('next'),
  254. new_last_output.create_load()),
  255. tree_ir.CreateDictionaryEdgeInstruction(
  256. store_user_root.create_load(),
  257. tree_ir.LiteralInstruction('last_output'),
  258. new_last_output.create_load()),
  259. tree_ir.DeleteEdgeInstruction(last_output_link.create_load()),
  260. tree_ir.NopInstruction())
  261. raise primitive_functions.PrimitiveFinished(result)
  262. instruction_analyzers = {
  263. 'if' : analyze_if,
  264. 'while' : analyze_while,
  265. 'return' : analyze_return,
  266. 'constant' : analyze_constant,
  267. 'output' : analyze_output
  268. }