Browse Source

Configure JIT codegen to use the 'RUN' instruction

jonathanvdc 8 years ago
parent
commit
1273d956f3

+ 7 - 6
kernel/modelverse_jit/jit.py

@@ -234,8 +234,9 @@ class ModelverseJit(object):
         """Tries to jit the function defined by the given entry point id and parameter list."""
         # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
         # pylint: disable=I0011,W0122
-
-        if body_id in self.jitted_entry_points:
+        if body_id is None:
+            raise ValueError('body_id cannot be None')
+        elif body_id in self.jitted_entry_points:
             # We have already compiled this function.
             raise primitive_functions.PrimitiveFinished(
                 self.jit_globals[self.jitted_entry_points[body_id]])
@@ -317,7 +318,7 @@ class ModelverseJit(object):
         # Extract the compiled function from the JIT global state.
         compiled_function = self.jit_globals[function_name]
 
-        # print(constructed_function)
+        print(constructed_function)
         raise primitive_functions.PrimitiveFinished(compiled_function)
 
 class AnalysisState(object):
@@ -584,7 +585,7 @@ class AnalysisState(object):
         # Possible alternative to the explicit syntax tree:
         #
         # raise primitive_functions.PrimitiveFinished(
-        #     tree_ir.JitCallInstruction(
+        #     tree_ir.create_jit_call(
         #         tree_ir.LoadGlobalInstruction('__get_input'),
         #         [],
         #         tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
@@ -892,7 +893,7 @@ class AnalysisState(object):
                 apply_intrinsic(intrinsic, named_args))
         else:
             raise primitive_functions.PrimitiveFinished(
-                tree_ir.JitCallInstruction(
+                tree_ir.create_jit_call(
                     tree_ir.LoadGlobalInstruction(compiled_func_name),
                     named_args,
                     tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
@@ -945,7 +946,7 @@ class AnalysisState(object):
         dict_literal = tree_ir.DictionaryLiteralInstruction(
             [(tree_ir.LiteralInstruction(key), val) for key, val in named_args])
         raise primitive_functions.PrimitiveFinished(
-            tree_ir.JitCallInstruction(
+            tree_ir.create_jit_call(
                 tree_ir.LoadGlobalInstruction(CALL_FUNCTION_NAME),
                 [('function_id', func_val), ('named_arguments', dict_literal)],
                 tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))

+ 3 - 3
kernel/modelverse_jit/runtime.py

@@ -20,11 +20,11 @@ def call_function(function_id, named_arguments, **kwargs):
     # frame.
     try:
         # Try to compile.
-        compiled_func, = yield [("RUN", kernel.jit_compile(user_root, body_id))]
+        compiled_func, = yield [("RUN", [kernel.jit_compile(user_root, body_id)])]
         # Add the keyword arguments to the argument dictionary.
         named_arguments.update(kwargs)
         # Run the function.
-        result, = yield [("RUN", compiled_func, named_arguments)]
+        result, = yield [("RUN", [compiled_func, named_arguments])]
         # Return.
         raise primitive_functions.PrimitiveFinished(result)
     except JitCompilationFailedException:
@@ -80,7 +80,7 @@ def call_function(function_id, named_arguments, **kwargs):
     username = kwargs['username']
     try:
         while 1:
-            result, = yield [("RUN", kernel.execute_rule(username))]
+            result, = yield [("RUN", [kernel.execute_rule(username)])]
             yield result
     except primitive_functions.InterpretedFunctionFinished as ex:
         raise primitive_functions.PrimitiveFinished(ex.result)

+ 70 - 69
kernel/modelverse_jit/tree_ir.py

@@ -192,7 +192,7 @@ class PythonGenerator(object):
             ', '.join([arg_i.generate_python_use(self) for arg_i in args]))
         self.state_definitions.append((result_name, request_tuple))
         self.state_definition_names.add(result_name)
-        if not self.combine_state_definitions:
+        if not self.combine_state_definitions or opcode == "RUN":
             self.flush_state_definitions()
 
     def flush_state_definitions(self):
@@ -366,63 +366,6 @@ class CallInstruction(Instruction):
                 self.target.generate_python_use(code_generator),
                 ', '.join([arg.generate_python_use(code_generator) for arg in self.argument_list])))
 
-class JitCallInstruction(Instruction):
-    """An instruction that calls a jitted function."""
-    def __init__(self, target, named_args, kwarg):
-        Instruction.__init__(self)
-        self.target = target
-        self.named_args = named_args
-        self.kwarg = kwarg
-
-    def get_children(self):
-        """Gets this instruction's sequence of child instructions."""
-        return [self.target] + [arg for _, arg in self.named_args] + [self.kwarg]
-
-    def create(self, new_children):
-        """Creates a new instruction of this type from the given sequence of child instructions."""
-        param_names = [name for name, _ in self.named_args]
-        return JitCallInstruction(
-            new_children[0], zip(param_names, new_children[1:-1]), new_children[-1])
-
-    def generate_python_def(self, code_generator):
-        """Generates Python code for this instruction."""
-        if self.target.has_definition():
-            self.target.generate_python_def(code_generator)
-
-        arg_list = []
-        for param_name, arg in self.named_args:
-            if arg.has_definition():
-                arg.generate_python_def(code_generator)
-
-            arg_list.append(
-                '%s=%s' % (param_name, arg.generate_python_use(code_generator)))
-
-        if self.kwarg.has_definition():
-            self.kwarg.generate_python_def(code_generator)
-
-        arg_list.append(
-            '**%s' % self.kwarg.generate_python_use(code_generator))
-
-        own_name = code_generator.get_result_name(self)
-        code_generator.append_line('try:')
-        code_generator.increase_indentation()
-        code_generator.append_line(
-            '%s_gen = %s(%s)' % (
-                own_name,
-                self.target.generate_python_use(code_generator),
-                ', '.join(arg_list)))
-        code_generator.append_line('%s_inp = None' % own_name)
-        code_generator.append_line('while 1:')
-        code_generator.increase_indentation()
-        code_generator.append_line(
-            '%s_inp = yield %s_gen.send(%s_inp)' % (own_name, own_name, own_name))
-        code_generator.decrease_indentation()
-        code_generator.decrease_indentation()
-        code_generator.append_line('except PrimitiveFinished as %s_ex:' % own_name)
-        code_generator.increase_indentation()
-        code_generator.append_line('%s = %s_ex.result' % (own_name, own_name))
-        code_generator.decrease_indentation()
-
 class PrintInstruction(Instruction):
     """An instruction that prints a value."""
     def __init__(self, argument):
@@ -661,7 +604,7 @@ class DictionaryLiteralInstruction(Instruction):
         keys = [k for k, _ in self.key_value_pairs]
         return DictionaryLiteralInstruction(zip(keys, new_children))
 
-    def has_definition(self):
+    def has_definition_impl(self):
         """Tells if this instruction requires a definition."""
         return any(
             [key.has_definition() or val.has_definition()
@@ -672,23 +615,37 @@ class DictionaryLiteralInstruction(Instruction):
         return DictionaryLiteralInstruction(
             [(key.simplify(), val.simplify()) for key, val in self.key_value_pairs])
 
+    def generate_dictionary_expr(self, code_generator):
+        """Generates an expression that creates this dictionary."""
+        return '{ %s }' % ', '.join(
+            ['%s : %s' % (
+                key.generate_python_use(code_generator),
+                val.generate_python_use(code_generator))
+             for key, val in self.key_value_pairs])
+
     def generate_python_def(self, code_generator):
         """Generates a Python statement that executes this instruction.
             The statement is appended immediately to the code generator."""
-        for key, val in self.key_value_pairs:
-            if key.has_definition():
-                key.generate_python_def(code_generator)
-            if val.has_definition():
-                val.generate_python_def(code_generator)
+        if self.has_definition():
+            for key, val in self.key_value_pairs:
+                if key.has_definition():
+                    key.generate_python_def(code_generator)
+                if val.has_definition():
+                    val.generate_python_def(code_generator)
+
+            code_generator.append_line('%s = %s' % (
+                code_generator.get_result_name(self),
+                self.generate_dictionary_expr(code_generator)))
+        else:
+            code_generator.append_line('pass')
 
     def generate_python_use(self, code_generator):
         """Generates a Python expression that retrieves this instruction's
            result. The expression is returned as a string."""
-        return '{ %s }' % ', '.join(
-            ['%s : %s' % (
-                key.generate_python_use(code_generator),
-                val.generate_python_use(code_generator))
-             for key, val in self.key_value_pairs])
+        if self.has_definition():
+            return code_generator.get_result_name(self)
+        else:
+            return self.generate_dictionary_expr(code_generator)
 
 class StateInstruction(Instruction):
     """An instruction that accesses the modelverse state."""
@@ -719,6 +676,21 @@ class StateInstruction(Instruction):
 
         code_generator.append_state_definition(self, self.get_opcode(), args)
 
+class RunGeneratorFunctionInstruction(StateInstruction):
+    """An instruction that runs a generator function."""
+    def __init__(self, function, argument_dict):
+        StateInstruction.__init__(self)
+        self.function = function
+        self.argument_dict = argument_dict
+
+    def get_opcode(self):
+        """Gets the opcode for this state instruction."""
+        return "RUN"
+
+    def get_arguments(self):
+        """Gets this state instruction's argument list."""
+        return [self.function, self.argument_dict]
+
 class VariableName(object):
     """A data structure that unifies names across instructions that access the
        same variable."""
@@ -1180,6 +1152,35 @@ def create_block(*statements):
             statements[0],
             create_block(*statements[1:]))
 
+def create_jit_call(target, named_arguments, kwargs):
+    """Creates a call that abides by the JIT's calling convention."""
+    # A JIT call looks like this:
+    #
+    # target = ...
+    # arg_dict = { ... }
+    # arg_dict.update(kwargs)
+    # result, = yield [("RUN", [target, arg_dict])]
+
+    results = []
+    if target.has_definition():
+        target_tmp = StoreLocalInstruction(None, target)
+        results.append(target_tmp)
+        target = target_tmp.create_load()
+
+    arg_dict = StoreLocalInstruction(
+        None,
+        DictionaryLiteralInstruction(
+            [(LiteralInstruction(key), val) for key, val in named_arguments]))
+    results.append(arg_dict)
+    results.append(
+        CallInstruction(
+            LoadMemberInstruction(arg_dict.create_load(), 'update'),
+            [kwargs]))
+    return CompoundInstruction(
+        create_block(*results),
+        RunGeneratorFunctionInstruction(
+            target, arg_dict.create_load()))
+
 def with_debug_info_trace(instruction, debug_info, function_name):
     """Prepends the given instruction with a tracing instruction that prints
        the given debug information and function name."""

+ 25 - 13
kernel/modelverse_kernel/main.py

@@ -40,19 +40,31 @@ def pop_requests(requests, opStack):
         if len(requests) == 0:
             set_finished_requests_flag(opStack)
 
-        request_length = len(first_request)
-        if request_length == 2:
-            # Format: ("RUN", gen)
-            _, gen = first_request
+        _, request_args = first_request
+        if len(request_args) == 1:
+            # Format: ("RUN", [gen])
+            gen, = request_args
             push_generator(gen, opStack)
         else:
-            # Format: ("RUN", func, kwargs)
+            # Format: ("RUN", [func, kwargs])
             # This format is useful because it also works for functions that
             # throw an exception but never yield.
-            _, func, kwargs = first_request
-            push_generator(func(**kwargs), opStack)
-
-        raise RunRequest()
+            func, kwargs = request_args
+            # We need to be extra careful here, because func(**kwargs) might
+            # not be a generator at all: it might simply be a method that
+            # raises an exception. To cope with this we need to push a dummy
+            # entry onto the stack if a StopIteration or PrimtiveFinished
+            # exception is thrown. The logic in execute_yields will then pop
+            # that dummy entry.
+            try:
+                push_generator(func(**kwargs), opStack)
+                raise RunRequest()
+            except StopIteration:
+                push_generator(None, opStack)
+                raise
+            except primitive_functions.PrimitiveFinished as ex:
+                push_generator(None, opStack)
+                raise
     else:
         # The state handles all other requests.
         if len(requests) == 0:
@@ -222,12 +234,12 @@ class ModelverseKernel(object):
                 raise Exception("%s: error understanding command (%s, %s)" % (self.debug_info[username], inst_v, self.phase_v))
 
         try:
-            yield [("RUN", gen)]
+            yield [("RUN", [gen])]
         except jit.JitCompilationFailedException as e:
             # Try again, but this time without the JIT.
             # print(e.message)
             gen = self.get_inst_phase_generator(inst_v, self.phase_v, user_root)
-            yield [("RUN", gen)]
+            yield [("RUN", [gen])]
 
     def get_inst_phase_generator(self, inst_v, phase_v, user_root):
         """Gets a generator for the given instruction in the given phase,
@@ -280,9 +292,9 @@ class ModelverseKernel(object):
         parameters["mvk"] = self
 
         # Have the JIT compile the function.
-        compiled_func, = yield [("RUN", self.jit_compile(user_root, inst))]
+        compiled_func, = yield [("RUN", [self.jit_compile(user_root, inst)])]
         # Run the compiled function.
-        results = yield [("RUN", compiled_func, parameters)]
+        results = yield [("RUN", [compiled_func, parameters])]
         if results is None:
             raise Exception(
                 "%s: primitive finished without returning a value!" % (self.debug_info[username]))