jit.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import math
  2. import keyword
  3. from collections import defaultdict
  4. import modelverse_kernel.primitives as primitive_functions
  5. class ModelverseJit(object):
  6. """A high-level interface to the modelverse JIT compiler."""
  7. def __init__(self):
  8. self.todo_entry_points = set()
  9. self.jitted_parameters = {}
  10. self.jit_globals = {
  11. 'PrimitiveFinished' : primitive_functions.PrimitiveFinished,
  12. }
  13. # jitted_entry_points maps body ids to values in jit_globals.
  14. self.jitted_entry_points = {}
  15. # global_functions maps global value names to body ids.
  16. self.global_functions = {}
  17. # global_functions_inv maps body ids to global value names.
  18. self.global_functions_inv = {}
  19. # jitted_function_aliases maps body ids to known aliases.
  20. self.jitted_function_aliases = defaultdict(set)
  21. self.jit_count = 0
  22. self.compilation_dependencies = {}
  23. self.cache = {}
  24. def mark_entry_point(self, body_id):
  25. """Marks the node with the given identifier as a function entry point."""
  26. if body_id not in self.jitted_entry_points:
  27. self.todo_entry_points.add(body_id)
  28. def is_jittable_entry_point(self, body_id):
  29. """Tells if the node with the given identifier is a function entry point that
  30. has not been marked as non-jittable. This only returns `True` if the JIT
  31. is enabled and the function entry point has been marked jittable, or if
  32. the function has already been compiled."""
  33. return ((body_id in self.todo_entry_points) or (body_id in self.jitted_entry_points))
  34. def generate_name(self, infix, suggested_name=None):
  35. """Generates a new name or picks the suggested name if it is still
  36. available."""
  37. if suggested_name is not None \
  38. and suggested_name not in self.jit_globals \
  39. and not keyword.iskeyword(suggested_name):
  40. self.jit_count += 1
  41. return suggested_name
  42. else:
  43. function_name = 'jit_%s%d' % (infix, self.jit_count)
  44. self.jit_count += 1
  45. return function_name
  46. def generate_function_name(self, body_id, suggested_name=None):
  47. """Generates a new function name or picks the suggested name if it is still
  48. available."""
  49. if suggested_name is None:
  50. suggested_name = self.get_global_name(body_id)
  51. return self.generate_name('func', suggested_name)
  52. def register_global(self, body_id, global_name):
  53. """Associates the given body id with the given global name."""
  54. self.global_functions[global_name] = body_id
  55. self.global_functions_inv[body_id] = global_name
  56. def get_global_name(self, body_id):
  57. """Gets the name of the global function with the given body id.
  58. Returns None if no known global exists with the given id."""
  59. if body_id in self.global_functions_inv:
  60. return self.global_functions_inv[body_id]
  61. else:
  62. return None
  63. def get_global_body_id(self, global_name):
  64. """Gets the body id of the global function with the given name.
  65. Returns None if no known global exists with the given name."""
  66. if global_name in self.global_functions:
  67. return self.global_functions[global_name]
  68. else:
  69. return None
  70. def register_compiled(self, body_id, compiled_function, function_name=None):
  71. """Registers a compiled entry point with the JIT."""
  72. # Get the function's name.
  73. actual_function_name = self.generate_function_name(body_id, function_name)
  74. # Map the body id to the given parameter list.
  75. self.jitted_entry_points[body_id] = actual_function_name
  76. self.jit_globals[actual_function_name] = compiled_function
  77. if function_name is not None:
  78. self.register_global(body_id, function_name)
  79. if body_id in self.todo_entry_points:
  80. self.todo_entry_points.remove(body_id)
  81. def __lookup_compiled_body_impl(self, body_id):
  82. """Looks up a compiled function by body id. Returns a matching function,
  83. or None if no function was found."""
  84. if body_id is not None and body_id in self.jitted_entry_points:
  85. return self.jit_globals[self.jitted_entry_points[body_id]]
  86. else:
  87. return None
  88. def __lookup_external_body_impl(self, global_name, body_id):
  89. """Looks up an external function by global name. Returns a matching function,
  90. or None if no function was found."""
  91. if global_name is not None and self.compiled_function_lookup is not None:
  92. result = self.compiled_function_lookup(global_name)
  93. if result is not None and body_id is not None:
  94. self.register_compiled(body_id, result, global_name)
  95. return result
  96. else:
  97. return None
  98. def lookup_compiled_body(self, body_id):
  99. """Looks up a compiled function by body id. Returns a matching function,
  100. or None if no function was found."""
  101. result = self.__lookup_compiled_body_impl(body_id)
  102. if result is not None:
  103. return result
  104. else:
  105. global_name = self.get_global_name(body_id)
  106. return self.__lookup_external_body_impl(global_name, body_id)
  107. def lookup_compiled_function(self, global_name):
  108. """Looks up a compiled function by global name. Returns a matching function,
  109. or None if no function was found."""
  110. body_id = self.get_global_body_id(global_name)
  111. result = self.__lookup_compiled_body_impl(body_id)
  112. if result is not None:
  113. return result
  114. else:
  115. return self.__lookup_external_body_impl(global_name, body_id)
  116. def jit_signature(self, body_id):
  117. """Acquires the signature for the given body id node, which consists of the
  118. parameter variables, parameter name and a flag that tells if the given function
  119. is mutable."""
  120. if body_id not in self.jitted_parameters:
  121. signature_id, = yield [("RRD", [body_id, "body"])]
  122. signature_id = signature_id[0]
  123. param_set_id, is_mutable = yield [
  124. ("RD", [signature_id, "params"]),
  125. ("RD", [signature_id, "mutable"])]
  126. if param_set_id is None:
  127. self.jitted_parameters[body_id] = ([], [], is_mutable)
  128. else:
  129. param_name_ids, = yield [("RDK", [param_set_id])]
  130. param_names = yield [("RV", [n]) for n in param_name_ids]
  131. #NOTE Patch up strange links...
  132. param_names = [i for i in param_names if i is not None]
  133. param_vars = yield [("RD", [param_set_id, k]) for k in param_names]
  134. #NOTE that variables might not be in the correct order, as we just read them out!
  135. lst = sorted([(name, var) for name, var in zip(param_names, param_vars)])
  136. param_vars = [i[1] for i in lst]
  137. param_names = [i[0] for i in lst]
  138. self.jitted_parameters[body_id] = (param_vars, param_names, is_mutable)
  139. raise primitive_functions.PrimitiveFinished(self.jitted_parameters[body_id])
  140. def check_jittable(self, body_id, suggested_name=None):
  141. """Checks if the function with the given body id is obviously non-jittable. If it's
  142. non-jittable, then a `JitCompilationFailedException` exception is thrown."""
  143. if body_id is None:
  144. raise ValueError('body_id cannot be None: ' + suggested_name)
  145. def jit_define_function(self, function_name, function_def):
  146. """Converts the given tree-IR function definition to Python code, defines it,
  147. and extracts the resulting function."""
  148. # The comment below makes pylint shut up about our (hopefully benign) use of exec here.
  149. # pylint: disable=I0011,W0122
  150. if self.jit_code_log_function is not None:
  151. self.jit_code_log_function(function_def)
  152. # Convert the function definition to Python code, and compile it.
  153. code_generator = tree_ir.PythonGenerator()
  154. function_def.generate_python_def(code_generator)
  155. source_map_name = self.get_source_map_name(function_name)
  156. if source_map_name is not None:
  157. self.jit_globals[source_map_name] = code_generator.source_map_builder.source_map
  158. exec(str(code_generator), self.jit_globals)
  159. # Extract the compiled function from the JIT global state.
  160. return self.jit_globals[function_name]
  161. def jit_delete_function(self, function_name):
  162. """Deletes the function with the given function name."""
  163. del self.jit_globals[function_name]
  164. import modelverse_kernel.primitives as primitive_functions
  165. class JitCompilationFailedException(Exception):
  166. """A type of exception that is raised when the jit fails to compile a function."""
  167. pass