Source code for furax_cs.multires_clusters

"""Multi-resolution clustering utilities using HEALPix ud_grade.

This module provides utilities for generating resolution-based patches
where pixels sharing the same low-resolution parent pixel are grouped
together. This is the approach used in the LiteBIRD PTEP methodology.
"""


import math

import jax
import jax.numpy as jnp
import jax_healpy as jhp
from jax_healpy.clustering import get_cutout_from_mask
from jaxtyping import Array, Float, Int


def _ud_grade_patches(
    ipix: Int[Array, " npix"], nside_in: int, nside_out: int
) -> Int[Array, " npix"]:
    """Label every pixel with the coarse-resolution patch it belongs to.

    The pixel-index map is degraded to ``nside_out`` with ``ud_grade`` and
    then brought back to ``nside_in``, so that the values are grouped by
    the pixels of the coarse resolution.

    Args:
        ipix: Full-sky pixel indices (0 to npix-1).
        nside_in: Resolution (nside) of the input map.
        nside_out: Resolution (nside) that defines the patches. A value of
            0 places every pixel in one single patch (all labels are 0).

    Returns:
        Patch labels at the original resolution, one per entry of ``ipix``.
    """
    # Guard clause: zero means "one global patch", so every label is 0.
    if nside_out == 0:
        return jnp.zeros_like(ipix)

    # The index map is handed to ud_grade as floating point; the labels are
    # cast back to integers once the round trip is complete.
    coarse_map = jhp.ud_grade(ipix.astype(jnp.float64), nside_out=nside_out)
    restored_map = jhp.ud_grade(coarse_map, nside_out=nside_in)
    return restored_map.astype(jnp.int64)


def _normalize_array(arr: Int[Array, " n"]) -> Int[Array, " n"]:
    """Relabel patch indices as consecutive integers starting at 0.

    Every element is replaced by the position of its value among the
    distinct values of ``arr``, as reported by the inverse mapping of
    ``jnp.unique``.
    """
    # Only the inverse mapping is needed; the unique values are discarded.
    # ``size=arr.size`` pins the length of the unique-values output (there
    # can never be more distinct values than elements).
    _, relabelled = jnp.unique(arr, return_inverse=True, size=arr.size)
    return relabelled.astype(jnp.int64)


def multires_clusters(
    mask: Float[Array, " npix"],
    indices: Int[Array, " n_valid"],
    target_ud_grade: dict[str, int],
    nside: int | None = None,
) -> dict[str, Int[Array, " n_valid"]]:
    """Generate multi-resolution cluster assignments using HEALPix ud_grade.

    This function creates resolution-based patches where pixels sharing the
    same low-resolution parent pixel are grouped together. This is the
    approach used in the LiteBIRD PTEP methodology.

    Args:
        mask: Full-sky HEALPix mask array (1 for valid pixels, 0 for masked).
            Only its length is read here, and only when ``nside`` is None.
        indices: Indices of unmasked pixels, typically from
            ``jnp.where(mask == 1)``.
        target_ud_grade: Dictionary mapping parameter names to target nside
            values. Expected keys: ``"beta_dust"``, ``"temp_dust"``,
            ``"beta_pl"``; each key ``k`` present produces an output entry
            ``"{k}_patches"``. Values are nside parameters (must be powers
            of 2). Use 0 for a single global patch.
        nside: Input map resolution. If None, inferred from mask size.

    Returns:
        Dictionary of normalized cluster assignments (cutout arrays).
        Keys are ``"{param}_patches"`` format. Values are integer arrays
        (cast to ``jnp.int64``) with labels renumbered from 0.

    Raises:
        ValueError: If ``nside`` is None and the mask length is not of the
            form ``12 * nside**2``.

    Example:
        >>> from furax_cs.data import get_mask
        >>> mask = get_mask("GAL020")
        >>> (indices,) = jnp.where(mask == 1)
        >>> target_resolutions = {
        ...     "beta_dust": 64,  # Full resolution
        ...     "temp_dust": 32,  # 4x fewer patches
        ...     "beta_pl": 16,  # 16x fewer patches
        ... }
        >>> clusters = multires_clusters(mask, indices, target_resolutions)
        >>> clusters["temp_dust_patches"].shape
        (n_unmasked_pixels,)
    """
    # Infer nside from the mask length when it is not given. Exact integer
    # arithmetic is used on purpose: a float square root followed by int()
    # would silently truncate for a mask of invalid length and produce an
    # index map that does not match the mask.
    if nside is None:
        npix = mask.shape[0]
        nside = math.isqrt(npix // 12)
        if 12 * nside * nside != npix:
            raise ValueError(
                f"Cannot infer nside: mask has {npix} pixels, "
                "which is not of the form 12 * nside**2."
            )

    # Full-sky pixel indices at the input resolution.
    ipix = jnp.arange(12 * nside**2)

    # One full-sky patch map per parameter, keyed as "<param>_patches".
    patch_indices = {
        f"{param_name}_patches": _ud_grade_patches(ipix, nside, int(target_nside))
        for param_name, target_nside in target_ud_grade.items()
    }

    # Extract cutout (only unmasked pixels).
    masked_patches = get_cutout_from_mask(patch_indices, indices)

    # Normalize indices for consistent indexing (0, 1, 2, ...).
    return jax.tree.map(_normalize_array, masked_patches)