瀏覽代碼

Shield JIT temporaries from the garbage collector

jonathanvdc 8 年之前
父節點
當前提交
de10bb6587
共有 3 個文件被更改,包括 265 次插入55 次删除
  1. 4 0
      kernel/modelverse_jit/jit.py
  2. 260 54
      kernel/modelverse_jit/tree_ir.py
  3. 1 1
      kernel/modelverse_kernel/main.py

+ 4 - 0
kernel/modelverse_jit/jit.py

@@ -346,6 +346,10 @@ class ModelverseJit(object):
         # Optimize the function's body.
         constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
 
+        # Shield temporaries from the GC.
+        constructed_body = tree_ir.protect_temporaries_from_gc(
+            constructed_body, tree_ir.LoadLocalInstruction(LOCALS_NODE_NAME))
+
         # Wrap the IR in a function definition, give it a unique name.
         constructed_function = tree_ir.DefineFunctionInstruction(
             function_name,

+ 260 - 54
kernel/modelverse_jit/tree_ir.py

@@ -25,23 +25,69 @@
 # `None` marker value to the server, without terminating the current
 # generator.
 
+# Let's just agree to disagree on map vs list comprehensions, pylint.
+# pylint: disable=I0011,W0141
+
 NOP_LITERAL = None
 """A literal that results in a nop during which execution may be interrupted
    when yielded."""
 
+UNKNOWN_RESULT_TYPE = 'unknown'
+"""The result type for instructions that produce either primitive values or
+   references to nodes."""
+
+PRIMITIVE_RESULT_TYPE = 'primitive'
+"""The result type for instructions that produce primitive values."""
+
+NODE_RESULT_TYPE = 'node'
+"""The result type for instructions that produce references to nodes."""
+
+NO_RESULT_TYPE = 'nothing'
+"""The result type for instructions that no result."""
+
+def result_type_intersection(first_type, second_type):
+    """Computes the intersection of the given result types."""
+    if first_type == second_type:
+        return first_type
+    elif ((first_type == PRIMITIVE_RESULT_TYPE and second_type == UNKNOWN_RESULT_TYPE)
+          or (first_type == UNKNOWN_RESULT_TYPE and second_type == PRIMITIVE_RESULT_TYPE)):
+        return PRIMITIVE_RESULT_TYPE
+    elif ((first_type == NODE_RESULT_TYPE and second_type == UNKNOWN_RESULT_TYPE)
+          or (first_type == UNKNOWN_RESULT_TYPE and second_type == NODE_RESULT_TYPE)):
+        return NODE_RESULT_TYPE
+    else:
+        return NO_RESULT_TYPE
+
+def result_type_union(first_type, second_type):
+    """Computes the union of the given result types."""
+    if first_type == second_type:
+        return first_type
+    elif ((first_type == PRIMITIVE_RESULT_TYPE and second_type == NO_RESULT_TYPE)
+          or (first_type == NO_RESULT_TYPE and second_type == PRIMITIVE_RESULT_TYPE)):
+        return PRIMITIVE_RESULT_TYPE
+    elif ((first_type == NODE_RESULT_TYPE and second_type == NO_RESULT_TYPE)
+          or (first_type == NO_RESULT_TYPE and second_type == NODE_RESULT_TYPE)):
+        return NODE_RESULT_TYPE
+    else:
+        return UNKNOWN_RESULT_TYPE
+
 class Instruction(object):
     """A base class for instructions. An instruction is essentially a syntax
        node that must first be defined, and can only then be used."""
 
     def __init__(self):
-        self.has_result_cache = None
+        self.result_type_cache = None
         self.has_definition_cache = None
 
+    def get_result_type(self):
+        """Gets this instruction's result type."""
+        if self.result_type_cache is None:
+            self.result_type_cache = self.get_result_type_impl()
+        return self.result_type_cache
+
     def has_result(self):
         """Tells if this instruction computes a result."""
-        if self.has_result_cache is None:
-            self.has_result_cache = self.has_result_impl()
-        return self.has_result_cache
+        return self.get_result_type() != NO_RESULT_TYPE
 
     def has_definition(self):
         """Tells if this instruction requires a definition."""
@@ -49,9 +95,9 @@ class Instruction(object):
             self.has_definition_cache = self.has_definition_impl()
         return self.has_definition_cache
 
-    def has_result_impl(self):
-        """Tells if this instruction computes a result."""
-        return True
+    def get_result_type_impl(self):
+        """Gets this instruction's result type."""
+        return PRIMITIVE_RESULT_TYPE
 
     def has_definition_impl(self):
         """Tells if this instruction requires a definition."""
@@ -219,9 +265,9 @@ class PythonGenerator(object):
 class VoidInstruction(Instruction):
     """A base class for instructions that do not return a value."""
 
-    def has_result_impl(self):
-        """Tells if this instruction computes a result."""
-        return False
+    def get_result_type_impl(self):
+        """Gets this instruction's result type."""
+        return NO_RESULT_TYPE
 
     def get_children(self):
         """Gets this instruction's sequence of child instructions."""
@@ -247,9 +293,11 @@ class SelectInstruction(Instruction):
         self.if_clause = if_clause
         self.else_clause = else_clause
 
-    def has_result_impl(self):
-        """Tells if this instruction computes a result."""
-        return self.if_clause.has_result() or self.else_clause.has_result()
+    def get_result_type_impl(self):
+        """Gets this instruction's result type."""
+        return result_type_intersection(
+            self.if_clause.get_result_type(),
+            self.else_clause.get_result_type())
 
     def simplify_node(self):
         """Applies basic simplification to this instruction only."""
@@ -390,16 +438,12 @@ class CallInstruction(Instruction):
                 self.target.generate_python_use(code_generator),
                 ', '.join([arg.generate_python_use(code_generator) for arg in self.argument_list])))
 
-class PrintInstruction(Instruction):
+class PrintInstruction(VoidInstruction):
     """An instruction that prints a value."""
     def __init__(self, argument):
-        Instruction.__init__(self)
+        VoidInstruction.__init__(self)
         self.argument = argument
 
-    def has_result_impl(self):
-        """Tells if this instruction has a result."""
-        return False
-
     def get_children(self):
         """Gets this instruction's sequence of child instructions."""
         return [self.argument]
@@ -555,9 +599,12 @@ class CompoundInstruction(Instruction):
         self.first = first
         self.second = second
 
-    def has_result_impl(self):
-        """Tells if this instruction has a result."""
-        return self.second.has_result() or self.first.has_result()
+    def get_result_type_impl(self):
+        """Gets this instruction's result type."""
+        if self.second.has_result():
+            return self.second.get_result_type()
+        else:
+            return self.first.get_result_type()
 
     def get_children(self):
         """Gets this instruction's sequence of child instructions."""
@@ -673,6 +720,10 @@ class DictionaryLiteralInstruction(Instruction):
 
 class StateInstruction(Instruction):
     """An instruction that accesses the modelverse state."""
+    def get_result_type_impl(self):
+        """Gets the type of value produced by this instruction."""
+        return NODE_RESULT_TYPE
+
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
         raise NotImplementedError()
@@ -702,10 +753,11 @@ class StateInstruction(Instruction):
 
 class RunGeneratorFunctionInstruction(StateInstruction):
     """An instruction that runs a generator function."""
-    def __init__(self, function, argument_dict):
+    def __init__(self, function, argument_dict, result_type=PRIMITIVE_RESULT_TYPE):
         StateInstruction.__init__(self)
         self.function = function
         self.argument_dict = argument_dict
+        self.result_type_cache = result_type
 
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
@@ -715,6 +767,11 @@ class RunGeneratorFunctionInstruction(StateInstruction):
         """Gets this state instruction's argument list."""
         return [self.function, self.argument_dict]
 
+    def create(self, new_children):
+        """Creates a new instruction of this type from the given sequence of child instructions."""
+        func, arg_dict = new_children
+        return RunGeneratorFunctionInstruction(func, arg_dict, self.get_result_type())
+
 class RunTailGeneratorFunctionInstruction(StateInstruction):
     """An instruction that runs a generator function."""
     def __init__(self, function, argument_dict):
@@ -722,6 +779,10 @@ class RunTailGeneratorFunctionInstruction(StateInstruction):
         self.function = function
         self.argument_dict = argument_dict
 
+    def get_result_type_impl(self):
+        """Gets the type of value produced by this instruction."""
+        return NO_RESULT_TYPE
+
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
         return "TAIL_CALL_KWARGS"
@@ -940,10 +1001,10 @@ class LoadMemberInstruction(Instruction):
             self.container.generate_python_use(code_generator),
             self.member_name)
 
-class StoreMemberInstruction(Instruction):
+class StoreMemberInstruction(VoidInstruction):
     """An instruction that stores a value in a container member."""
     def __init__(self, container, member_name, value):
-        Instruction.__init__(self)
+        VoidInstruction.__init__(self)
         self.container = container
         self.member_name = member_name
         self.value = value
@@ -952,10 +1013,6 @@ class StoreMemberInstruction(Instruction):
         """Tells if this instruction requires a definition."""
         return True
 
-    def has_result_impl(self):
-        """Tells if this instruction computes a result."""
-        return False
-
     def get_children(self):
         """Gets this instruction's sequence of child instructions."""
         return [self.container, self.value]
@@ -975,20 +1032,8 @@ class StoreMemberInstruction(Instruction):
             self.member_name,
             self.value.generate_python_use(code_generator)))
 
-class NopInstruction(Instruction):
+class NopInstruction(VoidInstruction):
     """A nop instruction, which allows for the kernel's thread of execution to be interrupted."""
-    def has_result_impl(self):
-        """Tells if this instruction computes a result."""
-        return False
-
-    def get_children(self):
-        """Gets this instruction's sequence of child instructions."""
-        return []
-
-    def create(self, new_children):
-        """Creates a new instruction of this type from the given sequence of child instructions."""
-        return self
-
     def generate_python_def(self, code_generator):
         """Generates a Python statement that executes this instruction.
            The statement is appended immediately to the code generator."""
@@ -1000,6 +1045,10 @@ class ReadValueInstruction(StateInstruction):
         StateInstruction.__init__(self)
         self.node_id = node_id
 
+    def get_result_type_impl(self):
+        """Gets the type of value produced by this instruction."""
+        return PRIMITIVE_RESULT_TYPE
+
     def simplify_node(self):
         """Applies basic simplification to this instruction only."""
         if isinstance(self.node_id, CreateNodeWithValueInstruction):
@@ -1065,6 +1114,10 @@ class ReadOutgoingEdgesInstruction(StateInstruction):
         StateInstruction.__init__(self)
         self.node_id = node_id
 
+    def get_result_type_impl(self):
+        """Gets the type of value produced by this instruction."""
+        return PRIMITIVE_RESULT_TYPE
+
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
         return "RO"
@@ -1079,6 +1132,10 @@ class ReadIncomingEdgesInstruction(StateInstruction):
         StateInstruction.__init__(self)
         self.node_id = node_id
 
+    def get_result_type_impl(self):
+        """Gets the type of value produced by this instruction."""
+        return PRIMITIVE_RESULT_TYPE
+
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
         return "RI"
@@ -1089,7 +1146,6 @@ class ReadIncomingEdgesInstruction(StateInstruction):
 
 class CreateNodeInstruction(StateInstruction):
     """An instruction that creates an empty node."""
-
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
         return "CN"
@@ -1149,9 +1205,9 @@ class DeleteNodeInstruction(StateInstruction):
         StateInstruction.__init__(self)
         self.node_id = node_id
 
-    def has_result(self):
-        """Tells if this instruction computes a result."""
-        return False
+    def get_result_type_impl(self):
+        """Gets the type of value produced by this instruction."""
+        return NO_RESULT_TYPE
 
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
@@ -1167,9 +1223,9 @@ class DeleteEdgeInstruction(StateInstruction):
         StateInstruction.__init__(self)
         self.edge_id = edge_id
 
-    def has_result(self):
-        """Tells if this instruction computes a result."""
-        return False
+    def get_result_type_impl(self):
+        """Gets the type of value produced by this instruction."""
+        return NO_RESULT_TYPE
 
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
@@ -1218,7 +1274,7 @@ def create_jit_call(target, named_arguments, kwargs):
     return CompoundInstruction(
         create_block(*results),
         RunGeneratorFunctionInstruction(
-            target, arg_dict.create_load()))
+            target, arg_dict.create_load(), NODE_RESULT_TYPE))
 
 def create_new_local_node(local_variable, connected_node, edge_variable=None):
     """Creates a local node that is the backing storage for a local variable.
@@ -1226,7 +1282,7 @@ def create_new_local_node(local_variable, connected_node, edge_variable=None):
        as dead by the GC. The newly created node is stored in the given
        local variable. The edge's id can also optionally be stored in a variable."""
     local_store = StoreLocalInstruction(local_variable, CreateNodeInstruction())
-    create_edge = CreateEdgeInstruction(local_store.create_load(), connected_node)
+    create_edge = CreateEdgeInstruction(connected_node, local_store.create_load())
     if edge_variable is not None:
         create_edge = StoreLocalInstruction(edge_variable, create_edge)
 
@@ -1247,17 +1303,167 @@ def with_debug_info_trace(instruction, debug_info, function_name):
                 LiteralInstruction('TRACE: %s(%s, JIT)' % (debug_info, function_name))),
             instruction)
 
+def map_instruction_tree_top_down(function, instruction):
+    """Applies the given mapping function to every instruction in the tree
+       that has the given instruction as root. The map is applied in a top-down
+       fashion."""
+    mapped_instruction = function(instruction)
+    return mapped_instruction.create(
+        [map_instruction_tree_top_down(function, child)
+         for child in mapped_instruction.get_children()])
+
 def map_and_simplify(function, instruction):
     """Applies the given mapping function to every instruction in the tree
        that has the given instruction as root, and simplifies it on-the-fly.
+       The map is applied in a bottom-up fashion.
 
        This is at least as powerful as first mapping and then simplifying, as
        maps and simplifications are interspersed."""
-    # Let's just agree to disagree on map vs list comprehensions, pylint.
-    # pylint: disable=I0011,W0141
     return function(
         instruction.create(
-            map(map_and_simplify, instruction.get_children()))).simplify_node()
+            [map_and_simplify(function, child)
+             for child in instruction.get_children()])).simplify_node()
+
+def iterate_as_stack(instruction, stack_iterator):
+    """Iterates over the given instruction and its children in the order in which temporaries are
+    'pushed' on and 'popped' from a virtual evaluation stack."""
+    if isinstance(instruction, SelectInstruction):
+        iterate_as_stack(instruction.condition, stack_iterator)
+        stack_iterator.pop()
+        if_iterator, else_iterator = stack_iterator.branch(), stack_iterator.branch()
+        iterate_as_stack(instruction.if_clause, if_iterator)
+        iterate_as_stack(instruction.else_clause, else_iterator)
+        stack_iterator.merge(if_iterator, else_iterator)
+    elif isinstance(instruction, CompoundInstruction):
+        iterate_as_stack(instruction.first, stack_iterator)
+        if instruction.second.has_result():
+            stack_iterator.pop()
+        iterate_as_stack(instruction.second, stack_iterator)
+        if not instruction.second.has_result():
+            stack_iterator.pop()
+    else:
+        children = instruction.get_children()
+        for child in children:
+            # Push all children onto the stack.
+            iterate_as_stack(child, stack_iterator)
+        for child in children:
+            # Pop all children from the stack.
+            stack_iterator.pop()
+        # Push the instruction.
+        stack_iterator.push(instruction)
+
+class StackIterator(object):
+    """A base class for stack iterators."""
+    def __init__(self, stack=None):
+        self.stack = [] if stack is None else stack
+
+    def pop(self):
+        """Pops an instruction from the stack."""
+        self.stack.pop()
+
+    def push(self, instruction):
+        """Pushes an instruction onto the stack."""
+        self.stack.append(set([instruction]))
+
+    def branch(self):
+        """Creates a copy of this stack iterator."""
+        return self.create(self.copy_stack())
+
+    def merge(self, *branches):
+        """Sets this stack iterator's stack to the union of the given branches."""
+        self.__init__(self.merge_stacks(*branches))
+
+    def copy_stack(self):
+        """Creates a copy of this stack iterator's stack."""
+        return [set(vals) for vals in self.stack]
+
+    def create(self, new_stack):
+        """Creates a stack iterator from the given stack"""
+        return type(self)(new_stack)
+
+    def merge_stacks(self, *branches):
+        """Computes the union of the stacks of the given branches."""
+        results = None
+        for branch in branches:
+            if results is None:
+                results = branch.copy_stack()
+            else:
+                assert len(branch.stack) == len(results)
+                results = [set.union(*t) for t in zip(branch.stack, results)]
+
+        return results
+
+def protect_temporaries_from_gc(instruction, connected_node):
+    """Protects temporaries from the garbage collector by connecting them to the given node."""
+    # # The reasoning behind this function
+    #
+    # A nop instruction (`yield None`) may trigger the garbage collector, which will delete
+    # unreachable ("dead") vertices and edges. Now take into account that a bare temporary node
+    # is actually unreachable from the root node. The consequence is that temporary nodes
+    # may be garbage-collected if a nop instruction is executed while they are on the evaluation
+    # "stack." This is _never_ what we want.
+    #
+    # To counter this, we can connect temporary nodes to a node that is reachable from the root.
+    # However, we only want to create edges between edges and a known reachable node if we really
+    # have to, because creating edges incurs some overhead.
+    #
+    # We will create an edge between a temporary and the known reachable node if and only if the
+    # temporary is on the "stack" when either a nop or a call instruction is executed.
+
+    class GCStackIterator(StackIterator):
+        """A stack iterator that detects which instructions might be at risk of getting garbage
+           collected."""
+        def __init__(self, stack=None, gc_temporaries=None):
+            StackIterator.__init__(self, stack)
+            self.gc_temporaries = set() if gc_temporaries is None else gc_temporaries
+
+        def push(self, instruction):
+            """Pushes an instruction onto the stack."""
+            if isinstance(instruction, (
+                    NopInstruction,
+                    RunGeneratorFunctionInstruction,
+                    RunTailGeneratorFunctionInstruction)):
+                # All values currently on the stack are at risk. Mark them as such.
+                for instruction_set in self.stack:
+                    self.gc_temporaries.update(instruction_set)
+
+            # Proceed.
+            StackIterator.push(self, instruction)
+
+        def merge(self, *branches):
+            """Sets this stack iterator's stack to the union of the given branches."""
+            self.__init__(
+                self.merge_stacks(*branches),
+                set.union(*[br.gc_temporaries for br in branches]))
+
+        def create(self, new_stack):
+            """Creates a stack iterator from the given stack"""
+            return GCStackIterator(new_stack, self.gc_temporaries)
+
+    # Find out which instructions are at risk.
+    gc_iterator = GCStackIterator()
+    iterate_as_stack(instruction, gc_iterator)
+    # These temporaries need to be protected from the GC.
+    gc_temporaries = gc_iterator.gc_temporaries
+
+    def protect_result(instruction):
+        """Protects the given instruction's (temporary) result."""
+        if instruction in gc_temporaries and instruction.get_result_type() == NODE_RESULT_TYPE:
+            gc_temporaries.remove(instruction)
+            store_instr = StoreLocalInstruction(None, instruction)
+            return CompoundInstruction(
+                store_instr,
+                CompoundInstruction(
+                    SelectInstruction(
+                        BinaryInstruction(
+                            store_instr.create_load(), 'is not', LiteralInstruction(None)),
+                        CreateEdgeInstruction(connected_node, store_instr.create_load()),
+                        EmptyInstruction()),
+                    store_instr.create_load()))
+        else:
+            return instruction
+
+    return map_instruction_tree_top_down(protect_result, instruction)
 
 if __name__ == "__main__":
     example_tree = SelectInstruction(

+ 1 - 1
kernel/modelverse_kernel/main.py

@@ -60,7 +60,7 @@ class ModelverseKernel(object):
         # To make the JIT compile 'input' instructions as calls to
         # modelverse_jit.runtime.get_input, uncomment the line below:
         #
-        #     self.jit.use_input_function()
+        self.jit.use_input_function()
         #
 
         self.debug_info = defaultdict(list)