vf2.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
# This module contains a VF2-inspired graph matching algorithm.
# It defines its own Graph type and is otherwise standalone: apart from util.timer (imported below), it has no dependencies on the rest of the muMLE framework.
# Author: Joeri Exelmans
  4. import itertools
  5. from util.timer import Timer, counted
  6. # like finding the 'strongly connected componenets', but edges are navigable in any direction
  7. def find_connected_components(graph):
  8. next_component = 0
  9. vtx_to_component = {}
  10. component_to_vtxs = []
  11. for vtx in graph.vtxs:
  12. if vtx in vtx_to_component:
  13. continue
  14. vtx_to_component[vtx] = next_component
  15. vtxs = []
  16. component_to_vtxs.append(vtxs)
  17. add_recursively(vtx, vtxs, vtx_to_component, next_component)
  18. next_component += 1
  19. return (vtx_to_component, component_to_vtxs)
  20. def add_recursively(vtx, vtxs: list, d: dict, component: int, already_visited: set = set()):
  21. if vtx in already_visited:
  22. return
  23. already_visited.add(vtx)
  24. vtxs.append(vtx)
  25. d[vtx] = component
  26. for edge in vtx.outgoing:
  27. add_recursively(edge.tgt, vtxs, d, component, already_visited)
  28. for edge in vtx.incoming:
  29. add_recursively(edge.src, vtxs, d, component, already_visited)
  30. class Graph:
  31. def __init__(self):
  32. self.vtxs = []
  33. self.edges = []
  34. class Vertex:
  35. def __init__(self, value):
  36. self.incoming = []
  37. self.outgoing = []
  38. self.value = value
  39. def __repr__(self):
  40. return f"V({self.value})"
  41. class Edge:
  42. def __init__(self, src: Vertex, tgt: Vertex, label=None):
  43. self.src = src
  44. self.tgt = tgt
  45. self.label = label
  46. # Add ourselves to src/tgt vertices
  47. self.src.outgoing.append(self)
  48. self.tgt.incoming.append(self)
  49. def __repr__(self):
  50. if self.label != None:
  51. return f"({self.src}--{self.label}->{self.tgt})"
  52. else:
  53. return f"({self.src}->{self.tgt})"
  54. class MatcherState:
  55. def __init__(self):
  56. self.mapping_vtxs = {} # guest -> host
  57. self.mapping_edges = {} # guest -> host
  58. self.r_mapping_vtxs = {} # host -> guest
  59. self.r_mapping_edges = {} # host -> guest
  60. self.h_unmatched_vtxs = []
  61. self.g_unmatched_vtxs = []
  62. # boundary is the most recently added (to the mapping) pair of (guest -> host) vertices
  63. self.boundary = None
  64. @staticmethod
  65. def make_initial(host, guest, pivot):
  66. state = MatcherState()
  67. state.h_unmatched_vtxs = [vtx for vtx in host.vtxs if vtx not in pivot.values()]
  68. state.g_unmatched_vtxs = [vtx for vtx in guest.vtxs if vtx not in pivot.keys()]
  69. # if guest_to_host_candidates != None:
  70. # state.g_unmatched_vtxs.sort(
  71. # # performance thingy:
  72. # # try to match guest vtxs with few candidates first (fail early!):
  73. # key=lambda guest_vtx: guest_to_host_candidates.get(guest_vtx, 0))
  74. state.mapping_vtxs = pivot
  75. state.r_mapping_vtxs = { v: k for k,v in state.mapping_vtxs.items() }
  76. return state
  77. # Grow the match set (creating a new copy)
  78. def grow_edge(self, host_edge, guest_edge):
  79. new_state = MatcherState()
  80. new_state.mapping_vtxs = self.mapping_vtxs
  81. new_state.mapping_edges = dict(self.mapping_edges)
  82. new_state.mapping_edges[guest_edge] = host_edge
  83. new_state.r_mapping_vtxs = self.r_mapping_vtxs
  84. new_state.r_mapping_edges = dict(self.r_mapping_edges)
  85. new_state.r_mapping_edges[host_edge] = guest_edge
  86. new_state.h_unmatched_vtxs = self.h_unmatched_vtxs
  87. new_state.g_unmatched_vtxs = self.g_unmatched_vtxs
  88. return new_state
  89. # Grow the match set (creating a new copy)
  90. def grow_vtx(self, host_vtx, guest_vtx):
  91. new_state = MatcherState()
  92. new_state.mapping_vtxs = dict(self.mapping_vtxs)
  93. new_state.mapping_vtxs[guest_vtx] = host_vtx
  94. new_state.mapping_edges = self.mapping_edges
  95. new_state.r_mapping_vtxs = dict(self.r_mapping_vtxs)
  96. new_state.r_mapping_vtxs[host_vtx] = guest_vtx
  97. new_state.r_mapping_edges = self.r_mapping_edges
  98. new_state.h_unmatched_vtxs = [h_vtx for h_vtx in self.h_unmatched_vtxs if h_vtx != host_vtx]
  99. new_state.g_unmatched_vtxs = [g_vtx for g_vtx in self.g_unmatched_vtxs if g_vtx != guest_vtx]
  100. new_state.boundary = (guest_vtx, host_vtx)
  101. return new_state
  102. def make_hashable(self):
  103. return frozenset(itertools.chain(
  104. ((gv,hv) for gv,hv in self.mapping_vtxs.items()),
  105. ((ge,he) for ge,he in self.mapping_edges.items()),
  106. ))
  107. def __repr__(self):
  108. # return self.make_hashable().__repr__()
  109. return "VTXS: "+self.mapping_vtxs.__repr__()+"\nEDGES: "+self.mapping_edges.__repr__()
  110. class MatcherVF2:
  111. # Guest is the pattern
  112. def __init__(self, host, guest, compare_fn, guest_to_host_candidates=None):
  113. self.host = host
  114. self.guest = guest
  115. self.compare_fn = compare_fn
  116. # map guest vertex to number of candidate vertices in host graph:
  117. if guest_to_host_candidates != None:
  118. self.guest_to_host_candidates = guest_to_host_candidates
  119. else:
  120. # atttempt to match every guest vertex with every host vertex (slow!)
  121. self.guest_to_host_candidates = { g_vtx : len(host.vtxs) for g_vtx in guest.vtxs }
  122. # with Timer("find_connected_components - guest"):
  123. self.guest_vtx_to_component, self.guest_component_to_vtxs = find_connected_components(guest)
  124. for component in self.guest_component_to_vtxs:
  125. pass
  126. # sort vertices in component such that the vertices of the rarest type (with the fewest element) occurs first
  127. component.sort(key=lambda guest_vtx: guest_to_host_candidates[guest_vtx])
  128. if len(self.guest_component_to_vtxs) > 1:
  129. print("warning: pattern has multiple components:", len(self.guest_component_to_vtxs))
  130. def match(self, pivot={}):
  131. yield from self._match(
  132. state=MatcherState.make_initial(self.host, self.guest, pivot),
  133. already_visited=set())
  134. # @counted
  135. def _match(self, state, already_visited, indent=0):
  136. # input()
  137. num_matches = 0
  138. def print_debug(*args):
  139. pass
  140. # print(" "*indent, *args) # uncomment to see a trace of the matching process
  141. print_debug("match")
  142. # Keep track of the states in the search space that we already visited
  143. hashable = state.make_hashable()
  144. if hashable in already_visited:
  145. print_debug(" SKIP - ALREADY VISITED")
  146. # print_debug(" ", hashable)
  147. return 0
  148. # print_debug(" ", [hash(a) for a in already_visited])
  149. # print_debug(" ADD STATE")
  150. # print_debug(" ", hash(hashable))
  151. already_visited.add(hashable)
  152. if len(state.mapping_vtxs) == len(self.guest.vtxs) and len(state.mapping_edges) == len(self.guest.edges):
  153. print_debug("GOT MATCH:")
  154. print_debug(" ", state.mapping_vtxs)
  155. print_debug(" ", state.mapping_edges)
  156. yield state
  157. return 1
  158. def read_edge(edge, direction):
  159. if direction == "outgoing":
  160. return edge.tgt
  161. elif direction == "incoming":
  162. return edge.src
  163. else:
  164. raise Exception("wtf!")
  165. def attempt_grow(direction, indent):
  166. num_matches = 0
  167. for g_matched_vtx, h_matched_vtx in state.mapping_vtxs.items():
  168. print_debug('attempt_grow', direction)
  169. for g_candidate_edge in getattr(g_matched_vtx, direction):
  170. print_debug('g_candidate_edge:', g_candidate_edge)
  171. g_candidate_vtx = read_edge(g_candidate_edge, direction)
  172. # g_to_skip_vtxs.add(g_candidate_vtx)
  173. if g_candidate_edge in state.mapping_edges:
  174. print_debug(" skip, guest edge already matched")
  175. continue # skip already matched guest edge
  176. for h_candidate_edge in getattr(h_matched_vtx, direction):
  177. if g_candidate_edge.label != h_candidate_edge.label:
  178. print_debug(" labels differ")
  179. continue
  180. print_debug('h_candidate_edge:', h_candidate_edge)
  181. if h_candidate_edge in state.r_mapping_edges:
  182. print_debug(" skip, host edge already matched")
  183. continue # skip already matched host edge
  184. print_debug('grow edge', g_candidate_edge, ':', h_candidate_edge, id(g_candidate_edge), id(h_candidate_edge))
  185. new_state = state.grow_edge(h_candidate_edge, g_candidate_edge)
  186. h_candidate_vtx = read_edge(h_candidate_edge, direction)
  187. num_matches += yield from attempt_match_vtxs(
  188. new_state,
  189. g_candidate_vtx,
  190. h_candidate_vtx,
  191. indent+1)
  192. print_debug('backtrack edge', g_candidate_edge, ':', h_candidate_edge, id(g_candidate_edge), id(h_candidate_edge))
  193. return num_matches
  194. def attempt_match_vtxs(state, g_candidate_vtx, h_candidate_vtx, indent):
  195. print_debug('attempt_match_vtxs')
  196. if g_candidate_vtx in state.mapping_vtxs:
  197. if state.mapping_vtxs[g_candidate_vtx] != h_candidate_vtx:
  198. print_debug(" nope, guest already mapped (mismatch)")
  199. return 0 # guest vtx is already mapped but doesn't match host vtx
  200. if h_candidate_vtx in state.r_mapping_vtxs:
  201. if state.r_mapping_vtxs[h_candidate_vtx] != g_candidate_vtx:
  202. print_debug(" nope, host already mapped (mismatch)")
  203. return 0 # host vtx is already mapped but doesn't match guest vtx
  204. g_outdegree = len(g_candidate_vtx.outgoing)
  205. h_outdegree = len(h_candidate_vtx.outgoing)
  206. if g_outdegree > h_outdegree:
  207. print_debug(" nope, outdegree")
  208. return 0
  209. g_indegree = len(g_candidate_vtx.incoming)
  210. h_indegree = len(h_candidate_vtx.incoming)
  211. if g_indegree > h_indegree:
  212. print_debug(" nope, indegree")
  213. return 0
  214. if not self.compare_fn(g_candidate_vtx, h_candidate_vtx):
  215. print_debug(" nope, bad compare")
  216. return 0
  217. new_state = state.grow_vtx(
  218. h_candidate_vtx,
  219. g_candidate_vtx)
  220. print_debug('grow vtx', g_candidate_vtx, ':', h_candidate_vtx, id(g_candidate_vtx), id(h_candidate_vtx))
  221. num_matches = yield from self._match(new_state, already_visited, indent+1)
  222. print_debug('backtrack vtx', g_candidate_vtx, ':', h_candidate_vtx, id(g_candidate_vtx), id(h_candidate_vtx))
  223. return num_matches
  224. print_debug('preferred...')
  225. num_matches += yield from attempt_grow('outgoing', indent+1)
  226. num_matches += yield from attempt_grow('incoming', indent+1)
  227. if num_matches == 0:
  228. print_debug('least preferred...')
  229. if state.boundary != None:
  230. g_boundary_vtx, _ = state.boundary
  231. guest_boundary_component = self.guest_vtx_to_component[g_boundary_vtx]
  232. # only try guest vertices that are in a different component (all vertices in the same component are already discovered via 'attempt_grow')
  233. guest_components_to_try = (c for i,c in enumerate(self.guest_component_to_vtxs) if i != guest_boundary_component)
  234. # for the host vertices however, we have to try them from all components, because different connected components of our pattern (=guest) could be mapped onto the same connected component in the host
  235. else:
  236. guest_components_to_try = self.guest_component_to_vtxs
  237. for g_component in guest_components_to_try:
  238. # we only need to pick ONE vertex from the component
  239. # in the future, this can be optimized further by picking the vertex of the type with the fewest instances
  240. g_candidate_vtx = g_component[0]
  241. g_vtx_matches = 0
  242. g_vtx_max = self.guest_to_host_candidates[g_candidate_vtx]
  243. # print(' guest vtx has', g_vtx_max, ' host candidates')
  244. if g_candidate_vtx in state.mapping_vtxs:
  245. print_debug("skip (already matched)", g_candidate_vtx)
  246. continue
  247. for h_candidate_vtx in state.h_unmatched_vtxs:
  248. N = yield from attempt_match_vtxs(state, g_candidate_vtx, h_candidate_vtx, indent+1)
  249. g_vtx_matches += N > 0
  250. num_matches += N
  251. if g_vtx_matches == g_vtx_max:
  252. print("EARLY STOP")
  253. break # found all matches
  254. return num_matches
  255. # demo time...
  256. if __name__ == "__main__":
  257. host = Graph()
  258. host.vtxs = [Vertex(0), Vertex(1), Vertex(2), Vertex(3)]
  259. host.edges = [
  260. Edge(host.vtxs[0], host.vtxs[1]),
  261. Edge(host.vtxs[1], host.vtxs[2]),
  262. Edge(host.vtxs[2], host.vtxs[0]),
  263. Edge(host.vtxs[2], host.vtxs[3]),
  264. Edge(host.vtxs[3], host.vtxs[2]),
  265. ]
  266. guest = Graph()
  267. guest.vtxs = [
  268. Vertex('v != 3'), # cannot be matched with Vertex(3) - changing this to True, you get 2 morphisms instead of one
  269. Vertex('True')] # can be matched with any node
  270. guest.edges = [
  271. # Look for a simple loop:
  272. Edge(guest.vtxs[0], guest.vtxs[1]),
  273. # Edge(guest.vtxs[1], guest.vtxs[0]),
  274. ]
  275. m = MatcherVF2(host, guest, lambda g_vtx, h_vtx: eval(g_vtx.value, {}, {'v':h_vtx.value}))
  276. import time
  277. durations = 0
  278. iterations = 1
  279. print("Patience...")
  280. for n in range(iterations):
  281. time_start = time.perf_counter_ns()
  282. matches = [mm for mm in m.match()]
  283. time_end = time.perf_counter_ns()
  284. time_duration = time_end - time_start
  285. durations += time_duration
  286. print(f'{iterations} iterations, took {durations/1000000:.3f} ms, {durations/iterations/1000000:.3f} ms per iteration')
  287. print("found", len(matches), "matches")
  288. for mm in matches:
  289. print("match:")
  290. print(" ", mm.mapping_vtxs)
  291. print(" ", mm.mapping_edges)
  292. print("######################")
  293. host = Graph()
  294. host.vtxs = [
  295. Vertex('pony'), # 1
  296. Vertex('pony'), # 3
  297. Vertex('bear'),
  298. Vertex('bear'),
  299. ]
  300. host.edges = [
  301. # match:
  302. Edge(host.vtxs[0], host.vtxs[1]),
  303. Edge(host.vtxs[1], host.vtxs[0]),
  304. ]
  305. guest = Graph()
  306. guest.vtxs = [
  307. Vertex('pony'), # 0
  308. Vertex('pony'), # 1
  309. Vertex('bear')]
  310. guest.edges = [
  311. Edge(guest.vtxs[0], guest.vtxs[1]),
  312. Edge(guest.vtxs[1], guest.vtxs[0]),
  313. ]
  314. m = MatcherVF2(host, guest, lambda g_vtx, h_vtx: g_vtx.value == h_vtx.value)
  315. import time
  316. durations = 0
  317. iterations = 1
  318. print("Patience...")
  319. for n in range(iterations):
  320. time_start = time.perf_counter_ns()
  321. matches = [mm for mm in m.match()]
  322. time_end = time.perf_counter_ns()
  323. time_duration = time_end - time_start
  324. durations += time_duration
  325. print(f'{iterations} iterations, took {durations/1000000:.3f} ms, {durations/iterations/1000000:.3f} ms per iteration')
  326. print("found", len(matches), "matches")
  327. for mm in matches:
  328. print("match:")
  329. print(" ", mm.mapping_vtxs)
  330. print(" ", mm.mapping_edges)