import math
import keyword
import time
from collections import defaultdict
import modelverse_kernel.primitives as primitive_functions
import modelverse_jit.bytecode_parser as bytecode_parser
import modelverse_jit.bytecode_to_tree as bytecode_to_tree
import modelverse_jit.bytecode_to_cfg as bytecode_to_cfg
import modelverse_jit.bytecode_ir as bytecode_ir
import modelverse_jit.bytecode_interpreter as bytecode_interpreter
import modelverse_jit.cfg_optimization as cfg_optimization
import modelverse_jit.cfg_to_tree as cfg_to_tree
import modelverse_jit.cfg_ir as cfg_ir
import modelverse_jit.tree_ir as tree_ir
import modelverse_jit.runtime as jit_runtime

# Import JitCompilationFailedException because it used to be defined
# in this module.
JitCompilationFailedException = jit_runtime.JitCompilationFailedException

def map_and_simplify_generator(function, instruction):
    """Applies the given mapping function to every instruction in the tree
       that has the given instruction as root, and simplifies it on-the-fly.

       This is at least as powerful as first mapping and then simplifying, as
       maps and simplifications are interspersed.

       This function assumes that function creates a generator that returns by
       raising a primitive_functions.PrimitiveFinished."""

    # First handle the children by mapping on them and then simplifying them.
    new_children = []
    for inst in instruction.get_children():
        new_inst, = yield [("CALL_ARGS", [map_and_simplify_generator, (function, inst)])]
        new_children.append(new_inst)

    # Then apply the function to the top-level node.
    transformed, = yield [("CALL_ARGS", [function, (instruction.create(new_children),)])]
    # Finally, simplify the transformed top-level node.
    raise primitive_functions.PrimitiveFinished(transformed.simplify_node())

def expand_constant_read(instruction):
    """Tries to replace a read of a constant node by a literal.

       If `instruction` is a ReadValueInstruction whose node id is a literal, the
       node's value is read (via an "RV" request) and returned as a
       LiteralInstruction. Any other instruction is returned unchanged. The
       result is delivered by raising PrimitiveFinished."""
    if isinstance(instruction, tree_ir.ReadValueInstruction) and \
        isinstance(instruction.node_id, tree_ir.LiteralInstruction):
        val, = yield [("RV", [instruction.node_id.literal])]
        raise primitive_functions.PrimitiveFinished(tree_ir.LiteralInstruction(val))
    else:
        raise primitive_functions.PrimitiveFinished(instruction)

def optimize_tree_ir(instruction):
    """Optimizes an IR tree by expanding constant reads and simplifying every node.
       Returns a generator that finishes by raising PrimitiveFinished."""
    return map_and_simplify_generator(expand_constant_read, instruction)

def create_bare_function(function_name, parameter_list, function_body):
    """Creates a function definition from the given function name, parameter list
       and function body. No prolog is included.

       The generated function additionally accepts `**<KWARGS_PARAMETER_NAME>`."""
    # Wrap the IR in a function definition, give it a unique name.
    return tree_ir.DefineFunctionInstruction(
        function_name,
        parameter_list + ['**' + jit_runtime.KWARGS_PARAMETER_NAME],
        function_body)

def create_function(
        function_name, parameter_list, param_dict,
        body_param_dict, function_body, source_map_name=None,
        compatible_temporary_protects=False):
    """Creates a function from the given function name, parameter list,
       variable-to-parameter name map, variable-to-local name map and
       function body. An optional source map can be included, too.

       The generated prolog creates a locals node rooted at the `task_root`
       keyword argument and, for every parameter, a local node whose 'value'
       edge points at the parameter. Temporaries are then protected from the GC
       by anchoring them to that locals node."""
    # Write a prologue and prepend it to the generated function body.
    prolog_statements = []
    # If the source map is not None, then we should generate a "DEBUG_INFO"
    # request.
    if source_map_name is not None:
        prolog_statements.append(
            tree_ir.RegisterDebugInfoInstruction(
                tree_ir.LiteralInstruction(function_name),
                tree_ir.LoadGlobalInstruction(source_map_name),
                tree_ir.LiteralInstruction(jit_runtime.BASELINE_JIT_ORIGIN_NAME)))

    # Create a LOCALS_NODE_NAME node, and connect it to the user root.
    prolog_statements.append(
        tree_ir.create_new_local_node(
            jit_runtime.LOCALS_NODE_NAME,
            tree_ir.LoadIndexInstruction(
                tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
                tree_ir.LiteralInstruction('task_root')),
            jit_runtime.LOCALS_EDGE_NAME))

    # For every parameter, create a local node hanging off the locals node and
    # store the parameter's value in it.
    for (key, val) in param_dict.items():
        arg_ptr = tree_ir.create_new_local_node(
            body_param_dict[key],
            tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME))
        prolog_statements.append(arg_ptr)
        prolog_statements.append(
            tree_ir.CreateDictionaryEdgeInstruction(
                tree_ir.LoadLocalInstruction(body_param_dict[key]),
                tree_ir.LiteralInstruction('value'),
                tree_ir.LoadLocalInstruction(val)))

    constructed_body = tree_ir.create_block(
        *(prolog_statements + [function_body]))

    # Shield temporaries from the GC.
    constructed_body = tree_ir.protect_temporaries_from_gc(
        constructed_body,
        tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME),
        compatible_temporary_protects)

    return create_bare_function(function_name, parameter_list, constructed_body)

def print_value(val):
    """A thin wrapper around 'print'."""
    print(val)

class ModelverseJit(object):
    """A high-level interface to the modelverse JIT compiler."""
    def __init__(self, max_instructions=None, compiled_function_lookup=None):
        """Creates a JIT instance.

           `max_instructions` is handed to the baseline JIT's bytecode analysis.
           `compiled_function_lookup`, if not None, is a callable that maps a global
           function name to a precompiled function (or None); it is consulted when
           no jitted version of a function is known yet."""
        self.todo_entry_points = set()
        self.no_jit_entry_points = set()
        # jitted_parameters caches jit_signature results, keyed by body id.
        self.jitted_parameters = {}
        # jit_globals is the global namespace in which generated code is exec'd.
        self.jit_globals = {
            'PrimitiveFinished' : primitive_functions.PrimitiveFinished,
            jit_runtime.CALL_FUNCTION_NAME : jit_runtime.call_function,
            jit_runtime.GET_INPUT_FUNCTION_NAME : jit_runtime.get_input,
            jit_runtime.JIT_THUNK_CONSTANT_FUNCTION_NAME : self.jit_thunk_constant_function,
            jit_runtime.JIT_THUNK_GLOBAL_FUNCTION_NAME : self.jit_thunk_global,
            jit_runtime.JIT_REJIT_FUNCTION_NAME : self.jit_rejit,
            jit_runtime.JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME : compile_function_body_fast,
            jit_runtime.UNREACHABLE_FUNCTION_NAME : jit_runtime.unreachable
        }
        # jitted_entry_points maps body ids to values in jit_globals.
        self.jitted_entry_points = {}
        # global_functions maps global value names to body ids.
        self.global_functions = {}
        # global_functions_inv maps body ids to global value names.
        self.global_functions_inv = {}
        # bytecode_graphs maps body ids to their parsed bytecode graphs.
        self.bytecode_graphs = {}
        # jitted_function_aliases maps body ids to known aliases.
        self.jitted_function_aliases = defaultdict(set)
        self.jit_count = 0
        self.max_instructions = max_instructions
        self.compiled_function_lookup = compiled_function_lookup
        # jit_intrinsics is a function name -> intrinsic map.
        self.jit_intrinsics = {}
        # cfg_jit_intrinsics is a function name -> intrinsic map.
        self.cfg_jit_intrinsics = {}
        # compilation_dependencies maps body ids that are currently being compiled
        # to the set of body ids that get marked non-jittable if compilation fails.
        self.compilation_dependencies = {}
        self.jit_enabled = True
        self.direct_calls_allowed = True
        self.tracing_enabled = False
        self.source_maps_enabled = True
        self.input_function_enabled = False
        self.nop_insertion_enabled = True
        self.thunks_enabled = True
        self.jit_success_log_function = None
        self.jit_code_log_function = None
        self.jit_timing_log = None
        self.compile_function_body = compile_function_body_baseline

    def set_jit_enabled(self, is_enabled=True):
        """Enables or disables the JIT."""
        self.jit_enabled = is_enabled

    def allow_direct_calls(self, is_allowed=True):
        """Allows or disallows direct calls from jitted to jitted code."""
        self.direct_calls_allowed = is_allowed

    def use_input_function(self, is_enabled=True):
        """Configures the JIT to compile 'input' instructions as function calls."""
        self.input_function_enabled = is_enabled

    def enable_tracing(self, is_enabled=True):
        """Enables or disables tracing for jitted code."""
        self.tracing_enabled = is_enabled

    def enable_source_maps(self, is_enabled=True):
        """Enables or disables the creation of source maps for jitted code. Source maps
           convert lines in the generated code to debug information.
           Source maps are enabled by default."""
        self.source_maps_enabled = is_enabled

    def enable_nop_insertion(self, is_enabled=True):
        """Enables or disables nop insertion for jitted code. If enabled, the JIT will
           insert nops at loop back-edges. Inserting nops sacrifices performance to
           keep the jitted code from blocking the thread of execution and consuming
           all resources; nops give the Modelverse server an opportunity to interrupt
           the currently running code."""
        self.nop_insertion_enabled = is_enabled

    def enable_thunks(self, is_enabled=True):
        """Enables or disables thunks for jitted code. Thunks delay the compilation of
           functions until they are actually used. Thunks generally reduce start-up
           time.

           Thunks are enabled by default."""
        self.thunks_enabled = is_enabled

    def set_jit_success_log(self, log_function=print_value):
        """Configures this JIT instance with a function that prints output to a log.
           Success and failure messages for specific functions are then sent to said log."""
        self.jit_success_log_function = log_function

    def set_jit_code_log(self, log_function=print_value):
        """Configures this JIT instance with a function that prints output to a log.
           Function definitions of jitted functions are then sent to said log."""
        self.jit_code_log_function = log_function

    def set_function_body_compiler(self, compile_function_body):
        """Sets the function that the JIT uses to compile function bodies."""
        self.compile_function_body = compile_function_body

    def set_jit_timing_log(self, log_function=print_value):
        """Configures this JIT instance with a function that prints output to a log.
           The time it takes to compile functions is then sent to this log."""
        self.jit_timing_log = log_function

    def mark_entry_point(self, body_id):
        """Marks the node with the given identifier as a function entry point."""
        if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
            self.todo_entry_points.add(body_id)

    def is_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point."""
        return body_id in self.todo_entry_points or \
            body_id in self.no_jit_entry_points or \
            body_id in self.jitted_entry_points

    def is_jittable_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point that
           has not been marked as non-jittable. This only returns `True` if the JIT
           is enabled and the function entry point has been marked jittable, or if
           the function has already been compiled."""
        return ((self.jit_enabled and body_id in self.todo_entry_points) or
                self.has_compiled(body_id))

    def has_compiled(self, body_id):
        """Tests if the function belonging to the given body node has been compiled yet."""
        return body_id in self.jitted_entry_points

    def get_compiled_name(self, body_id):
        """Gets the name of the compiled version of the given body node in the JIT
           global state."""
        if body_id in self.jitted_entry_points:
            return self.jitted_entry_points[body_id]
        else:
            return None

    def mark_no_jit(self, body_id):
        """Informs the JIT that the node with the given identifier is a function entry
           point that must never be jitted."""
        self.no_jit_entry_points.add(body_id)
        if body_id in self.todo_entry_points:
            self.todo_entry_points.remove(body_id)

    def generate_name(self, infix, suggested_name=None):
        """Generates a new name or picks the suggested name if it is still
           available. The suggested name is rejected if it is already a JIT
           global or a Python keyword; the fallback is 'jit_<infix><counter>'."""
        if suggested_name is not None \
            and suggested_name not in self.jit_globals \
            and not keyword.iskeyword(suggested_name):
            self.jit_count += 1
            return suggested_name
        else:
            function_name = 'jit_%s%d' % (infix, self.jit_count)
            self.jit_count += 1
            return function_name

    def generate_function_name(self, body_id, suggested_name=None):
        """Generates a new function name or picks the suggested name if it is still
           available. If no name is suggested, the body's global name (if any) is
           suggested instead."""
        if suggested_name is None:
            suggested_name = self.get_global_name(body_id)

        return self.generate_name('func', suggested_name)

    def register_global(self, body_id, global_name):
        """Associates the given body id with the given global name."""
        self.global_functions[global_name] = body_id
        self.global_functions_inv[body_id] = global_name

    def get_global_name(self, body_id):
        """Gets the name of the global function with the given body id.
           Returns None if no known global exists with the given id."""
        if body_id in self.global_functions_inv:
            return self.global_functions_inv[body_id]
        else:
            return None

    def get_global_body_id(self, global_name):
        """Gets the body id of the global function with the given name.
           Returns None if no known global exists with the given name."""
        if global_name in self.global_functions:
            return self.global_functions[global_name]
        else:
            return None

    def register_compiled(self, body_id, compiled_function, function_name=None):
        """Registers a compiled entry point with the JIT."""
        # Get the function's name.
        actual_function_name = self.generate_function_name(body_id, function_name)
        # Map the body id to the given parameter list.
        self.jitted_entry_points[body_id] = actual_function_name
        self.jit_globals[actual_function_name] = compiled_function
        if function_name is not None:
            self.register_global(body_id, function_name)

        if body_id in self.todo_entry_points:
            self.todo_entry_points.remove(body_id)

    def import_value(self, value, suggested_name=None):
        """Imports the given value into the JIT's global scope, with the given suggested name.
           The actual name of the value (within the JIT's global scope) is returned."""
        actual_name = self.generate_name('import', suggested_name)
        self.jit_globals[actual_name] = value
        return actual_name

    def __lookup_compiled_body_impl(self, body_id):
        """Looks up a compiled function by body id. Returns a matching function,
           or None if no function was found."""
        if body_id is not None and body_id in self.jitted_entry_points:
            return self.jit_globals[self.jitted_entry_points[body_id]]
        else:
            return None

    def __lookup_external_body_impl(self, global_name, body_id):
        """Looks up an external function by global name. Returns a matching function,
           or None if no function was found. A function that is found is registered
           as compiled if its body id is known."""
        if global_name is not None and self.compiled_function_lookup is not None:
            result = self.compiled_function_lookup(global_name)
            if result is not None and body_id is not None:
                self.register_compiled(body_id, result, global_name)

            return result
        else:
            return None

    def lookup_compiled_body(self, body_id):
        """Looks up a compiled function by body id. Returns a matching function,
           or None if no function was found."""
        result = self.__lookup_compiled_body_impl(body_id)
        if result is not None:
            return result
        else:
            global_name = self.get_global_name(body_id)
            return self.__lookup_external_body_impl(global_name, body_id)

    def lookup_compiled_function(self, global_name):
        """Looks up a compiled function by global name. Returns a matching function,
           or None if no function was found."""
        body_id = self.get_global_body_id(global_name)
        result = self.__lookup_compiled_body_impl(body_id)
        if result is not None:
            return result
        else:
            return self.__lookup_external_body_impl(global_name, body_id)

    def get_intrinsic(self, name):
        """Tries to find an intrinsic version of the function with the
           given name."""
        if name in self.jit_intrinsics:
            return self.jit_intrinsics[name]
        else:
            return None

    def get_cfg_intrinsic(self, name):
        """Tries to find an intrinsic version of the function with the
           given name that is specialized for CFGs."""
        if name in self.cfg_jit_intrinsics:
            return self.cfg_jit_intrinsics[name]
        else:
            return None

    def register_intrinsic(self, name, intrinsic_function, cfg_intrinsic_function=None):
        """Registers the given intrinsic with the JIT. This will make the JIT replace calls to
           the function with the given entry point by an application of the specified function."""
        self.jit_intrinsics[name] = intrinsic_function
        if cfg_intrinsic_function is not None:
            self.register_cfg_intrinsic(name, cfg_intrinsic_function)

    def register_cfg_intrinsic(self, name, cfg_intrinsic_function):
        """Registers the given intrinsic with the JIT. This will make the JIT replace calls to
           the function with the given entry point by an application of the specified function."""
        self.cfg_jit_intrinsics[name] = cfg_intrinsic_function

    def register_binary_intrinsic(self, name, operator):
        """Registers an intrinsic with the JIT that represents the given binary operation."""
        self.register_intrinsic(
            name,
            lambda a, b: tree_ir.CreateNodeWithValueInstruction(
                tree_ir.BinaryInstruction(
                    tree_ir.ReadValueInstruction(a),
                    operator,
                    tree_ir.ReadValueInstruction(b))),
            lambda original_def, a, b: original_def.redefine(
                cfg_ir.CreateNode(
                    original_def.insert_before(
                        cfg_ir.Binary(
                            original_def.insert_before(cfg_ir.Read(a)),
                            operator,
                            original_def.insert_before(cfg_ir.Read(b)))))))

    def register_unary_intrinsic(self, name, operator):
        """Registers an intrinsic with the JIT that represents the given unary operation."""
        self.register_intrinsic(
            name,
            lambda a: tree_ir.CreateNodeWithValueInstruction(
                tree_ir.UnaryInstruction(
                    operator,
                    tree_ir.ReadValueInstruction(a))),
            lambda original_def, a: original_def.redefine(
                cfg_ir.CreateNode(
                    original_def.insert_before(
                        cfg_ir.Unary(
                            operator,
                            original_def.insert_before(cfg_ir.Read(a)))))))

    def register_cast_intrinsic(self, name, target_type):
        """Registers an intrinsic with the JIT that represents a unary conversion operator.
           The conversion calls the global whose name is `target_type.__name__`."""
        self.register_intrinsic(
            name,
            lambda a: tree_ir.CreateNodeWithValueInstruction(
                tree_ir.CallInstruction(
                    tree_ir.LoadGlobalInstruction(target_type.__name__),
                    [tree_ir.ReadValueInstruction(a)])),
            lambda original_def, a: original_def.redefine(
                cfg_ir.CreateNode(
                    original_def.insert_before(
                        cfg_ir.create_pure_simple_call(
                            target_type.__name__,
                            original_def.insert_before(cfg_ir.Read(a)))))))

    def jit_signature(self, body_id):
        """Acquires the signature for the given body id node, which consists of the
           parameter variables, parameter name and a flag that tells if the given function
           is mutable. Results are cached per body id; the signature is delivered by
           raising PrimitiveFinished."""
        if body_id not in self.jitted_parameters:
            signature_id, = yield [("RRD", [body_id, jit_runtime.FUNCTION_BODY_KEY])]
            signature_id = signature_id[0]
            param_set_id, is_mutable = yield [
                ("RD", [signature_id, "params"]),
                ("RD", [signature_id, jit_runtime.MUTABLE_FUNCTION_KEY])]
            if param_set_id is None:
                self.jitted_parameters[body_id] = ([], [], is_mutable)
            else:
                param_name_ids, = yield [("RDK", [param_set_id])]
                param_names = yield [("RV", [n]) for n in param_name_ids]
                param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
                self.jitted_parameters[body_id] = (param_vars, param_names, is_mutable)

        raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])

    def jit_parse_bytecode(self, body_id):
        """Parses the given function body as a bytecode graph. Parsed graphs are
           cached per body id."""
        if body_id in self.bytecode_graphs:
            raise primitive_functions.PrimitiveFinished(self.bytecode_graphs[body_id])

        parser = bytecode_parser.BytecodeParser()
        result, = yield [("CALL_ARGS", [parser.parse_instruction, (body_id,)])]
        self.bytecode_graphs[body_id] = result
        raise primitive_functions.PrimitiveFinished(result)

    def check_jittable(self, body_id, suggested_name=None):
        """Checks if the function with the given body id is obviously non-jittable. If it's
           non-jittable, then a `JitCompilationFailedException` exception is thrown.
           A `ValueError` is raised if `body_id` is None."""
        if body_id is None:
            raise ValueError('body_id cannot be None')
        elif body_id in self.no_jit_entry_points:
            # We're not allowed to jit this function or have tried and failed before.
            raise JitCompilationFailedException(
                'Cannot jit function %s at %d because it is marked non-jittable.' % (
                    '' if suggested_name is None else "'" + suggested_name + "'",
                    body_id))
        elif not self.jit_enabled:
            # We're not allowed to jit anything.
            raise JitCompilationFailedException(
                'Cannot jit function %s at %d because the JIT has been disabled.' % (
                    '' if suggested_name is None else "'" + suggested_name + "'",
                    body_id))

    def jit_recompile(self, task_root, body_id, function_name, compile_function_body=None):
        """Replaces the function with the given name by compiling the bytecode at the given
           body id. If compilation fails, the body id and all of its recorded compilation
           dependencies are marked non-jittable and a JitCompilationFailedException is
           raised."""
        if self.jit_timing_log is not None:
            start_time = time.time()

        if compile_function_body is None:
            compile_function_body = self.compile_function_body

        self.check_jittable(body_id, function_name)

        # Generate a name for the function we're about to analyze, and pretend that
        # it already exists. (we need to do this for recursive functions)
        self.jitted_entry_points[body_id] = function_name
        self.jit_globals[function_name] = None

        (_, _, is_mutable), = yield [
            ("CALL_ARGS", [self.jit_signature, (body_id,)])]

        dependencies = set([body_id])
        self.compilation_dependencies[body_id] = dependencies

        def handle_jit_exception(exception):
            """Rolls back the compilation state and re-raises a failure with context."""
            # If analysis fails, then a JitCompilationFailedException will be thrown.
            del self.compilation_dependencies[body_id]
            for dep in dependencies:
                self.mark_no_jit(dep)
                if dep in self.jitted_entry_points:
                    del self.jitted_entry_points[dep]

            # Python 3 exceptions have no 'message' attribute (PEP 352); fall back
            # to str() so that this handler raises a JitCompilationFailedException
            # instead of an AttributeError.
            try:
                exception_message = exception.message
            except AttributeError:
                exception_message = str(exception)

            failure_message = "%s (function '%s' at %d)" % (
                exception_message, function_name, body_id)
            if self.jit_success_log_function is not None:
                self.jit_success_log_function('JIT compilation failed: %s' % failure_message)
            raise JitCompilationFailedException(failure_message)

        # Try to analyze the function's body.
        yield [("TRY", [])]
        yield [("CATCH", [JitCompilationFailedException, handle_jit_exception])]
        if is_mutable:
            # We can't just JIT mutable functions. That'd be dangerous.
            raise JitCompilationFailedException(
                "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)

        compiled_function, = yield [
            ("CALL_ARGS", [compile_function_body, (self, function_name, body_id, task_root)])]

        yield [("END_TRY", [])]
        del self.compilation_dependencies[body_id]

        if self.jit_success_log_function is not None:
            assert self.jitted_entry_points[body_id] == function_name
            self.jit_success_log_function(
                "JIT compilation successful: (function '%s' at %d)" % (function_name, body_id))

        if self.jit_timing_log is not None:
            end_time = time.time()
            compile_time = end_time - start_time
            self.jit_timing_log('Compile time for %s:%f' % (function_name, compile_time))

        raise primitive_functions.PrimitiveFinished(compiled_function)

    def get_source_map_name(self, function_name):
        """Gets the name of the given jitted function's source map. None is returned if source maps
           are disabled."""
        if self.source_maps_enabled:
            return function_name + "_source_map"
        else:
            return None

    def get_can_rejit_name(self, function_name):
        """Gets the name of the given jitted function's can-rejit flag."""
        return function_name + "_can_rejit"

    def jit_define_function(self, function_name, function_def):
        """Converts the given tree-IR function definition to Python code, defines it,
           and extracts the resulting function."""
        # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
        # pylint: disable=I0011,W0122
        if self.jit_code_log_function is not None:
            self.jit_code_log_function(function_def)

        # Convert the function definition to Python code, and compile it.
        code_generator = tree_ir.PythonGenerator()
        function_def.generate_python_def(code_generator)
        source_map_name = self.get_source_map_name(function_name)
        if source_map_name is not None:
            self.jit_globals[source_map_name] = code_generator.source_map_builder.source_map
        exec(str(code_generator), self.jit_globals)

        # Extract the compiled function from the JIT global state.
        return self.jit_globals[function_name]

    def jit_delete_function(self, function_name):
        """Deletes the function with the given function name."""
        del self.jit_globals[function_name]

    def jit_compile(self, task_root, body_id, suggested_name=None):
        """Tries to jit the function defined by the given entry point id. Already
           compiled (or externally available) functions are returned directly;
           otherwise the body is compiled via jit_recompile."""
        if body_id is None:
            raise ValueError('body_id cannot be None')
        elif body_id in self.jitted_entry_points:
            raise primitive_functions.PrimitiveFinished(
                self.jit_globals[self.jitted_entry_points[body_id]])

        compiled_func = self.lookup_compiled_body(body_id)
        if compiled_func is not None:
            raise primitive_functions.PrimitiveFinished(compiled_func)

        # Generate a name for the function we're about to analyze, and 're-compile'
        # it for the first time.
        function_name = self.generate_function_name(body_id, suggested_name)
        yield [("TAIL_CALL_ARGS", [self.jit_recompile, (task_root, body_id, function_name)])]

    def jit_rejit(self, task_root, body_id, function_name, compile_function_body=None):
        """Re-compiles the given function. If compilation fails, then the can-rejit
           flag is set to false."""
        old_jitted_func = self.jitted_entry_points[body_id]
        def __handle_jit_failed(_):
            """Restores the previous entry point and disables further re-jitting."""
            self.jit_globals[self.get_can_rejit_name(function_name)] = False
            self.jitted_entry_points[body_id] = old_jitted_func
            # 'discard' rather than 'remove': if jit_recompile failed in
            # check_jittable (e.g. because the JIT is disabled), body_id was never
            # added to no_jit_entry_points and 'remove' would raise a KeyError.
            self.no_jit_entry_points.discard(body_id)
            # NOTE(review): jit_recompile sets jit_globals[function_name] to None
            # before compiling and this handler does not restore it, so callers
            # presumably pass a name other than old_jitted_func — confirm.
            raise primitive_functions.PrimitiveFinished(None)

        yield [("TRY", [])]
        yield [("CATCH", [jit_runtime.JitCompilationFailedException, __handle_jit_failed])]
        jitted_function, = yield [
            ("CALL_ARGS",
             [self.jit_recompile, (task_root, body_id, function_name, compile_function_body)])]
        yield [("END_TRY", [])]

        # Update all aliases.
        for function_alias in self.jitted_function_aliases[body_id]:
            self.jit_globals[function_alias] = jitted_function

    def jit_thunk(self, get_function_body, global_name=None):
        """Creates a thunk from the given IR tree that computes the function's body id.
           This thunk is a function that will invoke the function whose body id is retrieved.
           The thunk's name in the JIT's global context is returned."""
        # The general idea is to first create a function that looks a bit like this:
        #
        #     def jit_get_function_body(**kwargs):
        #         raise primitive_functions.PrimitiveFinished(<get_function_body>)
        #
        get_function_body_name = self.generate_name('get_function_body')
        get_function_body_func_def = create_function(
            get_function_body_name, [], {}, {}, tree_ir.ReturnInstruction(get_function_body))
        get_function_body_func = self.jit_define_function(
            get_function_body_name, get_function_body_func_def)

        # Next, we want to create a thunk that invokes said function, and then replaces itself.
        thunk_name = self.generate_name('thunk', global_name)
        def __jit_thunk(**kwargs):
            """Resolves the body id, replaces itself by a compiled (or interpreted)
               version of the function, and tail-calls that version."""
            # Compute the body id, and delete the function that computes the body id; we won't
            # be needing it anymore after this call.
            body_id, = yield [("CALL_KWARGS", [get_function_body_func, kwargs])]
            self.jit_delete_function(get_function_body_name)

            # Try to associate the global name with the body id, if that's at all possible.
            if global_name is not None:
                self.register_global(body_id, global_name)

            compiled_function = self.lookup_compiled_body(body_id)
            if compiled_function is not None:
                # Replace this thunk by the compiled function.
                self.jit_globals[thunk_name] = compiled_function
                self.jitted_function_aliases[body_id].add(thunk_name)
            else:
                def __handle_jit_exception(_):
                    # Replace this thunk by a different thunk: one that calls the interpreter
                    # directly, without checking if the function is jittable.
                    (_, parameter_names, _), = yield [
                        ("CALL_ARGS", [self.jit_signature, (body_id,)])]

                    def __interpreter_thunk(**new_kwargs):
                        named_arg_dict = {name : new_kwargs[name] for name in parameter_names}
                        return jit_runtime.interpret_function_body(
                            body_id, named_arg_dict, **new_kwargs)

                    self.jit_globals[thunk_name] = __interpreter_thunk

                yield [("TRY", [])]
                yield [("CATCH", [JitCompilationFailedException, __handle_jit_exception])]
                compiled_function, = yield [
                    ("CALL_ARGS",
                     [self.jit_recompile, (kwargs['task_root'], body_id, thunk_name)])]
                yield [("END_TRY", [])]

            # Call the compiled function.
            yield [("TAIL_CALL_KWARGS", [compiled_function, kwargs])]

        self.jit_globals[thunk_name] = __jit_thunk
        return thunk_name

    def jit_thunk_constant_body(self, body_id):
        """Creates a thunk from the given body id.
           This thunk is a function that will invoke the function whose body id is given.
           The thunk's name in the JIT's global context is returned."""
        # Called for its side effect: this may register an externally compiled function.
        self.lookup_compiled_body(body_id)
        compiled_name = self.get_compiled_name(body_id)
        if compiled_name is not None:
            # We might have compiled the function with the given body id already. In that case,
            # we need not bother with constructing the thunk; we can return the compiled function
            # right away.
            return compiled_name
        else:
            # Looks like we'll just have to build that thunk after all.
            return self.jit_thunk(tree_ir.LiteralInstruction(body_id))

    def jit_thunk_constant_function(self, body_id):
        """Creates a thunk from the given function id.
           This thunk is a function that will invoke the function whose function id is given.
           The thunk's name in the JIT's global context is returned."""
        return self.jit_thunk(
            tree_ir.ReadDictionaryValueInstruction(
                tree_ir.LiteralInstruction(body_id),
                tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)))

    def jit_thunk_global(self, global_name):
        """Creates a thunk from given global name.
           This thunk is a function that will invoke the function whose body id is given.
           The thunk's name in the JIT's global context is returned."""
        # We might have compiled the function with the given name already. In that case,
        # we need not bother with constructing the thunk; we can return the compiled function
        # right away.
        body_id = self.get_global_body_id(global_name)
        if body_id is not None:
            self.lookup_compiled_body(body_id)
            compiled_name = self.get_compiled_name(body_id)
            if compiled_name is not None:
                return compiled_name

        # Looks like we'll just have to build that thunk after all.

        # We want to look up the global function like so
        #
        #     _globals, = yield [("RD", [kwargs['task_root'], "globals"])]
        #     global_var, = yield [("RD", [_globals, global_name])]
        #     function_id, = yield [("RD", [global_var, "value"])]
        #     body_id, = yield [("RD", [function_id, jit_runtime.FUNCTION_BODY_KEY])]
        #
        return self.jit_thunk(
            tree_ir.ReadDictionaryValueInstruction(
                tree_ir.ReadDictionaryValueInstruction(
                    tree_ir.ReadDictionaryValueInstruction(
                        tree_ir.ReadDictionaryValueInstruction(
                            tree_ir.LoadIndexInstruction(
                                tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
                                tree_ir.LiteralInstruction('task_root')),
                            tree_ir.LiteralInstruction('globals')),
                        tree_ir.LiteralInstruction(global_name)),
                    tree_ir.LiteralInstruction('value')),
                tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)),
            global_name)

def compile_function_body_interpret(jit, function_name, body_id, task_root, header=None):
    """Create a function that invokes the interpreter on the given function.

       The created function is stored in `jit.jit_globals[function_name]` and also
       delivered by raising PrimitiveFinished. `task_root` is not used here."""
    (parameter_ids, parameter_list, _), = yield [
        ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
    param_dict = dict(zip(parameter_ids, parameter_list))
    body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
    def __interpret_function(**kwargs):
        """Runs the optional header, then interprets the function's bytecode."""
        if header is not None:
            (done, result), = yield [("CALL_KWARGS", [header, kwargs])]
            if done:
                raise primitive_functions.PrimitiveFinished(result)

        # Split the keyword arguments into the function's own parameters (keyed by
        # parameter id) and the remaining keyword arguments.
        local_args = {}
        inner_kwargs = dict(kwargs)
        for param_id, name in param_dict.items():
            local_args[param_id] = inner_kwargs[name]
            del inner_kwargs[name]

        yield [("TAIL_CALL_ARGS",
                [bytecode_interpreter.interpret_bytecode_function,
                 (function_name, body_bytecode, local_args, inner_kwargs)])]

    jit.jit_globals[function_name] = __interpret_function
    raise primitive_functions.PrimitiveFinished(__interpret_function)

def compile_function_body_baseline(
        jit, function_name, body_id, task_root,
        header=None, compatible_temporary_protects=False):
    """Have the baseline JIT compile the function with the given name and body id."""
    (parameter_ids, parameter_list, _), = yield [
        ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
    param_dict = dict(zip(parameter_ids, parameter_list))
    body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
    body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
    state = bytecode_to_tree.AnalysisState(
        jit, body_id, task_root, body_param_dict,
        jit.max_instructions)
    constructed_body, = yield [("CALL_ARGS", [state.analyze, (body_bytecode,)])]
    if header is not None:
        constructed_body = tree_ir.create_block(header, constructed_body)

    # Optimize the function's body.
    constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]

    # Wrap the tree IR in a function definition.
    constructed_function = create_function(
        function_name, parameter_list, param_dict,
        body_param_dict, constructed_body, jit.get_source_map_name(function_name),
        compatible_temporary_protects)

    # Convert the function definition to Python code, and compile it.
    raise primitive_functions.PrimitiveFinished(
        jit.jit_define_function(function_name, constructed_function))

def compile_function_body_fast(jit, function_name, body_id, _):
    """Have the fast JIT compile the function with the given name and body id.
       The bytecode is lowered to a CFG, optimized, lowered to tree IR and then
       compiled to Python. The fourth argument (the task root) is ignored."""
    (parameter_ids, parameter_list, _), = yield [
        ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
    param_dict = dict(zip(parameter_ids, parameter_list))
    body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
    bytecode_analyzer = bytecode_to_cfg.AnalysisState(jit, function_name, param_dict)
    bytecode_analyzer.analyze(body_bytecode)
    entry_point, = yield [
        ("CALL_ARGS", [cfg_optimization.optimize, (bytecode_analyzer.entry_point, jit)])]
    if jit.jit_code_log_function is not None:
        jit.jit_code_log_function(
            "CFG for function '%s' at '%d':\n%s" % (
                function_name, body_id,
                '\n'.join(map(str, cfg_ir.get_all_reachable_blocks(entry_point)))))

    # Lower the CFG to tree IR.
    constructed_body = cfg_to_tree.lower_flow_graph(entry_point, jit)

    # Optimize the tree that was generated.
    constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
    constructed_function = create_bare_function(function_name, parameter_list, constructed_body)

    # Convert the function definition to Python code, and compile it.
    raise primitive_functions.PrimitiveFinished(
        jit.jit_define_function(function_name, constructed_function))

def favor_large_functions(body_bytecode):
    """Computes the initial temperature of a function based on the size of
       its body bytecode. Larger functions are favored and the temperature
       is incremented by one on every call."""
    # The rationale for this heuristic is that it does some damage control:
    # we can afford to decide (wrongly) not to fast-jit a small function,
    # because we can just fast-jit that function later on. Since the function
    # is so small, it will (hopefully) not be able to deal us a heavy blow in
    # terms of performance.
    #
    # If we decide not to fast-jit a large function however, we might end up
    # in a situation where said function runs for a long time before we
    # realize that we really should have jitted it. And that's exactly what
    # this heuristic tries to avoid.
    return len(body_bytecode.get_reachable()), 1

def favor_small_functions(body_bytecode):
    """Computes the initial temperature of a function based on the size of
       its body bytecode. Smaller functions are favored and the temperature
       is incremented by one on every call."""
    # The rationale for this heuristic is that small functions are easy to
    # fast-jit, because they probably won't trigger the non-linear complexity
    # of fast-jit's algorithms. So it might be cheaper to fast-jit small
    # functions and get a performance boost from that than to fast-jit large
    # functions.
    return ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD - len(body_bytecode.get_reachable()), 1

ADAPTIVE_JIT_LOOP_INSTRUCTION_MULTIPLIER = 4

ADAPTIVE_BASELINE_JIT_TEMPERATURE_THRESHOLD = 100
"""The threshold temperature at which the adaptive JIT will use the baseline JIT."""

ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD = 250
"""The threshold temperature at which the adaptive JIT will use the fast JIT."""

def favor_loops(body_bytecode):
    """Computes the initial temperature of a function. Code within
       a loop makes the function hotter; code outside loops makes
       the function colder. The temperature is incremented by one on
       every call."""
    reachable_instructions = body_bytecode.get_reachable()
    # First set the temperature to the baseline JIT threshold minus the number
    # of reachable instructions.
    temperature = ADAPTIVE_BASELINE_JIT_TEMPERATURE_THRESHOLD - len(reachable_instructions)
    for instruction in reachable_instructions:
        if isinstance(instruction, bytecode_ir.WhileInstruction):
            # Then increase the temperature by the number of instructions reachable
            # from loop bodies. Note that the algorithm will count nested loops twice.
            # This is actually by design.
            loop_body_instructions = instruction.body.get_reachable(
                lambda x: not isinstance(
                    x, (bytecode_ir.BreakInstruction, bytecode_ir.ContinueInstruction)))
            temperature += ADAPTIVE_JIT_LOOP_INSTRUCTION_MULTIPLIER * len(loop_body_instructions)

    return temperature, 1

def favor_small_loops(body_bytecode):
    """Computes the initial temperature of a function. Code within
       a loop makes the function hotter; code outside loops makes
       the function colder. The temperature increment grows with the
       base-2 logarithm of the number of reachable instructions (and is
       at least one)."""
    reachable_instructions = body_bytecode.get_reachable()
    # First set the temperature to 50 below the fast JIT threshold, minus the
    # number of reachable instructions.
    temperature = ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD - 50 - len(reachable_instructions)
    for instruction in reachable_instructions:
        if isinstance(instruction, bytecode_ir.WhileInstruction):
            # Then increase the temperature by the number of instructions reachable
            # from loop bodies. Note that the algorithm will count nested loops twice.
            # This is actually by design.
            loop_body_instructions = instruction.body.get_reachable(
                lambda x: not isinstance(
                    x, (bytecode_ir.BreakInstruction, bytecode_ir.ContinueInstruction)))
            temperature += (
                (ADAPTIVE_JIT_LOOP_INSTRUCTION_MULTIPLIER ** 2) *
                int(math.sqrt(len(loop_body_instructions))))

    return temperature, max(int(math.log(len(reachable_instructions), 2)), 1)

class AdaptiveJitState(object):
    """Shared state for adaptive JIT compilation."""
    def __init__(
            self, temperature_counter_name,
            temperature_increment, can_rejit_name):
        self.temperature_counter_name = temperature_counter_name
        self.temperature_increment = temperature_increment
        self.can_rejit_name = can_rejit_name

    def compile_interpreter(
            self, jit, function_name, body_id, task_root):
        """Compiles the given function as a function that controls the temperature counter
           and calls the interpreter."""
        def __increment_temperature(**kwargs):
            if jit.jit_globals[self.can_rejit_name]:
                temperature_counter_val = jit.jit_globals[self.temperature_counter_name]
                temperature_counter_val
+= self.temperature_increment jit.jit_globals[self.temperature_counter_name] = temperature_counter_val if temperature_counter_val >= ADAPTIVE_BASELINE_JIT_TEMPERATURE_THRESHOLD: if temperature_counter_val >= ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD: yield [ ("CALL_ARGS", [jit.jit_rejit, (task_root, body_id, function_name, compile_function_body_fast)])] else: yield [ ("CALL_ARGS", [jit.jit_rejit, (task_root, body_id, function_name, self.compile_baseline)])] result, = yield [("CALL_KWARGS", [jit.jit_globals[function_name], kwargs])] raise primitive_functions.PrimitiveFinished((True, result)) raise primitive_functions.PrimitiveFinished((False, None)) yield [ ("TAIL_CALL_ARGS", [compile_function_body_interpret, (jit, function_name, body_id, task_root, __increment_temperature)])] def compile_baseline( self, jit, function_name, body_id, task_root): """Compiles the given function with the baseline JIT, and inserts logic that controls the temperature counter.""" (_, parameter_list, _), = yield [ ("CALL_ARGS", [jit.jit_signature, (body_id,)])] # This tree represents the following logic: # # if can_rejit: # global temperature_counter # temperature_counter = temperature_counter + temperature_increment # if temperature_counter >= ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD: # yield [("CALL_KWARGS", [jit_runtime.JIT_REJIT_FUNCTION_NAME, {...}])] # yield [("TAIL_CALL_KWARGS", [function_name, {...}])] header = tree_ir.SelectInstruction( tree_ir.LoadGlobalInstruction(self.can_rejit_name), tree_ir.create_block( tree_ir.DeclareGlobalInstruction(self.temperature_counter_name), tree_ir.IgnoreInstruction( tree_ir.StoreGlobalInstruction( self.temperature_counter_name, tree_ir.BinaryInstruction( tree_ir.LoadGlobalInstruction(self.temperature_counter_name), '+', tree_ir.LiteralInstruction(self.temperature_increment)))), tree_ir.SelectInstruction( tree_ir.BinaryInstruction( tree_ir.LoadGlobalInstruction(self.temperature_counter_name), '>=', 
tree_ir.LiteralInstruction(ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD)), tree_ir.create_block( tree_ir.RunGeneratorFunctionInstruction( tree_ir.LoadGlobalInstruction(jit_runtime.JIT_REJIT_FUNCTION_NAME), tree_ir.DictionaryLiteralInstruction([ (tree_ir.LiteralInstruction('task_root'), bytecode_to_tree.load_task_root()), (tree_ir.LiteralInstruction('body_id'), tree_ir.LiteralInstruction(body_id)), (tree_ir.LiteralInstruction('function_name'), tree_ir.LiteralInstruction(function_name)), (tree_ir.LiteralInstruction('compile_function_body'), tree_ir.LoadGlobalInstruction( jit_runtime.JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME))]), result_type=tree_ir.NO_RESULT_TYPE), bytecode_to_tree.create_return( tree_ir.create_jit_call( tree_ir.LoadGlobalInstruction(function_name), [(name, tree_ir.LoadLocalInstruction(name)) for name in parameter_list], tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME)))), tree_ir.EmptyInstruction())), tree_ir.EmptyInstruction()) # Compile with the baseline JIT, and insert the header. yield [ ("TAIL_CALL_ARGS", [compile_function_body_baseline, (jit, function_name, body_id, task_root, header, True)])] def compile_function_body_adaptive( jit, function_name, body_id, task_root, temperature_heuristic=favor_loops): """Compile the function with the given name and body id. An execution engine is picked automatically, and the function may be compiled again at a later time.""" # The general idea behind this compilation technique is to first use the baseline JIT # to compile a function, and then switch to the fast JIT when we determine that doing # so would be a good idea. We maintain a 'temperature' counter, which has an initial value # and gets incremented every time the function is executed. 
body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])] initial_temperature, temperature_increment = temperature_heuristic(body_bytecode) if jit.jit_success_log_function is not None: jit.jit_success_log_function( "Initial temperature for '%s': %d" % (function_name, initial_temperature)) if initial_temperature >= ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD: # Initial temperature exceeds the fast-jit threshold. # Compile this thing with fast-jit right away. if jit.jit_success_log_function is not None: jit.jit_success_log_function( "Compiling '%s' with fast-jit." % function_name) yield [ ("TAIL_CALL_ARGS", [compile_function_body_fast, (jit, function_name, body_id, task_root)])] temperature_counter_name = jit.import_value( initial_temperature, function_name + "_temperature_counter") can_rejit_name = jit.get_can_rejit_name(function_name) jit.jit_globals[can_rejit_name] = True state = AdaptiveJitState(temperature_counter_name, temperature_increment, can_rejit_name) if initial_temperature >= ADAPTIVE_BASELINE_JIT_TEMPERATURE_THRESHOLD: # Initial temperature exceeds the baseline JIT threshold. # Compile this thing with baseline JIT right away. if jit.jit_success_log_function is not None: jit.jit_success_log_function( "Compiling '%s' with baseline-jit." % function_name) yield [ ("TAIL_CALL_ARGS", [state.compile_baseline, (jit, function_name, body_id, task_root)])] else: # Looks like we'll use the interpreter initially. if jit.jit_success_log_function is not None: jit.jit_success_log_function( "Compiling '%s' with bytecode-interpreter." % function_name) yield [ ("TAIL_CALL_ARGS", [state.compile_interpreter, (jit, function_name, body_id, task_root)])]