from __future__ import annotations
from collections.abc import Callable
from typing import Any, cast
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optimistix as optx
from jaxtyping import (
Array,
Float,
PyTree, # pyright: ignore
Scalar,
)
from ._compat import requires_scipy
from ._logging import warning
from .solvers import SELFCONDITIONED_SOLVERS, SOLVER_NAMES, get_solver
from .utils import condition
try:
import jaxopt # noqa: F401
except ImportError:
pass
# =============================================================================
# SCIPY MINIMIZE WITH VMAP SUPPORT
# =============================================================================
class ScipyMinimizeState(eqx.Module):
    """Result of :func:`scipy_minimize`, produced through ``jax.pure_callback``.

    This equinox module holds the optimization result in a JAX-compatible
    format, so it can be returned from code mapped with ``jax.vmap`` or
    ``jax.lax.map``.

    Attributes
    ----------
    params : PyTree
        Optimized parameters. They have the same pytree structure (and
        per-leaf shape) as the initial parameters given to the solver.
    fun_val : Scalar
        Final objective function value (float32 scalar).
    success : Scalar
        Whether the scipy optimizer reported success (bool scalar).
    iter_num : Scalar
        Number of iterations performed (int32 scalar).
    """

    # NOTE: field order defines the positional constructor of this module.
    params: PyTree[Float[Array, " P"]]
    fun_val: Scalar
    success: Scalar
    iter_num: Scalar
@requires_scipy
def scipy_minimize(
    fn: Callable[..., Scalar],
    init_params: PyTree[Float[Array, " P"]],
    lower_bound: PyTree[Float[Array, " P"]] | None = None,
    upper_bound: PyTree[Float[Array, " P"]] | None = None,
    method: str = "tnc",
    maxiter: int = 1000,
    *,
    scipy_options: dict[str, Any] | None = None,
    **fn_kwargs: Any,
) -> ScipyMinimizeState:
    """Scipy minimize wrapper that supports vmap via ``jax.pure_callback``.

    The function wraps a host-side ``jaxopt.ScipyBoundedMinimize`` run in
    ``jax.pure_callback`` (``vmap_method="sequential"``). This makes it usable
    under JAX transformations such as ``vmap`` and ``lax.map``.

    Parameters
    ----------
    fn : Callable
        Objective function to minimize. Called as ``fn(params, **fn_kwargs)``.
    init_params : PyTree
        Initial parameter values.
    lower_bound : PyTree, optional
        Lower bounds for parameters, with the same structure as
        ``init_params``. If only ``upper_bound`` is given, the lower side is
        treated as ``-inf``.
    upper_bound : PyTree, optional
        Upper bounds for parameters, with the same structure as
        ``init_params``. If only ``lower_bound`` is given, the upper side is
        treated as ``+inf``.
    method : str
        Scipy optimization method (default ``"tnc"``).
    maxiter : int
        Maximum number of iterations.
    scipy_options : dict, optional
        Method-specific options forwarded to scipy through
        ``ScipyBoundedMinimize(options=...)`` (e.g. ``ftol``, ``gtol``).
        They are merged over the default ``{"disp": False}``.
    **fn_kwargs
        Additional arguments passed to ``fn``. They cross the callback
        boundary as pytree arguments of ``jax.pure_callback``.

    Returns
    -------
    ScipyMinimizeState
        Optimization result containing params, fun_val, success and iter_num.

    Raises
    ------
    ImportError
        If ``cadre[scipy]`` optional dependencies are not installed, or if
        ``method == "cobyqa"`` and the ``cobyqa`` package cannot be imported.
    """
    from jaxopt import ScipyBoundedMinimize

    # Check optional dependencies at trace time. An error raised inside the
    # host callback would only surface later, wrapped by the callback
    # machinery.
    if method == "cobyqa":
        try:
            import cobyqa  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "cobyqa not installed. Install with: pip install jax-cadre[scipy]"
            ) from err

    # Explicit user options take precedence over the quiet default.
    solver_options = {"disp": False, **(scipy_options or {})}

    # pure_callback requires the host result to match these shapes/dtypes
    # exactly, so the host-side params are cast back to them below.
    params_spec = jax.tree.map(
        lambda x: jax.ShapeDtypeStruct(jnp.shape(x), jnp.result_type(x)),
        init_params,
    )

    def host_solver_callback(x_init, lower, upper, fn_kwargs):
        """Run the scipy solver on the host and return numpy results."""
        if lower is None and upper is None:
            bounds = None
        else:
            # A bounds tuple needs both sides; a missing side is unbounded.
            if lower is None:
                lower = jax.tree.map(lambda x: np.full(np.shape(x), -np.inf), x_init)
            if upper is None:
                upper = jax.tree.map(lambda x: np.full(np.shape(x), np.inf), x_init)
            bounds = (lower, upper)

        def scipy_fn(params, fn_kwargs):
            return fn(params, **fn_kwargs)

        solver = ScipyBoundedMinimize(
            fun=scipy_fn,
            method=method,
            jit=False,
            maxiter=maxiter,
            # Fresh copy per invocation so callbacks never share a dict.
            options=dict(solver_options),
        )
        res = solver.run(x_init, bounds=bounds, fn_kwargs=fn_kwargs)

        # Return numpy arrays with exactly the dtypes declared in result_shape.
        return {
            "params": jax.tree.map(
                lambda p, spec: np.asarray(p, dtype=spec.dtype),
                res.params,
                params_spec,
            ),
            "fun_val": np.array(res.state.fun_val, dtype=np.float32),
            "success": np.array(res.state.success, dtype=bool),
            "iter_num": np.array(res.state.iter_num, dtype=np.int32),
        }

    result_shape = {
        "params": params_spec,
        "fun_val": jax.ShapeDtypeStruct((), jnp.float32),
        "success": jax.ShapeDtypeStruct((), jnp.bool_),
        "iter_num": jax.ShapeDtypeStruct((), jnp.int32),
    }
    result_dict = jax.pure_callback(
        host_solver_callback,
        result_shape,
        init_params,
        lower_bound,
        upper_bound,
        fn_kwargs,
        vmap_method="sequential",
    )
    return ScipyMinimizeState(
        params=result_dict["params"],
        fun_val=result_dict["fun_val"],
        success=result_dict["success"],
        iter_num=result_dict["iter_num"],
    )


# =============================================================================
# UNIFIED STATE
# =============================================================================


class UnifiedState(eqx.Module):
    """Unified optimization state returned by :func:`minimize`.

    Attributes
    ----------
    best_loss : Scalar
        Best objective function value found. When preconditioning is active,
        this is the value of the conditioned (possibly output-scaled)
        objective.
    best_y : PyTree
        Best parameters found, mapped back to the original parameter space.
    iter_num : Scalar
        Number of iterations performed.
    solver_state : Any
        Internal solver state (Optimistix state or ScipyMinimizeState).
    """

    best_loss: Scalar
    best_y: PyTree[Float[Array, " P"]]
    iter_num: Scalar
    solver_state: Any


# =============================================================================
# UNIFIED OPTIMIZATION INTERFACE
# =============================================================================


def _scipy_tolerance_options(method: str, rtol: float, atol: float) -> dict[str, Any]:
    """Map the generic rtol/atol onto the option names of a scipy method."""
    if method == "tnc":
        return {"ftol": atol, "gtol": rtol, "xtol": atol}
    if method == "l-bfgs-b":
        return {"ftol": atol, "gtol": rtol}
    if method == "cobyqa":
        return {"final_tr_radius": atol}
    return {}


def minimize(
    fn: Callable[..., Scalar],
    init_params: PyTree[Float[Array, " P"]],
    solver_name: SOLVER_NAMES = "optax_lbfgs",
    max_iter: int = 1000,
    rtol: float = 1e-8,
    atol: float = 1e-8,
    lower_bound: PyTree[Float[Array, " P"]] | None = None,
    upper_bound: PyTree[Float[Array, " P"]] | None = None,
    precondition: bool = False,
    options: dict[str, Any] | None = None,
    refresh_steps: int = 10,
    **fn_kwargs: Any,
) -> tuple[PyTree[Float[Array, " P"]], UnifiedState]:
    """
    Unified optimization interface.

    Supports optax solvers, optimistix solvers (via ``optimistix.minimise``),
    and scipy solvers (via ``jaxopt.ScipyBoundedMinimize``, which requires
    ``cadre[scipy]``).

    Parameters
    ----------
    fn : Callable
        Objective function to minimize. Called as ``fn(params, **fn_kwargs)``.
    init_params : PyTree
        Initial parameter values.
    solver_name : str
        Solver identifier. See SOLVER_NAMES for available options.
    max_iter : int
        Maximum iterations.
    rtol, atol : float
        Relative/absolute tolerance for optimization convergence. For scipy
        solvers they are translated into method options: ``ftol``/``gtol``/
        ``xtol`` for TNC, ``ftol``/``gtol`` for L-BFGS-B, and
        ``final_tr_radius`` for COBYQA.
    lower_bound, upper_bound : PyTree, optional
        Box constraints.
    precondition : bool
        Whether to apply parameter transformation and output scaling.
        Ignored (with a warning) for self-conditioned solvers.
    options : dict, optional
        Extra arguments passed to the solver factory (get_solver). The dict
        is copied and never mutated. For scipy solvers, every key except
        ``cooldown`` and ``min_steps`` is also forwarded to scipy as a
        method option, overriding the tolerance-derived values above.
        For active-set solvers (``ADABK{N}`` family) the recognised keys are:

        * ``cooldown`` (int, default 20) — steps to suppress termination
          after a constraint release.
        * ``min_steps`` (int, default 10) — minimum iterations before
          termination is considered.
        * ``verbose_print`` (bool, default False) — print per-step debug
          info via ``jax.debug.print`` (JIT-compatible).
        * ``max_linesearch_steps`` (int, default 50) — maximum line-search
          steps per iteration (active-set and ``optax_lbfgs`` solvers).
        * ``linesearch`` (str) — linesearch variant for ``optax_lbfgs``
          (``"zoom"`` or ``"backtracking"``).
    refresh_steps : int
        Refresh interval for ``optx.TqdmProgressMeter`` (optimistix-type
        solvers only; ignored if the meter is unavailable).
    **fn_kwargs
        Additional arguments passed to fn.

    Returns
    -------
    final_params : PyTree
        Optimized parameters.
    final_state : UnifiedState
        Final optimizer state containing best loss, best parameters,
        iteration count, and solver state.

    Raises
    ------
    ValueError
        If get_solver returns a solver type other than "optimistix" or
        "scipy".
    """
    solver_name = cast(SOLVER_NAMES, solver_name)
    if solver_name in SELFCONDITIONED_SOLVERS and precondition:
        warning(f"Solver '{solver_name}' is self-conditioned; ignoring preconditioning request.")
        precondition = False

    if precondition:
        fn, to_opt, from_opt = condition(
            fn,
            lower=lower_bound,
            upper=upper_bound,
            scale_function=precondition,
            init_params=init_params,
            **fn_kwargs,
        )
        init_params = to_opt(init_params)
        lower_bound = to_opt(lower_bound) if lower_bound is not None else None
        upper_bound = to_opt(upper_bound) if upper_bound is not None else None
    else:
        from_opt = lambda x: x  # noqa: E731

    # Copy so the caller's ``options`` dict is never modified.
    _opts = dict(options or {})
    cooldown = _opts.get("cooldown", 20)
    min_steps = _opts.get("min_steps", 10)
    solver_kwargs = {k: v for k, v in _opts.items() if k not in ("cooldown", "min_steps")}
    solver, solver_type = get_solver(
        solver_name,
        rtol=rtol,
        atol=atol,
        lower=lower_bound,
        upper=upper_bound,
        cooldown=cooldown,
        min_steps=min_steps,
        **solver_kwargs,
    )

    if solver_type == "optimistix":
        # Optimistix uses the (y, args) signature; unpack args into fn kwargs.
        def optx_fn(y, fn_kwargs):
            return fn(y, **fn_kwargs)

        # TqdmProgressMeter may be missing from the installed optimistix.
        if hasattr(optx, "TqdmProgressMeter"):
            meter_kwargs = {"progress_meter": optx.TqdmProgressMeter(refresh_steps=refresh_steps)}
        else:
            meter_kwargs = {}
            warning("optx.TqdmProgressMeter not found. Progress meter disabled.")

        sol = optx.minimise(
            optx_fn,
            solver,
            init_params,
            max_steps=max_iter,
            throw=False,
            args=fn_kwargs,
            **meter_kwargs,
        )
        unified_state = UnifiedState(
            best_loss=sol.state.best_loss,
            best_y=from_opt(sol.state.best_y),
            iter_num=sol.stats["num_steps"],
            solver_state=sol.state,
        )
        return from_opt(sol.value), unified_state

    if solver_type == "scipy":
        # "scipy_<method>" -> "<method>"; maxsplit keeps any later underscores.
        method = solver_name.split("_", 1)[1]
        # Tolerance-derived defaults first; explicit user options override.
        scipy_options = {**_scipy_tolerance_options(method, rtol, atol), **solver_kwargs}
        state = scipy_minimize(
            fn=fn,
            init_params=init_params,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            method=method,
            maxiter=max_iter,
            scipy_options=scipy_options,
            **fn_kwargs,
        )
        unified_state = UnifiedState(
            best_loss=state.fun_val,
            best_y=from_opt(state.params),
            iter_num=state.iter_num,
            solver_state=state,
        )
        return from_opt(state.params), unified_state

    raise ValueError(f"Unknown solver type: {solver_type}")