rule.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. from concrete_syntax.textual_od.renderer import render_od
  2. import pprint
  3. from typing import Generator, Callable
  4. from uuid import UUID
  5. import functools
  6. from api.od import ODAPI
  7. from concrete_syntax.common import indent
  8. from transformation.matcher import match_od
  9. from transformation.rewriter import rewrite
  10. from transformation.cloner import clone_od
  11. from util.timer import Timer
  12. class Rule:
  13. def __init__(self, nacs: list[UUID], lhs: UUID, rhs: UUID):
  14. self.nacs = nacs
  15. self.lhs = lhs
  16. self.rhs = rhs
  17. PP = pprint.PrettyPrinter(depth=4)
  18. class _NAC_MATCHED(Exception):
  19. pass
  20. # Helper for executing NAC/LHS/RHS-type rules
  21. class RuleMatcherRewriter:
  22. def __init__(self, state, mm: UUID, mm_ramified: UUID, eval_context={}):
  23. self.state = state
  24. self.mm = mm
  25. self.mm_ramified = mm_ramified
  26. self.eval_context = eval_context
  27. # Generates matches.
  28. # Every match is a dictionary with entries LHS_element_name -> model_element_name
  29. def match_rule(self, m: UUID, lhs: UUID, nacs: list[UUID], rule_name: str) -> Generator[dict, None, None]:
  30. lhs_matcher = match_od(self.state,
  31. host_m=m,
  32. host_mm=self.mm,
  33. pattern_m=lhs,
  34. pattern_mm=self.mm_ramified,
  35. eval_context=self.eval_context,
  36. )
  37. try:
  38. # First we iterate over LHS-matches:
  39. # for i, lhs_match in enumerate(lhs_matcher):
  40. x=0
  41. while True:
  42. try:
  43. with Timer(f"MATCH LHS {rule_name}"):
  44. lhs_match = lhs_matcher.__next__()
  45. x += 1
  46. # Uncomment to see matches attempted - may give insight into why your rule is not matching
  47. # print(" lhs_match:", lhs_match)
  48. nac_matched = False
  49. with Timer(f"MATCH NACs {rule_name}"):
  50. try:
  51. for i_nac, nac in enumerate(nacs):
  52. # For every LHS-match, we see if there is a NAC-match:
  53. nac_matcher = match_od(self.state,
  54. host_m=m,
  55. host_mm=self.mm,
  56. pattern_m=nac,
  57. pattern_mm=self.mm_ramified,
  58. pivot=lhs_match, # try to "grow" LHS-match with NAC-match
  59. eval_context=self.eval_context,
  60. )
  61. try:
  62. # for nac_match in nac_matcher:
  63. while True:
  64. try:
  65. with Timer(f"MATCH NAC{i_nac} {rule_name}"):
  66. nac_match = nac_matcher.__next__()
  67. # The NAC has at least one match
  68. # (there could be more, but we know enough, so let's not waste CPU/MEM resources and proceed to next LHS match)
  69. raise _NAC_MATCHED()
  70. except StopIteration:
  71. break # no more nac-matches
  72. except Exception as e:
  73. # The exception may originate from eval'ed condition-code in LHS or NAC
  74. # Decorate exception with some context, to help with debugging
  75. e.add_note(f"while matching NAC of '{rule_name}'")
  76. raise
  77. except _NAC_MATCHED:
  78. continue # continue with next LHS-match
  79. # There were no NAC matches -> yield LHS-match!
  80. yield lhs_match
  81. except StopIteration:
  82. break # no more lhs-matches
  83. except Exception as e:
  84. # The exception may originate from eval'ed condition-code in LHS or NAC
  85. # Decorate exception with some context, to help with debugging
  86. e.add_note(f"while matching LHS of '{rule_name}'")
  87. raise
  88. def exec_rule(self, m: UUID, lhs: UUID, rhs: UUID, lhs_match: dict, rule_name: str, in_place=False):
  89. if in_place:
  90. # dangerous
  91. cloned_m = m
  92. else:
  93. cloned_m = clone_od(self.state, m, self.mm)
  94. # print('before clone:')
  95. # print(render_od(self.state, m, self.mm))
  96. # print('after clone:')
  97. # print(render_od(self.state, cloned_m, self.mm))
  98. try:
  99. rhs_match = rewrite(self.state,
  100. lhs_m=lhs,
  101. rhs_m=rhs,
  102. pattern_mm=self.mm_ramified,
  103. lhs_match=lhs_match,
  104. host_m=cloned_m,
  105. host_mm=self.mm,
  106. eval_context=self.eval_context,
  107. )
  108. except Exception as e:
  109. # Make exceptions raised in eval'ed code easier to trace:
  110. e.add_note(f"while executing RHS of '{rule_name}'")
  111. raise
  112. return (cloned_m, rhs_match)
  113. # This is often what you want: find a match, and execute the rule
  114. def exec_on_first_match(self, host: UUID, rule: Rule, rule_name: str, in_place=False):
  115. for lhs_match in self.match_rule(host, rule.lhs, rule.nacs, rule_name):
  116. (rewritten_host, rhs_match) = self.exec_rule(host, rule.lhs, rule.rhs, lhs_match, rule_name, in_place)
  117. return rewritten_host, lhs_match, rhs_match
  118. # Generator that yields actions in the format expected by 'Simulator' class
  119. class ActionGenerator:
  120. def __init__(self, matcher_rewriter: RuleMatcherRewriter, rule_dict: dict[str, Rule]):
  121. self.matcher_rewriter = matcher_rewriter
  122. self.rule_dict = rule_dict
  123. def __call__(self, od: ODAPI):
  124. at_least_one_match = False
  125. for rule_name, rule in self.rule_dict.items():
  126. match_iterator = self.matcher_rewriter.match_rule(od.m, rule.lhs, rule.nacs, rule_name)
  127. x = 0
  128. while True:
  129. try:
  130. # if True:
  131. with Timer(f"MATCH RULE {rule_name}"):
  132. lhs_match = match_iterator.__next__()
  133. x += 1
  134. # We got a match!
  135. def do_action(od, rule, lhs_match, rule_name):
  136. with Timer(f"EXEC RHS {rule_name}"):
  137. new_m, rhs_match = self.matcher_rewriter.exec_rule(od.m, rule.lhs, rule.rhs, lhs_match, rule_name)
  138. msgs = [f"executed rule '{rule_name}'\n" + indent(PP.pformat(rhs_match), 6)]
  139. return (ODAPI(od.state, new_m, od.mm), msgs)
  140. yield (
  141. rule_name + '\n' + indent(PP.pformat(lhs_match), 6), # description of action
  142. functools.partial(do_action, od, rule, lhs_match, rule_name) # the action itself (as a callback)
  143. )
  144. at_least_one_match = True
  145. except StopIteration:
  146. break
  147. return at_least_one_match
  148. # Given a list of actions (in high -> low priority), will always yield the highest priority enabled actions.
  149. class PriorityActionGenerator:
  150. def __init__(self, matcher_rewriter: RuleMatcherRewriter, rule_dicts: list[dict[str, Rule]]):
  151. self.generators = [ActionGenerator(matcher_rewriter, rule_dict) for rule_dict in rule_dicts]
  152. def __call__(self, od: ODAPI):
  153. for generator in self.generators:
  154. at_least_one_match = yield from generator(od)
  155. if at_least_one_match:
  156. return True
  157. return False
  158. # class ForAllGenerator:
  159. # def __init__(self, matcher_rewriter: RuleMatcherRewriter, rule_dict: dict[str, Rule]):
  160. # self.matcher_rewriter = matcher_rewriter
  161. # self.rule_dict = rule_dict
  162. # def __call__(self, od: ODAPI):
  163. # matches = []
  164. # for rule_name, rule in self.rule_dict.items():
  165. # for lhs_match in self.matcher_rewriter.match_rule(od.m, rule.lhs, rule.nacs, rule_name):
  166. # matches.append((rule_name, rule, lhs_match))
  167. # def do_action(matches):
  168. # pass
  169. # if len(matches) > 0:
  170. # yield (
  171. # [rule_name for rule_name, _, _ in matches]
  172. # )
  173. # return True
  174. # return False