"""Noise generation utilities for CMB component separation.
This module provides utilities for generating instrumental noise and
creating noise covariance operators used in likelihood computations.
"""
from functools import partial
from typing import Literal
import jax
import jax.numpy as jnp
from furax._instruments.sky import FGBusterInstrument, get_noise_sigma_from_instrument
from furax.obs.landscapes import FrequencyLandscape
from furax.obs.operators import NoiseDiagonalOperator
from furax.obs.stokes import Stokes
from jax_healpy.clustering import get_cutout_from_mask
from jaxtyping import Array, Int, PRNGKeyArray
[docs]
@partial(jax.jit, static_argnames=("nside", "stokes_type"))
def generate_noise_operator(
    key: PRNGKeyArray,
    noise_ratio: float,
    indices: Int[Array, " n_valid"],
    nside: int,
    masked_d: Stokes,
    instrument: FGBusterInstrument,
    stokes_type: Literal["QU", "IQU"] | None = None,
) -> tuple[Stokes, NoiseDiagonalOperator, Stokes]:
    """Generate noised data, the matching diagonal noise operator, and its variances.

    This function creates a complete white-noise model for CMB component separation:

    1. Draws full-sky white noise over the instrument's frequency bands, scaled
       by ``noise_ratio``.
    2. Restricts that noise to the unmasked pixels given by ``indices``.
    3. Scales the noise by the instrument noise sigma and adds it to ``masked_d``.
    4. Builds a diagonal noise covariance operator for the likelihood.

    Args:
        key: JAX PRNG key used to draw the noise realization.
        noise_ratio: Multiplicative scale applied to the instrument noise
            standard deviation. ``1.0`` uses the instrument's nominal noise
            level, ``0.1`` uses 10% of it, and ``0.0`` gives noiseless data.
        indices: Indices of the unmasked pixels in the full-sky map. They are
            used to cut the generated noise out along axis 1.
        nside: HEALPix resolution parameter. Static under ``jax.jit``.
        masked_d: Input data already restricted to the unmasked pixels (for
            example via ``get_cutout_from_mask``), as a ``Stokes`` pytree.
            Presumably its Stokes components match ``stokes_type``; confirm
            against callers.
        instrument: Instrument configuration providing the frequency bands and
            noise specifications.
        stokes_type: Stokes parameters to use. ``"QU"`` means polarization
            only, and ``"IQU"`` means intensity plus polarization. If ``None``,
            it is taken from ``masked_d.stokes``. Static under ``jax.jit``.

    Returns:
        A 3-tuple ``(noised_d, N, small_n)``:

        - **noised_d**: ``masked_d`` plus the generated instrumental noise.
        - **N**: ``NoiseDiagonalOperator`` built from ``small_n``, with
          ``in_structure=masked_d.structure``, for use in the likelihood.
        - **small_n**: Noise variances, with the same pytree structure as the
          instrument sigma. Each leaf is ``(sigma * noise_ratio) ** 2``, or
          ``1.0`` when ``noise_ratio == 0``.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> from furax_cs.data import get_instrument, get_mask
        >>> from jax_healpy.clustering import get_cutout_from_mask
        >>> instrument = get_instrument("LiteBIRD")
        >>> mask = get_mask("GAL020")
        >>> (indices,) = jnp.where(mask == 1)
        >>> # ``d``: full-sky multi-frequency Stokes maps (not constructed here)
        >>> masked_d = get_cutout_from_mask(d, indices, axis=1)
        >>> noised_d, N, small_n = generate_noise_operator(
        ...     key=jax.random.key(42),
        ...     noise_ratio=1.0,
        ...     indices=indices,
        ...     nside=64,
        ...     masked_d=masked_d,
        ...     instrument=instrument,
        ... )

    Notes:
        When ``noise_ratio == 0``, every variance is set to ``1.0`` instead of
        ``0``. This keeps ``N`` non-singular in the likelihood computation.

        The noise model assumes:

        - White (uncorrelated) noise between pixels and frequencies.
        - A diagonal noise covariance (independent noise per measurement).
        - A noise amplitude set by the instrument sensitivity
          (``get_noise_sigma_from_instrument``).

        ``nside`` and ``stokes_type`` are static arguments. Each distinct
        combination therefore triggers a separate compilation.
    """
    # Infer the Stokes type from the data container when not given explicitly.
    # ``masked_d.stokes`` is read during tracing, so it must be static.
    if stokes_type is None:
        stokes_type = masked_d.stokes

    # Frequency landscape spanning all instrument bands on the full sky.
    f_landscapes = FrequencyLandscape(nside, instrument.frequency, stokes_type)

    # Draw white noise and apply the requested noise scale.
    white_noise = f_landscapes.normal(key) * noise_ratio

    # Keep only the unmasked pixels so the noise aligns with ``masked_d``.
    white_noise = get_cutout_from_mask(white_noise, indices, axis=1)

    # Scale the noise by the instrument sensitivity.
    sigma = get_noise_sigma_from_instrument(instrument, nside, stokes_type=stokes_type)
    noise = white_noise * sigma

    noised_d = masked_d + noise

    # Variance of the injected noise is (sigma * noise_ratio)^2. When
    # noise_ratio == 0 it would be all zeros and make N singular, so 1.0 is
    # used instead. ``jnp.where`` keeps this branch traceable under jit.
    small_n = jax.tree.map(
        lambda s: jnp.where(noise_ratio == 0, 1.0, (s * noise_ratio) ** 2), sigma
    )

    # Diagonal noise covariance operator acting on data shaped like masked_d.
    N = NoiseDiagonalOperator(small_n, in_structure=masked_d.structure)
    return noised_d, N, small_n