run_tests.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. import os
  2. import sys
  3. import io
  4. import pprint
  5. import tempfile
  6. import unittest
  7. import xml.etree.ElementTree as ET
  8. from drawio2py import parser, abstract_syntax, generator, shapelib, util
  9. DATADIR = os.path.join(os.path.dirname(__file__), "data")
  10. class DummyOutput:
  11. def write(self, text: str):
  12. pass
  13. # The bare minimum that we can consider 'a test':
  14. # We parse XML, write it out again, then parse it again and write it out again.
  15. # Finally, we check if both serializations (the first and second one) are bitwise equal (they should).
  16. # To verify if the generated drawio file is really the "same" as the original, we currently manually open the file in drawio.
  17. def run_test(filename):
  18. # Parse (1st time):
  19. asyntax = parser.Parser.parse(os.path.join(DATADIR,filename))
  20. # Generate .drawio (1st time):
  21. csyntax = io.BytesIO()
  22. generator.generate(asyntax, csyntax)
  23. csyntax.seek(0)
  24. # Parse (2nd time):
  25. asyntax2 = parser.Parser.parse(csyntax)
  26. # Generate .drawio (2nd time):
  27. csyntax2 = io.BytesIO()
  28. generator.generate(asyntax2, csyntax2)
  29. csyntax2.seek(0)
  30. if (csyntax.getvalue() != csyntax2.getvalue()):
  31. # print("csyntax:", csyntax.getvalue())
  32. # print("csyntax2:", csyntax2.getvalue())
  33. raise Exception("Files differ after round-trip!")
  34. # Compares two XML trees.
  35. # From: https://stackoverflow.com/a/24349916
  36. def elements_equal(e1, e2, depth=0):
  37. print(" "*depth+e1.tag)
  38. if e1.tag != e2.tag:
  39. print("tags differ")
  40. return False
  41. if e1.tail != e2.tail:
  42. print("tail differs")
  43. return False
  44. if e1.attrib != e2.attrib:
  45. print("attributes differ")
  46. pprint.pprint(e1.attrib)
  47. pprint.pprint(e2.attrib)
  48. return False
  49. if len(e1) != len(e2):
  50. print("number of children differs")
  51. return False
  52. return all(elements_equal(c1, c2, depth+1) for c1, c2 in zip(e1, e2))
  53. # Currently unused.
  54. # Our generated XML always differs from the parsed XML.
  55. # That's because we skip certain attributes.
  56. def assert_roundtrip_equal(filename):
  57. expected_xml = ET.parse(os.path.join(DATADIR,filename)).getroot()
  58. asyntax = parser.Parser.parse_xml_root(expected_xml)
  59. actual_xml = generator.generate_mxfile(asyntax)
  60. if not elements_equal(expected_xml, actual_xml):
  61. raise Exception("Generated XML tree differs from parsed XML tree.")
  62. def parse_shapelib(filename):
  63. return shapelib.parse_library(os.path.join(DATADIR,filename))
  64. class Tests(unittest.TestCase):
  65. def test_1(self):
  66. run_test("test.drawio")
  67. def test_2(self):
  68. run_test("overview.drawio") # we eat our own dog food :)
  69. def test_3(self):
  70. run_test("TrivialPM.drawio")
  71. def test_label_offset(self):
  72. run_test("labelOffset.drawio")
  73. def test_edge_label(self):
  74. run_test("edgeLabel.drawio")
  75. # asyntax = parser.Parser.parse(os.path.join(DATADIR, "edgeLabel.drawio"))
  76. # with open(os.path.join(DATADIR, "edgeLabel-1.drawio"), 'wb') as f:
  77. # generator.generate(asyntax, f)
  78. def test_shapelib(self):
  79. common_lib = parse_shapelib("shapelibs/common.xml")
  80. pm_lib = parse_shapelib("shapelibs/pm.xml")
  81. id_gen = util.DrawioIDGenerator()
  82. page = util.generate_empty_page(id_gen, "MyFancyPage")
  83. cloner = shapelib.ShapeCloner(id_gen)
  84. initial = cloner.clone_vertex(pm_lib["(PM) Initial"], page.root, 100, 100)
  85. final = cloner.clone_vertex(pm_lib["(PM) Final"], page.root, 300, 300)
  86. cloner.clone_edge(common_lib["Control Flow"], page.root, initial, final)
  87. def parse_shapelib1(self):
  88. parse_shapelib("shapelibs/ftg.xml")
  89. def parse_shapelib2(self):
  90. parse_shapelib("shapelibs/pt.xml")
  91. def parse_shapelib3(self):
  92. parse_shapelib("shapelibs/ss.xml")