"""Interprets parsed bytecode graphs.""" import modelverse_jit.bytecode_ir as bytecode_ir import modelverse_jit.runtime as jit_runtime import modelverse_kernel.primitives as primitive_functions class BreakException(Exception): """A type of exception that is used to interpret 'break' instructions: the 'break' instructions throw a BreakException, which is then handled by the appropriate 'while' instruction.""" def __init__(self, loop): Exception.__init__(self) self.loop = loop class ContinueException(Exception): """A type of exception that is used to interpret 'continue' instructions: the 'continue' instructions throw a ContinueException, which is then handled by the appropriate 'while' instruction.""" def __init__(self, loop): Exception.__init__(self) self.loop = loop class InterpreterState(object): """The state of the bytecode interpreter.""" def __init__(self, gc_root_node, keyword_arg_dict, nop_period=20): self.gc_root_node = gc_root_node self.nop_period = nop_period self.keyword_arg_dict = keyword_arg_dict self.current_result = None self.nop_phase = 0 self.local_vars = {} def import_local(self, node_id, value): """Imports the given value as a local in this interpreter state.""" local_node, = yield [("CN", [])] yield [ ("CE", [self.gc_root_node, local_node]), ("CD", [local_node, "value", value])] self.local_vars[node_id] = local_node raise primitive_functions.PrimitiveFinished(None) def schedule_nop(self): """Increments the nop-phase. If a nop should be performed, then True is returned. Otherwise, False.""" self.nop_phase += 1 if self.nop_phase == self.nop_period: self.nop_phase = 0 return True else: return False def update_result(self, new_result): """Sets the current result to the given value, if it is not None.""" if new_result is not None: self.current_result = new_result def get_task_root(self): """Gets the task root node id.""" return self.keyword_arg_dict['task_root'] def get_kernel(self): """Gets the Modelverse kernel instance.""" return self.keyword_arg_dict['mvk'] def interpret(self, instruction): """Interprets the given instruction and returns the current result.""" instruction_type = type(instruction) if instruction_type in InterpreterState.INTERPRETERS: # Interpret the instruction. yield [("CALL_ARGS", [InterpreterState.INTERPRETERS[instruction_type], (self, instruction)])] # Maybe perform a nop. if self.schedule_nop(): yield None # Interpret the next instruction. next_instruction = instruction.next_instruction if next_instruction is not None: yield [("TAIL_CALL_ARGS", [self.interpret, (next_instruction,)])] else: raise primitive_functions.PrimitiveFinished(self.current_result) else: raise jit_runtime.JitCompilationFailedException( 'Unknown bytecode instruction: %r' % instruction) def interpret_select(self, instruction): """Interprets the given 'select' instruction.""" cond_node, = yield [("CALL_ARGS", [self.interpret, (instruction.condition,)])] cond_val, = yield [("RV", [cond_node])] if cond_val: yield [("TAIL_CALL_ARGS", [self.interpret, (instruction.if_clause,)])] elif instruction.else_clause is not None: yield [("TAIL_CALL_ARGS", [self.interpret, (instruction.else_clause,)])] else: raise primitive_functions.PrimitiveFinished(None) def interpret_while(self, instruction): """Interprets the given 'while' instruction.""" def __handle_break(exception): if exception.loop == instruction: # End the loop. raise primitive_functions.PrimitiveFinished(None) else: # Propagate the exception to the next 'while' loop. raise exception def __handle_continue(exception): if exception.loop == instruction: # Restart the loop. yield [("TAIL_CALL_ARGS", [self.interpret, (instruction,)])] else: # Propagate the exception to the next 'while' loop. raise exception yield [("TRY", [])] yield [("CATCH", [BreakException, __handle_break])] yield [("CATCH", [ContinueException, __handle_continue])] while 1: cond_node, = yield [("CALL_ARGS", [self.interpret, (instruction.condition,)])] cond_val, = yield [("RV", [cond_node])] if cond_val: yield [("CALL_ARGS", [self.interpret, (instruction.body,)])] else: break yield [("END_TRY", [])] raise primitive_functions.PrimitiveFinished(None) def interpret_break(self, instruction): """Interprets the given 'break' instruction.""" raise BreakException(instruction.loop) def interpret_continue(self, instruction): """Interprets the given 'continue' instruction.""" raise ContinueException(instruction.loop) def interpret_return(self, instruction): """Interprets the given 'return' instruction.""" if instruction.value is None: raise primitive_functions.InterpretedFunctionFinished(None) else: return_node, = yield [("CALL_ARGS", [self.interpret, (instruction.value,)])] raise primitive_functions.InterpretedFunctionFinished(return_node) def interpret_call(self, instruction): """Interprets the given 'call' instruction.""" target, = yield [("CALL_ARGS", [self.interpret, (instruction.target,)])] named_args = {} for name, arg_instruction in instruction.argument_list: arg, = yield [("CALL_ARGS", [self.interpret, (arg_instruction,)])] named_args[name] = arg kwargs = {'function_id': target, 'named_arguments': named_args} kwargs.update(self.keyword_arg_dict) result, = yield [("CALL_KWARGS", [jit_runtime.call_function, kwargs])] if result is not None: yield [("CE", [self.gc_root_node, result])] self.update_result(result) raise primitive_functions.PrimitiveFinished(None) def interpret_constant(self, instruction): """Interprets the given 'constant' instruction.""" self.update_result(instruction.constant_id) raise primitive_functions.PrimitiveFinished(None) def interpret_input(self, instruction): """Interprets the given 'input' instruction.""" result, = yield [("CALL_KWARGS", [jit_runtime.get_input, self.keyword_arg_dict])] self.update_result(result) yield [("CE", [self.gc_root_node, result])] raise primitive_functions.PrimitiveFinished(None) def interpret_output(self, instruction): """Interprets the given 'output' instruction.""" output_value, = yield [("CALL_ARGS", [self.interpret, (instruction.value,)])] task_root = self.get_task_root() last_output, last_output_link, new_last_output = yield [ ("RD", [task_root, "last_output"]), ("RDE", [task_root, "last_output"]), ("CN", []) ] yield [ ("CD", [last_output, "value", output_value]), ("CD", [last_output, "next", new_last_output]), ("CD", [task_root, "last_output", new_last_output]), ("DE", [last_output_link]) ] yield None raise primitive_functions.PrimitiveFinished(None) def interpret_declare(self, instruction): """Interprets a 'declare' (local) instruction.""" node_id = instruction.variable.node_id if node_id in self.local_vars: self.update_result(self.local_vars[node_id]) raise primitive_functions.PrimitiveFinished(None) else: local_node, = yield [("CN", [])] yield [("CE", [self.gc_root_node, local_node])] self.update_result(local_node) self.local_vars[node_id] = local_node raise primitive_functions.PrimitiveFinished(None) def interpret_global(self, instruction): """Interprets a (declare) 'global' instruction.""" var_name = instruction.variable.name task_root = self.get_task_root() _globals, = yield [("RD", [task_root, "globals"])] global_var, = yield [("RD", [_globals, var_name])] if global_var is None: global_var, = yield [("CN", [])] yield [("CD", [_globals, var_name, global_var])] self.update_result(global_var) yield [("CE", [self.gc_root_node, global_var])] raise primitive_functions.PrimitiveFinished(None) def interpret_resolve(self, instruction): """Interprets a 'resolve' instruction.""" node_id = instruction.variable.node_id if node_id in self.local_vars: self.update_result(self.local_vars[node_id]) raise primitive_functions.PrimitiveFinished(None) else: task_root = self.get_task_root() var_name = instruction.variable.name _globals, = yield [("RD", [task_root, "globals"])] global_var, = yield [("RD", [_globals, var_name])] if global_var is None: raise Exception(jit_runtime.GLOBAL_NOT_FOUND_MESSAGE_FORMAT % var_name) mvk = self.get_kernel() if mvk.suggest_function_names and mvk.jit.get_global_body_id(var_name) is None: global_val, = yield [("RD", [global_var, "value"])] if global_val is not None: func_body, = yield [("RD", [global_val, "body"])] if func_body is not None: mvk.jit.register_global(func_body, var_name) self.update_result(global_var) yield [("CE", [self.gc_root_node, global_var])] raise primitive_functions.PrimitiveFinished(None) def interpret_access(self, instruction): """Interprets an 'access' instruction.""" pointer_node, = yield [("CALL_ARGS", [self.interpret, (instruction.pointer,)])] value_node, = yield [("RD", [pointer_node, "value"])] self.update_result(value_node) yield [("CE", [self.gc_root_node, value_node])] raise primitive_functions.PrimitiveFinished(None) def interpret_assign(self, instruction): """Interprets an 'assign' instruction.""" pointer_node, = yield [("CALL_ARGS", [self.interpret, (instruction.pointer,)])] value_node, = yield [("CALL_ARGS", [self.interpret, (instruction.value,)])] value_link, = yield [("RDE", [pointer_node, "value"])] yield [ ("CD", [pointer_node, "value", value_node]), ("DE", [value_link])] raise primitive_functions.PrimitiveFinished(None) INTERPRETERS = { bytecode_ir.SelectInstruction: interpret_select, bytecode_ir.WhileInstruction: interpret_while, bytecode_ir.BreakInstruction: interpret_break, bytecode_ir.ContinueInstruction: interpret_continue, bytecode_ir.ReturnInstruction: interpret_return, bytecode_ir.CallInstruction: interpret_call, bytecode_ir.ConstantInstruction: interpret_constant, bytecode_ir.InputInstruction: interpret_input, bytecode_ir.OutputInstruction: interpret_output, bytecode_ir.DeclareInstruction: interpret_declare, bytecode_ir.GlobalInstruction: interpret_global, bytecode_ir.ResolveInstruction: interpret_resolve, bytecode_ir.AccessInstruction: interpret_access, bytecode_ir.AssignInstruction: interpret_assign } def interpret_bytecode_function(function_name, body_bytecode, local_arguments, keyword_arguments): """Interprets the bytecode function with the given name, body, named arguments and keyword arguments.""" yield [("DEBUG_INFO", [function_name, None, jit_runtime.BYTECODE_INTERPRETER_ORIGIN_NAME])] task_root = keyword_arguments['task_root'] gc_root_node, = yield [("CN", [])] gc_root_edge, = yield [("CE", [task_root, gc_root_node])] interpreter = InterpreterState(gc_root_node, keyword_arguments) for param_id, arg_node in local_arguments.items(): yield [("CALL_ARGS", [interpreter.import_local, (param_id, arg_node)])] def __handle_return(exception): yield [("DE", [gc_root_edge])] raise primitive_functions.PrimitiveFinished(exception.result) def __handle_break(_): raise jit_runtime.UnreachableCodeException( "Function '%s' tries to break out of a loop that is not currently executing." % function_name) def __handle_continue(_): raise jit_runtime.UnreachableCodeException( "Function '%s' tries to continue a loop that is not currently executing." % function_name) # Perform a nop before interpreting the function. yield None yield [("TRY", [])] yield [("CATCH", [primitive_functions.InterpretedFunctionFinished, __handle_return])] yield [("CATCH", [BreakException, __handle_break])] yield [("CATCH", [ContinueException, __handle_continue])] yield [("CALL_ARGS", [interpreter.interpret, (body_bytecode,)])] yield [("END_TRY", [])] raise jit_runtime.UnreachableCodeException("Function '%s' failed to return." % function_name)