123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565 |
- import modelverse_kernel.primitives as primitive_functions
- import modelverse_jit.tree_ir as tree_ir
# Name given to the catch-all ``**kwargs`` parameter of every jitted function.
KWARGS_PARAMETER_NAME = "remainder"
class JitCompilationFailedException(Exception):
    """Signals that the JIT was unable to compile a function."""
class ModelverseJit(object):
    """A high-level interface to the modelverse JIT compiler.

    Function entry points (body node ids) are tracked in three collections:
    ``todo_entry_points`` (known, not yet compiled), ``jitted_entry_points``
    (compiled, mapped to their Python function) and ``no_jit_entry_points``
    (must never be jitted, or jitting failed before).
    """
    def __init__(self, max_instructions=None):
        # Entry points that are known but not yet compiled.
        self.todo_entry_points = set()
        # Entry points that must never be jitted (explicitly or after a failure).
        self.no_jit_entry_points = set()
        # Maps entry point ids to their compiled Python functions.
        self.jitted_entry_points = {}
        # Globals dictionary in which the generated code is executed.
        self.jit_globals = {
            'PrimitiveFinished' : primitive_functions.PrimitiveFinished
        }
        # Counter used to give every jitted function a unique name.
        self.jit_count = 0
        # Instruction budget per analyzed function; defaults to 30.
        self.max_instructions = 30 if max_instructions is None else max_instructions

    def mark_entry_point(self, body_id):
        """Marks the node with the given identifier as a function entry point."""
        if body_id not in self.no_jit_entry_points and body_id not in self.jitted_entry_points:
            self.todo_entry_points.add(body_id)

    def is_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point."""
        return body_id in self.todo_entry_points or \
            body_id in self.no_jit_entry_points or \
            body_id in self.jitted_entry_points

    def is_jittable_entry_point(self, body_id):
        """Tells if the node with the given identifier is a function entry point that
        has not been marked as non-jittable."""
        return body_id in self.todo_entry_points or \
            body_id in self.jitted_entry_points

    def mark_no_jit(self, body_id):
        """Informs the JIT that the node with the given identifier is a function entry
        point that must never be jitted."""
        self.no_jit_entry_points.add(body_id)
        self.todo_entry_points.discard(body_id)

    def register_compiled(self, body_id, compiled):
        """Registers a compiled entry point with the JIT and removes it from the
        set of entry points that still need compiling."""
        self.jitted_entry_points[body_id] = compiled
        self.todo_entry_points.discard(body_id)

    def try_jit(self, body_id, parameter_list):
        """Tries to jit the function defined by the given entry point id and parameter list.

        This is a generator. It forwards the requests yielded by the analysis
        (lists of ``(opcode, args)`` tuples) to its caller, sends the replies
        back in, and finishes by raising ``PrimitiveFinished`` whose ``result``
        is the compiled Python function.

        Raises JitCompilationFailedException if the entry point is marked
        non-jittable, or if analysis fails; in the latter case the entry point
        is marked non-jittable so it is not retried.
        """
        if body_id in self.jitted_entry_points:
            # We have already compiled this function.
            raise primitive_functions.PrimitiveFinished(self.jitted_entry_points[body_id])
        elif body_id in self.no_jit_entry_points:
            # We're not allowed to jit this function or have tried and failed before.
            raise JitCompilationFailedException(
                'Cannot jit function at %d because it is marked non-jittable.' % body_id)

        try:
            gen = AnalysisState(self.max_instructions).analyze(body_id)
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as ex:
            constructed_body = ex.result
        except JitCompilationFailedException as ex:
            self.mark_no_jit(body_id)
            # str(ex) works on Python 2 and 3; BaseException.message was
            # removed in Python 3 and would raise AttributeError here.
            raise JitCompilationFailedException(
                '%s (function at %d)' % (str(ex), body_id))

        # Wrap the IR in a function definition, give it a unique name.
        constructed_function = tree_ir.DefineFunctionInstruction(
            'jit_func%d' % self.jit_count,
            parameter_list + ['**' + KWARGS_PARAMETER_NAME],
            constructed_body.simplify())
        self.jit_count += 1

        # Convert the function definition to Python code, and compile it.
        exec(str(constructed_function), self.jit_globals)

        # Extract the compiled function from the JIT global state.
        compiled_function = self.jit_globals[constructed_function.name]

        # Debug output: prints the generated Python source of the function.
        # NOTE(review): consider routing this through logging instead.
        print(constructed_function)

        # Save the compiled function so we can reuse it later; this also
        # removes body_id from the todo set, keeping the bookkeeping consistent.
        self.register_compiled(body_id, compiled_function)
        raise primitive_functions.PrimitiveFinished(compiled_function)
class AnalysisState(object):
    """The state of a bytecode analysis call graph.

    Every ``analyze*`` method is a generator: it yields lists of
    ``(opcode, args)`` requests (for example ``("RD", [node, key])``),
    expects the replies to be sent back in, and finishes by raising
    ``primitive_functions.PrimitiveFinished`` whose ``result`` is a
    ``tree_ir`` instruction (a list of them for ``analyze_all``).
    Analysis that cannot proceed raises ``JitCompilationFailedException``.
    """
    def __init__(self, max_instructions=None):
        # Ids of every instruction analyzed so far. Seeing an id twice means
        # the instruction graph is not a tree, which the JIT rejects.
        self.analyzed_instructions = set()
        # Maximum number of instructions to analyze, or None for no limit.
        self.max_instructions = max_instructions

    def get_local_name(self, local_id):
        """Gets the name for a local with the given id."""
        return 'local%d' % local_id

    def retrieve_user_root(self):
        """Creates an instruction that stores the user_root variable
        in a local. The value is read from the jitted function's kwargs."""
        return tree_ir.StoreLocalInstruction(
            'user_root',
            tree_ir.LoadIndexInstruction(
                tree_ir.LoadLocalInstruction(KWARGS_PARAMETER_NAME),
                tree_ir.LiteralInstruction('user_root')))

    def analyze(self, instruction_id):
        """Tries to build an intermediate representation from the instruction with the
        given id.

        If the instruction has a 'next' instruction, that one is analyzed as
        well and both results are combined into a CompoundInstruction.

        Raises JitCompilationFailedException if the instruction was already
        analyzed, if the instruction limit is reached, or if the instruction
        type has no analyzer.
        """
        # Check the analyzed_instructions set for instruction_id to avoid
        # infinite loops.
        if instruction_id in self.analyzed_instructions:
            raise JitCompilationFailedException('Cannot jit non-tree instruction graph.')
        elif (self.max_instructions is not None and
              len(self.analyzed_instructions) >= self.max_instructions):
            # '>=' (not '>') so that at most max_instructions instructions
            # are analyzed; the check runs before the id is added below.
            raise JitCompilationFailedException('Maximal number of instructions exceeded.')

        self.analyzed_instructions.add(instruction_id)
        instruction_val, = yield [("RV", [instruction_id])]
        instruction_val = instruction_val["value"]

        if instruction_val not in self.instruction_analyzers:
            raise JitCompilationFailedException(
                "Unknown instruction type: '%s'" % (instruction_val))

        # instruction_analyzers holds plain functions, so self is passed explicitly.
        gen = self.instruction_analyzers[instruction_val](self, instruction_id)
        try:
            inp = None
            while True:
                inp = yield gen.send(inp)
        except StopIteration:
            # Analyzers must finish by raising PrimitiveFinished.
            raise Exception(
                "Instruction analyzer (for '%s') finished without returning a value!" %
                (instruction_val))
        except primitive_functions.PrimitiveFinished as outer_e:
            # Check if the instruction has a 'next' instruction.
            next_instr, = yield [("RD", [instruction_id, "next"])]
            if next_instr is None:
                raise outer_e
            else:
                gen = self.analyze(next_instr)
                try:
                    inp = None
                    while True:
                        inp = yield gen.send(inp)
                except primitive_functions.PrimitiveFinished as inner_e:
                    raise primitive_functions.PrimitiveFinished(
                        tree_ir.CompoundInstruction(
                            outer_e.result,
                            inner_e.result))

    def analyze_all(self, instruction_ids):
        """Tries to compile a list of IR trees from the given list of instruction ids.

        Finishes with the results in the same order as instruction_ids."""
        results = []
        for inst in instruction_ids:
            gen = self.analyze(inst)
            try:
                inp = None
                while True:
                    inp = yield gen.send(inp)
            except primitive_functions.PrimitiveFinished as ex:
                results.append(ex.result)
        raise primitive_functions.PrimitiveFinished(results)

    def analyze_return(self, instruction_id):
        """Tries to analyze the given 'return' instruction."""
        retval_id, = yield [("RD", [instruction_id, 'value'])]
        if retval_id is None:
            raise primitive_functions.PrimitiveFinished(
                tree_ir.ReturnInstruction(
                    tree_ir.EmptyInstruction()))
        else:
            gen = self.analyze(retval_id)
            try:
                inp = None
                while True:
                    inp = yield gen.send(inp)
            except primitive_functions.PrimitiveFinished as ex:
                raise primitive_functions.PrimitiveFinished(
                    tree_ir.ReturnInstruction(ex.result))

    def analyze_if(self, instruction_id):
        """Tries to analyze the given 'if' instruction. A missing 'else'
        branch becomes an EmptyInstruction."""
        cond, true, false = yield [
            ("RD", [instruction_id, "cond"]),
            ("RD", [instruction_id, "then"]),
            ("RD", [instruction_id, "else"])]
        gen = self.analyze_all(
            [cond, true]
            if false is None
            else [cond, true, false])
        try:
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as ex:
            if false is None:
                cond_r, true_r = ex.result
                false_r = tree_ir.EmptyInstruction()
            else:
                cond_r, true_r, false_r = ex.result
            raise primitive_functions.PrimitiveFinished(
                tree_ir.SelectInstruction(
                    tree_ir.ReadValueInstruction(cond_r),
                    true_r,
                    false_r))

    def analyze_while(self, instruction_id):
        """Tries to analyze the given 'while' instruction. The result is a
        loop that breaks when the condition's value is false."""
        cond, body = yield [
            ("RD", [instruction_id, "cond"]),
            ("RD", [instruction_id, "body"])]
        gen = self.analyze_all([cond, body])
        try:
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as ex:
            cond_r, body_r = ex.result
            raise primitive_functions.PrimitiveFinished(
                tree_ir.LoopInstruction(
                    tree_ir.CompoundInstruction(
                        tree_ir.SelectInstruction(
                            tree_ir.ReadValueInstruction(cond_r),
                            tree_ir.EmptyInstruction(),
                            tree_ir.BreakInstruction()),
                        body_r)))

    def analyze_constant(self, instruction_id):
        """Tries to analyze the given 'constant' (literal) instruction."""
        node_id, = yield [("RD", [instruction_id, "node"])]
        raise primitive_functions.PrimitiveFinished(
            tree_ir.LiteralInstruction(node_id))

    def analyze_output(self, instruction_id):
        """Tries to analyze the given 'output' instruction."""
        # The plan is to basically generate this tree:
        #
        # value = <some tree>
        # last_output, last_output_link, new_last_output = \
        #                 yield [("RD", [user_root, "last_output"]),
        #                        ("RDE", [user_root, "last_output"]),
        #                        ("CN", []),
        #                       ]
        # _, _, _, _ = \
        #                 yield [("CD", [last_output, "value", value]),
        #                        ("CD", [last_output, "next", new_last_output]),
        #                        ("CD", [user_root, "last_output", new_last_output]),
        #                        ("DE", [last_output_link])
        #                       ]
        # yield None
        value_id, = yield [("RD", [instruction_id, "value"])]
        gen = self.analyze(value_id)
        try:
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as ex:
            value_local = tree_ir.StoreLocalInstruction('value', ex.result)

        store_user_root = self.retrieve_user_root()
        last_output = tree_ir.StoreLocalInstruction(
            'last_output',
            tree_ir.ReadDictionaryValueInstruction(
                store_user_root.create_load(),
                tree_ir.LiteralInstruction('last_output')))

        last_output_link = tree_ir.StoreLocalInstruction(
            'last_output_link',
            tree_ir.ReadDictionaryEdgeInstruction(
                store_user_root.create_load(),
                tree_ir.LiteralInstruction('last_output')))

        new_last_output = tree_ir.StoreLocalInstruction(
            'new_last_output',
            tree_ir.CreateNodeInstruction())

        result = tree_ir.create_block(
            value_local,
            store_user_root,
            last_output,
            last_output_link,
            new_last_output,
            tree_ir.CreateDictionaryEdgeInstruction(
                last_output.create_load(),
                tree_ir.LiteralInstruction('value'),
                value_local.create_load()),
            tree_ir.CreateDictionaryEdgeInstruction(
                last_output.create_load(),
                tree_ir.LiteralInstruction('next'),
                new_last_output.create_load()),
            tree_ir.CreateDictionaryEdgeInstruction(
                store_user_root.create_load(),
                tree_ir.LiteralInstruction('last_output'),
                new_last_output.create_load()),
            tree_ir.DeleteEdgeInstruction(last_output_link.create_load()),
            tree_ir.NopInstruction())

        raise primitive_functions.PrimitiveFinished(result)

    def analyze_input(self, _):
        """Tries to analyze the given 'input' instruction."""
        # The plan is to generate this tree:
        #
        #     value = None
        #     while True:
        #         if value is None:
        #             yield None # nop
        #         else:
        #             break
        #
        #         _input = yield [("RD", [user_root, "input"])]
        #         value = yield [("RD", [_input, "value"])]
        #
        #     _next = yield [("RD", [_input, "next"])]
        #     yield [("CD", [user_root, "input", _next])]
        #     yield [("DN", [_input])]
        user_root = self.retrieve_user_root()
        _input = tree_ir.StoreLocalInstruction(
            '_input',
            tree_ir.ReadDictionaryValueInstruction(
                user_root.create_load(),
                tree_ir.LiteralInstruction('input')))

        value = tree_ir.StoreLocalInstruction(
            'value',
            tree_ir.ReadDictionaryValueInstruction(
                _input.create_load(),
                tree_ir.LiteralInstruction('value')))

        raise primitive_functions.PrimitiveFinished(
            tree_ir.CompoundInstruction(
                tree_ir.create_block(
                    user_root,
                    value.create_store(tree_ir.LiteralInstruction(None)),
                    tree_ir.LoopInstruction(
                        tree_ir.create_block(
                            tree_ir.SelectInstruction(
                                tree_ir.BinaryInstruction(
                                    value.create_load(),
                                    'is',
                                    tree_ir.LiteralInstruction(None)),
                                tree_ir.NopInstruction(),
                                tree_ir.BreakInstruction()),
                            _input,
                            value)),
                    tree_ir.CreateDictionaryEdgeInstruction(
                        user_root.create_load(),
                        tree_ir.LiteralInstruction('input'),
                        tree_ir.ReadDictionaryValueInstruction(
                            _input.create_load(),
                            tree_ir.LiteralInstruction('next'))),
                    tree_ir.DeleteNodeInstruction(_input.create_load())),
                value.create_load()))

    def analyze_resolve(self, instruction_id):
        """Tries to analyze the given 'resolve' instruction."""
        var_id, = yield [("RD", [instruction_id, "var"])]
        var_name, = yield [("RV", [var_id])]

        # To resolve a variable, we'll do something along the
        # lines of:
        #
        #     if 'local_var' in locals():
        #         tmp = local_var
        #     else:
        #         _globals, = yield [("RD", [user_root, "globals"])]
        #         global_var, = yield [("RD", [_globals, var_name])]
        #
        #         if global_var is None:
        #             raise Exception("Runtime error: global '%s' not found" % (var_name))
        #
        #         tmp = global_var
        user_root = self.retrieve_user_root()
        global_var = tree_ir.StoreLocalInstruction(
            'global_var',
            tree_ir.ReadDictionaryValueInstruction(
                tree_ir.ReadDictionaryValueInstruction(
                    user_root.create_load(),
                    tree_ir.LiteralInstruction('globals')),
                tree_ir.LiteralInstruction(var_name)))

        err_block = tree_ir.SelectInstruction(
            tree_ir.BinaryInstruction(
                global_var.create_load(),
                'is',
                tree_ir.LiteralInstruction(None)),
            tree_ir.RaiseInstruction(
                tree_ir.CallInstruction(
                    tree_ir.LoadLocalInstruction('Exception'),
                    [tree_ir.LiteralInstruction(
                        "Runtime error: global '%s' not found" % var_name)
                    ])),
            tree_ir.EmptyInstruction())

        name = self.get_local_name(var_id)
        raise primitive_functions.PrimitiveFinished(
            tree_ir.SelectInstruction(
                tree_ir.LocalExistsInstruction(name),
                tree_ir.LoadLocalInstruction(name),
                tree_ir.CompoundInstruction(
                    tree_ir.create_block(
                        user_root,
                        global_var,
                        err_block),
                    global_var.create_load())))

    def analyze_declare(self, instruction_id):
        """Tries to analyze the given 'declare' instruction."""
        var_id, = yield [("RD", [instruction_id, "var"])]
        name = self.get_local_name(var_id)

        # The following logic declares a local:
        #
        #     if 'local_name' not in locals():
        #         local_name, = yield [("CN", [])]
        raise primitive_functions.PrimitiveFinished(
            tree_ir.SelectInstruction(
                tree_ir.LocalExistsInstruction(name),
                tree_ir.EmptyInstruction(),
                tree_ir.StoreLocalInstruction(
                    name,
                    tree_ir.CreateNodeInstruction())))

    def analyze_global(self, instruction_id):
        """Tries to analyze the given 'global' (declaration) instruction."""
        var_id, = yield [("RD", [instruction_id, "var"])]
        var_name, = yield [("RV", [var_id])]

        # To declare a global variable, we'll do something along the
        # lines of:
        #
        #     _globals, = yield [("RD", [user_root, "globals"])]
        #     global_var, = yield [("RD", [_globals, var_name])]
        #
        #     if global_var is None:
        #         global_var, = yield [("CN", [])]
        #         yield [("CD", [_globals, var_name, global_var])]
        #
        #     tmp = global_var
        user_root = self.retrieve_user_root()
        _globals = tree_ir.StoreLocalInstruction(
            '_globals',
            tree_ir.ReadDictionaryValueInstruction(
                user_root.create_load(),
                tree_ir.LiteralInstruction('globals')))

        global_var = tree_ir.StoreLocalInstruction(
            'global_var',
            tree_ir.ReadDictionaryValueInstruction(
                _globals.create_load(),
                tree_ir.LiteralInstruction(var_name)))

        raise primitive_functions.PrimitiveFinished(
            tree_ir.CompoundInstruction(
                tree_ir.create_block(
                    user_root,
                    _globals,
                    global_var,
                    tree_ir.SelectInstruction(
                        tree_ir.BinaryInstruction(
                            global_var.create_load(),
                            'is',
                            tree_ir.LiteralInstruction(None)),
                        tree_ir.create_block(
                            global_var.create_store(
                                tree_ir.CreateNodeInstruction()),
                            tree_ir.CreateDictionaryEdgeInstruction(
                                _globals.create_load(),
                                tree_ir.LiteralInstruction(var_name),
                                global_var.create_load())),
                        tree_ir.EmptyInstruction())),
                global_var.create_load()))

    def analyze_assign(self, instruction_id):
        """Tries to analyze the given 'assign' instruction."""
        var_id, value_id = yield [("RD", [instruction_id, "var"]),
                                  ("RD", [instruction_id, "value"])]

        try:
            gen = self.analyze_all([var_id, value_id])
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as ex:
            var_r, value_r = ex.result

        # Assignments work like this:
        #
        #     value_link = yield [("RDE", [variable, "value"])]
        #     _, _ =       yield [("CD", [variable, "value", value]),
        #                         ("DE", [value_link])]
        variable = tree_ir.StoreLocalInstruction('variable', var_r)
        value = tree_ir.StoreLocalInstruction('value', value_r)
        value_link = tree_ir.StoreLocalInstruction(
            'value_link',
            tree_ir.ReadDictionaryEdgeInstruction(
                variable.create_load(),
                tree_ir.LiteralInstruction('value')))

        raise primitive_functions.PrimitiveFinished(
            tree_ir.create_block(
                variable,
                value,
                value_link,
                tree_ir.CreateDictionaryEdgeInstruction(
                    variable.create_load(),
                    tree_ir.LiteralInstruction('value'),
                    value.create_load()),
                tree_ir.DeleteEdgeInstruction(
                    value_link.create_load())))

    def analyze_access(self, instruction_id):
        """Tries to analyze the given 'access' instruction."""
        var_id, = yield [("RD", [instruction_id, "var"])]

        try:
            gen = self.analyze(var_id)
            inp = None
            while True:
                inp = yield gen.send(inp)
        except primitive_functions.PrimitiveFinished as ex:
            var_r = ex.result

        # Accessing a variable is pretty easy. It really just boils
        # down to reading the value corresponding to the 'value' key
        # of the variable.
        #
        #     value, = yield [("RD", [returnvalue, "value"])]
        raise primitive_functions.PrimitiveFinished(
            tree_ir.ReadDictionaryValueInstruction(
                var_r,
                tree_ir.LiteralInstruction('value')))

    # Dispatch table from instruction type (the instruction node's value)
    # to its analyzer. The entries are plain functions, so callers pass
    # self explicitly (see analyze).
    instruction_analyzers = {
        'if' : analyze_if,
        'while' : analyze_while,
        'return' : analyze_return,
        'constant' : analyze_constant,
        'resolve' : analyze_resolve,
        'declare' : analyze_declare,
        'global' : analyze_global,
        'assign' : analyze_assign,
        'access' : analyze_access,
        'output' : analyze_output,
        'input' : analyze_input
    }
|