FGBuster vs FURAX: Framework Comparison for CMB Component Separation


Learning Objectives

By the end of this notebook, you will:

  • Understand the differences between traditional (FGBuster) and modern (FURAX) component separation frameworks

  • See the performance advantages of JAX over NumPy for CMB analysis

  • Learn how to implement and benchmark likelihood functions

  • Understand automatic differentiation benefits for parameter optimization

Background

The Component Separation Problem

CMB observations contain several signal components plus noise:

  • CMB signal: What we want to measure

  • Galactic dust: Modified blackbody emission

  • Synchrotron: Power-law emission from cosmic-ray electrons spiraling in the Galactic magnetic field

  • Instrumental noise: Detector noise and systematic effects

The challenge is to separate these components accurately to recover the CMB signal.
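Per pixel, this is usually written as a linear data model \(d = A s + n\): the frequency maps \(d\) are a mixing matrix \(A\) (one column per component SED) applied to the component amplitudes \(s\), plus noise \(n\). The NumPy sketch below builds a toy version of this model. The channel list, the constant CMB column and the pure power-law foreground columns are simplifying assumptions, not the SEDs used later in the notebook.

```python
import numpy as np

# Hypothetical toy setup: 5 channels (GHz) and 4 pixels
nu = np.array([40.0, 100.0, 150.0, 280.0, 402.0])
npix = 4


def toy_mixing_matrix(nu, beta_dust=1.54, beta_pl=-3.0, nu_d=150.0, nu_s=20.0):
    """Columns: CMB (constant, as in K_CMB units), dust and synchrotron as pure power laws."""
    cmb = np.ones_like(nu)
    dust = (nu / nu_d) ** beta_dust  # simplification: no blackbody factor
    sync = (nu / nu_s) ** beta_pl
    return np.stack([cmb, dust, sync], axis=1)  # shape (n_freq, n_comp)


rng = np.random.default_rng(0)
A = toy_mixing_matrix(nu)
s = rng.normal(size=(3, npix))  # component amplitudes, one column per pixel
n = 0.01 * rng.normal(size=(nu.size, npix))  # white noise
d = A @ s + n  # observed maps, shape (n_freq, npix)

# If A were known, the amplitudes would follow from a least-squares solve (N proportional to I here)
s_hat = np.linalg.solve(A.T @ A, A.T @ d)
print(d.shape, s_hat.shape)  # (5, 4) (3, 4)
```

In real component separation the spectral parameters inside \(A\) (such as beta_dust and beta_pl) are unknown too, and constraining them is the job of the likelihood defined later.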

[ ]:
!pip install -q furax-cs
!pip install --force-reinstall -r https://raw.githubusercontent.com/CMBSciPol/furax-cs/main/requirements.txt
!pip install git+https://github.com/fgbuster/fgbuster
[ ]:
# Standard library
from functools import partial

# furax-cs helpers (map caching and minimization) used below
import furax_cs as fcs

# JAX for high-performance computing
import jax
import jax.numpy as jnp
import jaxopt

# FGBuster - Traditional component separation framework
from fgbuster import (
    CMB,
    Dust,
    Synchrotron,
    basic_comp_sep,
    get_instrument,
)

# FURAX - Modern JAX-based framework
from furax import HomothetyOperator, tree
from furax.obs.landscapes import Stokes
from furax.obs.operators import (
    CMBOperator,
    DustOperator,
    MixingMatrixOperator,
    SynchrotronOperator,
)

# Enable 64-bit precision (set this before any JAX arrays are created)
jax.config.update("jax_enable_x64", True)

Step 1: Generate Simulated Sky Maps

We start by creating realistic CMB and foreground simulations using PySM (Python Sky Model). These simulated observations will serve as our test data for comparing the two frameworks.

Key Parameters:

  • NSIDE = 64: HEALPix resolution

  • Instrument: LiteBIRD frequency configuration (15 bands: 40-402 GHz)

  • Components: CMB + dust + synchrotron emission

  • Stokes parameters: I (total intensity), Q and U (linear polarization)

Why Use Simulations?

  1. Ground truth: We know the input parameters

  2. Controlled testing: Compare framework accuracy

  3. Reproducibility: Same data for fair comparison

Note: On HPC clusters without internet access, these maps are pre-cached using generate_maps.py

[2]:
nsides = [64]
for nside in nsides:
    fcs.save_to_cache(nside)
[INFO] Loaded freq_maps for nside 64 from cache with noise_ratio 0.0.
[3]:
nside = 64

nu, freq_maps = fcs.load_from_cache(nside)
# Check the shape of freq_maps
print("freq_maps shape:", freq_maps.shape)
[INFO] Loaded freq_maps for nside 64 from cache.
freq_maps shape: (15, 3, 49152)

Step 2: Convert Data to FURAX Format

FURAX uses a structured data format called Stokes that organizes polarization data (I, Q, U) in a JAX-compatible format. This conversion step is essential for interfacing with FURAX’s operator-based architecture.

Understanding the Data Structure

The frequency maps have shape (15, 3, 49152) where:

  • 15: Number of frequency channels (LiteBIRD bands)

  • 3: Stokes parameters (I intensity; Q, U linear polarization)

  • 49152: HEALPix pixels (12 × 64² for NSIDE=64)
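The pixel count follows from the HEALPix rule \(N_{pix} = 12\,N_{side}^2\). The short sketch below checks that arithmetic and shows how slicing the Stokes axis yields one (15, npix) array per Stokes parameter; freq_maps here is a zero-filled placeholder with the same layout, not the cached data.

```python
import numpy as np

nside = 64
npix = 12 * nside**2  # HEALPix: 12 base pixels, each split into nside**2 sub-pixels

# Placeholder array with the same (frequency, Stokes, pixel) layout as freq_maps
freq_maps = np.zeros((15, 3, npix))

# Slicing the Stokes axis gives one (15, npix) array per Stokes parameter,
# which is the per-field shape FURAX's Stokes container reports below
i_maps, q_maps, u_maps = (freq_maps[:, k, :] for k in range(3))
print(npix, i_maps.shape)  # 49152 (15, 49152)
```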

Note: Although FURAX includes its own functions to create sky maps from PySM, we use FGBuster here to ensure that both methods receive identical inputs for comparison.

[4]:
d = Stokes.from_stokes(I=freq_maps[:, 0, :], Q=freq_maps[:, 1, :], U=freq_maps[:, 2, :])
d.structure
[4]:
StokesIQU(i=ShapeDtypeStruct(shape=(15, 49152), dtype=float64), q=ShapeDtypeStruct(shape=(15, 49152), dtype=float64), u=ShapeDtypeStruct(shape=(15, 49152), dtype=float64))
[5]:
# Component Initialization and Reference Frequencies

# Reference frequencies for component models (in GHz)
dust_nu0 = 150.0  # Dust reference frequency (GHz) at which the dust amplitude is defined
synchrotron_nu0 = 20.0  # Synchrotron reference frequency (GHz), at the low end where synchrotron dominates

# Get LiteBIRD instrument specification
instrument = get_instrument("LiteBIRD")

# Define the astrophysical components for separation
# Each component has its own spectral energy distribution (SED)
components = [
    CMB(),  # Blackbody at 2.725K (no free parameters)
    Dust(dust_nu0),  # Modified blackbody (free parameters: beta_d, temp)
    Synchrotron(synchrotron_nu0),  # Power law (free parameter: beta_pl)
]
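The spectral laws named in the comments above can be written down explicitly. The sketch below evaluates them in Rayleigh-Jeans brightness units, normalized to 1 at each reference frequency: a modified blackbody \((\nu/\nu_0)^{\beta_d+1}\,(e^{x_0}-1)/(e^{x}-1)\) with \(x = h\nu/kT\) for dust, and a power law \((\nu/\nu_0)^{\beta_{pl}}\) for synchrotron. It is an illustration only: FGBuster and FURAX also convert to CMB thermodynamic units (K_CMB), a factor this sketch omits, and their exact conventions may differ.

```python
import numpy as np

H_OVER_K = 0.04799243  # h / k_B in K per GHz (h = 6.62607015e-34 J s, k_B = 1.380649e-23 J/K)


def dust_mbb_rj(nu, beta_d=1.54, temp=20.0, nu0=150.0):
    """Modified blackbody in Rayleigh-Jeans units, normalized to 1 at nu0 (GHz)."""
    x, x0 = H_OVER_K * nu / temp, H_OVER_K * nu0 / temp
    return (nu / nu0) ** (beta_d + 1.0) * np.expm1(x0) / np.expm1(x)


def synchrotron_rj(nu, beta_pl=-3.0, nu0=20.0):
    """Power law in Rayleigh-Jeans units, normalized to 1 at nu0 (GHz)."""
    return (nu / nu0) ** beta_pl


nu = np.array([40.0, 150.0, 402.0])  # lowest, reference and highest LiteBIRD-like frequencies
dust_sed = dust_mbb_rj(nu)       # rises with frequency over this range
sync_sed = synchrotron_rj(nu)    # falls steeply with frequency
print(dust_sed, sync_sed)
```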

Defining the Likelihood Function for Component Separation

In this cell, we define the negative_log_prob function, which calculates the negative log-likelihood of observing the given data d based on the model parameters.

The likelihood function is based on a quadratic form that includes the mixing matrix A, inverse noise covariance N^{-1}, and observed data d. The key term in the likelihood is:

\[\left(A^T N^{-1} d\right)^T \left(A^T N^{-1} A\right)^{-1} \left(A^T N^{-1} d\right)\]

Explanation of Each Term

  1. \(A\): The mixing matrix operator, which maps the component space to the observed frequency space.

  2. \(N^{-1}\): The inverse of the noise covariance matrix, represented by invN in the code.

  3. \(d\): The observed data, stored as a FURAX Stokes object.

Implementation Details

  • Transposing and applying A: A.T(d) applies the transpose of A to d, equivalent to the term \(A^T d\).

  • Computing the likelihood: The code forms \(x = A^T N^{-1} d\), applies the inverse of \(A^T N^{-1} A\) to \(x\) through the operator's .I attribute, and takes the dot product of the result with \(x\).

  • Negative log-likelihood: negative_log_prob returns the negative of this quadratic form. This differs from the negative log-likelihood (with the component amplitudes profiled out) only by a factor of 2 and parameter-independent terms such as \(d^T N^{-1} d\), so it can be minimized directly as a loss function.
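The step from the usual \(\chi^2\) to the quadratic form can be checked numerically. Plugging the best-fit amplitudes \(\hat s = (A^T N^{-1} A)^{-1} A^T N^{-1} d\) into \(\chi^2(s) = (d - As)^T N^{-1} (d - As)\) gives \(\chi^2(\hat s) = d^T N^{-1} d - (A^T N^{-1} d)^T (A^T N^{-1} A)^{-1} (A^T N^{-1} d)\), so minimizing the negative quadratic form over the spectral parameters is the same as minimizing \(\chi^2\) with the amplitudes profiled out. The NumPy sketch below checks this identity for a single pixel with a random, hypothetical \(A\) and a diagonal \(N^{-1}\).

```python
import numpy as np

rng = np.random.default_rng(1)
n_freq, n_comp = 6, 3
A = rng.normal(size=(n_freq, n_comp))            # arbitrary full-rank mixing matrix
N_inv = np.diag(rng.uniform(0.5, 2.0, n_freq))   # diagonal inverse noise covariance
d = rng.normal(size=n_freq)                      # one pixel of data

x = A.T @ N_inv @ d          # A^T N^-1 d
M = A.T @ N_inv @ A          # A^T N^-1 A
s_hat = np.linalg.solve(M, x)  # best-fit amplitudes, via a solve rather than an explicit inverse
quad = x @ s_hat             # (A^T N^-1 d)^T (A^T N^-1 A)^-1 (A^T N^-1 d)

r = d - A @ s_hat
chi2_min = r @ N_inv @ r     # chi^2 at the best-fit amplitudes
dNd = d @ N_inv @ d          # parameter-independent term

# Identity: chi2_min = d^T N^-1 d - quad
print(chi2_min, dNd - quad)
```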

[11]:
# Initialize FURAX operators and noise model
invN = HomothetyOperator(
    jnp.ones(1), in_structure=d.structure
)  # Identity noise (uniform weighting)
DND = invN(d) @ d  # d^T N^-1 d: parameter-independent, so not needed in the loss below

# Define input structure for component operators
in_structure = d.structure_for((d.shape[1],))  # Structure of one component map (npix pixels per Stokes field)

# True parameter values from simulation (our target for optimization)
best_params = {"temp_dust": 20.0, "beta_dust": 1.54, "beta_pl": -3.0}

# Reference frequencies (same as component initialization)
dust_nu0 = 150.0
synchrotron_nu0 = 20.0


@jax.jit  # JIT compilation for performance
def negative_log_prob(params, d):
    """
    Compute negative log-likelihood for component separation.

    This function implements the standard CMB likelihood:
    -ln(L) = (d - As)^T N^-1 (d - As) / 2 + const

    Substituting the best-fit amplitudes s_hat = (A^T N^-1 A)^-1 A^T N^-1 d
    (i.e. profiling over s) gives
    -ln(L) = [d^T N^-1 d - (A^T N^-1 d)^T (A^T N^-1 A)^-1 (A^T N^-1 d)] / 2 + const,
    so up to constants we minimize the negative of the quadratic form.
    """
    # Create component operators with current parameter values
    cmb = CMBOperator(nu, in_structure=in_structure)
    dust = DustOperator(
        nu,
        frequency0=dust_nu0,
        temperature=params["temp_dust"],
        beta=params["beta_dust"],
        in_structure=in_structure,
    )
    synchrotron = SynchrotronOperator(
        nu,
        frequency0=synchrotron_nu0,
        beta_pl=params["beta_pl"],
        in_structure=in_structure,
    )

    # Construct mixing matrix A(β) from component operators
    A = MixingMatrixOperator(cmb=cmb, dust=dust, synchrotron=synchrotron)

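    # Compute the quadratic form with the amplitudes profiled out:
    # L = (A^T N^-1 d)^T (A^T N^-1 A)^-1 (A^T N^-1 d)
    x = (A.T @ invN)(d)  # A^T N^-1 d
    s = (A.T @ invN @ A).I(x)  # best-fit amplitudes (A^T N^-1 A)^-1 A^T N^-1 d
    L = tree.dot(x, s)  # dot product over the whole Stokes pytree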

    return -L  # Return negative for minimization

Performance Analysis: Likelihood and Gradient Evaluation

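Next, we time one likelihood evaluation and one gradient evaluation of the FURAX implementation. JAX dispatches work asynchronously, so each timed call ends with block_until_ready; otherwise %timeit would measure only the dispatch. The first (warm-up) call triggers compilation and is excluded from the timing. Absolute numbers depend on the hardware.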
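Outside IPython, the same warm-up-then-block pattern can be written as a small helper. The sketch below is illustrative: benchmark and toy_loss are hypothetical names, and the toy loss stands in for negative_log_prob. It uses jax.block_until_ready, which waits on every leaf of a pytree, so it also covers gradients returned as dictionaries.

```python
import time

import jax
import jax.numpy as jnp


def benchmark(fn, *args, n_runs=20):
    """Mean wall-clock seconds per call, excluding the first (compiling) call."""
    jax.block_until_ready(fn(*args))  # warm-up: triggers JIT compilation
    start = time.perf_counter()
    for _ in range(n_runs):
        # JAX dispatches asynchronously; wait until every output leaf is computed
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / n_runs


@jax.jit
def toy_loss(params):
    # Hypothetical stand-in for a likelihood: params is a dict of arrays
    return jnp.sum(jnp.sin(params["a"]) ** 2 + params["b"] * params["a"])


params = {"a": jnp.linspace(0.0, 1.0, 100_000), "b": jnp.array(0.5)}
t_value = benchmark(toy_loss, params)
t_grad = benchmark(jax.jit(jax.grad(toy_loss)), params)  # gradient is a dict pytree
print(f"value: {t_value * 1e3:.3f} ms, gradient: {t_grad * 1e3:.3f} ms")
```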

[12]:
negative_log_prob(best_params, d)
[12]:
Array(-2.7526803e+13, dtype=float64)
[13]:
# Evaluate likelihood at true parameters (should be close to optimum)
likelihood_value = negative_log_prob(best_params, d)
print(f"Likelihood at true parameters: {likelihood_value}")

# Performance benchmarking with proper JAX timing (block_until_ready)
print("Performance of the negative log-likelihood evaluation:")
negative_log_prob(best_params, d).block_until_ready()  # Warm-up for JIT
%timeit negative_log_prob(best_params, d).block_until_ready()

print("Performance of the gradient evaluation:")
# JAX automatic differentiation - no manual gradient coding required!
grad_func = jax.grad(negative_log_prob)
grad_func(best_params, d)["beta_pl"].block_until_ready()  # Warm-up
%timeit grad_func(best_params, d)['beta_pl'].block_until_ready()
Likelihood at true parameters: -27526803021880.71
Performance of the negative log-likelihood evaluation:
9.04 ms ± 138 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Performance of the gradient evaluation:
20.1 ms ± 363 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Check for Correctness

In this cell, we perform a basic correctness check by comparing the gradients of the negative log-likelihood at two sets of parameters:

  1. Wrong Parameters: best_params shifted by a random offset. The same PRNG key is used for every leaf, so the same offset is added to all three parameters.

  2. Correct Parameters: The original best_params.

We reduce each gradient pytree to a single number with jax.tree.reduce(max, ...), i.e. its largest signed component, and compare the two. A gradient many orders of magnitude smaller at the correct parameters indicates that they lie at, or very near, a minimum of the loss; relative to a loss of order \(10^{13}\), a gradient component of order 1 is effectively zero.
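The same check can be illustrated on a toy problem where the answer is known exactly. In this sketch, toy_loss is a hypothetical quadratic bowl with its minimum at the notebook's best_params values. Taking the absolute value of each gradient leaf before reducing gives a true magnitude, whereas a plain max over signed components can be negative.

```python
import jax
import jax.numpy as jnp

true_params = {"temp_dust": 20.0, "beta_dust": 1.54, "beta_pl": -3.0}


def toy_loss(params):
    # Hypothetical loss: quadratic bowl with its minimum at true_params
    return sum((params[k] - true_params[k]) ** 2 for k in true_params)


def max_abs_grad(params):
    """Largest absolute gradient component over all leaves of the params pytree."""
    grads = jax.grad(toy_loss)(params)
    return max(float(jnp.abs(g)) for g in jax.tree_util.tree_leaves(grads))


perturbed = {k: v + 0.5 for k, v in true_params.items()}
print(max_abs_grad(true_params))  # 0.0 at the minimum
print(max_abs_grad(perturbed))    # about 1.0 (= 2 * 0.5) away from it
```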

[14]:
# Gradient-based correctness validation
# Generate perturbed parameters (intentionally wrong values)
wrong_params = jax.tree.map(lambda x: x + jax.random.normal(jax.random.PRNGKey(0)), best_params)

# Reduce each gradient pytree to its largest (signed) component
grad_wrong = jax.tree.reduce(max, jax.grad(negative_log_prob)(wrong_params, d))
grad_correct = jax.tree.reduce(max, jax.grad(negative_log_prob)(best_params, d))

print(f"Largest gradient component at wrong parameters: {grad_wrong:.2e}")
print(f"Largest gradient component at correct parameters: {grad_correct:.2e}")
print(f"Ratio (wrong/correct): {abs(grad_wrong / grad_correct):.1e}")

# A much smaller gradient at the true parameters is consistent with a correct implementation
Largest gradient component at wrong parameters: -5.01e+08
Largest gradient component at correct parameters: 7.72e-01
Ratio (wrong/correct): 6.5e+08

Using FURAX’s Built-in Likelihood Functions

Up to this point, we have implemented our own custom likelihood function to understand the underlying mathematics. However, FURAX provides optimized, pre-built likelihood functions that are more robust and efficient for production use.

Why Use Built-in Functions?

  1. Performance: they are built to be compiled and differentiated with JAX transformations (jit, grad), and can optionally use an analytical gradient (analytical_gradient=True)

  2. Robustness: They include proper error handling and numerical stability checks

  3. Consistency: Standardized interface across different component separation methods

  4. Maintenance: Less code to maintain and debug

Comparing Custom vs Built-in Implementation

Here we check that our custom implementation agrees with FURAX's built-in negative_log_likelihood function to within a relative tolerance of 1e-15. This validation confirms our understanding of the likelihood and lets us rely on the built-in function for the rest of this notebook.

[17]:
from furax.obs import negative_log_likelihood

negative_log_likelihood = partial(
    negative_log_likelihood, dust_nu0=dust_nu0, synchrotron_nu0=synchrotron_nu0
)

L = negative_log_likelihood(best_params, nu=nu, N=invN, d=d)

assert jax.tree.all(
    jax.tree.map(lambda x, y: jnp.isclose(x, y, rtol=1e-15), L, negative_log_prob(best_params, d))
)
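This cell rebinds negative_log_likelihood to a functools.partial with the reference frequencies fixed, and later cells wrap that object in another partial that repeats the same keywords and adds analytical_gradient=True. That is safe because nested partials merge their keyword arguments, with keywords supplied later overriding earlier ones. The standard-library sketch below shows this behaviour with a hypothetical stand-in function nll.

```python
from functools import partial


def nll(params, *, dust_nu0, synchrotron_nu0, analytical_gradient=False):
    """Hypothetical stand-in that just reports the keywords it received."""
    return dust_nu0, synchrotron_nu0, analytical_gradient


nll_fixed = partial(nll, dust_nu0=150.0, synchrotron_nu0=20.0)
nll_grad = partial(nll_fixed, dust_nu0=150.0, synchrotron_nu0=20.0, analytical_gradient=True)
nll_other = partial(nll_fixed, dust_nu0=353.0)  # a later keyword overrides the earlier one

print(nll_fixed({}))  # (150.0, 20.0, False)
print(nll_grad({}))   # (150.0, 20.0, True)
print(nll_other({}))  # (353.0, 20.0, False)
```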

Framework Validation: Comparing FURAX and FGBuster

In this section, we validate our FURAX implementation by comparing it to FGBuster’s well-established component separation algorithms. This cross-validation is crucial for ensuring that our modern JAX-based approach produces results consistent with the traditional NumPy-based methods.

The Importance of Cross-Framework Validation

  1. Scientific Rigor: Ensures our new methods don’t introduce systematic biases

  2. Trust Building: Demonstrates that modern tools preserve scientific accuracy

  3. Method Verification: Confirms our implementation of component separation is correct

  4. Performance Baseline: Establishes a reference for speed and accuracy comparisons

Test Cases: From Simple to Complex

We test both frameworks under increasingly challenging conditions:

  • Case 1: Optimal starting parameters (convergence test)

  • Case 2: Single incorrect parameter (robustness test)

  • Case 3: Multiple incorrect parameters (recovery test)

  • Case 4: All parameters wrong (optimization challenge)

Case 1: Initial Validation Using best_params as the Starting Point

We begin the validation process by setting best_params as the initial point for both our custom implementation and FGBuster’s c1d0s0 model. This allows us to directly compare the outputs and confirm that the models produce similar results when initialized with the true parameters.

[18]:
components[1]._set_default_of_free_symbols(beta_d=1.54, temp=20.0)
components[2]._set_default_of_free_symbols(beta_pl=-3.0)

result = basic_comp_sep(components, instrument, freq_maps)
print(result.params)
print(result.x)
['Dust.beta_d', 'Dust.temp', 'Synchrotron.beta_pl']
[ 1.54 20.   -3.  ]
[19]:
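options = {"disp": True}
# TNC limits work via `maxfun`; SciPy ignores `maxiter` here (hence the OptimizeWarning)
scipy_solver = jaxopt.ScipyBoundedMinimize(
    fun=negative_log_likelihood, method="TNC", jit=True, tol=1e-10, maxiter=1000, options=options
)
# (lower, upper) bounds as dicts with the same keys as params, so each bound matches its parameter
lower = {"temp_dust": 10.0, "beta_dust": 0.5, "beta_pl": -5.0}
upper = {"temp_dust": 40.0, "beta_dust": 3.0, "beta_pl": -1.0}
bounds = (lower, upper)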
result = scipy_solver.run(best_params, bounds, nu=nu, N=invN, d=d)
/home/wassim/micromamba/envs/fg/lib/python3.11/site-packages/jaxopt/_src/scipy_wrappers.py:341: OptimizeWarning: Unknown solver options: maxiter
  res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),
  NIT   NF   F                       GTG
    0    1 -2.752362190741737E+13   1.88746399E+20
tnc: fscale = 5.94297e-10
    1    2 -2.752362190741737E+13   1.88687369E+20
    2    4 -2.752484506360604E+13   0.00000000E+00
tnc: |pg| = 0 -> local minimum
    2    4 -2.752484506360604E+13   0.00000000E+00
tnc: Local minima reach (|pg| ~= 0)
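jaxopt's documentation describes the bounds of ScipyBoundedMinimize as a (lower, upper) pair of pytrees with the same structure as the initial parameters. Since params is a dict, writing the bounds as dicts with the same keys keeps each bound attached to its parameter. The sketch below shows why this matters: JAX flattens dicts in sorted-key order, so positional tuples of bounds may end up paired with (beta_dust, beta_pl, temp_dust) rather than with the order in which the keys were written. The clipping step only illustrates leaf-by-leaf matching; it is not what jaxopt does internally.

```python
import jax
import jax.numpy as jnp

params = {"temp_dust": 20.0, "beta_dust": 1.54, "beta_pl": -3.0}

# Dicts flatten in sorted-key order, not in insertion order
print(jax.tree_util.tree_leaves(params))  # [1.54, -3.0, 20.0]

# Bounds with the same keys as params line up leaf by leaf
lower = {"temp_dust": 10.0, "beta_dust": 0.5, "beta_pl": -5.0}
upper = {"temp_dust": 40.0, "beta_dust": 3.0, "beta_pl": -1.0}

# Illustration only: project a starting point into the box, leaf by leaf
start = {"temp_dust": 25.0, "beta_dust": 2.54, "beta_pl": -6.0}
clipped = jax.tree_util.tree_map(jnp.clip, start, lower, upper)
print({k: float(v) for k, v in clipped.items()})
```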

Case 2: Validation with One Incorrect Parameter (beta_dust Set to a Wrong Value)

[20]:
components[1]._set_default_of_free_symbols(beta_d=2.54, temp=20.0)
components[2]._set_default_of_free_symbols(beta_pl=-3.0)

result = basic_comp_sep(components, instrument, freq_maps)
print(result.params)
print(result.x)
<lambdifygenerated-9>:2: RuntimeWarning: overflow encountered in power
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
<lambdifygenerated-9>:2: RuntimeWarning: overflow encountered in power
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
<lambdifygenerated-10>:2: RuntimeWarning: overflow encountered in power
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)*log(0.05*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
SVD of A failed -> logL = -inf
SVD of A failed -> logL_dB not updated
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
<lambdifygenerated-9>:2: RuntimeWarning: overflow encountered in multiply
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
<lambdifygenerated-9>:2: RuntimeWarning: overflow encountered in power
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
['Dust.beta_d', 'Dust.temp', 'Synchrotron.beta_pl']
[ 1.53195651 19.97377884 -2.94289685]
[21]:
params = {"temp_dust": 20.0, "beta_dust": 2.54, "beta_pl": -3.0}

result = scipy_solver.run(params, bounds, nu=nu, N=invN, d=d)
result.params
/home/wassim/micromamba/envs/fg/lib/python3.11/site-packages/jaxopt/_src/scipy_wrappers.py:341: OptimizeWarning: Unknown solver options: maxiter
  res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),
  NIT   NF   F                       GTG
    0    1 -2.752362190741737E+13   1.88746399E+20
tnc: fscale = 5.94297e-10
    1    2 -2.752362190741737E+13   1.88687369E+20
    2    4 -2.752484506360604E+13   0.00000000E+00
tnc: |pg| = 0 -> local minimum
    2    4 -2.752484506360604E+13   0.00000000E+00
tnc: Local minima reach (|pg| ~= 0)
[21]:
{'beta_dust': Array(0.5, dtype=float64),
 'beta_pl': Array(10., dtype=float64),
 'temp_dust': Array(-1., dtype=float64)}

Case 3: Setting beta_dust and beta_pl to Incorrect Values

[22]:
components[1]._set_default_of_free_symbols(beta_d=2.54, temp=20.0)
components[2]._set_default_of_free_symbols(beta_pl=-6.0)

result = basic_comp_sep(components, instrument, freq_maps)
print(result.params)
print(result.x)
SVD of A failed -> logL = -inf
SVD of A failed -> logL_dB not updated
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
SVD of A failed -> logL = -inf
<lambdifygenerated-9>:2: RuntimeWarning: overflow encountered in power
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
<lambdifygenerated-10>:2: RuntimeWarning: overflow encountered in power
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)*log(0.05*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
SVD of A failed -> logL = -inf
<lambdifygenerated-9>:2: RuntimeWarning: overflow encountered in multiply
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
<lambdifygenerated-9>:2: RuntimeWarning: overflow encountered in power
  return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)
SVD of A failed -> logL = -inf
['Dust.beta_d', 'Dust.temp', 'Synchrotron.beta_pl']
[ 1.53034371 19.9741819  -5.99474719]
[23]:
params = {"temp_dust": 20.0, "beta_dust": 2.54, "beta_pl": -6.0}

result = scipy_solver.run(params, bounds, nu=nu, N=invN, d=d)
result.params
/home/wassim/micromamba/envs/fg/lib/python3.11/site-packages/jaxopt/_src/scipy_wrappers.py:341: OptimizeWarning: Unknown solver options: maxiter
  res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),
  NIT   NF   F                       GTG
    0    1 -2.752362190741737E+13   1.88746399E+20
tnc: fscale = 5.94297e-10
    1    2 -2.752362190741737E+13   1.88687369E+20
    2    4 -2.752484506360604E+13   0.00000000E+00
tnc: |pg| = 0 -> local minimum
    2    4 -2.752484506360604E+13   0.00000000E+00
tnc: Local minima reach (|pg| ~= 0)
[23]:
{'beta_dust': Array(0.5, dtype=float64),
 'beta_pl': Array(10., dtype=float64),
 'temp_dust': Array(-1., dtype=float64)}

Case 4: Setting All Parameters to Incorrect Values

[24]:
components[1]._set_default_of_free_symbols(beta_d=2.54, temp=25.0)
components[2]._set_default_of_free_symbols(beta_pl=-6.0)

result = basic_comp_sep(components, instrument, freq_maps)
print(result.params)
print(result.x)
['Dust.beta_d', 'Dust.temp', 'Synchrotron.beta_pl']
[ 1.53999883 20.00004072 -2.99997694]

Advanced Optimization with JAX: Using Optax

The JAX ecosystem provides gradient-based optimizers through Optax, which consume gradients computed by automatic differentiation. Below, solvers are selected by name through fcs.minimize, which also takes lower and upper bounds on the parameters. This is one of the main practical advantages of a JAX-based framework such as FURAX over traditional approaches.

Why JAX-based Optimization?

  1. Automatic Differentiation: No need for manual gradient computation

  2. GPU Acceleration: Seamless GPU support for large-scale problems

  3. Modern Algorithms: Access to latest optimization methods (Adam, L-BFGS, etc.)

  4. Composability: Easy to combine different optimization strategies

Be sure to use the analytical gradient of the likelihood (analytical_gradient=True below): differentiating through the Lineax linear solve with JAX's automatic differentiation can apply the operator to non-physical cotangents, which leads to NaNs.
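For the spectral likelihood, the analytical gradient has a compact closed form. Writing \(f(\beta) = (A^T N^{-1} d)^T (A^T N^{-1} A)^{-1} (A^T N^{-1} d)\) and \(\hat s = (A^T N^{-1} A)^{-1} A^T N^{-1} d\), differentiating gives \(\partial f / \partial \beta = 2\,(\partial_\beta A\,\hat s)^T N^{-1} (d - A \hat s)\), which needs one linear solve and no differentiation through the solver. The NumPy sketch below checks this identity against a central finite difference on a hypothetical two-component toy model; it illustrates the math and is not FURAX's implementation.

```python
import numpy as np

nu = np.array([40.0, 100.0, 150.0, 280.0, 402.0])  # GHz, toy channels
N_inv = np.eye(nu.size)  # white noise with unit variance (hypothetical)


def mixing(beta):
    """Toy A(beta): a constant CMB column and a dust power law with index beta."""
    return np.stack([np.ones_like(nu), (nu / 150.0) ** beta], axis=1)


def dmixing(beta):
    """dA/dbeta: only the dust column depends on beta."""
    return np.stack([np.zeros_like(nu), np.log(nu / 150.0) * (nu / 150.0) ** beta], axis=1)


rng = np.random.default_rng(2)
d = mixing(1.54) @ rng.normal(size=2) + 0.01 * rng.normal(size=nu.size)


def quad_form(beta):
    """f(beta) = (A^T N^-1 d)^T (A^T N^-1 A)^-1 (A^T N^-1 d)."""
    A = mixing(beta)
    x = A.T @ N_inv @ d
    return x @ np.linalg.solve(A.T @ N_inv @ A, x)


def quad_form_grad(beta):
    """Closed-form derivative 2 (dA s_hat)^T N^-1 (d - A s_hat)."""
    A, dA = mixing(beta), dmixing(beta)
    s_hat = np.linalg.solve(A.T @ N_inv @ A, A.T @ N_inv @ d)  # best-fit amplitudes
    return 2.0 * (dA @ s_hat) @ N_inv @ (d - A @ s_hat)


beta, h = 1.3, 1e-6
finite_diff = (quad_form(beta + h) - quad_form(beta - h)) / (2 * h)
print(quad_form_grad(beta), finite_diff)
```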

[25]:
negative_log_likelihood_fn = partial(
    negative_log_likelihood,
    dust_nu0=dust_nu0,
    synchrotron_nu0=synchrotron_nu0,
    analytical_gradient=True,
)

params = {"temp_dust": 25.0, "beta_dust": 2.54, "beta_pl": -3.0}

final_params, final_state = fcs.minimize(
    negative_log_likelihood_fn,
    params,
    "ADABK0",
    max_iter=1000,
    atol=1e-16,
    rtol=1e-16,
    lower_bound={"temp_dust": 10.0, "beta_dust": 0.5, "beta_pl": -5.0},
    upper_bound={"temp_dust": 40.0, "beta_dust": 3.0, "beta_pl": -1.0},
    nu=nu,
    N=invN,
    d=d,
)

print(f"Final parameters: {final_params}, number of iterations: {final_state.iter_num}")
print(f"Final value: {negative_log_prob(final_params, d=d)}")
[INFO] key active_set: max_constraints_to_release=1 / 3 params
100.00%|██████████| [00:22<00:00,  4.53%/s]
Final parameters: {'beta_dust': Array(1.53999737, dtype=float64), 'beta_pl': Array(-2.99993738, dtype=float64), 'temp_dust': Array(20.00008681, dtype=float64)}, number of iterations: 836
Final value: -27526803021880.703

L-BFGS: A Quasi-Newton Method

On noiseless data, L-BFGS can converge in far fewer iterations than first-order methods, because it builds an approximation of the inverse Hessian (the curvature) from recent gradient differences. Here, we switch to the optax_lbfgs optimizer to take advantage of this.
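The sketch below illustrates the quasi-Newton idea on a hypothetical toy loss whose curvature differs by four orders of magnitude between parameters. It uses SciPy's L-BFGS-B (not the Optax solver that fcs.minimize selects here), with gradients from jax.value_and_grad and box bounds shaped like the ones used in this notebook.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

jax.config.update("jax_enable_x64", True)  # SciPy works in float64

# Hypothetical toy loss: minimum at (beta_dust, beta_pl, temp_dust) = (1.54, -3.0, 20.0),
# with curvatures that differ by four orders of magnitude between parameters
target = jnp.array([1.54, -3.0, 20.0])
curvature = jnp.array([1e4, 1e2, 1.0])


def loss(p):
    return jnp.sum(curvature * (p - target) ** 2)


loss_and_grad = jax.jit(jax.value_and_grad(loss))


def fun(p):
    # SciPy passes NumPy arrays and expects (float, ndarray) back when jac=True
    value, grad = loss_and_grad(jnp.asarray(p))
    return float(value), np.array(grad, dtype=float)


res = minimize(
    fun,
    x0=np.array([2.54, -4.0, 25.0]),
    jac=True,
    method="L-BFGS-B",
    bounds=[(0.5, 3.0), (-5.0, -1.0), (10.0, 40.0)],  # same box as the notebook's bounds
)
print(res.x, res.nit)
```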

[26]:
negative_log_likelihood_fn = partial(
    negative_log_likelihood,
    dust_nu0=dust_nu0,
    synchrotron_nu0=synchrotron_nu0,
    analytical_gradient=True,
)

params = {"temp_dust": 25.0, "beta_dust": 2.54, "beta_pl": -3.0}

final_params, final_state = fcs.minimize(
    negative_log_likelihood_fn,
    params,
    "optax_lbfgs",
    max_iter=1000,
    atol=1e-16,
    rtol=1e-16,
    precondition=True,
    lower_bound={"temp_dust": 10.0, "beta_dust": 0.5, "beta_pl": -5.0},
    upper_bound={"temp_dust": 40.0, "beta_dust": 3.0, "beta_pl": -1.0},
    nu=nu,
    N=invN,
    d=d,
)

print(f"Final parameters: {final_params}, number of iterations: {final_state.iter_num}")
print(f"Final value: {negative_log_prob(final_params, d=d)}")
100.00%|██████████| [00:04<00:00, 20.46%/s]
Final parameters: {'beta_dust': Array(1.53999969, dtype=float64), 'beta_pl': Array(-2.99999613, dtype=float64), 'temp_dust': Array(20.00000995, dtype=float64)}, number of iterations: 108
Final value: -27526803021880.715