Browse Source

Allow the kernel to switch between baseline and complex JIT

jonathanvdc 8 years ago
parent
commit
7deb185691
2 changed files with 57 additions and 34 deletions
  1. 53 34
      kernel/modelverse_jit/jit.py
  2. 4 0
      kernel/modelverse_kernel/main.py

+ 53 - 34
kernel/modelverse_jit/jit.py

@@ -133,6 +133,7 @@ class ModelverseJit(object):
         self.thunks_enabled = True
         self.jit_success_log_function = None
         self.jit_code_log_function = None
+        self.compile_function_body = compile_function_body_baseline
 
     def set_jit_enabled(self, is_enabled=True):
         """Enables or disables the JIT."""
@@ -176,6 +177,10 @@ class ModelverseJit(object):
            Function definitions of jitted functions are then sent to said log."""
         self.jit_code_log_function = log_function
 
+    def set_function_body_compiler(self, compile_function_body):
+        """Sets the function that the JIT uses to compile function bodies."""
+        self.compile_function_body = compile_function_body
+
     def mark_entry_point(self, body_id):
         """Marks the node with the given identifier as a function entry point."""
         if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
@@ -433,11 +438,9 @@ class ModelverseJit(object):
         self.jitted_entry_points[body_id] = function_name
         self.jit_globals[function_name] = None
 
-        (parameter_ids, parameter_list, is_mutable), = yield [
+        (_, _, is_mutable), = yield [
             ("CALL_ARGS", [self.jit_signature, (body_id,)])]
 
-        param_dict = dict(zip(parameter_ids, parameter_list))
-        body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
         dependencies = set([body_id])
         self.compilation_dependencies[body_id] = dependencies
 
@@ -462,41 +465,13 @@ class ModelverseJit(object):
             # We can't just JIT mutable functions. That'd be dangerous.
             raise JitCompilationFailedException(
                 "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
-        body_bytecode, = yield [("CALL_ARGS", [self.jit_parse_bytecode, (body_id,)])]
-        state = bytecode_to_tree.AnalysisState(
-            self, body_id, task_root, body_param_dict,
-            self.max_instructions)
-        constructed_body, = yield [("CALL_ARGS", [state.analyze, (body_bytecode,)])]
-        if self.jit_code_log_function is not None:
-            bytecode_analyzer = bytecode_to_cfg.AnalysisState(param_dict)
-            bytecode_analyzer.analyze(body_bytecode)
-            yield [
-                ("CALL_ARGS", [cfg_optimization.optimize, (bytecode_analyzer.entry_point, self)])]
-            self.jit_code_log_function(
-                "CFG for function '%s' at '%d':\n%s" % (
-                    function_name, body_id,
-                    '\n'.join(
-                        map(
-                            str,
-                            cfg_optimization.get_all_reachable_blocks(
-                                bytecode_analyzer.entry_point)))))
-            cfg_func = create_bare_function(
-                function_name, parameter_list,
-                cfg_to_tree.lower_flow_graph(bytecode_analyzer.entry_point, self))
-            self.jit_code_log_function(
-                "Lowered CFG for function '%s' at '%d':\n%s" % (
-                    function_name, body_id, cfg_func))
+
+        constructed_function, = yield [
+            ("CALL_ARGS", [self.compile_function_body, (self, function_name, body_id, task_root)])]
 
         yield [("END_TRY", [])]
         del self.compilation_dependencies[body_id]
 
-        # Optimize the function's body.
-        constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
-
-        # Wrap the tree IR in a function definition.
-        constructed_function = create_function(
-            function_name, parameter_list, param_dict, body_param_dict, constructed_body)
-
         # Convert the function definition to Python code, and compile it.
         compiled_function = self.jit_define_function(function_name, constructed_function)
 
@@ -639,3 +614,47 @@ class ModelverseJit(object):
                     tree_ir.LiteralInstruction('value')),
                 tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)),
             global_name)
+
+def compile_function_body_baseline(jit, function_name, body_id, task_root):
+    """Have the baseline JIT compile the function with the given name and body id."""
+    (parameter_ids, parameter_list, _), = yield [
+        ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
+    param_dict = dict(zip(parameter_ids, parameter_list))
+    body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
+    body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
+    state = bytecode_to_tree.AnalysisState(
+        jit, body_id, task_root, body_param_dict,
+        jit.max_instructions)
+    constructed_body, = yield [("CALL_ARGS", [state.analyze, (body_bytecode,)])]
+
+    # Optimize the function's body.
+    constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
+
+    # Wrap the tree IR in a function definition.
+    raise primitive_functions.PrimitiveFinished(
+        create_function(
+            function_name, parameter_list, param_dict, body_param_dict, constructed_body))
+
+def compile_function_body_fast(jit, function_name, body_id, _):
+    """Have the fast JIT compile the function with the given name and body id."""
+    (parameter_ids, parameter_list, _), = yield [
+        ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
+    param_dict = dict(zip(parameter_ids, parameter_list))
+    body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
+    bytecode_analyzer = bytecode_to_cfg.AnalysisState(param_dict)
+    bytecode_analyzer.analyze(body_bytecode)
+    yield [
+        ("CALL_ARGS", [cfg_optimization.optimize, (bytecode_analyzer.entry_point, jit)])]
+    if jit.jit_code_log_function is not None:
+        jit.jit_code_log_function(
+            "CFG for function '%s' at '%d':\n%s" % (
+                function_name, body_id,
+                '\n'.join(
+                    map(
+                        str,
+                        cfg_optimization.get_all_reachable_blocks(
+                            bytecode_analyzer.entry_point)))))
+    raise primitive_functions.PrimitiveFinished(
+        create_bare_function(
+            function_name, parameter_list,
+            cfg_to_tree.lower_flow_graph(bytecode_analyzer.entry_point, jit)))

+ 4 - 0
kernel/modelverse_kernel/main.py

@@ -48,6 +48,10 @@ class ModelverseKernel(object):
         #
         #     self.jit.set_jit_enabled(False)
         #
+        # To always use the fast JIT, uncomment the line below:
+        #
+        #     self.jit.set_function_body_compiler(jit.compile_function_body_fast)
+        #
         # To disable direct calls in the JIT, uncomment the line below:
         #
         #     self.jit.allow_direct_calls(False)