Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ logs/

# Slurm batch scripts (user-specific, not part of the package)
*.sbatch
*.sh
90 changes: 89 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Pkg.add(url="https://github.com/LearningToOptimize/DecisionRulesExa.jl.git")

```julia
using DecisionRulesExa
using ExaModels, Flux, Random
using ExaModels, MadNLP, Flux, Random

Random.seed!(1)

Expand Down Expand Up @@ -92,6 +92,94 @@ train_tsddr(policy, x0, prob, ..., sampler;

Each pool entry gets its own MadNLP solver on a dedicated thread, with CUDA handles properly bound.

## Optional Critic Control Variates

`train_tsddr` can optionally train a scalar critic `C(w, xhat)`. The critic does
not replace the deterministic-equivalent solve in the default
`:control_variate` mode: solved target-constraint multipliers remain the
actor's primary local sensitivity signal. Instead, the critic supplies a learned
rollout-value guide and optional control variate.

For critic fitting, the preferred target is the stage-wise rollout objective via
`RolloutCriticTarget`, with `policy_state = :target` by default. This matches the
differentiable target recurrence used by the actor while evaluating the true
stage-by-stage objective. Set `policy_state = :realized` to train the critic on
closed-loop realized-state rollout labels. For ablations, use
`DeterministicEquivalentCriticTarget()` or `critic_training_target =
:deterministic_equivalent` to fit the deterministic-equivalent objective instead.

The default `control_variate = NoCriticControlVariate()` recovers the original
dual-only behavior. A scalar critic can be attached with:

```julia
input_dim = length(x0) + 2 * T * nx
critic = Chain(
Dense(input_dim => 128, tanh),
Dense(128 => 128, tanh),
Dense(128 => 1),
)

cv = ScalarCriticControlVariate(
critic;
featurizer = default_critic_featurizer,
value_loss_weight = 1.0,
gradient_loss_weight = 0.0,
)

critic_target = RolloutCriticTarget(
stage_problem;
horizon = T,
n_uncertainty = nx,
set_stage_parameters! = set_stage_parameters!,
realized_state = realized_state,
objective_no_target_penalty = objective_no_target_penalty,
policy_state = :target,
objective_value = :objective_no_target_penalty,
)

train_tsddr(
policy, x0, prob, prob.p_x0, prob.p_target, prob.p_w, sampler;
control_variate = cv,
critic_training_target = critic_target,
critic_rollout_samples_per_batch = 1,
actor_gradient_mode = :control_variate,
critic_cv_weight = 1.0,
num_cheap_critic_samples_per_batch = 32,
critic_updates_per_batch = 1,
critic_optimizer = Flux.Adam(1f-3),
)
```

The critic loss combines value matching against the selected target objective and
optional gradient matching against target multipliers:

```julia
value_loss_weight * mse(C(w, xhat), objective) +
gradient_loss_weight * mse(gradient(xhat -> C(w, xhat), xhat), lambda)
```

Set either weight to zero for objective-only or gradient-only critic training.
For rollout targets, objective-only critic training is usually the clean default,
because DE target multipliers are not exact gradients of the realized rollout
objective. If objectives and multipliers have very different magnitudes, prefer a
custom featurizer and tuned loss weights; the Hydro example normalizes volumes
and inflows before critic evaluation.

Actor modes:

- `:control_variate`: subtracts `critic_cv_weight * gradient_xhat(C)` from the
solved dual signal and adds the critic actor gradient back on solved or cheap
rollout samples. This is the recommended mode when dual multipliers are
reliable. `critic_cv_weight = 0.0` recovers dual-only updates.
- `:surrogate`: uses a practical hybrid of solved dual gradients and critic
actor gradients, controlled by `dual_actor_weight` and `critic_actor_weight`.
This is useful when raw dual/subgradient signals are empirically noisy or
unstable, but it is no longer a pure unbiased control-variate estimator.

`num_cheap_critic_samples_per_batch` draws additional uncertainty samples,
rolls out the current policy, and evaluates the critic actor term without any
extra MadNLP or ExaModels solve.

## Rollout evaluation

`RolloutEvaluation` evaluates the policy in deployment semantics (stage-by-stage sequential solves) on held-out scenarios:
Expand Down
1 change: 1 addition & 0 deletions examples/HydroPowerModels/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
CUDSS_jll = "4889d778-9329-5762-9fec-0578a5d30366"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DecisionRulesExa = "7c3e91a4-d8f2-4b6a-9e15-a2c4f7b80d53"
Expand Down
17 changes: 17 additions & 0 deletions examples/HydroPowerModels/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ Pre-solved deterministic-equivalent references (MOF format) are provided for val
| File | Description |
|---|---|
| `train_hydro_exa.jl` | Main training script with penalty scheduling, parallel GPU solves, and W&B logging |
| `train_hydro_exa_critic.jl` | Critic/control-variate variant of the main training script; uses normalized hydro features, a replay buffer, and cheap critic rollouts |
| `hydro_power_data.jl` | Data parsing (PowerModels JSON, hydro JSON, inflows CSV) |
| `hydro_power_exa.jl` | ExaModels problem builder for DC and AC OPF formulations |
| `eval_exa_de.jl` | Validation script comparing ExaModels results against JuMP reference |
Expand All @@ -55,6 +56,20 @@ julia --project -t auto train_hydro_exa.jl

Set `USE_GPU = true` in `train_hydro_exa.jl` (default). Requires a CUDA-capable GPU.

### GPU training with critic control variate

```julia
# From this directory:
julia --project -t auto train_hydro_exa_critic.jl
```

The critic script keeps the dual-multiplier actor update but adds a damped
control variate (`critic_cv_weight = 0.5`) trained on the stage-wise rollout
objective without target penalty. Its default critic rollout uses
`policy_state = :target`; set `CRITIC_POLICY_STATE = :realized` for closed-loop
critic labels. Deterministic-equivalent critic fitting remains available as an
ablation through `DeterministicEquivalentCriticTarget()`.

### CPU training

Set `USE_GPU = false` in `train_hydro_exa.jl`, then run the same command.
Expand All @@ -81,6 +96,8 @@ Key parameters in `train_hydro_exa.jl`:
- **Evaluation scheduling**: rollout evaluation starts with 4 scenarios and ramps to 32 at halfway
- **Parallel solves**: independent NLP copies solved concurrently via `Threads.@spawn` worker pool
- **Parallel rollout**: evaluation scenarios distributed across CPU stage-problem copies
- **Critic variant**: optional scalar critic with value and gradient matching,
replay-buffer training, and cheap critic actor samples
- **W&B logging**: training loss, rollout objectives, violation share, penalty multiplier

## Validation
Expand Down
9 changes: 7 additions & 2 deletions examples/HydroPowerModels/hydro_power_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,15 +354,20 @@ end
"""
sample_scenario(hydro_data, T) -> Vector{Float64}

Sample one inflow trajectory of length T*nHyd (flat, stage-major order).
Sample one inflow trajectory of length `T*nHyd` (flat, stage-major order).

Uses **joint** scenario sampling: at each stage one scenario index `ω` is drawn
and applied to all hydro reservoirs, preserving the spatial correlation present
in the historical inflow data. This matches SDDP's `SDDP.parameterize` semantics.
"""
function sample_scenario(hydro_data::HydroData, T::Int)
nHyd = hydro_data.nHyd
w = Vector{Float64}(undef, T * nHyd)
for t in 1:T
t_row = mod1(t, hydro_data.nStagesSample)
# One scenario index per stage — all reservoirs share it (joint sampling).
j = rand(1:hydro_data.nScenarios)
for r in 1:nHyd
j = rand(1:hydro_data.nScenarios)
w[(t-1)*nHyd + r] = hydro_data.scenario_inflows[r][t_row, j]
end
end
Expand Down
Loading