Sfoglia il codice sorgente

Fold constant reads in the JIT

jonathanvdc 8 anni fa
parent
commit
043ea99f29
2 ha cambiato i file con 67 aggiunte e 3 eliminazioni
  1. 55 3
      kernel/modelverse_jit/jit.py
  2. 12 0
      kernel/modelverse_jit/tree_ir.py

+ 55 - 3
kernel/modelverse_jit/jit.py

@@ -34,6 +34,50 @@ def apply_intrinsic(intrinsic_function, named_args):
             tree_ir.create_block(*store_instructions),
             tree_ir.create_block(*store_instructions),
             intrinsic_function(**arg_value_dict))
             intrinsic_function(**arg_value_dict))
 
 
+def map_and_simplify_generator(function, instruction):
+    """Applies the given mapping function to every instruction in the tree
+       that has the given instruction as root, and simplifies it on-the-fly.
+
+       This is at least as powerful as first mapping and then simplifying, as
+       maps and simplifications are interspersed.
+
+       This function assumes that function creates a generator that returns by
+       raising a primitive_functions.PrimitiveFinished."""
+
+    # First handle the children by mapping on them and then simplifying them.
+    new_children = []
+    for inst in instruction.get_children():
+        try:
+            gen = map_and_simplify_generator(function, inst)
+            inp = None
+            while True:
+                inp = yield gen.send(inp)
+        except primitive_functions.PrimitiveFinished as ex:
+            new_children.append(ex.result)
+
+    # Then apply the function to the top-level node.
+    try:
+        gen = function(instruction.create(new_children))
+        inp = None
+        while True:
+            inp = yield gen.send(inp)
+    except primitive_functions.PrimitiveFinished as ex:
+        # Finally, simplify the transformed top-level node.
+        raise primitive_functions.PrimitiveFinished(ex.result.simplify_node())
+
+def expand_constant_read(instruction):
+    """Tries to replace a read of a constant node by a literal."""
+    if isinstance(instruction, tree_ir.ReadValueInstruction) and \
+        isinstance(instruction.node_id, tree_ir.LiteralInstruction):
+        val, = yield [("RV", [instruction.node_id.literal])]
+        raise primitive_functions.PrimitiveFinished(tree_ir.LiteralInstruction(val))
+    else:
+        raise primitive_functions.PrimitiveFinished(instruction)
+
+def optimize_tree_ir(instruction):
+    """Optimizes an IR tree."""
+    return map_and_simplify_generator(expand_constant_read, instruction)
+
 class JitCompilationFailedException(Exception):
 class JitCompilationFailedException(Exception):
     """A type of exception that is raised when the jit fails to compile a function."""
     """A type of exception that is raised when the jit fails to compile a function."""
     pass
     pass
@@ -252,11 +296,19 @@ class ModelverseJit(object):
         constructed_body = tree_ir.create_block(
         constructed_body = tree_ir.create_block(
             *(prologue_statements + [constructed_body]))
             *(prologue_statements + [constructed_body]))
 
 
+        try:
+            gen = optimize_tree_ir(constructed_body)
+            inp = None
+            while True:
+                inp = yield gen.send(inp)
+        except primitive_functions.PrimitiveFinished as ex:
+            constructed_body = ex.result
+
         # Wrap the IR in a function definition, give it a unique name.
         # Wrap the IR in a function definition, give it a unique name.
         constructed_function = tree_ir.DefineFunctionInstruction(
         constructed_function = tree_ir.DefineFunctionInstruction(
             function_name,
             function_name,
             parameter_list + ['**' + KWARGS_PARAMETER_NAME],
             parameter_list + ['**' + KWARGS_PARAMETER_NAME],
-            constructed_body.simplify())
+            constructed_body)
         # Convert the function definition to Python code, and compile it.
         # Convert the function definition to Python code, and compile it.
         exec(str(constructed_function), self.jit_globals)
         exec(str(constructed_function), self.jit_globals)
         # Extract the compiled function from the JIT global state.
         # Extract the compiled function from the JIT global state.
@@ -370,8 +422,8 @@ class AnalysisState(object):
         """Tries to compile a list of IR trees from the given list of instruction ids."""
         """Tries to compile a list of IR trees from the given list of instruction ids."""
         results = []
         results = []
         for inst in instruction_ids:
         for inst in instruction_ids:
-            gen = self.analyze(inst)
             try:
             try:
+                gen = self.analyze(inst)
                 inp = None
                 inp = None
                 while True:
                 while True:
                     inp = yield gen.send(inp)
                     inp = yield gen.send(inp)
@@ -388,8 +440,8 @@ class AnalysisState(object):
                 tree_ir.ReturnInstruction(
                 tree_ir.ReturnInstruction(
                     tree_ir.EmptyInstruction()))
                     tree_ir.EmptyInstruction()))
         else:
         else:
-            gen = self.analyze(retval_id)
             try:
             try:
+                gen = self.analyze(retval_id)
                 inp = None
                 inp = None
                 while True:
                 while True:
                     inp = yield gen.send(inp)
                     inp = yield gen.send(inp)

+ 12 - 0
kernel/modelverse_jit/tree_ir.py

@@ -1129,6 +1129,18 @@ def with_debug_info_trace(instruction, debug_info):
                 LiteralInstruction('TRACE: %s(JIT)' % debug_info)),
                 LiteralInstruction('TRACE: %s(JIT)' % debug_info)),
             instruction)
             instruction)
 
 
+def map_and_simplify(function, instruction):
+    """Applies the given mapping function to every instruction in the tree
+       that has the given instruction as root, and simplifies it on-the-fly.
+
+       This is at least as powerful as first mapping and then simplifying, as
+       maps and simplifications are interspersed."""
+    # Let's just agree to disagree on map vs list comprehensions, pylint.
+    # pylint: disable=I0011,W0141
+    return function(
+        instruction.create(
+            map(map_and_simplify, instruction.get_children()))).simplify_node()
+
 if __name__ == "__main__":
 if __name__ == "__main__":
     example_tree = SelectInstruction(
     example_tree = SelectInstruction(
         LiteralInstruction(True),
         LiteralInstruction(True),