Browse Source

Insert JIT hooks in the kernel

jonathanvdc 8 years ago
parent
commit
c2517c0d60
3 changed files with 100 additions and 10 deletions
  1. 0 0
      kernel/modelverse_jit/__init__.py
  2. 51 0
      kernel/modelverse_jit/jit.py
  3. 49 10
      kernel/modelverse_kernel/main.py

+ 0 - 0
kernel/modelverse_jit/__init__.py


+ 51 - 0
kernel/modelverse_jit/jit.py

@@ -0,0 +1,51 @@
+import modelverse_kernel.primitives as primitive_functions
+import modelverse_jit.tree_ir as tree_ir
+
+class JitCompilationFailedException(Exception):
+    """A type of exception that is raised when the jit fails to compile a function."""
+    pass
+
+class ModelverseJit(object):
+    """A high-level interface to the modelverse JIT compiler."""
+
+    def __init__(self):
+        self.todo_entry_points = set()
+        self.no_jit_entry_points = set()
+        self.jitted_entry_points = {}
+
+    def mark_entry_point(self, body_id):
+        """Marks the node with the given identifier as a function entry point."""
+        if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
+            self.todo_entry_points.add(body_id)
+
+    def is_entry_point(self, body_id):
+        """Tells if the node with the given identifier is a function entry point."""
+        return body_id in self.todo_entry_points or \
+               body_id in self.no_jit_entry_points or \
+               body_id in self.jitted_entry_points
+
+    def is_jittable_entry_point(self, body_id):
+        """Tells if the node with the given identifier is a function entry point that
+           has not been marked as non-jittable."""
+        return body_id in self.todo_entry_points or \
+               body_id in self.jitted_entry_points
+
+    def mark_no_jit(self, body_id):
+        """Informs the JIT that the node with the given identifier is a function entry
+           point that must never be jitted."""
+        self.no_jit_entry_points.add(body_id)
+        if body_id in self.todo_entry_points:
+            self.todo_entry_points.remove(body_id)
+
+    def register_compiled(self, body_id, compiled):
+        """Registers a compiled entry point with the JIT."""
+        self.jitted_entry_points[body_id] = compiled
+        if body_id in self.todo_entry_points:
+            self.todo_entry_points.remove(body_id)
+
+    def try_jit(self, body_id, parameter_list):
+        """Tries to jit the function defined by the given entry point id and parameter list."""
+
+        print("Couldn't JIT: " + str(body_id))
+        self.mark_no_jit(body_id)
+        raise JitCompilationFailedException("Couln't JIT")

+ 49 - 10
kernel/modelverse_kernel/main.py

@@ -1,5 +1,6 @@
 import modelverse_kernel.primitives as primitive_functions
 import modelverse_kernel.compiled as compiled_functions
+import modelverse_jit.jit as jit
 import sys
 import time
 
@@ -18,6 +19,10 @@ class ModelverseKernel(object):
         self.generators = {}
         self.allow_compiled = True
         #self.allow_compiled = False
+
+        # `self.jit` is handles most JIT-related functionality.
+        # Set it to `None` to disable the JIT.
+        self.jit = jit.ModelverseJit()
         self.debug_info = "(no debug information found)"
 
     def execute_yields(self, username, operation, params, reply):
@@ -60,14 +65,14 @@ class ModelverseKernel(object):
         elif inst is None:
             raise Exception("Instruction pointer could not be found!")
         elif isinstance(phase_v, string_types):
-            if phase_v == "init" and inst in self.compiled:
+            if phase_v == "init" and (inst in self.compiled or \
+                self.jit.is_jittable_entry_point(inst)):
                 #print("%-30s(%s)" % ("COMPILED " + str(self.compiled[inst]), phase_v))
-                gen = self.execute_primitive(user_root, inst, username)
+                gen = self.execute_primitive_or_jit(user_root, inst, username)
             elif inst_v is None:
                 raise Exception("%s: error understanding command (%s, %s)" % (self.debug_info, inst_v, phase_v))
             else:
-                #print("%-30s(%s) -- %s" % (inst_v["value"], phase_v, username))
-                gen = getattr(self, "%s_%s" % (inst_v["value"], phase_v))(user_root)
+                gen = self.get_inst_phase_generator(inst_v, phase_v, user_root)
         elif inst_v is None:
             raise Exception("%s: error understanding command (%s, %s)" % (self.debug_info, inst_v, phase_v))
         elif inst_v["value"] == "call":
@@ -82,6 +87,21 @@ class ModelverseKernel(object):
                 inp = yield gen.send(inp)
         except StopIteration:
             pass
+        except jit.JitCompilationFailedException:
+            # Try again, but this time without the JIT.
+            gen = self.get_inst_phase_generator(inst_v, phase_v, user_root)
+            try:
+                inp = None
+                while 1:
+                    inp = yield gen.send(inp)
+            except StopIteration:
+                pass
+
+    def get_inst_phase_generator(self, inst_v, phase_v, user_root):
+        """Gets a generator for the given instruction in the given phase,
+           for the specified user root."""
+        #print("%-30s(%s) -- %s" % (inst_v["value"], phase_v, username))
+        return getattr(self, "%s_%s" % (inst_v["value"], phase_v))(user_root)
 
     ##########################
     ### Process primitives ###
@@ -97,8 +117,8 @@ class ModelverseKernel(object):
             self.primitives[bodies[i]] = getattr(primitive_functions, function_names[i])
         self.compiled.update(self.primitives)
 
-    def execute_primitive(self, user_root, inst, username):
-        # execute_primitive
+    def execute_primitive_or_jit(self, user_root, inst, username):
+        # execute_primitive_or_jit
         user_frame, =    yield [("RD", [user_root, "frame"])]
         symbols, =       yield [("RD", [user_frame, "symbols"])]
         all_links, =     yield [("RO", [symbols])]
@@ -119,7 +139,18 @@ class ModelverseKernel(object):
         try:
             # Forward the message we get to this generator
             # Sometimes it might not even be a generator, in which case this should already be in the except block (i.e., for the Read Root operation)
-            prim = self.compiled[inst](**parameters)
+            if inst in self.compiled:
+                prim = self.compiled[inst](**parameters)
+            else:
+                try:
+                    jit_gen = self.jit.try_jit(inst, dict_keys)
+                    inp = None
+                    while 1:
+                        inp = yield jit_gen.send(inp)
+                except primitive_functions.PrimitiveFinished as e:
+                    # Execution has ended with a returnvalue, so read it out from the exception being thrown
+                    prim = e.result
+
             inp = None
             while 1:
                 inp = yield prim.send(inp)
@@ -660,9 +691,12 @@ class ModelverseKernel(object):
 
         if param is None:
             returnvalue, =  yield [("RD", [user_frame, "returnvalue"])]
-            body, phase_link, frame_link, prev_phase, new_phase, new_frame, new_evalstack, new_symbols, new_returnvalue = \
-                            yield [("RD", [returnvalue, "body"]),
-                                   ("RDE", [user_frame, "phase"]),
+            body, =         yield [("RD", [returnvalue, "body"])]
+            if self.jit is not None:
+                self.jit.mark_entry_point(body)
+
+            phase_link, frame_link, prev_phase, new_phase, new_frame, new_evalstack, new_symbols, new_returnvalue = \
+                            yield [("RDE", [user_frame, "phase"]),
                                    ("RDE", [user_root, "frame"]),
                                    ("CNV", ["finish"]),
                                    ("CNV", ["init"]),
@@ -702,6 +736,11 @@ class ModelverseKernel(object):
                             yield [("RD", [signature, "params"]),
                                    ("RD", [inst, "last_param"]),
                                   ]
+
+            body, =         yield [("RD", [new_IP, "body"])]
+            if self.jit is not None:
+                self.jit.mark_entry_point(body)
+            
             name, =         yield [("RD", [last_param, "name"])]
             name_value, =   yield [("RV", [name])]
             returnvalue, formal_parameter, new_phase, variable = \