Browse Source

Simplify Duration class

Joeri Exelmans 5 years ago
parent
commit
0886183702
2 changed files with 48 additions and 76 deletions
  1. 42 48
      src/sccd/util/duration.py
  2. 6 28
      src/sccd/util/test_duration.py

+ 42 - 48
src/sccd/util/duration.py

@@ -55,31 +55,28 @@ class Duration:
     if self.val == 0 and self.unit is not None:
       raise Exception("Duration: Zero value should not have unit")
 
+    # Convert Duration to the largest possible unit without losing accuracy.
+
+    if self.unit is None:
+      return
+
+    while self.unit.larger:
+      next_unit, factor = self.unit.larger
+      if self.val % factor != 0:
+        break
+      self.val //= factor
+      self.unit = next_unit
+
   # Can only convert to smaller units.
   # Returns new Duration.
-  def convert(self, unit: Unit):
+  def _convert(self, unit: Unit) -> int:
     if self.unit is None:
-      return self
+      return 0
 
     # Precondition
     assert self.unit.relative_size >= unit.relative_size
     factor = self.unit.relative_size // unit.relative_size
-    return Duration(self.val * factor, unit)
-
-  # Convert Duration to the largest possible unit.
-  # Returns new Duration.
-  def normalize(self):
-    if self.unit is None:
-      return self
-
-    val = self.val
-    unit = self.unit
-    next_unit, factor = unit.larger
-    while val % factor == 0:
-      val //= factor
-      unit = next_unit
-      next_unit, factor = unit.larger
-    return Duration(val, unit)
+    return self.val * factor
 
   def __str__(self):
     if self.unit is None:
@@ -92,7 +89,7 @@ class Duration:
   def __eq__(self, other):
     return self.val == other.val and self.unit is other.unit
 
-  def __mul__(self, other):
+  def __mul__(self, other: int):
     new_val = self.val * other
     if new_val == 0:
       return Duration(0)
@@ -102,47 +99,44 @@ class Duration:
   # Commutativity
   __rmul__ = __mul__
 
-  def __floordiv__(self, other):
-    # if isinstance(other, Duration):
-      self_conv, other_conv = same_unit(self, other)
-      return self_conv.val // other_conv.val
-    # else:
-      # return Duration(self.val//other, self.unit)
-
+  def __floordiv__(self, other: 'Duration'):
+    if other.val == 0:
+      raise ZeroDivisionError("duration floordiv by zero duration")
+    self_conv, other_conv, _ = _same_unit(self, other)
+    return self_conv // other_conv
 
   def __mod__(self, other):
-    # if isinstance(other, Duration):
-      self_conv, other_conv = same_unit(self, other)
-      new_val = self_conv.val % other_conv.val
+      self_conv, other_conv, unit = _same_unit(self, other)
+      new_val = self_conv % other_conv
       if new_val == 0:
         return Duration(0)
       else:
-        return Duration(new_val, self_conv.unit)
-    # else:
-      # return Duration(self.val%other, self.unit)
+        return Duration(new_val, unit)
 
   def __lt__(self, other):
-    self_conv, other_conv = same_unit(self, other)
+    self_conv, other_conv = _same_unit(self, other)
     return self_conv.val < other_conv.val
 
-def same_unit(x: Duration, y: Duration) -> Tuple[Duration, Duration]:
+def _same_unit(x: Duration, y: Duration) -> Tuple[int, int, Unit]:
+  if x.unit is None:
+    return (0, y.val, y.unit)
+  if y.unit is None:
+    return (x.val, 0, x.unit)
+
   if x.unit.relative_size >= y.unit.relative_size:
-    x_conv = x.convert(y.unit)
-    y_conv = y
+    x_conv = x._convert(y.unit)
+    y_conv = y.val
+    unit = y.unit
   else:
-    x_conv = x
-    y_conv = y.convert(x.unit)
-  return (x_conv, y_conv)
+    x_conv = x.val
+    y_conv = y._convert(x.unit)
+    unit = x.unit
+  return (x_conv, y_conv, unit)
 
 def gcd_pair(x: Duration, y: Duration) -> Duration:
-  if x.unit is None:
-    return y
-  if y.unit is None:
-    return x
-
-  x_conv, y_conv = same_unit(x, y)
-  gcd = math.gcd(x_conv.val, y_conv.val)
-  return Duration(gcd, x_conv.unit).normalize()
+  x_conv, y_conv, unit = _same_unit(x, y)
+  gcd = math.gcd(x_conv, y_conv)
+  return Duration(gcd, unit)
 
-def gcd(*iterable) -> Duration:
+def gcd(*iterable: Iterable[Duration]) -> Duration:
   return functools.reduce(gcd_pair, iterable, Duration(0))

+ 6 - 28
src/sccd/util/test_duration.py

@@ -7,36 +7,12 @@ class TestDuration(unittest.TestCase):
     # The same amount of time, but objects not considered equal.
     x = Duration(1000, Millisecond)
     y = Duration(1, Second)
+    z = Duration(3, Day)
 
-    self.assertNotEqual(x, y)
+    self.assertEqual(x, y)
+    self.assertEqual(y, x)
 
-    self.assertEqual(x.normalize(), y.normalize())
-
-    # original objects left intact by normalize() operation
-    self.assertNotEqual(x, y)
-
-  def test_convert_unit(self):
-    x = Duration(2, Second)
-
-    x2 = x.convert(Microsecond)
-
-    self.assertEqual(x2, Duration(2000000, Microsecond))
-
-  def test_convert_zero(self):
-    x = Duration(0)
-
-    x2 = x.convert(Millisecond)
-
-    self.assertEqual(x2, Duration(0))
-
-  def test_normalize(self):
-    x = Duration(1000, Millisecond)
-    y = Duration(1000000, Microsecond)
-    z = Duration(0)
-
-    self.assertEqual(x.normalize(), Duration(1, Second))
-    self.assertEqual(y.normalize(), Duration(1, Second))
-    self.assertEqual(z.normalize(), Duration(0))
+    self.assertNotEqual(x, z)
 
   def test_gcd(self):
     x = Duration(20, Second)
@@ -95,6 +71,8 @@ class TestDuration(unittest.TestCase):
     self.assertEqual(y // x, 0)
     self.assertEqual(x // z, 33)
 
+    self.assertRaises(ZeroDivisionError, lambda: x // Duration(0))
+
   def test_mod(self):
     x = Duration(100, Millisecond)
     y = Duration(10, Microsecond)