Minimization Solvers¶
furax-cs exposes a unified interface for various minimization solvers via the
CADRE package (Constraint-Aware Descent Routine Executor),
which provides the underlying implementations for Optax, Optimistix, and SciPy solvers.
Available Solvers¶
The solver_name argument in the minimize function accepts the following:
Recommended¶
ADABK0: Best for noisy maps. Active-set method with AdaBelief direction and Top-K constraint release (K=0, i.e. one constraint released per iteration). Very robust in low-SNR regions. See How ADABK Works below.optax_lbfgs: Best for noiseless runs (systematics). L-BFGS with zoom linesearch (Strong Wolfe conditions). Very fast and accurate for smooth, noise-free landscapes.
Other Options¶
Active set variants (self-conditioned):
ADABK{N}— AdaBelief + Top-K active set.N * 0.1= fraction of constraints released per step.ADABK0releases 1 constraint/step (most stable),ADABK5releases up to 50%. (see How ADABK Works and paper for more info)active_set— Active set with Adam direction.active_set_sgd— Active set with SGD direction.active_set_adabelief— Active set with AdaBelief direction.active_set_adaw— Active set with AdamW direction.
Optax L-BFGS:
optax_lbfgs— L-BFGS with zoom linesearch (default) or backtracking.
Optax first-order:
adam— Adam optimizer.sgd— SGD with backtracking linesearch.adabelief— AdaBelief optimizer.adaw/adamw— AdamW optimizer.
Optimistix:
optimistix_bfgs— Full BFGS.optimistix_lbfgs— Limited-memory BFGS.optimistix_ncg_pr— Nonlinear Conjugate Gradient (Polak-Ribière).optimistix_ncg_hs— Nonlinear Conjugate Gradient (Hestenes-Stiefel).optimistix_ncg_fr— Nonlinear Conjugate Gradient (Fletcher-Reeves).optimistix_ncg_dy— Nonlinear Conjugate Gradient (Dai-Yuan).
SciPy (self-conditioned):
scipy_tnc— Truncated Newton (TNC).scipy_cobyqa— COBYQA (derivative-free constrained optimizer).
How ADABK Works¶
ADABK (Adaptive AdaBelief with Top-K Active Set, also called AdaTopK in the paper) is a JAX-native optimizer that combines the TNC active-set constraint strategy with the AdaBelief adaptive gradient method.
Internal parameter space¶
Physical parameters x (bounded by l, u) are mapped to a normalized [0, 1] representation via an affine transform:
y = (x − l) / (u − l)
This normalizes the optimization landscape and ensures consistent step sizes across parameters with different physical scales.
Active set and pivot vector¶
Each parameter has a pivot value p_i:
p_i = −1: parameter is at the lower bound (active constraint)
p_i = +1: parameter is at the upper bound (active constraint)
p_i = 0: parameter is free
Only free parameters (p_i = 0) are optimized at each iteration.
Top-K constraint release¶
At each iteration, a release score is computed for every active constraint:
score_i = p_i × (−g_i)
A positive score means the negative gradient points into the feasible region — releasing this constraint could decrease the objective. The Top-K fraction K controls how many constraints are released per iteration:
K = 0 (
ADABK0): releases 1 constraint at a time. Most stable, consistently reaches the lowest objective values.K = N (
ADABK{N}): releases up toN × 0.1fraction of active constraints.
Projected gradient and AdaBelief direction¶
Gradients for active constraints are zeroed out: g_proj = g ⊙ (p = 0). The projected gradient is then fed to AdaBelief, which adapts step sizes based on gradient variance. This makes it better suited to noisy gradient landscapes (low-SNR regions) than classical quasi-Newton methods (L-BFGS, TNC) which tend to reset their curvature history when gradients are unreliable.
Dynamic state rescaling¶
When the gradient norm falls outside [10⁻¹⁵, 10¹⁵], the cost function and AdaBelief moment estimates are rescaled:
m ← f_scale · m , v ← f_scale² · v
This prevents numerical under/overflow across the extreme dynamic range between the bright Galactic plane and faint high-latitude sky, without resetting the optimizer’s momentum.
Bounded line search¶
The step size α is capped at the distance to the nearest bound (α_max), then a line search finds the optimal α in [0, α_max]. If a parameter hits a bound, it becomes an active constraint.
Termination Control (ADABK solvers)¶
Four parameters control when ADABK solvers decide to stop. Pass them via the options dict of minimize():
final_params, state = minimize(
fn=my_loss_fn,
init_params=params,
solver_name="ADABK0",
lower_bound=lower,
upper_bound=upper,
options={
"cooldown": 50, # default: 20
"min_steps": 200, # default: 10
"verbose_print": True, # default: False
"max_linesearch_steps": 100, # default: 50
},
)
Parameter |
Default |
Description |
|---|---|---|
|
|
Steps to suppress termination after a constraint is released. Prevents premature convergence caused by a transient function spike when a bound constraint opens. |
|
|
Minimum iterations before termination is ever considered. Useful when the initial gradient is near zero but the landscape is not yet explored. |
|
|
Print per-step diagnostics: current |
|
|
Maximum bounded line-search steps per iteration. |
How termination is decided¶
Termination requires all of the following to hold simultaneously:
f_diff = |f_current − f_prev| < atol + rtol × max(1, |best_f|)— spike-immune f-change checkCauchy-convergence in y-space (base Optimistix check)
Step count ≥
min_stepsNot inside the cooldown window after the last constraint release
CLI equivalents¶
When using kmeans-model or ptep-model, the same parameters are available as flags:
kmeans-model ... --cooldown 50 --min-steps 200 --verbose
Conditioning¶
Conditioning (preconditioning) transforms the optimization problem to improve convergence. It applies two transformations before optimization:
Parameter scaling: min-max normalization to [0, 1] based on bounds.
Gradient scaling: the objective is scaled by 1/‖∇f‖ at initialization (like SciPy TNC’s
fscale), so the initial gradient norm is ≈ 1.
Self-conditioned solvers¶
These solvers handle conditioning internally and ignore the precondition flag:
Active set variants (
active_set,active_set_sgd,active_set_adabelief,active_set_adaw,ADABK{N}) — use internal affine transform + dynamic state rescaling.SciPy solvers (
scipy_tnc,scipy_cobyqa) — SciPy handles bounds and scaling internally.
Externally conditioned solvers¶
All other solvers (optax_lbfgs, adam, optimistix_*, etc.) benefit from external conditioning when dealing with poorly scaled problems. Pass precondition=True (or a custom scaling function) to minimize.
Minimizing Programmatically¶
Single Run¶
from cadre import minimize
final_params, state = minimize(
fn=my_loss_fn,
init_params=params,
solver_name="ADABK0", # or "optax_lbfgs"
lower_bound=lower,
upper_bound=upper,
**kwargs
)
Advanced: Stepping Interactively with Solvers¶
Since most solvers (except SciPy) are JAX-compatible, you can step through the optimization process manually. This is useful for custom logging or adaptive strategies.
Here is an example of running the same optimization problem with multiple solvers programmatically using get_solver.
import jax
import jax.numpy as jnp
import optimistix as optx
from cadre.solvers import get_solver
from functools import partial
# Define a simple quadratic loss function
def simple_loss(params, target):
return jnp.sum((params - target) ** 2) , None
target = jnp.array([1.0, 2.0, 3.0])
init_params = jnp.array([0.0, 0.0, 0.0])
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
# List of solvers to test
print(f"{'Loss':<12} | {'Step':<5} | {'Status':<10} | {'Final Params':<30}")
print("-" * 55)
# active_set or optax_lbfgs for example
solver, _ = get_solver("optax_lbfgs", rtol=1e-6, atol=1e-6)
state = solver.init(simple_loss, init_params, target ,{} , f_struct , None , frozenset())
step = partial(solver.step, fn=simple_loss, args=target, options={}, tags=frozenset())
terminate = partial(solver.terminate, fn=simple_loss, args=target, options={}, tags=frozenset())
y = init_params
for i in range(100):
y , state , aux = step(y=y, state=state)
done, result = terminate(y=y, state=state)
loss = state.state.f
print(f"{loss:<12.6f} | {i+1:<5} | {'Done' if done else 'In Progress':<10} | {y}")
if done:
break
y, _, _ = solver.postprocess(simple_loss, y, None, target, {}, state, frozenset(), result)