import xml.etree.ElementTree as ET
from urllib.parse import unquote
import base64
import zlib
import re

# Properties to ignore when rendering a node's property string.
IGNORE = ['id', 'label', 'placeholders', 'signal', 'class_name']


class Node:
    """A block from the diagram: its properties, ports, children and wiring.

    Args:
        id: The draw.io element id of the block.
        class_name: The block's ``class_name`` attribute.
        properties: The attribute mapping of the block's XML element.
    """

    def __init__(self, id, class_name, properties):
        self.id = id
        self.class_name = class_name
        self.properties = properties
        # (source_block, source_port) -> list of (target_block, target_port)
        self._connections = {}
        self._inputs = set()
        self._outputs = set()
        self.children = []

    def __contains__(self, item):
        """Return True if *item* names one of this node's input or output ports."""
        return item in self._inputs or item in self._outputs

    def __getitem__(self, item):
        """Return the property *item* (raises KeyError if absent)."""
        return self.properties[item]

    def add_input(self, name):
        """Register an input port; adding the same name again is a no-op."""
        self._inputs.add(name)

    def add_output(self, name):
        """Register an output port; adding the same name again is a no-op."""
        self._outputs.add(name)

    def get_inputs(self):
        """Return the input port names as a list (order not guaranteed)."""
        return list(self._inputs)

    def get_outputs(self):
        """Return the output port names as a list (order not guaranteed)."""
        return list(self._outputs)

    def get_connections(self):
        """Return the live mapping of source endpoints to lists of targets."""
        return self._connections

    def add_connection(self, source, target):
        """Record an edge from *source* to *target*; one source may fan out."""
        self._connections.setdefault(source, []).append(target)

    def get_properties_string(self, ignore=None):
        """Render the properties as ``", key=(value), key2=(value2)"``.

        Every entry, including the first, is prefixed with ``", "``. Empty
        values are rendered as ``None``. Keys listed in the module-level
        ``IGNORE`` or in *ignore* (any iterable of names) are skipped.
        """
        skipped = set(IGNORE).union(ignore or ())
        return "".join(
            f", {k}=({v if len(v) > 0 else 'None'})"
            for k, v in self.properties.items()
            if k not in skipped
        )

    def is_empty(self):
        """Return True if the node has no child blocks (ports do not count)."""
        return not self.children


class Parser:
    """Parse a draw.io file into :class:`Node` objects.

    Args:
        filename: Path of the draw.io file.
        input_port_class_name: ``class_name`` that marks input-port objects.
        output_port_class_name: ``class_name`` that marks output-port objects.
        ignore_empty_nodes: If True, top-level blocks without child blocks
            are not yielded by :meth:`parse`.
    """

    # Top-level blocks: objects with a class_name whose geometry holds an mxRectangle.
    _CLASS_OBJECT_PATH = ".//object/mxCell/mxGeometry/mxRectangle/../../..[@class_name]"
    # Objects carrying a ``role`` attribute (e.g. ``role="import"``).
    _SPECIAL_OBJECT_PATH = ".//object/mxCell/mxGeometry/../..[@role]"

    def __init__(self, filename, input_port_class_name="InputPortBlock",
                 output_port_class_name="OutputPortBlock", ignore_empty_nodes=False):
        self.filename = filename
        self.input_class = input_port_class_name
        self.output_class = output_port_class_name
        self.ignore_empty_nodes = ignore_empty_nodes
        # class_name -> {signal -> [output port names]}
        self.__signals = {}
        # Python import statements collected from role="import" objects.
        self.__imports = []
        # class names seen so far, used to reject duplicates.
        self.__class_names = set()

    @staticmethod
    def decode_and_deflate(data):
        """Decode one compressed draw.io page into an XML element.

        Draw.io compresses each page as follows:
        first, all data is url-encoded; next, it is deflated (raw deflate
        stream, no zlib header); finally, it is base64-encoded.
        To obtain the page data, we do the reverse.

        Args:
            data: The base64 text content of a ``<diagram>`` element.

        Returns:
            The root ``xml.etree.ElementTree.Element`` of the decoded page.
        """
        decoded_data = base64.b64decode(data)
        # Negative wbits selects a raw deflate stream without header/checksum.
        inflated = zlib.decompress(decoded_data, -15).decode('utf-8')
        url_decoded_data = unquote(inflated)
        return ET.fromstring(url_decoded_data)

    def parse(self):
        """Parse the file and yield a Node for each top-level block.

        A file containing no ``mxGraphModel`` element is treated as
        compressed: each ``<diagram>`` page is decoded separately, and pages
        without a payload are skipped. Otherwise the whole tree is searched
        directly. Import declarations (objects with ``role="import"``) are
        collected in both cases. They are available through
        :meth:`get_imports` once this generator has been exhausted.

        Yields:
            Node objects. Blocks dropped by ``ignore_empty_nodes`` are omitted.

        Raises:
            ParseException: On semantic errors in the model.
        """
        root = ET.parse(self.filename).getroot()
        compressed = root.find(".//mxGraphModel") is None
        if compressed:
            # Decoding happens page-wise.
            for page in root.findall(".//diagram"):
                if not page.text or not page.text.strip():
                    continue  # empty page: nothing to decode
                yield from self._parse_root(self.decode_and_deflate(page.text))
        else:
            yield from self._parse_root(root)

    def _parse_root(self, root):
        """Yield Nodes for the blocks under *root*, then record its imports."""
        for obj in root.findall(self._CLASS_OBJECT_PATH):
            node = self.create_node(root, obj.attrib)
            if node is not None:
                yield node
        for obj in root.findall(self._SPECIAL_OBJECT_PATH):
            if obj.attrib["role"] == "import":
                module = obj.attrib["module"]
                if "objects" in obj.attrib:
                    self.__imports.append(f"from {module} import {obj.attrib['objects']}")
                else:
                    self.__imports.append(f"import {module}")

    def get_imports(self):
        """Return the import statements collected so far (live list)."""
        return self.__imports

    def get_signals(self):
        """Return the captured output-port signals.

        The mapping has the shape ``{class_name: {signal: [port names]}}``.
        """
        return self.__signals

    def create_node(self, root, attr):
        """Build a Node for the top-level block described by *attr*.

        Args:
            root: Element to search (a decoded page or the file's root).
            attr: Attribute mapping of the block's ``object`` element.

        Returns:
            The populated Node, or None when ``ignore_empty_nodes`` is set and
            the block has no child blocks.

        Raises:
            ParseException: On a duplicate class name, a class name containing
                whitespace, a block without a container cell, or an edge with
                a missing or unknown endpoint.
        """
        class_name = attr["class_name"]
        # Detect duplicate class names.
        if class_name in self.__class_names:
            raise ParseException(f"Class '{class_name}' already defined.")
        # Detect spaces in class names.
        if re.search(r"\s", class_name) is not None:
            raise ParseException(f"Invalid class '{class_name}': Class names may not contain spaces.")
        node = Node(attr["id"], class_name, attr)
        self.__class_names.add(class_name)

        # NOTE(review): the block's contents are assumed to live in the second
        # element parented to the block (index 1) -- confirm against the template.
        parented = root.findall(".//*[@parent='%s']" % node.id)
        if len(parented) < 2:
            raise ParseException(f"Class '{class_name}' has no container for its contents.")
        container_id = parented[1].attrib["id"]

        lookup = self._add_components(node, root, container_id)
        if self.ignore_empty_nodes and node.is_empty():
            return None

        for edge in root.findall(".//*[@parent='%s'][@edge='1']" % container_id):
            source = self._resolve_endpoint(root, edge, "source", self.input_class, lookup)
            target = self._resolve_endpoint(root, edge, "target", self.output_class, lookup)
            # TODO: also allow attributes on edges?
            node.add_connection(source, target)
        return node

    def _add_components(self, node, root, container_id):
        """Attach the ports and child blocks inside the container to *node*.

        Returns:
            dict mapping child-block ids to their Node objects.
        """
        lookup = {}
        for com in root.findall(".//object/mxCell[@parent='%s']/.." % container_id):
            att = com.attrib
            kind = att["class_name"]
            if kind == self.input_class:
                # Duplicate ports are allowed for clarity in the model.
                # They map onto the same port!
                node.add_input(att["name"])
            elif kind == self.output_class:
                name = att["name"]
                node.add_output(name)
                # The output's signal info needs to be captured.
                if att.get("signal", "") != "":
                    self.__signals.setdefault(node.class_name, {}).setdefault(att["signal"], []).append(name)
            else:
                # Normal (child) block.
                child = Node(att["id"], kind, att)
                lookup[child.id] = child
                node.children.append(child)
        return lookup

    def _resolve_endpoint(self, root, edge, end, port_class, lookup):
        """Map the *end* ("source"/"target") of *edge* to ``(block, port)``.

        An endpoint whose class_name equals *port_class* is a port of the
        enclosing node and resolves to ``(port_name, "")``. Any other
        endpoint is a port of a child block and resolves to
        ``(child_node, port_name)`` via its mxCell's ``parent`` attribute.
        """
        edge_id = edge.attrib.get("id")
        cell_id = edge.attrib.get(end)
        if not cell_id:
            raise ParseException(f"Edge '{edge_id}' has no {end}.")
        element = root.find(".//*[@id='%s']" % cell_id)
        if element is None:
            raise ParseException(f"Edge '{edge_id}' {end} '{cell_id}' not found.")
        if element.attrib["class_name"] == port_class:
            return element.attrib["name"], ""
        return lookup[element[0].attrib["parent"]], element.attrib["name"]


class ParseException(Exception):
    """Semantic exceptions when parsing."""

    def __init__(self, message):
        super().__init__(message)


def parse_environment(vars):
    """Parse a ``"key=value, key2=value2"`` string into a dict.

    Keys and values are stripped of surrounding whitespace. Only the first
    ``=`` separates key from value, so values may themselves contain ``=``.
    Blank entries (empty input, trailing commas) are ignored. Later
    duplicates override earlier ones.

    Args:
        vars: The string to parse, or None.

    Returns:
        dict of str to str; empty for None or blank input.

    Raises:
        ValueError: If a non-blank entry contains no ``=``.
    """
    if vars is None:
        return {}
    env = {}
    for entry in vars.split(","):
        if not entry.strip():
            continue
        key, sep, value = entry.partition("=")
        if not sep:
            raise ValueError(f"Invalid environment entry '{entry.strip()}': expected key=value.")
        env[key.strip()] = value.strip()
    return env