# matcher.py


# This module contains a VF2-inspired graph matching algorithm.
# Author: Joeri Exelmans
  3. import itertools
  4. from util.timer import Timer
  5. # like finding the 'strongly connected componenets', but edges are navigable in any direction
  6. def find_connected_components(graph):
  7. next_component = 0
  8. vtx_to_component = {}
  9. component_to_vtxs = []
  10. for vtx in graph.vtxs:
  11. if vtx in vtx_to_component:
  12. continue
  13. vtx_to_component[vtx] = next_component
  14. vtxs = []
  15. component_to_vtxs.append(vtxs)
  16. add_recursively(vtx, vtxs, vtx_to_component, next_component)
  17. next_component += 1
  18. return (vtx_to_component, component_to_vtxs)
  19. def add_recursively(vtx, vtxs: list, d: dict, component: int, already_visited: set = set()):
  20. if vtx in already_visited:
  21. return
  22. already_visited.add(vtx)
  23. vtxs.append(vtx)
  24. d[vtx] = component
  25. for edge in vtx.outgoing:
  26. add_recursively(edge.tgt, vtxs, d, component, already_visited)
  27. for edge in vtx.incoming:
  28. add_recursively(edge.src, vtxs, d, component, already_visited)
  29. class Graph:
  30. def __init__(self):
  31. self.vtxs = []
  32. self.edges = []
  33. class Vertex:
  34. def __init__(self, value):
  35. self.incoming = []
  36. self.outgoing = []
  37. self.value = value
  38. def __repr__(self):
  39. return f"V({self.value})"
  40. class Edge:
  41. def __init__(self, src: Vertex, tgt: Vertex, label=None):
  42. self.src = src
  43. self.tgt = tgt
  44. self.label = label
  45. # Add ourselves to src/tgt vertices
  46. self.src.outgoing.append(self)
  47. self.tgt.incoming.append(self)
  48. def __repr__(self):
  49. if self.label != None:
  50. return f"({self.src}--{self.label}->{self.tgt})"
  51. else:
  52. return f"({self.src}->{self.tgt})"
  53. class MatcherState:
  54. def __init__(self):
  55. self.mapping_vtxs = {} # guest -> host
  56. self.mapping_edges = {} # guest -> host
  57. self.r_mapping_vtxs = {} # host -> guest
  58. self.r_mapping_edges = {} # host -> guest
  59. self.h_unmatched_vtxs = []
  60. self.g_unmatched_vtxs = []
  61. # boundary is the most recently added (to the mapping) pair of (guest -> host) vertices
  62. self.boundary = None
  63. @staticmethod
  64. def make_initial(host, guest):
  65. state = MatcherState()
  66. state.h_unmatched_vtxs = host.vtxs
  67. state.g_unmatched_vtxs = guest.vtxs
  68. return state
  69. # Grow the match set (creating a new copy)
  70. def grow_edge(self, host_edge, guest_edge):
  71. new_state = MatcherState()
  72. new_state.mapping_vtxs = self.mapping_vtxs
  73. new_state.mapping_edges = dict(self.mapping_edges)
  74. new_state.mapping_edges[guest_edge] = host_edge
  75. new_state.r_mapping_vtxs = self.r_mapping_vtxs
  76. new_state.r_mapping_edges = dict(self.r_mapping_edges)
  77. new_state.r_mapping_edges[host_edge] = guest_edge
  78. new_state.h_unmatched_vtxs = self.h_unmatched_vtxs
  79. new_state.g_unmatched_vtxs = self.g_unmatched_vtxs
  80. return new_state
  81. # Grow the match set (creating a new copy)
  82. def grow_vtx(self, host_vtx, guest_vtx):
  83. new_state = MatcherState()
  84. new_state.mapping_vtxs = dict(self.mapping_vtxs)
  85. new_state.mapping_vtxs[guest_vtx] = host_vtx
  86. new_state.mapping_edges = self.mapping_edges
  87. new_state.r_mapping_vtxs = dict(self.r_mapping_vtxs)
  88. new_state.r_mapping_vtxs[host_vtx] = guest_vtx
  89. new_state.r_mapping_edges = self.r_mapping_edges
  90. new_state.h_unmatched_vtxs = [h_vtx for h_vtx in self.h_unmatched_vtxs if h_vtx != host_vtx]
  91. new_state.g_unmatched_vtxs = [g_vtx for g_vtx in self.g_unmatched_vtxs if g_vtx != guest_vtx]
  92. new_state.boundary = (guest_vtx, host_vtx)
  93. return new_state
  94. def make_hashable(self):
  95. return frozenset(itertools.chain(
  96. ((gv,hv) for gv,hv in self.mapping_vtxs.items()),
  97. ((ge,he) for ge,he in self.mapping_edges.items()),
  98. ))
  99. def __repr__(self):
  100. # return self.make_hashable().__repr__()
  101. return "VTXS: "+self.mapping_vtxs.__repr__()+"\nEDGES: "+self.mapping_edges.__repr__()
  102. class MatcherVF2:
  103. # Guest is the pattern
  104. def __init__(self, host, guest, compare_fn):
  105. self.host = host
  106. self.guest = guest
  107. self.compare_fn = compare_fn
  108. # with Timer("find_connected_components - guest"):
  109. self.guest_vtx_to_component, self.guest_component_to_vtxs = find_connected_components(guest)
  110. # print("number of guest connected components:", len(self.guest_component_to_vtxs))
  111. def match(self):
  112. yield from self._match(
  113. state=MatcherState.make_initial(self.host, self.guest),
  114. already_visited=set())
  115. def _match(self, state, already_visited, indent=0):
  116. # input()
  117. def print_debug(*args):
  118. pass
  119. # print(" "*indent, *args) # uncomment to see a trace of the matching process
  120. print_debug("match")
  121. # Keep track of the states in the search space that we already visited
  122. hashable = state.make_hashable()
  123. if hashable in already_visited:
  124. print_debug(" SKIP - ALREADY VISITED")
  125. # print_debug(" ", hashable)
  126. return
  127. # print_debug(" ", [hash(a) for a in already_visited])
  128. # print_debug(" ADD STATE")
  129. # print_debug(" ", hash(hashable))
  130. already_visited.add(hashable)
  131. if len(state.mapping_vtxs) == len(self.guest.vtxs) and len(state.mapping_edges) == len(self.guest.edges):
  132. print_debug("GOT MATCH:")
  133. print_debug(" ", state.mapping_vtxs)
  134. print_debug(" ", state.mapping_edges)
  135. yield state
  136. return
  137. def read_edge(edge, direction):
  138. if direction == "outgoing":
  139. return edge.tgt
  140. elif direction == "incoming":
  141. return edge.src
  142. else:
  143. raise Exception("wtf!")
  144. def attempt_grow(direction, indent):
  145. for g_matched_vtx, h_matched_vtx in state.mapping_vtxs.items():
  146. print_debug('attempt_grow', direction)
  147. for g_candidate_edge in getattr(g_matched_vtx, direction):
  148. print_debug('g_candidate_edge:', g_candidate_edge)
  149. g_candidate_vtx = read_edge(g_candidate_edge, direction)
  150. # g_to_skip_vtxs.add(g_candidate_vtx)
  151. if g_candidate_edge in state.mapping_edges:
  152. print_debug(" skip, guest edge already matched")
  153. continue # skip already matched guest edge
  154. for h_candidate_edge in getattr(h_matched_vtx, direction):
  155. if g_candidate_edge.label != h_candidate_edge.label:
  156. print_debug(" labels differ")
  157. continue
  158. print_debug('h_candidate_edge:', h_candidate_edge)
  159. if h_candidate_edge in state.r_mapping_edges:
  160. print_debug(" skip, host edge already matched")
  161. continue # skip already matched host edge
  162. print_debug('grow edge', g_candidate_edge, ':', h_candidate_edge, id(g_candidate_edge), id(h_candidate_edge))
  163. new_state = state.grow_edge(h_candidate_edge, g_candidate_edge)
  164. h_candidate_vtx = read_edge(h_candidate_edge, direction)
  165. yield from attempt_match_vtxs(
  166. new_state,
  167. g_candidate_vtx,
  168. h_candidate_vtx,
  169. indent+1)
  170. print_debug('backtrack edge', g_candidate_edge, ':', h_candidate_edge, id(g_candidate_edge), id(h_candidate_edge))
  171. def attempt_match_vtxs(state, g_candidate_vtx, h_candidate_vtx, indent):
  172. print_debug('attempt_match_vtxs')
  173. if g_candidate_vtx in state.mapping_vtxs:
  174. if state.mapping_vtxs[g_candidate_vtx] != h_candidate_vtx:
  175. print_debug(" nope, guest already mapped (mismatch)")
  176. return # guest vtx is already mapped but doesn't match host vtx
  177. if h_candidate_vtx in state.r_mapping_vtxs:
  178. if state.r_mapping_vtxs[h_candidate_vtx] != g_candidate_vtx:
  179. print_debug(" nope, host already mapped (mismatch)")
  180. return # host vtx is already mapped but doesn't match guest vtx
  181. g_outdegree = len(g_candidate_vtx.outgoing)
  182. h_outdegree = len(h_candidate_vtx.outgoing)
  183. if g_outdegree > h_outdegree:
  184. print_debug(" nope, outdegree")
  185. return
  186. g_indegree = len(g_candidate_vtx.incoming)
  187. h_indegree = len(h_candidate_vtx.incoming)
  188. if g_indegree > h_indegree:
  189. print_debug(" nope, indegree")
  190. return
  191. if not self.compare_fn(g_candidate_vtx, h_candidate_vtx):
  192. print_debug(" nope, bad compare")
  193. return
  194. new_state = state.grow_vtx(
  195. h_candidate_vtx,
  196. g_candidate_vtx)
  197. print_debug('grow vtx', g_candidate_vtx, ':', h_candidate_vtx, id(g_candidate_vtx), id(h_candidate_vtx))
  198. yield from self._match(new_state, already_visited, indent+1)
  199. print_debug('backtrack vtx', g_candidate_vtx, ':', h_candidate_vtx, id(g_candidate_vtx), id(h_candidate_vtx))
  200. print_debug('preferred...')
  201. yield from attempt_grow('outgoing', indent+1)
  202. yield from attempt_grow('incoming', indent+1)
  203. print_debug('least preferred...')
  204. if state.boundary != None:
  205. g_boundary_vtx, _ = state.boundary
  206. guest_boundary_component = self.guest_vtx_to_component[g_boundary_vtx]
  207. # only try guest vertices that are in a different component (all vertices in the same component are already discovered via 'attempt_grow')
  208. guest_components_to_try = (c for i,c in enumerate(self.guest_component_to_vtxs) if i != guest_boundary_component)
  209. # for the host vertices however, we have to try them from all components, because different connected components of our pattern (=guest) could be mapped onto the same connected component in the host
  210. else:
  211. guest_components_to_try = self.guest_component_to_vtxs
  212. for g_candidate_vtxs in guest_components_to_try:
  213. for g_candidate_vtx in g_candidate_vtxs:
  214. if g_candidate_vtx in state.mapping_vtxs:
  215. print_debug("skip (already matched)", g_candidate_vtx)
  216. continue
  217. for h_candidate_vtx in state.h_unmatched_vtxs:
  218. yield from attempt_match_vtxs(state, g_candidate_vtx, h_candidate_vtx, indent+1)
  219. if indent == 0:
  220. print_debug('visited', len(already_visited), 'states total')
  221. # demo time...
  222. if __name__ == "__main__":
  223. host = Graph()
  224. host.vtxs = [Vertex(0), Vertex(1), Vertex(2), Vertex(3)]
  225. host.edges = [
  226. Edge(host.vtxs[0], host.vtxs[1]),
  227. Edge(host.vtxs[1], host.vtxs[2]),
  228. Edge(host.vtxs[2], host.vtxs[0]),
  229. Edge(host.vtxs[2], host.vtxs[3]),
  230. Edge(host.vtxs[3], host.vtxs[2]),
  231. ]
  232. guest = Graph()
  233. guest.vtxs = [
  234. Vertex('v != 3'), # cannot be matched with Vertex(3) - changing this to True, you get 2 morphisms instead of one
  235. Vertex('True')] # can be matched with any node
  236. guest.edges = [
  237. # Look for a simple loop:
  238. Edge(guest.vtxs[0], guest.vtxs[1]),
  239. # Edge(guest.vtxs[1], guest.vtxs[0]),
  240. ]
  241. m = MatcherVF2(host, guest, lambda g_vtx, h_vtx: eval(g_vtx.value, {}, {'v':h_vtx.value}))
  242. import time
  243. durations = 0
  244. iterations = 1
  245. print("Patience...")
  246. for n in range(iterations):
  247. time_start = time.perf_counter_ns()
  248. matches = [mm for mm in m.match()]
  249. time_end = time.perf_counter_ns()
  250. time_duration = time_end - time_start
  251. durations += time_duration
  252. print(f'{iterations} iterations, took {durations/1000000:.3f} ms, {durations/iterations/1000000:.3f} ms per iteration')
  253. print("found", len(matches), "matches")
  254. for mm in matches:
  255. print("match:")
  256. print(" ", mm.mapping_vtxs)
  257. print(" ", mm.mapping_edges)
  258. print("######################")
  259. host = Graph()
  260. host.vtxs = [
  261. Vertex('pony'), # 1
  262. Vertex('pony'), # 3
  263. Vertex('bear'),
  264. Vertex('bear'),
  265. ]
  266. host.edges = [
  267. # match:
  268. Edge(host.vtxs[0], host.vtxs[1]),
  269. Edge(host.vtxs[1], host.vtxs[0]),
  270. ]
  271. guest = Graph()
  272. guest.vtxs = [
  273. Vertex('pony'), # 0
  274. Vertex('pony'), # 1
  275. Vertex('bear')]
  276. guest.edges = [
  277. Edge(guest.vtxs[0], guest.vtxs[1]),
  278. Edge(guest.vtxs[1], guest.vtxs[0]),
  279. ]
  280. m = MatcherVF2(host, guest, lambda g_vtx, h_vtx: g_vtx.value == h_vtx.value)
  281. import time
  282. durations = 0
  283. iterations = 1
  284. print("Patience...")
  285. for n in range(iterations):
  286. time_start = time.perf_counter_ns()
  287. matches = [mm for mm in m.match()]
  288. time_end = time.perf_counter_ns()
  289. time_duration = time_end - time_start
  290. durations += time_duration
  291. print(f'{iterations} iterations, took {durations/1000000:.3f} ms, {durations/iterations/1000000:.3f} ms per iteration')
  292. print("found", len(matches), "matches")
  293. for mm in matches:
  294. print("match:")
  295. print(" ", mm.mapping_vtxs)
  296. print(" ", mm.mapping_edges)