Source code for lymph.utils

"""Module containing supporting classes and functions used accross the project."""

import logging
from collections.abc import Mapping, Sequence
from functools import cached_property, lru_cache, wraps
from typing import Any, TypeVar

import numpy as np

from lymph import types

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


def check_unique_names(graph: dict):
    """Validate the node names and connection lists of ``graph``.

    Every key of ``graph`` is unpacked as a 2-tuple whose second entry is the
    node's name; the corresponding value is the ordered collection of names
    that node is connected to.

    Raises:
        TypeError: If a connection collection is a ``set`` (it has no ordering).
        ValueError: If a node lists a connection more than once, if a node is
            connected to itself, or if two keys share the same node name.
    """
    seen_names = set()

    for (_, name), targets in graph.items():
        if isinstance(targets, set):
            raise TypeError("A node's connection list should not be a set (ordering)")

        has_duplicates = len(targets) != len(set(targets))
        if has_duplicates:
            raise ValueError(f"Duplicate connections for node {name} in graph")

        if name in targets:
            raise ValueError(f"Node {name} is connected to itself")

        seen_names.add(name)

    # Fewer distinct names than keys means at least two keys share a name.
    if len(graph) != len(seen_names):
        raise ValueError("Node names are not unique")
def check_spsn(spsn: list[float]):
    """Validate a pair of specificity and sensitivity values.

    ``spsn`` must contain exactly two numbers, each of them within the closed
    interval from 0.5 to 1.0.

    Raises:
        ValueError: If ``spsn`` does not have length two or if any of its
            values lies outside of the allowed interval.
    """
    wrong_length = len(spsn) != 2
    below_lower_bound = not np.all(np.greater_equal(spsn, 0.5))
    above_upper_bound = not np.all(np.less_equal(spsn, 1.0))

    if wrong_length or below_lower_bound or above_upper_bound:
        raise ValueError(
            "For each modality provide a list of two decimals between 0.5 and 1.0 as "
            "specificity & sensitivity respectively.",
        )
@lru_cache
def comp_transition_tensor(
    num_parent: int,
    num_child: int,
    is_tumor_spread: bool,
    is_growth: bool,
    spread_prob: float,
    micro_mod: float,
) -> np.ndarray:
    """Compute the transition factors of an edge.

    The result has shape ``(p, c, c)`` with ``p = num_parent`` the number of
    parent states and ``c = num_child`` the number of child states. It starts
    out as ``p`` stacked identity matrices, and a few rows are then overwritten
    depending on the kind of edge:

    - tumor spread: row ``[0, 0]`` becomes ``(1 - spread_prob, spread_prob)``.
    - growth: row ``[1, 1]`` becomes ``(0, 1 - spread_prob, spread_prob)``, which
      requires a child with three states.
    - three parent states: row ``[1, 0]`` uses ``spread_prob * micro_mod`` and
      row ``[2, 0]`` uses ``spread_prob`` as the spread probability.
    - otherwise: row ``[1, 0]`` becomes ``(1 - spread_prob, spread_prob)``.

    Results are memoized with ``lru_cache`` on the (hashable) arguments, so
    there is no need to delete and recompute tensors when the parameters of an
    edge change. The cache hands out the same array object for repeated
    identical calls, hence the returned array should not be modified in place.
    """
    tensor = np.stack([np.eye(num_child) for _ in range(num_parent)])

    # Zero padding lets a two-entry spread row also fit a child with 3 states.
    zero_padding = [0.0] * (num_child - 2)

    def spread_row(prob):
        """Build the row for a child in state 0 spreading with ``prob``."""
        return [1.0 - prob, prob, *zero_padding]

    if is_tumor_spread:
        # NOTE: this defines how tumors spread to LNLs
        tensor[0, 0, :] = spread_row(spread_prob)
    elif is_growth:
        # For growth, parent and child state are the same and the child node
        # is assumed to have three states.
        tensor[1, 1, :] = [0.0, (1 - spread_prob), spread_prob]
    elif num_parent == 3:
        # NOTE: this defines how ``micro_mod`` affects the spread probability
        tensor[1, 0, :] = spread_row(spread_prob * micro_mod)
        tensor[2, 0, :] = spread_row(spread_prob)
    else:
        tensor[1, 0, :] = spread_row(spread_prob)

    return tensor
def clinical(spsn: list) -> np.ndarray:
    """Build the confusion matrix of a clinical modality.

    ``spsn`` holds the specificity and sensitivity (in that order) and is
    validated with :py:func:`check_spsn` first. The returned array has three
    rows and two columns.

    A clinical modality can by definition *not* detect microscopic metastases,
    which is why the first two rows of the matrix are identical: both are
    governed by the specificity, only the last row uses the sensitivity.
    """
    check_spsn(spsn)
    specificity, sensitivity = spsn

    negative_row = [specificity, 1.0 - specificity]
    positive_row = [1.0 - sensitivity, sensitivity]
    return np.array([negative_row, negative_row, positive_row])
def pathological(spsn: list) -> np.ndarray:
    """Build the confusion matrix of a pathological modality.

    ``spsn`` holds the specificity and sensitivity (in that order) and is
    validated with :py:func:`check_spsn` first. The returned array has three
    rows and two columns.

    A pathological modality can detect microscopic disease, but is unable to
    differentiate between micro- and macroscopic involvement. Consequently, the
    last two rows of the matrix are identical and use the sensitivity, while
    only the first row uses the specificity.
    """
    check_spsn(spsn)
    specificity, sensitivity = spsn

    negative_row = [specificity, 1.0 - specificity]
    positive_row = [1.0 - sensitivity, sensitivity]
    return np.array([negative_row, positive_row, positive_row])
def tile_and_repeat(
    mat: np.ndarray,
    tile: tuple[int, int],
    repeat: tuple[int, int],
) -> np.ndarray:
    """Tile ``mat`` and then repeat the entries of the result along both axes.

    First, numpy's `tile`_ is called with ``tile`` as the number of repetitions
    per axis. Afterwards, numpy's `repeat`_ is applied along axis 0 with
    ``repeat[0]`` and along axis 1 with ``repeat[1]``.

    .. _tile: https://numpy.org/doc/stable/reference/generated/numpy.tile.html
    .. _repeat: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html

    >>> mat = np.array([[1, 2], [3, 4]])
    >>> tile_and_repeat(mat, (2, 2), (2, 2))
    array([[1, 1, 2, 2, 1, 1, 2, 2],
           [1, 1, 2, 2, 1, 1, 2, 2],
           [3, 3, 4, 4, 3, 3, 4, 4],
           [3, 3, 4, 4, 3, 3, 4, 4],
           [1, 1, 2, 2, 1, 1, 2, 2],
           [1, 1, 2, 2, 1, 1, 2, 2],
           [3, 3, 4, 4, 3, 3, 4, 4],
           [3, 3, 4, 4, 3, 3, 4, 4]])
    >>> tile_and_repeat(
    ...     mat=np.array([False, True], dtype=bool),
    ...     tile=(1, 2),
    ...     repeat=(1, 3),
    ... )
    array([[False, False, False,  True,  True,  True, False, False, False,
             True,  True,  True]])
    """
    return (
        np.tile(mat, tile)
        .repeat(repeat[0], axis=0)
        .repeat(repeat[1], axis=1)
    )
@lru_cache
def get_state_idx_matrix(lnl_idx: int, num_lnls: int, num_states: int) -> np.ndarray:
    """Return the indices for the transition tensor corresponding to ``lnl_idx``.

    The result is a square matrix with ``num_states ** num_lnls`` rows and
    columns in which all entries of a row are equal. Going down the rows, the
    values cycle through ``0, ..., num_states - 1``, and each value is held for
    ``num_states ** (num_lnls - lnl_idx - 1)`` consecutive rows. Results are
    memoized with ``lru_cache``.

    >>> get_state_idx_matrix(1, 3, 2)
    array([[0, 0, 0, 0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0, 0, 0, 0],
           [1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1],
           [0, 0, 0, 0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0, 0, 0, 0],
           [1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1]])
    >>> get_state_idx_matrix(1, 2, 3)
    array([[0, 0, 0, 0, 0, 0, 0, 0, 0],
           [1, 1, 1, 1, 1, 1, 1, 1, 1],
           [2, 2, 2, 2, 2, 2, 2, 2, 2],
           [0, 0, 0, 0, 0, 0, 0, 0, 0],
           [1, 1, 1, 1, 1, 1, 1, 1, 1],
           [2, 2, 2, 2, 2, 2, 2, 2, 2],
           [0, 0, 0, 0, 0, 0, 0, 0, 0],
           [1, 1, 1, 1, 1, 1, 1, 1, 1],
           [2, 2, 2, 2, 2, 2, 2, 2, 2]])
    """
    # Column vector holding every possible state once.
    state_column = np.arange(num_states).reshape(num_states, -1)

    # Stack the column ``outer_copies`` times and widen it to all columns.
    outer_copies = num_states**lnl_idx
    num_columns = num_states**num_lnls
    tiled = np.tile(state_column, (outer_copies, num_columns))

    # Hold every state for as many rows as the later LNLs have combinations.
    rows_per_state = num_states ** (num_lnls - lnl_idx - 1)
    return tiled.repeat(rows_per_state, axis=0)
def row_wise_kron(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Compute the `kronecker product`_ of two matrices row-wise.

    Row ``i`` of the result is the Kronecker product of row ``i`` of ``a`` with
    row ``i`` of ``b``. The result is a float array with as many rows as ``a``
    and ``a.shape[1] * b.shape[1]`` columns.

    .. _kronecker product: https://en.wikipedia.org/wiki/Kronecker_product

    >>> a = np.array([[1, 2], [3, 4]])
    >>> b = np.array([[5, 6], [7, 8]])
    >>> row_wise_kron(a, b)
    array([[ 5.,  6., 10., 12.],
           [21., 24., 28., 32.]])
    """
    num_rows = a.shape[0]
    num_cols = a.shape[1] * b.shape[1]
    out = np.zeros((num_rows, num_cols))

    for row_idx, a_row in enumerate(a):
        out[row_idx] = np.kron(a_row, b[row_idx])

    return out
[docs] def early_late_mapping(t_stage: int | str) -> str: """Map the reported T-category (i.e., 1, 2, 3, 4) to "early" and "late".""" t_stage = int(t_stage) if 0 <= t_stage <= 2: return "early" if 3 <= t_stage <= 4: return "late" raise ValueError(f"Invalid T-stage: {t_stage}")
def trigger(func: callable) -> callable:
    """Decorator that runs the instance's ``trigger_callbacks`` after ``func``.

    The decorated method is executed first. Then every callable found in the
    instance's ``trigger_callbacks`` attribute is called without arguments, in
    iteration order. Finally, the return value of ``func`` is handed back.
    """  # noqa: D401

    @wraps(func)
    def triggering(self, *args, **kwargs):
        outcome = func(self, *args, **kwargs)

        for notify in self.trigger_callbacks:
            notify()

        return outcome

    return triggering
class smart_updating_dict_cached_property(cached_property):  # noqa: N801
    """Cached property whose dict-like value is updated/cleared in place.

    Assigning to the attribute does not replace the cached object. Instead, the
    cached object is emptied and then filled with the assigned value. Deleting
    the attribute only empties the cached object.
    """

    def __set__(self, instance: object, value: Any) -> None:
        """Replace the content of the dict-like attribute with ``value``."""
        cached = self.__get__(instance)
        cached.clear()
        cached.update(value)

    def __delete__(self, instance: object) -> None:
        """Empty the dict-like attribute without discarding the object."""
        self.__get__(instance).clear()
def dict_to_func(mapping: dict[Any, Any]) -> callable:
    """Wrap a dictionary in a function that looks up its argument.

    The returned function takes a single key and returns ``mapping[key]``. It
    keeps a reference to ``mapping``, it does not copy it.

    >>> char_map = {'a': 1, 'b': 2, 'c': 3}
    >>> char_map = dict_to_func(char_map)
    >>> char_map('a')
    1
    """

    def lookup(key):
        return mapping[key]

    return lookup
def popfirst(seq: Sequence[Any]) -> tuple[Any, Sequence[Any]]:
    """Split a sequence into its first element and the remainder.

    When indexing raises an ``IndexError`` (i.e., the sequence is empty), the
    first entry of the result is ``None`` and the second entry is the sequence
    itself.

    Example:

    >>> popfirst([1, 2, 3])
    (1, [2, 3])
    >>> popfirst([])
    (None, [])
    """
    try:
        head, tail = seq[0], seq[1:]
    except IndexError:
        head, tail = None, seq

    return head, tail
def poplast(seq: Sequence[Any]) -> tuple[Sequence[Any], Any]:
    """Split a sequence into everything but the last element and the last one.

    If the sequence is empty, the first entry of the result is just the empty
    sequence and the second entry is ``None``.

    Example:

    >>> poplast([1, 2, 3])
    ([1, 2], 3)
    >>> poplast([])
    ([], None)
    """
    # Work on the reversed sequence, so that the last element comes first.
    backwards = seq[::-1]

    try:
        last, remainder = backwards[0], backwards[1:]
    except IndexError:
        last, remainder = None, backwards

    # Undo the reversal for the remaining elements.
    return remainder[::-1], last
def popat(seq: Sequence[Any], idx: int) -> tuple[Sequence[Any], Any, Sequence[Any]]:
    """Split a sequence into the part before, the element at, and the part after ``idx``.

    A negative ``idx`` counts from the end of the sequence. When ``idx`` is out
    of bounds, the element is ``None``: if ``idx`` is too large, the part before
    is the whole sequence and the part after is empty; if ``idx`` is too far in
    the negative direction, the part before is empty and the part after is the
    whole sequence. Empty parts are created via ``type(seq)()``.

    Example:

    >>> popat([1, 2, 3], 1)
    ([1], 2, [3])
    >>> popat([], 0)
    ([], None, [])
    >>> popat([1, 2, 3], -1)
    ([1, 2], 3, [])
    >>> popat([1, 2, 3], -10)
    ([], None, [1, 2, 3])
    >>> popat([1, 2, 3], 10)
    ([1, 2, 3], None, [])
    >>> popat((1, 2, 3), 10)
    ((1, 2, 3), None, ())
    """
    length = len(seq)
    pos = idx + length if idx < 0 else idx

    if pos < 0:
        return type(seq)(), None, seq

    if pos >= length:
        return seq, None, type(seq)()

    before, element, after = seq[:pos], seq[pos], seq[pos + 1 :]
    return before, element, after
def flatten(mapping, parent_key="", sep="_") -> dict:
    """Flatten a nested dictionary into a single level.

    Keys of nested dictionaries are joined to the key of their parent using
    ``sep``. A non-empty ``parent_key`` is prepended to all keys in the same
    way. Only values that are instances of ``dict`` are descended into.

    >>> flatten({"a": {"b": 1, "c": 2}, "d": 3})
    {'a_b': 1, 'a_c': 2, 'd': 3}
    """
    flat = {}

    for key, value in mapping.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else key

        if isinstance(value, dict):
            flat.update(flatten(value, full_key, sep=sep))
        else:
            flat[full_key] = value

    return flat
def unflatten_and_split(
    mapping: dict,
    expected_keys: list[str],
    sep: str = "_",
) -> tuple[dict, dict]:
    """Unflatten the part of a dict containing ``expected_keys`` and return the rest.

    Every key of ``mapping`` is split at the first occurrence of ``sep``. If the
    part before it is one of the ``expected_keys``, the value is stored in a
    nested dictionary under that prefix, using the part after ``sep`` as key.
    All other items are collected, unchanged, in the second returned dict.

    >>> unflatten_and_split({'a_b': 1, 'a_c_x': 2, 'd_y': 3}, expected_keys=['a'])
    ({'a': {'b': 1, 'c_x': 2}}, {'d_y': 3})
    """
    nested, remaining = {}, {}

    for flat_key, value in mapping.items():
        prefix, _, suffix = flat_key.partition(sep)

        if prefix in expected_keys:
            nested.setdefault(prefix, {})[suffix] = value
        else:
            remaining[flat_key] = value

    return nested, remaining
def get_params_from(
    objects: dict[str, types.HasGetParams],
    as_dict: bool = True,
    as_flat: bool = True,
) -> types.ParamsType:
    """Collect the parameters of all ``objects`` via their ``get_params()`` methods.

    Each object's ``get_params()`` is called with the given ``as_flat`` and its
    result is stored under the object's key. If ``as_flat`` is ``True`` or
    ``as_dict`` is ``False``, the collected dictionary is flattened with
    :py:func:`flatten`. With ``as_dict=False`` only the values are returned.
    """
    params = {
        name: obj.get_params(as_flat=as_flat)
        for name, obj in objects.items()
    }

    if as_flat or not as_dict:
        params = flatten(params)

    if as_dict:
        return params

    return params.values()
def set_params_for(
    objects: dict[str, types.HasSetParams],
    *args: float,
    **kwargs: float,
) -> tuple[float]:
    """Pass arguments to each ``set_params()`` method of the ``objects``.

    Keyword arguments prefixed with the key of one of the ``objects`` are only
    passed (without the prefix) to that object, all others are passed to every
    object. Object-specific keyword arguments override the shared ones.

    The positional arguments are handed to the first object. Whatever its
    ``set_params()`` returns is used as positional arguments for the next
    object, and so on. The return value of the last call is returned.
    """
    per_object, shared = unflatten_and_split(kwargs, expected_keys=objects.keys())

    for name, obj in objects.items():
        merged_kwargs = {**shared, **per_object.get(name, {})}
        args = obj.set_params(*args, **merged_kwargs)

    return args
def safe_set_params(
    model: types.ModelT,
    params: types.ParamsType | None = None,
) -> None:
    """Set the ``params`` of the ``model``, if there are any.

    Nothing happens when ``params`` is ``None``. Otherwise, the ``model``'s
    method ``set_named_params()`` is called: with ``params`` unpacked as keyword
    arguments if it is a ``dict``, and unpacked as positional arguments if not.
    """
    if params is None:
        return

    if isinstance(params, dict):
        model.set_named_params(**params)
        return

    model.set_named_params(*params)
def synchronize_params(
    get_from: dict[str, types.HasGetParams],
    set_to: dict[str, types.HasSetParams],
) -> None:
    """Copy the parameters of the objects in ``get_from`` to those in ``set_to``.

    For every key in ``set_to``, the parameters of the object stored under the
    same key in ``get_from`` are fetched as a dictionary and passed as keyword
    arguments to the ``set_params()`` method of the object in ``set_to``.
    """
    for name, target in set_to.items():
        source = get_from[name]
        source_params = source.get_params(as_dict=True)
        target.set_params(**source_params)
[docs] def draw_diagnosis( diagnosis_times: list[int], state_evolution: np.ndarray, observation_matrix: np.ndarray, possible_diagnosis: np.ndarray, rng: np.random.Generator | None = None, seed: int = 42, ) -> np.ndarray: """Draw diagnosis given ``diagnosis_times`` and hidden ``state_evolution``.""" if rng is None: rng = np.random.default_rng(seed) state_dists_given_time = state_evolution[diagnosis_times] observation_dists_given_time = state_dists_given_time @ observation_matrix drawn_observation_idxs = [ rng.choice(a=np.arange(len(possible_diagnosis)), p=dist) for dist in observation_dists_given_time ] return possible_diagnosis[drawn_observation_idxs].astype(bool)
def add_or_mult(llh: float, arr: np.ndarray, log: bool = True) -> float:
    """Combine the (log-)likelihood ``llh`` with the entries of ``arr``.

    If ``log`` is ``True``, the sum of the logarithms of ``arr`` is added to
    ``llh``. Otherwise, ``llh`` is multiplied with the product of ``arr``.
    """
    if not log:
        return llh * np.prod(arr)

    return llh + np.sum(np.log(arr))
K = TypeVar("K", bound=Any)
V = TypeVar("V", bound=Any)


def get_item(mapping: Mapping[K, V], keys: Sequence[K]) -> V:  # type: ignore
    """Return the first usable value of a mapping found under one of the ``keys``.

    The ``keys`` are tried in the given order. A key is skipped when it is
    missing from the ``mapping`` or when the value stored under it is ``None``.

    Raises:
        KeyError: If none of the ``keys`` yields a value that is not ``None``.

    >>> d = {'a': 1, 'b': 2}
    >>> get_item(d, ['a'])
    1
    >>> get_item(d, ['b', 'a'])
    2
    >>> get_item(d, ['x', 'b'])
    2
    >>> get_item(d, ['x', 'y'])
    Traceback (most recent call last):
    ...
    KeyError: "None of the keys=['x', 'y'] found in the mapping={'a': 1, 'b': 2}."
    """
    for key in keys:
        value = mapping.get(key, None)
        if value is None:
            continue
        return value

    raise KeyError(f"None of the {keys=} found in the {mapping=}.")