import math
import keyword
from collections import defaultdict
import modelverse_kernel.primitives as primitive_functions


class ModelverseJit(object):
    """A high-level interface to the modelverse JIT compiler."""

    def __init__(self):
        # Body ids that were marked as entry points but have not been compiled yet.
        self.todo_entry_points = set()
        # Caches the (param_vars, param_names, is_mutable) signature per body id.
        self.jitted_parameters = {}
        # The global namespace in which generated code is exec'd and compiled
        # functions are stored.
        self.jit_globals = {}
        # jitted_entry_points maps body ids to values in jit_globals.
        self.jitted_entry_points = {}
        # global_functions maps global value names to body ids.
        self.global_functions = {}
        # global_functions_inv maps body ids to global value names.
        self.global_functions_inv = {}
        # jitted_function_aliases maps body ids to known aliases.
        self.jitted_function_aliases = defaultdict(set)
        # Counter used to build unique generated names.
        self.jit_count = 0
        self.compilation_dependencies = {}
        self.cache = {}
        # Optional callback mapping a global name to an externally provided
        # compiled function (or None). Defaults to None so that lookups do not
        # fail with AttributeError when no callback was installed.
        self.compiled_function_lookup = None
        # Optional callback that receives every tree-IR function definition
        # right before it is compiled. Defaults to None (no logging).
        self.jit_code_log_function = None

    def mark_entry_point(self, body_id):
        """Marks the node with the given identifier as a function entry point.

        Bodies that have already been compiled are not re-marked."""
        if body_id not in self.jitted_entry_points:
            self.todo_entry_points.add(body_id)

    def is_jittable_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point.

        Returns `True` if the body id has been marked as an entry point and is
        still waiting to be compiled, or if the function has already been
        compiled."""
        return (body_id in self.todo_entry_points) or (body_id in self.jitted_entry_points)

    def generate_name(self, infix, suggested_name=None):
        """Generates a new global name, or picks the suggested name if it is still available.

        A suggested name is accepted only if it is not already bound in the JIT
        globals and is not a Python keyword. Otherwise a name of the form
        'jit_<infix><count>' is generated. Generated names that are already
        bound (for instance because a caller explicitly suggested such a name
        earlier) are skipped, so an existing global is never overwritten."""
        if (suggested_name is not None
                and suggested_name not in self.jit_globals
                and not keyword.iskeyword(suggested_name)):
            self.jit_count += 1
            return suggested_name

        while True:
            function_name = 'jit_%s%d' % (infix, self.jit_count)
            self.jit_count += 1
            if function_name not in self.jit_globals:
                return function_name

    def generate_function_name(self, body_id, suggested_name=None):
        """Generates a new function name or picks the suggested name if it is still available.

        When no name is suggested, the global name registered for `body_id`
        (if any) is used as the suggestion."""
        if suggested_name is None:
            suggested_name = self.get_global_name(body_id)

        return self.generate_name('func', suggested_name)

    def register_global(self, body_id, global_name):
        """Associates the given body id with the given global name (in both directions)."""
        self.global_functions[global_name] = body_id
        self.global_functions_inv[body_id] = global_name

    def get_global_name(self, body_id):
        """Gets the name of the global function with the given body id.

        Returns None if no known global exists with the given id."""
        return self.global_functions_inv.get(body_id)

    def get_global_body_id(self, global_name):
        """Gets the body id of the global function with the given name.

        Returns None if no known global exists with the given name."""
        return self.global_functions.get(global_name)

    def register_compiled(self, body_id, compiled_function, function_name=None):
        """Registers a compiled entry point with the JIT.

        The compiled function is stored in the JIT globals under a generated
        (or the suggested) name, and `body_id` is mapped to that name. If
        `function_name` is given, it is also recorded as the global name of
        `body_id`. Finally, `body_id` is removed from the to-do entry points."""
        # Get the function's name.
        actual_function_name = self.generate_function_name(body_id, function_name)
        # Map the body id to the name under which the compiled function is stored.
        self.jitted_entry_points[body_id] = actual_function_name
        self.jit_globals[actual_function_name] = compiled_function
        if function_name is not None:
            self.register_global(body_id, function_name)

        self.todo_entry_points.discard(body_id)

    def __lookup_compiled_body_impl(self, body_id):
        """Looks up a compiled function by body id.

        Returns a matching function, or None if no function was found."""
        if body_id is not None and body_id in self.jitted_entry_points:
            return self.jit_globals[self.jitted_entry_points[body_id]]
        else:
            return None

    def __lookup_external_body_impl(self, global_name, body_id):
        """Looks up an external function by global name, via `compiled_function_lookup`.

        If a function is found and `body_id` is known, the function is also
        registered as the compiled version of `body_id`. Returns a matching
        function, or None if no function was found."""
        if global_name is not None and self.compiled_function_lookup is not None:
            result = self.compiled_function_lookup(global_name)
            if result is not None and body_id is not None:
                self.register_compiled(body_id, result, global_name)

            return result
        else:
            return None

    def lookup_compiled_body(self, body_id):
        """Looks up a compiled function by body id.

        Falls back to an external lookup by the body's global name. Returns a
        matching function, or None if no function was found."""
        result = self.__lookup_compiled_body_impl(body_id)
        if result is not None:
            return result
        else:
            global_name = self.get_global_name(body_id)
            return self.__lookup_external_body_impl(global_name, body_id)

    def lookup_compiled_function(self, global_name):
        """Looks up a compiled function by global name.

        Falls back to an external lookup by that name. Returns a matching
        function, or None if no function was found."""
        body_id = self.get_global_body_id(global_name)
        result = self.__lookup_compiled_body_impl(body_id)
        if result is not None:
            return result
        else:
            return self.__lookup_external_body_impl(global_name, body_id)

    def jit_signature(self, body_id):
        """Acquires the signature for the given body id node, which consists of the
        parameter variables, parameter names and a flag that tells if the given
        function is mutable.

        This is a request generator: each `yield` hands a list of request tuples
        to the driver and receives their results. The final yield is a
        ("RETURN", [signature]) request. Signatures are cached per body id."""
        if body_id not in self.jitted_parameters:
            signature_id, = yield [("RRD", [body_id, "body"])]
            signature_id = signature_id[0]
            param_set_id, is_mutable = yield [
                ("RD", [signature_id, "params"]),
                ("RD", [signature_id, "mutable"])]
            if param_set_id is None:
                self.jitted_parameters[body_id] = ([], [], is_mutable)
            else:
                param_name_ids, = yield [("RDK", [param_set_id])]
                param_names = yield [("RV", [n]) for n in param_name_ids]
                # NOTE: Patch up strange links (keys without a value).
                param_names = [i for i in param_names if i is not None]
                param_vars = yield [("RD", [param_set_id, k]) for k in param_names]

                # NOTE: the variables might not be in a stable order, as they are
                # just read out of the parameter set; sort them by name.
                lst = sorted([(name, var) for name, var in zip(param_names, param_vars)])
                param_vars = [i[1] for i in lst]
                param_names = [i[0] for i in lst]
                self.jitted_parameters[body_id] = (param_vars, param_names, is_mutable)

        yield [("RETURN", [self.jitted_parameters[body_id]])]

    def check_jittable(self, body_id, suggested_name=None):
        """Checks if the function with the given body id is obviously non-jittable.

        Raises:
            ValueError: if `body_id` is None. The message includes
                `suggested_name`, which may itself be None."""
        if body_id is None:
            # '%s' formatting (rather than '+') keeps this working when
            # suggested_name is None.
            raise ValueError('body_id cannot be None: %s' % (suggested_name,))

    def jit_define_function(self, function_name, function_def):
        """Converts the given tree-IR function definition to Python code, defines it,
        and extracts the resulting function.

        The generated source is exec'd in `self.jit_globals`. It is expected to
        bind `function_name` there; otherwise a KeyError is raised."""
        # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
        # pylint: disable=I0011,W0122
        if self.jit_code_log_function is not None:
            self.jit_code_log_function(function_def)

        # Convert the function definition to Python code, and compile it.
        # NOTE(review): `tree_ir` and `self.get_source_map_name` are not defined
        # or imported in the visible code; presumably provided elsewhere -- confirm.
        code_generator = tree_ir.PythonGenerator()
        function_def.generate_python_def(code_generator)
        source_map_name = self.get_source_map_name(function_name)
        if source_map_name is not None:
            self.jit_globals[source_map_name] = code_generator.source_map_builder.source_map

        exec(str(code_generator), self.jit_globals)

        # Extract the compiled function from the JIT global state.
        return self.jit_globals[function_name]

    def jit_delete_function(self, function_name):
        """Deletes the function with the given function name from the JIT globals."""
        del self.jit_globals[function_name]


# NOTE(review): duplicates the identical import at the top of the file.
import modelverse_kernel.primitives as primitive_functions


class JitCompilationFailedException(Exception):
    """A type of exception that is raised when the jit fails to compile a function."""
    pass