import modelverse_kernel.primitives as primitive_functions
import modelverse_jit.bytecode_ir as bytecode_ir
import modelverse_jit.bytecode_parser as bytecode_parser
import modelverse_jit.bytecode_to_tree as bytecode_to_tree
import modelverse_jit.tree_ir as tree_ir
import modelverse_jit.runtime as jit_runtime
import keyword

# Import JitCompilationFailedException because it used to be defined
# in this module.
JitCompilationFailedException = jit_runtime.JitCompilationFailedException

def map_and_simplify_generator(function, instruction):
    """Applies the given mapping function to every instruction in the tree
       that has the given instruction as root, and simplifies it on-the-fly.

       This is at least as powerful as first mapping and then simplifying, as
       maps and simplifications are interspersed.

       This function assumes that function creates a generator that returns by
       raising a primitive_functions.PrimitiveFinished.

       This is itself a generator: it yields request lists (e.g. "CALL_ARGS")
       to its driver and finishes by raising PrimitiveFinished with the
       mapped-and-simplified tree."""
    # First handle the children by mapping on them and then simplifying them.
    new_children = []
    for inst in instruction.get_children():
        new_inst, = yield [("CALL_ARGS", [map_and_simplify_generator, (function, inst)])]
        new_children.append(new_inst)

    # Then apply the function to the top-level node.
    transformed, = yield [("CALL_ARGS", [function, (instruction.create(new_children),)])]
    # Finally, simplify the transformed top-level node.
    raise primitive_functions.PrimitiveFinished(transformed.simplify_node())

def expand_constant_read(instruction):
    """Tries to replace a read of a constant node by a literal.

       If the instruction is a ReadValueInstruction whose node id is itself a
       literal, the value is fetched with an "RV" request and returned as a
       LiteralInstruction. Any other instruction is returned unchanged (via
       PrimitiveFinished)."""
    if isinstance(instruction, tree_ir.ReadValueInstruction) and \
        isinstance(instruction.node_id, tree_ir.LiteralInstruction):
        val, = yield [("RV", [instruction.node_id.literal])]
        raise primitive_functions.PrimitiveFinished(tree_ir.LiteralInstruction(val))
    else:
        raise primitive_functions.PrimitiveFinished(instruction)

def optimize_tree_ir(instruction):
    """Optimizes an IR tree.

       Returns a generator that maps expand_constant_read over the tree and
       simplifies it as it goes."""
    return map_and_simplify_generator(expand_constant_read, instruction)

def print_value(val):
    """A thin wrapper around 'print'."""
    print(val)

def _describe_suggested_name(suggested_name):
    """Formats an optional suggested function name for use in error messages:
       the empty string when there is no name, the quoted name otherwise."""
    return '' if suggested_name is None else "'" + suggested_name + "'"

def _create_function_prologue(param_dict, body_param_dict):
    """Builds the prologue statements that are prepended to a jitted body.

       The prologue creates the LOCALS_NODE_NAME node (connected to the
       'user_root' keyword argument), then, for every parameter, creates a
       pointer node and a 'value' dictionary edge from it to the parameter.

       param_dict maps parameter node ids to parameter names; body_param_dict
       maps the same ids to the names of the per-parameter pointer locals."""
    prologue_statements = []
    # Create a LOCALS_NODE_NAME node, and connect it to the user root.
    prologue_statements.append(
        tree_ir.create_new_local_node(
            jit_runtime.LOCALS_NODE_NAME,
            tree_ir.LoadIndexInstruction(
                tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
                tree_ir.LiteralInstruction('user_root')),
            jit_runtime.LOCALS_EDGE_NAME))
    for (key, val) in param_dict.items():
        arg_ptr = tree_ir.create_new_local_node(
            body_param_dict[key],
            tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME))
        prologue_statements.append(arg_ptr)
        prologue_statements.append(
            tree_ir.CreateDictionaryEdgeInstruction(
                tree_ir.LoadLocalInstruction(body_param_dict[key]),
                tree_ir.LiteralInstruction('value'),
                tree_ir.LoadLocalInstruction(val)))
    return prologue_statements

class ModelverseJit(object):
    """A high-level interface to the modelverse JIT compiler."""
    def __init__(self, max_instructions=None, compiled_function_lookup=None):
        # Function entry points that may still be jitted.
        self.todo_entry_points = set()
        # Function entry points that must never be jitted.
        self.no_jit_entry_points = set()
        # Cache: body id -> (parameter vars, parameter names, is_mutable).
        self.jitted_parameters = {}
        # The global namespace in which jitted code is exec'd.
        self.jit_globals = {
            'PrimitiveFinished' : primitive_functions.PrimitiveFinished,
            jit_runtime.CALL_FUNCTION_NAME : jit_runtime.call_function,
            jit_runtime.GET_INPUT_FUNCTION_NAME : jit_runtime.get_input
        }
        # jitted_entry_points maps body ids to values in jit_globals.
        self.jitted_entry_points = {}
        # global_functions maps global value names to body ids.
        self.global_functions = {}
        # global_functions_inv maps body ids to global value names.
        self.global_functions_inv = {}
        # bytecode_graphs maps body ids to their parsed bytecode graphs.
        self.bytecode_graphs = {}
        # Counter used to generate fresh names in jit_globals.
        self.jit_count = 0
        self.max_instructions = max_instructions
        # Optional fallback used by lookup_compiled_function.
        self.compiled_function_lookup = compiled_function_lookup
        # jit_intrinsics is a function name -> intrinsic map.
        self.jit_intrinsics = {}
        # Maps body ids that are being compiled to their dependency sets.
        self.compilation_dependencies = {}
        self.jit_enabled = True
        self.direct_calls_allowed = True
        self.tracing_enabled = False
        self.input_function_enabled = False
        self.nop_insertion_enabled = True
        self.jit_success_log_function = None
        self.jit_code_log_function = None

    def set_jit_enabled(self, is_enabled=True):
        """Enables or disables the JIT."""
        self.jit_enabled = is_enabled

    def allow_direct_calls(self, is_allowed=True):
        """Allows or disallows direct calls from jitted to jitted code."""
        self.direct_calls_allowed = is_allowed

    def use_input_function(self, is_enabled=True):
        """Configures the JIT to compile 'input' instructions as function calls."""
        self.input_function_enabled = is_enabled

    def enable_tracing(self, is_enabled=True):
        """Enables or disables tracing for jitted code."""
        self.tracing_enabled = is_enabled

    def enable_nop_insertion(self, is_enabled=True):
        """Enables or disables nop insertion for jitted code. The JIT will insert nops at loop
           back-edges. Inserting nops sacrifices performance to keep the jitted code from
           blocking the thread of execution by consuming all resources; nops give the
           Modelverse server an opportunity to interrupt the currently running code."""
        self.nop_insertion_enabled = is_enabled

    def set_jit_success_log(self, log_function=print_value):
        """Configures this JIT instance with a function that prints output to a log.
           Success and failure messages for specific functions are then sent to said log."""
        self.jit_success_log_function = log_function

    def set_jit_code_log(self, log_function=print_value):
        """Configures this JIT instance with a function that prints output to a log.
           Function definitions of jitted functions are then sent to said log."""
        self.jit_code_log_function = log_function

    def mark_entry_point(self, body_id):
        """Marks the node with the given identifier as a function entry point.
           Entry points that are already compiled or marked non-jittable are
           left alone."""
        if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
            self.todo_entry_points.add(body_id)

    def is_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point."""
        return body_id in self.todo_entry_points or \
            body_id in self.no_jit_entry_points or \
            body_id in self.jitted_entry_points

    def is_jittable_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point that
           has not been marked as non-jittable. This only returns `True` if the JIT
           is enabled and the function entry point has been marked jittable, or if
           the function has already been compiled."""
        return ((self.jit_enabled and body_id in self.todo_entry_points) or
                self.has_compiled(body_id))

    def has_compiled(self, body_id):
        """Tests if the function belonging to the given body node has been compiled yet."""
        return body_id in self.jitted_entry_points

    def get_compiled_name(self, body_id):
        """Gets the name of the compiled version of the given body node in the JIT
           global state. Raises KeyError if the body has not been compiled."""
        return self.jitted_entry_points[body_id]

    def mark_no_jit(self, body_id):
        """Informs the JIT that the node with the given identifier is a function entry
           point that must never be jitted."""
        self.no_jit_entry_points.add(body_id)
        self.todo_entry_points.discard(body_id)

    def generate_name(self, infix, suggested_name=None):
        """Generates a new name or picks the suggested name if it is still available.

           The suggested name is used only if it is not None, not already bound in
           jit_globals and not a Python keyword. Otherwise a name of the form
           'jit_<infix><n>' is generated, skipping any that are already bound so
           that an existing global is never overwritten."""
        if suggested_name is not None \
            and suggested_name not in self.jit_globals \
            and not keyword.iskeyword(suggested_name):
            self.jit_count += 1
            return suggested_name

        while True:
            function_name = 'jit_%s%d' % (infix, self.jit_count)
            self.jit_count += 1
            # A previously accepted suggested name may have the generated form.
            if function_name not in self.jit_globals:
                return function_name

    def generate_function_name(self, body_id, suggested_name=None):
        """Generates a new function name or picks the suggested name if it is still
           available. If no name is suggested, the body's registered global name
           (if any) is suggested instead."""
        if suggested_name is None:
            suggested_name = self.get_global_name(body_id)

        return self.generate_name('func', suggested_name)

    def register_global(self, body_id, global_name):
        """Associates the given body id with the given global name."""
        self.global_functions[global_name] = body_id
        self.global_functions_inv[body_id] = global_name

    def get_global_name(self, body_id):
        """Gets the name of the global function with the given body id.
           Returns None if no known global exists with the given id."""
        return self.global_functions_inv.get(body_id)

    def get_global_body_id(self, global_name):
        """Gets the body id of the global function with the given name.
           Returns None if no known global exists with the given name."""
        return self.global_functions.get(global_name)

    def register_compiled(self, body_id, compiled_function, function_name=None):
        """Registers a compiled entry point with the JIT."""
        # Get the function's name.
        function_name = self.generate_function_name(body_id, function_name)
        # Map the body id to the given parameter list.
        self.jitted_entry_points[body_id] = function_name
        self.jit_globals[function_name] = compiled_function
        self.todo_entry_points.discard(body_id)

    def import_value(self, value, suggested_name=None):
        """Imports the given value into the JIT's global scope, with the given suggested name.
           The actual name of the value (within the JIT's global scope) is returned."""
        actual_name = self.generate_name('import', suggested_name)
        self.jit_globals[actual_name] = value
        return actual_name

    def lookup_compiled_function(self, name):
        """Looks up a compiled function by name. Returns a matching function,
           or None if no function was found. Names bound in jit_globals take
           precedence over the compiled_function_lookup fallback."""
        if name is None:
            return None
        elif name in self.jit_globals:
            return self.jit_globals[name]
        elif self.compiled_function_lookup is not None:
            return self.compiled_function_lookup(name)
        else:
            return None

    def get_intrinsic(self, name):
        """Tries to find an intrinsic version of the function with the given name.
           Returns None if no intrinsic is registered under that name."""
        return self.jit_intrinsics.get(name)

    def register_intrinsic(self, name, intrinsic_function):
        """Registers the given intrinsic with the JIT. This will make the JIT replace calls to
           the function with the given entry point by an application of the specified function."""
        self.jit_intrinsics[name] = intrinsic_function

    def register_binary_intrinsic(self, name, operator):
        """Registers an intrinsic with the JIT that represents the given binary operation."""
        self.register_intrinsic(name, lambda a, b: tree_ir.CreateNodeWithValueInstruction(
            tree_ir.BinaryInstruction(
                tree_ir.ReadValueInstruction(a),
                operator,
                tree_ir.ReadValueInstruction(b))))

    def register_unary_intrinsic(self, name, operator):
        """Registers an intrinsic with the JIT that represents the given unary operation."""
        self.register_intrinsic(name, lambda a: tree_ir.CreateNodeWithValueInstruction(
            tree_ir.UnaryInstruction(
                operator,
                tree_ir.ReadValueInstruction(a))))

    def register_cast_intrinsic(self, name, target_type):
        """Registers an intrinsic with the JIT that represents a unary conversion operator."""
        self.register_intrinsic(name, lambda a: tree_ir.CreateNodeWithValueInstruction(
            tree_ir.CallInstruction(
                tree_ir.LoadGlobalInstruction(target_type.__name__),
                [tree_ir.ReadValueInstruction(a)])))

    def jit_signature(self, body_id):
        """Acquires the signature for the given body id node, which consists of the
           parameter variables, parameter name and a flag that tells if the given function
           is mutable.

           This is a generator; it finishes by raising PrimitiveFinished with a
           (parameter_vars, parameter_names, is_mutable) tuple, which is cached
           in self.jitted_parameters."""
        if body_id not in self.jitted_parameters:
            signature_id, = yield [("RRD", [body_id, jit_runtime.FUNCTION_BODY_KEY])]
            signature_id = signature_id[0]
            param_set_id, is_mutable = yield [
                ("RD", [signature_id, "params"]),
                ("RD", [signature_id, jit_runtime.MUTABLE_FUNCTION_KEY])]
            if param_set_id is None:
                self.jitted_parameters[body_id] = ([], [], is_mutable)
            else:
                param_name_ids, = yield [("RDK", [param_set_id])]
                param_names = yield [("RV", [n]) for n in param_name_ids]
                param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
                self.jitted_parameters[body_id] = (param_vars, param_names, is_mutable)

        raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])

    def jit_parse_bytecode(self, body_id):
        """Parses the given function body as a bytecode graph.

           This is a generator; results are cached in self.bytecode_graphs and
           delivered by raising PrimitiveFinished."""
        if body_id in self.bytecode_graphs:
            raise primitive_functions.PrimitiveFinished(self.bytecode_graphs[body_id])

        parser = bytecode_parser.BytecodeParser()
        result, = yield [("CALL_ARGS", [parser.parse_instruction, (body_id,)])]
        self.bytecode_graphs[body_id] = result
        raise primitive_functions.PrimitiveFinished(result)

    def jit_compile(self, user_root, body_id, suggested_name=None):
        """Tries to jit the function defined by the given entry point id and parameter list.

           This is a generator that finishes by raising PrimitiveFinished with the
           compiled function (or with the placeholder value None if the function
           is currently being compiled, e.g. for recursive calls).

           Raises ValueError if body_id is None, and JitCompilationFailedException
           if the function cannot be jitted. On failure, the function and every
           recorded dependency are marked non-jittable."""
        # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
        # pylint: disable=I0011,W0122
        if body_id is None:
            raise ValueError('body_id cannot be None')
        elif body_id in self.jitted_entry_points:
            # We have already compiled this function.
            raise primitive_functions.PrimitiveFinished(
                self.jit_globals[self.jitted_entry_points[body_id]])
        elif body_id in self.no_jit_entry_points:
            # We're not allowed to jit this function or have tried and failed before.
            raise JitCompilationFailedException(
                'Cannot jit function %s at %d because it is marked non-jittable.' % (
                    _describe_suggested_name(suggested_name),
                    body_id))
        elif not self.jit_enabled:
            # We're not allowed to jit anything.
            raise JitCompilationFailedException(
                'Cannot jit function %s at %d because the JIT has been disabled.' % (
                    _describe_suggested_name(suggested_name),
                    body_id))

        # Generate a name for the function we're about to analyze, and pretend that
        # it already exists. (we need to do this for recursive functions)
        function_name = self.generate_function_name(body_id, suggested_name)
        self.jitted_entry_points[body_id] = function_name
        self.jit_globals[function_name] = None

        (parameter_ids, parameter_list, is_mutable), = yield [
            ("CALL_ARGS", [self.jit_signature, (body_id,)])]

        param_dict = dict(zip(parameter_ids, parameter_list))
        body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
        dependencies = set([body_id])
        self.compilation_dependencies[body_id] = dependencies

        def handle_jit_exception(exception):
            # If analysis fails, then a JitCompilationFailedException will be thrown.
            del self.compilation_dependencies[body_id]
            for dep in dependencies:
                self.mark_no_jit(dep)
                self.jitted_entry_points.pop(dep, None)

            # str(exception) is portable; the '.message' attribute does not
            # exist on Python 3 exceptions (and is deprecated on Python 2.6+).
            failure_message = "%s (function '%s' at %d)" % (
                str(exception), function_name, body_id)
            if self.jit_success_log_function is not None:
                self.jit_success_log_function('JIT compilation failed: %s' % failure_message)
            raise JitCompilationFailedException(failure_message)

        # Try to analyze the function's body.
        yield [("TRY", [])]
        yield [("CATCH", [JitCompilationFailedException, handle_jit_exception])]
        if is_mutable:
            # We can't just JIT mutable functions. That'd be dangerous.
            raise JitCompilationFailedException(
                "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)

        body_bytecode, = yield [("CALL_ARGS", [self.jit_parse_bytecode, (body_id,)])]
        state = bytecode_to_tree.AnalysisState(
            self, body_id, user_root, body_param_dict,
            self.max_instructions)
        constructed_body, = yield [("CALL_ARGS", [state.analyze, (body_bytecode,)])]
        yield [("END_TRY", [])]
        del self.compilation_dependencies[body_id]

        # Write a prologue and prepend it to the generated function body.
        prologue_statements = _create_function_prologue(param_dict, body_param_dict)
        constructed_body = tree_ir.create_block(
            *(prologue_statements + [constructed_body]))

        # Optimize the function's body.
        constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]

        # Shield temporaries from the GC.
        constructed_body = tree_ir.protect_temporaries_from_gc(
            constructed_body, tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME))

        # Wrap the IR in a function definition, give it a unique name.
        constructed_function = tree_ir.DefineFunctionInstruction(
            function_name,
            parameter_list + ['**' + jit_runtime.KWARGS_PARAMETER_NAME],
            constructed_body)
        # Convert the function definition to Python code, and compile it.
        # NOTE(review): the exec'd source is generated from the analyzed IR tree;
        # any literal values embedded in it must be rendered safely by tree_ir.
        exec(str(constructed_function), self.jit_globals)

        # Extract the compiled function from the JIT global state.
        compiled_function = self.jit_globals[function_name]

        if self.jit_success_log_function is not None:
            self.jit_success_log_function(
                "JIT compilation successful: (function '%s' at %d)" % (function_name, body_id))

        if self.jit_code_log_function is not None:
            self.jit_code_log_function(constructed_function)

        raise primitive_functions.PrimitiveFinished(compiled_function)