Source code for furax_cs.data.search_space
from __future__ import annotations
from pathlib import Path
from typing import Any, Union
import jax.numpy as jnp
import yaml
from jaxtyping import (
Array,
Int,
)
[docs]
def load_search_space(filepath: Union[str, Path] | None = None) -> dict[str, Int[Array, " N"]]:
    """Load a search space configuration from a YAML file.

    Args:
        filepath: Path to a custom search space YAML file. If None, the
            ``search_spaces_default.yaml`` file located in the same directory
            as this module is loaded instead.

    Returns:
        Dictionary produced by ``search_space_to_jax`` holding JAX arrays for
        T_d_patches, B_d_patches and B_s_patches.

    Raises:
        FileNotFoundError: If the resolved path does not exist.
        ValueError: If the YAML document is not a mapping (for example an
            empty file, a bare list or a scalar) or lacks a required key.

    Example:
        >>> space = load_search_space()
        >>> print(space["T_d_patches"])
    """
    if filepath is None:
        # Fall back to the default configuration shipped next to this module.
        filepath = Path(__file__).parent / "search_spaces_default.yaml"
    else:
        filepath = Path(filepath)

    if not filepath.exists():
        raise FileNotFoundError(f"Search space file not found: {filepath}")

    # Explicit encoding so parsing does not depend on the platform locale.
    with open(filepath, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # safe_load yields None for an empty document and may yield a list or a
    # scalar; reject those up front so callers get the documented ValueError
    # instead of a TypeError from the membership test below.
    if not isinstance(config, dict):
        raise ValueError(
            f"Search space YAML must contain a mapping at the top level, "
            f"got {type(config).__name__}: {filepath}"
        )

    # Validate required keys
    required_keys = ["T_d_patches", "B_d_patches", "B_s_patches"]
    missing_keys = [key for key in required_keys if key not in config]
    if missing_keys:
        raise ValueError(
            f"Search space YAML missing required keys: {missing_keys}. "
            f"Required keys: {required_keys}"
        )

    # Convert to JAX arrays
    return search_space_to_jax(config)
[docs]
def search_space_to_jax(config: dict[str, Any]) -> dict[str, Int[Array, " N"]]:
    """Convert a search space configuration into int32 JAX arrays.

    Only the keys ``T_d_patches``, ``B_d_patches`` and ``B_s_patches`` are
    read. A key that is absent from ``config`` is simply omitted from the
    result, and any other key in ``config`` is ignored. Each value is handed
    directly to ``jnp.array``; no range/dict specification is expanded here.

    Args:
        config: Dictionary with search space parameters (e.g. lists of ints
            as parsed from YAML).

    Returns:
        Dictionary mapping each recognised key present in ``config`` to a
        ``jnp.int32`` array, in the order T_d_patches, B_d_patches,
        B_s_patches.

    Example:
        >>> config = {"T_d_patches": [1, 2], "B_d_patches": [3, 4], "B_s_patches": [5]}
        >>> jax_space = search_space_to_jax(config)
    """
    # All three parameters are converted the same way, so one pass replaces
    # the former per-key branches. The tuple order fixes the key order of the
    # returned dictionary.
    patch_keys = ("T_d_patches", "B_d_patches", "B_s_patches")
    return {
        key: jnp.array(config[key], dtype=jnp.int32)
        for key in patch_keys
        if key in config
    }
[docs]
def dump_default_search_space(output_path: Union[str, Path]) -> None:
    """Write a copy of the default search space YAML to ``output_path``.

    The copy serves as a template that users can edit to define their own
    search space. Two confirmation lines are printed to stdout on success.

    Args:
        output_path: Destination path for the copied YAML file.

    Raises:
        FileNotFoundError: If ``search_spaces_default.yaml`` is not present in
            the directory containing this module.
        IOError: If the output file cannot be written.

    Example:
        >>> dump_default_search_space("my_search_space.yaml")
    """
    destination = Path(output_path)

    # The default configuration lives alongside this module.
    template = Path(__file__).parent / "search_spaces_default.yaml"
    if not template.exists():
        raise FileNotFoundError(
            f"Default search space file not found at {template}. "
            "This should not happen - please check the package installation."
        )

    # Read the template fully before touching the destination, then write it out.
    contents = template.read_text()
    destination.write_text(contents)

    print(f"Default search space configuration saved to: {destination}")
    print("You can now edit this file to customize the search space.")
[docs]
def validate_search_space(search_space: dict[str, Int[Array, " N"]]) -> None:
    """Check that a search space has the required keys and usable values.

    For each of T_d_patches, B_d_patches and B_s_patches (in that order) the
    entry must exist, be an ``Array`` instance, be non-empty, and contain no
    value below 1. The first failed check raises.

    Args:
        search_space: Dictionary with JAX arrays for search space parameters.

    Raises:
        ValueError: If a key is missing, an entry is not an array, an array
            is empty, or an array contains a value smaller than 1.

    Example:
        >>> validate_search_space(search_space)
    """
    for name in ("T_d_patches", "B_d_patches", "B_s_patches"):
        if name not in search_space:
            raise ValueError(f"Search space missing required key: {name}")

        values = search_space[name]
        # Checks run from cheapest to most specific; only the first failure
        # is reported for a given key.
        if not isinstance(values, Array):
            problem = f"{name} must be an array, got {type(values)}"
        elif values.size == 0:
            problem = f"{name} cannot be empty"
        elif jnp.any(values < 1):
            problem = f"{name} values must be >= 1, got minimum {jnp.min(values)}"
        else:
            continue
        raise ValueError(problem)