Selaa lähdekoodia

Add theoretical support for thunks to the JIT

jonathanvdc 8 vuotta sitten
vanhempi
commit
8669cc6a4a
2 muutettua tiedostoa jossa 159 lisäystä ja 16 poistoa
  1. 6 3
      kernel/modelverse_jit/bytecode_to_tree.py
  2. 153 13
      kernel/modelverse_jit/jit.py

+ 6 - 3
kernel/modelverse_jit/bytecode_to_tree.py

@@ -516,13 +516,16 @@ class AnalysisState(object):
         intrinsic = self.jit.get_intrinsic(callee_name)
 
         if intrinsic is None:
-            compiled_func = self.jit.lookup_compiled_function(callee_name)
+            if callee_name is not None:
+                self.jit.register_global(body_id, callee_name)
+                compiled_func = self.jit.lookup_compiled_function(callee_name)
+            else:
+                compiled_func = None
+
             if compiled_func is None:
                 # Compile the callee.
                 yield [
                     ("CALL_ARGS", [self.jit.jit_compile, (self.user_root, body_id, callee_name)])]
-            else:
-                self.jit.register_compiled(body_id, compiled_func, callee_name)
 
             # Get the callee's name.
             compiled_func_name = self.jit.get_compiled_name(body_id)

+ 153 - 13
kernel/modelverse_jit/jit.py

@@ -239,10 +239,12 @@ class ModelverseJit(object):
     def register_compiled(self, body_id, compiled_function, function_name=None):
         """Registers a compiled entry point with the JIT."""
         # Get the function's name.
-        function_name = self.generate_function_name(body_id, function_name)
+        actual_function_name = self.generate_function_name(body_id, function_name)
         # Map the body id to the given parameter list.
-        self.jitted_entry_points[body_id] = function_name
-        self.jit_globals[function_name] = compiled_function
+        self.jitted_entry_points[body_id] = actual_function_name
+        self.jit_globals[actual_function_name] = compiled_function
+        if function_name is not None:
+            self.register_global(body_id, function_name)
         if body_id in self.todo_entry_points:
             self.todo_entry_points.remove(body_id)
 
@@ -253,18 +255,46 @@ class ModelverseJit(object):
         self.jit_globals[actual_name] = value
         return actual_name
 
-    def lookup_compiled_function(self, name):
-        """Looks up a compiled function by name. Returns a matching function,
+    def __lookup_compiled_body_impl(self, body_id):
+        """Looks up a compiled function by body id. Returns a matching function,
            or None if no function was found."""
-        if name is None:
+        if body_id is not None and body_id in self.jitted_entry_points:
+            return self.jit_globals[self.jitted_entry_points[body_id]]
+        else:
             return None
-        elif name in self.jit_globals:
-            return self.jit_globals[name]
-        elif self.compiled_function_lookup is not None:
-            return self.compiled_function_lookup(name)
+
+    def __lookup_external_body_impl(self, global_name, body_id):
+        """Looks up an external function by global name. Returns a matching function,
+           or None if no function was found."""
+        if self.compiled_function_lookup is not None:
+            result = self.compiled_function_lookup(global_name)
+            if result is not None and body_id is not None:
+                self.register_compiled(body_id, result, global_name)
+
+            return result
         else:
             return None
 
+    def lookup_compiled_body(self, body_id):
+        """Looks up a compiled function by body id. Returns a matching function,
+           or None if no function was found."""
+        result = self.__lookup_compiled_body_impl(body_id)
+        if result is not None:
+            return result
+        else:
+            global_name = self.get_global_name(body_id)
+            return self.__lookup_external_body_impl(global_name, body_id)
+
+    def lookup_compiled_function(self, global_name):
+        """Looks up a compiled function by global name. Returns a matching function,
+           or None if no function was found."""
+        body_id = self.get_global_body_id(global_name)
+        result = self.__lookup_compiled_body_impl(body_id)
+        if result is not None:
+            return result
+        else:
+            return self.__lookup_external_body_impl(global_name, body_id)
+
     def get_intrinsic(self, name):
         """Tries to find an intrinsic version of the function with the
            given name."""
@@ -412,9 +442,6 @@ class ModelverseJit(object):
             self.jit_success_log_function(
                 "JIT compilation successful: (function '%s' at %d)" % (function_name, body_id))
 
-        if self.jit_code_log_function is not None:
-            self.jit_code_log_function(constructed_function)
-
         raise primitive_functions.PrimitiveFinished(compiled_function)
 
     def jit_define_function(self, function_name, function_def):
@@ -425,6 +452,10 @@ class ModelverseJit(object):
 
         # Convert the function definition to Python code, and compile it.
         exec(str(function_def), self.jit_globals)
+
+        if self.jit_code_log_function is not None:
+            self.jit_code_log_function(function_def)
+
         # Extract the compiled function from the JIT global state.
         return self.jit_globals[function_name]
 
@@ -438,3 +469,112 @@ class ModelverseJit(object):
         # it already exists. (we need to do this for recursive functions)
         function_name = self.generate_function_name(body_id, suggested_name)
         yield [("TAIL_CALL_ARGS", [self.jit_recompile, (user_root, body_id, function_name)])]
+
+    def jit_thunk(self, get_function_body, global_name=None):
+        """Creates a thunk from the given IR tree that computes the function's body id.
+           This thunk is a function that will invoke the function whose body id is retrieved.
+           The thunk's name in the JIT's global context is returned."""
+        # The general idea is to first create a function that looks a bit like this:
+        #
+        # def jit_get_function_body(**kwargs):
+        #     raise primitive_functions.PrimitiveFinished(<get_function_body>)
+        #
+        get_function_body_name = self.generate_name('get_function_body')
+        get_function_body_func_def, = yield [
+            ("CALL_ARGS",
+             [create_function,
+              (get_function_body_name, [], {}, {}, tree_ir.ReturnInstruction(get_function_body))])]
+        get_function_body_func = self.jit_define_function(
+            get_function_body_name, get_function_body_func_def)
+
+        # Next, we want to create a thunk that invokes said function, and then replaces itself.
+        def __jit_thunk(**kwargs):
+            # Compute the body id, and delete the function that computes the body id; we won't
+            # be needing it anymore after this call.
+            body_id = yield [("CALL_KWARGS", [get_function_body_func, kwargs])]
+            self.jit_delete_function(get_function_body_name)
+
+            # Try to associate the global name with the body id, if that's at all possible.
+            if global_name is not None:
+                self.register_global(body_id, global_name)
+
+            compiled_function = self.lookup_compiled_body(body_id)
+            if compiled_function is not None:
+                # Replace this thunk by the compiled function.
+                self.jit_globals[get_function_body_name] = compiled_function
+            else:
+                def __handle_jit_exception(_):
+                    # Replace this thunk by a different thunk: one that calls the interpreter
+                    # directly, without checking if the function is jittable.
+                    (_, parameter_names, _), = yield [
+                        ("CALL_ARGS", [self.jit_signature, (body_id,)])]
+                    def __interpreter_thunk(**new_kwargs):
+                        named_arg_dict = {name : new_kwargs[name] for name in parameter_names}
+                        return jit_runtime.interpret_function_body(
+                            body_id, named_arg_dict, **new_kwargs)
+
+                    self.jit_globals[get_function_body_name] = __interpreter_thunk
+
+                yield [("TRY", [])]
+                yield [("CATCH", [JitCompilationFailedException, __handle_jit_exception])]
+                compiled_function, = yield [
+                    ("CALL_ARGS",
+                     [self.jit_recompile, (kwargs['user_root'], body_id, "jit_thunk")])]
+                yield [("END_TRY", [])]
+
+            # Call the compiled function.
+            yield [("TAIL_CALL_KWARGS", [compiled_function, kwargs])]
+
+        thunk_name = self.generate_name('thunk', global_name)
+        self.jit_globals[thunk_name] = __jit_thunk
+        raise primitive_functions.PrimitiveFinished(thunk_name)
+
+    def jit_thunk_constant(self, body_id):
+        """Creates a thunk from given body id.
+           This thunk is a function that will invoke the function whose body id is given.
+           The thunk's name in the JIT's global context is returned."""
+        # We might have compiled the function with the given body id already. In that case,
+        # we need not bother with constructing the thunk; we can return the compiled function
+        # right away.
+        if self.lookup_compiled_body(body_id) is not None:
+            raise primitive_functions.PrimitiveFinished(self.get_compiled_name(body_id))
+
+        # Looks like we'll just have to build that thunk after all.
+        yield [
+            ("TAIL_CALL_ARGS",
+             [self.jit_thunk, (tree_ir.LiteralInstruction(body_id),)])]
+
+    def jit_thunk_global(self, global_name):
+        """Creates a thunk from given global name.
+           This thunk is a function that will invoke the function whose body id is given.
+           The thunk's name in the JIT's global context is returned."""
+        # We might have compiled the function with the given name already. In that case,
+        # we need not bother with constructing the thunk; we can return the compiled function
+        # right away.
+        body_id = self.get_global_body_id(global_name)
+        if body_id is not None and self.lookup_compiled_body(body_id) is not None:
+            raise primitive_functions.PrimitiveFinished(self.get_compiled_name(body_id))
+
+        # Looks like we'll just have to build that thunk after all.
+        # We want to look up the global function like so
+        #
+        #     _globals, = yield [("RD", [kwargs['user_root'], "globals"])]
+        #     global_var, = yield [("RD", [_globals, global_name])]
+        #     function_id, = yield [("RD", [global_var, "value"])]
+        #     body_id, = yield [("RD", [function_id, jit_runtime.FUNCTION_BODY_KEY])]
+        #
+        yield [
+            ("TAIL_CALL_ARGS",
+             [self.jit_thunk,
+              (tree_ir.ReadDictionaryValueInstruction(
+                  tree_ir.ReadDictionaryValueInstruction(
+                      tree_ir.ReadDictionaryValueInstruction(
+                          tree_ir.ReadDictionaryValueInstruction(
+                              tree_ir.LoadIndexInstruction(
+                                  tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
+                                  tree_ir.LiteralInstruction('user_root')),
+                              'globals'),
+                          global_name),
+                      'value'),
+                  tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)),
+               global_name)])]