Просмотр исходного кода

Change how the JIT stores jitted functions

jonathanvdc 8 лет назад
Родитель
Сommit
361b26aa10
2 измененных файлов с 13 добавлено и 8 удалено
  1. 12 7
      kernel/modelverse_jit/jit.py
  2. 1 1
      kernel/modelverse_kernel/main.py

+ 12 - 7
kernel/modelverse_jit/jit.py

@@ -45,20 +45,26 @@ class ModelverseJit(object):
         if body_id in self.todo_entry_points:
             self.todo_entry_points.remove(body_id)
 
-    def register_compiled(self, body_id, compiled):
+    def register_compiled(self, body_id, compiled_function, function_name=None):
         """Registers a compiled entry point with the JIT."""
-        self.jitted_entry_points[body_id] = compiled
+        if function_name is None:
+            function_name = 'jit_func%d' % self.jit_count
+            self.jit_count += 1
+
+        self.jitted_entry_points[body_id] = function_name
+        self.jit_globals[function_name] = compiled_function
         if body_id in self.todo_entry_points:
             self.todo_entry_points.remove(body_id)
 
-    def try_jit(self, body_id, parameter_list):
+    def jit_compile(self, body_id, parameter_list):
         """Tries to jit the function defined by the given entry point id and parameter list."""
         # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
         # pylint: disable=I0011,W0122
 
         if body_id in self.jitted_entry_points:
             # We have already compiled this function.
-            raise primitive_functions.PrimitiveFinished(self.jitted_entry_points[body_id])
+            raise primitive_functions.PrimitiveFinished(
+                self.jit_globals[self.jitted_entry_points[body_id]])
         elif body_id in self.no_jit_entry_points:
             # We're not allowed to jit this function or have tried and failed before.
             raise JitCompilationFailedException(
@@ -86,11 +92,10 @@ class ModelverseJit(object):
         exec(str(constructed_function), self.jit_globals)
         # Extract the compiled function from the JIT global state.
         compiled_function = self.jit_globals[constructed_function.name]
+        # Save the compiled function so we can reuse it later.
+        self.jitted_entry_points[body_id] = constructed_function.name
 
         print(constructed_function)
-
-        # Save the compiled function so we can reuse it later.
-        self.jitted_entry_points[body_id] = compiled_function
         raise primitive_functions.PrimitiveFinished(compiled_function)
 
 class AnalysisState(object):

+ 1 - 1
kernel/modelverse_kernel/main.py

@@ -144,7 +144,7 @@ class ModelverseKernel(object):
                 prim = self.compiled[inst](**parameters)
             else:
                 try:
-                    jit_gen = self.jit.try_jit(inst, dict_keys)
+                    jit_gen = self.jit.jit_compile(inst, dict_keys)
                     inp = None
                     while 1:
                         inp = yield jit_gen.send(inp)