desisky provides machine learning models and tools for DESI sky modeling:
- Predictive broadband model — Predicts surface brightness in V, g, r, and z photometric bands from observational metadata (moon position, transparency, eclipse fraction)
- Variational Autoencoder (VAE) — Compresses sky spectra (7,781 wavelength bins → 8-dimensional latent space) for analysis, anomaly detection, and dimensionality reduction. Trained with InfoVAE-MMD objective
- Latent Diffusion Models (LDM) — Generates realistic night-sky emission spectra using EDM preconditioning (Karras et al. 2022), conditioned on observational parameters:
- LDM Dark — Dark-time spectra conditioned on 8 features: sun position, transparency, galactic/ecliptic coordinates, and solar flux
- LDM Moon — Moon-contaminated spectra conditioned on 6 features: moon position, separation, and illumination fraction
- LDM Twilight — Twilight spectra conditioned on 4 features: observation altitude, transparency, sun altitude, and sun separation
- Data utilities — Download and load the DESI DR1 Sky Spectra Value-Added Catalog (VAC) with automatic SHA-256 integrity verification, subset filtering, and enrichment (V-band magnitudes, eclipse fractions, solar flux, coordinate transforms)
- Spectral analysis — Measure airglow emission line intensities and compute broadband magnitudes directly from spectra
- CLI tools — Train and run inference on all models from the command line with optional W&B experiment tracking
- Experiment tracking — Optional Weights & Biases integration with visualization callbacks and hyperparameter sweeps
Built with JAX/Equinox for high-performance model inference and designed to integrate with SpecSim and survey forecasting workflows. This repository hosts the code and notebooks supporting the forthcoming paper by Dowicz et al. (20XX).
- Installation
- Quick Start
- Data
- Models
- CLI Tools
- Training (Python API)
- Experiment Tracking (W&B)
- Visualization
- Examples
- Project Structure
- Development
- License
```bash
# Default: inference-ready (CPU)
pip install desisky

# With data loading (FITS files, enrichment)
pip install desisky[data]

# GPU training + data + visualization
pip install desisky[cuda12,data,viz]

# Everything (CPU) including W&B experiment tracking
pip install desisky[all]

# Everything with GPU
pip install desisky[all,cuda12]
```

Note: CUDA wheels may require manual installation. See the JAX installation guide for details.
| Extra | Packages |
|---|---|
| `cuda12` | `jax[cuda12]`, `equinox`, `optax`, `torch`, `tqdm` |
| `data` | `fitsio`, `pandas`, `speclite`, `astropy` |
| `viz` | `matplotlib`, `umap-learn` |
| `wandb` | `wandb`, `matplotlib`, `pandas` |
| `all` | All of the above (CPU JAX) |
Core dependencies (always installed): `numpy`, `scipy`, `requests`, `jax`, `equinox`.
```python
from desisky.data import SkySpecVAC

# Download DR1 VAC (~274 MB, with SHA-256 verification)
vac = SkySpecVAC(version="v1.0", download=True)

# Load wavelength, flux, and metadata
wavelength, flux, metadata = vac.load()
print(f"Wavelength: {wavelength.shape}")  # (7781,)
print(f"Flux: {flux.shape}")              # (9176, 7781)
print(f"Metadata columns: {list(metadata.columns)}")
# ['NIGHT', 'EXPID', 'TILEID', 'AIRMASS', 'EBV', 'MOONFRAC', 'MOONALT', ...]

# Enriched columns are added automatically for v1.0
print('SKY_MAG_V_SPEC' in metadata.columns)  # True
print('ECLIPSE_FRAC' in metadata.columns)    # True
```
```python
import desisky
import jax.numpy as jnp

model, meta = desisky.io.load_model("broadband")

# Input: [MOONSEP, OBSALT, MOONALT, MOONFRAC, TRANSPARENCY_GFA, ECLIPSE_FRAC]
x = jnp.array([45.0, 80.0, 30.0, 0.8, 0.95, 0.0])

# Predict surface brightness in V, g, r, z bands
y = model(x)  # Shape: (4,)
print(f"Predicted magnitudes: {y}")
```
```python
from desisky.io import load_model
from desisky.data import SkySpecVAC
import jax
import jax.random as jr

vac = SkySpecVAC(version="v1.0", download=True)
wavelength, flux, metadata = vac.load()

vae, meta = load_model("vae")

# Encode a single spectrum to its latent representation
mean, logvar = vae.encode(flux[0])
print(f"Latent mean: {mean}")  # Shape: (8,)

# Sample and decode
latent = vae.sample(mean, logvar, jr.PRNGKey(0))
reconstructed = vae.decode(latent)
print(f"Reconstructed shape: {reconstructed.shape}")  # (7781,)

# Batch encoding with vmap
batch_means, batch_logvars = jax.vmap(vae.encode)(flux)
print(f"Batch latents: {batch_means.shape}")  # (9176, 8)
```
```python
from desisky.inference import LatentDiffusionSampler
import jax.random as jr
import jax.numpy as jnp

# Create sampler (auto-loads model + VAE from pretrained checkpoints)
sampler = LatentDiffusionSampler("ldm_dark")

# Conditioning: [OBSALT, TRANSPARENCY_GFA, SUNALT, SOLFLUX, ECLLON, ECLLAT, GALLON, GALLAT]
conditioning = jnp.array([
    [2100.0, 0.9, -30.0, 150.0, 45.0, 10.0, 120.0, 5.0],  # Dark sky
])

generated = sampler.sample(
    key=jr.PRNGKey(42),
    conditioning=conditioning,
    guidance_scale=2.0,
)
print(f"Generated spectrum shape: {generated.shape}")  # (1, 7781)
```
The VAC provides subset methods for filtering observations by sky conditions:

**Dark time** (non-contaminated):

```python
wave, flux, meta = vac.load_dark_time()
# SUNALT < -20 | MOONALT < -5 | TRANSPARENCY_GFA > 0
```

**Twilight** (sun-contaminated):

```python
wave, flux, meta = vac.load_sun_contaminated()
# SUNALT > -20 | MOONALT <= -5 | SUNSEP <= 110 | TRANSPARENCY_GFA > 0
```

**Moon-contaminated**:

```python
wave, flux, meta = vac.load_moon_contaminated()
# SUNALT < -20 | MOONALT > 5 | MOONFRAC > 0.5 | MOONSEP <= 90 | TRANSPARENCY_GFA > 0
```

All data is automatically enriched with computed columns on load (v1.0 only):
| Column | Description |
|---|---|
| `SKY_MAG_V_SPEC` | V-band AB magnitude computed from the spectrum via `speclite` |
| `ECLIPSE_FRAC` | Lunar eclipse umbral coverage fraction (0-1) |
| `SOLFLUX` | Daily 10.7 cm solar flux (sfu) |
| `GALLON`, `GALLAT` | Galactic coordinates (from `TILERA`/`TILEDEC`) |
| `ECLLON`, `ECLLAT` | Ecliptic coordinates (from `TILERA`/`TILEDEC`) |
The individual enrichment functions are also available for standalone use:

```python
from desisky.data import (
    compute_vband_magnitudes,   # V-band magnitudes from spectra
    load_eclipse_catalog,       # NASA lunar eclipse catalog
    compute_eclipse_fraction,   # Eclipse umbral coverage
    load_solar_flux,            # F10.7 solar radio flux
    attach_solar_flux,          # Add SOLFLUX column to metadata
    add_galactic_coordinates,   # Add GALLON, GALLAT columns
    add_ecliptic_coordinates,   # Add ECLLON, ECLLAT columns
)
```

Extract physical features from spectra (requires `desisky[data]`):
```python
from desisky.data import measure_airglow_intensities, compute_broadband_mags

# Measure 10 airglow emission line intensities via continuum-subtracted integration
# Returns DataFrame: OI_5577, OI_6300, OI_6364, OH_1, OH_2, ..., OH_7
airglow = measure_airglow_intensities(wavelength, flux)

# Compute broadband magnitudes via speclite (V, g, r, z)
mags = compute_broadband_mags(wavelength, flux)  # Shape: (n_spectra, 4)
```

The airglow measurement follows the method of Noll et al. (2012), using two flanking continuum windows for background subtraction. Composite lines are also computed: OH (the sum of all OH bands) and the OI doublet (OI 6300 + OI 6364).
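For intuition, here is a self-contained sketch of the flanking-window measurement on a toy Gaussian line. The helper and window edges below are hypothetical; desisky's real windows are defined in `LINE_BANDS` (see the constants below):

```python
import numpy as np

def line_intensity(wave, flux, line, cont_left, cont_right):
    # Continuum level from the median of two flanking windows,
    # then trapezoidal integration of the continuum-subtracted line
    in_line = (wave >= line[0]) & (wave <= line[1])
    in_cont = ((wave >= cont_left[0]) & (wave <= cont_left[1])) | \
              ((wave >= cont_right[0]) & (wave <= cont_right[1]))
    y = flux[in_line] - np.median(flux[in_cont])
    x = wave[in_line]
    return np.sum(0.5 * (y[:-1] + y[1:]) * np.diff(x))

# Toy spectrum: flat continuum + Gaussian line near OI 5577
wave = np.linspace(5540.0, 5610.0, 701)
flux = 1.0 + 10.0 * np.exp(-0.5 * ((wave - 5577.3) / 1.5) ** 2)
print(line_intensity(wave, flux, (5570, 5585), (5550, 5560), (5595, 5605)))
```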
Related constants:

- `LINE_BANDS` — Dictionary of airglow line wavelength windows
- `AIRGLOW_CDF_NAMES` — Standard names for the 10 + 2 composite airglow features
- `BROADBAND_NAMES` — Standard names for the 4 broadband magnitudes (`["V", "g", "r", "z"]`)
- `FLUX_SCALE` — Default flux scaling factor (1e-17 erg/s/cm^2/A)
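A quick look at these constants (a sketch; the values in the comments are the documented defaults, and it is assumed here that the constants are importable from `desisky.data` like the functions above):

```python
from desisky.data import BROADBAND_NAMES, FLUX_SCALE, LINE_BANDS

print(BROADBAND_NAMES)   # ["V", "g", "r", "z"]
print(FLUX_SCALE)        # 1e-17 (erg/s/cm^2/A)
print(list(LINE_BANDS))  # airglow line names, e.g. 'OI_5577', 'OH_1', ...
```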
```bash
# Show default data directory
desisky-data dir

# Download DESI DR1 sky spectra VAC
desisky-data fetch --version v1.0

# Download to custom location
desisky-data fetch --root /path/to/data

# Skip checksum verification (not recommended)
desisky-data fetch --no-verify
```

Override the default data directory with an environment variable:
```bash
export DESISKY_DATA_DIR=/path/to/data
```

| Model | Architecture | Description |
|---|---|---|
| `broadband` | MLP (6 → 128 × 5 → 4) | Predicts V, g, r, z magnitudes from observational metadata |
| `vae` | Encoder-Decoder (7781 → 8 → 7781) | Compresses sky spectra to an 8D latent space |
| `ldm_dark` | 1D U-Net + EDM | Generates dark-time spectra (8 conditioning features) |
| `ldm_moon` | 1D U-Net + EDM | Generates moon-contaminated spectra (6 conditioning features) |
| `ldm_twilight` | 1D U-Net + EDM | Generates twilight spectra (4 conditioning features) |
LDM conditioning features:

- `ldm_dark`: `[OBSALT, TRANSPARENCY_GFA, SUNALT, SOLFLUX, ECLLON, ECLLAT, GALLON, GALLAT]`
- `ldm_moon`: `[OBSALT, TRANSPARENCY_GFA, SUNALT, MOONALT, MOONSEP, MOONFRAC]`
- `ldm_twilight`: `[OBSALT, TRANSPARENCY_GFA, SUNALT, SUNSEP]`
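For example, sampling the twilight variant with its 4-feature conditioning (the numeric values below are hypothetical placeholders, not vetted observing conditions):

```python
import jax.numpy as jnp
import jax.random as jr
from desisky.inference import LatentDiffusionSampler

sampler = LatentDiffusionSampler("ldm_twilight")

# Conditioning: [OBSALT, TRANSPARENCY_GFA, SUNALT, SUNSEP] (hypothetical values)
conditioning = jnp.array([[60.0, 0.9, -15.0, 80.0]])
spectra = sampler.sample(key=jr.PRNGKey(0), conditioning=conditioning)
print(spectra.shape)  # (1, 7781)
```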
All pre-trained weights are hosted on HuggingFace and downloaded automatically on first use:
```python
import desisky

# Load pre-trained weights (downloads from HuggingFace on first use)
model, meta = desisky.io.load_model("broadband")

# Load from a user checkpoint
model, meta = desisky.io.load_model("vae", path="path/to/checkpoint.eqx")

# Save a trained model
desisky.io.save(
    "my_model.eqx",
    model,
    meta={
        "schema": 1,
        "arch": {"in_channels": 7781, "latent_dim": 8},
        "training": {"date": "2025-01-15", "epoch": 100},
    },
)
```

Checkpoints use a JSON header (architecture + training metadata) followed by binary Equinox-serialized weights.
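As an illustration of that layout, a sketch that reads just the metadata header (assuming the header occupies the first line of the file; `peek_checkpoint_meta` is not a desisky API):

```python
import json

def peek_checkpoint_meta(path):
    # Read the leading JSON header line; the binary Equinox weights
    # that follow are left untouched.
    with open(path, "rb") as f:
        return json.loads(f.readline().decode("utf-8"))

meta = peek_checkpoint_meta("my_model.eqx")
print(meta.get("arch"))  # e.g. {'in_channels': 7781, 'latent_dim': 8}
```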
By default, downloaded weights are cached in `~/.desisky/models/<kind>/`. Override with:

```bash
export DESISKY_CACHE_DIR=/path/to/cache  # shell
```

```python
import os
os.environ["DESISKY_CACHE_DIR"] = "/path/to/cache"  # Python / notebook
```

All CLI commands are registered as console entry points and are available after installation.
- Inference commands work with the base install (`pip install desisky`)
- Training commands require training dependencies: `pip install desisky[all]` (CPU) or `pip install desisky[cuda12]` (GPU)
- Training with W&B visualization requires the `wandb` extra: `pip install desisky[all,wandb]` or `pip install desisky[cuda12,wandb]`
For the full reference, including data formats and W&B integration, see `docs/CLI_GUIDE.md`.
```bash
# Broadband MLP (moon-contaminated data)
desisky-train-broadband --epochs 500
desisky-train-broadband --epochs 500 --wandb

# VAE (full dataset)
desisky-train-vae --epochs 100
desisky-train-vae --epochs 100 --wandb

# LDM (per-variant: dark, moon, twilight)
desisky-train-ldm --variant dark --epochs 200
desisky-train-ldm --variant moon --epochs 300 --wandb --vae-path my_vae.eqx
```

All training scripts support:
- `--wandb` for optional W&B experiment tracking with automatic visualization callbacks
- `--data-path` for user-provided data (`.fits`, `.csv`, `.npz` depending on model)
- `--no-save` to skip checkpointing (useful for testing or sweeps)
- `--vae-path` / `--model-path` for custom pretrained weights
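For example, combining several of these flags in one run (the data file name is a placeholder):

```bash
desisky-train-vae --epochs 100 --data-path my_spectra.fits --no-save --wandb
```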
```bash
# Broadband predictions (CSV or npz output)
desisky-infer-broadband --output predictions.csv
desisky-infer-broadband --output predictions.npz --output-format npz

# VAE encode + reconstruct
desisky-infer-vae --subset dark --output dark_latents.npz

# LDM spectral generation
desisky-infer-ldm --variant dark --n-samples 500
desisky-infer-ldm --variant moon --n-samples 100 --guidance-scale 2.0
desisky-infer-ldm --conditioning '[[60,0.9,-30,150,45,10,120,5]]'
```

The VAE is trained with the InfoVAE-MMD objective, which gives finer control over the trade-off between reconstruction quality and latent-space regularization than a standard beta-VAE. The total loss is:
```
L = Reconstruction + beta * KL + (lam - beta) * MMD
```
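For reference, a self-contained sketch of the MMD term with an RBF kernel (illustrative only; the packaged implementation lives in `src/desisky/training/vae_losses.py`):

```python
import jax.numpy as jnp
import jax.random as jr

def rbf_kernel(x, y, sigma=1.0):
    # Pairwise RBF kernel between rows of x and y
    d2 = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-d2 / (2.0 * sigma**2))

def mmd(z_post, z_prior, sigma=1.0):
    # Maximum mean discrepancy between encoder samples and prior samples
    return (rbf_kernel(z_post, z_post, sigma).mean()
            + rbf_kernel(z_prior, z_prior, sigma).mean()
            - 2.0 * rbf_kernel(z_post, z_prior, sigma).mean())

z_post = jr.normal(jr.PRNGKey(0), (256, 8)) + 0.5  # shifted toy posterior
z_prior = jr.normal(jr.PRNGKey(1), (256, 8))       # N(0, I) prior draws
print(mmd(z_post, z_prior))  # > 0: the distributions differ
```

Training with desisky's `VAETrainer`: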
```python
from desisky.training import VAETrainer, VAETrainingConfig, NumpyLoader
from desisky.models.vae import make_SkyVAE
import jax.random as jr

model = make_SkyVAE(in_channels=7781, latent_dim=8, key=jr.PRNGKey(42))

config = VAETrainingConfig(
    epochs=100,
    learning_rate=1e-4,
    beta=1e-3,            # KL divergence weight
    lam=4.0,              # Total regularization weight (MMD weight = lam - beta)
    kernel_sigma="auto",  # RBF kernel bandwidth for MMD
)

# train_loader / test_loader: e.g. NumpyLoader instances over the spectra
trainer = VAETrainer(model, config)
trained_model, history = trainer.train(train_loader, test_loader)
```

The LDM is trained with the EDM framework (Karras et al. 2022), using continuous log-normal noise sampling, a preconditioned denoiser, and the EDM-weighted loss. An exponential moving average (EMA) of the model weights is maintained for stable inference.
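A compact sketch of those EDM ingredients (the `P_mean`/`P_std` defaults are the values from Karras et al. 2022, shown for illustration; desisky's trainer may use different settings):

```python
import jax
import jax.numpy as jnp
import jax.random as jr

def sample_sigma(key, batch_size, P_mean=-1.2, P_std=1.2):
    # Continuous log-normal noise levels: ln(sigma) ~ N(P_mean, P_std^2)
    return jnp.exp(P_mean + P_std * jr.normal(key, (batch_size,)))

def edm_loss_weight(sigma, sigma_data):
    # Per-sample EDM loss weighting lambda(sigma)
    return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2

def ema_update(ema_params, params, decay=0.9999):
    # Exponential moving average of weights, used for stable inference
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params
    )

sigmas = sample_sigma(jr.PRNGKey(0), 4)
print(edm_loss_weight(sigmas, sigma_data=0.5))
```

The full training setup with desisky's trainer looks like this: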
```python
from desisky.training import (
    LatentDiffusionTrainer, LDMTrainingConfig,
    fit_conditioning_scaler, normalize_conditioning,
)
from desisky.models.ldm import compute_sigma_data

# 1. Compute sigma_data from training latents
sigma_data = compute_sigma_data(latent_train)

# 2. Fit conditioning scaler on training data (stored in checkpoint for inference)
scaler = fit_conditioning_scaler(cond_train, ["OBSALT", "TRANSPARENCY_GFA", "SUNALT", ...])

# 3. Normalize conditioning with the scaler
cond_train_norm = normalize_conditioning(cond_train, scaler)
cond_val_norm = normalize_conditioning(cond_val, scaler)

# 4. Configure training — the scaler is passed here so it gets saved in the checkpoint metadata
config = LDMTrainingConfig(
    epochs=200,
    learning_rate=1e-4,
    meta_dim=8,                  # Number of conditioning features
    sigma_data=sigma_data,
    ema_decay=0.9999,
    early_stop_on_ema=True,      # Gate early stopping on EMA validation loss
    conditioning_scaler=scaler,  # Saved in checkpoint for auto-normalization at inference
)

trainer = LatentDiffusionTrainer(model, config)
model, ema_model, history = trainer.train(train_loader, val_loader)
```

Both trainers support:
- Automatic best-model checkpointing
- An optional `on_epoch_end(model, history, epoch)` callback for custom per-epoch logging
- `tqdm` progress bars (auto-detected; falls back to `print_every` when unavailable)
- Training without validation (`test_loader=None` / `val_loader=None`) for final training once hyperparameters are validated
Optionally integrate with Weights & Biases for real-time experiment tracking and hyperparameter sweeps:
```bash
pip install desisky[wandb]
```

```python
from desisky.training import VAETrainer, VAETrainingConfig, WandbConfig

config = VAETrainingConfig(epochs=100, learning_rate=1e-4)
wandb_config = WandbConfig(project="desisky-vae", tags=["experiment-1"])

trainer = VAETrainer(model, config, wandb_config=wandb_config)
model, history = trainer.train(train_loader, test_loader)
```

This logs all loss components (train/val) to your W&B dashboard every epoch. Add an `on_epoch_end` callback for custom visualization logging:
```python
from desisky.training import log_figure
from desisky.visualization import plot_vae_reconstructions

def on_epoch_end(model, history, epoch):
    fig = plot_vae_reconstructions(originals, reconstructions, wavelength)
    log_figure("viz/reconstructions", fig, epoch)

trainer = VAETrainer(
    model, config,
    wandb_config=wandb_config,
    on_epoch_end=on_epoch_end,
)
```

W&B hyperparameter sweeps are demonstrated in notebooks 07 and 08.
All visualization functions return plain matplotlib Figure objects and are usable with or without W&B:
```python
from desisky.visualization import (
    # Experiment tracking plots
    plot_vae_reconstructions,          # Original vs reconstructed spectra
    plot_latent_corner,                # Corner plot of latent dims, colored by sky condition
    plot_latent_corner_comparison,     # Corner plot comparing two latent distributions (e.g. real vs generated)
    plot_cdf_comparison,               # CDF + histogram with Wasserstein-1 (EMD) annotation
    plot_conditional_validation_grid,  # Feature statistics vs conditioning variable with 16-84% CI bands
    plot_broadband_cdfs,               # Broadband magnitude CDF comparison
    plot_airglow_cdfs,                 # Airglow line intensity CDF comparison
    # General diagnostics
    plot_loss_curve,                   # Training/validation loss curves
    plot_nn_outlier_analysis,          # 2x3 diagnostic panel for MLP models
    plot_latent_umap,                  # UMAP projection of latent space (requires umap-learn)
)
```

Key inference timings (after JIT compilation):
| Operation | CPU | GPU (RTX 3090) | Speedup |
|---|---|---|---|
| LDM inference (100 spectra, deterministic) | ~4,500 ms | ~130 ms | ~35x |
| VAE encode+decode (1,000 spectra) | ~27 ms | ~1.8 ms | ~15x |
| Broadband MLP (1,000 rows) | ~0.06 ms | ~0.06 ms | ~1x |
For practical inference runs (thousands of spectra), GPU is clearly worth it. See `docs/BENCHMARKS.md` for full results across platforms.

For running on NERSC Perlmutter with GPU support, see `docs/NERSC_SETUP.md`.
| Notebook | Description |
|---|---|
| 00_tutorial.ipynb | Start here -- End-to-end walkthrough of every major capability |
| 01_broadband_training.ipynb | Train broadband model on moon-contaminated subset |
| 02_vae_inference.ipynb | VAE encoding/decoding and latent space visualization |
| 03_vae_analysis.ipynb | Latent space interpolation and anomaly detection |
| 04_vae_training.ipynb | Train VAE from scratch with InfoVAE-MMD objective |
| 05_ldm_inference.ipynb | Generate dark/moon/twilight spectra with EDM sampler |
| 06_ldm_training.ipynb | Train LDM from scratch with EDM framework and EMA |
| 07_vae_wandb_training.ipynb | VAE + W&B: reconstruction plots, latent corners, sweeps |
| 08_ldm_wandb_training.ipynb | LDM + W&B: CDF comparisons, validation grids, sweeps |
| Notebook | Description |
|---|---|
| investigate_outlier_spectra.ipynb | Investigation of anomalous LDM-generated spectra (dark samples looking moon-like, twilight outliers) |
| specsim_polar_comparison.ipynb | Side-by-side comparison of desisky vs specsim sky predictions in polar coordinates |
| desi_data_pres.ipynb | Companion notebook for the DESI data systems presentation (slide figures, validation plots) |
Some notebooks require packages beyond `desisky[all]` (e.g. `specsim`) — install via `pip install -r notebooks/requirements.txt`.
```
desisky/
├── src/desisky/
│ ├── data/ # Data loading, enrichment, spectral analysis
│ │ ├── skyspec.py # SkySpecVAC class with subset filtering
│ │ ├── _core.py # Download utilities with SHA-256 verification
│ │ ├── _enrich.py # V-band, eclipse, solar flux, coordinates
│ │ ├── _spectral.py # Airglow line intensities, broadband magnitudes
│ │ └── _splits.py # Validation mask utilities
│ ├── models/ # Model architectures (JAX/Equinox)
│ │ ├── broadband.py # Broadband MLP
│ │ ├── vae.py # SkyVAE encoder-decoder
│ │ └── ldm.py # 1D U-Net + EDM preconditioning
│ ├── io/ # Model I/O and checkpoint handling
│ │ └── model_io.py # Save/load with JSON header + binary weights
│ ├── inference/ # Sampling algorithms
│ │ └── sampling.py # EDM Heun ODE solver, classifier-free guidance
│ ├── training/ # Training infrastructure
│ │ ├── trainer.py # BroadbandTrainer
│ │ ├── vae_trainer.py # VAETrainer (InfoVAE-MMD)
│ │ ├── ldm_trainer.py # LatentDiffusionTrainer (EDM)
│ │ ├── dataset.py # PyTorch Dataset/DataLoader wrappers
│ │ ├── losses.py # L2, Huber loss functions
│ │ ├── vae_losses.py # InfoVAE-MMD loss with RBF kernel
│ │ └── wandb_utils.py # W&B logging utilities
│ ├── visualization/ # Plotting
│ │ ├── plots.py # Loss curves, outlier analysis, broadband band panels
│ │ └── wandb_plots.py # Reconstructions, corner plots, CDFs, validation grids
│ └── scripts/ # CLI tools
│ ├── download_data.py # desisky-data command
│ ├── train_broadband.py # desisky-train-broadband
│ ├── train_vae.py # desisky-train-vae
│ ├── train_ldm.py # desisky-train-ldm
│ ├── infer_broadband.py # desisky-infer-broadband
│ ├── infer_vae.py # desisky-infer-vae
│ └── infer_ldm.py # desisky-infer-ldm
├── tests/ # 382 unit tests
├── examples/ # 9 user-facing tutorial notebooks
├── notebooks/ # Analysis & presentation notebooks
├── scripts/ # Standalone scripts (NERSC VAE training, data prep)
├── jobs/ # SLURM job scripts for NERSC training
├── benchmarks/ # Inference/training timing benchmarks
├── docs/
│ ├── CLI_GUIDE.md # CLI data formats, output formats, wandb reference
│ ├── BENCHMARKS.md # Cross-platform timing results (Entropy + NERSC Perlmutter)
│ └── NERSC_SETUP.md # NERSC environment setup and GPU training recipes
├── pyproject.toml
├── CHANGELOG.md
└── LICENSE.txt
```
```bash
git clone https://github.com/MatthewDowicz/desisky.git
cd desisky
pip install -e ".[all]"
pip install pytest pytest-cov

# Run all tests
pytest

# Run with coverage
pytest --cov=desisky --cov-report=html

# Run specific test file
pytest tests/test_model_io.py -v
```

desisky is distributed under the terms of the MIT license.