# Source code for ash_model.paths.randwalks

from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from scipy import sparse
import networkx as nx
import csrgraph as cg

from ash_model import ASH
from ash_model.paths import temporal_s_dag, TemporalEdge


def _normalize_rows(matrix: np.ndarray) -> np.ndarray:
    """
    Return a row-stochastic copy of a 2-D numpy matrix.

    Every row is divided by its own sum so it adds up to 1. Rows that sum to
    zero are divided by 1 instead, so they stay all-zero rather than
    becoming NaN. The input matrix is not modified.
    """
    totals = np.sum(matrix, axis=1, keepdims=True)
    # Replace zero totals by one (same dtype) to avoid 0/0 on empty rows.
    divisors = np.where(totals == 0, np.ones_like(totals), totals)
    return matrix / divisors


def _map_to_indices(items: List[Any]) -> Tuple[Dict[Any, int], Dict[int, Any]]:
    """
    Build a pair of lookup tables between items and their list positions.

    :param items: Items to index (must be hashable)
    :returns: ``(forward, reverse)`` where ``forward[item] -> position`` and
        ``reverse[position] -> item``. For repeated items the last position wins,
        and ``reverse`` only contains positions that survive in ``forward``.
    """
    forward: Dict[Any, int] = {}
    for position, item in enumerate(items):
        forward[item] = position
    reverse: Dict[int, Any] = dict(zip(forward.values(), forward.keys()))
    return forward, reverse


def _build_node_transition_matrix(
    h: "ASH", s: int, start: Optional[int], end: Optional[int]
) -> Tuple[sparse.csr_matrix, Dict[Any, int]]:
    """
    Construct the row-stochastic transition matrix between hypergraph nodes.

    Two distinct nodes ``u`` and ``v`` are connected when they appear together
    in at least ``s`` of the hyperedges returned for ``[start, end]``. The
    transition weight ``u -> v`` is proportional to that co-occurrence count.
    Rows of nodes without any qualifying neighbour are left all-zero.

    The matrix is assembled directly in sparse form, so memory grows with the
    number of connected pairs instead of with ``len(nodes) ** 2``.

    :param h: ASH hypergraph object
    :param s: Minimum s-incidence threshold (nodes must co-occur in at least s hyperedges)
    :param start: Lower temporal bound
    :param end: Upper temporal bound
    :returns: ``(T, n2idx)`` where ``T`` is an ``n x n`` CSR matrix (canonical:
        sorted column indices, no duplicates) and ``n2idx`` maps each node to
        its row/column index.
    """
    from collections import Counter
    from itertools import permutations

    nodes = list(h.nodes(start=start, end=end))
    n2idx = {node: idx for idx, node in enumerate(nodes)}
    n = len(nodes)

    # For every ordered pair of distinct nodes, count the hyperedges holding both.
    cooccurrence_counts: Counter[Tuple[Any, Any]] = Counter()
    for edge in h.hyperedges(start=start, end=end, as_ids=False):
        cooccurrence_counts.update(
            (u, v) for u, v in permutations(list(edge), 2) if u != v
        )

    # Keep only pairs that meet the s threshold, weighted by their count.
    rows: List[int] = []
    cols: List[int] = []
    weights: List[float] = []
    for (u, v), count in cooccurrence_counts.items():
        if count >= s:
            rows.append(n2idx[u])
            cols.append(n2idx[v])
            weights.append(float(count))

    T = sparse.csr_matrix(
        (
            np.asarray(weights, dtype=float),
            (np.asarray(rows, dtype=np.int64), np.asarray(cols, dtype=np.int64)),
        ),
        shape=(n, n),
    )
    T.sum_duplicates()  # canonical CSR: sorted indices, no duplicate entries

    # Row-normalise in place; empty rows keep a divisor of 1 and stay all-zero.
    row_sums = np.asarray(T.sum(axis=1)).ravel()
    row_sums[row_sums == 0] = 1.0
    # T.data is laid out row by row, so repeat each row's sum once per stored entry.
    T.data /= np.repeat(row_sums, np.diff(T.indptr))
    return T, n2idx


def _build_edge_transition_matrix(
    h: ASH, s: int, start: Optional[int], end: Optional[int]
) -> Tuple[sparse.csr_matrix, Dict[Any, int]]:
    """
    Build the row-stochastic transition matrix of the hyperedge s-line graph.

    The line graph comes from ``h.s_line_graph`` for the given ``s`` and time
    window. Its vertices (hyperedges) are indexed in sorted order. The
    adjacency matrix produced by ``networkx.to_numpy_array`` is rescaled so
    each non-empty row sums to one (all-zero rows stay zero), then converted
    to CSR.

    :param h: ASH hypergraph object
    :param s: Minimum size of node intersection between hyperedges
    :param start: Lower temporal bound
    :param end: Upper temporal bound
    :returns: ``(A, e2idx)`` with the CSR transition matrix and the
        hyperedge -> index mapping.
    """
    line_graph = h.s_line_graph(s=s, start=start, end=end)
    ordering = sorted(line_graph.nodes())
    e2idx = _map_to_indices(ordering)[0]
    adjacency = nx.to_numpy_array(line_graph, nodelist=ordering, dtype=float)
    return sparse.csr_matrix(_normalize_rows(adjacency)), e2idx


def random_walk_probabilities(
    h: ASH,
    s: int = 1,
    start: Optional[int] = None,
    end: Optional[int] = None,
    edge: bool = False,
) -> Tuple[sparse.csr_matrix, Dict[Any, int]]:
    """
    Compute the CSR transition matrix and index mapping for nodes or hyperedges.

    :param h: ASH hypergraph object
    :param s: Minimum s-incidence threshold (for nodes: co-occurrence count;
        for edges: intersection size)
    :param start: Lower temporal bound
    :param end: Upper temporal bound
    :param edge: If True, compute for the hyperedge line graph instead of nodes
    :returns: ``(T, index)`` - row-stochastic CSR matrix and vertex -> row index map
    """
    builder = (
        _build_edge_transition_matrix if edge else _build_node_transition_matrix
    )
    return builder(h, s, start, end)
def random_walks(
    h: ASH,
    start_from: Union[int, str, List[Union[int, str]], None] = None,
    num_walks: int = 100,
    walk_length: int = 10,
    p: float = 1.0,
    q: float = 1.0,
    s: int = 1,
    edge: bool = False,
    start: Optional[int] = None,
    end: Optional[int] = None,
    threads: int = -1,
) -> np.ndarray:
    """
    Generate biased (node2vec-style) random walks on an ASH hypergraph.

    Walks run on the node co-occurrence graph (``edge=False``) or on the
    hyperedge s-line graph (``edge=True``). In both cases they use the
    transition matrix from :func:`random_walk_probabilities` and are sampled
    by ``csrgraph``.

    :param h: ASH hypergraph object
    :param start_from: Node (or hyperedge when ``edge=True``) or list of them to
        start walks from. ``None`` passes ``start_nodes=None`` to csrgraph.
    :param num_walks: Number of walks per start node
    :param walk_length: Length of each walk
    :param p: Return parameter (higher values make walk less likely to return to
        previous node). Must be > 0; forwarded as ``return_weight=1/p``.
    :param q: In-out parameter (higher values make walk more local, lower values
        encourage exploration). Must be > 0; forwarded as ``neighbor_weight=1/q``.
    :param s: Minimum s-incidence threshold. For node walks, nodes must co-occur
        in at least s hyperedges to be connected. For edge walks, hyperedges
        must share at least s nodes to be connected.
    :param edge: If True, walk on hyperedge line graph
    :param start: Lower temporal bound
    :param end: Upper temporal bound
    :param threads: Parallel threads for random walk computation. Negative
        values (e.g. the default ``-1``) are passed to csrgraph as ``0``.
    :returns: Array of walks (each walk is a list of original node/edge IDs)
    :raises ValueError: If ``p`` or ``q`` is not strictly positive.
    :raises KeyError: If a ``start_from`` item is not a vertex of the walk graph
        in the requested time window.

    Examples
    --------
    .. code-block:: python

        # Node-based random walks with s=1 (any co-occurrence)
        walks = random_walks(h, num_walks=100, walk_length=10, s=1)

        # Node-based random walks with s=2 (nodes must co-occur in at least 2 hyperedges)
        walks = random_walks(h, num_walks=100, walk_length=10, s=2)

        # Hyperedge-based random walks with s=2 (hyperedges must share at least 2 nodes)
        walks = random_walks(h, num_walks=100, walk_length=10, s=2, edge=True)
    """
    # p and q are inverted below; reject values that would divide by zero or
    # produce negative sampling weights.
    if p <= 0 or q <= 0:
        raise ValueError(f"p and q must be strictly positive (got p={p}, q={q})")

    T_csr, n2idx = random_walk_probabilities(h, s, start, end, edge=edge)
    idx2n = {idx: node for node, idx in n2idx.items()}

    # NOTE(review): csrgraph documents threads=0 as "numba default thread
    # count"; the conventional "-1 = all cores" is mapped onto it instead of
    # being forwarded verbatim. Verify against the installed csrgraph version.
    G = cg.csrgraph(T_csr, threads=max(int(threads), 0))

    if start_from is None:
        start_nodes = None
    else:
        if not isinstance(start_from, list):
            start_from = [start_from]
        missing = [item for item in start_from if item not in n2idx]
        if missing:
            raise KeyError(
                f"start_from items not present in the walk graph: {missing!r}"
            )
        start_nodes = [n2idx[item] for item in start_from]

    raw = G.random_walks(
        walklen=walk_length,
        epochs=num_walks,
        start_nodes=start_nodes,
        return_weight=1.0 / p,
        neighbor_weight=1.0 / q,
    )
    # Translate csrgraph's integer vertex indices back to original IDs.
    return np.array([[idx2n[idx] for idx in walk] for walk in raw])
def _split_temporal_label(label: str) -> Tuple[str, str]:
    """
    Split a temporal-DAG vertex label ``"<id>_<time>"`` into ``(id, time)``.

    The split happens on the *last* underscore, so identifiers that themselves
    contain underscores survive intact (``"a_b_3"`` -> ``("a_b", "3")``).
    """
    base, _sep, t = label.rpartition("_")
    return base, t


def _pick_weighted(neis: List[Tuple[str, float]]) -> Tuple[str, float]:
    """
    Draw one ``(label, weight)`` pair from a non-empty neighbour list.

    Each neighbour is drawn with probability proportional to its weight. The
    draw becomes uniform when the weights do not sum to a positive value.
    Randomness comes from the global ``np.random`` state.
    """
    labels, raw_weights = zip(*neis)
    probs = np.array(raw_weights, dtype=float)
    total = probs.sum()
    if total <= 0:
        probs = np.ones_like(probs) / len(probs)
    else:
        probs = probs / total
    idx = np.random.choice(len(labels), p=probs)
    return labels[idx], float(raw_weights[idx])


def _sample_forward_walk(
    origin: str,
    neighbors: Dict[str, List[Tuple[str, float]]],
    walk_length: int,
    is_stop,
) -> List[Tuple[str, str, float]]:
    """
    Sample one walk of at most ``walk_length`` transitions starting at ``origin``.

    ``is_stop`` is a predicate on the label just reached. The walk ends early
    when it returns True (the stopping step is kept) or when the current
    vertex has no out-neighbours.

    :returns: List of ``(from_label, to_label, weight)`` transitions.
    """
    steps: List[Tuple[str, str, float]] = []
    cur = origin
    while len(steps) < walk_length:
        neis = neighbors.get(cur)
        if not neis:
            break  # temporal sink: no forward transition available
        nxt, w = _pick_weighted(neis)
        steps.append((cur, nxt, w))
        if is_stop(nxt):
            break
        cur = nxt
    return steps


def time_respecting_random_walks(
    h: ASH,
    s: int,
    start_from: Optional[Union[int, str, List[Union[int, str]]]] = None,
    stop_at: Optional[Union[int, str]] = None,
    start: Optional[int] = None,
    end: Optional[int] = None,
    num_walks: int = 100,
    walk_length: int = 10,
    p: float = 1.0,
    q: float = 1.0,
    edge: bool = False,
    threads: int = -1,
) -> Union[np.ndarray, Dict[Tuple[str, str], List[List[TemporalEdge]]]]:
    """
    Generate weighted, time-respecting random walks on a temporal hypergraph.

    A temporal DAG is built with ``temporal_s_dag``. Its vertices are
    ``"<id>_<time>"`` labels, and by that construction every arc points
    strictly forward in time (t -> t' with t' > t). Each step moves to an
    out-neighbour drawn with probability proportional to the arc's ``weight``
    attribute (default 1.0; uniform if the weights do not sum to a positive
    value). A walk ends after ``walk_length`` transitions, at a temporal sink
    (no out-neighbours), or on reaching ``stop_at``.

    Labels are split on their last underscore, so ids containing underscores
    are handled correctly.

    :param h: ASH hypergraph object
    :param s: Minimum s-incidence threshold
    :param start_from: Node or edge id (or list of ids) to start walks from.
        Every DAG vertex whose id part matches is used as an origin. With
        ``None``, the ``sources`` returned by ``temporal_s_dag`` are used when
        they are temporal vertices or have out-neighbours.
    :param stop_at: Node or edge id at which a walk stops (the step reaching it
        is kept). Falsy ids such as ``0`` are honoured.
    :param start: Lower temporal bound
    :param end: Upper temporal bound
    :param num_walks: Number of walks sampled per origin vertex
    :param walk_length: Maximum number of transitions per walk
    :param p: Accepted for API symmetry with ``random_walks``; currently NOT
        used by this sampler.
    :param q: Accepted for API symmetry with ``random_walks``; currently NOT
        used by this sampler.
    :param edge: If True, walk on hyperedges and return TemporalEdge paths
    :param threads: Currently unused.
    :returns: If ``edge=False``, an object ndarray of walks. Each walk lists
        the id (as a string) of every vertex reached after a transition; the
        origin itself is not included. If ``edge=True``, a dict mapping
        ``(first fr, last to)`` to lists of ``TemporalEdge`` paths.

    Examples
    --------
    .. code-block:: python

        # Time-respecting node walks
        walks = time_respecting_random_walks(h, s=1, num_walks=100, walk_length=10)

        # Time-respecting hyperedge walks
        walks_dict = time_respecting_random_walks(h, s=2, num_walks=100, walk_length=10, edge=True)

        # Start from specific nodes
        walks = time_respecting_random_walks(h, s=1, start_from=[1, 2], num_walks=50)
    """
    from collections import defaultdict

    DAG, sources, _ = temporal_s_dag(
        h, s, start_from, stop_at, start=start, end=end, edge=edge
    )

    # Temporal vertices are "<id>_<time>" strings; other DAG vertices are ignored.
    nodes = [n for n in DAG.nodes() if isinstance(n, str) and "_" in n]
    node_set = set(nodes)  # O(1) membership for the source filter below

    neighbors: Dict[str, List[Tuple[str, float]]] = {}
    for u, v, attrs in DAG.edges(data=True):
        if "_" not in str(u) or "_" not in str(v):
            continue
        neighbors.setdefault(str(u), []).append(
            (str(v), float(attrs.get("weight", 1.0)))
        )

    if start_from is None:
        start_nodes = [n for n in sources if n in neighbors or n in node_set]
    else:
        if not isinstance(start_from, list):
            start_from = [start_from]
        bases = {str(sf) for sf in start_from}
        start_nodes = [n for n in nodes if _split_temporal_label(n)[0] in bases]

    # `is None` (not truthiness) so that ids like 0 can be used as stop targets.
    stop_str = None if stop_at is None else str(stop_at)

    if edge:

        def is_stop(label: str) -> bool:
            return stop_str is not None and (
                label == stop_str or _split_temporal_label(label)[0] == stop_str
            )

        res: Dict[Tuple[str, str], List[List[TemporalEdge]]] = defaultdict(list)
        for origin in start_nodes:
            for _walk in range(num_walks):
                steps = _sample_forward_walk(origin, neighbors, walk_length, is_stop)
                if not steps:
                    continue
                path: List[TemporalEdge] = []
                for fr_label, to_label, w in steps:
                    fr, _ft = _split_temporal_label(fr_label)
                    to, tt = _split_temporal_label(to_label)
                    path.append(TemporalEdge(fr, to, float(w), int(tt)))
                res[(path[0].fr, path[-1].to)].append(path)
        return dict(res)

    def is_stop_node(label: str) -> bool:
        return stop_str is not None and _split_temporal_label(label)[0] == stop_str

    walks: List[List[Union[int, str]]] = []
    for origin in start_nodes:
        for _walk in range(num_walks):
            steps = _sample_forward_walk(origin, neighbors, walk_length, is_stop_node)
            if steps:
                walks.append(
                    [_split_temporal_label(to_label)[0] for _fr, to_label, _w in steps]
                )
    return np.array(walks, dtype=object)