Переглянути джерело

Define an adaptive JIT mode

jonathanvdc 8 роки тому
батько
коміт
0f17d07805

+ 77 - 0
kernel/modelverse_jit/bytecode_ir.py

@@ -6,6 +6,24 @@ class Instruction(object):
         self.next_instruction = None
         self.debug_information = None
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        raise NotImplementedError()
+
+    def get_reachable(self):
+        """Gets the set of all instructions that are reachable from the given instruction, including
+           this instruction."""
+        results = set()
+        stack = [self]
+        while len(stack) > 0:
+            instr = stack.pop()
+            results.add(instr)
+            for other in instr.get_directly_reachable():
+                if other not in results:
+                    stack.append(other)
+
+        return results
+
 class VariableNode(object):
     """Represents a variable node, which has an identifier and an optional name."""
     def __init__(self, node_id, name):
@@ -26,6 +44,13 @@ class SelectInstruction(Instruction):
         self.if_clause = if_clause
         self.else_clause = else_clause
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        if self.else_clause is None:
+            return (self.condition, self.if_clause)
+        else:
+            return (self.condition, self.if_clause, self.else_clause)
+
     constructor_parameters = (
         ('cond', Instruction),
         ('then', Instruction),
@@ -42,6 +67,10 @@ class WhileInstruction(Instruction):
         self.condition = condition
         self.body = body
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return (self.condition, self.body)
+
     constructor_parameters = (
         ('cond', Instruction),
         ('body', Instruction))
@@ -56,6 +85,10 @@ class BreakInstruction(Instruction):
         Instruction.__init__(self)
         self.loop = loop
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return (self.loop,)
+
     constructor_parameters = (('while', WhileInstruction),)
 
     def __repr__(self):
@@ -68,6 +101,10 @@ class ContinueInstruction(Instruction):
         Instruction.__init__(self)
         self.loop = loop
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return (self.loop,)
+
     constructor_parameters = (('while', WhileInstruction),)
 
     def __repr__(self):
@@ -81,6 +118,10 @@ class ReturnInstruction(Instruction):
         Instruction.__init__(self)
         self.value = value
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return (self.value,)
+
     constructor_parameters = (('value', Instruction),)
 
     def __repr__(self):
@@ -95,6 +136,10 @@ class CallInstruction(Instruction):
         self.target = target
         self.argument_list = argument_list
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return (self.target,) + tuple((value for _, value in self.argument_list))
+
     def __repr__(self):
         return '@%r: CallInstruction(@%r, [%s])' % (
             id(self), id(self.target),
@@ -108,6 +153,10 @@ class ConstantInstruction(Instruction):
         self.constant_id = constant_id
         assert self.constant_id is not None
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return ()
+
     constructor_parameters = (('node', int),)
 
     def __repr__(self):
@@ -119,6 +168,10 @@ class InputInstruction(Instruction):
     def __init__(self):
         Instruction.__init__(self)
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return ()
+
     constructor_parameters = ()
 
     def __repr__(self):
@@ -131,6 +184,10 @@ class OutputInstruction(Instruction):
         Instruction.__init__(self)
         self.value = value
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return (self.value,)
+
     constructor_parameters = (('value', Instruction),)
 
     def __repr__(self):
@@ -143,6 +200,10 @@ class DeclareInstruction(Instruction):
         Instruction.__init__(self)
         self.variable = variable
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return ()
+
     constructor_parameters = (('var', VariableNode),)
 
     def __repr__(self):
@@ -155,6 +216,10 @@ class GlobalInstruction(Instruction):
         Instruction.__init__(self)
         self.variable = variable
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return ()
+
     constructor_parameters = (('var', VariableNode),)
 
     def __repr__(self):
@@ -168,6 +233,10 @@ class ResolveInstruction(Instruction):
         Instruction.__init__(self)
         self.variable = variable
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return ()
+
     constructor_parameters = (('var', VariableNode),)
 
     def __repr__(self):
@@ -181,6 +250,10 @@ class AccessInstruction(Instruction):
         Instruction.__init__(self)
         self.pointer = pointer
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return (self.pointer,)
+
     constructor_parameters = (('var', Instruction),)
 
     def __repr__(self):
@@ -195,6 +268,10 @@ class AssignInstruction(Instruction):
         self.pointer = pointer
         self.value = value
 
+    def get_directly_reachable(self):
+        """Gets all instructions that are directly reachable from this instruction."""
+        return (self.pointer, self.value)
+
     constructor_parameters = (
         ('var', Instruction),
         ('value', Instruction))

+ 136 - 9
kernel/modelverse_jit/jit.py

@@ -117,6 +117,8 @@ class ModelverseJit(object):
             jit_runtime.GET_INPUT_FUNCTION_NAME : jit_runtime.get_input,
             jit_runtime.JIT_THUNK_CONSTANT_FUNCTION_NAME : self.jit_thunk_constant_function,
             jit_runtime.JIT_THUNK_GLOBAL_FUNCTION_NAME : self.jit_thunk_global,
+            jit_runtime.JIT_REJIT_FUNCTION_NAME : self.jit_rejit,
+            jit_runtime.JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME : compile_function_body_fast,
             jit_runtime.UNREACHABLE_FUNCTION_NAME : jit_runtime.unreachable
         }
         # jitted_entry_points maps body ids to values in jit_globals.
@@ -454,10 +456,6 @@ class ModelverseJit(object):
            non-jittable, then a `JitCompilationFailedException` exception is thrown."""
         if body_id is None:
             raise ValueError('body_id cannot be None')
-        elif body_id in self.jitted_entry_points:
-            # We have already compiled this function.
-            raise primitive_functions.PrimitiveFinished(
-                self.jit_globals[self.jitted_entry_points[body_id]])
         elif body_id in self.no_jit_entry_points:
             # We're not allowed to jit this function or have tried and failed before.
             raise JitCompilationFailedException(
@@ -471,9 +469,12 @@ class ModelverseJit(object):
                     '' if suggested_name is None else "'" + suggested_name + "'",
                     body_id))
 
-    def jit_recompile(self, task_root, body_id, function_name):
+    def jit_recompile(self, task_root, body_id, function_name, compile_function_body=None):
         """Replaces the function with the given name by compiling the bytecode at the given
            body id."""
+        if compile_function_body is None:
+            compile_function_body = self.compile_function_body
+
         self.check_jittable(body_id, function_name)
 
         # Generate a name for the function we're about to analyze, and pretend that
@@ -510,7 +511,7 @@ class ModelverseJit(object):
                 "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
 
         constructed_function, = yield [
-            ("CALL_ARGS", [self.compile_function_body, (self, function_name, body_id, task_root)])]
+            ("CALL_ARGS", [compile_function_body, (self, function_name, body_id, task_root)])]
 
         yield [("END_TRY", [])]
         del self.compilation_dependencies[body_id]
@@ -532,6 +533,10 @@ class ModelverseJit(object):
         else:
             return None
 
+    def get_can_rejit_name(self, function_name):
+        """Gets the name of the given jitted function's can-rejit flag."""
+        return function_name + "_can_rejit"
+
     def jit_define_function(self, function_name, function_def):
         """Converts the given tree-IR function definition to Python code, defines it,
            and extracts the resulting function."""
@@ -558,11 +563,35 @@ class ModelverseJit(object):
 
     def jit_compile(self, task_root, body_id, suggested_name=None):
         """Tries to jit the function defined by the given entry point id and parameter list."""
-        # Generate a name for the function we're about to analyze, and pretend that
-        # it already exists. (we need to do this for recursive functions)
+        if body_id is None:
+            raise ValueError('body_id cannot be None')
+        elif body_id in self.jitted_entry_points:
+            # We have already compiled this function.
+            raise primitive_functions.PrimitiveFinished(
+                self.jit_globals[self.jitted_entry_points[body_id]])
+
+        # Generate a name for the function we're about to analyze, and 're-compile'
+        # it for the first time.
         function_name = self.generate_function_name(body_id, suggested_name)
         yield [("TAIL_CALL_ARGS", [self.jit_recompile, (task_root, body_id, function_name)])]
 
+    def jit_rejit(self, task_root, body_id, function_name, compile_function_body=None):
+        """Re-compiles the given function. If compilation fails, then the can-rejit
+           flag is set to false."""
+        old_jitted_func = self.jitted_entry_points[body_id]
+        def __handle_jit_failed(_):
+            self.jit_globals[self.get_can_rejit_name(function_name)] = False
+            self.jitted_entry_points[body_id] = old_jitted_func
+            self.no_jit_entry_points.remove(body_id)
+            raise primitive_functions.PrimitiveFinished(None)
+
+        yield [("TRY", [])]
+        yield [("CATCH", [jit_runtime.JitCompilationFailedException, __handle_jit_failed])]
+        yield [
+            ("CALL_ARGS",
+             [self.jit_recompile, (task_root, body_id, function_name, compile_function_body)])]
+        yield [("END_TRY", [])]
+
     def jit_thunk(self, get_function_body, global_name=None):
         """Creates a thunk from the given IR tree that computes the function's body id.
            This thunk is a function that will invoke the function whose body id is retrieved.
@@ -680,7 +709,7 @@ class ModelverseJit(object):
                 tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)),
             global_name)
 
-def compile_function_body_baseline(jit, function_name, body_id, task_root):
+def compile_function_body_baseline(jit, function_name, body_id, task_root, header=None):
     """Have the baseline JIT compile the function with the given name and body id."""
     (parameter_ids, parameter_list, _), = yield [
         ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
@@ -692,6 +721,9 @@ def compile_function_body_baseline(jit, function_name, body_id, task_root):
         jit.max_instructions)
     constructed_body, = yield [("CALL_ARGS", [state.analyze, (body_bytecode,)])]
 
+    if header is not None:
+        constructed_body = tree_ir.create_block(header, constructed_body)
+
     # Optimize the function's body.
     constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
 
@@ -726,3 +758,98 @@ def compile_function_body_fast(jit, function_name, body_id, _):
         create_bare_function(
             function_name, parameter_list,
             constructed_body))
+
+def favor_large_functions(body_bytecode):
+    """Computes the initial temperature of a function based on the size of
+       its body bytecode. Larger functions are favored and the temperature
+       is incremented by one on every call."""
+    return (
+        len(body_bytecode.get_reachable()),
+        lambda old_value:
+        tree_ir.BinaryInstruction(
+            old_value,
+            '+',
+            tree_ir.LiteralInstruction(1)))
+
+ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD = 200
+"""The threshold temperature at which fast-jit will be used."""
+
+def compile_function_body_adaptive(
+        jit, function_name, body_id, task_root,
+        temperature_heuristic=favor_large_functions):
+    """Compile the function with the given name and body id. An execution engine is picked
+       automatically, and the function may be compiled again at a later time."""
+    # The general idea behind this compilation technique is to first use the baseline JIT
+    # to compile a function, and then switch to the fast JIT when we determine that doing
+    # so would be a good idea. We maintain a 'temperature' counter, which has an initial value
+    # and gets incremented every time the function is executed.
+
+    body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
+    initial_temperature, increment_temperature = temperature_heuristic(body_bytecode)
+
+    if initial_temperature >= ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD:
+        # Initial temperature exceeds the fast-jit threshold.
+        # Compile this thing with fast-jit right away.
+        yield [
+            ("TAIL_CALL_ARGS",
+             [compile_function_body_fast, (jit, function_name, body_id, task_root)])]
+
+    (_, parameter_list, _), = yield [
+        ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
+
+    temperature_counter_name = jit.import_value(
+        initial_temperature, function_name + "_temperature_counter")
+
+    can_rejit_name = jit.get_can_rejit_name(function_name)
+    jit.jit_globals[can_rejit_name] = True
+
+    # This tree represents the following logic:
+    #
+    # if can_rejit:
+    #     global temperature_counter
+    #     temperature_counter = increment_temperature(temperature_counter)
+    #     if temperature_counter >= ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD:
+    #         yield [("CALL_KWARGS", [jit_runtime.JIT_REJIT_FUNCTION_NAME, {...}])]
+    #         yield [("TAIL_CALL_KWARGS", [function_name, {...}])]
+
+    header = tree_ir.SelectInstruction(
+        tree_ir.LoadGlobalInstruction(can_rejit_name),
+        tree_ir.create_block(
+            tree_ir.DeclareGlobalInstruction(temperature_counter_name),
+            tree_ir.IgnoreInstruction(
+                tree_ir.StoreGlobalInstruction(
+                    temperature_counter_name,
+                    increment_temperature(
+                        tree_ir.LoadGlobalInstruction(temperature_counter_name)))),
+            tree_ir.SelectInstruction(
+                tree_ir.BinaryInstruction(
+                    tree_ir.LoadGlobalInstruction(temperature_counter_name),
+                    '>=',
+                    tree_ir.LiteralInstruction(ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD)),
+                tree_ir.create_block(
+                    tree_ir.RunGeneratorFunctionInstruction(
+                        tree_ir.LoadGlobalInstruction(jit_runtime.JIT_REJIT_FUNCTION_NAME),
+                        tree_ir.DictionaryLiteralInstruction([
+                            (tree_ir.LiteralInstruction('task_root'),
+                             bytecode_to_tree.load_task_root()),
+                            (tree_ir.LiteralInstruction('body_id'),
+                             tree_ir.LiteralInstruction(body_id)),
+                            (tree_ir.LiteralInstruction('function_name'),
+                             tree_ir.LiteralInstruction(function_name)),
+                            (tree_ir.LiteralInstruction('compile_function_body'),
+                             tree_ir.LoadGlobalInstruction(
+                                 jit_runtime.JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME))]),
+                        result_type=tree_ir.NO_RESULT_TYPE),
+                    tree_ir.create_jit_call(
+                        tree_ir.LoadGlobalInstruction(function_name),
+                        [(name, tree_ir.LoadLocalInstruction(name)) for name in parameter_list],
+                        tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
+                        tree_ir.RunTailGeneratorFunctionInstruction)),
+                tree_ir.EmptyInstruction())),
+        tree_ir.EmptyInstruction())
+
+    # Compile with the baseline JIT, and insert the header.
+    yield [
+        ("TAIL_CALL_ARGS",
+         [compile_function_body_baseline,
+          (jit, function_name, body_id, task_root, header)])]

+ 6 - 0
kernel/modelverse_jit/runtime.py

@@ -25,6 +25,12 @@ JIT_THUNK_CONSTANT_FUNCTION_NAME = "__jit_thunk_constant_function"
 JIT_THUNK_GLOBAL_FUNCTION_NAME = "__jit_thunk_global"
 """The name of the jit_thunk_global function in the JIT's global context."""
 
+JIT_REJIT_FUNCTION_NAME = "__jit_rejit"
+"""The name of the rejit function in the JIT's global context."""
+
+JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME = "__jit_compile_function_body_fast"
+"""The name of the compile_function_body_fast function in the JIT's global context."""
+
 UNREACHABLE_FUNCTION_NAME = "__unreachable"
 """The name of the unreachable function in the JIT's global context."""
 

+ 64 - 10
kernel/modelverse_jit/tree_ir.py

@@ -823,12 +823,18 @@ class DictionaryLiteralInstruction(Instruction):
 
     def get_children(self):
         """Gets this instruction's sequence of child instructions."""
-        return [val for _, val in self.key_value_pairs]
+        results = []
+        for key, val in self.key_value_pairs:
+            results.append(key)
+            results.append(val)
+        return results
 
     def create(self, new_children):
         """Creates a new instruction of this type from the given sequence of child instructions."""
-        keys = [k for k, _ in self.key_value_pairs]
-        return DictionaryLiteralInstruction(zip(keys, new_children))
+        new_kv_pairs = []
+        for i in xrange(len(self.key_value_pairs)):
+            new_kv_pairs.append((new_children[2 * i], new_children[2 * i + 1]))
+        return DictionaryLiteralInstruction(new_kv_pairs)
 
     def has_definition_impl(self):
         """Tells if this instruction requires a definition."""
@@ -981,14 +987,11 @@ class RunGeneratorFunctionInstruction(StateInstruction):
 
 class RunTailGeneratorFunctionInstruction(StateInstruction):
     """An instruction that runs a generator function."""
-    def __init__(self, function, argument_dict):
+    def __init__(self, function, argument_dict, result_type=NO_RESULT_TYPE):
         StateInstruction.__init__(self)
         self.function = function
         self.argument_dict = argument_dict
-
-    def get_result_type_impl(self):
-        """Gets the type of value produced by this instruction."""
-        return NO_RESULT_TYPE
+        self.result_type_cache = result_type
 
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
@@ -1234,6 +1237,55 @@ class LoadGlobalInstruction(VariableInstruction):
     def __repr__(self):
         return 'LoadGlobalInstruction(%r)' % self.name
 
+class StoreGlobalInstruction(VariableInstruction):
+    """An instruction that assigns a value to a global variable."""
+    def __init__(self, name, value):
+        VariableInstruction.__init__(self, name)
+        self.value = value
+
+    def get_children(self):
+        """Gets this instruction's sequence of child instructions."""
+        return [self.value]
+
+    def create(self, new_children):
+        """Creates a new instruction of this type from the given sequence of child instructions."""
+        val, = new_children
+        return StoreGlobalInstruction(self.name, val)
+
+    def generate_python_def(self, code_generator):
+        """Generates a Python statement that executes this instruction.
+           The statement is appended immediately to the code generator."""
+        code_generator.append_move_definition(self, self.value)
+
+    def __repr__(self):
+        return 'StoreGlobalInstruction(%r, %r)' % (self.name, self.value)
+
+class DeclareGlobalInstruction(VariableInstruction):
+    """An instruction that declares a name as a global variable."""
+    def get_children(self):
+        """Gets this instruction's sequence of child instructions."""
+        return []
+
+    def get_result_type_impl(self):
+        """Gets the type of value produced by this instruction."""
+        return NO_RESULT_TYPE
+
+    def has_result_temporary(self):
+        """Tells if this instruction stores its result in a temporary."""
+        return False
+
+    def create(self, new_children):
+        """Creates a new instruction of this type from the given sequence of child instructions."""
+        return self
+
+    def generate_python_def(self, code_generator):
+        """Generates a Python statement that executes this instruction.
+           The statement is appended immediately to the code generator."""
+        code_generator.append_line('global %s' % self.name)
+
+    def __repr__(self):
+        return 'DeclareGlobalInstruction(%r)' % self.name
+
 class LoadIndexInstruction(Instruction):
     """An instruction that produces a value by indexing a specified expression with
        a given key."""
@@ -1602,7 +1654,9 @@ def create_block(*statements):
             statements[0],
             create_block(*statements[1:]))
 
-def create_jit_call(target, named_arguments, kwargs):
+def create_jit_call(
+        target, named_arguments, kwargs,
+        create_run_generator=RunGeneratorFunctionInstruction):
     """Creates a call that abides by the JIT's calling convention."""
     # A JIT call looks like this:
     #
@@ -1628,7 +1682,7 @@ def create_jit_call(target, named_arguments, kwargs):
             [kwargs]))
     return CompoundInstruction(
         create_block(*results),
-        RunGeneratorFunctionInstruction(
+        create_run_generator(
             target, arg_dict.create_load(), NODE_RESULT_TYPE))
 
 def evaluate_and_load(value):