|
@@ -239,10 +239,12 @@ class ModelverseJit(object):
|
|
|
def register_compiled(self, body_id, compiled_function, function_name=None):
|
|
|
"""Registers a compiled entry point with the JIT."""
|
|
|
# Get the function's name.
|
|
|
- function_name = self.generate_function_name(body_id, function_name)
|
|
|
+ actual_function_name = self.generate_function_name(body_id, function_name)
|
|
|
# Map the body id to the given parameter list.
|
|
|
- self.jitted_entry_points[body_id] = function_name
|
|
|
- self.jit_globals[function_name] = compiled_function
|
|
|
+ self.jitted_entry_points[body_id] = actual_function_name
|
|
|
+ self.jit_globals[actual_function_name] = compiled_function
|
|
|
+ if function_name is not None:
|
|
|
+ self.register_global(body_id, function_name)
|
|
|
if body_id in self.todo_entry_points:
|
|
|
self.todo_entry_points.remove(body_id)
|
|
|
|
|
@@ -253,18 +255,46 @@ class ModelverseJit(object):
|
|
|
self.jit_globals[actual_name] = value
|
|
|
return actual_name
|
|
|
|
|
|
- def lookup_compiled_function(self, name):
|
|
|
- """Looks up a compiled function by name. Returns a matching function,
|
|
|
+ def __lookup_compiled_body_impl(self, body_id):
|
|
|
+ """Looks up a compiled function by body id. Returns a matching function,
|
|
|
or None if no function was found."""
|
|
|
- if name is None:
|
|
|
+ if body_id is not None and body_id in self.jitted_entry_points:
|
|
|
+ return self.jit_globals[self.jitted_entry_points[body_id]]
|
|
|
+ else:
|
|
|
return None
|
|
|
- elif name in self.jit_globals:
|
|
|
- return self.jit_globals[name]
|
|
|
- elif self.compiled_function_lookup is not None:
|
|
|
- return self.compiled_function_lookup(name)
|
|
|
+
|
|
|
+ def __lookup_external_body_impl(self, global_name, body_id):
|
|
|
+ """Looks up an external function by global name. Returns a matching function,
|
|
|
+ or None if no function was found."""
|
|
|
+ if self.compiled_function_lookup is not None:
|
|
|
+ result = self.compiled_function_lookup(global_name)
|
|
|
+ if result is not None and body_id is not None:
|
|
|
+ self.register_compiled(body_id, result, global_name)
|
|
|
+
|
|
|
+ return result
|
|
|
else:
|
|
|
return None
|
|
|
|
|
|
+ def lookup_compiled_body(self, body_id):
|
|
|
+ """Looks up a compiled function by body id. Returns a matching function,
|
|
|
+ or None if no function was found."""
|
|
|
+ result = self.__lookup_compiled_body_impl(body_id)
|
|
|
+ if result is not None:
|
|
|
+ return result
|
|
|
+ else:
|
|
|
+ global_name = self.get_global_name(body_id)
|
|
|
+ return self.__lookup_external_body_impl(global_name, body_id)
|
|
|
+
|
|
|
+ def lookup_compiled_function(self, global_name):
|
|
|
+ """Looks up a compiled function by global name. Returns a matching function,
|
|
|
+ or None if no function was found."""
|
|
|
+ body_id = self.get_global_body_id(global_name)
|
|
|
+ result = self.__lookup_compiled_body_impl(body_id)
|
|
|
+ if result is not None:
|
|
|
+ return result
|
|
|
+ else:
|
|
|
+ return self.__lookup_external_body_impl(global_name, body_id)
|
|
|
+
|
|
|
def get_intrinsic(self, name):
|
|
|
"""Tries to find an intrinsic version of the function with the
|
|
|
given name."""
|
|
@@ -412,9 +442,6 @@ class ModelverseJit(object):
|
|
|
self.jit_success_log_function(
|
|
|
"JIT compilation successful: (function '%s' at %d)" % (function_name, body_id))
|
|
|
|
|
|
- if self.jit_code_log_function is not None:
|
|
|
- self.jit_code_log_function(constructed_function)
|
|
|
-
|
|
|
raise primitive_functions.PrimitiveFinished(compiled_function)
|
|
|
|
|
|
def jit_define_function(self, function_name, function_def):
|
|
@@ -425,6 +452,10 @@ class ModelverseJit(object):
|
|
|
|
|
|
# Convert the function definition to Python code, and compile it.
|
|
|
exec(str(function_def), self.jit_globals)
|
|
|
+
|
|
|
+ if self.jit_code_log_function is not None:
|
|
|
+ self.jit_code_log_function(function_def)
|
|
|
+
|
|
|
# Extract the compiled function from the JIT global state.
|
|
|
return self.jit_globals[function_name]
|
|
|
|
|
@@ -438,3 +469,112 @@ class ModelverseJit(object):
|
|
|
# it already exists. (we need to do this for recursive functions)
|
|
|
function_name = self.generate_function_name(body_id, suggested_name)
|
|
|
yield [("TAIL_CALL_ARGS", [self.jit_recompile, (user_root, body_id, function_name)])]
|
|
|
+
|
|
|
+ def jit_thunk(self, get_function_body, global_name=None):
|
|
|
+ """Creates a thunk from the given IR tree that computes the function's body id.
|
|
|
+ This thunk is a function that will invoke the function whose body id is retrieved.
|
|
|
+ The thunk's name in the JIT's global context is returned."""
|
|
|
+ # The general idea is to first create a function that looks a bit like this:
|
|
|
+ #
|
|
|
+ # def jit_get_function_body(**kwargs):
|
|
|
+ # raise primitive_functions.PrimitiveFinished(<get_function_body>)
|
|
|
+ #
|
|
|
+ get_function_body_name = self.generate_name('get_function_body')
|
|
|
+ get_function_body_func_def, = yield [
|
|
|
+ ("CALL_ARGS",
|
|
|
+ [create_function,
|
|
|
+ (get_function_body_name, [], {}, {}, tree_ir.ReturnInstruction(get_function_body))])]
|
|
|
+ get_function_body_func = self.jit_define_function(
|
|
|
+ get_function_body_name, get_function_body_func_def)
|
|
|
+
|
|
|
+ # Next, we want to create a thunk that invokes said function, and then replaces itself.
|
|
|
+ def __jit_thunk(**kwargs):
|
|
|
+ # Compute the body id, and delete the function that computes the body id; we won't
|
|
|
+ # be needing it anymore after this call.
|
|
|
+ body_id = yield [("CALL_KWARGS", [get_function_body_func, kwargs])]
|
|
|
+ self.jit_delete_function(get_function_body_name)
|
|
|
+
|
|
|
+ # Try to associate the global name with the body id, if that's at all possible.
|
|
|
+ if global_name is not None:
|
|
|
+ self.register_global(body_id, global_name)
|
|
|
+
|
|
|
+ compiled_function = self.lookup_compiled_body(body_id)
|
|
|
+ if compiled_function is not None:
|
|
|
+ # Replace this thunk by the compiled function.
|
|
|
+ self.jit_globals[get_function_body_name] = compiled_function
|
|
|
+ else:
|
|
|
+ def __handle_jit_exception(_):
|
|
|
+ # Replace this thunk by a different thunk: one that calls the interpreter
|
|
|
+ # directly, without checking if the function is jittable.
|
|
|
+ (_, parameter_names, _), = yield [
|
|
|
+ ("CALL_ARGS", [self.jit_signature, (body_id,)])]
|
|
|
+ def __interpreter_thunk(**new_kwargs):
|
|
|
+ named_arg_dict = {name : new_kwargs[name] for name in parameter_names}
|
|
|
+ return jit_runtime.interpret_function_body(
|
|
|
+ body_id, named_arg_dict, **new_kwargs)
|
|
|
+
|
|
|
+ self.jit_globals[get_function_body_name] = __interpreter_thunk
|
|
|
+
|
|
|
+ yield [("TRY", [])]
|
|
|
+ yield [("CATCH", [JitCompilationFailedException, __handle_jit_exception])]
|
|
|
+ compiled_function, = yield [
|
|
|
+ ("CALL_ARGS",
|
|
|
+ [self.jit_recompile, (kwargs['user_root'], body_id, "jit_thunk")])]
|
|
|
+ yield [("END_TRY", [])]
|
|
|
+
|
|
|
+ # Call the compiled function.
|
|
|
+ yield [("TAIL_CALL_KWARGS", [compiled_function, kwargs])]
|
|
|
+
|
|
|
+ thunk_name = self.generate_name('thunk', global_name)
|
|
|
+ self.jit_globals[thunk_name] = __jit_thunk
|
|
|
+ raise primitive_functions.PrimitiveFinished(thunk_name)
|
|
|
+
|
|
|
+ def jit_thunk_constant(self, body_id):
|
|
|
+ """Creates a thunk from given body id.
|
|
|
+ This thunk is a function that will invoke the function whose body id is given.
|
|
|
+ The thunk's name in the JIT's global context is returned."""
|
|
|
+ # We might have compiled the function with the given body id already. In that case,
|
|
|
+ # we need not bother with constructing the thunk; we can return the compiled function
|
|
|
+ # right away.
|
|
|
+ if self.lookup_compiled_body(body_id) is not None:
|
|
|
+ raise primitive_functions.PrimitiveFinished(self.get_compiled_name(body_id))
|
|
|
+
|
|
|
+ # Looks like we'll just have to build that thunk after all.
|
|
|
+ yield [
|
|
|
+ ("TAIL_CALL_ARGS",
|
|
|
+ [self.jit_thunk, (tree_ir.LiteralInstruction(body_id),)])]
|
|
|
+
|
|
|
+ def jit_thunk_global(self, global_name):
|
|
|
+ """Creates a thunk from given global name.
|
|
|
+ This thunk is a function that will invoke the function whose body id is given.
|
|
|
+ The thunk's name in the JIT's global context is returned."""
|
|
|
+ # We might have compiled the function with the given name already. In that case,
|
|
|
+ # we need not bother with constructing the thunk; we can return the compiled function
|
|
|
+ # right away.
|
|
|
+ body_id = self.get_global_body_id(global_name)
|
|
|
+ if body_id is not None and self.lookup_compiled_body(body_id) is not None:
|
|
|
+ raise primitive_functions.PrimitiveFinished(self.get_compiled_name(body_id))
|
|
|
+
|
|
|
+ # Looks like we'll just have to build that thunk after all.
|
|
|
+ # We want to look up the global function like so
|
|
|
+ #
|
|
|
+ # _globals, = yield [("RD", [kwargs['user_root'], "globals"])]
|
|
|
+ # global_var, = yield [("RD", [_globals, global_name])]
|
|
|
+ # function_id, = yield [("RD", [global_var, "value"])]
|
|
|
+ # body_id, = yield [("RD", [function_id, jit_runtime.FUNCTION_BODY_KEY])]
|
|
|
+ #
|
|
|
+ yield [
|
|
|
+ ("TAIL_CALL_ARGS",
|
|
|
+ [self.jit_thunk,
|
|
|
+ (tree_ir.ReadDictionaryValueInstruction(
|
|
|
+ tree_ir.ReadDictionaryValueInstruction(
|
|
|
+ tree_ir.ReadDictionaryValueInstruction(
|
|
|
+ tree_ir.ReadDictionaryValueInstruction(
|
|
|
+ tree_ir.LoadIndexInstruction(
|
|
|
+ tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
|
|
|
+ tree_ir.LiteralInstruction('user_root')),
|
|
|
+ 'globals'),
|
|
|
+ global_name),
|
|
|
+ 'value'),
|
|
|
+ tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)),
|
|
|
+ global_name)])]
|