|
@@ -82,6 +82,10 @@ def optimize_tree_ir(instruction):
|
|
|
"""Optimizes an IR tree."""
|
|
|
return map_and_simplify_generator(expand_constant_read, instruction)
|
|
|
|
|
|
+def print_value(val):
|
|
|
+ """A thin wrapper around 'print'."""
|
|
|
+ print(val)
|
|
|
+
|
|
|
class ModelverseJit(object):
|
|
|
"""A high-level interface to the modelverse JIT compiler."""
|
|
|
def __init__(self, max_instructions=None, compiled_function_lookup=None):
|
|
@@ -105,6 +109,8 @@ class ModelverseJit(object):
|
|
|
self.tracing_enabled = False
|
|
|
self.input_function_enabled = False
|
|
|
self.nop_insertion_enabled = True
|
|
|
+ self.jit_success_log_function = None
|
|
|
+ self.jit_code_log_function = None
|
|
|
|
|
|
def set_jit_enabled(self, is_enabled=True):
|
|
|
"""Enables or disables the JIT."""
|
|
@@ -129,6 +135,16 @@ class ModelverseJit(object):
|
|
|
Modelverse server an opportunity to interrupt the currently running code."""
|
|
|
self.nop_insertion_enabled = is_enabled
|
|
|
|
|
|
+ def set_jit_success_log(self, log_function=print_value):
|
|
|
+ """Configures this JIT instance with a function that prints output to a log.
|
|
|
+ Success and failure messages for specific functions are then sent to said log."""
|
|
|
+ self.jit_success_log_function = log_function
|
|
|
+
|
|
|
+ def set_jit_code_log(self, log_function=print_value):
|
|
|
+ """Configures this JIT instance with a function that prints output to a log.
|
|
|
+ Function definitions of jitted functions are then sent to said log."""
|
|
|
+ self.jit_code_log_function = log_function
|
|
|
+
|
|
|
def mark_entry_point(self, body_id):
|
|
|
"""Marks the node with the given identifier as a function entry point."""
|
|
|
if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
|
|
@@ -246,19 +262,23 @@ class ModelverseJit(object):
|
|
|
tree_ir.LoadGlobalInstruction(target_type.__name__),
|
|
|
[tree_ir.ReadValueInstruction(a)])))
|
|
|
|
|
|
- def jit_parameters(self, body_id):
|
|
|
- """Acquires the parameter list for the given body id node."""
|
|
|
+ def jit_signature(self, body_id):
|
|
|
+ """Acquires the signature for the given body id node, which consists of the
|
|
|
+ parameter variables, parameter name and a flag that tells if the given function
|
|
|
+ is mutable."""
|
|
|
if body_id not in self.jitted_parameters:
|
|
|
- signature_id, = yield [("RRD", [body_id, "body"])]
|
|
|
+ signature_id, = yield [("RRD", [body_id, jit_runtime.FUNCTION_BODY_KEY])]
|
|
|
signature_id = signature_id[0]
|
|
|
- param_set_id, = yield [("RD", [signature_id, "params"])]
|
|
|
+ param_set_id, is_mutable = yield [
|
|
|
+ ("RD", [signature_id, "params"]),
|
|
|
+ ("RD", [signature_id, jit_runtime.MUTABLE_FUNCTION_KEY])]
|
|
|
if param_set_id is None:
|
|
|
- self.jitted_parameters[body_id] = ([], [])
|
|
|
+ self.jitted_parameters[body_id] = ([], [], is_mutable)
|
|
|
else:
|
|
|
param_name_ids, = yield [("RDK", [param_set_id])]
|
|
|
param_names = yield [("RV", [n]) for n in param_name_ids]
|
|
|
param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
|
|
|
- self.jitted_parameters[body_id] = (param_vars, param_names)
|
|
|
+ self.jitted_parameters[body_id] = (param_vars, param_names, is_mutable)
|
|
|
|
|
|
raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])
|
|
|
|
|
@@ -291,7 +311,8 @@ class ModelverseJit(object):
|
|
|
self.jitted_entry_points[body_id] = function_name
|
|
|
self.jit_globals[function_name] = None
|
|
|
|
|
|
- (parameter_ids, parameter_list), = yield [("CALL_ARGS", [self.jit_parameters, (body_id,)])]
|
|
|
+ (parameter_ids, parameter_list, is_mutable), = yield [
|
|
|
+ ("CALL_ARGS", [self.jit_signature, (body_id,)])]
|
|
|
|
|
|
param_dict = dict(zip(parameter_ids, parameter_list))
|
|
|
body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
|
|
@@ -306,12 +327,19 @@ class ModelverseJit(object):
|
|
|
if dep in self.jitted_entry_points:
|
|
|
del self.jitted_entry_points[dep]
|
|
|
|
|
|
- raise JitCompilationFailedException(
|
|
|
- "%s (function '%s' at %d)" % (exception.message, function_name, body_id))
|
|
|
+ failure_message = "%s (function '%s' at %d)" % (
|
|
|
+ exception.message, function_name, body_id)
|
|
|
+ if self.jit_success_log_function is not None:
|
|
|
+ self.jit_success_log_function('JIT compilation failed: %s' % failure_message)
|
|
|
+ raise JitCompilationFailedException(failure_message)
|
|
|
|
|
|
# Try to analyze the function's body.
|
|
|
yield [("TRY", [])]
|
|
|
yield [("CATCH", [JitCompilationFailedException, handle_jit_exception])]
|
|
|
+ if is_mutable:
|
|
|
+ # We can't just JIT mutable functions. That'd be dangerous.
|
|
|
+ raise JitCompilationFailedException(
|
|
|
+ "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
|
|
|
state = AnalysisState(
|
|
|
self, body_id, user_root, body_param_dict,
|
|
|
self.max_instructions)
|
|
@@ -360,7 +388,13 @@ class ModelverseJit(object):
|
|
|
# Extract the compiled function from the JIT global state.
|
|
|
compiled_function = self.jit_globals[function_name]
|
|
|
|
|
|
- # print(constructed_function)
|
|
|
+ if self.jit_success_log_function is not None:
|
|
|
+ self.jit_success_log_function(
|
|
|
+ "JIT compilation successful: (function '%s' at %d)" % (function_name, body_id))
|
|
|
+
|
|
|
+ if self.jit_code_log_function is not None:
|
|
|
+ self.jit_code_log_function(constructed_function)
|
|
|
+
|
|
|
raise primitive_functions.PrimitiveFinished(compiled_function)
|
|
|
|
|
|
class AnalysisState(object):
|
|
@@ -375,6 +409,7 @@ class AnalysisState(object):
|
|
|
self.jit = jit
|
|
|
self.local_mapping = local_mapping
|
|
|
self.function_name = jit.jitted_entry_points[body_id]
|
|
|
+ self.enclosing_loop_instruction = None
|
|
|
|
|
|
def get_local_name(self, local_id):
|
|
|
"""Gets the name for a local with the given id."""
|
|
@@ -510,7 +545,15 @@ class AnalysisState(object):
|
|
|
("RD", [instruction_id, "cond"]),
|
|
|
("RD", [instruction_id, "body"])]
|
|
|
|
|
|
- (cond_r, body_r), = yield [("CALL_ARGS", [self.analyze_all, ([cond, body],)])]
|
|
|
+ # Analyze the condition.
|
|
|
+ cond_r, = yield [("CALL_ARGS", [self.analyze, (cond,)])]
|
|
|
+ # Store the old enclosing loop on the stack, and make this loop the
|
|
|
+ # new enclosing loop.
|
|
|
+ old_loop_instruction = self.enclosing_loop_instruction
|
|
|
+ self.enclosing_loop_instruction = instruction_id
|
|
|
+ body_r, = yield [("CALL_ARGS", [self.analyze, (body,)])]
|
|
|
+ # Restore hte old enclosing loop.
|
|
|
+ self.enclosing_loop_instruction = old_loop_instruction
|
|
|
if self.jit.nop_insertion_enabled:
|
|
|
create_loop_body = lambda check, body: tree_ir.create_block(
|
|
|
check,
|
|
@@ -855,7 +898,7 @@ class AnalysisState(object):
|
|
|
"""Tries to analyze a direct 'call' instruction."""
|
|
|
self.register_function_var(callee_id)
|
|
|
|
|
|
- body_id, = yield [("RD", [callee_id, "body"])]
|
|
|
+ body_id, = yield [("RD", [callee_id, jit_runtime.FUNCTION_BODY_KEY])]
|
|
|
|
|
|
# Make this function dependent on the callee.
|
|
|
if body_id in self.jit.compilation_dependencies:
|
|
@@ -982,6 +1025,24 @@ class AnalysisState(object):
|
|
|
yield [("END_TRY", [])]
|
|
|
raise primitive_functions.PrimitiveFinished(result)
|
|
|
|
|
|
+ def analyze_break(self, instruction_id):
|
|
|
+ """Tries to analyze the given 'break' instruction."""
|
|
|
+ target_instruction_id, = yield [("RD", [instruction_id, "while"])]
|
|
|
+ if target_instruction_id == self.enclosing_loop_instruction:
|
|
|
+ raise primitive_functions.PrimitiveFinished(tree_ir.BreakInstruction())
|
|
|
+ else:
|
|
|
+ raise JitCompilationFailedException(
|
|
|
+ "Multilevel 'break' is not supported by the baseline JIT.")
|
|
|
+
|
|
|
+ def analyze_continue(self, instruction_id):
|
|
|
+ """Tries to analyze the given 'continue' instruction."""
|
|
|
+ target_instruction_id, = yield [("RD", [instruction_id, "while"])]
|
|
|
+ if target_instruction_id == self.enclosing_loop_instruction:
|
|
|
+ raise primitive_functions.PrimitiveFinished(tree_ir.ContinueInstruction())
|
|
|
+ else:
|
|
|
+ raise JitCompilationFailedException(
|
|
|
+ "Multilevel 'continue' is not supported by the baseline JIT.")
|
|
|
+
|
|
|
instruction_analyzers = {
|
|
|
'if' : analyze_if,
|
|
|
'while' : analyze_while,
|
|
@@ -994,6 +1055,8 @@ class AnalysisState(object):
|
|
|
'access' : analyze_access,
|
|
|
'output' : analyze_output,
|
|
|
'input' : analyze_input,
|
|
|
- 'call' : analyze_call
|
|
|
+ 'call' : analyze_call,
|
|
|
+ 'break' : analyze_break,
|
|
|
+ 'continue' : analyze_continue
|
|
|
}
|
|
|
|