from __future__ import annotations
import argparse
import os
import pickle
import re
from pathlib import Path
from typing import TypeAlias
import astropy.units as u
import camb
import healpy as hp
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
from camb import initialpower
from furax._instruments.sky import get_observation, get_sky
from furax.obs.operators import (
CMBOperator,
DustOperator,
MixingMatrixOperator,
SynchrotronOperator,
)
from furax.obs.stokes import Stokes, StokesQU
from jaxtyping import (
Array,
Bool,
Float,
Int,
PRNGKeyArray,
PyTree, # pyright: ignore
)
from ..logging_utils import info, success
from .instruments import get_instrument
# Mapping from a sky component name (this module uses "cmb", "dust" and
# "synchrotron") to the Stokes maps of that component.
SkyType: TypeAlias = dict[str, Stokes]
def compute_Cl_from_camb(r: float, which: str = "total", lmax: int = 2000) -> np.ndarray:
    """Run CAMB once and return one of the CMB power spectra it produces.

    Args:
        r: Tensor-to-scalar ratio given to the primordial power law.
        which: Key of the CAMB spectra dictionary to return, e.g. "total"
            (lensed total) or "tensor".
        lmax: Maximum multipole requested from CAMB.

    Returns:
        The selected spectrum array, shape (lmax+1, 4) with columns
        TT, EE, BB, TE (raw C_l in muK units).
    """
    tensor_mode = "tensor_param_indeptilt"

    # Background cosmology and accuracy settings, forwarded verbatim to CAMB.
    background = dict(
        Alens=1.0,
        H0=67.5,
        ombh2=0.022,
        omch2=0.122,
        mnu=0.06,
        omk=0,
        tau=0.06,
        As=2e-9,
        ns=0.965,
        halofit_version="mead",
        max_l_tensor=lmax,
        max_eta_k_tensor=18000,
        parameterization=tensor_mode,
    )
    params = camb.set_params(**background)
    params.set_for_lmax(lmax, lens_potential_accuracy=1)
    params.WantTensors = True

    # The primordial spectrum is replaced wholesale by this power law.
    # NOTE(review): ns is 0.96 here but 0.965 in the settings above; since
    # InitPower is overwritten, 0.96 is presumably the value in effect - confirm
    # which one is intended.
    power_law = initialpower.InitialPowerLaw()
    power_law.set_params(ns=0.96, r=r, parameterization=tensor_mode, nt=0, ntrun=0)
    params.InitPower = power_law

    solved = camb.get_results(params)
    all_spectra = solved.get_cmb_power_spectra(params, CMB_unit="muK", raw_cl=True)
    return all_spectra[which]
# Most recently requested templates, kept under their historical names so that
# any code inspecting these module globals keeps working.
_CL_PRIMORDIAL_R1: np.ndarray | None = None
_CL_LENSING: np.ndarray | None = None
# Memo of template spectra keyed by lmax:
# {lmax: (tensor spectrum at r=1, lensed total spectrum at r=0)}.
_TEMPLATE_SPECTRA_BY_LMAX: dict[int, tuple[np.ndarray, np.ndarray]] = {}


def _get_template_spectra(lmax: int = 2000) -> tuple[np.ndarray, np.ndarray]:
    """Return the (tensor r=1, lensed r=0) template spectra for ``lmax``.

    The spectra are computed with ``compute_Cl_from_camb`` on first use and
    memoised per ``lmax``. The cache is keyed by ``lmax`` because a single
    shared slot would hand back spectra computed for whichever ``lmax`` was
    requested first, silently ignoring later, different values.

    Args:
        lmax: Maximum multipole of the requested templates.

    Returns:
        A tuple ``(cl_primordial_r1, cl_lensing)``.
    """
    global _CL_PRIMORDIAL_R1, _CL_LENSING
    templates = _TEMPLATE_SPECTRA_BY_LMAX.get(lmax)
    if templates is None:
        templates = (
            compute_Cl_from_camb(r=1.0, which="tensor", lmax=lmax),
            compute_Cl_from_camb(r=0.0, which="total", lmax=lmax),
        )
        _TEMPLATE_SPECTRA_BY_LMAX[lmax] = templates
    # Mirror the latest result into the legacy globals.
    _CL_PRIMORDIAL_R1, _CL_LENSING = templates
    return templates
[docs]
class CMBLensedWithTensors:
    """Random CMB realization with a configurable tensor-to-scalar ratio.

    The power spectrum is the linear combination
    ``Cl(r) = r * Cl_tensor(r=1) + Cl_lensing`` of two CAMB templates, and the
    map is drawn from it with healpy ``synfast``.

    Args:
        nside: HEALPix resolution parameter.
        r: Tensor-to-scalar ratio. Defaults to 0.0.
        cmb_seed: Seed for the realization; ``None`` leaves the RNG untouched.
        lmax: Maximum multipole of the template spectra. Defaults to 2000.
    """

    def __init__(
        self,
        nside: int,
        r: float = 0.0,
        cmb_seed: int | None = None,
        lmax: int = 2000,
    ):
        self.nside = nside
        self.r = r
        self.lmax = lmax

        # Templates are computed lazily and shared across instances.
        tensor_template, lensing_template = _get_template_spectra(lmax)

        # Columns of the combined array are TT, EE, BB, TE.
        combined = r * tensor_template + lensing_template
        tt, ee, bb, te = (combined[:, column] for column in range(4))

        # synfast draws from numpy's global RNG, so seeding happens there.
        if cmb_seed is not None:
            np.random.seed(cmb_seed)

        realization = hp.synfast([tt, ee, bb, te], nside=nside, new=True)
        # Stored as a (T, Q, U) stack of maps in uK_CMB.
        self.map = np.array(realization)
def parse_sky_tag(sky: str) -> tuple[str | None, str]:
    """Split a sky string into its CMB tag and the remaining foreground tags.

    Two forms are recognised: a custom-r CMB tag ``cr<digits>`` anywhere in the
    string, or the legacy layout of consecutive two-character tags in which
    any tag starting with ``c`` is a CMB tag.

    Args:
        sky: Sky model string (e.g., "c1d0s0", "cr3d0s0").

    Returns:
        A tuple (cmb_tag, fg_tag). `cmb_tag` is None if no CMB is present, in
        which case `fg_tag` is the input string unchanged.

    Example:
        >>> parse_sky_tag("c1d1s1")
        ('c1', 'd1s1')
    """
    # Custom r form takes precedence over the two-character layout.
    custom = re.search(r"cr(\d+)", sky)
    if custom is not None:
        custom_tag = custom.group(0)
        return custom_tag, sky.replace(custom_tag, "")

    # Legacy form: walk the string in two-character steps and sort each chunk.
    cmb_chunks: list[str] = []
    fg_chunks: list[str] = []
    for start in range(0, len(sky), 2):
        chunk = sky[start : start + 2]
        if chunk.startswith("c"):
            cmb_chunks.append(chunk)
        else:
            fg_chunks.append(chunk)

    if not cmb_chunks:
        return None, sky
    # Only the first CMB tag is reported; every CMB tag is dropped from fg_tag.
    return cmb_chunks[0], "".join(fg_chunks)
def _parse_synt_tag(sky: str) -> tuple[str, int, int, int] | None:
    """Decode a synthetic-sky tag of the form ``{base_sky}_synt_bd{N}_td{N}_bs{N}``.

    Args:
        sky: Sky string to inspect.

    Returns:
        ``(base_sky, BD, TD, BS)`` with the three counts as ints when the whole
        string matches the pattern, otherwise None.
    """
    found = re.match(r"^(.+)_synt_bd(\d+)_td(\d+)_bs(\d+)$", sky)
    if found is None:
        return None
    base_sky, *raw_counts = found.groups()
    n_bd, n_td, n_bs = (int(raw) for raw in raw_counts)
    return base_sky, n_bd, n_td, n_bs
[docs]
def generate_custom_cmb(
    r_value: float, nside: int, seed: int | None = None
) -> Float[Array, "3 npix"]:
    """Draw one CMB realization for a given tensor-to-scalar ratio.

    Thin wrapper around ``CMBLensedWithTensors``: logs the requested ratio,
    builds the generator and hands back its map.

    Args:
        r_value: Tensor-to-scalar ratio.
        nside: HEALPix resolution.
        seed: Random seed forwarded to the generator.

    Returns:
        CMB map (3, npix) in uK_CMB, as a numpy array.

    Example:
        >>> cmb_map = generate_custom_cmb(r_value=0.01, nside=64, seed=42)
    """
    info(f"generating with r_value {r_value}")
    realization = CMBLensedWithTensors(nside=nside, r=r_value, cmb_seed=seed)
    return realization.map
[docs]
def save_to_cache(
    nside: int,
    noise_ratio: float = 0.0,
    instrument_name: str = "LiteBIRD",
    sky: str = "c1d0s0",
    key: PRNGKeyArray | None = None,
) -> tuple[Float[Array, " freqs"], Float[Array, " freqs 3 npix"]]:
    """Generate frequency maps and pickle them under ``freq_maps_cache/``.

    Three kinds of ``sky`` string are handled:

    * Synthetic tags (``{base}_synt_bd{N}_td{N}_bs{N}``): maps are simulated
      from perturbed spectral parameters on clustered patches, and three files
      are written (full, foreground-only, CMB-only). ``noise_ratio`` and ``key``
      are not used on this path and the full map is always stored under the
      ``no_noise`` name.
    * Strings containing ``cr<digits>``: the remaining tags are simulated with
      ``get_observation`` and a custom CMB with ``r = digits * 1e-3`` is added.
    * Anything else: the string is passed unchanged to ``get_observation``.

    If the target cache file(s) already exist they are loaded instead of
    regenerated.

    Args:
        nside: HEALPix resolution parameter.
        noise_ratio: Noise level ratio (0.0 = no noise, 1.0 = 100% noise). Defaults to 0.0.
        instrument_name: Instrument configuration name. Defaults to "LiteBIRD".
        sky: Sky model string. Defaults to "c1d0s0".
        key: JAX random key for noise generation. Defaults to ``PRNGKey(0)`` when None.

    Returns:
        A tuple (frequencies, freq_maps); frequencies come from the instrument and
        freq_maps has shape (n_freq, 3, n_pix) for Stokes I, Q, U.

    Notes:
        The cache directory is relative to the current working directory, and
        cache file names do not include ``instrument_name``, so maps generated
        for different instruments share the same file name.
        Cache files are read with ``pickle``; only load files you trust.

    Example:
        >>> freqs, maps = save_to_cache(nside=64, noise_ratio=0.0, sky="c1d0s0")
    """
    if key is None:
        key = jr.PRNGKey(0)
    instrument = get_instrument(instrument_name)
    # Define cache file path
    cache_dir = "freq_maps_cache"
    os.makedirs(cache_dir, exist_ok=True)
    # The noise ratio is encoded as a truncated integer percentage.
    noise_str = f"noise_{int(noise_ratio * 100)}" if noise_ratio > 0 else "no_noise"
    cache_file = os.path.join(cache_dir, f"freq_maps_nside_{nside}_{noise_str}_{sky}.pkl")
    synt = _parse_synt_tag(sky)
    if synt is not None:
        # --- Synthetic sky: spatially varying spectral parameters ---
        base_sky, bd, td, bs = synt
        full_cache = os.path.join(cache_dir, f"freq_maps_nside_{nside}_no_noise_{sky}.pkl")
        fg_cache = os.path.join(cache_dir, f"freq_maps_nside_{nside}_fg_{sky}.pkl")
        cmb_cache = os.path.join(cache_dir, f"freq_maps_nside_{nside}_cmb_{sky}.pkl")
        # Regenerate all three files unless every one of them is present.
        if not (
            os.path.exists(full_cache) and os.path.exists(fg_cache) and os.path.exists(cmb_cache)
        ):
            # Imported here rather than at module level (presumably to avoid an
            # import cycle or a heavy import - TODO confirm).
            from ..kmeans_clusters import kmeans_clusters as _kmeans_clusters
            npix = 12 * nside**2
            pysm_sky = get_sky(nside, base_sky)
            # Assumes base_sky yields components in the order CMB, dust,
            # synchrotron - TODO confirm against get_sky.
            cmb_arr = np.array(pysm_sky.components[0].map.to_value())  # (3, npix)
            dust_arr = np.array(pysm_sky.components[1].get_emission(160 * u.GHz))  # (3, npix)
            sync_arr = np.array(pysm_sky.components[2].get_emission(20 * u.GHz))  # (3, npix)
            # Only polarization (Q, U) enters the simulation.
            sky_synth = {
                "cmb": StokesQU(q=jnp.array(cmb_arr[1]), u=jnp.array(cmb_arr[2])),
                "dust": StokesQU(q=jnp.array(dust_arr[1]), u=jnp.array(dust_arr[2])),
                "synchrotron": StokesQU(q=jnp.array(sync_arr[1]), u=jnp.array(sync_arr[2])),
            }
            # Cluster the full sky (no pixel masked) into the requested number
            # of patches per spectral parameter.
            mask_full = jnp.ones(npix, dtype=bool)
            indices_full = jnp.arange(npix)
            n_regions_true = {
                "beta_dust_patches": bd,
                "temp_dust_patches": td,
                "beta_pl_patches": bs,
            }
            true_clusters = _kmeans_clusters(
                jax.random.key(0), mask_full, indices_full, n_regions_true
            )
            # One value per patch, starting from a common fiducial value ...
            base_params = {"beta_dust": 1.54, "temp_dust": 20.0, "beta_pl": -3.0}
            true_params_count = {"beta_dust": bd, "temp_dust": td, "beta_pl": bs}
            params_flat, tree_struct = jax.tree.flatten(
                jax.tree.map(lambda v, c: jnp.full((c,), v), base_params, true_params_count)
            )
            # ... then perturbed with Gaussian noise scaled by 0.2, using a fixed
            # key per flattened leaf so the result is reproducible.
            perturbed = [
                x + jr.normal(jr.PRNGKey(i), x.shape) * 0.2 for i, x in enumerate(params_flat)
            ]
            true_params = jax.tree.unflatten(tree_struct, perturbed)
            nu = instrument.frequency
            # Reference frequencies match those used for get_emission above.
            dust_nu0 = 160.0
            synchrotron_nu0 = 20.0
            d, d_nocmb = simulate_D_from_params(
                true_params, true_clusters, nu, sky_synth, dust_nu0, synchrotron_nu0
            )
            n_freq = len(nu)
            # Intensity is not simulated; the I plane is filled with zeros.
            zeros_freq = np.zeros((n_freq, npix))
            freq_maps = np.stack([zeros_freq, np.array(d.q), np.array(d.u)], axis=1)
            fg_maps = np.stack([zeros_freq, np.array(d_nocmb.q), np.array(d_nocmb.u)], axis=1)
            cmb_map = np.stack([cmb_arr[0], cmb_arr[1], cmb_arr[2]])  # (3, npix) full IQU
            with open(full_cache, "wb") as f:
                pickle.dump(freq_maps, f)
            with open(fg_cache, "wb") as f:
                pickle.dump(fg_maps, f)
            with open(cmb_cache, "wb") as f:
                pickle.dump(cmb_map, f)
            success(f"Generated synthetic sky cache for {sky}.")
        else:
            with open(full_cache, "rb") as f:
                freq_maps = pickle.load(f)
            info(f"Loaded synthetic sky cache for {sky}.")
        return np.array(instrument.frequency), freq_maps
    # Check if file exists, load if it does, otherwise create and save it
    r_val = None
    if os.path.exists(cache_file):
        with open(cache_file, "rb") as f:
            freq_maps = pickle.load(f)
        info(f"Loaded freq_maps for nside {nside} from cache with noise_ratio {noise_ratio}.")
    else:
        # Check for custom r CMB
        match = re.search(r"cr(\d+)", sky)
        custom_cmb_map = None
        fg_tag = sky
        if match:
            info(f"Detected custom r tag: {match.group(0)}")
            # The digits after "cr" are r in units of 1e-3 (e.g. "cr3" -> r=0.003).
            r_exp = int(match.group(1))
            r_val = r_exp * 1e-3
            info(f"Generating custom CMB with r={r_val}")
            fg_tag = sky.replace(match.group(0), "")
            # Derive the numpy seed from the JAX key by reading its element at
            # index 1 (note: index 1, not the first element).
            # `key` is never None here because it is defaulted at the top of the
            # function, so the `else 0` branch is not taken.
            seed = int(key[1]) if key is not None else 0
            custom_cmb_map = generate_custom_cmb(r_val, nside, seed=seed)
        else:
            r_val = 0.0  # Unused but here for Pyright
        # If we have custom CMB, we treat the rest as the "furax" sky
        tag_to_use = fg_tag if custom_cmb_map is not None else sky
        stokes_obs = get_observation(
            instrument,
            nside=nside,
            tag=tag_to_use,
            noise_ratio=noise_ratio,
            key=key,
            stokes_type="IQU",
            unit="uK_CMB",
        )
        # Convert Stokes PyTree to numpy array (n_freq, 3, n_pix)
        freq_maps = np.stack(
            [np.array(stokes_obs.i), np.array(stokes_obs.q), np.array(stokes_obs.u)], axis=1
        )
        if custom_cmb_map is not None:
            # Add custom CMB (broadcasting over frequencies)
            # freq_maps: (n_freq, 3, npix)
            # custom_cmb_map: (3, npix)
            freq_maps += custom_cmb_map[None, ...]
            assert r_val is not None
            info(f"Added custom CMB with r={r_val} to maps.")
        # Save freq_maps to the cache
        with open(cache_file, "wb") as f:
            pickle.dump(freq_maps, f)
        success(f"Generated and saved freq_maps for nside {nside}.")
    return np.array(instrument.frequency), freq_maps
[docs]
def load_from_cache(
    nside: int, noise_ratio: float = 0.0, instrument_name: str = "LiteBIRD", sky: str = "c1d0s0"
) -> tuple[Float[Array, " freqs"], Float[Array, " freqs 3 npix"]]:
    """Read previously cached frequency maps back from ``freq_maps_cache/``.

    Args:
        nside: HEALPix resolution parameter.
        noise_ratio: Noise level ratio (0.0 = no noise, 1.0 = 100% noise). Defaults to 0.0.
        instrument_name: Instrument configuration name. Defaults to "LiteBIRD".
        sky: Sky model string. Defaults to "c1d0s0".

    Returns:
        A tuple (frequencies, freq_maps); frequencies come from the instrument,
        freq_maps is whatever was pickled in the cache file.

    Raises:
        FileNotFoundError: If the cache file does not exist.

    Example:
        >>> freqs, maps = load_from_cache(nside=64, noise_ratio=0.0, sky="c1d0s0")
    """
    instrument = get_instrument(instrument_name)

    # Same naming scheme as the writer: truncated integer percentage of noise.
    if noise_ratio > 0:
        noise_label = f"noise_{int(noise_ratio * 100)}"
    else:
        noise_label = "no_noise"
    cache_path = os.path.join(
        "freq_maps_cache", f"freq_maps_nside_{nside}_{noise_label}_{sky}.pkl"
    )

    if not os.path.exists(cache_path):
        raise FileNotFoundError(
            f"Cache file for freq_maps with nside {nside} and noise_ratio {noise_ratio} not found.\n"
            f"Please generate it first by calling `generate_data --nside {nside}`."
        )

    with open(cache_path, "rb") as handle:
        cached_maps = pickle.load(handle)
    info(f"Loaded freq_maps for nside {nside} from cache.")
    return np.array(instrument.frequency), cached_maps
[docs]
def save_fg_map(
    nside: int,
    noise_ratio: float = 0.0,
    instrument_name: str = "LiteBIRD",
    sky: str = "c1d0s0",
    key: PRNGKeyArray | None = None,
) -> tuple[Float[Array, " freqs"], Float[Array, " freqs 3 npix"]]:
    """Produce (and cache) frequency maps that contain foregrounds only.

    For ordinary sky strings the CMB tag is stripped and the rest is delegated
    to ``save_to_cache``. For synthetic tags the dedicated foreground cache
    file is read, after generating it through ``save_to_cache`` if missing.

    Args:
        nside: HEALPix resolution parameter.
        noise_ratio: Noise level ratio (0.0 = no noise, 1.0 = 100% noise). Defaults to 0.0.
        instrument_name: Instrument configuration name. Defaults to "LiteBIRD".
        sky: Sky model string; its CMB component is removed. Defaults to "c1d0s0".
        key: JAX random key for noise generation. Defaults to None.

    Returns:
        A tuple (frequencies, freq_maps) containing only foreground contributions.

    Example:
        >>> freqs, fg_maps = save_fg_map(nside=64, sky="c1d1s1")
    """
    info(
        f"Generating fg map for nside {nside}, noise_ratio {noise_ratio}, instrument {instrument_name}"
    )

    if _parse_synt_tag(sky) is None:
        # Ordinary sky: drop the CMB tag and reuse the generic cache writer.
        _, foreground_sky = parse_sky_tag(sky)
        return save_to_cache(
            nside,
            noise_ratio=noise_ratio,
            instrument_name=instrument_name,
            sky=foreground_sky,
            key=key,
        )

    # Synthetic sky: the foreground file is a by-product of save_to_cache.
    instrument = get_instrument(instrument_name)
    fg_path = os.path.join("freq_maps_cache", f"freq_maps_nside_{nside}_fg_{sky}.pkl")
    if not os.path.exists(fg_path):
        save_to_cache(nside, noise_ratio=0.0, instrument_name=instrument_name, sky=sky)
    with open(fg_path, "rb") as handle:
        foreground_maps = pickle.load(handle)
    return np.array(instrument.frequency), foreground_maps
[docs]
def load_fg_map(
    nside: int, noise_ratio: float = 0.0, instrument_name: str = "LiteBIRD", sky: str = "c1d0s0"
) -> tuple[Float[Array, " freqs"], Float[Array, " freqs 3 npix"]]:
    """Read cached foreground-only frequency maps.

    Args:
        nside: HEALPix resolution parameter.
        noise_ratio: Noise level ratio (0.0 = no noise, 1.0 = 100% noise). Defaults to 0.0.
        instrument_name: Instrument configuration name. Defaults to "LiteBIRD".
        sky: Sky model string; its CMB component is ignored. Defaults to "c1d0s0".

    Returns:
        A tuple (frequencies, freq_maps) containing only foreground contributions.

    Raises:
        FileNotFoundError: If the required cache file does not exist.

    Example:
        >>> freqs, fg_maps = load_fg_map(nside=64, sky="c1d1s1")
    """
    if _parse_synt_tag(sky) is None:
        # Ordinary sky: the foreground cache is the cache of the CMB-less tag.
        _, foreground_sky = parse_sky_tag(sky)
        return load_from_cache(
            nside,
            noise_ratio=noise_ratio,
            instrument_name=instrument_name,
            sky=foreground_sky,
        )

    # Synthetic sky: a dedicated "fg" file sits next to the full-map file.
    instrument = get_instrument(instrument_name)
    fg_path = os.path.join("freq_maps_cache", f"freq_maps_nside_{nside}_fg_{sky}.pkl")
    if not os.path.exists(fg_path):
        raise FileNotFoundError(
            f"FG cache not found for sky {sky}. "
            f"Run: generate_data --sky {sky} --nside {nside}"
        )
    with open(fg_path, "rb") as handle:
        return np.array(instrument.frequency), pickle.load(handle)
[docs]
def save_cmb_map(nside: int, sky: str = "c1d0s0") -> Float[Array, "3 npix"]:
    """Produce and cache the CMB-only map belonging to a sky string.

    Synthetic tags read their dedicated CMB file (generating it via
    ``save_to_cache`` when missing). Otherwise the CMB tag is extracted: a
    ``cr<digits>`` tag yields a custom realization with ``r = digits * 0.001``
    and seed 0, any other CMB tag takes the map of the first component of
    ``get_sky(nside, sky)``. The result is always (re)written to the cache.

    Args:
        nside: HEALPix resolution parameter.
        sky: Sky model string. Defaults to "c1d0s0".

    Returns:
        CMB map with shape (3, n_pix) for Stokes I, Q, U, or zeros if the sky
        string holds no CMB tag.

    Example:
        >>> cmb_map = save_cmb_map(nside=64, sky="c1d0s0")
    """
    info(f"Generating CMB map for nside {nside}, sky {sky}")
    cache_dir = "freq_maps_cache"
    os.makedirs(cache_dir, exist_ok=True)

    if _parse_synt_tag(sky) is not None:
        synt_cmb_path = os.path.join(cache_dir, f"freq_maps_nside_{nside}_cmb_{sky}.pkl")
        if not os.path.exists(synt_cmb_path):
            save_to_cache(nside, sky=sky)
        with open(synt_cmb_path, "rb") as handle:
            return pickle.load(handle)

    cmb_tag, _ = parse_sky_tag(sky)
    if cmb_tag is None:
        # No CMB requested: an all-zero IQU map stands in for it.
        return np.zeros((3, 12 * nside**2))

    target_path = os.path.join(cache_dir, f"freq_maps_nside_{nside}_{cmb_tag}.pkl")
    custom_r = re.match(r"cr(\d+)", cmb_tag)
    if custom_r is None:
        # Assumes the CMB is the first component of the sky - TODO confirm.
        cmb_map = get_sky(nside, sky).components[0].map.to_value()
    else:
        r_exp = int(custom_r.group(1))
        r_val = r_exp * 0.001
        # Seed 0 mirrors what save_to_cache derives from its default key.
        cmb_map = generate_custom_cmb(r_val, nside, seed=0)

    with open(target_path, "wb") as handle:
        pickle.dump(cmb_map, handle)
    success(f"Generated and saved freq_maps for nside {nside} and for tag {cmb_tag}.")
    return cmb_map
[docs]
def load_cmb_map(nside: int, sky: str = "c1d0s0") -> Float[Array, "3 npix"]:
    """Read the cached CMB-only map belonging to a sky string.

    Args:
        nside: HEALPix resolution parameter.
        sky: Sky model string. Defaults to "c1d0s0".

    Returns:
        CMB map with shape (3, n_pix) for Stokes I, Q, U, or zeros if the sky
        string holds no CMB tag.

    Raises:
        FileNotFoundError: If the required cache file does not exist.

    Example:
        >>> cmb_map = load_cmb_map(nside=64, sky="c1d0s0")
    """
    cache_dir = "freq_maps_cache"

    if _parse_synt_tag(sky) is None:
        cmb_tag, _ = parse_sky_tag(sky)
        if cmb_tag is None:
            # No CMB in this sky: nothing is cached, return zeros directly.
            return np.zeros((3, 12 * nside**2))

        tag_path = os.path.join(cache_dir, f"freq_maps_nside_{nside}_{cmb_tag}.pkl")
        if not os.path.exists(tag_path):
            raise FileNotFoundError(
                f"Cache file for freq_maps with nside {nside} not found.\n"
                f"Please generate it first by calling `generate_data --nside {nside}`."
            )
        with open(tag_path, "rb") as handle:
            cmb_map = pickle.load(handle)
        info(f"Loaded freq_maps for nside {nside} from cache.")
        return cmb_map

    # Synthetic sky: the CMB map lives in its own "cmb" file.
    synt_cmb_path = os.path.join(cache_dir, f"freq_maps_nside_{nside}_cmb_{sky}.pkl")
    if not os.path.exists(synt_cmb_path):
        raise FileNotFoundError(
            f"CMB cache not found for sky {sky}. "
            f"Run: generate_data --sky {sky} --nside {nside}"
        )
    with open(synt_cmb_path, "rb") as handle:
        return pickle.load(handle)
[docs]
def get_mixin_matrix_operator(
    params: PyTree[Float[Array, " P"]],
    patch_indices: PyTree[Int[Array, " P"]],
    nu: Float[Array, " Nf"],
    sky: SkyType,
    dust_nu0: float,
    synchrotron_nu0: float,
) -> tuple[MixingMatrixOperator, MixingMatrixOperator]:
    """Build the mixing-matrix operators with and without the CMB component.

    Args:
        params: Spectral parameters; keys "temp_dust", "beta_dust" and "beta_pl" are read.
        patch_indices: Patch assignments; keys "temp_dust_patches",
            "beta_dust_patches" and "beta_pl_patches" are read.
        nu: Frequency array in GHz.
        sky: Sky component dictionary; its first value fixes the input structure.
        dust_nu0: Dust reference frequency in GHz.
        synchrotron_nu0: Synchrotron reference frequency in GHz.

    Returns:
        A tuple (operator with cmb + dust + synchrotron, operator with dust + synchrotron).

    Example:
        >>> A, A_nocmb = get_mixin_matrix_operator(params, patches, nu, sky, 150.0, 20.0)
    """
    # Every component shares the structure of the first sky entry.
    template = next(iter(sky.values()))
    in_structure = template.structure_for((template.shape[-1],))

    components = {
        "cmb": CMBOperator(nu, in_structure=in_structure),
        "dust": DustOperator(
            nu,
            frequency0=dust_nu0,
            temperature=params["temp_dust"],
            temperature_patch_indices=patch_indices["temp_dust_patches"],
            beta=params["beta_dust"],
            beta_patch_indices=patch_indices["beta_dust_patches"],
            in_structure=in_structure,
        ),
        "synchrotron": SynchrotronOperator(
            nu,
            frequency0=synchrotron_nu0,
            beta_pl=params["beta_pl"],
            beta_pl_patch_indices=patch_indices["beta_pl_patches"],
            in_structure=in_structure,
        ),
    }

    with_cmb = MixingMatrixOperator(**components)
    without_cmb = MixingMatrixOperator(
        dust=components["dust"], synchrotron=components["synchrotron"]
    )
    return with_cmb, without_cmb
[docs]
def simulate_D_from_params(
    params: PyTree[Float[Array, " P"]],
    patch_indices: PyTree[Int[Array, " P"]],
    nu: Float[Array, " Nf"],
    sky: SkyType,
    dust_nu0: float,
    synchrotron_nu0: float,
) -> tuple[Stokes, Stokes]:
    """Apply the mixing matrices to a sky to obtain simulated frequency maps.

    Args:
        params: Spectral parameters (temp_dust, beta_dust, beta_pl).
        patch_indices: Patch assignment indices for each parameter.
        nu: Frequency array in GHz.
        sky: Sky component dictionary; must contain a "cmb" entry.
        dust_nu0: Dust reference frequency in GHz.
        synchrotron_nu0: Synchrotron reference frequency in GHz.

    Returns:
        A tuple (d, d_nocmb) where d includes CMB and d_nocmb excludes it.

    Example:
        >>> d, d_nocmb = simulate_D_from_params(params, patches, nu, sky, 150.0, 20.0)
    """
    mixing_full, mixing_fg = get_mixin_matrix_operator(
        params, patch_indices, nu, sky, dust_nu0, synchrotron_nu0
    )
    with_cmb = mixing_full(sky)

    # Work on a shallow copy so the caller's dictionary keeps its "cmb" entry.
    foregrounds = sky.copy()
    del foregrounds["cmb"]
    without_cmb = mixing_fg(foregrounds)

    return with_cmb, without_cmb
# Names accepted as a single mask by get_mask and inside mask expressions.
# "_U" / "_L" suffixes select the upper / lower hemisphere part of a zone.
MASK_CHOICES = [
    "ALL",
    "GALACTIC",
    "GAL020_U",
    "GAL020_L",
    "GAL020",
    "GAL040_U",
    "GAL040_L",
    "GAL040",
    "GAL060_U",
    "GAL060_L",
    "GAL060",
]
[docs]
def sanitize_mask_name(mask_expr: str) -> str:
    """Turn a mask expression into a string usable as a folder name.

    Args:
        mask_expr: Mask expression that may contain the operators ``+`` (union)
            and ``-`` (subtract).

    Returns:
        The expression with ``+`` replaced by ``_UNION_`` and ``-`` replaced by
        ``_SUBTRACT_``; all other characters are left as they are.

    Example:
        >>> sanitize_mask_name("GAL020+GAL040")
        'GAL020_UNION_GAL040'
        >>> sanitize_mask_name("ALL-GALACTIC")
        'ALL_SUBTRACT_GALACTIC'
    """
    operator_words = str.maketrans({"+": "_UNION_", "-": "_SUBTRACT_"})
    return mask_expr.translate(operator_words)
def _parse_mask_expression(expr: str, nside: int) -> Bool[Array, " npix"]:
    """Evaluate a mask expression built from mask names, ``+`` and ``-``.

    The expression is split into alternating names and operators and folded
    strictly left to right; parentheses are not supported. ``+`` is a union
    (logical OR), ``-`` removes the right-hand mask (logical AND NOT).

    Args:
        expr: Mask expression, e.g. "GAL020+GAL040", "ALL-GALACTIC",
            "GAL020+GAL040-GALACTIC".
        nside: HEALPix resolution parameter.

    Returns:
        Boolean mask array where True indicates observed pixels.

    Raises:
        ValueError: If the expression is empty, malformed, or names an unknown mask.

    Example:
        >>> mask = _parse_mask_expression("GAL020+GAL040", nside=64)
        >>> mask = _parse_mask_expression("ALL-GALACTIC", nside=64)
    """

    def _require_known(mask_name: str) -> None:
        # Reject anything that is not one of the predefined zone names.
        if mask_name not in MASK_CHOICES:
            raise ValueError(
                f"Invalid mask name '{mask_name}' in expression '{expr}'. "
                f"Choose from: {MASK_CHOICES}"
            )

    # --- Tokenize: operators are single characters, names are what lies between.
    pieces: list[str] = []
    pending = ""
    for ch in expr:
        if ch == "+" or ch == "-":
            if pending:
                pieces.append(pending.strip())
                pending = ""
            pieces.append(ch)
        else:
            pending += ch
    if pending:
        pieces.append(pending.strip())

    if not pieces:
        raise ValueError(f"Empty mask expression: {expr}")

    # --- A lone token must itself be a mask name.
    if len(pieces) == 1:
        only = pieces[0]
        _require_known(only)
        return get_mask(only, nside)

    # --- name (op name)* always has an odd number of tokens.
    if len(pieces) % 2 == 0:
        raise ValueError(
            f"Invalid expression syntax: {expr}. Expected format: MASK [+/-] MASK [+/-] MASK ..."
        )

    # --- Fold left to right: index 0 is a name, odd indices are operators,
    # each followed by the name it applies to.
    first = pieces[0]
    _require_known(first)
    combined = get_mask(first, nside)

    for pos in range(1, len(pieces), 2):
        op = pieces[pos]
        if op not in ("+", "-"):
            raise ValueError(
                f"Expected operator (+ or -) at position {pos} in expression '{expr}', "
                f"got '{op}'"
            )
        # Defensive only: with an odd token count a name always follows.
        if pos + 1 >= len(pieces):
            raise ValueError(f"Operator '{op}' at end of expression '{expr}'")

        operand_name = pieces[pos + 1]
        _require_known(operand_name)
        operand = get_mask(operand_name, nside)

        if op == "+":
            combined = np.logical_or(combined, operand)
        else:
            combined = np.logical_and(combined, np.logical_not(operand))

    return combined
def _get_or_generate_mask_file(nside: int) -> Path:
    """Return the path of the mask archive for ``nside``, creating it if absent.

    Archives live in the ``masks`` directory next to this module. When the
    archive for ``nside`` is missing, every array of the nside-2048 source
    archive is resampled with ``hp.ud_grade``, cast to ``uint8`` and saved.

    Args:
        nside: HEALPix resolution parameter.

    Returns:
        Path to the ``.npz`` mask archive for ``nside``.
    """
    mask_dir = Path(__file__).parent / "masks"
    mask_file = mask_dir / f"GAL_PlanckMasks_{nside}.npz"
    if mask_file.exists():
        return mask_file

    # Generate from 2048 source. The archive is opened as a context manager so
    # its file handle is released as soon as the arrays have been read.
    source_file = mask_dir / "GAL_PlanckMasks_2048.npz"
    with np.load(source_file) as masks_2048:
        downgraded = {
            # "* 1.0" promotes the stored mask to float before resampling.
            name: hp.ud_grade(masks_2048[name] * 1.0, nside).astype(np.uint8)
            for name in masks_2048.files
        }

    np.savez(mask_file, **downgraded)
    success(f"Generated and cached mask for nside {nside}")
    return mask_file
[docs]
def get_mask(mask_name: str = "GAL020", nside: int = 64) -> Bool[Array, " npix"]:
    """Load a galactic mask zone, or evaluate a mask expression, at ``nside``.

    Args:
        mask_name: Mask identifier (e.g., "GAL020", "GAL040", "GALACTIC") or boolean
            expression (e.g., "GAL020+GAL040", "ALL-GALACTIC"). Defaults to "GAL020".
        nside: HEALPix resolution parameter. Defaults to 64.

    Returns:
        Boolean mask array where True indicates observed pixels.

    Raises:
        ValueError: If mask_name is invalid.

    Notes:
        Available mask choices: ALL, GALACTIC, GAL020, GAL040, GAL060, and
        their _U (upper) and _L (lower) hemisphere variants.

        Zones are nested rings: GAL020 is the stored GAL020 mask, GAL040 is the
        stored GAL040 mask minus GAL020, GAL060 is the stored GAL060 mask minus
        GAL040, and GALACTIC is everything outside the stored GAL060 mask.

        Boolean operations are supported:
        - Use + for union (logical OR)
        - Use - for subtraction (logical AND NOT)
        - Expressions are evaluated left-to-right
        - Examples: "GAL020+GAL040", "ALL-GALACTIC", "GAL020+GAL040-GALACTIC"

        Masks are automatically generated and cached on first call for each nside.

    Example:
        >>> mask = get_mask("GAL020", nside=64)
        >>> # Using boolean expression:
        >>> mask_union = get_mask("GAL020+GAL040", nside=64)
    """
    # Expressions are delegated; each of their operands comes back here.
    if "+" in mask_name or "-" in mask_name:
        return _parse_mask_expression(mask_name, nside)

    # Validate before touching the disk so a typo never triggers mask generation.
    if mask_name not in MASK_CHOICES:
        raise ValueError(f"Invalid mask name: {mask_name}. Choose from: {MASK_CHOICES}")

    masks_file = _get_or_generate_mask_file(nside)
    # Read inside a context manager so the archive handle is closed, and cast
    # to bool so every zone (including plain "GAL020") has a boolean dtype.
    with np.load(masks_file) as masks:
        stored_020 = masks["GAL020"].astype(bool)
        stored_040 = masks["GAL040"].astype(bool)
        stored_060 = masks["GAL060"].astype(bool)

    npix = 12 * nside**2
    ones = np.ones(npix, dtype=bool)

    # Nested rings: each wider stored mask loses the pixels of the narrower one.
    mask_galactic = np.logical_and(ones, np.logical_not(stored_060))
    ring_060 = np.logical_and(stored_060, np.logical_not(stored_040))
    ring_040 = np.logical_and(stored_040, np.logical_not(stored_020))

    # Hemisphere split on colatitude of the pixel centres, using the
    # resolution of the loaded masks themselves.
    # (Assuming theta < pi/2 corresponds to b > 0, i.e. the "upper" hemisphere.)
    mask_nside = hp.get_nside(stored_020)
    theta, _phi = hp.pix2ang(mask_nside, np.arange(hp.nside2npix(mask_nside)))
    upper = theta < np.pi / 2
    lower = theta >= np.pi / 2

    zones = {
        "ALL": ones,
        "GALACTIC": mask_galactic,
        "GAL020": stored_020,
        "GAL020_U": np.logical_and(stored_020, upper),
        "GAL020_L": np.logical_and(stored_020, lower),
        "GAL040": ring_040,
        "GAL040_U": np.logical_and(ring_040, upper),
        "GAL040_L": np.logical_and(ring_040, lower),
        "GAL060": ring_060,
        "GAL060_U": np.logical_and(ring_060, upper),
        "GAL060_L": np.logical_and(ring_060, lower),
    }
    return zones[mask_name]
[docs]
def generate_needed_maps(
    nside_list: list[int] | None = None,
    noise_ratio_list: list[float] | None = None,
    instrument_name: str = "LiteBIRD",
    sky_list: list[str] | None = None,
) -> None:
    """Generate and cache every map needed for the given configurations.

    Full frequency maps are produced for each (nside, noise_ratio, sky)
    combination; noise-free foreground-only maps and CMB-only maps are then
    produced for each (sky, nside) pair.

    Args:
        nside_list: HEALPix resolutions to generate. Defaults to [4, 8, 32, 64, 128].
        noise_ratio_list: Noise ratio configurations. Defaults to [0.0, 1.0].
        instrument_name: Instrument configuration. Defaults to "LiteBIRD".
        sky_list: Sky model strings. Defaults to ["c1d0s0", "c1d1s1"].

    Example:
        >>> generate_needed_maps(nside_list=[64], noise_ratio_list=[0.0, 1.0])
    """
    nsides = [4, 8, 32, 64, 128] if nside_list is None else nside_list
    noise_ratios = [0.0, 1.0] if noise_ratio_list is None else noise_ratio_list
    skies = ["c1d0s0", "c1d1s1"] if sky_list is None else sky_list

    # Pass 1: full frequency maps.
    for nside in nsides:
        for noise_ratio in noise_ratios:
            for sky in skies:
                save_to_cache(
                    nside,
                    noise_ratio=noise_ratio,
                    instrument_name=instrument_name,
                    sky=sky,
                )

    # Pass 2: noise-free foreground-only maps and CMB-only maps.
    for sky in skies:
        for nside in nsides:
            save_fg_map(nside, noise_ratio=0.0, instrument_name=instrument_name, sky=sky)
            save_cmb_map(nside, sky=sky)
def main():
    """Command-line entry point: parse options and generate all cached maps."""
    cli = argparse.ArgumentParser(
        description="Generate cached frequency maps for CMB component separation"
    )

    # (flag, add_argument keyword arguments) for every supported option.
    option_specs = (
        (
            "--nside",
            dict(
                type=int,
                nargs="+",
                default=[4, 8, 32, 64, 128],
                help="HEALPix resolution(s) to generate maps for (default: 4 8 32 64 128)",
            ),
        ),
        (
            "--noise-ratio",
            dict(
                type=float,
                nargs="+",
                default=[0.0, 1.0],
                help="Noise ratio level(s) to generate (0.0=no noise, 1.0=100%% noise, default: 0.0 1.0)",
            ),
        ),
        (
            "--instrument",
            dict(
                type=str,
                default="LiteBIRD",
                help="Instrument name (default: LiteBIRD)",
            ),
        ),
        (
            "--sky",
            dict(
                type=str,
                nargs="+",
                default=["c1d0s0", "c1d1s1"],
                help="Sky model tag(s) (default: c1d0s0 c1d1s1)",
            ),
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    options = cli.parse_args()
    generate_needed_maps(
        nside_list=options.nside,
        noise_ratio_list=options.noise_ratio,
        instrument_name=options.instrument,
        sky_list=options.sky,
    )


if __name__ == "__main__":
    main()