jit.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import modelverse_kernel.primitives as primitive_functions
  2. import modelverse_jit.tree_ir as tree_ir
  3. class JitCompilationFailedException(Exception):
  4. """A type of exception that is raised when the jit fails to compile a function."""
  5. pass
  6. class ModelverseJit(object):
  7. """A high-level interface to the modelverse JIT compiler."""
  8. def __init__(self):
  9. self.todo_entry_points = set()
  10. self.no_jit_entry_points = set()
  11. self.jitted_entry_points = {}
  12. def mark_entry_point(self, body_id):
  13. """Marks the node with the given identifier as a function entry point."""
  14. if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
  15. self.todo_entry_points.add(body_id)
  16. def is_entry_point(self, body_id):
  17. """Tells if the node with the given identifier is a function entry point."""
  18. return body_id in self.todo_entry_points or \
  19. body_id in self.no_jit_entry_points or \
  20. body_id in self.jitted_entry_points
  21. def is_jittable_entry_point(self, body_id):
  22. """Tells if the node with the given identifier is a function entry point that
  23. has not been marked as non-jittable."""
  24. return body_id in self.todo_entry_points or \
  25. body_id in self.jitted_entry_points
  26. def mark_no_jit(self, body_id):
  27. """Informs the JIT that the node with the given identifier is a function entry
  28. point that must never be jitted."""
  29. self.no_jit_entry_points.add(body_id)
  30. if body_id in self.todo_entry_points:
  31. self.todo_entry_points.remove(body_id)
  32. def register_compiled(self, body_id, compiled):
  33. """Registers a compiled entry point with the JIT."""
  34. self.jitted_entry_points[body_id] = compiled
  35. if body_id in self.todo_entry_points:
  36. self.todo_entry_points.remove(body_id)
  37. def try_jit(self, body_id, parameter_list):
  38. """Tries to jit the function defined by the given entry point id and parameter list."""
  39. print("Couldn't JIT: " + str(body_id))
  40. self.mark_no_jit(body_id)
  41. raise JitCompilationFailedException("Couln't JIT")