Browse Source

Optimize the 'interpret_function' function in the JIT runtime

jonathanvdc 8 years ago
parent
commit
e8c3971bad

+ 12 - 19
kernel/modelverse_jit/runtime.py

@@ -11,16 +11,18 @@ def interpret_function(function_id, named_arguments, **kwargs):
     body_id, = yield [("RD", [function_id, "body"])]
 
     # Create a new stack frame.
-    frame_link, new_phase, new_frame, new_evalstack, new_symbols, new_returnvalue = \
+    frame_link, new_phase, new_frame, new_evalstack, new_symbols, \
+        new_returnvalue, intrinsic_return = \
                     yield [("RDE", [user_root, "frame"]),
                            ("CNV", ["init"]),
                            ("CN", []),
                            ("CN", []),
                            ("CN", []),
                            ("CN", []),
+                           ("CN", [])
                           ]
 
-    _, _, _, _, _, _, _, _, _ = \
+    _, _, _, _, _, _, _, _, _, _ = \
                     yield [("CD", [user_root, "frame", new_frame]),
                            ("CD", [new_frame, "evalstack", new_evalstack]),
                            ("CD", [new_frame, "symbols", new_symbols]),
@@ -29,7 +31,11 @@ def interpret_function(function_id, named_arguments, **kwargs):
                            ("CD", [new_frame, "phase", new_phase]),
                            ("CD", [new_frame, "IP", body_id]),
                            ("CD", [new_frame, "prev", user_frame]),
-                           ("DE", [frame_link]),
+                           ("CD", [
+                               new_frame,
+                               primitive_functions.PRIMITIVE_RETURN_KEY,
+                               intrinsic_return]),
+                           ("DE", [frame_link])
                           ]
 
     # Put the parameters in the new stack frame's symbol table.
@@ -58,19 +64,6 @@ def interpret_function(function_id, named_arguments, **kwargs):
             while 1:
                 inp = yield gen.send(inp)
         except StopIteration:
-            # An instruction has been completed. Check if we've already returned.
-            #
-            # TODO: the statement below performs O(n) state reads whenever an instruction
-            # finishes, where n is the number of 'interpret_function' stack frames.
-            # I don't *think* that this is problematic (at least not in the short term),
-            # but an O(1) solution would obviously be much better; that's the interpreter's
-            # complexity. Perhaps we can annotate the stack frame we create here with a marker
-            # that the kernel can pick up on? We could have the kernel throw an exception whenever
-            # it encounters said marker.
-            current_user_frame, = yield [("RD", [user_root, "frame"])]
-            if current_user_frame == user_frame:
-                # We're done here. Extract the return value and get out.
-                returnvalue, = yield [("RD", [user_frame, "returnvalue"])]
-                raise primitive_functions.PrimitiveFinished(returnvalue)
-            else:
-                yield None
+            # An instruction has been completed. Forward it.
+            yield None
+        # Let primitive_functions.PrimitiveFinished bubble up.

+ 28 - 11
kernel/modelverse_kernel/main.py

@@ -582,6 +582,13 @@ class ModelverseKernel(object):
             _, _ =          yield [("CD", [user_root, "frame", prev_frame]),
                                    ("DN", [user_frame]),
                                   ]
+
+            # If the callee's frame is marked with the '__primitive_return' key, then
+            # we need to throw an exception instead of just finishing here. This design
+            # gives us O(1) state reads per jit-interpreter transition.
+            exception_return, = yield [("RD", [user_frame, primitive_functions.PRIMITIVE_RETURN_KEY])]
+            if exception_return is not None:
+                raise primitive_functions.PrimitiveFinished(None)
         else:
             evalstack, evalstack_link, ip_link, new_evalstack, evalstack_phase = \
                             yield [("RD", [user_frame, "evalstack"]),
@@ -601,17 +608,27 @@ class ModelverseKernel(object):
                                   ]
 
     def return_eval(self, user_root):
-        user_frame, =       yield [("RD", [user_root, "frame"])]
-        prev_frame, =       yield [("RD", [user_frame, "prev"])]
-        returnvalue, old_returnvalue_link = \
-                            yield [("RD", [user_frame, "returnvalue"]),
-                                   ("RDE", [prev_frame, "returnvalue"]),
-                                  ]
-        _, _, _, _ =        yield [("CD", [user_root, "frame", prev_frame]),
-                                   ("CD", [prev_frame, "returnvalue", returnvalue]),
-                                   ("DE", [old_returnvalue_link]),
-                                   ("DN", [user_frame]),
-                                  ]
+        user_frame, = yield [("RD", [user_root, "frame"])]
+        prev_frame, exception_return, returnvalue = yield [
+            ("RD", [user_frame, "prev"]),
+            ("RD", [user_frame, primitive_functions.PRIMITIVE_RETURN_KEY]),
+            ("RD", [user_frame, "returnvalue"])]
+
+        # If the callee's frame is marked with the '__primitive_return' key, then
+        # we need to throw an exception instead of just finishing here. This design
+        # gives us O(1) state reads per jit-interpreter transition.
+        if exception_return is not None:
+            yield [
+                ("CD", [user_root, "frame", prev_frame]),
+                ("DN", [user_frame])]
+            raise primitive_functions.PrimitiveFinished(returnvalue)
+        else:
+            old_returnvalue_link, = yield [("RDE", [prev_frame, "returnvalue"])]
+            yield [
+                ("CD", [user_root, "frame", prev_frame]),
+                ("CD", [prev_frame, "returnvalue", returnvalue]),
+                ("DE", [old_returnvalue_link]),
+                ("DN", [user_frame])]
 
     def constant_init(self, user_root):
         user_frame, =       yield [("RD", [user_root, "frame"])]

+ 4 - 0
kernel/modelverse_kernel/primitives.py

@@ -2,6 +2,10 @@
 class PrimitiveFinished(Exception):
     def __init__(self, value):
         self.result = value
+
+PRIMITIVE_RETURN_KEY = "__primitive_return"
+"""A dictionary key for functions which request that the kernel throw a PrimitiveFinished
+   exception with the return value instead of injecting the return value in the caller's frame."""
     
 def integer_subtraction(a, b, **remainder):
     a_value, b_value =  yield [("RV", [a]), ("RV", [b])]