소스 검색

More efficient conversions between duration units. Improved type checks for binary and unary expressions.

Joeri Exelmans 5 년 전
부모
커밋
f399300f64

+ 64 - 46
src/sccd/action_lang/static/expression.py

@@ -275,49 +275,57 @@ class BinaryExpression(Expression):
         lhs_t = self.lhs.init_expr(scope)
         rhs_t = self.rhs.init_expr(scope)
 
-        def comparison_type():
-            same_type()
-            return SCCDBool
-
-        def same_type():
-            if lhs_t != rhs_t:
-                raise StaticTypeError("Mixed LHS and RHS types in binary '%s'-expression: %s and %s" % (self.operator, str(lhs_t), str(rhs_t)))
-            return lhs_t
-
-        def sum_type():
-            if lhs_t != SCCDInt and lhs_t != SCCDFloat and lhs_t != SCCDDuration:
-                raise StaticTypeError("Invalid type '%s' for binary '%s'-expresion" % (lhs_t, self.operator))
-            return same_type()
-
-        def mult_type():
-            if lhs_t == rhs_t:
-                if lhs_t == Duration:
-                    raise StaticTypeError("Cannot multiply 'Duration' and 'Duration'")
+        def logical():
+            if lhs_t.is_bool_castable() and rhs_t.is_bool_castable():
+                return SCCDBool
+
+        def eq():
+            if lhs_t.is_eq(rhs_t):
+                return SCCDBool
+
+        def ord():
+            if lhs_t.is_ord(rhs_t):
+                return SCCDBool
+
+        def sum():
+            if lhs_t.is_summable(rhs_t):
                 return lhs_t
-            key = lambda x: {SCCDInt: 1, SCCDFloat: 2, SCCDDuration: 3}[x]
-            [smallest_type, largest_type] = sorted([lhs_t, rhs_t], key=key)
-            if largest_type == SCCDDuration and smallest_type == SCCDFloat:
-                raise StaticTypeError("Cannot multiply 'float' and 'Duration'")
-            return largest_type
 
-        return {
-            "and": lambda: SCCDBool,
-            "or":  lambda: SCCDBool,
-            "==":  comparison_type,
-            "!=":  comparison_type,
-            ">":   comparison_type,
-            ">=":  comparison_type,
-            "<":   comparison_type,
-            "<=":  comparison_type,
-            "+":   sum_type,
-            "-":   sum_type,
-            "*":   mult_type,
-            "/":   lambda: SCCDFloat,
-            "//":  same_type,
-            "%":   same_type,
-            "**":  same_type,
+        def mult():
+            return lhs_t.mult(rhs_t)
+
+        def div():
+            return lhs_t.div(rhs_t)
+
+        def floordiv():
+            return lhs_t.floordiv(rhs_t)
+
+        def exp():
+            return lhs_t.exp(rhs_t)
+
+        t = {
+            "and": logical,
+            "or":  logical,
+            "==":  eq,
+            "!=":  eq,
+            ">":   ord,
+            ">=":  ord,
+            "<":   ord,
+            "<=":  ord,
+            "+":   sum,
+            "-":   sum,
+            "*":   mult,
+            "/":   div,
+            "//":  floordiv,
+            "%":   floordiv,
+            "**":  exp,
         }[self.operator]()
 
+        if t is None:
+            raise StaticTypeError("Illegal types for '%s'-operation: %s and %s" % (self.operator, lhs_t, rhs_t))
+
+        return t
+
     def eval(self, memory: MemoryInterface):
         
         return {
@@ -348,15 +356,25 @@ class UnaryExpression(Expression):
 
     def init_expr(self, scope: Scope) -> SCCDType:
         expr_type = self.expr.init_expr(scope)
-        def num_type():
-            if expr_type != SCCDInt and expr_type != SCCDFloat:
-                raise StaticTypeError("Invalid type '%s' for unary '%s'-expresion" % (expr_type, self.operator))
-            return expr_type
-        return {
-            "not": lambda: SCCDBool,
-            "-":   num_type,
+
+        def logical():
+            if expr_type.is_bool_castable():
+                return SCCDBool
+                
+        def neg():
+            if expr_type.is_neg():
+                return expr_type
+
+        t = {
+            "not": logical,
+            "-":   neg,
         }[self.operator]()
 
+        if t is None:
+            raise StaticTypeError("Illegal type for unary '%s'-expression: %s" % (self.operator, expr_type))
+
+        return t
+
     def eval(self, memory: MemoryInterface):
         return {
             "not": lambda x: not x.eval(memory),

+ 105 - 7
src/sccd/action_lang/static/types.py

@@ -10,15 +10,92 @@ class SCCDType(ABC):
         
     def __str__(self):
         return termcolor.colored(self._str(), 'cyan')
-        # return self._str()
 
-@dataclass(frozen=True)
+    def __repr__(self):
+        return "SCCDType(" + self._str() + ")"
+
+    def is_neg(self):
+        return False
+
+    def is_summable(self, other):
+        return False
+
+    def floordiv(self, other) -> Optional['SCCDType']:
+        return None
+
+    def mult(self, other) -> Optional['SCCDType']:
+        return None
+
+    def div(self, other) -> Optional['SCCDType']:
+        return None
+
+    def exp(self, other) -> Optional['SCCDType']:
+        return None
+
+    # Can the type be used as input to '==' and '!=' operations?
+    def is_eq(self, other):
+        return False
+
+    # Can the type be used as input to '<', '<=', ... operations?
+    def is_ord(self, other):
+        return False
+
+    def is_bool_castable(self):
+        return False
+
+@dataclass(eq=False)
 class _SCCDSimpleType(SCCDType):
     name: str
+    neg: bool = False
+    summable: bool = False
+    eq: bool = False
+    ord: bool = False
+    bool_cast: bool = False
+
+    floordiv_dict: Dict[SCCDType, SCCDType] = field(default_factory=dict)
+    mult_dict: Dict[SCCDType, SCCDType] = field(default_factory=dict)
+    div_dict: Dict[SCCDType, SCCDType] = field(default_factory=dict)
+    exp_dict: Dict[SCCDType, SCCDType] = field(default_factory=dict)
 
     def _str(self):
         return self.name
 
+    def is_neg(self):
+        return self.neg
+
+    def is_summable(self, other):
+        if other is self:
+            return self.summable
+
+    def __dict_lookup(self, dict, other):
+        try:
+            return dict[other]
+        except KeyError:
+            return None
+
+    def floordiv(self, other):
+        return self.__dict_lookup(self.floordiv_dict, other)
+
+    def mult(self, other):
+        return self.__dict_lookup(self.mult_dict, other)
+
+    def div(self, other):
+        return self.__dict_lookup(self.div_dict, other)
+
+    def exp(self, other):
+        return self.__dict_lookup(self.exp_dict, other)
+
+    def is_eq(self, other):
+        if other is self:
+            return self.eq
+
+    def is_ord(self, other):
+        if other is self:
+            return self.ord
+
+    def is_bool_castable(self):
+        return self.bool_cast
+
 @dataclass(frozen=True)
 class SCCDFunction(SCCDType):
     param_types: List[SCCDType]
@@ -40,8 +117,29 @@ class SCCDArray(SCCDType):
     def _str(self):
         return "[" + str(self.element_type) + "]"
 
-SCCDBool = _SCCDSimpleType("bool")
-SCCDInt = _SCCDSimpleType("int")
-SCCDFloat = _SCCDSimpleType("float")
-SCCDDuration = _SCCDSimpleType("dur")
-SCCDString = _SCCDSimpleType("str")
+    def is_eq(self, other):
+        if isinstance(other, SCCDArray) and self.element_type.is_eq(other.element_type):
+            return True
+        return False
+
+SCCDBool = _SCCDSimpleType("bool", eq=True, bool_cast=True)
+SCCDInt = _SCCDSimpleType("int", neg=True, summable=True, eq=True, ord=True, bool_cast=True)
+SCCDFloat = _SCCDSimpleType("float", neg=True, summable=True, eq=True, ord=True)
+SCCDDuration = _SCCDSimpleType("dur", neg=True, summable=True, eq=True, ord=True, bool_cast=True)
+SCCDString = _SCCDSimpleType("str", summable=True, eq=True)
+
+# Supported operations between simple types:
+
+SCCDInt.mult_dict = {SCCDInt: SCCDInt, SCCDFloat: SCCDFloat, SCCDDuration: SCCDDuration}
+SCCDFloat.mult_dict = {SCCDInt: SCCDFloat, SCCDFloat: SCCDFloat}
+SCCDDuration.mult_dict = {SCCDInt: SCCDDuration}
+
+SCCDInt.div_dict = {SCCDInt: SCCDFloat, SCCDFloat: SCCDFloat}
+SCCDFloat.div_dict = {SCCDInt: SCCDFloat, SCCDFloat: SCCDFloat}
+
+SCCDInt.floordiv_dict = {SCCDInt: SCCDInt}
+SCCDFloat.floordiv_dict = {SCCDInt: SCCDFloat, SCCDFloat: SCCDFloat}
+SCCDDuration.floordiv_dict = {SCCDInt: SCCDDuration, SCCDDuration: SCCDInt}
+
+SCCDInt.exp_dict = {SCCDInt: SCCDInt, SCCDFloat: SCCDFloat}
+SCCDFloat.exp_dict = {SCCDInt: SCCDFloat, SCCDFloat: SCCDFloat}

+ 1 - 0
src/sccd/cd/globals.py

@@ -2,6 +2,7 @@ from typing import *
 from sccd.util.namespace import *
 from sccd.util.duration import *
 from sccd.util.debug import *
+from sccd.action_lang.static.exceptions import ModelError
 
 # Global values for all statecharts in a class diagram.
 class Globals:

+ 2 - 1
src/sccd/statechart/dynamic/round.py

@@ -69,7 +69,8 @@ class EnabledEventsStrategy(GeneratorStrategy):
 
     def cache_init(self):
         cache = {}
-        for event_id in bm_items(self.statechart.events):
+        cache[(0, 0)] = self.generate(None, 0, 0)
+        for event_id in bm_items(self.statechart.internal_events):
             events_bitmap = bit(event_id)
             cache[(events_bitmap, 0)] = self.generate(None, events_bitmap, 0)
         return cache

+ 17 - 1
src/sccd/statechart/parser/text.py

@@ -19,7 +19,23 @@ class StatechartTransformer(action_lang.ExpressionTransformer):
 
   # override
   def duration_literal(self, node):
-    d = SCDurationLiteral(node[0])
+    val = int(node[0])
+    suffix = node[1]
+
+    unit = {
+      "d": None, # 'd' stands for "duration", the non-unit for all zero-durations.
+                 # need this to parse zero-duration as a duration instead of int.
+      "fs": FemtoSecond,
+      "ps": PicoSecond,
+      "ns": Nanosecond,
+      "us": Microsecond,
+      "ms": Millisecond,
+      "s": Second,
+      "m": Minute,
+      "h": Hour
+    }[suffix]
+
+    d = SCDurationLiteral(duration(val, unit))
     self.globals.durations.append(d)
     return d
 

+ 109 - 69
src/sccd/util/duration.py

@@ -5,43 +5,31 @@ from typing import *
 import math
 import functools
 
-@dataclass
 class _Unit:
-  notation: str
-  relative_size: int
-  larger: Optional[Tuple[Any, int]] = None
-  # smaller: Optional[Tuple[Any, int]] = None
-
-  def __eq__(self, other):
-    return self is other
-
-FemtoSecond = _Unit("fs", 1)
-PicoSecond = _Unit("ps", 1000)
-Nanosecond = _Unit("ns", 1000000)
-Microsecond = _Unit("µs", 1000000000)
-Millisecond = _Unit("ms", 1000000000000)
-Second = _Unit("s", 1000000000000000)
-Minute = _Unit("m", 60000000000000000)
-Hour = _Unit("h", 3600000000000000000)
-Day = _Unit("D", 86400000000000000000)
-
-FemtoSecond.larger = (PicoSecond, 1000)
-PicoSecond.larger = (Nanosecond, 1000)
-Nanosecond.larger = (Microsecond, 1000)
-Microsecond.larger = (Millisecond, 1000)
-Millisecond.larger = (Second, 1000)
-Second.larger = (Minute, 60)
-Minute.larger = (Hour, 60)
-Hour.larger = (Day, 24)
-
-# Day.smaller = (Hour, 24)
-# Hour.smaller = (Minute, 60)
-# Minute.smaller = (Second, 60)
-# Second.smaller = (Millisecond, 1000)
-# Millisecond.smaller = (Microsecond, 1000)
-# Microsecond.smaller = (Nanosecond, 1000)
-# Nanosecond.smaller = (PicoSecond, 1000)
-# PicoSecond.smaller = (FemtoSecond, 1000)
+  __slots__ = ["notation", "larger", "conv_dict"]
+  def __init__(self, notation: str, larger: Optional[Tuple[int, '_Unit']] = None):
+    self.notation = notation
+    self.larger = larger
+
+    # Pre-calculate conversions to all larger units
+    self.conv_dict = {}
+    rel_size = 1
+    larger = (1, self)
+    while larger:
+      rel_size *= larger[0]
+      unit = larger[1]
+      self.conv_dict[unit] = rel_size
+      larger = unit.larger
+
+Day = _Unit("D")
+Hour = _Unit("h", (24, Day))
+Minute = _Unit("m", (60, Hour))
+Second = _Unit("s", (60, Minute))
+Millisecond = _Unit("ms", (1000, Second))
+Microsecond = _Unit("µs", (1000, Millisecond))
+Nanosecond = _Unit("ns", (1000, Microsecond))
+PicoSecond = _Unit("ps", (1000, Nanosecond))
+FemtoSecond = _Unit("fs", (1000, PicoSecond))
 
 class Duration(ABC):
   def __repr__(self):
@@ -71,28 +59,30 @@ class Duration(ABC):
   def __mul__(self):
     pass
 
+  @abstractmethod
+  def __rmul__(self):
+    pass
+
+  @abstractmethod
+  def __bool__(self):
+    pass
+
   def __floordiv__(self, other: 'Duration') -> int:
     if other is _zero:
       raise ZeroDivisionError("duration floordiv by zero duration")
     self_val, other_val, _ = _same_unit(self, other)
     return self_val // other_val
 
-  def __mod__(self, other):
+  def __mod__(self, other: 'Duration'):
       self_val, other_val, unit = _same_unit(self, other)
       new_val = self_val % other_val
-      if new_val == 0:
-        return _zero
-      else:
-        return _NonZeroDuration(new_val, unit)
+      return _duration_no_checks(new_val, unit)
 
   def __lt__(self, other):
     self_val, other_val, unit = _same_unit(self, other)
     return self_val < other_val
 
 class _ZeroDuration(Duration):
-  def _convert(self, unit: _Unit) -> int:
-    return 0
-
   def __str__(self):
     return '0 d'
 
@@ -103,20 +93,44 @@ class _ZeroDuration(Duration):
     return other
 
   def __sub__(self, other):
-    return duration(-other.val, other.unit)
+    return other.__neg__()
 
   def __neg__(self):
     return self
 
-
   def __mul__(self, other: int) -> Duration:
     return self
 
-  # Commutativity
   __rmul__ = __mul__
 
+  def __floordiv__(self, other) -> int:
+    if isinstance(other, Duration):
+      return Duration.__floordiv__(self, other)
+    elif isinstance(other, int):
+      if other == 0:
+        raise ZeroDivisionError("duration floordiv by zero")
+      else:
+        return _ZeroDuration
+    else:
+      raise TypeError("cannot floordiv duration by %s" % type(other))
+
+  def __mod__(self, other):
+    if isinstance(other, Duration):
+      return Duration.__mod__(self, other)
+    elif isinstance(other, int):
+      if other == 0:
+        raise ZeroDivisionError("duration modulo by zero")
+      else:
+        return _ZeroDuration
+    else:
+      raise TypeError("cannot modulo duration by %s" % type(other))
+
+  def __bool__(self):
+    return False
+
 _zero = _ZeroDuration() # Singleton. Only place the constructor should be called.
 
+# Only exported way to construct a Duration object
 def duration(val: int, unit: Optional[_Unit] = None) -> Duration:
   if unit is None:
     if val != 0:
@@ -125,31 +139,30 @@ def duration(val: int, unit: Optional[_Unit] = None) -> Duration:
       return _zero
   else:
     if val == 0:
-      raise Exception("Duration: Zero value should not have unit")
+      raise Exception("Duration: Zero value should have pseudo-unit 'd'")
     else:
       return _NonZeroDuration(val, unit)
 
-# @dataclass
+def _duration_no_checks(val: int, unit: _Unit) -> Duration:
+  if val == 0:
+    return _zero
+  else:
+    return _NonZeroDuration(val, unit)
+
 class _NonZeroDuration(Duration):
+  __slots__ = ["val", "unit"]
   def __init__(self, val: int, unit: _Unit = None):
     self.val = val
     self.unit = unit
 
+    # Use largest possible unit without losing precision
     while self.unit.larger:
-      next_unit, factor = self.unit.larger
+      factor, next_unit = self.unit.larger
       if self.val % factor != 0:
         break
       self.val //= factor
       self.unit = next_unit
 
-  # Can only convert to smaller units.
-  # Returns new Duration.
-  def _convert(self, unit: _Unit) -> int:
-    # Precondition
-    assert self.unit.relative_size >= unit.relative_size
-    factor = self.unit.relative_size // unit.relative_size
-    return self.val * factor
-
   def __str__(self):
     return str(self.val)+' '+self.unit.notation
 
@@ -162,25 +175,51 @@ class _NonZeroDuration(Duration):
     if other is _zero:
       return self
     self_val, other_val, unit = _same_unit(self, other)
-    return duration(self_val + other_val, unit)
+    return _duration_no_checks(self_val + other_val, unit)
 
   def __sub__(self, other):
     if other is _zero:
-      return duration(-self.val, self.unit)
+      return self
     self_val, other_val, unit = _same_unit(self, other)
-    return duration(self_val - other_val, unit)
+    return _duration_no_checks(self_val - other_val, unit)
 
   def __neg__(self):
-    return duration(-self.val, self.unit)
+    return _NonZeroDuration(-self.val, self.unit)
 
   def __mul__(self, other: int) -> Duration:
     if other == 0:
       return _zero
     return _NonZeroDuration(self.val * other, self.unit)
 
-  # Commutativity
   __rmul__ = __mul__
 
+  def __floordiv__(self, other) -> int:
+    if isinstance(other, Duration):
+      return Duration.__floordiv__(self, other)
+    elif isinstance(other, int):
+      if other == 0:
+        raise ZeroDivisionError("duration floordiv by zero")
+      else:
+        new_val = self.val // other
+        if new_val == 0:
+          return _zero
+        else:
+          return _NonZeroDuration(self.val, self.unit)
+    else:
+      raise TypeError("cannot floordiv duration by %s" % type(other))
+
+  def __mod__(self, other):
+    if isinstance(other, Duration):
+      return Duration.__mod__(self, other)
+    elif isinstance(other, int):
+      return _duration_no_checks(self.val % other, self.unit)
+    else:
+      raise TypeError("cannot modulo duration by %s" % type(other))
+
+  def __bool__(self):
+    return True
+
+# Convert both durations to the largest unit among them.
 def _same_unit(x: Duration, y: Duration) -> Tuple[int, int, _Unit]:
   if x is _zero:
     if y is _zero:
@@ -190,16 +229,17 @@ def _same_unit(x: Duration, y: Duration) -> Tuple[int, int, _Unit]:
   if y is _zero:
     return (x.val, 0, x.unit)
 
-  if x.unit.relative_size >= y.unit.relative_size:
-    return (x._convert(y.unit), y.val, y.unit)
-  else:
-    return (x.val, y._convert(x.unit), x.unit)
-  return (x_conv, y_conv, unit)
+  try:
+    factor = x.unit.conv_dict[y.unit]
+    return (x.val, y.val * factor, x.unit)
+  except KeyError:
+    factor = y.unit.conv_dict[x.unit]
+    return (x.val * factor, y.val, y.unit)
 
 def gcd_pair(x: Duration, y: Duration) -> Duration:
   x_conv, y_conv, unit = _same_unit(x, y)
   gcd = math.gcd(x_conv, y_conv)
-  return duration(gcd, unit)
+  return _duration_no_checks(gcd, unit)
 
 def gcd(*iterable: Iterable[Duration]) -> Duration:
   return functools.reduce(gcd_pair, iterable, _zero)