소스 검색

Make the JIT compile calls to certain functions as intrinsics

jonathanvdc 8 년 전
부모
커밋
6b753fa24e
3개의 변경된 파일84개의 추가작업 그리고 19개의 파일을 삭제
  1. 40 19
      kernel/modelverse_jit/jit.py
  2. 18 0
      kernel/modelverse_jit/tree_ir.py
  3. 26 0
      kernel/modelverse_kernel/main.py

+ 40 - 19
kernel/modelverse_jit/jit.py

@@ -22,6 +22,7 @@ class ModelverseJit(object):
         self.jit_count = 0
         self.max_instructions = 30 if max_instructions is None else max_instructions
         self.compiled_function_lookup = compiled_function_lookup
+        self.jit_intrinsics = {}
         self.jit_enabled = True
 
     def set_jit_enabled(self, is_enabled=True):
@@ -92,6 +93,19 @@ class ModelverseJit(object):
         else:
             return None
 
+    def register_intrinsic(self, name, apply_intrinsic):
+        """Registers the given intrisic with the JIT. This will make the JIT replace calls to
+           the function with the given entry point by an application of the specified function."""
+        self.jit_intrinsics[name] = apply_intrinsic
+
+    def register_binary_intrinsic(self, name, operator):
+        """Registers an intrinsic with the JIT that represents the given binary operation."""
+        self.register_intrinsic(name, lambda lhs, rhs: tree_ir.CreateNodeWithValueInstruction(
+            tree_ir.BinaryInstruction(
+                tree_ir.ReadValueInstruction(lhs),
+                operator,
+                tree_ir.ReadValueInstruction(rhs))))
+
     def jit_parameters(self, body_id):
         """Acquires the parameter list for the given body id node."""
         if body_id not in self.jitted_parameters:
@@ -680,21 +694,24 @@ class AnalysisState(object):
         except primitive_functions.PrimitiveFinished as ex:
             _, parameter_names = ex.result
 
-        compiled_func = self.jit.lookup_compiled_function(callee_name)
-        if compiled_func is None:
-            # Compile the callee.
-            try:
-                gen = self.jit.jit_compile(self.user_root, body_id, callee_name)
-                inp = None
-                while True:
-                    inp = yield gen.send(inp)
-            except primitive_functions.PrimitiveFinished as ex:
-                pass
-        else:
-            self.jit.register_compiled(body_id, compiled_func, callee_name)
+        is_intrinsic = callee_name in self.jit.jit_intrinsics
 
-        # Get the callee's name.
-        compiled_func_name = self.jit.get_compiled_name(body_id)
+        if not is_intrinsic:
+            compiled_func = self.jit.lookup_compiled_function(callee_name)
+            if compiled_func is None:
+                # Compile the callee.
+                try:
+                    gen = self.jit.jit_compile(self.user_root, body_id, callee_name)
+                    inp = None
+                    while True:
+                        inp = yield gen.send(inp)
+                except primitive_functions.PrimitiveFinished as ex:
+                    pass
+            else:
+                self.jit.register_compiled(body_id, compiled_func, callee_name)
+
+            # Get the callee's name.
+            compiled_func_name = self.jit.get_compiled_name(body_id)
 
         # Analyze the argument dictionary.
         try:
@@ -716,11 +733,15 @@ class AnalysisState(object):
                     "Cannot JIT-compile function call to '%s' with missing argument for "
                     "formal parameter '%s'." % (callee_name, param_name))
 
-        raise primitive_functions.PrimitiveFinished(
-            tree_ir.JitCallInstruction(
-                tree_ir.LoadGlobalInstruction(compiled_func_name),
-                arg_list,
-                tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
+        if is_intrinsic:
+            raise primitive_functions.PrimitiveFinished(
+                self.jit.jit_intrinsics[callee_name](*arg_list))
+        else:
+            raise primitive_functions.PrimitiveFinished(
+                tree_ir.JitCallInstruction(
+                    tree_ir.LoadGlobalInstruction(compiled_func_name),
+                    arg_list,
+                    tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
 
     def analyze_argument_dict(self, first_argument_id):
         """Analyzes the parameter-to-argument mapping started by the specified first argument

+ 18 - 0
kernel/modelverse_jit/tree_ir.py

@@ -666,6 +666,24 @@ class CreateNodeInstruction(StateInstruction):
         """Gets this state instruction's argument list."""
         return []
 
+class CreateNodeWithValueInstruction(StateInstruction):
+    """An instruction that creates a node with a given value."""
+    def __init__(self, value):
+        StateInstruction.__init__(self)
+        self.value = value
+
+    def simplify(self):
+        """Applies basic simplification to this instruction and its children."""
+        return CreateNodeWithValueInstruction(self.value.simplify())
+
+    def get_opcode(self):
+        """Gets the opcode for this state instruction."""
+        return "CNV"
+
+    def get_arguments(self):
+        """Gets this state instruction's argument list."""
+        return [self.value]
+
 class CreateDictionaryEdgeInstruction(StateInstruction):
     """An instruction that creates a dictionary edge."""
 

+ 26 - 0
kernel/modelverse_kernel/main.py

@@ -24,10 +24,36 @@ class ModelverseKernel(object):
             self.jit.compiled_function_lookup = lambda func_name: \
                 getattr(compiled_functions, func_name, None)
 
+        self.register_intrinsics()
+
         # To disable the JIT, uncomment the line below:
         #     self.jit.set_jit_enabled(False)
         self.debug_info = "(no debug information found)"
 
+    def register_intrinsics(self):
+        """Registers intrinsics with the JIT."""
+        self.jit.register_binary_intrinsic('value_eq', '==')
+        self.jit.register_binary_intrinsic('value_neq', '!=')
+
+        self.jit.register_binary_intrinsic('bool_and', 'and')
+        self.jit.register_binary_intrinsic('bool_or', 'or')
+
+        self.jit.register_binary_intrinsic('integer_addition', '+')
+        self.jit.register_binary_intrinsic('integer_multiplication', '*')
+        self.jit.register_binary_intrinsic('integer_division', '/')
+        self.jit.register_binary_intrinsic('integer_gt', '>')
+        self.jit.register_binary_intrinsic('integer_gte', '>=')
+        self.jit.register_binary_intrinsic('integer_lt', '<')
+        self.jit.register_binary_intrinsic('integer_lte', '<=')
+
+        self.jit.register_binary_intrinsic('float_addition', '+')
+        self.jit.register_binary_intrinsic('float_multiplication', '*')
+        self.jit.register_binary_intrinsic('float_division', '/')
+        self.jit.register_binary_intrinsic('float_gt', '>')
+        self.jit.register_binary_intrinsic('float_gte', '>=')
+        self.jit.register_binary_intrinsic('float_lt', '<')
+        self.jit.register_binary_intrinsic('float_lte', '<=')
+
     def execute_yields(self, username, operation, params, reply):
         try:
             self.success = True