runtime.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import modelverse_kernel.primitives as primitive_functions
  2. class JitCompilationFailedException(Exception):
  3. """A type of exception that is raised when the jit fails to compile a function."""
  4. pass
  5. MUTABLE_FUNCTION_KEY = "mutable"
  6. """A dictionary key for functions that are mutable."""
  7. FUNCTION_BODY_KEY = "body"
  8. """A dictionary key for function bodies."""
  9. KWARGS_PARAMETER_NAME = "kwargs"
  10. """The name of the kwargs parameter in jitted functions."""
  11. CALL_FUNCTION_NAME = "__call_function"
  12. """The name of the '__call_function' function, in the jitted function scope."""
  13. GET_INPUT_FUNCTION_NAME = "__get_input"
  14. """The name of the '__get_input' function, in the jitted function scope."""
  15. LOCALS_NODE_NAME = "jit_locals"
  16. """The name of the node that is connected to all JIT locals in a given function call."""
  17. LOCALS_EDGE_NAME = "jit_locals_edge"
  18. """The name of the edge that connects the LOCALS_NODE_NAME node to a user root."""
  19. def call_function(function_id, named_arguments, **kwargs):
  20. """Runs the function with the given id, passing it the specified argument dictionary."""
  21. user_root = kwargs['user_root']
  22. kernel = kwargs['mvk']
  23. body_id, is_mutable = yield [
  24. ("RD", [function_id, FUNCTION_BODY_KEY]),
  25. ("RD", [function_id, MUTABLE_FUNCTION_KEY])]
  26. # Try to jit the function here. We might be able to avoid building the stack
  27. # frame.
  28. def handle_jit_failed(_):
  29. """Interprets the function."""
  30. interpreter_args = {'body_id' : body_id, 'named_arguments' : named_arguments}
  31. interpreter_args.update(kwargs)
  32. yield [("TAIL_CALL_KWARGS", [interpret_function_body, interpreter_args])]
  33. if is_mutable is not None:
  34. kernel.jit.mark_no_jit(body_id)
  35. yield [("TAIL_CALL_ARGS", [handle_jit_failed, ()])]
  36. else:
  37. kernel.jit.mark_entry_point(body_id)
  38. yield [("TRY", [])]
  39. yield [("CATCH", [JitCompilationFailedException, handle_jit_failed])]
  40. # Try to compile.
  41. compiled_func, = yield [("CALL_ARGS", [kernel.jit_compile, (user_root, body_id)])]
  42. yield [("END_TRY", [])]
  43. # Add the keyword arguments to the argument dictionary.
  44. named_arguments.update(kwargs)
  45. # Run the function.
  46. yield [("TAIL_CALL_KWARGS", [compiled_func, named_arguments])]
  47. def interpret_function(function_id, named_arguments, **kwargs):
  48. """Makes the interpreter run the function with the given id for the specified
  49. argument dictionary."""
  50. body_id, = yield [("RD", [function_id, FUNCTION_BODY_KEY])]
  51. args = {'body_id' : body_id, named_arguments : named_arguments}
  52. args.update(kwargs)
  53. yield [("TAIL_CALL_KWARGS", [interpret_function_body, args])]
  54. def interpret_function_body(body_id, named_arguments, **kwargs):
  55. """Makes the interpreter run the function body with the given id for the specified
  56. argument dictionary."""
  57. user_root = kwargs['user_root']
  58. kernel = kwargs['mvk']
  59. user_frame, = yield [("RD", [user_root, "frame"])]
  60. inst, = yield [("RD", [user_frame, "IP"])]
  61. kernel.jit.mark_entry_point(body_id)
  62. # Create a new stack frame.
  63. frame_link, new_phase, new_frame, new_evalstack, new_symbols, \
  64. new_returnvalue, intrinsic_return = \
  65. yield [("RDE", [user_root, "frame"]),
  66. ("CNV", ["init"]),
  67. ("CN", []),
  68. ("CN", []),
  69. ("CN", []),
  70. ("CN", []),
  71. ("CN", [])
  72. ]
  73. _, _, _, _, _, _, _, _, _, _ = \
  74. yield [("CD", [user_root, "frame", new_frame]),
  75. ("CD", [new_frame, "evalstack", new_evalstack]),
  76. ("CD", [new_frame, "symbols", new_symbols]),
  77. ("CD", [new_frame, "returnvalue", new_returnvalue]),
  78. ("CD", [new_frame, "caller", inst]),
  79. ("CD", [new_frame, "phase", new_phase]),
  80. ("CD", [new_frame, "IP", body_id]),
  81. ("CD", [new_frame, "prev", user_frame]),
  82. ("CD", [
  83. new_frame,
  84. primitive_functions.EXCEPTION_RETURN_KEY,
  85. intrinsic_return]),
  86. ("DE", [frame_link])
  87. ]
  88. # Put the parameters in the new stack frame's symbol table.
  89. (parameter_vars, parameter_names, _), = yield [
  90. ("CALL_ARGS", [kernel.jit.jit_signature, (body_id,)])]
  91. parameter_dict = dict(zip(parameter_names, parameter_vars))
  92. for (key, value) in named_arguments.items():
  93. param_var = parameter_dict[key]
  94. variable, = yield [("CN", [])]
  95. yield [("CD", [variable, "value", value])]
  96. symbol_edge, = yield [("CE", [new_symbols, variable])]
  97. yield [("CE", [symbol_edge, param_var])]
  98. username = kwargs['username']
  99. def exception_handler(ex):
  100. # print('Returning from interpreted function. Result: %s' % ex.result)
  101. raise primitive_functions.PrimitiveFinished(ex.result)
  102. # Create an exception handler to catch and translate InterpretedFunctionFinished.
  103. yield [("TRY", [])]
  104. yield [("CATCH", [primitive_functions.InterpretedFunctionFinished, exception_handler])]
  105. while 1:
  106. result, = yield [("CALL_ARGS", [kernel.execute_rule, (username,)])]
  107. # An instruction has completed. Forward it.
  108. yield result
  109. def get_input(**parameters):
  110. """Retrieves input."""
  111. mvk = parameters["mvk"]
  112. user_root = parameters["user_root"]
  113. while 1:
  114. yield [("CALL_ARGS", [mvk.input_init, (user_root,)])]
  115. # Finished
  116. if mvk.success:
  117. # Got some input, so we can access it
  118. raise primitive_functions.PrimitiveFinished(mvk.input_value)
  119. else:
  120. # No input, so yield None but don't stop
  121. yield None