Browse Source

Use thunks in the JIT

jonathanvdc 8 years ago
parent
commit
8cc038d626

+ 27 - 9
kernel/modelverse_jit/bytecode_to_tree.py

@@ -588,15 +588,33 @@ class AnalysisState(object):
                 self.register_function_var(target.pointer.variable.node_id)
                 resolved_var_name = target.pointer.variable.name
 
-                # Try to look up the name as a global.
-                _globals, = yield [("RD", [self.user_root, "globals"])]
-                global_var, = yield [("RD", [_globals, resolved_var_name])]
-                global_val, = yield [("RD", [global_var, "value"])]
-
-                if global_val is not None:
-                    result, = yield [("CALL_ARGS", [self.analyze_direct_call, (
-                        global_val, resolved_var_name, argument_list)])]
-                    raise primitive_functions.PrimitiveFinished(result)
+                if self.jit.thunks_enabled:
+                    # Analyze the argument dictionary.
+                    named_args, = yield [("CALL_ARGS", [self.analyze_arguments, (argument_list,)])]
+
+                    # Try to resolve the callee as an intrinsic.
+                    intrinsic = self.jit.get_intrinsic(resolved_var_name)
+                    if intrinsic is not None:
+                        raise primitive_functions.PrimitiveFinished(
+                            apply_intrinsic(intrinsic, named_args))
+
+                    # Otherwise, build a thunk.
+                    thunk_name = self.jit.jit_thunk_global(target.pointer.variable.name)
+                    raise primitive_functions.PrimitiveFinished(
+                        tree_ir.create_jit_call(
+                            tree_ir.LoadGlobalInstruction(thunk_name),
+                            named_args,
+                            tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME)))
+                else:
+                    # Try to look up the name as a global.
+                    _globals, = yield [("RD", [self.user_root, "globals"])]
+                    global_var, = yield [("RD", [_globals, resolved_var_name])]
+                    global_val, = yield [("RD", [global_var, "value"])]
+
+                    if global_val is not None:
+                        result, = yield [("CALL_ARGS", [self.analyze_direct_call, (
+                            global_val, resolved_var_name, argument_list)])]
+                        raise primitive_functions.PrimitiveFinished(result)
         elif isinstance(target, bytecode_ir.ConstantInstruction):
             # 'const(func_id)' instructions are also translated to direct calls.
             result, = yield [("CALL_ARGS", [self.analyze_direct_call, (

+ 52 - 48
kernel/modelverse_jit/jit.py

@@ -74,9 +74,6 @@ def create_function(
     constructed_body = tree_ir.create_block(
         *(prologue_statements + [function_body]))
 
-    # Optimize the function's body.
-    constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
-
     # Shield temporaries from the GC.
     constructed_body = tree_ir.protect_temporaries_from_gc(
         constructed_body, tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME))
@@ -87,7 +84,7 @@ def create_function(
         parameter_list + ['**' + jit_runtime.KWARGS_PARAMETER_NAME],
         constructed_body)
 
-    raise primitive_functions.PrimitiveFinished(constructed_function)
+    return constructed_function
 
 def print_value(val):
     """A thin wrapper around 'print'."""
@@ -102,7 +99,9 @@ class ModelverseJit(object):
         self.jit_globals = {
             'PrimitiveFinished' : primitive_functions.PrimitiveFinished,
             jit_runtime.CALL_FUNCTION_NAME : jit_runtime.call_function,
-            jit_runtime.GET_INPUT_FUNCTION_NAME : jit_runtime.get_input
+            jit_runtime.GET_INPUT_FUNCTION_NAME : jit_runtime.get_input,
+            jit_runtime.JIT_THUNK_CONSTANT_FUNCTION_NAME : self.jit_thunk_constant,
+            jit_runtime.JIT_THUNK_GLOBAL_FUNCTION_NAME : self.jit_thunk_global
         }
         # jitted_entry_points maps body ids to values in jit_globals.
         self.jitted_entry_points = {}
@@ -123,6 +122,7 @@ class ModelverseJit(object):
         self.tracing_enabled = False
         self.input_function_enabled = False
         self.nop_insertion_enabled = True
+        self.thunks_enabled = True
         self.jit_success_log_function = None
         self.jit_code_log_function = None
 
@@ -143,12 +143,21 @@ class ModelverseJit(object):
         self.tracing_enabled = is_enabled
 
     def enable_nop_insertion(self, is_enabled=True):
-        """Enables or disables nop insertion for jitted code. The JIT will insert nops at loop
-           back-edges. Inserting nops sacrifices performance to keep the jitted code from
-           blocking the thread of execution by consuming all resources; nops give the
-           Modelverse server an opportunity to interrupt the currently running code."""
+        """Enables or disables nop insertion for jitted code. If enabled, the JIT will
+           insert nops at loop back-edges. Inserting nops sacrifices performance to
+           keep the jitted code from blocking the thread of execution and consuming
+           all resources; nops give the Modelverse server an opportunity to interrupt
+           the currently running code."""
         self.nop_insertion_enabled = is_enabled
 
+    def enable_thunks(self, is_enabled=True):
+        """Enables or disables thunks for jitted code. Thunks delay the compilation of
+           functions until they are actually used. Thunks generally reduce start-up
+           time.
+
+           Thunks are enabled by default."""
+        self.thunks_enabled = is_enabled
+
     def set_jit_success_log(self, log_function=print_value):
         """Configures this JIT instance with a function that prints output to a log.
            Success and failure messages for specific functions are then sent to said log."""
@@ -429,11 +438,12 @@ class ModelverseJit(object):
         yield [("END_TRY", [])]
         del self.compilation_dependencies[body_id]
 
+        # Optimize the function's body.
+        constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
+
         # Wrap the tree IR in a function definition.
-        constructed_function, = yield [
-            ("CALL_ARGS",
-             [create_function,
-              (function_name, parameter_list, param_dict, body_param_dict, constructed_body)])]
+        constructed_function = create_function(
+            function_name, parameter_list, param_dict, body_param_dict, constructed_body)
 
         # Convert the function definition to Python code, and compile it.
         compiled_function = self.jit_define_function(function_name, constructed_function)
@@ -480,18 +490,17 @@ class ModelverseJit(object):
         #     raise primitive_functions.PrimitiveFinished(<get_function_body>)
         #
         get_function_body_name = self.generate_name('get_function_body')
-        get_function_body_func_def, = yield [
-            ("CALL_ARGS",
-             [create_function,
-              (get_function_body_name, [], {}, {}, tree_ir.ReturnInstruction(get_function_body))])]
+        get_function_body_func_def = create_function(
+            get_function_body_name, [], {}, {}, tree_ir.ReturnInstruction(get_function_body))
         get_function_body_func = self.jit_define_function(
             get_function_body_name, get_function_body_func_def)
 
         # Next, we want to create a thunk that invokes said function, and then replaces itself.
+        thunk_name = self.generate_name('thunk', global_name)
         def __jit_thunk(**kwargs):
             # Compute the body id, and delete the function that computes the body id; we won't
             # be needing it anymore after this call.
-            body_id = yield [("CALL_KWARGS", [get_function_body_func, kwargs])]
+            body_id, = yield [("CALL_KWARGS", [get_function_body_func, kwargs])]
             self.jit_delete_function(get_function_body_name)
 
             # Try to associate the global name with the body id, if that's at all possible.
@@ -501,7 +510,7 @@ class ModelverseJit(object):
             compiled_function = self.lookup_compiled_body(body_id)
             if compiled_function is not None:
                 # Replace this thunk by the compiled function.
-                self.jit_globals[get_function_body_name] = compiled_function
+                self.jit_globals[thunk_name] = compiled_function
             else:
                 def __handle_jit_exception(_):
                     # Replace this thunk by a different thunk: one that calls the interpreter
@@ -513,36 +522,33 @@ class ModelverseJit(object):
                         return jit_runtime.interpret_function_body(
                             body_id, named_arg_dict, **new_kwargs)
 
-                    self.jit_globals[get_function_body_name] = __interpreter_thunk
+                    self.jit_globals[thunk_name] = __interpreter_thunk
 
                 yield [("TRY", [])]
                 yield [("CATCH", [JitCompilationFailedException, __handle_jit_exception])]
                 compiled_function, = yield [
                     ("CALL_ARGS",
-                     [self.jit_recompile, (kwargs['user_root'], body_id, "jit_thunk")])]
+                     [self.jit_recompile, (kwargs['user_root'], body_id, thunk_name)])]
                 yield [("END_TRY", [])]
 
             # Call the compiled function.
             yield [("TAIL_CALL_KWARGS", [compiled_function, kwargs])]
 
-        thunk_name = self.generate_name('thunk', global_name)
         self.jit_globals[thunk_name] = __jit_thunk
-        raise primitive_functions.PrimitiveFinished(thunk_name)
+        return thunk_name
 
     def jit_thunk_constant(self, body_id):
         """Creates a thunk from given body id.
            This thunk is a function that will invoke the function whose body id is given.
            The thunk's name in the JIT's global context is returned."""
-        # We might have compiled the function with the given body id already. In that case,
-        # we need not bother with constructing the thunk; we can return the compiled function
-        # right away.
         if self.lookup_compiled_body(body_id) is not None:
-            raise primitive_functions.PrimitiveFinished(self.get_compiled_name(body_id))
-
-        # Looks like we'll just have to build that thunk after all.
-        yield [
-            ("TAIL_CALL_ARGS",
-             [self.jit_thunk, (tree_ir.LiteralInstruction(body_id),)])]
+            # We might have compiled the function with the given body id already. In that case,
+            # we need not bother with constructing the thunk; we can return the compiled function
+            # right away.
+            return self.get_compiled_name(body_id)
+        else:
+            # Looks like we'll just have to build that thunk after all.
+            return self.jit_thunk(tree_ir.LiteralInstruction(body_id))
 
     def jit_thunk_global(self, global_name):
         """Creates a thunk from given global name.
@@ -553,7 +559,7 @@ class ModelverseJit(object):
         # right away.
         body_id = self.get_global_body_id(global_name)
         if body_id is not None and self.lookup_compiled_body(body_id) is not None:
-            raise primitive_functions.PrimitiveFinished(self.get_compiled_name(body_id))
+            return self.get_compiled_name(body_id)
 
         # Looks like we'll just have to build that thunk after all.
         # We want to look up the global function like so
@@ -563,18 +569,16 @@ class ModelverseJit(object):
         #     function_id, = yield [("RD", [global_var, "value"])]
         #     body_id, = yield [("RD", [function_id, jit_runtime.FUNCTION_BODY_KEY])]
         #
-        yield [
-            ("TAIL_CALL_ARGS",
-             [self.jit_thunk,
-              (tree_ir.ReadDictionaryValueInstruction(
-                  tree_ir.ReadDictionaryValueInstruction(
-                      tree_ir.ReadDictionaryValueInstruction(
-                          tree_ir.ReadDictionaryValueInstruction(
-                              tree_ir.LoadIndexInstruction(
-                                  tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
-                                  tree_ir.LiteralInstruction('user_root')),
-                              'globals'),
-                          global_name),
-                      'value'),
-                  tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)),
-               global_name)])]
+        return self.jit_thunk(
+            tree_ir.ReadDictionaryValueInstruction(
+                tree_ir.ReadDictionaryValueInstruction(
+                    tree_ir.ReadDictionaryValueInstruction(
+                        tree_ir.ReadDictionaryValueInstruction(
+                            tree_ir.LoadIndexInstruction(
+                                tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
+                                tree_ir.LiteralInstruction('user_root')),
+                            tree_ir.LiteralInstruction('globals')),
+                        tree_ir.LiteralInstruction(global_name)),
+                    tree_ir.LiteralInstruction('value')),
+                tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)),
+            global_name)

+ 6 - 0
kernel/modelverse_jit/runtime.py

@@ -19,6 +19,12 @@ CALL_FUNCTION_NAME = "__call_function"
 GET_INPUT_FUNCTION_NAME = "__get_input"
 """The name of the '__get_input' function, in the jitted function scope."""
 
+JIT_THUNK_CONSTANT_FUNCTION_NAME = "__jit_thunk_constant"
+"""The name of the jit_thunk_constant function in the JIT's global context."""
+
+JIT_THUNK_GLOBAL_FUNCTION_NAME = "__jit_thunk_global"
+"""The name of the jit_thunk_global function in the JIT's global context."""
+
 LOCALS_NODE_NAME = "jit_locals"
 """The name of the node that is connected to all JIT locals in a given function call."""
 

+ 4 - 0
kernel/modelverse_kernel/main.py

@@ -51,6 +51,10 @@ class ModelverseKernel(object):
         #
         #     self.jit.allow_direct_calls(False)
         #
+        # To disable thunks in the JIT, uncomment the line below:
+        #
+        #     self.jit.enable_thunks(False)
+        #
         # To make the JIT compile 'input' instructions as calls to
         # modelverse_jit.runtime.get_input, uncomment the line below:
         #