jit.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. import modelverse_kernel.primitives as primitive_functions
  2. import modelverse_jit.tree_ir as tree_ir
  3. KWARGS_PARAMETER_NAME = "remainder"
  4. """The name of the kwargs parameter in jitted functions."""
  5. class JitCompilationFailedException(Exception):
  6. """A type of exception that is raised when the jit fails to compile a function."""
  7. pass
  8. class ModelverseJit(object):
  9. """A high-level interface to the modelverse JIT compiler."""
  10. def __init__(self):
  11. self.todo_entry_points = set()
  12. self.no_jit_entry_points = set()
  13. self.jitted_entry_points = {}
  14. def mark_entry_point(self, body_id):
  15. """Marks the node with the given identifier as a function entry point."""
  16. if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
  17. self.todo_entry_points.add(body_id)
  18. def is_entry_point(self, body_id):
  19. """Tells if the node with the given identifier is a function entry point."""
  20. return body_id in self.todo_entry_points or \
  21. body_id in self.no_jit_entry_points or \
  22. body_id in self.jitted_entry_points
  23. def is_jittable_entry_point(self, body_id):
  24. """Tells if the node with the given identifier is a function entry point that
  25. has not been marked as non-jittable."""
  26. return body_id in self.todo_entry_points or \
  27. body_id in self.jitted_entry_points
  28. def mark_no_jit(self, body_id):
  29. """Informs the JIT that the node with the given identifier is a function entry
  30. point that must never be jitted."""
  31. self.no_jit_entry_points.add(body_id)
  32. if body_id in self.todo_entry_points:
  33. self.todo_entry_points.remove(body_id)
  34. def register_compiled(self, body_id, compiled):
  35. """Registers a compiled entry point with the JIT."""
  36. self.jitted_entry_points[body_id] = compiled
  37. if body_id in self.todo_entry_points:
  38. self.todo_entry_points.remove(body_id)
  39. def try_jit(self, body_id, parameter_list):
  40. """Tries to jit the function defined by the given entry point id and parameter list."""
  41. gen = AnalysisState().analyze(body_id)
  42. try:
  43. inp = None
  44. while True:
  45. inp = yield gen.send(inp)
  46. except primitive_functions.PrimitiveFinished as e:
  47. constructed_ir = e.result
  48. except JitCompilationFailedException:
  49. self.mark_no_jit(body_id)
  50. raise
  51. print(constructed_ir)
  52. self.mark_no_jit(body_id)
  53. raise JitCompilationFailedException("Can't jit function body at " + str(body_id))
  54. class AnalysisState(object):
  55. """The state of a bytecode analysis call graph."""
  56. def __init__(self):
  57. self.analyzed_instructions = set()
  58. def analyze(self, instruction_id):
  59. """Tries to build an intermediate representation from the instruction with the
  60. given id."""
  61. # Check the analyzed_instructions set for instruction_id to avoid
  62. # infinite loops.
  63. if instruction_id in self.analyzed_instructions:
  64. raise JitCompilationFailedException('Cannon jit non-tree instruction graph.')
  65. self.analyzed_instructions.add(instruction_id)
  66. instruction_val, = yield [("RV", [instruction_id])]
  67. instruction_val = instruction_val["value"]
  68. if instruction_val in self.instruction_analyzers:
  69. gen = self.instruction_analyzers[instruction_val](self, instruction_id)
  70. try:
  71. inp = None
  72. while True:
  73. inp = yield gen.send(inp)
  74. except StopIteration:
  75. raise Exception(
  76. "Instruction analyzer (for '%s') finished without returning a value!" %
  77. (instruction_val))
  78. except primitive_functions.PrimitiveFinished as outer_e:
  79. # Check if the instruction has a 'next' instruction.
  80. next_instr, = yield [("RD", [instruction_id, "next"])]
  81. if next_instr is None:
  82. raise outer_e
  83. else:
  84. gen = self.analyze(next_instr)
  85. try:
  86. inp = None
  87. while True:
  88. inp = yield gen.send(inp)
  89. except primitive_functions.PrimitiveFinished as inner_e:
  90. raise primitive_functions.PrimitiveFinished(
  91. tree_ir.CompoundInstruction(
  92. outer_e.result,
  93. inner_e.result))
  94. else:
  95. raise JitCompilationFailedException(
  96. "Unknown instruction type: '%s'" % (instruction_val))
  97. def analyze_all(self, instruction_ids):
  98. """Tries to compile a list of IR trees from the given list of instruction ids."""
  99. results = []
  100. for inst in instruction_ids:
  101. gen = self.analyze(inst)
  102. try:
  103. inp = None
  104. while True:
  105. inp = yield gen.send(inp)
  106. except primitive_functions.PrimitiveFinished as e:
  107. results.append(e.result)
  108. raise primitive_functions.PrimitiveFinished(results)
  109. def analyze_return(self, instruction_id):
  110. """Tries to analyze the given 'return' instruction."""
  111. retval_id, = yield [("RD", [instruction_id, 'value'])]
  112. if retval_id is None:
  113. raise primitive_functions.PrimitiveFinished(
  114. tree_ir.ReturnInstruction(
  115. tree_ir.EmptyInstruction()))
  116. else:
  117. gen = self.analyze(retval_id)
  118. try:
  119. inp = None
  120. while True:
  121. inp = yield gen.send(inp)
  122. except primitive_functions.PrimitiveFinished as e:
  123. raise primitive_functions.PrimitiveFinished(
  124. tree_ir.ReturnInstruction(e.result))
  125. def analyze_if(self, instruction_id):
  126. """Tries to analyze the given 'if' instruction."""
  127. cond, true, false = yield [
  128. ("RD", [instruction_id, "cond"]),
  129. ("RD", [instruction_id, "then"]),
  130. ("RD", [instruction_id, "else"])]
  131. gen = self.analyze_all([cond, true, false])
  132. try:
  133. inp = None
  134. while True:
  135. inp = yield gen.send(inp)
  136. except primitive_functions.PrimitiveFinished as e:
  137. cond_r, true_r, false_r = e.result
  138. raise primitive_functions.PrimitiveFinished(
  139. tree_ir.SelectInstruction(
  140. tree_ir.ReadValueInstruction(cond_r),
  141. true_r,
  142. false_r))
  143. def analyze_while(self, instruction_id):
  144. """Tries to analyze the given 'while' instruction."""
  145. cond, body = yield [
  146. ("RD", [instruction_id, "cond"]),
  147. ("RD", [instruction_id, "body"])]
  148. gen = self.analyze_all([cond, body])
  149. try:
  150. inp = None
  151. while True:
  152. inp = yield gen.send(inp)
  153. except primitive_functions.PrimitiveFinished as e:
  154. cond_r, body_r = e.result
  155. raise primitive_functions.PrimitiveFinished(
  156. tree_ir.LoopInstruction(
  157. tree_ir.CompoundInstruction(
  158. tree_ir.SelectInstruction(
  159. tree_ir.ReadValueInstruction(cond_r),
  160. tree_ir.EmptyInstruction(),
  161. tree_ir.BreakInstruction()),
  162. body_r)))
  163. def analyze_constant(self, instruction_id):
  164. """Tries to analyze the given 'constant' (literal) instruction."""
  165. node_id, = yield [("RD", [instruction_id, "node"])]
  166. raise primitive_functions.PrimitiveFinished(
  167. tree_ir.LiteralInstruction(node_id))
  168. def analyze_output(self, instruction_id):
  169. """Tries to analyze the given 'output' instruction."""
  170. # The plan is to basically generate this tree:
  171. #
  172. # value = <some tree>
  173. # last_output, last_output_link, new_last_output = \
  174. # yield [("RD", [user_root, "last_output"]),
  175. # ("RDE", [user_root, "last_output"]),
  176. # ("CN", []),
  177. # ]
  178. # _, _, _, _ = \
  179. # yield [("CD", [last_output, "value", value]),
  180. # ("CD", [last_output, "next", new_last_output]),
  181. # ("CD", [user_root, "last_output", new_last_output]),
  182. # ("DE", [last_output_link])
  183. # ]
  184. value_id, = yield [("RD", [instruction_id, "value"])]
  185. gen = self.analyze(value_id)
  186. try:
  187. inp = None
  188. while True:
  189. inp = yield gen.send(inp)
  190. except primitive_functions.PrimitiveFinished as e:
  191. value_local = tree_ir.StoreLocalInstruction('value', e.result)
  192. store_user_root = tree_ir.StoreLocalInstruction(
  193. 'user_root',
  194. tree_ir.LoadIndexInstruction(
  195. tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
  196. tree_ir.LiteralInstruction('user_root')))
  197. last_output = tree_ir.StoreLocalInstruction(
  198. 'last_output',
  199. tree_ir.ReadDictionaryValueInstruction(
  200. store_user_root.create_load(),
  201. tree_ir.LiteralInstruction('last_output')))
  202. last_output_link = tree_ir.StoreLocalInstruction(
  203. 'last_output_link',
  204. tree_ir.ReadDictionaryEdgeInstruction(
  205. store_user_root.create_load(),
  206. tree_ir.LiteralInstruction('last_output')))
  207. new_last_output = tree_ir.StoreLocalInstruction(
  208. 'new_last_output',
  209. tree_ir.CreateNodeInstruction())
  210. result = tree_ir.create_block(
  211. value_local,
  212. store_user_root,
  213. last_output,
  214. last_output_link,
  215. new_last_output,
  216. tree_ir.CreateDictionaryEdgeInstruction(
  217. last_output.create_load(),
  218. tree_ir.LiteralInstruction('value'),
  219. value_local.create_load()),
  220. tree_ir.CreateDictionaryEdgeInstruction(
  221. last_output.create_load(),
  222. tree_ir.LiteralInstruction('next'),
  223. new_last_output.create_load()),
  224. tree_ir.CreateDictionaryEdgeInstruction(
  225. store_user_root.create_load(),
  226. tree_ir.LiteralInstruction('last_output'),
  227. new_last_output.create_load()),
  228. tree_ir.DeleteEdgeInstruction(last_output_link.create_load()))
  229. raise primitive_functions.PrimitiveFinished(result)
  230. instruction_analyzers = {
  231. 'if' : analyze_if,
  232. 'while' : analyze_while,
  233. 'return' : analyze_return,
  234. 'constant' : analyze_constant,
  235. 'output' : analyze_output
  236. }