Просмотр исходного кода

Better LaTeX generator (stepwise!)

rparedis 4 лет назад
Родитель
Сommit
244540bf29
2 измененных файлов с 507 добавлено и 1 удалено
  1. 506 0
      src/CBD/converters/latexify.py
  2. 1 1
      src/CBD/preprocessing/rk-notes/rkf.py

+ 506 - 0
src/CBD/converters/latexify.py

@@ -0,0 +1,506 @@
+"""
+A module that allows the construction of LaTeX-equations from a CBD model,
+as a whole (one equation per block and per connection) or as a simplified
+version (by means of substitution).
+"""
+
+from copy import deepcopy
+
+# TODO: better paths; i.e. remove common prefix and escape non-latex characters
+# TODO: better implementation of 'delay', based on the time
+class CBD2Latex:
+	"""
+	Creates a corresponding set of LaTeX-equations for a CBD model.
+
+	Args:
+		model (CBD.CBD.CBD):    The model to create the equations for.
+	"""
+	def __init__(self, model):
+		self.model = model
+		self.equations = {}
+		self.outputs = [self.model.getPath() + "." + x for x in self.model.getSignals().keys()]
+		self._collect_equations()
+
+	def _collect_equations(self):
+		"""
+		Loads the equations from the model in.
+
+		See Also:
+			`Cláudio Gomes, Joachim Denil and Hans Vangheluwe. 2016. "Causal-Block Diagrams",
+			Technical Report <https://repository.uantwerpen.be/docman/irua/d28eb1/151279.pdf>`_
+		"""
+		# Add all blocks
+		for block in self.model.getBlocks():
+			func = _BLOCK_MAP.get(block.getBlockType(), None)
+			if func is None: continue
+			if isinstance(func, str):
+				func = lambda b, p, f=func: (p + ".OUT1", Fnc(f, [p + ".%s" % x for x in block.getInputPortNames()]))
+			res = func(block, block.getPath())
+			if isinstance(res, tuple):
+				self.equations[res[0]] = res[1]
+			elif isinstance(res, list):
+				for r in res:
+					self.equations[r[0]] = r[1]
+			if block.getBlockType() == "DelayBlock":
+				self.outputs.append(block.getPath() + ".OUT1")
+
+		# Add all connections
+		for block in self.model.getBlocks():
+			tp = block.getBlockType()
+			path = block.getPath()
+			for k, v in block.getLinksIn().items():
+				if tp == "OutputPortBlock":
+					self.equations[path] = [v.block.getPath() + "." + v.output_port]
+				else:
+					self.equations[path + "." + k] = v.block.getPath() + "." + v.output_port
+
+	def render(self):
+		"""
+		Creates the LaTeX string for the model, based on the current level of simplifications.
+		"""
+		latex = ""
+		for variable, value in self.equations.items():
+			var = variable
+			val = value
+			if isinstance(value, Fnc):
+				val = value.latex()
+			latex += r"{v} &=& {val}\\".format(v=var, val=val)
+
+		return latex
+
+	def simplify_links(self):
+		"""
+		First step to execute is a link simplification. Generally, there are more links
+		than blocks, so this function will take care of the largest simplification.
+		"""
+		links = set()
+		numeric = set()
+		for k, v in self.equations.items():
+			if isinstance(v, str):
+				links.add(k)
+			elif isinstance(v, (int, float)):
+				numeric.add(k)
+		for k, v in self.equations.items():
+			if isinstance(v, Fnc):
+				for link in links:
+					v.apply(link, self.equations[link])
+				for num in numeric:
+					v.apply(num, self.equations[num])
+		for link in links | numeric:
+			del self.equations[link]
+
+	def substitute(self):
+		"""
+		Combines multiple equations into one, based on the requested output, by
+		means of substitution. This function will be called multiple times: once
+		for each "step" in the simplification.
+
+		See Also:
+			:func:`simplify`
+		"""
+		outputs = self.outputs
+		to_delete = set()
+		for output in outputs:
+			if output not in self.equations: continue
+			v = self.equations[output]
+			if isinstance(v, list):
+				v = v[0]
+				self.equations[output] = self.equations[v]
+				to_delete.add(v)
+			elif isinstance(v, Fnc):
+				for k, e in self.equations.items():
+					if k not in outputs:
+						v.apply(k, e)
+						to_delete.add(k)
+			for f in self.equations.values():
+				if isinstance(f, Fnc):
+					if v not in outputs:
+						f.apply(v, output)
+			deps = self.get_dependencies_for(output)
+			for dep in deps:
+				if dep in to_delete:
+					to_delete.remove(dep)
+		for k, f in self.equations.items():
+			if isinstance(f, Fnc):
+				self.equations[k] = f.simplify()
+		for td in to_delete:
+			del self.equations[td]
+
+	def get_dependencies_for(self, variable, visited=tuple()):
+		"""
+		Tries to obtain all dependencies of a specific variable, to prevent
+		accidental removal. This will be done via a depth-first search.
+
+		Args:
+			variable (str): The variable to get the dependencies for.
+			visited (iter): A collection of all variables that have been
+							checked.
+		"""
+		value = self.equations.get(variable, None)
+		if isinstance(value, Fnc):
+			deps = value.dependencies()
+		else:
+			deps = []
+		i = 0
+		vis = list(visited)
+		while i < len(deps):
+			if deps[i] not in visited:
+				vis.append(deps[i])
+				n_deps = self.get_dependencies_for(deps[i], vis)
+				for dep in n_deps:
+					if dep not in deps:
+						deps.append(dep)
+			i += 1
+		return deps
+
+	def simplify(self, show_steps=False, steps=-1):
+		"""
+		Simplifies the system of equations to become a more optimal solution.
+
+		Args:
+			show_steps (bool):  When :code:`True`, all intermediary results will
+								be rendered with the :func:`render` method.
+								Defaults to :code:`False`.
+			steps (int):        When positive, this indicates the amount of steps
+								that must be taken. When negative, the equations
+								will be simplified until a convergence (i.e. no
+								possible changes) is reached. Defaults to -1.
+
+		See Also:
+			- :func:`simplify_links`
+			- :func:`substitute`
+		"""
+		if show_steps:
+			self.render()
+		self.simplify_links()
+		peq = None
+		i = 0
+		while peq != self.equations:
+			if 0 <= steps <= i: break
+			if show_steps:
+				self.render()
+			peq = self.equations.copy()
+			self.substitute()
+			i += 1
+
+
+class Fnc:
+	"""
+	An identifier of a function within the context of the equation system.
+	This class is a helper class to be used by the :class:`CBD2Latex` class.
+
+	Args:
+		name (str):     The name of the function.
+		args (list):    The ordered list of arguments to be applied by the
+						function.
+	"""
+	def __init__(self, name, args):
+		self.name = name
+		self.args = list(args)
+
+	def __repr__(self):
+		return "%s%s" % (self.name, self.args)
+
+	def __hash__(self):
+		return hash((self.name, tuple(self.args)))
+
+	def __eq__(self, other):
+		return isinstance(other, Fnc) and self.name == other.name and self.args == other.args
+
+	def apply(self, name, value):
+		"""
+		Recursively replaces all references of a variable by its value.
+
+		Args:
+			name (str): The variable to replace.
+			value:      The value to replace the variable by.
+		"""
+		for i, elem in enumerate(self.args):
+			if isinstance(elem, str):
+				if elem == name:
+					self.args[i] = value
+			elif isinstance(elem, Fnc):
+				elem.apply(name, value)
+
+	def simplify(self):
+		"""
+		Simplifies the function w.r.t. its meaning.
+		"""
+		nargs = self.args
+		name = self.name
+		if name == '+':
+			val = 0
+			occ = {}
+			nargs = []
+			for a in self.args:
+				if isinstance(a, (int, float)):
+					val += a
+				elif a in occ:
+					occ[a] += 1
+				else:
+					occ[a] = 1
+			if val != 0:
+				nargs.append(val)
+			for a, c in occ.items():
+				if c == 1:
+					nargs.append(a)
+				else:
+					nargs.append(Fnc("*", [a, c]))
+			if len(nargs) == 1:
+				return nargs[0]
+		elif name == '*':
+			val = 1
+			occ = {}
+			nargs = []
+			for a in self.args:
+				if isinstance(a, (int, float)):
+					val *= a
+				elif a in occ:
+					occ[a] += 1
+				else:
+					occ[a] = 1
+			if val != 1:
+				nargs.append(val)
+			for a, c in occ.items():
+				if c == 1:
+					nargs.append(a)
+				else:
+					nargs.append(Fnc("^", [a, c]))
+			if len(nargs) == 1:
+				return nargs[0]
+		elif name == '^':
+			if self.args[1] == 1:
+				return self.args[0]
+			if self.is_numeric():
+				return self.args[0] ** self.args[1]
+		elif name == 'root':
+			if self.args[1] == 1:
+				return self.args[0]
+			if self.is_numeric():
+				return self.args[0] ** (1.0 / self.args[1])
+		elif name == '-':
+			if self.is_numeric():
+				return -self.args[0]
+		elif name == '%':
+			if self.args[1] == 1:
+				return self.args[0]
+			if self.is_numeric():
+				return self.args[0] % self.args[1]
+		elif name == '~':
+			if self.is_numeric():
+				return 1.0/self.args[0]
+		elif name == 'abs':
+			if self.is_numeric():
+				return abs(self.args[0])
+		elif name == 'int':
+			if self.is_numeric():
+				return int(self.args[0])
+		elif name == 'clamp':
+			if self.is_numeric():
+				return min(max(self.args[0], self.args[1]), self.args[2])
+		elif name == 'max':
+			if self.is_numeric():
+				return max(self.args[0], self.args[1])
+		elif name == 'min':
+			if self.is_numeric():
+				return min(self.args[0], self.args[1])
+		elif name == '<':
+			if self.is_numeric():
+				return int(self.args[0] < self.args[1])
+		elif name == '<=':
+			if self.is_numeric():
+				return int(self.args[0] <= self.args[1])
+		elif name == '==':
+			return int(self.args[0] == self.args[1])
+		elif name == '!':
+			if self.is_numeric():
+				return 0 if self.args[0] else 1
+		elif name == 'or':
+			val = False
+			occ = {}
+			nargs = []
+			for a in self.args:
+				if isinstance(a, (int, float)):
+					val = val or a
+				elif a in occ:
+					occ[a] += 1
+				else:
+					occ[a] = 1
+			nargs.append(val)
+			for a in occ.keys():
+				nargs.append(a)
+			if len(nargs) == 1:
+				return nargs[0]
+		elif name == 'and':
+			val = True
+			occ = {}
+			nargs = []
+			for a in self.args:
+				if isinstance(a, (int, float)):
+					val = val and a
+				elif a in occ:
+					occ[a] += 1
+				else:
+					occ[a] = 1
+			nargs.append(val)
+			for a in occ.keys():
+				nargs.append(a)
+			if len(nargs) == 1:
+				return nargs[0]
+		elif name == 'D':
+			if self.args[0] == self.args[1]:
+				return self.args[0]
+
+		return Fnc(name, nargs)
+
+	def is_numeric(self):
+		"""
+		Checks if the function only contains numeric arguments.
+		"""
+		return all([isinstance(x, (int, float)) for x in self.args])
+
+	def dependencies(self):
+		"""
+		Obtains the dependencies for executing the function.
+		"""
+		x = []
+		for a in self.args:
+			if isinstance(a, str):
+				x.append(a)
+			elif isinstance(a, Fnc):
+				x += a.dependencies()
+		return list(set(x))
+
+	def brackets(self):
+		"""
+		Tests if it is required to enclose the function in brackets.
+		"""
+		return self.name in ["+", "-", "*", "~", "^", "root", "%", "or", "and", "==", "<=", "<"]
+
+	def latex(self):
+		"""
+		Returns a LaTeX-formatted string of this function.
+		"""
+		largs = deepcopy(self.args)
+		for i, a in enumerate(self.args):
+			if isinstance(a, Fnc):
+				if a.brackets():
+					largs[i] = "(%s)" % a.latex()
+				else:
+					largs[i] = a.latex()
+			elif isinstance(a, str):
+				largs[i] = "%s" % a
+			else:
+				largs[i] = str(a)
+
+		if self.name in ['+', '*', 'or', 'and']:
+			op = {
+				'*': r"\cdot ",
+				'or': r"\wedge ",
+				'and': r"\vee ",
+			}.get(self.name, self.name)
+			return (" %s " % op).join(largs)
+		elif self.name in '-!~':
+			op = {
+				'!': r"\neg ",
+				'~': "1/",
+			}.get(self.name, self.name)
+			return "{}{}".format(op, largs[0])
+		elif self.name == '^':
+			return "%s^{%s}" % (largs[0], largs[1])
+		elif self.name == 'root':
+			return "%s^{1/%s}" % (largs[1], largs[0])
+		elif self.name in ['%', '<', '<=', '==']:
+			op = {
+				"%": r"\mod ",
+				"<=": r"\leq ",
+				"==": r"\leftrightarrow ",
+			}.get(self.name, self.name)
+			return "%s %s %s" % (largs[0], op, largs[1])
+		elif self.name == 'D':
+			# return r"\left\{{\begin{{array}}{{lcr}}{ic}&\textrm{{if }}t + 1 = 0\\" \
+			# r"{n}&\textrm{{otherwise}}\end{{array}}\right.".format(ic=largs[1], n=largs[0])
+			return "delay(%s, %s)" % (largs[0], largs[1])
+		return "{}({})".format(self.name, ", ".join(largs))
+
+def _clamp_block(block, p):
+	if block._use_const:
+		return p + ".OUT1", Fnc('clamp', [p + ".IN1", block.min, block.max])
+	return p + ".OUT1", Fnc('clamp', [p + ".IN1", p + ".IN2", p + ".IN3"])
+
+# Maps all standard block types onto a function that is representative of the
+# corresponding equation. This function will be called with the block and its full
+# path; and it should return a tuple of :code:`(LeftHandSide, RightHandSide)`.
+# When the value in this dict is a string, the standard function will be used with
+# the string as a function name. This standard function is:
+#       lambda b, p, f=func: (p + ".OUT1", Fnc(f, [p + ".%s" % x for x in block.getInputPortNames()]))
+# Note: the LHS is required to be a single value!
+_BLOCK_MAP = {
+	"ConstantBlock": lambda block, p: (p + ".OUT1", block.getValue()),
+	"NegatorBlock": lambda block, p: (p + ".OUT1", Fnc('-', [p + ".IN1"])),
+	"InverterBlock": lambda block, p: (p + ".OUT1", Fnc('~', [p + ".IN1"])),
+	"AdderBlock": '+',
+	"ProductBlock": '*',
+	"ModuloBlock": lambda block, p: (p + ".OUT1", Fnc('%', [p + ".IN1", p + ".IN2"])),
+	"RootBlock": lambda block, p: (p + ".OUT1", Fnc('root', [p + ".IN1", p + ".IN2"])),
+	"PowerBlock": lambda block, p: (p + ".OUT1", Fnc('^', [p + ".IN1", p + ".IN2"])),
+	"AbsBlock": lambda block, p: (p + ".OUT1", Fnc('abs', [p + ".IN1"])),
+	"IntBlock": lambda block, p: (p + ".OUT1", Fnc('int', [p + ".IN1"])),
+	"ClampBlock": _clamp_block,
+	"GenericBlock": lambda block, p: (p + ".OUT1", Fnc(block.getBlockOperator(), [p + ".IN1"])),
+	"MultiplexerBlock": 'MUX',
+	"MaxBlock": 'max',
+	"MinBlock": 'min',
+	"LessThanBlock": lambda block, p: (p + ".OUT1", Fnc('<', [p + ".IN1", p + ".IN2"])),
+	"LessThanOrEqualsBlock": lambda block, p: (p + ".OUT1", Fnc('<=', [p + ".IN1", p + ".IN2"])),
+	"EqualsBlock": lambda block, p: (p + ".OUT1", Fnc('==', [p + ".IN1", p + ".IN2"])),
+	"NotBlock": lambda block, p: (p + ".OUT1", Fnc('!', [p + ".IN1"])),
+	"OrBlock": 'or',
+	"AndBlock": 'and',
+	"DelayBlock": 'D',
+	# "DelayBlock": lambda block, p: [(p + ".OUT1", Fnc('D', [p + ".IN1"])), (p + ".OUT1(0)", Fnc('S', [p + ".IC"]))],
+	"TimeBlock": lambda block, p: (p + ".OUT1", 't')
+}
+
+
+if __name__ == '__main__':
+	from CBD.CBD import CBD
+	from CBD.lib.std import *
+	class Test(CBD):
+		def __init__(self):
+			super().__init__("Test", [], ["x", "y"])
+			self.addBlock(AdderBlock("A"))
+			self.addBlock(ProductBlock("B"))
+			self.addBlock(ConstantBlock("C", 3.0))
+			self.addConnection("C", "A")
+			self.addConnection("A", "B")
+			self.addConnection("A", "B")
+			self.addConnection("B", "A")
+			self.addConnection("B", "y")
+			self.addConnection("A", "x")
+
+	class FibonacciGen(CBD):
+		def __init__(self, block_name):
+			super().__init__(block_name, input_ports=[], output_ports=['OUT1'])
+
+			# Create the Blocks
+			self.addBlock(DelayBlock("delay1"))
+			self.addBlock(DelayBlock("delay2"))
+			self.addBlock(AdderBlock("sum"))
+			self.addBlock(ConstantBlock("zero", value=(0)))
+			self.addBlock(ConstantBlock("one", value=(1)))
+
+			# Create the Connections
+			self.addConnection("delay1", "delay2", output_port_name='OUT1', input_port_name='IN1')
+			self.addConnection("delay1", "sum", output_port_name='OUT1', input_port_name='IN2')
+			self.addConnection("delay2", "sum", output_port_name='OUT1', input_port_name='IN1')
+			self.addConnection("sum", "delay1", output_port_name='OUT1', input_port_name='IN1')
+			self.addConnection("sum", "OUT1", output_port_name='OUT1')
+			self.addConnection("zero", "delay1", output_port_name='OUT1', input_port_name='IC')
+			self.addConnection("one", "delay2", output_port_name='OUT1', input_port_name='IC')
+
+
+	ltx = CBD2Latex(FibonacciGen("fib"))
+	# ltx.render()
+	ltx.simplify()
+	print(ltx.render())

+ 1 - 1
src/CBD/preprocessing/rk-notes/rkf.py

@@ -128,7 +128,7 @@ if __name__ == '__main__':
     # x0=[2,2,2]
 
     # t,u  = rkf(f=lorenz, a=0, b=1e+1, x0=x0, atol=1e-8, rtol=1e-6 , hmax=1e-1, hmin=1e-40,plot_stepsize=True).solve()
-    t, u  = rkf(f=test, a=0, b=1.4, x0=x0, atol=2e-5, rtol=0, hmax=0.1, safety=.84, plot_stepsize=False).solve()
+    t, u  = rkf(f=test, a=0, b=1.4, x0=x0, atol=2e-5, rtol=0, hmax=0.2, safety=.84, plot_stepsize=False).solve()
     # t, u  = rkf(f=test, a=0, b=2.0, x0=x0, atol=1e-5, rtol=0, hmax=0.25, safety=.84, plot_stepsize=False).solve()
 
     # x, y = u.T