| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207 |
- import xml.etree.ElementTree as ET
- from urllib.parse import unquote
- import base64
- import zlib
- import re
IGNORE = [
    'id',
    'label',
    'placeholders',
    'signal',
    'class_name',
]
"""Property names that are skipped when a node's properties are rendered."""
class Node:
    """A diagram element together with its ports, children and connections.

    Attributes:
        id: Identifier of the element, as passed to the constructor.
        class_name: Class name associated with the element.
        properties: Mapping of property names to values; ``node[key]`` reads
            from this mapping.
        children: Nodes nested inside this one, in insertion order.
    """

    def __init__(self, id, class_name, properties):
        """Create a node without ports, connections or children.

        Args:
            id: Identifier of the element. The parameter name shadows the
                builtin ``id``; it is kept so keyword callers keep working.
            class_name: Class name associated with the element.
            properties: Mapping of property names to values. It is stored by
                reference, not copied.
        """
        self.id = id
        self.class_name = class_name
        self.properties = properties
        # Maps a connection source to the list of targets it feeds.
        self._connections = {}
        # Port names are kept in sets, so registering a name twice is a no-op.
        self._inputs = set()
        self._outputs = set()
        self.children = []

    def __contains__(self, item):
        """Return True if ``item`` is the name of an input or output port."""
        return item in self._inputs or item in self._outputs

    def __getitem__(self, item):
        """Return the property called ``item``; raises KeyError if absent."""
        return self.properties[item]

    def add_input(self, name):
        """Register ``name`` as an input port."""
        self._inputs.add(name)

    def add_output(self, name):
        """Register ``name`` as an output port."""
        self._outputs.add(name)

    def get_inputs(self):
        """Return the input port names as a new list (order is unspecified)."""
        return list(self._inputs)

    def get_outputs(self):
        """Return the output port names as a new list (order is unspecified)."""
        return list(self._outputs)

    def get_connections(self):
        """Return the live mapping of connection sources to target lists."""
        return self._connections

    def add_connection(self, source, target):
        """Record a connection from ``source`` to ``target``.

        Args:
            source: Hashable key identifying the connection's origin.
            target: Destination appended to the source's target list.
        """
        self._connections.setdefault(source, []).append(target)

    def get_properties_string(self, ignore=None):
        """Render the node's properties as ``", key=(value)"`` fragments.

        Properties named in the module-level ``IGNORE`` list or in ``ignore``
        are left out. Empty values are rendered as ``None``.

        Args:
            ignore: Optional iterable of additional property names to skip.

        Returns:
            The concatenated fragments. Every fragment, including the first,
            starts with ``", "``; the result is ``""`` when nothing is left.
        """
        # Build the exclusion set once instead of concatenating two lists for
        # every property; ``None`` replaces the former mutable default.
        skipped = frozenset(IGNORE).union(ignore or ())
        return "".join(
            f", {key}=({value if len(value) > 0 else 'None'})"
            for key, value in self.properties.items()
            if key not in skipped
        )

    def is_empty(self):
        """Return True if the node has no children."""
        return not self.children
class Parser:
    """Parse a draw.io XML file into Node objects.

    Every ``<object>`` element that has a ``class_name`` attribute and whose
    geometry contains an ``mxRectangle`` is turned into a top-level Node.
    ``<object>`` elements with ``role="import"`` are turned into Python import
    statements, available through :meth:`get_imports`.

    Args:
        filename: Path or file object handed to ``ElementTree.parse``.
        input_port_class_name: ``class_name`` that marks an input port.
        output_port_class_name: ``class_name`` that marks an output port.
        ignore_empty_nodes: When true, classes without child nodes are
            skipped by :meth:`parse`.
    """

    # ElementTree XPath expressions; "../.." climbs back from the geometry to
    # the enclosing <object> so that the attribute predicate applies to it.
    _CLASS_OBJECT_PATH = ".//object/mxCell/mxGeometry/mxRectangle/../../..[@class_name]"
    _SPECIAL_OBJECT_PATH = ".//object/mxCell/mxGeometry/../..[@role]"

    def __init__(self, filename,
                 input_port_class_name="InputPortBlock",
                 output_port_class_name="OutputPortBlock",
                 ignore_empty_nodes=False):
        self.filename = filename
        self.input_class = input_port_class_name
        self.output_class = output_port_class_name
        self.ignore_empty_nodes = ignore_empty_nodes

        # class name -> signal name -> list of output port names
        self.__signals = {}
        # Import statements, in the order they were encountered.
        self.__imports = []
        # Class names seen so far; used to reject duplicates.
        self.__class_names = set()

    @staticmethod
    def decode_and_deflate(data):
        """Decode one compressed draw.io page.

        Draw.io compresses each page as follows:
            First, all data is url-encoded
            Next, it is compressed/deflated
            Finally, it is encoded according to base64.

        To obtain the page data, the steps are undone in reverse order:
        base64-decode, inflate (raw deflate stream, no zlib header), and
        url-decode.

        Args:
            data: The base64 text of a ``<diagram>`` element.

        Returns:
            The root Element of the decoded page XML.
        """
        raw = base64.b64decode(data)
        # wbits=-15 selects a raw deflate stream without header or checksum.
        inflated = zlib.decompress(raw, -15).decode('utf-8')
        return ET.fromstring(unquote(inflated))

    def parse(self):
        """Does the actual file parsing.

        If the file is compressed (it contains no ``mxGraphModel`` element),
        every ``<diagram>`` page is decoded and parsed on its own. Otherwise
        the whole tree is parsed at once.

        This is a generator: the import statements returned by
        :meth:`get_imports` are collected while it is being consumed, after
        the nodes of the page they belong to have been yielded.

        Yields:
            Node objects, representing the classes in the drawio file.
        """
        root = ET.parse(self.filename).getroot()
        compressed = root.find(".//mxGraphModel") is None

        if compressed:
            # Decoding happens pagewise and lazily, one page at a time.
            page_roots = (self.decode_and_deflate(page.text)
                          for page in root.findall(".//diagram"))
        else:
            page_roots = (root,)

        for page_root in page_roots:
            yield from self._parse_page(page_root)

    def _parse_page(self, page_root):
        """Yield the nodes found under ``page_root``, then collect its imports."""
        for obj in page_root.findall(self._CLASS_OBJECT_PATH):
            node = self.create_node(page_root, obj.attrib)
            if node is not None:
                yield node
        self._collect_imports(page_root)

    def _collect_imports(self, page_root):
        """Append an import statement for every ``role="import"`` object."""
        for obj in page_root.findall(self._SPECIAL_OBJECT_PATH):
            attrib = obj.attrib
            if attrib["role"] != "import":
                continue
            module = attrib["module"]
            if "objects" in attrib:
                names = attrib["objects"]
                self.__imports.append(f"from {module} import {names}")
            else:
                self.__imports.append(f"import {module}")

    def get_imports(self):
        """Return the list of import statements collected so far."""
        return self.__imports

    def create_node(self, root, attr):
        """Build the Node for one class object: ports, children, connections.

        Args:
            root: Element that is searched for the class's children and edges.
            attr: Attribute mapping of the class ``<object>`` element; it must
                contain ``id`` and ``class_name``.

        Returns:
            The populated Node, or None when ``ignore_empty_nodes`` is set and
            the class has no child nodes.

        Raises:
            ParseException: If the class name was already seen by this parser
                or contains whitespace.
        """
        class_name = attr["class_name"]

        # detect duplicate class names
        if class_name in self.__class_names:
            raise ParseException(f"Class '{class_name}' already defined.")

        # detect spaces in class names
        if re.search(r"\s", class_name) is not None:
            raise ParseException(f"Invalid class '{class_name}': Class names may not contain spaces.")

        node = Node(attr["id"], class_name, attr)
        self.__class_names.add(class_name)

        # Find the children of the node. The second element parented to the
        # class object is used as the container that holds its contents.
        # NOTE(review): presumably the first such element is a header/title
        # cell - confirm against the draw.io shape library in use.
        container = root.findall(f".//*[@parent='{node.id}']")[1]
        container_id = container.attrib["id"]

        # Child nodes by element id, used to resolve edge endpoints below.
        lookup = {}
        port_classes = (self.input_class, self.output_class)
        for component in root.findall(f".//object/mxCell[@parent='{container_id}']/.."):
            att = component.attrib
            if att["class_name"] in port_classes:
                # Create the ports
                name = att["name"]

                # Duplicate ports are allowed for clarity in the model.
                # They map onto the same port!
                if att["class_name"] == self.input_class:
                    node.add_input(name)
                else:
                    node.add_output(name)

                    # The output's signal info needs to be captured
                    # NOTE(review): nesting under the output branch follows
                    # the comment above - confirm inputs carry no signals.
                    if "signal" in att and att["signal"] != "":
                        self.__signals.setdefault(node.class_name, {}) \
                            .setdefault(att["signal"], []).append(name)
            else:
                # Normal Node
                child = Node(att["id"], att["class_name"], att)
                lookup[child.id] = child
                node.children.append(child)

        if self.ignore_empty_nodes and node.is_empty():
            return None

        for edge in root.findall(f".//*[@parent='{container_id}'][@edge='1']"):
            att = edge.attrib
            source = root.find(f".//*[@id='{att['source']}']")
            target = root.find(f".//*[@id='{att['target']}']")

            # TODO: check for valid connection!
            # TODO: also allow attributes on edges?
            node.add_connection(
                self._endpoint(source, self.input_class, lookup),
                self._endpoint(target, self.output_class, lookup),
            )
        return node

    @staticmethod
    def _endpoint(element, port_class, lookup):
        """Return the ``(block, port name)`` pair for one end of an edge.

        An element of class ``port_class`` is a port of the enclosing class
        and is identified by its own name with an empty port name. Any other
        element is a port of a child node; that child is looked up through
        the ``parent`` attribute of the element's first sub-element.
        """
        if element.attrib["class_name"] == port_class:
            return element.attrib["name"], ""
        return lookup[element[0].attrib["parent"]], element.attrib["name"]
class ParseException(Exception):
    """Error raised for semantic problems found while parsing."""

    def __init__(self, message):
        # Exactly one message argument is required; it becomes ``args[0]``.
        Exception.__init__(self, message)
def parse_environment(vars):
    """Parse a comma-separated list of ``key=value`` pairs into a dict.

    Whitespace around keys and values is stripped. Only the first ``=`` of an
    entry separates key from value, so values may themselves contain ``=``.
    Blank entries (an empty string, a trailing or doubled comma) are skipped.

    Args:
        vars: The string to parse, or None. The name shadows the builtin
            ``vars``; it is kept so keyword callers keep working.

    Returns:
        A dict mapping each key to its value; empty when ``vars`` is None or
        contains no entries.

    Raises:
        ValueError: If a non-blank entry contains no ``=``.
    """
    if vars is None:
        return {}

    environment = {}
    for entry in vars.split(","):
        if not entry.strip():
            continue
        key, separator, value = entry.partition("=")
        if not separator:
            raise ValueError(f"Invalid environment entry '{entry.strip()}': expected 'key=value'.")
        environment[key.strip()] = value.strip()
    return environment
|