runtime.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import modelverse_kernel.primitives as primitive_functions
  2. class JitCompilationFailedException(Exception):
  3. """A type of exception that is raised when the jit fails to compile a function."""
  4. pass
  5. MUTABLE_FUNCTION_KEY = "mutable"
  6. """A dictionary key for functions that are mutable."""
  7. FUNCTION_BODY_KEY = "body"
  8. """A dictionary key for function bodies."""
  9. def call_function(function_id, named_arguments, **kwargs):
  10. """Runs the function with the given id, passing it the specified argument dictionary."""
  11. task_root = kwargs['task_root']
  12. kernel = kwargs['mvk']
  13. body_id, is_mutable = yield [
  14. ("RD", [function_id, FUNCTION_BODY_KEY]),
  15. ("RD", [function_id, MUTABLE_FUNCTION_KEY])]
  16. # Try to jit the function here. We might be able to avoid building the stack
  17. # frame.
  18. def handle_jit_failed(_):
  19. """Interprets the function."""
  20. interpreter_args = {'function_id' : function_id, 'named_arguments' : named_arguments}
  21. interpreter_args.update(kwargs)
  22. yield [("TAIL_CALL_KWARGS", [interpret_function, interpreter_args])]
  23. if is_mutable is not None:
  24. kernel.jit.mark_no_jit(body_id)
  25. yield [("TAIL_CALL_ARGS", [handle_jit_failed, ()])]
  26. else:
  27. kernel.jit.mark_entry_point(body_id)
  28. yield [("TRY", [])]
  29. yield [("CATCH", [JitCompilationFailedException, handle_jit_failed])]
  30. # Try to compile.
  31. compiled_func, = yield [("CALL_ARGS", [kernel.jit_compile, (task_root, body_id)])]
  32. yield [("END_TRY", [])]
  33. # Add the keyword arguments to the argument dictionary.
  34. named_arguments.update(kwargs)
  35. # Run the function.
  36. yield [("TAIL_CALL_KWARGS", [compiled_func, named_arguments])]
def interpret_function(function_id, named_arguments, **kwargs):
    """Makes the interpreter run the function with the given id for the specified
    argument dictionary.

    This is a kernel generator: it yields lists of request tuples and receives
    their results back from the kernel. The order of the requests matters.

    Steps:
      1. Read the task's current frame, that frame's "IP", and the function's
         body; mark the body as a jit entry point.
      2. Push a new stack frame: create the frame node plus its "evalstack",
         "symbols", "returnvalue" and intrinsic-return nodes; set its "phase"
         to "init", its "IP" to the function body, "caller" to the old IP and
         "prev" to the old frame. Then point the task's "frame" entry at the
         new frame and delete the old link (read via RDE).
      3. Bind every entry of ``named_arguments`` in the new frame's symbol
         table, using the parameter variables that
         ``kernel.jit.jit_signature(body_id)`` reports.
      4. Repeatedly run ``kernel.execute_rule(taskname)`` and forward each
         result. The loop has no exit of its own; the function is expected to
         finish through the InterpretedFunctionFinished handler, which
         re-raises the result as PrimitiveFinished.

    Args:
        function_id: id of the function node to interpret.
        named_arguments: dict mapping parameter names to argument values.
        **kwargs: kernel context; must contain 'task_root', 'mvk' and
            'taskname'.

    Raises:
        KeyError: if a key of ``named_arguments`` is not a parameter name of
            the function, or if a required ``kwargs`` entry is missing.
    """
    task_root = kwargs['task_root']
    kernel = kwargs['mvk']
    task_frame, = yield [("RD", [task_root, "frame"])]
    # "IP" of the current frame becomes the new frame's "caller".
    inst, body_id = yield [("RD", [task_frame, "IP"]), ("RD", [function_id, FUNCTION_BODY_KEY])]
    kernel.jit.mark_entry_point(body_id)

    # Create a new stack frame. frame_link is presumably the edge id of the
    # task's current "frame" entry; it is deleted once the new frame is linked.
    frame_link, new_phase, new_frame, new_evalstack, new_symbols, \
        new_returnvalue, intrinsic_return = \
                    yield [("RDE", [task_root, "frame"]),
                           ("CNV", ["init"]),
                           ("CN", []),
                           ("CN", []),
                           ("CN", []),
                           ("CN", []),
                           ("CN", [])
                          ]

    # Wire the new frame in and make it the task's current frame.
    _, _, _, _, _, _, _, _, _, _ = \
                    yield [("CD", [task_root, "frame", new_frame]),
                           ("CD", [new_frame, "evalstack", new_evalstack]),
                           ("CD", [new_frame, "symbols", new_symbols]),
                           ("CD", [new_frame, "returnvalue", new_returnvalue]),
                           ("CD", [new_frame, "caller", inst]),
                           ("CD", [new_frame, "phase", new_phase]),
                           ("CD", [new_frame, "IP", body_id]),
                           ("CD", [new_frame, "prev", task_frame]),
                           ("CD", [
                               new_frame,
                               primitive_functions.EXCEPTION_RETURN_KEY,
                               intrinsic_return]),
                           ("DE", [frame_link])
                          ]

    # Put the parameters in the new stack frame's symbol table.
    # jit_signature yields (parameter variables, parameter names, <unused>).
    (parameter_vars, parameter_names, _), = yield [
        ("CALL_ARGS", [kernel.jit.jit_signature, (body_id,)])]
    parameter_dict = dict(zip(parameter_names, parameter_vars))
    for (key, value) in named_arguments.items():
        # Raises KeyError for an argument name that is not a parameter.
        param_var = parameter_dict[key]
        # A fresh node holds the argument under "value".
        variable, = yield [("CN", [])]
        yield [("CD", [variable, "value", value])]
        # Edge symbols -> variable, then an edge from that edge to the
        # parameter variable (presumably how the symbol's key is attached —
        # confirm against the interpreter's symbol lookup).
        symbol_edge, = yield [("CE", [new_symbols, variable])]
        yield [("CE", [symbol_edge, param_var])]

    taskname = kwargs['taskname']

    def exception_handler(ex):
        """Converts the interpreter's finish signal into a primitive result."""
        raise primitive_functions.PrimitiveFinished(ex.result)

    # Create an exception handler to catch and translate InterpretedFunctionFinished.
    # No END_TRY follows: the loop below never exits normally, so the handler
    # stays active for the rest of this generator's life.
    yield [("TRY", [])]
    yield [("CATCH", [primitive_functions.InterpretedFunctionFinished, exception_handler])]
    while 1:
        result, = yield [("CALL_ARGS", [kernel.execute_rule, (taskname,)])]
        # An instruction has completed. Forward it.
        yield result
  92. def get_input(**parameters):
  93. """Retrieves input."""
  94. mvk = parameters["mvk"]
  95. task_root = parameters["task_root"]
  96. while 1:
  97. yield [("CALL_ARGS", [mvk.input_init, (task_root,)])]
  98. # Finished
  99. if mvk.success:
  100. # Got some input, so we can access it
  101. raise primitive_functions.PrimitiveFinished(mvk.input_value)
  102. else:
  103. # No input, so yield None but don't stop
  104. yield None