"""Runtime support for JIT-compiled and interpreted Modelverse functions.

Each function below is a generator that talks to the kernel's request handler
by yielding lists of ``(opcode, arguments)`` request tuples. The values sent
back into the generator are the results of those requests, in order.
"""

import modelverse_kernel.primitives as primitive_functions


class JitCompilationFailedException(Exception):
    """A type of exception that is raised when the jit fails to compile a function."""
    pass


MUTABLE_FUNCTION_KEY = "mutable"
"""A dictionary key for functions that are mutable."""

FUNCTION_BODY_KEY = "body"
"""A dictionary key for function bodies."""


def call_function(function_id, named_arguments, **kwargs):
    """Runs the function with the given id, passing it the specified argument dictionary.

    Mutable functions are marked as "no jit" and handed to the interpreter
    (``interpret_function``). Other functions are JIT-compiled and tail-called.
    If compilation raises JitCompilationFailedException, they also fall back
    to the interpreter.

    Args:
        function_id: The id of the function node to call.
        named_arguments: A mapping from parameter names to argument values.
            It is not modified by this function.
        **kwargs: Kernel context. It must contain at least 'task_root' and
            'mvk', and it is forwarded to the compiled function or the
            interpreter.
    """
    task_root = kwargs['task_root']
    kernel = kwargs['mvk']
    body_id, is_mutable = yield [
        ("RD", [function_id, FUNCTION_BODY_KEY]),
        ("RD", [function_id, MUTABLE_FUNCTION_KEY])]

    # Try to jit the function here. We might be able to avoid building the
    # stack frame.
    def handle_jit_failed(_=None):
        """Interprets the function instead of running JIT-compiled code.

        When used as a CATCH handler, the argument is the caught exception and
        is ignored. It defaults to None so that this generator can also be
        tail-called with an empty argument tuple (the mutable case below).
        Without the default, that call would fail with a TypeError.
        """
        interpreter_args = {'function_id': function_id,
                            'named_arguments': named_arguments}
        interpreter_args.update(kwargs)
        yield [("TAIL_CALL_KWARGS", [interpret_function, interpreter_args])]

    if is_mutable is not None:
        # Mutable functions are not jitted; hand them straight to the
        # interpreter.
        kernel.jit.mark_no_jit(body_id)
        yield [("TAIL_CALL_ARGS", [handle_jit_failed, ()])]
    else:
        kernel.jit.mark_entry_point(body_id)

    # Try to compile; a JitCompilationFailedException diverts to the
    # interpreter through handle_jit_failed.
    yield [("TRY", [])]
    yield [("CATCH", [JitCompilationFailedException, handle_jit_failed])]
    compiled_func, = yield [
        ("CALL_ARGS", [kernel.jit_compile, (task_root, body_id)])]
    yield [("END_TRY", [])]

    # The compiled function also receives the keyword arguments. Work on a
    # copy so that kernel context ('mvk', 'task_root', ...) does not leak
    # into the caller's argument dictionary.
    compiled_args = dict(named_arguments)
    compiled_args.update(kwargs)

    # Run the function.
    yield [("TAIL_CALL_KWARGS", [compiled_func, compiled_args])]


def interpret_function(function_id, named_arguments, **kwargs):
    """Makes the interpreter run the function with the given id for the specified argument dictionary.

    Steps:

    1. Push a new stack frame for the function body onto the task. The task's
       "frame" link is re-pointed at the new frame, and the new frame's "prev"
       is set to the old frame.
    2. Bind each named argument in the new frame's symbol table.
    3. Repeatedly run ``kernel.execute_rule`` for the task and forward every
       result to the caller.

    An InterpretedFunctionFinished raised during step 3 is translated into a
    PrimitiveFinished that carries the same result.

    Args:
        function_id: The id of the function node to interpret.
        named_arguments: A mapping from parameter names to argument values.
            Every key must be a parameter name reported by
            ``kernel.jit.jit_signature`` for the body; otherwise a KeyError is
            raised.
        **kwargs: Kernel context. It must contain 'task_root', 'mvk' and
            'taskname'.
    """
    task_root = kwargs['task_root']
    kernel = kwargs['mvk']
    task_frame, = yield [("RD", [task_root, "frame"])]
    inst, body_id = yield [
        ("RD", [task_frame, "IP"]),
        ("RD", [function_id, FUNCTION_BODY_KEY])]
    kernel.jit.mark_entry_point(body_id)

    # Create a new stack frame.
    (frame_link, new_phase, new_frame, new_evalstack, new_symbols,
     new_returnvalue, intrinsic_return) = yield [
         ("RDE", [task_root, "frame"]),
         ("CNV", ["init"]),
         ("CN", []),
         ("CN", []),
         ("CN", []),
         ("CN", []),
         ("CN", []),
     ]

    # Wire up the new frame. The caller's current instruction is stored as
    # "caller", and the old frame link is removed so that "frame" points only
    # at the new frame.
    _, _, _, _, _, _, _, _, _, _ = yield [
        ("CD", [task_root, "frame", new_frame]),
        ("CD", [new_frame, "evalstack", new_evalstack]),
        ("CD", [new_frame, "symbols", new_symbols]),
        ("CD", [new_frame, "returnvalue", new_returnvalue]),
        ("CD", [new_frame, "caller", inst]),
        ("CD", [new_frame, "phase", new_phase]),
        ("CD", [new_frame, "IP", body_id]),
        ("CD", [new_frame, "prev", task_frame]),
        ("CD", [
            new_frame,
            primitive_functions.EXCEPTION_RETURN_KEY,
            intrinsic_return]),
        ("DE", [frame_link]),
    ]

    # Put the parameters in the new stack frame's symbol table.
    (parameter_vars, parameter_names, _), = yield [
        ("CALL_ARGS", [kernel.jit.jit_signature, (body_id,)])]
    parameter_dict = dict(zip(parameter_names, parameter_vars))

    for key, value in named_arguments.items():
        # A KeyError here means the caller passed an argument the function
        # does not declare.
        param_var = parameter_dict[key]
        variable, = yield [("CN", [])]
        yield [("CD", [variable, "value", value])]
        symbol_edge, = yield [("CE", [new_symbols, variable])]
        yield [("CE", [symbol_edge, param_var])]

    taskname = kwargs['taskname']

    def exception_handler(ex):
        """Translates InterpretedFunctionFinished into PrimitiveFinished."""
        raise primitive_functions.PrimitiveFinished(ex.result)

    # Create an exception handler to catch and translate
    # InterpretedFunctionFinished.
    yield [("TRY", [])]
    yield [("CATCH", [primitive_functions.InterpretedFunctionFinished,
                      exception_handler])]
    while True:
        result, = yield [("CALL_ARGS", [kernel.execute_rule, (taskname,)])]
        # An instruction has completed. Forward it.
        yield result


def get_input(**parameters):
    """Retrieves input.

    Repeatedly calls ``mvk.input_init`` for the task. If ``mvk.success`` is
    set afterwards, this finishes by raising PrimitiveFinished with
    ``mvk.input_value``. Otherwise it yields None and tries again.

    Args:
        **parameters: Kernel context. It must contain 'mvk' and 'task_root'.
    """
    mvk = parameters["mvk"]
    task_root = parameters["task_root"]
    while True:
        yield [("CALL_ARGS", [mvk.input_init, (task_root,)])]
        # Finished
        if mvk.success:
            # Got some input, so we can access it
            raise primitive_functions.PrimitiveFinished(mvk.input_value)
        else:
            # No input, so yield None but don't stop (presumably this lets
            # the scheduler run other work before retrying -- confirm)
            yield None