Browse Source

Detect 'mutable' functions in the JIT

jonathanvdc 8 years ago
parent
commit
85bea2fbbf
3 changed files with 37 additions and 14 deletions
  1. 17 8
      kernel/modelverse_jit/jit.py
  2. 19 5
      kernel/modelverse_jit/runtime.py
  3. 1 1
      kernel/modelverse_kernel/main.py

+ 17 - 8
kernel/modelverse_jit/jit.py

@@ -262,19 +262,23 @@ class ModelverseJit(object):
                 tree_ir.LoadGlobalInstruction(target_type.__name__),
                 [tree_ir.ReadValueInstruction(a)])))
 
-    def jit_parameters(self, body_id):
-        """Acquires the parameter list for the given body id node."""
+    def jit_signature(self, body_id):
+        """Acquires the signature for the given body id node, which consists of the
+           parameter variables, parameter name and a flag that tells if the given function
+           is mutable."""
         if body_id not in self.jitted_parameters:
-            signature_id, = yield [("RRD", [body_id, "body"])]
+            signature_id, = yield [("RRD", [body_id, jit_runtime.FUNCTION_BODY_KEY])]
             signature_id = signature_id[0]
-            param_set_id, = yield [("RD", [signature_id, "params"])]
+            param_set_id, is_mutable = yield [
+                ("RD", [signature_id, "params"]),
+                ("RD", [signature_id, jit_runtime.MUTABLE_FUNCTION_KEY])]
             if param_set_id is None:
-                self.jitted_parameters[body_id] = ([], [])
+                self.jitted_parameters[body_id] = ([], [], is_mutable)
             else:
                 param_name_ids, = yield [("RDK", [param_set_id])]
                 param_names = yield [("RV", [n]) for n in param_name_ids]
                 param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
-                self.jitted_parameters[body_id] = (param_vars, param_names)
+                self.jitted_parameters[body_id] = (param_vars, param_names, is_mutable)
 
         raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])
 
@@ -307,7 +311,8 @@ class ModelverseJit(object):
         self.jitted_entry_points[body_id] = function_name
         self.jit_globals[function_name] = None
 
-        (parameter_ids, parameter_list), = yield [("CALL_ARGS", [self.jit_parameters, (body_id,)])]
+        (parameter_ids, parameter_list, is_mutable), = yield [
+            ("CALL_ARGS", [self.jit_signature, (body_id,)])]
 
         param_dict = dict(zip(parameter_ids, parameter_list))
         body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
@@ -331,6 +336,10 @@ class ModelverseJit(object):
         # Try to analyze the function's body.
         yield [("TRY", [])]
         yield [("CATCH", [JitCompilationFailedException, handle_jit_exception])]
+        if is_mutable:
+            # We can't just JIT mutable functions. That'd be dangerous.
+            raise JitCompilationFailedException(
+                "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
         state = AnalysisState(
             self, body_id, user_root, body_param_dict,
             self.max_instructions)
@@ -880,7 +889,7 @@ class AnalysisState(object):
         """Tries to analyze a direct 'call' instruction."""
         self.register_function_var(callee_id)
 
-        body_id, = yield [("RD", [callee_id, "body"])]
+        body_id, = yield [("RD", [callee_id, jit_runtime.FUNCTION_BODY_KEY])]
 
         # Make this function dependent on the callee.
         if body_id in self.jit.compilation_dependencies:

+ 19 - 5
kernel/modelverse_jit/runtime.py

@@ -4,20 +4,34 @@ class JitCompilationFailedException(Exception):
     """A type of exception that is raised when the jit fails to compile a function."""
     pass
 
+MUTABLE_FUNCTION_KEY = "mutable"
+"""A dictionary key for functions that are mutable."""
+
+FUNCTION_BODY_KEY = "body"
+"""A dictionary key for function bodies."""
+
 def call_function(function_id, named_arguments, **kwargs):
     """Runs the function with the given id, passing it the specified argument dictionary."""
     user_root = kwargs['user_root']
     kernel = kwargs['mvk']
-    body_id, = yield [("RD", [function_id, "body"])]
-    kernel.jit.mark_entry_point(body_id)
+    body_id, is_mutable = yield [
+        ("RD", [function_id, FUNCTION_BODY_KEY]),
+        ("RD", [function_id, MUTABLE_FUNCTION_KEY])]
 
     # Try to jit the function here. We might be able to avoid building the stack
     # frame.
     def handle_jit_failed(_):
+        """Interprets the function."""
         interpreter_args = {'function_id' : function_id, 'named_arguments' : named_arguments}
         interpreter_args.update(kwargs)
         yield [("TAIL_CALL_KWARGS", [interpret_function, interpreter_args])]
 
+    if is_mutable is not None:
+        kernel.jit.mark_no_jit(body_id)
+        yield [("TAIL_CALL_ARGS", [handle_jit_failed, ()])]
+    else:
+        kernel.jit.mark_entry_point(body_id)
+
     yield [("TRY", [])]
     yield [("CATCH", [JitCompilationFailedException, handle_jit_failed])]
     # Try to compile.
@@ -34,7 +48,7 @@ def interpret_function(function_id, named_arguments, **kwargs):
     user_root = kwargs['user_root']
     kernel = kwargs['mvk']
     user_frame, = yield [("RD", [user_root, "frame"])]
-    inst, body_id = yield [("RD", [user_frame, "IP"]), ("RD", [function_id, "body"])]
+    inst, body_id = yield [("RD", [user_frame, "IP"]), ("RD", [function_id, FUNCTION_BODY_KEY])]
     kernel.jit.mark_entry_point(body_id)
 
     # Create a new stack frame.
@@ -66,8 +80,8 @@ def interpret_function(function_id, named_arguments, **kwargs):
                           ]
 
     # Put the parameters in the new stack frame's symbol table.
-    (parameter_vars, parameter_names), = yield [
-        ("CALL_ARGS", [kernel.jit.jit_parameters, (body_id,)])]
+    (parameter_vars, parameter_names, _), = yield [
+        ("CALL_ARGS", [kernel.jit.jit_signature, (body_id,)])]
     parameter_dict = dict(zip(parameter_names, parameter_vars))
 
     for (key, value) in named_arguments.items():

+ 1 - 1
kernel/modelverse_kernel/main.py

@@ -65,7 +65,7 @@ class ModelverseKernel(object):
         # To make the JIT print JIT successes and errors to the command-line,
         # uncomment the line below:
         #
-        #     self.jit.set_jit_success_log()
+        self.jit.set_jit_success_log()
         #
         # If you want, you can use a custom logging function:
         #