瀏覽代碼

Fixed tree rewriting and don't require hacks in the writing visitors anymore

Yentl Van Tendeloo 9 年之前
父節點
當前提交
258ca88eca

+ 2 - 5
interface/HUTN/hutn_compiler/constructors_visitor.py

@@ -124,11 +124,8 @@ class ConstructorsVisitor(Visitor):
         self.visit_literal(tree)
 
     def visit_rvalue(self, tree):
-        if tree.get_tail()[0].head == "expression":
-            self.visit_expression(tree.get_tail()[0])
-        else:
-            self.add_constructors('"access"')
-            self.visit_lvalue(tree)
+        self.add_constructors('"access"')
+        self.visit_lvalue(tree)
 
     def visit_lvalue(self, tree):
         symbol = self.get_symbol(tree)

+ 7 - 12
interface/HUTN/hutn_compiler/primitives_visitor.py

@@ -181,18 +181,13 @@ class PrimitivesVisitor(Visitor):
         self.visit_literal(tree)
 
     def visit_rvalue(self, tree):
-        if tree.get_tail()[0].head == "expression":
-            self.visit_expression(tree.get_tail()[0])
-            r = self.get_primitive(tree)
-            self.set_primitive(tree, r)
-        else:
-            self.visit_lvalue(tree)
-            r = self.get_primitive(tree)
-            if r is None:
-                return
-            a = self.value("access")
-            self.dict(a, "var", r)
-            self.set_primitive(tree, a)
+        self.visit_lvalue(tree)
+        r = self.get_primitive(tree)
+        if r is None:
+            return
+        a = self.value("access")
+        self.dict(a, "var", r)
+        self.set_primitive(tree, a)
 
     def visit_lvalue(self, tree):
         symbol = self.get_symbol(tree)

+ 3 - 32
interface/HUTN/hutn_compiler/semantics_visitor.py

@@ -500,39 +500,10 @@ class SemanticsVisitor(Visitor):
             operation = "dict_read"
             call_tree = SemanticsVisitor.func_call(operation, [node, expression])
             self.visit(call_tree)
-            tree.replace_child(child, call_tree)
+            tree.head = call_tree.head
+            tree._tail = call_tree.tail
+            tree.tail = call_tree.tail
             self.set_type(tree, self.get_type(node))
-            """
-            child = tree.get_tail()[0]
-            if len(child.get_tail()) > 1:
-                l, op, r = child.get_tail()
-                l_type, r_type = self.get_type(l), self.get_type(r)
-                if type(l_type) != type(r_type):
-                    print("Error: " + str(l_type) + " <-> " + str(r_type))
-                    raise RuntimeError(
-                        "{}:{}:{}: error: children were not casted".format(
-                            self.inputfiles[0],
-                            tree.startpos['line'],
-                            tree.startpos['column']
-                        ))
-                call_name = SemanticsVisitor.call_name_binary(l_type, op)
-                call_tree = SemanticsVisitor.func_call(call_name, [l, r])
-                try:
-                    self.visit(call_tree)
-                except RuntimeError:
-                    call_signature = "{0} function {1}({2}, {2})".format(
-                        str(types_mv.Boolean()), call_name, l_type)
-                    raise RuntimeError(
-                        "{}:{}:{}: error: cannot perform {}: function '{}' is "
-                        "not found".format(
-                            self.inputfiles[0],
-                            tree.startpos['line'],
-                            tree.startpos['column'],
-                            child.head,
-                            call_signature))
-                tree.replace_child(child, call_tree)
-            self.set_type(tree, self.get_type(tree.get_tail()[i]))
-            """
         else:
             # Simple
             self.visit_id(tree)

+ 8 - 4
interface/HUTN/test/graph_compilation_action_language/test_simple.py

@@ -7,12 +7,13 @@ from hutn_compiler.compiler import main
 def compile_file(obj, filename):
     result = main(util.get_code_path(filename), "grammars/actionlanguage.g", "PS", [])
     expected = open(util.get_expected_path(filename)).read()
+    print(result)
     result = postproc(result)
     expected = postproc(expected)
-    #if result != expected:
-    #    f = open(util.get_expected_path(filename), 'w')
-    #    f.write(result)
-    #    f.close()
+    if result != expected:
+        f = open(util.get_expected_path(filename), 'w')
+        f.write(result)
+        f.close()
     assert result == expected
 
 class TestSimple(unittest.TestCase):
@@ -54,3 +55,6 @@ class TestSimple(unittest.TestCase):
 
     def test_multi_include(self):
         compile_file(self, "multi_include.al")
+
+    def test_dict_access(self):
+        compile_file(self, "dict_access.al")