Преглед на файлове

Implement 'call' instructions in the JIT

jonathanvdc преди 8 години
родител
ревизия
7bdca769aa
променени са 3 файла, в които са добавени 278 реда и са изтрити 27 реда
  1. 217 19
      kernel/modelverse_jit/jit.py
  2. 60 7
      kernel/modelverse_jit/tree_ir.py
  3. 1 1
      kernel/modelverse_kernel/main.py

+ 217 - 19
kernel/modelverse_jit/jit.py

@@ -15,6 +15,7 @@ class ModelverseJit(object):
         self.todo_entry_points = set()
         self.no_jit_entry_points = set()
         self.jitted_entry_points = {}
+        self.jitted_parameters = {}
         self.jit_globals = {
             'PrimitiveFinished' : primitive_functions.PrimitiveFinished
         }
@@ -43,7 +44,16 @@ class ModelverseJit(object):
            is enabled and the function entry point has been marked jittable, or if
            the function has already been compiled."""
         return ((self.jit_enabled and body_id in self.todo_entry_points) or
-                body_id in self.jitted_entry_points)
+                self.has_compiled(body_id))
+
+    def has_compiled(self, body_id):
+        """Tests if the function belonging to the given body node has been compiled yet."""
+        return body_id in self.jitted_entry_points
+
+    def get_compiled_name(self, body_id):
+        """Gets the name of the compiled version of the given body node in the JIT
+           global state."""
+        return self.jitted_entry_points[body_id]
 
     def mark_no_jit(self, body_id):
         """Informs the JIT that the node with the given identifier is a function entry
@@ -52,18 +62,39 @@ class ModelverseJit(object):
         if body_id in self.todo_entry_points:
             self.todo_entry_points.remove(body_id)
 
+    def generate_function_name(self):
+        """Generates a new function name,"""
+        function_name = 'jit_func%d' % self.jit_count
+        self.jit_count += 1
+        return function_name
+
     def register_compiled(self, body_id, compiled_function, function_name=None):
         """Registers a compiled entry point with the JIT."""
         if function_name is None:
-            function_name = 'jit_func%d' % self.jit_count
-            self.jit_count += 1
+            function_name = self.generate_function_name()
 
         self.jitted_entry_points[body_id] = function_name
         self.jit_globals[function_name] = compiled_function
         if body_id in self.todo_entry_points:
             self.todo_entry_points.remove(body_id)
 
-    def jit_compile(self, body_id, parameter_list):
+    def jit_parameters(self, body_id):
+        """Acquires the parameter list for the given body id node."""
+        if body_id not in self.jitted_parameters:
+            signature_id, = yield [("RRD", [body_id, "body"])]
+            signature_id = signature_id[0]
+            param_set_id, = yield [("RD", [signature_id, "params"])]
+            if param_set_id is None:
+                self.jitted_parameters[body_id] = ([], [])
+            else:
+                param_name_ids, = yield [("RDK", [param_set_id])]
+                param_names = yield [("RV", [n]) for n in param_name_ids]
+                param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
+                self.jitted_parameters[body_id] = (param_vars, param_names)
+
+        raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])
+
+    def jit_compile(self, user_root, body_id):
         """Tries to jit the function defined by the given entry point id and parameter list."""
         # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
         # pylint: disable=I0011,W0122
@@ -77,8 +108,25 @@ class ModelverseJit(object):
             raise JitCompilationFailedException(
                 'Cannot jit function at %d because it is marked non-jittable.' % body_id)
 
+        # Generate a name for the function we're about to analyze, and pretend that
+        # it already exists. (we need to do this for recursive functions)
+        function_name = self.generate_function_name()
+        self.jitted_entry_points[body_id] = function_name
+        self.jit_globals[function_name] = None
+
+        try:
+            gen = self.jit_parameters(body_id)
+            inp = None
+            while True:
+                inp = yield gen.send(inp)
+        except primitive_functions.PrimitiveFinished as ex:
+            parameter_ids, parameter_list = ex.result
+
+        param_dict = dict(zip(parameter_ids, parameter_list))
+        body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
         try:
-            gen = AnalysisState(self.max_instructions).analyze(body_id)
+            gen = AnalysisState(
+                self, user_root, body_param_dict, self.max_instructions).analyze(body_id)
             inp = None
             while True:
                 inp = yield gen.send(inp)
@@ -86,35 +134,69 @@ class ModelverseJit(object):
             constructed_body = ex.result
         except JitCompilationFailedException as ex:
             self.mark_no_jit(body_id)
+            del self.jitted_entry_points[body_id]
             raise JitCompilationFailedException(
                 '%s (function at %d)' % (ex.message, body_id))
 
+        # Write a prologue and prepend it to the generated function body.
+        prologue_statements = []
+        for (key, val) in param_dict.items():
+            arg_ptr = tree_ir.StoreLocalInstruction(
+                body_param_dict[key],
+                tree_ir.CreateNodeInstruction())
+            prologue_statements.append(arg_ptr)
+            prologue_statements.append(
+                tree_ir.CreateDictionaryEdgeInstruction(
+                    arg_ptr.create_load(),
+                    tree_ir.LiteralInstruction('value'),
+                    tree_ir.LoadLocalInstruction(val)))
+
+        constructed_body = tree_ir.create_block(
+            *(prologue_statements + [constructed_body]))
+
         # Wrap the IR in a function definition, give it a unique name.
         constructed_function = tree_ir.DefineFunctionInstruction(
-            'jit_func%d' % self.jit_count,
+            function_name,
             parameter_list + ['**' + KWARGS_PARAMETER_NAME],
             constructed_body.simplify())
-        self.jit_count += 1
         # Convert the function definition to Python code, and compile it.
         exec(str(constructed_function), self.jit_globals)
         # Extract the compiled function from the JIT global state.
-        compiled_function = self.jit_globals[constructed_function.name]
-        # Save the compiled function so we can reuse it later.
-        self.jitted_entry_points[body_id] = constructed_function.name
+        compiled_function = self.jit_globals[function_name]
 
-        print(constructed_function)
+        # print(constructed_function)
         raise primitive_functions.PrimitiveFinished(compiled_function)
 
 class AnalysisState(object):
     """The state of a bytecode analysis call graph."""
-
-    def __init__(self, max_instructions=None):
+    def __init__(self, jit, user_root, local_mapping, max_instructions=None):
         self.analyzed_instructions = set()
+        self.function_vars = set()
+        self.local_vars = set()
         self.max_instructions = max_instructions
+        self.user_root = user_root
+        self.jit = jit
+        self.local_mapping = local_mapping
 
     def get_local_name(self, local_id):
         """Gets the name for a local with the given id."""
-        return 'local%d' % local_id
+        if local_id not in self.local_mapping:
+            self.local_mapping[local_id] = 'local%d' % local_id
+        return self.local_mapping[local_id]
+
+    def register_local_var(self, local_id):
+        """Registers the given variable node id as a local."""
+        if local_id in self.function_vars:
+            raise JitCompilationFailedException(
+                "Local is used as target of function call.")
+        self.local_vars.add(local_id)
+
+    def register_function_var(self, local_id):
+        """Registers the given variable node id as a function."""
+        if local_id in self.local_vars:
+            raise JitCompilationFailedException(
+                "Local is used as target of function call.")
+        self.function_vars.add(local_id)
 
     def retrieve_user_root(self):
         """Creates an instruction that stores the user_root variable
@@ -134,7 +216,7 @@ class AnalysisState(object):
             raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
         elif (self.max_instructions is not None and
               len(self.analyzed_instructions) > self.max_instructions):
-            raise JitCompilationFailedException('Maximal number of instructions exceeded.')
+            raise JitCompilationFailedException('Maximum number of instructions exceeded.')
 
         self.analyzed_instructions.add(instruction_id)
         instruction_val, = yield [("RV", [instruction_id])]
@@ -397,6 +479,12 @@ class AnalysisState(object):
         #
         #         tmp = global_var
 
+        name = self.get_local_name(var_id)
+
+        if var_name is None:
+            raise primitive_functions.PrimitiveFinished(
+                tree_ir.LoadLocalInstruction(name))
+
         user_root = self.retrieve_user_root()
         global_var = tree_ir.StoreLocalInstruction(
             'global_var',
@@ -413,13 +501,12 @@ class AnalysisState(object):
                 tree_ir.LiteralInstruction(None)),
             tree_ir.RaiseInstruction(
                 tree_ir.CallInstruction(
-                    tree_ir.LoadLocalInstruction('Exception'),
+                    tree_ir.LoadGlobalInstruction('Exception'),
                     [tree_ir.LiteralInstruction(
                         "Runtime error: global '%s' not found" % var_name)
                     ])),
             tree_ir.EmptyInstruction())
 
-        name = self.get_local_name(var_id)
         raise primitive_functions.PrimitiveFinished(
             tree_ir.SelectInstruction(
                 tree_ir.LocalExistsInstruction(name),
@@ -435,6 +522,8 @@ class AnalysisState(object):
         """Tries to analyze the given 'declare' function."""
         var_id, = yield [("RD", [instruction_id, "var"])]
 
+        self.register_local_var(var_id)
+
         name = self.get_local_name(var_id)
 
         # The following logic declares a local:
@@ -556,13 +645,121 @@ class AnalysisState(object):
         # down to reading the value corresponding to the 'value' key
         # of the variable.
         #
-        #     value, =  yield [("RD", [returnvalue, "value"])]
+        #     value, = yield [("RD", [returnvalue, "value"])]
 
         raise primitive_functions.PrimitiveFinished(
             tree_ir.ReadDictionaryValueInstruction(
                 var_r,
                 tree_ir.LiteralInstruction('value')))
 
+    def analyze_direct_call(self, callee_id, callee_name, first_parameter_id):
+        """Tries to analyze a direct 'call' instruction."""
+
+        self.register_function_var(callee_id)
+
+        body_id, = yield [("RD", [callee_id, "body"])]
+        # Analyze the parameter list.
+        try:
+            gen = self.jit.jit_parameters(body_id)
+            inp = None
+            while True:
+                inp = yield gen.send(inp)
+        except primitive_functions.PrimitiveFinished as ex:
+            _, parameter_names = ex.result
+
+        # Compile the callee.
+        try:
+            gen = self.jit.jit_compile(self.user_root, body_id)
+            inp = None
+            while True:
+                inp = yield gen.send(inp)
+        except primitive_functions.PrimitiveFinished as ex:
+            pass
+
+        # Get the callee's name.
+        compiled_func_name = self.jit.get_compiled_name(body_id)
+
+        # Analyze the argument dictionary.
+        try:
+            gen = self.analyze_argument_dict(first_parameter_id)
+            inp = None
+            while True:
+                inp = yield gen.send(inp)
+        except primitive_functions.PrimitiveFinished as ex:
+            arg_dict = ex.result
+
+        # Construct the argument list from the parameter list and
+        # argument dictionary.
+        arg_list = []
+        for param_name in parameter_names:
+            if param_name in arg_dict:
+                arg_list.append(arg_dict[param_name])
+            else:
+                raise JitCompilationFailedException(
+                    "Cannot JIT-compile function call to '%s' with missing argument for "
+                    "formal parameter '%s'." % (callee_name, param_name))
+
+        raise primitive_functions.PrimitiveFinished(
+            tree_ir.JitCallInstruction(
+                tree_ir.LoadGlobalInstruction(compiled_func_name),
+                arg_list,
+                tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
+
+    def analyze_argument_dict(self, first_argument_id):
+        """Analyzes the parameter-to-argument mapping started by the specified first argument
+           node."""
+        next_param = first_argument_id
+        argument_dict = {}
+        while next_param is not None:
+            param_name_id, = yield [("RD", [next_param, "name"])]
+            param_name, = yield [("RV", [param_name_id])]
+            param_val_id, = yield [("RD", [next_param, "value"])]
+            try:
+                gen = self.analyze(param_val_id)
+                inp = None
+                while True:
+                    inp = yield gen.send(inp)
+            except primitive_functions.PrimitiveFinished as ex:
+                argument_dict[param_name] = ex.result
+
+            next_param, = yield [("RD", [next_param, "next_param"])]
+
+        raise primitive_functions.PrimitiveFinished(argument_dict)
+
+    def analyze_call(self, instruction_id):
+        """Tries to analyze the given 'call' instruction."""
+        func_id, first_param_id, = yield [("RD", [instruction_id, "func"]),
+                                          ("RD", [instruction_id, "params"])]
+
+        # Figure out what the 'func' instruction's type is.
+        func_instruction_op, = yield [("RV", [func_id])]
+        if func_instruction_op['value'] == 'access':
+            # Calls to 'access(resolve(var))' instructions are translated to direct calls.
+            access_value_id, = yield [("RD", [func_id, "var"])]
+            access_value_op, = yield [("RV", [access_value_id])]
+            if access_value_op['value'] == 'resolve':
+                resolved_var_id, = yield [("RD", [access_value_id, "var"])]
+                resolved_var_name, = yield [("RV", [resolved_var_id])]
+
+                # Try to look the name up as a global.
+                _globals, = yield [("RD", [self.user_root, "globals"])]
+                global_var, = yield [("RD", [_globals, resolved_var_name])]
+                global_val, = yield [("RD", [global_var, "value"])]
+
+                if global_val is None:
+                    raise JitCompilationFailedException(
+                        "Cannot JIT function calls that target an unknown value.")
+                else:
+                    gen = self.analyze_direct_call(
+                        global_val, resolved_var_name, first_param_id)
+                    inp = None
+                    while True:
+                        inp = yield gen.send(inp)
+                    # PrimitiveFinished exception will bubble up from here.
+
+        raise JitCompilationFailedException("Cannot JIT indirect function calls yet.")
+
+
     instruction_analyzers = {
         'if' : analyze_if,
         'while' : analyze_while,
@@ -574,6 +771,7 @@ class AnalysisState(object):
         'assign' : analyze_assign,
         'access' : analyze_access,
         'output' : analyze_output,
-        'input' : analyze_input
+        'input' : analyze_input,
+        'call' : analyze_call
     }
 

+ 60 - 7
kernel/modelverse_jit/tree_ir.py

@@ -289,6 +289,51 @@ class CallInstruction(Instruction):
                 self.target.generate_python_use(code_generator),
                 ', '.join([arg.generate_python_use(code_generator) for arg in self.argument_list])))
 
+class JitCallInstruction(Instruction):
+    """An instruction that calls a jitted function."""
+    def __init__(self, target, argument_list, kwarg):
+        Instruction.__init__(self)
+        self.target = target
+        self.argument_list = argument_list
+        self.kwarg = kwarg
+
+    def simplify(self):
+        """Applies basic simplification to this instruction and its children."""
+        return JitCallInstruction(
+            self.target.simplify(),
+            [arg.simplify() for arg in self.argument_list],
+            self.kwarg.simplify())
+
+    def generate_python_def(self, code_generator):
+        """Generates Python code for this instruction."""
+        if self.target.has_definition():
+            self.target.generate_python_def(code_generator)
+
+        for arg in self.argument_list:
+            if arg.has_definition():
+                arg.generate_python_def(code_generator)
+
+        if self.kwarg.has_definition():
+            self.kwarg.generate_python_def(code_generator)
+
+        code_generator.append_line('try:')
+        code_generator.increase_indentation()
+        code_generator.append_line(
+            'gen = %s(%s, **%s) ' % (
+                self.target.generate_python_use(code_generator),
+                ', '.join([arg.generate_python_use(code_generator) for arg in self.argument_list]),
+                self.kwarg.generate_python_use(code_generator)))
+        code_generator.append_line('inp = None')
+        code_generator.append_line('while 1:')
+        code_generator.increase_indentation()
+        code_generator.append_line('inp = yield gen.send(inp)')
+        code_generator.decrease_indentation()
+        code_generator.decrease_indentation()
+        code_generator.append_line('except PrimitiveFinished as ex:')
+        code_generator.increase_indentation()
+        code_generator.append_line('%s = ex.result' % code_generator.get_result_name(self))
+        code_generator.decrease_indentation()
+
 class BinaryInstruction(Instruction):
     """An instruction that performs a binary operation."""
     def __init__(self, lhs, operator, rhs):
@@ -428,8 +473,8 @@ class StateInstruction(Instruction):
 
         code_generator.append_state_definition(self, self.get_opcode(), args)
 
-class LocalInstruction(Instruction):
-    """A base class for instructions that access local variables."""
+class VariableInstruction(Instruction):
+    """A base class for instructions that access variables."""
     def __init__(self, name):
         Instruction.__init__(self)
         self.name = name
@@ -439,6 +484,13 @@ class LocalInstruction(Instruction):
            instruction if it is not None."""
         return self.name
 
+    def generate_python_use(self, code_generator):
+        """Generates a Python expression that retrieves this instruction's
+           result. The expression is returned as a string."""
+        return self.name
+
+class LocalInstruction(VariableInstruction):
+    """A base class for instructions that access local variables."""
     def create_load(self):
         """Creates an instruction that loads the variable referenced by this instruction."""
         return LoadLocalInstruction(self.name)
@@ -448,11 +500,6 @@ class LocalInstruction(Instruction):
            by this instruction."""
         return StoreLocalInstruction(self.name, value)
 
-    def generate_python_use(self, code_generator):
-        """Generates a Python expression that retrieves this instruction's
-           result. The expression is returned as a string."""
-        return self.name
-
 class StoreLocalInstruction(LocalInstruction):
     """An instruction that stores a value in a local variable."""
     def __init__(self, name, value):
@@ -502,6 +549,12 @@ class LocalExistsInstruction(LocalInstruction):
            result. The expression is returned as a string."""
         return "'%s' in locals()" % self.name
 
+class LoadGlobalInstruction(VariableInstruction):
+    """An instruction that loads a value from a global variable."""
+    def has_definition(self):
+        """Tells if this instruction requires a definition."""
+        return False
+
 class LoadIndexInstruction(Instruction):
     """An instruction that produces a value by indexing a specified expression with
        a given key."""

+ 1 - 1
kernel/modelverse_kernel/main.py

@@ -138,7 +138,7 @@ class ModelverseKernel(object):
 
         # Have the JIT compile the function.
         try:
-            jit_gen = self.jit.jit_compile(inst, dict_keys)
+            jit_gen = self.jit.jit_compile(user_root, inst)
             inp = None
             while 1:
                 inp = yield jit_gen.send(inp)