|
@@ -22,6 +22,7 @@ class ModelverseJit(object):
|
|
|
self.jit_count = 0
|
|
|
self.max_instructions = 30 if max_instructions is None else max_instructions
|
|
|
self.compiled_function_lookup = compiled_function_lookup
|
|
|
+ self.jit_intrinsics = {}
|
|
|
self.jit_enabled = True
|
|
|
|
|
|
def set_jit_enabled(self, is_enabled=True):
|
|
@@ -92,6 +93,19 @@ class ModelverseJit(object):
|
|
|
else:
|
|
|
return None
|
|
|
|
|
|
+ def register_intrinsic(self, name, apply_intrinsic):
|
|
|
+ """Registers the given intrisic with the JIT. This will make the JIT replace calls to
|
|
|
+ the function with the given entry point by an application of the specified function."""
|
|
|
+ self.jit_intrinsics[name] = apply_intrinsic
|
|
|
+
|
|
|
+ def register_binary_intrinsic(self, name, operator):
|
|
|
+ """Registers an intrinsic with the JIT that represents the given binary operation."""
|
|
|
+ self.register_intrinsic(name, lambda lhs, rhs: tree_ir.CreateNodeWithValueInstruction(
|
|
|
+ tree_ir.BinaryInstruction(
|
|
|
+ tree_ir.ReadValueInstruction(lhs),
|
|
|
+ operator,
|
|
|
+ tree_ir.ReadValueInstruction(rhs))))
|
|
|
+
|
|
|
def jit_parameters(self, body_id):
|
|
|
"""Acquires the parameter list for the given body id node."""
|
|
|
if body_id not in self.jitted_parameters:
|
|
@@ -680,21 +694,24 @@ class AnalysisState(object):
|
|
|
except primitive_functions.PrimitiveFinished as ex:
|
|
|
_, parameter_names = ex.result
|
|
|
|
|
|
- compiled_func = self.jit.lookup_compiled_function(callee_name)
|
|
|
- if compiled_func is None:
|
|
|
- # Compile the callee.
|
|
|
- try:
|
|
|
- gen = self.jit.jit_compile(self.user_root, body_id, callee_name)
|
|
|
- inp = None
|
|
|
- while True:
|
|
|
- inp = yield gen.send(inp)
|
|
|
- except primitive_functions.PrimitiveFinished as ex:
|
|
|
- pass
|
|
|
- else:
|
|
|
- self.jit.register_compiled(body_id, compiled_func, callee_name)
|
|
|
+ is_intrinsic = callee_name in self.jit.jit_intrinsics
|
|
|
|
|
|
- # Get the callee's name.
|
|
|
- compiled_func_name = self.jit.get_compiled_name(body_id)
|
|
|
+ if not is_intrinsic:
|
|
|
+ compiled_func = self.jit.lookup_compiled_function(callee_name)
|
|
|
+ if compiled_func is None:
|
|
|
+ # Compile the callee.
|
|
|
+ try:
|
|
|
+ gen = self.jit.jit_compile(self.user_root, body_id, callee_name)
|
|
|
+ inp = None
|
|
|
+ while True:
|
|
|
+ inp = yield gen.send(inp)
|
|
|
+ except primitive_functions.PrimitiveFinished as ex:
|
|
|
+ pass
|
|
|
+ else:
|
|
|
+ self.jit.register_compiled(body_id, compiled_func, callee_name)
|
|
|
+
|
|
|
+ # Get the callee's name.
|
|
|
+ compiled_func_name = self.jit.get_compiled_name(body_id)
|
|
|
|
|
|
# Analyze the argument dictionary.
|
|
|
try:
|
|
@@ -716,11 +733,15 @@ class AnalysisState(object):
|
|
|
"Cannot JIT-compile function call to '%s' with missing argument for "
|
|
|
"formal parameter '%s'." % (callee_name, param_name))
|
|
|
|
|
|
- raise primitive_functions.PrimitiveFinished(
|
|
|
- tree_ir.JitCallInstruction(
|
|
|
- tree_ir.LoadGlobalInstruction(compiled_func_name),
|
|
|
- arg_list,
|
|
|
- tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
|
|
|
+ if is_intrinsic:
|
|
|
+ raise primitive_functions.PrimitiveFinished(
|
|
|
+ self.jit.jit_intrinsics[callee_name](*arg_list))
|
|
|
+ else:
|
|
|
+ raise primitive_functions.PrimitiveFinished(
|
|
|
+ tree_ir.JitCallInstruction(
|
|
|
+ tree_ir.LoadGlobalInstruction(compiled_func_name),
|
|
|
+ arg_list,
|
|
|
+ tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
|
|
|
|
|
|
def analyze_argument_dict(self, first_argument_id):
|
|
|
"""Analyzes the parameter-to-argument mapping started by the specified first argument
|