jit.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import modelverse_kernel.primitives as primitive_functions
  2. import modelverse_jit.tree_ir as tree_ir
  3. class JitCompilationFailedException(Exception):
  4. """A type of exception that is raised when the jit fails to compile a function."""
  5. pass
  6. class ModelverseJit(object):
  7. """A high-level interface to the modelverse JIT compiler."""
  8. def __init__(self):
  9. self.todo_entry_points = set()
  10. self.no_jit_entry_points = set()
  11. self.jitted_entry_points = {}
  12. def mark_entry_point(self, body_id):
  13. """Marks the node with the given identifier as a function entry point."""
  14. if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
  15. self.todo_entry_points.add(body_id)
  16. def is_entry_point(self, body_id):
  17. """Tells if the node with the given identifier is a function entry point."""
  18. return body_id in self.todo_entry_points or \
  19. body_id in self.no_jit_entry_points or \
  20. body_id in self.jitted_entry_points
  21. def is_jittable_entry_point(self, body_id):
  22. """Tells if the node with the given identifier is a function entry point that
  23. has not been marked as non-jittable."""
  24. return body_id in self.todo_entry_points or \
  25. body_id in self.jitted_entry_points
  26. def mark_no_jit(self, body_id):
  27. """Informs the JIT that the node with the given identifier is a function entry
  28. point that must never be jitted."""
  29. self.no_jit_entry_points.add(body_id)
  30. if body_id in self.todo_entry_points:
  31. self.todo_entry_points.remove(body_id)
  32. def register_compiled(self, body_id, compiled):
  33. """Registers a compiled entry point with the JIT."""
  34. self.jitted_entry_points[body_id] = compiled
  35. if body_id in self.todo_entry_points:
  36. self.todo_entry_points.remove(body_id)
  37. def try_jit(self, body_id, parameter_list):
  38. """Tries to jit the function defined by the given entry point id and parameter list."""
  39. gen = AnalysisState().analyze(body_id)
  40. try:
  41. inp = None
  42. while True:
  43. inp = yield gen.send(inp)
  44. except primitive_functions.PrimitiveFinished as e:
  45. pass
  46. self.mark_no_jit(body_id)
  47. raise JitCompilationFailedException("Can't JIT function body at " + str(body_id))
  48. class AnalysisState(object):
  49. """The state of a bytecode analysis call graph."""
  50. def __init__(self):
  51. self.analyzed_instructions = set()
  52. def analyze(self, instruction_id):
  53. """Tries to build an intermediate representation from the instruction with the
  54. given id."""
  55. instruction_val, = yield [("RV", [instruction_id])]
  56. instruction_val = instruction_val["value"]
  57. if instruction_val in self.instruction_analyzers:
  58. gen = self.instruction_analyzers[instruction_val](self, instruction_id)
  59. try:
  60. inp = None
  61. while True:
  62. inp = yield gen.send(inp)
  63. except StopIteration:
  64. raise Exception(
  65. "Instruction analyzer (for '%s') finished without returning a value!" %
  66. (instruction_val))
  67. else:
  68. raise JitCompilationFailedException(
  69. "Unknown instruction type: '%s'" % (instruction_val))
  70. def analyze_all(self, instruction_ids):
  71. """Tries to compile a list of IR trees from the given list of instruction ids."""
  72. results = []
  73. for inst in instruction_ids:
  74. gen = self.analyze(inst)
  75. try:
  76. inp = None
  77. while True:
  78. inp = yield gen.send(inp)
  79. except primitive_functions.PrimitiveFinished as e:
  80. results.append(e.result)
  81. raise primitive_functions.PrimitiveFinished(results)
  82. def analyze_return(self, instruction_id):
  83. """Tries to analyze the given 'return' instruction."""
  84. retval_id, = yield [("RD", [instruction_id, 'value'])]
  85. if retval_id is None:
  86. raise primitive_functions.PrimitiveFinished(
  87. tree_ir.ReturnInstruction(
  88. tree_ir.EmptyInstruction()))
  89. else:
  90. gen = self.analyze(retval_id)
  91. try:
  92. inp = None
  93. while True:
  94. inp = yield gen.send(inp)
  95. except primitive_functions.PrimitiveFinished as e:
  96. raise primitive_functions.PrimitiveFinished(
  97. tree_ir.ReturnInstruction(e.result))
  98. def analyze_if(self, instruction_id):
  99. """Tries to analyze the given 'if' instruction."""
  100. cond, true, false, next_inst = yield [
  101. ("RD", [instruction_id, "cond"]),
  102. ("RD", [instruction_id, "then"]),
  103. ("RD", [instruction_id, "else"]),
  104. ("RD", [instruction_id, "next"])]
  105. gen = self.analyze_all([cond, true, false, next_inst])
  106. try:
  107. inp = None
  108. while True:
  109. inp = yield gen.send(inp)
  110. except primitive_functions.PrimitiveFinished as e:
  111. cond_r, true_r, false_r, next_r = e.result
  112. raise primitive_functions.PrimitiveFinished(
  113. tree_ir.CompoundInstruction(
  114. tree_ir.SelectInstruction(
  115. tree_ir.ReadValueInstruction(cond_r),
  116. true_r,
  117. false_r),
  118. next_r))
  119. def analyze_constant(self, instruction_id):
  120. """Tries to analyze the given 'constant' (literal) instruction."""
  121. node_id, = yield [("RD", [instruction_id, "node"])]
  122. raise primitive_functions.PrimitiveFinished(
  123. tree_ir.LiteralInstruction(node_id))
  124. instruction_analyzers = {
  125. 'if' : analyze_if,
  126. 'return' : analyze_return,
  127. 'constant' : analyze_constant
  128. }