text.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import lark
  2. from sccd.action_lang.static.statement import *
  3. from collections import defaultdict
  4. # Lark transformer for generating a parse tree of our own types.
  5. class Transformer(lark.Transformer):
  6. def __init__(self):
  7. self.macros = defaultdict(list)
  8. def set_macro(self, macro_id, constructor):
  9. # print("registered macro", macro_id, constructor)
  10. self.macros[macro_id].append(constructor)
  11. def unset_macro(self, macro_id):
  12. # print("unregistered macro", macro_id)
  13. self.macros[macro_id].pop()
  14. array = Array
  15. block = Block
  16. def string_literal(self, node):
  17. return StringLiteral(node[0][1:-1])
  18. def int_literal(self, node):
  19. return IntLiteral(int(node[0].value))
  20. def float_literal(self, node):
  21. return FloatLiteral(float(node[0].value))
  22. def bool_literal(self, node):
  23. return BoolLiteral({
  24. "True": True,
  25. "False": False,
  26. }[node[0].value])
  27. def duration_literal(self, node):
  28. val = int(node[0])
  29. suffix = node[1]
  30. unit = {
  31. "d": None, # 'd' stands for "duration", the non-unit for all zero-durations.
  32. # need this to parse zero-duration as a duration instead of int.
  33. "fs": FemtoSecond,
  34. "ps": PicoSecond,
  35. "ns": Nanosecond,
  36. "us": Microsecond,
  37. "ms": Millisecond,
  38. "s": Second,
  39. "m": Minute,
  40. "h": Hour
  41. }[suffix]
  42. return DurationLiteral(duration(val, unit))
  43. def func_call(self, node):
  44. return FunctionCall(node[0], node[1].children)
  45. def macro_call(self, node):
  46. macro_id = node[0]
  47. params = node[1].children
  48. try:
  49. constructor = self.macros[macro_id][-1]
  50. except IndexError as e:
  51. print(self.macros)
  52. raise Exception("Unknown macro: %s" % macro_id) from e
  53. return constructor(params)
  54. def array_indexed(self, node):
  55. return ArrayIndexed(node[0], node[1])
  56. def identifier(self, node):
  57. name = node[0].value
  58. return Identifier(name)
  59. def binary_expr(self, node):
  60. return BinaryExpression(node[0], node[1].value, node[2])
  61. def unary_expr(self, node):
  62. return UnaryExpression(node[0].value, node[1])
  63. def group(self, node):
  64. return Group(node[0])
  65. def assignment(self, node):
  66. operator = node[1].value
  67. if operator == "=":
  68. return Assignment(node[0], node[2])
  69. else:
  70. # Increment, decrement etc. operators are just syntactic sugar
  71. bin_operator = {"+=": "+", "-=": "-", "*=": "*", "/=": "/", "//=": "//"}[operator]
  72. return Assignment(node[0], BinaryExpression(node[0], bin_operator, node[2]))
  73. def expression_stmt(self, node):
  74. return ExpressionStatement(node[0])
  75. def return_stmt(self, node):
  76. return ReturnStatement(node[0])
  77. def if_stmt(self, node):
  78. if len(node) == 2:
  79. return IfStatement(cond=node[0], if_body=node[1])
  80. else:
  81. return IfStatement(cond=node[0], if_body=node[1], else_body=node[2])
  82. def import_stmt(self, node):
  83. return ImportStatement(module_name=node[0])
  84. params_decl = list
  85. def param_decl(self, node):
  86. return ParamDecl(name=node[0].value, formal_type=node[1])
  87. def type_annot(self, node):
  88. return {
  89. "int": SCCDInt,
  90. "str": SCCDString,
  91. "float": SCCDFloat,
  92. "dur": SCCDDuration,
  93. }[node[0]]
  94. def func_type(self, node):
  95. if len(node) > 1:
  96. return SCCDFunction(param_types=node[0], return_type=node[1])
  97. else:
  98. return SCCDFunction(param_types=node[0])
  99. param_types = list
  100. def func_decl(self, node):
  101. return FunctionDeclaration(params_decl=node[0], body=node[1])
  102. import os
  103. grammar_dir = os.path.dirname(__file__)
  104. with open(os.path.join(grammar_dir,"action_lang.g")) as file:
  105. grammar = file.read()
  106. _default_parser = lark.Lark(grammar, parser="lalr", start=["expr", "stmt"], transformer=Transformer(), cache=True)
  107. class TextParser:
  108. def __init__(self, parser=_default_parser):
  109. self.parser = parser
  110. def parse_expr(self, text: str) -> Expression:
  111. return self.parser.parse(text, start="expr")
  112. def parse_stmt(self, text: str) -> Statement:
  113. return self.parser.parse(text, start="block")