CBDDraw.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. """
  2. Useful drawing function to easily draw a CBD model in Graphviz.
  3. """
  4. from CBD.Core import CBD
  5. from CBD.lib.std import *
  6. def gvDraw(cbd, filename, rankdir="LR", colors=None):
  7. """
  8. Outputs a :class:`CBD` as a `GraphViz <https://graphviz.org/>`_ script to filename.
  9. For instance, drawing the CBD given in the :doc:`examples/EvenNumberGen` example, the following figure
  10. can be obtained:
  11. .. figure:: _figures/EvenNumberGV.svg
  12. Note:
  13. The resulting Graphviz file might look "clunky" and messy when rendering with
  14. the standard :code:`dot` engine. The :code:`neato`, :code:`twopi` and :code:`circo`
  15. engines might provide a cleaner and more readable result.
  16. Args:
  17. cbd (CBD): The :class:`CBD` to draw.
  18. filename (str): The name of the dot-file.
  19. rankdir (str): The GraphViz rankdir of the plot. Must be either :code:`TB`
  20. or :code:`LR`.
  21. colors (dict): An optional dictionary of :code:`blockname -> color`.
  22. """
  23. # f = sys.stdout
  24. f = open(filename, "w")
  25. write = lambda s: f.write(s)
  26. write("""// CBD model of the {n} block
  27. // Created with CBD.converters.CBDDraw
  28. digraph model {{
  29. splines=ortho;
  30. label=<<B>{n} ({t})</B>>;
  31. labelloc=\"t\";
  32. fontsize=20;
  33. rankdir=\"{rd}\";
  34. """.format(n=cbd.getPath(), t=cbd.getBlockType(), rd=rankdir))
  35. if colors is None:
  36. colors = {}
  37. def writeBlock(block):
  38. """
  39. Writes a block to graphviz.
  40. Args:
  41. block: The block to write.
  42. """
  43. if isinstance(block, ConstantBlock):
  44. label = " {}\\n({})\\n{}".format(block.getBlockType(), block.getBlockName(), block.getValue())
  45. elif isinstance(block, GenericBlock):
  46. label = " {}\\n({})\\n{}".format(block.getBlockType(), block.getBlockName(), block.getBlockOperator())
  47. elif isinstance(block, ClampBlock) and block._use_const:
  48. label = " {}\\n({})\\n[{}, {}]".format(block.getBlockType(), block.getBlockName(), block.min, block.max)
  49. else:
  50. label = block.getBlockType() + "\\n(" + block.getBlockName() + ")"
  51. shape = "box"
  52. if isinstance(block, CBD):
  53. shape = "Msquare"
  54. elif isinstance(block, ConstantBlock):
  55. shape = "ellipse"
  56. col = ""
  57. if block.getBlockName() in colors:
  58. col = ", color=\"{0}\", fontcolor=\"{0}\"".format(colors[block.getBlockName()])
  59. write(" {b} [label=\"{lbl}\", shape={shape}{col}];\n".format(b=nodeName(block),
  60. lbl=label,
  61. shape=shape,
  62. col=col))
  63. def nodeName(block):
  64. return "node_%d" % id(block)
  65. for port in cbd.getInputPorts():
  66. s = "%s_%s" % (nodeName(cbd), port.name)
  67. write(" {s} [shape=point, width=0.01, height=0.01];\n".format(s=s))
  68. i = "inter_%d_%s" % (id(port.block), port.name)
  69. write(" {i} [shape=point, width=0.01, height=0.01];\n".format(i=i))
  70. write(" {b} -> {i} [taillabel=\"{inp}\", arrowhead=\"none\", arrowtail=\"inv\", dir=both];\n".format(i=i, b=s, inp=port.name))
  71. for block in cbd.getBlocks():
  72. writeBlock(block)
  73. for in_port in block.getInputPorts():
  74. other = in_port.getIncoming().source
  75. op = other.name
  76. i = "inter_%d_%s" % (id(other.block), op)
  77. write(" {i} -> {b} [headlabel=\"{inp}\", arrowhead=\"normal\", arrowtail=\"none\", dir=both];\n".format(i=i, b=nodeName(block), inp=in_port.name))
  78. for op in block.getOutputPortNames():
  79. i = "inter_%d_%s" % (id(block), op)
  80. # if i not in conn: continue
  81. write(" {i} [shape=point, width=0.01, height=0.01];\n".format(i=i))
  82. write(" {a} -> {i} [taillabel=\"{out}\", arrowtail=\"oinv\", arrowhead=\"none\", dir=both];\n"\
  83. .format(i=i, a=nodeName(block), out=op))
  84. for port in cbd.getOutputPorts():
  85. other = port.getIncoming().source
  86. i = "inter_%d_%s" % (id(other.block), other.name)
  87. t = "%s_%s" % (nodeName(cbd), port.name)
  88. write(" {b} [shape=point, width=0.01, height=0.01];\n".format(b=t))
  89. write(" {i} -> {b} [headlabel=\"{inp}\", arrowhead=\"onormal\", arrowtail=\"none\", dir=both];\n".format(i=i, b=t, inp=port.name))
  90. write("\n}")
  91. f.close()