import modelverse_kernel.primitives as primitive_functions
import modelverse_jit.tree_ir as tree_ir


class JitCompilationFailedException(Exception):
    """A type of exception that is raised when the jit fails to compile a function."""
    pass


class ModelverseJit(object):
    """A high-level interface to the modelverse JIT compiler.

    Function entry points are identified by the id of their body node and are
    tracked in three collections:

    * ``todo_entry_points``: entry points that are known but have neither been
      compiled nor rejected yet.
    * ``no_jit_entry_points``: entry points that must never be jitted.
    * ``jitted_entry_points``: a dict mapping entry point ids to the compiled
      result registered for them.
    """

    def __init__(self):
        self.todo_entry_points = set()
        self.no_jit_entry_points = set()
        self.jitted_entry_points = {}

    def mark_entry_point(self, body_id):
        """Marks the node with the given identifier as a function entry point.

        Entry points that were already rejected or compiled are left alone.
        """
        if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
            self.todo_entry_points.add(body_id)

    def is_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point."""
        return body_id in self.todo_entry_points or \
            body_id in self.no_jit_entry_points or \
            body_id in self.jitted_entry_points

    def is_jittable_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point that
        has not been marked as non-jittable."""
        return body_id in self.todo_entry_points or \
            body_id in self.jitted_entry_points

    def mark_no_jit(self, body_id):
        """Informs the JIT that the node with the given identifier is a function entry
        point that must never be jitted."""
        self.no_jit_entry_points.add(body_id)
        # The entry point is no longer pending; discard() is a no-op if absent.
        self.todo_entry_points.discard(body_id)

    def register_compiled(self, body_id, compiled):
        """Registers a compiled entry point with the JIT."""
        self.jitted_entry_points[body_id] = compiled
        # The entry point is no longer pending; discard() is a no-op if absent.
        self.todo_entry_points.discard(body_id)

    def try_jit(self, body_id, parameter_list):
        """Tries to jit the function defined by the given entry point id and parameter list.

        This is a generator: it forwards every request yielded by the bytecode
        analysis to its caller and sends the caller's replies back into the
        analysis. ``parameter_list`` is currently unused.

        The analysis builds an IR tree, but no code is generated from it yet,
        so every call ends by marking ``body_id`` as non-jittable and raising
        JitCompilationFailedException. If the analysis itself rejects the
        body, the body is also marked as non-jittable before the analysis
        error is re-raised, so it is not re-analyzed on every call.
        """
        gen = AnalysisState().analyze(body_id)
        try:
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished:
            # Analysis succeeded; code generation is not implemented yet, so
            # fall through to the failure path below.
            pass
        except JitCompilationFailedException:
            # Remember the rejection before propagating it.
            self.mark_no_jit(body_id)
            raise

        self.mark_no_jit(body_id)
        raise JitCompilationFailedException(
            "Can't JIT function body at " + str(body_id))


class AnalysisState(object):
    """The state of a bytecode analysis call graph.

    Every analysis method is a generator that yields lists of
    ``(request, arguments)`` tuples (e.g. ``("RV", [...])``, ``("RD", [...])``)
    to its caller, receives the corresponding results through ``send()``, and
    reports its outcome by raising ``primitive_functions.PrimitiveFinished``
    whose ``result`` is a ``tree_ir`` instruction.
    """

    def __init__(self):
        self.analyzed_instructions = set()

    def analyze(self, instruction_id):
        """Tries to build an intermediate representation from the instruction with the
        given id.

        The instruction's value selects an analyzer from ``instruction_analyzers``.
        If the instruction has a "next" instruction, that instruction is analyzed
        too and both results are combined in a ``tree_ir.CompoundInstruction``.

        Raises:
            JitCompilationFailedException: if the instruction type has no analyzer.
            Exception: if an analyzer finishes without raising PrimitiveFinished.
        """
        # Record that this instruction was visited. NOTE(review): this set is
        # not consulted anywhere in this class yet, so it does not by itself
        # prevent infinite loops.
        self.analyzed_instructions.add(instruction_id)
        instruction_val, = yield [("RV", [instruction_id])]
        instruction_val = instruction_val["value"]
        if instruction_val in self.instruction_analyzers:
            # instruction_analyzers holds plain functions, so pass self explicitly.
            gen = self.instruction_analyzers[instruction_val](self, instruction_id)
            try:
                inp = None
                while True:
                    inp = yield gen.send(inp)
            except StopIteration:
                raise Exception(
                    "Instruction analyzer (for '%s') finished without returning a value!" %
                    (instruction_val))
            except primitive_functions.PrimitiveFinished as outer_e:
                # Check if the instruction has a 'next' instruction.
                next_instr, = yield [("RD", [instruction_id, "next"])]
                if next_instr is None:
                    raise outer_e
                else:
                    gen = self.analyze(next_instr)
                    try:
                        inp = None
                        while True:
                            inp = yield gen.send(inp)
                    except primitive_functions.PrimitiveFinished as inner_e:
                        raise primitive_functions.PrimitiveFinished(
                            tree_ir.CompoundInstruction(
                                outer_e.result,
                                inner_e.result))
        else:
            raise JitCompilationFailedException(
                "Unknown instruction type: '%s'" % (instruction_val))

    def analyze_all(self, instruction_ids):
        """Tries to compile a list of IR trees from the given list of instruction ids.

        The results are reported, in input order, as the ``result`` list of a
        PrimitiveFinished exception.
        """
        results = []
        for inst in instruction_ids:
            gen = self.analyze(inst)
            try:
                inp = None
                while True:
                    inp = yield gen.send(inp)
            except primitive_functions.PrimitiveFinished as e:
                results.append(e.result)

        raise primitive_functions.PrimitiveFinished(results)

    def analyze_return(self, instruction_id):
        """Tries to analyze the given 'return' instruction.

        A 'return' without a "value" edge returns an empty instruction.
        """
        retval_id, = yield [("RD", [instruction_id, 'value'])]
        if retval_id is None:
            raise primitive_functions.PrimitiveFinished(
                tree_ir.ReturnInstruction(
                    tree_ir.EmptyInstruction()))
        else:
            gen = self.analyze(retval_id)
            try:
                inp = None
                while True:
                    inp = yield gen.send(inp)
            except primitive_functions.PrimitiveFinished as e:
                raise primitive_functions.PrimitiveFinished(
                    tree_ir.ReturnInstruction(e.result))

    def analyze_if(self, instruction_id):
        """Tries to analyze the given 'if' instruction.

        The "else" branch is optional. When the instruction has no "else" edge
        the RD request yields None (the same convention analyze() and
        analyze_return() rely on), and an empty instruction is used as the
        else-branch instead of trying to analyze a None instruction id.
        """
        cond, true, false = yield [
            ("RD", [instruction_id, "cond"]),
            ("RD", [instruction_id, "then"]),
            ("RD", [instruction_id, "else"])]

        if false is None:
            branch_ids = [cond, true]
        else:
            branch_ids = [cond, true, false]

        gen = self.analyze_all(branch_ids)
        try:
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as e:
            if false is None:
                cond_r, true_r = e.result
                false_r = tree_ir.EmptyInstruction()
            else:
                cond_r, true_r, false_r = e.result

            raise primitive_functions.PrimitiveFinished(
                tree_ir.SelectInstruction(
                    tree_ir.ReadValueInstruction(cond_r),
                    true_r,
                    false_r))

    def analyze_constant(self, instruction_id):
        """Tries to analyze the given 'constant' (literal) instruction.

        The result is a literal wrapping the id read from the "node" edge.
        """
        node_id, = yield [("RD", [instruction_id, "node"])]
        raise primitive_functions.PrimitiveFinished(
            tree_ir.LiteralInstruction(node_id))

    # Maps an instruction's value to the (unbound) analyzer function for it.
    instruction_analyzers = {
        'if' : analyze_if,
        'return' : analyze_return,
        'constant' : analyze_constant
    }