|
@@ -15,6 +15,7 @@ class ModelverseJit(object):
|
|
|
self.todo_entry_points = set()
|
|
|
self.no_jit_entry_points = set()
|
|
|
self.jitted_entry_points = {}
|
|
|
+ self.jitted_parameters = {}
|
|
|
self.jit_globals = {
|
|
|
'PrimitiveFinished' : primitive_functions.PrimitiveFinished
|
|
|
}
|
|
@@ -43,7 +44,16 @@ class ModelverseJit(object):
|
|
|
is enabled and the function entry point has been marked jittable, or if
|
|
|
the function has already been compiled."""
|
|
|
return ((self.jit_enabled and body_id in self.todo_entry_points) or
|
|
|
- body_id in self.jitted_entry_points)
|
|
|
+ self.has_compiled(body_id))
|
|
|
+
|
|
|
+ def has_compiled(self, body_id):
|
|
|
+ """Tests if the function belonging to the given body node has been compiled yet."""
|
|
|
+ return body_id in self.jitted_entry_points
|
|
|
+
|
|
|
+ def get_compiled_name(self, body_id):
|
|
|
+ """Gets the name of the compiled version of the given body node in the JIT
|
|
|
+ global state."""
|
|
|
+ return self.jitted_entry_points[body_id]
|
|
|
|
|
|
def mark_no_jit(self, body_id):
|
|
|
"""Informs the JIT that the node with the given identifier is a function entry
|
|
@@ -52,18 +62,39 @@ class ModelverseJit(object):
|
|
|
if body_id in self.todo_entry_points:
|
|
|
self.todo_entry_points.remove(body_id)
|
|
|
|
|
|
+ def generate_function_name(self):
|
|
|
+ """Generates a new function name,"""
|
|
|
+ function_name = 'jit_func%d' % self.jit_count
|
|
|
+ self.jit_count += 1
|
|
|
+ return function_name
|
|
|
+
|
|
|
def register_compiled(self, body_id, compiled_function, function_name=None):
|
|
|
"""Registers a compiled entry point with the JIT."""
|
|
|
if function_name is None:
|
|
|
- function_name = 'jit_func%d' % self.jit_count
|
|
|
- self.jit_count += 1
|
|
|
+ function_name = self.generate_function_name()
|
|
|
|
|
|
self.jitted_entry_points[body_id] = function_name
|
|
|
self.jit_globals[function_name] = compiled_function
|
|
|
if body_id in self.todo_entry_points:
|
|
|
self.todo_entry_points.remove(body_id)
|
|
|
|
|
|
- def jit_compile(self, body_id, parameter_list):
|
|
|
+ def jit_parameters(self, body_id):
|
|
|
+ """Acquires the parameter list for the given body id node."""
|
|
|
+ if body_id not in self.jitted_parameters:
|
|
|
+ signature_id, = yield [("RRD", [body_id, "body"])]
|
|
|
+ signature_id = signature_id[0]
|
|
|
+ param_set_id, = yield [("RD", [signature_id, "params"])]
|
|
|
+ if param_set_id is None:
|
|
|
+ self.jitted_parameters[body_id] = ([], [])
|
|
|
+ else:
|
|
|
+ param_name_ids, = yield [("RDK", [param_set_id])]
|
|
|
+ param_names = yield [("RV", [n]) for n in param_name_ids]
|
|
|
+ param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
|
|
|
+ self.jitted_parameters[body_id] = (param_vars, param_names)
|
|
|
+
|
|
|
+ raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])
|
|
|
+
|
|
|
+ def jit_compile(self, user_root, body_id):
|
|
|
"""Tries to jit the function defined by the given entry point id and parameter list."""
|
|
|
# The comment below makes pylint shut up about our (hopefully benign) use of exec here.
|
|
|
# pylint: disable=I0011,W0122
|
|
@@ -77,8 +108,25 @@ class ModelverseJit(object):
|
|
|
raise JitCompilationFailedException(
|
|
|
'Cannot jit function at %d because it is marked non-jittable.' % body_id)
|
|
|
|
|
|
+ # Generate a name for the function we're about to analyze, and pretend that
|
|
|
+ # it already exists. (we need to do this for recursive functions)
|
|
|
+ function_name = self.generate_function_name()
|
|
|
+ self.jitted_entry_points[body_id] = function_name
|
|
|
+ self.jit_globals[function_name] = None
|
|
|
+
|
|
|
+ try:
|
|
|
+ gen = self.jit_parameters(body_id)
|
|
|
+ inp = None
|
|
|
+ while True:
|
|
|
+ inp = yield gen.send(inp)
|
|
|
+ except primitive_functions.PrimitiveFinished as ex:
|
|
|
+ parameter_ids, parameter_list = ex.result
|
|
|
+
|
|
|
+ param_dict = dict(zip(parameter_ids, parameter_list))
|
|
|
+ body_param_dict = dict(zip(parameter_ids, [p + "_ptr" for p in parameter_list]))
|
|
|
try:
|
|
|
- gen = AnalysisState(self.max_instructions).analyze(body_id)
|
|
|
+ gen = AnalysisState(
|
|
|
+ self, user_root, body_param_dict, self.max_instructions).analyze(body_id)
|
|
|
inp = None
|
|
|
while True:
|
|
|
inp = yield gen.send(inp)
|
|
@@ -86,35 +134,69 @@ class ModelverseJit(object):
|
|
|
constructed_body = ex.result
|
|
|
except JitCompilationFailedException as ex:
|
|
|
self.mark_no_jit(body_id)
|
|
|
+ del self.jitted_entry_points[body_id]
|
|
|
raise JitCompilationFailedException(
|
|
|
'%s (function at %d)' % (ex.message, body_id))
|
|
|
|
|
|
+ # Write a prologue and prepend it to the generated function body.
|
|
|
+ prologue_statements = []
|
|
|
+ for (key, val) in param_dict.items():
|
|
|
+ arg_ptr = tree_ir.StoreLocalInstruction(
|
|
|
+ body_param_dict[key],
|
|
|
+ tree_ir.CreateNodeInstruction())
|
|
|
+ prologue_statements.append(arg_ptr)
|
|
|
+ prologue_statements.append(
|
|
|
+ tree_ir.CreateDictionaryEdgeInstruction(
|
|
|
+ arg_ptr.create_load(),
|
|
|
+ tree_ir.LiteralInstruction('value'),
|
|
|
+ tree_ir.LoadLocalInstruction(val)))
|
|
|
+
|
|
|
+ constructed_body = tree_ir.create_block(
|
|
|
+ *(prologue_statements + [constructed_body]))
|
|
|
+
|
|
|
# Wrap the IR in a function definition, give it a unique name.
|
|
|
constructed_function = tree_ir.DefineFunctionInstruction(
|
|
|
- 'jit_func%d' % self.jit_count,
|
|
|
+ function_name,
|
|
|
parameter_list + ['**' + KWARGS_PARAMETER_NAME],
|
|
|
constructed_body.simplify())
|
|
|
- self.jit_count += 1
|
|
|
# Convert the function definition to Python code, and compile it.
|
|
|
exec(str(constructed_function), self.jit_globals)
|
|
|
# Extract the compiled function from the JIT global state.
|
|
|
- compiled_function = self.jit_globals[constructed_function.name]
|
|
|
- # Save the compiled function so we can reuse it later.
|
|
|
- self.jitted_entry_points[body_id] = constructed_function.name
|
|
|
+ compiled_function = self.jit_globals[function_name]
|
|
|
|
|
|
- print(constructed_function)
|
|
|
+ # print(constructed_function)
|
|
|
raise primitive_functions.PrimitiveFinished(compiled_function)
|
|
|
|
|
|
class AnalysisState(object):
|
|
|
"""The state of a bytecode analysis call graph."""
|
|
|
-
|
|
|
- def __init__(self, max_instructions=None):
|
|
|
+ def __init__(self, jit, user_root, local_mapping, max_instructions=None):
|
|
|
self.analyzed_instructions = set()
|
|
|
+ self.function_vars = set()
|
|
|
+ self.local_vars = set()
|
|
|
self.max_instructions = max_instructions
|
|
|
+ self.user_root = user_root
|
|
|
+ self.jit = jit
|
|
|
+ self.local_mapping = local_mapping
|
|
|
|
|
|
def get_local_name(self, local_id):
|
|
|
"""Gets the name for a local with the given id."""
|
|
|
- return 'local%d' % local_id
|
|
|
+ if local_id not in self.local_mapping:
|
|
|
+ self.local_mapping[local_id] = 'local%d' % local_id
|
|
|
+ return self.local_mapping[local_id]
|
|
|
+
|
|
|
+ def register_local_var(self, local_id):
|
|
|
+ """Registers the given variable node id as a local."""
|
|
|
+ if local_id in self.function_vars:
|
|
|
+ raise JitCompilationFailedException(
|
|
|
+ "Local is used as target of function call.")
|
|
|
+ self.local_vars.add(local_id)
|
|
|
+
|
|
|
+ def register_function_var(self, local_id):
|
|
|
+ """Registers the given variable node id as a function."""
|
|
|
+ if local_id in self.local_vars:
|
|
|
+ raise JitCompilationFailedException(
|
|
|
+ "Local is used as target of function call.")
|
|
|
+ self.function_vars.add(local_id)
|
|
|
|
|
|
def retrieve_user_root(self):
|
|
|
"""Creates an instruction that stores the user_root variable
|
|
@@ -134,7 +216,7 @@ class AnalysisState(object):
|
|
|
raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
|
|
|
elif (self.max_instructions is not None and
|
|
|
len(self.analyzed_instructions) > self.max_instructions):
|
|
|
- raise JitCompilationFailedException('Maximal number of instructions exceeded.')
|
|
|
+ raise JitCompilationFailedException('Maximum number of instructions exceeded.')
|
|
|
|
|
|
self.analyzed_instructions.add(instruction_id)
|
|
|
instruction_val, = yield [("RV", [instruction_id])]
|
|
@@ -397,6 +479,12 @@ class AnalysisState(object):
|
|
|
#
|
|
|
# tmp = global_var
|
|
|
|
|
|
+ name = self.get_local_name(var_id)
|
|
|
+
|
|
|
+ if var_name is None:
|
|
|
+ raise primitive_functions.PrimitiveFinished(
|
|
|
+ tree_ir.LoadLocalInstruction(name))
|
|
|
+
|
|
|
user_root = self.retrieve_user_root()
|
|
|
global_var = tree_ir.StoreLocalInstruction(
|
|
|
'global_var',
|
|
@@ -413,13 +501,12 @@ class AnalysisState(object):
|
|
|
tree_ir.LiteralInstruction(None)),
|
|
|
tree_ir.RaiseInstruction(
|
|
|
tree_ir.CallInstruction(
|
|
|
- tree_ir.LoadLocalInstruction('Exception'),
|
|
|
+ tree_ir.LoadGlobalInstruction('Exception'),
|
|
|
[tree_ir.LiteralInstruction(
|
|
|
"Runtime error: global '%s' not found" % var_name)
|
|
|
])),
|
|
|
tree_ir.EmptyInstruction())
|
|
|
|
|
|
- name = self.get_local_name(var_id)
|
|
|
raise primitive_functions.PrimitiveFinished(
|
|
|
tree_ir.SelectInstruction(
|
|
|
tree_ir.LocalExistsInstruction(name),
|
|
@@ -435,6 +522,8 @@ class AnalysisState(object):
|
|
|
"""Tries to analyze the given 'declare' function."""
|
|
|
var_id, = yield [("RD", [instruction_id, "var"])]
|
|
|
|
|
|
+ self.register_local_var(var_id)
|
|
|
+
|
|
|
name = self.get_local_name(var_id)
|
|
|
|
|
|
# The following logic declares a local:
|
|
@@ -556,13 +645,121 @@ class AnalysisState(object):
|
|
|
# down to reading the value corresponding to the 'value' key
|
|
|
# of the variable.
|
|
|
#
|
|
|
- # value, = yield [("RD", [returnvalue, "value"])]
|
|
|
+ # value, = yield [("RD", [returnvalue, "value"])]
|
|
|
|
|
|
raise primitive_functions.PrimitiveFinished(
|
|
|
tree_ir.ReadDictionaryValueInstruction(
|
|
|
var_r,
|
|
|
tree_ir.LiteralInstruction('value')))
|
|
|
|
|
|
+ def analyze_direct_call(self, callee_id, callee_name, first_parameter_id):
|
|
|
+ """Tries to analyze a direct 'call' instruction."""
|
|
|
+
|
|
|
+ self.register_function_var(callee_id)
|
|
|
+
|
|
|
+ body_id, = yield [("RD", [callee_id, "body"])]
|
|
|
+ # Analyze the parameter list.
|
|
|
+ try:
|
|
|
+ gen = self.jit.jit_parameters(body_id)
|
|
|
+ inp = None
|
|
|
+ while True:
|
|
|
+ inp = yield gen.send(inp)
|
|
|
+ except primitive_functions.PrimitiveFinished as ex:
|
|
|
+ _, parameter_names = ex.result
|
|
|
+
|
|
|
+ # Compile the callee.
|
|
|
+ try:
|
|
|
+ gen = self.jit.jit_compile(self.user_root, body_id)
|
|
|
+ inp = None
|
|
|
+ while True:
|
|
|
+ inp = yield gen.send(inp)
|
|
|
+ except primitive_functions.PrimitiveFinished as ex:
|
|
|
+ pass
|
|
|
+
|
|
|
+ # Get the callee's name.
|
|
|
+ compiled_func_name = self.jit.get_compiled_name(body_id)
|
|
|
+
|
|
|
+ # Analyze the argument dictionary.
|
|
|
+ try:
|
|
|
+ gen = self.analyze_argument_dict(first_parameter_id)
|
|
|
+ inp = None
|
|
|
+ while True:
|
|
|
+ inp = yield gen.send(inp)
|
|
|
+ except primitive_functions.PrimitiveFinished as ex:
|
|
|
+ arg_dict = ex.result
|
|
|
+
|
|
|
+ # Construct the argument list from the parameter list and
|
|
|
+ # argument dictionary.
|
|
|
+ arg_list = []
|
|
|
+ for param_name in parameter_names:
|
|
|
+ if param_name in arg_dict:
|
|
|
+ arg_list.append(arg_dict[param_name])
|
|
|
+ else:
|
|
|
+ raise JitCompilationFailedException(
|
|
|
+ "Cannot JIT-compile function call to '%s' with missing argument for "
|
|
|
+ "formal parameter '%s'." % (callee_name, param_name))
|
|
|
+
|
|
|
+ raise primitive_functions.PrimitiveFinished(
|
|
|
+ tree_ir.JitCallInstruction(
|
|
|
+ tree_ir.LoadGlobalInstruction(compiled_func_name),
|
|
|
+ arg_list,
|
|
|
+ tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME)))
|
|
|
+
|
|
|
+ def analyze_argument_dict(self, first_argument_id):
|
|
|
+ """Analyzes the parameter-to-argument mapping started by the specified first argument
|
|
|
+ node."""
|
|
|
+ next_param = first_argument_id
|
|
|
+ argument_dict = {}
|
|
|
+ while next_param is not None:
|
|
|
+ param_name_id, = yield [("RD", [next_param, "name"])]
|
|
|
+ param_name, = yield [("RV", [param_name_id])]
|
|
|
+ param_val_id, = yield [("RD", [next_param, "value"])]
|
|
|
+ try:
|
|
|
+ gen = self.analyze(param_val_id)
|
|
|
+ inp = None
|
|
|
+ while True:
|
|
|
+ inp = yield gen.send(inp)
|
|
|
+ except primitive_functions.PrimitiveFinished as ex:
|
|
|
+ argument_dict[param_name] = ex.result
|
|
|
+
|
|
|
+ next_param, = yield [("RD", [next_param, "next_param"])]
|
|
|
+
|
|
|
+ raise primitive_functions.PrimitiveFinished(argument_dict)
|
|
|
+
|
|
|
+ def analyze_call(self, instruction_id):
|
|
|
+ """Tries to analyze the given 'call' instruction."""
|
|
|
+ func_id, first_param_id, = yield [("RD", [instruction_id, "func"]),
|
|
|
+ ("RD", [instruction_id, "params"])]
|
|
|
+
|
|
|
+ # Figure out what the 'func' instruction's type is.
|
|
|
+ func_instruction_op, = yield [("RV", [func_id])]
|
|
|
+ if func_instruction_op['value'] == 'access':
|
|
|
+ # Calls to 'access(resolve(var))' instructions are translated to direct calls.
|
|
|
+ access_value_id, = yield [("RD", [func_id, "var"])]
|
|
|
+ access_value_op, = yield [("RV", [access_value_id])]
|
|
|
+ if access_value_op['value'] == 'resolve':
|
|
|
+ resolved_var_id, = yield [("RD", [access_value_id, "var"])]
|
|
|
+ resolved_var_name, = yield [("RV", [resolved_var_id])]
|
|
|
+
|
|
|
+ # Try to look the name up as a global.
|
|
|
+ _globals, = yield [("RD", [self.user_root, "globals"])]
|
|
|
+ global_var, = yield [("RD", [_globals, resolved_var_name])]
|
|
|
+ global_val, = yield [("RD", [global_var, "value"])]
|
|
|
+
|
|
|
+ if global_val is None:
|
|
|
+ raise JitCompilationFailedException(
|
|
|
+ "Cannot JIT function calls that target an unknown value.")
|
|
|
+ else:
|
|
|
+ gen = self.analyze_direct_call(
|
|
|
+ global_val, resolved_var_name, first_param_id)
|
|
|
+ inp = None
|
|
|
+ while True:
|
|
|
+ inp = yield gen.send(inp)
|
|
|
+ # PrimitiveFinished exception will bubble up from here.
|
|
|
+
|
|
|
+ raise JitCompilationFailedException("Cannot JIT indirect function calls yet.")
|
|
|
+
|
|
|
+
|
|
|
instruction_analyzers = {
|
|
|
'if' : analyze_if,
|
|
|
'while' : analyze_while,
|
|
@@ -574,6 +771,7 @@ class AnalysisState(object):
|
|
|
'assign' : analyze_assign,
|
|
|
'access' : analyze_access,
|
|
|
'output' : analyze_output,
|
|
|
- 'input' : analyze_input
|
|
|
+ 'input' : analyze_input,
|
|
|
+ 'call' : analyze_call
|
|
|
}
|
|
|
|