rewrite.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. import functools
  2. from typing import List, Callable, Generator
  3. from api.od import ODAPI
  4. from .exec_node import ExecNode
  5. from .data_node import DataNode
  6. from ..RuleExecuter import RuleExecuter
  7. class Rewrite(ExecNode, DataNode):
  8. def __init__(self, label: str) -> None:
  9. ExecNode.__init__(self, out_connections=1)
  10. DataNode.__init__(self)
  11. self.label = label
  12. self.rule = None
  13. self.rule_executer : RuleExecuter
  14. def init_rule(self, rule, rule_executer):
  15. self.rule = rule
  16. self.rule_executer= rule_executer
  17. def execute(self, od: ODAPI) -> Generator | None:
  18. yield "ghello", functools.partial(self.rewrite, od)
  19. def rewrite(self, od):
  20. print("rewrite" + self.label)
  21. pivot = {}
  22. if self.data_in is not None:
  23. pivot = self.get_input_data()[0]
  24. self.store_data(self.rule_executer.rewrite_rule(od.m, self.rule, pivot=pivot), 1)
  25. return ODAPI(od.state, od.m, od.mm),[f"rewrite {self.label}\n\tpivot: {pivot}\n\t{"success" if self.data_out.success else "failure"}\n"]
  26. def generate_dot(self, nodes: List[str], edges: List[str], visited: set[int]) -> None:
  27. if self.id in visited:
  28. return
  29. nodes.append(f"{self.id}[label=R_{self.label.split("/")[-1]}]")
  30. ExecNode.generate_dot(self, nodes, edges, visited)
  31. DataNode.generate_dot(self, nodes, edges, visited)