jit.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import modelverse_kernel.primitives as primitive_functions
  2. import modelverse_jit.tree_ir as tree_ir
  3. class JitCompilationFailedException(Exception):
  4. """A type of exception that is raised when the jit fails to compile a function."""
  5. pass
  6. class ModelverseJit(object):
  7. """A high-level interface to the modelverse JIT compiler."""
  8. def __init__(self):
  9. self.todo_entry_points = set()
  10. self.no_jit_entry_points = set()
  11. self.jitted_entry_points = {}
  12. def mark_entry_point(self, body_id):
  13. """Marks the node with the given identifier as a function entry point."""
  14. if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
  15. self.todo_entry_points.add(body_id)
  16. def is_entry_point(self, body_id):
  17. """Tells if the node with the given identifier is a function entry point."""
  18. return body_id in self.todo_entry_points or \
  19. body_id in self.no_jit_entry_points or \
  20. body_id in self.jitted_entry_points
  21. def is_jittable_entry_point(self, body_id):
  22. """Tells if the node with the given identifier is a function entry point that
  23. has not been marked as non-jittable."""
  24. return body_id in self.todo_entry_points or \
  25. body_id in self.jitted_entry_points
  26. def mark_no_jit(self, body_id):
  27. """Informs the JIT that the node with the given identifier is a function entry
  28. point that must never be jitted."""
  29. self.no_jit_entry_points.add(body_id)
  30. if body_id in self.todo_entry_points:
  31. self.todo_entry_points.remove(body_id)
  32. def register_compiled(self, body_id, compiled):
  33. """Registers a compiled entry point with the JIT."""
  34. self.jitted_entry_points[body_id] = compiled
  35. if body_id in self.todo_entry_points:
  36. self.todo_entry_points.remove(body_id)
  37. def try_jit(self, body_id, parameter_list):
  38. """Tries to jit the function defined by the given entry point id and parameter list."""
  39. gen = AnalysisState().analyze(body_id)
  40. try:
  41. inp = None
  42. while True:
  43. inp = yield gen.send(inp)
  44. except primitive_functions.PrimitiveFinished as e:
  45. pass
  46. self.mark_no_jit(body_id)
  47. raise JitCompilationFailedException("Can't JIT function body at " + str(body_id))
  48. class AnalysisState(object):
  49. """The state of a bytecode analysis call graph."""
  50. def __init__(self):
  51. self.analyzed_instructions = set()
  52. def analyze(self, instruction_id):
  53. """Tries to build an intermediate representation from the instruction with the
  54. given id."""
  55. # Add the instruction id to the analyzed_instructions set to avoid
  56. # infinite loops.
  57. self.analyzed_instructions.add(instruction_id)
  58. instruction_val, = yield [("RV", [instruction_id])]
  59. instruction_val = instruction_val["value"]
  60. if instruction_val in self.instruction_analyzers:
  61. gen = self.instruction_analyzers[instruction_val](self, instruction_id)
  62. try:
  63. inp = None
  64. while True:
  65. inp = yield gen.send(inp)
  66. except StopIteration:
  67. raise Exception(
  68. "Instruction analyzer (for '%s') finished without returning a value!" %
  69. (instruction_val))
  70. except primitive_functions.PrimitiveFinished as outer_e:
  71. # Check if the instruction has a 'next' instruction.
  72. next_instr, = yield [("RD", [instruction_id, "next"])]
  73. if next_instr is None:
  74. raise outer_e
  75. else:
  76. gen = self.analyze(next_instr)
  77. try:
  78. inp = None
  79. while True:
  80. inp = yield gen.send(inp)
  81. except primitive_functions.PrimitiveFinished as inner_e:
  82. raise primitive_functions.PrimitiveFinished(
  83. tree_ir.CompoundInstruction(
  84. outer_e.result,
  85. inner_e.result))
  86. else:
  87. raise JitCompilationFailedException(
  88. "Unknown instruction type: '%s'" % (instruction_val))
  89. def analyze_all(self, instruction_ids):
  90. """Tries to compile a list of IR trees from the given list of instruction ids."""
  91. results = []
  92. for inst in instruction_ids:
  93. gen = self.analyze(inst)
  94. try:
  95. inp = None
  96. while True:
  97. inp = yield gen.send(inp)
  98. except primitive_functions.PrimitiveFinished as e:
  99. results.append(e.result)
  100. raise primitive_functions.PrimitiveFinished(results)
  101. def analyze_return(self, instruction_id):
  102. """Tries to analyze the given 'return' instruction."""
  103. retval_id, = yield [("RD", [instruction_id, 'value'])]
  104. if retval_id is None:
  105. raise primitive_functions.PrimitiveFinished(
  106. tree_ir.ReturnInstruction(
  107. tree_ir.EmptyInstruction()))
  108. else:
  109. gen = self.analyze(retval_id)
  110. try:
  111. inp = None
  112. while True:
  113. inp = yield gen.send(inp)
  114. except primitive_functions.PrimitiveFinished as e:
  115. raise primitive_functions.PrimitiveFinished(
  116. tree_ir.ReturnInstruction(e.result))
  117. def analyze_if(self, instruction_id):
  118. """Tries to analyze the given 'if' instruction."""
  119. cond, true, false = yield [
  120. ("RD", [instruction_id, "cond"]),
  121. ("RD", [instruction_id, "then"]),
  122. ("RD", [instruction_id, "else"])]
  123. gen = self.analyze_all([cond, true, false])
  124. try:
  125. inp = None
  126. while True:
  127. inp = yield gen.send(inp)
  128. except primitive_functions.PrimitiveFinished as e:
  129. cond_r, true_r, false_r = e.result
  130. raise primitive_functions.PrimitiveFinished(
  131. tree_ir.SelectInstruction(
  132. tree_ir.ReadValueInstruction(cond_r),
  133. true_r,
  134. false_r))
  135. def analyze_constant(self, instruction_id):
  136. """Tries to analyze the given 'constant' (literal) instruction."""
  137. node_id, = yield [("RD", [instruction_id, "node"])]
  138. raise primitive_functions.PrimitiveFinished(
  139. tree_ir.LiteralInstruction(node_id))
  140. instruction_analyzers = {
  141. 'if' : analyze_if,
  142. 'return' : analyze_return,
  143. 'constant' : analyze_constant
  144. }