from __future__ import annotations
from typing import Any, Literal, TypeAlias, Union
import jax
import jax.numpy as jnp
import optax
import optimistix as optx
from jaxtyping import Array, Bool, Float, PyTree
from optax._src import combine, transform
from optax._src import linesearch as _linesearch
from .active_set import active_set
# =============================================================================
# OFF-THE-SHELF L-BFGS SOLVERS
# =============================================================================
Solver: TypeAlias = Union[optx.BestSoFarMinimiser, str]
class ActiveSetMinimiser(optx.OptaxMinimiser):
cooldown_steps: int
min_steps: int
verbose_print: bool
def __init__(
self, optim, atol, rtol, cooldown_steps=20, min_steps=200, verbose_print=False, **kwargs
):
super().__init__(optim, atol=atol, rtol=rtol, **kwargs)
self.cooldown_steps = cooldown_steps
self.min_steps = min_steps
self.verbose_print = verbose_print
def terminate(
self,
fn: Any,
y: PyTree,
args: PyTree,
options: dict[str, Any],
state: Any,
tags: frozenset[object],
) -> tuple[Bool[Array, ""], optx.RESULTS]:
del fn, args, options
ast = state.opt_state # ActiveSetState
# Robust f-check: scale by best_f, not current f (spike-immune)
scale = jnp.maximum(1.0, jnp.abs(ast.best_f))
f_diff = jnp.abs(ast.f_val - ast.prev_f)
f_converged = f_diff < self.atol + self.rtol * scale
# Require BOTH robust f-check AND cauchy y-space convergence
converged = f_converged & state.terminate
# Override: don't terminate before min_steps
too_early = ast.count < self.min_steps
# Override: don't terminate during cooldown after a constraint release
# (last_release_step == -1 means no release has happened yet)
steps_since_release = ast.count - ast.last_release_step
in_cooldown = (ast.last_release_step >= 0) & (steps_since_release < self.cooldown_steps)
override = ast.constraints_released | in_cooldown | too_early
terminate = jnp.where(override, False, converged)
if self.verbose_print:
jax.debug.print(
"step={s} | f={f:.4e} best_f={bf:.4e} f_diff={fd:.4e} scale={sc:.4e} | "
"f_conv={fc} cooldown={cd} too_early={te} released={rel} cauchy={cau} => terminate={t}",
s=ast.count,
f=ast.f_val,
bf=ast.best_f,
fd=f_diff,
sc=scale,
fc=f_converged,
cd=in_cooldown,
te=too_early,
rel=ast.constraints_released,
cau=state.terminate,
t=terminate,
)
return terminate, optx.RESULTS.successful
def lbfgs_zoom(
learning_rate: optax.ScalarOrSchedule | None = None,
memory_size: int = 10,
scale_init_precond: bool = False,
max_linesearch_steps: int = 200,
initial_guess_strategy: str = "one",
slope_rtol: float = 1e-4,
curv_rtol: float = 0.9,
verbose: bool = False,
lower: PyTree[Float[Array, " P"]] | None = None,
upper: PyTree[Float[Array, " P"]] | None = None,
) -> optax.GradientTransformation:
"""L-BFGS with zoom linesearch (strong Wolfe conditions).
This is the standard L-BFGS with zoom linesearch that enforces both:
- Sufficient decrease (Armijo): f(x + η*d) ≤ f(x) + c1*η*∇f(x)ᵀd
- Curvature condition: |∇f(x + η*d)ᵀd| ≤ c2*|∇f(x)ᵀd|
Args:
learning_rate: Optional global scaling factor.
memory_size: Number of past updates for Hessian approximation.
scale_init_precond: Whether to scale initial Hessian approximation.
WARNING: Set to False for numerically sensitive problems.
max_linesearch_steps: Maximum iterations for zoom linesearch.
initial_guess_strategy: "one" (start at η=1) or "keep" (use previous).
slope_rtol: c1 parameter for Armijo condition (default 1e-4).
curv_rtol: c2 parameter for curvature condition (default 0.9).
verbose: Print linesearch debugging info.
lower: Optional lower bounds for box projection (pytree).
upper: Optional upper bounds for box projection (pytree).
Returns:
An optax GradientTransformation.
"""
if learning_rate is None:
base_scaling = transform.scale(-1.0)
else:
base_scaling = optax.scale_by_learning_rate(learning_rate)
linesearch = _linesearch.scale_by_zoom_linesearch(
max_linesearch_steps=max_linesearch_steps,
initial_guess_strategy=initial_guess_strategy,
slope_rtol=slope_rtol,
curv_rtol=curv_rtol,
verbose=verbose,
)
chain_components = [
transform.scale_by_lbfgs(
memory_size=memory_size,
scale_init_precond=scale_init_precond,
),
base_scaling,
linesearch,
]
# Add projection if bounds provided
if lower is not None and upper is not None:
chain_components.append(apply_projection(lower, upper))
return combine.chain(*chain_components)
def lbfgs_backtrack(
learning_rate: optax.ScalarOrSchedule | None = None,
memory_size: int = 10,
scale_init_precond: bool = False,
max_backtracking_steps: int = 200,
slope_rtol: float = 1e-4,
decrease_factor: float = 0.8,
increase_factor: float = 1.5,
max_learning_rate: float = 1.0,
verbose: bool = False,
lower: PyTree[Float[Array, " P"]] | None = None,
upper: PyTree[Float[Array, " P"]] | None = None,
) -> optax.GradientTransformation:
"""L-BFGS with backtracking linesearch (Armijo condition only).
Simpler than zoom linesearch, only enforces sufficient decrease:
- Armijo: f(x + η*d) ≤ f(x) + c1*η*∇f(x)ᵀd
Args:
learning_rate: Optional global scaling factor.
memory_size: Number of past updates for Hessian approximation.
scale_init_precond: Whether to scale initial Hessian approximation.
WARNING: Set to False for numerically sensitive problems.
max_backtracking_steps: Maximum backtracking iterations.
slope_rtol: c1 parameter for Armijo condition (default 1e-4).
decrease_factor: Multiply stepsize by this when condition fails (default 0.8).
increase_factor: Initial guess = previous * this factor (default 1.5).
max_learning_rate: Upper bound on stepsize (default 1.0).
verbose: Print linesearch debugging info.
lower: Optional lower bounds for box projection (pytree).
upper: Optional upper bounds for box projection (pytree).
Returns:
An optax GradientTransformation.
"""
if learning_rate is None:
base_scaling = transform.scale(-1.0)
else:
base_scaling = optax.scale_by_learning_rate(learning_rate)
linesearch = _linesearch.scale_by_backtracking_linesearch(
max_backtracking_steps=max_backtracking_steps,
slope_rtol=slope_rtol,
decrease_factor=decrease_factor,
increase_factor=increase_factor,
max_learning_rate=max_learning_rate,
verbose=verbose,
)
chain_components = [
transform.scale_by_lbfgs(
memory_size=memory_size,
scale_init_precond=scale_init_precond,
),
base_scaling,
linesearch,
]
# Add projection if bounds provided
if lower is not None and upper is not None:
chain_components.append(apply_projection(lower, upper))
return combine.chain(*chain_components)
def backtracking_adam(
max_backtracking_steps: int = 200,
slope_rtol: float = 1e-4,
decrease_factor: float = 0.8,
increase_factor: float = 1.5,
max_learning_rate: float = 1.0,
verbose: bool = False,
lower: PyTree[Float[Array, " P"]] | None = None,
upper: PyTree[Float[Array, " P"]] | None = None,
) -> optax.GradientTransformation:
"""Adam with backtracking linesearch (Armijo condition only)."""
linesearch = _linesearch.scale_by_backtracking_linesearch(
max_backtracking_steps=max_backtracking_steps,
slope_rtol=slope_rtol,
decrease_factor=decrease_factor,
increase_factor=increase_factor,
max_learning_rate=max_learning_rate,
verbose=verbose,
)
chain_components = [
optax.adam(learning_rate=1.0), # Learning rate handled by linesearch
linesearch,
]
# Add projection if bounds provided
if lower is not None and upper is not None:
chain_components.append(apply_projection(lower, upper))
return combine.chain(*chain_components)
# =============================================================================
# BOX PROJECTION TRANSFORMATION
# =============================================================================
def apply_projection(
lower: PyTree[Float[Array, " P"]] | None = None,
upper: PyTree[Float[Array, " P"]] | None = None,
) -> optax.GradientTransformation:
"""Wrap box projection into a GradientTransformation.
After applying this transformation, params + updates will be within [lower, upper].
The update rule: u_new = clip(p + u, lower, upper) - p
This can be chained with optimizers like:
optimizer = optax.chain(
optax.adam(learning_rate=1e-3),
apply_projection(lower={'w': 0.0}, upper={'w': 1.0})
)
Args:
lower: Lower bounds (pytree matching params structure)
upper: Upper bounds (pytree matching params structure)
Returns:
GradientTransformation that projects updates to keep params in bounds.
"""
def init_fn(params: PyTree[Float[Array, " P"]]) -> optax.EmptyState:
del params
return optax.EmptyState()
def update_fn(
updates: PyTree[Float[Array, " P"]],
state: optax.EmptyState,
params: PyTree[Float[Array, " P"]] | None = None,
) -> tuple[PyTree[Float[Array, " P"]], optax.EmptyState]:
if params is None:
raise ValueError("NO_PARAMS_MSG")
if lower is None or upper is None:
return updates, state
def process_leaf(
p: Float[Array, " P"],
u: Float[Array, " P"],
lo: Float[Array, " P"],
hi: Float[Array, " P"],
) -> Float[Array, " P"]:
if p is None or u is None:
return u
tentative = p + u
projected = jnp.clip(tentative, lo, hi)
return projected - p
new_updates = jax.tree.map(process_leaf, params, updates, lower, upper)
return new_updates, state
return optax.GradientTransformation(init_fn, update_fn)
# =============================================================================
# SOLVER NAMES AND FACTORY
# =============================================================================
SOLVER_NAMES = Literal[
# Optax L-BFGS (jax_grid_search compatible)
"optax_lbfgs",
"optax_lbfgs",
"adam",
"sgd",
"adabelief",
"adaw",
"active_set",
"active_set_sgd",
"active_set_adabelief",
"active_set_adaw",
# Optimistix BFGS
"optimistix_bfgs",
# Optimistix L-BFGS
"optimistix_lbfgs",
# Optimistix NCG (Armijo)
"optimistix_ncg_pr",
"optimistix_ncg_hs",
"optimistix_ncg_fr",
"optimistix_ncg_dy",
# Scipy
"scipy_tnc",
"scipy_cobyqa",
# Legacy aliases
"zoom",
"backtrack",
]
SELFCONDITIONED_SOLVERS = {"active_set", "active_set_sgd", "scipy_tnc", "scipy_cobyqa"}
[docs]
def get_solver(
solver_name: SOLVER_NAMES,
rtol: float = 1e-8,
atol: float = 1e-8,
learning_rate: float = 1e-3,
max_linesearch_steps: int = 50,
lower: PyTree[Float[Array, " P"]] | None = None,
upper: PyTree[Float[Array, " P"]] | None = None,
verbose_print: bool = False,
min_steps: int = 10,
cooldown: int = 20,
**kwargs: Any,
) -> tuple[Solver, Literal["optimistix", "scipy"]]:
"""
Create a solver instance from a name string.
Parameters
----------
solver_name : str
Solver identifier. See SOLVER_NAMES for available options.
rtol : float
Relative tolerance for optimistix solvers.
atol : float
Absolute tolerance for optimistix solvers.
learning_rate : float
Learning rate for adam solver.
max_linesearch_steps : int
Maximum linesearch steps for L-BFGS solvers.
lower : PyTree, optional
Lower bounds for box projection (optax solvers only).
upper : PyTree, optional
Upper bounds for box projection (optax solvers only).
verbose_print : bool
If True, print per-step termination diagnostics for active-set
solvers via ``jax.debug.print`` (JIT-compatible).
min_steps : int
Minimum iterations before termination is considered
(active-set solvers only).
cooldown : int
Steps to suppress termination after a constraint release
(active-set solvers only).
max_linesearch_steps : int
Maximum line-search steps per iteration (active-set and
``optax_lbfgs`` solvers).
Returns
-------
solver : Solver can be either a BestSoFar wrapped minimiser or a string for scipy.
The solver instance.
solver_type : str
One of "optimistix", "scipy".
"""
# Resolve aliases
# Optax solvers (with optional box projection)
if solver_name == "optax_lbfgs":
linesearch_type = kwargs.pop("linesearch", "zoom")
if linesearch_type == "zoom":
return optx.BestSoFarMinimiser(
optx.OptaxMinimiser(
lbfgs_zoom(
max_linesearch_steps=max_linesearch_steps,
lower=lower,
upper=upper,
**kwargs,
),
atol=atol,
rtol=rtol,
)
), "optimistix"
elif linesearch_type == "backtracking":
return optx.BestSoFarMinimiser(
optx.OptaxMinimiser(
lbfgs_backtrack(
max_backtracking_steps=max_linesearch_steps,
lower=lower,
upper=upper,
**kwargs,
),
atol=atol,
rtol=rtol,
)
), "optimistix"
else:
raise ValueError(
f"Unknown linesearch type: {linesearch_type}. Use 'backtracking' or 'zoom'."
)
elif solver_name == "adam":
# Chain adam with projection if bounds provided
# learning_rate from kwargs takes precedence over function parameter
lr = kwargs.pop("learning_rate", learning_rate)
adam_opt = optax.adam(learning_rate=lr, **kwargs)
if lower is not None and upper is not None:
adam_opt = combine.chain(adam_opt, apply_projection(lower, upper))
return optx.BestSoFarMinimiser(
optx.OptaxMinimiser(adam_opt, atol=atol, rtol=rtol)
), "optimistix"
elif solver_name == "sgd":
# Chain sgd with projection if bounds provided
# learning_rate from kwargs takes precedence (default 1.0 for linesearch use)
lr = kwargs.pop("learning_rate", 1.0)
direction = optax.sgd(learning_rate=lr)
# Keep your line search
linesearch = _linesearch.scale_by_backtracking_linesearch(
max_backtracking_steps=max_linesearch_steps
)
if lower is not None and upper is not None:
sgd_opt = combine.chain(direction, linesearch, apply_projection(lower, upper))
else:
sgd_opt = combine.chain(direction, linesearch)
return optx.BestSoFarMinimiser(
optx.OptaxMinimiser(sgd_opt, atol=atol, rtol=rtol)
), "optimistix"
elif solver_name == "adabelief":
lr = kwargs.pop("learning_rate", learning_rate)
opt = optax.adabelief(learning_rate=lr)
if lower is not None and upper is not None:
opt = combine.chain(opt, apply_projection(lower, upper))
return optx.BestSoFarMinimiser(optx.OptaxMinimiser(opt, atol=atol, rtol=rtol)), "optimistix"
elif solver_name == "adaw" or solver_name == "adamw":
lr = kwargs.pop("learning_rate", learning_rate)
opt = optax.adamw(learning_rate=lr, **kwargs)
if lower is not None and upper is not None:
opt = combine.chain(opt, apply_projection(lower, upper))
return optx.BestSoFarMinimiser(optx.OptaxMinimiser(opt, atol=atol, rtol=rtol)), "optimistix"
elif solver_name == "active_set":
# Default configuration for active set: Adam + configurable linesearch
# Extract learning_rate and linesearch options
lr = kwargs.pop("learning_rate", 1.0)
linesearch_type = kwargs.pop("linesearch", "backtracking")
direction = optax.adam(learning_rate=lr)
if linesearch_type == "backtracking":
linesearch = _linesearch.scale_by_backtracking_linesearch(
max_backtracking_steps=max_linesearch_steps
)
elif linesearch_type == "zoom":
linesearch = _linesearch.scale_by_zoom_linesearch(
max_linesearch_steps=max_linesearch_steps
)
else:
raise ValueError(
f"Unknown linesearch type: {linesearch_type}. Use 'backtracking' or 'zoom'."
)
return optx.BestSoFarMinimiser(
ActiveSetMinimiser(
active_set(
direction,
linesearch,
lower=lower,
upper=upper,
verbose_print=verbose_print,
**kwargs,
),
atol=atol,
rtol=rtol,
min_steps=min_steps,
cooldown_steps=cooldown,
verbose_print=verbose_print,
)
), "optimistix"
elif solver_name == "active_set_sgd":
# Default configuration for active set SGD: SGD + configurable linesearch
# Extract learning_rate and linesearch options
lr = kwargs.pop("learning_rate", 1.0)
linesearch_type = kwargs.pop("linesearch", "backtracking")
direction = optax.sgd(learning_rate=lr)
if linesearch_type == "backtracking":
linesearch = _linesearch.scale_by_backtracking_linesearch(
max_backtracking_steps=max_linesearch_steps
)
elif linesearch_type == "zoom":
linesearch = _linesearch.scale_by_zoom_linesearch(
max_linesearch_steps=max_linesearch_steps
)
else:
raise ValueError(
f"Unknown linesearch type: {linesearch_type}. Use 'backtracking' or 'zoom'."
)
return optx.BestSoFarMinimiser(
ActiveSetMinimiser(
active_set(
direction,
linesearch,
lower=lower,
upper=upper,
verbose_print=verbose_print,
**kwargs,
),
atol=atol,
rtol=rtol,
min_steps=min_steps,
cooldown_steps=cooldown,
verbose_print=verbose_print,
)
), "optimistix"
elif solver_name == "active_set_adabelief" or solver_name.startswith("ADABK"):
lr = kwargs.pop("learning_rate", 1.0)
linesearch_type = kwargs.pop("linesearch", "zoom")
max_constraints_to_release = kwargs.pop("max_constraints_to_release", None)
if max_constraints_to_release is None:
# check int in ADABKN as in ADABK5 for example
if solver_name.startswith("ADABK") and len(solver_name) > 5:
try:
max_constraints_to_release = int(solver_name[5:]) * 0.1
except ValueError:
raise ValueError(
f"Invalid solver name: {solver_name}. "
f"When using 'ADABK' prefix, it should be followed by an integer."
)
direction = optax.adabelief(learning_rate=lr)
if linesearch_type == "backtracking":
linesearch = _linesearch.scale_by_backtracking_linesearch(
max_backtracking_steps=max_linesearch_steps
)
elif linesearch_type == "zoom":
linesearch = _linesearch.scale_by_zoom_linesearch(
max_linesearch_steps=max_linesearch_steps
)
else:
raise ValueError(
f"Unknown linesearch type: {linesearch_type}. Use 'backtracking' or 'zoom'."
)
return optx.BestSoFarMinimiser(
ActiveSetMinimiser(
active_set(
direction,
linesearch,
lower=lower,
upper=upper,
max_constraints_to_release=max_constraints_to_release,
verbose_print=verbose_print,
**kwargs,
),
atol=atol,
rtol=rtol,
min_steps=min_steps,
cooldown_steps=cooldown,
verbose_print=verbose_print,
)
), "optimistix"
elif solver_name == "active_set_adaw":
lr = kwargs.pop("learning_rate", 1.0)
linesearch_type = kwargs.pop("linesearch", "backtracking")
direction = optax.adamw(learning_rate=lr)
if linesearch_type == "backtracking":
linesearch = _linesearch.scale_by_backtracking_linesearch(
max_backtracking_steps=max_linesearch_steps
)
elif linesearch_type == "zoom":
linesearch = _linesearch.scale_by_zoom_linesearch(
max_linesearch_steps=max_linesearch_steps
)
else:
raise ValueError(
f"Unknown linesearch type: {linesearch_type}. Use 'backtracking' or 'zoom'."
)
return optx.BestSoFarMinimiser(
ActiveSetMinimiser(
active_set(
direction,
linesearch,
lower=lower,
upper=upper,
verbose_print=verbose_print,
**kwargs,
),
atol=atol,
rtol=rtol,
min_steps=min_steps,
cooldown_steps=cooldown,
verbose_print=verbose_print,
)
), "optimistix"
# Optimistix BFGS
elif solver_name == "optimistix_bfgs":
return optx.BestSoFarMinimiser(optx.BFGS(rtol=rtol, atol=atol, **kwargs)), "optimistix"
# Optimistix L-BFGS
elif solver_name == "optimistix_lbfgs":
return optx.BestSoFarMinimiser(optx.LBFGS(rtol=rtol, atol=atol, **kwargs)), "optimistix"
# Optimistix NCG (Armijo)
elif solver_name == "optimistix_ncg_pr":
return optx.BestSoFarMinimiser(
optx.NonlinearCG(rtol=rtol, atol=atol, method=optx.polak_ribiere, **kwargs)
), "optimistix"
elif solver_name == "optimistix_ncg_hs":
return optx.BestSoFarMinimiser(
optx.NonlinearCG(rtol=rtol, atol=atol, method=optx.hestenes_stiefel, **kwargs)
), "optimistix"
elif solver_name == "optimistix_ncg_fr":
return optx.BestSoFarMinimiser(
optx.NonlinearCG(rtol=rtol, atol=atol, method=optx.fletcher_reeves, **kwargs)
), "optimistix"
elif solver_name == "optimistix_ncg_dy":
return optx.BestSoFarMinimiser(
optx.NonlinearCG(rtol=rtol, atol=atol, method=optx.dai_yuan, **kwargs)
), "optimistix"
# Scipy (returns string — handled by minimize() via scipy_minimize)
elif solver_name == "scipy_tnc":
return "scipy_tnc", "scipy"
elif solver_name == "scipy_cobyqa":
return "scipy_cobyqa", "scipy"
else:
raise ValueError(f"Unknown solver: {solver_name}. Available: {list(SOLVER_NAMES.__args__)}")