|
@@ -4,6 +4,32 @@ import modelverse_jit.tree_ir as tree_ir
|
|
|
KWARGS_PARAMETER_NAME = "kwargs"
|
|
|
"""The name of the kwargs parameter in jitted functions."""
|
|
|
|
|
|
+def get_parameter_names(compiled_function):
|
|
|
+ """Gets the given compiled function's parameter names."""
|
|
|
+ if hasattr(compiled_function, '__code__'):
|
|
|
+ return compiled_function.__code__.co_varnames[
|
|
|
+ :compiled_function.__code__.co_argcount]
|
|
|
+ elif hasattr(compiled_function, '__init__'):
|
|
|
+ return get_parameter_names(compiled_function.__init__)[1:]
|
|
|
+ else:
|
|
|
+ raise ValueError("'compiled_function' must be a function or a type.")
|
|
|
+
|
|
|
+def apply_intrinsic(intrinsic_function, named_args):
|
|
|
+ """Applies the given intrinsic to the given sequence of named arguments."""
|
|
|
+ param_names = get_parameter_names(intrinsic_function)
|
|
|
+ if tuple(param_names) == tuple([n for n, _ in named_args]):
|
|
|
+ # Perfect match. Yay!
|
|
|
+ return intrinsic_function(**dict(named_args))
|
|
|
+ else:
|
|
|
+ # We'll have to store the arguments into locals to preserve
|
|
|
+ # the order of evaluation.
|
|
|
+ stored_args = [(name, tree_ir.StoreLocalInstruction(None, arg)) for name, arg in named_args]
|
|
|
+ arg_value_dict = dict([(name, arg.create_load()) for name, arg in stored_args])
|
|
|
+ store_instructions = [instruction for _, instruction in stored_args]
|
|
|
+ return tree_ir.CompoundInstruction(
|
|
|
+ tree_ir.create_block(*store_instructions),
|
|
|
+ intrinsic_function(**arg_value_dict))
|
|
|
+
|
|
|
class JitCompilationFailedException(Exception):
|
|
|
"""A type of exception that is raised when the jit fails to compile a function."""
|
|
|
pass
|
|
@@ -22,6 +48,7 @@ class ModelverseJit(object):
|
|
|
self.jit_count = 0
|
|
|
self.max_instructions = 30 if max_instructions is None else max_instructions
|
|
|
self.compiled_function_lookup = compiled_function_lookup
|
|
|
+ # jit_intrinsics is a function name -> intrinsic map.
|
|
|
self.jit_intrinsics = {}
|
|
|
self.compilation_dependencies = {}
|
|
|
self.jit_enabled = True
|
|
@@ -78,7 +105,9 @@ class ModelverseJit(object):
|
|
|
|
|
|
def register_compiled(self, body_id, compiled_function, function_name=None):
|
|
|
"""Registers a compiled entry point with the JIT."""
|
|
|
+ # Get the function's name.
|
|
|
function_name = self.generate_function_name(function_name)
|
|
|
+ # Map the body id to the given parameter list.
|
|
|
self.jitted_entry_points[body_id] = function_name
|
|
|
self.jit_globals[function_name] = compiled_function
|
|
|
if body_id in self.todo_entry_points:
|
|
@@ -94,25 +123,33 @@ class ModelverseJit(object):
|
|
|
else:
|
|
|
return None
|
|
|
|
|
|
- def register_intrinsic(self, name, apply_intrinsic):
|
|
|
+ def get_intrinsic(self, name):
|
|
|
+ """Tries to find an intrinsic version of the function with the
|
|
|
+ given name."""
|
|
|
+ if name in self.jit_intrinsics:
|
|
|
+ return self.jit_intrinsics[name]
|
|
|
+ else:
|
|
|
+ return None
|
|
|
+
|
|
|
+ def register_intrinsic(self, name, intrinsic_function):
|
|
|
"""Registers the given intrisic with the JIT. This will make the JIT replace calls to
|
|
|
the function with the given entry point by an application of the specified function."""
|
|
|
- self.jit_intrinsics[name] = apply_intrinsic
|
|
|
+ self.jit_intrinsics[name] = intrinsic_function
|
|
|
|
|
|
def register_binary_intrinsic(self, name, operator):
|
|
|
"""Registers an intrinsic with the JIT that represents the given binary operation."""
|
|
|
- self.register_intrinsic(name, lambda lhs, rhs: tree_ir.CreateNodeWithValueInstruction(
|
|
|
+ self.register_intrinsic(name, lambda a, b: tree_ir.CreateNodeWithValueInstruction(
|
|
|
tree_ir.BinaryInstruction(
|
|
|
- tree_ir.ReadValueInstruction(lhs),
|
|
|
+ tree_ir.ReadValueInstruction(a),
|
|
|
operator,
|
|
|
- tree_ir.ReadValueInstruction(rhs))))
|
|
|
+ tree_ir.ReadValueInstruction(b))))
|
|
|
|
|
|
def register_unary_intrinsic(self, name, operator):
|
|
|
"""Registers an intrinsic with the JIT that represents the given unary operation."""
|
|
|
- self.register_intrinsic(name, lambda val: tree_ir.CreateNodeWithValueInstruction(
|
|
|
+ self.register_intrinsic(name, lambda a: tree_ir.CreateNodeWithValueInstruction(
|
|
|
tree_ir.UnaryInstruction(
|
|
|
operator,
|
|
|
- tree_ir.ReadValueInstruction(val))))
|
|
|
+ tree_ir.ReadValueInstruction(a))))
|
|
|
|
|
|
def jit_parameters(self, body_id):
|
|
|
"""Acquires the parameter list for the given body id node."""
|
|
@@ -708,18 +745,10 @@ class AnalysisState(object):
|
|
|
if body_id in self.jit.compilation_dependencies:
|
|
|
self.jit.compilation_dependencies[body_id].add(self.body_id)
|
|
|
|
|
|
- # Analyze the parameter list.
|
|
|
- try:
|
|
|
- gen = self.jit.jit_parameters(body_id)
|
|
|
- inp = None
|
|
|
- while True:
|
|
|
- inp = yield gen.send(inp)
|
|
|
- except primitive_functions.PrimitiveFinished as ex:
|
|
|
- _, parameter_names = ex.result
|
|
|
+ # Figure out if the function might be an intrinsic.
|
|
|
+ intrinsic = self.jit.get_intrinsic(callee_name)
|
|
|
|
|
|
- is_intrinsic = callee_name in self.jit.jit_intrinsics
|
|
|
-
|
|
|
- if not is_intrinsic:
|
|
|
+ if intrinsic is None:
|
|
|
compiled_func = self.jit.lookup_compiled_function(callee_name)
|
|
|
if compiled_func is None:
|
|
|
# Compile the callee.
|
|
@@ -738,39 +767,28 @@ class AnalysisState(object):
|
|
|
|
|
|
# Analyze the argument dictionary.
|
|
|
try:
|
|
|
- gen = self.analyze_argument_dict(first_parameter_id)
|
|
|
+ gen = self.analyze_arguments(first_parameter_id)
|
|
|
inp = None
|
|
|
while True:
|
|
|
inp = yield gen.send(inp)
|
|
|
except primitive_functions.PrimitiveFinished as ex:
|
|
|
- arg_dict = ex.result
|
|
|
-
|
|
|
- # Construct the argument list from the parameter list and
|
|
|
- # argument dictionary.
|
|
|
- arg_list = []
|
|
|
- for param_name in parameter_names:
|
|
|
- if param_name in arg_dict:
|
|
|
- arg_list.append(arg_dict[param_name])
|
|
|
- else:
|
|
|
- raise JitCompilationFailedException(
|
|
|
- "Cannot JIT-compile function call to '%s' with missing argument for "
|
|
|
- "formal parameter '%s'." % (callee_name, param_name))
|
|
|
+ named_args = ex.result
|
|
|
|
|
|
- if is_intrinsic:
|
|
|
+ if intrinsic is not None:
|
|
|
raise primitive_functions.PrimitiveFinished(
|
|
|
- self.jit.jit_intrinsics[callee_name](*arg_list))
|
|
|
+ apply_intrinsic(intrinsic, named_args))
|
|
|
else:
|
|
|
raise primitive_functions.PrimitiveFinished(
|
|
|
tree_ir.JitCallInstruction(
|
|
|
tree_ir.LoadGlobalInstruction(compiled_func_name),
|
|
|
- arg_list,
|
|
|
+ named_args,
|
|
|
tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
|
|
|
|
|
|
- def analyze_argument_dict(self, first_argument_id):
|
|
|
+ def analyze_arguments(self, first_argument_id):
|
|
|
"""Analyzes the parameter-to-argument mapping started by the specified first argument
|
|
|
node."""
|
|
|
next_param = first_argument_id
|
|
|
- argument_dict = {}
|
|
|
+ named_args = []
|
|
|
while next_param is not None:
|
|
|
param_name_id, = yield [("RD", [next_param, "name"])]
|
|
|
param_name, = yield [("RV", [param_name_id])]
|
|
@@ -781,11 +799,11 @@ class AnalysisState(object):
|
|
|
while True:
|
|
|
inp = yield gen.send(inp)
|
|
|
except primitive_functions.PrimitiveFinished as ex:
|
|
|
- argument_dict[param_name] = ex.result
|
|
|
+ named_args.append((param_name, ex.result))
|
|
|
|
|
|
next_param, = yield [("RD", [next_param, "next_param"])]
|
|
|
|
|
|
- raise primitive_functions.PrimitiveFinished(argument_dict)
|
|
|
+ raise primitive_functions.PrimitiveFinished(named_args)
|
|
|
|
|
|
def analyze_call(self, instruction_id):
|
|
|
"""Tries to analyze the given 'call' instruction."""
|