123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166 |
- import modelverse_kernel.primitives as primitive_functions
- import modelverse_jit.tree_ir as tree_ir
class JitCompilationFailedException(Exception):
    """Raised when the JIT is unable to compile a function body.

    The exception message describes why compilation was abandoned.
    """
class ModelverseJit(object):
    """A high-level interface to the modelverse JIT compiler.

    Function bodies are identified by the id of their entry-point node and are
    tracked in three collections:

    * ``todo_entry_points``: entry points that may still be jitted;
    * ``no_jit_entry_points``: entry points that must never be jitted;
    * ``jitted_entry_points``: a dict mapping entry points to their compiled form.
    """

    def __init__(self):
        self.todo_entry_points = set()
        self.no_jit_entry_points = set()
        self.jitted_entry_points = {}

    def mark_entry_point(self, body_id):
        """Marks the node with the given identifier as a function entry point.

        Nodes that are already marked as non-jittable, or that already have a
        compiled form, are left untouched.
        """
        if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
            self.todo_entry_points.add(body_id)

    def is_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point,
        regardless of whether it may be jitted."""
        return (self.is_jittable_entry_point(body_id)
                or body_id in self.no_jit_entry_points)

    def is_jittable_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point that
        has not been marked as non-jittable."""
        return body_id in self.todo_entry_points or body_id in self.jitted_entry_points

    def mark_no_jit(self, body_id):
        """Informs the JIT that the node with the given identifier is a function entry
        point that must never be jitted."""
        self.no_jit_entry_points.add(body_id)
        # discard() is a no-op when the body was never queued for compilation.
        self.todo_entry_points.discard(body_id)

    def register_compiled(self, body_id, compiled):
        """Registers a compiled entry point with the JIT.

        The entry point is removed from the set of bodies awaiting compilation.
        """
        self.jitted_entry_points[body_id] = compiled
        self.todo_entry_points.discard(body_id)

    def try_jit(self, body_id, parameter_list):
        """Tries to jit the function defined by the given entry point id and parameter list.

        This is a generator. It forwards every request yielded by the bytecode
        analysis to its caller, and sends the caller's replies back into the
        analysis.

        The analysis result is not yet turned into a compiled function, so this
        always ends by marking ``body_id`` as non-jittable and raising
        JitCompilationFailedException. That marking also happens when the analysis
        itself rejects the body, so the same body is not re-analyzed on every
        call.

        ``parameter_list`` is currently unused.
        """
        gen = AnalysisState().analyze(body_id)
        try:
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished:
            # Analysis succeeded, but code generation is not implemented yet:
            # fall through to the failure path below.
            pass
        except JitCompilationFailedException:
            # The analysis rejected this body; remember that so it is not retried.
            self.mark_no_jit(body_id)
            raise

        self.mark_no_jit(body_id)
        raise JitCompilationFailedException("Can't JIT function body at " + str(body_id))
class AnalysisState(object):
    """The state of a bytecode analysis call graph.

    The ``analyze*`` methods are generators. They yield lists of requests such as
    ``("RV", [node])`` and ``("RD", [node, edge_name])``, receive the replies via
    ``send``, and report their result by raising
    ``primitive_functions.PrimitiveFinished`` with the IR tree in ``.result``.
    """

    def __init__(self):
        # Ids of every instruction node analyzed so far by this state; used to
        # reject instruction graphs that revisit a node.
        self.analyzed_instructions = set()

    def analyze(self, instruction_id):
        """Tries to build an intermediate representation from the instruction with the
        given id, followed by the chain of instructions reachable through its 'next'
        links.

        Raises JitCompilationFailedException if the instruction was already
        analyzed by this state (the instruction graph is not a tree), or if its
        instruction type has no analyzer.
        """
        # Refuse to analyze an instruction twice. A cycle in the instruction graph
        # (e.g. a loop reached through 'next') would otherwise recurse forever.
        if instruction_id in self.analyzed_instructions:
            raise JitCompilationFailedException(
                "Cannot jit non-tree instruction graph (revisited instruction %s)." %
                (instruction_id,))
        self.analyzed_instructions.add(instruction_id)

        instruction_val, = yield [("RV", [instruction_id])]
        instruction_val = instruction_val["value"]

        if instruction_val not in self.instruction_analyzers:
            raise JitCompilationFailedException(
                "Unknown instruction type: '%s'" % (instruction_val,))

        gen = self.instruction_analyzers[instruction_val](self, instruction_id)
        try:
            inp = None
            while True:
                inp = yield gen.send(inp)
        except StopIteration:
            raise Exception(
                "Instruction analyzer (for '%s') finished without returning a value!" %
                (instruction_val,))
        except primitive_functions.PrimitiveFinished as outer_e:
            # Check if the instruction has a 'next' instruction.
            next_instr, = yield [("RD", [instruction_id, "next"])]
            if next_instr is None:
                raise outer_e
            # Analyze the rest of the chain and sequence it after this instruction.
            gen = self.analyze(next_instr)
            try:
                inp = None
                while True:
                    inp = yield gen.send(inp)
            except primitive_functions.PrimitiveFinished as inner_e:
                raise primitive_functions.PrimitiveFinished(
                    tree_ir.CompoundInstruction(
                        outer_e.result,
                        inner_e.result))

    def analyze_all(self, instruction_ids):
        """Tries to compile a list of IR trees from the given list of instruction ids.

        The trees are reported, in input order, as the result of
        PrimitiveFinished.
        """
        results = []
        for inst in instruction_ids:
            gen = self.analyze(inst)
            try:
                inp = None
                while True:
                    inp = yield gen.send(inp)
            except primitive_functions.PrimitiveFinished as e:
                results.append(e.result)

        raise primitive_functions.PrimitiveFinished(results)

    def analyze_return(self, instruction_id):
        """Tries to analyze the given 'return' instruction.

        A missing 'value' edge produces a return of an empty instruction.
        """
        retval_id, = yield [("RD", [instruction_id, 'value'])]
        if retval_id is None:
            raise primitive_functions.PrimitiveFinished(
                tree_ir.ReturnInstruction(
                    tree_ir.EmptyInstruction()))

        gen = self.analyze(retval_id)
        try:
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as e:
            raise primitive_functions.PrimitiveFinished(
                tree_ir.ReturnInstruction(e.result))

    def analyze_if(self, instruction_id):
        """Tries to analyze the given 'if' instruction.

        The 'else' branch is optional; when it is absent, an empty instruction is
        used in its place.
        """
        cond, true, false = yield [
            ("RD", [instruction_id, "cond"]),
            ("RD", [instruction_id, "then"]),
            ("RD", [instruction_id, "else"])]

        # RD yields None for a missing edge (as with 'next' in analyze). Only
        # analyze the branches that exist, rather than trying to analyze node None.
        branch_ids = [cond, true] if false is None else [cond, true, false]
        gen = self.analyze_all(branch_ids)
        try:
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as e:
            if false is None:
                cond_r, true_r = e.result
                false_r = tree_ir.EmptyInstruction()
            else:
                cond_r, true_r, false_r = e.result

            raise primitive_functions.PrimitiveFinished(
                tree_ir.SelectInstruction(
                    tree_ir.ReadValueInstruction(cond_r),
                    true_r,
                    false_r))

    def analyze_constant(self, instruction_id):
        """Tries to analyze the given 'constant' (literal) instruction."""
        node_id, = yield [("RD", [instruction_id, "node"])]
        raise primitive_functions.PrimitiveFinished(
            tree_ir.LiteralInstruction(node_id))

    # Maps an instruction node's value to the (unbound) analyzer for that
    # instruction type; analyze() calls these with `self` passed explicitly.
    instruction_analyzers = {
        'if' : analyze_if,
        'return' : analyze_return,
        'constant' : analyze_constant
    }
|