|
@@ -8,104 +8,9 @@ import modelverse_jit.runtime as jit_runtime
|
|
|
# in this module.
|
|
|
JitCompilationFailedException = jit_runtime.JitCompilationFailedException
|
|
|
|
|
|
-def map_and_simplify_generator(function, instruction):
|
|
|
- """Applies the given mapping function to every instruction in the tree
|
|
|
- that has the given instruction as root, and simplifies it on-the-fly.
|
|
|
-
|
|
|
- This is at least as powerful as first mapping and then simplifying, as
|
|
|
- maps and simplifications are interspersed.
|
|
|
-
|
|
|
- This function assumes that function creates a generator that returns by
|
|
|
- raising a primitive_functions.PrimitiveFinished."""
|
|
|
-
|
|
|
- # First handle the children by mapping on them and then simplifying them.
|
|
|
- new_children = []
|
|
|
- for inst in instruction.get_children():
|
|
|
- new_inst, = yield [("CALL_ARGS", [map_and_simplify_generator, (function, inst)])]
|
|
|
- new_children.append(new_inst)
|
|
|
-
|
|
|
- # Then apply the function to the top-level node.
|
|
|
- transformed, = yield [("CALL_ARGS", [function, (instruction.create(new_children),)])]
|
|
|
- # Finally, simplify the transformed top-level node.
|
|
|
- raise primitive_functions.PrimitiveFinished(transformed.simplify_node())
|
|
|
-
|
|
|
-def expand_constant_read(instruction):
|
|
|
- """Tries to replace a read of a constant node by a literal."""
|
|
|
- if isinstance(instruction, tree_ir.ReadValueInstruction) and \
|
|
|
- isinstance(instruction.node_id, tree_ir.LiteralInstruction):
|
|
|
- val, = yield [("RV", [instruction.node_id.literal])]
|
|
|
- raise primitive_functions.PrimitiveFinished(tree_ir.LiteralInstruction(val))
|
|
|
- else:
|
|
|
- raise primitive_functions.PrimitiveFinished(instruction)
|
|
|
-
|
|
|
-def optimize_tree_ir(instruction):
|
|
|
- """Optimizes an IR tree."""
|
|
|
- return map_and_simplify_generator(expand_constant_read, instruction)
|
|
|
-
|
|
|
-def create_bare_function(function_name, parameter_list, function_body):
|
|
|
- """Creates a function definition from the given function name, parameter list
|
|
|
- and function body. No prolog is included."""
|
|
|
- # Wrap the IR in a function definition, give it a unique name.
|
|
|
- return tree_ir.DefineFunctionInstruction(
|
|
|
- function_name,
|
|
|
- parameter_list + ['**' + jit_runtime.KWARGS_PARAMETER_NAME],
|
|
|
- function_body)
|
|
|
-
|
|
|
-def create_function(
|
|
|
- function_name, parameter_list, param_dict,
|
|
|
- body_param_dict, function_body, source_map_name=None,
|
|
|
- compatible_temporary_protects=False):
|
|
|
- """Creates a function from the given function name, parameter list,
|
|
|
- variable-to-parameter name map, variable-to-local name map and
|
|
|
- function body. An optional source map can be included, too."""
|
|
|
- # Write a prologue and prepend it to the generated function body.
|
|
|
- prolog_statements = []
|
|
|
- # If the source map is not None, then we should generate a "DEBUG_INFO"
|
|
|
- # request.
|
|
|
- if source_map_name is not None:
|
|
|
- prolog_statements.append(
|
|
|
- tree_ir.RegisterDebugInfoInstruction(
|
|
|
- tree_ir.LiteralInstruction(function_name),
|
|
|
- tree_ir.LoadGlobalInstruction(source_map_name),
|
|
|
- tree_ir.LiteralInstruction(jit_runtime.BASELINE_JIT_ORIGIN_NAME)))
|
|
|
-
|
|
|
- # Create a LOCALS_NODE_NAME node, and connect it to the user root.
|
|
|
- prolog_statements.append(
|
|
|
- tree_ir.create_new_local_node(
|
|
|
- jit_runtime.LOCALS_NODE_NAME,
|
|
|
- tree_ir.LoadIndexInstruction(
|
|
|
- tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
|
|
|
- tree_ir.LiteralInstruction('task_root')),
|
|
|
- jit_runtime.LOCALS_EDGE_NAME))
|
|
|
- for (key, val) in list(param_dict.items()):
|
|
|
- arg_ptr = tree_ir.create_new_local_node(
|
|
|
- body_param_dict[key],
|
|
|
- tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME))
|
|
|
- prolog_statements.append(arg_ptr)
|
|
|
- prolog_statements.append(
|
|
|
- tree_ir.CreateDictionaryEdgeInstruction(
|
|
|
- tree_ir.LoadLocalInstruction(body_param_dict[key]),
|
|
|
- tree_ir.LiteralInstruction('value'),
|
|
|
- tree_ir.LoadLocalInstruction(val)))
|
|
|
-
|
|
|
- constructed_body = tree_ir.create_block(
|
|
|
- *(prolog_statements + [function_body]))
|
|
|
-
|
|
|
- # Shield temporaries from the GC.
|
|
|
- constructed_body = tree_ir.protect_temporaries_from_gc(
|
|
|
- constructed_body,
|
|
|
- tree_ir.LoadLocalInstruction(jit_runtime.LOCALS_NODE_NAME),
|
|
|
- compatible_temporary_protects)
|
|
|
-
|
|
|
- return create_bare_function(function_name, parameter_list, constructed_body)
|
|
|
-
|
|
|
-def print_value(val):
|
|
|
- """A thin wrapper around 'print'."""
|
|
|
- print(val)
|
|
|
-
|
|
|
class ModelverseJit(object):
|
|
|
"""A high-level interface to the modelverse JIT compiler."""
|
|
|
- def __init__(self, max_instructions=None, compiled_function_lookup=None):
|
|
|
+ def __init__(self):
|
|
|
self.todo_entry_points = set()
|
|
|
self.no_jit_entry_points = set()
|
|
|
self.jitted_parameters = {}
|
|
@@ -118,22 +23,11 @@ class ModelverseJit(object):
|
|
|
self.global_functions = {}
|
|
|
# global_functions_inv maps body ids to global value names.
|
|
|
self.global_functions_inv = {}
|
|
|
- # bytecode_graphs maps body ids to their parsed bytecode graphs.
|
|
|
- self.bytecode_graphs = {}
|
|
|
# jitted_function_aliases maps body ids to known aliases.
|
|
|
self.jitted_function_aliases = defaultdict(set)
|
|
|
self.jit_count = 0
|
|
|
- self.max_instructions = max_instructions
|
|
|
- self.compiled_function_lookup = compiled_function_lookup
|
|
|
self.compilation_dependencies = {}
|
|
|
self.jit_enabled = True
|
|
|
- self.direct_calls_allowed = True
|
|
|
- self.tracing_enabled = False
|
|
|
- self.source_maps_enabled = True
|
|
|
- self.input_function_enabled = False
|
|
|
- self.nop_insertion_enabled = True
|
|
|
- self.jit_success_log_function = None
|
|
|
- self.jit_code_log_function = None
|
|
|
|
|
|
def set_jit_enabled(self, is_enabled=True):
|
|
|
"""Enables or disables the JIT."""
|
|
@@ -166,16 +60,6 @@ class ModelverseJit(object):
|
|
|
the currently running code."""
|
|
|
self.nop_insertion_enabled = is_enabled
|
|
|
|
|
|
- def set_jit_success_log(self, log_function=print_value):
|
|
|
- """Configures this JIT instance with a function that prints output to a log.
|
|
|
- Success and failure messages for specific functions are then sent to said log."""
|
|
|
- self.jit_success_log_function = log_function
|
|
|
-
|
|
|
- def set_jit_code_log(self, log_function=print_value):
|
|
|
- """Configures this JIT instance with a function that prints output to a log.
|
|
|
- Function definitions of jitted functions are then sent to said log."""
|
|
|
- self.jit_code_log_function = log_function
|
|
|
-
|
|
|
def set_function_body_compiler(self, compile_function_body):
|
|
|
"""Sets the function that the JIT uses to compile function bodies."""
|
|
|
self.compile_function_body = compile_function_body
|
|
@@ -348,16 +232,6 @@ class ModelverseJit(object):
|
|
|
|
|
|
raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])
|
|
|
|
|
|
- def jit_parse_bytecode(self, body_id):
|
|
|
- """Parses the given function body as a bytecode graph."""
|
|
|
- if body_id in self.bytecode_graphs:
|
|
|
- raise primitive_functions.PrimitiveFinished(self.bytecode_graphs[body_id])
|
|
|
-
|
|
|
- parser = bytecode_parser.BytecodeParser()
|
|
|
- result, = yield [("CALL_ARGS", [parser.parse_instruction, (body_id,)])]
|
|
|
- self.bytecode_graphs[body_id] = result
|
|
|
- raise primitive_functions.PrimitiveFinished(result)
|
|
|
-
|
|
|
def check_jittable(self, body_id, suggested_name=None):
|
|
|
"""Checks if the function with the given body id is obviously non-jittable. If it's
|
|
|
non-jittable, then a `JitCompilationFailedException` exception is thrown."""
|
|
@@ -376,73 +250,6 @@ class ModelverseJit(object):
|
|
|
'' if suggested_name is None else "'" + suggested_name + "'",
|
|
|
body_id))
|
|
|
|
|
|
- def jit_recompile(self, task_root, body_id, function_name, compile_function_body=None):
|
|
|
- """Replaces the function with the given name by compiling the bytecode at the given
|
|
|
- body id."""
|
|
|
- if compile_function_body is None:
|
|
|
- compile_function_body = self.compile_function_body
|
|
|
-
|
|
|
- self.check_jittable(body_id, function_name)
|
|
|
-
|
|
|
- # Generate a name for the function we're about to analyze, and pretend that
|
|
|
- # it already exists. (we need to do this for recursive functions)
|
|
|
- self.jitted_entry_points[body_id] = function_name
|
|
|
- self.jit_globals[function_name] = None
|
|
|
-
|
|
|
- (_, _, is_mutable), = yield [
|
|
|
- ("CALL_ARGS", [self.jit_signature, (body_id,)])]
|
|
|
-
|
|
|
- dependencies = set([body_id])
|
|
|
- self.compilation_dependencies[body_id] = dependencies
|
|
|
-
|
|
|
- def handle_jit_exception(exception):
|
|
|
- # If analysis fails, then a JitCompilationFailedException will be thrown.
|
|
|
- print("EXCEPTION with mutable")
|
|
|
- del self.compilation_dependencies[body_id]
|
|
|
- for dep in dependencies:
|
|
|
- self.mark_no_jit(dep)
|
|
|
- if dep in self.jitted_entry_points:
|
|
|
- del self.jitted_entry_points[dep]
|
|
|
-
|
|
|
- failure_message = "%s (function '%s' at %d)" % (
|
|
|
- str(exception), function_name, body_id)
|
|
|
- if self.jit_success_log_function is not None:
|
|
|
- self.jit_success_log_function('JIT compilation failed: %s' % failure_message)
|
|
|
- raise JitCompilationFailedException(failure_message)
|
|
|
-
|
|
|
- # Try to analyze the function's body.
|
|
|
- yield [("TRY", [])]
|
|
|
- yield [("CATCH", [JitCompilationFailedException, handle_jit_exception])]
|
|
|
- if is_mutable:
|
|
|
- # We can't just JIT mutable functions. That'd be dangerous.
|
|
|
- raise JitCompilationFailedException(
|
|
|
- "Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
|
|
|
-
|
|
|
- compiled_function, = yield [
|
|
|
- ("CALL_ARGS", [compile_function_body, (self, function_name, body_id, task_root)])]
|
|
|
-
|
|
|
- yield [("END_TRY", [])]
|
|
|
- del self.compilation_dependencies[body_id]
|
|
|
-
|
|
|
- if self.jit_success_log_function is not None:
|
|
|
- assert self.jitted_entry_points[body_id] == function_name
|
|
|
- self.jit_success_log_function(
|
|
|
- "JIT compilation successful: (function '%s' at %d)" % (function_name, body_id))
|
|
|
-
|
|
|
- raise primitive_functions.PrimitiveFinished(compiled_function)
|
|
|
-
|
|
|
- def get_source_map_name(self, function_name):
|
|
|
- """Gets the name of the given jitted function's source map. None is returned if source maps
|
|
|
- are disabled."""
|
|
|
- if self.source_maps_enabled:
|
|
|
- return function_name + "_source_map"
|
|
|
- else:
|
|
|
- return None
|
|
|
-
|
|
|
- def get_can_rejit_name(self, function_name):
|
|
|
- """Gets the name of the given jitted function's can-rejit flag."""
|
|
|
- return function_name + "_can_rejit"
|
|
|
-
|
|
|
def jit_define_function(self, function_name, function_def):
|
|
|
"""Converts the given tree-IR function definition to Python code, defines it,
|
|
|
and extracts the resulting function."""
|
|
@@ -465,47 +272,3 @@ class ModelverseJit(object):
|
|
|
def jit_delete_function(self, function_name):
|
|
|
"""Deletes the function with the given function name."""
|
|
|
del self.jit_globals[function_name]
|
|
|
-
|
|
|
- def jit_compile(self, task_root, body_id, suggested_name=None):
|
|
|
- """Tries to jit the function defined by the given entry point id and parameter list."""
|
|
|
- if body_id is None:
|
|
|
- raise ValueError('body_id cannot be None: ' + str(suggested_name))
|
|
|
- elif body_id in self.jitted_entry_points:
|
|
|
- raise primitive_functions.PrimitiveFinished(
|
|
|
- self.jit_globals[self.jitted_entry_points[body_id]])
|
|
|
-
|
|
|
- compiled_func = self.lookup_compiled_body(body_id)
|
|
|
- if compiled_func is not None:
|
|
|
- raise primitive_functions.PrimitiveFinished(compiled_func)
|
|
|
-
|
|
|
- # Generate a name for the function we're about to analyze, and 're-compile'
|
|
|
- # it for the first time.
|
|
|
- function_name = self.generate_function_name(body_id, suggested_name)
|
|
|
- yield [("TAIL_CALL_ARGS", [self.jit_recompile, (task_root, body_id, function_name)])]
|
|
|
-
|
|
|
- def jit_rejit(self, task_root, body_id, function_name, compile_function_body=None):
|
|
|
- """Re-compiles the given function. If compilation fails, then the can-rejit
|
|
|
- flag is set to false."""
|
|
|
- old_jitted_func = self.jitted_entry_points[body_id]
|
|
|
- def __handle_jit_failed(_):
|
|
|
- self.jit_globals[self.get_can_rejit_name(function_name)] = False
|
|
|
- self.jitted_entry_points[body_id] = old_jitted_func
|
|
|
- self.no_jit_entry_points.remove(body_id)
|
|
|
- raise primitive_functions.PrimitiveFinished(None)
|
|
|
-
|
|
|
- yield [("TRY", [])]
|
|
|
- yield [("CATCH", [jit_runtime.JitCompilationFailedException, __handle_jit_failed])]
|
|
|
- jitted_function, = yield [
|
|
|
- ("CALL_ARGS",
|
|
|
- [self.jit_recompile, (task_root, body_id, function_name, compile_function_body)])]
|
|
|
- yield [("END_TRY", [])]
|
|
|
-
|
|
|
- # Update all aliases.
|
|
|
- for function_alias in self.jitted_function_aliases[body_id]:
|
|
|
- self.jit_globals[function_alias] = jitted_function
|
|
|
-
|
|
|
- def new_compile(self, body_id):
|
|
|
- print("Compiling body ID " + str(body_id))
|
|
|
- raise JitCompilationFailedException("Function was marked '%s'." % jit_runtime.MUTABLE_FUNCTION_KEY)
|
|
|
-
|
|
|
- #raise primitive_functions.PrimitiveFinished("pass")
|