compiled.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. from modelverse_kernel.primitives import PrimitiveFinished
  2. import modelverse_jit.runtime as jit_runtime
  3. def reverseKeyLookupMulti(a, b, **remainder):
  4. edges, = yield [("RO", [a])]
  5. b_val, = yield [("RV", [b])]
  6. expanded_edges = yield [("RE", [i]) for i in edges]
  7. values = yield [("RV", [i[1]]) for i in expanded_edges]
  8. result, = yield [("CN", [])]
  9. for i, edge in enumerate(values):
  10. if b_val == edge:
  11. # Found our edge: edges[i]
  12. outgoing, = yield [("RO", [edges[i]])]
  13. value, = yield [("RE", [outgoing[0]])]
  14. yield [("CE", [result, value[1]])]
  15. raise PrimitiveFinished(result)
  16. """
  17. edges_out, edges_in, result = yield [("RO", [a]), ("RI", [b]), ("CN", [])]
  18. python_result = set()
  19. options = set(edges_out) & set(edges_in)
  20. if options:
  21. all_destination_edges = yield [("RO", [i]) for i in options]
  22. for e in all_destination_edges:
  23. st, = yield [("RE", [edge]) for edge in e]
  24. s, t = st
  25. python_result.add(t)
  26. yield [("CE", [result, i]) for i in python_result]
  27. raise PrimitiveFinished(result)
  28. """
  29. def reverseKeyLookup(a, b, **remainder):
  30. edges_out, edges_in = yield [("RO", [a]), ("RI", [b])]
  31. options = set(edges_out) & set(edges_in)
  32. if options:
  33. # Select one option randomly
  34. edge = options.pop()
  35. out_edges, = yield [("RO", [edge])]
  36. # Select one option randomly
  37. out_edge = out_edges.pop()
  38. e, = yield [("RE", [out_edge])]
  39. s, t = e
  40. raise PrimitiveFinished(t)
  41. else:
  42. result, = yield [("CNV", ["(unknown: %s)" % b])]
  43. raise PrimitiveFinished(result)
  44. def set_copy(a, **remainder):
  45. b, = yield [("CN", [])]
  46. links, = yield [("RO", [a])]
  47. exp_links = yield [("RE", [i]) for i in links]
  48. _ = yield [("CE", [b, i[1]]) for i in exp_links]
  49. raise PrimitiveFinished(b)
  50. def check_symbols(a, b, c, **remainder):
  51. symbols = {}
  52. function_name, = yield [("RV", [b])]
  53. symbols[function_name] = False
  54. object_links, = yield [("RO", [c])]
  55. set_elements = yield [("RE", [i]) for i in object_links]
  56. set_elements = [i[1] for i in set_elements]
  57. set_values = yield [("RV", [i]) for i in set_elements]
  58. set_elements = yield [("RD", [a, i]) for i in set_values]
  59. symbols_set = yield [("RD", [i, "symbols"]) for i in set_elements]
  60. all_keys = yield [("RDK", [i]) for i in symbols_set]
  61. for i, s in zip(all_keys, symbols_set):
  62. # For each object we have found
  63. keys = yield [("RV", [j]) for j in i]
  64. values = yield [("RD", [s, j]) for j in keys]
  65. values = yield [("RV", [j]) for j in values]
  66. for key, value in zip(keys, values):
  67. k = key
  68. v = value
  69. if v and symbols.get(k, False):
  70. result, = yield [("CNV", ["ERROR: multiple definition of symbol " + str(key)])]
  71. raise PrimitiveFinished(result)
  72. elif v and not symbols.get(k, False):
  73. symbols[k] = True
  74. elif not v and k not in symbols:
  75. symbols[k] = False
  76. for i, j in symbols.items():
  77. if i == "input" or i == "output":
  78. continue
  79. if not j:
  80. result, = yield [("CNV", ["ERROR: undefined symbol " + str(i)])]
  81. raise PrimitiveFinished(result)
  82. result, = yield [("CNV", ["OK"])]
  83. raise PrimitiveFinished(result)
  84. def construct_const(**remainder):
  85. v, = yield [("CNV", [{"value": "constant"}])]
  86. # Get input: keep trying until we get something
  87. inp, = yield [("CALL_KWARGS", [jit_runtime.get_input, remainder])]
  88. yield [("CD", [v, "node", inp])]
  89. raise PrimitiveFinished(v)
  90. def instantiated_name(a, b, **remainder):
  91. name_value, = yield [("RV", [b])]
  92. if name_value == "":
  93. b, = yield [("CNV", ["__" + str(a)])]
  94. raise PrimitiveFinished(b)
  95. def set_merge(a, b, **remainder):
  96. outputs, = yield [("RO", [b])]
  97. values = yield [("RE", [i]) for i in outputs]
  98. yield [("CE", [a, i[1]]) for i in values]
  99. raise PrimitiveFinished(a)
  100. def has_value(a, **remainder):
  101. v, = yield [("RV", [a])]
  102. if v is None:
  103. result, = yield [("CNV", [False])]
  104. else:
  105. result, = yield [("CNV", [True])]
  106. raise PrimitiveFinished(result)
  107. def make_reverse_dictionary(a, **remainder):
  108. reverse, = yield [("CN", [])]
  109. key_nodes, = yield [("RDK", [a])]
  110. values = yield [("RDN", [a, i]) for i in key_nodes]
  111. yield [("CD", [reverse, str(v), k]) for k, v in zip(key_nodes, values)]
  112. raise PrimitiveFinished(reverse)
  113. def dict_eq(a, b, **remainder):
  114. key_nodes, = yield [("RDK", [a])]
  115. key_values = yield [("RV", [i]) for i in key_nodes]
  116. values = yield [("RD", [a, i]) for i in key_values]
  117. values = yield [("RV", [i]) for i in values]
  118. a_dict = dict(zip(key_values, values))
  119. key_nodes, = yield [("RDK", [b])]
  120. key_values = yield [("RV", [i]) for i in key_nodes]
  121. values = yield [("RD", [b, i]) for i in key_values]
  122. values = yield [("RV", [i]) for i in values]
  123. b_dict = dict(zip(key_values, values))
  124. result, = yield [("CNV", [a_dict == b_dict])]
  125. raise PrimitiveFinished(result)
  126. def string_substr(a, b, c, **remainder):
  127. a_val, b_val, c_val = yield [("RV", [a]),
  128. ("RV", [b]),
  129. ("RV", [c])]
  130. try:
  131. new_value = a_val[b_val:c_val]
  132. except:
  133. new_value = ""
  134. result, = yield [("CNV", [new_value])]
  135. raise PrimitiveFinished(result)