Переглянути джерело

Make Drawio-XML generator more composable

Joeri Exelmans 2 роки тому
батько
коміт
3b6abf7614
3 змінених файлів з 99 додано та 67 видалено
  1. 72 67
      drawio2py/generator.py
  2. 25 0
      drawio2py/util.py
  3. 2 0
      test/run_tests.py

+ 72 - 67
drawio2py/generator.py

@@ -7,12 +7,30 @@ from drawio2py.abstract_syntax import *
 
 
 def generate(drawio: DrawIOFile, file_object):
+	graph = generate_mxfile(drawio)
+	et = ET.ElementTree(graph)
+	et.write(file_object)
+
+def generate_mxfile(drawio: DrawIOFile) -> ET.Element:
 	graph = ET.Element("mxfile", {
 		"compressed": "false", # always write uncompressed files
 		"version": drawio.version,
 		"type": "device",
 		"pages": str(len(drawio.pages))
 	})
+	for page in drawio.pages:
+		dia = generate_diagram(page)
+		graph.append(dia)
+	return graph
+
+
+def generate_diagram(page: Page) -> ET.Element:
+	dia = ET.Element("diagram", {
+		"id": str(page.id),
+		"name": page.name
+	})
+	mxgm = ET.SubElement(dia, "mxGraphModel", page.attributes)
+	root = ET.SubElement(mxgm, "root")
 
 	def generate_style(style: Style) -> str:
 		res = ""
@@ -20,77 +38,64 @@ def generate(drawio: DrawIOFile, file_object):
 			res += "%s=%s;" % (str(k), str(v))
 		return res
 
-	for page in drawio.pages:
-		dia = ET.SubElement(graph, "diagram", {
-			"id": str(page.id),
-			"name": page.name
-		})
-		mxgm = ET.Element("mxGraphModel", page.attributes)
-		root = ET.SubElement(mxgm, "root")
-
-		def write_cell(cell):
-			attrs = {}
-			if not cell.properties:
-				attrs["id"] = cell.id
-			if cell.value:
-				attrs["value"] = cell.value
-			if cell.style:
-				attrs["style"] = generate_style(cell.style)
-			if cell.parent:
-				attrs["parent"] = cell.parent.id
-			if isinstance(cell, Edge):
-				attrs["edge"] = "1"
-				if cell.source:
-					attrs["source"] = cell.source.id
-				if cell.target:
-					attrs["target"] = cell.target.id
-					
-			if cell.properties:
-				# Wrap in <object> if there are properties
-				properties_with_id = cell.properties.copy()
-				properties_with_id['id'] = cell.id  # <object> gets the ID, not the <mxCell> (WTF drawio!)
-				par = ET.SubElement(root, "object", properties_with_id)
-			else:
-				par = root
-
-			# Create the actual <mxCell>
-			c = ET.SubElement(par, "mxCell", attrs, **cell.attributes)
-
-			# Geometry
-			if isinstance(cell, Vertex):
-				ET.SubElement(c, "mxGeometry", {
-					"x": str(cell.geometry.x),
-					"y": str(cell.geometry.y),
-					"width": str(cell.geometry.width),
-					"height": str(cell.geometry.height),
-					"as": "geometry"
-				})
-			elif isinstance(cell, Edge):
-				g = ET.SubElement(c, "mxGeometry", {
-					"relative": "1",
-					"as": "geometry"
-				})
-				if len(cell.geometry.points) > 0:
-					a = ET.SubElement(g, "Array", {"as": "points"})
-					for p in cell.geometry.points:
-						ET.SubElement(a, "mxPoint", {"x": str(p.x), "y": str(p.y)})
-				if cell.geometry.source_point is not None:
-					sp = cell.geometry.source_point
-					ET.SubElement(g, "mxPoint", {"x": str(sp.x), "y": str(sp.y), "as": "sourcePoint"})
-				if cell.geometry.target_point is not None:
-					tp = cell.geometry.target_point
-					ET.SubElement(g, "mxPoint", {"x": str(tp.x), "y": str(tp.y), "as": "targetPoint"})
+	def write_cell(cell):
+		attrs = {}
+		if not cell.properties:
+			attrs["id"] = cell.id
+		if cell.value:
+			attrs["value"] = cell.value
+		if cell.style:
+			attrs["style"] = generate_style(cell.style)
+		if cell.parent:
+			attrs["parent"] = cell.parent.id
+		if isinstance(cell, Edge):
+			attrs["edge"] = "1"
+			if cell.source:
+				attrs["source"] = cell.source.id
+			if cell.target:
+				attrs["target"] = cell.target.id
+				
+		if cell.properties:
+			# Wrap in <object> if there are properties
+			properties_with_id = cell.properties.copy()
+			properties_with_id['id'] = cell.id  # <object> gets the ID, not the <mxCell> (WTF drawio!)
+			par = ET.SubElement(root, "object", properties_with_id)
+		else:
+			par = root
 
-			for child in cell.children:
-				write_cell(child)
+		# Create the actual <mxCell>
+		c = ET.SubElement(par, "mxCell", attrs, **cell.attributes)
 
-		write_cell(page.root)
+		# Geometry
+		if isinstance(cell, Vertex):
+			ET.SubElement(c, "mxGeometry", {
+				"x": str(cell.geometry.x),
+				"y": str(cell.geometry.y),
+				"width": str(cell.geometry.width),
+				"height": str(cell.geometry.height),
+				"as": "geometry"
+			})
+		elif isinstance(cell, Edge):
+			g = ET.SubElement(c, "mxGeometry", {
+				"relative": "1",
+				"as": "geometry"
+			})
+			if len(cell.geometry.points) > 0:
+				a = ET.SubElement(g, "Array", {"as": "points"})
+				for p in cell.geometry.points:
+					ET.SubElement(a, "mxPoint", {"x": str(p.x), "y": str(p.y)})
+			if cell.geometry.source_point is not None:
+				sp = cell.geometry.source_point
+				ET.SubElement(g, "mxPoint", {"x": str(sp.x), "y": str(sp.y), "as": "sourcePoint"})
+			if cell.geometry.target_point is not None:
+				tp = cell.geometry.target_point
+				ET.SubElement(g, "mxPoint", {"x": str(tp.x), "y": str(tp.y), "as": "targetPoint"})
 
-		dia.append(mxgm)
-
-	et = ET.ElementTree(graph)
-	et.write(file_object)
+		for child in cell.children:
+			write_cell(child)
 
+	write_cell(page.root)
+	return dia
 
 if __name__ == '__main__':
 	from drawio2py.parser import Parser

+ 25 - 0
drawio2py/util.py

@@ -14,3 +14,28 @@ def find_lca(cell1: Cell, cell2: Cell) -> Optional[Cell]:
             return lca
         lca = lca.parent
     return None
+
+def generate_empty_root_and_layer() -> Cell:
+    """
+    In typical draw.io fashion, generates two Cells, the root (id="0") and one empty layer (id="1")
+    """
+    root = dio_as.Cell(
+        id="0",
+        value="",
+        parent=None,
+        children=[],
+        properties={},
+        style=None,
+        attributes={},
+    )
+    layer = dio_as.Cell(
+        id="1",
+        value="",
+        parent=root,
+        children=[],
+        properties={},
+        style=None,
+        attributes={},
+    )
+    root.children.append(layer)
+    return root

+ 2 - 0
test/run_tests.py

@@ -31,6 +31,8 @@ def run_test(filename):
     csyntax2.seek(0)
 
     if (csyntax.getvalue() != csyntax2.getvalue()):
+        # print("csyntax:", csyntax.getvalue())
+        # print("csyntax2:", csyntax2.getvalue())
         raise Exception("Files differ after round-trip!")
 
 def parse_shapelib(filename):