# plot_results.py: HintCO plotting utility that renders co-simulation CSV results as PDF plots.
import argparse
import csv
import os
import sys
from collections import defaultdict

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
  7. def read_data(filepath):
  8. results = {}
  9. with open(filepath, 'r') as csvfile:
  10. reader = csv.DictReader(csvfile, delimiter=';')
  11. for col in reader.fieldnames:
  12. if col != "":
  13. results[col] = []
  14. for row in reader:
  15. for col in results:
  16. results[col].append(float(row[col]))
  17. return results
  18. def selectTime(raw_data, maxTime):
  19. indexToTrim = 0
  20. maxIndex = len(raw_data["time"])
  21. while indexToTrim<maxIndex and raw_data["time"][indexToTrim] < maxTime:
  22. indexToTrim+=1
  23. if indexToTrim<maxIndex:
  24. for col in raw_data.keys():
  25. del raw_data[col][indexToTrim:]
  26. #print("Data trimmed up to index ", indexToTrim)
  27. return raw_data
  28. def get_all_csvs(results_folder):
  29. results = []
  30. directory = os.fsencode(results_folder)
  31. for root, dirs, files in os.walk(directory):
  32. for file in files:
  33. filename = os.fsdecode(file)
  34. if filename.endswith(".csv"):
  35. filepath = os.fsdecode(os.path.join(root,file))
  36. results.append((filename, filepath))
  37. return results
  38. def plot_results(results_folder, analyticalSolutionDir):
  39. for (filename, filepath) in get_all_csvs(results_folder):
  40. # Read all data onto a dictionary
  41. print("Reading data from file",filename,"... ", end="", flush=True)
  42. raw_data = read_data(filepath)
  43. print("Done.")
  44. if analyticalSolutionDir!="":
  45. analytical_sol_file = os.path.join(analyticalSolutionDir, filename)
  46. if os.path.exists(analytical_sol_file):
  47. print("Reading analytical solution of file ",filename,"... ", end="", flush=True)
  48. analytical_data_raw = read_data(analytical_sol_file)
  49. print("Done.")
  50. print("Columns read: ", analytical_data_raw.keys())
  51. print("Trimming data to time ",raw_data["time"][-1],"... ", end="", flush=True)
  52. solution_raw = selectTime(analytical_data_raw, raw_data["time"][-1])
  53. print("Done.")
  54. else:
  55. print("No analytical solution for file ",filename,": " + analytical_sol_file)
  56. solution_raw = None
  57. else:
  58. solution_raw = None
  59. # Create a group of plots per filenames
  60. num_trajectories = len(raw_data.keys()) - 1 # Exclude time
  61. print("Creating ",num_trajectories," plots... ", end="", flush=True)
  62. plot_num = 1
  63. plot_file_name = filepath.replace(".csv",".pdf")
  64. pp = PdfPages(plot_file_name)
  65. for col in raw_data.keys():
  66. if col != "time":
  67. print("Plotting col " + col + "...")
  68. plt.subplot(num_trajectories, 1, plot_num)
  69. plt.xlabel("time")
  70. plt.plot(raw_data["time"], raw_data[col], '-', label=col)
  71. if solution_raw != None:
  72. if col in solution_raw:
  73. plt.plot(solution_raw["time"], solution_raw[col], '-', label=col+"_anl")
  74. plt.legend()
  75. plot_num += 1
  76. pp.savefig()
  77. pp.close()
  78. plt.clf()
  79. print("Done.")
  80. def smallest_identifier_string(strs):
  81. # Taken from https://medium.com/@d_dchris/10-methods-to-solve-the-longest-common-prefix-problem-using-python-leetcode-14-a87bb3eb0f3a
  82. longest_pre = ""
  83. if not strs:
  84. return longest_pre
  85. shortest_str = min(strs, key=len)
  86. for i in range(len(shortest_str)):
  87. if all([x.startswith(shortest_str[:i+1]) for x in strs]):
  88. longest_pre = shortest_str[:i+1]
  89. else:
  90. break
  91. return longest_pre
  92. def plot_merge_results(results_folder):
  93. # Get all files (and their path) ending in csv.
  94. all_csv_files = get_all_csvs(results_folder)
  95. # Group them by filename. Eg., all fmu1.csv files should be together.
  96. csv_files_grouped = {}
  97. for (filename, filepath) in all_csv_files:
  98. if not filename in csv_files_grouped:
  99. csv_files_grouped[filename] = []
  100. csv_files_grouped[filename].append(filepath)
  101. print("Files grouped.")
  102. # Create plot for each group
  103. for (filename, files) in csv_files_grouped.items():
  104. plot_filepath = files[0]
  105. plot_file_name = plot_filepath.replace(".csv",".pdf")
  106. pp = PdfPages(plot_file_name)
  107. group_id = smallest_identifier_string(files)
  108. for filepath in files:
  109. file_identifier = filepath[len(group_id):]
  110. raw_data = read_data(filepath)
  111. print("Read data from file",filename,".")
  112. num_trajectories = len(raw_data.keys()) - 1 # Exclude time
  113. # Assumes that each csv column is in the same order on all csvs. This is a reasonable assumption...
  114. plot_num = 1
  115. for col in raw_data.keys():
  116. if col != "time":
  117. print("Plotting col " + col + "...")
  118. plt.subplot(num_trajectories, 1, plot_num)
  119. plt.xlabel("time")
  120. plt.plot(raw_data["time"], raw_data[col], '-', label=file_identifier+"."+col)
  121. plt.legend()
  122. plot_num += 1
  123. pp.savefig()
  124. pp.close()
  125. plt.clf()
  126. print("Done.")
  127. if __name__ == '__main__':
  128. parser = argparse.ArgumentParser(description='HintCO plotting utility')
  129. parser.add_argument('resultsdir', help='directory where csv files of the co-simulation are stored.')
  130. parser.add_argument('--solutionsdir', help='directory where csv files of the analytical solution are stored.')
  131. parser.add_argument('--merge', action="store_true", help='Instead of creating a plot per folder, merge all plots into a single folder. This requires a fixed folder structure')
  132. args = parser.parse_args()
  133. results_folder = args.resultsdir
  134. analyticalSolutionDir = args.solutionsdir if args.solutionsdir else ""
  135. merge = args.merge
  136. print("results=",results_folder)
  137. print("analyticalSolutionDir=",analyticalSolutionDir)
  138. print("merge=",merge)
  139. if merge:
  140. plot_merge_results(results_folder)
  141. else:
  142. plot_results(results_folder, analyticalSolutionDir)