"""furax_cs.data.generate_maps: generate, cache and load frequency, foreground, CMB and mask maps."""

from __future__ import annotations

import argparse
import os
import pickle
import re
from pathlib import Path
from typing import TypeAlias

import astropy.units as u
import camb
import healpy as hp
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
from camb import initialpower
from furax._instruments.sky import get_observation, get_sky
from furax.obs.operators import (
    CMBOperator,
    DustOperator,
    MixingMatrixOperator,
    SynchrotronOperator,
)
from furax.obs.stokes import Stokes, StokesQU
from jaxtyping import (
    Array,
    Bool,
    Float,
    Int,
    PRNGKeyArray,
    PyTree,  # pyright: ignore
)

from ..logging_utils import info, success
from .instruments import get_instrument

# Sky components keyed by component name (save_to_cache uses "cmb", "dust" and
# "synchrotron"), each value being a furax Stokes pytree.
SkyType: TypeAlias = dict[str, Stokes]


def compute_Cl_from_camb(r: float, which: str = "total", lmax: int = 2000) -> np.ndarray:
    """Compute CMB power spectra from CAMB.

    Args:
        r: Tensor-to-scalar ratio.
        which: Spectrum to return. "total" (lensed scalar + tensor) and
            "tensor" are the ones used in this module; any other spectrum name
            CAMB's ``get_cmb_power_spectra`` produces is also accepted.
        lmax: Maximum multipole requested from CAMB.

    Returns:
        Raw C_ell array (muK^2), one row per multipole. For "total" and
        "tensor" the columns are TT, EE, BB, TE.

    Raises:
        ValueError: If ``which`` is not a spectrum name CAMB can return.
    """
    known_spectra = (
        "total",
        "unlensed_scalar",
        "unlensed_total",
        "lens_potential",
        "lensed_scalar",
        "tensor",
    )
    # Fail fast, before running the (expensive) Boltzmann solver.
    if which not in known_spectra:
        raise ValueError(f"Unknown spectrum {which!r}; choose one of {known_spectra}")

    cosmo_params = camb.set_params(
        Alens=1.0,
        H0=67.5,
        ombh2=0.022,
        omch2=0.122,
        mnu=0.06,
        omk=0,
        tau=0.06,
        As=2e-9,
        ns=0.965,
        halofit_version="mead",
        max_l_tensor=lmax,
        max_eta_k_tensor=18000,
        parameterization="tensor_param_indeptilt",
    )
    cosmo_params.set_for_lmax(lmax, lens_potential_accuracy=1)
    cosmo_params.WantTensors = True

    # NOTE: this InitPower object replaces the primordial spectrum configured in
    # set_params above, so the effective scalar tilt is ns=0.96 (not 0.965) and
    # As falls back to InitialPowerLaw's default (presumably 2e-9 — confirm).
    infl_params = initialpower.InitialPowerLaw()
    infl_params.set_params(ns=0.96, r=r, parameterization="tensor_param_indeptilt", nt=0, ntrun=0)
    cosmo_params.InitPower = infl_params

    results = camb.get_results(cosmo_params)
    # get_results already computed the power spectra; passing the params again
    # would make CAMB redo the whole calculation. Only request the one we return.
    spectra = results.get_cmb_power_spectra(CMB_unit="muK", raw_cl=True, spectra=(which,))
    return spectra[which]


_CL_PRIMORDIAL_R1: np.ndarray | None = None
_CL_LENSING: np.ndarray | None = None


def _get_template_spectra(lmax: int = 2000) -> tuple[np.ndarray, np.ndarray]:
    """Get or compute template spectra for linear approximation."""
    global _CL_PRIMORDIAL_R1, _CL_LENSING
    if _CL_PRIMORDIAL_R1 is None:
        _CL_PRIMORDIAL_R1 = compute_Cl_from_camb(r=1.0, which="tensor", lmax=lmax)
    if _CL_LENSING is None:
        _CL_LENSING = compute_Cl_from_camb(r=0.0, which="total", lmax=lmax)
    return _CL_PRIMORDIAL_R1, _CL_LENSING


class CMBLensedWithTensors:
    """CMB map generator with a custom tensor-to-scalar ratio r.

    Power spectra come from CAMB templates combined linearly as

        Cl(r) = r * Cl_tensor(r=1) + Cl_lensing

    and a (T, Q, U) realization is drawn with ``healpy.synfast``.

    Args:
        nside: HEALPix resolution parameter.
        r: Tensor-to-scalar ratio. Defaults to 0.0.
        cmb_seed: Seed applied to NumPy's global RNG for the synfast draw. The
            caller's RNG state is restored afterwards. When None, the current
            global RNG state is used (and advanced) as before.
        lmax: Maximum multipole for the CAMB spectra. Defaults to 2000.

    Attributes:
        nside: HEALPix resolution parameter.
        r: Tensor-to-scalar ratio.
        lmax: Maximum multipole of the template spectra.
        map: NumPy array holding the generated (T, Q, U) maps in uK_CMB.
    """

    def __init__(
        self,
        nside: int,
        r: float = 0.0,
        cmb_seed: int | None = None,
        lmax: int = 2000,
    ):
        self.nside = nside
        self.r = r
        self.lmax = lmax

        # Template spectra are computed once per lmax and reused.
        cl_primordial_r1, cl_lensing = _get_template_spectra(lmax)
        # Linear combination; columns are TT, EE, BB, TE.
        cl_total = r * cl_primordial_r1 + cl_lensing
        # synfast with new=True expects the diagonal ordering TT, EE, BB, TE.
        cls = [cl_total[:, column] for column in range(4)]

        # synfast draws from NumPy's global RNG, so seeding is what makes the map
        # reproducible. Restore the caller's state afterwards so building a map
        # does not silently reseed unrelated random draws elsewhere.
        saved_state = None
        if cmb_seed is not None:
            saved_state = np.random.get_state()
            np.random.seed(cmb_seed)
        try:
            self.map = np.array(hp.synfast(cls, nside=nside, new=True))
        finally:
            if saved_state is not None:
                np.random.set_state(saved_state)
def parse_sky_tag(sky: str) -> tuple[str | None, str]: """Parse sky string to separate CMB and foreground tags. Args: sky: Sky model string (e.g., "c1d0s0", "cr3d0s0"). Returns: A tuple (cmb_tag, fg_tag). `cmb_tag` is None if no CMB present. Example: >>> cmb_tag, fg_tag = parse_sky_tag("c1d1s1") >>> print(cmb_tag, fg_tag) c1 d1s1 """ # Check for custom r pattern (crX) match = re.search(r"cr(\d+)", sky) if match: cmb_tag = match.group(0) fg_tag = sky.replace(cmb_tag, "") return cmb_tag, fg_tag # Legacy 2-char parsing tags = [sky[i : i + 2] for i in range(0, len(sky), 2)] cmb_tags = [t for t in tags if t.startswith("c")] if cmb_tags: cmb_tag = cmb_tags[0] fg_tags = [t for t in tags if not t.startswith("c")] fg_tag = "".join(fg_tags) return cmb_tag, fg_tag return None, sky def _parse_synt_tag(sky: str) -> tuple[str, int, int, int] | None: """Returns (base_sky, BD, TD, BS) if sky matches {base_sky}_synt_bd{N}_td{N}_bs{N}, else None.""" match = re.match(r"^(.+)_synt_bd(\d+)_td(\d+)_bs(\d+)$", sky) if match: return match.group(1), int(match.group(2)), int(match.group(3)), int(match.group(4)) return None
def generate_custom_cmb(
    r_value: float, nside: int, seed: int | None = None
) -> Float[Array, "3 npix"]:
    """Draw one CMB realization for a given tensor-to-scalar ratio.

    The spectra use the linear-in-r approximation of ``CMBLensedWithTensors``
    (CAMB templates + healpy synfast).

    Args:
        r_value: Tensor-to-scalar ratio.
        nside: HEALPix resolution.
        seed: Seed forwarded to the generator as ``cmb_seed``.

    Returns:
        NumPy array of shape (3, npix) with the (T, Q, U) maps in uK_CMB.

    Example:
        >>> cmb_map = generate_custom_cmb(r_value=0.01, nside=64, seed=42)
    """
    info(f"generating with r_value {r_value}")
    generator = CMBLensedWithTensors(nside=nside, r=r_value, cmb_seed=seed)
    cmb_map = generator.map  # already a (3, npix) NumPy array
    return cmb_map
def _build_synthetic_sky_maps(
    nside: int, instrument, base_sky: str, bd: int, td: int, bs: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Simulate a synthetic sky whose spectral parameters vary per k-means patch.

    ``bd``/``td``/``bs`` patches are used for beta_dust/temp_dust/beta_pl. Each
    patch value is the base value plus N(0, 0.2) noise drawn with fixed keys.

    Returns:
        ``(freq_maps, fg_maps, cmb_map)``. The full and foreground-only maps have
        shape (n_freq, 3, npix) with an all-zero intensity row. The CMB map is
        the (3, npix) IQU map of the first PySM component.
    """
    # Only synthetic skies need the clustering code.
    from ..kmeans_clusters import kmeans_clusters as _kmeans_clusters

    npix = 12 * nside**2
    dust_nu0 = 160.0
    synchrotron_nu0 = 20.0

    pysm_sky = get_sky(nside, base_sky)
    # NOTE(review): assumes the PySM components are ordered (cmb, dust, synchrotron).
    cmb_arr = np.array(pysm_sky.components[0].map.to_value())  # (3, npix)
    dust_arr = np.array(pysm_sky.components[1].get_emission(160 * u.GHz))  # (3, npix)
    sync_arr = np.array(pysm_sky.components[2].get_emission(20 * u.GHz))  # (3, npix)
    sky_synth = {
        "cmb": StokesQU(q=jnp.array(cmb_arr[1]), u=jnp.array(cmb_arr[2])),
        "dust": StokesQU(q=jnp.array(dust_arr[1]), u=jnp.array(dust_arr[2])),
        "synchrotron": StokesQU(q=jnp.array(sync_arr[1]), u=jnp.array(sync_arr[2])),
    }

    mask_full = jnp.ones(npix, dtype=bool)
    indices_full = jnp.arange(npix)
    n_regions_true = {
        "beta_dust_patches": bd,
        "temp_dust_patches": td,
        "beta_pl_patches": bs,
    }
    true_clusters = _kmeans_clusters(jax.random.key(0), mask_full, indices_full, n_regions_true)

    base_params = {"beta_dust": 1.54, "temp_dust": 20.0, "beta_pl": -3.0}
    true_params_count = {"beta_dust": bd, "temp_dust": td, "beta_pl": bs}
    params_flat, tree_struct = jax.tree.flatten(
        jax.tree.map(lambda v, c: jnp.full((c,), v), base_params, true_params_count)
    )
    perturbed = [x + jr.normal(jr.PRNGKey(i), x.shape) * 0.2 for i, x in enumerate(params_flat)]
    true_params = jax.tree.unflatten(tree_struct, perturbed)

    nu = instrument.frequency
    d, d_nocmb = simulate_D_from_params(
        true_params, true_clusters, nu, sky_synth, dust_nu0, synchrotron_nu0
    )
    zeros_freq = np.zeros((len(nu), npix))
    freq_maps = np.stack([zeros_freq, np.array(d.q), np.array(d.u)], axis=1)
    fg_maps = np.stack([zeros_freq, np.array(d_nocmb.q), np.array(d_nocmb.u)], axis=1)
    cmb_map = np.stack([cmb_arr[0], cmb_arr[1], cmb_arr[2]])  # (3, npix) full IQU
    return freq_maps, fg_maps, cmb_map


def _load_or_build_synthetic_cache(
    nside: int,
    instrument,
    sky: str,
    synt: tuple[str, int, int, int],
    cache_dir: str,
    noise_ratio: float,
) -> np.ndarray:
    """Return the full synthetic-sky maps, rebuilding all three caches if any is missing.

    The full, foreground-only and CMB-only maps are always written together.
    """
    base_sky, bd, td, bs = synt
    full_cache = os.path.join(cache_dir, f"freq_maps_nside_{nside}_no_noise_{sky}.pkl")
    fg_cache = os.path.join(cache_dir, f"freq_maps_nside_{nside}_fg_{sky}.pkl")
    cmb_cache = os.path.join(cache_dir, f"freq_maps_nside_{nside}_cmb_{sky}.pkl")

    if noise_ratio > 0:
        # Synthetic skies are only ever simulated noiseless (and cached under the
        # no_noise name); report that the requested noise is not applied.
        info(f"Synthetic sky {sky} is noiseless; ignoring noise_ratio {noise_ratio}.")

    if all(os.path.exists(path) for path in (full_cache, fg_cache, cmb_cache)):
        # Cache files are written locally by this module; never point them at untrusted pickles.
        with open(full_cache, "rb") as f:
            freq_maps = pickle.load(f)
        info(f"Loaded synthetic sky cache for {sky}.")
        return freq_maps

    freq_maps, fg_maps, cmb_map = _build_synthetic_sky_maps(nside, instrument, base_sky, bd, td, bs)
    for path, payload in ((full_cache, freq_maps), (fg_cache, fg_maps), (cmb_cache, cmb_map)):
        with open(path, "wb") as f:
            pickle.dump(payload, f)
    success(f"Generated synthetic sky cache for {sky}.")
    return freq_maps


def _simulate_observed_maps(
    instrument, nside: int, noise_ratio: float, sky: str, key: PRNGKeyArray
) -> np.ndarray:
    """Observe ``sky`` with furax and return IQU maps of shape (n_freq, 3, npix).

    A ``cr<N>`` tag is removed from the furax tag and replaced by a CAMB/synfast
    CMB with r = N * 1e-3, which is added to every frequency channel.
    """
    match = re.search(r"cr(\d+)", sky)
    custom_cmb_map = None
    r_val = 0.0
    tag_to_use = sky
    if match:
        info(f"Detected custom r tag: {match.group(0)}")
        r_val = int(match.group(1)) * 1e-3
        info(f"Generating custom CMB with r={r_val}")
        # The furax observation then only covers the foreground part of the tag.
        tag_to_use = sky.replace(match.group(0), "")
        # Seed synfast from the second word of the key data. jr.key_data accepts
        # both legacy uint32 keys and typed keys (jax.random.key); indexing a
        # typed key directly would fail.
        seed = int(jr.key_data(key)[1])
        custom_cmb_map = generate_custom_cmb(r_val, nside, seed=seed)

    stokes_obs = get_observation(
        instrument,
        nside=nside,
        tag=tag_to_use,
        noise_ratio=noise_ratio,
        key=key,
        stokes_type="IQU",
        unit="uK_CMB",
    )
    # Convert the Stokes pytree to a NumPy array (n_freq, 3, n_pix).
    freq_maps = np.stack(
        [np.array(stokes_obs.i), np.array(stokes_obs.q), np.array(stokes_obs.u)], axis=1
    )
    if custom_cmb_map is not None:
        # Broadcast the (3, npix) CMB over the frequency axis.
        freq_maps += custom_cmb_map[None, ...]
        info(f"Added custom CMB with r={r_val} to maps.")
    return freq_maps


def save_to_cache(
    nside: int,
    noise_ratio: float = 0.0,
    instrument_name: str = "LiteBIRD",
    sky: str = "c1d0s0",
    key: PRNGKeyArray | None = None,
) -> tuple[Float[Array, " freqs"], Float[Array, " freqs 3 npix"]]:
    """Generate (or load) and cache frequency maps for component separation.

    Sky tags are handled as follows:

    * ``{base}_synt_bd{N}_td{N}_bs{N}`` is a synthetic sky with per-patch
      spectral parameters. Full, foreground-only and CMB-only caches are
      written together. These maps are noiseless, ``noise_ratio`` and ``key``
      are not used, and the intensity row is zero.
    * A tag containing ``cr<N>`` adds a CAMB/synfast CMB with r = N * 1e-3 to the
      furax observation of the remaining tag.
    * Any other tag is observed with furax ``get_observation`` as is.

    Args:
        nside: HEALPix resolution parameter.
        noise_ratio: Noise level ratio (0.0 = no noise, 1.0 = 100% noise). It is
            encoded in the cache file name as ``int(noise_ratio * 100)``.
        instrument_name: Instrument configuration name. Defaults to "LiteBIRD".
        sky: Sky model tag (e.g., "c1d0s0"). Defaults to "c1d0s0".
        key: JAX PRNG key for noise generation (legacy or typed). Defaults to
            ``jr.PRNGKey(0)``.

    Returns:
        ``(frequencies, freq_maps)``, where frequencies are in GHz and freq_maps
        has shape (n_freq, 3, n_pix) for Stokes I, Q, U.

    Example:
        >>> freqs, maps = save_to_cache(nside=64, noise_ratio=0.0, sky="c1d0s0")
    """
    if key is None:
        key = jr.PRNGKey(0)
    instrument = get_instrument(instrument_name)
    frequencies = np.array(instrument.frequency)

    cache_dir = "freq_maps_cache"
    os.makedirs(cache_dir, exist_ok=True)

    synt = _parse_synt_tag(sky)
    if synt is not None:
        freq_maps = _load_or_build_synthetic_cache(
            nside, instrument, sky, synt, cache_dir, noise_ratio
        )
        return frequencies, freq_maps

    noise_str = f"noise_{int(noise_ratio * 100)}" if noise_ratio > 0 else "no_noise"
    cache_file = os.path.join(cache_dir, f"freq_maps_nside_{nside}_{noise_str}_{sky}.pkl")

    if os.path.exists(cache_file):
        # Cache files are written locally by this module; never point them at untrusted pickles.
        with open(cache_file, "rb") as f:
            freq_maps = pickle.load(f)
        info(f"Loaded freq_maps for nside {nside} from cache with noise_ratio {noise_ratio}.")
    else:
        freq_maps = _simulate_observed_maps(instrument, nside, noise_ratio, sky, key)
        with open(cache_file, "wb") as f:
            pickle.dump(freq_maps, f)
        success(f"Generated and saved freq_maps for nside {nside}.")

    return frequencies, freq_maps
def load_from_cache(
    nside: int, noise_ratio: float = 0.0, instrument_name: str = "LiteBIRD", sky: str = "c1d0s0"
) -> tuple[Float[Array, " freqs"], Float[Array, " freqs 3 npix"]]:
    """Load cached frequency maps from disk.

    Args:
        nside: HEALPix resolution parameter.
        noise_ratio: Noise level ratio (0.0 = no noise, 1.0 = 100% noise). It is
            encoded in the cache file name as ``int(noise_ratio * 100)``.
        instrument_name: Instrument configuration name; only used for the
            returned frequencies, not for the cache file name.
        sky: Sky model tag. Defaults to "c1d0s0".

    Returns:
        ``(frequencies, freq_maps)`` loaded from cache.

    Raises:
        FileNotFoundError: If the cache file does not exist. The message names
            the missing file and the command that regenerates it.

    Example:
        >>> freqs, maps = load_from_cache(nside=64, noise_ratio=0.0, sky="c1d0s0")
    """
    instrument = get_instrument(instrument_name)
    noise_str = f"noise_{int(noise_ratio * 100)}" if noise_ratio > 0 else "no_noise"
    cache_file = os.path.join("freq_maps_cache", f"freq_maps_nside_{nside}_{noise_str}_{sky}.pkl")

    if not os.path.exists(cache_file):
        # Point at the exact sky/noise combination; `--nside` alone would only
        # regenerate the default skies.
        raise FileNotFoundError(
            f"Cache file for freq_maps with nside {nside}, noise_ratio {noise_ratio} "
            f"and sky {sky} not found ({cache_file}).\n"
            f"Please generate it first by calling "
            f"`generate_data --nside {nside} --noise-ratio {noise_ratio} --sky {sky}`."
        )

    # Cache files are written locally by this module; never point them at untrusted pickles.
    with open(cache_file, "rb") as f:
        freq_maps = pickle.load(f)
    info(f"Loaded freq_maps for nside {nside} from cache.")
    return np.array(instrument.frequency), freq_maps
def save_fg_map(
    nside: int,
    noise_ratio: float = 0.0,
    instrument_name: str = "LiteBIRD",
    sky: str = "c1d0s0",
    key: PRNGKeyArray | None = None,
) -> tuple[Float[Array, " freqs"], Float[Array, " freqs 3 npix"]]:
    """Generate and cache foreground-only frequency maps (CMB excluded).

    For regular tags, the CMB preset is stripped with ``parse_sky_tag`` and the
    remaining tag goes through ``save_to_cache``. Synthetic ``_synt_`` tags read
    the dedicated foreground cache, which ``save_to_cache`` builds noiselessly
    when it is missing. ``noise_ratio`` and ``key`` are not used on that path.

    Args:
        nside: HEALPix resolution parameter.
        noise_ratio: Noise level ratio (0.0 = no noise, 1.0 = 100% noise).
        instrument_name: Instrument configuration name. Defaults to "LiteBIRD".
        sky: Sky model tag. Its CMB part is removed. Defaults to "c1d0s0".
        key: JAX random key for noise generation. Defaults to None.

    Returns:
        ``(frequencies, freq_maps)`` containing only foreground contributions.

    Example:
        >>> freqs, fg_maps = save_fg_map(nside=64, sky="c1d1s1")
    """
    info(
        f"Generating fg map for nside {nside}, noise_ratio {noise_ratio}, instrument {instrument_name}"
    )

    if _parse_synt_tag(sky) is None:
        _, fg_only_sky = parse_sky_tag(sky)
        return save_to_cache(
            nside, noise_ratio=noise_ratio, instrument_name=instrument_name, sky=fg_only_sky, key=key
        )

    instrument = get_instrument(instrument_name)
    fg_cache = os.path.join("freq_maps_cache", f"freq_maps_nside_{nside}_fg_{sky}.pkl")
    if not os.path.exists(fg_cache):
        # Building the synthetic sky writes the fg cache as a by-product.
        save_to_cache(nside, noise_ratio=0.0, instrument_name=instrument_name, sky=sky)
    with open(fg_cache, "rb") as handle:
        fg_maps = pickle.load(handle)
    return np.array(instrument.frequency), fg_maps
def load_fg_map(
    nside: int, noise_ratio: float = 0.0, instrument_name: str = "LiteBIRD", sky: str = "c1d0s0"
) -> tuple[Float[Array, " freqs"], Float[Array, " freqs 3 npix"]]:
    """Load cached foreground-only frequency maps.

    Regular tags have their CMB preset stripped and are read via
    ``load_from_cache``. Synthetic ``_synt_`` tags are read from their dedicated
    foreground cache file.

    Args:
        nside: HEALPix resolution parameter.
        noise_ratio: Noise level ratio (0.0 = no noise, 1.0 = 100% noise).
        instrument_name: Instrument configuration name. Defaults to "LiteBIRD".
        sky: Sky model tag. Its CMB part is excluded. Defaults to "c1d0s0".

    Returns:
        ``(frequencies, freq_maps)`` containing only foreground contributions.

    Raises:
        FileNotFoundError: If the relevant cache file is missing.

    Example:
        >>> freqs, fg_maps = load_fg_map(nside=64, sky="c1d1s1")
    """
    if _parse_synt_tag(sky) is None:
        _, fg_only_sky = parse_sky_tag(sky)
        return load_from_cache(
            nside, noise_ratio=noise_ratio, instrument_name=instrument_name, sky=fg_only_sky
        )

    instrument = get_instrument(instrument_name)
    fg_cache = os.path.join("freq_maps_cache", f"freq_maps_nside_{nside}_fg_{sky}.pkl")
    if not os.path.exists(fg_cache):
        raise FileNotFoundError(
            f"FG cache not found for sky {sky}. "
            f"Run: generate_data --sky {sky} --nside {nside}"
        )
    frequencies = np.array(instrument.frequency)
    with open(fg_cache, "rb") as handle:
        fg_maps = pickle.load(handle)
    return frequencies, fg_maps
def save_cmb_map(nside: int, sky: str = "c1d0s0") -> Float[Array, "3 npix"]:
    """Generate and cache the CMB-only map of a sky tag.

    * Synthetic ``_synt_`` tags return their dedicated CMB cache, building it via
      ``save_to_cache`` if needed.
    * Tags without a CMB preset return zeros.
    * ``cr<N>`` tags are drawn with ``generate_custom_cmb`` (r = N * 0.001, seed 0).
    * Other CMB presets come from the first component of ``get_sky``.

    For the last two cases the map is always regenerated and written to
    ``freq_maps_nside_{nside}_{cmb_tag}.pkl``.

    Args:
        nside: HEALPix resolution parameter.
        sky: Sky model tag. Defaults to "c1d0s0".

    Returns:
        CMB map of shape (3, n_pix) for Stokes I, Q, U, or zeros if no CMB.

    Example:
        >>> cmb_map = save_cmb_map(nside=64, sky="c1d0s0")
    """
    info(f"Generating CMB map for nside {nside}, sky {sky}")
    cache_dir = "freq_maps_cache"
    os.makedirs(cache_dir, exist_ok=True)

    if _parse_synt_tag(sky) is not None:
        synt_cmb_cache = os.path.join(cache_dir, f"freq_maps_nside_{nside}_cmb_{sky}.pkl")
        if not os.path.exists(synt_cmb_cache):
            save_to_cache(nside, sky=sky)
        with open(synt_cmb_cache, "rb") as handle:
            return pickle.load(handle)

    cmb_tag, _ = parse_sky_tag(sky)
    if cmb_tag is None:
        return np.zeros((3, 12 * nside**2))

    custom_r = re.match(r"cr(\d+)", cmb_tag)
    if custom_r is None:
        # NOTE(review): assumes the first PySM component of `sky` is the CMB.
        cmb_map = get_sky(nside, sky).components[0].map.to_value()
    else:
        # seed=0 matches the seed save_to_cache derives from its default key.
        cmb_map = generate_custom_cmb(int(custom_r.group(1)) * 0.001, nside, seed=0)

    cache_file = os.path.join(cache_dir, f"freq_maps_nside_{nside}_{cmb_tag}.pkl")
    with open(cache_file, "wb") as handle:
        pickle.dump(cmb_map, handle)
    success(f"Generated and saved freq_maps for nside {nside} and for tag {cmb_tag}.")
    return cmb_map
def load_cmb_map(nside: int, sky: str = "c1d0s0") -> Float[Array, "3 npix"]:
    """Load the cached CMB-only map of a sky tag.

    Args:
        nside: HEALPix resolution parameter.
        sky: Sky model tag. Defaults to "c1d0s0".

    Returns:
        CMB map of shape (3, n_pix) for Stokes I, Q, U, or zeros if the tag has
        no CMB preset.

    Raises:
        FileNotFoundError: If the cache file does not exist. Both branches name
            the sky tag to pass to ``generate_data``.

    Example:
        >>> cmb_map = load_cmb_map(nside=64, sky="c1d0s0")
    """
    cache_dir = "freq_maps_cache"

    if _parse_synt_tag(sky) is not None:
        cmb_cache = os.path.join(cache_dir, f"freq_maps_nside_{nside}_cmb_{sky}.pkl")
        if not os.path.exists(cmb_cache):
            raise FileNotFoundError(
                f"CMB cache not found for sky {sky}. "
                f"Run: generate_data --sky {sky} --nside {nside}"
            )
        with open(cmb_cache, "rb") as f:
            return pickle.load(f)

    cmb_tag, _ = parse_sky_tag(sky)
    if cmb_tag is None:
        return np.zeros((3, 12 * nside**2))

    cache_file = os.path.join(cache_dir, f"freq_maps_nside_{nside}_{cmb_tag}.pkl")
    if not os.path.exists(cache_file):
        # Name the tag and sky so the suggested command regenerates this exact cache.
        raise FileNotFoundError(
            f"Cache file for CMB tag {cmb_tag} with nside {nside} not found ({cache_file}).\n"
            f"Please generate it first by calling `generate_data --sky {sky} --nside {nside}`."
        )
    with open(cache_file, "rb") as f:
        cmb_map = pickle.load(f)
    info(f"Loaded freq_maps for nside {nside} from cache.")
    return cmb_map
def get_mixin_matrix_operator(
    params: PyTree[Float[Array, " P"]],
    patch_indices: PyTree[Int[Array, " P"]],
    nu: Float[Array, " Nf"],
    sky: SkyType,
    dust_nu0: float,
    synchrotron_nu0: float,
) -> tuple[MixingMatrixOperator, MixingMatrixOperator]:
    """Build the mixing-matrix operators with and without the CMB column.

    The input structure is taken from the first entry of ``sky``. Dust and
    synchrotron use per-patch parameters selected by ``patch_indices``.

    Args:
        params: Spectral parameters with keys "temp_dust", "beta_dust", "beta_pl".
        patch_indices: Patch indices with keys "temp_dust_patches",
            "beta_dust_patches", "beta_pl_patches".
        nu: Frequency array in GHz.
        sky: Sky component dictionary.
        dust_nu0: Dust reference frequency in GHz.
        synchrotron_nu0: Synchrotron reference frequency in GHz.

    Returns:
        ``(A, A_nocmb)``: the CMB+dust+synchrotron operator and the
        dust+synchrotron operator.

    Example:
        >>> A, A_nocmb = get_mixin_matrix_operator(params, patches, nu, sky, 150.0, 20.0)
    """
    reference = next(iter(sky.values()))
    n_pix = reference.shape[-1]
    structure = reference.structure_for((n_pix,))

    cmb = CMBOperator(nu, in_structure=structure)
    dust = DustOperator(
        nu,
        frequency0=dust_nu0,
        temperature=params["temp_dust"],
        temperature_patch_indices=patch_indices["temp_dust_patches"],
        beta=params["beta_dust"],
        beta_patch_indices=patch_indices["beta_dust_patches"],
        in_structure=structure,
    )
    synchrotron = SynchrotronOperator(
        nu,
        frequency0=synchrotron_nu0,
        beta_pl=params["beta_pl"],
        beta_pl_patch_indices=patch_indices["beta_pl_patches"],
        in_structure=structure,
    )

    with_cmb = MixingMatrixOperator(cmb=cmb, dust=dust, synchrotron=synchrotron)
    without_cmb = MixingMatrixOperator(dust=dust, synchrotron=synchrotron)
    return with_cmb, without_cmb
def simulate_D_from_params(
    params: PyTree[Float[Array, " P"]],
    patch_indices: PyTree[Int[Array, " P"]],
    nu: Float[Array, " Nf"],
    sky: SkyType,
    dust_nu0: float,
    synchrotron_nu0: float,
) -> tuple[Stokes, Stokes]:
    """Apply the mixing operators to a sky, with and without its CMB component.

    Args:
        params: Spectral parameters (temp_dust, beta_dust, beta_pl).
        patch_indices: Patch assignment indices for each parameter.
        nu: Frequency array in GHz.
        sky: Sky component dictionary; must contain a "cmb" entry (KeyError otherwise).
        dust_nu0: Dust reference frequency in GHz.
        synchrotron_nu0: Synchrotron reference frequency in GHz.

    Returns:
        ``(d, d_nocmb)``, the simulated data including and excluding the CMB.

    Example:
        >>> d, d_nocmb = simulate_D_from_params(params, patches, nu, sky, 150.0, 20.0)
    """
    mixing, mixing_nocmb = get_mixin_matrix_operator(
        params, patch_indices, nu, sky, dust_nu0, synchrotron_nu0
    )
    d = mixing(sky)

    # Shallow copy so the caller's dict keeps its CMB entry.
    foregrounds = sky.copy()
    del foregrounds["cmb"]
    d_nocmb = mixing_nocmb(foregrounds)
    return d, d_nocmb
# Mask names accepted by get_mask:
# - "ALL": every pixel.
# - "GALACTIC": the complement of the GAL060 mask.
# - GAL020/GAL040/GAL060 zones, each with an upper (_U) and a lower (_L) hemisphere variant.
MASK_CHOICES = [
    "ALL",
    "GALACTIC",
    "GAL020_U",
    "GAL020_L",
    "GAL020",
    "GAL040_U",
    "GAL040_L",
    "GAL040",
    "GAL060_U",
    "GAL060_L",
    "GAL060",
]
def sanitize_mask_name(mask_expr: str) -> str:
    """Turn a mask expression into a string usable as a folder name.

    Every "+" (union) becomes "_UNION_" and every "-" (subtraction) becomes
    "_SUBTRACT_". All other characters are kept.

    Args:
        mask_expr: Mask expression, possibly containing + or - operators.

    Returns:
        The expression with its operators spelled out.

    Example:
        >>> sanitize_mask_name("GAL020+GAL040")
        'GAL020_UNION_GAL040'
        >>> sanitize_mask_name("ALL-GALACTIC")
        'ALL_SUBTRACT_GALACTIC'
    """
    operator_names = str.maketrans({"+": "_UNION_", "-": "_SUBTRACT_"})
    return mask_expr.translate(operator_names)
def _parse_mask_expression(expr: str, nside: int) -> Bool[Array, " npix"]:
    """Parse and evaluate a boolean mask expression.

    Supports left-to-right evaluation of + (union) and - (subtraction).
    Parentheses are not supported.

    Args:
        expr: Mask expression, e.g. "GAL020+GAL040", "ALL-GALACTIC",
            "GAL020+GAL040-GALACTIC". Whitespace around names is ignored.
        nside: HEALPix resolution parameter.

    Returns:
        Boolean mask array where True indicates observed pixels.

    Raises:
        ValueError: If the expression is empty, names an unknown mask, or does
            not alternate MASK / operator / MASK.

    Example:
        >>> mask = _parse_mask_expression("GAL020+GAL040", nside=64)
        >>> mask = _parse_mask_expression("ALL-GALACTIC", nside=64)
    """
    # Split on the operators while keeping them as tokens. Empty pieces (from
    # leading, trailing or doubled operators) are dropped; the rest is stripped.
    tokens = [piece.strip() for piece in re.split(r"([+-])", expr) if piece]
    if not tokens:
        raise ValueError(f"Empty mask expression: {expr}")

    def load_named_mask(mask_name: str):
        # Validate against MASK_CHOICES before loading.
        if mask_name not in MASK_CHOICES:
            raise ValueError(
                f"Invalid mask name '{mask_name}' in expression '{expr}'. "
                f"Choose from: {MASK_CHOICES}"
            )
        return get_mask(mask_name, nside)

    if len(tokens) == 1:
        return load_named_mask(tokens[0])

    # A well-formed expression alternates MASK op MASK ..., so its length is odd.
    if len(tokens) % 2 == 0:
        raise ValueError(
            f"Invalid expression syntax: {expr}. Expected format: MASK [+/-] MASK [+/-] MASK ..."
        )

    result = load_named_mask(tokens[0])
    for i in range(1, len(tokens), 2):
        operator = tokens[i]
        if operator not in ("+", "-"):
            raise ValueError(
                f"Expected operator (+ or -) at position {i} in expression '{expr}', "
                f"got '{operator}'"
            )
        next_mask = load_named_mask(tokens[i + 1])
        if operator == "+":
            result = np.logical_or(result, next_mask)
        else:
            result = np.logical_and(result, np.logical_not(next_mask))
    return result


def _get_or_generate_mask_file(nside: int) -> Path:
    """Get the path to the mask file, generating it from the 2048 source if needed.

    Degraded masks are written next to the source as
    ``GAL_PlanckMasks_{nside}.npz`` with uint8 values. When degrading, healpy's
    ``ud_grade`` averages child pixels, and ``astype(np.uint8)`` truncates, so
    only pixels fully inside a mask keep the value 1.

    Args:
        nside: HEALPix resolution parameter.

    Returns:
        Path to the mask file.
    """
    mask_dir = Path(__file__).parent / "masks"
    mask_file = mask_dir / f"GAL_PlanckMasks_{nside}.npz"
    if mask_file.exists():
        return mask_file

    source_file = mask_dir / "GAL_PlanckMasks_2048.npz"
    # Use the NpzFile as a context manager so the underlying file handle is closed.
    with np.load(source_file) as masks_2048:
        downgraded = {
            name: hp.ud_grade(masks_2048[name] * 1.0, nside).astype(np.uint8)
            for name in masks_2048.files
        }
    np.savez(mask_file, **downgraded)
    success(f"Generated and cached mask for nside {nside}")
    return mask_file
def get_mask(mask_name: str = "GAL020", nside: int = 64) -> Bool[Array, " npix"]:
    """Load and process galactic masks at the requested resolution.

    Args:
        mask_name: Mask identifier (e.g., "GAL020", "GAL040", "GALACTIC") or a
            boolean expression (e.g., "GAL020+GAL040", "ALL-GALACTIC").
            Defaults to "GAL020".
        nside: HEALPix resolution parameter. Defaults to 64.

    Returns:
        Boolean mask array (dtype bool for every choice) where True indicates
        observed pixels.

    Raises:
        ValueError: If mask_name is invalid.

    Notes:
        Available mask choices are ALL, GALACTIC, GAL020, GAL040 and GAL060,
        plus their _U (upper) and _L (lower) hemisphere variants.

        GAL040 and GAL060 are returned as disjoint rings: GAL040 excludes
        GAL020, and GAL060 excludes GAL040. GALACTIC is the complement of the
        stored GAL060 mask.

        Boolean operations are supported:
        - Use + for union (logical OR).
        - Use - for subtraction (logical AND NOT).
        - Expressions are evaluated left to right.
        - Examples: "GAL020+GAL040", "ALL-GALACTIC", "GAL020+GAL040-GALACTIC".

        Masks are generated and cached on the first call for each nside.

    Example:
        >>> mask = get_mask("GAL020", nside=64)
        >>> mask_union = get_mask("GAL020+GAL040", nside=64)
    """
    if "+" in mask_name or "-" in mask_name:
        return _parse_mask_expression(mask_name, nside)

    # Validate before touching (and possibly generating) mask files on disk.
    if mask_name not in MASK_CHOICES:
        raise ValueError(f"Invalid mask name: {mask_name}. Choose from: {MASK_CHOICES}")

    npix = 12 * nside**2
    ones = np.ones(npix, dtype=bool)
    if mask_name == "ALL":
        return ones

    # Stored masks may be uint8 (the generated files are). Cast to bool so that
    # every zone, including the raw GAL020, is a real boolean mask. A uint8 array
    # used as an index would do integer fancy-indexing instead of masking.
    with np.load(_get_or_generate_mask_file(nside)) as masks:
        mask_GAL020 = masks["GAL020"].astype(bool)
        mask_GAL040 = masks["GAL040"].astype(bool)
        mask_GAL060 = masks["GAL060"].astype(bool)

    if mask_name == "GALACTIC":
        return np.logical_and(ones, np.logical_not(mask_GAL060))

    # Disjoint rings: GAL060 minus GAL040 (full), and GAL040 minus GAL020.
    bands = {
        "GAL020": mask_GAL020,
        "GAL040": np.logical_and(mask_GAL040, np.logical_not(mask_GAL020)),
        "GAL060": np.logical_and(mask_GAL060, np.logical_not(mask_GAL040)),
    }
    band_name, _, half = mask_name.partition("_")
    band = bands[band_name]
    if not half:
        return band

    # Hemisphere split on the mask's own pixelization.
    # (Assuming theta < pi/2 corresponds to b > 0, i.e. the "upper" hemisphere.)
    mask_nside = hp.get_nside(mask_GAL020)
    theta, _ = hp.pix2ang(mask_nside, np.arange(hp.nside2npix(mask_nside)))
    upper = theta < np.pi / 2
    hemisphere = upper if half == "U" else np.logical_not(upper)
    return np.logical_and(band, hemisphere)
def generate_needed_maps(
    nside_list: list[int] | None = None,
    noise_ratio_list: list[float] | None = None,
    instrument_name: str = "LiteBIRD",
    sky_list: list[str] | None = None,
) -> None:
    """Batch-generate every cache needed for the given configurations.

    First pass: full frequency maps for every (nside, noise_ratio, sky)
    combination. Second pass: noiseless foreground-only maps and CMB-only maps
    for every (sky, nside) pair.

    Args:
        nside_list: HEALPix resolutions. Defaults to [4, 8, 32, 64, 128].
        noise_ratio_list: Noise ratio configurations. Defaults to [0.0, 1.0].
        instrument_name: Instrument configuration. Defaults to "LiteBIRD".
        sky_list: Sky model tags. Defaults to ["c1d0s0", "c1d1s1"].

    Example:
        >>> generate_needed_maps(nside_list=[64], noise_ratio_list=[0.0, 1.0])
    """
    nsides = [4, 8, 32, 64, 128] if nside_list is None else nside_list
    noise_ratios = [0.0, 1.0] if noise_ratio_list is None else noise_ratio_list
    skies = ["c1d0s0", "c1d1s1"] if sky_list is None else sky_list

    for resolution in nsides:
        for ratio in noise_ratios:
            for sky_tag in skies:
                save_to_cache(
                    resolution, noise_ratio=ratio, instrument_name=instrument_name, sky=sky_tag
                )

    for sky_tag in skies:
        for resolution in nsides:
            save_fg_map(resolution, noise_ratio=0.0, instrument_name=instrument_name, sky=sky_tag)
            save_cmb_map(resolution, sky=sky_tag)
def main():
    """Command-line entry point: parse the options and build the requested caches."""
    parser = argparse.ArgumentParser(
        description="Generate cached frequency maps for CMB component separation"
    )
    # Option table: flag -> argparse keyword arguments (insertion order = help order).
    cli_options = {
        "--nside": {
            "type": int,
            "nargs": "+",
            "default": [4, 8, 32, 64, 128],
            "help": "HEALPix resolution(s) to generate maps for (default: 4 8 32 64 128)",
        },
        "--noise-ratio": {
            "type": float,
            "nargs": "+",
            "default": [0.0, 1.0],
            "help": "Noise ratio level(s) to generate (0.0=no noise, 1.0=100%% noise, default: 0.0 1.0)",
        },
        "--instrument": {
            "type": str,
            "default": "LiteBIRD",
            "help": "Instrument name (default: LiteBIRD)",
        },
        "--sky": {
            "type": str,
            "nargs": "+",
            "default": ["c1d0s0", "c1d1s1"],
            "help": "Sky model tag(s) (default: c1d0s0 c1d1s1)",
        },
    }
    for flag, settings in cli_options.items():
        parser.add_argument(flag, **settings)

    cli_args = parser.parse_args()
    generate_needed_maps(
        nside_list=cli_args.nside,
        noise_ratio_list=cli_args.noise_ratio,
        instrument_name=cli_args.instrument,
        sky_list=cli_args.sky,
    )


if __name__ == "__main__":
    main()