API Reference

Clustering

furax_cs.kmeans_clusters.kmeans_clusters(prngkey, mask, indices, regions, max_patches=None, initial_sample_size=1)

Generate K-means cluster assignments for spectral parameter optimization.

This function performs spherical K-means clustering on HEALPix sky pixels to partition the sky into regions with potentially different spectral parameters. The clustering is applied separately for each parameter type (e.g., dust temperature, dust spectral index, synchrotron index).
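The numpy sketch below illustrates spherical K-means in general. It works on unit vectors; for HEALPix pixels these could come from np.stack(healpy.pix2vec(nside, indices), axis=-1). Each point is assigned to the centroid with the largest dot product, and each centroid is then replaced by the normalized mean of its members. This is not the furax_cs implementation. The function name is hypothetical, and the greedy farthest-point initialization is an assumption made for this sketch. It is unrelated to how initial_sample_size is used.

```python
import numpy as np


def spherical_kmeans_sketch(vecs, n_clusters, n_iter=20, seed=0):
    """Hypothetical spherical K-means on unit vectors of shape (n, 3)."""
    rng = np.random.default_rng(seed)
    # Greedy farthest-point initialization (an assumption for this sketch):
    # start from a random point, then repeatedly add the point least similar
    # to every centroid chosen so far.
    centroids = [vecs[rng.integers(len(vecs))]]
    for _ in range(1, n_clusters):
        similarity = np.max(vecs @ np.array(centroids).T, axis=1)
        centroids.append(vecs[np.argmin(similarity)])
    centroids = np.array(centroids)

    for _ in range(n_iter):
        # Assign each point to the centroid with the largest cosine similarity
        labels = np.argmax(vecs @ centroids.T, axis=1)
        for k in range(n_clusters):
            members = vecs[labels == k]
            if len(members) > 0:  # keep the previous centroid if a cluster empties
                mean = members.sum(axis=0)
                centroids[k] = mean / np.linalg.norm(mean)
    return labels, centroids


# Two tight groups of points around the north and south poles
north = np.array([[0.0, 0.0, 1.0], [0.01, 0.0, 1.0], [0.0, 0.01, 1.0]])
north /= np.linalg.norm(north, axis=1, keepdims=True)
vecs = np.vstack([north, -north])
labels, _ = spherical_kmeans_sketch(vecs, n_clusters=2)
```

Because centroids stay on the unit sphere, the dot product equals the cosine similarity. This is why spherical K-means suits directions on the sky better than Euclidean K-means on raw coordinates.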

Parameters:
  • prngkey (Key[jaxlib._jax.Array, ''] | UInt32[jaxlib._jax.Array, '2']) – JAX PRNG key for reproducible clustering initialization.

  • mask (Float[jaxlib._jax.Array, 'npix']) – Full-sky HEALPix mask array (1 for valid pixels, 0 for masked).

  • indices (Int[jaxlib._jax.Array, 'n_valid']) – Indices of unmasked pixels, typically from jnp.where(mask == 1).

  • regions (dict[str, int]) – Dictionary mapping patch names to target numbers of clusters. Expected keys: "temp_dust_patches", "beta_dust_patches", "beta_pl_patches".

  • max_patches (dict[str, int] | None) – Maximum number of clusters per parameter. If None, the values in regions are used. This controls the size of the output arrays for JIT compatibility.

  • initial_sample_size (int) – Number of initial samples for K-means initialization. Defaults to 1.

Returns:

Dictionary of cluster assignments (cutout arrays, not full-sky). Keys match those of regions. Values are int64 arrays of shape (n_valid,).

Return type:

dict[str, Int[jaxlib._jax.Array, 'n_valid']]

Example

>>> import jax
>>> import jax.numpy as jnp
>>> from furax_cs.data import get_mask
>>> from furax_cs.kmeans_clusters import kmeans_clusters
>>> mask = get_mask("GAL020")
>>> (indices,) = jnp.where(mask == 1)
>>> regions = {
...     "temp_dust_patches": 50,
...     "beta_dust_patches": 500,
...     "beta_pl_patches": 50,
... }
>>> clusters = kmeans_clusters(
...     jax.random.key(0), mask, indices, regions
... )
>>> clusters["beta_dust_patches"].shape == indices.shape
True

Notes

The function uses normalize_by_first_occurrence to ensure cluster indices are contiguous starting from 0, which is required for parameter indexing.
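The numpy sketch below shows the idea of relabeling by first occurrence. The first distinct label encountered becomes 0, the next becomes 1, and so on. The function name is hypothetical. The library's normalize_by_first_occurrence may be implemented differently, for example to remain JIT-compatible.

```python
import numpy as np


def relabel_by_first_occurrence(labels):
    """Map arbitrary integer labels to 0..K-1 in order of first appearance."""
    labels = np.asarray(labels)
    _, first_idx, inverse = np.unique(labels, return_index=True, return_inverse=True)
    # Rank each unique label by the position where it first appears
    order = np.argsort(first_idx)
    rank = np.empty_like(order)
    rank[order] = np.arange(len(order))
    return rank[inverse.ravel()]


print(relabel_by_first_occurrence([7, 3, 7, 9, 3]))  # [0 1 0 2 1]
```

Contiguous labels mean a parameter array of length K can be indexed directly with the cluster assignments, with no unused slots.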

furax_cs.multires_clusters.multires_clusters(mask, indices, target_ud_grade, nside=None)

Generate multi-resolution cluster assignments using HEALPix ud_grade.

This function creates resolution-based patches where pixels sharing the same low-resolution parent pixel are grouped together. This is the approach used in the LiteBIRD PTEP methodology.
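Grouping pixels by their low-resolution parent is easiest to see in HEALPix NESTED ordering. In that scheme, the parent of pixel p at nside_in is p // (nside_in // nside_out)**2. The sketch below assumes NESTED indices. Healpy maps default to RING ordering, so RING indices would first need converting, for example with healpy.ring2nest. The function name is hypothetical, and the library's ud_grade-based implementation may differ. The input indices are sorted, as jnp.where returns them, so the parent indices are non-decreasing. As a result, np.unique's sorted labels already match first-occurrence order.

```python
import numpy as np


def multires_labels_nested(indices, nside_in, nside_out):
    """Hypothetical sketch: patch labels from NESTED parent pixels."""
    indices = np.asarray(indices)
    if nside_out == 0:
        # 0 means one global patch covering every valid pixel
        return np.zeros(indices.shape, dtype=np.int64)
    # Each parent pixel at nside_out contains (nside_in // nside_out)**2 children
    children_per_parent = (nside_in // nside_out) ** 2
    parents = indices // children_per_parent
    # Relabel parent ids to contiguous patch indices 0..K-1
    _, labels = np.unique(parents, return_inverse=True)
    return labels.ravel().astype(np.int64)


valid = np.array([0, 1, 2, 3, 4, 8, 47])  # sorted NESTED indices at nside=2
print(multires_labels_nested(valid, nside_in=2, nside_out=1))  # [0 0 0 0 1 2 3]
```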

Parameters:
  • mask (Float[jaxlib._jax.Array, 'npix']) – Full-sky HEALPix mask array (1 for valid pixels, 0 for masked).

  • indices (Int[jaxlib._jax.Array, 'n_valid']) – Indices of unmasked pixels, typically from jnp.where(mask == 1).

  • target_ud_grade (dict[str, int]) – Dictionary mapping parameter names to target nside values. Expected keys: "beta_dust", "temp_dust", "beta_pl". Values are nside parameters (must be powers of 2). Use 0 for a single global patch.

  • nside (int | None) – Input map resolution. If None, inferred from mask size.

Returns:

Dictionary of normalized cluster assignments (cutout arrays). Keys have the form "{param}_patches" (for example, "beta_dust_patches"). Values are int64 arrays of shape (n_valid,).

Return type:

dict[str, Int[jaxlib._jax.Array, 'n_valid']]

Example

>>> import jax.numpy as jnp
>>> from furax_cs.data import get_mask
>>> from furax_cs.multires_clusters import multires_clusters
>>> mask = get_mask("GAL020")  # nside=64 by default
>>> (indices,) = jnp.where(mask == 1)
>>> target_resolutions = {
...     "beta_dust": 64,   # Full resolution
...     "temp_dust": 32,   # 4x fewer patches
...     "beta_pl": 16,     # 16x fewer patches
... }
>>> clusters = multires_clusters(mask, indices, target_resolutions)
>>> clusters["temp_dust_patches"].shape == indices.shape
True

Noise

furax_cs.noise.generate_noise_operator(key, noise_ratio, indices, nside, masked_d, instrument, stokes_type=None)

Generate noised data and corresponding noise covariance operator.

This function creates a complete noise model for CMB component separation:

  1. Generates white noise scaled by instrument sensitivity.

  2. Adds noise to the input data.

  3. Creates a diagonal noise covariance operator for the likelihood.
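The three steps above can be sketched in numpy under stated assumptions. The instrument sensitivity is taken to be a per-band noise depth in µK·arcmin; these units are an assumption. The depth is converted to a per-pixel rms by dividing by the square root of the HEALPix pixel area in arcmin², and noise_ratio then scales that rms. The function name is hypothetical, and a numpy seed stands in for the JAX key. The diagonal covariance is stored as a variance array, so applying N⁻¹ is an elementwise division.

```python
import numpy as np

ARCMIN_PER_RAD = 180.0 * 60.0 / np.pi


def white_noise_model(depth_uk_arcmin, nside, n_valid, noise_ratio, seed=0):
    """Hypothetical sketch: white noise and diagonal variance, shape (n_freq, n_valid)."""
    depth = np.asarray(depth_uk_arcmin, dtype=float)
    n_freq = depth.size
    # Solid angle of one HEALPix pixel (4*pi / npix sr), converted to arcmin^2
    pix_area_arcmin2 = 4.0 * np.pi / (12 * nside**2) * ARCMIN_PER_RAD**2
    # Step 1: per-pixel rms from the band depth, scaled by noise_ratio
    sigma = noise_ratio * depth / np.sqrt(pix_area_arcmin2)
    rng = np.random.default_rng(seed)
    noise = sigma[:, None] * rng.standard_normal((n_freq, n_valid))
    if noise_ratio == 0.0:
        # Noiseless case: unit variance keeps N invertible (as in the Notes)
        variance = np.ones((n_freq, n_valid))
    else:
        variance = np.repeat(sigma[:, None] ** 2, n_valid, axis=1)
    return noise, variance


d = np.ones((2, 5))  # toy data: 2 bands, 5 valid pixels
noise, variance = white_noise_model([10.0, 20.0], nside=4, n_valid=5, noise_ratio=1.0)
noised_d = d + noise  # Step 2: add the noise realization to the data
weighted = noised_d / variance  # Step 3: apply the diagonal N^{-1} elementwise
```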

Parameters:
  • key (Key[jaxlib._jax.Array, ''] | UInt32[jaxlib._jax.Array, '2']) – JAX PRNG key for reproducible noise generation.

  • noise_ratio (float) – Noise level as fraction of signal. Use 0.0 for noiseless analysis. Typical values: 0.1 (10%), 1.0 (100%), etc.

  • indices (Int[jaxlib._jax.Array, 'n_valid']) – Indices of unmasked pixels from the full-sky map.

  • nside (int) – HEALPix resolution parameter.

  • masked_d (Stokes) – Input data restricted to the unmasked pixels (a cutout), as a Stokes object (e.g., QU or IQU).

  • instrument (FGBusterInstrument) – Instrument configuration containing frequency bands and noise specs.

  • stokes_type (Literal['QU', 'IQU'] | None) – Stokes parameters to use. If None, inferred from masked_d.stokes. "QU" for polarization-only, "IQU" for intensity + polarization.

Returns:

A tuple (noised_d, N) containing:

  • noised_d: Data with added instrumental noise.

  • N: Diagonal noise covariance operator for use in likelihood computation.

Return type:

tuple

Example

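>>> import jax
>>> import jax.numpy as jnp
>>> from furax_cs.data import get_instrument, get_mask
>>> from furax_cs.noise import generate_noise_operator
>>> from jax_healpy.clustering import get_cutout_from_mask
>>> instrument = get_instrument("LiteBIRD")
>>> mask = get_mask("GAL020")
>>> (indices,) = jnp.where(mask == 1)
>>> # d: full-sky Stokes data with the pixel axis at position 1 (assumed already defined)
>>> masked_d = get_cutout_from_mask(d, indices, axis=1)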
>>> noised_d, N = generate_noise_operator(
...     key=jax.random.key(42),
...     noise_ratio=1.0,
...     indices=indices,
...     nside=64,
...     masked_d=masked_d,
...     instrument=instrument,
... )

Notes

When noise_ratio=0, the noise operator uses variance=1.0 to avoid singular matrices in the likelihood computation.

The noise model assumes:

  • White (uncorrelated) noise between pixels and frequencies.

  • Diagonal noise covariance (independent noise per measurement).

  • Noise variance determined by instrument sensitivity.

Optimization

Optimization is provided by the CADRE package. See the minimization docs for the full solver reference.

cadre.minimize.minimize(fn, init_params, solver_name='optax_lbfgs', max_iter=1000, rtol=1e-08, atol=1e-08, lower_bound=None, upper_bound=None, precondition=False, options=None, refresh_steps=10, **fn_kwargs)

Unified optimization interface.

Supports optax solvers, optimistix solvers (via optimistix.minimise), and scipy solvers (via jaxopt.ScipyMinimize, requires cadre[scipy]).

Parameters:
  • fn (Callable) – Objective function to minimize. Should accept (params, **fn_kwargs).

  • init_params (PyTree) – Initial parameter values.

  • solver_name (str) – Solver identifier. See SOLVER_NAMES for available options.

  • max_iter (int) – Maximum iterations.

  • rtol (float) – Relative tolerance for optimization convergence.

  • atol (float) – Absolute tolerance for optimization convergence.

  • lower_bound (PyTree, optional) – Lower bounds for box constraints.

  • upper_bound (PyTree, optional) – Upper bounds for box constraints.

  • precondition (bool) – Whether to apply parameter transformation and output scaling.

  • options (dict, optional) –

    Extra keyword arguments passed to the solver factory (get_solver). The recognised keys are:

    • cooldown (int, default 20) – steps to suppress termination after a constraint release (active-set ADABK{N} solvers only).

    • min_steps (int, default 10) – minimum iterations before termination is considered (active-set solvers only).

    • verbose_print (bool, default False) – print per-step debug info via jax.debug.print (JIT-compatible; active-set solvers).

    • max_linesearch_steps (int, default 50) – maximum line-search steps per iteration (active-set and optax_lbfgs solvers).

    • linesearch (str) – line-search variant for optax_lbfgs ("zoom" or "backtracking").

  • refresh_steps (int)

  • **fn_kwargs (Any) – Additional arguments passed to fn.

Returns:

  • final_params (PyTree) – Optimized parameters.

  • final_state (UnifiedState) – Final optimizer state containing best loss, best parameters, iteration count, and solver state.

Return type:

tuple[PyTree[jaxtyping.Float[jaxlib._jax.Array, 'P']], UnifiedState]

cadre.solvers.get_solver(solver_name, rtol=1e-08, atol=1e-08, learning_rate=0.001, max_linesearch_steps=50, lower=None, upper=None, verbose_print=False, min_steps=10, cooldown=20, **kwargs)

Create a solver instance from a name string.

Parameters:
  • solver_name (str) – Solver identifier. See SOLVER_NAMES for available options.

  • rtol (float) – Relative tolerance for optimistix solvers.

  • atol (float) – Absolute tolerance for optimistix solvers.

  • learning_rate (float) – Learning rate for the adam solver.

  • max_linesearch_steps (int) – Maximum line-search steps per iteration (active-set and optax_lbfgs solvers).

  • lower (PyTree, optional) – Lower bounds for box projection (optax solvers only).

  • upper (PyTree, optional) – Upper bounds for box projection (optax solvers only).

  • verbose_print (bool) – If True, print per-step termination diagnostics for active-set solvers via jax.debug.print (JIT-compatible).

  • min_steps (int) – Minimum iterations before termination is considered (active-set solvers only).

  • cooldown (int) – Steps to suppress termination after a constraint release (active-set solvers only).

  • kwargs (Any)

Returns:

  • solver (BestSoFarMinimiser | str) – The solver instance: a minimiser wrapped in BestSoFar, or a solver-name string for scipy solvers.

  • solver_type (str) – Either "optimistix" or "scipy".

Return type:

tuple[BestSoFarMinimiser | str, Literal[‘optimistix’, ‘scipy’]]
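CADRE's solvers are not reproduced here. The numpy sketch below uses projected gradient descent to illustrate three ideas from this section:

  • Box projection (lower_bound/upper_bound): each step is clipped back into the box.

  • Best-so-far tracking: the lowest loss seen is kept.

  • rtol/atol stopping: iteration stops once the loss change falls below atol + rtol·|loss|.

The function name, the fixed learning rate, and this particular stopping rule are assumptions for the sketch. They need not match any solver in SOLVER_NAMES.

```python
import numpy as np


def projected_gradient_descent(fn, grad, x0, lower, upper, lr=0.1,
                               max_iter=1000, rtol=1e-8, atol=1e-8):
    """Hypothetical sketch: box-constrained descent with best-so-far tracking."""
    x = np.clip(np.asarray(x0, dtype=float), lower, upper)
    best_x, best_loss = x.copy(), fn(x)
    prev_loss = best_loss
    for it in range(1, max_iter + 1):
        # Gradient step followed by projection onto the box [lower, upper]
        x = np.clip(x - lr * grad(x), lower, upper)
        loss = fn(x)
        if loss < best_loss:  # keep the best iterate seen so far
            best_x, best_loss = x.copy(), loss
        # Stop when the loss change is below atol + rtol * |loss|
        if abs(prev_loss - loss) <= atol + rtol * abs(loss):
            break
        prev_loss = loss
    return best_x, best_loss, it


def loss(x):
    return float(np.sum((x - 3.0) ** 2))


def grad(x):
    return 2.0 * (x - 3.0)


best_x, best_loss, n_iter = projected_gradient_descent(loss, grad, [0.0], lower=-5.0, upper=2.0)
print(best_x)  # [2.] because the unconstrained minimum 3.0 lies outside the box
```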

Binning

furax_cs.binning.bin_parameter_map(pixel_map, nbins)

Bin a valid-pixels-only parameter map into equal-width bins.

Parameters:
  • pixel_map (np.ndarray) – 1-D array of parameter values for valid pixels only (no UNSEEN).

  • nbins (int) – Number of equal-width bins.

Returns:

  • patch_indices (np.ndarray) – 0-based bin indices, shape (n_valid,), values in [0, nbins-1].

  • bin_centers (np.ndarray) – Bin center values, shape (nbins,).

  • bin_edges (np.ndarray) – Bin edges, shape (nbins+1,).

Return type:

tuple[ndarray, ndarray, ndarray]

Example

Reconstruct a full-sky binned map from valid-pixel results:

import numpy as np
import healpy as hp
from furax_cs.binning import bin_parameter_map

nside = 64
npix = hp.nside2npix(nside)
mask = np.load("mask.npy")
(valid,) = np.where(mask == 1)

pixel_values = ...  # parameter values for valid pixels only
patch_indices, centers, edges = bin_parameter_map(pixel_values, nbins=10)

# Write back to a full-sky map (masked pixels = UNSEEN)
out_map = np.full(npix, hp.UNSEEN)
out_map[valid] = patch_indices.astype(float)
np.save("patches_beta_dust.npy", out_map)
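Equal-width binning itself can be sketched in a few lines of numpy. This is an illustration, not necessarily how bin_parameter_map is implemented, and the function name is hypothetical. The edges span the data range. Calling np.digitize against the interior edges yields 0-based indices in [0, nbins-1], and the maximum value falls in the last bin.

```python
import numpy as np


def equal_width_bins(values, nbins):
    """Hypothetical sketch returning (patch_indices, bin_centers, bin_edges)."""
    values = np.asarray(values, dtype=float)
    bin_edges = np.linspace(values.min(), values.max(), nbins + 1)
    # Digitizing against the nbins-1 interior edges gives indices 0..nbins-1;
    # values equal to the maximum land in the last bin.
    patch_indices = np.digitize(values, bin_edges[1:-1])
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    return patch_indices, bin_centers, bin_edges


idx, centers, edges = equal_width_bins([0.0, 0.5, 1.0, 2.0], nbins=2)
print(idx, centers)  # [0 0 1 1] [0.5 1.5]
```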

Data

Data generation and instrument configuration for CMB component separation.

furax_cs.data.get_mixin_matrix_operator(params, patch_indices, nu, sky, dust_nu0, synchrotron_nu0)

Construct mixing matrix operators for CMB and foregrounds.

Parameters:
  • params (PyTree[jaxtyping.Float[jaxlib._jax.Array, 'P']]) – Spectral parameters (temp_dust, beta_dust, beta_pl).

  • patch_indices (PyTree[jaxtyping.Int[jaxlib._jax.Array, 'P']]) – Patch assignment indices for each parameter.

  • nu (Float[jaxlib._jax.Array, 'Nf']) – Frequency array in GHz.

  • sky (dict[str, Stokes]) – Sky component dictionary from FURAX.

  • dust_nu0 (float) – Dust reference frequency in GHz.

  • synchrotron_nu0 (float) – Synchrotron reference frequency in GHz.

Returns:

A tuple (MixingMatrixOperator with CMB, MixingMatrixOperator without CMB).

Return type:

tuple[MixingMatrixOperator, MixingMatrixOperator]

Example

>>> A, A_nocmb = get_mixin_matrix_operator(params, patches, nu, sky, 150.0, 20.0)
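The numpy sketch below illustrates the general structure of a three-component mixing matrix. It is not the MixingMatrixOperator implementation. It uses the standard parametric SEDs:

  • CMB: a column converted from K_CMB to K_RJ.

  • Dust: a modified blackbody.

  • Synchrotron: a power law.

Each foreground is normalized at its reference frequency. Working in Rayleigh-Jeans units, the dictionary keys, and the function names are assumptions for the sketch. Per-pixel parameters are gathered from patch values through patch_indices, which is how a few patch parameters spread over many pixels.

```python
import numpy as np

H_OVER_K = 6.62607015e-34 / 1.380649e-23  # Planck over Boltzmann constant, K*s
T_CMB = 2.7255  # CMB monopole temperature, K


def cmb_krj_per_kcmb(nu_ghz):
    """Conversion factor dT_RJ/dT_CMB at frequency nu (GHz)."""
    x = H_OVER_K * nu_ghz * 1e9 / T_CMB
    return x**2 * np.exp(x) / np.expm1(x) ** 2


def dust_mbb(nu_ghz, nu0_ghz, beta_d, temp_d):
    """Modified blackbody in RJ units, equal to 1 at nu0."""
    x = H_OVER_K * nu_ghz * 1e9 / temp_d
    x0 = H_OVER_K * nu0_ghz * 1e9 / temp_d
    return (nu_ghz / nu0_ghz) ** (beta_d + 1) * np.expm1(x0) / np.expm1(x)


def mixing_matrix_sketch(params, patch_indices, nu, dust_nu0, synchrotron_nu0):
    """Return A with shape (n_freq, n_pix, 3): columns CMB, dust, synchrotron."""
    # Gather per-pixel spectral parameters from per-patch values
    temp_d = params["temp_dust"][patch_indices["temp_dust_patches"]]
    beta_d = params["beta_dust"][patch_indices["beta_dust_patches"]]
    beta_s = params["beta_pl"][patch_indices["beta_pl_patches"]]
    nu = np.asarray(nu, dtype=float)[:, None]  # (n_freq, 1) broadcasts over pixels
    cmb = np.broadcast_to(cmb_krj_per_kcmb(nu), (nu.shape[0], temp_d.shape[0]))
    dust = dust_mbb(nu, dust_nu0, beta_d, temp_d)
    sync = (nu / synchrotron_nu0) ** beta_s
    return np.stack([cmb, dust, sync], axis=-1)


params = {
    "temp_dust": np.array([20.0]),
    "beta_dust": np.array([1.5, 1.6]),
    "beta_pl": np.array([-3.0]),
}
patch_indices = {
    "temp_dust_patches": np.zeros(4, dtype=int),
    "beta_dust_patches": np.array([0, 0, 1, 1]),
    "beta_pl_patches": np.zeros(4, dtype=int),
}
A = mixing_matrix_sketch(params, patch_indices, [20.0, 150.0], 150.0, 20.0)
print(A.shape)  # (2, 4, 3)
```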

furax_cs.data.load_cmb_map(nside, sky='c1d0s0')

Load cached CMB-only maps.

Parameters:
  • nside (int) – HEALPix resolution parameter.

  • sky (str) – Sky model preset string. Defaults to "c1d0s0".

Returns:

CMB map with shape (3, n_pix) for Stokes I, Q, U, or zeros if no CMB.

Raises:

FileNotFoundError – If the cache file does not exist.

Return type:

Float[jaxlib._jax.Array, '3 npix']

Example

>>> cmb_map = load_cmb_map(nside=64, sky="c1d0s0")

furax_cs.data.load_fg_map(nside, noise_ratio=0.0, instrument_name='LiteBIRD', sky='c1d0s0')

Load cached foreground-only frequency maps.

Parameters:
  • nside (int) – HEALPix resolution parameter.

  • noise_ratio (float) – Noise level ratio (0.0 = no noise, 1.0 = 100% noise). Defaults to 0.0.

  • instrument_name (str) – Instrument configuration name. Defaults to "LiteBIRD".

  • sky (str) – Sky model preset string; the CMB component is excluded automatically. Defaults to "c1d0s0".

Returns:

A tuple (frequencies, freq_maps) containing only foreground contributions.

Return type:

tuple[Float[jaxlib._jax.Array, 'freqs'], Float[jaxlib._jax.Array, 'freqs 3 npix']]

Example

>>> freqs, fg_maps = load_fg_map(nside=64, sky="c1d1s1")

furax_cs.data.load_from_cache(nside, noise_ratio=0.0, instrument_name='LiteBIRD', sky='c1d0s0')

Load cached frequency maps from disk.

Parameters:
  • nside (int) – HEALPix resolution parameter.

  • noise_ratio (float) – Noise level ratio (0.0 = no noise, 1.0 = 100% noise). Defaults to 0.0.

  • instrument_name (str) – Instrument configuration name. Defaults to "LiteBIRD".

  • sky (str) – Sky model preset string. Defaults to "c1d0s0".

Returns:

A tuple (frequencies, freq_maps) loaded from cache.

Raises:

FileNotFoundError – If the cache file does not exist.

Return type:

tuple[Float[jaxlib._jax.Array, 'freqs'], Float[jaxlib._jax.Array, 'freqs 3 npix']]

Example

>>> freqs, maps = load_from_cache(nside=64, noise_ratio=0.0, sky="c1d0s0")

furax_cs.data.save_cmb_map(nside, sky='c1d0s0')

Generate and cache CMB-only maps for template generation.

Parameters:
  • nside (int) – HEALPix resolution parameter.

  • sky (str) – Sky model preset string. Defaults to "c1d0s0".

Returns:

CMB map with shape (3, n_pix) for Stokes I, Q, U, or zeros if no CMB.

Return type:

Float[jaxlib._jax.Array, '3 npix']

Example

>>> cmb_map = save_cmb_map(nside=64, sky="c1d0s0")

furax_cs.data.save_fg_map(nside, noise_ratio=0.0, instrument_name='LiteBIRD', sky='c1d0s0', key=None)

Generate and cache foreground-only frequency maps (CMB excluded).

Parameters:
  • nside (int) – HEALPix resolution parameter.

  • noise_ratio (float) – Noise level ratio (0.0 = no noise, 1.0 = 100% noise). Defaults to 0.0.

  • instrument_name (str) – Instrument configuration name. Defaults to "LiteBIRD".

  • sky (str) – Sky model preset string; the CMB component is removed automatically. Defaults to "c1d0s0".

  • key (Key[jaxlib._jax.Array, ''] | UInt32[jaxlib._jax.Array, '2'] | None) – JAX random key for noise generation. Defaults to None.

Returns:

A tuple (frequencies, freq_maps) containing only foreground contributions.

Return type:

tuple[Float[jaxlib._jax.Array, 'freqs'], Float[jaxlib._jax.Array, 'freqs 3 npix']]

Example

>>> freqs, fg_maps = save_fg_map(nside=64, sky="c1d1s1")

furax_cs.data.save_to_cache(nside, noise_ratio=0.0, instrument_name='LiteBIRD', sky='c1d0s0', key=None)

Generate and cache frequency maps for component separation.

Parameters:
  • nside (int) – HEALPix resolution parameter.

  • noise_ratio (float) – Noise level ratio (0.0 = no noise, 1.0 = 100% noise). Defaults to 0.0.

  • instrument_name (str) – Instrument configuration name. Defaults to "LiteBIRD".

  • sky (str) – Sky model preset string in PySM notation (e.g., "c1d0s0" combines CMB model c1, dust model d0, and synchrotron model s0). Defaults to "c1d0s0".

  • key (Key[jaxlib._jax.Array, ''] | UInt32[jaxlib._jax.Array, '2'] | None) – JAX random key for noise generation. Defaults to None.

Returns:

A tuple (frequencies, freq_maps) where frequencies are in GHz and freq_maps have shape (n_freq, 3, n_pix) for Stokes I, Q, U.

Return type:

tuple[Float[jaxlib._jax.Array, 'freqs'], Float[jaxlib._jax.Array, 'freqs 3 npix']]

Example

>>> freqs, maps = save_to_cache(nside=64, noise_ratio=0.0, sky="c1d0s0")

furax_cs.data.simulate_D_from_params(params, patch_indices, nu, sky, dust_nu0, synchrotron_nu0)

Simulate observed frequency maps given spectral parameters.

Parameters:
  • params (PyTree[jaxtyping.Float[jaxlib._jax.Array, 'P']]) – Spectral parameters (temp_dust, beta_dust, beta_pl).

  • patch_indices (PyTree[jaxtyping.Int[jaxlib._jax.Array, 'P']]) – Patch assignment indices for each parameter.

  • nu (Float[jaxlib._jax.Array, 'Nf']) – Frequency array in GHz.

  • sky (dict[str, Stokes]) – Sky component dictionary.

  • dust_nu0 (float) – Dust reference frequency in GHz.

  • synchrotron_nu0 (float) – Synchrotron reference frequency in GHz.

Returns:

A tuple (d, d_nocmb) where d includes CMB and d_nocmb excludes it.

Return type:

tuple[Stokes, Stokes]

Example

>>> d, d_nocmb = simulate_D_from_params(params, patches, nu, sky, 150.0, 20.0)

furax_cs.data.get_mask(mask_name='GAL020', nside=64)

Load and process galactic masks at specified resolution.

Parameters:
  • mask_name (str) – Mask identifier (e.g., "GAL020", "GAL040", "GALACTIC") or boolean expression (e.g., "GAL020+GAL040", "ALL-GALACTIC"). Defaults to "GAL020".

  • nside (int) – HEALPix resolution parameter. Defaults to 64.

Returns:

Boolean mask array where True indicates observed pixels.

Raises:

ValueError – If mask_name is invalid.

Return type:

Bool[jaxlib._jax.Array, 'npix']

Notes

Available mask choices: ALL, GALACTIC, GAL020, GAL040, GAL060, and their _U (upper) and _L (lower) hemisphere variants.

Boolean operations are supported:

  • Use + for union (logical OR).

  • Use - for subtraction (logical AND NOT).

  • Expressions are evaluated left to right.

  • Examples: "GAL020+GAL040", "ALL-GALACTIC", "GAL020+GAL040-GALACTIC".

Masks are automatically generated and cached on first call for each nside.
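The left-to-right evaluation rule can be sketched with boolean numpy arrays. The helper name and the dictionary of precomputed masks are hypothetical; get_mask builds and caches its masks itself.

```python
import re

import numpy as np


def evaluate_mask_expression(expr, masks):
    """Evaluate e.g. "GAL020+GAL040-GALACTIC" strictly left to right."""
    # Splitting on a capturing group keeps the operators:
    # ["GAL020", "+", "GAL040", "-", "GALACTIC"]
    tokens = re.split(r"([+-])", expr)
    result = masks[tokens[0]].copy()
    for op, name in zip(tokens[1::2], tokens[2::2]):
        if op == "+":
            result = result | masks[name]   # union (logical OR)
        else:
            result = result & ~masks[name]  # subtraction (logical AND NOT)
    return result


masks = {
    "ALL": np.array([True, True, True, True]),
    "GALACTIC": np.array([False, True, True, False]),
    "GAL020": np.array([True, False, False, False]),
}
print(evaluate_mask_expression("ALL-GALACTIC", masks).tolist())  # [True, False, False, True]
```

Because evaluation is strictly left to right, "GAL020+GALACTIC-ALL" is ((GAL020 OR GALACTIC) AND NOT ALL), which is empty.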

Example

>>> mask = get_mask("GAL020", nside=64)
>>> # Using boolean expression:
>>> mask_union = get_mask("GAL020+GAL040", nside=64)

furax_cs.data.get_instrument(instrument_name)

Get an instrument configuration by name.

Parameters:

instrument_name (str) – Name of the instrument (e.g., "LiteBIRD", "Planck"). Must correspond to an entry in instruments.yaml. Use "default" for the FGBuster default instrument.

Returns:

The instrument configuration object with frequency bands and sensitivities.

Raises:

ValueError – If instrument_name is not found in the configuration.

Return type:

FGBusterInstrument

Example

>>> instrument = get_instrument("LiteBIRD")
>>> print(instrument.frequency)

furax_cs.data.load_search_space(filepath=None)

Load search space configuration from YAML file.

Parameters:

filepath (str | Path | None) – Path to custom search space YAML file. If None, loads the default configuration from search_spaces_default.yaml.

Returns:

Dictionary with JAX arrays for T_d_patches, B_d_patches, B_s_patches.

Raises:
  • FileNotFoundError – If the specified filepath does not exist.

  • ValueError – If the YAML file is missing required keys or has invalid values.

Return type:

dict[str, Int[jaxlib._jax.Array, 'N']]

Example

>>> space = load_search_space()
>>> print(space["T_d_patches"])

furax_cs.data.search_space_to_jax(config)

Convert search space configuration from YAML to JAX arrays.

Parameters:

config (dict[str, Any]) – Dictionary with search space parameters as lists or numpy arrays.

Returns:

Dictionary with JAX arrays ready for grid search.

Return type:

dict[str, Int[jaxlib._jax.Array, 'N']]

Example

>>> config = {"T_d_patches": [1, 2], "B_d_patches": [3, 4], "B_s_patches": [5]}
>>> jax_space = search_space_to_jax(config)

furax_cs.data.dump_default_search_space(output_path)

Dump the default search space configuration to a YAML file.

This creates a template file that users can customize for their needs.

Parameters:

output_path (str | Path) – Path where the default search space YAML will be saved.

Raises:
  • FileNotFoundError – If the default config file is missing in the package.

  • IOError – If the output file cannot be written.

Return type:

None

Example

>>> dump_default_search_space("my_search_space.yaml")

furax_cs.data.validate_search_space(search_space)

Validate the structure and values of a search space.

Parameters:

search_space (dict[str, Int[jaxlib._jax.Array, 'N']]) – Dictionary with JAX arrays for search space parameters.

Raises:

ValueError – If validation fails (missing keys, empty arrays, invalid values).

Return type:

None

Example

>>> validate_search_space(search_space)
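The sketch below shows checks of the kinds listed under Raises. The function name is hypothetical. Requiring positive integer patch counts is an assumption about what counts as an invalid value, and validate_search_space may apply different rules.

```python
import numpy as np

REQUIRED_KEYS = ("T_d_patches", "B_d_patches", "B_s_patches")


def check_search_space(search_space):
    """Raise ValueError on missing keys, empty arrays, or non-positive counts."""
    for key in REQUIRED_KEYS:
        if key not in search_space:
            raise ValueError(f"Missing search space key: {key}")
        values = np.asarray(search_space[key])
        if values.size == 0:
            raise ValueError(f"{key} must not be empty")
        # Patch counts are assumed to be positive integers
        if not np.issubdtype(values.dtype, np.integer) or np.any(values < 1):
            raise ValueError(f"{key} must contain positive integers")


check_search_space({"T_d_patches": [1, 2], "B_d_patches": [3, 4], "B_s_patches": [5]})
```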

furax_cs.data.sanitize_mask_name(mask_expr)

Convert a mask expression to a valid folder name.

Parameters:

mask_expr (str) – Mask expression potentially containing + (union) or - (subtract) operators.

Returns:

Sanitized folder name with operators replaced by descriptive names.

Return type:

str

Example

>>> sanitize_mask_name("GAL020+GAL040")
'GAL020_UNION_GAL040'
>>> sanitize_mask_name("ALL-GALACTIC")
'ALL_SUBTRACT_GALACTIC'

furax_cs.data.generate_custom_cmb(r_value, nside, seed=None)

Generate a CMB realization with a specific tensor-to-scalar ratio r.

Uses CAMB and healpy synfast, with a linear approximation in r for the B-modes.

Parameters:
  • r_value (float) – Tensor-to-scalar ratio.

  • nside (int) – HEALPix resolution.

  • seed (int | None) – Random seed for generation.

Returns:

CMB map (3, npix) in uK_CMB.

Return type:

Float[jaxlib._jax.Array, '3 npix']

Example

>>> cmb_map = generate_custom_cmb(r_value=0.01, nside=64, seed=42)

furax_cs.data.generate_needed_maps(nside_list=None, noise_ratio_list=None, instrument_name='LiteBIRD', sky_list=None)

Batch generate and cache all required frequency maps.

Parameters:
  • nside_list (list[int] | None) – HEALPix resolutions to generate. Defaults to [4, 8, 32, 64, 128].

  • noise_ratio_list (list[float] | None) – Noise ratio configurations. Defaults to [0.0, 1.0].

  • instrument_name (str) – Instrument configuration. Defaults to "LiteBIRD".

  • sky_list (list[str] | None) – Sky model presets. Defaults to ["c1d0s0", "c1d1s1"].

Return type:

None

Notes

Generates full frequency maps, foreground-only maps, and CMB-only maps for all combinations of input parameters.

Example

>>> generate_needed_maps(nside_list=[64], noise_ratio_list=[0.0, 1.0])

class furax_cs.data.CMBLensedWithTensors(nside, r=0.0, cmb_seed=None, lmax=2000)

Bases: object

CMB map generator with custom tensor-to-scalar ratio r.

Uses CAMB for power spectra and healpy synfast for map generation. Power spectra are computed as:

Cl(r) = r * Cl_tensor(r=1) + Cl_lensing

Parameters:
  • nside (int) – HEALPix resolution parameter.

  • r (float) – Tensor-to-scalar ratio. Defaults to 0.0.

  • cmb_seed (int | None) – Random seed for map generation.

  • lmax (int) – Maximum multipole for power spectra. Defaults to 2000.
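Because the spectra depend linearly on r, the tensor spectra need to be computed only once, at r = 1, and can then be rescaled for any other value. The sketch below assumes both spectra are already available as arrays of shape (4, lmax + 1) in healpy's "new" ordering (TT, EE, BB, TE), for example from CAMB. The function name is hypothetical, and the toy spectra are placeholders, not physical values. The combined array could then be passed to healpy.synfast(cls, nside, new=True) to draw a map.

```python
import numpy as np


def combine_spectra(cl_tensor_r1, cl_lensed, r):
    """C_ell(r) = r * C_ell^tensor(r=1) + C_ell^lensed, per spectrum and multipole."""
    cl_tensor_r1 = np.asarray(cl_tensor_r1, dtype=float)
    cl_lensed = np.asarray(cl_lensed, dtype=float)
    if cl_tensor_r1.shape != cl_lensed.shape:
        raise ValueError("tensor and lensed spectra must have the same shape")
    return r * cl_tensor_r1 + cl_lensed


lmax = 8
# Toy placeholder spectra (not physical), rows TT, EE, BB, TE
cl_tensor_r1 = np.ones((4, lmax + 1))
cl_lensed = np.full((4, lmax + 1), 2.0)
cls = combine_spectra(cl_tensor_r1, cl_lensed, r=0.01)
```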