|
@@ -15,6 +15,8 @@ class ModelverseJit(object):
|
|
|
self.todo_entry_points = set()
|
|
|
self.no_jit_entry_points = set()
|
|
|
self.jitted_entry_points = {}
|
|
|
+ self.jit_globals = {}
|
|
|
+ self.jit_count = 0
|
|
|
|
|
|
def mark_entry_point(self, body_id):
|
|
|
"""Marks the node with the given identifier as a function entry point."""
|
|
@@ -48,20 +50,42 @@ class ModelverseJit(object):
|
|
|
|
|
|
def try_jit(self, body_id, parameter_list):
|
|
|
"""Tries to jit the function defined by the given entry point id and parameter list."""
|
|
|
+ if body_id in self.jitted_entry_points:
|
|
|
+ # We have already compiled this function.
|
|
|
+ raise primitive_functions.PrimitiveFinished(self.jitted_entry_points[body_id])
|
|
|
+ elif body_id in self.no_jit_entry_points:
|
|
|
+ # We're not allowed to jit this function or have tried and failed before.
|
|
|
+ raise JitCompilationFailedException(
|
|
|
+ 'Cannot jit function at %d because it is marked non-jittable.' % body_id)
|
|
|
+
|
|
|
gen = AnalysisState().analyze(body_id)
|
|
|
try:
|
|
|
inp = None
|
|
|
while True:
|
|
|
inp = yield gen.send(inp)
|
|
|
- except primitive_functions.PrimitiveFinished as e:
|
|
|
- constructed_ir = e.result
|
|
|
- except JitCompilationFailedException:
|
|
|
+ except primitive_functions.PrimitiveFinished as ex:
|
|
|
+ constructed_body = ex.result
|
|
|
+ except JitCompilationFailedException as ex:
|
|
|
self.mark_no_jit(body_id)
|
|
|
- raise
|
|
|
-
|
|
|
- print(constructed_ir)
|
|
|
- self.mark_no_jit(body_id)
|
|
|
- raise JitCompilationFailedException("Can't jit function body at " + str(body_id))
|
|
|
+ raise JitCompilationFailedException(
|
|
|
+ '%s (function at %d)' % (ex.message, body_id))
|
|
|
+
|
|
|
+ # Wrap the IR in a function definition, give it a unique name.
|
|
|
+ constructed_function = tree_ir.DefineFunctionInstruction(
|
|
|
+ 'jit_func%d' % self.jit_count,
|
|
|
+ parameter_list + ['**' + KWARGS_PARAMETER_NAME],
|
|
|
+ constructed_body)
|
|
|
+ self.jit_count += 1
|
|
|
+ # Convert the function definition to Python code, and compile it.
|
|
|
+ exec(str(constructed_function), self.jit_globals)
|
|
|
+ # Extract the compiled function from the JIT global state.
|
|
|
+ compiled_function = self.jit_globals[constructed_function.name]
|
|
|
+
|
|
|
+ print(constructed_function)
|
|
|
+
|
|
|
+ # Save the compiled function so we can reuse it later.
|
|
|
+ self.jitted_entry_points[body_id] = compiled_function
|
|
|
+ raise primitive_functions.PrimitiveFinished(compiled_function)
|
|
|
|
|
|
class AnalysisState(object):
|
|
|
"""The state of a bytecode analysis call graph."""
|
|
@@ -75,7 +99,7 @@ class AnalysisState(object):
|
|
|
# Check the analyzed_instructions set for instruction_id to avoid
|
|
|
# infinite loops.
|
|
|
if instruction_id in self.analyzed_instructions:
|
|
|
- raise JitCompilationFailedException('Cannon jit non-tree instruction graph.')
|
|
|
+ raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
|
|
|
|
|
|
self.analyzed_instructions.add(instruction_id)
|
|
|
instruction_val, = yield [("RV", [instruction_id])]
|
|
@@ -206,6 +230,7 @@ class AnalysisState(object):
|
|
|
# ("CD", [user_root, "last_output", new_last_output]),
|
|
|
# ("DE", [last_output_link])
|
|
|
# ]
|
|
|
+ # yield 'nop'
|
|
|
|
|
|
value_id, = yield [("RD", [instruction_id, "value"])]
|
|
|
gen = self.analyze(value_id)
|
|
@@ -256,7 +281,8 @@ class AnalysisState(object):
|
|
|
store_user_root.create_load(),
|
|
|
tree_ir.LiteralInstruction('last_output'),
|
|
|
new_last_output.create_load()),
|
|
|
- tree_ir.DeleteEdgeInstruction(last_output_link.create_load()))
|
|
|
+ tree_ir.DeleteEdgeInstruction(last_output_link.create_load()),
|
|
|
+ tree_ir.NopInstruction())
|
|
|
|
|
|
raise primitive_functions.PrimitiveFinished(result)
|
|
|
|