"""Module to manage matrices of the :py:class:`~lymph.models.Unilateral` class."""
# pylint: disable=too-few-public-methods
from __future__ import annotations
import warnings
from collections.abc import Iterable
from functools import lru_cache
import numpy as np
import pandas as pd
from lymph import graph, types
from lymph.modalities import Modality
from lymph.utils import get_state_idx_matrix, row_wise_kron, tile_and_repeat
[docs]
@lru_cache(maxsize=128)
def generate_transition(
    lnls: Iterable[graph.LymphNodeLevel],
    edges: Iterable[graph.Edge],
    num_states: int,
) -> np.ndarray:
    """Build the transition matrix of the lymph model.

    The result is a square array with ``num_states ** len(lnls)`` rows and columns.
    It is assembled as the element-wise product of one factor per LNL, where each
    factor combines the contributions of all edges in that LNL's ``inc`` attribute.

    ``edges`` is not used in the computation. It is only part of the signature so
    that it contributes to the :py:func:`functools.lru_cache` key. Because of the
    cache, all arguments must be hashable (e.g. tuples rather than lists).
    """
    lnl_list = list(lnls)  # a list is required for the `index()` lookup below
    num_lnls = len(lnl_list)
    dim = num_states**num_lnls
    transition_matrix = np.ones(shape=(dim, dim))

    for lnl_idx, lnl in enumerate(lnl_list):
        from_state = get_state_idx_matrix(
            lnl_idx=lnl_idx,
            num_lnls=num_lnls,
            num_states=num_states,
        )
        to_state = from_state.T

        # Entries where this LNL advances by exactly one state. They are combined
        # differently from all other entries (see the `np.where` below).
        is_progression = to_state == from_state + 1

        # Start with one where the LNL keeps its state and zero where it changes.
        # The two branches of the `np.where` below rely on this initialisation
        # because "spread" and "no spread" entries are accumulated differently.
        lnl_factor = to_state == from_state

        for edge in lnl.inc:
            if edge.is_tumor_spread:
                # Tumor edges always use index 0 along the parent axis.
                parent_state = 0
            else:
                parent_state = get_state_idx_matrix(
                    lnl_idx=lnl_list.index(edge.parent),
                    num_lnls=num_lnls,
                    num_states=num_states,
                )
            edge_grid = edge.transition_tensor[parent_state, from_state, to_state]

            # Progression: complement of "no incoming edge spreads", i.e.
            # 1 - (1 - p_1) * (1 - p_2) * ...
            # Everything else: plain product over the incoming edges.
            lnl_factor = np.where(
                is_progression,
                1 - (1 - lnl_factor) * (1 - edge_grid),
                lnl_factor * edge_grid,
            )

        transition_matrix *= lnl_factor

    return transition_matrix
[docs]
@lru_cache(maxsize=128)
def generate_observation(
    modalities: Iterable[Modality],
    num_lnls: int,
    base: int = 2,
) -> np.ndarray:
    """Build the observation matrix of the lymph model.

    For every modality, its ``confusion_matrix`` is Kronecker-multiplied with itself
    ``num_lnls`` times. The per-modality matrices are then combined one after another
    via :py:func:`~lymph.utils.row_wise_kron`, starting from a column of ones with
    ``base ** num_lnls`` rows.

    The function is wrapped in :py:func:`functools.lru_cache`, so ``modalities`` must
    be hashable (e.g. a tuple).
    """
    observation_matrix = np.ones(shape=(base**num_lnls, 1))

    for modality in modalities:
        # Kronecker power of the confusion matrix: one factor per LNL.
        joint_confusion = np.ones(shape=(1, 1))
        for _ in range(num_lnls):
            joint_confusion = np.kron(joint_confusion, modality.confusion_matrix)

        observation_matrix = row_wise_kron(observation_matrix, joint_confusion)

    return observation_matrix
[docs]
def compute_encoding(
    lnls: list[str],
    pattern: pd.Series | dict[str, types.InvolvementIndicator],
    base: int = 2,
) -> np.ndarray:
    """Encode a ``pattern`` of involvement as a boolean mask over complete states.

    The returned 1D array has ``base ** len(lnls)`` entries. An entry is ``True`` if
    the corresponding complete state/diagnosis is compatible with the ``pattern``.

    In the binary case (``base=2``), ``pattern[lnl]`` may be one of:

    - ``False`` or ``"healthy"``: The LNL is healthy.
    - ``True`` or ``"involved"``: The LNL is involved.

    In the trinary case (``base=3``), ``pattern[lnl]`` may be one of:

    - ``False`` or ``"healthy"``: The LNL is healthy.
    - ``True`` or ``"involved"``: The LNL is involved (micro- or macroscopic).
    - ``"micro"``: The LNL is involved microscopically only.
    - ``"macro"``: The LNL is involved macroscopically only.
    - ``"notmacro"``: The LNL is healthy or involved microscopically.

    In both cases, an LNL that is absent from ``pattern`` or whose value satisfies
    ``pd.isna(pattern[lnl])`` is treated as unknown and does not restrict the result.

    Raises a ``ValueError`` if ``base`` is neither 2 nor 3, or if a value in
    ``pattern`` is not one of the indicators listed above.

    >>> compute_encoding(["II", "III"], {"II": True, "III": False})
    array([False, False,  True, False])
    >>> compute_encoding(["II", "III"], {"II": "involved"})
    array([False, False,  True,  True])
    >>> compute_encoding(
    ...     lnls=["II", "III"],
    ...     pattern={"II": True, "III": False},
    ...     base=3,
    ... )
    array([False, False, False,  True, False, False,  True, False, False])
    >>> compute_encoding(
    ...     lnls=["II", "III"],
    ...     pattern={"II": "micro", "III": "notmacro"},
    ...     base=3,
    ... )
    array([False, False, False,  True,  True, False, False, False, False])
    """
    num_lnls = len(lnls)
    encoding = np.ones(shape=base**num_lnls, dtype=bool)

    # Per-base table: indicator -> which single-LNL states are compatible with it.
    masks_by_base = {
        2: {
            "healthy": (True, False),
            "involved": (False, True),
        },
        3: {
            "healthy": (True, False, False),
            "involved": (False, True, True),
            "micro": (False, True, False),
            "macro": (False, False, True),
            "notmacro": (True, True, False),
        },
    }
    if base not in masks_by_base:
        raise ValueError(f"Invalid base {base}.")

    element_map = {
        indicator: np.array(mask) for indicator, mask in masks_by_base[base].items()
    }
    # The booleans are aliases of the "healthy" / "involved" indicators.
    element_map[False] = element_map["healthy"]
    element_map[True] = element_map["involved"]

    for position, lnl in enumerate(lnls):
        if lnl not in pattern:
            continue
        indicator = pattern[lnl]
        if pd.isna(indicator):
            continue

        try:
            element = element_map[indicator]
        except KeyError as key_err:
            raise ValueError(
                f"Invalid pattern for LNL {lnl}: {indicator}",
            ) from key_err

        # Expand the single-LNL mask to the full state space. The first LNL varies
        # slowest (repeated most), the last LNL varies fastest (tiled most).
        lnl_mask = tile_and_repeat(
            mat=element,
            tile=(1, base**position),
            repeat=(1, base ** (num_lnls - position - 1)),
        )[0]
        encoding = np.logical_and(encoding, lnl_mask)

    return encoding
[docs]
def generate_data_encoding(
    patient_data: pd.DataFrame,
    modalities: dict[str, Modality],
    lnls: list[str],
) -> np.ndarray:
    r"""Generate the boolean data matrix from patient data.

    ``patient_data`` needs to contain the top-level column ``"_model"`` (a
    ``KeyError`` is raised otherwise). Each row under ``"_model"`` is looked up by
    modality name and the resulting per-modality involvement pattern is passed to
    :py:func:`compute_encoding` with ``base=2``, since observations are always binary.
    The per-modality encodings of one patient are combined via a Kronecker product,
    in the iteration order of ``modalities``.

    A modality that is listed in ``modalities`` but not present in a patient's row
    triggers a warning and is treated as a completely unknown diagnosis, i.e. it
    contributes an all-``True`` factor. This keeps the length of every patient's
    encoding equal to the number of columns of the result.

    The returned matrix has the shape :math:`M \times 2^{N \cdot \mathcal{O}}`,
    where :math:`M` is the number of patients, :math:`N` is the number of lymph node
    levels and :math:`\mathcal{O}` is the number of diagnostic modalities.
    """
    result = np.ones(
        shape=(2 ** (len(lnls) * len(modalities)), len(patient_data)),
        dtype=bool,
    )
    # Factor used for a modality without data: every diagnosis is compatible.
    unknown_diagnosis = np.ones(shape=2 ** len(lnls), dtype=bool)

    for i, (_, patient_row) in enumerate(patient_data["_model"].iterrows()):
        patient_encoding = np.ones(shape=1, dtype=bool)
        for modality_name in modalities:
            if modality_name in patient_row:
                diagnosis_encoding = compute_encoding(
                    lnls=lnls,
                    pattern=patient_row[modality_name],
                    base=2,  # observations are always binary!
                )
            else:
                warnings.warn(f"Modality {modality_name} not in data. Skipping.")
                # Simply skipping the Kronecker factor would make the encoding too
                # short to fit into the result whenever other modalities are present.
                diagnosis_encoding = unknown_diagnosis

            patient_encoding = np.kron(patient_encoding, diagnosis_encoding)

        result[:, i] = patient_encoding

    # One row per patient.
    return result.T
[docs]
@lru_cache
def evolve_midext(max_time: int, midext_prob: float) -> np.ndarray:
    """Compute the evolution over the state of a tumor's midline extension.

    The returned array has the shape ``(max_time + 1, 2)``. Row ``t`` holds the
    probabilities of the two states (column 0: no midline extension, column 1:
    midline extension) after ``t`` time steps. At ``t = 0`` all probability mass is
    in column 0. In every step, the mass in column 0 moves to column 1 with
    probability ``midext_prob``, while mass in column 1 stays there.

    The function is wrapped in :py:func:`functools.lru_cache`, so repeated calls with
    the same arguments return the very same array object. Callers should not modify
    it in place.
    """
    midext_states = np.zeros(shape=(max_time + 1, 2), dtype=float)
    midext_states[0, 0] = 1.0

    midext_transition_matrix = np.array(
        [
            [1 - midext_prob, midext_prob],
            [0.0, 1.0],
        ],
    )

    # Propagate the state distribution one time step at a time.
    for t in range(max_time):
        midext_states[t + 1, :] = midext_states[t, :] @ midext_transition_matrix

    return midext_states
[docs]
def fast_trace(
    left: np.ndarray,
    right: np.ndarray,
) -> np.ndarray:
    """Compute the diagonal of the matrix product ``left @ right`` without forming it.

    For 2D inputs, entry ``j`` of the result is ``sum_i left[j, i] * right[i, j]``,
    i.e. the ``j``-th diagonal element of ``left @ right``. Summing the returned
    array therefore yields the trace of the product.

    This is based on the observation that the trace of a product of two matrices is
    equal to the sum of the element-wise products of one matrix with the transpose of
    the other. See
    `Wikipedia <https://en.wikipedia.org/wiki/Trace_(linear_algebra)#Properties>`_ and
    `StackOverflow <https://stackoverflow.com/a/18854776>`_ for more information.
    """
    elementwise = left.T * right
    # Reduce over rows only, so one value per diagonal entry is kept.
    return elementwise.sum(axis=0)