Sfoglia il codice sorgente

Optimize the dict iteration idiom in fast-JIT

jonathanvdc 8 anni fa
parent
commit
7e7be8ad44

+ 88 - 0
kernel/modelverse_jit/cfg_data_structures.py

@@ -0,0 +1,88 @@
+"""Defines optimizations that replace Modelverse data structures by Python data structures."""
+
+import modelverse_jit.cfg_ir as cfg_ir
+
+SET_DEF_REWRITE_RULES = {
+    'dict_keys':
+    lambda def_def:
+    def_def.redefine(
+        cfg_ir.DirectFunctionCall(
+            cfg_ir.REVERSE_LIST_MACRO_NAME,
+            [('seq',
+              def_def.insert_before(
+                  cfg_ir.DirectFunctionCall(
+                      cfg_ir.READ_DICT_KEYS_MACRO_NAME,
+                      cfg_ir.get_def_value(def_def).argument_list,
+                      cfg_ir.MACRO_POSITIONAL_CALLING_CONVENTION,
+                      has_side_effects=False)))],
+            cfg_ir.MACRO_POSITIONAL_CALLING_CONVENTION,
+            has_side_effects=False))
+}
+
+def __redefine_as_list_len(use_def, def_def):
+    use_def.redefine(
+        cfg_ir.CreateNode(
+            use_def.insert_before(
+                cfg_ir.DirectFunctionCall(
+                    'len', [('seq', def_def)],
+                    cfg_ir.SIMPLE_POSITIONAL_CALLING_CONVENTION,
+                    has_side_effects=False))))
+
+SET_USE_REWRITE_RULES = {
+    ('set_pop', 'a'):
+    lambda use_def, def_def:
+    use_def.redefine(
+        cfg_ir.DirectFunctionCall(
+            'pop', [('self', def_def)],
+            cfg_ir.SELF_POSITIONAL_CALLING_CONVENTION)),
+    ('list_len', 'a'): __redefine_as_list_len,
+    ('read_nr_out', 'a'): __redefine_as_list_len
+}
+
+def get_call_def_rewriter(definition, def_rewrite_rules):
+    """Gets an appropriate rewrite rule from the given dictionary of call rewrite rules."""
+    if cfg_ir.is_call(definition, calling_convention=cfg_ir.JIT_CALLING_CONVENTION):
+        call = cfg_ir.get_def_value(definition)
+        if call.target_name in def_rewrite_rules:
+            return def_rewrite_rules[call.target_name]
+
+    return None
+
+def get_call_use_rewriter(use_definition, def_definition, use_rewrite_rules):
+    """Gets an appropriate rewrite rule from the given dictionary of call rewrite rules."""
+    if cfg_ir.is_call(use_definition, calling_convention=cfg_ir.JIT_CALLING_CONVENTION):
+        call = cfg_ir.get_def_value(use_definition)
+        for arg_name, arg_def in call.argument_list:
+            if arg_def == def_definition:
+                key = (call.target_name, arg_name)
+                if key in use_rewrite_rules:
+                    return use_rewrite_rules[key]
+
+    return None
+
+def apply_rewrite_rules(
+        entry_point,
+        get_def_rewriter,
+        get_use_rewriter):
+    """Applies the given definition and use rewrite rules to all definitions and uses where a
+       rewrite rule can be found for both the definitions and the uses."""
+    # pylint: disable=I0011,W0108
+    cfg_ir.match_and_rewrite(
+        entry_point,
+        lambda def_def:
+        get_def_rewriter(def_def) is not None,
+        lambda use_def, def_def:
+        get_use_rewriter(use_def, def_def) is not None,
+        lambda def_def:
+        get_def_rewriter(def_def)(def_def),
+        lambda use_def, def_def:
+        get_use_rewriter(use_def, def_def)(use_def, def_def))
+
+def optimize_data_structures(entry_point):
+    """Optimizes data structures in the graph defined by the given entry point."""
+    apply_rewrite_rules(
+        entry_point,
+        lambda def_def:
+        get_call_def_rewriter(def_def, SET_DEF_REWRITE_RULES),
+        lambda use_def, def_def:
+        get_call_use_rewriter(use_def, def_def, SET_USE_REWRITE_RULES))

+ 22 - 0
kernel/modelverse_jit/cfg_ir.py

@@ -440,6 +440,10 @@ SIMPLE_POSITIONAL_CALLING_CONVENTION = 'simple-positional'
 """The calling convention for functions that use 'return' statements to return.
 """The calling convention for functions that use 'return' statements to return.
    Arguments are matched to parameters based on position."""
    Arguments are matched to parameters based on position."""
 
 
+SELF_POSITIONAL_CALLING_CONVENTION = 'self-positional'
+"""A calling convention that is identical to SIMPLE_POSITIONAL_CALLING_CONVENTION, except
+   for the fact that the first argument is used as the 'self' parameter."""
+
 JIT_CALLING_CONVENTION = 'jit'
 JIT_CALLING_CONVENTION = 'jit'
 """The calling convention for jitted functions."""
 """The calling convention for jitted functions."""
 
 
@@ -449,6 +453,12 @@ MACRO_POSITIONAL_CALLING_CONVENTION = 'macro-positional'
 PRINT_MACRO_NAME = 'print'
 PRINT_MACRO_NAME = 'print'
 """The name of the 'print' macro."""
 """The name of the 'print' macro."""
 
 
+READ_DICT_KEYS_MACRO_NAME = 'read_dict_keys'
+"""The name of the macro that reads all keys from a dictionary."""
+
+REVERSE_LIST_MACRO_NAME = 'reverse_list'
+"""The name of the list reversal macro."""
+
 class DirectFunctionCall(Value):
 class DirectFunctionCall(Value):
     """A value that is the result of a direct function call."""
     """A value that is the result of a direct function call."""
     def __init__(
     def __init__(
@@ -849,6 +859,18 @@ def get_literal_def_value(def_or_value):
     """Gets the value of the given literal value or definition with an underlying literal."""
     """Gets the value of the given literal value or definition with an underlying literal."""
     return apply_to_value(get_literal_value, def_or_value)
     return apply_to_value(get_literal_value, def_or_value)
 
 
+def is_call(def_or_value, target_name=None, calling_convention=None):
+    """Tells if the given definition or value is a direct call.
+       The callee's name must match the given name, or the specified name must be None
+       The call must have the given calling convention, or the calling convention must be None."""
+    value = get_def_value(def_or_value)
+    if isinstance(value, DirectFunctionCall):
+        return ((target_name is None or value.target_name == target_name)
+                and (calling_convention is None
+                     or calling_convention == value.calling_convention))
+    else:
+        return False
+
 def get_all_predecessor_blocks(entry_point):
 def get_all_predecessor_blocks(entry_point):
     """Creates a mapping of blocks to their direct predecessors for every block in the control-flow
     """Creates a mapping of blocks to their direct predecessors for every block in the control-flow
        graph defined by the given entry point."""
        graph defined by the given entry point."""

+ 2 - 0
kernel/modelverse_jit/cfg_optimization.py

@@ -4,6 +4,7 @@ from collections import defaultdict
 import modelverse_jit.cfg_ir as cfg_ir
 import modelverse_jit.cfg_ir as cfg_ir
 import modelverse_jit.cfg_dominators as cfg_dominators
 import modelverse_jit.cfg_dominators as cfg_dominators
 import modelverse_jit.cfg_ssa_construction as cfg_ssa_construction
 import modelverse_jit.cfg_ssa_construction as cfg_ssa_construction
+import modelverse_jit.cfg_data_structures as cfg_data_structures
 import modelverse_kernel.primitives as primitive_functions
 import modelverse_kernel.primitives as primitive_functions
 
 
 def is_empty_block(block):
 def is_empty_block(block):
@@ -363,6 +364,7 @@ def optimize(entry_point, jit):
     eliminate_trivial_phis(entry_point)
     eliminate_trivial_phis(entry_point)
     entry_point = cfg_ssa_construction.construct_ssa_form(entry_point)
     entry_point = cfg_ssa_construction.construct_ssa_form(entry_point)
     optimize_calls(entry_point, jit)
     optimize_calls(entry_point, jit)
+    cfg_data_structures.optimize_data_structures(entry_point)
     yield [("CALL_ARGS", [inline_constants, (entry_point,)])]
     yield [("CALL_ARGS", [inline_constants, (entry_point,)])]
     optimize_reads(entry_point)
     optimize_reads(entry_point)
     simplify_values(entry_point)
     simplify_values(entry_point)

+ 14 - 1
kernel/modelverse_jit/cfg_to_tree.py

@@ -838,6 +838,14 @@ class LoweringState(object):
             tree_ir.LoadGlobalInstruction(value.target_name),
             tree_ir.LoadGlobalInstruction(value.target_name),
             [self.use_definition(arg) for _, arg in value.argument_list])
             [self.use_definition(arg) for _, arg in value.argument_list])
 
 
+    def lower_self_positional_call(self, value):
+        """Lowers a direct call that uses the 'self-positional' calling convention."""
+        all_args = [self.use_definition(arg) for _, arg in value.argument_list]
+        target = all_args[0]
+        arg_list = all_args[1:]
+        return tree_ir.CallInstruction(
+            tree_ir.LoadMemberInstruction(target, value.target_name), arg_list)
+
     def lower_jit_call(self, value):
     def lower_jit_call(self, value):
         """Lowers a direct call that uses the 'jit' calling convention."""
         """Lowers a direct call that uses the 'jit' calling convention."""
         arg_list = [(name, self.use_definition(arg))
         arg_list = [(name, self.use_definition(arg))
@@ -954,12 +962,17 @@ class LoweringState(object):
 
 
     call_lowerings = {
     call_lowerings = {
         cfg_ir.SIMPLE_POSITIONAL_CALLING_CONVENTION : lower_simple_positional_call,
         cfg_ir.SIMPLE_POSITIONAL_CALLING_CONVENTION : lower_simple_positional_call,
+        cfg_ir.SELF_POSITIONAL_CALLING_CONVENTION : lower_self_positional_call,
         cfg_ir.JIT_CALLING_CONVENTION : lower_jit_call,
         cfg_ir.JIT_CALLING_CONVENTION : lower_jit_call,
         cfg_ir.MACRO_POSITIONAL_CALLING_CONVENTION : lower_macro_call
         cfg_ir.MACRO_POSITIONAL_CALLING_CONVENTION : lower_macro_call
     }
     }
 
 
     macro_lowerings = {
     macro_lowerings = {
-        cfg_ir.PRINT_MACRO_NAME: tree_ir.PrintInstruction
+        cfg_ir.PRINT_MACRO_NAME: tree_ir.PrintInstruction,
+        cfg_ir.READ_DICT_KEYS_MACRO_NAME: tree_ir.ReadDictionaryKeysInstruction,
+        cfg_ir.REVERSE_LIST_MACRO_NAME:
+        lambda seq:
+        tree_ir.ListSliceInstruction(seq, None, None, tree_ir.LiteralInstruction(-1))
     }
     }
 
 
 def lower_flow_graph(entry_point, jit):
 def lower_flow_graph(entry_point, jit):

+ 71 - 0
kernel/modelverse_jit/tree_ir.py

@@ -813,6 +813,63 @@ class DictionaryLiteralInstruction(Instruction):
     def __repr__(self):
     def __repr__(self):
         return "DictionaryLiteralInstruction(%r)" % self.key_value_pairs
         return "DictionaryLiteralInstruction(%r)" % self.key_value_pairs
 
 
+class ListSliceInstruction(Instruction):
+    """Slices a list."""
+    def __init__(self, seq, start, end, step):
+        Instruction.__init__(self)
+        self.seq = seq
+        self.start = start
+        self.end = end
+        self.step = step
+
+    def get_children(self):
+        """Gets this instruction's sequence of child instructions."""
+        all_items = (self.seq, self.start, self.end, self.step)
+        return [item for item in all_items if item is not None]
+
+    def create(self, new_children):
+        """Creates a new instruction of this type from the given sequence of child instructions."""
+        # pylint: disable=I0011,E1120
+        args = []
+        i = 0
+        for old_item in (self.seq, self.start, self.end, self.step):
+            if old_item is None:
+                args.append(None)
+            else:
+                args.append(new_children[i])
+                i += 1
+
+        assert len(new_children) == i
+        assert len(args) == 4
+        return ListSliceInstruction(*args)
+
+    def has_definition_impl(self):
+        """Tells if this instruction requires a definition."""
+        return any([item.has_definition() for item in self.get_children()])
+
+    def has_result_temporary(self):
+        """Tells if this instruction stores its result in a temporary."""
+        return False
+
+    def generate_python_def(self, code_generator):
+        """Generates a Python statement that executes this instruction.
+            The statement is appended immediately to the code generator."""
+        for item in self.get_children():
+            if item.has_definition():
+                item.generate_python_def(code_generator)
+
+    def generate_python_use(self, code_generator):
+        """Generates a Python expression that retrieves this instruction's
+           result. The expression is returned as a string."""
+        return '%s[%s:%s:%s]' % (
+            self.seq.generate_python_use(code_generator),
+            '' if self.start is None else self.start.generate_python_use(code_generator),
+            '' if self.end is None else self.end.generate_python_use(code_generator),
+            '' if self.step is None else self.step.generate_python_use(code_generator))
+
+    def __repr__(self):
+        return "ListSliceInstruction(%r, %r, %r, %r)" % (self.seq, self.start, self.end, self.step)
+
 class StateInstruction(Instruction):
 class StateInstruction(Instruction):
     """An instruction that accesses the modelverse state."""
     """An instruction that accesses the modelverse state."""
     def get_result_type_impl(self):
     def get_result_type_impl(self):
@@ -1290,6 +1347,20 @@ class ReadDictionaryEdgeInstruction(StateInstruction):
         """Gets this state instruction's argument list."""
         """Gets this state instruction's argument list."""
         return [self.node_id, self.key]
         return [self.node_id, self.key]
 
 
+class ReadDictionaryKeysInstruction(StateInstruction):
+    """An instruction that reads all keys from a dictionary."""
+    def __init__(self, node_id):
+        StateInstruction.__init__(self)
+        self.node_id = node_id
+
+    def get_opcode(self):
+        """Gets the opcode for this state instruction."""
+        return "RDK"
+
+    def get_arguments(self):
+        """Gets this state instruction's argument list."""
+        return [self.node_id]
+
 class ReadEdgeInstruction(StateInstruction):
 class ReadEdgeInstruction(StateInstruction):
     """An instruction that reads an edge."""
     """An instruction that reads an edge."""
     def __init__(self, node_id):
     def __init__(self, node_id):