Browse Source

Switch to call-by-name in JIT

jonathanvdc 8 years ago
parent
commit
5bf4011e5d
3 changed files with 116 additions and 83 deletions
  1. 24 24
      kernel/modelverse_jit/intrinsics.py
  2. 56 38
      kernel/modelverse_jit/jit.py
  3. 36 21
      kernel/modelverse_jit/tree_ir.py

+ 24 - 24
kernel/modelverse_jit/intrinsics.py

@@ -36,78 +36,78 @@ UNARY_INTRINSICS = {
 MISC_INTRINSICS = {
     # Reference equality
     'element_eq' :
-        lambda lhs, rhs:
+        lambda a, b:
         tree_ir.CreateNodeWithValueInstruction(
-            tree_ir.BinaryInstruction(lhs, '==', rhs)),
+            tree_ir.BinaryInstruction(a, '==', b)),
     'element_neq' :
-        lambda lhs, rhs:
+        lambda a, b:
         tree_ir.CreateNodeWithValueInstruction(
-            tree_ir.BinaryInstruction(lhs, '!=', rhs)),
+            tree_ir.BinaryInstruction(a, '!=', b)),
 
     # Strings
     'string_get' :
-        lambda str_node, index_node:
+        lambda a, b:
         tree_ir.CreateNodeWithValueInstruction(
             tree_ir.LoadIndexInstruction(
-                tree_ir.ReadValueInstruction(str_node),
-                tree_ir.ReadValueInstruction(index_node))),
+                tree_ir.ReadValueInstruction(a),
+                tree_ir.ReadValueInstruction(b))),
     'string_len' :
-        lambda str_node:
+        lambda a:
         tree_ir.CreateNodeWithValueInstruction(
             tree_ir.CallInstruction(
                 tree_ir.LoadGlobalInstruction('len'),
-                [tree_ir.ReadValueInstruction(str_node)])),
+                [tree_ir.ReadValueInstruction(a)])),
     'string_join' :
-        lambda lhs, rhs:
+        lambda a, b:
         tree_ir.CreateNodeWithValueInstruction(
             tree_ir.BinaryInstruction(
                 tree_ir.CallInstruction(
                     tree_ir.LoadGlobalInstruction('str'),
-                    [tree_ir.ReadValueInstruction(lhs)]),
+                    [tree_ir.ReadValueInstruction(a)]),
                 '+',
                 tree_ir.CallInstruction(
                     tree_ir.LoadGlobalInstruction('str'),
-                    [tree_ir.ReadValueInstruction(rhs)]))),
+                    [tree_ir.ReadValueInstruction(b)]))),
 
     # State creation
     'create_node' : tree_ir.CreateNodeInstruction,
     'create_edge' : tree_ir.CreateEdgeInstruction,
     'create_value' :
-        lambda val:
+        lambda a:
         tree_ir.CreateNodeWithValueInstruction(
-            tree_ir.ReadValueInstruction(val)),
+            tree_ir.ReadValueInstruction(a)),
 
     # State reads
     'read_edge_src' :
-        lambda e:
+        lambda a:
         tree_ir.LoadIndexInstruction(
-            tree_ir.ReadEdgeInstruction(e),
+            tree_ir.ReadEdgeInstruction(a),
             tree_ir.LiteralInstruction(0)),
     'read_edge_dst' :
-        lambda e:
+        lambda a:
         tree_ir.LoadIndexInstruction(
-            tree_ir.ReadEdgeInstruction(e),
+            tree_ir.ReadEdgeInstruction(a),
             tree_ir.LiteralInstruction(1)),
     'is_edge' :
-        lambda e:
+        lambda a:
         tree_ir.CreateNodeWithValueInstruction(
             tree_ir.BinaryInstruction(
                 tree_ir.LoadIndexInstruction(
-                    tree_ir.ReadEdgeInstruction(e),
+                    tree_ir.ReadEdgeInstruction(a),
                     tree_ir.LiteralInstruction(0)),
                 'is not',
                 tree_ir.LiteralInstruction(None))),
 
     # Dictionary operations
     'dict_read' :
-        lambda dict_node, key:
+        lambda a, b:
         tree_ir.ReadDictionaryValueInstruction(
-            dict_node, tree_ir.ReadValueInstruction(key)),
+            a, tree_ir.ReadValueInstruction(b)),
 
     'dict_read_edge' :
-        lambda dict_node, key:
+        lambda a, b:
         tree_ir.ReadDictionaryEdgeInstruction(
-            dict_node, tree_ir.ReadValueInstruction(key))
+            a, tree_ir.ReadValueInstruction(b))
 }
 
 def register_intrinsics(target_jit):

+ 56 - 38
kernel/modelverse_jit/jit.py

@@ -4,6 +4,32 @@ import modelverse_jit.tree_ir as tree_ir
 KWARGS_PARAMETER_NAME = "kwargs"
 """The name of the kwargs parameter in jitted functions."""
 
+def get_parameter_names(compiled_function):
+    """Gets the given compiled function's parameter names."""
+    if hasattr(compiled_function, '__code__'):
+        return compiled_function.__code__.co_varnames[
+            :compiled_function.__code__.co_argcount]
+    elif hasattr(compiled_function, '__init__'):
+        return get_parameter_names(compiled_function.__init__)[1:]
+    else:
+        raise ValueError("'compiled_function' must be a function or a type.")
+
+def apply_intrinsic(intrinsic_function, named_args):
+    """Applies the given intrinsic to the given sequence of named arguments."""
+    param_names = get_parameter_names(intrinsic_function)
+    if tuple(param_names) == tuple([n for n, _ in named_args]):
+        # Perfect match. Yay!
+        return intrinsic_function(**dict(named_args))
+    else:
+        # We'll have to store the arguments into locals to preserve
+        # the order of evaluation.
+        stored_args = [(name, tree_ir.StoreLocalInstruction(None, arg)) for name, arg in named_args]
+        arg_value_dict = dict([(name, arg.create_load()) for name, arg in stored_args])
+        store_instructions = [instruction for _, instruction in stored_args]
+        return tree_ir.CompoundInstruction(
+            tree_ir.create_block(*store_instructions),
+            intrinsic_function(**arg_value_dict))
+
 class JitCompilationFailedException(Exception):
     """A type of exception that is raised when the jit fails to compile a function."""
     pass
@@ -22,6 +48,7 @@ class ModelverseJit(object):
         self.jit_count = 0
         self.max_instructions = 30 if max_instructions is None else max_instructions
         self.compiled_function_lookup = compiled_function_lookup
+        # jit_intrinsics is a function name -> intrinsic map.
         self.jit_intrinsics = {}
         self.compilation_dependencies = {}
         self.jit_enabled = True
@@ -78,7 +105,9 @@ class ModelverseJit(object):
 
     def register_compiled(self, body_id, compiled_function, function_name=None):
         """Registers a compiled entry point with the JIT."""
+        # Get the function's name.
         function_name = self.generate_function_name(function_name)
+        # Map the body id to the given parameter list.
         self.jitted_entry_points[body_id] = function_name
         self.jit_globals[function_name] = compiled_function
         if body_id in self.todo_entry_points:
@@ -94,25 +123,33 @@ class ModelverseJit(object):
         else:
             return None
 
-    def register_intrinsic(self, name, apply_intrinsic):
+    def get_intrinsic(self, name):
+        """Tries to find an intrinsic version of the function with the
+           given name."""
+        if name in self.jit_intrinsics:
+            return self.jit_intrinsics[name]
+        else:
+            return None
+
+    def register_intrinsic(self, name, intrinsic_function):
         """Registers the given intrisic with the JIT. This will make the JIT replace calls to
            the function with the given entry point by an application of the specified function."""
-        self.jit_intrinsics[name] = apply_intrinsic
+        self.jit_intrinsics[name] = intrinsic_function
 
     def register_binary_intrinsic(self, name, operator):
         """Registers an intrinsic with the JIT that represents the given binary operation."""
-        self.register_intrinsic(name, lambda lhs, rhs: tree_ir.CreateNodeWithValueInstruction(
+        self.register_intrinsic(name, lambda a, b: tree_ir.CreateNodeWithValueInstruction(
             tree_ir.BinaryInstruction(
-                tree_ir.ReadValueInstruction(lhs),
+                tree_ir.ReadValueInstruction(a),
                 operator,
-                tree_ir.ReadValueInstruction(rhs))))
+                tree_ir.ReadValueInstruction(b))))
 
     def register_unary_intrinsic(self, name, operator):
         """Registers an intrinsic with the JIT that represents the given unary operation."""
-        self.register_intrinsic(name, lambda val: tree_ir.CreateNodeWithValueInstruction(
+        self.register_intrinsic(name, lambda a: tree_ir.CreateNodeWithValueInstruction(
             tree_ir.UnaryInstruction(
                 operator,
-                tree_ir.ReadValueInstruction(val))))
+                tree_ir.ReadValueInstruction(a))))
 
     def jit_parameters(self, body_id):
         """Acquires the parameter list for the given body id node."""
@@ -708,18 +745,10 @@ class AnalysisState(object):
         if body_id in self.jit.compilation_dependencies:
             self.jit.compilation_dependencies[body_id].add(self.body_id)
 
-        # Analyze the parameter list.
-        try:
-            gen = self.jit.jit_parameters(body_id)
-            inp = None
-            while True:
-                inp = yield gen.send(inp)
-        except primitive_functions.PrimitiveFinished as ex:
-            _, parameter_names = ex.result
+        # Figure out if the function might be an intrinsic.
+        intrinsic = self.jit.get_intrinsic(callee_name)
 
-        is_intrinsic = callee_name in self.jit.jit_intrinsics
-
-        if not is_intrinsic:
+        if intrinsic is None:
             compiled_func = self.jit.lookup_compiled_function(callee_name)
             if compiled_func is None:
                 # Compile the callee.
@@ -738,39 +767,28 @@ class AnalysisState(object):
 
         # Analyze the argument dictionary.
         try:
-            gen = self.analyze_argument_dict(first_parameter_id)
+            gen = self.analyze_arguments(first_parameter_id)
             inp = None
             while True:
                 inp = yield gen.send(inp)
         except primitive_functions.PrimitiveFinished as ex:
-            arg_dict = ex.result
-
-        # Construct the argument list from the parameter list and
-        # argument dictionary.
-        arg_list = []
-        for param_name in parameter_names:
-            if param_name in arg_dict:
-                arg_list.append(arg_dict[param_name])
-            else:
-                raise JitCompilationFailedException(
-                    "Cannot JIT-compile function call to '%s' with missing argument for "
-                    "formal parameter '%s'." % (callee_name, param_name))
+            named_args = ex.result
 
-        if is_intrinsic:
+        if intrinsic is not None:
             raise primitive_functions.PrimitiveFinished(
-                self.jit.jit_intrinsics[callee_name](*arg_list))
+                apply_intrinsic(intrinsic, named_args))
         else:
             raise primitive_functions.PrimitiveFinished(
                 tree_ir.JitCallInstruction(
                     tree_ir.LoadGlobalInstruction(compiled_func_name),
-                    arg_list,
+                    named_args,
                     tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
 
-    def analyze_argument_dict(self, first_argument_id):
+    def analyze_arguments(self, first_argument_id):
         """Analyzes the parameter-to-argument mapping started by the specified first argument
            node."""
         next_param = first_argument_id
-        argument_dict = {}
+        named_args = []
         while next_param is not None:
             param_name_id, = yield [("RD", [next_param, "name"])]
             param_name, = yield [("RV", [param_name_id])]
@@ -781,11 +799,11 @@ class AnalysisState(object):
                 while True:
                     inp = yield gen.send(inp)
             except primitive_functions.PrimitiveFinished as ex:
-                argument_dict[param_name] = ex.result
+                named_args.append((param_name, ex.result))
 
             next_param, = yield [("RD", [next_param, "next_param"])]
 
-        raise primitive_functions.PrimitiveFinished(argument_dict)
+        raise primitive_functions.PrimitiveFinished(named_args)
 
     def analyze_call(self, instruction_id):
         """Tries to analyze the given 'call' instruction."""

+ 36 - 21
kernel/modelverse_jit/tree_ir.py

@@ -44,7 +44,7 @@ class Instruction(object):
         """Tells if this instruction requires a definition."""
         return True
 
-    def get_result_name_override(self):
+    def get_result_name_override(self, code_generator):
         """Gets a value that overrides the code generator's result name for this
            instruction if it is not None."""
         return None
@@ -117,7 +117,7 @@ class PythonGenerator(object):
     def get_result_name(self, instruction, advised_name=None):
         """Gets the name of the given instruction's result variable."""
         if instruction not in self.result_value_dict:
-            override_name = instruction.get_result_name_override()
+            override_name = instruction.get_result_name_override(self)
             if override_name is not None:
                 self.result_value_dict[instruction] = override_name
             elif advised_name is not None:
@@ -291,17 +291,17 @@ class CallInstruction(Instruction):
 
 class JitCallInstruction(Instruction):
     """An instruction that calls a jitted function."""
-    def __init__(self, target, argument_list, kwarg):
+    def __init__(self, target, named_args, kwarg):
         Instruction.__init__(self)
         self.target = target
-        self.argument_list = argument_list
+        self.named_args = named_args
         self.kwarg = kwarg
 
     def simplify(self):
         """Applies basic simplification to this instruction and its children."""
         return JitCallInstruction(
             self.target.simplify(),
-            [arg.simplify() for arg in self.argument_list],
+            [(param_name, arg.simplify()) for param_name, arg in self.named_args],
             self.kwarg.simplify())
 
     def generate_python_def(self, code_generator):
@@ -309,21 +309,26 @@ class JitCallInstruction(Instruction):
         if self.target.has_definition():
             self.target.generate_python_def(code_generator)
 
-        for arg in self.argument_list:
+        arg_list = []
+        for param_name, arg in self.named_args:
             if arg.has_definition():
                 arg.generate_python_def(code_generator)
 
+            arg_list.append(
+                '%s=%s' % (param_name, arg.generate_python_use(code_generator)))
+
         if self.kwarg.has_definition():
             self.kwarg.generate_python_def(code_generator)
 
+        arg_list.append(
+            '**%s' % self.kwarg.generate_python_use(code_generator))
+
         code_generator.append_line('try:')
         code_generator.increase_indentation()
         code_generator.append_line(
             'gen = %s(%s) ' % (
                 self.target.generate_python_use(code_generator),
-                ', '.join(
-                    [arg.generate_python_use(code_generator) for arg in self.argument_list] +
-                    ['**' + self.kwarg.generate_python_use(code_generator)])))
+                ', '.join(arg_list)))
         code_generator.append_line('inp = None')
         code_generator.append_line('while 1:')
         code_generator.increase_indentation()
@@ -505,21 +510,30 @@ class StateInstruction(Instruction):
 
         code_generator.append_state_definition(self, self.get_opcode(), args)
 
-class VariableInstruction(Instruction):
-    """A base class for instructions that access variables."""
+class VariableName(object):
+    """A data structure that unifies names across instructions that access the
+       same variable."""
     def __init__(self, name):
-        Instruction.__init__(self)
         self.name = name
 
-    def get_result_name_override(self):
+    def get_result_name_override(self, _):
         """Gets a value that overrides the code generator's result name for this
            instruction if it is not None."""
         return self.name
 
-    def generate_python_use(self, code_generator):
-        """Generates a Python expression that retrieves this instruction's
-           result. The expression is returned as a string."""
-        return self.name
+class VariableInstruction(Instruction):
+    """A base class for instructions that access variables."""
+    def __init__(self, name):
+        Instruction.__init__(self)
+        if isinstance(name, str) or isinstance(name, unicode) or name is None:
+            self.name = VariableName(name)
+        else:
+            self.name = name
+
+    def get_result_name_override(self, code_generator):
+        """Gets a value that overrides the code generator's result name for this
+           instruction if it is not None."""
+        return code_generator.get_result_name(self.name)
 
 class LocalInstruction(VariableInstruction):
     """A base class for instructions that access local variables."""
@@ -553,10 +567,10 @@ class LoadLocalInstruction(LocalInstruction):
         """Tells if this instruction requires a definition."""
         return False
 
-class DefineFunctionInstruction(LocalInstruction):
+class DefineFunctionInstruction(VariableInstruction):
     """An instruction that defines a function."""
     def __init__(self, name, parameter_list, body):
-        LocalInstruction.__init__(self, name)
+        VariableInstruction.__init__(self, name)
         self.parameter_list = parameter_list
         self.body = body
 
@@ -564,7 +578,8 @@ class DefineFunctionInstruction(LocalInstruction):
         """Generates a Python statement that executes this instruction.
            The statement is appended immediately to the code generator."""
 
-        code_generator.append_line('def %s(%s):' % (self.name, ', '.join(self.parameter_list)))
+        code_generator.append_line('def %s(%s):' % (
+            code_generator.get_result_name(self), ', '.join(self.parameter_list)))
         code_generator.increase_indentation()
         self.body.generate_python_def(code_generator)
         code_generator.decrease_indentation()
@@ -579,7 +594,7 @@ class LocalExistsInstruction(LocalInstruction):
     def generate_python_use(self, code_generator):
         """Generates a Python expression that retrieves this instruction's
            result. The expression is returned as a string."""
-        return "'%s' in locals()" % self.name
+        return "'%s' in locals()" % self.get_result_name_override(code_generator)
 
 class LoadGlobalInstruction(VariableInstruction):
     """An instruction that loads a value from a global variable."""