compiled.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. from modelverse_kernel.primitives import PrimitiveFinished
  2. import modelverse_jit.runtime as jit_runtime
  3. def reverseKeyLookupMulti(a, b, **remainder):
  4. edges, = yield [("RO", [a])]
  5. b_val, = yield [("RV", [b])]
  6. expanded_edges = yield [("RE", [i]) for i in edges]
  7. values = yield [("RV", [i[1]]) for i in expanded_edges]
  8. result, = yield [("CN", [])]
  9. for i, edge in enumerate(values):
  10. if b_val == edge:
  11. # Found our edge: edges[i]
  12. outgoing, = yield [("RO", [edges[i]])]
  13. value, = yield [("RE", [outgoing[0]])]
  14. yield [("CE", [result, value[1]])]
  15. raise PrimitiveFinished(result)
  16. def reverseKeyLookup(a, b, **remainder):
  17. # Guess, which might work
  18. guess, = yield [("RD", [a, "__%s" % b])]
  19. if guess == b:
  20. result, = yield [("CNV", ["__%s" % b])]
  21. raise PrimitiveFinished(result)
  22. # We failed: do the slow way
  23. edges, = yield [("RO", [a])]
  24. expanded_edges = yield [("RE", [i]) for i in edges]
  25. for i, edge in enumerate(expanded_edges):
  26. if b == edge[1]:
  27. # Found our edge: edges[i]
  28. outgoing, = yield [("RO", [edges[i]])]
  29. result, = yield [("RE", [outgoing[0]])]
  30. raise PrimitiveFinished(result[1])
  31. result, = yield [("CNV", ["(unknown: %s)" % b])]
  32. raise PrimitiveFinished(result)
  33. def set_copy(a, **remainder):
  34. b, = yield [("CN", [])]
  35. links, = yield [("RO", [a])]
  36. exp_links = yield [("RE", [i]) for i in links]
  37. _ = yield [("CE", [b, i[1]]) for i in exp_links]
  38. raise PrimitiveFinished(b)
  39. def check_symbols(a, b, c, **remainder):
  40. symbols = {}
  41. function_name, = yield [("RV", [b])]
  42. symbols[function_name] = False
  43. object_links, = yield [("RO", [c])]
  44. set_elements = yield [("RE", [i]) for i in object_links]
  45. set_elements = [i[1] for i in set_elements]
  46. set_values = yield [("RV", [i]) for i in set_elements]
  47. set_elements = yield [("RD", [a, i]) for i in set_values]
  48. symbols_set = yield [("RD", [i, "symbols"]) for i in set_elements]
  49. all_keys = yield [("RDK", [i]) for i in symbols_set]
  50. for i, s in zip(all_keys, symbols_set):
  51. # For each object we have found
  52. keys = yield [("RV", [j]) for j in i]
  53. values = yield [("RD", [s, j]) for j in keys]
  54. values = yield [("RV", [j]) for j in values]
  55. for key, value in zip(keys, values):
  56. k = key
  57. v = value
  58. if v and symbols.get(k, False):
  59. result, = yield [("CNV", ["ERROR: multiple definition of symbol " + str(key)])]
  60. raise PrimitiveFinished(result)
  61. elif v and not symbols.get(k, False):
  62. symbols[k] = True
  63. elif not v and k not in symbols:
  64. symbols[k] = False
  65. for i, j in symbols.items():
  66. if i == "input" or i == "output":
  67. continue
  68. if not j:
  69. result, = yield [("CNV", ["ERROR: undefined symbol " + str(i)])]
  70. raise PrimitiveFinished(result)
  71. result, = yield [("CNV", ["OK"])]
  72. raise PrimitiveFinished(result)
  73. def construct_const(**remainder):
  74. v, = yield [("CNV", [{"value": "constant"}])]
  75. # Get input: keep trying until we get something
  76. inp, = yield [("CALL_KWARGS", [jit_runtime.get_input, remainder])]
  77. yield [("CD", [v, "node", inp])]
  78. raise PrimitiveFinished(v)
  79. def instantiated_name(a, b, **remainder):
  80. name_value, = yield [("RV", [b])]
  81. if name_value == "":
  82. b, = yield [("CNV", ["__" + str(a)])]
  83. raise PrimitiveFinished(b)
  84. def set_merge(a, b, **remainder):
  85. outputs, = yield [("RO", [b])]
  86. values = yield [("RE", [i]) for i in outputs]
  87. yield [("CE", [a, i[1]]) for i in values]
  88. raise PrimitiveFinished(a)
  89. def has_value(a, **remainder):
  90. v, = yield [("RV", [a])]
  91. if v is None:
  92. result, = yield [("CNV", [False])]
  93. else:
  94. result, = yield [("CNV", [True])]
  95. raise PrimitiveFinished(result)
  96. def make_reverse_dictionary(a, **remainder):
  97. reverse, = yield [("CN", [])]
  98. key_nodes, = yield [("RDK", [a])]
  99. values = yield [("RDN", [a, i]) for i in key_nodes]
  100. yield [("CD", [reverse, str(v), k]) for k, v in zip(key_nodes, values)]
  101. raise PrimitiveFinished(reverse)
  102. def dict_eq(a, b, **remainder):
  103. key_nodes, = yield [("RDK", [a])]
  104. key_values = yield [("RV", [i]) for i in key_nodes]
  105. values = yield [("RD", [a, i]) for i in key_values]
  106. values = yield [("RV", [i]) for i in values]
  107. a_dict = dict(zip(key_values, values))
  108. key_nodes, = yield [("RDK", [b])]
  109. key_values = yield [("RV", [i]) for i in key_nodes]
  110. values = yield [("RD", [b, i]) for i in key_values]
  111. values = yield [("RV", [i]) for i in values]
  112. b_dict = dict(zip(key_values, values))
  113. result, = yield [("CNV", [a_dict == b_dict])]
  114. raise PrimitiveFinished(result)