rule.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. import pprint
  2. from typing import Generator, Callable
  3. from uuid import UUID
  4. import functools
  5. from api.od import ODAPI
  6. from concrete_syntax.common import indent
  7. from transformation.matcher import match_od
  8. from transformation.rewriter import rewrite
  9. from transformation.cloner import clone_od
  10. from util.timer import Timer
  11. class Rule:
  12. def __init__(self, nacs: list[UUID], lhs: UUID, rhs: UUID):
  13. self.nacs = nacs
  14. self.lhs = lhs
  15. self.rhs = rhs
  16. PP = pprint.PrettyPrinter(depth=4)
  17. class _NAC_MATCHED(Exception):
  18. pass
  19. # Helper for executing NAC/LHS/RHS-type rules
  20. class RuleMatcherRewriter:
  21. def __init__(self, state, mm: UUID, mm_ramified: UUID):
  22. self.state = state
  23. self.mm = mm
  24. self.mm_ramified = mm_ramified
  25. # Generates matches.
  26. # Every match is a dictionary with entries LHS_element_name -> model_element_name
  27. def match_rule(self, m: UUID, lhs: UUID, nacs: list[UUID], rule_name: str) -> Generator[dict, None, None]:
  28. lhs_matcher = match_od(self.state,
  29. host_m=m,
  30. host_mm=self.mm,
  31. pattern_m=lhs,
  32. pattern_mm=self.mm_ramified)
  33. try:
  34. # First we iterate over LHS-matches:
  35. # for i, lhs_match in enumerate(lhs_matcher):
  36. x=0
  37. while True:
  38. try:
  39. with Timer(f"MATCH LHS {rule_name}"):
  40. lhs_match = lhs_matcher.__next__()
  41. x += 1
  42. nac_matched = False
  43. try:
  44. for i_nac, nac in enumerate(nacs):
  45. # For every LHS-match, we see if there is a NAC-match:
  46. nac_matcher = match_od(self.state,
  47. host_m=m,
  48. host_mm=self.mm,
  49. pattern_m=nac,
  50. pattern_mm=self.mm_ramified,
  51. pivot=lhs_match) # try to "grow" LHS-match with NAC-match
  52. try:
  53. # for nac_match in nac_matcher:
  54. while True:
  55. try:
  56. with Timer(f"MATCH NAC{i_nac} {rule_name}"):
  57. nac_match = nac_matcher.__next__()
  58. raise _NAC_MATCHED()
  59. except StopIteration:
  60. break # no more nac-matches
  61. # The NAC has at least one match
  62. # (there could be more, but we know enough, so let's not waste CPU/MEM resources and proceed to next LHS match)
  63. nac_matched = True
  64. break
  65. except Exception as e:
  66. # The exception may originate from eval'ed condition-code in LHS or NAC
  67. # Decorate exception with some context, to help with debugging
  68. e.add_note(f"while matching NAC of '{rule_name}'")
  69. raise
  70. except _NAC_MATCHED:
  71. continue # continue with next LHS-match
  72. # There were no NAC matches -> yield LHS-match!
  73. yield lhs_match
  74. except StopIteration:
  75. break # no more lhs-matches
  76. except Exception as e:
  77. # The exception may originate from eval'ed condition-code in LHS or NAC
  78. # Decorate exception with some context, to help with debugging
  79. e.add_note(f"while matching LHS of '{rule_name}'")
  80. raise
  81. def exec_rule(self, m: UUID, lhs: UUID, rhs: UUID, lhs_match: dict, rule_name: str):
  82. cloned_m = clone_od(self.state, m, self.mm)
  83. try:
  84. rhs_match = rewrite(self.state,
  85. lhs_m=lhs,
  86. rhs_m=rhs,
  87. pattern_mm=self.mm_ramified,
  88. lhs_match=lhs_match,
  89. host_m=cloned_m,
  90. host_mm=self.mm)
  91. except Exception as e:
  92. # Make exceptions raised in eval'ed code easier to trace:
  93. e.add_note(f"while executing RHS of '{rule_name}'")
  94. raise
  95. return (cloned_m, rhs_match)
  96. # Generator that yields actions in the format expected by 'Simulator' class
  97. class ActionGenerator:
  98. def __init__(self, matcher_rewriter: RuleMatcherRewriter, rule_dict: dict[str, Rule]):
  99. self.matcher_rewriter = matcher_rewriter
  100. self.rule_dict = rule_dict
  101. def __call__(self, od: ODAPI):
  102. at_least_one_match = False
  103. for rule_name, rule in self.rule_dict.items():
  104. match_iterator = self.matcher_rewriter.match_rule(od.m, rule.lhs, rule.nacs, rule_name)
  105. x = 0
  106. while True:
  107. try:
  108. # if True:
  109. with Timer(f"MATCH RULE {rule_name}"):
  110. lhs_match = match_iterator.__next__()
  111. x += 1
  112. # We got a match!
  113. def do_action(od, rule, lhs_match, rule_name):
  114. with Timer(f"EXEC RHS {rule_name}"):
  115. new_m, rhs_match = self.matcher_rewriter.exec_rule(od.m, rule.lhs, rule.rhs, lhs_match, rule_name)
  116. msgs = [f"executed rule '{rule_name}'\n" + indent(PP.pformat(rhs_match), 6)]
  117. return (ODAPI(od.state, new_m, od.mm), msgs)
  118. yield (
  119. rule_name + '\n' + indent(PP.pformat(lhs_match), 6), # description of action
  120. functools.partial(do_action, od, rule, lhs_match, rule_name) # the action itself (as a callback)
  121. )
  122. at_least_one_match = True
  123. except StopIteration:
  124. break
  125. return at_least_one_match
  126. # Given a list of actions (in high -> low priority), will always yield the highest priority enabled actions.
  127. class PriorityActionGenerator:
  128. def __init__(self, matcher_rewriter: RuleMatcherRewriter, rule_dicts: list[dict[str, Rule]]):
  129. self.generators = [ActionGenerator(matcher_rewriter, rule_dict) for rule_dict in rule_dicts]
  130. def __call__(self, od: ODAPI):
  131. for generator in self.generators:
  132. at_least_one_match = yield from generator(od)
  133. if at_least_one_match:
  134. return True
  135. return False
  136. # class ForAllGenerator:
  137. # def __init__(self, matcher_rewriter: RuleMatcherRewriter, rule_dict: dict[str, Rule]):
  138. # self.matcher_rewriter = matcher_rewriter
  139. # self.rule_dict = rule_dict
  140. # def __call__(self, od: ODAPI):
  141. # matches = []
  142. # for rule_name, rule in self.rule_dict.items():
  143. # for lhs_match in self.matcher_rewriter.match_rule(od.m, rule.lhs, rule.nacs, rule_name):
  144. # matches.append((rule_name, rule, lhs_match))
  145. # def do_action(matches):
  146. # pass
  147. # if len(matches) > 0:
  148. # yield (
  149. # [rule_name for rule_name, _, _ in matches]
  150. # )
  151. # return True
  152. # return False