API Reference
Clustering
- furax_cs.kmeans_clusters.kmeans_clusters(prngkey, mask, indices, regions, max_patches=None, initial_sample_size=1)
Generate K-means cluster assignments for spectral parameter optimization.
This function performs spherical K-means clustering on HEALPix sky pixels to partition the sky into regions with potentially different spectral parameters. The clustering is applied separately for each parameter type (e.g., dust temperature, dust spectral index, synchrotron index).
- Parameters:
prngkey (Key[jaxlib._jax.Array, ''] | UInt32[jaxlib._jax.Array, '2']) – JAX PRNG key for reproducible clustering initialization.
mask (Float[jaxlib._jax.Array, 'npix']) – Full-sky HEALPix mask array (1 for valid pixels, 0 for masked).
indices (Int[jaxlib._jax.Array, 'n_valid']) – Indices of unmasked pixels, typically from jnp.where(mask == 1).
regions (dict[str, int]) – Dictionary mapping patch names to target number of clusters. Expected keys: "temp_dust_patches", "beta_dust_patches", "beta_pl_patches".
max_patches (dict[str, int] | None) – Maximum number of clusters per parameter. If None, uses the regions values. This controls the size of the output arrays for JIT compatibility.
initial_sample_size (int) – Number of initial samples for K-means initialization. Defaults to 1.
- Returns:
Dictionary of cluster assignments (cutout arrays, not full-sky). Keys match regions. Values are int64 arrays of shape (n_unmasked,).
- Return type:
Example
>>> import jax
>>> import jax.numpy as jnp
>>> from furax_cs.data import get_mask
>>> from furax_cs.kmeans_clusters import kmeans_clusters
>>> mask = get_mask("GAL020")
>>> (indices,) = jnp.where(mask == 1)
>>> regions = {
...     "temp_dust_patches": 50,
...     "beta_dust_patches": 500,
...     "beta_pl_patches": 50,
... }
>>> clusters = kmeans_clusters(
...     jax.random.key(0), mask, indices, regions
... )
>>> clusters["beta_dust_patches"].shape == indices.shape
True
Notes
The function uses normalize_by_first_occurrence to ensure cluster indices are contiguous starting from 0, which is required for parameter indexing.
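For intuition, the clustering step can be sketched in plain NumPy: a generic spherical K-means on unit vectors (such as HEALPix pixel centres obtained with healpy.pix2vec), followed by a relabelling to 0..K-1 in order of first appearance, in the spirit of normalize_by_first_occurrence. The helper names spherical_kmeans and relabel_by_first_occurrence are hypothetical. furax_cs's JAX implementation (initialisation, iteration count, padding to max_patches for JIT) may differ.

```python
import numpy as np


def spherical_kmeans(vectors, n_clusters, n_iter=20, seed=0):
    """Toy spherical K-means: dot-product assignment, unit-norm centroids (n_iter >= 1)."""
    rng = np.random.default_rng(seed)
    # Initialise centroids from distinct, randomly chosen input vectors (fancy indexing copies).
    centroids = vectors[rng.choice(len(vectors), size=n_clusters, replace=False)]
    for _ in range(n_iter):
        # Assign each vector to the centroid with the largest dot product.
        labels = np.argmax(vectors @ centroids.T, axis=1)
        for k in range(n_clusters):
            members = vectors[labels == k]
            if len(members):  # keep the previous centroid if a cluster empties
                mean = members.sum(axis=0)
                centroids[k] = mean / np.linalg.norm(mean)
    return labels


def relabel_by_first_occurrence(labels):
    """Map arbitrary labels to 0..K-1 in order of first appearance."""
    _, first_idx, inverse = np.unique(labels, return_index=True, return_inverse=True)
    rank = np.empty(len(first_idx), dtype=np.int64)
    # The label seen first gets 0, the next new label gets 1, and so on.
    rank[np.argsort(first_idx)] = np.arange(len(first_idx))
    return rank[inverse.ravel()]


# Random unit vectors stand in for the centres of unmasked pixels.
rng = np.random.default_rng(0)
pixels = rng.normal(size=(500, 3))
pixels /= np.linalg.norm(pixels, axis=1, keepdims=True)
clusters = relabel_by_first_occurrence(spherical_kmeans(pixels, n_clusters=8))
```

The relabelling matters because patch labels are later used directly as indices into per-patch parameter arrays, so they must start at 0 and have no gaps.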
- furax_cs.multires_clusters.multires_clusters(mask, indices, target_ud_grade, nside=None)
Generate multi-resolution cluster assignments using HEALPix ud_grade.
This function creates resolution-based patches where pixels sharing the same low-resolution parent pixel are grouped together. This is the approach used in the LiteBIRD PTEP methodology.
- Parameters:
mask (Float[jaxlib._jax.Array, 'npix']) – Full-sky HEALPix mask array (1 for valid pixels, 0 for masked).
indices (Int[jaxlib._jax.Array, 'n_valid']) – Indices of unmasked pixels, typically from jnp.where(mask == 1).
target_ud_grade (dict[str, int]) – Dictionary mapping parameter names to target nside values. Expected keys: "beta_dust", "temp_dust", "beta_pl". Values are nside parameters (must be powers of 2). Use 0 for a single global patch.
nside (int | None) – Input map resolution. If None, inferred from mask size.
- Returns:
Dictionary of normalized cluster assignments (cutout arrays). Keys use the "{param}_patches" format. Values are int64 arrays.
- Return type:
Example
>>> import jax.numpy as jnp
>>> from furax_cs.data import get_mask
>>> from furax_cs.multires_clusters import multires_clusters
>>> mask = get_mask("GAL020")
>>> (indices,) = jnp.where(mask == 1)
>>> target_resolutions = {
...     "beta_dust": 64,  # Full resolution
...     "temp_dust": 32,  # 4x fewer patches
...     "beta_pl": 16,  # 16x fewer patches
... }
>>> clusters = multires_clusters(mask, indices, target_resolutions)
>>> clusters["temp_dust_patches"].shape == indices.shape
True
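The parent-pixel grouping can be illustrated with integer arithmetic, assuming pixel indices in HEALPix NESTED ordering, where the (nside_in / nside_out)**2 children of a low-resolution pixel are consecutive. RING-ordered indices (the healpy default) would first need converting, for example with healpy.ring2nest. The function name is hypothetical, and labels here follow sorted parent order, which may not match furax_cs's normalisation.

```python
import numpy as np


def multires_patches_nested(valid_pix_nested, nside_in, nside_out):
    """Group NESTED pixel indices by their parent pixel at nside_out."""
    valid_pix_nested = np.asarray(valid_pix_nested)
    if nside_out == 0:
        # nside 0 means one global patch.
        return np.zeros(valid_pix_nested.shape, dtype=np.int64)
    ratio = nside_in // nside_out
    # In NESTED ordering, the children of one parent are ratio**2 consecutive indices.
    parents = valid_pix_nested // (ratio * ratio)
    # Relabel parent ids to contiguous patch numbers 0..K-1.
    _, labels = np.unique(parents, return_inverse=True)
    return labels.ravel().astype(np.int64)


# Full sky at nside 4 (192 pixels) grouped onto nside 2 parents (48 patches).
all_pix = np.arange(12 * 4**2)
patches = multires_patches_nested(all_pix, nside_in=4, nside_out=2)
```

Each output patch collects (nside_in / nside_out)**2 input pixels, which is where the 4x and 16x reductions quoted in the example come from.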
Noise
- furax_cs.noise.generate_noise_operator(key, noise_ratio, indices, nside, masked_d, instrument, stokes_type=None)
Generate noised data and corresponding noise covariance operator.
This function creates a complete noise model for CMB component separation:
1. Generates white noise scaled by instrument sensitivity.
2. Adds noise to the input data.
3. Creates a diagonal noise covariance operator for the likelihood.
- Parameters:
key (Key[jaxlib._jax.Array, ''] | UInt32[jaxlib._jax.Array, '2']) – JAX PRNG key for reproducible noise generation.
noise_ratio (float) – Noise level as fraction of signal. Use 0.0 for noiseless analysis. Typical values: 0.1 (10%), 1.0 (100%), etc.
indices (Int[jaxlib._jax.Array, 'n_valid']) – Indices of unmasked pixels from the full-sky map.
nside (int) – HEALPix resolution parameter.
masked_d (Stokes) – Input data (already masked/cutout), Stokes object with Q and U.
instrument (FGBusterInstrument) – Instrument configuration containing frequency bands and noise specs.
stokes_type (Literal['QU', 'IQU'] | None) – Stokes parameters to use: "QU" for polarization only, "IQU" for intensity + polarization. If None, inferred from masked_d.stokes.
- Returns:
A tuple containing
noised_d: Data with added instrumental noise.
N: Diagonal noise covariance operator for use in likelihood computation.
- Return type:
Example
>>> import jax
>>> import jax.numpy as jnp
>>> from furax_cs.data import get_instrument, get_mask
>>> from furax_cs.noise import generate_noise_operator
>>> from jax_healpy.clustering import get_cutout_from_mask
>>> instrument = get_instrument("LiteBIRD")
>>> mask = get_mask("GAL020")
>>> (indices,) = jnp.where(mask == 1)
>>> # d: full-sky frequency maps as a Stokes object (construction not shown)
>>> masked_d = get_cutout_from_mask(d, indices, axis=1)
>>> noised_d, N = generate_noise_operator(
...     key=jax.random.key(42),
...     noise_ratio=1.0,
...     indices=indices,
...     nside=64,
...     masked_d=masked_d,
...     instrument=instrument,
... )
Notes
When noise_ratio=0, the noise operator uses variance=1.0 to avoid singular matrices in the likelihood computation.
The noise model assumes:
- White (uncorrelated) noise between pixels and frequencies.
- Diagonal noise covariance (independent noise per measurement).
- Noise variance determined by instrument sensitivity.
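A NumPy sketch of this white-noise model follows, assuming one noise depth per frequency in uK·arcmin that is turned into a per-pixel standard deviation by dividing by the pixel side length sqrt(4π/npix), expressed in arcmin. The function name, the depth argument and the use of a single depth for every Stokes parameter are assumptions for illustration. The diagonal covariance is represented by its diagonal: an array holding one variance per sample.

```python
import numpy as np


def white_noise_and_variance(data, depth_uk_arcmin, nside, noise_ratio, seed=0):
    """Add white noise to (n_freq, n_stokes, n_pix) data and return (noised, variance)."""
    npix = 12 * nside**2
    # Side length of a HEALPix pixel in arcmin.
    pixel_arcmin = np.sqrt(4 * np.pi / npix) * (180 * 60 / np.pi)
    sigma = np.asarray(depth_uk_arcmin, dtype=float) / pixel_arcmin
    # Scale by noise_ratio and broadcast over Stokes parameters and pixels.
    sigma = noise_ratio * sigma[:, None, None]
    rng = np.random.default_rng(seed)
    noised = data + sigma * rng.standard_normal(data.shape)
    if noise_ratio == 0:
        # Unit variance keeps the inverse covariance finite in the noiseless case.
        variance = np.ones(data.shape)
    else:
        variance = np.broadcast_to(sigma**2, data.shape).copy()
    return noised, variance


maps = np.ones((3, 2, 12 * 4**2))  # 3 frequencies, Q and U, nside 4
noised, variance = white_noise_and_variance(maps, [10.0, 20.0, 30.0], nside=4, noise_ratio=1.0)
weighted = noised / variance  # applying the inverse of a diagonal N is elementwise
```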
Optimization
Optimization is provided by the CADRE package. See the minimization docs for the full solver reference.
- cadre.minimize.minimize(fn, init_params, solver_name='optax_lbfgs', max_iter=1000, rtol=1e-08, atol=1e-08, lower_bound=None, upper_bound=None, precondition=False, options=None, refresh_steps=10, **fn_kwargs)
Unified optimization interface.
Supports optax solvers, optimistix solvers (via optimistix.minimise), and scipy solvers (via jaxopt.ScipyMinimize, requires cadre[scipy]).
- Parameters:
fn (Callable) – Objective function to minimize. Should accept (params, **fn_kwargs).
init_params (PyTree) – Initial parameter values.
solver_name (str) – Solver identifier. See SOLVER_NAMES for available options.
max_iter (int) – Maximum iterations.
rtol (float) – Relative tolerance for optimization convergence.
atol (float) – Absolute tolerance for optimization convergence.
lower_bound (PyTree, optional) – Lower bounds for box constraints.
upper_bound (PyTree, optional) – Upper bounds for box constraints.
precondition (bool) – Whether to apply parameter transformation and output scaling.
options (dict, optional) – Extra arguments passed to the solver factory (get_solver). Recognised keys:
- cooldown (int, default 20): steps to suppress termination after a constraint release (active-set solvers, ADABK{N} family).
- min_steps (int, default 10): minimum iterations before termination is considered (active-set solvers).
- verbose_print (bool, default False): print per-step debug info via jax.debug.print (JIT-compatible; active-set solvers).
- max_linesearch_steps (int, default 50): maximum line-search steps per iteration (active-set and optax_lbfgs solvers).
- linesearch (str): linesearch variant for optax_lbfgs ("zoom" or "backtracking").
refresh_steps (int)
**fn_kwargs (Any) – Additional arguments passed to fn.
- Returns:
final_params (PyTree) – Optimized parameters.
final_state (UnifiedState) – Final optimizer state containing best loss, best parameters, iteration count, and solver state.
- Return type:
tuple[PyTree[jaxtyping.Float[jaxlib._jax.Array, ‘P’]], UnifiedState]
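To make the roles of lower_bound, upper_bound, rtol and atol concrete, here is a generic projected gradient descent in NumPy. Each step is clipped back into the box, and iteration stops when the change in loss falls below atol + rtol * |loss|. This sketches the concepts only: it is not how cadre's solvers (L-BFGS, optimistix minimisers, scipy) update parameters, and the Cauchy-style stopping rule is an assumption.

```python
import numpy as np


def projected_gradient_descent(fn, grad_fn, x0, lower, upper, lr=0.1,
                               max_iter=1000, rtol=1e-8, atol=1e-8):
    """Minimise fn over the box [lower, upper] by gradient steps plus clipping."""
    x = np.clip(np.asarray(x0, dtype=float), lower, upper)
    f_old = fn(x)
    n_iter = 0
    for n_iter in range(1, max_iter + 1):
        # Gradient step, then projection back onto the box constraints.
        x = np.clip(x - lr * grad_fn(x), lower, upper)
        f_new = fn(x)
        # Stop when the loss change is within atol + rtol * |loss|.
        if abs(f_new - f_old) <= atol + rtol * abs(f_new):
            break
        f_old = f_new
    return x, n_iter


def loss(x):
    # Unconstrained minimum at x = 3 in every coordinate.
    return float(np.sum((x - 3.0) ** 2))


def grad(x):
    return 2.0 * (x - 3.0)


# The box caps parameters at 2, so the constrained optimum sits on the bound.
x_opt, n_steps = projected_gradient_descent(loss, grad, [0.0, 1.0], lower=0.0, upper=2.0)
```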
- cadre.solvers.get_solver(solver_name, rtol=1e-08, atol=1e-08, learning_rate=0.001, max_linesearch_steps=50, lower=None, upper=None, verbose_print=False, min_steps=10, cooldown=20, **kwargs)
Create a solver instance from a name string.
- Parameters:
solver_name (str) – Solver identifier. See SOLVER_NAMES for available options.
rtol (float) – Relative tolerance for optimistix solvers.
atol (float) – Absolute tolerance for optimistix solvers.
learning_rate (float) – Learning rate for adam solver.
max_linesearch_steps (int) – Maximum line-search steps per iteration for L-BFGS solvers (including optax_lbfgs) and active-set solvers.
lower (PyTree, optional) – Lower bounds for box projection (optax solvers only).
upper (PyTree, optional) – Upper bounds for box projection (optax solvers only).
verbose_print (bool) – If True, print per-step termination diagnostics for active-set solvers via jax.debug.print (JIT-compatible).
min_steps (int) – Minimum iterations before termination is considered (active-set solvers only).
cooldown (int) – Steps to suppress termination after a constraint release (active-set solvers only).
kwargs (Any)
- Returns:
solver (BestSoFarMinimiser | str) – The solver instance: a BestSoFar-wrapped minimiser, or a string when the scipy backend is used.
solver_type (str) – One of “optimistix” or “scipy”.
- Return type:
tuple[BestSoFarMinimiser | str, Literal[‘optimistix’, ‘scipy’]]
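The cooldown and min_steps options gate when an active-set solver may stop. The hypothetical helper below shows one plausible reading of the parameter descriptions: a convergence signal is ignored until min_steps iterations have run, and until cooldown steps have passed since the last constraint release. It is not cadre's code, and the exact off-by-one conventions there may differ.

```python
def should_terminate(step, last_release_step, converged, min_steps=10, cooldown=20):
    """Hypothetical termination gate for an active-set solver."""
    if step < min_steps:
        # Too early: never stop before min_steps iterations.
        return False
    if last_release_step is not None and step - last_release_step < cooldown:
        # A constraint was released recently; give the solver time to move.
        return False
    return bool(converged)


# Converged at step 15, but a constraint was released at step 10: keep going.
decision = should_terminate(15, last_release_step=10, converged=True)
```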
Binning
- furax_cs.binning.bin_parameter_map(pixel_map, nbins)
Bin a valid-pixels-only parameter map into equal-width bins.
- Parameters:
pixel_map (np.ndarray) – 1-D array of parameter values for valid pixels only (no UNSEEN).
nbins (int) – Number of equal-width bins.
- Returns:
patch_indices (np.ndarray) – 0-based bin indices, shape (n_valid,), values in [0, nbins-1].
bin_centers (np.ndarray) – Bin center values, shape (nbins,).
bin_edges (np.ndarray) – Bin edges, shape (nbins+1,).
- Return type:
Example
Reconstruct a full-sky binned map from valid-pixel results:
import numpy as np
import healpy as hp
from furax_cs import bin_parameter_map

nside = 64
npix = hp.nside2npix(nside)
mask = np.load("mask.npy")
(valid,) = np.where(mask == 1)
pixel_values = ...  # parameter values for valid pixels only
patch_indices, centers, edges = bin_parameter_map(pixel_values, nbins=10)

# Write back to a full-sky map (masked pixels = UNSEEN)
out_map = np.full(npix, hp.UNSEEN)
out_map[valid] = patch_indices.astype(float)
np.save("patches_beta_dust.npy", out_map)
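Equal-width binning itself takes only a few lines of NumPy. The sketch below, with the hypothetical name equal_width_bins, spans the data range with nbins + 1 evenly spaced edges and digitizes against the interior edges, so the minimum lands in bin 0 and the maximum in bin nbins - 1. bin_parameter_map may handle edge cases (constant input, NaNs) differently.

```python
import numpy as np


def equal_width_bins(values, nbins):
    """Return (indices, centers, edges) for equal-width bins over the data range."""
    values = np.asarray(values, dtype=float)
    edges = np.linspace(values.min(), values.max(), nbins + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])
    # Digitize against interior edges only, giving indices in [0, nbins - 1].
    indices = np.digitize(values, edges[1:-1])
    return indices, centers, edges


indices, centers, edges = equal_width_bins(np.arange(10), nbins=5)
```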
Data
Data generation and instrument configuration for CMB component separation.
- furax_cs.data.get_mixin_matrix_operator(params, patch_indices, nu, sky, dust_nu0, synchrotron_nu0)
Construct mixing matrix operators for CMB and foregrounds.
- Parameters:
params (PyTree[jaxtyping.Float[jaxlib._jax.Array, 'P']]) – Spectral parameters (temp_dust, beta_dust, beta_pl).
patch_indices (PyTree[jaxtyping.Int[jaxlib._jax.Array, 'P']]) – Patch assignment indices for each parameter.
nu (Float[jaxlib._jax.Array, 'Nf']) – Frequency array in GHz.
sky (dict[str, Stokes]) – Sky component dictionary from FURAX.
dust_nu0 (float) – Dust reference frequency in GHz.
synchrotron_nu0 (float) – Synchrotron reference frequency in GHz.
- Returns:
A tuple (MixingMatrixOperator with CMB, MixingMatrixOperator without CMB).
- Return type:
tuple[MixingMatrixOperator, MixingMatrixOperator]
Example
>>> A, A_nocmb = get_mixin_matrix_operator(params, patches, nu, sky, 150.0, 20.0)
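For reference, the standard parametric model behind such operators can be written per pixel. The CMB column is the K_CMB to K_RJ conversion factor, dust is a modified blackbody with (beta_dust, temp_dust), and synchrotron is a power law with beta_pl, each normalised at its reference frequency. The NumPy sketch below works in Rayleigh-Jeans units and expands patch-level parameters to pixels by indexing with the patch arrays. The dictionary keys mirror names used on this page, but the function names, units and normalisation are assumptions, and FURAX's MixingMatrixOperator may use different conventions.

```python
import numpy as np

H_OVER_K = 6.62607015e-34 / 1.380649e-23  # Planck constant over Boltzmann constant (K s)
T_CMB = 2.7255  # CMB temperature (K)


def dust_sed(nu, beta_d, temp_d, nu0):
    """Modified blackbody in RJ units, equal to 1 at nu0 (frequencies in GHz)."""
    x = H_OVER_K * nu * 1e9 / temp_d
    x0 = H_OVER_K * nu0 * 1e9 / temp_d
    return (nu / nu0) ** (beta_d + 1) * np.expm1(x0) / np.expm1(x)


def sync_sed(nu, beta_pl, nu0):
    """Power law in RJ units, equal to 1 at nu0."""
    return (nu / nu0) ** beta_pl


def cmb_sed(nu):
    """Conversion factor from K_CMB to K_RJ."""
    x = H_OVER_K * nu * 1e9 / T_CMB
    return x**2 * np.exp(x) / np.expm1(x) ** 2


def pixel_mixing_matrix(nu, params, patches, dust_nu0, sync_nu0):
    """Return A with shape (n_freq, 3 components [cmb, dust, sync], n_pix)."""
    # Expand one value per patch into one value per pixel.
    beta_d = params["beta_dust"][patches["beta_dust_patches"]]
    temp_d = params["temp_dust"][patches["temp_dust_patches"]]
    beta_s = params["beta_pl"][patches["beta_pl_patches"]]
    nu = np.asarray(nu, dtype=float)[:, None]  # (n_freq, 1) broadcasts against pixels
    cmb = np.broadcast_to(cmb_sed(nu), (nu.shape[0], beta_d.size))
    dust = dust_sed(nu, beta_d, temp_d, dust_nu0)
    sync = sync_sed(nu, beta_s, sync_nu0)
    return np.stack([cmb, dust, sync], axis=1)


nu = np.array([20.0, 150.0])
params = {"beta_dust": np.array([1.54, 1.60]), "temp_dust": np.array([20.0]), "beta_pl": np.array([-3.0])}
patches = {
    "beta_dust_patches": np.array([0, 1, 1]),
    "temp_dust_patches": np.zeros(3, dtype=int),
    "beta_pl_patches": np.zeros(3, dtype=int),
}
A = pixel_mixing_matrix(nu, params, patches, dust_nu0=150.0, sync_nu0=20.0)
```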
- furax_cs.data.load_cmb_map(nside, sky='c1d0s0')
Load cached CMB-only maps.
- Parameters:
- Returns:
CMB map with shape (3, n_pix) for Stokes I, Q, U, or zeros if no CMB.
- Raises:
FileNotFoundError – If cache file does not exist.
- Return type:
Float[jaxlib._jax.Array, ‘3 npix’]
Example
>>> cmb_map = load_cmb_map(nside=64, sky="c1d0s0")
- furax_cs.data.load_fg_map(nside, noise_ratio=0.0, instrument_name='LiteBIRD', sky='c1d0s0')
Load cached foreground-only frequency maps.
- Parameters:
nside (int) – HEALPix resolution parameter.
noise_ratio (float) – Noise level ratio (0.0 = no noise, 1.0 = 100% noise). Defaults to 0.0.
instrument_name (str) – Instrument configuration name. Defaults to “LiteBIRD”.
sky (str) – Sky model preset string, CMB automatically excluded. Defaults to “c1d0s0”.
- Returns:
A tuple (frequencies, freq_maps) containing only foreground contributions.
- Return type:
tuple[Float[jaxlib._jax.Array, ‘freqs’], Float[jaxlib._jax.Array, ‘freqs 3 npix’]]
Example
>>> freqs, fg_maps = load_fg_map(nside=64, sky="c1d1s1")
- furax_cs.data.load_from_cache(nside, noise_ratio=0.0, instrument_name='LiteBIRD', sky='c1d0s0')
Load cached frequency maps from disk.
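- Parameters:
nside (int) – HEALPix resolution parameter.
noise_ratio (float) – Noise level ratio (0.0 = no noise, 1.0 = 100% noise). Defaults to 0.0.
instrument_name (str) – Instrument configuration name. Defaults to “LiteBIRD”.
sky (str) – Sky model preset string. Defaults to “c1d0s0”.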
- Returns:
A tuple (frequencies, freq_maps) loaded from cache.
- Raises:
FileNotFoundError – If cache file does not exist.
- Return type:
tuple[Float[jaxlib._jax.Array, ‘freqs’], Float[jaxlib._jax.Array, ‘freqs 3 npix’]]
Example
>>> freqs, maps = load_from_cache(nside=64, noise_ratio=0.0, sky="c1d0s0")
- furax_cs.data.save_cmb_map(nside, sky='c1d0s0')
Generate and cache CMB-only maps for template generation.
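- Parameters:
nside (int) – HEALPix resolution parameter.
sky (str) – Sky model preset string. Defaults to “c1d0s0”.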
- Returns:
CMB map with shape (3, n_pix) for Stokes I, Q, U, or zeros if no CMB.
- Return type:
Float[jaxlib._jax.Array, ‘3 npix’]
Example
>>> cmb_map = save_cmb_map(nside=64, sky="c1d0s0")
- furax_cs.data.save_fg_map(nside, noise_ratio=0.0, instrument_name='LiteBIRD', sky='c1d0s0', key=None)
Generate and cache foreground-only frequency maps (CMB excluded).
- Parameters:
nside (int) – HEALPix resolution parameter.
noise_ratio (float) – Noise level ratio (0.0 = no noise, 1.0 = 100% noise). Defaults to 0.0.
instrument_name (str) – Instrument configuration name. Defaults to “LiteBIRD”.
sky (str) – Sky model preset string, CMB component automatically removed. Defaults to “c1d0s0”.
key (Key[jaxlib._jax.Array, ''] | UInt32[jaxlib._jax.Array, '2'] | None) – JAX random key for noise generation. Defaults to None.
- Returns:
A tuple (frequencies, freq_maps) containing only foreground contributions.
- Return type:
tuple[Float[jaxlib._jax.Array, ‘freqs’], Float[jaxlib._jax.Array, ‘freqs 3 npix’]]
Example
>>> freqs, fg_maps = save_fg_map(nside=64, sky="c1d1s1")
- furax_cs.data.save_to_cache(nside, noise_ratio=0.0, instrument_name='LiteBIRD', sky='c1d0s0', key=None)
Generate and cache frequency maps for component separation.
- Parameters:
nside (int) – HEALPix resolution parameter.
noise_ratio (float) – Noise level ratio (0.0 = no noise, 1.0 = 100% noise). Defaults to 0.0.
instrument_name (str) – Instrument configuration name. Defaults to “LiteBIRD”.
sky (str) – Sky model preset string (e.g., “c1d0s0” for CMB plus the PySM d0 dust and s0 synchrotron models). Defaults to “c1d0s0”.
key (Key[jaxlib._jax.Array, ''] | UInt32[jaxlib._jax.Array, '2'] | None) – JAX random key for noise generation. Defaults to None.
- Returns:
A tuple (frequencies, freq_maps) where frequencies are in GHz and freq_maps have shape (n_freq, 3, n_pix) for Stokes I, Q, U.
- Return type:
tuple[Float[jaxlib._jax.Array, ‘freqs’], Float[jaxlib._jax.Array, ‘freqs 3 npix’]]
Example
>>> freqs, maps = save_to_cache(nside=64, noise_ratio=0.0, sky="c1d0s0")
- furax_cs.data.simulate_D_from_params(params, patch_indices, nu, sky, dust_nu0, synchrotron_nu0)
Simulate observed frequency maps given spectral parameters.
- Parameters:
params (PyTree[jaxtyping.Float[jaxlib._jax.Array, 'P']]) – Spectral parameters (temp_dust, beta_dust, beta_pl).
patch_indices (PyTree[jaxtyping.Int[jaxlib._jax.Array, 'P']]) – Patch assignment indices for each parameter.
nu (Float[jaxlib._jax.Array, 'Nf']) – Frequency array in GHz.
sky (dict[str, Stokes]) – Sky component dictionary from FURAX.
dust_nu0 (float) – Dust reference frequency in GHz.
synchrotron_nu0 (float) – Synchrotron reference frequency in GHz.
- Returns:
A tuple (d, d_nocmb) where d includes CMB and d_nocmb excludes it.
- Return type:
tuple[Stokes, Stokes]
Example
>>> d, d_nocmb = simulate_D_from_params(params, patches, nu, sky, 150.0, 20.0)
- furax_cs.data.get_mask(mask_name='GAL020', nside=64)
Load and process galactic masks at specified resolution.
- Parameters:
- Returns:
Boolean mask array where True indicates observed pixels.
- Raises:
ValueError – If mask_name is invalid.
- Return type:
Bool[jaxlib._jax.Array, ‘npix’]
Notes
Available mask choices: ALL, GALACTIC, GAL020, GAL040, GAL060, and their _U (upper) and _L (lower) hemisphere variants.
Boolean operations are supported:
- Use + for union (logical OR).
- Use - for subtraction (logical AND NOT).
- Expressions are evaluated left-to-right.
- Examples: “GAL020+GAL040”, “ALL-GALACTIC”, “GAL020+GAL040-GALACTIC”.
Masks are automatically generated and cached on first call for each nside.
Example
>>> mask = get_mask("GAL020", nside=64)
>>> # Using boolean expression:
>>> mask_union = get_mask("GAL020+GAL040", nside=64)
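The left-to-right rule from the Notes can be sketched with plain boolean arrays. evaluate_mask_expression and the toy masks A, B, C are hypothetical; get_mask builds its component masks from HEALPix data and caches them, which this sketch does not attempt.

```python
import re

import numpy as np


def evaluate_mask_expression(expr, masks):
    """Evaluate an expression such as "A+B-C" strictly left to right."""
    # The capturing group keeps the operators: ["A", "+", "B", "-", "C"].
    tokens = re.split(r"([+-])", expr)
    result = np.array(masks[tokens[0]], dtype=bool)
    for op, name in zip(tokens[1::2], tokens[2::2]):
        if op == "+":
            result = result | masks[name]  # union (logical OR)
        else:
            result = result & ~masks[name]  # subtraction (logical AND NOT)
    return result


# Toy four-pixel masks, not real sky masks.
toy_masks = {
    "A": np.array([True, True, False, False]),
    "B": np.array([False, True, True, False]),
    "C": np.array([False, False, True, True]),
}
combined = evaluate_mask_expression("A+B-C", toy_masks)
```

Because evaluation is strictly left to right, "A+B-C" means (A OR B) AND NOT C, not A OR (B AND NOT C).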
- furax_cs.data.get_instrument(instrument_name)
Get an instrument configuration by name.
- Parameters:
instrument_name (str) – Name of the instrument (e.g., “LiteBIRD”, “Planck”). Must correspond to an entry in instruments.yaml. Use “default” for the FGBuster default instrument.
- Returns:
The instrument configuration object with frequency bands and sensitivities.
- Raises:
ValueError – If instrument_name is not found in the configuration.
- Return type:
FGBusterInstrument
Example
>>> instrument = get_instrument("LiteBIRD")
>>> print(instrument.frequency)
- furax_cs.data.load_search_space(filepath=None)
Load search space configuration from YAML file.
- Parameters:
filepath (str | Path | None) – Path to custom search space YAML file. If None, loads the default configuration from search_spaces_default.yaml.
- Returns:
Dictionary with JAX arrays for T_d_patches, B_d_patches, B_s_patches.
- Raises:
FileNotFoundError – If the specified filepath does not exist.
ValueError – If the YAML file is missing required keys or has invalid values.
- Return type:
Example
>>> space = load_search_space()
>>> print(space["T_d_patches"])
- furax_cs.data.search_space_to_jax(config)
Convert search space configuration from YAML to JAX arrays.
- Parameters:
config (dict[str, Any]) – Dictionary with search space parameters as lists or numpy arrays.
- Returns:
Dictionary with JAX arrays ready for grid search.
- Return type:
Example
>>> config = {"T_d_patches": [1, 2], "B_d_patches": [3, 4], "B_s_patches": [5]}
>>> jax_space = search_space_to_jax(config)
- furax_cs.data.dump_default_search_space(output_path)
Dump the default search space configuration to a YAML file.
This creates a template file that users can customize for their needs.
- Parameters:
output_path (str | Path) – Path where the default search space YAML will be saved.
- Raises:
FileNotFoundError – If the default config file is missing in the package.
IOError – If the output file cannot be written.
- Return type:
None
Example
>>> dump_default_search_space("my_search_space.yaml")
- furax_cs.data.validate_search_space(search_space)
Validate that search space has valid structure and values.
- Parameters:
search_space (dict[str, Int[jaxlib._jax.Array, 'N']]) – Dictionary with JAX arrays for search space parameters.
- Raises:
ValueError – If validation fails (missing keys, empty arrays, invalid values).
- Return type:
None
Example
>>> validate_search_space(search_space)
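A validator along these lines might look as follows. The required keys come from load_search_space; the rule that every entry must be a positive integer patch count is an assumption about what "invalid values" means, and check_search_space is a hypothetical name.

```python
import numpy as np

REQUIRED_KEYS = ("T_d_patches", "B_d_patches", "B_s_patches")


def check_search_space(search_space):
    """Hypothetical validator mirroring the documented ValueError cases."""
    missing = [key for key in REQUIRED_KEYS if key not in search_space]
    if missing:
        raise ValueError(f"missing keys: {missing}")
    for key in REQUIRED_KEYS:
        values = np.asarray(search_space[key])
        if values.size == 0:
            raise ValueError(f"{key} is empty")
        # Assumed rule: patch counts must be integers >= 1.
        if not np.issubdtype(values.dtype, np.integer) or np.any(values < 1):
            raise ValueError(f"{key} must contain positive integers")


check_search_space({"T_d_patches": [1, 2], "B_d_patches": [3, 4], "B_s_patches": [5]})
try:
    check_search_space({"T_d_patches": [1, 2]})
except ValueError as err:
    print(err)  # reports the two missing keys
```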
- furax_cs.data.sanitize_mask_name(mask_expr)
Convert mask expression to valid folder name.
- Parameters:
mask_expr (str) – Mask expression potentially containing + (union) or - (subtract) operators.
- Returns:
Sanitized folder name with operators replaced by descriptive names.
- Return type:
str
Example
>>> sanitize_mask_name("GAL020+GAL040")
'GAL020_UNION_GAL040'
>>> sanitize_mask_name("ALL-GALACTIC")
'ALL_SUBTRACT_GALACTIC'
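Both documented examples are reproduced by a plain string substitution. The helper below is a hypothetical sketch; sanitize_mask_name may treat other characters differently.

```python
def sanitize_mask_expression(mask_expr):
    """Replace mask operators with folder-name-safe words."""
    return mask_expr.replace("+", "_UNION_").replace("-", "_SUBTRACT_")


folder = sanitize_mask_expression("GAL020+GAL040-GALACTIC")
```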
- furax_cs.data.generate_custom_cmb(r_value, nside, seed=None)
Generate a CMB realization with a specific tensor-to-scalar ratio r.
Uses CAMB and healpy synfast, with a linear approximation for the B-modes.
- Parameters:
- Returns:
CMB map (3, npix) in uK_CMB.
- Return type:
Float[jaxlib._jax.Array, ‘3 npix’]
Example
>>> cmb_map = generate_custom_cmb(r_value=0.01, nside=64, seed=42)
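One common reading of a linear approximation for B-modes is that the tensor BB spectrum scales proportionally with r. Under that reading, a spectrum for any r can be assembled from a lensing BB term plus a tensor BB spectrum computed once at r = 1 (for example with CAMB), before the full set of spectra is passed to healpy.synfast. The sketch below shows only that scaling step. The helper name is hypothetical, and this interpretation of generate_custom_cmb is an assumption.

```python
import numpy as np


def bb_spectrum_for_r(cl_bb_lensing, cl_bb_tensor_r1, r):
    """Total BB power: lensing contribution plus r times the r = 1 tensor spectrum."""
    lensing = np.asarray(cl_bb_lensing, dtype=float)
    tensor = np.asarray(cl_bb_tensor_r1, dtype=float)
    return lensing + r * tensor


# Toy three-multipole spectra, not physical values.
cl_bb = bb_spectrum_for_r([1.0, 2.0, 3.0], [10.0, 20.0, 30.0], r=0.01)
```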
- furax_cs.data.generate_needed_maps(nside_list=None, noise_ratio_list=None, instrument_name='LiteBIRD', sky_list=None)
Batch generate and cache all required frequency maps.
- Parameters:
nside_list (list[int] | None) – HEALPix resolutions to generate. Defaults to [4, 8, 32, 64, 128].
noise_ratio_list (list[float] | None) – Noise ratio configurations. Defaults to [0.0, 1.0].
instrument_name (str) – Instrument configuration. Defaults to “LiteBIRD”.
sky_list (list[str] | None) – Sky model presets. Defaults to [“c1d0s0”, “c1d1s1”].
- Return type:
None
Notes
Generates full frequency maps, foreground-only maps, and CMB-only maps for all combinations of input parameters.
Example
>>> generate_needed_maps(nside_list=[64], noise_ratio_list=[0.0, 1.0])
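"All combinations" in the Notes is a Cartesian product of the input lists. With the documented defaults this gives 5 × 2 × 2 = 20 configurations for the full and foreground-only maps. The snippet below only enumerates them. The note that CMB-only maps depend on nside and sky alone is inferred from the save_cmb_map(nside, sky) signature, not from the implementation.

```python
from itertools import product

nside_list = [4, 8, 32, 64, 128]
noise_ratio_list = [0.0, 1.0]
sky_list = ["c1d0s0", "c1d1s1"]

# Full and foreground-only maps: one per (nside, noise_ratio, sky) triple.
map_configs = list(product(nside_list, noise_ratio_list, sky_list))
# CMB-only maps take no noise_ratio argument, so (nside, sky) pairs suffice.
cmb_configs = list(product(nside_list, sky_list))

for nside, noise_ratio, sky in map_configs:
    print(f"nside={nside} noise_ratio={noise_ratio} sky={sky}")
```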