# intrinsics.py -- intrinsic definitions for the Modelverse JIT and the
# functions that register them with a JIT instance.

import time

import modelverse_jit.jit as jit
import modelverse_jit.tree_ir as tree_ir
import modelverse_jit.runtime as jit_runtime
  5. BINARY_INTRINSICS = {
  6. 'value_eq' : '==',
  7. 'value_neq' : '!=',
  8. 'bool_and' : 'and',
  9. 'bool_or' : 'or',
  10. 'integer_addition' : '+',
  11. 'integer_subtraction' : '-',
  12. 'integer_multiplication' : '*',
  13. 'integer_division' : '/',
  14. 'integer_gt' : '>',
  15. 'integer_gte' : '>=',
  16. 'integer_lt' : '<',
  17. 'integer_lte' : '<=',
  18. 'float_addition' : '+',
  19. 'float_subtraction' : '-',
  20. 'float_multiplication' : '*',
  21. 'float_division' : '/',
  22. 'float_gt' : '>',
  23. 'float_gte' : '>=',
  24. 'float_lt' : '<',
  25. 'float_lte' : '<='
  26. }
  27. UNARY_INTRINSICS = {
  28. 'bool_not' : 'not',
  29. 'integer_neg' : '-',
  30. 'float_neg' : '-'
  31. }
  32. CAST_INTRINSICS = {
  33. 'cast_i2f' : float,
  34. 'cast_i2s' : str,
  35. 'cast_i2b' : bool,
  36. 'cast_f2i' : int,
  37. 'cast_f2s' : str,
  38. 'cast_f2b' : bool,
  39. 'cast_s2i' : int,
  40. 'cast_s2f' : float,
  41. 'cast_s2b' : bool,
  42. 'cast_b2i' : int,
  43. 'cast_b2f' : float,
  44. 'cast_b2s' : str
  45. }
  46. def create_get_length(expression):
  47. """Creates an expression that evaluates the given expression, and then
  48. computes the length of its result."""
  49. return tree_ir.CallInstruction(
  50. tree_ir.LoadGlobalInstruction('len'),
  51. [expression])
# Don't complain about the variable names, pylint: it's important that we
# get them right.
# pylint: disable=I0011,C0103
  55. def __set_add(a, b):
  56. tmp = tree_ir.StoreLocalInstruction(None, a)
  57. return tree_ir.create_block(
  58. tmp,
  59. tree_ir.CreateEdgeInstruction(tmp.create_load(), b),
  60. tmp.create_load())
  61. def __dict_add(a, b, c):
  62. a_tmp = tree_ir.StoreLocalInstruction(None, a)
  63. b_tmp = tree_ir.StoreLocalInstruction(None, b)
  64. return tree_ir.create_block(
  65. a_tmp,
  66. b_tmp,
  67. tree_ir.CreateEdgeInstruction(
  68. tree_ir.CreateEdgeInstruction(a_tmp.create_load(), c),
  69. b_tmp.create_load()),
  70. a_tmp.create_load())
  71. def __list_read(a, b):
  72. # The statements in this function generate the following code:
  73. #
  74. # a_tmp = a # To make sure a is evaluated before b.
  75. # b_value, = yield [("RV", [b])]
  76. # result, = yield [("RD", [a_tmp, b_value])]
  77. # if result is None:
  78. # raise Exception("List read out of bounds: %s" % b_value)
  79. # result
  80. a_tmp = tree_ir.StoreLocalInstruction(None, a)
  81. b_val = tree_ir.StoreLocalInstruction(
  82. None,
  83. tree_ir.ReadValueInstruction(b))
  84. result = tree_ir.StoreLocalInstruction(
  85. None,
  86. tree_ir.ReadDictionaryValueInstruction(
  87. a_tmp.create_load(), b_val.create_load()))
  88. return tree_ir.create_block(
  89. a_tmp,
  90. b_val,
  91. result,
  92. tree_ir.SelectInstruction(
  93. tree_ir.BinaryInstruction(
  94. result.create_load(),
  95. 'is',
  96. tree_ir.LiteralInstruction(None)),
  97. tree_ir.RaiseInstruction(
  98. tree_ir.CallInstruction(
  99. tree_ir.LoadGlobalInstruction('Exception'),
  100. [tree_ir.BinaryInstruction(
  101. tree_ir.LiteralInstruction('List read out of bounds: %s'),
  102. '%',
  103. b_val.create_load())])),
  104. tree_ir.NopInstruction()),
  105. result.create_load())
  106. def __list_append(a, b):
  107. # We want to generate code that is more or less equivalent to:
  108. #
  109. # a_tmp = a
  110. # b_tmp = b
  111. # a_outgoing, = yield [("RO", [a_tmp])]
  112. # _ = yield [("CD", [a_tmp, len(a_outgoing), b_tmp])]
  113. # a
  114. a_tmp = tree_ir.StoreLocalInstruction(None, a)
  115. b_tmp = tree_ir.StoreLocalInstruction(None, b)
  116. return tree_ir.create_block(
  117. a_tmp,
  118. tree_ir.CreateDictionaryEdgeInstruction(
  119. a_tmp.create_load(),
  120. create_get_length(
  121. tree_ir.ReadOutgoingEdgesInstruction(
  122. a_tmp.create_load())),
  123. b_tmp),
  124. a_tmp.create_load())
  125. def __log(a):
  126. # Original definition:
  127. #
  128. # def log(a, **remainder):
  129. # a_value, = yield [("RV", [a])]
  130. # print("== LOG == " + str(a_value))
  131. # raise PrimitiveFinished(a)
  132. a_tmp = tree_ir.StoreLocalInstruction(None, a)
  133. return tree_ir.CompoundInstruction(
  134. tree_ir.create_block(
  135. a_tmp,
  136. tree_ir.PrintInstruction(
  137. tree_ir.BinaryInstruction(
  138. tree_ir.LiteralInstruction("== LOG == "),
  139. '+',
  140. tree_ir.CallInstruction(
  141. tree_ir.LoadGlobalInstruction('str'),
  142. [tree_ir.ReadValueInstruction(a_tmp.create_load())])))),
  143. a_tmp.create_load())
  144. MISC_INTRINSICS = {
  145. # Reference equality
  146. 'element_eq' :
  147. lambda a, b:
  148. tree_ir.CreateNodeWithValueInstruction(
  149. tree_ir.BinaryInstruction(a, '==', b)),
  150. 'element_neq' :
  151. lambda a, b:
  152. tree_ir.CreateNodeWithValueInstruction(
  153. tree_ir.BinaryInstruction(a, '!=', b)),
  154. # Strings
  155. 'string_get' :
  156. lambda a, b:
  157. tree_ir.CreateNodeWithValueInstruction(
  158. tree_ir.LoadIndexInstruction(
  159. tree_ir.ReadValueInstruction(a),
  160. tree_ir.ReadValueInstruction(b))),
  161. 'string_len' :
  162. lambda a:
  163. tree_ir.CreateNodeWithValueInstruction(
  164. tree_ir.CallInstruction(
  165. tree_ir.LoadGlobalInstruction('len'),
  166. [tree_ir.ReadValueInstruction(a)])),
  167. 'string_join' :
  168. lambda a, b:
  169. tree_ir.CreateNodeWithValueInstruction(
  170. tree_ir.BinaryInstruction(
  171. tree_ir.CallInstruction(
  172. tree_ir.LoadGlobalInstruction('str'),
  173. [tree_ir.ReadValueInstruction(a)]),
  174. '+',
  175. tree_ir.CallInstruction(
  176. tree_ir.LoadGlobalInstruction('str'),
  177. [tree_ir.ReadValueInstruction(b)]))),
  178. 'string_startswith' :
  179. lambda a, b:
  180. tree_ir.CreateNodeWithValueInstruction(
  181. tree_ir.CallInstruction(
  182. tree_ir.LoadMemberInstruction(
  183. tree_ir.ReadValueInstruction(a),
  184. 'startswith'),
  185. [tree_ir.ReadValueInstruction(b)])),
  186. # State creation
  187. 'create_node' : tree_ir.CreateNodeInstruction,
  188. 'create_edge' :
  189. # Lambda is totally necessary here, pylint.
  190. # You totally dropped the ball on this one.
  191. # pylint: disable=I0011,W0108
  192. lambda a, b:
  193. tree_ir.CreateEdgeInstruction(a, b),
  194. 'create_value' :
  195. lambda a:
  196. tree_ir.CreateNodeWithValueInstruction(
  197. tree_ir.ReadValueInstruction(a)),
  198. # State reads
  199. 'read_edge_src' :
  200. lambda a:
  201. tree_ir.LoadIndexInstruction(
  202. tree_ir.ReadEdgeInstruction(a),
  203. tree_ir.LiteralInstruction(0)),
  204. 'read_edge_dst' :
  205. lambda a:
  206. tree_ir.LoadIndexInstruction(
  207. tree_ir.ReadEdgeInstruction(a),
  208. tree_ir.LiteralInstruction(1)),
  209. 'is_edge' :
  210. lambda a:
  211. tree_ir.CreateNodeWithValueInstruction(
  212. tree_ir.BinaryInstruction(
  213. tree_ir.LoadIndexInstruction(
  214. tree_ir.ReadEdgeInstruction(a),
  215. tree_ir.LiteralInstruction(0)),
  216. 'is not',
  217. tree_ir.LiteralInstruction(None))),
  218. # read_root
  219. 'read_root' :
  220. lambda:
  221. tree_ir.LoadIndexInstruction(
  222. tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
  223. tree_ir.LiteralInstruction('root')),
  224. # read_userroot
  225. 'read_userroot' :
  226. lambda:
  227. tree_ir.LoadIndexInstruction(
  228. tree_ir.LoadLocalInstruction(jit_runtime.KWARGS_PARAMETER_NAME),
  229. tree_ir.LiteralInstruction('user_root')),
  230. # Dictionary operations
  231. 'dict_read' :
  232. lambda a, b:
  233. tree_ir.ReadDictionaryValueInstruction(
  234. a, tree_ir.ReadValueInstruction(b)),
  235. 'dict_read_edge' :
  236. lambda a, b:
  237. tree_ir.ReadDictionaryEdgeInstruction(
  238. a, tree_ir.ReadValueInstruction(b)),
  239. 'dict_add' : __dict_add,
  240. # Set operations
  241. 'set_add' : __set_add,
  242. # List operations
  243. 'list_len' :
  244. lambda a:
  245. tree_ir.CreateNodeWithValueInstruction(
  246. create_get_length(tree_ir.ReadOutgoingEdgesInstruction(a))),
  247. 'list_read' : __list_read,
  248. 'list_append' : __list_append,
  249. # log
  250. 'log' : __log
  251. }
  252. def register_time_intrinsic(target_jit):
  253. """Registers the time() intrinsic with the given JIT."""
  254. import_name = target_jit.import_value(time.time, 'time')
  255. target_jit.register_intrinsic(
  256. 'time',
  257. lambda: tree_ir.CreateNodeWithValueInstruction(
  258. tree_ir.CallInstruction(
  259. tree_ir.LoadGlobalInstruction(import_name),
  260. [])))
  261. def register_intrinsics(target_jit):
  262. """Registers all intrinsics in the module with the given JIT."""
  263. for (key, value) in BINARY_INTRINSICS.items():
  264. target_jit.register_binary_intrinsic(key, value)
  265. for (key, value) in UNARY_INTRINSICS.items():
  266. target_jit.register_unary_intrinsic(key, value)
  267. for (key, value) in CAST_INTRINSICS.items():
  268. target_jit.register_cast_intrinsic(key, value)
  269. for (key, value) in MISC_INTRINSICS.items():
  270. target_jit.register_intrinsic(key, value)
  271. register_time_intrinsic(target_jit)