"""Source code for ash_model.paths.time_respecting_walks: time-respecting s-walks on temporal hypergraphs."""

import copy
from collections import defaultdict
from dataclasses import dataclass
from typing import Tuple, Optional, List, Union

import networkx as nx
import numpy as np

from ash_model import ASH


@dataclass(frozen=True)
class TemporalEdge:
    """
    A single temporal edge (one step) of an s-walk.

    Instances are immutable and hashable (frozen dataclass), so walks made of
    them can be converted to tuples and used as dictionary keys.
    """

    # Id of the hyperedge the step leaves from.
    fr: str
    # Id of the hyperedge the step arrives at.
    to: str
    # Weight attached to the step.
    weight: float
    # Timestamp associated with the step.
    tid: int


def _co_membership_counts(h, tid):
    """Count, for every ordered node pair, the hyperedges at snapshot ``tid`` containing both."""
    counts = defaultdict(lambda: defaultdict(int))
    for he in h.hyperedges(start=tid, end=tid):
        members = [str(n) for n in h.get_hyperedge_nodes(he)]
        for a, u in enumerate(members):
            for b, v in enumerate(members):
                if a != b:
                    counts[u][v] += 1
    return counts


def temporal_s_dag(
    h: ASH,
    s: int,
    start_from: Optional[Union[int, str, List[Union[int, str]]]] = None,
    stop_at: Optional[Union[int, str]] = None,
    start: Optional[int] = None,
    end: Optional[int] = None,
    edge: bool = False,
) -> Tuple[nx.DiGraph, List[str], List[str]]:
    """
    Build a time-respecting DAG over [start, end] for either hyperedges (edge=True)
    or nodes (edge=False). Nodes are labeled as "<id>_<tid>".

    Edges connect items from different timestamps following chronological order,
    ensuring time-respecting properties. Only forward-in-time edges are created:
    from timestamp t to timestamps t' where t' > t.

    :param h: The source hypergraph.
    :param s: Minimum s-incidence threshold.
    :param start_from: Node or hyperedge id(s) to start from. If None, starts from all items
        present at the first snapshot in range.
    :param stop_at: Node or hyperedge id to stop at (optional).
    :param start: First snapshot ID to include. Defaults to earliest.
    :param end: Last snapshot ID to include. Defaults to latest.
    :param edge: If True, operate on hyperedges. If False, operate on nodes.
    :returns: (DAG, sources, targets) with labels "<id>_<tid>". An empty DAG and empty
        lists are returned when the hypergraph has no snapshots or when no snapshot
        falls inside [start, end].
    :raises ValueError: If the [start, end] interval is not a valid subset of the
        hypergraph's snapshot IDs.
    """
    ids = h.temporal_snapshots_ids()
    if len(ids) == 0:
        return nx.DiGraph(), [], []

    if end is None:
        end = ids[-1]
    if start is None:
        start = ids[0]

    if start < min(ids) or start > end or end > max(ids) or start > max(ids):
        raise ValueError(
            f"The specified interval {[start, end]} is not a proper subset of the network timestamps "
            f"{[min(ids), max(ids)]}."
        )

    start_idx = next(i for i, t in enumerate(ids) if t >= start)
    end_idx = max(i for i, t in enumerate(ids) if t <= end)
    ids = ids[start_idx : end_idx + 1]
    # A valid interval can still fall strictly between two snapshot ids; without
    # this guard the seedless branch below would fail on ids[0].
    if len(ids) == 0:
        return nx.DiGraph(), [], []

    # Normalize seeds to list[str]
    if start_from is None:
        seeds: List[str] = []
    elif isinstance(start_from, (list, tuple, set)):
        seeds = [str(x) for x in start_from]
    else:
        seeds = [str(start_from)]

    # Items (hyperedge ids or node ids, as strings) present at each timestamp,
    # queried once per timestamp and kept in the order the hypergraph returns them.
    items_by_tid = {}
    for tid in ids:
        present = h.hyperedges(start=tid, end=tid) if edge else h.nodes(start=tid, end=tid)
        items_by_tid[tid] = [str(x) for x in present]
    item_at_time = {tid: set(items) for tid, items in items_by_tid.items()}

    # The connections towards a future snapshot do not depend on the snapshot we
    # come from, so they are memoized instead of being recomputed for every pair.
    incident_cache = {}
    neighbor_cache = {}

    all_edges = []
    for i, tid in enumerate(ids):
        for future_tid in ids[i + 1 :]:
            if edge:
                # Edge mode: s-incident hyperedges at the future timestamp.
                for he_id in items_by_tid[tid]:
                    cache_key = (he_id, future_tid)
                    if cache_key not in incident_cache:
                        incident_cache[cache_key] = list(
                            h.get_s_incident(
                                he_id, s=s, start=future_tid, end=future_tid
                            )
                        )
                    for n_id, w in incident_cache[cache_key]:
                        if n_id != he_id:
                            all_edges.append(
                                (f"{he_id}_{tid}", f"{n_id}_{future_tid}", w)
                            )
            else:
                # Node mode: co-members at the future timestamp.
                if future_tid not in neighbor_cache:
                    neighbor_cache[future_tid] = _co_membership_counts(h, future_tid)
                future_counts = neighbor_cache[future_tid]
                for n_id in items_by_tid[tid]:
                    for v, c in future_counts.get(n_id, {}).items():
                        if c >= s and v != n_id:
                            all_edges.append(
                                (f"{n_id}_{tid}", f"{v}_{future_tid}", c)
                            )

    DG = nx.DiGraph()
    for u, v, w in all_edges:
        DG.add_edge(u, v, weight=w)

    # Dicts are used as insertion-ordered sets of labels.
    sources, targets = {}, {}

    if seeds:
        # Sources are the seed items at every timestamp where they exist and
        # have at least one outgoing edge.
        for seed_id in seeds:
            for tid in ids:
                node_label = f"{seed_id}_{tid}"
                if (
                    seed_id in item_at_time[tid]
                    and DG.has_node(node_label)
                    and DG.out_degree[node_label] > 0
                ):
                    sources[node_label] = None
    else:
        # Without seeds, sources are all items at the first timestamp that have
        # outgoing edges.
        first_tid = ids[0]
        for item_id in item_at_time[first_tid]:
            node_label = f"{item_id}_{first_tid}"
            if DG.has_node(node_label) and DG.out_degree[node_label] > 0:
                sources[node_label] = None

    if stop_at is not None:
        # Only labels of the stop item that are actually reached are targets.
        stop_id = str(stop_at)
        for tid in ids:
            node_label = f"{stop_id}_{tid}"
            if DG.has_node(node_label) and DG.in_degree[node_label] > 0:
                targets[node_label] = None
    else:
        # Every reached label that is not a source is a potential target.
        for node in DG.nodes():
            if node not in sources and DG.in_degree[node] > 0:
                targets[node] = None

    # Seed items are never targets. The id is everything before the LAST
    # underscore of the label, so ids that contain underscores stay intact.
    excluded_ids = set(seeds)
    final_targets = [t for t in targets if t.rsplit("_", 1)[0] not in excluded_ids]

    return DG, list(sources), final_targets
def time_respecting_s_walks(
    h: ASH,
    s: int,
    start_from: Union[str, List[str]],
    stop_at: Optional[str] = None,
    start: Optional[int] = None,
    end: Optional[int] = None,
    sample: float = 1,
) -> dict:
    """
    Enumerate all time-respecting s-walks between a given source and optionally
    a target hyperedge.

    :param h: The source hypergraph.
    :param s: Minimum number of shared nodes for s-incidence.
    :param start_from: Id, or list of ids, of the hyperedge(s) to start from.
    :param stop_at: If provided, only walks ending at this hyperedge are considered.
    :param start: First snapshot to include.
    :param end: Last snapshot to include.
    :param sample: Fraction of source-target pairs to sample (0 < sample <= 1).
    :returns: Mapping (start_edge, end_edge) -> list of walks (TemporalEdge lists).
    """
    DAG, sources, targets = temporal_s_dag(
        h, s, start_from=start_from, stop_at=stop_at, start=start, end=end, edge=True
    )

    pairs = [(x, y) for x in sources for y in targets]
    # Nothing to draw from when there are no pairs; skip the sampler entirely.
    if sample < 1 and pairs:
        to_sample = int(len(pairs) * sample)
        idxs = np.random.choice(len(pairs), size=to_sample, replace=False)
        pairs = [pairs[i] for i in idxs]

    paths = []
    for src, dst in pairs:
        for path_nodes in nx.all_simple_paths(DAG, src, dst):
            seq = []
            for u, v in zip(path_nodes, path_nodes[1:]):
                # Labels are "<id>_<tid>": split on the LAST underscore so that
                # ids containing underscores are not truncated.
                fr_id = u.rsplit("_", 1)[0]
                to_id, to_tid = v.rsplit("_", 1)
                seq.append(
                    TemporalEdge(fr_id, to_id, DAG[u][v]["weight"], int(to_tid))
                )

            if len(seq) > 0:
                valid = True
                for prev_edge, cur_edge in zip(seq, seq[1:]):
                    # Each step must happen at a strictly later timestamp.
                    if cur_edge.tid <= prev_edge.tid:
                        valid = False
                        break
                    # Reject an immediate back-and-forth between the same pair.
                    if cur_edge.fr == prev_edge.to and cur_edge.to == prev_edge.fr:
                        valid = False
                        break
                if valid:
                    paths.append(seq)

    # Drop duplicated walks while keeping first-seen order.
    unique = list({tuple(w): w for w in paths}.values())

    res = defaultdict(list)
    for w in unique:
        res[(w[0].fr, w[-1].to)].append(w)
    return res
def all_time_respecting_s_walks(
    h: ASH,
    s: int,
    start: int = None,
    end: int = None,
    sample: float = 1,
) -> dict:
    """
    Compute time-respecting s-walks originating from every hyperedge in the graph.

    Each hyperedge returned by ``h.hyperedges(start=start, end=end)`` is used in
    turn as the single origin of :func:`time_respecting_s_walks`; the non-empty
    results are merged into one dictionary.

    :param h: The hypergraph.
    :param s: Minimum s-incidence threshold.
    :param start: Earliest snapshot to include.
    :param end: Latest snapshot to include.
    :param sample: Fraction of source-target samples per origin.
    :returns: Mapping (origin_edge, destination_edge) -> list of walks.
    """
    collected = {}
    for origin in h.hyperedges(start=start, end=end):
        per_origin = time_respecting_s_walks(
            h,
            s,
            start_from=origin,
            stop_at=None,
            start=start,
            end=end,
            sample=sample,
        )
        # Re-key on the origin as yielded by the hypergraph, keeping only
        # destinations that actually have walks.
        collected.update(
            ((origin, endpoints[1]), found)
            for endpoints, found in per_origin.items()
            if found
        )
    return collected
def annotate_walks(paths: list) -> dict:
    """
    Annotate a list of s-walks with standard path metrics.

    For every walk the following metrics are computed: ``length`` (number of
    steps), ``duration`` (``tid`` of the last step minus ``tid`` of the first
    step), ``weight`` (sum of the step weights) and ``reach`` (``tid`` of the
    last step). Walks are then grouped into categories:

    * ``shortest`` / ``fastest`` / ``heaviest`` / ``foremost``: walks with
      minimum length, minimum duration, maximum weight, minimum reach.
    * ``<a>_<b>`` (e.g. ``shortest_fastest``): among the walks optimal for
      ``<a>``, those that are also optimal for ``<b>``.

    :param paths: The walks to classify; each walk is a non-empty sequence of
        steps exposing ``tid`` and ``weight`` attributes.
    :returns: Dictionary of metric names to lists of walks. Every category maps
        to an empty list when ``paths`` is empty.
    """
    categories = (
        "shortest",
        "fastest",
        "heaviest",
        "foremost",
        "shortest_fastest",
        "shortest_heaviest",
        "fastest_shortest",
        "fastest_heaviest",
        "heaviest_fastest",
        "heaviest_shortest",
    )
    paths = list(paths)
    if not paths:
        return {name: [] for name in categories}

    metrics = [
        {
            "path": p,
            "length": len(p),
            "duration": p[-1].tid - p[0].tid,
            "weight": sum(e.weight for e in p),
            "reach": p[-1].tid,
        }
        for p in paths
    ]

    def best(candidates, key, pick):
        # Keep the candidates whose metric equals the optimum (min or max)
        # taken over these candidates only.
        optimum = pick(m[key] for m in candidates)
        return [m for m in candidates if m[key] == optimum]

    def walks(selected):
        return [m["path"] for m in selected]

    shortest = best(metrics, "length", min)
    fastest = best(metrics, "duration", min)
    heaviest = best(metrics, "weight", max)
    foremost = best(metrics, "reach", min)

    return {
        "shortest": walks(shortest),
        "fastest": walks(fastest),
        "heaviest": walks(heaviest),
        "foremost": walks(foremost),
        # Secondary criteria are applied only within the primary-optimal walks.
        "shortest_fastest": walks(best(shortest, "duration", min)),
        "shortest_heaviest": walks(best(shortest, "weight", max)),
        "fastest_shortest": walks(best(fastest, "length", min)),
        "fastest_heaviest": walks(best(fastest, "weight", max)),
        "heaviest_fastest": walks(best(heaviest, "duration", min)),
        "heaviest_shortest": walks(best(heaviest, "length", min)),
    }
def walk_length(path: list) -> int:
    """
    Return how many steps (temporal edges) a walk is made of.

    :param path: The walk to measure.
    :returns: Number of steps in the walk.
    """
    n_steps = len(path)
    return n_steps
def walk_duration(path: list) -> int:
    """
    Return the time spanned by a temporal walk.

    The ``tid`` of both steps is converted with ``int`` before subtracting.

    :param path: The walk to measure; must contain at least one step.
    :returns: ``tid`` of the last step minus ``tid`` of the first step.
    """
    last_step = path[-1]
    first_step = path[0]
    arrival = int(last_step.tid)
    departure = int(first_step.tid)
    return arrival - departure
def walk_weight(path: list) -> float:
    """
    Compute the total weight of a temporal walk.

    :param path: The walk to measure; each step exposes a ``weight`` attribute.
    :returns: Cumulative weight of the walk (``0`` for an empty walk).
    """
    # Step weights are declared as floats on TemporalEdge, so the sum is
    # annotated as float rather than int.
    return sum(step.weight for step in path)