# Source code for cadre.minimize

from __future__ import annotations

from collections.abc import Callable
from typing import Any, cast

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optimistix as optx
from jaxtyping import (
    Array,
    Float,
    PyTree,  # pyright: ignore
    Scalar,
)

from ._compat import requires_scipy
from ._logging import warning
from .solvers import SELFCONDITIONED_SOLVERS, SOLVER_NAMES, get_solver
from .utils import condition

try:
    import jaxopt  # noqa: F401
except ImportError:
    pass

# =============================================================================
# SCIPY MINIMIZE WITH VMAP SUPPORT
# =============================================================================


class ScipyMinimizeState(eqx.Module):
    """Result of a ``scipy_minimize`` solve, packed for JAX transformations.

    ``scipy_minimize`` runs the scipy solver on the host through
    ``jax.pure_callback`` and stores its outputs in this equinox module, so the
    result is a PyTree that can flow through ``vmap`` / ``lax.map``.

    Attributes
    ----------
    params : PyTree
        Optimized parameters; same tree structure, shapes and dtypes as the
        ``init_params`` passed to ``scipy_minimize``.
    fun_val : Scalar
        Final objective value reported by the solver (float32 scalar).
    success : Scalar
        Whether the solver reported successful convergence (bool scalar).
    iter_num : Scalar
        Number of iterations reported by the solver (int32 scalar).
    """

    # All four fields are filled by keyword in ``scipy_minimize`` from the
    # dict returned by the host-side callback.
    params: PyTree[Float[Array, " P"]]
    fun_val: Scalar
    success: Scalar
    iter_num: Scalar


@requires_scipy
def scipy_minimize(
    fn: Callable[..., Scalar],
    init_params: PyTree[Float[Array, " P"]],
    lower_bound: PyTree[Float[Array, " P"]] | None = None,
    upper_bound: PyTree[Float[Array, " P"]] | None = None,
    method: str = "tnc",
    maxiter: int = 1000,
    scipy_options: dict[str, Any] | None = None,
    **fn_kwargs: Any,
) -> ScipyMinimizeState:
    """Scipy minimize wrapper that supports vmap via jax.pure_callback.

    This function wraps scipy optimization in a way that is compatible with
    JAX transformations like vmap and lax.map. It uses jax.pure_callback to
    call the host-side scipy solver (through ``jaxopt.ScipyBoundedMinimize``).

    Parameters
    ----------
    fn : Callable
        Objective function to minimize. Should accept (params, **fn_kwargs).
    init_params : PyTree
        Initial parameter values.
    lower_bound : PyTree, optional
        Lower bounds for parameters. Same shape as init_params. If only one
        of the two bounds is given, the missing side is treated as unbounded
        (``-inf`` / ``+inf``).
    upper_bound : PyTree, optional
        Upper bounds for parameters. Same shape as init_params.
    method : str
        Scipy optimization method (default "tnc").
    maxiter : int
        Maximum number of iterations.
    scipy_options : dict, optional
        Extra method options forwarded to ``scipy.optimize.minimize`` (e.g.
        ``ftol``, ``gtol``). They are merged on top of ``{"disp": False}``.
        The mapping is copied and never mutated.
    **fn_kwargs
        Additional arguments passed to fn.

    Returns
    -------
    ScipyMinimizeState
        Optimization result containing params, fun_val, success, and iter_num.

    Raises
    ------
    ImportError
        If ``cadre[scipy]`` optional dependencies are not installed, or if
        ``method="cobyqa"`` and the ``cobyqa`` package is missing.
    """
    from jaxopt import ScipyBoundedMinimize

    # Validate optional backends at trace time: an exception raised inside a
    # pure_callback only surfaces later as an opaque runtime error.
    if method == "cobyqa":
        try:
            import cobyqa  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "cobyqa not installed. Install with: pip install jax-cadre[scipy]"
            ) from err

    # Snapshot the caller's options once; the host callback builds a fresh
    # dict from it on every solve.
    extra_options = dict(scipy_options) if scipy_options else {}

    def objective(params, fn_kwargs):
        # jaxopt forwards ``fn_kwargs=...`` by keyword, so the name matters.
        return fn(params, **fn_kwargs)

    def host_solver_callback(x_init, lower, upper, host_kwargs):
        """Host-side scipy solve; receives and returns NumPy arrays."""
        if lower is None and upper is None:
            bounds = None
        else:
            # jaxopt needs both sides of the box; open a missing side.
            if lower is None:
                lower = jax.tree.map(lambda x: np.full_like(x, -np.inf), x_init)
            if upper is None:
                upper = jax.tree.map(lambda x: np.full_like(x, np.inf), x_init)
            bounds = (lower, upper)

        solver = ScipyBoundedMinimize(
            fun=objective,
            method=method,
            jit=False,
            maxiter=maxiter,
            options={"disp": False, **extra_options},
        )
        res = solver.run(x_init, bounds=bounds, fn_kwargs=host_kwargs)

        # pure_callback requires outputs to match ``result_shape`` exactly, so
        # cast params back to the dtypes of the initial parameters.
        return {
            "params": jax.tree.map(
                lambda p, ref: np.asarray(p, dtype=np.asarray(ref).dtype),
                res.params,
                x_init,
            ),
            "fun_val": np.array(res.state.fun_val, dtype=np.float32),
            "success": np.array(res.state.success, dtype=bool),
            "iter_num": np.array(res.state.iter_num, dtype=np.int32),
        }

    # Static description of the callback outputs.
    result_shape = {
        "params": jax.tree.map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), init_params),
        "fun_val": jax.ShapeDtypeStruct((), jnp.float32),
        "success": jax.ShapeDtypeStruct((), jnp.bool_),
        "iter_num": jax.ShapeDtypeStruct((), jnp.int32),
    }

    result_dict = jax.pure_callback(
        host_solver_callback,
        result_shape,
        init_params,
        lower_bound,
        upper_bound,
        fn_kwargs,
        vmap_method="sequential",
    )

    return ScipyMinimizeState(
        params=result_dict["params"],
        fun_val=result_dict["fun_val"],
        success=result_dict["success"],
        iter_num=result_dict["iter_num"],
    )


# =============================================================================
# UNIFIED STATE
# =============================================================================


class UnifiedState(eqx.Module):
    """Unified optimization state returned by ``minimize``.

    Attributes
    ----------
    best_loss : Scalar
        Best objective function value found.
    best_y : PyTree
        Best parameters found (in original space).
    iter_num : Scalar
        Number of iterations performed.
    solver_state : Any
        Internal solver state (Optimistix state or ScipyMinimizeState).
    """

    best_loss: Scalar
    best_y: PyTree[Float[Array, " P"]]
    iter_num: Scalar
    solver_state: Any


# =============================================================================
# UNIFIED OPTIMIZATION INTERFACE
# =============================================================================


def minimize(
    fn: Callable[..., Scalar],
    init_params: PyTree[Float[Array, " P"]],
    solver_name: SOLVER_NAMES = "optax_lbfgs",
    max_iter: int = 1000,
    rtol: float = 1e-8,
    atol: float = 1e-8,
    lower_bound: PyTree[Float[Array, " P"]] | None = None,
    upper_bound: PyTree[Float[Array, " P"]] | None = None,
    precondition: bool = False,
    options: dict[str, Any] | None = None,
    refresh_steps: int = 10,
    **fn_kwargs: Any,
) -> tuple[PyTree[Float[Array, " P"]], UnifiedState]:
    """
    Unified optimization interface.

    Supports optax solvers, optimistix solvers (via optimistix.minimise), and
    scipy solvers (via jaxopt.ScipyBoundedMinimize, requires ``cadre[scipy]``).

    Parameters
    ----------
    fn : Callable
        Objective function to minimize. Should accept (params, **fn_kwargs).
    init_params : PyTree
        Initial parameter values.
    solver_name : str
        Solver identifier. See SOLVER_NAMES for available options.
    max_iter : int
        Maximum iterations.
    rtol, atol : float
        Relative/absolute tolerance for optimization convergence.
    lower_bound, upper_bound : PyTree, optional
        Box constraints.
    precondition : bool
        Whether to apply parameter transformation and output scaling.
    options : dict, optional
        Extra arguments passed to the solver factory (get_solver). The mapping
        is copied and never mutated.

        For active-set solvers (``ADABK{N}`` family) the recognised keys are:

        * ``cooldown`` (int, default 20) — steps to suppress termination
          after a constraint release.
        * ``min_steps`` (int, default 10) — minimum iterations before
          termination is considered.
        * ``verbose_print`` (bool, default False) — print per-step debug
          info via ``jax.debug.print`` (JIT-compatible).
        * ``max_linesearch_steps`` (int, default 50) — maximum line-search
          steps per iteration (active-set and ``optax_lbfgs`` solvers).
        * ``linesearch`` (str) — linesearch variant for ``optax_lbfgs``
          (``"zoom"`` or ``"backtracking"``).

        For scipy solvers, every key except ``cooldown`` and ``min_steps`` is
        also forwarded to ``scipy.optimize.minimize`` as a method option, on
        top of tolerance defaults derived from ``rtol``/``atol`` (``tnc``:
        ``ftol``/``xtol`` = atol, ``gtol`` = rtol; ``l-bfgs-b``: ``ftol`` =
        atol, ``gtol`` = rtol; ``cobyqa``: ``final_tr_radius`` = atol). Keys
        given explicitly in ``options`` take precedence over those defaults.
    refresh_steps : int
        ``refresh_steps`` of ``optx.TqdmProgressMeter`` (optimistix solvers
        only; ignored when the progress meter is unavailable).
    **fn_kwargs
        Additional arguments passed to fn.

    Returns
    -------
    final_params : PyTree
        Optimized parameters.
    final_state : UnifiedState
        Final optimizer state containing best loss, best parameters,
        iteration count, and solver state.

    Raises
    ------
    ValueError
        If get_solver reports an unknown solver type.
    """
    solver_name = cast(SOLVER_NAMES, solver_name)

    if solver_name in SELFCONDITIONED_SOLVERS and precondition:
        warning(f"Solver '{solver_name}' is self-conditioned; ignoring preconditioning request.")
        precondition = False

    if precondition:
        fn, to_opt, from_opt = condition(
            fn,
            lower=lower_bound,
            upper=upper_bound,
            scale_function=precondition,
            init_params=init_params,
            **fn_kwargs,
        )
        init_params = to_opt(init_params)
        lower_bound = to_opt(lower_bound) if lower_bound is not None else None
        upper_bound = to_opt(upper_bound) if upper_bound is not None else None
    else:

        def from_opt(x):
            # Identity: parameters already live in the original space.
            return x

    # Copy so the caller's ``options`` dict is never mutated.
    _opts = dict(options) if options else {}
    cooldown = _opts.get("cooldown", 20)
    min_steps = _opts.get("min_steps", 10)
    solver_kwargs = {k: v for k, v in _opts.items() if k not in ("cooldown", "min_steps")}

    solver, solver_type = get_solver(
        solver_name,
        rtol=rtol,
        atol=atol,
        lower=lower_bound,
        upper=upper_bound,
        cooldown=cooldown,
        min_steps=min_steps,
        **solver_kwargs,
    )

    if solver_type == "optimistix":
        # Optimistix uses (y, args) signature, wrap fn
        def optx_fn(y, fn_kwargs):
            return fn(y, **fn_kwargs)

        # Older optimistix versions lack TqdmProgressMeter; run without one.
        if not hasattr(optx, "TqdmProgressMeter"):
            meter_kwargs = {}
            warning("optx.TqdmProgressMeter not found. Progress meter disabled.")
        else:
            meter_kwargs = {"progress_meter": optx.TqdmProgressMeter(refresh_steps=refresh_steps)}

        sol = optx.minimise(
            optx_fn,
            solver,
            init_params,
            max_steps=max_iter,
            throw=False,
            args=fn_kwargs,
            **meter_kwargs,
        )

        unified_state = UnifiedState(
            best_loss=sol.state.best_loss,
            best_y=from_opt(sol.state.best_y),
            iter_num=sol.stats["num_steps"],
            solver_state=sol.state,
        )
        return from_opt(sol.value), unified_state

    elif solver_type == "scipy":
        # Scipy via vmap-compatible scipy_minimize; the method name is the
        # second "_"-separated segment of the solver name.
        method = solver_name.split("_")[1]

        # Tolerance defaults per method; explicit user options override them.
        tolerance_defaults: dict[str, Any] = {}
        if method == "tnc":
            tolerance_defaults = {"ftol": atol, "gtol": rtol, "xtol": atol}
        elif method == "l-bfgs-b":
            tolerance_defaults = {"ftol": atol, "gtol": rtol}
        elif method == "cobyqa":
            tolerance_defaults = {"final_tr_radius": atol}
        scipy_options = {**tolerance_defaults, **solver_kwargs}

        state = scipy_minimize(
            fn=fn,
            init_params=init_params,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            method=method,
            maxiter=max_iter,
            scipy_options=scipy_options,
            **fn_kwargs,
        )

        unified_state = UnifiedState(
            best_loss=state.fun_val,
            best_y=from_opt(state.params),
            iter_num=state.iter_num,
            solver_state=state,
        )
        return from_opt(state.params), unified_state

    else:
        raise ValueError(f"Unknown solver type: {solver_type}")