|
@@ -133,6 +133,7 @@ class ModelverseJit(object):
|
|
|
self.thunks_enabled = True
|
|
|
self.jit_success_log_function = None
|
|
|
self.jit_code_log_function = None
|
|
|
+ self.compile_function_body = compile_function_body_baseline
|
|
|
|
|
|
def set_jit_enabled(self, is_enabled=True):
|
|
|
"""Enables or disables the JIT."""
|
|
@@ -176,6 +177,10 @@ class ModelverseJit(object):
|
|
|
Function definitions of jitted functions are then sent to said log."""
|
|
|
self.jit_code_log_function = log_function
|
|
|
|
|
|
+ def set_function_body_compiler(self, compile_function_body):
|
|
|
+ """Sets the function that the JIT uses to compile function bodies."""
|
|
|
+ self.compile_function_body = compile_function_body
|
|
|
+
|
|
|
def mark_entry_point(self, body_id):
|
|
|
"""Marks the node with the given identifier as a function entry point."""
|
|
|
if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
|
|
@@ -433,11 +438,9 @@ class ModelverseJit(object):
|
|
|
self.jitted_entry_points[body_id] = function_name
|
|
|
self.jit_globals[function_name] = None
|
|
|
|
|
|
- (parameter_ids, parameter_list, is_mutable), = yield [
|
|
|
+ (_, _, is_mutable), = yield [
|
|
|
("CALL_ARGS", [self.jit_signature, (body_id,)])]
|
|
|
|
|
|
- param_dict = dict(zip(parameter_ids, parameter_list))
|
|
|
- body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
|
|
|
dependencies = set([body_id])
|
|
|
self.compilation_dependencies[body_id] = dependencies
|
|
|
|
|
@@ -462,41 +465,13 @@ class ModelverseJit(object):
|
|
|
# We can't just JIT mutable functions. That'd be dangerous.
|
|
|
raise JitCompilationFailedException(
|
|
|
"Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
|
|
|
- body_bytecode, = yield [("CALL_ARGS", [self.jit_parse_bytecode, (body_id,)])]
|
|
|
- state = bytecode_to_tree.AnalysisState(
|
|
|
- self, body_id, task_root, body_param_dict,
|
|
|
- self.max_instructions)
|
|
|
- constructed_body, = yield [("CALL_ARGS", [state.analyze, (body_bytecode,)])]
|
|
|
- if self.jit_code_log_function is not None:
|
|
|
- bytecode_analyzer = bytecode_to_cfg.AnalysisState(param_dict)
|
|
|
- bytecode_analyzer.analyze(body_bytecode)
|
|
|
- yield [
|
|
|
- ("CALL_ARGS", [cfg_optimization.optimize, (bytecode_analyzer.entry_point, self)])]
|
|
|
- self.jit_code_log_function(
|
|
|
- "CFG for function '%s' at '%d':\n%s" % (
|
|
|
- function_name, body_id,
|
|
|
- '\n'.join(
|
|
|
- map(
|
|
|
- str,
|
|
|
- cfg_optimization.get_all_reachable_blocks(
|
|
|
- bytecode_analyzer.entry_point)))))
|
|
|
- cfg_func = create_bare_function(
|
|
|
- function_name, parameter_list,
|
|
|
- cfg_to_tree.lower_flow_graph(bytecode_analyzer.entry_point, self))
|
|
|
- self.jit_code_log_function(
|
|
|
- "Lowered CFG for function '%s' at '%d':\n%s" % (
|
|
|
- function_name, body_id, cfg_func))
|
|
|
+
|
|
|
+ constructed_function, = yield [
|
|
|
+ ("CALL_ARGS", [self.compile_function_body, (self, function_name, body_id, task_root)])]
|
|
|
|
|
|
yield [("END_TRY", [])]
|
|
|
del self.compilation_dependencies[body_id]
|
|
|
|
|
|
- # Optimize the function's body.
|
|
|
- constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
|
|
|
-
|
|
|
- # Wrap the tree IR in a function definition.
|
|
|
- constructed_function = create_function(
|
|
|
- function_name, parameter_list, param_dict, body_param_dict, constructed_body)
|
|
|
-
|
|
|
# Convert the function definition to Python code, and compile it.
|
|
|
compiled_function = self.jit_define_function(function_name, constructed_function)
|
|
|
|
|
@@ -639,3 +614,47 @@ class ModelverseJit(object):
|
|
|
tree_ir.LiteralInstruction('value')),
|
|
|
tree_ir.LiteralInstruction(jit_runtime.FUNCTION_BODY_KEY)),
|
|
|
global_name)
|
|
|
+
|
|
|
+def compile_function_body_baseline(jit, function_name, body_id, task_root):
|
|
|
+ """Have the baseline JIT compile the function with the given name and body id."""
|
|
|
+ (parameter_ids, parameter_list, _), = yield [
|
|
|
+ ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
|
|
|
+ param_dict = dict(zip(parameter_ids, parameter_list))
|
|
|
+ body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
|
|
|
+ body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
|
|
|
+ state = bytecode_to_tree.AnalysisState(
|
|
|
+ jit, body_id, task_root, body_param_dict,
|
|
|
+ jit.max_instructions)
|
|
|
+ constructed_body, = yield [("CALL_ARGS", [state.analyze, (body_bytecode,)])]
|
|
|
+
|
|
|
+ # Optimize the function's body.
|
|
|
+ constructed_body, = yield [("CALL_ARGS", [optimize_tree_ir, (constructed_body,)])]
|
|
|
+
|
|
|
+ # Wrap the tree IR in a function definition.
|
|
|
+ raise primitive_functions.PrimitiveFinished(
|
|
|
+ create_function(
|
|
|
+ function_name, parameter_list, param_dict, body_param_dict, constructed_body))
|
|
|
+
|
|
|
+def compile_function_body_fast(jit, function_name, body_id, _):
|
|
|
+ """Have the fast JIT compile the function with the given name and body id."""
|
|
|
+ (parameter_ids, parameter_list, _), = yield [
|
|
|
+ ("CALL_ARGS", [jit.jit_signature, (body_id,)])]
|
|
|
+ param_dict = dict(zip(parameter_ids, parameter_list))
|
|
|
+ body_bytecode, = yield [("CALL_ARGS", [jit.jit_parse_bytecode, (body_id,)])]
|
|
|
+ bytecode_analyzer = bytecode_to_cfg.AnalysisState(param_dict)
|
|
|
+ bytecode_analyzer.analyze(body_bytecode)
|
|
|
+ yield [
|
|
|
+ ("CALL_ARGS", [cfg_optimization.optimize, (bytecode_analyzer.entry_point, jit)])]
|
|
|
+ if jit.jit_code_log_function is not None:
|
|
|
+ jit.jit_code_log_function(
|
|
|
+ "CFG for function '%s' at '%d':\n%s" % (
|
|
|
+ function_name, body_id,
|
|
|
+ '\n'.join(
|
|
|
+ map(
|
|
|
+ str,
|
|
|
+ cfg_optimization.get_all_reachable_blocks(
|
|
|
+ bytecode_analyzer.entry_point)))))
|
|
|
+ raise primitive_functions.PrimitiveFinished(
|
|
|
+ create_bare_function(
|
|
|
+ function_name, parameter_list,
|
|
|
+ cfg_to_tree.lower_flow_graph(bytecode_analyzer.entry_point, jit)))
|