# Source code for lymph.modalities

"""Module implementing management of the diagnostic modalities.

This allows the user to define diagnostic modalities and their sensitivity/specificity
values. This is necessary to compute the likelihood of a dataset (that was created by
recording the output of diagnostic modalities), given the model and its parameters
(which we want to learn).
"""

from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from typing import Literal, TypeVar

import numpy as np


class Modality:
    """Stores the confusion matrix of a diagnostic modality.

    The matrix is derived lazily from ``spec`` and ``sens`` and cached in
    ``_confusion_matrix``. Assigning a new ``spec`` or ``sens`` drops the cache, and
    the cache is also refreshed when its row count no longer matches ``is_trinary``.
    """

    def __init__(
        self,
        spec: float,
        sens: float,
        is_trinary: bool = False,
    ) -> None:
        """Initialize the modality.

        Args:
            spec: Specificity, must lie in ``[0, 1]``.
            sens: Sensitivity, must lie in ``[0, 1]``.
            is_trinary: Whether the confusion matrix needs three rows instead of two.

        Raises:
            ValueError: If ``spec`` or ``sens`` is outside of ``[0, 1]``.
        """
        if not (0.0 <= sens <= 1.0 and 0.0 <= spec <= 1.0):
            raise ValueError("Senstivity and specificity must be between 0 and 1.")

        self.spec = spec
        self.sens = sens
        self.is_trinary = is_trinary

    def __hash__(self) -> int:
        """Return a hash of the modality.

        This is computed from the confusion matrix of the modality.
        """
        return hash(self.confusion_matrix.tobytes())

    def __eq__(self, other: object) -> bool:
        """Check if two modalities are equal.

        Two modalities are equal when their confusion matrices have the same shape
        and the same entries. Anything that is not a :py:class:`Modality` is unequal.
        """
        if not isinstance(other, Modality):
            return False

        # ``array_equal`` copes with differing shapes (binary vs. trinary) and the
        # ``bool`` call guarantees a builtin ``bool`` instead of a NumPy scalar.
        return bool(np.array_equal(self.confusion_matrix, other.confusion_matrix))

    def __repr__(self) -> str:
        """Return a string representation of the modality."""
        return (
            f"{type(self).__name__}("
            f"spec={self.spec!r}, "
            f"sens={self.sens!r}, "
            f"is_trinary={self.is_trinary!r})"
        )

    @property
    def spec(self) -> float:
        """Return the specificity of the modality."""
        return self._spec

    @spec.setter
    def spec(self, value: float) -> None:
        """Set the specificity of the modality and drop the cached matrix.

        Raises:
            ValueError: If ``value`` is outside of ``[0, 1]``.
        """
        if not 0.0 <= value <= 1.0:
            raise ValueError("Specificity must be between 0 and 1.")

        # The cached matrix was built from the old value, so it must go.
        if hasattr(self, "_confusion_matrix"):
            del self._confusion_matrix

        self._spec = value

    @property
    def sens(self) -> float:
        """Return the sensitivity of the modality."""
        return self._sens

    @sens.setter
    def sens(self, value: float) -> None:
        """Set the sensitivity of the modality and drop the cached matrix.

        Raises:
            ValueError: If ``value`` is outside of ``[0, 1]``.
        """
        if not 0.0 <= value <= 1.0:
            raise ValueError("Sensitivity must be between 0 and 1.")

        # The cached matrix was built from the old value, so it must go.
        if hasattr(self, "_confusion_matrix"):
            del self._confusion_matrix

        self._sens = value

    def compute_confusion_matrix(self) -> np.ndarray:
        """Compute the confusion matrix of the modality.

        Returns:
            A 2x2 array ``[[spec, 1 - spec], [1 - sens, sens]]``.
        """
        return np.array(
            [
                [self.spec, 1.0 - self.spec],
                [1.0 - self.sens, self.sens],
            ],
        )

    @property
    def confusion_matrix(self) -> np.ndarray:
        """Return the confusion matrix of the modality.

        The matrix is computed on first access and recomputed whenever the cached
        one does not have the number of rows that ``is_trinary`` calls for.

        Raises:
            ValueError: If the freshly computed matrix fails
                :py:meth:`check_confusion_matrix`.
        """
        expected_rows = 3 if self.is_trinary else 2
        cached = getattr(self, "_confusion_matrix", None)

        # Refresh in both directions: a stale 3-row matrix after switching to
        # binary is just as wrong as a stale 2-row matrix after going trinary.
        if cached is None or cached.shape[0] != expected_rows:
            self.confusion_matrix = self.compute_confusion_matrix()

        return self._confusion_matrix

    @confusion_matrix.setter
    def confusion_matrix(self, value: np.ndarray) -> None:
        """Set the confusion matrix of the modality after validating it."""
        self.check_confusion_matrix(value)
        self._confusion_matrix = value

    def check_confusion_matrix(self, value: np.ndarray) -> None:
        """Check if the confusion matrix is valid.

        Raises:
            ValueError: If a row does not sum to one, if an entry lies outside of
                ``[0, 1]``, or if the number of rows does not match ``is_trinary``.
        """
        row_sums = np.sum(value, axis=1)
        if not np.allclose(row_sums, 1.0):
            raise ValueError("Rows of confusion matrix must sum to one.")

        if not np.all(np.greater_equal(value, 0.0)):
            raise ValueError("Confusion matrix must be non-negative.")

        if not np.all(np.less_equal(value, 1.0)):
            raise ValueError("Confusion matrix must be less than or equal to one.")

        if self.is_trinary and value.shape[0] != 3:
            raise ValueError("Confusion matrix must have 3 rows for trinary models.")

        if not self.is_trinary and value.shape[0] != 2:
            raise ValueError("Confusion matrix must have 2 rows for binary models.")
class Clinical(Modality):
    """Confusion matrix holder for a clinical modality."""

    def compute_confusion_matrix(self) -> np.ndarray:
        """Build the confusion matrix of the clinical modality.

        For a binary modality this is the 2x2 matrix of the parent class. For a
        trinary one, the first row of that matrix is repeated on top, which yields
        a 3x2 matrix with rows ``[first, first, second]``.
        """
        two_row_matrix = super().compute_confusion_matrix()

        if self.is_trinary:
            # Pick rows 0, 0, 1 of the binary matrix to form the three rows.
            return two_row_matrix[[0, 0, 1]]

        return two_row_matrix
class Pathological(Modality):
    """Confusion matrix holder for a pathological modality."""

    def compute_confusion_matrix(self) -> np.ndarray:
        """Build the confusion matrix of the pathological modality.

        For a binary modality this is the 2x2 matrix of the parent class. For a
        trinary one, the second row of that matrix is repeated at the bottom, which
        yields a 3x2 matrix with rows ``[first, second, second]``.
        """
        two_row_matrix = super().compute_confusion_matrix()

        if self.is_trinary:
            # Pick rows 0, 1, 1 of the binary matrix to form the three rows.
            return two_row_matrix[[0, 1, 1]]

        return two_row_matrix
# Type variable bound to ``Composite``; used to annotate ``self`` in its methods.
MC = TypeVar("MC", bound="Composite")
class Composite(ABC):
    """Abstract base class implementing the composite pattern for diagnostic modalities.

    Any class inheriting from this class should be able to handle the definition of
    diagnostic modalities and their sensitivity/specificity values.

    A composite is either a leaf, which stores the modalities itself in
    ``_modalities``, or an inner node that forwards every operation to the
    composites in ``_modality_children``.
    """

    _is_trinary: bool
    _modalities: dict[str, Modality]  # only for leaf nodes
    _modality_children: dict[str, Composite]

    def __init__(
        self: MC,
        modality_children: dict[str, Composite] | None = None,
        is_modality_leaf: bool = False,
    ) -> None:
        """Initialize the modality composite.

        Args:
            modality_children: Child composites to delegate to. Ignored when
                ``is_modality_leaf`` is true.
            is_modality_leaf: Whether this composite stores modalities itself.
        """
        if modality_children is None:
            modality_children = {}

        if is_modality_leaf:
            self._modalities = {}
            modality_children = {}  # ignore any provided children

        self._modality_children = modality_children

    @property
    def _is_modality_leaf(self: MC) -> bool:
        """Return whether the composite is a leaf node.

        Raises:
            AttributeError: If the composite has neither children nor a
                ``_modalities`` storage of its own.
        """
        if len(self._modality_children) > 0:
            return False

        if not hasattr(self, "_modalities"):
            raise AttributeError(f"{self} has no children and no modalities.")

        return True

    @property
    @abstractmethod
    def is_trinary(self: MC) -> bool:
        """Return whether the modality is trinary."""

    def modalities_hash(self: MC) -> int:
        """Compute a hash from all stored modalities.

        See the :py:meth:`.Modality.__hash__` method for more information.
        """
        hash_res = 0
        if self._is_modality_leaf:
            for name, modality in self._modalities.items():
                hash_res = hash((hash_res, name, hash(modality)))
        else:
            for child in self._modality_children.values():
                hash_res = hash((hash_res, child.modalities_hash()))

        return hash_res

    def get_modality(self: MC, name: str) -> Modality:
        """Return the modality with the given ``name``.

        Raises:
            KeyError: If no modality of that ``name`` is stored.
        """
        return self.get_all_modalities()[name]

    def get_all_modalities(self: MC) -> dict[str, Modality]:
        """Return all modalities of the composite.

        This will issue a warning if it finds that not all modalities of the
        composite are equal. Note that it will always return the modalities of the
        first child. This means one should NOT try to set the modalities via the
        returned dictionary of this method. Instead, use the
        :py:meth:`.set_modality` method.
        """
        if self._is_modality_leaf:
            return self._modalities

        children = iter(self._modality_children.values())
        first_modalities = next(children).get_all_modalities()

        # Visit every remaining child instead of stopping at the first mismatch,
        # so that warnings raised further down the tree are still emitted.
        found_mismatch = False
        for other_child in children:
            if first_modalities != other_child.get_all_modalities():
                found_mismatch = True

        if found_mismatch:
            warnings.warn("Not all modalities are equal. Returning first one.")

        return first_modalities

    def set_modality(
        self,
        name: str,
        spec: float,
        sens: float,
        kind: Literal["clinical", "pathological"] = "clinical",
    ) -> None:
        """Set the modality with the given ``name``.

        Args:
            name: Key under which the modality is stored.
            spec: Specificity of the modality.
            sens: Sensitivity of the modality.
            kind: ``"pathological"`` creates a :py:class:`Pathological` modality,
                any other value a :py:class:`Clinical` one.
        """
        if self._is_modality_leaf:
            cls = Pathological if kind == "pathological" else Clinical
            self._modalities[name] = cls(spec, sens, self.is_trinary)
        else:
            for child in self._modality_children.values():
                child.set_modality(name, spec, sens, kind)

    def del_modality(self: MC, name: str) -> None:
        """Delete the modality with the given ``name``.

        Raises:
            KeyError: If a leaf does not store a modality of that ``name``.
        """
        if self._is_modality_leaf:
            del self._modalities[name]
        else:
            for child in self._modality_children.values():
                child.del_modality(name)

    def replace_all_modalities(self: MC, modalities: dict[str, Modality]) -> None:
        """Replace all modalities of the composite with new ``modalities``.

        Each modality is recreated from its ``spec`` and ``sens`` via
        :py:meth:`set_modality`; the passed objects themselves are not stored.
        """
        # Work on a snapshot: ``modalities`` may be the very dict a leaf stores
        # (that is what ``get_all_modalities`` returns), and clearing the leaf
        # would then empty the input before it has been read.
        modalities = dict(modalities)

        if self._is_modality_leaf:
            self.clear_modalities()
            for name, modality in modalities.items():
                kind = (
                    "pathological" if isinstance(modality, Pathological) else "clinical"
                )
                self.set_modality(name, modality.spec, modality.sens, kind)
        else:
            for child in self._modality_children.values():
                child.replace_all_modalities(modalities)

    def clear_modalities(self: MC) -> None:
        """Clear all modalities of the composite."""
        if self._is_modality_leaf:
            self._modalities.clear()
        else:
            for child in self._modality_children.values():
                child.clear_modalities()