|
@@ -117,6 +117,8 @@ class ModelverseJit(object):
|
|
|
jit_runtime.GET_INPUT_FUNCTION_NAME : jit_runtime.get_input,
|
|
|
jit_runtime.JIT_THUNK_CONSTANT_FUNCTION_NAME : self.jit_thunk_constant_function,
|
|
|
jit_runtime.JIT_THUNK_GLOBAL_FUNCTION_NAME : self.jit_thunk_global,
|
|
|
+ jit_runtime.JIT_REJIT_FUNCTION_NAME : self.jit_rejit,
|
|
|
+ jit_runtime.JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME : compile_function_body_fast,
|
|
|
jit_runtime.UNREACHABLE_FUNCTION_NAME : jit_runtime.unreachable
|
|
|
}
|
|
|
# jitted_entry_points maps body ids to values in jit_globals.
|
|
@@ -454,10 +456,6 @@ class ModelverseJit(object):
|
|
|
non-jittable, then a `JitCompilationFailedException` exception is thrown."""
|
|
|
if body_id is None:
|
|
|
raise ValueError('body_id cannot be None')
|
|
|
- elif body_id in self.jitted_entry_points:
|
|
|
- # We have already compiled this function.
|
|
|
- raise primitive_functions.PrimitiveFinished(
|
|
|
- self.jit_globals[self.jitted_entry_points[body_id]])
|
|
|
elif body_id in self.no_jit_entry_points:
|
|
|
# We're not allowed to jit this function or have tried and failed before.
|
|
|
raise JitCompilationFailedException(
|
|
@@ -471,9 +469,12 @@ class ModelverseJit(object):
|
|
|
'' if suggested_name is None else "'" + suggested_name + "'",
|
|
|
body_id))
|
|
|
|
|
|
- def jit_recompile(self, task_root, body_id, function_name):
|
|
|
+ def jit_recompile(self, task_root, body_id, function_name, compile_function_body=None):
|
|
|
"""Replaces the function with the given name by compiling the bytecode at the given
|
|
|
body id."""
|
|
|
+ if compile_function_body is None:
|
|
|
+ compile_function_body = self.compile_function_body
|
|
|
+
|
|
|
self.check_jittable(body_id, function_name)
|
|
|
|
|
|
# Generate a name for the function we're about to analyze, and pretend that
|
|
@@ -510,7 +511,7 @@ class ModelverseJit(object):
|
|
|
"Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
|
|
|
|
|
|
constructed_function, = yield [
|
|
|
- ("CALL_ARGS", [self.compile_function_body, (self, function_name, body_id, task_root)])]
|
|
|
+ ("CALL_ARGS", [compile_function_body, (self, function_name, body_id, task_root)])]
|
|
|
|
|
|
yield [("END_TRY", [])]
|
|
|
del self.compilation_dependencies[body_id]
|
|
@@ -532,6 +533,10 @@ class ModelverseJit(object):
|
|
|
else:
|
|
|
return None
|
|
|
|
|
|
+ def get_can_rejit_name(self, function_name):
|
|
|
+ """Gets the name of the given jitted function's can-rejit flag."""
|
|
|
+ return function_name + "_can_rejit"
|
|
|
+
|
|
|
def jit_define_function(self, function_name, function_def):
|
|
|
"""Converts the given tree-IR function definition to Python code, defines it,
|
|
|
and extracts the resulting function."""
|
|
@@ -558,11 +563,35 @@ class ModelverseJit(object):
|
|
|
|
|
|
def jit_compile(self, task_root, body_id, suggested_name=None):
|
|
|
"""Tries to jit the function defined by the given entry point id and parameter list."""
|
|
|
- # Generate a name for the function we're about to analyze, and pretend that
|
|
|
- # it already exists. (we need to do this for recursive functions)
|
|
|
+ if body_id is None:
|
|
|
+ raise ValueError('body_id cannot be None')
|
|
|
+ elif body_id in self.jitted_entry_points:
|
|
|
+ # We have already compiled this function.
|
|
|
+ raise primitive_functions.PrimitiveFinished(
|
|
|
+ self.jit_globals[self.jitted_entry_points[body_id]])
|
|
|
+
|
|
|
+ # Generate a name for the function we're about to analyze, and 're-compile'
|
|
|
+ # it for the first time.
|
|
|
function_name = self.generate_function_name(body_id, suggested_name)
|
|
|
yield [("TAIL_CALL_ARGS", [self.jit_recompile, (task_root, body_id, function_name)])]
|
|
|
|
|
|
+ def jit_rejit(self, task_root, body_id, function_name, compile_function_body=None):
|
|
|
+ """Re-compiles the given function. If compilation fails, then the can-rejit
|
|
|
+ flag is set to false."""
|
|
|
+ old_jitted_func = self.jitted_entry_points[body_id]
|
|
|
+ def __handle_jit_failed(_):
|
|
|
+ self.jit_globals[self.get_can_rejit_name(function_name)] = False
|
|
|
+ self.jitted_entry_points[body_id] = old_jitted_func
|
|
|
+ self.no_jit_entry_points.remove(body_id)
|
|
|
+ raise primitive_functions.PrimitiveFinished(None)
|
|
|
+
|
|
|
+ yield [("TRY", [])]
|
|
|
+ yield [("CATCH", [jit_runtime.JitCompilationFailedException, __handle_jit_failed])]
|
|
|
+ yield [
|
|
|
+ ("CALL_ARGS",
|
|
|
+ [self.jit_recompile, (task_root, body_id, function_name, compile_function_body)])]
|
|
|
+ yield [("END_TRY", [])]
|
|
|
+
|
|
|
def jit_thunk(self, get_function_body, global_name=None):
|
|
|
"""Creates a thunk from the given IR tree that computes the function's body id.
|
|
|
This thunk is a function that will invoke the function whose body id is retrieved.
|
|
@@ -680,7 +709,7 @@ class ModelverseJit(object):
|
|
|
tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)),
|
|
|
global_name)
|
|
|
|
|
|
-def compile_function_body_baseline(jit, function_name, body_id, task_root):
|
|
|
+def compile_function_body_baseline(jit, function_name, body_id, task_root, header=None):
|
|
|
"""Have the baseline JIT compile the function with the given name and body id."""
|
|
|
(parameter_ids, parameter_list, _), = yield [
|
|
|
("CALL_ARGS", [jit.jit_signature, (body_id,)])]
|
|
@@ -692,6 +721,9 @@ def compile_function_body_baseline(jit, function_name, body_id, task_root):
|
|
|
jit.max_instructions)
|
|
|
constructed_body, = yield [("CALL_ARGS", [state.analyze, (body_bytecode,)])]
|
|
|
|
|
|
+ if header is not None:
|
|
|
+ constructed_body = tree_ir.create_block(header, constructed_body)
|
|
|
+
|
|
|
# Optimize the function's body.
|
|
|
constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
|
|
|
|
|
@@ -726,3 +758,98 @@ def compile_function_body_fast(jit, function_name, body_id, _):
|
|
|
create_bare_function(
|
|
|
function_name, parameter_list,
|
|
|
constructed_body))
|
|
|
+
|
|
|
+def favor_large_functions(body_bytecode):
|
|
|
+ """Computes the initial temperature of a function based on the size of
|
|
|
+ its body bytecode. Larger functions are favored and the temperature
|
|
|
+ is incremented by one on every call."""
|
|
|
+ return (
|
|
|
+ len(body_bytecode.get_reachable()),
|
|
|
+ lambda old_value:
|
|
|
+ tree_ir.BinaryInstruction(
|
|
|
+ old_value,
|
|
|
+ '+',
|
|
|
+ tree_ir.LiteralInstruction(1)))
|
|
|
+
|
|
|
+ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD = 200
|
|
|
+"""The threshold temperature at which fast-jit will be used."""
|
|
|
+
|
|
|
+def compile_function_body_adaptive(
|
|
|
+ jit, function_name, body_id, task_root,
|
|
|
+ temperature_heuristic=favor_large_functions):
|
|
|
+ """Compile the function with the given name and body id. An execution engine is picked
|
|
|
+ automatically, and the function may be compiled again at a later time."""
|
|
|
+ # The general idea behind this compilation technique is to first use the baseline JIT
|
|
|
+ # to compile a function, and then switch to the fast JIT when we determine that doing
|
|
|
+ # so would be a good idea. We maintain a 'temperature' counter, which has an initial value
|
|
|
+ # and gets incremented every time the function is executed.
|
|
|
+
|
|
|
+ body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
|
|
|
+ initial_temperature, increment_temperature = temperature_heuristic(body_bytecode)
|
|
|
+
|
|
|
+ if initial_temperature >= ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD:
|
|
|
+ # Initial temperature exceeds the fast-jit threshold.
|
|
|
+ # Compile this thing with fast-jit right away.
|
|
|
+ yield [
|
|
|
+ ("TAIL_CALL_ARGS",
|
|
|
+ [compile_function_body_fast, (jit, function_name, body_id, task_root)])]
|
|
|
+
|
|
|
+ (_, parameter_list, _), = yield [
|
|
|
+ ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
|
|
|
+
|
|
|
+ temperature_counter_name = jit.import_value(
|
|
|
+ initial_temperature, function_name + "_temperature_counter")
|
|
|
+
|
|
|
+ can_rejit_name = jit.get_can_rejit_name(function_name)
|
|
|
+ jit.jit_globals[can_rejit_name] = True
|
|
|
+
|
|
|
+ # This tree represents the following logic:
|
|
|
+ #
|
|
|
+ # if can_rejit:
|
|
|
+ # global temperature_counter
|
|
|
+ # temperature_counter = increment_temperature(temperature_counter)
|
|
|
+ # if temperature_counter >= ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD:
|
|
|
+ # yield [("CALL_KWARGS", [jit_runtime.JIT_REJIT_FUNCTION_NAME, {...}])]
|
|
|
+ # yield [("TAIL_CALL_KWARGS", [function_name, {...}])]
|
|
|
+
|
|
|
+ header = tree_ir.SelectInstruction(
|
|
|
+ tree_ir.LoadGlobalInstruction(can_rejit_name),
|
|
|
+ tree_ir.create_block(
|
|
|
+ tree_ir.DeclareGlobalInstruction(temperature_counter_name),
|
|
|
+ tree_ir.IgnoreInstruction(
|
|
|
+ tree_ir.StoreGlobalInstruction(
|
|
|
+ temperature_counter_name,
|
|
|
+ increment_temperature(
|
|
|
+ tree_ir.LoadGlobalInstruction(temperature_counter_name)))),
|
|
|
+ tree_ir.SelectInstruction(
|
|
|
+ tree_ir.BinaryInstruction(
|
|
|
+ tree_ir.LoadGlobalInstruction(temperature_counter_name),
|
|
|
+ '>=',
|
|
|
+ tree_ir.LiteralInstruction(ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD)),
|
|
|
+ tree_ir.create_block(
|
|
|
+ tree_ir.RunGeneratorFunctionInstruction(
|
|
|
+ tree_ir.LoadGlobalInstruction(jit_runtime.JIT_REJIT_FUNCTION_NAME),
|
|
|
+ tree_ir.DictionaryLiteralInstruction([
|
|
|
+ (tree_ir.LiteralInstruction('task_root'),
|
|
|
+ bytecode_to_tree.load_task_root()),
|
|
|
+ (tree_ir.LiteralInstruction('body_id'),
|
|
|
+ tree_ir.LiteralInstruction(body_id)),
|
|
|
+ (tree_ir.LiteralInstruction('function_name'),
|
|
|
+ tree_ir.LiteralInstruction(function_name)),
|
|
|
+ (tree_ir.LiteralInstruction('compile_function_body'),
|
|
|
+ tree_ir.LoadGlobalInstruction(
|
|
|
+ jit_runtime.JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME))]),
|
|
|
+ result_type=tree_ir.NO_RESULT_TYPE),
|
|
|
+ tree_ir.create_jit_call(
|
|
|
+ tree_ir.LoadGlobalInstruction(function_name),
|
|
|
+ [(name, tree_ir.LoadLocalInstruction(name)) for name in parameter_list],
|
|
|
+ tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
|
|
|
+ tree_ir.RunTailGeneratorFunctionInstruction)),
|
|
|
+ tree_ir.EmptyInstruction())),
|
|
|
+ tree_ir.EmptyInstruction())
|
|
|
+
|
|
|
+ # Compile with the baseline JIT, and insert the header.
|
|
|
+ yield [
|
|
|
+ ("TAIL_CALL_ARGS",
|
|
|
+ [compile_function_body_baseline,
|
|
|
+ (jit, function_name, body_id, task_root, header)])]
|