parser.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. import xml.etree.ElementTree as ET
  2. from urllib.parse import unquote
  3. import base64
  4. import zlib
  5. import re
  6. IGNORE = ['id', 'label', 'placeholders', 'signal', 'class_name']
  7. """Properties to ignore when parsing."""
  8. class Node:
  9. def __init__(self, id, class_name, properties):
  10. self.id = id
  11. self.class_name = class_name
  12. self.properties = properties
  13. self._connections = {}
  14. self._inputs = set()
  15. self._outputs = set()
  16. self.children = []
  17. def __contains__(self, item):
  18. return item in self._inputs or item in self._outputs
  19. def __getitem__(self, item):
  20. return self.properties[item]
  21. def add_input(self, name):
  22. self._inputs.add(name)
  23. def add_output(self, name):
  24. self._outputs.add(name)
  25. def get_inputs(self):
  26. return list(self._inputs)
  27. def get_outputs(self):
  28. return list(self._outputs)
  29. def get_connections(self):
  30. return self._connections
  31. def add_connection(self, source, target):
  32. if source in self._connections:
  33. self._connections[source].append(target)
  34. else:
  35. self._connections[source] = [target]
  36. def get_properties_string(self, ignore=[]):
  37. res = ""
  38. for s in [f"{k}=({v if len(v) > 0 else 'None'})" for k, v in self.properties.items() if k not in IGNORE + ignore]:
  39. res += ", " + s
  40. return res
  41. def is_empty(self):
  42. return len(self.children) == 0
  43. class Parser:
  44. def __init__(self, filename,
  45. input_port_class_name="InputPortBlock",
  46. output_port_class_name="OutputPortBlock",
  47. ignore_empty_nodes=False):
  48. self.filename = filename
  49. self.input_class = input_port_class_name
  50. self.output_class = output_port_class_name
  51. self.ignore_empty_nodes = ignore_empty_nodes
  52. self.__signals = {}
  53. self.__imports = []
  54. self.__class_names = set()
  55. @staticmethod
  56. def decode_and_deflate(data):
  57. """Draw.io compresses each page as follows:
  58. First, all data is url-encoded
  59. Next, it is compressed/deflated
  60. Finally, it is encoded according to base64.
  61. To obtain the page data, we have to do the reverse.
  62. Returns:
  63. Uncompressed and decoded data as a string.
  64. """
  65. decoded_data = base64.b64decode(data)
  66. inflated = zlib.decompress(decoded_data, -15).decode('utf-8')
  67. url_decoded_data = unquote(inflated)
  68. return ET.fromstring(url_decoded_data)
  69. def parse(self):
  70. """Does the actual file parsing.
  71. If the file is compressed, we uncompress and work from there.
  72. If it wasn't compressed, we can work with the whole tree.
  73. Returns:
  74. A list of Node objects, representing the drawio file.
  75. """
  76. tree = ET.parse(self.filename)
  77. root = tree.getroot()
  78. compressed = len(root.findall(".//mxGraphModel")) == 0
  79. class_object_path = ".//object/mxCell/mxGeometry/mxRectangle/../../..[@class_name]"
  80. special_object_path = ".//object/mxCell/mxGeometry/../..[@role]"
  81. if compressed:
  82. # If compressed, first decode base64, then deflate, then url decode
  83. pages = root.findall(".//diagram")
  84. for page in pages: # Decoding happens pagewise
  85. nroot = self.decode_and_deflate(page.text)
  86. objects = nroot.findall(class_object_path)
  87. for obj in objects:
  88. res = self.create_node(nroot, obj.attrib)
  89. if res is not None:
  90. yield res
  91. special = nroot.findall(special_object_path)
  92. for obj in special:
  93. if obj.attrib["role"] == "import":
  94. module = obj.attrib["module"]
  95. if "objects" in obj.attrib:
  96. objects = obj.attrib["objects"]
  97. self.__imports.append(f"from {module} import {objects}")
  98. else:
  99. self.__imports.append(f"import {module}")
  100. else:
  101. objects = root.findall(class_object_path)
  102. for obj in objects:
  103. res = self.create_node(root, obj.attrib)
  104. if res is not None:
  105. yield res
  106. def get_imports(self):
  107. return self.__imports
  108. def create_node(self, root, attr):
  109. class_name = attr["class_name"]
  110. # detect duplicate class names
  111. if class_name in self.__class_names:
  112. raise ParseException(f"Class '{class_name}' already defined.")
  113. # detect spaces in class names
  114. if re.search(r"\s", class_name) is not None:
  115. raise ParseException(f"Invalid class '{class_name}': Class names may not contain spaces.")
  116. node = Node(attr["id"], class_name, attr)
  117. self.__class_names.add(class_name)
  118. # Find the children of the node
  119. _rect = root.findall(".//*[@parent='%s']" % node.id)[1]
  120. components = root.findall(".//object/mxCell[@parent='%s']/.." % _rect.attrib["id"])
  121. lookup = {}
  122. for com in components:
  123. att = com.attrib
  124. if att["class_name"] in [self.input_class, self.output_class]:
  125. # Create the ports
  126. name = att["name"]
  127. # Duplicate ports are allowed for clarity in the model.
  128. # They map onto the same port!
  129. if att["class_name"] == self.input_class:
  130. node.add_input(name)
  131. else:
  132. node.add_output(name)
  133. # The output's signal info needs to be captured
  134. if "signal" in att and att["signal"] != "":
  135. self.__signals.setdefault(node.class_name, {}).setdefault(att["signal"], []).append(name)
  136. else:
  137. # Normal Node
  138. child = Node(att["id"], att["class_name"], att)
  139. lookup[child.id] = child
  140. node.children.append(child)
  141. if self.ignore_empty_nodes and node.is_empty():
  142. return None
  143. edges = root.findall(".//*[@parent='%s'][@edge='1']" % _rect.attrib["id"])
  144. for edge in edges:
  145. att = edge.attrib
  146. source = root.find(".//*[@id='%s']" % att["source"])
  147. target = root.find(".//*[@id='%s']" % att["target"])
  148. # TODO: check for valid connection!
  149. if source.attrib["class_name"] == self.input_class:
  150. sblock = source.attrib["name"]
  151. spn = ""
  152. else:
  153. sblock = lookup[source[0].attrib["parent"]]
  154. spn = source.attrib["name"]
  155. if target.attrib["class_name"] == self.output_class:
  156. tblock = target.attrib["name"]
  157. tpn = ""
  158. else:
  159. tblock = lookup[target[0].attrib["parent"]]
  160. tpn = target.attrib["name"]
  161. # TODO: also allow attributes on edges?
  162. node.add_connection((sblock, spn), (tblock, tpn))
  163. return node
  164. class ParseException(Exception):
  165. """Semantic exceptions when parsing."""
  166. def __init__(self, message):
  167. super().__init__(message)
  168. def parse_environment(vars):
  169. if vars is None:
  170. return {}
  171. sets = vars.split(",")
  172. return {k.strip(): v.strip() for k, v in [x.split("=") for x in sets]}