12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- import modelverse_kernel.primitives as primitive_functions
- import modelverse_jit.tree_ir as tree_ir
class JitCompilationFailedException(Exception):
    """Signals that the JIT was unable to compile a function.

    Adds nothing beyond the standard ``Exception`` behaviour; the type alone
    lets callers distinguish JIT failures from other errors.
    """
class ModelverseJit(object):
    """A high-level interface to the modelverse JIT compiler.

    Function entry points are identified by the id of their body node and are
    tracked in three collections:

    * ``todo_entry_points`` -- entry points that are known but have been
      neither compiled nor marked as non-jittable;
    * ``no_jit_entry_points`` -- entry points that must never be jitted;
    * ``jitted_entry_points`` -- a dict mapping an entry point id to the
      compiled object registered for it.
    """

    def __init__(self):
        # Known entry points still waiting to be compiled or rejected.
        self.todo_entry_points = set()
        # Entry points that must never be handed to the JIT.
        self.no_jit_entry_points = set()
        # Maps an entry point id to its registered compiled object.
        self.jitted_entry_points = {}

    def mark_entry_point(self, body_id):
        """Marks the node with the given identifier as a function entry point.

        Ids that are already marked as non-jittable or already compiled are
        not added to the to-do set again.
        """
        if (body_id not in self.no_jit_entry_points
                and body_id not in self.jitted_entry_points):
            self.todo_entry_points.add(body_id)

    def is_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point.

        Returns True if the id is in any of the three tracking collections.
        """
        return (body_id in self.todo_entry_points
                or body_id in self.no_jit_entry_points
                or body_id in self.jitted_entry_points)

    def is_jittable_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point
        that has not been marked as non-jittable.

        Returns True if the id is pending in the to-do set or already compiled.
        """
        return (body_id in self.todo_entry_points
                or body_id in self.jitted_entry_points)

    def mark_no_jit(self, body_id):
        """Informs the JIT that the node with the given identifier is a function
        entry point that must never be jitted.

        The id is also removed from the to-do set if present.
        """
        self.no_jit_entry_points.add(body_id)
        # discard() is a no-op for ids that were never queued.
        self.todo_entry_points.discard(body_id)

    def register_compiled(self, body_id, compiled):
        """Registers a compiled entry point with the JIT.

        Args:
            body_id: identifier of the entry point's body node.
            compiled: the compiled object to associate with ``body_id``;
                replaces any object previously registered for it.
        """
        self.jitted_entry_points[body_id] = compiled
        # The entry point is no longer pending; discard() tolerates absence.
        self.todo_entry_points.discard(body_id)

    def try_jit(self, body_id, parameter_list):
        """Tries to jit the function defined by the given entry point id and
        parameter list.

        Compilation is not implemented yet: this always prints a notice, marks
        ``body_id`` as non-jittable (so ``mark_entry_point`` will not queue it
        again), and raises.

        Args:
            body_id: identifier of the entry point's body node.
            parameter_list: currently unused.

        Raises:
            JitCompilationFailedException: always.
        """
        print("Couldn't JIT: " + str(body_id))
        self.mark_no_jit(body_id)
        raise JitCompilationFailedException("Couldn't JIT")
|