{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# FGBuster vs FURAX: Framework Comparison for CMB Component Separation\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/CMBSciPol/furax-cs/blob/main/notebooks/01_FGBuster_vs_FURAX_Comparison.ipynb)\n", "\n", "## Learning Objectives\n", "\n", "By the end of this notebook, you will:\n", "- Understand the differences between a traditional (FGBuster) and a modern (FURAX) component separation framework\n", "- See the performance advantages of JAX over NumPy for CMB analysis\n", "- Learn how to implement and benchmark likelihood functions\n", "- Understand how automatic differentiation helps with parameter optimization\n", "\n", "## Background\n", "\n", "### The Component Separation Problem\n", "\n", "Multi-frequency CMB observations contain several superposed signals:\n", "- **CMB signal**: What we want to measure\n", "- **Galactic dust**: Thermal emission from dust grains, modeled as a modified blackbody\n", "- **Synchrotron**: Power-law emission from cosmic-ray electrons spiraling in the Galactic magnetic field\n", "- **Instrumental noise**: Detector noise and systematic effects\n", "\n", "The challenge is to separate these components accurately enough to recover the CMB signal." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install -q furax-cs\n", "!pip install --force-reinstall -r https://raw.githubusercontent.com/CMBSciPol/furax-cs/main/requirements.txt\n", "!pip install git+https://github.com/fgbuster/fgbuster" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Core scientific computing\n", "from functools import partial\n", "\n", "import furax_cs as fcs\n", "\n", "# JAX for high-performance computing\n", "import jax\n", "import jax.numpy as jnp\n", "import jaxopt\n", "\n", "# FGBuster - Traditional component separation framework\n", "from fgbuster import (\n", "    CMB,\n", "    Dust,\n", "    Synchrotron,\n", "    basic_comp_sep,\n", "    get_instrument,\n", ")\n", "\n", "# FURAX - Modern JAX-based framework\n", "from furax import HomothetyOperator, tree\n", "from furax.obs.landscapes import Stokes\n", "from furax.obs.operators import (\n", "    CMBOperator,\n", "    DustOperator,\n", "    MixingMatrixOperator,\n", "    SynchrotronOperator,\n", ")\n", "\n", "# Set JAX to use 64-bit precision for scientific accuracy\n", "jax.config.update(\"jax_enable_x64\", True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1: Generate Simulated Sky Maps\n", "\n", "We start by creating realistic CMB and foreground simulations using PySM (Python Sky Model). These simulated observations serve as test data for comparing the two frameworks.\n", "\n", "### Key Parameters:\n", "- **NSIDE = 64**: HEALPix resolution\n", "- **Instrument**: LiteBIRD frequency configuration (15 bands: 40-402 GHz)\n", "- **Components**: CMB + dust + synchrotron emission\n", "- **Stokes parameters**: I (total intensity), Q and U (linear polarization)\n", "\n", "### Why Use Simulations?\n", "1. **Ground truth**: We know the input parameters\n", "2. **Controlled testing**: Compare framework accuracy\n", "3. 
**Reproducibility**: Same data for fair comparison\n", "\n", "**Note**: On HPC clusters without internet access, these maps are pre-cached using `generate_maps.py`" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[INFO] Loaded freq_maps for nside 64 from cache with noise_ratio 0.0.\n" ] } ], "source": [ "nsides = [64]\n", "for nside in nsides:\n", " fcs.save_to_cache(nside)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[INFO] Loaded freq_maps for nside 64 from cache.\n", "freq_maps shape: (15, 3, 49152)\n" ] } ], "source": [ "nside = 64\n", "\n", "nu, freq_maps = fcs.load_from_cache(nside)\n", "# Check the shape of freq_maps\n", "print(\"freq_maps shape:\", freq_maps.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 2: Convert Data to FURAX Format\n", "\n", "FURAX uses a structured data format called `Stokes` that organizes polarization data (I, Q, U) in a JAX-compatible format. This conversion step is essential for interfacing with FURAX's operator-based architecture.\n", "\n", "### Understanding the Data Structure\n", "\n", "The frequency maps have shape `(15, 3, 49152)` where:\n", "- **15**: Number of frequency channels (LiteBIRD bands)\n", "- **3**: Stokes parameters (I, Q, U polarization) \n", "- **49152**: HEALPix pixels (12 × 64² for NSIDE=64)\n", "\n", "**Note**: Although FURAX includes its own functions to create sky maps from PySM, we use `fgbuster` here to ensure that both methods receive identical inputs for comparison." 
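] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The layout described above can be sketched with plain JAX. The cell below is a minimal illustration, not FURAX code. It builds a small, hypothetical stand-in for `freq_maps` at NSIDE = 4 (so $12 \\times 4^2 = 192$ pixels) and splits the Stokes axis into a dict with one `(n_freq, n_pix)` leaf per Stokes parameter. This mirrors the `i`, `q`, `u` leaves of the `StokesIQU` structure printed below. JAX functions such as `jax.tree.map` then act leaf by leaf on this pytree. Pytree structure is also what allows `Stokes` objects to be passed to jitted functions such as `negative_log_prob`. All `demo_*` names are hypothetical and leave the notebook's variables untouched." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "# Hypothetical stand-in for the (n_freq, n_stokes, n_pix) array returned by fcs.load_from_cache\n", "demo_nside = 4\n", "demo_npix = 12 * demo_nside**2  # HEALPix pixel count: 12 * NSIDE^2 (49152 for NSIDE = 64)\n", "demo_maps = jnp.ones((15, 3, demo_npix))\n", "\n", "# One (n_freq, n_pix) leaf per Stokes parameter, like the i/q/u leaves of StokesIQU\n", "demo_stokes = {'i': demo_maps[:, 0, :], 'q': demo_maps[:, 1, :], 'u': demo_maps[:, 2, :]}\n", "\n", "# jax.tree.map applies a function to every leaf and keeps the dict structure\n", "demo_shapes = jax.tree.map(lambda leaf: leaf.shape, demo_stokes)\n", "\n", "# Sum of squares over all leaves: a pytree dot product of the data with itself\n", "demo_norm2 = sum(jnp.sum(leaf * leaf) for leaf in jax.tree.leaves(demo_stokes))\n", "\n", "print(demo_shapes)\n", "print(demo_norm2)"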
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "StokesIQU(i=ShapeDtypeStruct(shape=(15, 49152), dtype=float64), q=ShapeDtypeStruct(shape=(15, 49152), dtype=float64), u=ShapeDtypeStruct(shape=(15, 49152), dtype=float64))" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d = Stokes.from_stokes(I=freq_maps[:, 0, :], Q=freq_maps[:, 1, :], U=freq_maps[:, 2, :])\n", "d.structure" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Component Initialization and Reference Frequencies\n", "\n", "# Reference frequencies for component models (in GHz)\n", "dust_nu0 = 150.0  # Reference frequency at which the dust amplitude is defined\n", "synchrotron_nu0 = 20.0  # Synchrotron reference frequency, low where synchrotron dominates\n", "\n", "# Get LiteBIRD instrument specification\n", "instrument = get_instrument(\"LiteBIRD\")\n", "\n", "# Define the astrophysical components for separation\n", "# Each component has its own spectral energy distribution (SED)\n", "components = [\n", "    CMB(),  # Blackbody at 2.725 K (no free parameters)\n", "    Dust(dust_nu0),  # Modified blackbody (free parameters: beta_d, temp)\n", "    Synchrotron(synchrotron_nu0),  # Power law (free parameter: beta_pl)\n", "]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Defining the Likelihood Function for Component Separation\n", "\n", "In this cell, we define the `negative_log_prob` function, which evaluates the negative log-likelihood of the observed data `d` as a function of the spectral parameters.\n", "\n", "The likelihood is built from a quadratic form involving the mixing matrix $A$, the inverse noise covariance $N^{-1}$, and the observed data $d$. Its key term is:\n", "\n", "$$\n", "\\left(A^T N^{-1} d\\right)^T \\left(A^T N^{-1} A\\right)^{-1} \\left(A^T N^{-1} d\\right)\n", "$$\n", "\n", "### Explanation of Each Term\n", "\n", "1. 
**$A$**: The mixing matrix operator, which maps the component space to the observed frequency space.\n", "2. **$N^{-1}$**: The inverse of the noise covariance matrix, represented by `invN` in the code.\n", "3. **$d$**: The observed data, stored as a FURAX `Stokes` pytree.\n", "\n", "### Implementation Details\n", "\n", "- **Applying $A^T N^{-1}$**: `(A.T @ invN)(d)` composes the two operators and applies them to `d`, giving $x = A^T N^{-1} d$.\n", "- **Computing the quadratic form**: `(A.T @ invN @ A).I(x)` applies the inverse of $A^T N^{-1} A$ to $x$, which gives the maximum-likelihood amplitudes $\\hat{s}$. `tree.dot(x, s)` then takes the dot product of the two pytrees.\n", "- **Negative log-likelihood**: `negative_log_prob` returns minus this quadratic form. Once the amplitudes are set to $\\hat{s}$, this equals $-2 \\ln \\mathcal{L}$ up to an additive constant, so it can be minimized directly as a loss function.\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Initialize FURAX operators and noise model\n", "invN = HomothetyOperator(\n", "    jnp.ones(1), in_structure=d.structure\n", ")  # Identity noise (uniform weighting)\n", "DND = invN(d) @ d  # Noise-weighted data norm d^T N^-1 d (a constant; not needed for optimization)\n", "\n", "# Define input structure for component operators\n", "in_structure = d.structure_for((d.shape[1],))  # Structure for pixel-wise operations\n", "\n", "# True parameter values from simulation (our target for optimization)\n", "best_params = {\"temp_dust\": 20.0, \"beta_dust\": 1.54, \"beta_pl\": -3.0}\n", "\n", "# Reference frequencies (same as component initialization)\n", "dust_nu0 = 150.0\n", "synchrotron_nu0 = 20.0\n", "\n", "\n", "@jax.jit  # JIT compilation for performance\n", "def negative_log_prob(params, d):\n", "    \"\"\"\n", "    Compute negative log-likelihood for component separation.\n", "\n", "    This function implements the standard CMB likelihood:\n", "    -ln(L) = (d - As)^T N^-1 (d - As) / 2 + const\n", "\n", "    Replacing the amplitudes s by their maximum-likelihood estimate\n", "    s_hat = (A^T N^-1 A)^-1 A^T N^-1 d reduces -2 ln(L), up to a constant,\n", "    to minus the quadratic form shown in the mathematical derivation.\n", "    \"\"\"\n", "    # Create component operators with current parameter values\n", "    cmb = CMBOperator(nu, in_structure=in_structure)\n", "    dust = DustOperator(\n", "        nu,\n", "        frequency0=dust_nu0,\n", "        temperature=params[\"temp_dust\"],\n", "        beta=params[\"beta_dust\"],\n", "        in_structure=in_structure,\n", "    )\n", "    synchrotron = SynchrotronOperator(\n", "        nu,\n", "        frequency0=synchrotron_nu0,\n", "        beta_pl=params[\"beta_pl\"],\n", "        in_structure=in_structure,\n", "    )\n", "\n", "    # Construct mixing matrix A(β) from component operators\n", "    A = MixingMatrixOperator(cmb=cmb, dust=dust, synchrotron=synchrotron)\n", "\n", "    # Compute the profiled likelihood (amplitudes s replaced by s_hat):\n", "    # L = (A^T N^-1 d)^T (A^T N^-1 A)^-1 (A^T N^-1 d)\n", "    x = (A.T @ invN)(d)  # A^T N^-1 d\n", "    s = (A.T @ invN @ A).I(x)  # s_hat = (A^T N^-1 A)^-1 A^T N^-1 d\n", "    L = tree.dot(x, s)\n", "\n", "    return -L  # Return negative for minimization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Performance Analysis: Likelihood and Gradient Evaluation\n", "\n", "Next, we time one likelihood evaluation and one gradient evaluation of our FURAX implementation. Each timed call ends with `block_until_ready()`. JAX dispatches work asynchronously, so without it `%timeit` would measure only the dispatch, not the computation. A warm-up call first triggers JIT compilation, so compile time is excluded from the timings." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Array(-2.7526803e+13, dtype=float64)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "negative_log_prob(best_params, d)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Likelihood at true parameters: -27526803021880.71\n", "Performance of the negative log-likelihood evaluation:\n", "9.04 ms ± 138 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", "Performance of the gradient evaluation:\n", "20.1 ms ± 363 μs per loop (mean ± std. dev. 
of 7 runs, 10 loops each)\n" ] } ], "source": [ "# Evaluate likelihood at true parameters (the optimum for noiseless data)\n", "likelihood_value = negative_log_prob(best_params, d)\n", "print(f\"Likelihood at true parameters: {likelihood_value}\")\n", "\n", "# Performance benchmarking with proper JAX timing (block_until_ready)\n", "print(\"Performance of the negative log-likelihood evaluation:\")\n", "negative_log_prob(best_params, d).block_until_ready()  # Warm-up for JIT\n", "%timeit negative_log_prob(best_params, d).block_until_ready()\n", "\n", "print(\"Performance of the gradient evaluation:\")\n", "# JAX automatic differentiation - no manual gradient coding required!\n", "grad_func = jax.grad(negative_log_prob)\n", "grad_func(best_params, d)[\"beta_pl\"].block_until_ready()  # Warm-up\n", "%timeit grad_func(best_params, d)['beta_pl'].block_until_ready()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Check for Correctness\n", "\n", "In this cell, we perform a basic correctness check by comparing the gradients of the negative log-likelihood at two sets of parameters:\n", "\n", "1. **Wrong Parameters**: `best_params` with the same random offset, drawn from a standard normal distribution, added to every entry.\n", "2. 
**Correct Parameters**: The original `best_params`.\n", "\n", "By comparing the largest absolute gradient component at each point, we can check that the gradient at the correct parameters is much smaller, as expected at or near an optimum.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Gradient-based correctness validation\n", "# Perturb every parameter by the same standard-normal offset (intentionally wrong values)\n", "wrong_params = jax.tree.map(lambda x: x + jax.random.normal(jax.random.PRNGKey(0)), best_params)\n", "\n", "# Compare the largest |gradient| component: smaller values indicate proximity to the optimum\n", "grad_wrong = jax.tree.reduce(max, jax.tree.map(jnp.abs, jax.grad(negative_log_prob)(wrong_params, d)))\n", "grad_correct = jax.tree.reduce(max, jax.tree.map(jnp.abs, jax.grad(negative_log_prob)(best_params, d)))\n", "\n", "print(f\"Max |gradient| at wrong parameters: {grad_wrong:.2e}\")\n", "print(f\"Max |gradient| at correct parameters: {grad_correct:.2e}\")\n", "print(f\"Ratio (wrong/correct): {grad_wrong / grad_correct:.1e}\")\n", "\n", "# A much smaller gradient at the true parameters is consistent with a correct implementation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using FURAX's Built-in Likelihood Functions\n", "\n", "Up to this point, we have implemented our own likelihood function to understand the underlying mathematics. FURAX also provides a built-in `negative_log_likelihood` function (in `furax.obs`) intended for production use.\n", "\n", "### Why Use Built-in Functions?\n", "\n", "1. **Performance**: They are written to be JIT-compiled and differentiated with JAX\n", "2. **Robustness**: They expose options such as `analytical_gradient=True`, which replaces automatic differentiation through the Lineax linear solve with an analytical gradient\n", "3. 
**Consistency**: Standardized interface across different component separation methods\n", "4. **Maintenance**: Less code to maintain and debug\n", "\n", "### Comparing Custom vs Built-in Implementation\n", "\n", "Here we check that our custom implementation agrees with FURAX's built-in `negative_log_likelihood` function to a relative tolerance of $10^{-15}$ at `best_params`. This confirms that both compute the same quantity, so we use the built-in function from here on." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "from furax.obs import negative_log_likelihood\n", "\n", "# Bind the reference frequencies (note: this rebinds the imported name)\n", "negative_log_likelihood = partial(\n", "    negative_log_likelihood, dust_nu0=dust_nu0, synchrotron_nu0=synchrotron_nu0\n", ")\n", "\n", "L = negative_log_likelihood(best_params, nu=nu, N=invN, d=d)\n", "\n", "assert jax.tree.all(\n", "    jax.tree.map(lambda x, y: jnp.isclose(x, y, rtol=1e-15), L, negative_log_prob(best_params, d))\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Framework Validation: Comparing FURAX and FGBuster\n", "\n", "In this section, we validate our FURAX implementation by comparing it to FGBuster's well-established component separation algorithms. This cross-validation checks that the JAX-based approach produces results consistent with the traditional NumPy-based methods.\n", "\n", "## The Importance of Cross-Framework Validation\n", "\n", "1. **Scientific Rigor**: Ensures our new methods don't introduce systematic biases\n", "2. **Trust Building**: Demonstrates that modern tools preserve scientific accuracy\n", "3. **Method Verification**: Confirms our implementation of component separation is correct\n", "4. 
**Performance Baseline**: Establishes a reference for speed and accuracy comparisons\n", "\n", "## Test Cases: From Simple to Complex\n", "\n", "We test both frameworks under increasingly challenging conditions:\n", "- **Case 1**: Optimal starting parameters (convergence test)\n", "- **Case 2**: Single incorrect parameter (robustness test)\n", "- **Case 3**: Multiple incorrect parameters (recovery test)\n", "- **Case 4**: All parameters wrong (optimization challenge)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Case 1: Initial Validation Using `best_params` as the Starting Point\n", "\n", "We begin the validation by using `best_params` as the starting point for both the FURAX likelihood (minimized with `jaxopt`'s TNC wrapper) and FGBuster’s `c1d0s0` model. This lets us compare the outputs directly and confirm that both frameworks produce similar results when initialized with the true parameters.\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['Dust.beta_d', 'Dust.temp', 'Synchrotron.beta_pl']\n", "[ 1.54 20. -3. 
]\n" ] } ], "source": [ "components[1]._set_default_of_free_symbols(beta_d=1.54, temp=20.0)\n", "components[2]._set_default_of_free_symbols(beta_pl=-3.0)\n", "\n", "result = basic_comp_sep(components, instrument, freq_maps)\n", "print(result.params)\n", "print(result.x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "options = {\"disp\": True}\n", "scipy_solver = jaxopt.ScipyBoundedMinimize(\n", "    fun=negative_log_likelihood, method=\"TNC\", jit=True, tol=1e-10, maxiter=1000, options=options\n", ")\n", "# Lower and upper bounds must have the same pytree structure (dict keys) as the parameters:\n", "# a positional tuple would be matched against the dict keys in sorted order\n", "# (beta_dust, beta_pl, temp_dust). The ranges contain every starting point used below.\n", "bounds = (\n", "    {\"temp_dust\": 10.0, \"beta_dust\": 0.5, \"beta_pl\": -7.0},\n", "    {\"temp_dust\": 30.0, \"beta_dust\": 3.0, \"beta_pl\": -1.0},\n", ")\n", "result = scipy_solver.run(best_params, bounds, nu=nu, N=invN, d=d)\n", "result.params" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Case 2: Validation with Incorrect Parameter: Setting `beta_dust` to a Wrong Value" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ ":2: RuntimeWarning: overflow encountered in power\n", "    return 
568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)\n", ":2: RuntimeWarning: overflow encountered in power\n", "    return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)*log(0.05*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "SVD of A failed -> logL = -inf\n", "SVD of A failed -> logL_dB not updated\n", "SVD of A failed -> logL = -inf\n", "SVD of A failed -> logL = -inf\n", "SVD of A failed -> logL = -inf\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":2: RuntimeWarning: overflow encountered in multiply\n", "    return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "SVD of A failed -> logL = -inf\n", "SVD of A failed -> logL = -inf\n", "SVD of A failed -> logL = -inf\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":2: RuntimeWarning: overflow encountered in power\n", "    return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "['Dust.beta_d', 'Dust.temp', 'Synchrotron.beta_pl']\n", "[ 1.53195651 19.97377884 -2.94289685]\n" ] } ], "source": [ "components[1]._set_default_of_free_symbols(beta_d=2.54, temp=20.0)\n", "components[2]._set_default_of_free_symbols(beta_pl=-3.0)\n", "\n", "result = basic_comp_sep(components, instrument, freq_maps)\n", "print(result.params)\n", "print(result.x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "params = {\"temp_dust\": 20.0, \"beta_dust\": 2.54, \"beta_pl\": -3.0}\n", "\n", "result = scipy_solver.run(params, bounds, nu=nu, N=invN, d=d)\n", "result.params" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Case 3: Setting `beta_dust` and `beta_pl` to Incorrect Values\n" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SVD of A failed -> logL = -inf\n", "SVD of A failed -> logL_dB not updated\n", "SVD of A failed -> logL = -inf\n", "SVD of A failed -> logL = -inf\n", "SVD of A failed -> logL = -inf\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":2: RuntimeWarning: overflow encountered in power\n", "    return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)\n", ":2: RuntimeWarning: overflow encountered in power\n", "    return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)*log(0.05*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "SVD of A 
failed -> logL = -inf\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":2: RuntimeWarning: overflow encountered in multiply\n", "    return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)\n", ":2: RuntimeWarning: overflow encountered in power\n", "    return 568.8620443215493*(0.05*nu)**beta_pl*numpy.expm1(0.01760867023799751*nu)**2*exp(-0.01760867023799751*nu)/(nu**2*numpy.expm1(0.3521734047599502)**2)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "SVD of A failed -> logL = -inf\n", "['Dust.beta_d', 'Dust.temp', 'Synchrotron.beta_pl']\n", "[ 1.53034371 19.9741819 -5.99474719]\n" ] } ], "source": [ "components[1]._set_default_of_free_symbols(beta_d=2.54, temp=20.0)\n", "components[2]._set_default_of_free_symbols(beta_pl=-6.0)\n", "\n", "result = basic_comp_sep(components, instrument, freq_maps)\n", "print(result.params)\n", "print(result.x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "params = {\"temp_dust\": 20.0, \"beta_dust\": 2.54, \"beta_pl\": -6.0}\n", 
"\n", "result = scipy_solver.run(params, bounds, nu=nu, N=invN, d=d)\n", "result.params" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Case 4: Setting All Parameters to Incorrect Values\n" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['Dust.beta_d', 'Dust.temp', 'Synchrotron.beta_pl']\n", "[ 1.53999883 20.00004072 -2.99997694]\n" ] } ], "source": [ "components[1]._set_default_of_free_symbols(beta_d=2.54, temp=25.0)\n", "components[2]._set_default_of_free_symbols(beta_pl=-6.0)\n", "\n", "result = basic_comp_sep(components, instrument, freq_maps)\n", "print(result.params)\n", "print(result.x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Advanced Optimization with JAX: Using Optax\n", "\n", "The JAX ecosystem provides gradient-based optimizers through Optax, and automatic differentiation supplies the gradients they need. In the cells below, `fcs.minimize` runs gradient-based solvers on the FURAX likelihood. The solver is selected by name (`\"ADABK0\"`, `\"optax_lbfgs\"`), and box constraints are passed through `lower_bound` and `upper_bound`. This flexibility is one of the key advantages of the FURAX framework over traditional approaches.\n", "\n", "## Why JAX-based Optimization?\n", "\n", "1. **Automatic Differentiation**: No need for manual gradient computation\n", "2. **GPU Acceleration**: Seamless GPU support for large-scale problems\n", "3. **Modern Algorithms**: Access to recent optimization methods (Adam, L-BFGS, etc.)\n", "4. **Composability**: Easy to combine different optimization strategies\n", "\n", "**Note**: Use the analytical gradient of the likelihood function (`analytical_gradient=True`). Automatic differentiation through the Lineax linear solve can apply the operator to non-physical cotangents, which leads to NaNs." 
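] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cell below illustrates what automatic differentiation provides without running the full FURAX pipeline. It is a minimal sketch on a hypothetical toy problem: two components (a flat, CMB-like SED and a power law with index `beta`), six made-up frequencies, identity noise, and noiseless data. `toy_spectral_nll` evaluates the same quadratic form as `negative_log_prob`, but uses a dense `jnp.linalg.solve` in place of FURAX operators. `jax.value_and_grad` supplies the gradient, which we cross-check against a central finite difference. At the true index, the gradient vanishes and the value reaches its minimum, $-d^T d$ summed over pixels. This happens because the noiseless data lie entirely in the span of the mixing-matrix columns. The toy differentiates directly through `jnp.linalg.solve`, which is well behaved for a dense 2 x 2 system; the Lineax caveat above concerns FURAX's operator-based solve." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "# Same 64-bit setting as the rest of the notebook (a no-op if already enabled)\n", "jax.config.update('jax_enable_x64', True)\n", "\n", "# Hypothetical toy sky: flat SED plus a power law, at made-up frequencies (GHz)\n", "toy_nu = jnp.array([30.0, 40.0, 60.0, 90.0, 150.0, 220.0])\n", "toy_nu0 = 30.0\n", "toy_beta_true = -3.0\n", "\n", "\n", "def toy_mixing_matrix(beta):\n", "    # Columns: flat SED and power law (nu / nu0)**beta, shape (n_freq, 2)\n", "    return jnp.stack([jnp.ones_like(toy_nu), (toy_nu / toy_nu0) ** beta], axis=1)\n", "\n", "\n", "# Noiseless data d = A(beta_true) s for 100 pixels with random amplitudes\n", "toy_s = jax.random.normal(jax.random.PRNGKey(1), (2, 100))\n", "toy_d = toy_mixing_matrix(toy_beta_true) @ toy_s\n", "\n", "\n", "def toy_spectral_nll(beta):\n", "    # -(A^T d)^T (A^T A)^-1 (A^T d) with N = identity, summed over pixels\n", "    A = toy_mixing_matrix(beta)\n", "    x = A.T @ toy_d\n", "    s_hat = jnp.linalg.solve(A.T @ A, x)\n", "    return -jnp.sum(x * s_hat)\n", "\n", "\n", "# Value and gradient from automatic differentiation, no hand-derived formula\n", "toy_value_and_grad = jax.jit(jax.value_and_grad(toy_spectral_nll))\n", "toy_val_true, toy_grad_true = toy_value_and_grad(toy_beta_true)\n", "toy_val_wrong, toy_grad_wrong = toy_value_and_grad(-2.5)\n", "\n", "# Central finite difference as an independent check of the autodiff gradient\n", "toy_h = 1e-5\n", "toy_fd_wrong = (toy_spectral_nll(-2.5 + toy_h) - toy_spectral_nll(-2.5 - toy_h)) / (2 * toy_h)\n", "\n", "print(f'beta = -3.0: value {toy_val_true:.6f}, gradient {toy_grad_true:.2e}')\n", "print(f'beta = -2.5: value {toy_val_wrong:.6f}, autodiff {toy_grad_wrong:.6f}, finite diff {toy_fd_wrong:.6f}')"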
] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[INFO] key active_set: max_constraints_to_release=1 / 3 params\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100.00%|██████████| [00:22<00:00, 4.53%/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Final parameters: {'beta_dust': Array(1.53999737, dtype=float64), 'beta_pl': Array(-2.99993738, dtype=float64), 'temp_dust': Array(20.00008681, dtype=float64)}, number of iterations: 836\n", "Likelihood at final parameters: -27526803021880.703\n" ] } ], "source": [ "negative_log_likelihood_fn = partial(\n", "    negative_log_likelihood,\n", "    dust_nu0=dust_nu0,\n", "    synchrotron_nu0=synchrotron_nu0,\n", "    analytical_gradient=True,\n", ")\n", "\n", "params = {\"temp_dust\": 25.0, \"beta_dust\": 2.54, \"beta_pl\": -3.0}\n", "\n", "final_params, final_state = fcs.minimize(\n", "    negative_log_likelihood_fn,\n", "    params,\n", "    \"ADABK0\",\n", "    max_iter=1000,\n", "    atol=1e-16,\n", "    rtol=1e-16,\n", "    lower_bound={\"temp_dust\": 10.0, \"beta_dust\": 0.5, \"beta_pl\": -5.0},\n", "    upper_bound={\"temp_dust\": 40.0, \"beta_dust\": 3.0, \"beta_pl\": -1.0},\n", "    nu=nu,\n", "    N=invN,\n", "    d=d,\n", ")\n", "\n", "print(f\"Final parameters: {final_params}, number of iterations: {final_state.iter_num}\")\n", "print(f\"Likelihood at final parameters: {negative_log_prob(final_params, d=d)}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### L-BFGS: A Quasi-Newton Method\n", "\n", "L-BFGS builds a limited-memory approximation of the inverse Hessian from recent gradient differences, so it uses approximate second-order (curvature) information. On noiseless data, this can speed up convergence considerably. In the runs shown in this notebook, `final_state.iter_num` is 108 for `optax_lbfgs` versus 836 for `ADABK0`. Here, we switch to the `optax_lbfgs` optimizer to take advantage of this." 
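] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The benefit of curvature information can be sketched on a hypothetical toy problem, independent of FURAX. The toy is a two-parameter quadratic whose curvatures differ by a factor of 100. It loosely stands in for parameters that live on very different scales (compare `temp_dust` around 20 with `beta_dust` around 1.5). The cell below compares 100 steps of plain gradient descent, using step size 1/L for the largest curvature L, against `jax.scipy.optimize.minimize(..., method='BFGS')`. BFGS is the full-memory quasi-Newton method that L-BFGS approximates with a few stored gradient pairs. It is used here only because it ships with JAX, not because `fcs.minimize` relies on it. Along the low-curvature direction, gradient descent shrinks the error by only a factor of 0.99 per step, so about 37% of the initial error remains after 100 steps. BFGS, by contrast, learns the curvature and typically converges in a handful of iterations." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import jax.scipy.optimize\n", "\n", "# Same 64-bit setting as the rest of the notebook (a no-op if already enabled)\n", "jax.config.update('jax_enable_x64', True)\n", "\n", "# Hypothetical ill-conditioned quadratic f(x) = x^T H x / 2 - b^T x with curvatures 1 and 100\n", "toy_H = jnp.diag(jnp.array([1.0, 100.0]))\n", "toy_b = jnp.array([1.0, 1.0])\n", "toy_x_star = jnp.linalg.solve(toy_H, toy_b)  # exact minimizer H^-1 b = [1, 0.01]\n", "\n", "\n", "def toy_quadratic(x):\n", "    return 0.5 * x @ toy_H @ x - toy_b @ x\n", "\n", "\n", "# Plain gradient descent with step 1/L, where L = 100 is the largest curvature\n", "toy_grad = jax.jit(jax.grad(toy_quadratic))\n", "toy_x_gd = jnp.zeros(2)\n", "for _ in range(100):\n", "    toy_x_gd = toy_x_gd - 0.01 * toy_grad(toy_x_gd)\n", "\n", "# Quasi-Newton BFGS from the same starting point\n", "toy_res = jax.scipy.optimize.minimize(toy_quadratic, jnp.zeros(2), method='BFGS')\n", "\n", "print('gradient descent error:', jnp.linalg.norm(toy_x_gd - toy_x_star))\n", "print('BFGS error:', jnp.linalg.norm(toy_res.x - toy_x_star), 'iterations:', toy_res.nit)"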
] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100.00%|██████████| [00:04<00:00, 20.46%/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Final parameters: {'beta_dust': Array(1.53999969, dtype=float64), 'beta_pl': Array(-2.99999613, dtype=float64), 'temp_dust': Array(20.00000995, dtype=float64)}, number of iterations: 108\n", "Likelihood at final parameters: -27526803021880.715\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "negative_log_likelihood_fn = partial(\n", "    negative_log_likelihood,\n", "    dust_nu0=dust_nu0,\n", "    synchrotron_nu0=synchrotron_nu0,\n", "    analytical_gradient=True,\n", ")\n", "\n", "params = {\"temp_dust\": 25.0, \"beta_dust\": 2.54, \"beta_pl\": -3.0}\n", "\n", "final_params, final_state = fcs.minimize(\n", "    negative_log_likelihood_fn,\n", "    params,\n", "    \"optax_lbfgs\",\n", "    max_iter=1000,\n", "    atol=1e-16,\n", "    rtol=1e-16,\n", "    precondition=True,\n", "    lower_bound={\"temp_dust\": 10.0, \"beta_dust\": 0.5, \"beta_pl\": -5.0},\n", "    upper_bound={\"temp_dust\": 40.0, \"beta_dust\": 3.0, \"beta_pl\": -1.0},\n", "    nu=nu,\n", "    N=invN,\n", "    d=d,\n", ")\n", "\n", "print(f\"Final parameters: {final_params}, number of iterations: {final_state.iter_num}\")\n", "print(f\"Likelihood at final parameters: {negative_log_prob(final_params, d=d)}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 2 }