Browse Source

Add a switch to compile 'input' instructions as calls

jonathanvdc 8 years ago
parent
commit
774e604290

+ 24 - 7
kernel/modelverse_jit/jit.py

@@ -13,6 +13,9 @@ KWARGS_PARAMETER_NAME = "kwargs"
 CALL_FUNCTION_NAME = "__call_function"
 """The name of the '__call_function' function, in the jitted function scope."""
 
+GET_INPUT_FUNCTION_NAME = "__get_input"
+"""The name of the '__get_input' function, in the jitted function scope."""
+
 def get_parameter_names(compiled_function):
     """Gets the given compiled function's parameter names."""
     if hasattr(compiled_function, '__code__'):
@@ -82,7 +85,8 @@ class ModelverseJit(object):
         self.jitted_parameters = {}
         self.jit_globals = {
             'PrimitiveFinished' : primitive_functions.PrimitiveFinished,
-            CALL_FUNCTION_NAME : jit_runtime.call_function
+            CALL_FUNCTION_NAME : jit_runtime.call_function,
+            GET_INPUT_FUNCTION_NAME : jit_runtime.get_input
         }
         self.jit_count = 0
         self.max_instructions = max_instructions
@@ -93,6 +97,7 @@ class ModelverseJit(object):
         self.jit_enabled = True
         self.direct_calls_allowed = True
         self.tracing_enabled = False
+        self.input_function_enabled = False
 
     def set_jit_enabled(self, is_enabled=True):
         """Enables or disables the JIT."""
@@ -102,6 +107,10 @@ class ModelverseJit(object):
         """Allows or disallows direct calls from jitted to jitted code."""
         self.direct_calls_allowed = is_allowed
 
+    def use_input_function(self, is_enabled=True):
+        """Configures the JIT to compile 'input' instructions as function calls."""
+        self.input_function_enabled = is_enabled
+
     def enable_tracing(self, is_enabled=True):
         """Enables or disables tracing for jitted code."""
         self.tracing_enabled = is_enabled
@@ -236,6 +245,12 @@ class ModelverseJit(object):
                 'Cannot jit function %s at %d because it is marked non-jittable.' % (
                     '' if suggested_name is None else "'" + suggested_name + "'",
                     body_id))
+        elif not self.jit_enabled:
+            # We're not allowed to jit anything.
+            raise JitCompilationFailedException(
+                'Cannot jit function %s at %d because the JIT has been disabled.' % (
+                    '' if suggested_name is None else "'" + suggested_name + "'",
+                    body_id))
 
         # Generate a name for the function we're about to analyze, and pretend that
         # it already exists. (we need to do this for recursive functions)
@@ -376,6 +391,8 @@ class AnalysisState(object):
             # Analyze the instruction itself.
             outer_result, = yield [
                 ("CALL_ARGS", [self.instruction_analyzers[instruction_val], (self, instruction_id)])]
+            if self.jit.tracing_enabled:
+                outer_result = tree_ir.with_debug_info_trace(outer_result, debug_info, self.function_name)
             # Check if the instruction has a 'next' instruction.
             next_instr, = yield [("RD", [instruction_id, "next"])]
             if next_instr is None:
@@ -523,12 +540,12 @@ class AnalysisState(object):
         """Tries to analyze the given 'input' instruction."""
 
         # Possible alternative to the explicit syntax tree:
-        #
-        # raise primitive_functions.PrimitiveFinished(
-        #     tree_ir.create_jit_call(
-        #         tree_ir.LoadGlobalInstruction('__get_input'),
-        #         [],
-        #         tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
+        if self.jit.input_function_enabled:
+            raise primitive_functions.PrimitiveFinished(
+                tree_ir.create_jit_call(
+                    tree_ir.LoadGlobalInstruction(GET_INPUT_FUNCTION_NAME),
+                    [],
+                    tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
 
         # The plan is to generate this tree:
         #

+ 14 - 0
kernel/modelverse_jit/runtime.py

@@ -89,3 +89,17 @@ def call_function(function_id, named_arguments, **kwargs):
         result, = yield [("CALL_ARGS", [kernel.execute_rule, (username,)])]
         # An instruction has completed. Forward it.
         yield result
+
+def get_input(**parameters):
+    """Retrieves input."""
+    mvk = parameters["mvk"]
+    user_root = parameters["user_root"]
+    while 1:
+        yield [("CALL_ARGS", [mvk.input_init, (user_root,)])]
+        # Finished
+        if mvk.success:
+            # Got some input, so we can access it
+            raise primitive_functions.PrimitiveFinished(mvk.input_value)
+        else:
+            # No input, so yield None but don't stop
+            yield None

+ 1 - 1
kernel/modelverse_jit/tree_ir.py

@@ -1188,7 +1188,7 @@ def with_debug_info_trace(instruction, debug_info, function_name):
         return instruction
     else:
         if debug_info is None:
-            debug_info = 'unknown location'
+            debug_info = 'unknown location '
         if function_name is None:
             function_name = 'unknown function'
         return create_block(

+ 10 - 2
kernel/modelverse_kernel/main.py

@@ -44,15 +44,24 @@ class ModelverseKernel(object):
         jit_intrinsics.register_intrinsics(self.jit)
 
         # To disable the JIT, uncomment the line below:
+        #
         #     self.jit.set_jit_enabled(False)
         #
         # To disable direct calls in the JIT, uncomment the line below:
+        #
         #     self.jit.allow_direct_calls(False)
         #
         # To enable tracing in the JIT (for debugging purposes), uncomment
         # the line below:
         #
         #     self.jit.enable_tracing()
+        #
+        # To make the JIT compile 'input' instructions as calls to
+        # modelverse_jit.runtime.get_input, uncomment the line below:
+        #
+        #     self.jit.use_input_function()
+        #
+
         self.debug_info = defaultdict(list)
 
     def execute_yields(self, username, operation, params, reply):
@@ -114,8 +123,7 @@ class ModelverseKernel(object):
             # Try again, but this time without the JIT.
             # print(exception.message)
             gen = self.get_inst_phase_generator(inst_v, self.phase_v, user_root)
-            yield [("CALL", [gen])]
-            raise primitive_functions.PrimitiveFinished(None)
+            yield [("TAIL_CALL", [gen])]
 
         yield [("TRY", [])]
         yield [("CATCH", [jit.JitCompilationFailedException, handle_jit_failed])]