|
@@ -111,13 +111,6 @@ class ModelverseJit(object):
|
|
|
self.jitted_parameters = {}
|
|
|
self.jit_globals = {
|
|
|
'PrimitiveFinished' : primitive_functions.PrimitiveFinished,
|
|
|
- jit_runtime.CALL_FUNCTION_NAME : jit_runtime.call_function,
|
|
|
- jit_runtime.GET_INPUT_FUNCTION_NAME : jit_runtime.get_input,
|
|
|
- jit_runtime.JIT_THUNK_CONSTANT_FUNCTION_NAME : self.jit_thunk_constant_function,
|
|
|
- jit_runtime.JIT_THUNK_GLOBAL_FUNCTION_NAME : self.jit_thunk_global,
|
|
|
- jit_runtime.JIT_REJIT_FUNCTION_NAME : self.jit_rejit,
|
|
|
- jit_runtime.JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME : compile_function_body_fast,
|
|
|
- jit_runtime.UNREACHABLE_FUNCTION_NAME : jit_runtime.unreachable
|
|
|
}
|
|
|
# jitted_entry_points maps body ids to values in jit_globals.
|
|
|
self.jitted_entry_points = {}
|
|
@@ -139,7 +132,6 @@ class ModelverseJit(object):
|
|
|
self.source_maps_enabled = True
|
|
|
self.input_function_enabled = False
|
|
|
self.nop_insertion_enabled = True
|
|
|
- self.thunks_enabled = True
|
|
|
self.jit_success_log_function = None
|
|
|
self.jit_code_log_function = None
|
|
|
self.compile_function_body = compile_function_body_baseline
|
|
@@ -175,14 +167,6 @@ class ModelverseJit(object):
|
|
|
the currently running code."""
|
|
|
self.nop_insertion_enabled = is_enabled
|
|
|
|
|
|
- def enable_thunks(self, is_enabled=True):
|
|
|
- """Enables or disables thunks for jitted code. Thunks delay the compilation of
|
|
|
- functions until they are actually used. Thunks generally reduce start-up
|
|
|
- time.
|
|
|
-
|
|
|
- Thunks are enabled by default."""
|
|
|
- self.thunks_enabled = is_enabled
|
|
|
-
|
|
|
def set_jit_success_log(self, log_function=print_value):
|
|
|
"""Configures this JIT instance with a function that prints output to a log.
|
|
|
Success and failure messages for specific functions are then sent to said log."""
|
|
@@ -521,124 +505,6 @@ class ModelverseJit(object):
|
|
|
for function_alias in self.jitted_function_aliases[body_id]:
|
|
|
self.jit_globals[function_alias] = jitted_function
|
|
|
|
|
|
- def jit_thunk(self, get_function_body, global_name=None):
|
|
|
- """Creates a thunk from the given IR tree that computes the function's body id.
|
|
|
- This thunk is a function that will invoke the function whose body id is retrieved.
|
|
|
- The thunk's name in the JIT's global context is returned."""
|
|
|
- # The general idea is to first create a function that looks a bit like this:
|
|
|
- #
|
|
|
- # def jit_get_function_body(**kwargs):
|
|
|
- # raise primitive_functions.PrimitiveFinished(<get_function_body>)
|
|
|
- #
|
|
|
- get_function_body_name = self.generate_name('get_function_body')
|
|
|
- get_function_body_func_def = create_function(
|
|
|
- get_function_body_name, [], {}, {}, tree_ir.ReturnInstruction(get_function_body))
|
|
|
- get_function_body_func = self.jit_define_function(
|
|
|
- get_function_body_name, get_function_body_func_def)
|
|
|
-
|
|
|
- # Next, we want to create a thunk that invokes said function, and then replaces itself.
|
|
|
- thunk_name = self.generate_name('thunk', global_name)
|
|
|
- def __jit_thunk(**kwargs):
|
|
|
- # Compute the body id, and delete the function that computes the body id; we won't
|
|
|
- # be needing it anymore after this call.
|
|
|
- body_id, = yield [("CALL_KWARGS", [get_function_body_func, kwargs])]
|
|
|
- self.jit_delete_function(get_function_body_name)
|
|
|
-
|
|
|
- # Try to associate the global name with the body id, if that's at all possible.
|
|
|
- if global_name is not None:
|
|
|
- self.register_global(body_id, global_name)
|
|
|
-
|
|
|
- compiled_function = self.lookup_compiled_body(body_id)
|
|
|
- if compiled_function is not None:
|
|
|
- # Replace this thunk by the compiled function.
|
|
|
- self.jit_globals[thunk_name] = compiled_function
|
|
|
- self.jitted_function_aliases[body_id].add(thunk_name)
|
|
|
- else:
|
|
|
- def __handle_jit_exception(_):
|
|
|
- # Replace this thunk by a different thunk: one that calls the interpreter
|
|
|
- # directly, without checking if the function is jittable.
|
|
|
- (_, parameter_names, _), = yield [
|
|
|
- ("CALL_ARGS", [self.jit_signature, (body_id,)])]
|
|
|
- def __interpreter_thunk(**new_kwargs):
|
|
|
- named_arg_dict = {name : new_kwargs[name] for name in parameter_names}
|
|
|
- return jit_runtime.interpret_function_body(
|
|
|
- body_id, named_arg_dict, **new_kwargs)
|
|
|
-
|
|
|
- self.jit_globals[thunk_name] = __interpreter_thunk
|
|
|
-
|
|
|
- yield [("TRY", [])]
|
|
|
- yield [("CATCH", [JitCompilationFailedException, __handle_jit_exception])]
|
|
|
- compiled_function, = yield [
|
|
|
- ("CALL_ARGS",
|
|
|
- [self.jit_recompile, (kwargs['task_root'], body_id, thunk_name)])]
|
|
|
- yield [("END_TRY", [])]
|
|
|
-
|
|
|
- # Call the compiled function.
|
|
|
- yield [("TAIL_CALL_KWARGS", [compiled_function, kwargs])]
|
|
|
-
|
|
|
- self.jit_globals[thunk_name] = __jit_thunk
|
|
|
- return thunk_name
|
|
|
-
|
|
|
- def jit_thunk_constant_body(self, body_id):
|
|
|
- """Creates a thunk from the given body id.
|
|
|
- This thunk is a function that will invoke the function whose body id is given.
|
|
|
- The thunk's name in the JIT's global context is returned."""
|
|
|
- self.lookup_compiled_body(body_id)
|
|
|
- compiled_name = self.get_compiled_name(body_id)
|
|
|
- if compiled_name is not None:
|
|
|
- # We might have compiled the function with the given body id already. In that case,
|
|
|
- # we need not bother with constructing the thunk; we can return the compiled function
|
|
|
- # right away.
|
|
|
- return compiled_name
|
|
|
- else:
|
|
|
- # Looks like we'll just have to build that thunk after all.
|
|
|
- return self.jit_thunk(tree_ir.LiteralInstruction(body_id))
|
|
|
-
|
|
|
- def jit_thunk_constant_function(self, body_id):
|
|
|
- """Creates a thunk from the given function id.
|
|
|
- This thunk is a function that will invoke the function whose function id is given.
|
|
|
- The thunk's name in the JIT's global context is returned."""
|
|
|
- return self.jit_thunk(
|
|
|
- tree_ir.ReadDictionaryValueInstruction(
|
|
|
- tree_ir.LiteralInstruction(body_id),
|
|
|
- tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)))
|
|
|
-
|
|
|
- def jit_thunk_global(self, global_name):
|
|
|
- """Creates a thunk from given global name.
|
|
|
- This thunk is a function that will invoke the function whose body id is given.
|
|
|
- The thunk's name in the JIT's global context is returned."""
|
|
|
- # We might have compiled the function with the given name already. In that case,
|
|
|
- # we need not bother with constructing the thunk; we can return the compiled function
|
|
|
- # right away.
|
|
|
- body_id = self.get_global_body_id(global_name)
|
|
|
- if body_id is not None:
|
|
|
- self.lookup_compiled_body(body_id)
|
|
|
- compiled_name = self.get_compiled_name(body_id)
|
|
|
- if compiled_name is not None:
|
|
|
- return compiled_name
|
|
|
-
|
|
|
- # Looks like we'll just have to build that thunk after all.
|
|
|
- # We want to look up the global function like so
|
|
|
- #
|
|
|
- # _globals, = yield [("RD", [kwargs['task_root'], "globals"])]
|
|
|
- # global_var, = yield [("RD", [_globals, global_name])]
|
|
|
- # function_id, = yield [("RD", [global_var, "value"])]
|
|
|
- # body_id, = yield [("RD", [function_id, jit_runtime.FUNCTION_BODY_KEY])]
|
|
|
- #
|
|
|
- return self.jit_thunk(
|
|
|
- tree_ir.ReadDictionaryValueInstruction(
|
|
|
- tree_ir.ReadDictionaryValueInstruction(
|
|
|
- tree_ir.ReadDictionaryValueInstruction(
|
|
|
- tree_ir.ReadDictionaryValueInstruction(
|
|
|
- tree_ir.LoadIndexInstruction(
|
|
|
- tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
|
|
|
- tree_ir.LiteralInstruction('task_root')),
|
|
|
- tree_ir.LiteralInstruction('globals')),
|
|
|
- tree_ir.LiteralInstruction(global_name)),
|
|
|
- tree_ir.LiteralInstruction('value')),
|
|
|
- tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)),
|
|
|
- global_name)
|
|
|
-
|
|
|
def new_compile(self, body_id):
|
|
|
print("Compiling body ID " + str(body_id))
|
|
|
raise JitCompilationFailedException("Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
|