Browse Source

Ensure that the JIT's locals don't get GC'ed

jonathanvdc 8 years ago
parent
commit
b03a765eda
3 changed files with 46 additions and 11 deletions
  1. 29 7
      kernel/modelverse_jit/jit.py
  2. 5 4
      kernel/modelverse_jit/runtime.py
  3. 12 0
      kernel/modelverse_jit/tree_ir.py

+ 29 - 7
kernel/modelverse_jit/jit.py

@@ -16,6 +16,12 @@ CALL_FUNCTION_NAME = "__call_function"
 GET_INPUT_FUNCTION_NAME = "__get_input"
 GET_INPUT_FUNCTION_NAME = "__get_input"
 """The name of the '__get_input' function, in the jitted function scope."""
 """The name of the '__get_input' function, in the jitted function scope."""
 
 
+LOCALS_NODE_NAME = "jit_locals"
+"""The name of the node that is connected to all JIT locals in a given function call."""
+
+LOCALS_EDGE_NAME = "jit_locals_edge"
+"""The name of the edge that connects the LOCALS_NODE_NAME node to a user root."""
+
 def get_parameter_names(compiled_function):
 def get_parameter_names(compiled_function):
     """Gets the given compiled function's parameter names."""
     """Gets the given compiled function's parameter names."""
     if hasattr(compiled_function, '__code__'):
     if hasattr(compiled_function, '__code__'):
@@ -315,14 +321,22 @@ class ModelverseJit(object):
 
 
         # Write a prologue and prepend it to the generated function body.
         # Write a prologue and prepend it to the generated function body.
         prologue_statements = []
         prologue_statements = []
+        # Create a LOCALS_NODE_NAME node, and connect it to the user root.
+        prologue_statements.append(
+            tree_ir.create_new_local_node(
+                LOCALS_NODE_NAME,
+                tree_ir.LoadIndexInstruction(
+                    tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
+                    tree_ir.LiteralInstruction('user_root')),
+                LOCALS_EDGE_NAME))
         for (key, val) in param_dict.items():
         for (key, val) in param_dict.items():
-            arg_ptr = tree_ir.StoreLocalInstruction(
+            arg_ptr = tree_ir.create_new_local_node(
                 body_param_dict[key],
                 body_param_dict[key],
-                tree_ir.CreateNodeInstruction())
+                tree_ir.LoadLocalInstruction(LOCALS_NODE_NAME))
             prologue_statements.append(arg_ptr)
             prologue_statements.append(arg_ptr)
             prologue_statements.append(
             prologue_statements.append(
                 tree_ir.CreateDictionaryEdgeInstruction(
                 tree_ir.CreateDictionaryEdgeInstruction(
-                    arg_ptr.create_load(),
+                    tree_ir.LoadLocalInstruction(body_param_dict[key]),
                     tree_ir.LiteralInstruction('value'),
                     tree_ir.LiteralInstruction('value'),
                     tree_ir.LoadLocalInstruction(val)))
                     tree_ir.LoadLocalInstruction(val)))
 
 
@@ -446,14 +460,21 @@ class AnalysisState(object):
     def analyze_return(self, instruction_id):
     def analyze_return(self, instruction_id):
         """Tries to analyze the given 'return' instruction."""
         """Tries to analyze the given 'return' instruction."""
         retval_id, = yield [("RD", [instruction_id, 'value'])]
         retval_id, = yield [("RD", [instruction_id, 'value'])]
+        def create_return(return_value):
+            return tree_ir.ReturnInstruction(
+                tree_ir.CompoundInstruction(
+                    return_value,
+                    tree_ir.DeleteEdgeInstruction(
+                        tree_ir.LoadLocalInstruction(LOCALS_EDGE_NAME))))
+
         if retval_id is None:
         if retval_id is None:
             raise primitive_functions.PrimitiveFinished(
             raise primitive_functions.PrimitiveFinished(
-                tree_ir.ReturnInstruction(
+                create_return(
                     tree_ir.EmptyInstruction()))
                     tree_ir.EmptyInstruction()))
         else:
         else:
             retval, = yield [("CALL_ARGS", [self.analyze, (retval_id,)])]
             retval, = yield [("CALL_ARGS", [self.analyze, (retval_id,)])]
             raise primitive_functions.PrimitiveFinished(
             raise primitive_functions.PrimitiveFinished(
-                tree_ir.ReturnInstruction(retval))
+                create_return(retval))
 
 
     def analyze_if(self, instruction_id):
     def analyze_if(self, instruction_id):
         """Tries to analyze the given 'if' instruction."""
         """Tries to analyze the given 'if' instruction."""
@@ -712,14 +733,15 @@ class AnalysisState(object):
         #
         #
         #     if 'local_name' not in locals():
         #     if 'local_name' not in locals():
         #         local_name, = yield [("CN", [])]
         #         local_name, = yield [("CN", [])]
+        #         yield [("CE", [LOCALS_NODE_NAME, local_name])]
 
 
         raise primitive_functions.PrimitiveFinished(
         raise primitive_functions.PrimitiveFinished(
             tree_ir.SelectInstruction(
             tree_ir.SelectInstruction(
                 tree_ir.LocalExistsInstruction(name),
                 tree_ir.LocalExistsInstruction(name),
                 tree_ir.EmptyInstruction(),
                 tree_ir.EmptyInstruction(),
-                tree_ir.StoreLocalInstruction(
+                tree_ir.create_new_local_node(
                     name,
                     name,
-                    tree_ir.CreateNodeInstruction())))
+                    tree_ir.LoadLocalInstruction(LOCALS_NODE_NAME))))
 
 
     def analyze_global(self, instruction_id):
     def analyze_global(self, instruction_id):
         """Tries to analyze the given 'global' (declaration) instruction."""
         """Tries to analyze the given 'global' (declaration) instruction."""

+ 5 - 4
kernel/modelverse_jit/runtime.py

@@ -14,7 +14,7 @@ def call_function(function_id, named_arguments, **kwargs):
     # Try to jit the function here. We might be able to avoid building the stack
     # Try to jit the function here. We might be able to avoid building the stack
     # frame.
     # frame.
     def handle_jit_failed(_):
     def handle_jit_failed(_):
-        interpreter_args = {'body_id' : body_id, 'named_arguments' : named_arguments}
+        interpreter_args = {'function_id' : function_id, 'named_arguments' : named_arguments}
         interpreter_args.update(kwargs)
         interpreter_args.update(kwargs)
         yield [("TAIL_CALL_KWARGS", [interpret_function, interpreter_args])]
         yield [("TAIL_CALL_KWARGS", [interpret_function, interpreter_args])]
 
 
@@ -28,13 +28,14 @@ def call_function(function_id, named_arguments, **kwargs):
     # Run the function.
     # Run the function.
     yield [("TAIL_CALL_KWARGS", [compiled_func, named_arguments])]
     yield [("TAIL_CALL_KWARGS", [compiled_func, named_arguments])]
 
 
-def interpret_function(body_id, named_arguments, **kwargs):
-    """Makes the interpreter run the function with the given id with the specified
+def interpret_function(function_id, named_arguments, **kwargs):
+    """Makes the interpreter run the function with the given id for the specified
        argument dictionary."""
        argument dictionary."""
     user_root = kwargs['user_root']
     user_root = kwargs['user_root']
     kernel = kwargs['mvk']
     kernel = kwargs['mvk']
     user_frame, = yield [("RD", [user_root, "frame"])]
     user_frame, = yield [("RD", [user_root, "frame"])]
-    inst, = yield [("RD", [user_frame, "IP"])]
+    inst, body_id = yield [("RD", [user_frame, "IP"]), ("RD", [function_id, "body"])]
+    kernel.jit.mark_entry_point(body_id)
 
 
     # Create a new stack frame.
     # Create a new stack frame.
     frame_link, new_phase, new_frame, new_evalstack, new_symbols, \
     frame_link, new_phase, new_frame, new_evalstack, new_symbols, \

+ 12 - 0
kernel/modelverse_jit/tree_ir.py

@@ -1220,6 +1220,18 @@ def create_jit_call(target, named_arguments, kwargs):
         RunGeneratorFunctionInstruction(
         RunGeneratorFunctionInstruction(
             target, arg_dict.create_load()))
             target, arg_dict.create_load()))
 
 
+def create_new_local_node(local_variable, connected_node, edge_variable=None):
+    """Creates a local node that is the backing storage for a local variable.
+       This node is connected to a given node to make sure it's not perceived
+       as dead by the GC. The newly created node is stored in the given
+       local variable. The edge's id can also optionally be stored in a variable."""
+    local_store = StoreLocalInstruction(local_variable, CreateNodeInstruction())
+    create_edge = CreateEdgeInstruction(local_store.create_load(), connected_node)
+    if edge_variable is not None:
+        create_edge = StoreLocalInstruction(edge_variable, create_edge)
+
+    return create_block(local_store, create_edge)
+
 def with_debug_info_trace(instruction, debug_info, function_name):
 def with_debug_info_trace(instruction, debug_info, function_name):
     """Prepends the given instruction with a tracing instruction that prints
     """Prepends the given instruction with a tracing instruction that prints
        the given debug information and function name."""
        the given debug information and function name."""