Browse Source

Implement calls to interpreted code from jitted code

jonathanvdc 8 years ago
parent
commit
a3d3e5f006
3 changed files with 173 additions and 12 deletions
  1. 62 11
      kernel/modelverse_jit/jit.py
  2. 76 0
      kernel/modelverse_jit/runtime.py
  3. 35 1
      kernel/modelverse_jit/tree_ir.py

+ 62 - 11
kernel/modelverse_jit/jit.py

@@ -1,9 +1,13 @@
 import modelverse_kernel.primitives as primitive_functions
 import modelverse_jit.tree_ir as tree_ir
+import modelverse_jit.runtime as jit_runtime
 
 KWARGS_PARAMETER_NAME = "kwargs"
 """The name of the kwargs parameter in jitted functions."""
 
+INTERPRET_FUNCTION_NAME = "__interpret_function"
+"""The name of the '__interpret_function' function, in the jitted function scope."""
+
 def get_parameter_names(compiled_function):
     """Gets the given compiled function's parameter names."""
     if hasattr(compiled_function, '__code__'):
@@ -42,7 +46,8 @@ class ModelverseJit(object):
         self.jitted_entry_points = {}
         self.jitted_parameters = {}
         self.jit_globals = {
-            'PrimitiveFinished' : primitive_functions.PrimitiveFinished
+            'PrimitiveFinished' : primitive_functions.PrimitiveFinished,
+            INTERPRET_FUNCTION_NAME : jit_runtime.interpret_function
         }
         self.jit_count = 0
         self.max_instructions = max_instructions
@@ -825,11 +830,40 @@ class AnalysisState(object):
 
         raise primitive_functions.PrimitiveFinished(named_args)
 
-    def analyze_call(self, instruction_id):
-        """Tries to analyze the given 'call' instruction."""
-        func_id, first_param_id, = yield [("RD", [instruction_id, "func"]),
-                                          ("RD", [instruction_id, "params"])]
+    def analyze_indirect_call(self, func_id, first_arg_id):
+        """Analyzes a call to an unknown function."""
+
+        # First off, let's analyze the callee and the argument list.
+        try:
+            gen = self.analyze(func_id)
+            inp = None
+            while True:
+                inp = yield gen.send(inp)
+        except primitive_functions.PrimitiveFinished as ex:
+            func_val = ex.result
+
+        try:
+            gen = self.analyze_arguments(first_arg_id)
+            inp = None
+            while True:
+                inp = yield gen.send(inp)
+        except primitive_functions.PrimitiveFinished as ex:
+            named_args = ex.result
+
+        # Call the __interpret_function function to run the interpreter, like so:
+        #
+        # __interpret_function(function_id, { first_param_name : first_param_val, ... }, **kwargs)
+        #
+        dict_literal = tree_ir.DictionaryLiteralInstruction(
+            [(tree_ir.LiteralInstruction(key), val) for key, val in named_args])
+        raise primitive_functions.PrimitiveFinished(
+            tree_ir.JitCallInstruction(
+                tree_ir.LoadGlobalInstruction(INTERPRET_FUNCTION_NAME),
+                [('function_id', func_val), ('named_arguments', dict_literal)],
+                tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
 
+    def try_analyze_direct_call(self, func_id, first_param_id):
+        """Tries to analyze the given 'call' instruction as a direct call."""
         # Figure out what the 'func' instruction's type is.
         func_instruction_op, = yield [("RV", [func_id])]
         if func_instruction_op['value'] == 'access':
@@ -840,15 +874,12 @@ class AnalysisState(object):
                 resolved_var_id, = yield [("RD", [access_value_id, "var"])]
                 resolved_var_name, = yield [("RV", [resolved_var_id])]
 
-                # Try to look the name up as a global.
+                # Try to look up the name as a global.
                 _globals, = yield [("RD", [self.user_root, "globals"])]
                 global_var, = yield [("RD", [_globals, resolved_var_name])]
                 global_val, = yield [("RD", [global_var, "value"])]
 
-                if global_val is None:
-                    raise JitCompilationFailedException(
-                        "Cannot JIT function calls that target an unknown value.")
-                else:
+                if global_val is not None:
                     gen = self.analyze_direct_call(
                         global_val, resolved_var_name, first_param_id)
                     inp = None
@@ -856,8 +887,28 @@ class AnalysisState(object):
                         inp = yield gen.send(inp)
                     # PrimitiveFinished exception will bubble up from here.
 
-        raise JitCompilationFailedException("Cannot JIT indirect function calls yet.")
+        raise JitCompilationFailedException(
+            "Cannot JIT function calls that target an unknown value as direct calls.")
 
+    def analyze_call(self, instruction_id):
+        """Tries to analyze the given 'call' instruction."""
+        func_id, first_param_id, = yield [("RD", [instruction_id, "func"]),
+                                          ("RD", [instruction_id, "params"])]
+
+        try:
+            # Try to analyze the call as a direct call.
+            gen = self.try_analyze_direct_call(func_id, first_param_id)
+            inp = None
+            while 1:
+                inp = yield gen.send(inp)
+            # PrimitiveFinished exception will bubble up from here.
+        except JitCompilationFailedException:
+            # Looks like we'll have to compile it as an indirect call.
+            gen = self.analyze_indirect_call(func_id, first_param_id)
+            inp = None
+            while True:
+                inp = yield gen.send(inp)
+            # PrimitiveFinished exception will bubble up from here.
 
     instruction_analyzers = {
         'if' : analyze_if,

+ 76 - 0
kernel/modelverse_jit/runtime.py

@@ -0,0 +1,76 @@
+import modelverse_kernel.primitives as primitive_functions
+
+def interpret_function(function_id, named_arguments, **kwargs):
+    """Makes the interpreter run the function with the given id with the specified
+       argument dictionary."""
+
+    user_root = kwargs['user_root']
+    user_frame, = yield [("RD", [user_root, "frame"])]
+    inst, = yield [("RD", [user_frame, "IP"])]
+
+    body_id, = yield [("RD", [function_id, "body"])]
+
+    # Create a new stack frame.
+    frame_link, new_phase, new_frame, new_evalstack, new_symbols, new_returnvalue = \
+                    yield [("RDE", [user_root, "frame"]),
+                           ("CNV", ["init"]),
+                           ("CN", []),
+                           ("CN", []),
+                           ("CN", []),
+                           ("CN", []),
+                          ]
+
+    _, _, _, _, _, _, _, _, _ = \
+                    yield [("CD", [user_root, "frame", new_frame]),
+                           ("CD", [new_frame, "evalstack", new_evalstack]),
+                           ("CD", [new_frame, "symbols", new_symbols]),
+                           ("CD", [new_frame, "returnvalue", new_returnvalue]),
+                           ("CD", [new_frame, "caller", inst]),
+                           ("CD", [new_frame, "phase", new_phase]),
+                           ("CD", [new_frame, "IP", body_id]),
+                           ("CD", [new_frame, "prev", user_frame]),
+                           ("DE", [frame_link]),
+                          ]
+
+    # Put the parameters in the new stack frame's symbol table.
+    kernel = kwargs['mvk']
+    try:
+        gen = kernel.jit.jit_parameters(body_id)
+        inp = None
+        while 1:
+            inp = yield gen.send(inp)
+    except primitive_functions.PrimitiveFinished as ex:
+        parameter_vars, parameter_names = ex.result
+        parameter_dict = dict(zip(parameter_names, parameter_vars))
+
+    for (key, value) in named_arguments.items():
+        param_var = parameter_dict[key]
+        variable, = yield [("CN", [])]
+        yield [("CD", [variable, "value", value])]
+        symbol_edge, = yield [("CE", [new_symbols, variable])]
+        yield [("CE", [symbol_edge, param_var])]
+
+    username = kwargs['username']
+    while 1:
+        try:
+            gen = kernel.execute_rule(username)
+            inp = None
+            while 1:
+                inp = yield gen.send(inp)
+        except StopIteration:
+            # An instruction has been completed. Check if we've already returned.
+            #
+            # TODO: the statement below performs O(n) state reads whenever an instruction
+            # finishes, where n is the number of 'interpret_function' stack frames.
+            # I don't *think* that this is problematic (at least not in the short term),
+            # but an O(1) solution would obviously be much better; that's the interpreter's
+            # complexity. Perhaps we can annotate the stack frame we create here with a marker
+            # that the kernel can pick up on? We could have the kernel throw an exception whenever
+            # it encounters said marker.
+            current_user_frame, = yield [("RD", [user_root, "frame"])]
+            if current_user_frame == user_frame:
+                # We're done here. Extract the return value and get out.
+                returnvalue, = yield [("RD", [user_frame, "returnvalue"])]
+                raise primitive_functions.PrimitiveFinished(returnvalue)
+            else:
+                yield None

+ 35 - 1
kernel/modelverse_jit/tree_ir.py

@@ -495,9 +495,43 @@ class LiteralInstruction(Instruction):
            result. The expression is returned as a string."""
         return repr(self.literal)
 
+class DictionaryLiteralInstruction(Instruction):
+    """Constructs a dictionary literal."""
+    def __init__(self, key_value_pairs):
+        Instruction.__init__(self)
+        self.key_value_pairs = key_value_pairs
+
+    def has_definition(self):
+        """Tells if this instruction requires a definition."""
+        return any(
+            [key.has_definition() or val.has_definition()
+             for key, val in self.key_value_pairs])
+
+    def simplify(self):
+        """Applies basic simplification to this instruction and its children."""
+        return DictionaryLiteralInstruction(
+            [(key.simplify(), val.simplify()) for key, val in self.key_value_pairs])
+
+    def generate_python_def(self, code_generator):
+        """Generates a Python statement that executes this instruction.
+            The statement is appended immediately to the code generator."""
+        for key, val in self.key_value_pairs:
+            if key.has_definition():
+                key.generate_python_def(code_generator)
+            if val.has_definition():
+                val.generate_python_def(code_generator)
+
+    def generate_python_use(self, code_generator):
+        """Generates a Python expression that retrieves this instruction's
+           result. The expression is returned as a string."""
+        return '{ %s }' % ', '.join(
+            ['%s : %s' % (
+                key.generate_python_use(code_generator),
+                val.generate_python_use(code_generator))
+             for key, val in self.key_value_pairs])
+
 class StateInstruction(Instruction):
     """An instruction that accesses the modelverse state."""
-
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
         raise NotImplementedError()