Browse Source

Shield JIT temporaries from the garbage collector

jonathanvdc 8 years ago
parent
commit
c50d53ad80
3 changed files with 265 additions and 55 deletions
  1. 4 0
      kernel/modelverse_jit/jit.py
  2. 260 54
      kernel/modelverse_jit/tree_ir.py
  3. 1 1
      kernel/modelverse_kernel/main.py

+ 4 - 0
kernel/modelverse_jit/jit.py

@@ -346,6 +346,10 @@ class ModelverseJit(object):
         # Optimize the function's body.
         # Optimize the function's body.
         constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
         constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
 
 
+        # Shield temporaries from the GC.
+        constructed_body = tree_ir.protect_temporaries_from_gc(
+            constructed_body, tree_ir.LoadLocalInstruction(LOCALS_NODE_NAME))
+
         # Wrap the IR in a function definition, give it a unique name.
         # Wrap the IR in a function definition, give it a unique name.
         constructed_function = tree_ir.DefineFunctionInstruction(
         constructed_function = tree_ir.DefineFunctionInstruction(
             function_name,
             function_name,

+ 260 - 54
kernel/modelverse_jit/tree_ir.py

@@ -25,23 +25,69 @@
 # `None` marker value to the server, without terminating the current
 # `None` marker value to the server, without terminating the current
 # generator.
 # generator.
 
 
+# Let's just agree to disagree on map vs list comprehensions, pylint.
+# pylint: disable=I0011,W0141
+
 NOP_LITERAL = None
 NOP_LITERAL = None
 """A literal that results in a nop during which execution may be interrupted
 """A literal that results in a nop during which execution may be interrupted
    when yielded."""
    when yielded."""
 
 
+UNKNOWN_RESULT_TYPE = 'unknown'
+"""The result type for instructions that produce either primitive values or
+   references to nodes."""
+
+PRIMITIVE_RESULT_TYPE = 'primitive'
+"""The result type for instructions that produce primitive values."""
+
+NODE_RESULT_TYPE = 'node'
+"""The result type for instructions that produce references to nodes."""
+
+NO_RESULT_TYPE = 'nothing'
+"""The result type for instructions that no result."""
+
+def result_type_intersection(first_type, second_type):
+    """Computes the intersection of the given result types."""
+    if first_type == second_type:
+        return first_type
+    elif ((first_type == PRIMITIVE_RESULT_TYPE and second_type == UNKNOWN_RESULT_TYPE)
+          or (first_type == UNKNOWN_RESULT_TYPE and second_type == PRIMITIVE_RESULT_TYPE)):
+        return PRIMITIVE_RESULT_TYPE
+    elif ((first_type == NODE_RESULT_TYPE and second_type == UNKNOWN_RESULT_TYPE)
+          or (first_type == UNKNOWN_RESULT_TYPE and second_type == NODE_RESULT_TYPE)):
+        return NODE_RESULT_TYPE
+    else:
+        return NO_RESULT_TYPE
+
+def result_type_union(first_type, second_type):
+    """Computes the union of the given result types."""
+    if first_type == second_type:
+        return first_type
+    elif ((first_type == PRIMITIVE_RESULT_TYPE and second_type == NO_RESULT_TYPE)
+          or (first_type == NO_RESULT_TYPE and second_type == PRIMITIVE_RESULT_TYPE)):
+        return PRIMITIVE_RESULT_TYPE
+    elif ((first_type == NODE_RESULT_TYPE and second_type == NO_RESULT_TYPE)
+          or (first_type == NO_RESULT_TYPE and second_type == NODE_RESULT_TYPE)):
+        return NODE_RESULT_TYPE
+    else:
+        return UNKNOWN_RESULT_TYPE
+
 class Instruction(object):
 class Instruction(object):
     """A base class for instructions. An instruction is essentially a syntax
     """A base class for instructions. An instruction is essentially a syntax
        node that must first be defined, and can only then be used."""
        node that must first be defined, and can only then be used."""
 
 
     def __init__(self):
     def __init__(self):
-        self.has_result_cache = None
+        self.result_type_cache = None
         self.has_definition_cache = None
         self.has_definition_cache = None
 
 
+    def get_result_type(self):
+        """Gets this instruction's result type."""
+        if self.result_type_cache is None:
+            self.result_type_cache = self.get_result_type_impl()
+        return self.result_type_cache
+
     def has_result(self):
     def has_result(self):
         """Tells if this instruction computes a result."""
         """Tells if this instruction computes a result."""
-        if self.has_result_cache is None:
-            self.has_result_cache = self.has_result_impl()
-        return self.has_result_cache
+        return self.get_result_type() != NO_RESULT_TYPE
 
 
     def has_definition(self):
     def has_definition(self):
         """Tells if this instruction requires a definition."""
         """Tells if this instruction requires a definition."""
@@ -49,9 +95,9 @@ class Instruction(object):
             self.has_definition_cache = self.has_definition_impl()
             self.has_definition_cache = self.has_definition_impl()
         return self.has_definition_cache
         return self.has_definition_cache
 
 
-    def has_result_impl(self):
-        """Tells if this instruction computes a result."""
-        return True
+    def get_result_type_impl(self):
+        """Gets this instruction's result type."""
+        return PRIMITIVE_RESULT_TYPE
 
 
     def has_definition_impl(self):
     def has_definition_impl(self):
         """Tells if this instruction requires a definition."""
         """Tells if this instruction requires a definition."""
@@ -219,9 +265,9 @@ class PythonGenerator(object):
 class VoidInstruction(Instruction):
 class VoidInstruction(Instruction):
     """A base class for instructions that do not return a value."""
     """A base class for instructions that do not return a value."""
 
 
-    def has_result_impl(self):
-        """Tells if this instruction computes a result."""
-        return False
+    def get_result_type_impl(self):
+        """Gets this instruction's result type."""
+        return NO_RESULT_TYPE
 
 
     def get_children(self):
     def get_children(self):
         """Gets this instruction's sequence of child instructions."""
         """Gets this instruction's sequence of child instructions."""
@@ -247,9 +293,11 @@ class SelectInstruction(Instruction):
         self.if_clause = if_clause
         self.if_clause = if_clause
         self.else_clause = else_clause
         self.else_clause = else_clause
 
 
-    def has_result_impl(self):
-        """Tells if this instruction computes a result."""
-        return self.if_clause.has_result() or self.else_clause.has_result()
+    def get_result_type_impl(self):
+        """Gets this instruction's result type."""
+        return result_type_intersection(
+            self.if_clause.get_result_type(),
+            self.else_clause.get_result_type())
 
 
     def simplify_node(self):
     def simplify_node(self):
         """Applies basic simplification to this instruction only."""
         """Applies basic simplification to this instruction only."""
@@ -390,16 +438,12 @@ class CallInstruction(Instruction):
                 self.target.generate_python_use(code_generator),
                 self.target.generate_python_use(code_generator),
                 ', '.join([arg.generate_python_use(code_generator) for arg in self.argument_list])))
                 ', '.join([arg.generate_python_use(code_generator) for arg in self.argument_list])))
 
 
-class PrintInstruction(Instruction):
+class PrintInstruction(VoidInstruction):
     """An instruction that prints a value."""
     """An instruction that prints a value."""
     def __init__(self, argument):
     def __init__(self, argument):
-        Instruction.__init__(self)
+        VoidInstruction.__init__(self)
         self.argument = argument
         self.argument = argument
 
 
-    def has_result_impl(self):
-        """Tells if this instruction has a result."""
-        return False
-
     def get_children(self):
     def get_children(self):
         """Gets this instruction's sequence of child instructions."""
         """Gets this instruction's sequence of child instructions."""
         return [self.argument]
         return [self.argument]
@@ -555,9 +599,12 @@ class CompoundInstruction(Instruction):
         self.first = first
         self.first = first
         self.second = second
         self.second = second
 
 
-    def has_result_impl(self):
-        """Tells if this instruction has a result."""
-        return self.second.has_result() or self.first.has_result()
+    def get_result_type_impl(self):
+        """Gets this instruction's result type."""
+        if self.second.has_result():
+            return self.second.get_result_type()
+        else:
+            return self.first.get_result_type()
 
 
     def get_children(self):
     def get_children(self):
         """Gets this instruction's sequence of child instructions."""
         """Gets this instruction's sequence of child instructions."""
@@ -673,6 +720,10 @@ class DictionaryLiteralInstruction(Instruction):
 
 
 class StateInstruction(Instruction):
 class StateInstruction(Instruction):
     """An instruction that accesses the modelverse state."""
     """An instruction that accesses the modelverse state."""
+    def get_result_type_impl(self):
+        """Gets the type of value produced by this instruction."""
+        return NODE_RESULT_TYPE
+
     def get_opcode(self):
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
         """Gets the opcode for this state instruction."""
         raise NotImplementedError()
         raise NotImplementedError()
@@ -702,10 +753,11 @@ class StateInstruction(Instruction):
 
 
 class RunGeneratorFunctionInstruction(StateInstruction):
 class RunGeneratorFunctionInstruction(StateInstruction):
     """An instruction that runs a generator function."""
     """An instruction that runs a generator function."""
-    def __init__(self, function, argument_dict):
+    def __init__(self, function, argument_dict, result_type=PRIMITIVE_RESULT_TYPE):
         StateInstruction.__init__(self)
         StateInstruction.__init__(self)
         self.function = function
         self.function = function
         self.argument_dict = argument_dict
         self.argument_dict = argument_dict
+        self.result_type_cache = result_type
 
 
     def get_opcode(self):
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
         """Gets the opcode for this state instruction."""
@@ -715,6 +767,11 @@ class RunGeneratorFunctionInstruction(StateInstruction):
         """Gets this state instruction's argument list."""
         """Gets this state instruction's argument list."""
         return [self.function, self.argument_dict]
         return [self.function, self.argument_dict]
 
 
+    def create(self, new_children):
+        """Creates a new instruction of this type from the given sequence of child instructions."""
+        func, arg_dict = new_children
+        return RunGeneratorFunctionInstruction(func, arg_dict, self.get_result_type())
+
 class RunTailGeneratorFunctionInstruction(StateInstruction):
 class RunTailGeneratorFunctionInstruction(StateInstruction):
     """An instruction that runs a generator function."""
     """An instruction that runs a generator function."""
     def __init__(self, function, argument_dict):
     def __init__(self, function, argument_dict):
@@ -722,6 +779,10 @@ class RunTailGeneratorFunctionInstruction(StateInstruction):
         self.function = function
         self.function = function
         self.argument_dict = argument_dict
         self.argument_dict = argument_dict
 
 
+    def get_result_type_impl(self):
+        """Gets the type of value produced by this instruction."""
+        return NO_RESULT_TYPE
+
     def get_opcode(self):
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
         """Gets the opcode for this state instruction."""
         return "TAIL_CALL_KWARGS"
         return "TAIL_CALL_KWARGS"
@@ -940,10 +1001,10 @@ class LoadMemberInstruction(Instruction):
             self.container.generate_python_use(code_generator),
             self.container.generate_python_use(code_generator),
             self.member_name)
             self.member_name)
 
 
-class StoreMemberInstruction(Instruction):
+class StoreMemberInstruction(VoidInstruction):
     """An instruction that stores a value in a container member."""
     """An instruction that stores a value in a container member."""
     def __init__(self, container, member_name, value):
     def __init__(self, container, member_name, value):
-        Instruction.__init__(self)
+        VoidInstruction.__init__(self)
         self.container = container
         self.container = container
         self.member_name = member_name
         self.member_name = member_name
         self.value = value
         self.value = value
@@ -952,10 +1013,6 @@ class StoreMemberInstruction(Instruction):
         """Tells if this instruction requires a definition."""
         """Tells if this instruction requires a definition."""
         return True
         return True
 
 
-    def has_result_impl(self):
-        """Tells if this instruction computes a result."""
-        return False
-
     def get_children(self):
     def get_children(self):
         """Gets this instruction's sequence of child instructions."""
         """Gets this instruction's sequence of child instructions."""
         return [self.container, self.value]
         return [self.container, self.value]
@@ -975,20 +1032,8 @@ class StoreMemberInstruction(Instruction):
             self.member_name,
             self.member_name,
             self.value.generate_python_use(code_generator)))
             self.value.generate_python_use(code_generator)))
 
 
-class NopInstruction(Instruction):
+class NopInstruction(VoidInstruction):
     """A nop instruction, which allows for the kernel's thread of execution to be interrupted."""
     """A nop instruction, which allows for the kernel's thread of execution to be interrupted."""
-    def has_result_impl(self):
-        """Tells if this instruction computes a result."""
-        return False
-
-    def get_children(self):
-        """Gets this instruction's sequence of child instructions."""
-        return []
-
-    def create(self, new_children):
-        """Creates a new instruction of this type from the given sequence of child instructions."""
-        return self
-
     def generate_python_def(self, code_generator):
     def generate_python_def(self, code_generator):
         """Generates a Python statement that executes this instruction.
         """Generates a Python statement that executes this instruction.
            The statement is appended immediately to the code generator."""
            The statement is appended immediately to the code generator."""
@@ -1000,6 +1045,10 @@ class ReadValueInstruction(StateInstruction):
         StateInstruction.__init__(self)
         StateInstruction.__init__(self)
         self.node_id = node_id
         self.node_id = node_id
 
 
+    def get_result_type_impl(self):
+        """Gets the type of value produced by this instruction."""
+        return PRIMITIVE_RESULT_TYPE
+
     def simplify_node(self):
     def simplify_node(self):
         """Applies basic simplification to this instruction only."""
         """Applies basic simplification to this instruction only."""
         if isinstance(self.node_id, CreateNodeWithValueInstruction):
         if isinstance(self.node_id, CreateNodeWithValueInstruction):
@@ -1065,6 +1114,10 @@ class ReadOutgoingEdgesInstruction(StateInstruction):
         StateInstruction.__init__(self)
         StateInstruction.__init__(self)
         self.node_id = node_id
         self.node_id = node_id
 
 
+    def get_result_type_impl(self):
+        """Gets the type of value produced by this instruction."""
+        return PRIMITIVE_RESULT_TYPE
+
     def get_opcode(self):
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
         """Gets the opcode for this state instruction."""
         return "RO"
         return "RO"
@@ -1079,6 +1132,10 @@ class ReadIncomingEdgesInstruction(StateInstruction):
         StateInstruction.__init__(self)
         StateInstruction.__init__(self)
         self.node_id = node_id
         self.node_id = node_id
 
 
+    def get_result_type_impl(self):
+        """Gets the type of value produced by this instruction."""
+        return PRIMITIVE_RESULT_TYPE
+
     def get_opcode(self):
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
         """Gets the opcode for this state instruction."""
         return "RI"
         return "RI"
@@ -1089,7 +1146,6 @@ class ReadIncomingEdgesInstruction(StateInstruction):
 
 
 class CreateNodeInstruction(StateInstruction):
 class CreateNodeInstruction(StateInstruction):
     """An instruction that creates an empty node."""
     """An instruction that creates an empty node."""
-
     def get_opcode(self):
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
         """Gets the opcode for this state instruction."""
         return "CN"
         return "CN"
@@ -1149,9 +1205,9 @@ class DeleteNodeInstruction(StateInstruction):
         StateInstruction.__init__(self)
         StateInstruction.__init__(self)
         self.node_id = node_id
         self.node_id = node_id
 
 
-    def has_result(self):
-        """Tells if this instruction computes a result."""
-        return False
+    def get_result_type_impl(self):
+        """Gets the type of value produced by this instruction."""
+        return NO_RESULT_TYPE
 
 
     def get_opcode(self):
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
         """Gets the opcode for this state instruction."""
@@ -1167,9 +1223,9 @@ class DeleteEdgeInstruction(StateInstruction):
         StateInstruction.__init__(self)
         StateInstruction.__init__(self)
         self.edge_id = edge_id
         self.edge_id = edge_id
 
 
-    def has_result(self):
-        """Tells if this instruction computes a result."""
-        return False
+    def get_result_type_impl(self):
+        """Gets the type of value produced by this instruction."""
+        return NO_RESULT_TYPE
 
 
     def get_opcode(self):
     def get_opcode(self):
         """Gets the opcode for this state instruction."""
         """Gets the opcode for this state instruction."""
@@ -1218,7 +1274,7 @@ def create_jit_call(target, named_arguments, kwargs):
     return CompoundInstruction(
     return CompoundInstruction(
         create_block(*results),
         create_block(*results),
         RunGeneratorFunctionInstruction(
         RunGeneratorFunctionInstruction(
-            target, arg_dict.create_load()))
+            target, arg_dict.create_load(), NODE_RESULT_TYPE))
 
 
 def create_new_local_node(local_variable, connected_node, edge_variable=None):
 def create_new_local_node(local_variable, connected_node, edge_variable=None):
     """Creates a local node that is the backing storage for a local variable.
     """Creates a local node that is the backing storage for a local variable.
@@ -1226,7 +1282,7 @@ def create_new_local_node(local_variable, connected_node, edge_variable=None):
        as dead by the GC. The newly created node is stored in the given
        as dead by the GC. The newly created node is stored in the given
        local variable. The edge's id can also optionally be stored in a variable."""
        local variable. The edge's id can also optionally be stored in a variable."""
     local_store = StoreLocalInstruction(local_variable, CreateNodeInstruction())
     local_store = StoreLocalInstruction(local_variable, CreateNodeInstruction())
-    create_edge = CreateEdgeInstruction(local_store.create_load(), connected_node)
+    create_edge = CreateEdgeInstruction(connected_node, local_store.create_load())
     if edge_variable is not None:
     if edge_variable is not None:
         create_edge = StoreLocalInstruction(edge_variable, create_edge)
         create_edge = StoreLocalInstruction(edge_variable, create_edge)
 
 
@@ -1247,17 +1303,167 @@ def with_debug_info_trace(instruction, debug_info, function_name):
                 LiteralInstruction('TRACE: %s(%s, JIT)' % (debug_info, function_name))),
                 LiteralInstruction('TRACE: %s(%s, JIT)' % (debug_info, function_name))),
             instruction)
             instruction)
 
 
+def map_instruction_tree_top_down(function, instruction):
+    """Applies the given mapping function to every instruction in the tree
+       that has the given instruction as root. The map is applied in a top-down
+       fashion."""
+    mapped_instruction = function(instruction)
+    return mapped_instruction.create(
+        [map_instruction_tree_top_down(function, child)
+         for child in mapped_instruction.get_children()])
+
 def map_and_simplify(function, instruction):
 def map_and_simplify(function, instruction):
     """Applies the given mapping function to every instruction in the tree
     """Applies the given mapping function to every instruction in the tree
        that has the given instruction as root, and simplifies it on-the-fly.
        that has the given instruction as root, and simplifies it on-the-fly.
+       The map is applied in a bottom-up fashion.
 
 
        This is at least as powerful as first mapping and then simplifying, as
        This is at least as powerful as first mapping and then simplifying, as
        maps and simplifications are interspersed."""
        maps and simplifications are interspersed."""
-    # Let's just agree to disagree on map vs list comprehensions, pylint.
-    # pylint: disable=I0011,W0141
     return function(
     return function(
         instruction.create(
         instruction.create(
-            map(map_and_simplify, instruction.get_children()))).simplify_node()
+            [map_and_simplify(function, child)
+             for child in instruction.get_children()])).simplify_node()
+
+def iterate_as_stack(instruction, stack_iterator):
+    """Iterates over the given instruction and its children in the order in which temporaries are
+    'pushed' on and 'popped' from a virtual evaluation stack."""
+    if isinstance(instruction, SelectInstruction):
+        iterate_as_stack(instruction.condition, stack_iterator)
+        stack_iterator.pop()
+        if_iterator, else_iterator = stack_iterator.branch(), stack_iterator.branch()
+        iterate_as_stack(instruction.if_clause, if_iterator)
+        iterate_as_stack(instruction.else_clause, else_iterator)
+        stack_iterator.merge(if_iterator, else_iterator)
+    elif isinstance(instruction, CompoundInstruction):
+        iterate_as_stack(instruction.first, stack_iterator)
+        if instruction.second.has_result():
+            stack_iterator.pop()
+        iterate_as_stack(instruction.second, stack_iterator)
+        if not instruction.second.has_result():
+            stack_iterator.pop()
+    else:
+        children = instruction.get_children()
+        for child in children:
+            # Push all children onto the stack.
+            iterate_as_stack(child, stack_iterator)
+        for child in children:
+            # Pop all children from the stack.
+            stack_iterator.pop()
+        # Push the instruction.
+        stack_iterator.push(instruction)
+
+class StackIterator(object):
+    """A base class for stack iterators."""
+    def __init__(self, stack=None):
+        self.stack = [] if stack is None else stack
+
+    def pop(self):
+        """Pops an instruction from the stack."""
+        self.stack.pop()
+
+    def push(self, instruction):
+        """Pushes an instruction onto the stack."""
+        self.stack.append(set([instruction]))
+
+    def branch(self):
+        """Creates a copy of this stack iterator."""
+        return self.create(self.copy_stack())
+
+    def merge(self, *branches):
+        """Sets this stack iterator's stack to the union of the given branches."""
+        self.__init__(self.merge_stacks(*branches))
+
+    def copy_stack(self):
+        """Creates a copy of this stack iterator's stack."""
+        return [set(vals) for vals in self.stack]
+
+    def create(self, new_stack):
+        """Creates a stack iterator from the given stack"""
+        return type(self)(new_stack)
+
+    def merge_stacks(self, *branches):
+        """Computes the union of the stacks of the given branches."""
+        results = None
+        for branch in branches:
+            if results is None:
+                results = branch.copy_stack()
+            else:
+                assert len(branch.stack) == len(results)
+                results = [set.union(*t) for t in zip(branch.stack, results)]
+
+        return results
+
+def protect_temporaries_from_gc(instruction, connected_node):
+    """Protects temporaries from the garbage collector by connecting them to the given node."""
+    # # The reasoning behind this function
+    #
+    # A nop instruction (`yield None`) may trigger the garbage collector, which will delete
+    # unreachable ("dead") vertices and edges. Now take into account that a bare temporary node
+    # is actually unreachable from the root node. The consequence is that temporary nodes
+    # may be garbage-collected if a nop instruction is executed while they are on the evaluation
+    # "stack." This is _never_ what we want.
+    #
+    # To counter this, we can connect temporary nodes to a node that is reachable from the root.
+    # However, we only want to create edges between edges and a known reachable node if we really
+    # have to, because creating edges incurs some overhead.
+    #
+    # We will create an edge between a temporary and the known reachable node if and only if the
+    # temporary is on the "stack" when either a nop or a call instruction is executed.
+
+    class GCStackIterator(StackIterator):
+        """A stack iterator that detects which instructions might be at risk of getting garbage
+           collected."""
+        def __init__(self, stack=None, gc_temporaries=None):
+            StackIterator.__init__(self, stack)
+            self.gc_temporaries = set() if gc_temporaries is None else gc_temporaries
+
+        def push(self, instruction):
+            """Pushes an instruction onto the stack."""
+            if isinstance(instruction, (
+                    NopInstruction,
+                    RunGeneratorFunctionInstruction,
+                    RunTailGeneratorFunctionInstruction)):
+                # All values currently on the stack are at risk. Mark them as such.
+                for instruction_set in self.stack:
+                    self.gc_temporaries.update(instruction_set)
+
+            # Proceed.
+            StackIterator.push(self, instruction)
+
+        def merge(self, *branches):
+            """Sets this stack iterator's stack to the union of the given branches."""
+            self.__init__(
+                self.merge_stacks(*branches),
+                set.union(*[br.gc_temporaries for br in branches]))
+
+        def create(self, new_stack):
+            """Creates a stack iterator from the given stack"""
+            return GCStackIterator(new_stack, self.gc_temporaries)
+
+    # Find out which instructions are at risk.
+    gc_iterator = GCStackIterator()
+    iterate_as_stack(instruction, gc_iterator)
+    # These temporaries need to be protected from the GC.
+    gc_temporaries = gc_iterator.gc_temporaries
+
+    def protect_result(instruction):
+        """Protects the given instruction's (temporary) result."""
+        if instruction in gc_temporaries and instruction.get_result_type() == NODE_RESULT_TYPE:
+            gc_temporaries.remove(instruction)
+            store_instr = StoreLocalInstruction(None, instruction)
+            return CompoundInstruction(
+                store_instr,
+                CompoundInstruction(
+                    SelectInstruction(
+                        BinaryInstruction(
+                            store_instr.create_load(), 'is not', LiteralInstruction(None)),
+                        CreateEdgeInstruction(connected_node, store_instr.create_load()),
+                        EmptyInstruction()),
+                    store_instr.create_load()))
+        else:
+            return instruction
+
+    return map_instruction_tree_top_down(protect_result, instruction)
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     example_tree = SelectInstruction(
     example_tree = SelectInstruction(

+ 1 - 1
kernel/modelverse_kernel/main.py

@@ -60,7 +60,7 @@ class ModelverseKernel(object):
         # To make the JIT compile 'input' instructions as calls to
         # To make the JIT compile 'input' instructions as calls to
         # modelverse_jit.runtime.get_input, uncomment the line below:
         # modelverse_jit.runtime.get_input, uncomment the line below:
         #
         #
-        #     self.jit.use_input_function()
+        self.jit.use_input_function()
         #
         #
 
 
         self.debug_info = defaultdict(list)
         self.debug_info = defaultdict(list)