import modelverse_kernel.primitives as primitive_functions
import modelverse_jit.tree_ir as tree_ir

# The name of the kwargs parameter in jitted functions.
KWARGS_PARAMETER_NAME = "remainder"


class JitCompilationFailedException(Exception):
    """A type of exception that is raised when the jit fails to compile a function."""
    pass


class ModelverseJit(object):
    """A high-level interface to the modelverse JIT compiler.

    Function entry points (identified by the id of their body node) are
    tracked in three collections: entry points that still have to be jitted,
    entry points that must never be jitted, and entry points that have
    already been compiled.
    """

    def __init__(self):
        # Entry points that have been discovered but not yet jitted.
        self.todo_entry_points = set()
        # Entry points that must never be jitted.
        self.no_jit_entry_points = set()
        # Maps entry point ids to their compiled form.
        self.jitted_entry_points = {}

    def mark_entry_point(self, body_id):
        """Marks the node with the given identifier as a function entry point.

        Entry points that are already known to be non-jittable or that have
        already been compiled are left alone.
        """
        if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
            self.todo_entry_points.add(body_id)

    def is_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point."""
        return (body_id in self.todo_entry_points
                or body_id in self.no_jit_entry_points
                or body_id in self.jitted_entry_points)

    def is_jittable_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point that
        has not been marked as non-jittable."""
        return (body_id in self.todo_entry_points
                or body_id in self.jitted_entry_points)

    def mark_no_jit(self, body_id):
        """Informs the JIT that the node with the given identifier is a function entry
        point that must never be jitted."""
        self.no_jit_entry_points.add(body_id)
        self.todo_entry_points.discard(body_id)

    def register_compiled(self, body_id, compiled):
        """Registers a compiled entry point with the JIT."""
        self.jitted_entry_points[body_id] = compiled
        self.todo_entry_points.discard(body_id)

    def try_jit(self, body_id, parameter_list):
        """Tries to jit the function defined by the given entry point id and parameter list.

        This is a generator: it forwards the request lists yielded by the
        bytecode analysis to its caller and sends the caller's replies back
        into the analysis.

        Only IR construction is implemented so far. After an IR tree has been
        built successfully, it is printed, `body_id` is marked as non-jittable
        and a JitCompilationFailedException is raised anyway.

        Note: `parameter_list` is currently unused.

        Raises:
            JitCompilationFailedException: if the analysis fails (after
                marking `body_id` as non-jittable). Until IR compilation is
                implemented, this is also raised after a successful analysis.
        """
        gen = AnalysisState().analyze(body_id)
        try:
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as e:
            constructed_ir = e.result
        except JitCompilationFailedException:
            self.mark_no_jit(body_id)
            raise

        # TODO: compile `constructed_ir` into a callable instead of giving up.
        print(constructed_ir)
        self.mark_no_jit(body_id)
        raise JitCompilationFailedException("Can't jit function body at " + str(body_id))


class AnalysisState(object):
    """The state of a bytecode analysis call graph.

    All `analyze*` methods are generators that follow the same protocol:
    - They yield lists of modelverse requests, such as ("RV", [...]) and
      ("RD", [...]), and receive the corresponding list of replies.
    - They finish by raising `primitive_functions.PrimitiveFinished` that
      carries the resulting `tree_ir` instruction, or a list of them for
      `analyze_all`.
    """

    def __init__(self):
        # Ids of every instruction visited so far; used to reject graphs that
        # are not trees (shared or cyclic instructions).
        self.analyzed_instructions = set()

    def analyze(self, instruction_id):
        """Tries to build an intermediate representation from the instruction with the given id.

        The instruction's 'next' chain is followed as well. The resulting IR
        for a chain is a CompoundInstruction of the instruction's IR and the
        IR of everything that follows it.

        Raises:
            JitCompilationFailedException: if the instruction was already
                visited (non-tree graph), or if its type has no analyzer.
        """
        # Check the analyzed_instructions set for instruction_id to avoid
        # infinite loops.
        if instruction_id in self.analyzed_instructions:
            raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')

        self.analyzed_instructions.add(instruction_id)
        instruction_val, = yield [("RV", [instruction_id])]
        instruction_val = instruction_val["value"]

        if instruction_val not in self.instruction_analyzers:
            raise JitCompilationFailedException(
                "Unknown instruction type: '%s'" % (instruction_val))

        # The table holds plain functions, so `self` is passed explicitly.
        gen = self.instruction_analyzers[instruction_val](self, instruction_id)
        try:
            inp = None
            while True:
                inp = yield gen.send(inp)
        except StopIteration:
            raise Exception(
                "Instruction analyzer (for '%s') finished without returning a value!" %
                (instruction_val))
        except primitive_functions.PrimitiveFinished as outer_e:
            # Check if the instruction has a 'next' instruction.
            next_instr, = yield [("RD", [instruction_id, "next"])]
            if next_instr is None:
                raise outer_e
            else:
                gen = self.analyze(next_instr)
                try:
                    inp = None
                    while True:
                        inp = yield gen.send(inp)
                except primitive_functions.PrimitiveFinished as inner_e:
                    raise primitive_functions.PrimitiveFinished(
                        tree_ir.CompoundInstruction(
                            outer_e.result,
                            inner_e.result))

    def analyze_all(self, instruction_ids):
        """Tries to compile a list of IR trees from the given list of instruction ids.

        The resulting list is in the same order as `instruction_ids`.
        """
        results = []
        for inst in instruction_ids:
            gen = self.analyze(inst)
            try:
                inp = None
                while True:
                    inp = yield gen.send(inp)
            except primitive_functions.PrimitiveFinished as e:
                results.append(e.result)

        raise primitive_functions.PrimitiveFinished(results)

    def analyze_return(self, instruction_id):
        """Tries to analyze the given 'return' instruction.

        A 'return' without a 'value' edge returns an EmptyInstruction.
        """
        retval_id, = yield [("RD", [instruction_id, 'value'])]
        if retval_id is None:
            raise primitive_functions.PrimitiveFinished(
                tree_ir.ReturnInstruction(
                    tree_ir.EmptyInstruction()))
        else:
            gen = self.analyze(retval_id)
            try:
                inp = None
                while True:
                    inp = yield gen.send(inp)
            except primitive_functions.PrimitiveFinished as e:
                raise primitive_functions.PrimitiveFinished(
                    tree_ir.ReturnInstruction(e.result))

    def analyze_if(self, instruction_id):
        """Tries to analyze the given 'if' instruction.

        The 'else' branch is optional. When the instruction has no 'else'
        edge, an EmptyInstruction is used as the false branch.
        """
        cond, true, false = yield [
            ("RD", [instruction_id, "cond"]),
            ("RD", [instruction_id, "then"]),
            ("RD", [instruction_id, "else"])]

        # Analyzing a missing (None) 'else' branch would issue a read request
        # for a nonexistent node, so it is only analyzed when present.
        branch_ids = [cond, true] if false is None else [cond, true, false]
        gen = self.analyze_all(branch_ids)
        try:
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as e:
            if false is None:
                cond_r, true_r = e.result
                false_r = tree_ir.EmptyInstruction()
            else:
                cond_r, true_r, false_r = e.result
            raise primitive_functions.PrimitiveFinished(
                tree_ir.SelectInstruction(
                    tree_ir.ReadValueInstruction(cond_r),
                    true_r,
                    false_r))

    def analyze_while(self, instruction_id):
        """Tries to analyze the given 'while' instruction.

        The loop is lowered to an unconditional loop whose body first tests
        the condition (breaking out when it does not hold) and then runs the
        original body.
        """
        cond, body = yield [
            ("RD", [instruction_id, "cond"]),
            ("RD", [instruction_id, "body"])]

        gen = self.analyze_all([cond, body])
        try:
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as e:
            cond_r, body_r = e.result
            raise primitive_functions.PrimitiveFinished(
                tree_ir.LoopInstruction(
                    tree_ir.CompoundInstruction(
                        tree_ir.SelectInstruction(
                            tree_ir.ReadValueInstruction(cond_r),
                            tree_ir.EmptyInstruction(),
                            tree_ir.BreakInstruction()),
                        body_r)))

    def analyze_constant(self, instruction_id):
        """Tries to analyze the given 'constant' (literal) instruction.

        The node id found under the instruction's 'node' edge becomes a
        LiteralInstruction.
        """
        node_id, = yield [("RD", [instruction_id, "node"])]
        raise primitive_functions.PrimitiveFinished(
            tree_ir.LiteralInstruction(node_id))

    def analyze_output(self, instruction_id):
        """Tries to analyze the given 'output' instruction."""
        # The plan is to basically generate this tree:
        #
        # value = <some tree>
        # last_output, last_output_link, new_last_output = \
        #                 yield [("RD", [user_root, "last_output"]),
        #                        ("RDE", [user_root, "last_output"]),
        #                        ("CN", []),
        #                       ]
        # _, _, _, _ = \
        #                 yield [("CD", [last_output, "value", value]),
        #                        ("CD", [last_output, "next", new_last_output]),
        #                        ("CD", [user_root, "last_output", new_last_output]),
        #                        ("DE", [last_output_link])
        #                       ]

        value_id, = yield [("RD", [instruction_id, "value"])]
        gen = self.analyze(value_id)
        try:
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as e:
            value_local = tree_ir.StoreLocalInstruction('value', e.result)

            # user_root is read from the jitted function's kwargs parameter.
            store_user_root = tree_ir.StoreLocalInstruction(
                'user_root',
                tree_ir.LoadIndexInstruction(
                    tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
                    tree_ir.LiteralInstruction('user_root')))

            last_output = tree_ir.StoreLocalInstruction(
                'last_output',
                tree_ir.ReadDictionaryValueInstruction(
                    store_user_root.create_load(),
                    tree_ir.LiteralInstruction('last_output')))

            last_output_link = tree_ir.StoreLocalInstruction(
                'last_output_link',
                tree_ir.ReadDictionaryEdgeInstruction(
                    store_user_root.create_load(),
                    tree_ir.LiteralInstruction('last_output')))

            new_last_output = tree_ir.StoreLocalInstruction(
                'new_last_output',
                tree_ir.CreateNodeInstruction())

            result = tree_ir.create_block(
                value_local,
                store_user_root,
                last_output,
                last_output_link,
                new_last_output,
                tree_ir.CreateDictionaryEdgeInstruction(
                    last_output.create_load(),
                    tree_ir.LiteralInstruction('value'),
                    value_local.create_load()),
                tree_ir.CreateDictionaryEdgeInstruction(
                    last_output.create_load(),
                    tree_ir.LiteralInstruction('next'),
                    new_last_output.create_load()),
                tree_ir.CreateDictionaryEdgeInstruction(
                    store_user_root.create_load(),
                    tree_ir.LiteralInstruction('last_output'),
                    new_last_output.create_load()),
                tree_ir.DeleteEdgeInstruction(last_output_link.create_load()))

            raise primitive_functions.PrimitiveFinished(result)

    # Maps an instruction node's "value" (its instruction type) to the plain
    # function that analyzes it. `analyze` calls these with `self` passed
    # explicitly.
    instruction_analyzers = {
        'if' : analyze_if,
        'while' : analyze_while,
        'return' : analyze_return,
        'constant' : analyze_constant,
        'output' : analyze_output
    }