Browse Source

Implement 'call' instructions in the JIT

jonathanvdc 8 years ago
parent
commit
7fcb4983ce
3 changed files with 278 additions and 27 deletions
  1. 217 19
      kernel/modelverse_jit/jit.py
  2. 60 7
      kernel/modelverse_jit/tree_ir.py
  3. 1 1
      kernel/modelverse_kernel/main.py

+ 217 - 19
kernel/modelverse_jit/jit.py

@@ -15,6 +15,7 @@ class ModelverseJit(object):
         self.todo_entry_points = set()
         self.todo_entry_points = set()
         self.no_jit_entry_points = set()
         self.no_jit_entry_points = set()
         self.jitted_entry_points = {}
         self.jitted_entry_points = {}
+        self.jitted_parameters = {}
         self.jit_globals = {
         self.jit_globals = {
             'PrimitiveFinished' : primitive_functions.PrimitiveFinished
             'PrimitiveFinished' : primitive_functions.PrimitiveFinished
         }
         }
@@ -43,7 +44,16 @@ class ModelverseJit(object):
            is enabled and the function entry point has been marked jittable, or if
            is enabled and the function entry point has been marked jittable, or if
            the function has already been compiled."""
            the function has already been compiled."""
         return ((self.jit_enabled and body_id in self.todo_entry_points) or
         return ((self.jit_enabled and body_id in self.todo_entry_points) or
-                body_id in self.jitted_entry_points)
+                self.has_compiled(body_id))
+
+    def has_compiled(self, body_id):
+        """Tests if the function belonging to the given body node has been compiled yet."""
+        return body_id in self.jitted_entry_points
+
+    def get_compiled_name(self, body_id):
+        """Gets the name of the compiled version of the given body node in the JIT
+           global state."""
+        return self.jitted_entry_points[body_id]
 
 
     def mark_no_jit(self, body_id):
     def mark_no_jit(self, body_id):
         """Informs the JIT that the node with the given identifier is a function entry
         """Informs the JIT that the node with the given identifier is a function entry
@@ -52,18 +62,39 @@ class ModelverseJit(object):
         if body_id in self.todo_entry_points:
         if body_id in self.todo_entry_points:
             self.todo_entry_points.remove(body_id)
             self.todo_entry_points.remove(body_id)
 
 
+    def generate_function_name(self):
+        """Generates a new function name,"""
+        function_name = 'jit_func%d' % self.jit_count
+        self.jit_count += 1
+        return function_name
+
     def register_compiled(self, body_id, compiled_function, function_name=None):
     def register_compiled(self, body_id, compiled_function, function_name=None):
         """Registers a compiled entry point with the JIT."""
         """Registers a compiled entry point with the JIT."""
         if function_name is None:
         if function_name is None:
-            function_name = 'jit_func%d' % self.jit_count
-            self.jit_count += 1
+            function_name = self.generate_function_name()
 
 
         self.jitted_entry_points[body_id] = function_name
         self.jitted_entry_points[body_id] = function_name
         self.jit_globals[function_name] = compiled_function
         self.jit_globals[function_name] = compiled_function
         if body_id in self.todo_entry_points:
         if body_id in self.todo_entry_points:
             self.todo_entry_points.remove(body_id)
             self.todo_entry_points.remove(body_id)
 
 
-    def jit_compile(self, body_id, parameter_list):
+    def jit_parameters(self, body_id):
+        """Acquires the parameter list for the given body id node."""
+        if body_id not in self.jitted_parameters:
+            signature_id, = yield [("RRD", [body_id, "body"])]
+            signature_id = signature_id[0]
+            param_set_id, = yield [("RD", [signature_id, "params"])]
+            if param_set_id is None:
+                self.jitted_parameters[body_id] = ([], [])
+            else:
+                param_name_ids, = yield [("RDK", [param_set_id])]
+                param_names = yield [("RV", [n]) for n in param_name_ids]
+                param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
+                self.jitted_parameters[body_id] = (param_vars, param_names)
+
+        raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])
+
+    def jit_compile(self, user_root, body_id):
         """Tries to jit the function defined by the given entry point id and parameter list."""
         """Tries to jit the function defined by the given entry point id and parameter list."""
         # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
         # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
         # pylint: disable=I0011,W0122
         # pylint: disable=I0011,W0122
@@ -77,8 +108,25 @@ class ModelverseJit(object):
             raise JitCompilationFailedException(
             raise JitCompilationFailedException(
                 'Cannot jit function at %d because it is marked non-jittable.' % body_id)
                 'Cannot jit function at %d because it is marked non-jittable.' % body_id)
 
 
+        # Generate a name for the function we're about to analyze, and pretend that
+        # it already exists. (we need to do this for recursive functions)
+        function_name = self.generate_function_name()
+        self.jitted_entry_points[body_id] = function_name
+        self.jit_globals[function_name] = None
+
+        try:
+            gen = self.jit_parameters(body_id)
+            inp = None
+            while True:
+                inp = yield gen.send(inp)
+        except primitive_functions.PrimitiveFinished as ex:
+            parameter_ids, parameter_list = ex.result
+
+        param_dict = dict(zip(parameter_ids, parameter_list))
+        body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
         try:
         try:
-            gen = AnalysisState(self.max_instructions).analyze(body_id)
+            gen = AnalysisState(
+                self, user_root, body_param_dict, self.max_instructions).analyze(body_id)
             inp = None
             inp = None
             while True:
             while True:
                 inp = yield gen.send(inp)
                 inp = yield gen.send(inp)
@@ -86,35 +134,69 @@ class ModelverseJit(object):
             constructed_body = ex.result
             constructed_body = ex.result
         except JitCompilationFailedException as ex:
         except JitCompilationFailedException as ex:
             self.mark_no_jit(body_id)
             self.mark_no_jit(body_id)
+            del self.jitted_entry_points[body_id]
             raise JitCompilationFailedException(
             raise JitCompilationFailedException(
                 '%s (function at %d)' % (ex.message, body_id))
                 '%s (function at %d)' % (ex.message, body_id))
 
 
+        # Write a prologue and prepend it to the generated function body.
+        prologue_statements = []
+        for (key, val) in param_dict.items():
+            arg_ptr = tree_ir.StoreLocalInstruction(
+                body_param_dict[key],
+                tree_ir.CreateNodeInstruction())
+            prologue_statements.append(arg_ptr)
+            prologue_statements.append(
+                tree_ir.CreateDictionaryEdgeInstruction(
+                    arg_ptr.create_load(),
+                    tree_ir.LiteralInstruction('value'),
+                    tree_ir.LoadLocalInstruction(val)))
+
+        constructed_body = tree_ir.create_block(
+            *(prologue_statements + [constructed_body]))
+
         # Wrap the IR in a function definition, give it a unique name.
         # Wrap the IR in a function definition, give it a unique name.
         constructed_function = tree_ir.DefineFunctionInstruction(
         constructed_function = tree_ir.DefineFunctionInstruction(
-            'jit_func%d' % self.jit_count,
+            function_name,
             parameter_list + ['**' + KWARGS_PARAMETER_NAME],
             parameter_list + ['**' + KWARGS_PARAMETER_NAME],
             constructed_body.simplify())
             constructed_body.simplify())
-        self.jit_count += 1
         # Convert the function definition to Python code, and compile it.
         # Convert the function definition to Python code, and compile it.
         exec(str(constructed_function), self.jit_globals)
         exec(str(constructed_function), self.jit_globals)
         # Extract the compiled function from the JIT global state.
         # Extract the compiled function from the JIT global state.
-        compiled_function = self.jit_globals[constructed_function.name]
-        # Save the compiled function so we can reuse it later.
-        self.jitted_entry_points[body_id] = constructed_function.name
+        compiled_function = self.jit_globals[function_name]
 
 
-        print(constructed_function)
+        # print(constructed_function)
         raise primitive_functions.PrimitiveFinished(compiled_function)
         raise primitive_functions.PrimitiveFinished(compiled_function)
 
 
 class AnalysisState(object):
 class AnalysisState(object):
     """The state of a bytecode analysis call graph."""
     """The state of a bytecode analysis call graph."""
-
-    def __init__(self, max_instructions=None):
+    def __init__(self, jit, user_root, local_mapping, max_instructions=None):
         self.analyzed_instructions = set()
         self.analyzed_instructions = set()
+        self.function_vars = set()
+        self.local_vars = set()
         self.max_instructions = max_instructions
         self.max_instructions = max_instructions
+        self.user_root = user_root
+        self.jit = jit
+        self.local_mapping = local_mapping
 
 
     def get_local_name(self, local_id):
     def get_local_name(self, local_id):
         """Gets the name for a local with the given id."""
         """Gets the name for a local with the given id."""
-        return 'local%d' % local_id
+        if local_id not in self.local_mapping:
+            self.local_mapping[local_id] = 'local%d' % local_id
+        return self.local_mapping[local_id]
+
+    def register_local_var(self, local_id):
+        """Registers the given variable node id as a local."""
+        if local_id in self.function_vars:
+            raise JitCompilationFailedException(
+                "Local is used as target of function call.")
+        self.local_vars.add(local_id)
+
+    def register_function_var(self, local_id):
+        """Registers the given variable node id as a function."""
+        if local_id in self.local_vars:
+            raise JitCompilationFailedException(
+                "Local is used as target of function call.")
+        self.function_vars.add(local_id)
 
 
     def retrieve_user_root(self):
     def retrieve_user_root(self):
         """Creates an instruction that stores the user_root variable
         """Creates an instruction that stores the user_root variable
@@ -134,7 +216,7 @@ class AnalysisState(object):
             raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
             raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
         elif (self.max_instructions is not None and
         elif (self.max_instructions is not None and
               len(self.analyzed_instructions) > self.max_instructions):
               len(self.analyzed_instructions) > self.max_instructions):
-            raise JitCompilationFailedException('Maximal number of instructions exceeded.')
+            raise JitCompilationFailedException('Maximum number of instructions exceeded.')
 
 
         self.analyzed_instructions.add(instruction_id)
         self.analyzed_instructions.add(instruction_id)
         instruction_val, = yield [("RV", [instruction_id])]
         instruction_val, = yield [("RV", [instruction_id])]
@@ -397,6 +479,12 @@ class AnalysisState(object):
         #
         #
         #         tmp = global_var
         #         tmp = global_var
 
 
+        name = self.get_local_name(var_id)
+
+        if var_name is None:
+            raise primitive_functions.PrimitiveFinished(
+                tree_ir.LoadLocalInstruction(name))
+
         user_root = self.retrieve_user_root()
         user_root = self.retrieve_user_root()
         global_var = tree_ir.StoreLocalInstruction(
         global_var = tree_ir.StoreLocalInstruction(
             'global_var',
             'global_var',
@@ -413,13 +501,12 @@ class AnalysisState(object):
                 tree_ir.LiteralInstruction(None)),
                 tree_ir.LiteralInstruction(None)),
             tree_ir.RaiseInstruction(
             tree_ir.RaiseInstruction(
                 tree_ir.CallInstruction(
                 tree_ir.CallInstruction(
-                    tree_ir.LoadLocalInstruction('Exception'),
+                    tree_ir.LoadGlobalInstruction('Exception'),
                     [tree_ir.LiteralInstruction(
                     [tree_ir.LiteralInstruction(
                         "Runtime error: global '%s' not found" % var_name)
                         "Runtime error: global '%s' not found" % var_name)
                     ])),
                     ])),
             tree_ir.EmptyInstruction())
             tree_ir.EmptyInstruction())
 
 
-        name = self.get_local_name(var_id)
         raise primitive_functions.PrimitiveFinished(
         raise primitive_functions.PrimitiveFinished(
             tree_ir.SelectInstruction(
             tree_ir.SelectInstruction(
                 tree_ir.LocalExistsInstruction(name),
                 tree_ir.LocalExistsInstruction(name),
@@ -435,6 +522,8 @@ class AnalysisState(object):
         """Tries to analyze the given 'declare' function."""
         """Tries to analyze the given 'declare' function."""
         var_id, = yield [("RD", [instruction_id, "var"])]
         var_id, = yield [("RD", [instruction_id, "var"])]
 
 
+        self.register_local_var(var_id)
+
         name = self.get_local_name(var_id)
         name = self.get_local_name(var_id)
 
 
         # The following logic declares a local:
         # The following logic declares a local:
@@ -556,13 +645,121 @@ class AnalysisState(object):
         # down to reading the value corresponding to the 'value' key
         # down to reading the value corresponding to the 'value' key
         # of the variable.
         # of the variable.
         #
         #
-        #     value, =  yield [("RD", [returnvalue, "value"])]
+        #     value, = yield [("RD", [returnvalue, "value"])]
 
 
         raise primitive_functions.PrimitiveFinished(
         raise primitive_functions.PrimitiveFinished(
             tree_ir.ReadDictionaryValueInstruction(
             tree_ir.ReadDictionaryValueInstruction(
                 var_r,
                 var_r,
                 tree_ir.LiteralInstruction('value')))
                 tree_ir.LiteralInstruction('value')))
 
 
+    def analyze_direct_call(self, callee_id, callee_name, first_parameter_id):
+        """Tries to analyze a direct 'call' instruction."""
+
+        self.register_function_var(callee_id)
+
+        body_id, = yield [("RD", [callee_id, "body"])]
+        # Analyze the parameter list.
+        try:
+            gen = self.jit.jit_parameters(body_id)
+            inp = None
+            while True:
+                inp = yield gen.send(inp)
+        except primitive_functions.PrimitiveFinished as ex:
+            _, parameter_names = ex.result
+
+        # Compile the callee.
+        try:
+            gen = self.jit.jit_compile(self.user_root, body_id)
+            inp = None
+            while True:
+                inp = yield gen.send(inp)
+        except primitive_functions.PrimitiveFinished as ex:
+            pass
+
+        # Get the callee's name.
+        compiled_func_name = self.jit.get_compiled_name(body_id)
+
+        # Analyze the argument dictionary.
+        try:
+            gen = self.analyze_argument_dict(first_parameter_id)
+            inp = None
+            while True:
+                inp = yield gen.send(inp)
+        except primitive_functions.PrimitiveFinished as ex:
+            arg_dict = ex.result
+
+        # Construct the argument list from the parameter list and
+        # argument dictionary.
+        arg_list = []
+        for param_name in parameter_names:
+            if param_name in arg_dict:
+                arg_list.append(arg_dict[param_name])
+            else:
+                raise JitCompilationFailedException(
+                    "Cannot JIT-compile function call to '%s' with missing argument for "
+                    "formal parameter '%s'." % (callee_name, param_name))
+
+        raise primitive_functions.PrimitiveFinished(
+            tree_ir.JitCallInstruction(
+                tree_ir.LoadGlobalInstruction(compiled_func_name),
+                arg_list,
+                tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
+
+    def analyze_argument_dict(self, first_argument_id):
+        """Analyzes the parameter-to-argument mapping started by the specified first argument
+           node."""
+        next_param = first_argument_id
+        argument_dict = {}
+        while next_param is not None:
+            param_name_id, = yield [("RD", [next_param, "name"])]
+            param_name, = yield [("RV", [param_name_id])]
+            param_val_id, = yield [("RD", [next_param, "value"])]
+            try:
+                gen = self.analyze(param_val_id)
+                inp = None
+                while True:
+                    inp = yield gen.send(inp)
+            except primitive_functions.PrimitiveFinished as ex:
+                argument_dict[param_name] = ex.result
+
+            next_param, = yield [("RD", [next_param, "next_param"])]
+
+        raise primitive_functions.PrimitiveFinished(argument_dict)
+
+    def analyze_call(self, instruction_id):
+        """Tries to analyze the given 'call' instruction."""
+        func_id, first_param_id, = yield [("RD", [instruction_id, "func"]),
+                                          ("RD", [instruction_id, "params"])]
+
+        # Figure out what the 'func' instruction's type is.
+        func_instruction_op, = yield [("RV", [func_id])]
+        if func_instruction_op['value'] == 'access':
+            # Calls to 'access(resolve(var))' instructions are translated to direct calls.
+            access_value_id, = yield [("RD", [func_id, "var"])]
+            access_value_op, = yield [("RV", [access_value_id])]
+            if access_value_op['value'] == 'resolve':
+                resolved_var_id, = yield [("RD", [access_value_id, "var"])]
+                resolved_var_name, = yield [("RV", [resolved_var_id])]
+
+                # Try to look the name up as a global.
+                _globals, = yield [("RD", [self.user_root, "globals"])]
+                global_var, = yield [("RD", [_globals, resolved_var_name])]
+                global_val, = yield [("RD", [global_var, "value"])]
+
+                if global_val is None:
+                    raise JitCompilationFailedException(
+                        "Cannot JIT function calls that target an unknown value.")
+                else:
+                    gen = self.analyze_direct_call(
+                        global_val, resolved_var_name, first_param_id)
+                    inp = None
+                    while True:
+                        inp = yield gen.send(inp)
+                    # PrimitiveFinished exception will bubble up from here.
+
+        raise JitCompilationFailedException("Cannot JIT indirect function calls yet.")
+
+
     instruction_analyzers = {
     instruction_analyzers = {
         'if' : analyze_if,
         'if' : analyze_if,
         'while' : analyze_while,
         'while' : analyze_while,
@@ -574,6 +771,7 @@ class AnalysisState(object):
         'assign' : analyze_assign,
         'assign' : analyze_assign,
         'access' : analyze_access,
         'access' : analyze_access,
         'output' : analyze_output,
         'output' : analyze_output,
-        'input' : analyze_input
+        'input' : analyze_input,
+        'call' : analyze_call
     }
     }
 
 

+ 60 - 7
kernel/modelverse_jit/tree_ir.py

@@ -289,6 +289,51 @@ class CallInstruction(Instruction):
                 self.target.generate_python_use(code_generator),
                 self.target.generate_python_use(code_generator),
                 ', '.join([arg.generate_python_use(code_generator) for arg in self.argument_list])))
                 ', '.join([arg.generate_python_use(code_generator) for arg in self.argument_list])))
 
 
+class JitCallInstruction(Instruction):
+    """An instruction that calls a jitted function."""
+    def __init__(self, target, argument_list, kwarg):
+        Instruction.__init__(self)
+        self.target = target
+        self.argument_list = argument_list
+        self.kwarg = kwarg
+
+    def simplify(self):
+        """Applies basic simplification to this instruction and its children."""
+        return JitCallInstruction(
+            self.target.simplify(),
+            [arg.simplify() for arg in self.argument_list],
+            self.kwarg.simplify())
+
+    def generate_python_def(self, code_generator):
+        """Generates Python code for this instruction."""
+        if self.target.has_definition():
+            self.target.generate_python_def(code_generator)
+
+        for arg in self.argument_list:
+            if arg.has_definition():
+                arg.generate_python_def(code_generator)
+
+        if self.kwarg.has_definition():
+            self.kwarg.generate_python_def(code_generator)
+
+        code_generator.append_line('try:')
+        code_generator.increase_indentation()
+        code_generator.append_line(
+            'gen = %s(%s, **%s) ' % (
+                self.target.generate_python_use(code_generator),
+                ', '.join([arg.generate_python_use(code_generator) for arg in self.argument_list]),
+                self.kwarg.generate_python_use(code_generator)))
+        code_generator.append_line('inp = None')
+        code_generator.append_line('while 1:')
+        code_generator.increase_indentation()
+        code_generator.append_line('inp = yield gen.send(inp)')
+        code_generator.decrease_indentation()
+        code_generator.decrease_indentation()
+        code_generator.append_line('except PrimitiveFinished as ex:')
+        code_generator.increase_indentation()
+        code_generator.append_line('%s = ex.result' % code_generator.get_result_name(self))
+        code_generator.decrease_indentation()
+
 class BinaryInstruction(Instruction):
 class BinaryInstruction(Instruction):
     """An instruction that performs a binary operation."""
     """An instruction that performs a binary operation."""
     def __init__(self, lhs, operator, rhs):
     def __init__(self, lhs, operator, rhs):
@@ -428,8 +473,8 @@ class StateInstruction(Instruction):
 
 
         code_generator.append_state_definition(self, self.get_opcode(), args)
         code_generator.append_state_definition(self, self.get_opcode(), args)
 
 
-class LocalInstruction(Instruction):
-    """A base class for instructions that access local variables."""
+class VariableInstruction(Instruction):
+    """A base class for instructions that access variables."""
     def __init__(self, name):
     def __init__(self, name):
         Instruction.__init__(self)
         Instruction.__init__(self)
         self.name = name
         self.name = name
@@ -439,6 +484,13 @@ class LocalInstruction(Instruction):
            instruction if it is not None."""
            instruction if it is not None."""
         return self.name
         return self.name
 
 
+    def generate_python_use(self, code_generator):
+        """Generates a Python expression that retrieves this instruction's
+           result. The expression is returned as a string."""
+        return self.name
+
+class LocalInstruction(VariableInstruction):
+    """A base class for instructions that access local variables."""
     def create_load(self):
     def create_load(self):
         """Creates an instruction that loads the variable referenced by this instruction."""
         """Creates an instruction that loads the variable referenced by this instruction."""
         return LoadLocalInstruction(self.name)
         return LoadLocalInstruction(self.name)
@@ -448,11 +500,6 @@ class LocalInstruction(Instruction):
            by this instruction."""
            by this instruction."""
         return StoreLocalInstruction(self.name, value)
         return StoreLocalInstruction(self.name, value)
 
 
-    def generate_python_use(self, code_generator):
-        """Generates a Python expression that retrieves this instruction's
-           result. The expression is returned as a string."""
-        return self.name
-
 class StoreLocalInstruction(LocalInstruction):
 class StoreLocalInstruction(LocalInstruction):
     """An instruction that stores a value in a local variable."""
     """An instruction that stores a value in a local variable."""
     def __init__(self, name, value):
     def __init__(self, name, value):
@@ -502,6 +549,12 @@ class LocalExistsInstruction(LocalInstruction):
            result. The expression is returned as a string."""
            result. The expression is returned as a string."""
         return "'%s' in locals()" % self.name
         return "'%s' in locals()" % self.name
 
 
+class LoadGlobalInstruction(VariableInstruction):
+    """An instruction that loads a value from a global variable."""
+    def has_definition(self):
+        """Tells if this instruction requires a definition."""
+        return False
+
 class LoadIndexInstruction(Instruction):
 class LoadIndexInstruction(Instruction):
     """An instruction that produces a value by indexing a specified expression with
     """An instruction that produces a value by indexing a specified expression with
        a given key."""
        a given key."""

+ 1 - 1
kernel/modelverse_kernel/main.py

@@ -138,7 +138,7 @@ class ModelverseKernel(object):
 
 
         # Have the JIT compile the function.
         # Have the JIT compile the function.
         try:
         try:
-            jit_gen = self.jit.jit_compile(inst, dict_keys)
+            jit_gen = self.jit.jit_compile(user_root, inst)
             inp = None
             inp = None
             while 1:
             while 1:
                 inp = yield jit_gen.send(inp)
                 inp = yield jit_gen.send(inp)