Przeglądaj źródła

Store compiled and primitive functions in the JIT

jonathanvdc 8 lat temu
rodzic
commit
e6eb7192de
2 zmienionych plików z 25 dodań i 29 usunięć
  1. 4 4
      kernel/modelverse_jit/jit.py
  2. 21 25
      kernel/modelverse_kernel/main.py

+ 4 - 4
kernel/modelverse_jit/jit.py

@@ -40,10 +40,10 @@ class ModelverseJit(object):
     def is_jittable_entry_point(self, body_id):
         """Tells if the node with the given identifier is a function entry point that
            has not been marked as non-jittable. This only returns `True` if the JIT
-           is enabled"""
-        return self.jit_enabled and (
-            body_id in self.todo_entry_points or
-            body_id in self.jitted_entry_points)
+           is enabled and the function entry point has been marked jittable, or if
+           the function has already been compiled."""
+        return ((self.jit_enabled and body_id in self.todo_entry_points) or
+                body_id in self.jitted_entry_points)
 
     def mark_no_jit(self, body_id):
         """Informs the JIT that the node with the given identifier is a function entry

+ 21 - 25
kernel/modelverse_kernel/main.py

@@ -12,8 +12,6 @@ else:
 class ModelverseKernel(object):
     def __init__(self, root):
         self.root = root
-        self.primitives = {}
-        self.compiled = {}
         self.returnvalue = None
         self.success = True
         self.generators = {}
@@ -66,9 +64,9 @@ class ModelverseKernel(object):
         elif inst is None:
             raise Exception("Instruction pointer could not be found!")
         elif isinstance(phase_v, string_types):
-            if phase_v == "init" and (inst in self.compiled or self.jit.is_jittable_entry_point(inst)):
-                #print("%-30s(%s)" % ("COMPILED " + str(self.compiled[inst]), phase_v))
-                gen = self.execute_primitive_or_jit(user_root, inst, username)
+            if phase_v == "init" and self.jit.is_jittable_entry_point(inst):
+                #print("%-30s(%s)" % ("COMPILED " + str(self.jit.jitted_entry_points[inst]), phase_v))
+                gen = self.execute_jit(user_root, inst, username)
             elif inst_v is None:
                 raise Exception("%s: error understanding command (%s, %s)" % (self.debug_info, inst_v, phase_v))
             else:
@@ -115,11 +113,13 @@ class ModelverseKernel(object):
         signatures  =    yield [("RDN", [primitives, f]) for f in keys]
         bodies =         yield [("RD", [f, "body"]) for f in signatures]
         for i in range(len(keys)):
-            self.primitives[bodies[i]] = getattr(primitive_functions, function_names[i])
-        self.compiled.update(self.primitives)
+            self.jit.register_compiled(
+                bodies[i], 
+                getattr(primitive_functions, function_names[i]), 
+                function_names[i])
 
-    def execute_primitive_or_jit(self, user_root, inst, username):
-        # execute_primitive_or_jit
+    def execute_jit(self, user_root, inst, username):
+        # execute_jit
         user_frame, =    yield [("RD", [user_root, "frame"])]
         symbols, =       yield [("RD", [user_frame, "symbols"])]
         all_links, =     yield [("RO", [symbols])]
@@ -136,22 +136,18 @@ class ModelverseKernel(object):
         parameters["username"] = username
         parameters["mvk"] = self
 
-        # prim is a generator itself!
+        # Have the JIT compile the function.
         try:
-            # Forward the message we get to this generator
-            # Sometimes it might not even be a generator, in which case this should already be in the except block (i.e., for the Read Root operation)
-            if inst in self.compiled:
-                prim = self.compiled[inst](**parameters)
-            else:
-                try:
-                    jit_gen = self.jit.jit_compile(inst, dict_keys)
-                    inp = None
-                    while 1:
-                        inp = yield jit_gen.send(inp)
-                except primitive_functions.PrimitiveFinished as e:
-                    # Execution has ended with a returnvalue, so read it out from the exception being thrown
-                    prim = e.result(**parameters)
+            jit_gen = self.jit.jit_compile(inst, dict_keys)
+            inp = None
+            while 1:
+                inp = yield jit_gen.send(inp)
+        except primitive_functions.PrimitiveFinished as e:
+            compiled_func = e.result
 
+        # Run the compiled function.
+        try:
+            prim = compiled_func(**parameters)
             inp = None
             while 1:
                 inp = yield prim.send(inp)
@@ -163,7 +159,7 @@ class ModelverseKernel(object):
             result = e.result
 
             #if result is None:
-            #    raise Exception("Primitive raised exception: value of None for operation %s with parameters %s" % (self.compiled[inst], str(parameters)))
+            #    raise Exception("Primitive raised exception: value of None for operation %s with parameters %s" % (compiled_func, str(parameters)))
 
         # Clean up the current stack, as if a return happened
         old_frame, =    yield [("RD", [user_frame, "prev"])]
@@ -463,7 +459,7 @@ class ModelverseKernel(object):
                     # For this, we read out the body of the resolved data
                     compiler_val, =  yield [("RD", [variable, "value"])]
                     compiler_body, = yield [("RD", [compiler_val, "body"])]
-                    self.compiled[compiler_body] = compiled_function
+                    self.jit.register_compiled(compiler_body, compiled_function, var_name)
 
         else:
             phase_link, returnvalue_link, new_phase = \