Browse Source

Create a bytecode graph interpreter

jonathanvdc 8 years ago
parent
commit
7fa49badec

+ 317 - 0
kernel/modelverse_jit/bytecode_interpreter.py

@@ -0,0 +1,317 @@
+"""Interprets parsed bytecode graphs."""
+
+import modelverse_jit.bytecode_ir as bytecode_ir
+import modelverse_jit.runtime as jit_runtime
+import modelverse_kernel.primitives as primitive_functions
+
+class BreakException(Exception):
+    """A type of exception that is used to interpret 'break' instructions:
+       the 'break' instructions throw a BreakException, which is then handled
+       by the appropriate 'while' instruction."""
+    def __init__(self, loop):
+        Exception.__init__(self)
+        self.loop = loop
+
+class ContinueException(Exception):
+    """A type of exception that is used to interpret 'continue' instructions:
+       the 'continue' instructions throw a ContinueException, which is then handled
+       by the appropriate 'while' instruction."""
+    def __init__(self, loop):
+        Exception.__init__(self)
+        self.loop = loop
+
+class InterpreterState(object):
+    """The state of the bytecode interpreter."""
+    def __init__(self, gc_root_node, keyword_arg_dict, nop_period=20):
+        self.gc_root_node = gc_root_node
+        self.nop_period = nop_period
+        self.keyword_arg_dict = keyword_arg_dict
+        self.current_result = None
+        self.nop_phase = 0
+        self.local_vars = {}
+
+    def import_local(self, node_id, value):
+        """Imports the given value as a local in this interpreter state."""
+        local_node, = yield [("CN", [])]
+        yield [
+            ("CE", [self.gc_root_node, local_node]),
+            ("CD", [local_node, "value", value])]
+        self.local_vars[node_id] = local_node
+        raise primitive_functions.PrimitiveFinished(None)
+
+    def schedule_nop(self):
+        """Increments the nop-phase. If a nop should be performed, then True is returned.
+           Otherwise, False."""
+        self.nop_phase += 1
+        if self.nop_phase == self.nop_period:
+            self.nop_phase = 0
+            return True
+        else:
+            return False
+
+    def update_result(self, new_result):
+        """Sets the current result to the given value, if it is not None."""
+        if new_result is not None:
+            self.current_result = new_result
+
+    def get_task_root(self):
+        """Gets the task root node id."""
+        return self.keyword_arg_dict['task_root']
+
+    def get_kernel(self):
+        """Gets the Modelverse kernel instance."""
+        return self.keyword_arg_dict['mvk']
+
+    def interpret(self, instruction):
+        """Interprets the given instruction and returns the current result."""
+        instruction_type = type(instruction)
+        if instruction_type in InterpreterState.INTERPRETERS:
+            # Interpret the instruction.
+            yield [("CALL_ARGS",
+                    [InterpreterState.INTERPRETERS[instruction_type], (self, instruction)])]
+
+            # Maybe perform a nop.
+            if self.schedule_nop():
+                yield None
+
+            # Interpret the next instruction.
+            next_instruction = instruction.next_instruction
+            if next_instruction is not None:
+                yield [("TAIL_CALL_ARGS", [self.interpret, (next_instruction,)])]
+            else:
+                raise primitive_functions.PrimitiveFinished(self.current_result)
+        else:
+            raise jit_runtime.JitCompilationFailedException(
+                'Unknown bytecode instruction: %r' % instruction)
+
+    def interpret_select(self, instruction):
+        """Interprets the given 'select' instruction."""
+        cond_node, = yield [("CALL_ARGS", [self.interpret, (instruction.condition,)])]
+        cond_val, = yield [("RV", [cond_node])]
+        if cond_val:
+            yield [("TAIL_CALL_ARGS", [self.interpret, (instruction.if_clause,)])]
+        elif instruction.else_clause is not None:
+            yield [("TAIL_CALL_ARGS", [self.interpret, (instruction.else_clause,)])]
+        else:
+            raise primitive_functions.PrimitiveFinished(None)
+
+    def interpret_while(self, instruction):
+        """Interprets the given 'while' instruction."""
+        def __handle_break(exception):
+            if exception.loop == instruction:
+                # End the loop.
+                raise primitive_functions.PrimitiveFinished(None)
+            else:
+                # Propagate the exception to the next 'while' loop.
+                raise exception
+
+        def __handle_continue(exception):
+            if exception.loop == instruction:
+                # Restart the loop.
+                yield [("TAIL_CALL_ARGS", [self.interpret, (instruction,)])]
+            else:
+                # Propagate the exception to the next 'while' loop.
+                raise exception
+
+        yield [("TRY", [])]
+        yield [("CATCH", [BreakException, __handle_break])]
+        yield [("CATCH", [ContinueException, __handle_continue])]
+        while 1:
+            cond_node, = yield [("CALL_ARGS", [self.interpret, (instruction.condition,)])]
+            cond_val, = yield [("RV", [cond_node])]
+            if cond_val:
+                yield [("CALL_ARGS", [self.interpret, (instruction.body,)])]
+            else:
+                break
+        yield [("END_TRY", [])]
+
+        raise primitive_functions.PrimitiveFinished(None)
+
+    def interpret_break(self, instruction):
+        """Interprets the given 'break' instruction."""
+        raise BreakException(instruction.loop)
+
+    def interpret_continue(self, instruction):
+        """Interprets the given 'continue' instruction."""
+        raise ContinueException(instruction.loop)
+
+    def interpret_return(self, instruction):
+        """Interprets the given 'return' instruction."""
+        if instruction.value is None:
+            raise primitive_functions.InterpretedFunctionFinished(None)
+        else:
+            return_node, = yield [("CALL_ARGS", [self.interpret, (instruction.value,)])]
+            raise primitive_functions.InterpretedFunctionFinished(return_node)
+
+    def interpret_call(self, instruction):
+        """Interprets the given 'call' instruction."""
+        target, = yield [("CALL_ARGS", [self.interpret, (instruction.target,)])]
+        named_args = {}
+        for name, arg_instruction in instruction.argument_list:
+            arg, = yield [("CALL_ARGS", [self.interpret, (arg_instruction,)])]
+            named_args[name] = arg
+
+        kwargs = {'function_id': target, 'named_arguments': named_args}
+        kwargs.update(self.keyword_arg_dict)
+        result, = yield [("CALL_KWARGS", [jit_runtime.call_function, kwargs])]
+        if result is not None:
+            yield [("CE", [self.gc_root_node, result])]
+            self.update_result(result)
+
+        raise primitive_functions.PrimitiveFinished(None)
+
+    def interpret_constant(self, instruction):
+        """Interprets the given 'constant' instruction."""
+        self.update_result(instruction.constant_id)
+        raise primitive_functions.PrimitiveFinished(None)
+
+    def interpret_input(self, instruction):
+        """Interprets the given 'input' instruction."""
+        result, = yield [("CALL_KWARGS", [jit_runtime.get_input, self.keyword_arg_dict])]
+        self.update_result(result)
+        yield [("CE", [self.gc_root_node, result])]
+        raise primitive_functions.PrimitiveFinished(None)
+
+    def interpret_output(self, instruction):
+        """Interprets the given 'output' instruction."""
+        output_value, = yield [("CALL_ARGS", [self.interpret, (instruction.value,)])]
+        task_root = self.get_task_root()
+        last_output, last_output_link, new_last_output = yield [
+            ("RD", [task_root, "last_output"]),
+            ("RDE", [task_root, "last_output"]),
+            ("CN", [])
+        ]
+        yield [
+            ("CD", [last_output, "value", output_value]),
+            ("CD", [last_output, "next", new_last_output]),
+            ("CD", [task_root, "last_output", new_last_output]),
+            ("DE", [last_output_link])
+        ]
+        yield None
+        raise primitive_functions.PrimitiveFinished(None)
+
+    def interpret_declare(self, instruction):
+        """Interprets a 'declare' (local) instruction."""
+        node_id = instruction.variable.node_id
+        if node_id in self.local_vars:
+            self.update_result(self.local_vars[node_id])
+            raise primitive_functions.PrimitiveFinished(None)
+        else:
+            local_node, = yield [("CN", [])]
+            yield [("CE", [self.gc_root_node, local_node])]
+            self.update_result(local_node)
+            self.local_vars[node_id] = local_node
+            raise primitive_functions.PrimitiveFinished(None)
+
+    def interpret_global(self, instruction):
+        """Interprets a (declare) 'global' instruction."""
+        var_name = instruction.variable.name
+        task_root = self.get_task_root()
+        _globals, = yield [("RD", [task_root, "globals"])]
+        global_var, = yield [("RD", [_globals, var_name])]
+
+        if global_var is None:
+            global_var, = yield [("CN", [])]
+            yield [("CD", [_globals, var_name, global_var])]
+
+        self.update_result(global_var)
+        yield [("CE", [self.gc_root_node, global_var])]
+        raise primitive_functions.PrimitiveFinished(None)
+
+    def interpret_resolve(self, instruction):
+        """Interprets a 'resolve' instruction."""
+        node_id = instruction.variable.node_id
+        if node_id in self.local_vars:
+            self.update_result(self.local_vars[node_id])
+            raise primitive_functions.PrimitiveFinished(None)
+        else:
+            task_root = self.get_task_root()
+            var_name = instruction.variable.name
+            _globals, = yield [("RD", [task_root, "globals"])]
+            global_var, = yield [("RD", [_globals, var_name])]
+
+            if global_var is None:
+                raise Exception(jit_runtime.GLOBAL_NOT_FOUND_MESSAGE_FORMAT % var_name)
+
+            mvk = self.get_kernel()
+            if mvk.suggest_function_names and mvk.jit.get_global_body_id(var_name) is None:
+                global_val, = yield [("RD", [global_var, "value"])]
+                if global_val is not None:
+                    func_body, = yield [("RD", [global_val, "body"])]
+                    if func_body is not None:
+                        mvk.jit.register_global(func_body, var_name)
+
+            self.update_result(global_var)
+            yield [("CE", [self.gc_root_node, global_var])]
+            raise primitive_functions.PrimitiveFinished(None)
+
+    def interpret_access(self, instruction):
+        """Interprets an 'access' instruction."""
+        pointer_node, = yield [("CALL_ARGS", [self.interpret, (instruction.pointer,)])]
+        value_node, = yield [("RD", [pointer_node, "value"])]
+        self.update_result(value_node)
+        yield [("CE", [self.gc_root_node, value_node])]
+        raise primitive_functions.PrimitiveFinished(None)
+
+    def interpret_assign(self, instruction):
+        """Interprets an 'assign' instruction."""
+        pointer_node, = yield [("CALL_ARGS", [self.interpret, (instruction.pointer,)])]
+        value_node, = yield [("CALL_ARGS", [self.interpret, (instruction.value,)])]
+        value_link, = yield [("RDE", [pointer_node, "value"])]
+        yield [
+            ("CD", [pointer_node, "value", value_node]),
+            ("DE", [value_link])]
+        raise primitive_functions.PrimitiveFinished(None)
+
+    INTERPRETERS = {
+        bytecode_ir.SelectInstruction: interpret_select,
+        bytecode_ir.WhileInstruction: interpret_while,
+        bytecode_ir.BreakInstruction: interpret_break,
+        bytecode_ir.ContinueInstruction: interpret_continue,
+        bytecode_ir.ReturnInstruction: interpret_return,
+        bytecode_ir.CallInstruction: interpret_call,
+        bytecode_ir.ConstantInstruction: interpret_constant,
+        bytecode_ir.InputInstruction: interpret_input,
+        bytecode_ir.OutputInstruction: interpret_output,
+        bytecode_ir.DeclareInstruction: interpret_declare,
+        bytecode_ir.GlobalInstruction: interpret_global,
+        bytecode_ir.ResolveInstruction: interpret_resolve,
+        bytecode_ir.AccessInstruction: interpret_access,
+        bytecode_ir.AssignInstruction: interpret_assign
+    }
+
+def interpret_bytecode_function(function_name, body_bytecode, local_arguments, keyword_arguments):
+    """Interprets the bytecode function with the given name, body, named arguments and
+       keyword arguments."""
+    yield [("DEBUG_INFO", [function_name, None, jit_runtime.BYTECODE_INTERPRETER_ORIGIN_NAME])]
+    task_root = keyword_arguments['task_root']
+    gc_root_node, = yield [("CN", [])]
+    gc_root_edge, = yield [("CE", [task_root, gc_root_node])]
+    interpreter = InterpreterState(gc_root_node, keyword_arguments)
+    for param_id, arg_node in local_arguments.items():
+        yield [("CALL_ARGS", [interpreter.import_local, (param_id, arg_node)])]
+
+    def __handle_return(exception):
+        yield [("DE", [gc_root_edge])]
+        raise primitive_functions.PrimitiveFinished(exception.result)
+
+    def __handle_break(_):
+        raise jit_runtime.UnreachableCodeException(
+            "Function '%s' tries to break out of a loop that is not currently executing." %
+            function_name)
+
+    def __handle_continue(_):
+        raise jit_runtime.UnreachableCodeException(
+            "Function '%s' tries to continue a loop that is not currently executing." %
+            function_name)
+
+    # Perform a nop before interpreting the function.
+    yield None
+
+    yield [("TRY", [])]
+    yield [("CATCH", [primitive_functions.InterpretedFunctionFinished, __handle_return])]
+    yield [("CATCH", [BreakException, __handle_break])]
+    yield [("CATCH", [ContinueException, __handle_continue])]
+    yield [("CALL_ARGS", [interpreter.interpret, (body_bytecode,)])]
+    yield [("END_TRY", [])]
+    raise jit_runtime.UnreachableCodeException("Function '%s' failed to return." % function_name)

+ 25 - 1
kernel/modelverse_jit/jit.py

@@ -5,6 +5,7 @@ import modelverse_jit.bytecode_parser as bytecode_parser
 import modelverse_jit.bytecode_to_tree as bytecode_to_tree
 import modelverse_jit.bytecode_to_cfg as bytecode_to_cfg
 import modelverse_jit.bytecode_ir as bytecode_ir
+import modelverse_jit.bytecode_interpreter as bytecode_interpreter
 import modelverse_jit.cfg_optimization as cfg_optimization
 import modelverse_jit.cfg_to_tree as cfg_to_tree
 import modelverse_jit.cfg_ir as cfg_ir
@@ -566,10 +567,13 @@ class ModelverseJit(object):
         if body_id is None:
             raise ValueError('body_id cannot be None')
         elif body_id in self.jitted_entry_points:
-            # We have already compiled this function.
             raise primitive_functions.PrimitiveFinished(
                 self.jit_globals[self.jitted_entry_points[body_id]])
 
+        compiled_func = self.lookup_compiled_body(body_id)
+        if compiled_func is not None:
+            raise primitive_functions.PrimitiveFinished(compiled_func)
+
         # Generate a name for the function we're about to analyze, and 're-compile'
         # it for the first time.
         function_name = self.generate_function_name(body_id, suggested_name)
@@ -714,6 +718,26 @@ class ModelverseJit(object):
                 tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)),
             global_name)
 
+def compile_function_body_interpret(jit, function_name, body_id, task_root):
+    """Create a function that invokes the interpreter on the given function."""
+    (parameter_ids, parameter_list, _), = yield [
+        ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
+    param_dict = dict(zip(parameter_ids, parameter_list))
+    body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
+    def __interpret_function(**kwargs):
+        local_args = {}
+        inner_kwargs = dict(kwargs)
+        for param_id, name in param_dict.items():
+            local_args[param_id] = inner_kwargs[name]
+            del inner_kwargs[name]
+
+        yield [("TAIL_CALL_ARGS",
+                [bytecode_interpreter.interpret_bytecode_function,
+                 (function_name, body_bytecode, local_args, inner_kwargs)])]
+
+    jit.jit_globals[function_name] = __interpret_function
+    raise primitive_functions.PrimitiveFinished(__interpret_function)
+
 def compile_function_body_baseline(jit, function_name, body_id, task_root, header=None):
     """Have the baseline JIT compile the function with the given name and body id."""
     (parameter_ids, parameter_list, _), = yield [

+ 3 - 0
kernel/modelverse_jit/runtime.py

@@ -43,6 +43,9 @@ LOCALS_EDGE_NAME = "jit_locals_edge"
 GLOBAL_NOT_FOUND_MESSAGE_FORMAT = "Not found as global: %s"
 """The format of the 'not found as global' message. Takes a single argument."""
 
+BYTECODE_INTERPRETER_ORIGIN_NAME = "bytecode-interpreter"
+"""The origin name for functions that were produced by the bytecode interpreter."""
+
 BASELINE_JIT_ORIGIN_NAME = "baseline-jit"
 """The origin name for functions that were produced by the baseline JIT."""