Explorar o código

Use thunks in the JIT

jonathanvdc %!s(int64=9) %!d(string=hai) anos
pai
achega
18a7dc7f61

+ 27 - 9
kernel/modelverse_jit/bytecode_to_tree.py

@@ -588,15 +588,33 @@ class AnalysisState(object):
                 self.register_function_var(target.pointer.variable.node_id)
                 self.register_function_var(target.pointer.variable.node_id)
                 resolved_var_name = target.pointer.variable.name
                 resolved_var_name = target.pointer.variable.name
 
 
-                # Try to look up the name as a global.
-                _globals, = yield [("RD", [self.user_root, "globals"])]
-                global_var, = yield [("RD", [_globals, resolved_var_name])]
-                global_val, = yield [("RD", [global_var, "value"])]
-
-                if global_val is not None:
-                    result, = yield [("CALL_ARGS", [self.analyze_direct_call, (
-                        global_val, resolved_var_name, argument_list)])]
-                    raise primitive_functions.PrimitiveFinished(result)
+                if self.jit.thunks_enabled:
+                    # Analyze the argument dictionary.
+                    named_args, = yield [("CALL_ARGS", [self.analyze_arguments, (argument_list,)])]
+
+                    # Try to resolve the callee as an intrinsic.
+                    intrinsic = self.jit.get_intrinsic(resolved_var_name)
+                    if intrinsic is not None:
+                        raise primitive_functions.PrimitiveFinished(
+                            apply_intrinsic(intrinsic, named_args))
+
+                    # Otherwise, build a thunk.
+                    thunk_name = self.jit.jit_thunk_global(target.pointer.variable.name)
+                    raise primitive_functions.PrimitiveFinished(
+                        tree_ir.create_jit_call(
+                            tree_ir.LoadGlobalInstruction(thunk_name),
+                            named_args,
+                            tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME)))
+                else:
+                    # Try to look up the name as a global.
+                    _globals, = yield [("RD", [self.user_root, "globals"])]
+                    global_var, = yield [("RD", [_globals, resolved_var_name])]
+                    global_val, = yield [("RD", [global_var, "value"])]
+
+                    if global_val is not None:
+                        result, = yield [("CALL_ARGS", [self.analyze_direct_call, (
+                            global_val, resolved_var_name, argument_list)])]
+                        raise primitive_functions.PrimitiveFinished(result)
         elif isinstance(target, bytecode_ir.ConstantInstruction):
         elif isinstance(target, bytecode_ir.ConstantInstruction):
             # 'const(func_id)' instructions are also translated to direct calls.
             # 'const(func_id)' instructions are also translated to direct calls.
             result, = yield [("CALL_ARGS", [self.analyze_direct_call, (
             result, = yield [("CALL_ARGS", [self.analyze_direct_call, (

+ 52 - 48
kernel/modelverse_jit/jit.py

@@ -74,9 +74,6 @@ def create_function(
     constructed_body = tree_ir.create_block(
     constructed_body = tree_ir.create_block(
         *(prologue_statements + [function_body]))
         *(prologue_statements + [function_body]))
 
 
-    # Optimize the function's body.
-    constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
-
     # Shield temporaries from the GC.
     # Shield temporaries from the GC.
     constructed_body = tree_ir.protect_temporaries_from_gc(
     constructed_body = tree_ir.protect_temporaries_from_gc(
         constructed_body, tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME))
         constructed_body, tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME))
@@ -87,7 +84,7 @@ def create_function(
         parameter_list + ['**' + jit_runtime.KWARGS_PARAMETER_NAME],
         parameter_list + ['**' + jit_runtime.KWARGS_PARAMETER_NAME],
         constructed_body)
         constructed_body)
 
 
-    raise primitive_functions.PrimitiveFinished(constructed_function)
+    return constructed_function
 
 
 def print_value(val):
 def print_value(val):
     """A thin wrapper around 'print'."""
     """A thin wrapper around 'print'."""
@@ -102,7 +99,9 @@ class ModelverseJit(object):
         self.jit_globals = {
         self.jit_globals = {
             'PrimitiveFinished' : primitive_functions.PrimitiveFinished,
             'PrimitiveFinished' : primitive_functions.PrimitiveFinished,
             jit_runtime.CALL_FUNCTION_NAME : jit_runtime.call_function,
             jit_runtime.CALL_FUNCTION_NAME : jit_runtime.call_function,
-            jit_runtime.GET_INPUT_FUNCTION_NAME : jit_runtime.get_input
+            jit_runtime.GET_INPUT_FUNCTION_NAME : jit_runtime.get_input,
+            jit_runtime.JIT_THUNK_CONSTANT_FUNCTION_NAME : self.jit_thunk_constant,
+            jit_runtime.JIT_THUNK_GLOBAL_FUNCTION_NAME : self.jit_thunk_global
         }
         }
         # jitted_entry_points maps body ids to values in jit_globals.
         # jitted_entry_points maps body ids to values in jit_globals.
         self.jitted_entry_points = {}
         self.jitted_entry_points = {}
@@ -123,6 +122,7 @@ class ModelverseJit(object):
         self.tracing_enabled = False
         self.tracing_enabled = False
         self.input_function_enabled = False
         self.input_function_enabled = False
         self.nop_insertion_enabled = True
         self.nop_insertion_enabled = True
+        self.thunks_enabled = True
         self.jit_success_log_function = None
         self.jit_success_log_function = None
         self.jit_code_log_function = None
         self.jit_code_log_function = None
 
 
@@ -143,12 +143,21 @@ class ModelverseJit(object):
         self.tracing_enabled = is_enabled
         self.tracing_enabled = is_enabled
 
 
     def enable_nop_insertion(self, is_enabled=True):
     def enable_nop_insertion(self, is_enabled=True):
-        """Enables or disables nop insertion for jitted code. The JIT will insert nops at loop
-           back-edges. Inserting nops sacrifices performance to keep the jitted code from
-           blocking the thread of execution by consuming all resources; nops give the
-           Modelverse server an opportunity to interrupt the currently running code."""
+        """Enables or disables nop insertion for jitted code. If enabled, the JIT will
+           insert nops at loop back-edges. Inserting nops sacrifices performance to
+           keep the jitted code from blocking the thread of execution and consuming
+           all resources; nops give the Modelverse server an opportunity to interrupt
+           the currently running code."""
         self.nop_insertion_enabled = is_enabled
         self.nop_insertion_enabled = is_enabled
 
 
+    def enable_thunks(self, is_enabled=True):
+        """Enables or disables thunks for jitted code. Thunks delay the compilation of
+           functions until they are actually used. Thunks generally reduce start-up
+           time.
+
+           Thunks are enabled by default."""
+        self.thunks_enabled = is_enabled
+
     def set_jit_success_log(self, log_function=print_value):
     def set_jit_success_log(self, log_function=print_value):
         """Configures this JIT instance with a function that prints output to a log.
         """Configures this JIT instance with a function that prints output to a log.
            Success and failure messages for specific functions are then sent to said log."""
            Success and failure messages for specific functions are then sent to said log."""
@@ -429,11 +438,12 @@ class ModelverseJit(object):
         yield [("END_TRY", [])]
         yield [("END_TRY", [])]
         del self.compilation_dependencies[body_id]
         del self.compilation_dependencies[body_id]
 
 
+        # Optimize the function's body.
+        constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
+
         # Wrap the tree IR in a function definition.
         # Wrap the tree IR in a function definition.
-        constructed_function, = yield [
-            ("CALL_ARGS",
-             [create_function,
-              (function_name, parameter_list, param_dict, body_param_dict, constructed_body)])]
+        constructed_function = create_function(
+            function_name, parameter_list, param_dict, body_param_dict, constructed_body)
 
 
         # Convert the function definition to Python code, and compile it.
         # Convert the function definition to Python code, and compile it.
         compiled_function = self.jit_define_function(function_name, constructed_function)
         compiled_function = self.jit_define_function(function_name, constructed_function)
@@ -480,18 +490,17 @@ class ModelverseJit(object):
         #     raise primitive_functions.PrimitiveFinished(<get_function_body>)
         #     raise primitive_functions.PrimitiveFinished(<get_function_body>)
         #
         #
         get_function_body_name = self.generate_name('get_function_body')
         get_function_body_name = self.generate_name('get_function_body')
-        get_function_body_func_def, = yield [
-            ("CALL_ARGS",
-             [create_function,
-              (get_function_body_name, [], {}, {}, tree_ir.ReturnInstruction(get_function_body))])]
+        get_function_body_func_def = create_function(
+            get_function_body_name, [], {}, {}, tree_ir.ReturnInstruction(get_function_body))
         get_function_body_func = self.jit_define_function(
         get_function_body_func = self.jit_define_function(
             get_function_body_name, get_function_body_func_def)
             get_function_body_name, get_function_body_func_def)
 
 
         # Next, we want to create a thunk that invokes said function, and then replaces itself.
         # Next, we want to create a thunk that invokes said function, and then replaces itself.
+        thunk_name = self.generate_name('thunk', global_name)
         def __jit_thunk(**kwargs):
         def __jit_thunk(**kwargs):
             # Compute the body id, and delete the function that computes the body id; we won't
             # Compute the body id, and delete the function that computes the body id; we won't
             # be needing it anymore after this call.
             # be needing it anymore after this call.
-            body_id = yield [("CALL_KWARGS", [get_function_body_func, kwargs])]
+            body_id, = yield [("CALL_KWARGS", [get_function_body_func, kwargs])]
             self.jit_delete_function(get_function_body_name)
             self.jit_delete_function(get_function_body_name)
 
 
             # Try to associate the global name with the body id, if that's at all possible.
             # Try to associate the global name with the body id, if that's at all possible.
@@ -501,7 +510,7 @@ class ModelverseJit(object):
             compiled_function = self.lookup_compiled_body(body_id)
             compiled_function = self.lookup_compiled_body(body_id)
             if compiled_function is not None:
             if compiled_function is not None:
                 # Replace this thunk by the compiled function.
                 # Replace this thunk by the compiled function.
-                self.jit_globals[get_function_body_name] = compiled_function
+                self.jit_globals[thunk_name] = compiled_function
             else:
             else:
                 def __handle_jit_exception(_):
                 def __handle_jit_exception(_):
                     # Replace this thunk by a different thunk: one that calls the interpreter
                     # Replace this thunk by a different thunk: one that calls the interpreter
@@ -513,36 +522,33 @@ class ModelverseJit(object):
                         return jit_runtime.interpret_function_body(
                         return jit_runtime.interpret_function_body(
                             body_id, named_arg_dict, **new_kwargs)
                             body_id, named_arg_dict, **new_kwargs)
 
 
-                    self.jit_globals[get_function_body_name] = __interpreter_thunk
+                    self.jit_globals[thunk_name] = __interpreter_thunk
 
 
                 yield [("TRY", [])]
                 yield [("TRY", [])]
                 yield [("CATCH", [JitCompilationFailedException, __handle_jit_exception])]
                 yield [("CATCH", [JitCompilationFailedException, __handle_jit_exception])]
                 compiled_function, = yield [
                 compiled_function, = yield [
                     ("CALL_ARGS",
                     ("CALL_ARGS",
-                     [self.jit_recompile, (kwargs['user_root'], body_id, "jit_thunk")])]
+                     [self.jit_recompile, (kwargs['user_root'], body_id, thunk_name)])]
                 yield [("END_TRY", [])]
                 yield [("END_TRY", [])]
 
 
             # Call the compiled function.
             # Call the compiled function.
             yield [("TAIL_CALL_KWARGS", [compiled_function, kwargs])]
             yield [("TAIL_CALL_KWARGS", [compiled_function, kwargs])]
 
 
-        thunk_name = self.generate_name('thunk', global_name)
         self.jit_globals[thunk_name] = __jit_thunk
         self.jit_globals[thunk_name] = __jit_thunk
-        raise primitive_functions.PrimitiveFinished(thunk_name)
+        return thunk_name
 
 
     def jit_thunk_constant(self, body_id):
     def jit_thunk_constant(self, body_id):
         """Creates a thunk from given body id.
         """Creates a thunk from given body id.
            This thunk is a function that will invoke the function whose body id is given.
            This thunk is a function that will invoke the function whose body id is given.
            The thunk's name in the JIT's global context is returned."""
            The thunk's name in the JIT's global context is returned."""
-        # We might have compiled the function with the given body id already. In that case,
-        # we need not bother with constructing the thunk; we can return the compiled function
-        # right away.
         if self.lookup_compiled_body(body_id) is not None:
         if self.lookup_compiled_body(body_id) is not None:
-            raise primitive_functions.PrimitiveFinished(self.get_compiled_name(body_id))
-
-        # Looks like we'll just have to build that thunk after all.
-        yield [
-            ("TAIL_CALL_ARGS",
-             [self.jit_thunk, (tree_ir.LiteralInstruction(body_id),)])]
+            # We might have compiled the function with the given body id already. In that case,
+            # we need not bother with constructing the thunk; we can return the compiled function
+            # right away.
+            return self.get_compiled_name(body_id)
+        else:
+            # Looks like we'll just have to build that thunk after all.
+            return self.jit_thunk(tree_ir.LiteralInstruction(body_id))
 
 
     def jit_thunk_global(self, global_name):
     def jit_thunk_global(self, global_name):
         """Creates a thunk from given global name.
         """Creates a thunk from given global name.
@@ -553,7 +559,7 @@ class ModelverseJit(object):
         # right away.
         # right away.
         body_id = self.get_global_body_id(global_name)
         body_id = self.get_global_body_id(global_name)
         if body_id is not None and self.lookup_compiled_body(body_id) is not None:
         if body_id is not None and self.lookup_compiled_body(body_id) is not None:
-            raise primitive_functions.PrimitiveFinished(self.get_compiled_name(body_id))
+            return self.get_compiled_name(body_id)
 
 
         # Looks like we'll just have to build that thunk after all.
         # Looks like we'll just have to build that thunk after all.
         # We want to look up the global function like so
         # We want to look up the global function like so
@@ -563,18 +569,16 @@ class ModelverseJit(object):
         #     function_id, = yield [("RD", [global_var, "value"])]
         #     function_id, = yield [("RD", [global_var, "value"])]
         #     body_id, = yield [("RD", [function_id, jit_runtime.FUNCTION_BODY_KEY])]
         #     body_id, = yield [("RD", [function_id, jit_runtime.FUNCTION_BODY_KEY])]
         #
         #
-        yield [
-            ("TAIL_CALL_ARGS",
-             [self.jit_thunk,
-              (tree_ir.ReadDictionaryValueInstruction(
-                  tree_ir.ReadDictionaryValueInstruction(
-                      tree_ir.ReadDictionaryValueInstruction(
-                          tree_ir.ReadDictionaryValueInstruction(
-                              tree_ir.LoadIndexInstruction(
-                                  tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
-                                  tree_ir.LiteralInstruction('user_root')),
-                              'globals'),
-                          global_name),
-                      'value'),
-                  tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)),
-               global_name)])]
+        return self.jit_thunk(
+            tree_ir.ReadDictionaryValueInstruction(
+                tree_ir.ReadDictionaryValueInstruction(
+                    tree_ir.ReadDictionaryValueInstruction(
+                        tree_ir.ReadDictionaryValueInstruction(
+                            tree_ir.LoadIndexInstruction(
+                                tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
+                                tree_ir.LiteralInstruction('user_root')),
+                            tree_ir.LiteralInstruction('globals')),
+                        tree_ir.LiteralInstruction(global_name)),
+                    tree_ir.LiteralInstruction('value')),
+                tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)),
+            global_name)

+ 6 - 0
kernel/modelverse_jit/runtime.py

@@ -19,6 +19,12 @@ CALL_FUNCTION_NAME = "__call_function"
 GET_INPUT_FUNCTION_NAME = "__get_input"
 GET_INPUT_FUNCTION_NAME = "__get_input"
 """The name of the '__get_input' function, in the jitted function scope."""
 """The name of the '__get_input' function, in the jitted function scope."""
 
 
+JIT_THUNK_CONSTANT_FUNCTION_NAME = "__jit_thunk_constant"
+"""The name of the jit_thunk_constant function in the JIT's global context."""
+
+JIT_THUNK_GLOBAL_FUNCTION_NAME = "__jit_thunk_global"
+"""The name of the jit_thunk_global function in the JIT's global context."""
+
 LOCALS_NODE_NAME = "jit_locals"
 LOCALS_NODE_NAME = "jit_locals"
 """The name of the node that is connected to all JIT locals in a given function call."""
 """The name of the node that is connected to all JIT locals in a given function call."""
 
 

+ 4 - 0
kernel/modelverse_kernel/main.py

@@ -51,6 +51,10 @@ class ModelverseKernel(object):
         #
         #
         #     self.jit.allow_direct_calls(False)
         #     self.jit.allow_direct_calls(False)
         #
         #
+        # To disable thunks in the JIT, uncomment the line below:
+        #
+        #     self.jit.enable_thunks(False)
+        #
         # To make the JIT compile 'input' instructions as calls to
         # To make the JIT compile 'input' instructions as calls to
         # modelverse_jit.runtime.get_input, uncomment the line below:
         # modelverse_jit.runtime.get_input, uncomment the line below:
         #
         #