123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189 |
- import modelverse_kernel.primitives as primitive_functions
class JitCompilationFailedException(Exception):
    """Raised when the JIT is unable to compile a function.

    Callers catch this to fall back to interpreting the function instead.
    """
# --- Keys on function nodes -------------------------------------------------
# Dictionary key present on functions that are mutable.
MUTABLE_FUNCTION_KEY = "mutable"
# Dictionary key that maps a function to its body.
FUNCTION_BODY_KEY = "body"

# --- Names visible inside a jitted function's scope -------------------------
# Name of the kwargs parameter in jitted functions.
KWARGS_PARAMETER_NAME = "kwargs"
# Name of the '__call_function' function in the jitted function scope.
CALL_FUNCTION_NAME = "__call_function"
# Name of the '__get_input' function in the jitted function scope.
GET_INPUT_FUNCTION_NAME = "__get_input"

# --- Names in the JIT's global context --------------------------------------
# Name of the jit_thunk_constant_function function.
JIT_THUNK_CONSTANT_FUNCTION_NAME = "__jit_thunk_constant_function"
# Name of the jit_thunk_global function.
JIT_THUNK_GLOBAL_FUNCTION_NAME = "__jit_thunk_global"
# Name of the rejit function.
JIT_REJIT_FUNCTION_NAME = "__jit_rejit"
# Name of the compile_function_body_fast function.
JIT_COMPILE_FUNCTION_BODY_FAST_FUNCTION_NAME = "__jit_compile_function_body_fast"
# Name of the unreachable function.
UNREACHABLE_FUNCTION_NAME = "__unreachable"

# --- Per-call JIT locals ----------------------------------------------------
# Name of the node that is connected to all JIT locals in a given function call.
LOCALS_NODE_NAME = "jit_locals"
# Name of the edge that connects the LOCALS_NODE_NAME node to a user root.
LOCALS_EDGE_NAME = "jit_locals_edge"

# --- Diagnostics ------------------------------------------------------------
# Format of the 'not found as global' message; takes a single argument.
GLOBAL_NOT_FOUND_MESSAGE_FORMAT = "Not found as global: %s"

# --- Origins of compiled functions ------------------------------------------
# Origin name for functions produced by the baseline JIT.
BASELINE_JIT_ORIGIN_NAME = "baseline-jit"
# Origin name for functions produced by the fast JIT.
FAST_JIT_ORIGIN_NAME = "fast-jit"
def format_stack_frame(function_name, debug_info, origin='unknown'):
    """Render a single stack frame as text.

    The result has the shape '<debug_info>in <function_name> (<origin>)'.
    A function name or debug information of None is replaced by a
    placeholder. debug_info is prepended verbatim with no separator, so it
    appears to be expected to carry its own trailing space (as the
    placeholder does).
    """
    name = 'unknown function' if function_name is None else function_name
    location = '[unknown location] ' if debug_info is None else debug_info
    return '%(location)sin %(name)s (%(origin)s)' % {
        'location': location,
        'name': name,
        'origin': origin,
    }
def format_trace_message(debug_info, function_name, origin='unknown'):
    """Build a 'TRACE: ' line that describes the given stack frame.

    Note that the first two parameters are in the opposite order from
    format_stack_frame.
    """
    frame_text = format_stack_frame(function_name, debug_info, origin)
    return 'TRACE: ' + frame_text
def call_function(function_id, named_arguments, **kwargs):
    """Runs the function with the given id, passing it the specified argument dictionary.

    The function body is compiled by the JIT when possible. If the function is
    flagged as mutable, or if compilation raises a
    JitCompilationFailedException, the body is run by the interpreter instead.

    Args:
        function_id: The node id of the function to call.
        named_arguments: A dictionary that maps parameter names to argument
            values. It is not modified.
        **kwargs: The runtime context; must contain 'task_root' and 'mvk'
            (the kernel). It is forwarded to the callee as keyword arguments.
    """
    task_root = kwargs['task_root']
    kernel = kwargs['mvk']
    body_id, is_mutable = yield [
        ("RD", [function_id, FUNCTION_BODY_KEY]),
        ("RD", [function_id, MUTABLE_FUNCTION_KEY])]

    # Try to jit the function here. We might be able to avoid building the stack
    # frame.
    def handle_jit_failed(_exception=None):
        """Interprets the function body.

        Serves both as the CATCH handler for JitCompilationFailedException
        (which passes the exception) and as the direct fallback for mutable
        functions (which passes no argument), hence the optional parameter.
        """
        interpreter_args = {'body_id': body_id, 'named_arguments': named_arguments}
        interpreter_args.update(kwargs)
        yield [("TAIL_CALL_KWARGS", [interpret_function_body, interpreter_args])]

    if is_mutable is not None:
        # Mutable functions are excluded from the JIT (presumably because their
        # body may change after compilation) and are always interpreted.
        kernel.jit.mark_no_jit(body_id)
        yield [("TAIL_CALL_ARGS", [handle_jit_failed, ()])]
    else:
        kernel.jit.mark_entry_point(body_id)

    yield [("TRY", [])]
    yield [("CATCH", [JitCompilationFailedException, handle_jit_failed])]
    # Try to compile.
    compiled_func, = yield [("CALL_ARGS", [kernel.jit_compile, (task_root, body_id)])]
    yield [("END_TRY", [])]
    # Combine the arguments with the runtime context in a fresh dictionary so
    # the caller's argument dictionary is left untouched.
    call_args = dict(named_arguments)
    call_args.update(kwargs)
    # Run the function.
    yield [("TAIL_CALL_KWARGS", [compiled_func, call_args])]
def interpret_function(function_id, named_arguments, **kwargs):
    """Makes the interpreter run the function with the given id for the specified
    argument dictionary.

    Args:
        function_id: The node id of the function to interpret.
        named_arguments: A dictionary that maps parameter names to argument
            values.
        **kwargs: The runtime context, forwarded unchanged to
            interpret_function_body.
    """
    body_id, = yield [("RD", [function_id, FUNCTION_BODY_KEY])]
    # The key must be the string 'named_arguments' so it binds to the
    # parameter of interpret_function_body (a dict cannot be a dict key).
    args = {'body_id': body_id, 'named_arguments': named_arguments}
    args.update(kwargs)
    yield [("TAIL_CALL_KWARGS", [interpret_function_body, args])]
def interpret_function_body(body_id, named_arguments, **kwargs):
    """Makes the interpreter run the function body with the given id for the specified
    argument dictionary.

    A new stack frame is pushed onto the task. Its IP is the function body and
    its "prev" is the task's previous frame. The arguments are bound in the new
    frame's symbol table. Then the kernel executes rules one at a time until
    the interpreted function finishes.

    Args:
        body_id: The node id of the function body to interpret.
        named_arguments: A dictionary that maps parameter names to argument
            values. Every key must be a parameter name of the body; an
            unknown name raises KeyError.
        **kwargs: The runtime context; must contain 'task_root', 'mvk' (the
            kernel) and 'taskname'.

    Raises:
        primitive_functions.PrimitiveFinished: raised by the installed handler
            with the interpreted function's result once
            InterpretedFunctionFinished is caught.
    """
    task_root = kwargs['task_root']
    kernel = kwargs['mvk']
    # Read the task's current frame and its IP; the IP is recorded as the new
    # frame's "caller" below.
    user_frame, = yield [("RD", [task_root, "frame"])]
    inst, = yield [("RD", [user_frame, "IP"])]
    kernel.jit.mark_entry_point(body_id)

    # Create a new stack frame. This fetches the edge that currently links the
    # task to its frame (deleted below) and an "init" phase value, and creates
    # five fresh nodes: the frame itself, its evalstack, symbols, returnvalue,
    # and the node stored under EXCEPTION_RETURN_KEY.
    frame_link, new_phase, new_frame, new_evalstack, new_symbols, \
        new_returnvalue, intrinsic_return = \
                    yield [("RDE", [task_root, "frame"]),
                           ("CNV", ["init"]),
                           ("CN", []),
                           ("CN", []),
                           ("CN", []),
                           ("CN", []),
                           ("CN", [])
                          ]

    # Populate the new frame, make it the task's current frame with the old
    # frame as its "prev", and delete the old task -> frame link. The new
    # frame starts at the function body (its "IP").
    _, _, _, _, _, _, _, _, _, _ = \
                    yield [("CD", [task_root, "frame", new_frame]),
                           ("CD", [new_frame, "evalstack", new_evalstack]),
                           ("CD", [new_frame, "symbols", new_symbols]),
                           ("CD", [new_frame, "returnvalue", new_returnvalue]),
                           ("CD", [new_frame, "caller", inst]),
                           ("CD", [new_frame, "phase", new_phase]),
                           ("CD", [new_frame, "IP", body_id]),
                           ("CD", [new_frame, "prev", user_frame]),
                           ("CD", [
                               new_frame,
                               primitive_functions.EXCEPTION_RETURN_KEY,
                               intrinsic_return]),
                           ("DE", [frame_link])
                          ]

    # Put the parameters in the new stack frame's symbol table.
    # jit_signature yields the parameter variable nodes and their names; the
    # third component is not used here.
    (parameter_vars, parameter_names, _), = yield [
        ("CALL_ARGS", [kernel.jit.jit_signature, (body_id,)])]
    parameter_dict = dict(zip(parameter_names, parameter_vars))
    for (key, value) in named_arguments.items():
        # Raises KeyError if `key` is not one of the body's parameter names.
        param_var = parameter_dict[key]
        # Each binding is a fresh node whose "value" is the argument. The node
        # is attached to the symbol table by an edge, and that edge gets an
        # edge to the parameter node (presumably serving as the entry's key).
        variable, = yield [("CN", [])]
        yield [("CD", [variable, "value", value])]
        symbol_edge, = yield [("CE", [new_symbols, variable])]
        yield [("CE", [symbol_edge, param_var])]

    taskname = kwargs['taskname']
    def exception_handler(ex):
        # Turn the interpreter's completion signal into this primitive's
        # result.
        raise primitive_functions.PrimitiveFinished(ex.result)

    # Create an exception handler to catch and translate InterpretedFunctionFinished.
    yield [("TRY", [])]
    yield [("CATCH", [primitive_functions.InterpretedFunctionFinished, exception_handler])]
    # There is no END_TRY: the loop below never exits normally and is expected
    # to end when InterpretedFunctionFinished is raised and handled above.
    while 1:
        # Ask the kernel to execute one rule for this task.
        result, = yield [("CALL_ARGS", [kernel.execute_rule, (taskname,)])]
        # An instruction has completed. Forward it (yield it to whatever drives
        # this generator).
        yield result
class UnreachableCodeException(Exception):
    """Signals that code which was supposed to be unreachable got executed."""
def unreachable():
    """Mark a point in the code that must never be executed.

    Raises:
        UnreachableCodeException: always.
    """
    message = 'An unreachable statement was reached.'
    raise UnreachableCodeException(message)
def get_input(**parameters):
    """Retrieve input, polling until some becomes available.

    Each round asks the kernel (parameters["mvk"]) to run input_init for the
    task root. While the kernel reports no success, the generator yields None
    and tries again on the next resumption. Once input is available, it
    finishes by raising primitive_functions.PrimitiveFinished with the
    kernel's input_value.
    """
    kernel = parameters["mvk"]
    root = parameters["task_root"]
    while True:
        yield [("CALL_ARGS", [kernel.input_init, (root,)])]
        if not kernel.success:
            # Nothing to read yet: hand control back without finishing.
            yield None
            continue
        # Input arrived, so report it as this primitive's result.
        raise primitive_functions.PrimitiveFinished(kernel.input_value)
|