Add an adaptive JIT heuristic that favors small functions

jonathanvdc 8 years ago
commit 8c89761326
2 changed files with 41 additions and 2 deletions
  1. hybrid_server/classes/mvkcontroller.xml (+5 -1)
  2. kernel/modelverse_jit/jit.py (+36 -1)

+5 -1   hybrid_server/classes/mvkcontroller.xml

@@ -40,8 +40,12 @@
                     self.mvk.jit.set_function_body_compiler(jit.compile_function_body_fast)
                 elif opt == 'baseline-jit':
                     self.mvk.jit.set_function_body_compiler(jit.compile_function_body_baseline)
-                elif opt == 'adaptive-jit':
+                elif opt == 'adaptive-jit' or opt == 'adaptive-jit-favor-large-functions':
                     self.mvk.jit.set_function_body_compiler(jit.compile_function_body_adaptive)
                     self.mvk.jit.set_function_body_compiler(jit.compile_function_body_adaptive)
+                elif opt == 'adaptive-jit-favor-small-functions':
+                    self.mvk.jit.set_function_body_compiler(
+                        lambda *args: jit.compile_function_body_adaptive(
+                            *args, temperature_heuristic=jit.favor_small_functions))
                 else:
                     print("warning: unknown kernel option '%s'." % opt)
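
The new 'adaptive-jit-favor-small-functions' branch differs from the plain adaptive branch only in how it binds the compiler: the lambda forwards the positional arguments unchanged and pins the temperature_heuristic keyword. A minimal sketch of that same binding via functools.partial, using only names that appear in this diff (the import path is an assumption about how the controller sees the jit module):

    from functools import partial

    # Assumption: the jit module is importable like this from the controller's
    # point of view; mvkcontroller.xml already has it in scope as `jit`.
    import modelverse_jit.jit as jit

    # Equivalent to the controller's lambda: positional arguments are forwarded
    # unchanged, and temperature_heuristic is fixed to the heuristic added by
    # this commit.
    compile_small_first = partial(
        jit.compile_function_body_adaptive,
        temperature_heuristic=jit.favor_small_functions)

    # self.mvk.jit.set_function_body_compiler(compile_small_first)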
 
 

+36 -1   kernel/modelverse_jit/jit.py

@@ -1,4 +1,5 @@
 import keyword
+from collections import defaultdict
 import modelverse_kernel.primitives as primitive_functions
 import modelverse_jit.bytecode_parser as bytecode_parser
 import modelverse_jit.bytecode_to_tree as bytecode_to_tree
@@ -129,6 +130,8 @@ class ModelverseJit(object):
         self.global_functions_inv = {}
         # bytecode_graphs maps body ids to their parsed bytecode graphs.
         self.bytecode_graphs = {}
+        # jitted_function_aliases maps body ids to known aliases.
+        self.jitted_function_aliases = defaultdict(set)
         self.jit_count = 0
         self.max_instructions = max_instructions
         self.compiled_function_lookup = compiled_function_lookup
@@ -587,11 +590,15 @@ class ModelverseJit(object):
 
         yield [("TRY", [])]
         yield [("CATCH", [jit_runtime.JitCompilationFailedException, __handle_jit_failed])]
-        yield [
+        jitted_function, = yield [
             ("CALL_ARGS",
              [self.jit_recompile, (task_root, body_id, function_name, compile_function_body)])]
         yield [("END_TRY", [])]
 
+        # Update all aliases.
+        for function_alias in self.jitted_function_aliases[body_id]:
+            self.jit_globals[function_alias] = jitted_function
+
     def jit_thunk(self, get_function_body, global_name=None):
         """Creates a thunk from the given IR tree that computes the function's body id.
            This thunk is a function that will invoke the function whose body id is retrieved.
@@ -623,6 +630,7 @@ class ModelverseJit(object):
             if compiled_function is not None:
                 # Replace this thunk by the compiled function.
                 self.jit_globals[thunk_name] = compiled_function
+                self.jitted_function_aliases[body_id].add(thunk_name)
             else:
                 def __handle_jit_exception(_):
                     # Replace this thunk by a different thunk: one that calls the interpreter
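
The alias bookkeeping introduced across the two hunks above follows a simple pattern: remember every global name that was bound to a given body id, and rebind all of them when that body is recompiled, so a stale thunk name can never keep calling old code. A self-contained sketch of that pattern using only the standard library (the class and method names are illustrative, not the ModelverseJit API):

    from collections import defaultdict

    class AliasTable(object):
        """Tracks which global names alias which function body (illustrative)."""

        def __init__(self):
            self.globals = {}
            # Maps a body id to every global name known to point at it.
            self.aliases = defaultdict(set)

        def register(self, body_id, name, function):
            self.globals[name] = function
            self.aliases[body_id].add(name)

        def recompile(self, body_id, new_function):
            # Rebind every alias so no caller holds on to the old code.
            for name in self.aliases[body_id]:
                self.globals[name] = new_function

    table = AliasTable()
    table.register(42, 'thunk_42', lambda: 'interpreted')
    table.recompile(42, lambda: 'jitted')
    assert table.globals['thunk_42']() == 'jitted'
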
@@ -763,6 +771,16 @@ def favor_large_functions(body_bytecode):
     """Computes the initial temperature of a function based on the size of
        its body bytecode. Larger functions are favored and the temperature
        is incremented by one on every call."""
+    # The rationale for this heuristic is that it does some damage control:
+    # we can afford to decide (wrongly) not to fast-jit a small function,
+    # because we can just fast-jit that function later on. Since the function
+    # is so small, it will (hopefully) not be able to deal us a heavy blow in
+    # terms of performance.
+    #
+    # If we decide not to fast-jit a large function however, we might end up
+    # in a situation where said function runs for a long time before we
+    # realize that we really should have jitted it. And that's exactly what
+    # this heuristic tries to avoid.
     return (
         len(body_bytecode.get_reachable()),
         lambda old_value:
@@ -771,6 +789,23 @@ def favor_large_functions(body_bytecode):
             '+',
             tree_ir.LiteralInstruction(1)))
 
+def favor_small_functions(body_bytecode):
+    """Computes the initial temperature of a function based on the size of
+       its body bytecode. Smaller functions are favored and the temperature
+       is incremented by one on every call."""
+    # The rationale for this heuristic is that small functions are easy to
+    # fast-jit, because they probably won't trigger the non-linear complexity
+    # of fast-jit's algorithms. So it might be cheaper to fast-jit small
+    # functions and get a performance boost from that than to fast-jit large
+    # functions.
+    return (
+        ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD - len(body_bytecode.get_reachable()),
+        lambda old_value:
+        tree_ir.BinaryInstruction(
+            old_value,
+            '+',
+            tree_ir.LiteralInstruction(1)))
+
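
Both favor_large_functions and favor_small_functions share one contract: given a function's body bytecode, they return a pair of an initial temperature and a callback that builds the per-call temperature increment as a tree_ir expression. A simplified plain-Python model of how an adaptive JIT could consume such a pair, with integers standing in for tree_ir instructions (only the threshold constant, shown in the trailing context below, comes from the module; the helper names are illustrative):

    ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD = 200

    def favor_small(body_size):
        """Small bodies start close to the threshold, so they are jitted sooner."""
        return (ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD - body_size,
                lambda old_value: old_value + 1)

    def favor_large(body_size):
        """Large bodies start close to the threshold instead."""
        return (body_size, lambda old_value: old_value + 1)

    def crosses_threshold(heuristic, body_size, call_count):
        # Start at the initial temperature and apply the increment once per call.
        temperature, bump = heuristic(body_size)
        for _ in range(call_count):
            temperature = bump(temperature)
        return temperature >= ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD

    # A 10-instruction body reaches the fast-jit threshold after only 10 calls
    # under the small-function heuristic, versus 190 calls under the other one.
    assert crosses_threshold(favor_small, 10, 10)
    assert not crosses_threshold(favor_large, 10, 10)
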
 ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD = 200
 """The threshold temperature at which fast-jit will be used."""