|
|
@@ -0,0 +1,94 @@
|
|
|
+"""
|
|
|
+Transforms equations/textual denotations to CBD models.
|
|
|
+"""
|
|
|
+import os
|
|
|
+from lark import Lark, Transformer, Token
|
|
|
+
|
|
|
+class eq2CBD:
|
|
|
+ def __init__(self):
|
|
|
+ filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "eq.lark")
|
|
|
+ with open(filename) as file:
|
|
|
+ contents = file.read()
|
|
|
+ parser = Lark(contents, parser="earley")
|
|
|
+ tree = parser.parse("y = 3 * (-z + -4)")
|
|
|
+ # print(tree.pretty())
|
|
|
+ transformer = EqTransformer()
|
|
|
+ print(transformer.transform(tree))
|
|
|
+
|
|
|
+class Node:
|
|
|
+ def __init__(self, value, type):
|
|
|
+ self.value = value
|
|
|
+ self.type = type
|
|
|
+ self.conn = []
|
|
|
+
|
|
|
+ def __repr__(self):
|
|
|
+ if len(self.conn) > 0:
|
|
|
+ return "%s(%s) %s" % (self.type, self.value, self.conn)
|
|
|
+ return "%s(%s)" % (self.type, self.value)
|
|
|
+
|
|
|
+class EqTransformer(Transformer):
|
|
|
+ def __init__(self):
|
|
|
+ super().__init__()
|
|
|
+ self.vars = {}
|
|
|
+
|
|
|
+ def eq(self, items):
|
|
|
+ node = Node("=", "OUT")
|
|
|
+ node.conn.append(items[0])
|
|
|
+ node.conn.append(items[2])
|
|
|
+ return node
|
|
|
+
|
|
|
+ def sum(self, items):
|
|
|
+ if len(items) > 1:
|
|
|
+ add = Node("+", "SUM")
|
|
|
+ add.conn.append(items[0])
|
|
|
+ for i in range((len(items) - 1) // 2):
|
|
|
+ idx = (i * 2) + 1
|
|
|
+ if items[idx].type == "ADD":
|
|
|
+ add.conn.append(items[idx+1])
|
|
|
+ else:
|
|
|
+ neg = Node("-", "NEG")
|
|
|
+ neg.conn.append(items[idx+1])
|
|
|
+ add.conn.append(neg)
|
|
|
+ return add
|
|
|
+ return items[0]
|
|
|
+
|
|
|
+ def prod(self, items):
|
|
|
+ if len(items) > 1:
|
|
|
+ mul = Node("*", "MUL")
|
|
|
+ mul.conn.append(items[0])
|
|
|
+ for i in range((len(items) - 1) // 2):
|
|
|
+ idx = (i * 2) + 1
|
|
|
+ if items[idx].type == "MUL":
|
|
|
+ mul.conn.append(items[idx+1])
|
|
|
+ else:
|
|
|
+ inv = Node("/", "INV")
|
|
|
+ inv.conn.append(items[idx+1])
|
|
|
+ mul.conn.append(inv)
|
|
|
+ return mul
|
|
|
+ return items[0]
|
|
|
+
|
|
|
+ def pow(self, items):
|
|
|
+ if len(items) > 1:
|
|
|
+ pow = Node("^", "POW")
|
|
|
+ pow.conn.append(items[0])
|
|
|
+ if len(items) == 3:
|
|
|
+ pow.conn.append(items[2])
|
|
|
+ return pow
|
|
|
+ return items[0]
|
|
|
+
|
|
|
+ def var(self, items):
|
|
|
+ if len(items) == 2:
|
|
|
+ neg = Node("-", "NEG")
|
|
|
+ neg.conn.append(items[1])
|
|
|
+ return neg
|
|
|
+ return items[0]
|
|
|
+
|
|
|
+ def VNAME(self, tok):
|
|
|
+ return self.vars.setdefault(tok.value, Node(tok.value, tok.type))
|
|
|
+
|
|
|
+ def NUMBER(self, tok):
|
|
|
+ return Node(float(tok.value), "NUMBER")
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ eq2CBD()
|