solver.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. import math
  2. from .Core import CBD
  3. from .util import PYTHON_VERSION
  4. if PYTHON_VERSION == 3:
  5. # Python 2 complient
  6. from functools import reduce
  7. # Superclass for possible additional solvers
  8. class Solver:
  9. """
  10. Superclass that can solve algebraic loops.
  11. Args:
  12. logger (Logger): The logger to use.
  13. """
  14. def __init__(self, logger):
  15. self._logger = logger
  16. def checkValidity(self, path, component):
  17. """
  18. Checks the validity of an algebraic loop.
  19. Args:
  20. path (str): The path of the top-level block.
  21. component (list): The blocks in the algebraic loop.
  22. """
  23. raise NotImplementedError()
  24. def constructInput(self, component, curIt):
  25. """
  26. Constructs input for the solver.
  27. Args:
  28. component (list): The blocks in the algebraic loop.
  29. curIt (int): The current iteration of the simulation.
  30. See Also:
  31. :func:`solve`
  32. """
  33. raise NotImplementedError()
  34. def solve(self, solverInput):
  35. """
  36. Solves the algebraic loop.
  37. Args:
  38. solverInput: The constructed input.
  39. See Also:
  40. :func:`constructInput`
  41. """
  42. raise NotImplementedError()
  43. class LinearSolver(Solver):
  44. """
  45. Solves linear algebraic loops using matrices.
  46. """
  47. def checkValidity(self, path, component):
  48. if not self.__isLinear(component):
  49. self._logger.fatal("Cannot solve non-linear algebraic loop.\nSelf: {}\nComponents: {}".format(path, component))
  50. def __isLinear(self, strongComponent):
  51. """Determines if an algebraic loop describes a linear equation or not.
  52. For a block to comprise the strong component, at least one of its dependencies must be in the strong
  53. component as well.
  54. Args:
  55. strongComponent (list): The detected loop, in a list (of BaseBlock instances)
  56. Returns:
  57. :class:`True` if the loop is linear, else :code:`False`.
  58. """
  59. # TO IMPLEMENT
  60. """
  61. A non-linear equation is generated when the following conditions occur:
  62. (1) there is a multiplication operation being performed between two unknowns.
  63. (2) there is an invertion operation being performed in an unknown.
  64. (3) some non-linear block belongs to the strong component
  65. The condition (1) can be operationalized by finding a product block that has two dependencies belonging to
  66. the strong component. This will immediatly tell us that it is a product between two unknowns.
  67. The condition (2) can be operationalized simply by finding an inverter block in the strong component.
  68. Because the inverter block only has one input, if it is in the strong component, it means that its only
  69. dependency is in the strong component.
  70. """
  71. # WON'T APPEAR: Constant, Sequence, Time, Logging
  72. # LINEAR: Negator, Adder, Delay, Input, Output, Wire
  73. # NON-LINEAR: Inverter, Modulo, Root, LT, EQ, LTE, Not, Or, And, MUX, Generic, ABS, Int, Power, Min, Max, Clamp
  74. # SEMI-LINEAR: Product // MUX?
  75. for block in strongComponent:
  76. # condition (1)
  77. if block.getBlockType() == "ProductBlock":
  78. dependenciesUnknown = [x for x in block.getDependencies(0) if x in strongComponent]
  79. if len(dependenciesUnknown) >= 2:
  80. return False
  81. # condition (2) and (3)
  82. if block.getBlockType() in ["InverterBlock", "ModuloBlock", "RootBlock", "LessThanBlock", "EqualsBlock",
  83. "LessThanOrEqualsBlock", "NotBlock", "OrBlock", "AndBlock", "MinBlock",
  84. "MaxBlock", "MultiplexerBlock", "GenericBlock", "AbsBlock", "IntBlock",
  85. "PowerBlock", "ClampBlock"]:
  86. return False
  87. return True
  88. def constructInput(self, strongComponent, curIteration):
  89. """
  90. Constructs input for a solver of systems of linear equations
  91. Input consists of two matrices:
  92. - M1: coefficient matrix, where each row represents an equation of the system
  93. - M2: result matrix, where each element is the result for the corresponding equation in M1
  94. """
  95. # Initialize matrices with zeros
  96. size = len(strongComponent)
  97. M1 = Matrix(size, size)
  98. M2 = [0] * size
  99. # block -> index of block
  100. indexdict = dict()
  101. for i, block in enumerate(strongComponent):
  102. indexdict[block] = i
  103. # Get low-level dependency
  104. resolveBlock = lambda possibleDep, output_port: possibleDep if not isinstance(possibleDep, CBD) else possibleDep.getBlockByName(output_port)
  105. # Get list of low-level dependencies from n inputs
  106. def getBlockDependencies2(block):
  107. return (resolveBlock(b, output_port) for (b, output_port) in [block.getBlockConnectedToInput(x) for x in block.getInputPortNames()])
  108. for i, block in enumerate(strongComponent):
  109. if block.getBlockType() == "AdderBlock":
  110. for external in [x for x in getBlockDependencies2(block) if x not in strongComponent]:
  111. M2[i] -= external.getSignal()[curIteration].value
  112. M1[i, i] = -1
  113. for compInStrong in [x for x in getBlockDependencies2(block) if x in strongComponent]:
  114. M1[i, indexdict[compInStrong]] += 1
  115. elif block.getBlockType() == "ProductBlock":
  116. # M2 can stay 0
  117. M1[i, i] = -1
  118. fact = reduce(lambda x, y: x * y, [x.getSignal()[curIteration].value for x in getBlockDependencies2(block) if x not in strongComponent])
  119. for compInStrong in [x for x in getBlockDependencies2(block) if x in strongComponent]:
  120. M1[i, indexdict[compInStrong]] += fact
  121. elif block.getBlockType() == "NegatorBlock":
  122. # M2 can stay 0
  123. M1[i, i] = -1
  124. possibleDep, output_port = block.getBlockConnectedToInput("IN1")
  125. M1[i, indexdict[resolveBlock(possibleDep, output_port)]] = - 1
  126. elif block.getBlockType() == "InputPortBlock":
  127. # M2 can stay 0
  128. M1[i, i] = 1
  129. possibleDep, output_port = block.parent.getBlockConnectedToInput(block.getBlockName())
  130. M1[i, indexdict[resolveBlock(possibleDep, output_port)]] = - 1
  131. elif block.getBlockType() == "OutputPortBlock" or block.getBlockType() == "WireBlock":
  132. # M2 can stay 0
  133. M1[i, i] = 1
  134. dblock = block.getDependencies(0)[0]
  135. if isinstance(dblock, CBD):
  136. oport = block.getLinksIn()['IN1'].output_port
  137. dblock = dblock.getBlockByName(oport).getLinksIn()['IN1'].block
  138. M1[i, indexdict[dblock]] = - 1
  139. elif block.getBlockType() == "DelayBlock":
  140. # If a delay is in a strong component, this is the first iteration
  141. # FIXME: turn this into a normal error?
  142. assert curIteration == 0
  143. # And so the dependency is the IC
  144. # M2 can stay 0 because we have an equation of the type -x = -ic <=> -x + ic = 0
  145. M1[i, i] = -1
  146. possibleDep, output_port = block.getBlockConnectedToInput("IC")
  147. dependency = resolveBlock(possibleDep, output_port)
  148. assert dependency in strongComponent
  149. M1[i, indexdict[dependency]] = 1
  150. else:
  151. self._logger.fatal("Unknown element '{}', please implement".format(block.getBlockType()))
  152. return M1, M2
  153. def solve(self, solverInput):
  154. M1, M2 = solverInput
  155. n = M1.rows
  156. indxc = [0] * n
  157. indxr = [0] * n
  158. ipiv = [0] * n
  159. icol = 0
  160. irow = 0
  161. for i in range(n):
  162. big = 0.0
  163. for j in range(n):
  164. if ipiv[j] != 1:
  165. for k in range(n):
  166. if ipiv[k] == 0:
  167. nb = math.fabs(M1[j, k])
  168. if nb >= big:
  169. big = nb
  170. irow = j
  171. icol = k
  172. elif ipiv[k] > 1:
  173. raise ValueError("GAUSSJ: Singular Matrix-1")
  174. ipiv[icol] += 1
  175. if irow != icol:
  176. for l in range(n):
  177. M1[irow, l], M1[icol, l] = M1[icol, l], M1[irow, l]
  178. M2[irow], M2[icol] = M2[icol], M2[irow]
  179. indxr[i] = irow
  180. indxc[i] = icol
  181. if M1[icol, icol] == 0.0:
  182. raise ValueError("GAUSSJ: Singular Matrix-2")
  183. pivinv = 1.0 / M1[icol, icol]
  184. M1[icol, icol] = 1.0
  185. for l in range(n):
  186. M1[icol, l] *= pivinv
  187. M2[icol] *= pivinv
  188. for ll in range(n):
  189. if ll != icol:
  190. dum = M1[ll, icol]
  191. M1[ll, icol] = 0.0
  192. for l in range(n):
  193. M1[ll, l] -= M1[icol, l] * dum
  194. M2[ll] -= M2[icol] * dum
  195. for l in range(n - 1, 0, -1):
  196. if indxr[l] != indxc[l]:
  197. for k in range(n):
  198. M1[k, indxr[l]], M1[k, indxc[l]] = M1[k, indxc[l]], M1[k, indxr[l]]
  199. return solverInput[1]
  200. class Matrix:
  201. """Custom, efficient matrix class. This class is used for efficiency purposes.
  202. - Using a while/for loop is slow.
  203. - Using :class:`[[0] * n] * n` will have n references to the same list.
  204. - Using :class:`[[0] * size for _ in range(size)]` can be 5 times slower
  205. than this class!
  206. Numpy could be used to even further increase efficiency, but this increases the
  207. required dependencies for external hardware systems (that may not provide these options).
  208. Note:
  209. Internally, the matrix is segmented into chunks of 500.000.000 items.
  210. """
  211. def __init__(self, rows, cols):
  212. self.rows = rows
  213. self.cols = cols
  214. self.size = rows * cols
  215. self.__max_list_size = 500 * 1000 * 1000
  216. self.data = [[0] * ((rows * cols) % self.__max_list_size)]
  217. for r in range(self.size // self.__max_list_size):
  218. self.data.append([0] * self.__max_list_size)
  219. def __getitem__(self, idx):
  220. absolute = idx[0] * self.cols + idx[1]
  221. outer = absolute // self.__max_list_size
  222. inner = absolute % self.__max_list_size
  223. return self.data[outer][inner]
  224. def __setitem__(self, idx, value):
  225. absolute = idx[0] * self.cols + idx[1]
  226. outer = absolute // self.__max_list_size
  227. inner = absolute % self.__max_list_size
  228. self.data[outer][inner] = value
  229. def __str__(self):
  230. res = ""
  231. for row in range(self.rows):
  232. if len(res) > 0:
  233. res += "\n"
  234. res += "["
  235. for col in range(self.cols):
  236. res += "\t%8.4f" % self[row, col]
  237. res += "\t]"
  238. return res
  239. try:
  240. import sympy
  241. # TODO: non-unique solutions?
  242. class SympySolver(Solver):
  243. def checkValidity(self, path, component):
  244. raise NotImplementedError("The Sympy Solver is not finished yet, so please refrain from using it.")
  245. def constructInput(self, component, curIt):
  246. eqs = {}
  247. for i, block in enumerate(component):
  248. args = []
  249. for x in self.__dependencies(block):
  250. if x not in component:
  251. args.append(x.getSignal()[curIt].value)
  252. else:
  253. args.append(sympy.symbols('x%d' % component.index(x)))
  254. eqs['x%d' % i] = self.__OPERATIONS[block.getBlockType()](args, block)
  255. return eqs
  256. def solve(self, solverInput):
  257. sol = []
  258. eqs = []
  259. answer = [0] * len(solverInput)
  260. for k, v in solverInput.items():
  261. x = sympy.symbols(k)
  262. sol.append(x)
  263. eqs.append(v - x)
  264. solution = sympy.nonlinsolve(eqs, sol)
  265. print(solverInput, solution)
  266. # TODO: Clamp, MUX, Split, LTE, Eq, LT, not, and, or, delay
  267. __OPERATIONS = {
  268. "AdderBlock": lambda l, _: sum(l),
  269. "ProductBlock": lambda l, _: reduce((lambda a, b: a * b), l),
  270. "NegatorBlock": lambda l, _: -l[0],
  271. "InverterBlock": lambda l, _: 1.0/l[0],
  272. "ModuloBlock": lambda l, _: l[0] % l[1],
  273. "RootBlock": lambda l, _: sympy.root(l[0], l[1]),
  274. "PowerBlock": lambda l, _: l[0] ** l[1],
  275. "AbsBlock": lambda l, _: abs(l[0]),
  276. "IntBlock": lambda l, _: sympy.floor(l[0]),
  277. "GenericBlock": lambda l, b: getattr(sympy, b.getBlockOperator())(l[0]),
  278. "MaxBlock": lambda l: sympy.Max(*l),
  279. "MinBlock": lambda l: sympy.Min(*l),
  280. }
  281. @staticmethod
  282. def __dependencies(block):
  283. blocks = []
  284. for s in block.getInputPortNames():
  285. b, op = block.getBlockConnectedToInput(s)
  286. if isinstance(b, CBD):
  287. b = b.getBlockByName(op)
  288. blocks.append(b)
  289. return blocks
  290. except: pass