Переглянути джерело

Merge branch 'jit' of msdl.uantwerpen.be:jonathanvdc/modelverse into yentl

Yentl Van Tendeloo 8 роки тому
батько
коміт
b9ed72ac69

+ 76 - 13
kernel/modelverse_jit/jit.py

@@ -82,6 +82,10 @@ def optimize_tree_ir(instruction):
     """Optimizes an IR tree."""
     return map_and_simplify_generator(expand_constant_read, instruction)
 
+def print_value(val):
+    """A thin wrapper around 'print'."""
+    print(val)
+
 class ModelverseJit(object):
     """A high-level interface to the modelverse JIT compiler."""
     def __init__(self, max_instructions=None, compiled_function_lookup=None):
@@ -105,6 +109,8 @@ class ModelverseJit(object):
         self.tracing_enabled = False
         self.input_function_enabled = False
         self.nop_insertion_enabled = True
+        self.jit_success_log_function = None
+        self.jit_code_log_function = None
 
     def set_jit_enabled(self, is_enabled=True):
         """Enables or disables the JIT."""
@@ -129,6 +135,16 @@ class ModelverseJit(object):
            Modelverse server an opportunity to interrupt the currently running code."""
         self.nop_insertion_enabled = is_enabled
 
+    def set_jit_success_log(self, log_function=print_value):
+        """Configures this JIT instance with a function that prints output to a log.
+           Success and failure messages for specific functions are then sent to said log."""
+        self.jit_success_log_function = log_function
+
+    def set_jit_code_log(self, log_function=print_value):
+        """Configures this JIT instance with a function that prints output to a log.
+           Function definitions of jitted functions are then sent to said log."""
+        self.jit_code_log_function = log_function
+
     def mark_entry_point(self, body_id):
         """Marks the node with the given identifier as a function entry point."""
         if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
@@ -246,19 +262,23 @@ class ModelverseJit(object):
                 tree_ir.LoadGlobalInstruction(target_type.__name__),
                 [tree_ir.ReadValueInstruction(a)])))
 
-    def jit_parameters(self, body_id):
-        """Acquires the parameter list for the given body id node."""
+    def jit_signature(self, body_id):
+        """Acquires the signature for the given body id node, which consists of the
+           parameter variables, parameter name and a flag that tells if the given function
+           is mutable."""
         if body_id not in self.jitted_parameters:
-            signature_id, = yield [("RRD", [body_id, "body"])]
+            signature_id, = yield [("RRD", [body_id, jit_runtime.FUNCTION_BODY_KEY])]
             signature_id = signature_id[0]
-            param_set_id, = yield [("RD", [signature_id, "params"])]
+            param_set_id, is_mutable = yield [
+                ("RD", [signature_id, "params"]),
+                ("RD", [signature_id, jit_runtime.MUTABLE_FUNCTION_KEY])]
             if param_set_id is None:
-                self.jitted_parameters[body_id] = ([], [])
+                self.jitted_parameters[body_id] = ([], [], is_mutable)
             else:
                 param_name_ids, = yield [("RDK", [param_set_id])]
                 param_names = yield [("RV", [n]) for n in param_name_ids]
                 param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
-                self.jitted_parameters[body_id] = (param_vars, param_names)
+                self.jitted_parameters[body_id] = (param_vars, param_names, is_mutable)
 
         raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])
 
@@ -291,7 +311,8 @@ class ModelverseJit(object):
         self.jitted_entry_points[body_id] = function_name
         self.jit_globals[function_name] = None
 
-        (parameter_ids, parameter_list), = yield [("CALL_ARGS", [self.jit_parameters, (body_id,)])]
+        (parameter_ids, parameter_list, is_mutable), = yield [
+            ("CALL_ARGS", [self.jit_signature, (body_id,)])]
 
         param_dict = dict(zip(parameter_ids, parameter_list))
         body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
@@ -306,12 +327,19 @@ class ModelverseJit(object):
                 if dep in self.jitted_entry_points:
                     del self.jitted_entry_points[dep]
 
-            raise JitCompilationFailedException(
-                "%s (function '%s' at %d)" % (exception.message, function_name, body_id))
+            failure_message = "%s (function '%s' at %d)" % (
+                exception.message, function_name, body_id)
+            if self.jit_success_log_function is not None:
+                self.jit_success_log_function('JIT compilation failed: %s' % failure_message)
+            raise JitCompilationFailedException(failure_message)
 
         # Try to analyze the function's body.
         yield [("TRY", [])]
         yield [("CATCH", [JitCompilationFailedException, handle_jit_exception])]
+        if is_mutable:
+            # We can't just JIT mutable functions. That'd be dangerous.
+            raise JitCompilationFailedException(
+                "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
         state = AnalysisState(
             self, body_id, user_root, body_param_dict,
             self.max_instructions)
@@ -360,7 +388,13 @@ class ModelverseJit(object):
         # Extract the compiled function from the JIT global state.
         compiled_function = self.jit_globals[function_name]
 
-        # print(constructed_function)
+        if self.jit_success_log_function is not None:
+            self.jit_success_log_function(
+                "JIT compilation successful: (function '%s' at %d)" % (function_name, body_id))
+
+        if self.jit_code_log_function is not None:
+            self.jit_code_log_function(constructed_function)
+
         raise primitive_functions.PrimitiveFinished(compiled_function)
 
 class AnalysisState(object):
@@ -375,6 +409,7 @@ class AnalysisState(object):
         self.jit = jit
         self.local_mapping = local_mapping
         self.function_name = jit.jitted_entry_points[body_id]
+        self.enclosing_loop_instruction = None
 
     def get_local_name(self, local_id):
         """Gets the name for a local with the given id."""
@@ -510,7 +545,15 @@ class AnalysisState(object):
             ("RD", [instruction_id, "cond"]),
             ("RD", [instruction_id, "body"])]
 
-        (cond_r, body_r), = yield [("CALL_ARGS", [self.analyze_all, ([cond, body],)])]
+        # Analyze the condition.
+        cond_r, = yield [("CALL_ARGS", [self.analyze, (cond,)])]
+        # Store the old enclosing loop on the stack, and make this loop the
+        # new enclosing loop.
+        old_loop_instruction = self.enclosing_loop_instruction
+        self.enclosing_loop_instruction = instruction_id
+        body_r, = yield [("CALL_ARGS", [self.analyze, (body,)])]
+        # Restore hte old enclosing loop.
+        self.enclosing_loop_instruction = old_loop_instruction
         if self.jit.nop_insertion_enabled:
             create_loop_body = lambda check, body: tree_ir.create_block(
                 check,
@@ -855,7 +898,7 @@ class AnalysisState(object):
         """Tries to analyze a direct 'call' instruction."""
         self.register_function_var(callee_id)
 
-        body_id, = yield [("RD", [callee_id, "body"])]
+        body_id, = yield [("RD", [callee_id, jit_runtime.FUNCTION_BODY_KEY])]
 
         # Make this function dependent on the callee.
         if body_id in self.jit.compilation_dependencies:
@@ -982,6 +1025,24 @@ class AnalysisState(object):
         yield [("END_TRY", [])]
         raise primitive_functions.PrimitiveFinished(result)
 
+    def analyze_break(self, instruction_id):
+        """Tries to analyze the given 'break' instruction."""
+        target_instruction_id, = yield [("RD", [instruction_id, "while"])]
+        if target_instruction_id == self.enclosing_loop_instruction:
+            raise primitive_functions.PrimitiveFinished(tree_ir.BreakInstruction())
+        else:
+            raise JitCompilationFailedException(
+                "Multilevel 'break' is not supported by the baseline JIT.")
+
+    def analyze_continue(self, instruction_id):
+        """Tries to analyze the given 'continue' instruction."""
+        target_instruction_id, = yield [("RD", [instruction_id, "while"])]
+        if target_instruction_id == self.enclosing_loop_instruction:
+            raise primitive_functions.PrimitiveFinished(tree_ir.ContinueInstruction())
+        else:
+            raise JitCompilationFailedException(
+                "Multilevel 'continue' is not supported by the baseline JIT.")
+
     instruction_analyzers = {
         'if' : analyze_if,
         'while' : analyze_while,
@@ -994,6 +1055,8 @@ class AnalysisState(object):
         'access' : analyze_access,
         'output' : analyze_output,
         'input' : analyze_input,
-        'call' : analyze_call
+        'call' : analyze_call,
+        'break' : analyze_break,
+        'continue' : analyze_continue
     }
 

+ 19 - 5
kernel/modelverse_jit/runtime.py

@@ -4,20 +4,34 @@ class JitCompilationFailedException(Exception):
     """A type of exception that is raised when the jit fails to compile a function."""
     pass
 
+MUTABLE_FUNCTION_KEY = "mutable"
+"""A dictionary key for functions that are mutable."""
+
+FUNCTION_BODY_KEY = "body"
+"""A dictionary key for function bodies."""
+
 def call_function(function_id, named_arguments, **kwargs):
     """Runs the function with the given id, passing it the specified argument dictionary."""
     user_root = kwargs['user_root']
     kernel = kwargs['mvk']
-    body_id, = yield [("RD", [function_id, "body"])]
-    kernel.jit.mark_entry_point(body_id)
+    body_id, is_mutable = yield [
+        ("RD", [function_id, FUNCTION_BODY_KEY]),
+        ("RD", [function_id, MUTABLE_FUNCTION_KEY])]
 
     # Try to jit the function here. We might be able to avoid building the stack
     # frame.
     def handle_jit_failed(_):
+        """Interprets the function."""
         interpreter_args = {'function_id' : function_id, 'named_arguments' : named_arguments}
         interpreter_args.update(kwargs)
         yield [("TAIL_CALL_KWARGS", [interpret_function, interpreter_args])]
 
+    if is_mutable is not None:
+        kernel.jit.mark_no_jit(body_id)
+        yield [("TAIL_CALL_ARGS", [handle_jit_failed, ()])]
+    else:
+        kernel.jit.mark_entry_point(body_id)
+
     yield [("TRY", [])]
     yield [("CATCH", [JitCompilationFailedException, handle_jit_failed])]
     # Try to compile.
@@ -34,7 +48,7 @@ def interpret_function(function_id, named_arguments, **kwargs):
     user_root = kwargs['user_root']
     kernel = kwargs['mvk']
     user_frame, = yield [("RD", [user_root, "frame"])]
-    inst, body_id = yield [("RD", [user_frame, "IP"]), ("RD", [function_id, "body"])]
+    inst, body_id = yield [("RD", [user_frame, "IP"]), ("RD", [function_id, FUNCTION_BODY_KEY])]
     kernel.jit.mark_entry_point(body_id)
 
     # Create a new stack frame.
@@ -66,8 +80,8 @@ def interpret_function(function_id, named_arguments, **kwargs):
                           ]
 
     # Put the parameters in the new stack frame's symbol table.
-    (parameter_vars, parameter_names), = yield [
-        ("CALL_ARGS", [kernel.jit.jit_parameters, (body_id,)])]
+    (parameter_vars, parameter_names, _), = yield [
+        ("CALL_ARGS", [kernel.jit.jit_signature, (body_id,)])]
     parameter_dict = dict(zip(parameter_names, parameter_vars))
 
     for (key, value) in named_arguments.items():

+ 21 - 4
kernel/modelverse_kernel/main.py

@@ -52,15 +52,33 @@ class ModelverseKernel(object):
         #
         #     self.jit.allow_direct_calls(False)
         #
+        # To make the JIT compile 'input' instructions as calls to
+        # modelverse_jit.runtime.get_input, uncomment the line below:
+        #
+        #     self.jit.use_input_function()
+        #
         # To enable tracing in the JIT (for debugging purposes), uncomment
         # the line below:
         #
         #     self.jit.enable_tracing()
         #
-        # To make the JIT compile 'input' instructions as calls to
-        # modelverse_jit.runtime.get_input, uncomment the line below:
+        # To make the JIT print JIT successes and errors to the command-line,
+        # uncomment the line below:
         #
-        #     self.jit.use_input_function()
+        #     self.jit.set_jit_success_log()
+        #
+        # If you want, you can use a custom logging function:
+        #
+        #     self.jit.set_jit_success_log(logging_function)
+        #
+        # To make the JIT print jitted code to the command-line, uncomment the
+        # line below:
+        #
+        #     self.jit.set_jit_code_log()
+        #
+        # If you want, you can use a custom logging function:
+        #
+        #     self.jit.set_jit_code_log(logging_function)
         #
 
         self.debug_info = defaultdict(list)
@@ -126,7 +144,6 @@ class ModelverseKernel(object):
 
             def handle_jit_failed(exception):
                 # Try again, but this time without the JIT.
-                # print(exception.message)
                 gen = self.get_inst_phase_generator(inst_v, self.phase_v, user_root)
                 yield [("TAIL_CALL", [gen])]
 

+ 1 - 16
performance/code/test_harness.alc

@@ -3,26 +3,11 @@ include "jit.alh"
 
 Void function test_main()
 
-Void function call_function(function_name : String):
-	// Resolve the specified function, and execute it.
-	Element func
-	func = resolve(function_name)
-	func()
-	return!
-
 Void function main():
-	String config
 	Integer start_time
 	Integer end_time
-	config = input()
-	// if (config == "interpreter"):
-		// set_jit_enabled(False)
-
 	start_time = time()
-	// HACK: use `call_function` to hide what would otherwise be a direct call to `test_main`
-	// from the JIT. This prevents the JIT from compiling `test_main` _before_ `config` has
-	// been analyzed.
-	call_function("test_main")
+	test_main()
 	end_time = time()
 	output(end_time - start_time)
 	

+ 1 - 1
performance/utils.py

@@ -244,7 +244,7 @@ def run_perf_test(files, parameters, optimization_level, n_iterations=1):
     for _ in xrange(n_iterations):
         result += float(
             run_file_single_output(
-                files, [optimization_level] + parameters + [0], 'CO',
+                files, parameters + [0], 'CO',
                 optimization_level)) / float(n_iterations)
     return result