import keyword
import modelverse_kernel.primitives as primitive_functions
import modelverse_jit.bytecode_parser as bytecode_parser
import modelverse_jit.bytecode_to_tree as bytecode_to_tree
import modelverse_jit.bytecode_to_cfg as bytecode_to_cfg
import modelverse_jit.cfg_optimization as cfg_optimization
import modelverse_jit.cfg_to_tree as cfg_to_tree
import modelverse_jit.cfg_ir as cfg_ir
import modelverse_jit.tree_ir as tree_ir
import modelverse_jit.runtime as jit_runtime

# Import JitCompilationFailedException because it used to be defined
# in this module.
JitCompilationFailedException = jit_runtime.JitCompilationFailedException


def map_and_simplify_generator(function, instruction):
    """Applies the given mapping function to every instruction in the tree
       that has the given instruction as root, and simplifies it on-the-fly.

       This is at least as powerful as first mapping and then simplifying, as
       maps and simplifications are interspersed.

       This function assumes that function creates a generator that returns by
       raising a primitive_functions.PrimitiveFinished. The result of this
       generator is likewise delivered by raising PrimitiveFinished, with the
       simplified, transformed root as its argument."""

    # First handle the children by mapping on them and then simplifying them.
    new_children = []
    for inst in instruction.get_children():
        new_inst, = yield [("CALL_ARGS", [map_and_simplify_generator, (function, inst)])]
        new_children.append(new_inst)

    # Then apply the function to the top-level node.
    transformed, = yield [("CALL_ARGS", [function, (instruction.create(new_children),)])]
    # Finally, simplify the transformed top-level node.
    raise primitive_functions.PrimitiveFinished(transformed.simplify_node())


def expand_constant_read(instruction):
    """Tries to replace a read of a constant node by a literal.

       If the instruction is a ReadValueInstruction whose node id is a
       LiteralInstruction, the value is read right away (via an "RV" request)
       and wrapped in a new LiteralInstruction. Any other instruction is
       handed back unchanged. The result is delivered by raising
       PrimitiveFinished."""
    if isinstance(instruction, tree_ir.ReadValueInstruction) and \
            isinstance(instruction.node_id, tree_ir.LiteralInstruction):
        val, = yield [("RV", [instruction.node_id.literal])]
        raise primitive_functions.PrimitiveFinished(tree_ir.LiteralInstruction(val))
    else:
        raise primitive_functions.PrimitiveFinished(instruction)


def optimize_tree_ir(instruction):
    """Optimizes an IR tree.

       Returns a generator that maps expand_constant_read over the tree and
       simplifies every node along the way."""
    return map_and_simplify_generator(expand_constant_read, instruction)


def create_bare_function(function_name, parameter_list, function_body):
    """Creates a function definition from the given function name, parameter
       list and function body. No prolog is included.

       A '**kwargs'-style parameter named jit_runtime.KWARGS_PARAMETER_NAME is
       appended to the parameter list."""
    # Wrap the IR in a function definition, give it a unique name.
    return tree_ir.DefineFunctionInstruction(
        function_name,
        parameter_list + ['**' + jit_runtime.KWARGS_PARAMETER_NAME],
        function_body)


def create_function(
        function_name, parameter_list, param_dict,
        body_param_dict, function_body):
    """Creates a function from the given function name, parameter list,
       variable-to-parameter name map, variable-to-local name map and
       function body.

       The generated function starts with a prologue that creates the locals
       node and, for every entry of param_dict, a local node whose 'value'
       edge points at the corresponding parameter."""
    # Write a prologue and prepend it to the generated function body.
    prologue_statements = []
    # Create a LOCALS_NODE_NAME node, and connect it to the user root.
    prologue_statements.append(
        tree_ir.create_new_local_node(
            jit_runtime.LOCALS_NODE_NAME,
            tree_ir.LoadIndexInstruction(
                tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
                tree_ir.LiteralInstruction('task_root')),
            jit_runtime.LOCALS_EDGE_NAME))
    for (key, val) in param_dict.items():
        # Create a node for the argument and hang it off the locals node.
        arg_ptr = tree_ir.create_new_local_node(
            body_param_dict[key],
            tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME))
        prologue_statements.append(arg_ptr)
        # Make the argument node's 'value' edge point at the parameter.
        prologue_statements.append(
            tree_ir.CreateDictionaryEdgeInstruction(
                tree_ir.LoadLocalInstruction(body_param_dict[key]),
                tree_ir.LiteralInstruction('value'),
                tree_ir.LoadLocalInstruction(val)))

    constructed_body = tree_ir.create_block(
        *(prologue_statements + [function_body]))

    # Shield temporaries from the GC.
    constructed_body = tree_ir.protect_temporaries_from_gc(
        constructed_body, tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME))

    return create_bare_function(function_name, parameter_list, constructed_body)


def print_value(val):
    """A thin wrapper around 'print'."""
    print(val)


class ModelverseJit(object):
    """A high-level interface to the modelverse JIT compiler."""
    def __init__(self, max_instructions=None, compiled_function_lookup=None):
        """Creates a JIT instance.

           max_instructions is handed to the baseline compiler's analysis
           state; compiled_function_lookup is an optional callable that maps a
           global name to an externally compiled function (or None)."""
        self.todo_entry_points = set()
        self.no_jit_entry_points = set()
        self.jitted_parameters = {}
        self.jit_globals = {
            'PrimitiveFinished' : primitive_functions.PrimitiveFinished,
            jit_runtime.CALL_FUNCTION_NAME : jit_runtime.call_function,
            jit_runtime.GET_INPUT_FUNCTION_NAME : jit_runtime.get_input,
            jit_runtime.JIT_THUNK_CONSTANT_FUNCTION_NAME : self.jit_thunk_constant,
            jit_runtime.JIT_THUNK_GLOBAL_FUNCTION_NAME : self.jit_thunk_global
        }
        # jitted_entry_points maps body ids to values in jit_globals.
        self.jitted_entry_points = {}
        # global_functions maps global value names to body ids.
        self.global_functions = {}
        # global_functions_inv maps body ids to global value names.
        self.global_functions_inv = {}
        # bytecode_graphs maps body ids to their parsed bytecode graphs.
        self.bytecode_graphs = {}
        self.jit_count = 0
        self.max_instructions = max_instructions
        self.compiled_function_lookup = compiled_function_lookup
        # jit_intrinsics is a function name -> intrinsic map.
        self.jit_intrinsics = {}
        # cfg_jit_intrinsics is a function name -> intrinsic map.
        self.cfg_jit_intrinsics = {}
        self.compilation_dependencies = {}
        self.jit_enabled = True
        self.direct_calls_allowed = True
        self.tracing_enabled = False
        self.input_function_enabled = False
        self.nop_insertion_enabled = True
        self.thunks_enabled = True
        self.jit_success_log_function = None
        self.jit_code_log_function = None
        self.compile_function_body = compile_function_body_baseline

    def set_jit_enabled(self, is_enabled=True):
        """Enables or disables the JIT."""
        self.jit_enabled = is_enabled

    def allow_direct_calls(self, is_allowed=True):
        """Allows or disallows direct calls from jitted to jitted code."""
        self.direct_calls_allowed = is_allowed

    def use_input_function(self, is_enabled=True):
        """Configures the JIT to compile 'input' instructions as function calls."""
        self.input_function_enabled = is_enabled

    def enable_tracing(self, is_enabled=True):
        """Enables or disables tracing for jitted code."""
        self.tracing_enabled = is_enabled

    def enable_nop_insertion(self, is_enabled=True):
        """Enables or disables nop insertion for jitted code. If enabled, the JIT will
           insert nops at loop back-edges. Inserting nops sacrifices performance to
           keep the jitted code from blocking the thread of execution and consuming
           all resources; nops give the Modelverse server an opportunity to interrupt
           the currently running code."""
        self.nop_insertion_enabled = is_enabled

    def enable_thunks(self, is_enabled=True):
        """Enables or disables thunks for jitted code. Thunks delay the compilation of
           functions until they are actually used. Thunks generally reduce start-up
           time.

           Thunks are enabled by default."""
        self.thunks_enabled = is_enabled

    def set_jit_success_log(self, log_function=print_value):
        """Configures this JIT instance with a function that prints output to a log.
           Success and failure messages for specific functions are then sent to said log."""
        self.jit_success_log_function = log_function

    def set_jit_code_log(self, log_function=print_value):
        """Configures this JIT instance with a function that prints output to a log.
           Function definitions of jitted functions are then sent to said log."""
        self.jit_code_log_function = log_function

    def set_function_body_compiler(self, compile_function_body):
        """Sets the function that the JIT uses to compile function bodies.

           The function is invoked as
           compile_function_body(jit, function_name, body_id, task_root)."""
        self.compile_function_body = compile_function_body

    def mark_entry_point(self, body_id):
        """Marks the node with the given identifier as a function entry point.

           Entry points that are already compiled or marked non-jittable are
           left alone."""
        if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
            self.todo_entry_points.add(body_id)

    def is_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point."""
        return body_id in self.todo_entry_points or \
               body_id in self.no_jit_entry_points or \
               body_id in self.jitted_entry_points

    def is_jittable_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point that
           has not been marked as non-jittable. This only returns `True` if the JIT
           is enabled and the function entry point has been marked jittable, or if
           the function has already been compiled."""
        return ((self.jit_enabled and body_id in self.todo_entry_points) or
                self.has_compiled(body_id))

    def has_compiled(self, body_id):
        """Tests if the function belonging to the given body node has been compiled yet."""
        return body_id in self.jitted_entry_points

    def get_compiled_name(self, body_id):
        """Gets the name of the compiled version of the given body node in the JIT
           global state. Returns None if the body node has not been compiled."""
        return self.jitted_entry_points.get(body_id)

    def mark_no_jit(self, body_id):
        """Informs the JIT that the node with the given identifier is a function entry
           point that must never be jitted."""
        self.no_jit_entry_points.add(body_id)
        # discard is a no-op if the body id was never queued.
        self.todo_entry_points.discard(body_id)

    def generate_name(self, infix, suggested_name=None):
        """Generates a new name or picks the suggested name if it is still
           available.

           The suggested name is rejected if it is already a JIT global or a
           Python keyword. The name counter is advanced either way."""
        if suggested_name is not None \
                and suggested_name not in self.jit_globals \
                and not keyword.iskeyword(suggested_name):
            self.jit_count += 1
            return suggested_name
        else:
            function_name = 'jit_%s%d' % (infix, self.jit_count)
            self.jit_count += 1
            return function_name

    def generate_function_name(self, body_id, suggested_name=None):
        """Generates a new function name or picks the suggested name if it is still
           available. If no name is suggested, the global name registered for
           the body id (if any) is used as the suggestion."""
        if suggested_name is None:
            suggested_name = self.get_global_name(body_id)

        return self.generate_name('func', suggested_name)

    def register_global(self, body_id, global_name):
        """Associates the given body id with the given global name."""
        self.global_functions[global_name] = body_id
        self.global_functions_inv[body_id] = global_name

    def get_global_name(self, body_id):
        """Gets the name of the global function with the given body id.
           Returns None if no known global exists with the given id."""
        return self.global_functions_inv.get(body_id)

    def get_global_body_id(self, global_name):
        """Gets the body id of the global function with the given name.
           Returns None if no known global exists with the given name."""
        return self.global_functions.get(global_name)

    def register_compiled(self, body_id, compiled_function, function_name=None):
        """Registers a compiled entry point with the JIT.

           If function_name is given, the body id is also registered as the
           global of that name. The body id is removed from the set of entry
           points that still await compilation."""
        # Get the function's name.
        actual_function_name = self.generate_function_name(body_id, function_name)
        # Map the body id to the function's name, and the name to the function.
        self.jitted_entry_points[body_id] = actual_function_name
        self.jit_globals[actual_function_name] = compiled_function
        if function_name is not None:
            self.register_global(body_id, function_name)

        self.todo_entry_points.discard(body_id)

    def import_value(self, value, suggested_name=None):
        """Imports the given value into the JIT's global scope, with the given suggested name.
           The actual name of the value (within the JIT's global scope) is returned."""
        actual_name = self.generate_name('import', suggested_name)
        self.jit_globals[actual_name] = value
        return actual_name

    def __lookup_compiled_body_impl(self, body_id):
        """Looks up a compiled function by body id. Returns a matching function,
           or None if no function was found."""
        if body_id is not None and body_id in self.jitted_entry_points:
            return self.jit_globals[self.jitted_entry_points[body_id]]
        else:
            return None

    def __lookup_external_body_impl(self, global_name, body_id):
        """Looks up an external function by global name. Returns a matching function,
           or None if no function was found.

           A function that is found is registered as the compiled version of
           body_id, provided that body_id is not None."""
        if self.compiled_function_lookup is not None:
            result = self.compiled_function_lookup(global_name)
            if result is not None and body_id is not None:
                self.register_compiled(body_id, result, global_name)

            return result
        else:
            return None

    def lookup_compiled_body(self, body_id):
        """Looks up a compiled function by body id. Returns a matching function,
           or None if no function was found."""
        result = self.__lookup_compiled_body_impl(body_id)
        if result is not None:
            return result
        else:
            # Fall back to the external lookup, keyed by global name.
            global_name = self.get_global_name(body_id)
            return self.__lookup_external_body_impl(global_name, body_id)

    def lookup_compiled_function(self, global_name):
        """Looks up a compiled function by global name. Returns a matching function,
           or None if no function was found."""
        body_id = self.get_global_body_id(global_name)
        result = self.__lookup_compiled_body_impl(body_id)
        if result is not None:
            return result
        else:
            return self.__lookup_external_body_impl(global_name, body_id)

    def get_intrinsic(self, name):
        """Tries to find an intrinsic version of the function with the given name.
           Returns None if there is no such intrinsic."""
        return self.jit_intrinsics.get(name)

    def get_cfg_intrinsic(self, name):
        """Tries to find an intrinsic version of the function with the given name that
           is specialized for CFGs. Returns None if there is no such intrinsic."""
        return self.cfg_jit_intrinsics.get(name)

    def register_intrinsic(self, name, intrinsic_function, cfg_intrinsic_function=None):
        """Registers the given intrinsic with the JIT. This will make the JIT replace calls to
           the function with the given entry point by an application of the specified function.

           cfg_intrinsic_function, if given, is registered as the CFG-specific
           version of the intrinsic."""
        self.jit_intrinsics[name] = intrinsic_function
        if cfg_intrinsic_function is not None:
            self.cfg_jit_intrinsics[name] = cfg_intrinsic_function

    def register_binary_intrinsic(self, name, operator):
        """Registers an intrinsic with the JIT that represents the given binary operation.

           Both a tree-IR and a CFG-IR version of the intrinsic are registered;
           each reads the values of both operands, applies the operator and
           creates a node for the result."""
        self.register_intrinsic(
            name,
            lambda a, b:
            tree_ir.CreateNodeWithValueInstruction(
                tree_ir.BinaryInstruction(
                    tree_ir.ReadValueInstruction(a),
                    operator,
                    tree_ir.ReadValueInstruction(b))),
            lambda original_def, a, b:
            original_def.redefine(
                cfg_ir.CreateNode(
                    original_def.insert_before(
                        cfg_ir.Binary(
                            original_def.insert_before(cfg_ir.Read(a)),
                            operator,
                            original_def.insert_before(cfg_ir.Read(b)))))))

    def register_unary_intrinsic(self, name, operator):
        """Registers an intrinsic with the JIT that represents the given unary operation."""
        self.register_intrinsic(
            name,
            lambda a:
            tree_ir.CreateNodeWithValueInstruction(
                tree_ir.UnaryInstruction(
                    operator,
                    tree_ir.ReadValueInstruction(a))))

    def register_cast_intrinsic(self, name, target_type):
        """Registers an intrinsic with the JIT that represents a unary conversion operator.

           The conversion is emitted as a call to the global named
           target_type.__name__."""
        self.register_intrinsic(
            name,
            lambda a:
            tree_ir.CreateNodeWithValueInstruction(
                tree_ir.CallInstruction(
                    tree_ir.LoadGlobalInstruction(target_type.__name__),
                    [tree_ir.ReadValueInstruction(a)])))

    def jit_signature(self, body_id):
        """Acquires the signature for the given body id node, which consists of the
           parameter variables, parameter name and a flag that tells if the given
           function is mutable.

           Signatures are cached per body id. The result is delivered by
           raising PrimitiveFinished."""
        if body_id not in self.jitted_parameters:
            signature_id, = yield [("RRD", [body_id, jit_runtime.FUNCTION_BODY_KEY])]
            signature_id = signature_id[0]
            param_set_id, is_mutable = yield [
                ("RD", [signature_id, "params"]),
                ("RD", [signature_id, jit_runtime.MUTABLE_FUNCTION_KEY])]
            if param_set_id is None:
                # No parameter set: the function takes no parameters.
                self.jitted_parameters[body_id] = ([], [], is_mutable)
            else:
                param_name_ids, = yield [("RDK", [param_set_id])]
                param_names = yield [("RV", [n]) for n in param_name_ids]
                param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
                self.jitted_parameters[body_id] = (param_vars, param_names, is_mutable)

        raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])

    def jit_parse_bytecode(self, body_id):
        """Parses the given function body as a bytecode graph.

           Parsed graphs are cached per body id. The result is delivered by
           raising PrimitiveFinished."""
        if body_id in self.bytecode_graphs:
            raise primitive_functions.PrimitiveFinished(self.bytecode_graphs[body_id])

        parser = bytecode_parser.BytecodeParser()
        result, = yield [("CALL_ARGS", [parser.parse_instruction, (body_id,)])]
        self.bytecode_graphs[body_id] = result
        raise primitive_functions.PrimitiveFinished(result)

    def check_jittable(self, body_id, suggested_name=None):
        """Checks if the function with the given body id is obviously non-jittable. If it's
           non-jittable, then a `JitCompilationFailedException` exception is thrown.

           A ValueError is raised if body_id is None. If the function has
           already been compiled, PrimitiveFinished is raised with the compiled
           function as its argument."""
        if body_id is None:
            raise ValueError('body_id cannot be None')
        elif body_id in self.jitted_entry_points:
            # We have already compiled this function.
            raise primitive_functions.PrimitiveFinished(
                self.jit_globals[self.jitted_entry_points[body_id]])
        elif body_id in self.no_jit_entry_points:
            # We're not allowed to jit this function or have tried and failed before.
            raise JitCompilationFailedException(
                'Cannot jit function %s at %d because it is marked non-jittable.' % (
                    '' if suggested_name is None else "'" + suggested_name + "'",
                    body_id))
        elif not self.jit_enabled:
            # We're not allowed to jit anything.
            raise JitCompilationFailedException(
                'Cannot jit function %s at %d because the JIT has been disabled.' % (
                    '' if suggested_name is None else "'" + suggested_name + "'",
                    body_id))

    def jit_recompile(self, task_root, body_id, function_name):
        """Replaces the function with the given name by compiling the bytecode at the given
           body id.

           The compiled function is delivered by raising PrimitiveFinished. If
           compilation fails, the function and all of its recorded compilation
           dependencies are marked non-jittable and a
           JitCompilationFailedException is raised."""
        self.check_jittable(body_id, function_name)

        # Generate a name for the function we're about to analyze, and pretend that
        # it already exists. (we need to do this for recursive functions)
        self.jitted_entry_points[body_id] = function_name
        self.jit_globals[function_name] = None

        (_, _, is_mutable), = yield [
            ("CALL_ARGS", [self.jit_signature, (body_id,)])]

        dependencies = {body_id}
        self.compilation_dependencies[body_id] = dependencies

        def handle_jit_exception(exception):
            # If analysis fails, then a JitCompilationFailedException will be thrown.
            del self.compilation_dependencies[body_id]
            for dep in dependencies:
                self.mark_no_jit(dep)
                if dep in self.jitted_entry_points:
                    del self.jitted_entry_points[dep]

            # Exceptions have no 'message' attribute on Python 3 unless the
            # exception class defines one; fall back to str() in that case.
            failure_message = "%s (function '%s' at %d)" % (
                getattr(exception, 'message', str(exception)), function_name, body_id)
            if self.jit_success_log_function is not None:
                self.jit_success_log_function('JIT compilation failed: %s' % failure_message)
            raise JitCompilationFailedException(failure_message)

        # Try to analyze the function's body.
        yield [("TRY", [])]
        yield [("CATCH", [JitCompilationFailedException, handle_jit_exception])]
        if is_mutable:
            # We can't just JIT mutable functions. That'd be dangerous.
            raise JitCompilationFailedException(
                "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)

        constructed_function, = yield [
            ("CALL_ARGS", [self.compile_function_body, (self, function_name, body_id, task_root)])]

        yield [("END_TRY", [])]
        del self.compilation_dependencies[body_id]

        # Convert the function definition to Python code, and compile it.
        compiled_function = self.jit_define_function(function_name, constructed_function)

        if self.jit_success_log_function is not None:
            self.jit_success_log_function(
                "JIT compilation successful: (function '%s' at %d)" % (function_name, body_id))

        raise primitive_functions.PrimitiveFinished(compiled_function)

    def jit_define_function(self, function_name, function_def):
        """Converts the given tree-IR function definition to Python code, defines it,
           and extracts the resulting function."""
        # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
        # pylint: disable=I0011,W0122
        if self.jit_code_log_function is not None:
            self.jit_code_log_function(function_def)

        # Convert the function definition to Python code, and compile it.
        exec(str(function_def), self.jit_globals)
        # Extract the compiled function from the JIT global state.
        return self.jit_globals[function_name]

    def jit_delete_function(self, function_name):
        """Deletes the function with the given function name."""
        del self.jit_globals[function_name]

    def jit_compile(self, task_root, body_id, suggested_name=None):
        """Tries to jit the function defined by the given entry point id and parameter list."""
        # Generate a name for the function we're about to analyze, and pretend that
        # it already exists. (we need to do this for recursive functions)
        function_name = self.generate_function_name(body_id, suggested_name)
        yield [("TAIL_CALL_ARGS", [self.jit_recompile, (task_root, body_id, function_name)])]

    def jit_thunk(self, get_function_body, global_name=None):
        """Creates a thunk from the given IR tree that computes the function's body id.
           This thunk is a function that will invoke the function whose body id is retrieved.
           The thunk's name in the JIT's global context is returned."""
        # The general idea is to first create a function that looks a bit like this:
        #
        #     def jit_get_function_body(**kwargs):
        #         raise primitive_functions.PrimitiveFinished()
        #
        get_function_body_name = self.generate_name('get_function_body')
        get_function_body_func_def = create_function(
            get_function_body_name, [], {}, {}, tree_ir.ReturnInstruction(get_function_body))
        get_function_body_func = self.jit_define_function(
            get_function_body_name, get_function_body_func_def)

        # Next, we want to create a thunk that invokes said function, and then replaces itself.
        thunk_name = self.generate_name('thunk', global_name)

        def __jit_thunk(**kwargs):
            # Compute the body id, and delete the function that computes the body id; we won't
            # be needing it anymore after this call.
            body_id, = yield [("CALL_KWARGS", [get_function_body_func, kwargs])]
            self.jit_delete_function(get_function_body_name)

            # Try to associate the global name with the body id, if that's at all possible.
            if global_name is not None:
                self.register_global(body_id, global_name)

            compiled_function = self.lookup_compiled_body(body_id)
            if compiled_function is not None:
                # Replace this thunk by the compiled function.
                self.jit_globals[thunk_name] = compiled_function
            else:
                def __handle_jit_exception(_):
                    # Replace this thunk by a different thunk: one that calls the interpreter
                    # directly, without checking if the function is jittable.
                    # NOTE(review): this handler installs the interpreter thunk but does
                    # not invoke it itself -- confirm against the kernel's CATCH semantics.
                    (_, parameter_names, _), = yield [
                        ("CALL_ARGS", [self.jit_signature, (body_id,)])]

                    def __interpreter_thunk(**new_kwargs):
                        named_arg_dict = {name : new_kwargs[name] for name in parameter_names}
                        return jit_runtime.interpret_function_body(
                            body_id, named_arg_dict, **new_kwargs)

                    self.jit_globals[thunk_name] = __interpreter_thunk

                yield [("TRY", [])]
                yield [("CATCH", [JitCompilationFailedException, __handle_jit_exception])]
                compiled_function, = yield [
                    ("CALL_ARGS",
                     [self.jit_recompile, (kwargs['task_root'], body_id, thunk_name)])]
                yield [("END_TRY", [])]

            # Call the compiled function.
            yield [("TAIL_CALL_KWARGS", [compiled_function, kwargs])]

        self.jit_globals[thunk_name] = __jit_thunk
        return thunk_name

    def jit_thunk_constant(self, body_id):
        """Creates a thunk from given body id.
           This thunk is a function that will invoke the function whose body id is given.
           The thunk's name in the JIT's global context is returned."""
        # Give the external lookup a chance to register a compiled version.
        self.lookup_compiled_body(body_id)

        compiled_name = self.get_compiled_name(body_id)
        if compiled_name is not None:
            # We might have compiled the function with the given body id already. In that case,
            # we need not bother with constructing the thunk; we can return the compiled function
            # right away.
            return compiled_name
        else:
            # Looks like we'll just have to build that thunk after all.
            return self.jit_thunk(tree_ir.LiteralInstruction(body_id))

    def jit_thunk_global(self, global_name):
        """Creates a thunk from given global name.
           This thunk is a function that will invoke the function whose body id is given.
           The thunk's name in the JIT's global context is returned."""
        # We might have compiled the function with the given name already. In that case,
        # we need not bother with constructing the thunk; we can return the compiled function
        # right away.
        body_id = self.get_global_body_id(global_name)
        if body_id is not None:
            self.lookup_compiled_body(body_id)

            compiled_name = self.get_compiled_name(body_id)
            if compiled_name is not None:
                return compiled_name

        # Looks like we'll just have to build that thunk after all.
        # We want to look up the global function like so
        #
        #     _globals, = yield [("RD", [kwargs['task_root'], "globals"])]
        #     global_var, = yield [("RD", [_globals, global_name])]
        #     function_id, = yield [("RD", [global_var, "value"])]
        #     body_id, = yield [("RD", [function_id, jit_runtime.FUNCTION_BODY_KEY])]
        #
        return self.jit_thunk(
            tree_ir.ReadDictionaryValueInstruction(
                tree_ir.ReadDictionaryValueInstruction(
                    tree_ir.ReadDictionaryValueInstruction(
                        tree_ir.ReadDictionaryValueInstruction(
                            tree_ir.LoadIndexInstruction(
                                tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
                                tree_ir.LiteralInstruction('task_root')),
                            tree_ir.LiteralInstruction('globals')),
                        tree_ir.LiteralInstruction(global_name)),
                    tree_ir.LiteralInstruction('value')),
                tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)),
            global_name)


def compile_function_body_baseline(jit, function_name, body_id, task_root):
    """Have the baseline JIT compile the function with the given name and body id.

       The resulting function definition is delivered by raising
       PrimitiveFinished."""
    (parameter_ids, parameter_list, _), = yield [
        ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
    param_dict = dict(zip(parameter_ids, parameter_list))
    body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
    body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
    state = bytecode_to_tree.AnalysisState(
        jit, body_id, task_root, body_param_dict,
        jit.max_instructions)
    constructed_body, = yield [("CALL_ARGS", [state.analyze, (body_bytecode,)])]

    # Optimize the function's body.
    constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]

    # Wrap the tree IR in a function definition.
    raise primitive_functions.PrimitiveFinished(
        create_function(
            function_name, parameter_list, param_dict,
            body_param_dict, constructed_body))


def compile_function_body_fast(jit, function_name, body_id, _):
    """Have the fast JIT compile the function with the given name and body id.

       The bytecode is turned into a control-flow graph, which is optimized
       and then lowered to a tree-IR function definition. The definition is
       delivered by raising PrimitiveFinished. The last parameter (the task
       root) is unused."""
    (parameter_ids, parameter_list, _), = yield [
        ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
    param_dict = dict(zip(parameter_ids, parameter_list))
    body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
    bytecode_analyzer = bytecode_to_cfg.AnalysisState(param_dict)
    bytecode_analyzer.analyze(body_bytecode)
    yield [
        ("CALL_ARGS", [cfg_optimization.optimize, (bytecode_analyzer.entry_point, jit)])]
    if jit.jit_code_log_function is not None:
        jit.jit_code_log_function(
            "CFG for function '%s' at '%d':\n%s" % (
                function_name, body_id,
                '\n'.join(
                    map(
                        str,
                        cfg_ir.get_all_reachable_blocks(
                            bytecode_analyzer.entry_point)))))
    raise primitive_functions.PrimitiveFinished(
        create_bare_function(
            function_name, parameter_list,
            cfg_to_tree.lower_flow_graph(bytecode_analyzer.entry_point, jit)))