Преглед изворног кода

Optimize the dict iteration idiom in fast-JIT

jonathanvdc пре 8 година
родитељ
комит
7e7be8ad44

+ 88 - 0
kernel/modelverse_jit/cfg_data_structures.py

@@ -0,0 +1,88 @@
+"""Defines optimizations that replace Modelverse data structures by Python data structures."""
+
+import modelverse_jit.cfg_ir as cfg_ir
+
+SET_DEF_REWRITE_RULES = {
+    'dict_keys':
+    lambda def_def:
+    def_def.redefine(
+        cfg_ir.DirectFunctionCall(
+            cfg_ir.REVERSE_LIST_MACRO_NAME,
+            [('seq',
+              def_def.insert_before(
+                  cfg_ir.DirectFunctionCall(
+                      cfg_ir.READ_DICT_KEYS_MACRO_NAME,
+                      cfg_ir.get_def_value(def_def).argument_list,
+                      cfg_ir.MACRO_POSITIONAL_CALLING_CONVENTION,
+                      has_side_effects=False)))],
+            cfg_ir.MACRO_POSITIONAL_CALLING_CONVENTION,
+            has_side_effects=False))
+}
+
+def __redefine_as_list_len(use_def, def_def):
+    use_def.redefine(
+        cfg_ir.CreateNode(
+            use_def.insert_before(
+                cfg_ir.DirectFunctionCall(
+                    'len', [('seq', def_def)],
+                    cfg_ir.SIMPLE_POSITIONAL_CALLING_CONVENTION,
+                    has_side_effects=False))))
+
+SET_USE_REWRITE_RULES = {
+    ('set_pop', 'a'):
+    lambda use_def, def_def:
+    use_def.redefine(
+        cfg_ir.DirectFunctionCall(
+            'pop', [('self', def_def)],
+            cfg_ir.SELF_POSITIONAL_CALLING_CONVENTION)),
+    ('list_len', 'a'): __redefine_as_list_len,
+    ('read_nr_out', 'a'): __redefine_as_list_len
+}
+
+def get_call_def_rewriter(definition, def_rewrite_rules):
+    """Gets an appropriate rewrite rule from the given dictionary of call rewrite rules."""
+    if cfg_ir.is_call(definition, calling_convention=cfg_ir.JIT_CALLING_CONVENTION):
+        call = cfg_ir.get_def_value(definition)
+        if call.target_name in def_rewrite_rules:
+            return def_rewrite_rules[call.target_name]
+
+    return None
+
+def get_call_use_rewriter(use_definition, def_definition, use_rewrite_rules):
+    """Gets an appropriate rewrite rule from the given dictionary of call rewrite rules."""
+    if cfg_ir.is_call(use_definition, calling_convention=cfg_ir.JIT_CALLING_CONVENTION):
+        call = cfg_ir.get_def_value(use_definition)
+        for arg_name, arg_def in call.argument_list:
+            if arg_def == def_definition:
+                key = (call.target_name, arg_name)
+                if key in use_rewrite_rules:
+                    return use_rewrite_rules[key]
+
+    return None
+
+def apply_rewrite_rules(
+        entry_point,
+        get_def_rewriter,
+        get_use_rewriter):
+    """Applies the given definition and use rewrite rules to all definitions and uses where a
+       rewrite rule can be found for both the definitions and the uses."""
+    # pylint: disable=I0011,W0108
+    cfg_ir.match_and_rewrite(
+        entry_point,
+        lambda def_def:
+        get_def_rewriter(def_def) is not None,
+        lambda use_def, def_def:
+        get_use_rewriter(use_def, def_def) is not None,
+        lambda def_def:
+        get_def_rewriter(def_def)(def_def),
+        lambda use_def, def_def:
+        get_use_rewriter(use_def, def_def)(use_def, def_def))
+
+def optimize_data_structures(entry_point):
+    """Optimizes data structures in the graph defined by the given entry point."""
+    apply_rewrite_rules(
+        entry_point,
+        lambda def_def:
+        get_call_def_rewriter(def_def, SET_DEF_REWRITE_RULES),
+        lambda use_def, def_def:
+        get_call_use_rewriter(use_def, def_def, SET_USE_REWRITE_RULES))

+ 22 - 0
kernel/modelverse_jit/cfg_ir.py

@@ -440,6 +440,10 @@ SIMPLE_POSITIONAL_CALLING_CONVENTION = 'simple-positional'
 """The calling convention for functions that use 'return' statements to return.
    Arguments are matched to parameters based on position."""
 
+SELF_POSITIONAL_CALLING_CONVENTION = 'self-positional'
+"""A calling convention that is identical to SIMPLE_POSITIONAL_CALLING_CONVENTION, except
+   for the fact that the first argument is used as the 'self' parameter."""
+
 JIT_CALLING_CONVENTION = 'jit'
 """The calling convention for jitted functions."""
 
@@ -449,6 +453,12 @@ MACRO_POSITIONAL_CALLING_CONVENTION = 'macro-positional'
 PRINT_MACRO_NAME = 'print'
 """The name of the 'print' macro."""
 
+READ_DICT_KEYS_MACRO_NAME = 'read_dict_keys'
+"""The name of the macro that reads all keys from a dictionary."""
+
+REVERSE_LIST_MACRO_NAME = 'reverse_list'
+"""The name of the list reversal macro."""
+
 class DirectFunctionCall(Value):
     """A value that is the result of a direct function call."""
     def __init__(
@@ -849,6 +859,18 @@ def get_literal_def_value(def_or_value):
     """Gets the value of the given literal value or definition with an underlying literal."""
     return apply_to_value(get_literal_value, def_or_value)
 
+def is_call(def_or_value, target_name=None, calling_convention=None):
+    """Tells if the given definition or value is a direct call.
+       The callee's name must match the given name, or the specified name must be None
+       The call must have the given calling convention, or the calling convention must be None."""
+    value = get_def_value(def_or_value)
+    if isinstance(value, DirectFunctionCall):
+        return ((target_name is None or value.target_name == target_name)
+                and (calling_convention is None
+                     or calling_convention == value.calling_convention))
+    else:
+        return False
+
 def get_all_predecessor_blocks(entry_point):
     """Creates a mapping of blocks to their direct predecessors for every block in the control-flow
        graph defined by the given entry point."""

+ 2 - 0
kernel/modelverse_jit/cfg_optimization.py

@@ -4,6 +4,7 @@ from collections import defaultdict
 import modelverse_jit.cfg_ir as cfg_ir
 import modelverse_jit.cfg_dominators as cfg_dominators
 import modelverse_jit.cfg_ssa_construction as cfg_ssa_construction
+import modelverse_jit.cfg_data_structures as cfg_data_structures
 import modelverse_kernel.primitives as primitive_functions
 
 def is_empty_block(block):
@@ -363,6 +364,7 @@ def optimize(entry_point, jit):
     eliminate_trivial_phis(entry_point)
     entry_point = cfg_ssa_construction.construct_ssa_form(entry_point)
     optimize_calls(entry_point, jit)
+    cfg_data_structures.optimize_data_structures(entry_point)
     yield [("CALL_ARGS", [inline_constants, (entry_point,)])]
     optimize_reads(entry_point)
     simplify_values(entry_point)

+ 14 - 1
kernel/modelverse_jit/cfg_to_tree.py

@@ -838,6 +838,14 @@ class LoweringState(object):
             tree_ir.LoadGlobalInstruction(value.target_name),
             [self.use_definition(arg) for _, arg in value.argument_list])
 
+    def lower_self_positional_call(self, value):
+        """Lowers a direct call that uses the 'self-positional' calling convention."""
+        all_args = [self.use_definition(arg) for _, arg in value.argument_list]
+        target = all_args[0]
+        arg_list = all_args[1:]
+        return tree_ir.CallInstruction(
+            tree_ir.LoadMemberInstruction(target, value.target_name), arg_list)
+
     def lower_jit_call(self, value):
         """Lowers a direct call that uses the 'jit' calling convention."""
         arg_list = [(name, self.use_definition(arg))
@@ -954,12 +962,17 @@ class LoweringState(object):
 
     call_lowerings = {
         cfg_ir.SIMPLE_POSITIONAL_CALLING_CONVENTION : lower_simple_positional_call,
+        cfg_ir.SELF_POSITIONAL_CALLING_CONVENTION : lower_self_positional_call,
         cfg_ir.JIT_CALLING_CONVENTION : lower_jit_call,
         cfg_ir.MACRO_POSITIONAL_CALLING_CONVENTION : lower_macro_call
     }
 
     macro_lowerings = {
-        cfg_ir.PRINT_MACRO_NAME: tree_ir.PrintInstruction
+        cfg_ir.PRINT_MACRO_NAME: tree_ir.PrintInstruction,
+        cfg_ir.READ_DICT_KEYS_MACRO_NAME: tree_ir.ReadDictionaryKeysInstruction,
+        cfg_ir.REVERSE_LIST_MACRO_NAME:
+        lambda seq:
+        tree_ir.ListSliceInstruction(seq, None, None, tree_ir.LiteralInstruction(-1))
     }
 
 def lower_flow_graph(entry_point, jit):

+ 71 - 0
kernel/modelverse_jit/tree_ir.py

@@ -813,6 +813,63 @@ class DictionaryLiteralInstruction(Instruction):
     def __repr__(self):
         return "DictionaryLiteralInstruction(%r)" % self.key_value_pairs
 
+class ListSliceInstruction(Instruction):
+    """Slices a list."""
+    def __init__(self, seq, start, end, step):
+        Instruction.__init__(self)
+        self.seq = seq
+        self.start = start
+        self.end = end
+        self.step = step
+
+    def get_children(self):
+        """Gets this instruction's sequence of child instructions."""
+        all_items = (self.seq, self.start, self.end, self.step)
+        return [item for item in all_items if item is not None]
+
+    def create(self, new_children):
+        """Creates a new instruction of this type from the given sequence of child instructions."""
+        # pylint: disable=I0011,E1120
+        args = []
+        i = 0
+        for old_item in (self.seq, self.start, self.end, self.step):
+            if old_item is None:
+                args.append(None)
+            else:
+                args.append(new_children[i])
+                i += 1
+
+        assert len(new_children) == i
+        assert len(args) == 4
+        return ListSliceInstruction(*args)
+
+    def has_definition_impl(self):
+        """Tells if this instruction requires a definition."""
+        return any([item.has_definition() for item in self.get_children()])
+
+    def has_result_temporary(self):
+        """Tells if this instruction stores its result in a temporary."""
+        return False
+
+    def generate_python_def(self, code_generator):
+        """Generates a Python statement that executes this instruction.
+            The statement is appended immediately to the code generator."""
+        for item in self.get_children():
+            if item.has_definition():
+                item.generate_python_def(code_generator)
+
+    def generate_python_use(self, code_generator):
+        """Generates a Python expression that retrieves this instruction's
+           result. The expression is returned as a string."""
+        return '%s[%s:%s:%s]' % (
+            self.seq.generate_python_use(code_generator),
+            '' if self.start is None else self.start.generate_python_use(code_generator),
+            '' if self.end is None else self.end.generate_python_use(code_generator),
+            '' if self.step is None else self.step.generate_python_use(code_generator))
+
+    def __repr__(self):
+        return "ListSliceInstruction(%r, %r, %r, %r)" % (self.seq, self.start, self.end, self.step)
+
 class StateInstruction(Instruction):
     """An instruction that accesses the modelverse state."""
     def get_result_type_impl(self):
@@ -1290,6 +1347,20 @@ class ReadDictionaryEdgeInstruction(StateInstruction):
         """Gets this state instruction's argument list."""
         return [self.node_id, self.key]
 
+class ReadDictionaryKeysInstruction(StateInstruction):
+    """An instruction that reads all keys from a dictionary."""
+    def __init__(self, node_id):
+        StateInstruction.__init__(self)
+        self.node_id = node_id
+
+    def get_opcode(self):
+        """Gets the opcode for this state instruction."""
+        return "RDK"
+
+    def get_arguments(self):
+        """Gets this state instruction's argument list."""
+        return [self.node_id]
+
 class ReadEdgeInstruction(StateInstruction):
     """An instruction that reads an edge."""
     def __init__(self, node_id):