import modelverse_kernel.primitives as primitive_functions import modelverse_jit.tree_ir as tree_ir import modelverse_jit.runtime as jit_runtime import keyword # Import JitCompilationFailedException because it used to be defined # in this module. JitCompilationFailedException = jit_runtime.JitCompilationFailedException KWARGS_PARAMETER_NAME = "kwargs" """The name of the kwargs parameter in jitted functions.""" CALL_FUNCTION_NAME = "__call_function" """The name of the '__call_function' function, in the jitted function scope.""" GET_INPUT_FUNCTION_NAME = "__get_input" """The name of the '__get_input' function, in the jitted function scope.""" LOCALS_NODE_NAME = "jit_locals" """The name of the node that is connected to all JIT locals in a given function call.""" LOCALS_EDGE_NAME = "jit_locals_edge" """The name of the edge that connects the LOCALS_NODE_NAME node to a user root.""" def get_parameter_names(compiled_function): """Gets the given compiled function's parameter names.""" if hasattr(compiled_function, '__code__'): return compiled_function.__code__.co_varnames[ :compiled_function.__code__.co_argcount] elif hasattr(compiled_function, '__init__'): return get_parameter_names(compiled_function.__init__)[1:] else: raise ValueError("'compiled_function' must be a function or a type.") def apply_intrinsic(intrinsic_function, named_args): """Applies the given intrinsic to the given sequence of named arguments.""" param_names = get_parameter_names(intrinsic_function) if tuple(param_names) == tuple([n for n, _ in named_args]): # Perfect match. Yay! return intrinsic_function(**dict(named_args)) else: # We'll have to store the arguments into locals to preserve # the order of evaluation. 
stored_args = [(name, tree_ir.StoreLocalInstruction(None, arg)) for name, arg in named_args] arg_value_dict = dict([(name, arg.create_load()) for name, arg in stored_args]) store_instructions = [instruction for _, instruction in stored_args] return tree_ir.CompoundInstruction( tree_ir.create_block(*store_instructions), intrinsic_function(**arg_value_dict)) def map_and_simplify_generator(function, instruction): """Applies the given mapping function to every instruction in the tree that has the given instruction as root, and simplifies it on-the-fly. This is at least as powerful as first mapping and then simplifying, as maps and simplifications are interspersed. This function assumes that function creates a generator that returns by raising a primitive_functions.PrimitiveFinished.""" # First handle the children by mapping on them and then simplifying them. new_children = [] for inst in instruction.get_children(): new_inst, = yield [("CALL_ARGS", [map_and_simplify_generator, (function, inst)])] new_children.append(new_inst) # Then apply the function to the top-level node. transformed, = yield [("CALL_ARGS", [function, (instruction.create(new_children),)])] # Finally, simplify the transformed top-level node. 
raise primitive_functions.PrimitiveFinished(transformed.simplify_node()) def expand_constant_read(instruction): """Tries to replace a read of a constant node by a literal.""" if isinstance(instruction, tree_ir.ReadValueInstruction) and \ isinstance(instruction.node_id, tree_ir.LiteralInstruction): val, = yield [("RV", [instruction.node_id.literal])] raise primitive_functions.PrimitiveFinished(tree_ir.LiteralInstruction(val)) else: raise primitive_functions.PrimitiveFinished(instruction) def optimize_tree_ir(instruction): """Optimizes an IR tree.""" return map_and_simplify_generator(expand_constant_read, instruction) def print_value(val): """A thin wrapper around 'print'.""" print(val) class ModelverseJit(object): """A high-level interface to the modelverse JIT compiler.""" def __init__(self, max_instructions=None, compiled_function_lookup=None): self.todo_entry_points = set() self.no_jit_entry_points = set() self.jitted_entry_points = {} self.jitted_parameters = {} self.jit_globals = { 'PrimitiveFinished' : primitive_functions.PrimitiveFinished, CALL_FUNCTION_NAME : jit_runtime.call_function, GET_INPUT_FUNCTION_NAME : jit_runtime.get_input } self.jit_count = 0 self.max_instructions = max_instructions self.compiled_function_lookup = compiled_function_lookup # jit_intrinsics is a function name -> intrinsic map. 
self.jit_intrinsics = {} self.compilation_dependencies = {} self.jit_enabled = True self.direct_calls_allowed = True self.tracing_enabled = False self.input_function_enabled = False self.nop_insertion_enabled = True self.jit_success_log_function = None self.jit_code_log_function = None def set_jit_enabled(self, is_enabled=True): """Enables or disables the JIT.""" self.jit_enabled = is_enabled def allow_direct_calls(self, is_allowed=True): """Allows or disallows direct calls from jitted to jitted code.""" self.direct_calls_allowed = is_allowed def use_input_function(self, is_enabled=True): """Configures the JIT to compile 'input' instructions as function calls.""" self.input_function_enabled = is_enabled def enable_tracing(self, is_enabled=True): """Enables or disables tracing for jitted code.""" self.tracing_enabled = is_enabled def enable_nop_insertion(self, is_enabled=True): """Enables or disables nop insertion for jitted code. The JIT will insert nops at loop back-edges. Inserting nops sacrifices performance to keep the jitted code from blocking the thread of execution by consuming all resources; nops give the Modelverse server an opportunity to interrupt the currently running code.""" self.nop_insertion_enabled = is_enabled def set_jit_success_log(self, log_function=print_value): """Configures this JIT instance with a function that prints output to a log. Success and failure messages for specific functions are then sent to said log.""" self.jit_success_log_function = log_function def set_jit_code_log(self, log_function=print_value): """Configures this JIT instance with a function that prints output to a log. 
Function definitions of jitted functions are then sent to said log.""" self.jit_code_log_function = log_function def mark_entry_point(self, body_id): """Marks the node with the given identifier as a function entry point.""" if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points: self.todo_entry_points.add(body_id) def is_entry_point(self, body_id): """Tells if the node with the given identifier is a function entry point.""" return body_id in self.todo_entry_points or \ body_id in self.no_jit_entry_points or \ body_id in self.jitted_entry_points def is_jittable_entry_point(self, body_id): """Tells if the node with the given identifier is a function entry point that has not been marked as non-jittable. This only returns `True` if the JIT is enabled and the function entry point has been marked jittable, or if the function has already been compiled.""" return ((self.jit_enabled and body_id in self.todo_entry_points) or self.has_compiled(body_id)) def has_compiled(self, body_id): """Tests if the function belonging to the given body node has been compiled yet.""" return body_id in self.jitted_entry_points def get_compiled_name(self, body_id): """Gets the name of the compiled version of the given body node in the JIT global state.""" return self.jitted_entry_points[body_id] def mark_no_jit(self, body_id): """Informs the JIT that the node with the given identifier is a function entry point that must never be jitted.""" self.no_jit_entry_points.add(body_id) if body_id in self.todo_entry_points: self.todo_entry_points.remove(body_id) def generate_name(self, infix, suggested_name=None): """Generates a new name or picks the suggested name if it is still available.""" if suggested_name is not None \ and suggested_name not in self.jit_globals \ and not keyword.iskeyword(suggested_name): self.jit_count += 1 return suggested_name else: function_name = 'jit_%s%d' % (infix, self.jit_count) self.jit_count += 1 return function_name def 
generate_function_name(self, suggested_name=None): """Generates a new function name or picks the suggested name if it is still available.""" return self.generate_name('func', suggested_name) def register_compiled(self, body_id, compiled_function, function_name=None): """Registers a compiled entry point with the JIT.""" # Get the function's name. function_name = self.generate_function_name(function_name) # Map the body id to the given parameter list. self.jitted_entry_points[body_id] = function_name self.jit_globals[function_name] = compiled_function if body_id in self.todo_entry_points: self.todo_entry_points.remove(body_id) def import_value(self, value, suggested_name=None): """Imports the given value into the JIT's global scope, with the given suggested name. The actual name of the value (within the JIT's global scope) is returned.""" actual_name = self.generate_name('import', suggested_name) self.jit_globals[actual_name] = value return actual_name def lookup_compiled_function(self, name): """Looks up a compiled function by name. Returns a matching function, or None if no function was found.""" if name is None: return None elif name in self.jit_globals: return self.jit_globals[name] elif self.compiled_function_lookup is not None: return self.compiled_function_lookup(name) else: return None def get_intrinsic(self, name): """Tries to find an intrinsic version of the function with the given name.""" if name in self.jit_intrinsics: return self.jit_intrinsics[name] else: return None def register_intrinsic(self, name, intrinsic_function): """Registers the given intrisic with the JIT. 
This will make the JIT replace calls to the function with the given entry point by an application of the specified function.""" self.jit_intrinsics[name] = intrinsic_function def register_binary_intrinsic(self, name, operator): """Registers an intrinsic with the JIT that represents the given binary operation.""" self.register_intrinsic(name, lambda a, b: tree_ir.CreateNodeWithValueInstruction( tree_ir.BinaryInstruction( tree_ir.ReadValueInstruction(a), operator, tree_ir.ReadValueInstruction(b)))) def register_unary_intrinsic(self, name, operator): """Registers an intrinsic with the JIT that represents the given unary operation.""" self.register_intrinsic(name, lambda a: tree_ir.CreateNodeWithValueInstruction( tree_ir.UnaryInstruction( operator, tree_ir.ReadValueInstruction(a)))) def register_cast_intrinsic(self, name, target_type): """Registers an intrinsic with the JIT that represents a unary conversion operator.""" self.register_intrinsic(name, lambda a: tree_ir.CreateNodeWithValueInstruction( tree_ir.CallInstruction( tree_ir.LoadGlobalInstruction(target_type.__name__), [tree_ir.ReadValueInstruction(a)]))) def jit_signature(self, body_id): """Acquires the signature for the given body id node, which consists of the parameter variables, parameter name and a flag that tells if the given function is mutable.""" if body_id not in self.jitted_parameters: signature_id, = yield [("RRD", [body_id, jit_runtime.FUNCTION_BODY_KEY])] signature_id = signature_id[0] param_set_id, is_mutable = yield [ ("RD", [signature_id, "params"]), ("RD", [signature_id, jit_runtime.MUTABLE_FUNCTION_KEY])] if param_set_id is None: self.jitted_parameters[body_id] = ([], [], is_mutable) else: param_name_ids, = yield [("RDK", [param_set_id])] param_names = yield [("RV", [n]) for n in param_name_ids] param_vars = yield [("RD", [param_set_id, k]) for k in param_names] self.jitted_parameters[body_id] = (param_vars, param_names, is_mutable) raise 
primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id]) def jit_compile(self, user_root, body_id, suggested_name=None): """Tries to jit the function defined by the given entry point id and parameter list.""" # The comment below makes pylint shut up about our (hopefully benign) use of exec here. # pylint: disable=I0011,W0122 if body_id is None: raise ValueError('body_id cannot be None') elif body_id in self.jitted_entry_points: # We have already compiled this function. raise primitive_functions.PrimitiveFinished( self.jit_globals[self.jitted_entry_points[body_id]]) elif body_id in self.no_jit_entry_points: # We're not allowed to jit this function or have tried and failed before. raise JitCompilationFailedException( 'Cannot jit function %s at %d because it is marked non-jittable.' % ( '' if suggested_name is None else "'" + suggested_name + "'", body_id)) elif not self.jit_enabled: # We're not allowed to jit anything. raise JitCompilationFailedException( 'Cannot jit function %s at %d because the JIT has been disabled.' % ( '' if suggested_name is None else "'" + suggested_name + "'", body_id)) # Generate a name for the function we're about to analyze, and pretend that # it already exists. (we need to do this for recursive functions) function_name = self.generate_function_name(suggested_name) self.jitted_entry_points[body_id] = function_name self.jit_globals[function_name] = None (parameter_ids, parameter_list, is_mutable), = yield [ ("CALL_ARGS", [self.jit_signature, (body_id,)])] param_dict = dict(zip(parameter_ids, parameter_list)) body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list])) dependencies = set([body_id]) self.compilation_dependencies[body_id] = dependencies def handle_jit_exception(exception): # If analysis fails, then a JitCompilationFailedException will be thrown. 
del self.compilation_dependencies[body_id] for dep in dependencies: self.mark_no_jit(dep) if dep in self.jitted_entry_points: del self.jitted_entry_points[dep] failure_message = "%s (function '%s' at %d)" % ( exception.message, function_name, body_id) if self.jit_success_log_function is not None: self.jit_success_log_function('JIT compilation failed: %s' % failure_message) raise JitCompilationFailedException(failure_message) # Try to analyze the function's body. yield [("TRY", [])] yield [("CATCH", [JitCompilationFailedException, handle_jit_exception])] if is_mutable: # We can't just JIT mutable functions. That'd be dangerous. raise JitCompilationFailedException( "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY) state = AnalysisState( self, body_id, user_root, body_param_dict, self.max_instructions) constructed_body, = yield [("CALL_ARGS", [state.analyze, (body_id,)])] yield [("END_TRY", [])] del self.compilation_dependencies[body_id] # Write a prologue and prepend it to the generated function body. prologue_statements = [] # Create a LOCALS_NODE_NAME node, and connect it to the user root. prologue_statements.append( tree_ir.create_new_local_node( LOCALS_NODE_NAME, tree_ir.LoadIndexInstruction( tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME), tree_ir.LiteralInstruction('user_root')), LOCALS_EDGE_NAME)) for (key, val) in param_dict.items(): arg_ptr = tree_ir.create_new_local_node( body_param_dict[key], tree_ir.LoadLocalInstruction(LOCALS_NODE_NAME)) prologue_statements.append(arg_ptr) prologue_statements.append( tree_ir.CreateDictionaryEdgeInstruction( tree_ir.LoadLocalInstruction(body_param_dict[key]), tree_ir.LiteralInstruction('value'), tree_ir.LoadLocalInstruction(val))) constructed_body = tree_ir.create_block( *(prologue_statements + [constructed_body])) # Optimize the function's body. constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])] # Shield temporaries from the GC. 
constructed_body = tree_ir.protect_temporaries_from_gc( constructed_body, tree_ir.LoadLocalInstruction(LOCALS_NODE_NAME)) # Wrap the IR in a function definition, give it a unique name. constructed_function = tree_ir.DefineFunctionInstruction( function_name, parameter_list + ['**' + KWARGS_PARAMETER_NAME], constructed_body) # Convert the function definition to Python code, and compile it. exec(str(constructed_function), self.jit_globals) # Extract the compiled function from the JIT global state. compiled_function = self.jit_globals[function_name] if self.jit_success_log_function is not None: self.jit_success_log_function( "JIT compilation successful: (function '%s' at %d)" % (function_name, body_id)) if self.jit_code_log_function is not None: self.jit_code_log_function(constructed_function) raise primitive_functions.PrimitiveFinished(compiled_function) class AnalysisState(object): """The state of a bytecode analysis call graph.""" def __init__(self, jit, body_id, user_root, local_mapping, max_instructions=None): self.analyzed_instructions = set() self.function_vars = set() self.local_vars = set() self.body_id = body_id self.max_instructions = max_instructions self.user_root = user_root self.jit = jit self.local_mapping = local_mapping self.function_name = jit.jitted_entry_points[body_id] self.enclosing_loop_instruction = None def get_local_name(self, local_id): """Gets the name for a local with the given id.""" if local_id not in self.local_mapping: self.local_mapping[local_id] = 'local%d' % local_id return self.local_mapping[local_id] def register_local_var(self, local_id): """Registers the given variable node id as a local.""" if local_id in self.function_vars: raise JitCompilationFailedException( "Local is used as target of function call.") self.local_vars.add(local_id) def register_function_var(self, local_id): """Registers the given variable node id as a function.""" if local_id in self.local_vars: raise JitCompilationFailedException( "Local is used as target 
of function call.") self.function_vars.add(local_id) def retrieve_user_root(self): """Creates an instruction that stores the user_root variable in a local.""" return tree_ir.StoreLocalInstruction( 'user_root', tree_ir.LoadIndexInstruction( tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME), tree_ir.LiteralInstruction('user_root'))) def load_kernel(self): """Creates an instruction that loads the Modelverse kernel.""" return tree_ir.LoadIndexInstruction( tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME), tree_ir.LiteralInstruction('mvk')) def analyze(self, instruction_id): """Tries to build an intermediate representation from the instruction with the given id.""" # Check the analyzed_instructions set for instruction_id to avoid # infinite loops. if instruction_id in self.analyzed_instructions: raise JitCompilationFailedException('Cannot jit non-tree instruction graph.') elif (self.max_instructions is not None and len(self.analyzed_instructions) > self.max_instructions): raise JitCompilationFailedException('Maximum number of instructions exceeded.') self.analyzed_instructions.add(instruction_id) instruction_val, = yield [("RV", [instruction_id])] instruction_val = instruction_val["value"] if instruction_val in self.instruction_analyzers: # If tracing is enabled, then this would be an appropriate time to # retrieve the debug information. if self.jit.tracing_enabled: debug_info, = yield [("RD", [instruction_id, "__debug"])] if debug_info is not None: debug_info, = yield [("RV", [debug_info])] # Analyze the instruction itself. outer_result, = yield [ ("CALL_ARGS", [self.instruction_analyzers[instruction_val], (self, instruction_id)])] if self.jit.tracing_enabled: outer_result = tree_ir.with_debug_info_trace(outer_result, debug_info, self.function_name) # Check if the instruction has a 'next' instruction. 
next_instr, = yield [("RD", [instruction_id, "next"])] if next_instr is None: raise primitive_functions.PrimitiveFinished(outer_result) else: next_result, = yield [("CALL_ARGS", [self.analyze, (next_instr,)])] raise primitive_functions.PrimitiveFinished( tree_ir.CompoundInstruction( outer_result, next_result)) else: raise JitCompilationFailedException( "Unknown instruction type: '%s'" % (instruction_val)) def analyze_all(self, instruction_ids): """Tries to compile a list of IR trees from the given list of instruction ids.""" results = [] for inst in instruction_ids: analyzed_inst, = yield [("CALL_ARGS", [self.analyze, (inst,)])] results.append(analyzed_inst) raise primitive_functions.PrimitiveFinished(results) def analyze_return(self, instruction_id): """Tries to analyze the given 'return' instruction.""" retval_id, = yield [("RD", [instruction_id, 'value'])] def create_return(return_value): return tree_ir.ReturnInstruction( tree_ir.CompoundInstruction( return_value, tree_ir.DeleteEdgeInstruction( tree_ir.LoadLocalInstruction(LOCALS_EDGE_NAME)))) if retval_id is None: raise primitive_functions.PrimitiveFinished( create_return( tree_ir.EmptyInstruction())) else: retval, = yield [("CALL_ARGS", [self.analyze, (retval_id,)])] raise primitive_functions.PrimitiveFinished( create_return(retval)) def analyze_if(self, instruction_id): """Tries to analyze the given 'if' instruction.""" cond, true, false = yield [ ("RD", [instruction_id, "cond"]), ("RD", [instruction_id, "then"]), ("RD", [instruction_id, "else"])] analysis_results, = yield [("CALL_ARGS", [self.analyze_all, ( [cond, true] if false is None else [cond, true, false],)])] if false is None: cond_r, true_r = analysis_results false_r = tree_ir.EmptyInstruction() else: cond_r, true_r, false_r = analysis_results raise primitive_functions.PrimitiveFinished( tree_ir.SelectInstruction( tree_ir.ReadValueInstruction(cond_r), true_r, false_r)) def analyze_while(self, instruction_id): """Tries to analyze the given 'while' 
instruction.""" cond, body = yield [ ("RD", [instruction_id, "cond"]), ("RD", [instruction_id, "body"])] # Analyze the condition. cond_r, = yield [("CALL_ARGS", [self.analyze, (cond,)])] # Store the old enclosing loop on the stack, and make this loop the # new enclosing loop. old_loop_instruction = self.enclosing_loop_instruction self.enclosing_loop_instruction = instruction_id body_r, = yield [("CALL_ARGS", [self.analyze, (body,)])] # Restore hte old enclosing loop. self.enclosing_loop_instruction = old_loop_instruction if self.jit.nop_insertion_enabled: create_loop_body = lambda check, body: tree_ir.create_block( check, body_r, tree_ir.NopInstruction()) else: create_loop_body = tree_ir.CompoundInstruction raise primitive_functions.PrimitiveFinished( tree_ir.LoopInstruction( create_loop_body( tree_ir.SelectInstruction( tree_ir.ReadValueInstruction(cond_r), tree_ir.EmptyInstruction(), tree_ir.BreakInstruction()), body_r))) def analyze_constant(self, instruction_id): """Tries to analyze the given 'constant' (literal) instruction.""" node_id, = yield [("RD", [instruction_id, "node"])] raise primitive_functions.PrimitiveFinished( tree_ir.LiteralInstruction(node_id)) def analyze_output(self, instruction_id): """Tries to analyze the given 'output' instruction.""" # The plan is to basically generate this tree: # # value = # last_output, last_output_link, new_last_output = \ # yield [("RD", [user_root, "last_output"]), # ("RDE", [user_root, "last_output"]), # ("CN", []), # ] # _, _, _, _ = \ # yield [("CD", [last_output, "value", value]), # ("CD", [last_output, "next", new_last_output]), # ("CD", [user_root, "last_output", new_last_output]), # ("DE", [last_output_link]) # ] # yield None value_id, = yield [("RD", [instruction_id, "value"])] value_val, = yield [("CALL_ARGS", [self.analyze, (value_id,)])] value_local = tree_ir.StoreLocalInstruction('value', value_val) store_user_root = self.retrieve_user_root() last_output = tree_ir.StoreLocalInstruction( 'last_output', 
tree_ir.ReadDictionaryValueInstruction( store_user_root.create_load(), tree_ir.LiteralInstruction('last_output'))) last_output_link = tree_ir.StoreLocalInstruction( 'last_output_link', tree_ir.ReadDictionaryEdgeInstruction( store_user_root.create_load(), tree_ir.LiteralInstruction('last_output'))) new_last_output = tree_ir.StoreLocalInstruction( 'new_last_output', tree_ir.CreateNodeInstruction()) result = tree_ir.create_block( value_local, store_user_root, last_output, last_output_link, new_last_output, tree_ir.CreateDictionaryEdgeInstruction( last_output.create_load(), tree_ir.LiteralInstruction('value'), value_local.create_load()), tree_ir.CreateDictionaryEdgeInstruction( last_output.create_load(), tree_ir.LiteralInstruction('next'), new_last_output.create_load()), tree_ir.CreateDictionaryEdgeInstruction( store_user_root.create_load(), tree_ir.LiteralInstruction('last_output'), new_last_output.create_load()), tree_ir.DeleteEdgeInstruction(last_output_link.create_load()), tree_ir.NopInstruction()) raise primitive_functions.PrimitiveFinished(result) def analyze_input(self, _): """Tries to analyze the given 'input' instruction.""" # Possible alternative to the explicit syntax tree: if self.jit.input_function_enabled: raise primitive_functions.PrimitiveFinished( tree_ir.create_jit_call( tree_ir.LoadGlobalInstruction(GET_INPUT_FUNCTION_NAME), [], tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME))) # The plan is to generate this tree: # # value = None # while True: # _input = yield [("RD", [user_root, "input"])] # value = yield [("RD", [_input, "value"])] # # if value is None: # kwargs['mvk'].success = False # to avoid blocking # yield None # nop/interrupt # else: # break # # _next = yield [("RD", [_input, "next"])] # yield [("CD", [user_root, "input", _next])] # yield [("CE", [jit_locals, value])] # yield [("DN", [_input])] user_root = self.retrieve_user_root() _input = tree_ir.StoreLocalInstruction( None, tree_ir.ReadDictionaryValueInstruction( 
user_root.create_load(), tree_ir.LiteralInstruction('input'))) value = tree_ir.StoreLocalInstruction( None, tree_ir.ReadDictionaryValueInstruction( _input.create_load(), tree_ir.LiteralInstruction('value'))) raise primitive_functions.PrimitiveFinished( tree_ir.CompoundInstruction( tree_ir.create_block( user_root, value.create_store(tree_ir.LiteralInstruction(None)), tree_ir.LoopInstruction( tree_ir.create_block( _input, value, tree_ir.SelectInstruction( tree_ir.BinaryInstruction( value.create_load(), 'is', tree_ir.LiteralInstruction(None)), tree_ir.create_block( tree_ir.StoreMemberInstruction( self.load_kernel(), 'success', tree_ir.LiteralInstruction(False)), tree_ir.NopInstruction()), tree_ir.BreakInstruction()))), tree_ir.CreateDictionaryEdgeInstruction( user_root.create_load(), tree_ir.LiteralInstruction('input'), tree_ir.ReadDictionaryValueInstruction( _input.create_load(), tree_ir.LiteralInstruction('next'))), tree_ir.CreateEdgeInstruction( tree_ir.LoadLocalInstruction(LOCALS_NODE_NAME), value.create_load()), tree_ir.DeleteNodeInstruction(_input.create_load())), value.create_load())) def analyze_resolve(self, instruction_id): """Tries to analyze the given 'resolve' instruction.""" var_id, = yield [("RD", [instruction_id, "var"])] var_name, = yield [("RV", [var_id])] # To resolve a variable, we'll do something along the # lines of: # # if 'local_var' in locals(): # tmp = local_var # else: # _globals, = yield [("RD", [user_root, "globals"])] # global_var, = yield [("RD", [_globals, var_name])] # # if global_var is None: # raise Exception("Not found as global: %s" % (var_name)) # # tmp = global_var name = self.get_local_name(var_id) if var_name is None: raise primitive_functions.PrimitiveFinished( tree_ir.LoadLocalInstruction(name)) user_root = self.retrieve_user_root() global_var = tree_ir.StoreLocalInstruction( 'global_var', tree_ir.ReadDictionaryValueInstruction( tree_ir.ReadDictionaryValueInstruction( user_root.create_load(), 
tree_ir.LiteralInstruction('globals')), tree_ir.LiteralInstruction(var_name))) err_block = tree_ir.SelectInstruction( tree_ir.BinaryInstruction( global_var.create_load(), 'is', tree_ir.LiteralInstruction(None)), tree_ir.RaiseInstruction( tree_ir.CallInstruction( tree_ir.LoadGlobalInstruction('Exception'), [tree_ir.LiteralInstruction( "Not found as global: %s" % var_name) ])), tree_ir.EmptyInstruction()) raise primitive_functions.PrimitiveFinished( tree_ir.SelectInstruction( tree_ir.LocalExistsInstruction(name), tree_ir.LoadLocalInstruction(name), tree_ir.CompoundInstruction( tree_ir.create_block( user_root, global_var, err_block), global_var.create_load()))) def analyze_declare(self, instruction_id): """Tries to analyze the given 'declare' function.""" var_id, = yield [("RD", [instruction_id, "var"])] self.register_local_var(var_id) name = self.get_local_name(var_id) # The following logic declares a local: # # if 'local_name' not in locals(): # local_name, = yield [("CN", [])] # yield [("CE", [LOCALS_NODE_NAME, local_name])] raise primitive_functions.PrimitiveFinished( tree_ir.SelectInstruction( tree_ir.LocalExistsInstruction(name), tree_ir.EmptyInstruction(), tree_ir.create_new_local_node( name, tree_ir.LoadLocalInstruction(LOCALS_NODE_NAME)))) def analyze_global(self, instruction_id): """Tries to analyze the given 'global' (declaration) instruction.""" var_id, = yield [("RD", [instruction_id, "var"])] var_name, = yield [("RV", [var_id])] # To resolve a variable, we'll do something along the # lines of: # # _globals, = yield [("RD", [user_root, "globals"])] # global_var = yield [("RD", [_globals, var_name])] # # if global_var is None: # global_var, = yield [("CN", [])] # yield [("CD", [_globals, var_name, global_var])] # # tmp = global_var user_root = self.retrieve_user_root() _globals = tree_ir.StoreLocalInstruction( '_globals', tree_ir.ReadDictionaryValueInstruction( user_root.create_load(), tree_ir.LiteralInstruction('globals'))) global_var = 
tree_ir.StoreLocalInstruction(
            'global_var',
            tree_ir.ReadDictionaryValueInstruction(
                _globals.create_load(),
                tree_ir.LiteralInstruction(var_name)))

        raise primitive_functions.PrimitiveFinished(
            tree_ir.CompoundInstruction(
                tree_ir.create_block(
                    user_root,
                    _globals,
                    global_var,
                    tree_ir.SelectInstruction(
                        tree_ir.BinaryInstruction(
                            global_var.create_load(),
                            'is',
                            tree_ir.LiteralInstruction(None)),
                        tree_ir.create_block(
                            global_var.create_store(
                                tree_ir.CreateNodeInstruction()),
                            tree_ir.CreateDictionaryEdgeInstruction(
                                _globals.create_load(),
                                tree_ir.LiteralInstruction(var_name),
                                global_var.create_load())),
                        tree_ir.EmptyInstruction())),
                global_var.create_load()))

    def analyze_assign(self, instruction_id):
        """Tries to analyze the given 'assign' instruction.

        Reads the instruction's 'var' and 'value' children, analyzes both
        through self.analyze_all, and finishes (by raising
        primitive_functions.PrimitiveFinished) with a tree_ir block that
        replaces the variable's 'value' edge by a new edge to the assigned
        value."""
        var_id, value_id = yield [("RD", [instruction_id, "var"]),
                                  ("RD", [instruction_id, "value"])]

        (var_r, value_r), = yield [("CALL_ARGS", [self.analyze_all, ([var_id, value_id],)])]

        # Assignments work like this:
        #
        #     value_link = yield [("RDE", [variable, "value"])]
        #     _, _ = yield [("CD", [variable, "value", value]),
        #                   ("DE", [value_link])]

        # Both operands are stored into (anonymous) locals so that each is
        # evaluated exactly once, variable first, then value.
        variable = tree_ir.StoreLocalInstruction(None, var_r)
        value = tree_ir.StoreLocalInstruction(None, value_r)
        # The existing 'value' edge is read into a local *before* the new edge
        # is created, so the DeleteEdgeInstruction at the end removes the edge
        # that existed prior to this assignment.
        value_link = tree_ir.StoreLocalInstruction(
            'value_link',
            tree_ir.ReadDictionaryEdgeInstruction(
                variable.create_load(),
                tree_ir.LiteralInstruction('value')))

        raise primitive_functions.PrimitiveFinished(
            tree_ir.create_block(
                variable,
                value,
                value_link,
                tree_ir.CreateDictionaryEdgeInstruction(
                    variable.create_load(),
                    tree_ir.LiteralInstruction('value'),
                    value.create_load()),
                tree_ir.DeleteEdgeInstruction(
                    value_link.create_load())))

    def analyze_access(self, instruction_id):
        """Tries to analyze the given 'access' instruction.

        Analyzes the instruction's 'var' child and finishes with a tree_ir
        instruction that reads the 'value' dictionary entry of that
        variable."""
        var_id, = yield [("RD", [instruction_id, "var"])]
        var_r, = yield [("CALL_ARGS", [self.analyze, (var_id,)])]

        # Accessing a variable is pretty easy. It really just boils
        # down to reading the value corresponding to the 'value' key
        # of the variable.
        #
        #     value, = yield [("RD", [returnvalue, "value"])]

        raise primitive_functions.PrimitiveFinished(
            tree_ir.ReadDictionaryValueInstruction(
                var_r,
                tree_ir.LiteralInstruction('value')))

    def analyze_direct_call(self, callee_id, callee_name, first_parameter_id):
        """Tries to analyze a direct 'call' instruction.

        callee_id is the function node being called. callee_name is the
        global name it was resolved through, or None when the callee came
        from a 'constant' instruction (see try_analyze_direct_call).
        first_parameter_id is the first node of the argument list.

        If an intrinsic is registered for the callee (by callee_name, or
        else by the callee's compiled name), the intrinsic is applied to the
        analyzed arguments. Otherwise the callee is compiled when
        lookup_compiled_function does not already know it, and a jitted call
        to the compiled function is emitted, forwarding the caller's
        kwargs."""
        self.register_function_var(callee_id)

        body_id, = yield [("RD", [callee_id, jit_runtime.FUNCTION_BODY_KEY])]

        # Make this function dependent on the callee. This is only recorded
        # when the callee's body already has an entry in
        # compilation_dependencies.
        if body_id in self.jit.compilation_dependencies:
            self.jit.compilation_dependencies[body_id].add(self.body_id)

        # Figure out if the function might be an intrinsic.
        intrinsic = self.jit.get_intrinsic(callee_name)

        if intrinsic is None:
            compiled_func = self.jit.lookup_compiled_function(callee_name)
            if compiled_func is None:
                # Compile the callee.
                yield [("CALL_ARGS", [self.jit.jit_compile, (self.user_root, body_id, callee_name)])]
            else:
                self.jit.register_compiled(body_id, compiled_func, callee_name)

            # Get the callee's name.
            compiled_func_name = self.jit.get_compiled_name(body_id)

            # This handles the corner case where a constant node is called, like
            # 'call(constant(9), ...)'. In this case, `callee_name` is `None`
            # because 'constant(9)' doesn't give us a name. However, we can look up
            # the name of the function at a specific node. If that turns out to be
            # an intrinsic, then we still want to pick the intrinsic over a call.
            intrinsic = self.jit.get_intrinsic(compiled_func_name)

        # Analyze the argument dictionary.
        named_args, = yield [("CALL_ARGS", [self.analyze_arguments, (first_parameter_id,)])]

        if intrinsic is not None:
            raise primitive_functions.PrimitiveFinished(
                apply_intrinsic(intrinsic, named_args))
        else:
            # compiled_func_name is always bound here: intrinsic can only be
            # None at this point if the 'if intrinsic is None' branch above
            # ran, and that branch assigns compiled_func_name.
            raise primitive_functions.PrimitiveFinished(
                tree_ir.create_jit_call(
                    tree_ir.LoadGlobalInstruction(compiled_func_name),
                    named_args,
                    tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))

    def analyze_arguments(self, first_argument_id):
        """Analyzes the parameter-to-argument mapping started by the specified first argument node.

        The arguments form a chain of nodes linked by their 'next_param'
        entry; the walk stops when that entry reads as None. Each node
        contributes its 'name' (read as a value) and its analyzed 'value'.
        Finishes with a list of (parameter name, analyzed argument) pairs in
        chain order, so callers can preserve evaluation order."""
        next_param = first_argument_id
        named_args = []
        while next_param is not None:
            param_name_id, = yield [("RD", [next_param, "name"])]
            param_name, = yield [("RV", [param_name_id])]
            param_val_id, = yield [("RD", [next_param, "value"])]
            param_val, = yield [("CALL_ARGS", [self.analyze, (param_val_id,)])]
            named_args.append((param_name, param_val))

            next_param, = yield [("RD", [next_param, "next_param"])]

        raise primitive_functions.PrimitiveFinished(named_args)

    def analyze_indirect_call(self, func_id, first_arg_id):
        """Analyzes a call to an unknown function.

        Emits a jitted call to the CALL_FUNCTION_NAME helper, passing the
        analyzed callee expression as 'function_id' and the analyzed
        arguments as a dictionary literal under 'named_arguments', and
        forwarding the caller's kwargs."""
        # First off, let's analyze the callee and the argument list.
        func_val, = yield [("CALL_ARGS", [self.analyze, (func_id,)])]
        named_args, = yield [("CALL_ARGS", [self.analyze_arguments, (first_arg_id,)])]

        # Call the __call_function function to run the interpreter, like so:
        #
        # __call_function(function_id, { first_param_name : first_param_val, ... }, **kwargs)
        #
        dict_literal = tree_ir.DictionaryLiteralInstruction(
            [(tree_ir.LiteralInstruction(key), val) for key, val in named_args])
        raise primitive_functions.PrimitiveFinished(
            tree_ir.create_jit_call(
                tree_ir.LoadGlobalInstruction(CALL_FUNCTION_NAME),
                [('function_id', func_val), ('named_arguments', dict_literal)],
                tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))

    def try_analyze_direct_call(self, func_id, first_param_id):
        """Tries to analyze the given 'call' instruction as a direct call.

        Two callee shapes are handled:
          * access(resolve(name)), where name is bound in the user root's
            'globals' and that global has a 'value';
          * constant(node), which calls the node directly (no name).

        Raises JitCompilationFailedException when direct calls are disabled
        by the JIT or when the callee does not match either shape."""
        if not self.jit.direct_calls_allowed:
            raise JitCompilationFailedException('Direct calls are not allowed by the JIT.')

        # Figure out what the 'func' instruction's type is.
        func_instruction_op, = yield [("RV", [func_id])]
        if func_instruction_op['value'] == 'access':
            # 'access(resolve(var))' instructions are translated to direct calls.
            access_value_id, = yield [("RD", [func_id, "var"])]
            access_value_op, = yield [("RV", [access_value_id])]
            if access_value_op['value'] == 'resolve':
                resolved_var_id, = yield [("RD", [access_value_id, "var"])]
                resolved_var_name, = yield [("RV", [resolved_var_id])]

                # Try to look up the name as a global.
                # NOTE(review): if resolved_var_name is not a key of globals,
                # global_var is presumably None and the next read is issued on
                # None; this relies on the request handler answering that read
                # with None (so we fall through to the failure below) -- confirm.
                _globals, = yield [("RD", [self.user_root, "globals"])]
                global_var, = yield [("RD", [_globals, resolved_var_name])]
                global_val, = yield [("RD", [global_var, "value"])]

                if global_val is not None:
                    result, = yield [("CALL_ARGS", [self.analyze_direct_call, (
                        global_val, resolved_var_name, first_param_id)])]
                    raise primitive_functions.PrimitiveFinished(result)
        elif func_instruction_op['value'] == 'constant':
            # 'const(func_id)' instructions are also translated to direct calls.
            function_val_id, = yield [("RD", [func_id, "node"])]
            result, = yield [("CALL_ARGS", [self.analyze_direct_call, (
                function_val_id, None, first_param_id)])]
            raise primitive_functions.PrimitiveFinished(result)

        # Every shape that was not handled above ends up here.
        raise JitCompilationFailedException(
            "Cannot JIT function calls that target an unknown value as direct calls.")

    def analyze_call(self, instruction_id):
        """Tries to analyze the given 'call' instruction.

        The call is first analyzed as a direct call (try_analyze_direct_call).
        If that raises JitCompilationFailedException, the call is analyzed as
        an indirect call (analyze_indirect_call) instead."""
        func_id, first_param_id, = yield [("RD", [instruction_id, "func"]),
                                          ("RD", [instruction_id, "params"])]

        def handle_exception(exception):
            # Looks like we'll have to compile it as an indirect call.
            gen = self.analyze_indirect_call(func_id, first_param_id)
            result, = yield [("CALL", [gen])]
            raise primitive_functions.PrimitiveFinished(result)

        # Try to analyze the call as a direct call.
        # NOTE(review): TRY/CATCH/END_TRY are requests to whatever drives this
        # generator (not shown here). As registered, a
        # JitCompilationFailedException raised inside the try region is
        # presumably routed to handle_exception, whose finished result then
        # stands in for this analysis' result -- confirm against the driver.
        yield [("TRY", [])]
        yield [("CATCH", [JitCompilationFailedException, handle_exception])]
        result, = yield [("CALL_ARGS", [self.try_analyze_direct_call, (func_id, first_param_id)])]
        yield [("END_TRY", [])]
        raise primitive_functions.PrimitiveFinished(result)

    def analyze_break(self, instruction_id):
        """Tries to analyze the given 'break' instruction.

        Only a break out of the innermost enclosing loop
        (self.enclosing_loop_instruction) is supported; any other target
        raises JitCompilationFailedException."""
        target_instruction_id, = yield [("RD", [instruction_id, "while"])]
        if target_instruction_id == self.enclosing_loop_instruction:
            raise primitive_functions.PrimitiveFinished(tree_ir.BreakInstruction())
        else:
            raise JitCompilationFailedException(
                "Multilevel 'break' is not supported by the baseline JIT.")

    def analyze_continue(self, instruction_id):
        """Tries to analyze the given 'continue' instruction.

        Only a continue of the innermost enclosing loop
        (self.enclosing_loop_instruction) is supported; any other target
        raises JitCompilationFailedException."""
        target_instruction_id, = yield [("RD", [instruction_id, "while"])]
        if target_instruction_id == self.enclosing_loop_instruction:
            raise primitive_functions.PrimitiveFinished(tree_ir.ContinueInstruction())
        else:
            raise JitCompilationFailedException(
                "Multilevel 'continue' is not supported by the baseline JIT.")

    # Maps each instruction opcode name to the analyzer function for it. The
    # values are the plain functions defined in this class body (not bound
    # methods).
    # NOTE(review): presumably looked up by the instruction node's opcode value
    # and invoked with the analysis state as `self`; the dispatcher is not
    # visible here -- confirm.
    instruction_analyzers = {
        'if' : analyze_if,
        'while' : analyze_while,
        'return' : analyze_return,
        'constant' : analyze_constant,
        'resolve' : analyze_resolve,
        'declare' : analyze_declare,
        'global' : analyze_global,
        'assign' : analyze_assign,
        'access' : analyze_access,
        'output' : analyze_output,
        'input' : analyze_input,
        'call' : analyze_call,
        'break' : analyze_break,
        'continue' : analyze_continue
    }