Browse Source

Strategically insert GC protects

jonathanvdc 8 years ago
parent
commit
77eee2f563

+ 93 - 13
kernel/modelverse_jit/cfg_ir.py

@@ -49,23 +49,31 @@ class BasicBlock(object):
         self.renumber_definitions()
         self.renumber_definitions()
         return result
         return result
 
 
-    def insert_definition_before(self, anchor, value):
-        """Inserts the second definition or value before the first definition."""
-        index = None
+    def __get_def_index_for_insert(self, anchor):
         for i, definition in enumerate(self.definitions):
         for i, definition in enumerate(self.definitions):
             if definition.definition_index == anchor.definition_index:
             if definition.definition_index == anchor.definition_index:
-                index = i
+                return i
 
 
-        if index is None:
-            raise ValueError(
-                'Cannot insert a definition because the anchor '
-                'is not defined in this block.')
+        raise ValueError(
+            'Cannot insert a definition because the anchor '
+            'is not defined in this block.')
 
 
+    def insert_definition_before(self, anchor, value):
+        """Inserts the second definition or value before the first definition."""
+        index = self.__get_def_index_for_insert(anchor)
         result = self.create_definition(value)
         result = self.create_definition(value)
         self.definitions.insert(index, result)
         self.definitions.insert(index, result)
         self.renumber_definitions()
         self.renumber_definitions()
         return result
         return result
 
 
+    def insert_definition_after(self, anchor, value):
+        """Creates a definition for `value` and places it directly behind `anchor`
+           in this block's definition list. The new definition is returned.
+           A ValueError is raised if `anchor` is not defined in this block."""
+        anchor_position = self.__get_def_index_for_insert(anchor)
+        new_definition = self.create_definition(value)
+        # Inserting at the slot right after the anchor shifts every later definition back.
+        self.definitions.insert(anchor_position + 1, new_definition)
+        self.renumber_definitions()
+        return new_definition
+
     def append_definition(self, value):
     def append_definition(self, value):
         """Defines the given value in this basic block."""
         """Defines the given value in this basic block."""
         result = self.create_definition(value)
         result = self.create_definition(value)
@@ -158,6 +166,10 @@ class Definition(object):
         """Inserts the given value or definition before this definition."""
         """Inserts the given value or definition before this definition."""
         return self.block.insert_definition_before(self, value)
         return self.block.insert_definition_before(self, value)
 
 
+    def insert_after(self, value):
+        """Asks the block that owns this definition to insert the given value or
+           definition right after it, and returns the block's result."""
+        owning_block = self.block
+        return owning_block.insert_definition_after(self, value)
+
     def ref_str(self):
     def ref_str(self):
         """Gets a string that represents a reference to this definition."""
         """Gets a string that represents a reference to this definition."""
         return '$%d' % self.index
         return '$%d' % self.index
@@ -464,6 +476,9 @@ JIT_CALLING_CONVENTION = 'jit'
 MACRO_POSITIONAL_CALLING_CONVENTION = 'macro-positional'
 MACRO_POSITIONAL_CALLING_CONVENTION = 'macro-positional'
 """The calling convention for well-known functions that are expanded as macros during codegen."""
 """The calling convention for well-known functions that are expanded as macros during codegen."""
 
 
+MACRO_IO_CALLING_CONVENTION = 'macro-io'
+"""The calling convention for the 'input' and 'output' macros."""
+
 PRINT_MACRO_NAME = 'print'
 PRINT_MACRO_NAME = 'print'
 """The name of the 'print' macro."""
 """The name of the 'print' macro."""
 
 
@@ -479,13 +494,22 @@ READ_DICT_KEYS_MACRO_NAME = 'read_dict_keys'
 REVERSE_LIST_MACRO_NAME = 'reverse_list'
 REVERSE_LIST_MACRO_NAME = 'reverse_list'
 """The name of the list reversal macro."""
 """The name of the list reversal macro."""
 
 
+GC_PROTECT_MACRO_NAME = 'gc_protect'
+"""The name of the macro that unconditionally protects its first argument from the GC by
+   drawing an edge between it and the second argument."""
+
+MAYBE_GC_PROTECT_MACRO_NAME = 'maybe_gc_protect'
+"""The name of the macro that protects its first argument from the GC by drawing an edge between
+   it and the second argument, but only if that first argument is not None."""
+
 class DirectFunctionCall(Value):
 class DirectFunctionCall(Value):
     """A value that is the result of a direct function call."""
     """A value that is the result of a direct function call."""
     def __init__(
     def __init__(
             self, target_name, argument_list,
             self, target_name, argument_list,
             calling_convention=JIT_CALLING_CONVENTION,
             calling_convention=JIT_CALLING_CONVENTION,
             has_value=True,
             has_value=True,
-            has_side_effects=True):
+            has_side_effects=True,
+            has_bidirectional_dependencies=False):
         Value.__init__(self)
         Value.__init__(self)
         self.target_name = target_name
         self.target_name = target_name
         assert all([isinstance(val, Definition) for _, val in argument_list])
         assert all([isinstance(val, Definition) for _, val in argument_list])
@@ -493,6 +517,7 @@ class DirectFunctionCall(Value):
         self.calling_convention = calling_convention
         self.calling_convention = calling_convention
         self.has_value_val = has_value
         self.has_value_val = has_value
         self.has_side_effects_val = has_side_effects
         self.has_side_effects_val = has_side_effects
+        self.has_bidirectional_deps_val = has_bidirectional_dependencies
 
 
     def has_side_effects(self):
     def has_side_effects(self):
         """Tells if this instruction has side-effects."""
         """Tells if this instruction has side-effects."""
@@ -502,13 +527,18 @@ class DirectFunctionCall(Value):
         """Tells if this value produces a result that is not None."""
         """Tells if this value produces a result that is not None."""
         return self.has_value_val
         return self.has_value_val
 
 
+    def has_bidirectional_dependencies(self):
+        """Reports the bidirectional-dependencies flag that was stored on this call
+           when it was constructed."""
+        flag = self.has_bidirectional_deps_val
+        return flag
+
     def create(self, new_dependencies):
     def create(self, new_dependencies):
         """Creates an instruction of this type from the given set of dependencies."""
         """Creates an instruction of this type from the given set of dependencies."""
         return DirectFunctionCall(
         return DirectFunctionCall(
             self.target_name,
             self.target_name,
             [(name, new_val)
             [(name, new_val)
              for new_val, (name, _) in zip(new_dependencies, self.argument_list)],
              for new_val, (name, _) in zip(new_dependencies, self.argument_list)],
-            self.calling_convention, self.has_value_val, self.has_side_effects_val)
+            self.calling_convention, self.has_value_val, self.has_side_effects_val,
+            self.has_bidirectional_deps_val)
 
 
     def get_dependencies(self):
     def get_dependencies(self):
         """Gets all definitions and instructions on which this instruction depends."""
         """Gets all definitions and instructions on which this instruction depends."""
@@ -519,7 +549,9 @@ class DirectFunctionCall(Value):
            The result is a formatted string that consists of a calling convention,
            The result is a formatted string that consists of a calling convention,
            and optionally information that pertains to whether the function returns
            and optionally information that pertains to whether the function returns
            a value and has side-effects."""
            a value and has side-effects."""
-        if self.has_side_effects() and self.has_value():
+        if (self.has_side_effects()
+                and self.has_value()
+                and not self.has_bidirectional_dependencies()):
             return repr(self.calling_convention)
             return repr(self.calling_convention)
 
 
         contents = [repr(self.calling_convention)]
         contents = [repr(self.calling_convention)]
@@ -527,6 +559,8 @@ class DirectFunctionCall(Value):
             contents.append('pure')
             contents.append('pure')
         if not self.has_value():
         if not self.has_value():
             contents.append('void')
             contents.append('void')
+        if self.has_bidirectional_dependencies():
+            contents.append('two-way-dependencies')
 
 
         return '(%s)' % ', '.join(contents)
         return '(%s)' % ', '.join(contents)
 
 
@@ -827,14 +861,40 @@ def create_output(argument):
     """Creates a value that outputs the specified argument."""
     """Creates a value that outputs the specified argument."""
     return DirectFunctionCall(
     return DirectFunctionCall(
         OUTPUT_MACRO_NAME, [('argument', argument)],
         OUTPUT_MACRO_NAME, [('argument', argument)],
-        calling_convention=MACRO_POSITIONAL_CALLING_CONVENTION,
+        calling_convention=MACRO_IO_CALLING_CONVENTION,
         has_value=False)
         has_value=False)
 
 
 def create_input():
 def create_input():
     """Creates a value that pops a value from the input queue."""
     """Creates a value that pops a value from the input queue."""
     return DirectFunctionCall(
     return DirectFunctionCall(
         INPUT_MACRO_NAME, [],
         INPUT_MACRO_NAME, [],
-        calling_convention=MACRO_POSITIONAL_CALLING_CONVENTION)
+        calling_convention=MACRO_IO_CALLING_CONVENTION)
+
+def create_gc_protect(protected_value, root):
+    """Builds a 'gc_protect' macro call that protects `protected_value` from the GC
+       by drawing an edge between it and `root`. The call is flagged as having no
+       value, no side-effects and bidirectional dependencies."""
+    arguments = [('protected_value', protected_value), ('root', root)]
+    return DirectFunctionCall(
+        GC_PROTECT_MACRO_NAME,
+        arguments,
+        MACRO_POSITIONAL_CALLING_CONVENTION,
+        has_value=False,
+        has_side_effects=False,
+        has_bidirectional_dependencies=True)
+
+def create_conditional_gc_protect(protected_value, root):
+    """Builds a 'maybe_gc_protect' macro call that protects `protected_value` from
+       the GC by drawing an edge between it and `root`, but only if the protected
+       value is not None. The protected value doubles as the condition argument.
+       The call is flagged as having no value, no side-effects and bidirectional
+       dependencies."""
+    arguments = [
+        ('condition', protected_value),
+        ('protected_value', protected_value),
+        ('root', root)]
+    return DirectFunctionCall(
+        MAYBE_GC_PROTECT_MACRO_NAME,
+        arguments,
+        MACRO_POSITIONAL_CALLING_CONVENTION,
+        has_value=False,
+        has_side_effects=False,
+        has_bidirectional_dependencies=True)
 
 
 def get_def_value(def_or_value):
 def get_def_value(def_or_value):
     """Returns the given value, or the underlying value of the given definition, whichever is
     """Returns the given value, or the underlying value of the given definition, whichever is
@@ -953,6 +1013,26 @@ def get_trivial_phi_value(parameter_def, values):
 
 
     return result
     return result
 
 
+def find_all_def_uses(entry_point):
+    """Collects use information for every block reachable from the given entry point.
+       Returns a (users, blocks) tuple: `users` maps a definition to the list of
+       definitions or flow instructions that depend on it, and `blocks` maps block
+       parameters, definitions and flow instructions to the block that contains them."""
+    users_by_def = defaultdict(list)
+    block_by_def = {}
+    for current_block in list(get_all_blocks(entry_point)):
+        # Block parameters belong to the block that declares them.
+        block_by_def.update(
+            (parameter_def, current_block) for parameter_def in current_block.parameters)
+        # The flow instruction is treated as one more user at the end of the block.
+        for user in current_block.definitions + [current_block.flow]:
+            block_by_def[user] = current_block
+            for used in user.get_all_dependencies():
+                if isinstance(used, Branch):
+                    # Branches are not recorded as used definitions.
+                    continue
+                users_by_def[used].append(user)
+
+    return users_by_def, block_by_def
+
 def match_and_rewrite(entry_point, match_def, match_use, rewrite_def, rewrite_use):
 def match_and_rewrite(entry_point, match_def, match_use, rewrite_def, rewrite_use):
     """Matches and rewrites chains of definitions and uses in the graph defined by
     """Matches and rewrites chains of definitions and uses in the graph defined by
        the given entry point."""
        the given entry point."""

+ 130 - 0
kernel/modelverse_jit/cfg_optimization.py

@@ -364,6 +364,133 @@ def optimize_reads(entry_point):
             cfg_ir.Read(def_def.insert_before(def_def.value))),
             cfg_ir.Read(def_def.insert_before(def_def.value))),
         lambda use_def, def_def: use_def.redefine(def_def))
         lambda use_def, def_def: use_def.redefine(def_def))
 
 
+def protect_from_gc(entry_point):
+    """Protects locals in the control-flow graph defined by the given
+       entry point from the GC.
+
+       An AllocateRootNode definition is prepended to the entry point. A 'gc_protect'
+       call is inserted after every CreateNode definition. A 'maybe_gc_protect' call is
+       inserted after every indirect function call, and after every value-producing
+       direct call that uses the JIT or IO calling convention. A DeallocateRootNode
+       definition is appended to every block whose flow is a return or a throw."""
+    root_node = entry_point.prepend_definition(cfg_ir.AllocateRootNode())
+    def protect_def_from_gc(definition):
+        """Protects the given definition from the GC."""
+        definition.insert_after(cfg_ir.create_gc_protect(definition, root_node))
+
+    def maybe_protect_def_from_gc(definition):
+        """Protects the given definition from the GC, if its result is not None."""
+        definition.insert_after(cfg_ir.create_conditional_gc_protect(definition, root_node))
+
+    def is_gc_prone_direct_call(def_value):
+        """Tells if the given value is a value-producing JIT or IO direct call."""
+        return (isinstance(def_value, cfg_ir.DirectFunctionCall)
+                and (def_value.calling_convention == cfg_ir.JIT_CALLING_CONVENTION
+                     or def_value.calling_convention == cfg_ir.MACRO_IO_CALLING_CONVENTION)
+                and def_value.has_value())
+
+    for block in cfg_ir.get_all_blocks(entry_point):
+        # Iterate over a snapshot of the definition list: protecting a definition inserts
+        # a new definition into `block.definitions`, and the list must not be mutated
+        # while it is being iterated over. The freshly inserted protection calls never
+        # need protection themselves, so skipping them does not change the result.
+        for definition in list(block.definitions):
+            def_value = cfg_ir.get_def_value(definition)
+            if isinstance(def_value, cfg_ir.CreateNode):
+                protect_def_from_gc(definition)
+            elif (isinstance(def_value, cfg_ir.IndirectFunctionCall)
+                  or is_gc_prone_direct_call(def_value)):
+                maybe_protect_def_from_gc(definition)
+
+        if isinstance(block.flow, (cfg_ir.ReturnFlow, cfg_ir.ThrowFlow)):
+            block.append_definition(cfg_ir.DeallocateRootNode(root_node))
+
+def elide_gc_protects(entry_point):
+    """Tries to elide GC protection values.
+
+       Every 'gc_protect'/'maybe_gc_protect' call whose protected definition is not
+       found to be ineligible is redefined as a None literal. A definition is ineligible
+       if it is a definition of a definition, is registered in more than one block, is
+       passed as a branch argument, or is still used in its block at or after the first
+       definition that may cause a GC."""
+    # We don't need to protect a value from the GC if it is used for the
+    # last time _before_ the GC has an opportunity to kick in. To simplify
+    # things, we'll do a quick block-based analysis.
+    def __may_cause_gc(definition):
+        # Indirect calls and direct calls with the JIT or IO calling convention are
+        # treated as points where the GC may run; everything else is not.
+        def_value = cfg_ir.get_def_value(definition)
+        if isinstance(def_value, cfg_ir.IndirectFunctionCall):
+            return True
+        elif (isinstance(def_value, cfg_ir.DirectFunctionCall)
+              and (def_value.calling_convention == cfg_ir.JIT_CALLING_CONVENTION
+                   or def_value.calling_convention == cfg_ir.MACRO_IO_CALLING_CONVENTION)):
+            return True
+        else:
+            return False
+
+    def __get_protected_def(def_or_value):
+        # Returns the definition protected by a gc_protect/maybe_gc_protect call, or None
+        # if the given definition or value is not such a call.
+        value = cfg_ir.get_def_value(def_or_value)
+        if cfg_ir.is_call(
+                value, target_name=cfg_ir.GC_PROTECT_MACRO_NAME,
+                calling_convention=cfg_ir.MACRO_POSITIONAL_CALLING_CONVENTION):
+            # gc_protect arguments: (protected_value, root).
+            _, protected_def = value.argument_list[0]
+            return protected_def
+        elif cfg_ir.is_call(
+                value, target_name=cfg_ir.MAYBE_GC_PROTECT_MACRO_NAME,
+                calling_convention=cfg_ir.MACRO_POSITIONAL_CALLING_CONVENTION):
+            # maybe_gc_protect arguments: (condition, protected_value, root).
+            _, protected_def = value.argument_list[1]
+            return protected_def
+        else:
+            return None
+
+    # Maps a definition to the block in which it was most recently defined or used.
+    def_blocks = {}
+    def __register_def_or_use(definition, block):
+        # `ineligible_defs` is assigned below, before this helper is called for the first time.
+        if definition in def_blocks and def_blocks[definition] != block:
+            # Definition seems to be used across basic blocks.
+            ineligible_defs.add(definition)
+
+        def_blocks[definition] = block
+
+    # Definitions whose GC protections must be kept.
+    ineligible_defs = set()
+    # Maps a protected definition to the gc_protect/maybe_gc_protect definitions that guard it.
+    def_protections = defaultdict(list)
+    for block in cfg_ir.get_all_blocks(entry_point):
+        # Definitions seen since the last definition that may cause a GC.
+        no_gc_defs = set()
+        # Ordinary (non-protection) definitions of this block.
+        block_defs = set()
+        # Maps a definition to the index of the first later definition that may cause a GC.
+        first_gc = {}
+        # Maps a definition to the index of its last user in this block, or to None if the
+        # block's flow instruction depends on it.
+        last_def_uses = {}
+        for i, definition in enumerate(block.definitions):
+            if isinstance(definition.value, cfg_ir.Definition):
+                # Handling definitions of definitions is complicated and they should already have
+                # been expanded at this point. Just mark them as ineligible.
+                ineligible_defs.add(definition)
+                ineligible_defs.add(definition.value)
+                continue
+
+            protected_def = __get_protected_def(definition)
+            if protected_def is not None:
+                # We just ran into a gc_protect/maybe_gc_protect.
+                # Record it, but don't count it as a use of the protected definition.
+                def_protections[protected_def].append(definition)
+                continue
+
+            block_defs.add(definition)
+            __register_def_or_use(definition, block)
+
+            for dependency in definition.get_all_dependencies():
+                __register_def_or_use(dependency, block)
+                last_def_uses[dependency] = i
+
+            if __may_cause_gc(definition):
+                # Every definition seen since the previous GC point gets this definition
+                # as its first GC point.
+                for gc_def in no_gc_defs:
+                    first_gc[gc_def] = i
+                no_gc_defs = set()
+
+            # The current definition is added after the reset, so a definition that may
+            # cause a GC is never its own first GC point.
+            no_gc_defs.add(definition)
+
+        # Mark all branch arguments as ineligible.
+        for branch in block.flow.branches():
+            ineligible_defs.update(branch.arguments)
+
+        # A use by the flow instruction comes after every definition in the block.
+        for dependency in block.flow.get_dependencies():
+            last_def_uses[dependency] = None
+
+        for definition in block_defs:
+            if definition in ineligible_defs:
+                # Definition was already ineligible.
+                continue
+
+            # Mark `definition` as ineligible if there is a GC definition in the range of
+            # definitions (definition, last_def_uses[definition]].
+            if definition in first_gc:
+                if definition in last_def_uses:
+                    last_use = last_def_uses[definition]
+                    if last_use is None or first_gc[definition] <= last_use:
+                        ineligible_defs.add(definition)
+
+    # Elide all GC protections for definitions which are not in the `ineligible_defs` set.
+    for protected, protections in def_protections.items():
+        if protected not in ineligible_defs:
+            for protect_def in protections:
+                # The protection is replaced by a None literal rather than removed here.
+                protect_def.redefine(cfg_ir.Literal(None))
+
 def optimize(entry_point, jit):
 def optimize(entry_point, jit):
     """Optimizes the control-flow graph defined by the given entry point.
     """Optimizes the control-flow graph defined by the given entry point.
        A potentially altered entry point is returned."""
        A potentially altered entry point is returned."""
@@ -382,4 +509,7 @@ def optimize(entry_point, jit):
     expand_indirect_definitions(entry_point)
     expand_indirect_definitions(entry_point)
     eliminate_unused_definitions(entry_point)
     eliminate_unused_definitions(entry_point)
     merge_blocks(entry_point)
     merge_blocks(entry_point)
+    protect_from_gc(entry_point)
+    elide_gc_protects(entry_point)
+    eliminate_unused_definitions(entry_point)
     raise primitive_functions.PrimitiveFinished(entry_point)
     raise primitive_functions.PrimitiveFinished(entry_point)

+ 38 - 25
kernel/modelverse_jit/cfg_to_tree.py

@@ -293,19 +293,7 @@ def find_inlinable_definitions(entry_point):
     """Computes a set of definitions which are eligible for inlining, i.e., they
     """Computes a set of definitions which are eligible for inlining, i.e., they
        are used only once and are consumed in the same basic block in which they
        are used only once and are consumed in the same basic block in which they
        are defined."""
        are defined."""
-    all_blocks = list(cfg_ir.get_all_blocks(entry_point))
-
-    # Find all definition users for each definition.
-    def_users = defaultdict(set)
-    def_blocks = {}
-    for block in all_blocks:
-        for parameter_def in block.parameters:
-            def_blocks[parameter_def] = block
-        for definition in block.definitions + [block.flow]:
-            def_blocks[definition] = block
-            for dependency in definition.get_all_dependencies():
-                if not isinstance(dependency, cfg_ir.Branch):
-                    def_users[dependency].add(definition)
+    def_users, def_blocks = cfg_ir.find_all_def_uses(entry_point)
 
 
     # Find all definitions which are eligible for inlining.
     # Find all definitions which are eligible for inlining.
     eligible_defs = set()
     eligible_defs = set()
@@ -317,6 +305,17 @@ def find_inlinable_definitions(entry_point):
 
 
     return eligible_defs
     return eligible_defs
 
 
+def lower_gc_protect(protected_value, root):
+    """Lowers a GC_PROTECT_MACRO_NAME macro call to an ignored instruction that
+       creates an edge from the root to the protected value."""
+    protecting_edge = tree_ir.CreateEdgeInstruction(root, protected_value)
+    return tree_ir.IgnoreInstruction(protecting_edge)
+
+def lower_maybe_gc_protect(condition, protected_value, root):
+    """Lowers a MAYBE_GC_PROTECT_MACRO_NAME macro call to a select instruction that
+       performs the GC protection if the condition is not None, and is empty otherwise."""
+    is_not_none = tree_ir.BinaryInstruction(
+        condition, 'is not', tree_ir.LiteralInstruction(None))
+    protection = lower_gc_protect(protected_value, root)
+    return tree_ir.SelectInstruction(is_not_none, protection, tree_ir.EmptyInstruction())
+
 class SimpleDefinitionScheduler(object):
 class SimpleDefinitionScheduler(object):
     """Schedules definitions within a basic block in the order they occur in."""
     """Schedules definitions within a basic block in the order they occur in."""
     def __init__(self, lowering_state, definitions, flow):
     def __init__(self, lowering_state, definitions, flow):
@@ -649,15 +648,6 @@ class LoweringState(object):
         self.root_edge_names = {}
         self.root_edge_names = {}
         self.inlinable_definitions = inlinable_definitions
         self.inlinable_definitions = inlinable_definitions
         self.scheduler = None
         self.scheduler = None
-        self.macro_lowerings = {
-            cfg_ir.PRINT_MACRO_NAME: tree_ir.PrintInstruction,
-            cfg_ir.OUTPUT_MACRO_NAME: bytecode_to_tree.create_output,
-            cfg_ir.INPUT_MACRO_NAME: lambda: bytecode_to_tree.create_input(self.jit.use_input_function),
-            cfg_ir.READ_DICT_KEYS_MACRO_NAME: tree_ir.ReadDictionaryKeysInstruction,
-            cfg_ir.REVERSE_LIST_MACRO_NAME:
-            lambda seq:
-            tree_ir.ListSliceInstruction(seq, None, None, tree_ir.LiteralInstruction(-1))
-        }
 
 
     def __get_root_edge_name(self, root_node):
     def __get_root_edge_name(self, root_node):
         """Gets the name of the given root edge's variable."""
         """Gets the name of the given root edge's variable."""
@@ -868,13 +858,25 @@ class LoweringState(object):
     def lower_macro_call(self, value):
     def lower_macro_call(self, value):
         """Expands a macro call."""
         """Expands a macro call."""
         arg_list = [self.use_definition(arg) for _, arg in value.argument_list]
         arg_list = [self.use_definition(arg) for _, arg in value.argument_list]
-        if value.target_name in self.macro_lowerings:
-            return self.macro_lowerings[value.target_name](*arg_list)
+        if value.target_name in LoweringState.macro_lowerings:
+            return LoweringState.macro_lowerings[value.target_name](*arg_list)
         else:
         else:
             raise jit_runtime.JitCompilationFailedException(
             raise jit_runtime.JitCompilationFailedException(
                 "Unknown macro: '%s' in instruction '%s'" %
                 "Unknown macro: '%s' in instruction '%s'" %
                 (value.target_name, value))
                 (value.target_name, value))
 
 
+    def lower_io_call(self, value):
+        """Lowers a call that uses the IO calling convention. Only the 'input' and
+           'output' macros are recognized; any other target makes JIT compilation fail."""
+        # All arguments are lowered up front, even though 'input' does not consume any.
+        lowered_args = [self.use_definition(arg_def) for _, arg_def in value.argument_list]
+        target = value.target_name
+        if target == cfg_ir.INPUT_MACRO_NAME:
+            return bytecode_to_tree.create_input(self.jit.use_input_function)
+        if target == cfg_ir.OUTPUT_MACRO_NAME:
+            return bytecode_to_tree.create_output(*lowered_args)
+
+        raise jit_runtime.JitCompilationFailedException(
+            "Unknown IO macro: '%s' in instruction '%s'" % (target, value))
+
     def lower_jump(self, flow):
     def lower_jump(self, flow):
         """Lowers the given 'jump' flow instruction to a tree."""
         """Lowers the given 'jump' flow instruction to a tree."""
         return self.lower_branch(flow.branch)
         return self.lower_branch(flow.branch)
@@ -969,7 +971,18 @@ class LoweringState(object):
         cfg_ir.SIMPLE_POSITIONAL_CALLING_CONVENTION : lower_simple_positional_call,
         cfg_ir.SIMPLE_POSITIONAL_CALLING_CONVENTION : lower_simple_positional_call,
         cfg_ir.SELF_POSITIONAL_CALLING_CONVENTION : lower_self_positional_call,
         cfg_ir.SELF_POSITIONAL_CALLING_CONVENTION : lower_self_positional_call,
         cfg_ir.JIT_CALLING_CONVENTION : lower_jit_call,
         cfg_ir.JIT_CALLING_CONVENTION : lower_jit_call,
-        cfg_ir.MACRO_POSITIONAL_CALLING_CONVENTION : lower_macro_call
+        cfg_ir.MACRO_POSITIONAL_CALLING_CONVENTION : lower_macro_call,
+        cfg_ir.MACRO_IO_CALLING_CONVENTION : lower_io_call
+    }
+
+    macro_lowerings = {
+        cfg_ir.PRINT_MACRO_NAME: tree_ir.PrintInstruction,
+        cfg_ir.READ_DICT_KEYS_MACRO_NAME: tree_ir.ReadDictionaryKeysInstruction,
+        cfg_ir.REVERSE_LIST_MACRO_NAME:
+        lambda seq:
+        tree_ir.ListSliceInstruction(seq, None, None, tree_ir.LiteralInstruction(-1)),
+        cfg_ir.GC_PROTECT_MACRO_NAME: lower_gc_protect,
+        cfg_ir.MAYBE_GC_PROTECT_MACRO_NAME: lower_maybe_gc_protect
     }
     }
 
 
 def lower_flow_graph(entry_point, jit):
 def lower_flow_graph(entry_point, jit):