import modelverse_kernel.primitives as primitive_functions
import modelverse_jit.tree_ir as tree_ir


class JitCompilationFailedException(Exception):
    """A type of exception that is raised when the jit fails to compile a function."""


class ModelverseJit(object):
    """A high-level interface to the modelverse JIT compiler.

    Entry points (identified by their body node id) are tracked in three
    collections: entry points still waiting to be jitted, entry points that
    must never be jitted, and entry points that have already been compiled.
    """

    def __init__(self):
        # Entry points that have been discovered but not yet compiled.
        self.todo_entry_points = set()
        # Entry points that must never be jitted.
        self.no_jit_entry_points = set()
        # Maps an entry point id to its compiled form.
        self.jitted_entry_points = {}

    def mark_entry_point(self, body_id):
        """Marks the node with the given identifier as a function entry point.

        The node is queued for jitting only if it has neither been marked
        as non-jittable nor already been compiled.
        """
        if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
            self.todo_entry_points.add(body_id)

    def is_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point."""
        return body_id in self.todo_entry_points or \
            body_id in self.no_jit_entry_points or \
            body_id in self.jitted_entry_points

    def is_jittable_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point that
        has not been marked as non-jittable."""
        return body_id in self.todo_entry_points or \
            body_id in self.jitted_entry_points

    def mark_no_jit(self, body_id):
        """Informs the JIT that the node with the given identifier is a function entry
        point that must never be jitted."""
        self.no_jit_entry_points.add(body_id)
        # Drop it from the pending queue, if it was queued at all.
        self.todo_entry_points.discard(body_id)

    def register_compiled(self, body_id, compiled):
        """Registers a compiled entry point with the JIT.

        The compiled form is stored under `body_id`, and the entry point is
        removed from the pending queue if it was queued.
        """
        self.jitted_entry_points[body_id] = compiled
        self.todo_entry_points.discard(body_id)

    def try_jit(self, body_id, parameter_list):
        """Tries to jit the function defined by the given entry point id and parameter list.

        Compilation is not implemented yet: this method always prints a
        diagnostic, marks the entry point as non-jittable (so it is not
        attempted again) and raises.

        Raises:
            JitCompilationFailedException: always, in the current implementation.
        """
        print("Couldn't JIT: " + str(body_id))
        self.mark_no_jit(body_id)
        raise JitCompilationFailedException("Couldn't JIT")