Преглед на файлове

Add module for parsing drawio shape libraries, and instantiating their shapes.

Joeri Exelmans преди 2 години
родител
ревизия
1a0d79d140

+ 3 - 1
drawio2py/abstract_syntax.py

@@ -29,7 +29,9 @@ class Point(Element):
 	y: Decimal
 
 @dataclass(eq=False)
-class VertexGeometry(Point):
+class VertexGeometry(Element):
+	x: Optional[Decimal]
+	y: Optional[Decimal]
 	width: Decimal
 	height: Decimal
 

+ 7 - 5
drawio2py/generator.py

@@ -29,11 +29,13 @@ def generate(drawio: DrawIOFile, file_object):
 		root = ET.SubElement(mxgm, "root")
 
 		def write_cell(cell):
-			attrs = {
-				"id": "" if cell.properties else str(cell.id),
-				"value": cell.value,
-				"style": generate_style(cell.style),
-			}
+			attrs = {}
+			if not cell.properties:
+				attrs["id"] = cell.id
+			if cell.value:
+				attrs["value"] = cell.value
+			if cell.style:
+				attrs["style"] = generate_style(cell.style)
 			if cell.parent:
 				attrs["parent"] = cell.parent.id
 			if isinstance(cell, Edge):

+ 4 - 4
drawio2py/parser.py

@@ -50,17 +50,17 @@ class Parser:
 		# Dangling Edges
 		source = groot.find(".//mxPoint[@as='sourcePoint']")
 		if source is not None:
-			source = Point(source.attrib["x"], source.attrib["y"])
+			source = Point(source.get("x", None), source.get("y", None))
 		target = groot.find(".//mxPoint[@as='targetPoint']")
 		if target is not None:
-			target = Point(target.attrib["x"], target.attrib["y"])
+			target = Point(target.get("x", None), target.get("y", None))
 
 		return EdgeGeometry(pts, source, target)
 
 	@staticmethod
 	def parse_cell_geometry(groot):
 		if groot is None: return None
-		return VertexGeometry(groot.attrib["x"], groot.attrib["y"],
+		return VertexGeometry(groot.get("x", None), groot.attrib.get("y", None),
 		                    groot.attrib["width"], groot.attrib["height"])
 
 	@staticmethod
@@ -144,7 +144,7 @@ class Parser:
 			root = root)
 
 	@staticmethod
-	def parse_mxgraphmodel(mxgm):
+	def parse_mxgraphmodel(mxgm) -> Cell:
 		cell_dict = {} # mapping from ID to cell
 		Parser.parse_components(mxgm[0], cell_dict)
 

+ 76 - 0
drawio2py/shapelib.py

@@ -0,0 +1,76 @@
+# Some functions for dealing with Draw.io shape libraries (that are shown in the left pane)
+
+from typing import List, Dict, Tuple
+from decimal import Decimal
+import json
+import secrets
+import xml.etree.ElementTree as ET
+import copy
+
+from drawio2py.abstract_syntax import *
+from drawio2py.parser import Parser
+from drawio2py import util
+
+# Generates (more or less globally unique) Cell IDs in a format similar to how drawio does it:
+class DrawioIDGenerator:
+    def __init__(self):
+        self.next_id = 0
+        self.prefix = secrets.token_urlsafe(20) # a random token of similar complexity to Drawio's cell IDs.
+
+    def gen(self) -> str:
+        id = self.next_id
+        self.next_id += 1
+        return self.prefix + "-" + str(id)
+
+def parse_library(path) -> Dict[str, Cell]:
+    # A library is at the highest level an XML tree, with only one node: <mxlibrary>
+    tree = ET.parse(path)
+    # The "text" in this node is a JSON array:
+    elements = json.loads(tree.getroot().text)
+    library = {}
+    for el in elements:
+        # Every element in the array has a "title" (shown to the user when hovering the shape) and "xml", which is actually Base64-encoded compressed URL-encoded XML (lol) (of an <mxgraphmodel>), which is the same format as what we encounter in a .drawio file:
+        mxgm = Parser.decode_and_deflate(el["xml"])
+        # We create a dictionary mapping title to <mxgraphmodel>:
+        shape_root = Parser.parse_mxgraphmodel(mxgm)
+
+        # Some assertions on what we expect from shape libraries:
+        if len(shape_root.children) != 1:
+            raise Exception("Library shape '" + el["title"] + "': The root does not contain one layer, but " + len(shape_root.children))
+        [shape_layer] = shape_root.children
+        if len(shape_layer.children) != 1:
+            raise Exception("Library shape '" + el["title"] + "': Layer does not contain one cell, but " + len(shape_layer.children))
+        [shape_cell] = shape_layer.children
+        library[el["title"]] = shape_cell
+    return library
+
+class ShapeCloner:
+    def __init__(self, id_gen: DrawioIDGenerator):
+        self.id_gen = id_gen
+
+    def clone_cell(self, template: Cell, parent: Cell) -> Cell:
+        cell = copy.deepcopy(template)
+        cell.id = self.id_gen.gen()
+        cell.parent = parent
+        parent.children.append(cell)
+        return cell
+
+    def clone_vertex(self, template: Cell, parent: Cell, x: Decimal, y: Decimal) -> Cell:
+        cell = self.clone_cell(template, parent)
+        if type(cell) != Vertex:
+            raise Exception("Expected template to contain Vertex, instead the type was " + str(type(cell)))
+        cell.geometry.x = x
+        cell.geometry.y = y
+        return cell
+
+    def clone_edge(self, template: Cell, parent: Cell, source: Cell, target: Cell):
+        if util.find_lca(source, target) != parent:
+            raise Exception("Drawio invariant violation: The parent of an edge must always be the LCA of the source and target of that edge.")
+        cell = self.clone_cell(template, parent)
+        if type(cell) != Edge:
+            raise Exception("Expected template to contain Edge, instead the type was " + str(type(cell)))
+        cell.source = source
+        cell.target = target
+        cell.geometry.source_point = None
+        cell.geometry.target_point = None
+        return cell

+ 16 - 0
drawio2py/util.py

@@ -0,0 +1,16 @@
+from typing import Optional
+from drawio2py.abstract_syntax import Cell
+
+def is_descendant(ancestor: Cell, descendant: Cell) -> bool:
+    for child in ancestor.children:
+        if child == descendant or is_descendant(child, descendant):
+            return True
+    return False
+
+def find_lca(cell1: Cell, cell2: Cell) -> Optional[Cell]:
+    lca = cell1.parent
+    while lca != None:
+        if is_descendant(lca, cell2):
+            return lca
+        lca = lca.parent
+    return None

+ 6 - 0
test/data/shapelibs/README.md

@@ -0,0 +1,6 @@
+These drawio libraries are meant only as input data for the tests.
+
+They were cloned from the "official" libraries for FTG+PM, but no effort will be made to keep them up-to-date with the "official" libraries.
+
+The official libraries can be found here:
+https://msdl.uantwerpen.be/git/jexelmans/drawio/src/master/src/main/webapp/myPlugins/drawiolibs

Файловите разлики са ограничени, защото са твърде много
+ 10 - 0
test/data/shapelibs/common.xml


Файловите разлики са ограничени, защото са твърде много
+ 10 - 0
test/data/shapelibs/ftg.xml


Файловите разлики са ограничени, защото са твърде много
+ 9 - 0
test/data/shapelibs/pm.xml


Файловите разлики са ограничени, защото са твърде много
+ 6 - 0
test/data/shapelibs/pt.xml


Файловите разлики са ограничени, защото са твърде много
+ 5 - 0
test/data/shapelibs/ss.xml


+ 32 - 2
test/run_tests.py

@@ -1,11 +1,12 @@
 import os
 import sys
 import io
-from drawio2py import parser, abstract_syntax, generator
 import pprint
 import tempfile
 import unittest
 
+from drawio2py import parser, abstract_syntax, generator, shapelib
+
 DATADIR = os.path.join(os.path.dirname(__file__), "data")
 
 class DummyOutput:
@@ -41,7 +42,9 @@ def run_test(filename):
         # print(csyntax.getvalue())
         # print(csyntax2.getvalue())
         raise Exception("Files differ after round-trip!")
-    print(filename, "OK")
+
+def parse_shapelib(filename):
+    return shapelib.parse_library(os.path.join(DATADIR,filename))
 
 class Tests(unittest.TestCase):
     def test_1(self):
@@ -52,3 +55,30 @@ class Tests(unittest.TestCase):
 
     def test_3(self):
         run_test("TrivialPM.drawio")
+
+    def test_shapelib(self):
+        common_lib = parse_shapelib("shapelibs/common.xml")
+        pm_lib = parse_shapelib("shapelibs/pm.xml")
+
+        root = abstract_syntax.Cell(
+            id="0",
+            value="",
+            parent=None,
+            children=[],
+            properties={},
+            style=None,
+            attributes={},
+        )
+        cloner = shapelib.ShapeCloner(id_gen=shapelib.DrawioIDGenerator())
+        initial = cloner.clone_vertex(pm_lib["(PM) Initial"], root, 100, 100)
+        final = cloner.clone_vertex(pm_lib["(PM) Final"], root, 300, 300)
+        cloner.clone_edge(common_lib["Control Flow"], root, initial, final)
+
+    def parse_shapelib1(self):
+        parse_shapelib("shapelibs/ftg.xml")
+
+    def parse_shapelib2(self):
+        parse_shapelib("shapelibs/pt.xml")
+
+    def parse_shapelib3(self):
+        parse_shapelib("shapelibs/ss.xml")