1.1 --- a/imiptools/period.py Sun Oct 18 18:02:07 2015 +0200
1.2 +++ b/imiptools/period.py Sun Oct 18 19:33:04 2015 +0200
1.3 @@ -28,6 +28,70 @@
1.4 get_tzid, \
1.5 to_timezone, to_utc_datetime
1.6
1.7 +class Comparable:
1.8 +
1.9 + "A date/datetime wrapper that allows comparisons with other types."
1.10 +
1.11 + def __init__(self, dt):
1.12 + self.dt = dt
1.13 +
1.14 + def __cmp__(self, other):
1.15 + dt = None
1.16 + odt = None
1.17 +
1.18 + # Find any dates/datetimes.
1.19 +
1.20 + if isinstance(self.dt, date):
1.21 + dt = self.dt
1.22 + if isinstance(other, date):
1.23 + odt = other
1.24 + elif isinstance(other, Comparable):
1.25 + if isinstance(other.dt, date):
1.26 + odt = other.dt
1.27 + else:
1.28 + other = other.dt
1.29 +
1.30 + if dt and odt:
1.31 + return cmp(dt, odt)
1.32 + elif dt:
1.33 + return other.__rcmp__(dt)
1.34 + elif odt:
1.35 + return self.dt.__cmp__(odt)
1.36 + else:
1.37 + return self.dt.__cmp__(other)
1.38 +
1.39 +class PointInTime:
1.40 +
1.41 + "A base class for special values."
1.42 +
1.43 + pass
1.44 +
1.45 +class StartOfTime(PointInTime):
1.46 +
1.47 + "A special value that compares earlier than other values."
1.48 +
1.49 + def __cmp__(self, other):
1.50 + if isinstance(other, StartOfTime):
1.51 + return 0
1.52 + else:
1.53 + return -1
1.54 +
1.55 + def __rcmp__(self, other):
1.56 + return -self.__cmp__(other)
1.57 +
1.58 +class EndOfTime(PointInTime):
1.59 +
1.60 + "A special value that compares later than other values."
1.61 +
1.62 + def __cmp__(self, other):
1.63 + if isinstance(other, EndOfTime):
1.64 + return 0
1.65 + else:
1.66 + return 1
1.67 +
1.68 + def __rcmp__(self, other):
1.69 + return -self.__cmp__(other)
1.70 +
1.71 class PeriodBase:
1.72
1.73 "A basic period abstraction."
1.74 @@ -44,12 +108,16 @@
1.75
1.76 if isinstance(other, PeriodBase):
1.77 return cmp(
1.78 - (self.get_start_point(), self.get_end_point()),
1.79 - (other.get_start_point(), other.get_end_point())
1.80 + (Comparable(self.get_start_point() or StartOfTime()), Comparable(self.get_end_point() or EndOfTime())),
1.81 + (Comparable(other.get_start_point() or StartOfTime()), Comparable(other.get_end_point() or EndOfTime()))
1.82 )
1.83 else:
1.84 return 1
1.85
1.86 + def overlaps(self, other):
1.87 + return Comparable(self.get_end_point()) > Comparable(other.get_start_point()) and \
1.88 + Comparable(self.get_start_point()) < Comparable(other.get_end_point())
1.89 +
1.90 def get_key(self):
1.91 return self.get_start(), self.get_end()
1.92
1.93 @@ -96,8 +164,8 @@
1.94 dates/datetimes.
1.95 """
1.96
1.97 - self.start = isinstance(start, date) and start or get_datetime(start)
1.98 - self.end = isinstance(end, date) and end or get_datetime(end)
1.99 + self.start = isinstance(start, date) and start or isinstance(start, PointInTime) and start or get_datetime(start) or StartOfTime()
1.100 + self.end = isinstance(end, date) and end or isinstance(end, PointInTime) and end or get_datetime(end) or EndOfTime()
1.101 self.tzid = tzid
1.102 self.origin = origin
1.103
1.104 @@ -113,10 +181,12 @@
1.105 return get_tzid(self.get_start_attr(), self.get_end_attr()) or self.tzid
1.106
1.107 def get_start_point(self):
1.108 - return to_utc_datetime(self.get_start(), self.get_tzid())
1.109 + start = self.get_start()
1.110 + return isinstance(start, PointInTime) and start or to_utc_datetime(start, self.get_tzid())
1.111
1.112 def get_end_point(self):
1.113 - return to_utc_datetime(self.get_end(), self.get_tzid())
1.114 + end = self.get_end()
1.115 + return isinstance(end, PointInTime) and end or to_utc_datetime(end, self.get_tzid())
1.116
1.117 # Period and event recurrence logic.
1.118
1.119 @@ -411,15 +481,15 @@
1.120 # Find the range of periods potentially overlapping the period in a version
1.121 # of the free/busy collection sorted according to end datetimes.
1.122
1.123 - endpoints = [(fb.get_end_point(), fb.get_start_point(), fb) for fb in startpoints]
1.124 + endpoints = [(Period(fb.get_end_point(), fb.get_end_point()), fb) for fb in startpoints]
1.125 endpoints.sort()
1.126 - first = bisect_left(endpoints, (period.get_start_point(),))
1.127 + first = bisect_left(endpoints, (Period(period.get_start_point(), period.get_start_point()),))
1.128 endpoints = endpoints[first:]
1.129
1.130 overlapping = set()
1.131
1.132 - for end, start, fb in endpoints:
1.133 - if end > period.get_start_point() and start < period.get_end_point():
1.134 + for p, fb in endpoints:
1.135 + if fb.overlaps(period):
1.136 overlapping.add(fb)
1.137
1.138 overlapping = list(overlapping)