# vf2.py
# This module contains a VF2-inspired graph matching algorithm.
# It defines its own Graph type, and can be used standalone (no dependencies on the rest of the muMLE framework).
# Author: Joeri Exelmans
  4. import itertools
  5. from util.timer import Timer
# Like finding the 'strongly connected components', but edges are navigable in any direction.
  7. def find_connected_components(graph):
  8. next_component = 0
  9. vtx_to_component = {}
  10. component_to_vtxs = []
  11. for vtx in graph.vtxs:
  12. if vtx in vtx_to_component:
  13. continue
  14. vtx_to_component[vtx] = next_component
  15. vtxs = []
  16. component_to_vtxs.append(vtxs)
  17. add_recursively(vtx, vtxs, vtx_to_component, next_component)
  18. next_component += 1
  19. return (vtx_to_component, component_to_vtxs)
  20. def add_recursively(vtx, vtxs: list, d: dict, component: int, already_visited: set = set()):
  21. if vtx in already_visited:
  22. return
  23. already_visited.add(vtx)
  24. vtxs.append(vtx)
  25. d[vtx] = component
  26. for edge in vtx.outgoing:
  27. add_recursively(edge.tgt, vtxs, d, component, already_visited)
  28. for edge in vtx.incoming:
  29. add_recursively(edge.src, vtxs, d, component, already_visited)
  30. class Graph:
  31. def __init__(self):
  32. self.vtxs = []
  33. self.edges = []
  34. class Vertex:
  35. def __init__(self, value):
  36. self.incoming = []
  37. self.outgoing = []
  38. self.value = value
  39. def __repr__(self):
  40. return f"V({self.value})"
  41. class Edge:
  42. def __init__(self, src: Vertex, tgt: Vertex, label=None):
  43. self.src = src
  44. self.tgt = tgt
  45. self.label = label
  46. # Add ourselves to src/tgt vertices
  47. self.src.outgoing.append(self)
  48. self.tgt.incoming.append(self)
  49. def __repr__(self):
  50. if self.label != None:
  51. return f"({self.src}--{self.label}->{self.tgt})"
  52. else:
  53. return f"({self.src}->{self.tgt})"
  54. class MatcherState:
  55. def __init__(self):
  56. self.mapping_vtxs = {} # guest -> host
  57. self.mapping_edges = {} # guest -> host
  58. self.r_mapping_vtxs = {} # host -> guest
  59. self.r_mapping_edges = {} # host -> guest
  60. self.h_unmatched_vtxs = []
  61. self.g_unmatched_vtxs = []
  62. # boundary is the most recently added (to the mapping) pair of (guest -> host) vertices
  63. self.boundary = None
  64. @staticmethod
  65. def make_initial(host, guest, pivot):
  66. state = MatcherState()
  67. state.h_unmatched_vtxs = [vtx for vtx in host.vtxs if vtx not in pivot.values()]
  68. state.g_unmatched_vtxs = [vtx for vtx in guest.vtxs if vtx not in pivot.keys()]
  69. state.mapping_vtxs = pivot
  70. state.r_mapping_vtxs = { v: k for k,v in state.mapping_vtxs.items() }
  71. return state
  72. # Grow the match set (creating a new copy)
  73. def grow_edge(self, host_edge, guest_edge):
  74. new_state = MatcherState()
  75. new_state.mapping_vtxs = self.mapping_vtxs
  76. new_state.mapping_edges = dict(self.mapping_edges)
  77. new_state.mapping_edges[guest_edge] = host_edge
  78. new_state.r_mapping_vtxs = self.r_mapping_vtxs
  79. new_state.r_mapping_edges = dict(self.r_mapping_edges)
  80. new_state.r_mapping_edges[host_edge] = guest_edge
  81. new_state.h_unmatched_vtxs = self.h_unmatched_vtxs
  82. new_state.g_unmatched_vtxs = self.g_unmatched_vtxs
  83. return new_state
  84. # Grow the match set (creating a new copy)
  85. def grow_vtx(self, host_vtx, guest_vtx):
  86. new_state = MatcherState()
  87. new_state.mapping_vtxs = dict(self.mapping_vtxs)
  88. new_state.mapping_vtxs[guest_vtx] = host_vtx
  89. new_state.mapping_edges = self.mapping_edges
  90. new_state.r_mapping_vtxs = dict(self.r_mapping_vtxs)
  91. new_state.r_mapping_vtxs[host_vtx] = guest_vtx
  92. new_state.r_mapping_edges = self.r_mapping_edges
  93. new_state.h_unmatched_vtxs = [h_vtx for h_vtx in self.h_unmatched_vtxs if h_vtx != host_vtx]
  94. new_state.g_unmatched_vtxs = [g_vtx for g_vtx in self.g_unmatched_vtxs if g_vtx != guest_vtx]
  95. new_state.boundary = (guest_vtx, host_vtx)
  96. return new_state
  97. def make_hashable(self):
  98. return frozenset(itertools.chain(
  99. ((gv,hv) for gv,hv in self.mapping_vtxs.items()),
  100. ((ge,he) for ge,he in self.mapping_edges.items()),
  101. ))
  102. def __repr__(self):
  103. # return self.make_hashable().__repr__()
  104. return "VTXS: "+self.mapping_vtxs.__repr__()+"\nEDGES: "+self.mapping_edges.__repr__()
  105. class MatcherVF2:
  106. # Guest is the pattern
  107. def __init__(self, host, guest, compare_fn):
  108. self.host = host
  109. self.guest = guest
  110. self.compare_fn = compare_fn
  111. # with Timer("find_connected_components - guest"):
  112. self.guest_vtx_to_component, self.guest_component_to_vtxs = find_connected_components(guest)
  113. # print("number of guest connected components:", len(self.guest_component_to_vtxs))
  114. def match(self, pivot={}):
  115. yield from self._match(
  116. state=MatcherState.make_initial(self.host, self.guest, pivot),
  117. already_visited=set())
  118. def _match(self, state, already_visited, indent=0):
  119. # input()
  120. def print_debug(*args):
  121. pass
  122. # print(" "*indent, *args) # uncomment to see a trace of the matching process
  123. print_debug("match")
  124. # Keep track of the states in the search space that we already visited
  125. hashable = state.make_hashable()
  126. if hashable in already_visited:
  127. print_debug(" SKIP - ALREADY VISITED")
  128. # print_debug(" ", hashable)
  129. return
  130. # print_debug(" ", [hash(a) for a in already_visited])
  131. # print_debug(" ADD STATE")
  132. # print_debug(" ", hash(hashable))
  133. already_visited.add(hashable)
  134. if len(state.mapping_vtxs) == len(self.guest.vtxs) and len(state.mapping_edges) == len(self.guest.edges):
  135. print_debug("GOT MATCH:")
  136. print_debug(" ", state.mapping_vtxs)
  137. print_debug(" ", state.mapping_edges)
  138. yield state
  139. return
  140. def read_edge(edge, direction):
  141. if direction == "outgoing":
  142. return edge.tgt
  143. elif direction == "incoming":
  144. return edge.src
  145. else:
  146. raise Exception("wtf!")
  147. def attempt_grow(direction, indent):
  148. for g_matched_vtx, h_matched_vtx in state.mapping_vtxs.items():
  149. print_debug('attempt_grow', direction)
  150. for g_candidate_edge in getattr(g_matched_vtx, direction):
  151. print_debug('g_candidate_edge:', g_candidate_edge)
  152. g_candidate_vtx = read_edge(g_candidate_edge, direction)
  153. # g_to_skip_vtxs.add(g_candidate_vtx)
  154. if g_candidate_edge in state.mapping_edges:
  155. print_debug(" skip, guest edge already matched")
  156. continue # skip already matched guest edge
  157. for h_candidate_edge in getattr(h_matched_vtx, direction):
  158. if g_candidate_edge.label != h_candidate_edge.label:
  159. print_debug(" labels differ")
  160. continue
  161. print_debug('h_candidate_edge:', h_candidate_edge)
  162. if h_candidate_edge in state.r_mapping_edges:
  163. print_debug(" skip, host edge already matched")
  164. continue # skip already matched host edge
  165. print_debug('grow edge', g_candidate_edge, ':', h_candidate_edge, id(g_candidate_edge), id(h_candidate_edge))
  166. new_state = state.grow_edge(h_candidate_edge, g_candidate_edge)
  167. h_candidate_vtx = read_edge(h_candidate_edge, direction)
  168. yield from attempt_match_vtxs(
  169. new_state,
  170. g_candidate_vtx,
  171. h_candidate_vtx,
  172. indent+1)
  173. print_debug('backtrack edge', g_candidate_edge, ':', h_candidate_edge, id(g_candidate_edge), id(h_candidate_edge))
  174. def attempt_match_vtxs(state, g_candidate_vtx, h_candidate_vtx, indent):
  175. print_debug('attempt_match_vtxs')
  176. if g_candidate_vtx in state.mapping_vtxs:
  177. if state.mapping_vtxs[g_candidate_vtx] != h_candidate_vtx:
  178. print_debug(" nope, guest already mapped (mismatch)")
  179. return # guest vtx is already mapped but doesn't match host vtx
  180. if h_candidate_vtx in state.r_mapping_vtxs:
  181. if state.r_mapping_vtxs[h_candidate_vtx] != g_candidate_vtx:
  182. print_debug(" nope, host already mapped (mismatch)")
  183. return # host vtx is already mapped but doesn't match guest vtx
  184. g_outdegree = len(g_candidate_vtx.outgoing)
  185. h_outdegree = len(h_candidate_vtx.outgoing)
  186. if g_outdegree > h_outdegree:
  187. print_debug(" nope, outdegree")
  188. return
  189. g_indegree = len(g_candidate_vtx.incoming)
  190. h_indegree = len(h_candidate_vtx.incoming)
  191. if g_indegree > h_indegree:
  192. print_debug(" nope, indegree")
  193. return
  194. if not self.compare_fn(g_candidate_vtx, h_candidate_vtx):
  195. print_debug(" nope, bad compare")
  196. return
  197. new_state = state.grow_vtx(
  198. h_candidate_vtx,
  199. g_candidate_vtx)
  200. print_debug('grow vtx', g_candidate_vtx, ':', h_candidate_vtx, id(g_candidate_vtx), id(h_candidate_vtx))
  201. yield from self._match(new_state, already_visited, indent+1)
  202. print_debug('backtrack vtx', g_candidate_vtx, ':', h_candidate_vtx, id(g_candidate_vtx), id(h_candidate_vtx))
  203. print_debug('preferred...')
  204. yield from attempt_grow('outgoing', indent+1)
  205. yield from attempt_grow('incoming', indent+1)
  206. print_debug('least preferred...')
  207. if state.boundary != None:
  208. g_boundary_vtx, _ = state.boundary
  209. guest_boundary_component = self.guest_vtx_to_component[g_boundary_vtx]
  210. # only try guest vertices that are in a different component (all vertices in the same component are already discovered via 'attempt_grow')
  211. guest_components_to_try = (c for i,c in enumerate(self.guest_component_to_vtxs) if i != guest_boundary_component)
  212. # for the host vertices however, we have to try them from all components, because different connected components of our pattern (=guest) could be mapped onto the same connected component in the host
  213. else:
  214. guest_components_to_try = self.guest_component_to_vtxs
  215. for g_candidate_vtxs in guest_components_to_try:
  216. for g_candidate_vtx in g_candidate_vtxs:
  217. if g_candidate_vtx in state.mapping_vtxs:
  218. print_debug("skip (already matched)", g_candidate_vtx)
  219. continue
  220. for h_candidate_vtx in state.h_unmatched_vtxs:
  221. yield from attempt_match_vtxs(state, g_candidate_vtx, h_candidate_vtx, indent+1)
  222. if indent == 0:
  223. print_debug('visited', len(already_visited), 'states total')
  224. # demo time...
  225. if __name__ == "__main__":
  226. host = Graph()
  227. host.vtxs = [Vertex(0), Vertex(1), Vertex(2), Vertex(3)]
  228. host.edges = [
  229. Edge(host.vtxs[0], host.vtxs[1]),
  230. Edge(host.vtxs[1], host.vtxs[2]),
  231. Edge(host.vtxs[2], host.vtxs[0]),
  232. Edge(host.vtxs[2], host.vtxs[3]),
  233. Edge(host.vtxs[3], host.vtxs[2]),
  234. ]
  235. guest = Graph()
  236. guest.vtxs = [
  237. Vertex('v != 3'), # cannot be matched with Vertex(3) - changing this to True, you get 2 morphisms instead of one
  238. Vertex('True')] # can be matched with any node
  239. guest.edges = [
  240. # Look for a simple loop:
  241. Edge(guest.vtxs[0], guest.vtxs[1]),
  242. # Edge(guest.vtxs[1], guest.vtxs[0]),
  243. ]
  244. m = MatcherVF2(host, guest, lambda g_vtx, h_vtx: eval(g_vtx.value, {}, {'v':h_vtx.value}))
  245. import time
  246. durations = 0
  247. iterations = 1
  248. print("Patience...")
  249. for n in range(iterations):
  250. time_start = time.perf_counter_ns()
  251. matches = [mm for mm in m.match()]
  252. time_end = time.perf_counter_ns()
  253. time_duration = time_end - time_start
  254. durations += time_duration
  255. print(f'{iterations} iterations, took {durations/1000000:.3f} ms, {durations/iterations/1000000:.3f} ms per iteration')
  256. print("found", len(matches), "matches")
  257. for mm in matches:
  258. print("match:")
  259. print(" ", mm.mapping_vtxs)
  260. print(" ", mm.mapping_edges)
  261. print("######################")
  262. host = Graph()
  263. host.vtxs = [
  264. Vertex('pony'), # 1
  265. Vertex('pony'), # 3
  266. Vertex('bear'),
  267. Vertex('bear'),
  268. ]
  269. host.edges = [
  270. # match:
  271. Edge(host.vtxs[0], host.vtxs[1]),
  272. Edge(host.vtxs[1], host.vtxs[0]),
  273. ]
  274. guest = Graph()
  275. guest.vtxs = [
  276. Vertex('pony'), # 0
  277. Vertex('pony'), # 1
  278. Vertex('bear')]
  279. guest.edges = [
  280. Edge(guest.vtxs[0], guest.vtxs[1]),
  281. Edge(guest.vtxs[1], guest.vtxs[0]),
  282. ]
  283. m = MatcherVF2(host, guest, lambda g_vtx, h_vtx: g_vtx.value == h_vtx.value)
  284. import time
  285. durations = 0
  286. iterations = 1
  287. print("Patience...")
  288. for n in range(iterations):
  289. time_start = time.perf_counter_ns()
  290. matches = [mm for mm in m.match()]
  291. time_end = time.perf_counter_ns()
  292. time_duration = time_end - time_start
  293. durations += time_duration
  294. print(f'{iterations} iterations, took {durations/1000000:.3f} ms, {durations/iterations/1000000:.3f} ms per iteration')
  295. print("found", len(matches), "matches")
  296. for mm in matches:
  297. print("match:")
  298. print(" ", mm.mapping_vtxs)
  299. print(" ", mm.mapping_edges)