瀏覽代碼

Make JIT capable of compiling and running a leaf function

Specifically, the integration/test_main.py works with JIT enabled now!
jonathanvdc 8 年之前
父節點
當前提交
28acadeb06
共有 3 個文件被更改,包括 70 次插入12 次删除
  1. 36 10
      kernel/modelverse_jit/jit.py
  2. 32 0
      kernel/modelverse_jit/tree_ir.py
  3. 2 2
      kernel/modelverse_kernel/main.py

+ 36 - 10
kernel/modelverse_jit/jit.py

@@ -15,6 +15,8 @@ class ModelverseJit(object):
         self.todo_entry_points = set()
         self.no_jit_entry_points = set()
         self.jitted_entry_points = {}
+        self.jit_globals = {}
+        self.jit_count = 0
 
     def mark_entry_point(self, body_id):
         """Marks the node with the given identifier as a function entry point."""
@@ -48,20 +50,42 @@ class ModelverseJit(object):
 
     def try_jit(self, body_id, parameter_list):
         """Tries to jit the function defined by the given entry point id and parameter list."""
+        if body_id in self.jitted_entry_points:
+            # We have already compiled this function.
+            raise primitive_functions.PrimitiveFinished(self.jitted_entry_points[body_id])
+        elif body_id in self.no_jit_entry_points:
+            # We're not allowed to jit this function or have tried and failed before.
+            raise JitCompilationFailedException(
+                'Cannot jit function at %d because it is marked non-jittable.' % body_id)
+
         gen = AnalysisState().analyze(body_id)
         try:
             inp = None
             while True:
                 inp = yield gen.send(inp)
-        except primitive_functions.PrimitiveFinished as e:
-            constructed_ir = e.result
-        except JitCompilationFailedException:
+        except primitive_functions.PrimitiveFinished as ex:
+            constructed_body = ex.result
+        except JitCompilationFailedException as ex:
             self.mark_no_jit(body_id)
-            raise
-
-        print(constructed_ir)
-        self.mark_no_jit(body_id)
-        raise JitCompilationFailedException("Can't jit function body at " + str(body_id))
+            raise JitCompilationFailedException(
+                '%s (function at %d)' % (ex.message, body_id))
+
+        # Wrap the IR in a function definition, give it a unique name.
+        constructed_function = tree_ir.DefineFunctionInstruction(
+            'jit_func%d' % self.jit_count,
+            parameter_list + ['**' + KWARGS_PARAMETER_NAME],
+            constructed_body)
+        self.jit_count += 1
+        # Convert the function definition to Python code, and compile it.
+        exec(str(constructed_function), self.jit_globals)
+        # Extract the compiled function from the JIT global state.
+        compiled_function = self.jit_globals[constructed_function.name]
+
+        print(constructed_function)
+        
+        # Save the compiled function so we can reuse it later.
+        self.jitted_entry_points[body_id] = compiled_function
+        raise primitive_functions.PrimitiveFinished(compiled_function)
 
 class AnalysisState(object):
     """The state of a bytecode analysis call graph."""
@@ -75,7 +99,7 @@ class AnalysisState(object):
         # Check the analyzed_instructions set for instruction_id to avoid
         # infinite loops.
         if instruction_id in self.analyzed_instructions:
-            raise JitCompilationFailedException('Cannon jit non-tree instruction graph.')
+            raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
 
         self.analyzed_instructions.add(instruction_id)
         instruction_val, = yield [("RV", [instruction_id])]
@@ -206,6 +230,7 @@ class AnalysisState(object):
         #                        ("CD", [user_root, "last_output", new_last_output]),
         #                        ("DE", [last_output_link])
         #                       ]
+        # yield 'nop'
 
         value_id, = yield [("RD", [instruction_id, "value"])]
         gen = self.analyze(value_id)
@@ -256,7 +281,8 @@ class AnalysisState(object):
                 store_user_root.create_load(),
                 tree_ir.LiteralInstruction('last_output'),
                 new_last_output.create_load()),
-            tree_ir.DeleteEdgeInstruction(last_output_link.create_load()))
+            tree_ir.DeleteEdgeInstruction(last_output_link.create_load()),
+            tree_ir.NopInstruction())
 
         raise primitive_functions.PrimitiveFinished(result)
 

+ 32 - 0
kernel/modelverse_jit/tree_ir.py

@@ -1,4 +1,8 @@
 
+NOP_LITERAL = None
+"""A literal that results in a nop during which execution may be interrupted
+   when yielded."""
+
 class Instruction(object):
     """A base class for instructions. An instruction is essentially a syntax
        node that must first be defined, and can only then be used."""
@@ -328,6 +332,22 @@ class LoadLocalInstruction(LocalInstruction):
         """Tells if this instruction requires a definition."""
         return False
 
+class DefineFunctionInstruction(LocalInstruction):
+    """An instruction that defines a function."""
+    def __init__(self, name, parameter_list, body):
+        LocalInstruction.__init__(self, name)
+        self.parameter_list = parameter_list
+        self.body = body
+
+    def generate_python_def(self, code_generator):
+        """Generates a Python statement that executes this instruction.
+           The statement is appended immediately to the code generator."""
+
+        code_generator.append_line('def %s(%s):' % (self.name, ', '.join(self.parameter_list)))
+        code_generator.increase_indentation()
+        self.body.generate_python_def(code_generator)
+        code_generator.decrease_indentation()
+
 class LoadIndexInstruction(Instruction):
     """An instruction that produces a value by indexing a specified expression with
        a given key."""
@@ -353,6 +373,18 @@ class LoadIndexInstruction(Instruction):
             self.indexed.generate_python_use(code_generator),
             self.key.generate_python_use(code_generator))
 
+class NopInstruction(Instruction):
+    """A nop instruction, which allows for the kernel's thread of execution to be interrupted."""
+
+    def has_result(self):
+        """Tells if this instruction computes a result."""
+        return False
+
+    def generate_python_def(self, code_generator):
+        """Generates a Python statement that executes this instruction.
+           The statement is appended immediately to the code generator."""
+        code_generator.append_line('yield %s' % repr(NOP_LITERAL))
+
 class ReadValueInstruction(StateInstruction):
     """An instruction that reads a value from a node."""
 

+ 2 - 2
kernel/modelverse_kernel/main.py

@@ -20,7 +20,7 @@ class ModelverseKernel(object):
         self.allow_compiled = True
         #self.allow_compiled = False
 
-        # `self.jit` is handles most JIT-related functionality.
+        # `self.jit` handles most JIT-related functionality.
         # Set it to `None` to disable the JIT.
         self.jit = jit.ModelverseJit()
         self.debug_info = "(no debug information found)"
@@ -150,7 +150,7 @@ class ModelverseKernel(object):
                         inp = yield jit_gen.send(inp)
                 except primitive_functions.PrimitiveFinished as e:
                     # Execution has ended with a returnvalue, so read it out from the exception being thrown
-                    prim = e.result
+                    prim = e.result(**parameters)
 
             inp = None
             while 1: