Skip to content

Force (Oscillating) RMA

rma_kinetics.models.ForceRMA

Model of rapidly changing RMA expression, in which the RMA production rate oscillates sinusoidally at frequency freq.

Attributes:

Name Type Description
rma_prod_rate float

RMA production rate (concentration/time).

rma_rt_rate float

RMA reverse transcytosis rate (1/time).

rma_deg_rate float

RMA degradation rate (1/time).

freq float

Frequency of oscillations (1/time).

Source code in src/rma_kinetics/models/force.py
class ForceRMA(AbstractModel):
    """
    Model of rapidly changing RMA expression.

    The brain RMA production rate is forced sinusoidally,
    `rma_prod_rate * (1 + sin(2 * pi * freq * t))`, so it oscillates between
    0 and `2 * rma_prod_rate` around a mean of `rma_prod_rate`. Brain RMA is
    moved to plasma by reverse transcytosis and degraded there. The
    `freq_recovery`, `coherence_cutoff`, `power_cutoff` and `coh_cutoff`
    methods quantify how well the forcing oscillation can be recovered from
    (optionally noisy) plasma RMA trajectories.

    Attributes:
        rma_prod_rate (float): Mean RMA production rate (concentration/time).
        rma_rt_rate (float): RMA reverse transcytosis rate (1/time).
        rma_deg_rate (float): RMA degradation rate (1/time).
        freq (float): Frequency of oscillations (1/time).
    """
    rma_prod_rate: float
    rma_rt_rate: float
    rma_deg_rate: float
    freq: float

    def __init__(
        self,
        rma_prod_rate: float,
        rma_rt_rate: float,
        rma_deg_rate: float,
        freq: float,
        time_units: Time = Time.hours,
        conc_units: Concentration = Concentration.nanomolar
    ):
        """
        Args:
            rma_prod_rate (float): Mean RMA production rate (concentration/time).
            rma_rt_rate (float): RMA reverse transcytosis rate (1/time).
            rma_deg_rate (float): RMA degradation rate (1/time).
            freq (float): Frequency of oscillations (1/time).
            time_units (Time): Time units, forwarded to `AbstractModel`.
            conc_units (Concentration): Concentration units, forwarded to `AbstractModel`.
        """
        super().__init__(rma_prod_rate, rma_rt_rate, rma_deg_rate, time_units, conc_units)
        self.freq = freq

    def _force_rma_prod_rate(self, t: float) -> Array:
        """
        Sinusoid oscillating RMA production rate.

        Args:
            t (float): Time point.

        Returns:
            rma_prod_rate (jax.Array): RMA production rate evaluated at time `t`.
        """
        return self.rma_prod_rate * (1 + jnp.sin(2 * jnp.pi * self.freq * t))

    def _model(self, t: float, y: PyTree[float], args=None) -> PyTree[float]:
        """
        ODE model implementation. See the model equations section for more details.

        Arguments:
            t (float): Time point.
            y (PyTree[float]): Brain and plasma RMA concentrations, in that order.
            args: Unused; accepted for the `ODETerm` vector-field call signature.

        Returns:
            dydt (PyTree[float]): Change in brain and plasma RMA concentrations.
        """
        brain_rma, plasma_rma = y
        plasma_rma_transport_flux = self.rma_rt_rate * brain_rma
        dbrain_rma = self._force_rma_prod_rate(t) - plasma_rma_transport_flux
        dplasma_rma = plasma_rma_transport_flux - (self.rma_deg_rate * plasma_rma)

        return dbrain_rma, dplasma_rma

    def _terms(self) -> AbstractTerm:
        """Wrap `_model` in a diffrax `ODETerm`."""
        return ODETerm(self._model)

    def _apply_noise(self, solution: Array, std: float, prng_key: Array) -> Array:
        """
        Apply multiplicative Gaussian noise to a given trajectory.

        Each element is scaled by `1 + std * z` with an independent standard
        normal draw `z`, and the result is clipped at zero.

        Args:
            solution (jax.Array): Solution/trajectory to apply noise to (any shape).
            std (float): Noise standard deviation, relative to the signal level.
            prng_key (jax.Array): Jax PRNG key.

        Returns:
            noisy_solution (jax.Array): Solution with applied noise, clipped at zero.
        """
        # Full shape (not just the first axis) so multi-dimensional inputs get
        # per-element noise; for 1-D inputs this equals shape=(len(solution),).
        noise = std * jrand.normal(prng_key, shape=jnp.shape(solution))
        return jnp.clip(solution * (1 + noise), min=0)

    def _noisy_normalized_plasma(self, plasma_rma: Array, noise_std: float, prng_key: Array) -> Array:
        """
        Optionally add noise to a plasma RMA trajectory, then normalize it.

        Noise is applied only when `noise_std > 0`. The choice is made with
        `lax.cond`, so `noise_std` may be a traced value. The result is divided
        by `rma_prod_rate / rma_deg_rate`, the plasma level at which
        degradation balances the mean production rate.

        Args:
            plasma_rma (jax.Array): Plasma RMA trajectory.
            noise_std (float): Gaussian noise standard deviation.
            prng_key (jax.Array): Jax PRNG key for this noise realization.

        Returns:
            norm_plasma_rma (jax.Array): Normalized (possibly noisy) trajectory.
        """
        observed = cond(
            noise_std > 0,
            lambda: self._apply_noise(plasma_rma, noise_std, prng_key),
            lambda: plasma_rma
        )
        return observed / (self.rma_prod_rate / self.rma_deg_rate)

    def freq_recovery(
        self,
        simulation_config: dict[str, Any],
        noise_std: float,
        prng_key: Array,
        target_freq: float,
        n_iter: int,
        fs: float,
        rtol: float,
        min_snr: float
    ):
        """
        Calculate temporal resolution of model at a given noise level
        and target frequency.

        The model is simulated once. For each of `n_iter` noise realizations,
        the normalized plasma RMA trajectory is mean-subtracted and its Welch
        PSD is computed. A realization counts as recovered when both hold:

        - the PSD peak lies within `rtol` (relative) of `target_freq`;
        - the peak power divided by the mean PSD outside that band is at
          least `min_snr`.

        Args:
            simulation_config (dict[str, Any]): `simulate` method params as a dictionary.
            noise_std (float): Gaussian noise standard deviation.
            prng_key (jax.Array): Jax PRNG key.
            target_freq (float): Target frequency of true RMA oscillations.
            n_iter (int): Number of iterations to run for bootstrapping.
            fs (float): Sampling frequency (presumably the `sampling_rate`
                used in `simulation_config`).
            rtol (float): Relative tolerance for matching target and recovered
                frequencies. It also sets the band around `target_freq` that is
                excluded from the noise-floor estimate.
            min_snr (float): Minimum ratio of peak PSD to mean out-of-band PSD
                required to count a recovery.

        Returns:
            resolution (float): Fraction (between 0 and 1) of noise realizations
                in which the target frequency was recovered.
        """
        # The ODE is deterministic, so a single simulation serves every
        # iteration; only the noise realization changes inside the loop.
        plasma_rma = self.simulate(**simulation_config).ys[1]
        nperseg = len(plasma_rma) // 2
        keys = jrand.split(prng_key, n_iter)

        def body_fn(i, n_recovered):
            norm_plasma_rma = self._noisy_normalized_plasma(plasma_rma, noise_std, keys[i])
            freq, psd = welch(norm_plasma_rma - jnp.mean(norm_plasma_rma), fs=fs, nperseg=nperseg)

            peak_idx = jnp.argmax(psd)
            freq_match = jnp.isclose(freq[peak_idx], target_freq, rtol=rtol, atol=0)

            # Noise floor: PSD bins outside the tolerance band around the target.
            in_band = jnp.isclose(freq, target_freq, rtol=rtol, atol=0)
            psd_noise = jnp.where(~in_band, psd, jnp.nan)
            # SNR compares the peak *power* against the mean noise-floor power.
            snr = psd[peak_idx] / jnp.nanmean(psd_noise)

            return cond(
                jnp.logical_and(freq_match, snr >= min_snr),
                lambda: n_recovered + 1,
                lambda: n_recovered
            )

        n_recovered = fori_loop(0, n_iter, body_fn, 0)
        return n_recovered / n_iter

    def coherence_cutoff(
        self,
        input_signal: Array,
        input_psd: Array,
        simulation_config: dict[str, Any],
        noise_std: float,
        prng_key: Array,
        target_freq: float,
        n_iter: int,
        fs: float,
    ):
        """
        Calculate coherence based frequency cutoff of model at a given noise level
        and target frequency.

        Args:
            input_signal (Array): Input sine wave used to drive RMA production.
            input_psd (Array): Unused by this method; kept for API compatibility.
            simulation_config (dict[str, Any]): `simulate` method params as a dictionary.
            noise_std (float): Gaussian noise standard deviation.
            prng_key (jax.Array): Jax PRNG key.
            target_freq (float): Target frequency of true RMA oscillations.
            n_iter (int): Number of iterations to run for bootstrapping.
            fs (float): Sampling frequency.

        Returns:
            coherence (float): Average magnitude-squared coherence at the frequency
                bin closest to `target_freq`.
        """
        # Use `simulate` (as documented and as the sibling methods do) rather
        # than raw `diffeqsolve`, which rejects `simulate`-style parameters.
        plasma_rma = self.simulate(**simulation_config).ys[1]
        keys = jrand.split(prng_key, n_iter)

        def body_fn(i, sum_coh):
            norm_plasma_rma = self._noisy_normalized_plasma(plasma_rma, noise_std, keys[i])
            freq, coh = coherence(input_signal, norm_plasma_rma, fs=fs)
            # Coherence at the frequency bin nearest the target frequency.
            target_idx = jnp.argmin(jnp.abs(freq - target_freq))
            return sum_coh + coh[target_idx]

        sum_coh = fori_loop(0, n_iter, body_fn, 0.0)
        return sum_coh / n_iter

    def power_cutoff(
        self,
        simulation_config: dict[str, Any],
        noise_std: float,
        prng_key: Array,
        target_freq: float,
        n_iter: int,
        fs: float,
        rtol: float
    ):
        """
        Calculate power based frequency cutoff of model at a given noise level
        and target frequency.

        Args:
            simulation_config (dict[str, Any]): `simulate` method params as a dictionary.
            noise_std (float): Gaussian noise standard deviation.
            prng_key (jax.Array): Jax PRNG key.
            target_freq (float): Target frequency of true RMA oscillations.
            n_iter (int): Number of iterations to run for bootstrapping.
            fs (float): Sampling frequency.
            rtol (float): Relative half-width of the band around `target_freq`,
                i.e. the band is `target_freq * (1 - rtol)` to `target_freq * (1 + rtol)`.

        Returns:
            power_ratio (float): Mean, over noise realizations, of the target-band
                power (as computed by `bandpower`) divided by the total PSD power.
        """
        plasma_rma = self.simulate(**simulation_config).ys[1]
        nperseg = len(plasma_rma) // 2
        keys = jrand.split(prng_key, n_iter)
        bands = (target_freq - rtol*target_freq, target_freq + rtol*target_freq)

        def body_fn(i, sum_power_ratio):
            norm_plasma_rma = self._noisy_normalized_plasma(plasma_rma, noise_std, keys[i])
            freq, psd = welch(norm_plasma_rma - jnp.mean(norm_plasma_rma), fs=fs, nperseg=nperseg)

            target_bandpower = bandpower(psd, freq, bands)
            # Total power as a rectangle-rule integral over the uniform frequency grid.
            total_bandpower = jnp.sum(psd) * (freq[1] - freq[0])
            return sum_power_ratio + target_bandpower / total_bandpower

        sum_power_ratio = fori_loop(0, n_iter, body_fn, 0.0)
        return sum_power_ratio / n_iter

    def coh_cutoff(
        self,
        simulation_config: dict[str, Any],
        input_signal: jnp.ndarray,
        input_psd: jnp.ndarray,
        noise_std: float,
        prng_key: Array,
        target_freq: float,
        n_iter: int,
        fs: float,
    ):
        """
        Calculate coherence based frequency cutoff of model at a given noise level
        and target frequency.

        Coherence is computed as `|Pxy|^2 / (Pxx * Pyy)`. `Pxy` is the Welch
        cross spectral density of `input_signal` and the plasma trajectory,
        `Pyy` is the plasma Welch PSD, and `Pxx` is the supplied `input_psd`.

        Args:
            simulation_config (dict[str, Any]): `simulate` method params as a dictionary.
            input_signal (Array): Input sine wave used to drive RMA production.
            input_psd (Array): Input signal power spectral density. It is divided
                element-wise, so it must match the length of the Welch output used
                here (`nperseg = len(trajectory) // 2`). It is presumably computed
                with the same `fs` and `nperseg`.
            noise_std (float): Gaussian noise standard deviation.
            prng_key (jax.Array): Jax PRNG key.
            target_freq (float): Target frequency of true RMA oscillations.
            n_iter (int): Number of iterations to run for bootstrapping.
            fs (float): Sampling frequency.

        Returns:
            coherence (float): Average magnitude-squared coherence at the frequency
                bin closest to `target_freq`.
        """
        plasma_rma = self.simulate(**simulation_config).ys[1]
        nperseg = len(plasma_rma) // 2
        keys = jrand.split(prng_key, n_iter)

        def body_fn(i, sum_coh):
            norm_plasma_rma = self._noisy_normalized_plasma(plasma_rma, noise_std, keys[i])

            freq, psdy = welch(norm_plasma_rma, fs=fs, nperseg=nperseg)
            _, psdxy = csd(input_signal, norm_plasma_rma, fs=fs, nperseg=nperseg)
            cxy = jnp.abs(psdxy)**2 / input_psd / psdy

            # Coherence at the frequency bin nearest the target frequency.
            target_idx = jnp.argmin(jnp.abs(freq - target_freq))
            return sum_coh + cxy[target_idx]

        sum_coh = fori_loop(0, n_iter, body_fn, 0.0)
        return sum_coh / n_iter

simulate(t0: float, t1: float, y0: PyTree[float], dt0: float | None = None, sampling_rate: float = 1, stepsize_controller: AbstractStepSizeController = PIDController(rtol=1e-05, atol=1e-05), max_steps: int = 4096, solver: AbstractSolver = Kvaerno3(), adjoint: AbstractAdjoint = RecursiveCheckpointAdjoint(), throw: bool = True, progress_meter: AbstractProgressMeter = NoProgressMeter())

Simulates model within the given time interval.

Wraps diffrax.diffeqsolve with specific defaults for RMA model simulation.

Parameters:

Name Type Description Default
t0 float

Start time of integration.

required
t1 float

Stop time of integration.

required
y0 PyTree[float]

Tuple of initial conditions.

required
dt0 float | None

Initial step size if using adaptive step sizes, or size of all steps if using constant stepsize.

None
sampling_rate float

Sampling rate for saving solution.

1
stepsize_controller AbstractStepSizeController

Determines how to change step size during integration.

PIDController(rtol=1e-05, atol=1e-05)
max_steps int

Max number of steps before stopping.

4096
solver AbstractSolver

Differential equation solver.

Kvaerno3()
adjoint AbstractAdjoint

How to differentiate.

RecursiveCheckpointAdjoint()
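throw bool

Raise an exception if integration fails.

True
progress_meter AbstractProgressMeter

Progress meter passed through to diffrax.diffeqsolve.

NoProgressMeter()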

Returns:

Name Type Description
solution Solution

A Solution object, built from the diffrax.Solution returned by diffeqsolve, with added plotting methods.

Source code in src/rma_kinetics/models/abstract.py
def simulate(
        self,
        t0: float,
        t1: float,
        y0: PyTree[float],
        dt0: float | None = None,
        sampling_rate: float = 1,
        stepsize_controller: AbstractStepSizeController = PIDController(rtol=1e-5, atol=1e-5),
        max_steps: int = 4096,
        solver: AbstractSolver = Kvaerno3(),
        adjoint: AbstractAdjoint = RecursiveCheckpointAdjoint(),
        throw: bool = True,
        progress_meter: AbstractProgressMeter = NoProgressMeter()
):
    """
    Simulates model within the given time interval.

    Wraps `diffrax.diffeqsolve` with specific defaults for RMA model simulation.
    The solution is saved at `int((t1 - t0) * sampling_rate)` evenly spaced
    time points from `t0` to `t1` (both endpoints included).

    Arguments:
        t0 (float): Start time of integration.
        t1 (float): Stop time of integration.
        y0 (PyTree[float]): Tuple of initial conditions.
        dt0 (float | None): Initial step size if using adaptive
            step sizes, or size of all steps if using constant stepsize.
        sampling_rate (float): Sampling rate for saving solution
            (save points per unit time).
        stepsize_controller (AbstractStepSizeController): Determines
            how to change step size during integration.
        max_steps (int): Max number of steps before stopping.
        solver (AbstractSolver): Differential equation solver.
        adjoint (AbstractAdjoint): How to differentiate.
        throw (bool): Raise an exception if integration fails.
        progress_meter (AbstractProgressMeter): Progress meter passed
            through to `diffrax.diffeqsolve`.

    Returns:
        solution (Solution): A Solution object built from the diffrax
            solution and this model's time/concentration units, with added
            plotting methods.
    """
    # Size the save grid by the interval length rather than by `t1` alone, so
    # intervals that do not start at 0 are still saved at `sampling_rate`.
    # This is identical to sizing by `t1` when t0 == 0.
    # NOTE(review): `linspace` includes both endpoints, so the actual spacing is
    # (t1 - t0) / (n_save - 1), slightly coarser than 1 / sampling_rate.
    n_save = int((t1 - t0) * sampling_rate)
    saveat = SaveAt(ts=jnp.linspace(t0, t1, n_save))
    diffsol = diffeqsolve(
        self._terms(),
        solver,
        t0,
        t1,
        dt0,
        y0,
        saveat=saveat,
        stepsize_controller=stepsize_controller,
        max_steps=max_steps,
        adjoint=adjoint,
        throw=throw,
        progress_meter=progress_meter
    )

    return Solution(diffsol, self.time_units, self.conc_units)

_model(t: float, y: PyTree[float], args=None) -> PyTree[float]

ODE model implementation. See the model equations section for more details.

Parameters:

Name Type Description Default
t float

Time point.

required
y PyTree[float]

Brain and plasma RMA concentrations.

required

Returns:

Name Type Description
dydt PyTree[float]

Change in brain and plasma RMA concentrations.

Source code in src/rma_kinetics/models/force.py
def _model(self, t: float, y: PyTree[float], args=None) -> PyTree[float]:
    """
    Right-hand side of the forced RMA ODE system.

    Brain RMA is produced at the time-varying rate given by
    `_force_rma_prod_rate` and leaves the brain for plasma at rate
    `rma_rt_rate * brain`. Plasma RMA is removed at rate
    `rma_deg_rate * plasma`.

    Arguments:
        t (float): Time point.
        y (PyTree[float]): Brain and plasma RMA concentrations, in that order.
        args: Unused; accepted for the ODE-term vector-field call signature.

    Returns:
        dydt (PyTree[float]): Tuple of (brain derivative, plasma derivative).
    """
    brain, plasma = y
    # RMA flux from brain into plasma via reverse transcytosis.
    rt_flux = self.rma_rt_rate * brain
    d_brain = self._force_rma_prod_rate(t) - rt_flux
    d_plasma = rt_flux - self.rma_deg_rate * plasma
    return d_brain, d_plasma

_force_rma_prod_rate(t: float) -> Array

Sinusoid oscillating RMA production rate.

Parameters:

Name Type Description Default
t float

Time point.

required

Returns:

Name Type Description
rma_prod_rate Array

RMA production rate evaluated at time t.

Source code in src/rma_kinetics/models/force.py
def _force_rma_prod_rate(self, t: float) -> Array:
    """
    Evaluate the sinusoidally modulated RMA production rate.

    The rate is `rma_prod_rate * (1 + sin(2 * pi * freq * t))`. It swings
    between 0 and `2 * rma_prod_rate` with period `1 / freq`.

    Args:
        t (float): Time point.

    Returns:
        rma_prod_rate (jax.Array): RMA production rate evaluated at time `t`.
    """
    phase = 2 * jnp.pi * self.freq * t
    modulation = 1 + jnp.sin(phase)
    return self.rma_prod_rate * modulation

_apply_noise(solution: Array, std: float, prng_key: Array) -> Array

Apply multiplicative Gaussian noise to a given trajectory (the result is clipped at zero).

Parameters:

Name Type Description Default
solution Array

Solution/trajectory to apply noise to.

required
std float

Noise standard deviation.

required
prng_key Array

Jax PRNG key.

required

Returns:

Name Type Description
noisy_solution Array

Solution with applied Gaussian noise.

Source code in src/rma_kinetics/models/force.py
def _apply_noise(self, solution: Array, std: float, prng_key: Array) -> Array:
    """
    Apply multiplicative Gaussian noise to a given trajectory.

    Each element is scaled by `1 + std * z` with an independent standard
    normal draw `z`. The result is clipped at zero so it stays non-negative.

    Args:
        solution (jax.Array): Solution/trajectory to apply noise to. Any shape
            is accepted; one noise draw is made per element.
        std (float): Noise standard deviation, relative to the signal level.
        prng_key (jax.Array): Jax PRNG key.

    Returns:
        noisy_solution (jax.Array): Solution with applied noise, clipped at zero.
    """
    # Draw noise with the full shape of `solution` (not just its first axis) so
    # multi-dimensional trajectories get per-element noise. For 1-D inputs this
    # is the same draw as shape=(len(solution),).
    noise = std * jrand.normal(prng_key, shape=jnp.shape(solution))
    return jnp.clip(solution * (1 + noise), min=0)