Explorar o código

Increased performance a bit by changing MvS implementation slightly

Yentl Van Tendeloo hai 8 anos
pai
achega
b9a7e2cb5b
Modificouse 1 ficheiro con 127 adicións e 116 borrados
  1. 127 116
      state/modelverse_state/main.py

+ 127 - 116
state/modelverse_state/main.py

@@ -26,8 +26,8 @@ class ModelverseState(object):
     def __init__(self, bootfile = None):
         self.free_id = 0
         self.edges = {}
-        self.outgoing = defaultdict(set)
-        self.incoming = defaultdict(set)
+        self.outgoing = {}
+        self.incoming = {}
         self.values = {}
         self.nodes = set()
         self.GC = True
@@ -61,8 +61,8 @@ class ModelverseState(object):
                 self.root, self.free_id, self.nodes, self.edges, self.values = pickle.load(open(picklefile, 'rb'))
                 for name in self.edges:
                     source, destination = self.edges[name]
-                    self.outgoing[source].add(name)
-                    self.incoming[destination].add(name)
+                    self.outgoing.setdefault(source, set()).add(name)
+                    self.incoming.setdefault(destination, set()).add(name)
                 return self.root
             else:
                 raise Exception("Invalid pickle")
@@ -137,8 +137,8 @@ class ModelverseState(object):
         elif target not in self.edges and target not in self.nodes:
             return (None, status.FAIL_CE_TARGET)
         else:
-            self.outgoing[source].add(self.free_id)
-            self.incoming[target].add(self.free_id)
+            self.outgoing.setdefault(source, set()).add(self.free_id)
+            self.incoming.setdefault(target, set()).add(self.free_id)
             self.edges[self.free_id] = (source, target)
             self.free_id += 1
             return (self.free_id - 1, status.SUCCESS)
@@ -177,43 +177,36 @@ class ModelverseState(object):
         return (None, status.SUCCESS)
 
     def read_value(self, node):
-        if node not in self.nodes:
+        if node in self.values:
+            return (self.values[node], status.SUCCESS)
+        elif node not in self.nodes:
             return (None, status.FAIL_RV_UNKNOWN)
-        v = self.values.get(node, None)
-        if v is None:
-            return (None, status.FAIL_RV_NO_VALUE)
         else:
-            return (v, status.SUCCESS)
+            return (None, status.FAIL_RV_NO_VALUE)
 
     def read_outgoing(self, elem):
         if elem in self.edges or elem in self.nodes:
-            return (list(self.outgoing[elem]), status.SUCCESS)
+            if elem in self.outgoing:
+                return (list(self.outgoing[elem]), status.SUCCESS)
+            else:
+                return ([], status.SUCCESS)
         else:
             return (None, status.FAIL_RO_UNKNOWN)
 
     def read_incoming(self, elem):
         if elem in self.edges or elem in self.nodes:
-            return (list(self.incoming[elem]), status.SUCCESS)
+            if elem in self.incoming:
+                return (list(self.incoming[elem]), status.SUCCESS)
+            else:
+                return ([], status.SUCCESS)
         else:
             return (None, status.FAIL_RI_UNKNOWN)
 
     def read_edge(self, edge):
-        v = self.edges.get(edge, None)
-        if v is None:
-            return ([None, None], status.FAIL_RE_UNKNOWN)
+        if edge in self.edges:
+            return (list(self.edges[edge]), status.SUCCESS)
         else:
-            s, t = v
-            return ([s, t], status.SUCCESS)
-
-    def read_dict_old(self, node, value):
-        e, s = self.read_dict_edge(node, value)
-        if s != status.SUCCESS:
-            return (None, {status.FAIL_RDICTE_UNKNOWN: status.FAIL_RDICT_UNKNOWN,
-                           status.FAIL_RDICTE_UNCERTAIN: status.FAIL_RDICT_UNCERTAIN,
-                           status.FAIL_RDICTE_OOB: status.FAIL_RDICT_OOB,
-                           status.FAIL_RDICTE_NOT_FOUND: status.FAIL_RDICT_NOT_FOUND,
-                           status.FAIL_RDICTE_AMBIGUOUS: status.FAIL_RDICT_AMBIGUOUS}[s])
-        return (self.edges[e][1], status.SUCCESS)
+            return ([None, None], status.FAIL_RE_UNKNOWN)
 
     def read_dict(self, node, value):
         try:
@@ -229,35 +222,37 @@ class ModelverseState(object):
             # Didn't exist
             pass
 
+        # Get all outgoing links
+        if node in self.outgoing:
+            for e1 in self.outgoing[node]:
+                # For each link, we read the links that might link to a data value
+                if e1 in self.outgoing:
+                    for e2 in self.outgoing[e1]:
+                        # Now read out the target of the link
+                        target = self.edges[e2][1]
+                        # And access its value
+                        if target in self.values and self.values[target] == value:
+                            # Found a match
+                            # Now get the target of the original link
+                            self.cache.setdefault(node, {})[value] = e1
+                            return (self.edges[e1][1], status.SUCCESS)
+
         if node not in self.nodes and node not in self.edges:
             return (None, status.FAIL_RDICT_UNKNOWN)
-        if not self.is_valid_datavalue(value):
+        elif not self.is_valid_datavalue(value):
             return (None, status.FAIL_RDICT_OOB)
-            
-        # Get all outgoing links
-        for e1 in self.outgoing.get(node, set()):
-            data_links = self.outgoing.get(e1, set())
-            # For each link, we read the links that might link to a data value
-            for e2 in data_links:
-                # Now read out the target of the link
-                target = self.edges[e2][1]
-                # And access its value
-                v = self.values.get(target, None)
-                if v == value:
-                    # Found a match
-                    # Now get the target of the original link
-                    self.cache.setdefault(node, {})[value] = e1
-                    return (self.edges[e1][1], status.SUCCESS)
-        return (None, status.FAIL_RDICT_NOT_FOUND)
+        else:
+            return (None, status.FAIL_RDICT_NOT_FOUND)
 
     def read_dict_keys(self, node):
         if node not in self.nodes and node not in self.edges:
             return (None, status.FAIL_RDICTKEYS_UNKNOWN)
         result = []
-        for e1 in self.outgoing.get(node, set()):
-            data_links = self.outgoing.get(e1, set())
-            for e2 in data_links:
-                result.append(self.edges[e2][1])
+        if node in self.outgoing:
+            for e1 in self.outgoing[node]:
+                if e1 in self.outgoing:
+                    for e2 in self.outgoing[e1]:
+                        result.append(self.edges[e2][1])
         return (result, status.SUCCESS)
 
     def read_dict_edge(self, node, value):
@@ -274,31 +269,30 @@ class ModelverseState(object):
             # Didn't exist
             pass
 
-        if node not in self.nodes and node not in self.edges:
-            return (None, status.FAIL_RDICTE_UNKNOWN)
-        if not self.is_valid_datavalue(value):
-            return (None, status.FAIL_RDICTE_OOB)
-            
         # Get all outgoing links
         found = None
-        for e1 in self.outgoing.get(node, set()):
-            data_links = self.outgoing.get(e1, set())
-            # For each link, we read the links that might link to a data value
-            for e2 in data_links:
-                # Now read out the target of the link
-                target = self.edges[e2][1]
-                # And access its value
-                v = self.values.get(target, None)
-                if v == value:
-                    # Found a match
-                    # Now get the target of the original link
-                    if found is not None:
-                        print("Duplicate key on value: %s : %s (%s <-> %s)!" % (v, type(v), found, e1))
-                        return (None, status.FAIL_RDICTE_AMBIGUOUS)
-                    found = e1
-                    self.cache.setdefault(node, {})[value] = e1
+        if node in self.outgoing:
+            for e1 in self.outgoing[node]:
+                # For each link, we read the links that might link to a data value
+                if e1 in self.outgoing:
+                    for e2 in self.outgoing[e1]:
+                        # Now read out the target of the link
+                        target = self.edges[e2][1]
+                        # And access its value
+                        if target in self.values and self.values[target] == value:
+                            # Found a match
+                            # Now get the target of the original link
+                            if found is not None:
+                                print("Duplicate key on value: %s : %s (%s <-> %s)!" % (self.values[target], type(v), found, e1))
+                                return (None, status.FAIL_RDICTE_AMBIGUOUS)
+                            found = e1
+                            self.cache.setdefault(node, {})[value] = e1
         if found is not None:
             return (found, status.SUCCESS)
+        elif node not in self.nodes and node not in self.edges:
+            return (None, status.FAIL_RDICTE_UNKNOWN)
+        elif not self.is_valid_datavalue(value):
+            return (None, status.FAIL_RDICTE_OOB)
         else:
             return (None, status.FAIL_RDICTE_NOT_FOUND)
 
@@ -318,20 +312,19 @@ class ModelverseState(object):
 
         # Get all outgoing links
         found = None
-        for e1 in self.outgoing.get(node, set()):
-            data_links = self.outgoing.get(e1, set())
-            # For each link, we read the links that might link to a data value
-            for e2 in data_links:
-                # Now read out the target of the link
-                target = self.edges[e2][1]
-                # And access its value
-                if target == value_node:
-                    # Found a match
-                    # Now get the target of the original link
-                    if found is not None:
-                        print("Duplicate key on node: %s (%s <-> %s)!" % (value_node, found, e1))
-                        return (None, status.FAIL_RDICTNE_AMBIGUOUS)
-                    found = e1
+        if node in self.outgoing:
+            for e1 in self.outgoing[node]:
+                # For each link, we read the links that might link to a data value
+                if e1 in self.outgoing:
+                    for e2 in self.outgoing[e1]:
+                        # And access its value
+                        if self.edges[e2][1] == value_node:
+                            # Found a match
+                            # Now get the target of the original link
+                            if found is not None:
+                                print("Duplicate key on node: %s (%s <-> %s)!" % (value_node, found, e1))
+                                return (None, status.FAIL_RDICTNE_AMBIGUOUS)
+                            found = e1
         if found is not None:
             return (found, status.SUCCESS)
         else:
@@ -344,20 +337,20 @@ class ModelverseState(object):
             return (None, status.FAIL_RRDICT_OOB)
         # Get all outgoing links
         matches = []
-        for e1 in self.incoming.get(node, set()):
-            data_links = self.outgoing.get(e1, set())
-            # For each link, we read the links that might link to a data value
-            for e2 in data_links:
-                # Now read out the target of the link
-                target = self.edges[e2][1]
-                # And access its value
-                v = self.values.get(target, None)
-                if v == value:
-                    # Found a match
-                    if len(data_links) > 1:
-                        return (None, status.FAIL_RRDICT_UNCERTAIN)
-                    else:
-                        matches.append(e1)
+        if node in self.incoming:
+            for e1 in self.incoming[node]:
+                # For each link, we read the links that might link to a data value
+                if e1 in self.outgoing:
+                    for e2 in self.outgoing[e1]:
+                        # Now read out the target of the link
+                        target = self.edges[e2][1]
+                        # And access its value
+                        if target in self.values and self.values[target] == value:
+                            # Found a match
+                            if len(self.outgoing[e1]) > 1:
+                                return (None, status.FAIL_RRDICT_UNCERTAIN)
+                            else:
+                                matches.append(e1)
         if len(matches) == 0:
             return (None, status.FAIL_RRDICT_NOT_FOUND)
         else:
@@ -379,15 +372,23 @@ class ModelverseState(object):
             del self.values[node]
 
         s = set()
-        for e in self.outgoing[node]:
-            s.add(e)
-        for e in self.incoming[node]:
-            s.add(e)
+        if node in self.outgoing:
+            for e in self.outgoing[node]:
+                s.add(e)
+            del self.outgoing[node]
+        if node in self.incoming:
+            for e in self.incoming[node]:
+                s.add(e)
+            del self.incoming[node]
+
         for e in s:
             self.delete_edge(e)
 
-        del self.outgoing[node]
-        del self.incoming[node]
+        if node in self.outgoing:
+            del self.outgoing[node]
+        if node in self.incoming:
+            del self.incoming[node]
+
         return (None, status.SUCCESS)
 
     def delete_edge(self, edge):
@@ -395,22 +396,30 @@ class ModelverseState(object):
             return (None, status.FAIL_DE_UNKNOWN)
 
         s, t = self.edges[edge]
-        self.incoming[t].remove(edge)
-        self.outgoing[s].remove(edge)
+        if t in self.incoming:
+            self.incoming[t].remove(edge)
+        if s in self.outgoing:
+            self.outgoing[s].remove(edge)
 
         del self.edges[edge]
 
         s = set()
-        for e in self.outgoing[edge]:
-            s.add(e)
-        for e in self.incoming[edge]:
-            s.add(e)
+        if edge in self.outgoing:
+            for e in self.outgoing[edge]:
+                s.add(e)
+        if edge in self.incoming:
+            for e in self.incoming[edge]:
+                s.add(e)
+
         for e in s:
             self.delete_edge(e)
-        del self.outgoing[edge]
-        del self.incoming[edge]
 
-        if (self.GC) and (not self.incoming[t]) and (t not in self.edges):
+        if edge in self.outgoing:
+            del self.outgoing[edge]
+        if edge in self.incoming:
+            del self.incoming[edge]
+
+        if (self.GC) and (t in self.incoming and not self.incoming[t]) and (t not in self.edges):
             # Remove this node as well
             # Edges aren't deleted like this, as they might have a reachable target and source!
             # If they haven't, they will be removed because the source was removed.
@@ -421,7 +430,7 @@ class ModelverseState(object):
     def garbage_collect(self):  
         while self.to_delete:
             t = self.to_delete.pop()
-            if not self.incoming[t]:
+            if t in self.incoming and not self.incoming[t]:
                 self.delete_node(t)
 
     def purge(self):
@@ -437,8 +446,10 @@ class ModelverseState(object):
                 values.remove(elem)
                 if elem in self.edges:
                     visit_list.extend(self.edges[elem])
-                visit_list.extend(self.outgoing[elem])
-                visit_list.extend(self.incoming[elem])
+                if elem in self.outgoing:
+                    visit_list.extend(self.outgoing[elem])
+                if elem in self.incoming:
+                    visit_list.extend(self.incoming[elem])
 
         # All remaining elements are to be purged
         if len(values) > 0: