소스 검색

mpl plottingmanager/handlers now working

rparedis 5 년 전
부모
커밋
4ad7647132
1개의 변경된 파일236개의 추가작업 그리고 85개의 파일을 삭제
  1. 236 85
      src/CBD/lib/interface/plottinghandler.py

+ 236 - 85
src/CBD/lib/interface/plottinghandler.py

@@ -1,89 +1,240 @@
+from enum import Enum
+
 import matplotlib.animation as animation
-import matplotlib.pyplot as plt
-from matplotlib.lines import Line2D
-from matplotlib.container import BarContainer
-from matplotlib.collections import _CollectionWithSizes
-
-# TODO: also allow bokeh, seaborn, ggplot... instead of only mpl
-
-class LivePlotHandler:
-	def __init__(self, framework):
-		# TODO: check framework validity: is it installed?
-		self.framework = framework
-		self.__plots = []
-		self.__open = False
-
-	def listen(self):
-		"""
-		Open the plots and start listening to updates.
-		"""
-		self.__open = True
-		self.__setup()
-
-	def close(self):
-		"""
-		Stop listening to updates and close the plots.
-		Will be called upon close of the figure(s)/plot(s).
-		"""
-		self.__teardown()
-		self.__open = False
-
-	def is_open(self):
-		"""
-		If the plot is not showing anymore, the simulation
-		might need to be terminated as well.
-		"""
-		return self.__open
-
-	def __setup(self):
-		# TODO: connect figure/plot close event to close call
-		# TODO: ion and show figure-dependent?
-		plt.ion()
-		plt.show()
-
-	def __teardown(self):
-		plt.ioff()
-
-	def registerPlot(self, block, elem, interval=100):
-		anim = animation.FuncAnimation(elem.axes, lambda i, el=elem, b=block: self.update(i, el, b),
-		                               fargs=(elem, block), interval=interval)
-		self.__plots.append([block, elem, anim])
-
-	# TODO
-	def unregisterBlock(self, block):
-		pass
-
-	def refresh(self):
-		"""
-		Some frameworks require a refresh of the plot.
-		This function provides that functionality.
-		"""
-		# TODO: figure-dependent?
-		plt.draw()
-		plt.pause(0.01)
-
-	def update(self, _, elem, block):
-		if isinstance(elem, Line2D):
-			elem.set_data(*block.data_xy)
-		elif isinstance(elem, _CollectionWithSizes):
-			elem.set_offsets(block.data)
-		elif isinstance(elem, BarContainer):    # TODO: horizontal bar?
-			for rect, h in zip(elem.patches, block.data_xy[1]):
-				rect.set_height(h)
-		# TODO: stem, eventplot, pie, stackplot, broken barh, vlines, hlines, fill
-
-	# def update_xlimit(self, min, max):
-	# 	pass
-	#
-	# def update_ylimit(self, min, max):
-	# 	pass
+from bokeh.plotting import curdoc
+# TODO: Bokeh (see TODOs), GGplot, Seaborn
+# TODO: Jupyter
 
 
-if __name__ == '__main__':
-	lph = LivePlotHandler("matplotlib")
-	lph.registerPlot(block, line)
-	# register update function to the framework's update method
-	lph.listen()
-	# run simulation => call window refresh (lph.refresh()) function
-	lph.close()
+class Backend(Enum):
+	MPL        = 1
+	MATPLOTLIB = 1
+	BOKEH      = 2
+	GGPLOT     = 3
+
+# Note: for Bokeh, a server must be started via 'bokeh serve'
+# Note: Seaborn is built on top of matplotlib
+
+class PlotHandler:
+	def __init__(self, object, figure, kind, interval=100, backend=Backend.MPL):
+		self.object = object
+		self.kind = kind
+		self.figure = figure
+
+		self.kind._backend = backend
+		self.elm = self.kind.create(figure)
+		self.__opened = True
+
+		# obtaining information
+		self.__get_data = lambda obj: obj.data_xy
+		self.__events = {
+			"update_event": [],
+			"close_event": []
+		}
+
+		# backend info:
+		if backend == Backend.MPL:
+			self.__ani = animation.FuncAnimation(figure[0], lambda _: self.update(), interval=interval)
+			figure[0].canvas.mpl_connect('close_event', lambda evt: self.close_event())
+		elif backend == Backend.BOKEH:
+			curdoc().add_periodic_callback(lambda: self.update(), interval)
+			# TODO (is this even possible?):
+			curdoc().on_session_destroyed(lambda ctx: self.close_event())
+
+	def signal(self, name, *args):
+		if name not in self.__events:
+			raise ValueError("Invalid signal '%s' in PlotHandler." % name)
+		for evt in self.__events[name]:
+			evt(*args)
+
+	def connect(self, name, function):
+		if name not in self.__events:
+			raise ValueError("Invalid signal '%s' in PlotHandler." % name)
+		self.__events[name].append(function)
+
+	def set_data_getter(self, function):
+		self.__get_data = function
+
+	def get_data(self):
+		return self.__get_data(self.object)
+
+	def close_event(self):
+		self.__opened = False
+		self.signal('close_event')
+
+	def is_opened(self):
+		return self.__opened
+
+	def update(self):
+		data = self.get_data()
+		self.kind.update(self.elm, *data)
+		self.signal('update_event', data)
+
+	@staticmethod
+	def follow(data, size, lower_bound=-float('inf'), upper_bound=float('inf')):
+		assert upper_bound - lower_bound >= size, "Invalid size: outside bounds."
+		if upper_bound - lower_bound == size:
+			return lower_bound, upper_bound
+		if len(data) == 0:
+			return -size / 2.0, size / 2.0
+		value = data[-1]
+		low = max(value - size / 2.0, lower_bound)
+		high = min(low + size, upper_bound)
+		if high == upper_bound:
+			low = high - size
+		return low, high
+
+
+class PlotKind:
+	def __init__(self, *args, **kwargs):
+		self._backend = None
+		self.args = args
+		self.kwargs = kwargs
+
+	def is_backend(self, backend):
+		return self._backend == backend
+
+	def create(self, figure):
+		raise NotImplementedError()
+
+	def update(self, element, *data):
+		raise NotImplementedError()
+
+
+class LinePlot(PlotKind):
+	def create(self, figure):
+		if self.is_backend(Backend.MPL):
+			# matplotlib: figure[1] is the axis
+			line, = figure[1].plot([], [], *self.args, **self.kwargs)
+			return line
+		elif self.is_backend(Backend.BOKEH):
+			return figure.line([], [], *self.args, **self.kwargs)
+
+	def update(self, element, *data):
+		if self.is_backend(Backend.MPL):
+			element.set_data(data[0], data[1])
+		elif self.is_backend(Backend.BOKEH):
+			element.data_source.data.update(x=data[0], y=data[1])
+
+
+class StepPlot(PlotKind):
+	def create(self, figure):
+		if self.is_backend(Backend.MPL):
+			# matplotlib: figure[1] is the axis
+			line, = figure[1].step([], [], *self.args, **self.kwargs)
+			return line
+		elif self.is_backend(Backend.BOKEH):
+			return figure.step([], [], *self.args, **self.kwargs)
+
+	def update(self, element, *data):
+		if self.is_backend(Backend.MPL):
+			element.set_data(data[0], data[1])
+		elif self.is_backend(Backend.BOKEH):
+			element.data_source.data.update(x=data[0], y=data[1])
+
 
+class ScatterPlot(PlotKind):
+	def create(self, figure):
+		if self.is_backend(Backend.MPL):
+			# matplotlib: figure[1] is the axis
+			pathc = figure[1].scatter([], [], *self.args, **self.kwargs)
+			return pathc
+		elif self.is_backend(Backend.BOKEH):
+			return figure.circle([], [], *self.args, **self.kwargs)
+
+	def update(self, element, *data):
+		if self.is_backend(Backend.MPL):
+			element.set_offsets(list(zip(*data)))
+		elif self.is_backend(Backend.BOKEH):
+			element.data_source.data.update(x=data[0], y=data[1])
+
+
+class PlotManager:
+	def __init__(self, backend=Backend.MPL):
+		self.__handlers = {}
+		self.backend = backend
+
+		# TODO: check existance of backend
+
+	def is_opened(self):
+		return any([h.is_opened() for h in self.__handlers.values()])
+
+	def register(self, name, object, figure, kind, interval=100):
+		if name in self.__handlers:
+			raise ValueError("PlotManager: PlotHandler '%s' already registered." % name)
+		self.__handlers[name] = PlotHandler(object, figure, kind, interval, self.backend)
+
+	def unregister(self, name):
+		self.get(name)
+		del self.__handlers[name]
+
+	def get(self, name):
+		if name not in self.__handlers:
+			raise ValueError("PlotManager: No PlotHandler exists with name '%s'." % name)
+		return self.__handlers[name]
+
+	def connect(self, handler_name, event_name, function):
+		self.get(handler_name).connect(event_name, function)
+
+
+# TEMPORARY TEST OBJECT:
+import math
+class __Block:
+	def __init__(self, method):
+		self.x = []
+		self.y = []
+		self.int = 0.0
+		self.method = method
+
+	@property
+	def data_xy(self):
+		self.x.append(self.int)
+		self.y.append(getattr(math, self.method)(self.int))
+		self.int += 0.1
+		return self.x, self.y
+
+
+
+def mpl():
+	import matplotlib.pyplot as plt
+	fig = plt.figure(figsize=(5, 5), dpi=100)
+	ax = fig.add_subplot(111)
+	ax.set_ylim((-1, 1))
+
+	manager = PlotManager()
+	manager.register("sin", __Block('sin'), (fig, ax), ScatterPlot())
+	manager.register("cos", __Block('cos'), (fig, ax), LinePlot(c='red'))
+	manager.connect('sin', 'update_event', lambda d, axis=ax: axis.set_xlim(PlotHandler.follow(d[0], 10.0, 0.0)))
+
+	plt.show()
+
+
+def bkh():
+	from bokeh.plotting import figure
+	from bokeh.client import push_session
+
+	fig = figure(plot_width=400, plot_height=400, y_range=(-1, 1))
+	curdoc().add_root(fig)
+
+	manager = PlotManager(Backend.BOKEH)
+	manager.register("sin", __Block('sin'), fig, ScatterPlot())
+	manager.register("cos", __Block('cos'), fig, LinePlot(color='red'))
+
+	def set_xlim(limits):
+		lower, upper = limits
+		fig.x_range.start = lower
+		fig.x_range.end = upper
+	manager.connect('sin', 'update_event', lambda d: set_xlim(PlotHandler.follow(d[0], 10.0, 0.0)))
+
+	session = push_session(curdoc())
+	session.show()
+
+	# TODO: disable flickering? (is this possible?)
+	import time
+	while manager.is_opened():
+		session.push()
+		time.sleep(0.1)
+
+
+if __name__ == '__main__':
+	bkh()