import modelverse_kernel.primitives as primitive_functions


class JitCompilationFailedException(Exception):
    """A type of exception that is raised when the jit fails to compile a function."""
    pass


MUTABLE_FUNCTION_KEY = "mutable"
"""A dictionary key for functions that are mutable."""

FUNCTION_BODY_KEY = "body"
"""A dictionary key for function bodies."""

KWARGS_PARAMETER_NAME = "kwargs"
"""The name of the kwargs parameter in jitted functions."""

CALL_FUNCTION_NAME = "__call_function"
"""The name of the '__call_function' function, in the jitted function scope."""

GET_INPUT_FUNCTION_NAME = "__get_input"
"""The name of the '__get_input' function, in the jitted function scope."""

JIT_THUNK_CONSTANT_FUNCTION_NAME = "__jit_thunk_constant_function"
"""The name of the jit_thunk_constant_function function in the JIT's global context."""

JIT_THUNK_GLOBAL_FUNCTION_NAME = "__jit_thunk_global"
"""The name of the jit_thunk_global function in the JIT's global context."""

JIT_REJIT_FUNCTION_NAME = "__jit_rejit"
"""The name of the rejit function in the JIT's global context."""

JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME = "__jit_compile_function_body_fast"
"""The name of the compile_function_body_fast function in the JIT's global context."""

UNREACHABLE_FUNCTION_NAME = "__unreachable"
"""The name of the unreachable function in the JIT's global context."""

LOCALS_NODE_NAME = "jit_locals"
"""The name of the node that is connected to all JIT locals in a given function call."""

LOCALS_EDGE_NAME = "jit_locals_edge"
"""The name of the edge that connects the LOCALS_NODE_NAME node to a user root."""

GLOBAL_NOT_FOUND_MESSAGE_FORMAT = "Not found as global: %s"
"""The format of the 'not found as global' message.
Takes a single argument."""

BASELINE_JIT_ORIGIN_NAME = "baseline-jit"
"""The origin name for functions that were produced by the baseline JIT."""

FAST_JIT_ORIGIN_NAME = "fast-jit"
"""The origin name for functions that were produced by the fast JIT."""


def format_stack_frame(function_name, debug_info, origin='unknown'):
    """Formats a stack frame, which consists of a function name,
    debug information and an origin.

    A function name of None is rendered as 'unknown function'; debug
    information of None is rendered as '[unknown location] '. The debug
    information is placed directly in front of the word 'in', so it is
    expected to carry its own trailing separator."""
    if function_name is None:
        function_name = 'unknown function'
    if debug_info is None:
        debug_info = '[unknown location] '
    return '%sin %s (%s)' % (debug_info, function_name, origin)


def format_trace_message(debug_info, function_name, origin='unknown'):
    """Creates a formatted trace message.

    The result is the formatted stack frame prefixed with 'TRACE: '.
    Note that the first two parameters are in the opposite order from
    format_stack_frame's."""
    return 'TRACE: %s' % format_stack_frame(function_name, debug_info, origin)


def call_function(function_id, named_arguments, **kwargs):
    """Runs the function with the given id, passing it the specified argument dictionary.

    This is a generator that yields kernel request lists. It reads
    'task_root' and 'mvk' from the keyword arguments. Functions that carry
    a MUTABLE_FUNCTION_KEY entry are marked as not jittable and handed to
    the interpreter; other functions are compiled with the JIT, and the
    interpreter is used only if compilation raises
    JitCompilationFailedException.

    Note that, on the compiled path, the keyword arguments are merged into
    'named_arguments' in place."""
    task_root = kwargs['task_root']
    kernel = kwargs['mvk']
    body_id, is_mutable = yield [
        ("RD", [function_id, FUNCTION_BODY_KEY]),
        ("RD", [function_id, MUTABLE_FUNCTION_KEY])]

    # Try to jit the function here. We might be able to avoid building the stack
    # frame.
    def handle_jit_failed(_=None):
        """Interprets the function."""
        # The parameter is optional because this function serves both as a
        # CATCH handler (which is given the exception) and as the target of
        # a tail call with an empty argument tuple.
        interpreter_args = {'body_id': body_id, 'named_arguments': named_arguments}
        interpreter_args.update(kwargs)
        yield [("TAIL_CALL_KWARGS", [interpret_function_body, interpreter_args])]

    if is_mutable is not None:
        kernel.jit.mark_no_jit(body_id)
        # NOTE(review): the tail call is presumed to replace this generator,
        # so the JIT path below is not reached for mutable functions --
        # confirm against the kernel's TAIL_CALL_ARGS handling.
        yield [("TAIL_CALL_ARGS", [handle_jit_failed, ()])]
    else:
        kernel.jit.mark_entry_point(body_id)

    yield [("TRY", [])]
    yield [("CATCH", [JitCompilationFailedException, handle_jit_failed])]
    # Try to compile.
    compiled_func, = yield [("CALL_ARGS", [kernel.jit_compile, (task_root, body_id)])]
    yield [("END_TRY", [])]
    # Add the keyword arguments to the argument dictionary.
    named_arguments.update(kwargs)
    # Run the function.
    yield [("TAIL_CALL_KWARGS", [compiled_func, named_arguments])]


def interpret_function(function_id, named_arguments, **kwargs):
    """Makes the interpreter run the function with the given id for
    the specified argument dictionary.

    This is a generator that yields kernel request lists: it reads the
    function's body and then tail-calls interpret_function_body with that
    body, the argument dictionary and all keyword arguments."""
    body_id, = yield [("RD", [function_id, FUNCTION_BODY_KEY])]
    # The keys are the parameter names of interpret_function_body, so they
    # must be strings; the argument dictionary itself is the *value*.
    args = {'body_id': body_id, 'named_arguments': named_arguments}
    args.update(kwargs)
    yield [("TAIL_CALL_KWARGS", [interpret_function_body, args])]


def interpret_function_body(body_id, named_arguments, **kwargs):
    """Makes the interpreter run the function body with the given id for
    the specified argument dictionary.

    This is a generator that yields kernel request lists. It reads
    'task_root', 'mvk' and 'taskname' from the keyword arguments, builds a
    new stack frame whose instruction pointer is the function body, binds
    every named argument to its parameter in the frame's symbol table and
    then keeps executing kernel rules, forwarding each result.

    A KeyError is raised if 'named_arguments' contains a name that is not
    in the signature reported by the JIT. When
    InterpretedFunctionFinished is caught, its result is re-raised as
    PrimitiveFinished."""
    task_root = kwargs['task_root']
    kernel = kwargs['mvk']
    user_frame, = yield [("RD", [task_root, "frame"])]
    inst, = yield [("RD", [user_frame, "IP"])]
    kernel.jit.mark_entry_point(body_id)

    # Create a new stack frame.
    frame_link, new_phase, new_frame, new_evalstack, new_symbols, \
        new_returnvalue, intrinsic_return = \
        yield [("RDE", [task_root, "frame"]),
               ("CNV", ["init"]),
               ("CN", []),
               ("CN", []),
               ("CN", []),
               ("CN", []),
               ("CN", []),
              ]

    # Wire up the new frame, then delete the old 'frame' edge of the task
    # root so that only the new frame remains linked as the current one.
    _, _, _, _, _, _, _, _, _, _ = \
        yield [("CD", [task_root, "frame", new_frame]),
               ("CD", [new_frame, "evalstack", new_evalstack]),
               ("CD", [new_frame, "symbols", new_symbols]),
               ("CD", [new_frame, "returnvalue", new_returnvalue]),
               ("CD", [new_frame, "caller", inst]),
               ("CD", [new_frame, "phase", new_phase]),
               ("CD", [new_frame, "IP", body_id]),
               ("CD", [new_frame, "prev", user_frame]),
               ("CD", [
                   new_frame,
                   primitive_functions.EXCEPTION_RETURN_KEY,
                   intrinsic_return]),
               ("DE", [frame_link]),
              ]

    # Put the parameters in the new stack frame's symbol table.
    (parameter_vars, parameter_names, _), = yield [
        ("CALL_ARGS", [kernel.jit.jit_signature, (body_id,)])]
    parameter_dict = dict(zip(parameter_names, parameter_vars))

    for (key, value) in named_arguments.items():
        param_var = parameter_dict[key]
        variable, = yield [("CN", [])]
        yield [("CD", [variable, "value", value])]
        symbol_edge, = yield [("CE", [new_symbols, variable])]
        yield [("CE", [symbol_edge, param_var])]

    taskname = kwargs['taskname']

    def exception_handler(ex):
        """Translates InterpretedFunctionFinished into PrimitiveFinished."""
        raise primitive_functions.PrimitiveFinished(ex.result)

    # Create an exception handler to catch and translate InterpretedFunctionFinished.
    yield [("TRY", [])]
    yield [("CATCH", [primitive_functions.InterpretedFunctionFinished, exception_handler])]
    while True:
        result, = yield [("CALL_ARGS", [kernel.execute_rule, (taskname,)])]
        # An instruction has completed. Forward it.
        yield result


class UnreachableCodeException(Exception):
    """The type of exception that is thrown when supposedly
    unreachable code is executed."""
    pass


def unreachable():
    """Marks unreachable code.

    Always raises UnreachableCodeException."""
    raise UnreachableCodeException('An unreachable statement was reached.')


def get_input(**parameters):
    """Retrieves input.

    This is a generator that reads 'mvk' and 'task_root' from the keyword
    arguments. It repeatedly asks the kernel to run 'input_init' for the
    task root; once the kernel's 'success' flag is set, the kernel's
    'input_value' is delivered by raising PrimitiveFinished. Until then,
    None is yielded after every unsuccessful attempt."""
    mvk = parameters["mvk"]
    task_root = parameters["task_root"]
    while True:
        yield [("CALL_ARGS", [mvk.input_init, (task_root,)])]
        # Finished
        if mvk.success:
            # Got some input, so we can access it
            raise primitive_functions.PrimitiveFinished(mvk.input_value)
        else:
            # No input, so yield None but don't stop
            yield None