123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324 |
- """Interprets parsed bytecode graphs."""
- import modelverse_jit.bytecode_ir as bytecode_ir
- import modelverse_jit.runtime as jit_runtime
- import modelverse_jit.source_map as source_map
- import modelverse_kernel.primitives as primitive_functions
class BreakException(Exception):
    """Signals a 'break' instruction.

    Interpreting a 'break' raises this exception. The 'while' instruction
    whose identity matches ``loop`` catches it and ends; any other 'while'
    re-raises it so that an enclosing loop gets a chance to handle it.

    Attributes:
        loop: the 'while' instruction this break exits.
    """
    def __init__(self, loop):
        super(BreakException, self).__init__()
        # Target loop, compared (with ==) against each handling 'while'.
        self.loop = loop
class ContinueException(Exception):
    """Signals a 'continue' instruction.

    Interpreting a 'continue' raises this exception. The 'while' instruction
    whose identity matches ``loop`` catches it and restarts; any other
    'while' re-raises it so that an enclosing loop gets a chance to handle it.

    Attributes:
        loop: the 'while' instruction this continue restarts.
    """
    def __init__(self, loop):
        super(ContinueException, self).__init__()
        # Target loop, compared (with ==) against each handling 'while'.
        self.loop = loop
class InterpreterState(object):
    """Mutable state shared by all instruction handlers while one bytecode
    function body is being interpreted.

    Every ``interpret_*`` method is written against the request protocol used
    throughout this module:
      * yielding a list of request tuples, e.g. ``[("CN", [])]``, produces
        the list of their results;
      * yielding ``None`` performs a nop;
      * a handler "returns" by raising ``PrimitiveFinished``.
    A few handlers (break, continue, constant) contain no ``yield`` at all and
    therefore raise as soon as they are called.
    """
    def __init__(self, gc_root_node, keyword_arg_dict, src_map, nop_period=10):
        # Node that every intermediate value is linked from (via "CE"
        # requests), presumably to keep those values reachable for the
        # garbage collector while the function runs.
        self.gc_root_node = gc_root_node
        # Keyword arguments forwarded to runtime calls; must provide
        # 'task_root' and 'mvk' (see get_task_root / get_kernel).
        self.keyword_arg_dict = keyword_arg_dict
        # Source map whose debug_information follows the running instruction.
        self.src_map = src_map
        # A nop is yielded once every `nop_period` interpreted instructions.
        self.nop_period = nop_period
        self.nop_phase = 0
        # Most recent non-None value produced by an instruction.
        self.current_result = None
        # Variable node id -> node that stores the local's "value".
        self.local_vars = {}

    def import_local(self, node_id, value):
        """Creates a local-variable node for ``node_id`` holding ``value``,
        links it from the GC root node and registers it in ``local_vars``."""
        var_node, = yield [("CN", [])]
        yield [
            ("CE", [self.gc_root_node, var_node]),
            ("CD", [var_node, "value", value]),
        ]
        self.local_vars[node_id] = var_node
        raise primitive_functions.PrimitiveFinished(None)

    def schedule_nop(self):
        """Advances the nop counter; returns True (and resets the counter)
        exactly when a nop is due, False otherwise."""
        self.nop_phase += 1
        nop_due = self.nop_phase == self.nop_period
        if nop_due:
            self.nop_phase = 0
        return nop_due

    def update_result(self, new_result):
        """Makes ``new_result`` the current result. ``None`` is ignored, so
        the previous result is kept in that case."""
        if new_result is None:
            return
        self.current_result = new_result

    def get_task_root(self):
        """Returns the task root node id from the keyword arguments."""
        return self.keyword_arg_dict['task_root']

    def get_kernel(self):
        """Returns the Modelverse kernel instance from the keyword arguments."""
        return self.keyword_arg_dict['mvk']

    def interpret(self, instruction):
        """Interprets ``instruction`` followed by its chain of
        ``next_instruction`` successors, then finishes with the current
        result.

        Raises JitCompilationFailedException if no handler is registered in
        INTERPRETERS for the instruction's exact type.
        """
        instruction_type = type(instruction)
        if instruction_type not in InterpreterState.INTERPRETERS:
            raise jit_runtime.JitCompilationFailedException(
                'Unknown bytecode instruction: %r' % instruction)
        handler = InterpreterState.INTERPRETERS[instruction_type]

        # Attribute source-map lookups to this instruction while it runs;
        # instructions without debug information inherit the enclosing one.
        saved_debug_info = self.src_map.debug_information
        new_debug_info = instruction.debug_information
        if new_debug_info is not None:
            self.src_map.debug_information = new_debug_info

        yield [("CALL_ARGS", [handler, (self, instruction)])]

        self.src_map.debug_information = saved_debug_info

        # Every `nop_period` instructions, hand out a nop.
        if self.schedule_nop():
            yield None

        successor = instruction.next_instruction
        if successor is None:
            raise primitive_functions.PrimitiveFinished(self.current_result)
        yield [("TAIL_CALL_ARGS", [self.interpret, (successor,)])]

    def interpret_select(self, instruction):
        """Interprets a 'select' (if/else) instruction: evaluates the
        condition, then tail-calls into the matching clause. Without an else
        clause, a false condition simply finishes."""
        cond_node, = yield [("CALL_ARGS", [self.interpret, (instruction.condition,)])]
        cond_val, = yield [("RV", [cond_node])]
        if cond_val:
            branch = instruction.if_clause
        elif instruction.else_clause is not None:
            branch = instruction.else_clause
        else:
            raise primitive_functions.PrimitiveFinished(None)
        yield [("TAIL_CALL_ARGS", [self.interpret, (branch,)])]

    def interpret_while(self, instruction):
        """Interprets a 'while' loop.

        'break' and 'continue' instructions raise BreakException and
        ContinueException. The handlers installed here consume only the ones
        whose ``loop`` equals this instruction. All others are re-raised so
        that an enclosing 'while' can handle them.
        """
        def on_break(exception):
            # Plain function (no yield): ends the loop right away.
            if exception.loop == instruction:
                raise primitive_functions.PrimitiveFinished(None)
            raise exception

        def on_continue(exception):
            # Generator: restarts the loop by tail-calling this method again.
            if exception.loop == instruction:
                yield [("TAIL_CALL_ARGS", [self.interpret_while, (instruction,)])]
            else:
                raise exception

        yield [("TRY", [])]
        yield [("CATCH", [BreakException, on_break])]
        yield [("CATCH", [ContinueException, on_continue])]
        while True:
            cond_node, = yield [("CALL_ARGS", [self.interpret, (instruction.condition,)])]
            keep_going, = yield [("RV", [cond_node])]
            if not keep_going:
                break
            yield [("CALL_ARGS", [self.interpret, (instruction.body,)])]
        yield [("END_TRY", [])]

        raise primitive_functions.PrimitiveFinished(None)

    def interpret_break(self, instruction):
        """Interprets a 'break' instruction by raising BreakException for the
        loop it targets (no yield: raises as soon as it is called)."""
        raise BreakException(instruction.loop)

    def interpret_continue(self, instruction):
        """Interprets a 'continue' instruction by raising ContinueException
        for the loop it targets (no yield: raises as soon as it is called)."""
        raise ContinueException(instruction.loop)

    def interpret_return(self, instruction):
        """Interprets a 'return' instruction by raising
        InterpretedFunctionFinished with the returned node, or with None for
        a bare return."""
        return_node = None
        if instruction.value is not None:
            return_node, = yield [("CALL_ARGS", [self.interpret, (instruction.value,)])]
        raise primitive_functions.InterpretedFunctionFinished(return_node)

    def interpret_call(self, instruction):
        """Interprets a 'call' instruction.

        The target is evaluated first, then the arguments in list order.
        A non-None call result is linked from the GC root node and becomes
        the current result.
        """
        target, = yield [("CALL_ARGS", [self.interpret, (instruction.target,)])]
        named_args = {}
        for arg_name, arg_instruction in instruction.argument_list:
            arg_node, = yield [("CALL_ARGS", [self.interpret, (arg_instruction,)])]
            named_args[arg_name] = arg_node

        call_kwargs = {'function_id': target, 'named_arguments': named_args}
        # The interpreter-wide keyword arguments are merged in last, so they
        # win on a key collision.
        call_kwargs.update(self.keyword_arg_dict)
        result, = yield [("CALL_KWARGS", [jit_runtime.call_function, call_kwargs])]
        if result is not None:
            yield [("CE", [self.gc_root_node, result])]
            self.update_result(result)
        raise primitive_functions.PrimitiveFinished(None)

    def interpret_constant(self, instruction):
        """Interprets a 'constant' instruction: its constant id becomes the
        current result (no yield: finishes as soon as it is called)."""
        self.update_result(instruction.constant_id)
        raise primitive_functions.PrimitiveFinished(None)

    def interpret_input(self, instruction):
        """Interprets an 'input' instruction. The node returned by
        jit_runtime.get_input becomes the current result and is linked from
        the GC root node."""
        input_node, = yield [("CALL_KWARGS", [jit_runtime.get_input, self.keyword_arg_dict])]
        self.update_result(input_node)
        yield [("CE", [self.gc_root_node, input_node])]
        raise primitive_functions.PrimitiveFinished(None)

    def interpret_output(self, instruction):
        """Interprets an 'output' instruction.

        The value is stored in the task's current "last_output" node. A fresh
        node is linked as its "next" and becomes the new "last_output", which
        replaces the old "last_output" edge.
        """
        output_value, = yield [("CALL_ARGS", [self.interpret, (instruction.value,)])]
        task_root = self.get_task_root()
        old_tail, old_tail_link, new_tail = yield [
            ("RD", [task_root, "last_output"]),
            ("RDE", [task_root, "last_output"]),
            ("CN", []),
        ]
        yield [
            ("CD", [old_tail, "value", output_value]),
            ("CD", [old_tail, "next", new_tail]),
            ("CD", [task_root, "last_output", new_tail]),
            ("DE", [old_tail_link]),
        ]
        # Nop right after producing output. NOTE(review): presumably this
        # lets the output be picked up promptly; confirm.
        yield None
        raise primitive_functions.PrimitiveFinished(None)

    def interpret_declare(self, instruction):
        """Interprets a 'declare' (local) instruction.

        A fresh, GC-rooted node is created for the variable unless one is
        already registered for its node id. A re-declaration keeps the
        existing node.
        """
        node_id = instruction.variable.node_id
        if node_id not in self.local_vars:
            var_node, = yield [("CN", [])]
            yield [("CE", [self.gc_root_node, var_node])]
            self.local_vars[node_id] = var_node
        raise primitive_functions.PrimitiveFinished(None)

    def interpret_global(self, instruction):
        """Interprets a (declare) 'global' instruction: makes sure the task's
        "globals" node has an entry for the variable name, creating an empty
        node for it when absent."""
        var_name = instruction.variable.name
        task_root = self.get_task_root()
        globals_node, = yield [("RD", [task_root, "globals"])]
        existing_var, = yield [("RD", [globals_node, var_name])]
        if existing_var is None:
            new_var, = yield [("CN", [])]
            yield [("CD", [globals_node, var_name, new_var])]
        raise primitive_functions.PrimitiveFinished(None)

    def interpret_resolve(self, instruction):
        """Interprets a 'resolve' instruction: the node holding the variable
        becomes the current result.

        Locals are looked up in ``local_vars``. Any other variable is looked
        up in the task's "globals" node, and an Exception is raised if it is
        missing there as well.
        """
        node_id = instruction.variable.node_id
        if node_id in self.local_vars:
            self.update_result(self.local_vars[node_id])
            raise primitive_functions.PrimitiveFinished(None)

        task_root = self.get_task_root()
        var_name = instruction.variable.name
        globals_node, = yield [("RD", [task_root, "globals"])]
        global_var, = yield [("RD", [globals_node, var_name])]
        if global_var is None:
            raise Exception(jit_runtime.GLOBAL_NOT_FOUND_MESSAGE_FORMAT % var_name)

        kernel = self.get_kernel()
        if kernel.suggest_function_names and kernel.jit.get_global_body_id(var_name) is None:
            # If the global's value has a "body", tell the JIT that body goes
            # by this global's name.
            global_val, = yield [("RD", [global_var, "value"])]
            func_body = None
            if global_val is not None:
                func_body, = yield [("RD", [global_val, "body"])]
            if func_body is not None:
                kernel.jit.register_global(func_body, var_name)

        self.update_result(global_var)
        yield [("CE", [self.gc_root_node, global_var])]
        raise primitive_functions.PrimitiveFinished(None)

    def interpret_access(self, instruction):
        """Interprets an 'access' instruction: the "value" of the evaluated
        pointer node becomes the current result and is linked from the GC
        root node."""
        pointer_node, = yield [("CALL_ARGS", [self.interpret, (instruction.pointer,)])]
        value_node, = yield [("RD", [pointer_node, "value"])]
        self.update_result(value_node)
        yield [("CE", [self.gc_root_node, value_node])]
        raise primitive_functions.PrimitiveFinished(None)

    def interpret_assign(self, instruction):
        """Interprets an 'assign' instruction: points the pointer node's
        "value" at the evaluated value and deletes the previous "value"
        edge."""
        pointer_node, = yield [("CALL_ARGS", [self.interpret, (instruction.pointer,)])]
        value_node, = yield [("CALL_ARGS", [self.interpret, (instruction.value,)])]
        old_value_link, = yield [("RDE", [pointer_node, "value"])]
        # The new edge is created before the old one is deleted, in one batch.
        yield [
            ("CD", [pointer_node, "value", value_node]),
            ("DE", [old_value_link]),
        ]
        raise primitive_functions.PrimitiveFinished(None)

    # Dispatch table: bytecode instruction class -> handler function. The
    # handlers are stored unbound and invoked as handler(state, instruction)
    # by `interpret`; lookup uses the exact type, so subclasses do not match.
    INTERPRETERS = {
        bytecode_ir.SelectInstruction: interpret_select,
        bytecode_ir.WhileInstruction: interpret_while,
        bytecode_ir.BreakInstruction: interpret_break,
        bytecode_ir.ContinueInstruction: interpret_continue,
        bytecode_ir.ReturnInstruction: interpret_return,
        bytecode_ir.CallInstruction: interpret_call,
        bytecode_ir.ConstantInstruction: interpret_constant,
        bytecode_ir.InputInstruction: interpret_input,
        bytecode_ir.OutputInstruction: interpret_output,
        bytecode_ir.DeclareInstruction: interpret_declare,
        bytecode_ir.GlobalInstruction: interpret_global,
        bytecode_ir.ResolveInstruction: interpret_resolve,
        bytecode_ir.AccessInstruction: interpret_access,
        bytecode_ir.AssignInstruction: interpret_assign
    }
def interpret_bytecode_function(function_name, body_bytecode, local_arguments, keyword_arguments):
    """Interprets the bytecode function ``function_name`` with body
    ``body_bytecode``.

    ``local_arguments`` maps parameter node ids to the argument nodes bound to
    them; each one is imported as a local. ``keyword_arguments`` is handed to
    the interpreter state and must contain 'task_root'.

    The function finishes with the node carried by the body's 'return'
    instruction, after unlinking the per-call GC root from the task root.
    A 'break'/'continue' that escapes every loop, or a body that ends
    without returning, raises UnreachableCodeException.
    """
    src_map = source_map.ManualSourceMap()
    yield [("DEBUG_INFO", [function_name, src_map, jit_runtime.BYTECODE_INTERPRETER_ORIGIN_NAME])]

    # Per-call GC root node, linked from the task root for the call's lifetime.
    task_root = keyword_arguments['task_root']
    gc_root_node, = yield [("CN", [])]
    gc_root_edge, = yield [("CE", [task_root, gc_root_node])]

    state = InterpreterState(gc_root_node, keyword_arguments, src_map)
    for param_id, arg_node in local_arguments.items():
        yield [("CALL_ARGS", [state.import_local, (param_id, arg_node)])]

    def on_return(exception):
        # Drop the GC root link, then finish with the returned node.
        yield [("DE", [gc_root_edge])]
        raise primitive_functions.PrimitiveFinished(exception.result)

    def on_stray_break(_):
        # A break no enclosing 'while' consumed.
        raise jit_runtime.UnreachableCodeException(
            "Function '%s' tries to break out of a loop that is not currently executing." %
            function_name)

    def on_stray_continue(_):
        # A continue no enclosing 'while' consumed.
        raise jit_runtime.UnreachableCodeException(
            "Function '%s' tries to continue a loop that is not currently executing." %
            function_name)

    # Nop before the body starts running.
    yield None

    yield [("TRY", [])]
    yield [("CATCH", [primitive_functions.InterpretedFunctionFinished, on_return])]
    yield [("CATCH", [BreakException, on_stray_break])]
    yield [("CATCH", [ContinueException, on_stray_continue])]
    yield [("CALL_ARGS", [state.interpret, (body_bytecode,)])]
    yield [("END_TRY", [])]

    # The body completed without executing a 'return' instruction.
    raise jit_runtime.UnreachableCodeException("Function '%s' failed to return." % function_name)
|