瀏覽代碼

Move parts of compile_function_body_adaptive into a separate class

jonathanvdc 8 年之前
父節點
當前提交
721b3d18d3
共有 1 個文件被更改,包括 72 次插入55 次删除
  1. 72 55
      kernel/modelverse_jit/jit.py

+ 72 - 55
kernel/modelverse_jit/jit.py

@@ -848,6 +848,76 @@ def favor_loops(body_bytecode):
 
     return temperature, 1
 
+class AdaptiveJitState(object):
+    """Shared state for adaptive JIT compilation."""
+    def __init__(
+            self, temperature_counter_name,
+            temperature_increment, can_rejit_name):
+        self.temperature_counter_name = temperature_counter_name
+        self.temperature_increment = temperature_increment
+        self.can_rejit_name = can_rejit_name
+
+    def compile_baseline(
+            self, jit, function_name, body_id, task_root):
+        """Compiles the given function with the baseline JIT, and inserts logic that controls
+           the temperature counter."""
+        (_, parameter_list, _), = yield [
+            ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
+
+        # This tree represents the following logic:
+        #
+        # if can_rejit:
+        #     global temperature_counter
+        #     temperature_counter = temperature_counter + temperature_increment
+        #     if temperature_counter >= ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD:
+        #         yield [("CALL_KWARGS", [jit_runtime.JIT_REJIT_FUNCTION_NAME, {...}])]
+        #         yield [("TAIL_CALL_KWARGS", [function_name, {...}])]
+
+        header = tree_ir.SelectInstruction(
+            tree_ir.LoadGlobalInstruction(self.can_rejit_name),
+            tree_ir.create_block(
+                tree_ir.DeclareGlobalInstruction(self.temperature_counter_name),
+                tree_ir.IgnoreInstruction(
+                    tree_ir.StoreGlobalInstruction(
+                        self.temperature_counter_name,
+                        tree_ir.BinaryInstruction(
+                            tree_ir.LoadGlobalInstruction(self.temperature_counter_name),
+                            '+',
+                            tree_ir.LiteralInstruction(self.temperature_increment)))),
+                tree_ir.SelectInstruction(
+                    tree_ir.BinaryInstruction(
+                        tree_ir.LoadGlobalInstruction(self.temperature_counter_name),
+                        '>=',
+                        tree_ir.LiteralInstruction(ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD)),
+                    tree_ir.create_block(
+                        tree_ir.RunGeneratorFunctionInstruction(
+                            tree_ir.LoadGlobalInstruction(jit_runtime.JIT_REJIT_FUNCTION_NAME),
+                            tree_ir.DictionaryLiteralInstruction([
+                                (tree_ir.LiteralInstruction('task_root'),
+                                 bytecode_to_tree.load_task_root()),
+                                (tree_ir.LiteralInstruction('body_id'),
+                                 tree_ir.LiteralInstruction(body_id)),
+                                (tree_ir.LiteralInstruction('function_name'),
+                                 tree_ir.LiteralInstruction(function_name)),
+                                (tree_ir.LiteralInstruction('compile_function_body'),
+                                 tree_ir.LoadGlobalInstruction(
+                                     jit_runtime.JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME))]),
+                            result_type=tree_ir.NO_RESULT_TYPE),
+                        bytecode_to_tree.create_return(
+                            tree_ir.create_jit_call(
+                                tree_ir.LoadGlobalInstruction(function_name),
+                                [(name, tree_ir.LoadLocalInstruction(name))
+                                 for name in parameter_list],
+                                tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME)))),
+                    tree_ir.EmptyInstruction())),
+            tree_ir.EmptyInstruction())
+
+        # Compile with the baseline JIT, and insert the header.
+        yield [
+            ("TAIL_CALL_ARGS",
+             [compile_function_body_baseline,
+              (jit, function_name, body_id, task_root, header)])]
+
 def compile_function_body_adaptive(
         jit, function_name, body_id, task_root,
         temperature_heuristic=favor_loops):
@@ -871,64 +941,11 @@ def compile_function_body_adaptive(
             ("TAIL_CALL_ARGS",
              [compile_function_body_fast, (jit, function_name, body_id, task_root)])]
 
-    (_, parameter_list, _), = yield [
-        ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
-
     temperature_counter_name = jit.import_value(
         initial_temperature, function_name + "_temperature_counter")
 
     can_rejit_name = jit.get_can_rejit_name(function_name)
     jit.jit_globals[can_rejit_name] = True
 
-    # This tree represents the following logic:
-    #
-    # if can_rejit:
-    #     global temperature_counter
-    #     temperature_counter = temperature_counter + temperature_increment
-    #     if temperature_counter >= ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD:
-    #         yield [("CALL_KWARGS", [jit_runtime.JIT_REJIT_FUNCTION_NAME, {...}])]
-    #         yield [("TAIL_CALL_KWARGS", [function_name, {...}])]
-
-    header = tree_ir.SelectInstruction(
-        tree_ir.LoadGlobalInstruction(can_rejit_name),
-        tree_ir.create_block(
-            tree_ir.DeclareGlobalInstruction(temperature_counter_name),
-            tree_ir.IgnoreInstruction(
-                tree_ir.StoreGlobalInstruction(
-                    temperature_counter_name,
-                    tree_ir.BinaryInstruction(
-                        tree_ir.LoadGlobalInstruction(temperature_counter_name),
-                        '+',
-                        tree_ir.LiteralInstruction(temperature_increment)))),
-            tree_ir.SelectInstruction(
-                tree_ir.BinaryInstruction(
-                    tree_ir.LoadGlobalInstruction(temperature_counter_name),
-                    '>=',
-                    tree_ir.LiteralInstruction(ADAPTIVE_FAST_JIT_TEMPERATURE_THRESHOLD)),
-                tree_ir.create_block(
-                    tree_ir.RunGeneratorFunctionInstruction(
-                        tree_ir.LoadGlobalInstruction(jit_runtime.JIT_REJIT_FUNCTION_NAME),
-                        tree_ir.DictionaryLiteralInstruction([
-                            (tree_ir.LiteralInstruction('task_root'),
-                             bytecode_to_tree.load_task_root()),
-                            (tree_ir.LiteralInstruction('body_id'),
-                             tree_ir.LiteralInstruction(body_id)),
-                            (tree_ir.LiteralInstruction('function_name'),
-                             tree_ir.LiteralInstruction(function_name)),
-                            (tree_ir.LiteralInstruction('compile_function_body'),
-                             tree_ir.LoadGlobalInstruction(
-                                 jit_runtime.JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME))]),
-                        result_type=tree_ir.NO_RESULT_TYPE),
-                    bytecode_to_tree.create_return(
-                        tree_ir.create_jit_call(
-                            tree_ir.LoadGlobalInstruction(function_name),
-                            [(name, tree_ir.LoadLocalInstruction(name)) for name in parameter_list],
-                            tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME)))),
-                tree_ir.EmptyInstruction())),
-        tree_ir.EmptyInstruction())
-
-    # Compile with the baseline JIT, and insert the header.
-    yield [
-        ("TAIL_CALL_ARGS",
-         [compile_function_body_baseline,
-          (jit, function_name, body_id, task_root, header)])]
+    state = AdaptiveJitState(temperature_counter_name, temperature_increment, can_rejit_name)
+    yield [("TAIL_CALL_ARGS", [state.compile_baseline, (jit, function_name, body_id, task_root)])]