Skip to content

Abstract Model

rma_kinetics.models.AbstractModel

Abstract RMA model.

Attributes:

Name Type Description
rma_prod_rate float

RMA production rate (concentration/time).

rma_rt_rate float

RMA reverse transcytosis rate (1/time).

rma_deg_rate float

RMA degradation rate (1/time).

time_units Time

Time units (Default = Time.hours).

conc_units Concentration

Concentration units (Default = Concentration.nanomolar).

Source code in src/rma_kinetics/models/abstract.py
class AbstractModel(EqxModule):
    """
    Abstract RMA model.

    Child classes implement `_model` (the ODE/SDE right-hand side) and
    `_terms` (which wraps the model for `diffrax.diffeqsolve`); `simulate`
    then integrates the model over a time interval.

    Attributes:
        rma_prod_rate (float): RMA production rate (concentration/time).
        rma_rt_rate (float): RMA reverse transcytosis rate (1/time).
        rma_deg_rate (float): RMA degradation rate (1/time).
        time_units (Time): Time units (Default = Time.hours).
        conc_units (Concentration): Concentration units (Default = Concentration.nanomolar).
    """
    rma_prod_rate: float
    rma_rt_rate: float
    rma_deg_rate: float
    time_units: Time
    conc_units: Concentration

    def __init__(
        self,
        rma_prod_rate: float,
        rma_rt_rate: float,
        rma_deg_rate: float,
        time_units: Time = Time.hours,
        conc_units: Concentration = Concentration.nanomolar,
    ):
        self.rma_prod_rate = rma_prod_rate
        self.rma_rt_rate = rma_rt_rate
        self.rma_deg_rate = rma_deg_rate
        self.time_units = time_units
        self.conc_units = conc_units

    @abstractmethod
    def _model(self, t: float, y: PyTree[float], args=None) -> PyTree[float]:
        """
        Final ODE/SDE model (implemented by child class).

        Args:
            t (float): Time point.
            y (PyTree[float]): Brain and plasma RMA concentrations.
            args: Optional extra arguments (default None). Note that
                `simulate` does not pass an `args` value to `diffeqsolve`.

        Returns:
            dydt (PyTree[float]): Change in brain and plasma RMA concentrations
                (along with any other additional species).
        """
        pass

    @abstractmethod
    def _terms(self) -> AbstractTerm:
        """
        Wraps model in `AbstractTerm` for use with the differential
        equation solver (implemented by child class).

        Returns:
            term (AbstractTerm): Terms for use with `diffrax.diffeqsolve`.
        """
        pass

    def simulate(
        self,
        t0: float,
        t1: float,
        y0: PyTree[float],
        dt0: float | None = None,
        sampling_rate: float = 1,
        stepsize_controller: AbstractStepSizeController = PIDController(rtol=1e-5, atol=1e-5),
        max_steps: int = 4096,
        solver: AbstractSolver = Kvaerno3(),
        adjoint: AbstractAdjoint = RecursiveCheckpointAdjoint(),
        throw: bool = True,
        progress_meter: AbstractProgressMeter = NoProgressMeter(),
    ):
        """
        Simulates model within the given time interval.

        Wraps `diffrax.diffeqsolve` with specific defaults for RMA model simulation.

        Arguments:
            t0 (float): Start time of integration.
            t1 (float): Stop time of integration.
            y0 (PyTree[float]): Tuple of initial conditions.
            dt0 (float | None): Initial step size if using adaptive
                step sizes, or size of all steps if using constant stepsize.
            sampling_rate (float): Number of saved samples per unit time.
                The solution is saved on an evenly spaced grid covering
                [t0, t1], endpoints included; both t0 and t1 are always saved.
            stepsize_controller (AbstractStepSizeController): Determines
                how to change step size during integration.
            max_steps (int): Max number of steps before stopping.
            solver (AbstractSolver): Differential equation solver.
            adjoint (AbstractAdjoint): How to differentiate.
            throw (bool): Raise an exception if integration fails.
            progress_meter (AbstractProgressMeter): Progress meter passed
                to `diffrax.diffeqsolve`.

        Returns:
            solution (Solution): A `Solution` object built from the
                `diffrax.Solution` returned by the solver, with added
                plotting methods.

        Raises:
            ValueError: If `sampling_rate` is not positive.
        """
        if sampling_rate <= 0:
            raise ValueError(f"sampling_rate must be positive, got {sampling_rate}.")

        # Size the save grid from the interval length (t1 - t0), not from t1
        # alone, so intervals that do not start at 0 are sampled correctly.
        # The `+ 1` is the fencepost: N intervals need N + 1 points for the
        # spacing to be 1/sampling_rate (exactly so when
        # (t1 - t0) * sampling_rate is an integer). Always keep at least the
        # two endpoints so a short interval never yields an empty `ts`.
        num_samples = max(int(abs(t1 - t0) * sampling_rate) + 1, 2)
        saveat = SaveAt(ts=jnp.linspace(t0, t1, num_samples))
        diffsol = diffeqsolve(
            self._terms(),
            solver,
            t0,
            t1,
            dt0,
            y0,
            saveat=saveat,
            stepsize_controller=stepsize_controller,
            max_steps=max_steps,
            adjoint=adjoint,
            throw=throw,
            progress_meter=progress_meter,
        )

        return Solution(diffsol, self.time_units, self.conc_units)

simulate(t0: float, t1: float, y0: PyTree[float], dt0: float | None = None, sampling_rate: float = 1, stepsize_controller: AbstractStepSizeController = PIDController(rtol=1e-05, atol=1e-05), max_steps: int = 4096, solver: AbstractSolver = Kvaerno3(), adjoint: AbstractAdjoint = RecursiveCheckpointAdjoint(), throw: bool = True, progress_meter: AbstractProgressMeter = NoProgressMeter())

Simulates model within the given time interval.

Wraps diffrax.diffeqsolve with specific defaults for RMA model simulation.

Parameters:

Name Type Description Default
t0 float

Start time of integration.

required
t1 float

Stop time of integration.

required
y0 PyTree[float]

Tuple of initial conditions.

required
dt0 float | None

Initial step size if using adaptive step sizes, or size of all steps if using constant stepsize.

None
sampling_rate float

Sampling rate (samples per unit time) at which the solution is saved.

1
stepsize_controller AbstractStepSizeController

Determines how to change step size during integration.

PIDController(rtol=1e-05, atol=1e-05)
max_steps int

Max number of steps before stopping.

4096
solver AbstractSolver

Differential equation solver.

Kvaerno3()
adjoint AbstractAdjoint

How to differentiate.

RecursiveCheckpointAdjoint()
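throw bool

Raise an exception if integration fails.

True
progress_meter AbstractProgressMeter

Progress meter passed to diffrax.diffeqsolve.

NoProgressMeter()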

Returns:

Name Type Description
solution Solution

A Solution object, built from the diffrax.Solution returned by the solver, with added plotting methods.

Source code in src/rma_kinetics/models/abstract.py
def simulate(
    self,
    t0: float,
    t1: float,
    y0: PyTree[float],
    dt0: float | None = None,
    sampling_rate: float = 1,
    stepsize_controller: AbstractStepSizeController = PIDController(rtol=1e-5, atol=1e-5),
    max_steps: int = 4096,
    solver: AbstractSolver = Kvaerno3(),
    adjoint: AbstractAdjoint = RecursiveCheckpointAdjoint(),
    throw: bool = True,
    progress_meter: AbstractProgressMeter = NoProgressMeter(),
):
    """
    Simulates model within the given time interval.

    Wraps `diffrax.diffeqsolve` with specific defaults for RMA model simulation.

    Arguments:
        t0 (float): Start time of integration.
        t1 (float): Stop time of integration.
        y0 (PyTree[float]): Tuple of initial conditions.
        dt0 (float | None): Initial step size if using adaptive
            step sizes, or size of all steps if using constant stepsize.
        sampling_rate (float): Number of saved samples per unit time.
            The solution is saved on an evenly spaced grid covering
            [t0, t1], endpoints included; both t0 and t1 are always saved.
        stepsize_controller (AbstractStepSizeController): Determines
            how to change step size during integration.
        max_steps (int): Max number of steps before stopping.
        solver (AbstractSolver): Differential equation solver.
        adjoint (AbstractAdjoint): How to differentiate.
        throw (bool): Raise an exception if integration fails.
        progress_meter (AbstractProgressMeter): Progress meter passed
            to `diffrax.diffeqsolve`.

    Returns:
        solution (Solution): A `Solution` object built from the
            `diffrax.Solution` returned by the solver, with added
            plotting methods.

    Raises:
        ValueError: If `sampling_rate` is not positive.
    """
    if sampling_rate <= 0:
        raise ValueError(f"sampling_rate must be positive, got {sampling_rate}.")

    # Size the save grid from the interval length (t1 - t0), not from t1
    # alone, so intervals that do not start at 0 are sampled correctly.
    # The `+ 1` is the fencepost: N intervals need N + 1 points for the
    # spacing to be 1/sampling_rate (exactly so when
    # (t1 - t0) * sampling_rate is an integer). Always keep at least the
    # two endpoints so a short interval never yields an empty `ts`.
    num_samples = max(int(abs(t1 - t0) * sampling_rate) + 1, 2)
    saveat = SaveAt(ts=jnp.linspace(t0, t1, num_samples))
    diffsol = diffeqsolve(
        self._terms(),
        solver,
        t0,
        t1,
        dt0,
        y0,
        saveat=saveat,
        stepsize_controller=stepsize_controller,
        max_steps=max_steps,
        adjoint=adjoint,
        throw=throw,
        progress_meter=progress_meter,
    )

    return Solution(diffsol, self.time_units, self.conc_units)

_terms() -> AbstractTerm

Wraps model in AbstractTerm for use with the differential equation solver (implemented by child class).

Returns:

Name Type Description
term AbstractTerm

Terms for use with diffrax.diffeqsolve.

Source code in src/rma_kinetics/models/abstract.py
@abstractmethod
def _terms(self) -> AbstractTerm:
    """
    Package the model as an `AbstractTerm` for the differential equation
    solver. Child classes must provide the implementation.

    Returns:
        term (AbstractTerm): Term object handed to `diffrax.diffeqsolve`.
    """
    ...

_model(t: float, y: PyTree[float], args=None) -> PyTree[float]

Final ODE/SDE model (implemented by child class)

Parameters:

Name Type Description Default
t float

Time point.

required
y PyTree[float]

Brain and plasma RMA concentrations

required

Returns:

Name Type Description
dydt PyTree[float]

Change in brain and plasma RMA concentrations (along with any other additional species).

Source code in src/rma_kinetics/models/abstract.py
@abstractmethod
def _model(self, t: float, y: PyTree[float], args=None) -> PyTree[float]:
    """
    Right-hand side of the final ODE/SDE model. Child classes must
    provide the implementation.

    Args:
        t (float): Time point at which the derivative is evaluated.
        y (PyTree[float]): Current brain and plasma RMA concentrations.
        args: Optional extra arguments (default None).

    Returns:
        dydt (PyTree[float]): Rate of change of the brain and plasma RMA
            concentrations, plus any additional species the model tracks.
    """
    ...