import modelverse_kernel.primitives as primitive_functions
import modelverse_jit.tree_ir as tree_ir

KWARGS_PARAMETER_NAME = "remainder"
"""The name of the kwargs parameter in jitted functions."""


class JitCompilationFailedException(Exception):
    """A type of exception that is raised when the jit fails to compile a function."""
    pass


class ModelverseJit(object):
    """A high-level interface to the modelverse JIT compiler.

    Function entry points (body node ids) are tracked in three collections:
    `todo_entry_points` (known, not yet compiled), `no_jit_entry_points`
    (must never be jitted, or failed to jit before) and `jitted_entry_points`
    (a dictionary that maps an entry point to its compiled Python function).
    """

    def __init__(self, max_instructions=None):
        """Creates a JIT.

        max_instructions is the maximal number of instructions that the
        analysis of a single function may visit; None selects the default
        of 30.
        """
        self.todo_entry_points = set()
        self.no_jit_entry_points = set()
        self.jitted_entry_points = {}
        # Global namespace that generated code is exec'd in; compiled functions
        # are looked up from here by name. PrimitiveFinished is exposed,
        # presumably so that generated code can refer to it.
        self.jit_globals = {
            'PrimitiveFinished' : primitive_functions.PrimitiveFinished
        }
        # Counter that gives every compiled function a unique name.
        self.jit_count = 0
        self.max_instructions = 30 if max_instructions is None else max_instructions

    def mark_entry_point(self, body_id):
        """Marks the node with the given identifier as a function entry point.

        Entry points that are already known to be non-jittable or that have
        already been compiled are left alone.
        """
        if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
            self.todo_entry_points.add(body_id)

    def is_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point."""
        return body_id in self.todo_entry_points or \
               body_id in self.no_jit_entry_points or \
               body_id in self.jitted_entry_points

    def is_jittable_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point that
           has not been marked as non-jittable."""
        return body_id in self.todo_entry_points or \
               body_id in self.jitted_entry_points

    def mark_no_jit(self, body_id):
        """Informs the JIT that the node with the given identifier is a function entry
           point that must never be jitted."""
        self.no_jit_entry_points.add(body_id)
        self.todo_entry_points.discard(body_id)

    def register_compiled(self, body_id, compiled):
        """Registers a compiled entry point with the JIT."""
        self.jitted_entry_points[body_id] = compiled
        self.todo_entry_points.discard(body_id)

    def try_jit(self, body_id, parameter_list):
        """Tries to jit the function defined by the given entry point id and parameter list.

        This is a generator: it yields the requests issued by the bytecode
        analysis and expects the corresponding responses to be sent back in.

        On success, it raises primitive_functions.PrimitiveFinished carrying the
        compiled function. That function is defined with the names in
        parameter_list plus a '**remainder' keyword-argument parameter.

        Raises JitCompilationFailedException if the entry point is marked as
        non-jittable, or if the analysis fails. In the latter case the entry
        point is marked as non-jittable first.
        """
        if body_id in self.jitted_entry_points:
            # We have already compiled this function.
            raise primitive_functions.PrimitiveFinished(self.jitted_entry_points[body_id])
        elif body_id in self.no_jit_entry_points:
            # We're not allowed to jit this function or have tried and failed before.
            raise JitCompilationFailedException(
                'Cannot jit function at %d because it is marked non-jittable.' % body_id)

        # Drive the analysis generator: forward its requests to our caller and
        # the caller's responses back to the analysis.
        try:
            gen = AnalysisState(self.max_instructions).analyze(body_id)
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as ex:
            constructed_body = ex.result
        except JitCompilationFailedException as ex:
            self.mark_no_jit(body_id)
            # Use str(ex) rather than ex.message: BaseException.message was
            # deprecated in Python 2.6 and no longer exists in Python 3.
            raise JitCompilationFailedException(
                '%s (function at %d)' % (str(ex), body_id))

        # Wrap the IR in a function definition, give it a unique name.
        constructed_function = tree_ir.DefineFunctionInstruction(
            'jit_func%d' % self.jit_count,
            parameter_list + ['**' + KWARGS_PARAMETER_NAME],
            constructed_body.simplify())
        self.jit_count += 1

        # Convert the function definition to Python code, and compile it.
        exec(str(constructed_function), self.jit_globals)

        # Extract the compiled function from the JIT global state.
        compiled_function = self.jit_globals[constructed_function.name]

        # Dumps the generated function definition to stdout.
        print(constructed_function)

        # Save the compiled function so we can reuse it later. register_compiled
        # also removes body_id from the to-do set, which a plain dictionary
        # store would leave behind.
        self.register_compiled(body_id, compiled_function)
        raise primitive_functions.PrimitiveFinished(compiled_function)


class AnalysisState(object):
    """The state of a bytecode analysis call graph.

    Every analysis method is a generator. It yields lists of requests (tuples
    such as ("RD", [node, key])), presumably in the modelverse kernel's
    request format, and is sent back the corresponding list of responses.

    A method produces its result by raising
    primitive_functions.PrimitiveFinished with a tree_ir instruction (or, for
    analyze_all, a list of them) as the exception's result.
    """

    def __init__(self, max_instructions=None):
        """Creates an analysis state.

        max_instructions bounds the number of instructions that may be
        analyzed; None means no bound.
        """
        # Ids of every instruction analyzed so far; used both to reject
        # non-tree instruction graphs and to enforce max_instructions.
        self.analyzed_instructions = set()
        self.max_instructions = max_instructions

    def get_local_name(self, local_id):
        """Gets the name for a local with the given id."""
        return 'local%d' % local_id

    def retrieve_user_root(self):
        """Creates an instruction that stores the user_root variable in a local.

        The value is read from the 'user_root' key of the jitted function's
        kwargs parameter.
        """
        return tree_ir.StoreLocalInstruction(
            'user_root',
            tree_ir.LoadIndexInstruction(
                tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
                tree_ir.LiteralInstruction('user_root')))

    def analyze(self, instruction_id):
        """Tries to build an intermediate representation from the instruction with the given id.

        The instruction is dispatched to the analyzer registered for its type
        in `instruction_analyzers`. If the instruction has a 'next'
        instruction, that one is analyzed as well, and both results are
        combined in a tree_ir.CompoundInstruction.

        Raises JitCompilationFailedException if:
          * the instruction was already analyzed (the graph is not a tree);
          * the instruction limit is reached; or
          * the instruction type is unknown.
        """
        # Check the analyzed_instructions set for instruction_id to avoid
        # infinite loops.
        if instruction_id in self.analyzed_instructions:
            raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
        elif (self.max_instructions is not None and
              len(self.analyzed_instructions) >= self.max_instructions):
            # This check runs before instruction_id is added, so once
            # max_instructions instructions have been analyzed, one more would
            # exceed the limit.
            raise JitCompilationFailedException('Maximal number of instructions exceeded.')

        self.analyzed_instructions.add(instruction_id)
        instruction_val, = yield [("RV", [instruction_id])]
        instruction_val = instruction_val["value"]

        if instruction_val in self.instruction_analyzers:
            gen = self.instruction_analyzers[instruction_val](self, instruction_id)
            try:
                inp = None
                while True:
                    inp = yield gen.send(inp)
            except StopIteration:
                # Analyzers must report their result via PrimitiveFinished;
                # plain generator exhaustion indicates a broken analyzer.
                raise Exception(
                    "Instruction analyzer (for '%s') finished without returning a value!" %
                    (instruction_val))
            except primitive_functions.PrimitiveFinished as outer_e:
                # Check if the instruction has a 'next' instruction.
                next_instr, = yield [("RD", [instruction_id, "next"])]
                if next_instr is None:
                    raise outer_e
                else:
                    gen = self.analyze(next_instr)
                    try:
                        inp = None
                        while True:
                            inp = yield gen.send(inp)
                    except primitive_functions.PrimitiveFinished as inner_e:
                        raise primitive_functions.PrimitiveFinished(
                            tree_ir.CompoundInstruction(
                                outer_e.result,
                                inner_e.result))
        else:
            raise JitCompilationFailedException(
                "Unknown instruction type: '%s'" % (instruction_val))

    def analyze_all(self, instruction_ids):
        """Tries to compile a list of IR trees from the given list of instruction ids.

        Raises PrimitiveFinished with the list of results, in the order of
        instruction_ids.
        """
        results = []
        for inst in instruction_ids:
            gen = self.analyze(inst)
            try:
                inp = None
                while True:
                    inp = yield gen.send(inp)
            except primitive_functions.PrimitiveFinished as ex:
                results.append(ex.result)

        raise primitive_functions.PrimitiveFinished(results)

    def analyze_return(self, instruction_id):
        """Tries to analyze the given 'return' instruction.

        A missing 'value' results in a return of an empty instruction.
        """
        retval_id, = yield [("RD", [instruction_id, 'value'])]
        if retval_id is None:
            raise primitive_functions.PrimitiveFinished(
                tree_ir.ReturnInstruction(
                    tree_ir.EmptyInstruction()))
        else:
            gen = self.analyze(retval_id)
            try:
                inp = None
                while True:
                    inp = yield gen.send(inp)
            except primitive_functions.PrimitiveFinished as ex:
                raise primitive_functions.PrimitiveFinished(
                    tree_ir.ReturnInstruction(ex.result))

    def analyze_if(self, instruction_id):
        """Tries to analyze the given 'if' instruction.

        A missing 'else' branch is compiled to an empty instruction.
        """
        cond, true, false = yield [
            ("RD", [instruction_id, "cond"]),
            ("RD", [instruction_id, "then"]),
            ("RD", [instruction_id, "else"])]

        gen = self.analyze_all(
            [cond, true]
            if false is None
            else [cond, true, false])
        try:
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as ex:
            if false is None:
                cond_r, true_r = ex.result
                false_r = tree_ir.EmptyInstruction()
            else:
                cond_r, true_r, false_r = ex.result

        # The condition evaluates to a node; its value decides the branch.
        raise primitive_functions.PrimitiveFinished(
            tree_ir.SelectInstruction(
                tree_ir.ReadValueInstruction(cond_r),
                true_r,
                false_r))

    def analyze_while(self, instruction_id):
        """Tries to analyze the given 'while' instruction.

        The loop is compiled to an infinite loop that first breaks out if the
        condition's value is false, and then runs the body.
        """
        cond, body = yield [
            ("RD", [instruction_id, "cond"]),
            ("RD", [instruction_id, "body"])]

        gen = self.analyze_all([cond, body])
        try:
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as ex:
            cond_r, body_r = ex.result

        raise primitive_functions.PrimitiveFinished(
            tree_ir.LoopInstruction(
                tree_ir.CompoundInstruction(
                    tree_ir.SelectInstruction(
                        tree_ir.ReadValueInstruction(cond_r),
                        tree_ir.EmptyInstruction(),
                        tree_ir.BreakInstruction()),
                    body_r)))

    def analyze_constant(self, instruction_id):
        """Tries to analyze the given 'constant' (literal) instruction.

        The result is a literal holding the id of the instruction's 'node'.
        """
        node_id, = yield [("RD", [instruction_id, "node"])]
        raise primitive_functions.PrimitiveFinished(
            tree_ir.LiteralInstruction(node_id))

    def analyze_output(self, instruction_id):
        """Tries to analyze the given 'output' instruction."""

        # The plan is to basically generate this tree:
        #
        # value = <some tree>
        # last_output, last_output_link, new_last_output = \
        #                 yield [("RD", [user_root, "last_output"]),
        #                        ("RDE", [user_root, "last_output"]),
        #                        ("CN", []),
        #                       ]
        # _, _, _, _ = \
        #                 yield [("CD", [last_output, "value", value]),
        #                        ("CD", [last_output, "next", new_last_output]),
        #                        ("CD", [user_root, "last_output", new_last_output]),
        #                        ("DE", [last_output_link])
        #                       ]
        # yield None

        value_id, = yield [("RD", [instruction_id, "value"])]
        gen = self.analyze(value_id)
        try:
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as ex:
            value_local = tree_ir.StoreLocalInstruction('value', ex.result)

        store_user_root = self.retrieve_user_root()
        last_output = tree_ir.StoreLocalInstruction(
            'last_output',
            tree_ir.ReadDictionaryValueInstruction(
                store_user_root.create_load(),
                tree_ir.LiteralInstruction('last_output')))

        last_output_link = tree_ir.StoreLocalInstruction(
            'last_output_link',
            tree_ir.ReadDictionaryEdgeInstruction(
                store_user_root.create_load(),
                tree_ir.LiteralInstruction('last_output')))

        new_last_output = tree_ir.StoreLocalInstruction(
            'new_last_output',
            tree_ir.CreateNodeInstruction())

        result = tree_ir.create_block(
            value_local,
            store_user_root,
            last_output,
            last_output_link,
            new_last_output,
            tree_ir.CreateDictionaryEdgeInstruction(
                last_output.create_load(),
                tree_ir.LiteralInstruction('value'),
                value_local.create_load()),
            tree_ir.CreateDictionaryEdgeInstruction(
                last_output.create_load(),
                tree_ir.LiteralInstruction('next'),
                new_last_output.create_load()),
            tree_ir.CreateDictionaryEdgeInstruction(
                store_user_root.create_load(),
                tree_ir.LiteralInstruction('last_output'),
                new_last_output.create_load()),
            tree_ir.DeleteEdgeInstruction(last_output_link.create_load()),
            # Corresponds to the final 'yield None' in the plan above.
            tree_ir.NopInstruction())

        raise primitive_functions.PrimitiveFinished(result)

    def analyze_resolve(self, instruction_id):
        """Tries to analyze the given 'resolve' instruction.

        Resolves to the local named after the 'var' node if that local exists
        at run time. Otherwise it falls back to the global with the same
        name. A runtime Exception is raised if that global is missing.
        """
        var_id, = yield [("RD", [instruction_id, "var"])]
        var_name, = yield [("RV", [var_id])]

        # To resolve a variable, we'll do something along the
        # lines of:
        #
        #     if 'local_var' in locals():
        #         tmp = local_var
        #     else:
        #         _globals, = yield [("RD", [user_root, "globals"])]
        #         global_var, = yield [("RD", [_globals, var_name])]
        #
        #         if global_var is None:
        #             raise Exception("Runtime error: global '%s' not found" % (var_name))
        #
        #         tmp = global_var

        user_root = self.retrieve_user_root()
        global_var = tree_ir.StoreLocalInstruction(
            'global_var',
            tree_ir.ReadDictionaryValueInstruction(
                tree_ir.ReadDictionaryValueInstruction(
                    user_root.create_load(),
                    tree_ir.LiteralInstruction('globals')),
                tree_ir.LiteralInstruction(var_name)))

        err_block = tree_ir.SelectInstruction(
            tree_ir.BinaryInstruction(
                global_var.create_load(),
                'is',
                tree_ir.LiteralInstruction(None)),
            tree_ir.RaiseInstruction(
                tree_ir.CallInstruction(
                    tree_ir.LoadLocalInstruction('Exception'),
                    [tree_ir.LiteralInstruction(
                        "Runtime error: global '%s' not found" % var_name)
                    ])),
            tree_ir.EmptyInstruction())

        name = self.get_local_name(var_id)
        raise primitive_functions.PrimitiveFinished(
            tree_ir.SelectInstruction(
                tree_ir.LocalExistsInstruction(name),
                tree_ir.LoadLocalInstruction(name),
                tree_ir.CompoundInstruction(
                    tree_ir.create_block(
                        user_root,
                        global_var,
                        err_block),
                    global_var.create_load())))

    def analyze_declare(self, instruction_id):
        """Tries to analyze the given 'declare' function.

        Creates a fresh node for the local named after the 'var' node, unless
        that local already exists at run time.
        """
        var_id, = yield [("RD", [instruction_id, "var"])]
        name = self.get_local_name(var_id)

        # The following logic declares a local:
        #
        #     if 'local_name' not in locals():
        #         local_name, = yield [("CN", [])]

        raise primitive_functions.PrimitiveFinished(
            tree_ir.SelectInstruction(
                tree_ir.LocalExistsInstruction(name),
                tree_ir.EmptyInstruction(),
                tree_ir.StoreLocalInstruction(
                    name,
                    tree_ir.CreateNodeInstruction())))

    def analyze_global(self, instruction_id):
        """Tries to analyze the given 'global' (declaration) instruction.

        Evaluates to the global variable node named by 'var'. If the global
        does not exist yet, the node is created and registered in the user's
        globals.
        """
        var_id, = yield [("RD", [instruction_id, "var"])]
        var_name, = yield [("RV", [var_id])]

        # To declare a global variable, we'll do something along the
        # lines of:
        #
        #     _globals, = yield [("RD", [user_root, "globals"])]
        #     global_var, = yield [("RD", [_globals, var_name])]
        #
        #     if global_var is None:
        #         global_var, = yield [("CN", [])]
        #         yield [("CD", [_globals, var_name, global_var])]
        #
        #     tmp = global_var

        user_root = self.retrieve_user_root()
        _globals = tree_ir.StoreLocalInstruction(
            '_globals',
            tree_ir.ReadDictionaryValueInstruction(
                user_root.create_load(),
                tree_ir.LiteralInstruction('globals')))

        global_var = tree_ir.StoreLocalInstruction(
            'global_var',
            tree_ir.ReadDictionaryValueInstruction(
                _globals.create_load(),
                tree_ir.LiteralInstruction(var_name)))

        raise primitive_functions.PrimitiveFinished(
            tree_ir.CompoundInstruction(
                tree_ir.create_block(
                    user_root,
                    _globals,
                    global_var,
                    tree_ir.SelectInstruction(
                        tree_ir.BinaryInstruction(
                            global_var.create_load(),
                            'is',
                            tree_ir.LiteralInstruction(None)),
                        tree_ir.create_block(
                            global_var.create_store(
                                tree_ir.CreateNodeInstruction()),
                            tree_ir.CreateDictionaryEdgeInstruction(
                                _globals.create_load(),
                                tree_ir.LiteralInstruction(var_name),
                                global_var.create_load())),
                        tree_ir.EmptyInstruction())),
                global_var.create_load()))

    def analyze_assign(self, instruction_id):
        """Tries to analyze the given 'assign' instruction.

        Replaces the variable's 'value' edge with one that points to the
        newly computed value.
        """
        var_id, value_id = yield [("RD", [instruction_id, "var"]),
                                  ("RD", [instruction_id, "value"])]

        try:
            gen = self.analyze_all([var_id, value_id])
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as ex:
            var_r, value_r = ex.result

        # Assignments work like this:
        #
        #     value_link = yield [("RDE", [variable, "value"])]
        #     _, _ =       yield [("CD", [variable, "value", value]),
        #                         ("DE", [value_link])]

        variable = tree_ir.StoreLocalInstruction('variable', var_r)
        value = tree_ir.StoreLocalInstruction('value', value_r)
        value_link = tree_ir.StoreLocalInstruction(
            'value_link',
            tree_ir.ReadDictionaryEdgeInstruction(
                variable.create_load(),
                tree_ir.LiteralInstruction('value')))

        raise primitive_functions.PrimitiveFinished(
            tree_ir.create_block(
                variable,
                value,
                value_link,
                tree_ir.CreateDictionaryEdgeInstruction(
                    variable.create_load(),
                    tree_ir.LiteralInstruction('value'),
                    value.create_load()),
                tree_ir.DeleteEdgeInstruction(
                    value_link.create_load())))

    def analyze_access(self, instruction_id):
        """Tries to analyze the given 'access' instruction.

        Evaluates to the node stored under the variable's 'value' key.
        """
        var_id, = yield [("RD", [instruction_id, "var"])]

        try:
            gen = self.analyze(var_id)
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as ex:
            var_r = ex.result

        # Accessing a variable is pretty easy. It really just boils
        # down to reading the value corresponding to the 'value' key
        # of the variable.
        #
        #     value, = yield [("RD", [returnvalue, "value"])]

        raise primitive_functions.PrimitiveFinished(
            tree_ir.ReadDictionaryValueInstruction(
                var_r,
                tree_ir.LiteralInstruction('value')))

    # Maps instruction type names (the value of an instruction node) to the
    # analyzer for that type. The entries are plain functions, so analyze
    # calls them with `self` passed explicitly.
    instruction_analyzers = {
        'if' : analyze_if,
        'while' : analyze_while,
        'return' : analyze_return,
        'constant' : analyze_constant,
        'resolve' : analyze_resolve,
        'declare' : analyze_declare,
        'global' : analyze_global,
        'assign' : analyze_assign,
        'access' : analyze_access,
        'output' : analyze_output
    }