compiled.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. from modelverse_kernel.primitives import PrimitiveFinished
  2. import modelverse_jit.runtime as jit_runtime
  3. def reverseKeyLookupMulti(a, b, **remainder):
  4. edges, = yield [("RO", [a])]
  5. b_val, = yield [("RV", [b])]
  6. expanded_edges = yield [("RE", [i]) for i in edges]
  7. values = yield [("RV", [i[1]]) for i in expanded_edges]
  8. result, = yield [("CN", [])]
  9. for i, edge in enumerate(values):
  10. if b_val == edge:
  11. # Found our edge: edges[i]
  12. outgoing, = yield [("RO", [edges[i]])]
  13. value, = yield [("RE", [outgoing[0]])]
  14. yield [("CE", [result, value[1]])]
  15. raise PrimitiveFinished(result)
  16. def reverseKeyLookup(a, b, **remainder):
  17. edges, = yield [("RO", [a])]
  18. expanded_edges = yield [("RE", [i]) for i in edges]
  19. for i, edge in enumerate(expanded_edges):
  20. if b == edge[1]:
  21. # Found our edge: edges[i]
  22. outgoing, = yield [("RO", [edges[i]])]
  23. result, = yield [("RE", [outgoing[0]])]
  24. raise PrimitiveFinished(result[1])
  25. result, = yield [("CNV", ["(unknown: %s)" % b])]
  26. raise PrimitiveFinished(result)
  27. def set_copy(a, **remainder):
  28. b, = yield [("CN", [])]
  29. links, = yield [("RO", [a])]
  30. exp_links = yield [("RE", [i]) for i in links]
  31. _ = yield [("CE", [b, i[1]]) for i in exp_links]
  32. raise PrimitiveFinished(b)
  33. def check_symbols(a, b, c, **remainder):
  34. symbols = {}
  35. function_name, = yield [("RV", [b])]
  36. symbols[function_name] = False
  37. object_links, = yield [("RO", [c])]
  38. set_elements = yield [("RE", [i]) for i in object_links]
  39. set_elements = [i[1] for i in set_elements]
  40. set_values = yield [("RV", [i]) for i in set_elements]
  41. set_elements = yield [("RD", [a, i]) for i in set_values]
  42. symbols_set = yield [("RD", [i, "symbols"]) for i in set_elements]
  43. all_keys = yield [("RDK", [i]) for i in symbols_set]
  44. for i, s in zip(all_keys, symbols_set):
  45. # For each object we have found
  46. keys = yield [("RV", [j]) for j in i]
  47. values = yield [("RD", [s, j]) for j in keys]
  48. values = yield [("RV", [j]) for j in values]
  49. for key, value in zip(keys, values):
  50. k = key
  51. v = value
  52. if v and symbols.get(k, False):
  53. result, = yield [("CNV", ["ERROR: multiple definition of symbol " + str(key)])]
  54. raise PrimitiveFinished(result)
  55. elif v and not symbols.get(k, False):
  56. symbols[k] = True
  57. elif not v and k not in symbols:
  58. symbols[k] = False
  59. for i, j in symbols.items():
  60. if i == "input" or i == "output":
  61. continue
  62. if not j:
  63. result, = yield [("CNV", ["ERROR: undefined symbol " + str(i)])]
  64. raise PrimitiveFinished(result)
  65. result, = yield [("CNV", ["OK"])]
  66. raise PrimitiveFinished(result)
  67. def construct_const(**remainder):
  68. v, = yield [("CNV", [{"value": "constant"}])]
  69. # Get input: keep trying until we get something
  70. inp, = yield [("CALL_KWARGS", [jit_runtime.get_input, remainder])]
  71. yield [("CD", [v, "node", inp])]
  72. raise PrimitiveFinished(v)
  73. def instantiated_name(a, b, **remainder):
  74. name_value, = yield [("RV", [b])]
  75. if name_value == "":
  76. b, = yield [("CNV", ["__" + str(a)])]
  77. raise PrimitiveFinished(b)
  78. def set_merge(a, b, **remainder):
  79. outputs, = yield [("RO", [b])]
  80. values = yield [("RE", [i]) for i in outputs]
  81. yield [("CE", [a, i[1]]) for i in values]
  82. raise PrimitiveFinished(a)