Browse Source

Rename interpret_function to call_function, add straight-to-JIT optimization

jonathanvdc 8 years ago
parent
commit
1d067be8c7
3 changed files with 49 additions and 18 deletions
  1. 10 10
      kernel/modelverse_jit/jit.py
  2. 28 1
      kernel/modelverse_jit/runtime.py
  3. 11 7
      kernel/modelverse_kernel/main.py

+ 10 - 10
kernel/modelverse_jit/jit.py

@@ -3,11 +3,15 @@ import modelverse_jit.tree_ir as tree_ir
 import modelverse_jit.runtime as jit_runtime
 import keyword
 
+# Import JitCompilationFailedException because it used to be defined
+# in this module.
+JitCompilationFailedException = jit_runtime.JitCompilationFailedException
+
 KWARGS_PARAMETER_NAME = "kwargs"
 """The name of the kwargs parameter in jitted functions."""
 
-INTERPRET_FUNCTION_NAME = "__interpret_function"
-"""The name of the '__interpret_function' function, in the jitted function scope."""
+CALL_FUNCTION_NAME = "__call_function"
+"""The name of the '__call_function' function, in the jitted function scope."""
 
 def get_parameter_names(compiled_function):
     """Gets the given compiled function's parameter names."""
@@ -79,10 +83,6 @@ def optimize_tree_ir(instruction):
     """Optimizes an IR tree."""
     return map_and_simplify_generator(expand_constant_read, instruction)
 
-class JitCompilationFailedException(Exception):
-    """A type of exception that is raised when the jit fails to compile a function."""
-    pass
-
 class ModelverseJit(object):
     """A high-level interface to the modelverse JIT compiler."""
     def __init__(self, max_instructions=None, compiled_function_lookup=None):
@@ -92,7 +92,7 @@ class ModelverseJit(object):
         self.jitted_parameters = {}
         self.jit_globals = {
             'PrimitiveFinished' : primitive_functions.PrimitiveFinished,
-            INTERPRET_FUNCTION_NAME : jit_runtime.interpret_function
+            CALL_FUNCTION_NAME : jit_runtime.call_function
         }
         self.jit_count = 0
         self.max_instructions = max_instructions
@@ -938,15 +938,15 @@ class AnalysisState(object):
         except primitive_functions.PrimitiveFinished as ex:
             named_args = ex.result
 
-        # Call the __interpret_function function to run the interpreter, like so:
+        # Call the __call_function function to run the interpreter, like so:
         #
-        # __interpret_function(function_id, { first_param_name : first_param_val, ... }, **kwargs)
+        # __call_function(function_id, { first_param_name : first_param_val, ... }, **kwargs)
         #
         dict_literal = tree_ir.DictionaryLiteralInstruction(
             [(tree_ir.LiteralInstruction(key), val) for key, val in named_args])
         raise primitive_functions.PrimitiveFinished(
             tree_ir.JitCallInstruction(
-                tree_ir.LoadGlobalInstruction(INTERPRET_FUNCTION_NAME),
+                tree_ir.LoadGlobalInstruction(CALL_FUNCTION_NAME),
                 [('function_id', func_val), ('named_arguments', dict_literal)],
                 tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
 

+ 28 - 1
kernel/modelverse_jit/runtime.py

@@ -1,6 +1,10 @@
 import modelverse_kernel.primitives as primitive_functions
 
-def interpret_function(function_id, named_arguments, **kwargs):
+class JitCompilationFailedException(Exception):
+    """A type of exception that is raised when the jit fails to compile a function."""
+    pass
+
+def call_function(function_id, named_arguments, **kwargs):
     """Makes the interpreter run the function with the given id with the specified
        argument dictionary."""
 
@@ -12,6 +16,29 @@ def interpret_function(function_id, named_arguments, **kwargs):
     kernel = kwargs['mvk']
     kernel.jit.mark_entry_point(body_id)
 
+    # Try to jit the function here. We might be able to avoid building the stack
+    # frame.
+    try:
+        # Try to compile.
+        jit_gen = kernel.jit_compile(user_root, body_id)
+        inp = None
+        while 1:
+            inp = yield jit_gen.send(inp)
+    except primitive_functions.PrimitiveFinished as ex:
+        compiled_func = ex.result
+        # Add the keyword arguments to the argument dictionary.
+        named_arguments.update(kwargs)
+        # Run the function.
+        gen = compiled_func(**named_arguments)
+        inp = None
+        while 1:
+            inp = yield gen.send(inp)
+        # Let primitive_functions.PrimitiveFinished bubble up.
+    except JitCompilationFailedException:
+        # That's quite alright. Just build a stack frame and hand the function to
+        # the interpreter.
+        pass
+
     # Create a new stack frame.
     frame_link, new_phase, new_frame, new_evalstack, new_symbols, \
         new_returnvalue, intrinsic_return = \

+ 11 - 7
kernel/modelverse_kernel/main.py

@@ -147,6 +147,16 @@ class ModelverseKernel(object):
                 getattr(primitive_functions, function_names[i]), 
                 function_names[i])
 
+    def jit_compile(self, user_root, inst):
+        # Try to retrieve the suggested name.
+        if self.suggested_function_names is not None and inst in self.suggested_function_names:
+            suggested_name = self.suggested_function_names[inst]
+        else:
+            suggested_name = None
+
+        # Have the JIT compile the function.
+        return self.jit.jit_compile(user_root, inst, suggested_name)
+
     def execute_jit(self, user_root, inst, username):
         # execute_jit
         user_frame, =    yield [("RD", [user_root, "frame"])]
@@ -165,15 +175,9 @@ class ModelverseKernel(object):
         parameters["username"] = username
         parameters["mvk"] = self
 
-        # Try to retrieve the suggested name.
-        if self.suggested_function_names is not None and inst in self.suggested_function_names:
-            suggested_name = self.suggested_function_names[inst]
-        else:
-            suggested_name = None
-
         # Have the JIT compile the function.
         try:
-            jit_gen = self.jit.jit_compile(user_root, inst, suggested_name)
+            jit_gen = self.jit_compile(user_root, inst)
             inp = None
             while 1:
                 inp = yield jit_gen.send(inp)