Source code for lymph.mixins

"""Mixin classes to enhance functionality of models."""

from collections.abc import Sequence

from lymph.types import ParamsType
from lymph.utils import does_contain_in_order


[docs] class NamedParamsMixin: """Allow defining a :py:attr:`.named_params` subset of params to set and get.""" @property def named_params(self) -> Sequence[str]: """Sequence of parameter names that may be changed. Only parameter names are allowed that would also be recognized by the :py:meth:`~lymph.types.Model.set_params` method. For example, ``"TtoII_spread"`` or ``"late_p"`` could be valid named parameters. Even global parameters like ``"spread"`` work. .. warning:: The order is important: If the :py:attr:`.named_params` are set to e.g. ``["TtoII_spread", "spread"]``, then the ``"spread"`` parameter will override the ``"TtoII_spread"``. This exists for reproducibility reasons: It allows for a subset of parameters to be set via a special method (:py:meth:`.set_named_params`). Subsequently, only these parameters can be set via that method, both using positional and keyword arguments. A use case for this is parameter sampling. E.g., someone samples only a subset of parameters and stores these as an unnamed array along with a list of the parameters names they correspond to. Without the :py:attr:`.named_params` and the :py:meth:`.set_named_params` method, it would be tricky to load those values back into the model. .. seealso:: `This issue`_ on GitHub provides more information for the rationale behind this mixin. .. _This issue: https://github.com/rmnvsl/lymph/issues/95 """ return getattr(self, "_named_params", self.get_params(as_dict=True).keys()) @named_params.setter def named_params(self, new_names: Sequence[str]) -> None: """Set the named params.""" if not isinstance(new_names, Sequence): try: new_names = list(new_names) except TypeError as te: raise ValueError("Named params must be castable to a sequence.") from te default_params = list(self.get_params(as_dict=True, as_flat=True).keys()) for name in new_names: if not name.isidentifier(): raise ValueError(f"Named param {name} isn't valid identifier.") is_valid = False for default_name in default_params: if does_contain_in_order( sequence=default_name.split("_"), items=name.split("_"), ): is_valid = True if not is_valid: raise ValueError(f"Named param {name} is not a settable param.") self._named_params = new_names
[docs] def get_named_params(self, as_dict: bool = True) -> ParamsType: """Get the values of the :py:attr:`.named_params`. .. note:: Unlike the general :py:meth:`~lymph.types.Model.get_params` method, this method does not support the keyword argument ``as_flat``. The returned dictionary (if ``as_dict=True``) will always be flat. """ all_params = self.get_params(as_dict=True, as_flat=True) named_params = {k: all_params[k] for k in self.named_params} return named_params if as_dict else named_params.values()
[docs] def set_named_params(self, *args, **kwargs) -> None: """Set the values of the :py:attr:`.named_params`. .. note:: Positional arguments are overwritten by keyword arguments, which must only contain keys that are in :py:attr:`.named_params`. """ if not set(self.named_params).issuperset(kwargs.keys()): raise ValueError(f"Kwargs must be subset of named params, but is {kwargs}.") new_params = dict(zip(self.named_params, args, strict=False)) new_params.update(kwargs) self.set_params(**new_params)