Browse Source

Switch to call-by-name in JIT

jonathanvdc 8 years ago
parent
commit
5bf4011e5d
3 changed files with 116 additions and 83 deletions
  1. 24 24
      kernel/modelverse_jit/intrinsics.py
  2. 56 38
      kernel/modelverse_jit/jit.py
  3. 36 21
      kernel/modelverse_jit/tree_ir.py

+ 24 - 24
kernel/modelverse_jit/intrinsics.py

@@ -36,78 +36,78 @@ UNARY_INTRINSICS = {
 MISC_INTRINSICS = {
 MISC_INTRINSICS = {
     # Reference equality
     # Reference equality
     'element_eq' :
     'element_eq' :
-        lambda lhs, rhs:
+        lambda a, b:
         tree_ir.CreateNodeWithValueInstruction(
         tree_ir.CreateNodeWithValueInstruction(
-            tree_ir.BinaryInstruction(lhs, '==', rhs)),
+            tree_ir.BinaryInstruction(a, '==', b)),
     'element_neq' :
     'element_neq' :
-        lambda lhs, rhs:
+        lambda a, b:
         tree_ir.CreateNodeWithValueInstruction(
         tree_ir.CreateNodeWithValueInstruction(
-            tree_ir.BinaryInstruction(lhs, '!=', rhs)),
+            tree_ir.BinaryInstruction(a, '!=', b)),
 
 
     # Strings
     # Strings
     'string_get' :
     'string_get' :
-        lambda str_node, index_node:
+        lambda a, b:
         tree_ir.CreateNodeWithValueInstruction(
         tree_ir.CreateNodeWithValueInstruction(
             tree_ir.LoadIndexInstruction(
             tree_ir.LoadIndexInstruction(
-                tree_ir.ReadValueInstruction(str_node),
-                tree_ir.ReadValueInstruction(index_node))),
+                tree_ir.ReadValueInstruction(a),
+                tree_ir.ReadValueInstruction(b))),
     'string_len' :
     'string_len' :
-        lambda str_node:
+        lambda a:
         tree_ir.CreateNodeWithValueInstruction(
         tree_ir.CreateNodeWithValueInstruction(
             tree_ir.CallInstruction(
             tree_ir.CallInstruction(
                 tree_ir.LoadGlobalInstruction('len'),
                 tree_ir.LoadGlobalInstruction('len'),
-                [tree_ir.ReadValueInstruction(str_node)])),
+                [tree_ir.ReadValueInstruction(a)])),
     'string_join' :
     'string_join' :
-        lambda lhs, rhs:
+        lambda a, b:
         tree_ir.CreateNodeWithValueInstruction(
         tree_ir.CreateNodeWithValueInstruction(
             tree_ir.BinaryInstruction(
             tree_ir.BinaryInstruction(
                 tree_ir.CallInstruction(
                 tree_ir.CallInstruction(
                     tree_ir.LoadGlobalInstruction('str'),
                     tree_ir.LoadGlobalInstruction('str'),
-                    [tree_ir.ReadValueInstruction(lhs)]),
+                    [tree_ir.ReadValueInstruction(a)]),
                 '+',
                 '+',
                 tree_ir.CallInstruction(
                 tree_ir.CallInstruction(
                     tree_ir.LoadGlobalInstruction('str'),
                     tree_ir.LoadGlobalInstruction('str'),
-                    [tree_ir.ReadValueInstruction(rhs)]))),
+                    [tree_ir.ReadValueInstruction(b)]))),
 
 
     # State creation
     # State creation
     'create_node' : tree_ir.CreateNodeInstruction,
     'create_node' : tree_ir.CreateNodeInstruction,
     'create_edge' : tree_ir.CreateEdgeInstruction,
     'create_edge' : tree_ir.CreateEdgeInstruction,
     'create_value' :
     'create_value' :
-        lambda val:
+        lambda a:
         tree_ir.CreateNodeWithValueInstruction(
         tree_ir.CreateNodeWithValueInstruction(
-            tree_ir.ReadValueInstruction(val)),
+            tree_ir.ReadValueInstruction(a)),
 
 
     # State reads
     # State reads
     'read_edge_src' :
     'read_edge_src' :
-        lambda e:
+        lambda a:
         tree_ir.LoadIndexInstruction(
         tree_ir.LoadIndexInstruction(
-            tree_ir.ReadEdgeInstruction(e),
+            tree_ir.ReadEdgeInstruction(a),
             tree_ir.LiteralInstruction(0)),
             tree_ir.LiteralInstruction(0)),
     'read_edge_dst' :
     'read_edge_dst' :
-        lambda e:
+        lambda a:
         tree_ir.LoadIndexInstruction(
         tree_ir.LoadIndexInstruction(
-            tree_ir.ReadEdgeInstruction(e),
+            tree_ir.ReadEdgeInstruction(a),
             tree_ir.LiteralInstruction(1)),
             tree_ir.LiteralInstruction(1)),
     'is_edge' :
     'is_edge' :
-        lambda e:
+        lambda a:
         tree_ir.CreateNodeWithValueInstruction(
         tree_ir.CreateNodeWithValueInstruction(
             tree_ir.BinaryInstruction(
             tree_ir.BinaryInstruction(
                 tree_ir.LoadIndexInstruction(
                 tree_ir.LoadIndexInstruction(
-                    tree_ir.ReadEdgeInstruction(e),
+                    tree_ir.ReadEdgeInstruction(a),
                     tree_ir.LiteralInstruction(0)),
                     tree_ir.LiteralInstruction(0)),
                 'is not',
                 'is not',
                 tree_ir.LiteralInstruction(None))),
                 tree_ir.LiteralInstruction(None))),
 
 
     # Dictionary operations
     # Dictionary operations
     'dict_read' :
     'dict_read' :
-        lambda dict_node, key:
+        lambda a, b:
         tree_ir.ReadDictionaryValueInstruction(
         tree_ir.ReadDictionaryValueInstruction(
-            dict_node, tree_ir.ReadValueInstruction(key)),
+            a, tree_ir.ReadValueInstruction(b)),
 
 
     'dict_read_edge' :
     'dict_read_edge' :
-        lambda dict_node, key:
+        lambda a, b:
         tree_ir.ReadDictionaryEdgeInstruction(
         tree_ir.ReadDictionaryEdgeInstruction(
-            dict_node, tree_ir.ReadValueInstruction(key))
+            a, tree_ir.ReadValueInstruction(b))
 }
 }
 
 
 def register_intrinsics(target_jit):
 def register_intrinsics(target_jit):

+ 56 - 38
kernel/modelverse_jit/jit.py

@@ -4,6 +4,32 @@ import modelverse_jit.tree_ir as tree_ir
 KWARGS_PARAMETER_NAME = "kwargs"
 KWARGS_PARAMETER_NAME = "kwargs"
 """The name of the kwargs parameter in jitted functions."""
 """The name of the kwargs parameter in jitted functions."""
 
 
+def get_parameter_names(compiled_function):
+    """Gets the given compiled function's parameter names."""
+    if hasattr(compiled_function, '__code__'):
+        return compiled_function.__code__.co_varnames[
+            :compiled_function.__code__.co_argcount]
+    elif hasattr(compiled_function, '__init__'):
+        return get_parameter_names(compiled_function.__init__)[1:]
+    else:
+        raise ValueError("'compiled_function' must be a function or a type.")
+
+def apply_intrinsic(intrinsic_function, named_args):
+    """Applies the given intrinsic to the given sequence of named arguments."""
+    param_names = get_parameter_names(intrinsic_function)
+    if tuple(param_names) == tuple([n for n, _ in named_args]):
+        # Perfect match. Yay!
+        return intrinsic_function(**dict(named_args))
+    else:
+        # We'll have to store the arguments into locals to preserve
+        # the order of evaluation.
+        stored_args = [(name, tree_ir.StoreLocalInstruction(None, arg)) for name, arg in named_args]
+        arg_value_dict = dict([(name, arg.create_load()) for name, arg in stored_args])
+        store_instructions = [instruction for _, instruction in stored_args]
+        return tree_ir.CompoundInstruction(
+            tree_ir.create_block(*store_instructions),
+            intrinsic_function(**arg_value_dict))
+
 class JitCompilationFailedException(Exception):
 class JitCompilationFailedException(Exception):
     """A type of exception that is raised when the jit fails to compile a function."""
     """A type of exception that is raised when the jit fails to compile a function."""
     pass
     pass
@@ -22,6 +48,7 @@ class ModelverseJit(object):
         self.jit_count = 0
         self.jit_count = 0
         self.max_instructions = 30 if max_instructions is None else max_instructions
         self.max_instructions = 30 if max_instructions is None else max_instructions
         self.compiled_function_lookup = compiled_function_lookup
         self.compiled_function_lookup = compiled_function_lookup
+        # jit_intrinsics is a function name -> intrinsic map.
         self.jit_intrinsics = {}
         self.jit_intrinsics = {}
         self.compilation_dependencies = {}
         self.compilation_dependencies = {}
         self.jit_enabled = True
         self.jit_enabled = True
@@ -78,7 +105,9 @@ class ModelverseJit(object):
 
 
     def register_compiled(self, body_id, compiled_function, function_name=None):
     def register_compiled(self, body_id, compiled_function, function_name=None):
         """Registers a compiled entry point with the JIT."""
         """Registers a compiled entry point with the JIT."""
+        # Get the function's name.
         function_name = self.generate_function_name(function_name)
         function_name = self.generate_function_name(function_name)
+        # Map the body id to the given parameter list.
         self.jitted_entry_points[body_id] = function_name
         self.jitted_entry_points[body_id] = function_name
         self.jit_globals[function_name] = compiled_function
         self.jit_globals[function_name] = compiled_function
         if body_id in self.todo_entry_points:
         if body_id in self.todo_entry_points:
@@ -94,25 +123,33 @@ class ModelverseJit(object):
         else:
         else:
             return None
             return None
 
 
-    def register_intrinsic(self, name, apply_intrinsic):
+    def get_intrinsic(self, name):
+        """Tries to find an intrinsic version of the function with the
+           given name."""
+        if name in self.jit_intrinsics:
+            return self.jit_intrinsics[name]
+        else:
+            return None
+
+    def register_intrinsic(self, name, intrinsic_function):
         """Registers the given intrisic with the JIT. This will make the JIT replace calls to
         """Registers the given intrisic with the JIT. This will make the JIT replace calls to
            the function with the given entry point by an application of the specified function."""
            the function with the given entry point by an application of the specified function."""
-        self.jit_intrinsics[name] = apply_intrinsic
+        self.jit_intrinsics[name] = intrinsic_function
 
 
     def register_binary_intrinsic(self, name, operator):
     def register_binary_intrinsic(self, name, operator):
         """Registers an intrinsic with the JIT that represents the given binary operation."""
         """Registers an intrinsic with the JIT that represents the given binary operation."""
-        self.register_intrinsic(name, lambda lhs, rhs: tree_ir.CreateNodeWithValueInstruction(
+        self.register_intrinsic(name, lambda a, b: tree_ir.CreateNodeWithValueInstruction(
             tree_ir.BinaryInstruction(
             tree_ir.BinaryInstruction(
-                tree_ir.ReadValueInstruction(lhs),
+                tree_ir.ReadValueInstruction(a),
                 operator,
                 operator,
-                tree_ir.ReadValueInstruction(rhs))))
+                tree_ir.ReadValueInstruction(b))))
 
 
     def register_unary_intrinsic(self, name, operator):
     def register_unary_intrinsic(self, name, operator):
         """Registers an intrinsic with the JIT that represents the given unary operation."""
         """Registers an intrinsic with the JIT that represents the given unary operation."""
-        self.register_intrinsic(name, lambda val: tree_ir.CreateNodeWithValueInstruction(
+        self.register_intrinsic(name, lambda a: tree_ir.CreateNodeWithValueInstruction(
             tree_ir.UnaryInstruction(
             tree_ir.UnaryInstruction(
                 operator,
                 operator,
-                tree_ir.ReadValueInstruction(val))))
+                tree_ir.ReadValueInstruction(a))))
 
 
     def jit_parameters(self, body_id):
     def jit_parameters(self, body_id):
         """Acquires the parameter list for the given body id node."""
         """Acquires the parameter list for the given body id node."""
@@ -708,18 +745,10 @@ class AnalysisState(object):
         if body_id in self.jit.compilation_dependencies:
         if body_id in self.jit.compilation_dependencies:
             self.jit.compilation_dependencies[body_id].add(self.body_id)
             self.jit.compilation_dependencies[body_id].add(self.body_id)
 
 
-        # Analyze the parameter list.
-        try:
-            gen = self.jit.jit_parameters(body_id)
-            inp = None
-            while True:
-                inp = yield gen.send(inp)
-        except primitive_functions.PrimitiveFinished as ex:
-            _, parameter_names = ex.result
+        # Figure out if the function might be an intrinsic.
+        intrinsic = self.jit.get_intrinsic(callee_name)
 
 
-        is_intrinsic = callee_name in self.jit.jit_intrinsics
-
-        if not is_intrinsic:
+        if intrinsic is None:
             compiled_func = self.jit.lookup_compiled_function(callee_name)
             compiled_func = self.jit.lookup_compiled_function(callee_name)
             if compiled_func is None:
             if compiled_func is None:
                 # Compile the callee.
                 # Compile the callee.
@@ -738,39 +767,28 @@ class AnalysisState(object):
 
 
         # Analyze the argument dictionary.
         # Analyze the argument dictionary.
         try:
         try:
-            gen = self.analyze_argument_dict(first_parameter_id)
+            gen = self.analyze_arguments(first_parameter_id)
             inp = None
             inp = None
             while True:
             while True:
                 inp = yield gen.send(inp)
                 inp = yield gen.send(inp)
         except primitive_functions.PrimitiveFinished as ex:
         except primitive_functions.PrimitiveFinished as ex:
-            arg_dict = ex.result
-
-        # Construct the argument list from the parameter list and
-        # argument dictionary.
-        arg_list = []
-        for param_name in parameter_names:
-            if param_name in arg_dict:
-                arg_list.append(arg_dict[param_name])
-            else:
-                raise JitCompilationFailedException(
-                    "Cannot JIT-compile function call to '%s' with missing argument for "
-                    "formal parameter '%s'." % (callee_name, param_name))
+            named_args = ex.result
 
 
-        if is_intrinsic:
+        if intrinsic is not None:
             raise primitive_functions.PrimitiveFinished(
             raise primitive_functions.PrimitiveFinished(
-                self.jit.jit_intrinsics[callee_name](*arg_list))
+                apply_intrinsic(intrinsic, named_args))
         else:
         else:
             raise primitive_functions.PrimitiveFinished(
             raise primitive_functions.PrimitiveFinished(
                 tree_ir.JitCallInstruction(
                 tree_ir.JitCallInstruction(
                     tree_ir.LoadGlobalInstruction(compiled_func_name),
                     tree_ir.LoadGlobalInstruction(compiled_func_name),
-                    arg_list,
+                    named_args,
                     tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
                     tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
 
 
-    def analyze_argument_dict(self, first_argument_id):
+    def analyze_arguments(self, first_argument_id):
         """Analyzes the parameter-to-argument mapping started by the specified first argument
         """Analyzes the parameter-to-argument mapping started by the specified first argument
            node."""
            node."""
         next_param = first_argument_id
         next_param = first_argument_id
-        argument_dict = {}
+        named_args = []
         while next_param is not None:
         while next_param is not None:
             param_name_id, = yield [("RD", [next_param, "name"])]
             param_name_id, = yield [("RD", [next_param, "name"])]
             param_name, = yield [("RV", [param_name_id])]
             param_name, = yield [("RV", [param_name_id])]
@@ -781,11 +799,11 @@ class AnalysisState(object):
                 while True:
                 while True:
                     inp = yield gen.send(inp)
                     inp = yield gen.send(inp)
             except primitive_functions.PrimitiveFinished as ex:
             except primitive_functions.PrimitiveFinished as ex:
-                argument_dict[param_name] = ex.result
+                named_args.append((param_name, ex.result))
 
 
             next_param, = yield [("RD", [next_param, "next_param"])]
             next_param, = yield [("RD", [next_param, "next_param"])]
 
 
-        raise primitive_functions.PrimitiveFinished(argument_dict)
+        raise primitive_functions.PrimitiveFinished(named_args)
 
 
     def analyze_call(self, instruction_id):
     def analyze_call(self, instruction_id):
         """Tries to analyze the given 'call' instruction."""
         """Tries to analyze the given 'call' instruction."""

+ 36 - 21
kernel/modelverse_jit/tree_ir.py

@@ -44,7 +44,7 @@ class Instruction(object):
         """Tells if this instruction requires a definition."""
         """Tells if this instruction requires a definition."""
         return True
         return True
 
 
-    def get_result_name_override(self):
+    def get_result_name_override(self, code_generator):
         """Gets a value that overrides the code generator's result name for this
         """Gets a value that overrides the code generator's result name for this
            instruction if it is not None."""
            instruction if it is not None."""
         return None
         return None
@@ -117,7 +117,7 @@ class PythonGenerator(object):
     def get_result_name(self, instruction, advised_name=None):
     def get_result_name(self, instruction, advised_name=None):
         """Gets the name of the given instruction's result variable."""
         """Gets the name of the given instruction's result variable."""
         if instruction not in self.result_value_dict:
         if instruction not in self.result_value_dict:
-            override_name = instruction.get_result_name_override()
+            override_name = instruction.get_result_name_override(self)
             if override_name is not None:
             if override_name is not None:
                 self.result_value_dict[instruction] = override_name
                 self.result_value_dict[instruction] = override_name
             elif advised_name is not None:
             elif advised_name is not None:
@@ -291,17 +291,17 @@ class CallInstruction(Instruction):
 
 
 class JitCallInstruction(Instruction):
 class JitCallInstruction(Instruction):
     """An instruction that calls a jitted function."""
     """An instruction that calls a jitted function."""
-    def __init__(self, target, argument_list, kwarg):
+    def __init__(self, target, named_args, kwarg):
         Instruction.__init__(self)
         Instruction.__init__(self)
         self.target = target
         self.target = target
-        self.argument_list = argument_list
+        self.named_args = named_args
         self.kwarg = kwarg
         self.kwarg = kwarg
 
 
     def simplify(self):
     def simplify(self):
         """Applies basic simplification to this instruction and its children."""
         """Applies basic simplification to this instruction and its children."""
         return JitCallInstruction(
         return JitCallInstruction(
             self.target.simplify(),
             self.target.simplify(),
-            [arg.simplify() for arg in self.argument_list],
+            [(param_name, arg.simplify()) for param_name, arg in self.named_args],
             self.kwarg.simplify())
             self.kwarg.simplify())
 
 
     def generate_python_def(self, code_generator):
     def generate_python_def(self, code_generator):
@@ -309,21 +309,26 @@ class JitCallInstruction(Instruction):
         if self.target.has_definition():
         if self.target.has_definition():
             self.target.generate_python_def(code_generator)
             self.target.generate_python_def(code_generator)
 
 
-        for arg in self.argument_list:
+        arg_list = []
+        for param_name, arg in self.named_args:
             if arg.has_definition():
             if arg.has_definition():
                 arg.generate_python_def(code_generator)
                 arg.generate_python_def(code_generator)
 
 
+            arg_list.append(
+                '%s=%s' % (param_name, arg.generate_python_use(code_generator)))
+
         if self.kwarg.has_definition():
         if self.kwarg.has_definition():
             self.kwarg.generate_python_def(code_generator)
             self.kwarg.generate_python_def(code_generator)
 
 
+        arg_list.append(
+            '**%s' % self.kwarg.generate_python_use(code_generator))
+
         code_generator.append_line('try:')
         code_generator.append_line('try:')
         code_generator.increase_indentation()
         code_generator.increase_indentation()
         code_generator.append_line(
         code_generator.append_line(
             'gen = %s(%s) ' % (
             'gen = %s(%s) ' % (
                 self.target.generate_python_use(code_generator),
                 self.target.generate_python_use(code_generator),
-                ', '.join(
-                    [arg.generate_python_use(code_generator) for arg in self.argument_list] +
-                    ['**' + self.kwarg.generate_python_use(code_generator)])))
+                ', '.join(arg_list)))
         code_generator.append_line('inp = None')
         code_generator.append_line('inp = None')
         code_generator.append_line('while 1:')
         code_generator.append_line('while 1:')
         code_generator.increase_indentation()
         code_generator.increase_indentation()
@@ -505,21 +510,30 @@ class StateInstruction(Instruction):
 
 
         code_generator.append_state_definition(self, self.get_opcode(), args)
         code_generator.append_state_definition(self, self.get_opcode(), args)
 
 
-class VariableInstruction(Instruction):
-    """A base class for instructions that access variables."""
+class VariableName(object):
+    """A data structure that unifies names across instructions that access the
+       same variable."""
     def __init__(self, name):
     def __init__(self, name):
-        Instruction.__init__(self)
         self.name = name
         self.name = name
 
 
-    def get_result_name_override(self):
+    def get_result_name_override(self, _):
         """Gets a value that overrides the code generator's result name for this
         """Gets a value that overrides the code generator's result name for this
            instruction if it is not None."""
            instruction if it is not None."""
         return self.name
         return self.name
 
 
-    def generate_python_use(self, code_generator):
-        """Generates a Python expression that retrieves this instruction's
-           result. The expression is returned as a string."""
-        return self.name
+class VariableInstruction(Instruction):
+    """A base class for instructions that access variables."""
+    def __init__(self, name):
+        Instruction.__init__(self)
+        if isinstance(name, str) or isinstance(name, unicode) or name is None:
+            self.name = VariableName(name)
+        else:
+            self.name = name
+
+    def get_result_name_override(self, code_generator):
+        """Gets a value that overrides the code generator's result name for this
+           instruction if it is not None."""
+        return code_generator.get_result_name(self.name)
 
 
 class LocalInstruction(VariableInstruction):
 class LocalInstruction(VariableInstruction):
     """A base class for instructions that access local variables."""
     """A base class for instructions that access local variables."""
@@ -553,10 +567,10 @@ class LoadLocalInstruction(LocalInstruction):
         """Tells if this instruction requires a definition."""
         """Tells if this instruction requires a definition."""
         return False
         return False
 
 
-class DefineFunctionInstruction(LocalInstruction):
+class DefineFunctionInstruction(VariableInstruction):
     """An instruction that defines a function."""
     """An instruction that defines a function."""
     def __init__(self, name, parameter_list, body):
     def __init__(self, name, parameter_list, body):
-        LocalInstruction.__init__(self, name)
+        VariableInstruction.__init__(self, name)
         self.parameter_list = parameter_list
         self.parameter_list = parameter_list
         self.body = body
         self.body = body
 
 
@@ -564,7 +578,8 @@ class DefineFunctionInstruction(LocalInstruction):
         """Generates a Python statement that executes this instruction.
         """Generates a Python statement that executes this instruction.
            The statement is appended immediately to the code generator."""
            The statement is appended immediately to the code generator."""
 
 
-        code_generator.append_line('def %s(%s):' % (self.name, ', '.join(self.parameter_list)))
+        code_generator.append_line('def %s(%s):' % (
+            code_generator.get_result_name(self), ', '.join(self.parameter_list)))
         code_generator.increase_indentation()
         code_generator.increase_indentation()
         self.body.generate_python_def(code_generator)
         self.body.generate_python_def(code_generator)
         code_generator.decrease_indentation()
         code_generator.decrease_indentation()
@@ -579,7 +594,7 @@ class LocalExistsInstruction(LocalInstruction):
     def generate_python_use(self, code_generator):
     def generate_python_use(self, code_generator):
         """Generates a Python expression that retrieves this instruction's
         """Generates a Python expression that retrieves this instruction's
            result. The expression is returned as a string."""
            result. The expression is returned as a string."""
-        return "'%s' in locals()" % self.name
+        return "'%s' in locals()" % self.get_result_name_override(code_generator)
 
 
 class LoadGlobalInstruction(VariableInstruction):
 class LoadGlobalInstruction(VariableInstruction):
     """An instruction that loads a value from a global variable."""
     """An instruction that loads a value from a global variable."""