Browse Source

More Latex debugging and cleaner solution

rparedis 3 years ago
parent
commit
3a65cad8aa

File diff suppressed because it is too large
+ 1 - 1
examples/BuiltIn.xml


BIN
examples/SinGen/examples.zip


+ 18 - 0
examples/SinGen/latex_example.py

@@ -0,0 +1,18 @@
+# Import your model, using the SinGen example
+from SinGen import *
+
+# Import the simulator and LaTeX generator
+from CBD.simulator import Simulator
+from CBD.converters.latexify import CBD2Latex
+
+# Create the model
+model = SinGen('model')
+
+# Create the converter (see docs for info on args)
+cbd2latex = CBD2Latex(model, show_steps=True)
+
+# Simplify the set of equations
+cbd2latex.simplify()
+
+# Show the result (will be shown automatically because of show_steps)
+print(cbd2latex.render())

+ 18 - 0
examples/SinGen/rkf_example.py

@@ -0,0 +1,18 @@
+# Import your model, using the SinGen example
+from SinGen import *
+
+# Import the simulator and Runge-Kutta reqs
+from CBD.simulator import Simulator
+from CBD.preprocessing.butcher import ButcherTableau as BT
+from CBD.preprocessing.rungekutta import RKPreprocessor
+
+# Create the model
+model = SinGen('model')
+
+# Create the Runge-Kutta preprocessor, using the RKF45 method
+RKP = RKPreprocessor(BT.RKF45(), atol=2e-5, hmin=0.1, safety=.84)
+
+# Create the new model and simulate it
+new_model = RKP.preprocess(model)
+sim = Simulator(new_model)
+sim.run()

File diff suppressed because it is too large
+ 92 - 17
examples/notebook/.ipynb_checkpoints/HybridTrain-checkpoint.ipynb


+ 22 - 11
src/CBD/converters/latexify/CBD2Latex.py

@@ -17,8 +17,9 @@ class CBD2Latex:
 								will be removed from all path names. This name is
 								a common prefix over the system.
 								Defaults to :code:`True`.
-		escape_nonlatex (bool): When :code:`True`, non-latex characters are escaped
-								from the rendered result. Defaults to :code:`True`.
+		escape_nonlatex (bool): When :code:`True`, non-latex characters (e.g., underscores)
+								are escaped from the rendered result, if rendered as LaTeX.
+								Defaults to :code:`True`.
 		time_variable (str):    The name for the variable that represents the time
 								(i.e., the current iteration). Defaults to :code:`'i'`.
 		render_latex (bool):    When :code:`True`, the :func:`render` method will
@@ -90,7 +91,8 @@ class CBD2Latex:
 			mname = model.getPath(config['path_sep']) + config['path_sep']
 			if name.startswith(mname):
 				name = name[len(mname):]
-		if config["escape_nonlatex"]:
+		name = name.replace("-", "_")
+		if config["render_latex"] and config["escape_nonlatex"]:
 			name = name.replace("_", r"\_")
 		return name
 
@@ -167,12 +169,10 @@ class CBD2Latex:
 		"""
 		res = ""
 		for eq in self.equations:
-			eqs = eq.at(Time.now(self.config["time_variable"], self.config["delta_t"]))
+			eqs = eq.at(Time.now())
 			for e in eqs:
-				time = e.rhs.time
-				if isinstance(eq.rhs, DelayFnc):
-					time = e.rhs.eq_time
-				res += e.eq().format(T=self.config["time_format"]).format(time=str(time)) + '\n'
+				eq_time_fmt, time_fmt = self._get_time_formats(e)
+				res += e.eq().format(T=time_fmt, E=eq_time_fmt) + '\n'
 		return res
 
 	def latex(self):
@@ -181,11 +181,22 @@ class CBD2Latex:
 		"""
 		res = ""
 		for eq in self.equations:
-			eqs = eq.at(Time.now(self.config["time_variable"], self.config["delta_t"]))
+			eqs = eq.at(Time.now())
 			for e in eqs:
-				res += e.latex().format(T=self.config["time_format"]).format(time=str(e.rhs.time)) + '\\\\\n'
+				eq_time_fmt, time_fmt = self._get_time_formats(e)
+				res += e.latex().format(T=time_fmt, E=eq_time_fmt) + '\\\\\n'
 		return res
 
+	def _get_time_formats(self, e):
+		dt = self.config["delta_t"]
+		if dt != "":
+			dt = " * " + dt
+		time_fmt = self.config["time_format"].format(time=str(e.rhs.time)).format(
+			i=self.config["time_variable"], dt=dt)
+		eq_time_fmt = self.config["time_format"].format(time=str(e.rhs.eq_time)).format(
+			i=self.config["time_variable"], dt=dt)
+		return eq_time_fmt, time_fmt
+
 	def simplify_links(self):
 		"""
 		First step to execute is a link simplification. Generally, there are more links
@@ -312,7 +323,7 @@ class CBD2Latex:
 			print("" + text + ":")
 		else:
 			print("STEP %d:" % self._step, text)
-		print(self.render(None))
+		print(self.render())
 		self._step += 1
 
 

+ 33 - 31
src/CBD/converters/latexify/functions.py

@@ -13,23 +13,19 @@ class Time:
 		value (numeric):    The value of the time.
 		relative (bool):    When :code:`True`, the time is relative to some time variable.
 							When :code:`False`, the time is meant to be an absolute time.
-		var (str):          The time variable to use. Defaults to :code:`i`.
 	"""
-	def __init__(self, value, relative=False, var="", delta=""):
+	def __init__(self, value, relative=False):
 		self.value = value
 		self.relative = relative
-		self.var = var
-		self.delta = delta
 
 	def __eq__(self, other):
 		return other.value == self.value and other.relative == self.relative
 
 	def __add__(self, other):
 		if isinstance(other, (int, float)):
-			return Time(self.value + other, self.relative, self.var, self.delta)
+			return Time(self.value + other, self.relative)
 		if self.is_relative() and other.is_relative():
-			assert self.var == other.var
-			return Time(self.value + other.value, True, self.var, other.delta)
+			return Time(self.value + other.value, True)
 		if self.is_absolute() or other.is_absolute():
 			return Time(self.value + other.value)
 		raise TypeError("unsupported operand type(s) for +: '%s' and 'Time'" % str(other.__class__))
@@ -37,19 +33,11 @@ class Time:
 	def __str__(self):
 		if self.is_relative():
 			if self.value == 0:
-				return self.var
+				return "{i}"
 			elif self.value > 0:
-				if self.delta == "":
-					return self.var + " + " + str(self.value)
-				if self.value == 1:
-					return self.var + " + " + self.delta
-				return self.var + " + (" + str(self.value) + " * "+ self.delta + ")"
+				return "{i} + " + str(self.value) + "{dt}"
 			elif self.value < 0:
-				if self.delta == "":
-					return self.var + " - " + str(abs(self.value))
-				if self.value == -1:
-					return self.var + " - " + self.delta
-				return self.var + " - (" + str(abs(self.value)) + " * "+ self.delta + ")"
+				return "{i} - " + str(abs(self.value)) + "{dt}"
 		else:
 			return str(self.value)
 
@@ -60,8 +48,8 @@ class Time:
 		return self.relative
 
 	@staticmethod
-	def now(var="", delta=""):
-		return Time(0, True, var, delta)
+	def now():
+		return Time(0, True)
 
 
 class Eq:
@@ -109,10 +97,10 @@ class Eq:
 		return self.rhs.get_dependencies()
 
 	def eq(self):
-		return "{lhs}{{T}} = {rhs}".format(lhs=self.lhs, rhs=self.rhs.eq())
+		return "{lhs}{{E}} = {rhs}".format(lhs=self.lhs, rhs=self.rhs.eq())
 
 	def latex(self):
-		return "{lhs}{{T}} = {rhs}".format(lhs=self.lhs, rhs=self.rhs.latex())
+		return "{lhs}{{E}} = {rhs}".format(lhs=self.lhs, rhs=self.rhs.latex())
 
 	def apply(self, other):
 		self.rhs = self.rhs.apply(other)
@@ -157,7 +145,7 @@ class Fnc:
 
 	def __eq__(self, other):
 		return self.name == other.name and self.time == other.time and self.eq_time == other.time and \
-		       self.args == other.args
+				self.args == other.args
 
 	def __hash__(self):
 		return hash(self.name)
@@ -204,17 +192,19 @@ class Fnc:
 		"""
 		fncsets = []
 		for arg in self.args:
-			fncsets.append(arg.at(time))
+			fncsets.append([x for x in arg.at(time)])
 		if len(fncsets) == 1:
-			args = fncsets
+			args = [a for a in fncsets]
 		else:
 			args = self._cross_product_fncs(*fncsets)
-		# TODO: create a list of functions that respect all arguments
 		for a in args:
-			k = self.__class__(self.name, a[0].time, a[0].eq_time)
+			k = self.create(a[0].time, a[0].eq_time)
 			k.args = a
 			yield k
 
+	def create(self, time, eq_time):
+		return self.__class__(self.name, time, eq_time)
+
 	def apply(self, eq):
 		for i, a in enumerate(self.args):
 			self.args[i] = a.apply(eq)
@@ -266,6 +256,9 @@ class MultiFnc(Fnc):
 		Fnc.__init__(self, name, time=time, eq_time=eq_time)
 		self.symbol = symbol
 
+	def create(self, time, eq_time):
+		return self.__class__(self.name, self.symbol, time, eq_time)
+
 	def eq(self):
 		return "(" + (" " + self.name + " ").join([a.eq() for a in self.args]) + ")"
 
@@ -299,9 +292,9 @@ class MultiFnc(Fnc):
 				others.append(a)
 			else:
 				if self.name == '+':
-					na = _BLOCK_MAP["ProductBlock"](None)[0]
+					na = BLOCK_MAP["ProductBlock"](None)[0]
 				elif self.name == '*':
-					na = _BLOCK_MAP["PowerBlock"](None)[0]
+					na = BLOCK_MAP["PowerBlock"](None)[0]
 				else:
 					continue
 				na.time = self.time
@@ -323,6 +316,9 @@ class UnaryFnc(Fnc):
 		Fnc.__init__(self, name, time=time, eq_time=eq_time)
 		self.symbol = symbol
 
+	def create(self, time, eq_time):
+		return self.__class__(self.name, self.symbol, time, eq_time)
+
 	def eq(self):
 		return "(" + self.name + self.args[0].eq() + ")"
 
@@ -353,6 +349,9 @@ class BinaryFnc(Fnc):
 		Fnc.__init__(self, name, time=time, eq_time=eq_time)
 		self.symbol = symbol
 
+	def create(self, time, eq_time):
+		return self.__class__(self.name, self.symbol, time, eq_time)
+
 	def eq(self):
 		return "(" + self.name.format(a=self.args[0].eq(), b=self.args[1].eq()) + ")"
 
@@ -379,6 +378,9 @@ class ConstantFnc(Fnc):
 		Fnc.__init__(self, name, time=time, eq_time=eq_time)
 		self.val = val
 
+	def create(self, time, eq_time):
+		return self.__class__(self.name, self.val, time, eq_time)
+
 	def __str__(self):
 		return str(self.val)
 
@@ -429,10 +431,10 @@ class VarFnc(Fnc):
 class DelayFnc(Fnc):
 	def at(self, time):
 		res = []
-		res += self.args[0].at(time + Time(-1, True, time.var, time.delta))
+		res += self.args[0].at(time + Time(-1, True))
 		for i in range(len(res)):
 			res[i].eq_time = time
-		res += self.args[1].at(Time(0, False, time.var))
+		res += self.args[1].at(Time(0))
 		return res
 
 # TODO: der, integrator?