# Source code for ash_model.classes.presence_store

from __future__ import annotations

"""presence_store.py
~~~~~~~~~~~~~~~~~~~~
Tiny abstraction layer that lets **ASH** swap the internal representation of
*temporal presence* without touching the public API.

Only the very small subset of dictionary behaviour used by ``ASH`` is
re‑implemented:

* ``__getitem__(t)`` – return the set of IDs present at *t* (read‑only).
* ``setdefault(t, set())`` – return a *mutable* set‑like proxy so that calls
  like ``self._snapshots.setdefault(t, set()).add(hid)`` keep working.
* ``get(t, default)`` – same semantics as ``dict.get``.
* ``keys()`` – iterable of snapshot indices.

Two concrete stores are provided:

* :class:`DensePresenceStore` – thin subclass of ``defaultdict(set)`` that
    keeps a *dense* mapping ``time → set[id]``.  This
    is the default behavior.
* :class:`IntervalPresenceStore` – keeps ``id → list[(start, end)]`` disjoint
  intervals and a tiny ``time_counts`` map so that ``keys()`` is cheap.

Switching the back‑end is as simple as:

>>> h = ASH(backend="interval")

No other public API changed.
"""

import bisect
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Dict, Iterable, Iterator, List, Set, Tuple

###############################################################################
# Abstract façade
###############################################################################


class PresenceStore(ABC):
    """Minimal dict-like interface required by :class:`ASH`.

    Concrete stores map a snapshot index *t* to the set of ids present at
    that time.  Subclasses must provide ``__getitem__``, ``setdefault``,
    ``get`` and ``keys``; the remaining container dunders are derived from
    ``keys()``.
    """

    # ---------------------------------------------------------------------
    # Required by ASH
    # ---------------------------------------------------------------------
    @abstractmethod
    def __getitem__(self, t: int) -> Set[int]:
        """Return **read-only** snapshot set for *t*."""

    @abstractmethod
    def setdefault(self, t: int, default: Set[int]) -> "_SnapshotMutable":
        """Return mutable set-like view (creates snapshot if absent)."""

    @abstractmethod
    def get(self, t: int, default=None):
        """Dict-style *get*: the snapshot for *t*, or *default* if unknown."""

    @abstractmethod
    def keys(self) -> Iterable[int]:
        """Return iterable of snapshot indices."""

    # ------------------------------------------------------------------
    # Convenience - these are never called directly by ASH but make the
    # façade quack like a normal dict.
    # ------------------------------------------------------------------
    def __contains__(self, key: int) -> bool:  # pragma: no cover – trivial
        """Return ``True`` when *key* is one of the snapshot indices."""
        return key in self.keys()

    def __iter__(self) -> Iterator[int]:  # pragma: no cover – trivial
        """Iterate over the snapshot indices reported by ``keys()``."""
        return iter(self.keys())

    def __len__(self) -> int:  # pragma: no cover – trivial
        """Return the number of snapshot indices reported by ``keys()``."""
        # Count lazily: ``keys()`` is only required to be iterable, and there
        # is no need to materialise a list just to measure it.
        return sum(1 for _ in self.keys())
############################################################################### # Dense implementation (status quo) ###############################################################################
class DensePresenceStore(defaultdict, PresenceStore):
    """Keep the original *dense* mapping ``time → set[id]`` intact.

    The store is a ``defaultdict`` whose default factory is ``set``, so
    indexing an unknown time creates (and stores) an empty snapshot.
    """

    def __init__(self):
        """Create an empty store whose missing snapshots default to ``set()``."""
        super().__init__(set)

    # defaultdict already provides all behaviours we need.  The overrides
    # below are just for static typing clarity.
    def __getitem__(self, t: int) -> Set[int]:  # type: ignore[override]
        """Return the snapshot set for *t* (created empty if missing)."""
        return super().__getitem__(t)

    def setdefault(self, t: int, default: Set[int]) -> Set[int]:  # type: ignore[override]
        """Return the snapshot for *t*, storing *default* first if absent."""
        return super().setdefault(t, default)

    # ------------------------------------------------------------------
    # Copy / pickle support
    # ------------------------------------------------------------------
    # ``defaultdict`` rebuilds copies by calling ``type(self)`` with the
    # default factory as an argument.  ``__init__`` above takes no
    # arguments, so the inherited ``copy``/``__copy__``/``__reduce__`` would
    # raise ``TypeError``.  The overrides below rebuild with the
    # zero-argument constructor instead.
    def copy(self) -> "DensePresenceStore":
        """Return a shallow copy (snapshot sets are shared, as in ``dict.copy``)."""
        clone = type(self)()
        clone.update(self)
        return clone

    __copy__ = copy

    def __reduce__(self):
        """Support ``pickle`` and ``copy.deepcopy`` via the no-arg constructor."""
        return type(self), (), None, None, iter(self.items())
###############################################################################
# Interval implementation
###############################################################################


class _SnapshotMutable(set):
    """A *mutable* view returned by :meth:`IntervalPresenceStore.setdefault`.

    The view is a materialised copy of the ids present at one time index.
    Every mutation is mirrored into the owning store through its
    ``_add_occurrence`` / ``_remove_occurrence`` hooks, so the interval
    representation stays in sync with what the caller sees.

    All in-place ``set`` mutators are overridden and funnelled through
    :meth:`add` and :meth:`discard`.  Without that, a call such as
    ``view.update(ids)`` would change only the local copy and the store
    would silently lose the edit.
    """

    __slots__ = ("_store", "_time")

    def __init__(self, store: "IntervalPresenceStore", time: int, data: Set[int]):
        """Copy *data* into the view and remember the owning store and time."""
        super().__init__(data)  # materialised copy so we can do normal set ops
        self._store = store
        self._time = time

    # Primitive mutators - keep the interval representation in sync ---------
    def add(self, element: int) -> None:  # type: ignore[override]
        """Add *element*; notify the store only if it was not present."""
        if element in self:
            return
        super().add(element)
        self._store._add_occurrence(element, self._time)

    def discard(self, element: int) -> None:  # type: ignore[override]
        """Remove *element* if present; notify the store only on a real removal."""
        if element not in self:
            return
        super().discard(element)
        self._store._remove_occurrence(element, self._time)

    def remove(self, element: int) -> None:  # type: ignore[override]
        """Remove *element*, raising ``KeyError`` when it is absent."""
        if element not in self:
            raise KeyError(element)
        self.discard(element)

    def pop(self) -> int:  # type: ignore[override]
        """Remove and return an arbitrary element (``KeyError`` if empty)."""
        element = super().pop()
        self._store._remove_occurrence(element, self._time)
        return element

    def clear(self) -> None:
        """Remove every element, notifying the store for each one."""
        # Iterate over a snapshot: ``discard`` mutates the set being walked.
        for element in tuple(self):
            self.discard(element)

    # Bulk mutators - expressed through add/discard --------------------------
    def update(self, *others: Iterable[int]) -> None:
        """Add every element of each iterable in *others*."""
        for other in others:
            for element in other:
                self.add(element)

    def difference_update(self, *others: Iterable[int]) -> None:
        """Discard every element of each iterable in *others*."""
        for other in others:
            # ``tuple`` guards against ``other`` being this very set.
            for element in tuple(other):
                self.discard(element)

    def intersection_update(self, *others: Iterable[int]) -> None:
        """Keep only the elements that also appear in all of *others*."""
        keep = set(self).intersection(*others)
        for element in tuple(self):
            if element not in keep:
                self.discard(element)

    def symmetric_difference_update(self, other: Iterable[int]) -> None:
        """Toggle membership of every distinct element of *other*."""
        for element in set(other):
            if element in self:
                self.discard(element)
            else:
                self.add(element)

    # In-place operators - same operand rule as ``set`` ----------------------
    def __ior__(self, other):
        if not isinstance(other, (set, frozenset)):
            return NotImplemented
        self.update(other)
        return self

    def __isub__(self, other):
        if not isinstance(other, (set, frozenset)):
            return NotImplemented
        self.difference_update(other)
        return self

    def __iand__(self, other):
        if not isinstance(other, (set, frozenset)):
            return NotImplemented
        self.intersection_update(other)
        return self

    def __ixor__(self, other):
        if not isinstance(other, (set, frozenset)):
            return NotImplemented
        self.symmetric_difference_update(other)
        return self
class IntervalPresenceStore(PresenceStore):
    """Sparse *interval* representation.

    Internally we keep:

    ``_intervals``   – mapping ``id → List[(start, end)]`` (sorted, disjoint,
    inclusive bounds).
    ``_starts``      – mapping ``id → List[start]`` (parallel to
    ``_intervals``, used for bisect).
    ``_time_events`` – difference array ``time → delta``: adding presence
    over ``[a, b]`` does ``+1`` at ``a`` and ``-1`` at ``b + 1``.
    ``_time_counts`` – ``time → number of ids present``, rebuilt lazily from
    ``_time_events`` when ``keys()`` is requested.

    Invariant relied on throughout: for every id, the prefix sum of the
    deltas contributed by that id is ``1`` exactly on the times covered by
    its intervals and ``0`` elsewhere.  Each mutator therefore records the
    *net* change in coverage, nothing more and nothing less.
    """

    def __init__(self):
        """Create an empty store."""
        self._intervals: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
        self._starts: Dict[int, List[int]] = defaultdict(list)  # for bisect
        self._time_events: Dict[int, int] = defaultdict(int)  # difference array
        self._time_counts: Dict[int, int] = {}
        self._time_counts_valid: bool = True  # lazy rebuild flag

    # ------------------------------------------------------------------
    # Public façade - dict-like behaviour expected by ASH
    # ------------------------------------------------------------------
    def __getitem__(self, t: int) -> Set[int]:  # noqa: Dunder – matches dict API
        """Return *read-only* snapshot (materialised as an ordinary set)."""
        return self._materialise(t)

    def setdefault(self, t: int, default: Set[int]) -> "_SnapshotMutable":  # noqa: Dunder
        """Return *mutable* view for snapshot *t* (creates if absent).

        *default* is accepted for dict compatibility but not used: the view
        is always built from the ids currently present at *t*.
        """
        # Ensures time bucket exists so later *discard* knows "empty" means
        # absence, not "store unaware".
        # NOTE(review): this zero bucket lives only in the cached counts; the
        # next lazy rebuild clears it unless some id is present at *t*.
        if t not in self._time_counts:
            self._time_counts[t] = 0  # really empty snapshot for now
        return _SnapshotMutable(self, t, self._materialise(t))

    def get(self, t: int, default=None):  # noqa: Dunder – dict API
        """Return the snapshot for *t* if *t* is a known key, else *default*."""
        return self._materialise(t) if t in self else default

    def keys(self) -> Iterable[int]:  # noqa: Dunder – dict API
        """Return the snapshot indices (refreshing the cached counts first)."""
        self._ensure_time_counts()
        return self._time_counts.keys()

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------
    def _ensure_time_counts(self) -> None:
        """Rebuild ``_time_counts`` from ``_time_events`` (lazy reconstruction)."""
        if self._time_counts_valid:
            return
        counts = self._time_counts
        # Clear in place so key views handed out earlier stay attached.
        counts.clear()
        running = 0
        previous = None
        for t, delta in sorted(self._time_events.items()):
            # Every time in [previous, t) carries the running count; only
            # times with at least one id present get an entry.
            if running > 0:
                for tt in range(previous, t):
                    counts[tt] = running
            running += delta
            previous = t
        self._time_counts_valid = True

    def _record_span(self, first: int, last: int, delta: int) -> None:
        """Apply *delta* to the presence count of every time in [first, last]."""
        self._time_events[first] += delta
        self._time_events[last + 1] -= delta
        self._time_counts_valid = False

    @staticmethod
    def _first_reaching(
        intervals: List[Tuple[int, int]], starts: List[int], t: int
    ) -> int:
        """Return the index of the first interval whose end is ``>= t``.

        Returns ``len(intervals)`` when every interval ends before *t*.
        """
        i = bisect.bisect_right(starts, t) - 1
        if i >= 0 and intervals[i][1] >= t:
            return i  # interval i starts at or before t and reaches it
        return i + 1  # otherwise the next one (starting after t) is first

    # ---------- helpers for snapshot (de)materialisation -----------------
    def _materialise(self, t: int) -> Set[int]:
        """Compute *set* of IDs alive at time *t* (one bisect per id)."""
        present: Set[int] = set()
        for hid, intervals in self._intervals.items():
            if not intervals:
                continue
            # Last interval starting at or before t, if any.
            i = bisect.bisect_right(self._starts[hid], t) - 1
            if i >= 0 and intervals[i][1] >= t:
                present.add(hid)
        return present

    # ---------- mutators -------------------------------------------------
    def _add_occurrence(self, hid: int, t: int) -> None:
        """Insert a *single* time-point into ``hid``'s interval list."""
        intervals = self._intervals[hid]
        starts = self._starts[hid]

        nxt = bisect.bisect_right(starts, t)  # first interval starting after t
        prev = nxt - 1
        if prev >= 0 and intervals[prev][1] >= t:
            return  # already present - nothing to do

        joins_prev = prev >= 0 and intervals[prev][1] == t - 1
        joins_next = nxt < len(intervals) and intervals[nxt][0] == t + 1

        if joins_prev and joins_next:
            # t bridges two intervals: fuse them into one.
            intervals[prev] = (intervals[prev][0], intervals[nxt][1])
            del intervals[nxt]
            del starts[nxt]
        elif joins_prev:
            intervals[prev] = (intervals[prev][0], t)
        elif joins_next:
            intervals[nxt] = (t, intervals[nxt][1])
            starts[nxt] = t
        else:
            intervals.insert(nxt, (t, t))
            starts.insert(nxt, t)

        # Coverage grew by exactly the point t in every branch above,
        # including the bridging case, so the event is always recorded.
        self._record_span(t, t, 1)

    def _remove_occurrence(self, hid: int, t: int) -> None:
        """Remove a single time-point from ``hid``'s intervals (if present)."""
        # ``.get`` avoids creating empty entries for unknown ids.
        intervals = self._intervals.get(hid)
        if not intervals:
            return
        starts = self._starts[hid]

        i = bisect.bisect_right(starts, t) - 1
        if i < 0:
            return
        s, e = intervals[i]
        if e < t:
            return  # t falls in a gap

        if s == e:  # whole interval goes away
            del intervals[i]
            del starts[i]
        elif t == s:  # shrink from left
            intervals[i] = (s + 1, e)
            starts[i] = s + 1
        elif t == e:  # shrink from right
            intervals[i] = (s, e - 1)
        else:  # split interval in two
            intervals[i] = (s, t - 1)
            intervals.insert(i + 1, (t + 1, e))
            starts.insert(i + 1, t + 1)

        self._record_span(t, t, -1)

    def _add_interval(self, hid: int, start: int, end: int) -> None:
        """Insert the entire [start, end] span into ``hid``'s interval list.

        Existing intervals that overlap or touch the span are absorbed into
        one merged interval.  An inverted span (``start > end``) is empty
        and leaves the store untouched.
        """
        if start > end:
            return
        intervals = self._intervals[hid]
        starts = self._starts[hid]

        # First interval that overlaps or is adjacent to the span on the left.
        lo = self._first_reaching(intervals, starts, start - 1)
        hi = lo
        new_s, new_e = start, end
        while hi < len(intervals) and intervals[hi][0] - 1 <= new_e:
            s, e = intervals[hi]
            new_s = min(new_s, s)
            new_e = max(new_e, e)
            self._record_span(s, e, -1)  # old coverage is replaced below
            hi += 1

        intervals[lo:hi] = [(new_s, new_e)]
        starts[lo:hi] = [new_s]
        self._record_span(new_s, new_e, 1)

    def _remove_interval(self, hid: int, start: int, end: int) -> None:
        """Remove the entire [start, end] span from ``hid``'s interval list.

        Intervals that only partly overlap the span are trimmed or split.
        An inverted span (``start > end``) is empty and leaves the store
        untouched, as does an id with no recorded intervals.
        """
        if start > end:
            return
        intervals = self._intervals.get(hid)
        if not intervals:
            return
        starts = self._starts[hid]

        lo = self._first_reaching(intervals, starts, start)
        hi = lo
        remaining: List[Tuple[int, int]] = []
        while hi < len(intervals) and intervals[hi][0] <= end:
            s, e = intervals[hi]
            self._record_span(s, e, -1)
            if s < start:
                remaining.append((s, start - 1))
            if e > end:
                remaining.append((end + 1, e))
            hi += 1
        if hi == lo:
            return  # nothing overlaps the span

        intervals[lo:hi] = remaining
        starts[lo:hi] = [s for s, _ in remaining]
        # Net effect on the difference array: minus the old intervals, plus
        # the surviving pieces.  The span itself must NOT be subtracted on
        # top of that, or removals are counted twice.
        for s, e in remaining:
            self._record_span(s, e, 1)