From 4ef0c8cd98a7747dc33cce3dff32a2957b858c09 Mon Sep 17 00:00:00 2001 From: Andrew Rosemberg Date: Thu, 18 Jun 2026 21:23:34 -0400 Subject: [PATCH 1/9] add critic first attempt --- README.md | 88 ++++ examples/HydroPowerModels/Project.toml | 1 + examples/HydroPowerModels/README.md | 17 + .../train_hydro_exa_critic.jl | 458 ++++++++++++++++++ src/DecisionRulesExa.jl | 15 + src/critic_control_variate.jl | 322 ++++++++++++ src/rollout.jl | 36 +- src/training.jl | 304 +++++++++++- test/runtests.jl | 144 ++++++ 9 files changed, 1371 insertions(+), 14 deletions(-) create mode 100644 examples/HydroPowerModels/train_hydro_exa_critic.jl create mode 100644 src/critic_control_variate.jl diff --git a/README.md b/README.md index a97ad10..a9ab925 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,94 @@ train_tsddr(policy, x0, prob, ..., sampler; Each pool entry gets its own MadNLP solver on a dedicated thread, with CUDA handles properly bound. +## Optional Critic Control Variates + +`train_tsddr` can optionally train a scalar critic `C(w, xhat)`. The critic does +not replace the deterministic-equivalent solve in the default +`:control_variate` mode: solved target-constraint multipliers remain the +actor's primary local sensitivity signal. Instead, the critic supplies a learned +rollout-value guide and optional control variate. + +For critic fitting, the preferred target is the stage-wise rollout objective via +`RolloutCriticTarget`, with `policy_state = :target` by default. This matches the +differentiable target recurrence used by the actor while evaluating the true +stage-by-stage objective. Set `policy_state = :realized` to train the critic on +closed-loop realized-state rollout labels. For ablations, use +`DeterministicEquivalentCriticTarget()` or `critic_training_target = +:deterministic_equivalent` to fit the deterministic-equivalent objective instead. + +The default `control_variate = NoCriticControlVariate()` recovers the original +dual-only behavior. 
A scalar critic can be attached with: + +```julia +input_dim = length(x0) + 2 * T * nx +critic = Chain( + Dense(input_dim => 128, tanh), + Dense(128 => 128, tanh), + Dense(128 => 1), +) + +cv = ScalarCriticControlVariate( + critic; + featurizer = default_critic_featurizer, + value_loss_weight = 1.0, + gradient_loss_weight = 0.0, +) + +critic_target = RolloutCriticTarget( + stage_problem; + horizon = T, + n_uncertainty = nx, + set_stage_parameters! = set_stage_parameters!, + realized_state = realized_state, + objective_no_target_penalty = objective_no_target_penalty, + policy_state = :target, + objective_value = :objective_no_target_penalty, +) + +train_tsddr( + policy, x0, prob, prob.p_x0, prob.p_target, prob.p_w, sampler; + control_variate = cv, + critic_training_target = critic_target, + critic_rollout_samples_per_batch = 1, + actor_gradient_mode = :control_variate, + critic_cv_weight = 1.0, + num_cheap_critic_samples_per_batch = 32, + critic_updates_per_batch = 1, + critic_optimizer = Flux.Adam(1f-3), +) +``` + +The critic loss combines value matching against the selected target objective and +optional gradient matching against target multipliers: + +```julia +value_loss_weight * mse(C(w, xhat), objective) + +gradient_loss_weight * mse(gradient(xhat -> C(w, xhat), xhat), lambda) +``` + +Set either weight to zero for objective-only or gradient-only critic training. +For rollout targets, objective-only critic training is usually the clean default, +because DE target multipliers are not exact gradients of the realized rollout +objective. If objectives and multipliers have very different magnitudes, prefer a +custom featurizer and tuned loss weights; the Hydro example normalizes volumes +and inflows before critic evaluation. + +Actor modes: + +- `:control_variate`: subtracts `critic_cv_weight * gradient_xhat(C)` from the + solved dual signal and adds the critic actor gradient back on solved or cheap + rollout samples. 
This is the recommended mode when dual multipliers are + reliable. `critic_cv_weight = 0.0` recovers dual-only updates. +- `:surrogate`: uses a practical hybrid of solved dual gradients and critic + actor gradients, controlled by `dual_actor_weight` and `critic_actor_weight`. + This is useful when raw dual/subgradient signals are empirically noisy or + unstable, but it is no longer a pure unbiased control-variate estimator. + +`num_cheap_critic_samples_per_batch` draws additional uncertainty samples, +rolls out the current policy, and evaluates the critic actor term without any +extra MadNLP or ExaModels solve. + ## Rollout evaluation `RolloutEvaluation` evaluates the policy in deployment semantics (stage-by-stage sequential solves) on held-out scenarios: diff --git a/examples/HydroPowerModels/Project.toml b/examples/HydroPowerModels/Project.toml index 9a4fe03..30e1a40 100644 --- a/examples/HydroPowerModels/Project.toml +++ b/examples/HydroPowerModels/Project.toml @@ -1,6 +1,7 @@ [deps] CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" CUDSS_jll = "4889d778-9329-5762-9fec-0578a5d30366" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DecisionRulesExa = "7c3e91a4-d8f2-4b6a-9e15-a2c4f7b80d53" diff --git a/examples/HydroPowerModels/README.md b/examples/HydroPowerModels/README.md index 8d799e4..65f91a6 100644 --- a/examples/HydroPowerModels/README.md +++ b/examples/HydroPowerModels/README.md @@ -39,6 +39,7 @@ Pre-solved deterministic-equivalent references (MOF format) are provided for val | File | Description | |---|---| | `train_hydro_exa.jl` | Main training script with penalty scheduling, parallel GPU solves, and W&B logging | +| `train_hydro_exa_critic.jl` | Critic/control-variate variant of the main training script; uses normalized hydro features, a replay buffer, and cheap critic rollouts | | `hydro_power_data.jl` | Data parsing (PowerModels JSON, hydro JSON, inflows 
CSV) | | `hydro_power_exa.jl` | ExaModels problem builder for DC and AC OPF formulations | | `eval_exa_de.jl` | Validation script comparing ExaModels results against JuMP reference | @@ -55,6 +56,20 @@ julia --project -t auto train_hydro_exa.jl Set `USE_GPU = true` in `train_hydro_exa.jl` (default). Requires a CUDA-capable GPU. +### GPU training with critic control variate + +```julia +# From this directory: +julia --project -t auto train_hydro_exa_critic.jl +``` + +The critic script keeps the dual-multiplier actor update but adds a damped +control variate (`critic_cv_weight = 0.5`) trained on the stage-wise rollout +objective without target penalty. Its default critic rollout uses +`policy_state = :target`; set `CRITIC_POLICY_STATE = :realized` for closed-loop +critic labels. Deterministic-equivalent critic fitting remains available as an +ablation through `DeterministicEquivalentCriticTarget()`. + ### CPU training Set `USE_GPU = false` in `train_hydro_exa.jl`, then run the same command. @@ -81,6 +96,8 @@ Key parameters in `train_hydro_exa.jl`: - **Evaluation scheduling**: rollout evaluation starts with 4 scenarios and ramps to 32 at halfway - **Parallel solves**: independent NLP copies solved concurrently via `Threads.@spawn` worker pool - **Parallel rollout**: evaluation scenarios distributed across CPU stage-problem copies +- **Critic variant**: optional scalar critic with value matching (gradient matching is + off by default), replay-buffer training, and cheap critic actor samples - **W&B logging**: training loss, rollout objectives, violation share, penalty multiplier ## Validation diff --git a/examples/HydroPowerModels/train_hydro_exa_critic.jl b/examples/HydroPowerModels/train_hydro_exa_critic.jl new file mode 100644 index 0000000..24e696c --- /dev/null +++ b/examples/HydroPowerModels/train_hydro_exa_critic.jl @@ -0,0 +1,458 @@ +# train_hydro_exa_critic.jl +# +# HydroPowerModels training with ExaModels + MadNLP (DC or AC OPF).
+# Uses train_tsddr from DecisionRulesExa with a scalar critic control variate. +# +# Usage: +# julia --project -t auto train_hydro_exa_critic.jl + +using DecisionRulesExa +using ExaModels +using Flux +using Statistics, Random, Dates +using Wandb, Logging +using JLD2 +using MadNLP +using MadNLPGPU, KernelAbstractions, CUDA +using CUDSS, CUDSS_jll, cuDNN + +const SCRIPT_DIR = dirname(@__FILE__) +include(joinpath(SCRIPT_DIR, "hydro_power_data.jl")) +include(joinpath(SCRIPT_DIR, "hydro_power_exa.jl")) + +# ── Configuration ───────────────────────────────────────────────────────────── + +const CASE_NAME = "bolivia" +const FORMULATION = :ac_polar # :dc or :ac_polar +const FORM_LABEL = FORMULATION === :ac_polar ? "ACPPowerModel" : "DCPPowerModel" + +const CASE_DIR = joinpath(SCRIPT_DIR, CASE_NAME) +const PM_FILE = joinpath(CASE_DIR, "PowerModels.json") +const HYDRO_FILE = joinpath(CASE_DIR, "hydro.json") +const INFLOW_FILE = joinpath(CASE_DIR, "inflows.csv") +const DEMAND_FILE = joinpath(CASE_DIR, "demand.csv") + +const LAYERS = [128, 128] +const ACTIVATION = sigmoid +const NUM_STAGES = 96 +const NUM_EPOCHS = 80 +const NUM_BATCHES = 100 +const MAX_EVAL_SCENARIOS = 32 +const EVAL_EVERY = 25 + +const EVAL_SCHEDULE = [ + (1, div(NUM_EPOCHS * NUM_BATCHES, 2), 4), + (div(NUM_EPOCHS * NUM_BATCHES, 2) + 1, NUM_EPOCHS * NUM_BATCHES, MAX_EVAL_SCENARIOS), +] +const LR = 1f-3 +const GRAD_CLIP = 10.0f0 +const CRITIC_LR = 5f-4 +const CRITIC_HIDDEN = [256, 128] +const CRITIC_VALUE_LOSS_WEIGHT = 1.0 +const CRITIC_GRADIENT_LOSS_WEIGHT = 0.0 +const CRITIC_CV_WEIGHT = 0.5 +const CRITIC_UPDATES_PER_BATCH = 2 +const CRITIC_BUFFER_SIZE = 512 +const CRITIC_BATCH_SIZE = 32 + +const TARGET_PEN_ARG = :auto +const DEFICIT_COST = 1e5 +const USE_GPU = true +const load_scaler = 0.6 +const NUM_WORKERS = 4 +const CRITIC_ROLLOUT_SAMPLES_PER_BATCH = 0 # eval rollouts feed the critic via external_critic_samples +const CRITIC_POLICY_STATE = :target # set to :realized for closed-loop critic targets +const 
CRITIC_ROLLOUT_OBJECTIVE = :objective_no_target_penalty +const NUM_CHEAP_CRITIC_SAMPLES_PER_BATCH = 4 * NUM_WORKERS + +const PENALTY_SCHEDULE = [ + (1, div(NUM_EPOCHS * NUM_BATCHES, 4), 0.1), + (div(NUM_EPOCHS * NUM_BATCHES, 4) + 1, div(NUM_EPOCHS * NUM_BATCHES, 4) * 2, 1.0), + (div(NUM_EPOCHS * NUM_BATCHES, 4) * 2 + 1, div(NUM_EPOCHS * NUM_BATCHES, 4) * 3, 10.0), + (div(NUM_EPOCHS * NUM_BATCHES, 4) * 3 + 1, NUM_EPOCHS * NUM_BATCHES, 30.0), +] + +const NUM_TRAIN_SCHEDULE = [ + (1, div(NUM_EPOCHS * NUM_BATCHES, 5), NUM_WORKERS), + (div(NUM_EPOCHS * NUM_BATCHES, 5) + 1, div(NUM_EPOCHS * NUM_BATCHES, 5) * 2, 2 * NUM_WORKERS), + (div(NUM_EPOCHS * NUM_BATCHES, 5) * 2 + 1, div(NUM_EPOCHS * NUM_BATCHES, 5) * 3, 4 * NUM_WORKERS), + (div(NUM_EPOCHS * NUM_BATCHES, 5) * 3 + 1, div(NUM_EPOCHS * NUM_BATCHES, 5) * 4, 8 * NUM_WORKERS), + (div(NUM_EPOCHS * NUM_BATCHES, 5) * 4 + 1, NUM_EPOCHS * NUM_BATCHES, 8 * NUM_WORKERS), +] + +const SOLVER_KWARGS = (print_level = MadNLP.ERROR, tol = 1e-6, max_iter = 9000) + +const RUN_NAME = "$(CASE_NAME)-$(FORM_LABEL)-h$(NUM_STAGES)-deteq-gpu-critic-cv-$(Dates.format(now(), "yyyymmdd-HHMMSS"))" +const MODEL_DIR = joinpath(CASE_DIR, FORM_LABEL, "models") +mkpath(MODEL_DIR) +const MODEL_PATH = joinpath(MODEL_DIR, RUN_NAME * ".jld2") + +const PRE_TRAINED = nothing + +# ── Load data ───────────────────────────────────────────────────────────────── + +@info "Loading power system data..." +power_data = load_power_data(PM_FILE) +@info " nBus=$(power_data.nBus) nGen=$(power_data.nGen)" + +@info "Loading hydro data..." +hydro_data = load_hydro_data(HYDRO_FILE, INFLOW_FILE, power_data; + num_stages = NUM_STAGES * 10) +nHyd = hydro_data.nHyd +T = NUM_STAGES +@info " nHyd=$(nHyd) nScenarios=$(hydro_data.nScenarios)" + +demand_mat = if isfile(DEMAND_FILE) + @info "Loading demand from $(DEMAND_FILE)..." 
+ load_demand(DEMAND_FILE, power_data; T = T) +else + nothing +end + +# ── Build ExaModels DE ──────────────────────────────────────────────────────── + +resolved_pen = TARGET_PEN_ARG === :auto ? + auto_target_penalty(power_data, hydro_data) : + Float64(TARGET_PEN_ARG) +@info "Target penalty ($(TARGET_PEN_ARG === :auto ? "auto" : "manual")): ρ=$(round(resolved_pen; digits=2))" + +backend = USE_GPU ? (@info "Using GPU backend"; CUDA.CUDABackend()) : + (@info "Using CPU backend"; nothing) + +function _build_de() + build_hydro_de(power_data, hydro_data, T; + backend = backend, + float_type = Float64, + formulation = FORMULATION, + target_penalty = TARGET_PEN_ARG, + deficit_cost = DEFICIT_COST, + demand_matrix = demand_mat, + load_scaler = load_scaler, + ) +end + +@info "Building $(T)-stage ExaModels DE (formulation=$FORMULATION)..." +prob = _build_de() + +@info "Building $(NUM_WORKERS)-worker problem pool..." +problem_pool = [(prob, prob.p_x0, prob.p_target, prob.p_inflow)] +for i in 2:NUM_WORKERS + p = _build_de() + push!(problem_pool, (p, p.p_x0, p.p_target, p.p_inflow)) +end +@info " Pool ready: $(NUM_WORKERS) independent DE instances on $(USE_GPU ? "GPU" : "CPU")" + +x0_init = Float32.([clamp(hydro_data.initial_volumes[r], + hydro_data.units[r].min_vol, + hydro_data.units[r].max_vol) + for r in 1:nHyd]) + +# ── Critic/control variate ─────────────────────────────────────────────────── + +volume_scale = Float32.([max(abs(u.max_vol), abs(u.min_vol), 1.0) for u in hydro_data.units]) +inflow_scale = Float32.([ + max(maximum(abs, hydro_data.scenario_inflows[r]), 1.0) + for r in 1:nHyd +]) + +const _full_inflow_scale = repeat(inflow_scale, T) +const _full_volume_scale = repeat(volume_scale, T) + +function hydro_critic_featurizer(initial_state, w_flat, xhat_flat) + x0_scaled = Float32.(initial_state) ./ volume_scale + w_scaled = Float32.(w_flat) ./ _full_inflow_scale + x_scaled = Float32.(xhat_flat) ./ _full_volume_scale + return vcat(x0_scaled, w_scaled, x_scaled) +end + +critic_input_dim = nHyd + 2 * T * nHyd +critic_layers = Any[]
+critic_in = critic_input_dim +for h in CRITIC_HIDDEN + push!(critic_layers, Flux.Dense(critic_in => h, tanh)) + global critic_in = h +end +push!(critic_layers, Flux.Dense(critic_in => 1)) +critic = Flux.Chain(critic_layers...) + +control_variate = ScalarCriticControlVariate( + critic; + featurizer = hydro_critic_featurizer, + value_loss_weight = CRITIC_VALUE_LOSS_WEIGHT, + gradient_loss_weight = CRITIC_GRADIENT_LOSS_WEIGHT, +) + +# ── Smoke test ──────────────────────────────────────────────────────────────── + +w_mean = mean_inflow(hydro_data, T) +ExaModels.set_parameter!(prob.core, prob.p_x0, x0_init) +ExaModels.set_parameter!(prob.core, prob.p_inflow, w_mean) +ExaModels.set_parameter!(prob.core, prob.p_target, zeros(T * nHyd)) +@info "Smoke test: solving DE with mean inflows..." +result0 = MadNLP.madnlp(prob.model; SOLVER_KWARGS..., print_level = MadNLP.WARN) +@info " Status: $(result0.status) Objective: $(round(result0.objective; digits=4))" +isfinite(result0.objective) || error("Smoke test returned non-finite objective") +solve_succeeded(result0) || @warn "Smoke test did not fully converge; proceeding anyway" + +# ── Policy ──────────────────────────────────────────────────────────────────── + +policy = StateConditionedPolicy(nHyd, nHyd, nHyd, LAYERS; + activation = ACTIVATION, + encoder_type = Flux.LSTM) + +if !isnothing(PRE_TRAINED) + @info "Loading pre-trained model from $(PRE_TRAINED)..." 
+ Flux.loadmodel!(policy, JLD2.load(PRE_TRAINED, "model_state")) +end + +# ── W&B logging ─────────────────────────────────────────────────────────────── + +lg = WandbLogger( + project = "RL", + name = RUN_NAME, + save_code = false, + config = Dict( + "case" => CASE_NAME, + "formulation" => FORM_LABEL, + "num_stages" => T, + "layers" => LAYERS, + "activation" => string(ACTIVATION), + "target_penalty" => "auto=$(round(resolved_pen; digits=2))", + "deficit_cost" => DEFICIT_COST, + "num_epochs" => NUM_EPOCHS, + "num_batches" => NUM_BATCHES, + "max_eval_scenarios" => MAX_EVAL_SCENARIOS, + "eval_schedule" => string(EVAL_SCHEDULE), + "eval_every" => EVAL_EVERY, + "lr" => LR, + "grad_clip" => GRAD_CLIP, + "critic_lr" => CRITIC_LR, + "critic_hidden" => CRITIC_HIDDEN, + "critic_value_loss_weight" => CRITIC_VALUE_LOSS_WEIGHT, + "critic_gradient_loss_weight" => CRITIC_GRADIENT_LOSS_WEIGHT, + "critic_cv_weight" => CRITIC_CV_WEIGHT, + "critic_updates_per_batch" => CRITIC_UPDATES_PER_BATCH, + "critic_buffer_size" => CRITIC_BUFFER_SIZE, + "critic_batch_size" => CRITIC_BATCH_SIZE, + "critic_training_target" => "rollout", + "critic_rollout_samples_per_batch" => CRITIC_ROLLOUT_SAMPLES_PER_BATCH, + "critic_policy_state" => string(CRITIC_POLICY_STATE), + "critic_rollout_objective" => string(CRITIC_ROLLOUT_OBJECTIVE), + "num_cheap_critic_samples_per_batch" => NUM_CHEAP_CRITIC_SAMPLES_PER_BATCH, + "backend" => USE_GPU ? "GPU" : "CPU", + "load_scaler" => load_scaler, + "penalty_schedule" => string(PENALTY_SCHEDULE), + "num_train_schedule" => string(NUM_TRAIN_SCHEDULE), + "num_workers" => NUM_WORKERS, + ), +) + +# ── Training ────────────────────────────────────────────────────────────────── + +Random.seed!(8788) + +best_obj = Inf +epoch_losses = Float64[] + +stage_demand = demand_mat === nothing ? 
nothing : demand_mat[1:1, :] +function _build_rollout_de() + build_hydro_de(power_data, hydro_data, 1; + backend = nothing, + float_type = Float64, + formulation = FORMULATION, + target_penalty = TARGET_PEN_ARG, + deficit_cost = DEFICIT_COST, + demand_matrix = stage_demand, + load_scaler = load_scaler, + ) +end +rollout_prob = _build_rollout_de() +rollout_pool = [_build_rollout_de() for _ in 1:NUM_WORKERS] +@info "Rollout pool ready: $(NUM_WORKERS) CPU stage-problem copies" + +function set_hydro_rollout_stage!(stage_prob, state_in, wt, target, stage) + ExaModels.set_parameter!(stage_prob.core, stage_prob.p_x0, state_in) + ExaModels.set_parameter!(stage_prob.core, stage_prob.p_inflow, wt) + if demand_mat !== nothing + set_demand!(stage_prob, load_scaler .* demand_mat[stage:stage, :]) + end + ExaModels.set_parameter!(stage_prob.core, stage_prob.p_target, target) + return stage_prob +end + +hydro_realized_state(stage_prob, result) = + Array(hydro_solution(stage_prob, result).reservoir[:, end]) + +function hydro_objective_no_target_penalty(stage_prob, result) + sol = hydro_solution(stage_prob, result) + return result.objective - (resolved_pen / 2) * sum(abs2, Array(sol.delta)) +end + +Random.seed!(8789) +eval_scenarios = [sample_scenario(hydro_data, T) for _ in 1:MAX_EVAL_SCENARIOS] +rollout_evaluation = RolloutEvaluation( + rollout_prob, + x0_init, + eval_scenarios; + horizon = T, + n_uncertainty = nHyd, + set_stage_parameters! = set_hydro_rollout_stage!, + realized_state = hydro_realized_state, + objective_no_target_penalty = hydro_objective_no_target_penalty, + madnlp_kwargs = SOLVER_KWARGS, + warmstart = true, + stride = EVAL_EVERY, + policy_state = :target, + stage_problem_pool = rollout_pool, + active_scenarios = 4, +) +realized_rollout_evaluation = RolloutEvaluation( + rollout_prob, + x0_init, + eval_scenarios; + horizon = T, + n_uncertainty = nHyd, + set_stage_parameters! 
= set_hydro_rollout_stage!, + realized_state = hydro_realized_state, + objective_no_target_penalty = hydro_objective_no_target_penalty, + madnlp_kwargs = SOLVER_KWARGS, + warmstart = true, + stride = EVAL_EVERY, + policy_state = :realized, + stage_problem_pool = rollout_pool, + active_scenarios = 4, +) + +critic_training_target = RolloutCriticTarget( + rollout_prob; + horizon = T, + n_uncertainty = nHyd, + set_stage_parameters! = set_hydro_rollout_stage!, + realized_state = hydro_realized_state, + objective_no_target_penalty = hydro_objective_no_target_penalty, + madnlp_kwargs = SOLVER_KWARGS, + warmstart = true, + policy_state = CRITIC_POLICY_STATE, + objective_value = CRITIC_ROLLOUT_OBJECTIVE, +) + +Random.seed!(8788) + +function _schedule_value(schedule, iter, default) + for (lo, hi, val) in schedule + lo <= iter <= hi && return val + end + return default +end + +current_penalty_mult = Ref(NaN) +shared_critic_samples = Any[] + +train_tsddr( + policy, + x0_init, + prob, + prob.p_x0, + prob.p_target, + prob.p_inflow, + () -> sample_scenario(hydro_data, T); # returns flat Float32 vector, length T*nHyd + num_batches = NUM_EPOCHS * NUM_BATCHES, + num_train_per_batch = NUM_WORKERS, + optimizer = Flux.Optimisers.OptimiserChain( + Flux.Optimisers.ClipGrad(GRAD_CLIP), + Flux.Adam(LR), + ), + madnlp_kwargs = SOLVER_KWARGS, + warmstart = true, + problem_pool = problem_pool, + control_variate = control_variate, + actor_gradient_mode = :control_variate, + critic_cv_weight = CRITIC_CV_WEIGHT, + critic_updates_per_batch = CRITIC_UPDATES_PER_BATCH, + critic_buffer_size = CRITIC_BUFFER_SIZE, + critic_batch_size = CRITIC_BATCH_SIZE, + critic_training_target = critic_training_target, + critic_rollout_samples_per_batch = CRITIC_ROLLOUT_SAMPLES_PER_BATCH, + num_cheap_critic_samples_per_batch = NUM_CHEAP_CRITIC_SAMPLES_PER_BATCH, + critic_optimizer = Flux.Adam(CRITIC_LR), + external_critic_samples = shared_critic_samples, + adjust_hyperparameters = (iter, opt_state, n) -> begin + 
mult = _schedule_value(PENALTY_SCHEDULE, iter, last(PENALTY_SCHEDULE)[3]) + if mult != current_penalty_mult[] + current_penalty_mult[] = mult + ρ_half_scaled = prob.base_penalty_half * mult + penalty_vals = fill(ρ_half_scaled, T * nHyd) + for (p, _, _, _) in problem_pool + ExaModels.set_parameter!(p.core, p.p_penalty_half, penalty_vals) + end + @info "Penalty multiplier → $mult (ρ/2 = $(round(ρ_half_scaled; digits=2)))" + end + n_eval = _schedule_value(EVAL_SCHEDULE, iter, MAX_EVAL_SCENARIOS) + rollout_evaluation.active_scenarios = n_eval + realized_rollout_evaluation.active_scenarios = n_eval + return _schedule_value(NUM_TRAIN_SCHEDULE, iter, n) + end, + record_loss = (iter, m, loss, tag) -> begin + metrics = Dict{String, Any}(tag => loss, "batch" => iter) + isfinite(loss) && push!(epoch_losses, loss) + + if iter % EVAL_EVERY == 0 + rollout_evaluation(iter, m) + realized_rollout_evaluation(iter, m) + append!(shared_critic_samples, + critic_samples_from_evaluation(rollout_evaluation)) + metrics["metrics/rollout_objective_no_deficit"] = + rollout_evaluation.last_objective_no_target_penalty + metrics["metrics/rollout_target_violation_share"] = + rollout_evaluation.last_violation_share + metrics["metrics/rollout_realized_objective_no_deficit"] = + realized_rollout_evaluation.last_objective_no_target_penalty + metrics["metrics/rollout_realized_target_violation_share"] = + realized_rollout_evaluation.last_violation_share + metrics["metrics/rollout_n_ok"] = + realized_rollout_evaluation.last_n_ok + end + + if !isnan(current_penalty_mult[]) + metrics["metrics/target_penalty_multiplier"] = current_penalty_mult[] + end + metrics["metrics/num_train_per_batch"] = + _schedule_value(NUM_TRAIN_SCHEDULE, iter, 1) + metrics["metrics/active_eval_scenarios"] = + _schedule_value(EVAL_SCHEDULE, iter, MAX_EVAL_SCENARIOS) + + batch_in_epoch = (iter - 1) % NUM_BATCHES + 1 + if batch_in_epoch == NUM_BATCHES + epoch = (iter - 1) ÷ NUM_BATCHES + 1 + mean_loss = isempty(epoch_losses) ? 
NaN : mean(epoch_losses) + n_ok = length(epoch_losses) + empty!(epoch_losses) + Wandb.log(lg, Dict("metrics/epoch_objective" => mean_loss, "epoch" => epoch)) + @info "Epoch $epoch/$NUM_EPOCHS mean=$(round(mean_loss; digits=2)) ok=$n_ok/$NUM_BATCHES" + if isfinite(mean_loss) && mean_loss < best_obj + global best_obj = mean_loss + jldsave(MODEL_PATH; + model_state = Flux.state(cpu(m)), + critic_state = Flux.state(cpu(critic)), + critic_config = Dict( + "mode" => "control_variate", + "training_target" => "rollout", + "policy_state" => string(CRITIC_POLICY_STATE), + "rollout_objective" => string(CRITIC_ROLLOUT_OBJECTIVE), + "critic_cv_weight" => CRITIC_CV_WEIGHT, + "value_loss_weight" => CRITIC_VALUE_LOSS_WEIGHT, + "gradient_loss_weight" => CRITIC_GRADIENT_LOSS_WEIGHT, + "critic_rollout_samples_per_batch" => CRITIC_ROLLOUT_SAMPLES_PER_BATCH, + "num_cheap_critic_samples_per_batch" => NUM_CHEAP_CRITIC_SAMPLES_PER_BATCH, + ), + ) + @info " → New best: $(round(mean_loss; digits=4)) — saved $MODEL_PATH" + end + end + Wandb.log(lg, metrics) + return false + end, +) + +close(lg) +@info "Done. 
Best model saved to: $(MODEL_PATH)" diff --git a/src/DecisionRulesExa.jl b/src/DecisionRulesExa.jl index ad02924..4e23112 100644 --- a/src/DecisionRulesExa.jl +++ b/src/DecisionRulesExa.jl @@ -13,6 +13,7 @@ using ChainRulesCore include("utils.jl") include("deterministic_equivalent.jl") include("policy.jl") +include("critic_control_variate.jl") include("training.jl") include("rollout.jl") @@ -43,6 +44,20 @@ export solve_succeeded, materialize_tangent, _all_finite_gradient, + AbstractCriticControlVariate, + AbstractCriticTrainingTarget, + NoCriticControlVariate, + DeterministicEquivalentCriticTarget, + RolloutCriticTarget, + ScalarCriticControlVariate, + CriticSample, + CriticReplayBuffer, + default_critic_featurizer, + critic_value, + critic_xhat_gradient, + critic_loss, + update_critic!, + critic_samples_from_evaluation, simulate_tsddr, train_tsddr, diff --git a/src/critic_control_variate.jl b/src/critic_control_variate.jl new file mode 100644 index 0000000..3f4fbbe --- /dev/null +++ b/src/critic_control_variate.jl @@ -0,0 +1,322 @@ +# critic_control_variate.jl +# +# Optional scalar critic / control-variate support for TS-DDR training. + +abstract type AbstractCriticControlVariate end + +abstract type AbstractCriticTrainingTarget end + +""" + DeterministicEquivalentCriticTarget() + +Train critic value targets from the full deterministic-equivalent objective. +This is useful for ablations and for pure DE control-variate experiments. +""" +struct DeterministicEquivalentCriticTarget <: AbstractCriticTrainingTarget end + +""" + RolloutCriticTarget(stage_problem; kwargs...) + +Train critic value targets from stage-wise rollout evaluation. This is the +preferred target when the critic is meant to guide convergence of the deployed +rollout objective rather than the deterministic-equivalent surrogate. 
+ +Required keyword callbacks match `rollout_tsddr`: +- `set_stage_parameters!` +- `realized_state` + +By default `policy_state = :target`, matching the differentiable target +recurrence used by the actor. Set `policy_state = :realized` to train on +closed-loop realized-state rollout targets. +""" +struct RolloutCriticTarget{S,R,O,M} <: AbstractCriticTrainingTarget + stage_problem + horizon::Int + n_uncertainty::Int + set_stage_parameters!::S + realized_state::R + objective_no_target_penalty::O + madnlp_kwargs::M + warmstart::Bool + policy_state::Symbol + reuse_solver::Bool + objective_value::Symbol +end + +function RolloutCriticTarget( + stage_problem; + horizon::Int, + n_uncertainty::Int, + set_stage_parameters!::Function, + realized_state::Function, + objective_no_target_penalty::Function = (prob, result) -> result.objective, + madnlp_kwargs = NamedTuple(), + warmstart::Bool = true, + policy_state::Symbol = :target, + reuse_solver::Bool = false, + objective_value::Symbol = :objective_no_target_penalty, +) + policy_state in (:target, :realized) || + error("policy_state must be :target or :realized") + objective_value in (:objective, :objective_no_target_penalty) || + error("objective_value must be :objective or :objective_no_target_penalty") + return RolloutCriticTarget( + stage_problem, + horizon, + n_uncertainty, + set_stage_parameters!, + realized_state, + objective_no_target_penalty, + madnlp_kwargs, + warmstart, + policy_state, + reuse_solver, + objective_value, + ) +end + +""" + NoCriticControlVariate() + +Default no-op critic configuration. Passing this to `train_tsddr` recovers the +original dual-multiplier actor update. +""" +struct NoCriticControlVariate <: AbstractCriticControlVariate end + +""" + ScalarCriticControlVariate(critic; featurizer=default_critic_featurizer, + value_loss_weight=0.1, + gradient_loss_weight=1.0) + +Wrap a scalar Flux-compatible critic `C(w, xhat)` for optional TS-DDR +control-variate training. 
The critic is called as `critic(features)`, where +`features = featurizer(initial_state, uncertainty, xhat)`. + +The critic loss is + + value_loss_weight * mse(C, objective) + + gradient_loss_weight * mse(gradient(xhat -> C, xhat), target_multipliers) + +Either loss weight may be zero. +""" +struct ScalarCriticControlVariate{C,F} <: AbstractCriticControlVariate + critic::C + featurizer::F + value_loss_weight::Float64 + gradient_loss_weight::Float64 +end + +function ScalarCriticControlVariate( + critic; + featurizer = default_critic_featurizer, + value_loss_weight::Real = 0.1, + gradient_loss_weight::Real = 1.0, +) + value_loss_weight >= 0 || error("value_loss_weight must be nonnegative") + gradient_loss_weight >= 0 || error("gradient_loss_weight must be nonnegative") + return ScalarCriticControlVariate( + critic, + featurizer, + Float64(value_loss_weight), + Float64(gradient_loss_weight), + ) +end + +""" + CriticSample(initial_state, uncertainty, xhat, objective_value, + target_multipliers; metadata=nothing) + +Training sample for a scalar critic. Samples are produced from already-solved +TS-DDR scenarios and do not require additional optimization solves. 
+""" +struct CriticSample{I,W,X,L,M} + initial_state::I + uncertainty::W + xhat::X + objective_value::Float64 + target_multipliers::L + metadata::M +end + +function CriticSample( + initial_state, + uncertainty, + xhat, + objective_value::Real, + target_multipliers; + metadata = nothing, +) + return CriticSample( + initial_state, + uncertainty, + xhat, + Float64(objective_value), + target_multipliers, + metadata, + ) +end + +mutable struct CriticReplayBuffer{S} + samples::Vector{S} + max_size::Int +end + +CriticReplayBuffer(max_size::Integer) = + CriticReplayBuffer{Any}(Any[], max(0, Int(max_size))) + +function push_critic_sample!(buffer::CriticReplayBuffer, sample::CriticSample) + buffer.max_size == 0 && return buffer + push!(buffer.samples, sample) + overflow = length(buffer.samples) - buffer.max_size + overflow > 0 && deleteat!(buffer.samples, 1:overflow) + return buffer +end + +function push_critic_samples!(buffer::CriticReplayBuffer, samples) + for sample in samples + push_critic_sample!(buffer, sample) + end + return buffer +end + +""" + default_critic_featurizer(initial_state, uncertainty, xhat) + +Default critic featurizer: concatenate flattened initial state, uncertainty, and +policy target trajectory. +""" +default_critic_featurizer(initial_state, uncertainty, xhat) = + vcat(vec(initial_state), vec(uncertainty), vec(xhat)) + +_scalar_output(y::Number) = y +_scalar_output(y::AbstractArray) = begin + length(y) == 1 || error("critic must return a scalar or length-1 array, got length $(length(y))") + return only(vec(y)) +end + +function _critic_value(critic, featurizer, initial_state, uncertainty, xhat) + features = featurizer(initial_state, uncertainty, xhat) + return _scalar_output(critic(features)) +end + +""" + critic_value(control_variate, initial_state, uncertainty, xhat) + +Evaluate the scalar critic on one scenario. 
+""" +critic_value( + cv::ScalarCriticControlVariate, + initial_state, + uncertainty, + xhat, +) = _critic_value(cv.critic, cv.featurizer, initial_state, uncertainty, xhat) + +""" + critic_xhat_gradient(control_variate, initial_state, uncertainty, xhat) + +Return `gradient(xhat -> C(initial_state, uncertainty, xhat), xhat)` and check +that it has the same shape as `xhat`. +""" +function critic_xhat_gradient( + cv::ScalarCriticControlVariate, + initial_state, + uncertainty, + xhat, +) + gx = Zygote.gradient(x -> critic_value(cv, initial_state, uncertainty, x), xhat)[1] + gx = gx === nothing ? zero(xhat) : gx + size(gx) == size(xhat) || + error("critic xhat gradient shape $(size(gx)) does not match xhat shape $(size(xhat))") + return gx +end + +function _check_critic_sample_shapes(sample::CriticSample, grad_xhat = nothing) + size(sample.target_multipliers) == size(sample.xhat) || + error("target_multipliers shape $(size(sample.target_multipliers)) does not match xhat shape $(size(sample.xhat))") + if grad_xhat !== nothing + size(grad_xhat) == size(sample.xhat) || + error("critic xhat gradient shape $(size(grad_xhat)) does not match xhat shape $(size(sample.xhat))") + end + return true +end + +function _critic_loss_with( + critic, + cv::ScalarCriticControlVariate, + samples; + value_loss_weight::Real = cv.value_loss_weight, + gradient_loss_weight::Real = cv.gradient_loss_weight, +) + isempty(samples) && return 0.0 + value_w = Float64(value_loss_weight) + grad_w = Float64(gradient_loss_weight) + value_w >= 0 || error("value_loss_weight must be nonnegative") + grad_w >= 0 || error("gradient_loss_weight must be nonnegative") + + total = 0.0 + for sample in samples + _check_critic_sample_shapes(sample) + if value_w > 0 + pred = _critic_value( + critic, + cv.featurizer, + sample.initial_state, + sample.uncertainty, + sample.xhat, + ) + target = convert(typeof(pred), sample.objective_value) + total = total + value_w * abs2(pred - target) + end + if grad_w > 0 + gx = 
Zygote.gradient(sample.xhat) do x + _critic_value(critic, cv.featurizer, sample.initial_state, sample.uncertainty, x) + end[1] + gx = gx === nothing ? zero(sample.xhat) : gx + _check_critic_sample_shapes(sample, gx) + total = total + grad_w * sum(abs2, gx .- sample.target_multipliers) / length(sample.xhat) + end + end + return total / length(samples) +end + +""" + critic_loss(control_variate, samples; value_loss_weight, gradient_loss_weight) + +Compute the scalar critic loss on a collection of `CriticSample`s. +""" +critic_loss(cv::ScalarCriticControlVariate, samples; kwargs...) = + _critic_loss_with(cv.critic, cv, samples; kwargs...) + +function _critic_minibatch(samples, batch_size) + n = length(samples) + n == 0 && return samples + if batch_size === nothing || batch_size >= n + return samples + end + idx = rand(1:n, Int(batch_size)) + return samples[idx] +end + +""" + update_critic!(opt_state, control_variate, samples; batch_size=nothing) + +Run one critic optimizer step and return the numeric loss. Only critic +parameters are updated. 
+""" +function update_critic!( + opt_state, + cv::ScalarCriticControlVariate, + samples; + batch_size = nothing, +) + batch = _critic_minibatch(samples, batch_size) + isempty(batch) && return NaN + gs = Zygote.gradient(cv.critic) do critic + _critic_loss_with(critic, cv, batch) + end + grad = materialize_tangent(gs[1]) + if grad !== nothing && _all_finite_gradient(grad) + Flux.update!(opt_state, cv.critic, grad) + end + return Float64(critic_loss(cv, batch)) +end diff --git a/src/rollout.jl b/src/rollout.jl index 1485496..4139a25 100644 --- a/src/rollout.jl +++ b/src/rollout.jl @@ -152,6 +152,7 @@ mutable struct RolloutEvaluation <: Function last_objective_no_target_penalty::Float64 last_violation_share::Float64 last_n_ok::Int + last_scenario_data::Vector{Any} end function RolloutEvaluation( @@ -197,10 +198,12 @@ function RolloutEvaluation( NaN, NaN, 0, + Any[], ) end function (evaluation::RolloutEvaluation)(iter, model) + empty!(evaluation.last_scenario_data) iter % evaluation.stride == 0 || return nothing n_eval = min(evaluation.active_scenarios, length(evaluation.scenarios)) @@ -234,6 +237,7 @@ function (evaluation::RolloutEvaluation)(iter, model) total += result.objective total_no_penalty += result.objective_no_target_penalty n_ok += 1 + push!(evaluation.last_scenario_data, (i, result)) end else # Parallel path: distribute scenarios across pool @@ -267,11 +271,12 @@ function (evaluation::RolloutEvaluation)(iter, model) results[round_start + j - 1] = fetch(t) end end - for r in results + for (idx, r) in enumerate(results) r === nothing && continue total += r.objective total_no_penalty += r.objective_no_target_penalty n_ok += 1 + push!(evaluation.last_scenario_data, (idx, r)) end end @@ -291,3 +296,32 @@ function (evaluation::RolloutEvaluation)(iter, model) ) return nothing end + +""" + critic_samples_from_evaluation(eval; objective_key) -> Vector{CriticSample} + +Convert the last rollout evaluation results into `CriticSample`s for critic +training. 
Target multipliers are zero (rollout evaluation does not produce +duals), so these samples contribute only to the value loss term. +""" +function critic_samples_from_evaluation( + eval_obj::RolloutEvaluation; + objective_key::Symbol = :objective_no_target_penalty, +) + isempty(eval_obj.last_scenario_data) && return CriticSample[] + F = eltype(eval_obj.initial_state) + samples = CriticSample[] + for (i, result) in eval_obj.last_scenario_data + w_flat = eval_obj.scenarios[i] + xhat_flat = F.(vcat(result.target_trajectory...)) + obj = Float64(getfield(result, objective_key)) + push!(samples, CriticSample( + F.(eval_obj.initial_state), + F.(w_flat), + xhat_flat, + obj, + zeros(F, length(xhat_flat)), + )) + end + return samples +end diff --git a/src/training.jl b/src/training.jl index 534ac22..688863b 100644 --- a/src/training.jl +++ b/src/training.jl @@ -168,6 +168,128 @@ function simulate_tsddr( return (objective = result.objective, lambda = F.(Array(λ))) end +function _rollout_xhat_flat(model, initial_state, w_flat, T::Int, F) + nw = length(w_flat) ÷ T + nx = length(initial_state) + Flux.reset!(model) + buf = Zygote.Buffer(zeros(F, nx * T)) + prev = F.(initial_state) + for t in 1:T + wt = F.(w_flat[(t-1)*nw+1 : t*nw]) + xt = model(vcat(wt, prev)) + for i in 1:nx + buf[(t-1)*nx + i] = xt[i] + end + prev = xt + end + return copy(buf) +end + +_has_critic(::NoCriticControlVariate) = false +_has_critic(::AbstractCriticControlVariate) = true + +function _validate_critic_training_args(; + actor_gradient_mode, + critic_cv_weight, + dual_actor_weight, + critic_actor_weight, + critic_updates_per_batch, + critic_buffer_size, + critic_rollout_samples_per_batch, + num_cheap_critic_samples_per_batch, +) + actor_gradient_mode in (:control_variate, :surrogate) || + error("actor_gradient_mode must be :control_variate or :surrogate") + critic_cv_weight >= 0 || error("critic_cv_weight must be nonnegative") + dual_actor_weight >= 0 || error("dual_actor_weight must be nonnegative") + 
critic_actor_weight >= 0 || error("critic_actor_weight must be nonnegative") + critic_updates_per_batch >= 0 || error("critic_updates_per_batch must be nonnegative") + critic_buffer_size >= 0 || error("critic_buffer_size must be nonnegative") + if critic_rollout_samples_per_batch !== nothing + critic_rollout_samples_per_batch >= 0 || + error("critic_rollout_samples_per_batch must be nonnegative or nothing") + end + num_cheap_critic_samples_per_batch >= 0 || + error("num_cheap_critic_samples_per_batch must be nonnegative") + return true +end + +function _resolve_critic_training_target(target, has_critic::Bool) + has_critic || return DeterministicEquivalentCriticTarget() + target isa AbstractCriticTrainingTarget && return target + if target === :deterministic_equivalent || target === :de + return DeterministicEquivalentCriticTarget() + elseif target === :rollout + error("critic_training_target=:rollout requires a RolloutCriticTarget(...) configuration") + else + error("critic_training_target must be RolloutCriticTarget(...), DeterministicEquivalentCriticTarget(), :rollout, or :deterministic_equivalent") + end +end + +function _critic_sample_from_rollout( + model, + initial_state, + target::RolloutCriticTarget, + w_flat, + lambda, + F, + solver_state, +) + result = rollout_tsddr( + model, + initial_state, + target.stage_problem, + w_flat; + horizon = target.horizon, + n_uncertainty = target.n_uncertainty, + set_stage_parameters! = target.set_stage_parameters!, + realized_state = target.realized_state, + objective_no_target_penalty = target.objective_no_target_penalty, + madnlp_kwargs = target.madnlp_kwargs, + warmstart = target.warmstart, + policy_state = target.policy_state, + solver_state = solver_state, + reuse_solver = target.reuse_solver, + ) + result === nothing && return nothing + + objective = target.objective_value === :objective ? 
+ result.objective : result.objective_no_target_penalty + xhat_flat = F.(vcat(result.target_trajectory...)) + return CriticSample(F.(initial_state), F.(w_flat), xhat_flat, objective, F.(lambda)) +end + +function _rollout_critic_samples( + model, + initial_state, + target::RolloutCriticTarget, + de_samples, + F, + max_samples, + solver_state, +) + isempty(de_samples) && return CriticSample[] + n = max_samples === nothing ? length(de_samples) : min(Int(max_samples), length(de_samples)) + n == 0 && return CriticSample[] + idx = n == length(de_samples) ? eachindex(de_samples) : randperm(length(de_samples))[1:n] + samples = CriticSample[] + for i in idx + s = de_samples[i] + sample = _critic_sample_from_rollout( + model, + initial_state, + target, + s.uncertainty, + s.target_multipliers, + F, + solver_state, + ) + sample === nothing && continue + push!(samples, sample) + end + return samples +end + # ── train_tsddr ─────────────────────────────────────────────────────────────── """ @@ -202,6 +324,20 @@ Keyword arguments (mirror `train_multistage`): - `problem_pool` : vector of `(de, p_x0, p_target, p_uncertainty)` tuples for parallel GPU solves; each entry gets its own MadNLP solver and samples are distributed round-robin across the pool +- `control_variate` : optional `ScalarCriticControlVariate`; default + `NoCriticControlVariate()` recovers the original update +- `critic_training_target` : `RolloutCriticTarget(...)` for rollout-objective + critic fitting, or `DeterministicEquivalentCriticTarget()` + / `:deterministic_equivalent` for DE ablations +- `critic_rollout_samples_per_batch`: number of solved batch scenarios to rerun + through stage-wise rollout for critic targets; + `nothing` uses all successful solved scenarios +- `actor_gradient_mode` : `:control_variate` or `:surrogate` +- `num_cheap_critic_samples_per_batch`: extra policy rollouts used only for + critic actor terms; these do not trigger NLP solves +- `external_critic_samples` : mutable vector; 
`record_loss` can push + `CriticSample`s (e.g. from `critic_samples_from_evaluation`) + to feed the critic replay buffer without extra solves """ function train_tsddr( model, @@ -225,11 +361,40 @@ function train_tsddr( madnlp_kwargs = NamedTuple(), warmstart::Bool = true, problem_pool = nothing, + control_variate::AbstractCriticControlVariate = NoCriticControlVariate(), + actor_gradient_mode::Symbol = :control_variate, + critic_cv_weight::Real = 1.0, + dual_actor_weight::Real = 1.0, + critic_actor_weight::Real = 1.0, + critic_updates_per_batch::Int = 1, + critic_buffer_size::Int = 0, + critic_batch_size = nothing, + critic_training_target = :rollout, + critic_rollout_samples_per_batch = nothing, + num_cheap_critic_samples_per_batch::Int = 0, + critic_optimizer = Flux.Adam(1f-3), + external_critic_samples = nothing, ) T = det_equivalent.horizon F = eltype(initial_state) nx = length(initial_state) + _validate_critic_training_args( + actor_gradient_mode = actor_gradient_mode, + critic_cv_weight = critic_cv_weight, + dual_actor_weight = dual_actor_weight, + critic_actor_weight = critic_actor_weight, + critic_updates_per_batch = critic_updates_per_batch, + critic_buffer_size = critic_buffer_size, + critic_rollout_samples_per_batch = critic_rollout_samples_per_batch, + num_cheap_critic_samples_per_batch = num_cheap_critic_samples_per_batch, + ) + has_critic = _has_critic(control_variate) + resolved_critic_training_target = _resolve_critic_training_target( + critic_training_target, + has_critic, + ) + # ── Build worker pool ──────────────────────────────────────────────────── if problem_pool === nothing _pool = [(det_equivalent, p_x0, p_target, p_uncertainty)] @@ -277,6 +442,12 @@ function train_tsddr( end opt_state = Flux.setup(optimizer, model) + critic_opt_state = has_critic ? 
Flux.setup(critic_optimizer, control_variate.critic) : nothing + critic_buffer = CriticReplayBuffer(critic_buffer_size) + critic_rollout_solver_state = resolved_critic_training_target isa RolloutCriticTarget && + resolved_critic_training_target.reuse_solver ? + _make_solver(resolved_critic_training_target.stage_problem.model, + resolved_critic_training_target.madnlp_kwargs) : nothing try @@ -339,31 +510,138 @@ function train_tsddr( # Step 3: Collect valid results valid = Vector{Tuple{Vector{F}, Vector{F}}}() + de_samples = CriticSample[] obj_sum = 0.0 - for r in solve_ok + for (s, r) in enumerate(solve_ok) r === nothing && continue push!(valid, (r[1], r[2])) + if has_critic + _, xhat_flat = sample_data[s] + push!(de_samples, CriticSample(F.(initial_state), r[1], xhat_flat, r[3], r[2])) + end obj_sum += r[3] end n_ok = length(valid) mean_obj = n_ok > 0 ? obj_sum / n_ok : NaN + if has_critic && n_ok > 0 && critic_updates_per_batch > 0 + valid_samples = if resolved_critic_training_target isa RolloutCriticTarget + _rollout_critic_samples( + model, + initial_state, + resolved_critic_training_target, + de_samples, + F, + critic_rollout_samples_per_batch, + critic_rollout_solver_state, + ) + else + de_samples + end + if external_critic_samples !== nothing && !isempty(external_critic_samples) + merged = Any[] + append!(merged, valid_samples) + append!(merged, external_critic_samples) + empty!(external_critic_samples) + valid_samples = merged + end + if critic_buffer_size > 0 + push_critic_samples!(critic_buffer, valid_samples) + critic_samples = critic_buffer.samples + else + critic_samples = valid_samples + end + for _ in 1:critic_updates_per_batch + update_critic!( + critic_opt_state, + control_variate, + critic_samples; + batch_size = critic_batch_size, + ) + end + end + # ── Gradient: ∇_θ (1/n) Σ_s ⟨λ_s, rollout_s(θ)⟩ ──────────────────── if n_ok > 0 - gs = Zygote.gradient(model) do m - total = zero(F) - for (w_flat_s, λf) in valid - nw = length(w_flat_s) ÷ T - 
Flux.reset!(m) - prev_ad = F.(initial_state) - for t in 1:T - wt = F.(w_flat_s[(t-1)*nw+1 : t*nw]) - xt = m(vcat(wt, prev_ad)) - total = total + sum(λf[(t-1)*nx+1 : t*nx] .* xt) - prev_ad = xt + if !has_critic + gs = Zygote.gradient(model) do m + total = zero(F) + for (w_flat_s, λf) in valid + nw = length(w_flat_s) ÷ T + Flux.reset!(m) + prev_ad = F.(initial_state) + for t in 1:T + wt = F.(w_flat_s[(t-1)*nw+1 : t*nw]) + xt = m(vcat(wt, prev_ad)) + total = total + sum(λf[(t-1)*nx+1 : t*nx] .* xt) + prev_ad = xt + end + end + total / F(n_ok) + end + else + solved_weights = Vector{Tuple{Vector{F}, Vector{F}}}() + for sample in de_samples + λf = F.(sample.target_multipliers) + if actor_gradient_mode === :control_variate + if critic_cv_weight == 0 + actor_weight = λf + else + gx = F.(critic_xhat_gradient( + control_variate, + sample.initial_state, + sample.uncertainty, + sample.xhat, + )) + _check_critic_sample_shapes(sample, gx) + actor_weight = λf .- F(critic_cv_weight) .* gx + end + else + actor_weight = F(dual_actor_weight) .* λf + end + push!(solved_weights, (F.(sample.uncertainty), actor_weight)) + end + + critic_uncertainties = if num_cheap_critic_samples_per_batch > 0 + [F.(uncertainty_sampler()) for _ in 1:num_cheap_critic_samples_per_batch] + else + [F.(sample.uncertainty) for sample in de_samples] + end + + gs = Zygote.gradient(model) do m + residual_total = zero(F) + for (w_flat_s, actor_weight) in solved_weights + nw = length(w_flat_s) ÷ T + Flux.reset!(m) + prev_ad = F.(initial_state) + for t in 1:T + wt = F.(w_flat_s[(t-1)*nw+1 : t*nw]) + xt = m(vcat(wt, prev_ad)) + residual_total = + residual_total + sum(actor_weight[(t-1)*nx+1 : t*nx] .* xt) + prev_ad = xt + end + end + actor_loss = residual_total / F(n_ok) + + critic_coeff = actor_gradient_mode === :control_variate ? 
+ F(critic_cv_weight) : F(critic_actor_weight) + if critic_coeff != 0 && !isempty(critic_uncertainties) + critic_total = zero(actor_loss) + for w_flat_s in critic_uncertainties + xhat_ad = _rollout_xhat_flat(m, initial_state, w_flat_s, T, F) + critic_total = critic_total + critic_value( + control_variate, + F.(initial_state), + w_flat_s, + xhat_ad, + ) + end + actor_loss = actor_loss + + critic_coeff * critic_total / F(length(critic_uncertainties)) end + actor_loss end - total / F(n_ok) end grad = materialize_tangent(gs[1]) diff --git a/test/runtests.jl b/test/runtests.jl index 561c405..857e673 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,8 @@ using Test using DecisionRulesExa +using Flux +using Random +using Zygote @testset "DeterministicEquivalentProblem (CPU)" begin T = 6 @@ -39,3 +42,144 @@ using DecisionRulesExa @test length(u_sol) == n_u @test length(δ_sol) == n_x end + +function _state_vector(m) + out = Float64[] + function visit(x) + if x isa AbstractArray && eltype(x) <: Number + append!(out, vec(Float64.(x))) + elseif x isa NamedTuple + foreach(visit, values(x)) + elseif x isa Tuple + foreach(visit, x) + end + return nothing + end + visit(Flux.state(m)) + return out +end + +@testset "Critic control variate helpers" begin + initial_state = Float32[1] + uncertainty = Float32[0.2, -0.1] + xhat = Float32[1.5, -2.0] + + quadratic = x -> sum(abs2, x) / 2 + cv = ScalarCriticControlVariate( + quadratic; + featurizer = (x0, w, x) -> x, + value_loss_weight = 1.0, + gradient_loss_weight = 1.0, + ) + + @test critic_value(cv, initial_state, uncertainty, xhat) ≈ sum(abs2, xhat) / 2 + @test critic_xhat_gradient(cv, initial_state, uncertainty, xhat) ≈ xhat + + sample = CriticSample( + initial_state, + uncertainty, + xhat, + sum(abs2, xhat) / 2, + copy(xhat), + ) + @test critic_loss(cv, [sample]) ≈ 0 atol = 1e-6 + + bad_sample = CriticSample(initial_state, uncertainty, xhat, 0.0, Float32[1.0]) + @test_throws ErrorException critic_loss(cv, [bad_sample]) 
+ + value_only = ScalarCriticControlVariate( + quadratic; + featurizer = (x0, w, x) -> x, + value_loss_weight = 1.0, + gradient_loss_weight = 0.0, + ) + grad_only = ScalarCriticControlVariate( + quadratic; + featurizer = (x0, w, x) -> x, + value_loss_weight = 0.0, + gradient_loss_weight = 1.0, + ) + hybrid = ScalarCriticControlVariate( + quadratic; + featurizer = (x0, w, x) -> x, + value_loss_weight = 0.1, + gradient_loss_weight = 1.0, + ) + @test isfinite(critic_loss(value_only, [sample])) + @test isfinite(critic_loss(grad_only, [sample])) + @test isfinite(critic_loss(hybrid, [sample])) +end + +@testset "Critic and actor update separation" begin + Random.seed!(11) + critic = Chain(Dense(2 => 1, bias = false)) + critic[1].weight .= 0.0f0 + cv = ScalarCriticControlVariate( + critic; + featurizer = (x0, w, x) -> x, + value_loss_weight = 1.0, + gradient_loss_weight = 0.0, + ) + sample = CriticSample(Float32[0], Float32[0], Float32[1, -1], 2.0, Float32[0, 0]) + + critic_before = _state_vector(critic) + actor = Chain(Dense(1 => 2, bias = false)) + actor_before = _state_vector(actor) + opt_state = Flux.setup(Flux.Descent(0.1f0), critic) + loss = update_critic!(opt_state, cv, [sample]) + @test isfinite(loss) + @test _state_vector(critic) != critic_before + @test _state_vector(actor) == actor_before + + critic_before_actor_update = _state_vector(critic) + actor_opt = Flux.setup(Flux.Descent(0.1f0), actor) + gs = Zygote.gradient(actor) do m + x = m(Float32[1]) + critic_value(cv, Float32[0], Float32[0], x) + end + Flux.update!(actor_opt, actor, materialize_tangent(gs[1])) + @test _state_vector(actor) != actor_before + @test _state_vector(critic) == critic_before_actor_update + + for (value_w, grad_w) in ((0.0, 1.0), (0.1, 1.0)) + c = Chain(Dense(2 => 1, bias = false)) + c[1].weight .= 0.0f0 + cv_step = ScalarCriticControlVariate( + c; + featurizer = (x0, w, x) -> x, + value_loss_weight = value_w, + gradient_loss_weight = grad_w, + ) + s = CriticSample(Float32[0], Float32[0], 
Float32[1, -1], 2.0, Float32[1, -1]) + before = _state_vector(c) + st = Flux.setup(Flux.Descent(0.1f0), c) + @test isfinite(update_critic!(st, cv_step, [s])) + @test _state_vector(c) != before + end +end + +@testset "Control-variate actor gradients" begin + Random.seed!(12) + actor_dual = Chain(Dense(1 => 2, bias = false)) + Random.seed!(12) + actor_cv = Chain(Dense(1 => 2, bias = false)) + + x_in = Float32[2] + lambda = Float32[1, -3] + zero_cv = ScalarCriticControlVariate( + x -> zero(sum(x)); + featurizer = (x0, w, x) -> x, + ) + + g_dual = Zygote.gradient(actor_dual) do m + sum(lambda .* m(x_in)) + end[1] + g_cv = Zygote.gradient(actor_cv) do m + xhat = m(x_in) + gx = critic_xhat_gradient(zero_cv, Float32[0], Float32[0], xhat) + sum((lambda .- gx) .* xhat) + critic_value(zero_cv, Float32[0], Float32[0], xhat) + end[1] + + @test materialize_tangent(g_cv).layers[1].weight ≈ + materialize_tangent(g_dual).layers[1].weight +end From 08380d7201be77c16b1f98d02e9d6a44baa66ffe Mon Sep 17 00:00:00 2001 From: Andrew Rosemberg Date: Sat, 20 Jun 2026 12:16:26 -0400 Subject: [PATCH 2/9] improve critic --- examples/HydroPowerModels/train_hydro_exa.jl | 4 ++++ examples/HydroPowerModels/train_hydro_exa_critic.jl | 11 +++++++++-- src/critic_control_variate.jl | 7 ++++++- src/rollout.jl | 6 ++++-- src/training.jl | 2 ++ 5 files changed, 25 insertions(+), 5 deletions(-) diff --git a/examples/HydroPowerModels/train_hydro_exa.jl b/examples/HydroPowerModels/train_hydro_exa.jl index 2a4f40e..62e6f90 100644 --- a/examples/HydroPowerModels/train_hydro_exa.jl +++ b/examples/HydroPowerModels/train_hydro_exa.jl @@ -314,10 +314,14 @@ train_tsddr( if iter % EVAL_EVERY == 0 rollout_evaluation(iter, m) realized_rollout_evaluation(iter, m) + metrics["metrics/rollout_objective_no_target_penalty"] = + rollout_evaluation.last_objective_no_target_penalty metrics["metrics/rollout_objective_no_deficit"] = rollout_evaluation.last_objective_no_target_penalty 
metrics["metrics/rollout_target_violation_share"] = rollout_evaluation.last_violation_share + metrics["metrics/rollout_realized_objective_no_target_penalty"] = + realized_rollout_evaluation.last_objective_no_target_penalty metrics["metrics/rollout_realized_objective_no_deficit"] = realized_rollout_evaluation.last_objective_no_target_penalty metrics["metrics/rollout_realized_target_violation_share"] = diff --git a/examples/HydroPowerModels/train_hydro_exa_critic.jl b/examples/HydroPowerModels/train_hydro_exa_critic.jl index 24e696c..3d9a5a0 100644 --- a/examples/HydroPowerModels/train_hydro_exa_critic.jl +++ b/examples/HydroPowerModels/train_hydro_exa_critic.jl @@ -62,7 +62,7 @@ const load_scaler = 0.6 const NUM_WORKERS = 4 const CRITIC_ROLLOUT_SAMPLES_PER_BATCH = 0 # eval rollouts feed the critic via external_critic_samples const CRITIC_POLICY_STATE = :target # set to :realized for closed-loop critic targets -const CRITIC_ROLLOUT_OBJECTIVE = :objective_no_target_penalty +const CRITIC_ROLLOUT_OBJECTIVE = :objective const NUM_CHEAP_CRITIC_SAMPLES_PER_BATCH = 4 * NUM_WORKERS const PENALTY_SCHEDULE = [ @@ -400,11 +400,18 @@ train_tsddr( rollout_evaluation(iter, m) realized_rollout_evaluation(iter, m) append!(shared_critic_samples, - critic_samples_from_evaluation(rollout_evaluation)) + critic_samples_from_evaluation( + rollout_evaluation; + objective_key = CRITIC_ROLLOUT_OBJECTIVE, + )) + metrics["metrics/rollout_objective_no_target_penalty"] = + rollout_evaluation.last_objective_no_target_penalty metrics["metrics/rollout_objective_no_deficit"] = rollout_evaluation.last_objective_no_target_penalty metrics["metrics/rollout_target_violation_share"] = rollout_evaluation.last_violation_share + metrics["metrics/rollout_realized_objective_no_target_penalty"] = + realized_rollout_evaluation.last_objective_no_target_penalty metrics["metrics/rollout_realized_objective_no_deficit"] = realized_rollout_evaluation.last_objective_no_target_penalty 
metrics["metrics/rollout_realized_target_violation_share"] = diff --git a/src/critic_control_variate.jl b/src/critic_control_variate.jl index 3f4fbbe..a1953d0 100644 --- a/src/critic_control_variate.jl +++ b/src/critic_control_variate.jl @@ -28,6 +28,11 @@ Required keyword callbacks match `rollout_tsddr`: By default `policy_state = :target`, matching the differentiable target recurrence used by the actor. Set `policy_state = :realized` to train on closed-loop realized-state rollout targets. + +By default `objective_value = :objective`, so critic value targets include the +same target-penalty contribution that appears in the dual actor signal. Set +`objective_value = :objective_no_target_penalty` to train on the rollout +objective with target-slack penalties removed. """ struct RolloutCriticTarget{S,R,O,M} <: AbstractCriticTrainingTarget stage_problem @@ -54,7 +59,7 @@ function RolloutCriticTarget( warmstart::Bool = true, policy_state::Symbol = :target, reuse_solver::Bool = false, - objective_value::Symbol = :objective_no_target_penalty, + objective_value::Symbol = :objective, ) policy_state in (:target, :realized) || error("policy_state must be :target or :realized") diff --git a/src/rollout.jl b/src/rollout.jl index 4139a25..f998d5b 100644 --- a/src/rollout.jl +++ b/src/rollout.jl @@ -302,11 +302,13 @@ end Convert the last rollout evaluation results into `CriticSample`s for critic training. Target multipliers are zero (rollout evaluation does not produce -duals), so these samples contribute only to the value loss term. +duals), so these samples contribute only to the value loss term. By default the +critic target uses the full rollout objective; pass +`objective_key = :objective_no_target_penalty` to remove target-slack penalties. 
""" function critic_samples_from_evaluation( eval_obj::RolloutEvaluation; - objective_key::Symbol = :objective_no_target_penalty, + objective_key::Symbol = :objective, ) isempty(eval_obj.last_scenario_data) && return CriticSample[] F = eltype(eval_obj.initial_state) diff --git a/src/training.jl b/src/training.jl index 688863b..eddb53d 100644 --- a/src/training.jl +++ b/src/training.jl @@ -235,6 +235,8 @@ function _critic_sample_from_rollout( F, solver_state, ) + # Keep both rollout objective variants available; target.objective_value + # below selects which one is used as the critic value target. result = rollout_tsddr( model, initial_state, From ea67e89e86c0211928c7c9a6d184cc6551c57e52 Mon Sep 17 00:00:00 2001 From: Andrew Rosemberg Date: Sat, 20 Jun 2026 14:46:23 -0400 Subject: [PATCH 3/9] update penalty --- examples/HydroPowerModels/hydro_power_exa.jl | 81 +++++++++++++++++--- examples/HydroPowerModels/train_hydro_exa.jl | 15 +++- 2 files changed, 82 insertions(+), 14 deletions(-) diff --git a/examples/HydroPowerModels/hydro_power_exa.jl b/examples/HydroPowerModels/hydro_power_exa.jl index 2ad9329..40feee9 100644 --- a/examples/HydroPowerModels/hydro_power_exa.jl +++ b/examples/HydroPowerModels/hydro_power_exa.jl @@ -55,6 +55,8 @@ struct HydroExaDEProblem p_target p_penalty_half # ExaModels parameter for (ρ/2)*mult (length T*nHyd) base_penalty_half::Float64 # ρ/2 at multiplier=1 + p_penalty_l1 # ExaModels parameter for L1 penalty (length T*nHyd) + base_penalty_l1::Float64 # L1 coefficient at multiplier=1 # sizes nHyd::Int nBus::Int @@ -142,7 +144,8 @@ end """ build_hydro_de(power_data, hydro_data, T; backend=nothing, float_type=Float64, formulation=:dc, - target_penalty=:auto, demand_matrix=nothing, + target_penalty=:auto, target_penalty_l1=:auto, + demand_matrix=nothing, reactive_demand_matrix=nothing, deficit_cost=nothing) -> HydroExaDEProblem @@ -163,6 +166,10 @@ Pass a large value (e.g. 
1e5, >> max thermal cost) to effectively enforce hard K `target_penalty` sets the L2 coefficient ρ for the `(ρ/2)·δ²` target slack penalty. Pass `:auto` (default) to use `2 × max_gen_cost`, matching JuMP's `penalty_l2 = :auto`. + +`target_penalty_l1` sets the L1 coefficient for the `λ·|δ|` target slack penalty. +Pass `:auto` (default) to use the same value as L2 ρ. Pass `nothing` to disable L1. +The L1 term is reformulated as `λ·(δ⁺ + δ⁻)` with `δ = δ⁺ − δ⁻`, `δ⁺,δ⁻ ≥ 0`. """ function build_hydro_de(power_data::PowerData, hydro_data::HydroData, @@ -171,6 +178,7 @@ function build_hydro_de(power_data::PowerData, float_type::Type{<:AbstractFloat} = Float64, formulation::Symbol = :dc, target_penalty::Union{Real,Symbol} = :auto, + target_penalty_l1::Union{Real,Symbol,Nothing} = :auto, demand_matrix = nothing, reactive_demand_matrix = nothing, deficit_cost::Union{Nothing,Real} = nothing, @@ -183,6 +191,7 @@ function build_hydro_de(power_data::PowerData, return _build_dc_hydro_de(power_data, hydro_data, T; backend=backend, float_type=float_type, target_penalty=target_penalty, + target_penalty_l1=target_penalty_l1, demand_matrix=demand_matrix, deficit_cost=deficit_cost, load_scaler=load_scaler) @@ -190,6 +199,7 @@ function build_hydro_de(power_data::PowerData, return _build_ac_hydro_de(power_data, hydro_data, T; backend=backend, float_type=float_type, target_penalty=target_penalty, + target_penalty_l1=target_penalty_l1, demand_matrix=demand_matrix, reactive_demand_matrix=reactive_demand_matrix, deficit_cost=deficit_cost, @@ -205,6 +215,7 @@ function _build_dc_hydro_de(power_data::PowerData, backend = nothing, float_type::Type{<:AbstractFloat} = Float64, target_penalty::Union{Real,Symbol} = :auto, + target_penalty_l1::Union{Real,Symbol,Nothing} = :auto, demand_matrix = nothing, deficit_cost::Union{Nothing,Real} = nothing, load_scaler::Real = 1.0) @@ -215,6 +226,14 @@ function _build_dc_hydro_de(power_data::PowerData, nHyd = hydro_data.nHyd K = float_type(hydro_data.K) ρ = 
float_type(target_penalty === :auto ? auto_target_penalty(power_data, hydro_data) : target_penalty) + ρ_l1 = if target_penalty_l1 === :auto + ρ + elseif target_penalty_l1 === nothing + zero(float_type) + else + float_type(target_penalty_l1) + end + use_l1 = ρ_l1 > 0 baseMVA = float_type(power_data.baseMVA) cd = float_type(deficit_cost !== nothing ? deficit_cost : power_data.cost_deficit) @@ -251,8 +270,9 @@ function _build_dc_hydro_de(power_data::PowerData, # Spill: T*nHyd (non-negative) spill = ExaModels.variable(core, T * nHyd; lvar = float_type(0)) - # Target slack δ: T*nHyd (free, penalized quadratically) - delta = ExaModels.variable(core, T * nHyd) + # Target slack: δ = δ⁺ − δ⁻ with δ⁺,δ⁻ ≥ 0 (L1+L2 Lagrangian penalty) + delta_pos = ExaModels.variable(core, T * nHyd; lvar = float_type(0)) + delta_neg = ExaModels.variable(core, T * nHyd; lvar = float_type(0)) # ── Parameters ──────────────────────────────────────────────────────────── @@ -267,6 +287,7 @@ function _build_dc_hydro_de(power_data::PowerData, p_inflow = ExaModels.parameter(core, zeros(float_type, T * nHyd)) p_target = ExaModels.parameter(core, zeros(float_type, T * nHyd)) p_penalty_half = ExaModels.parameter(core, fill(float_type(ρ / 2), T * nHyd)) + p_penalty_l1 = ExaModels.parameter(core, fill(ρ_l1, T * nHyd)) # ── Objective ───────────────────────────────────────────────────────────── @@ -285,12 +306,21 @@ function _build_dc_hydro_de(power_data::PowerData, for item in def_cost_items ) + # L2 penalty: (ρ/2)·(δ⁺ − δ⁻)² delta_items = [(idx = _ri(nHyd, t, r),) for t in 1:T for r in 1:nHyd] ExaModels.objective(core, - p_penalty_half[item.idx] * delta[item.idx]^2 + p_penalty_half[item.idx] * (delta_pos[item.idx] - delta_neg[item.idx])^2 for item in delta_items ) + # L1 penalty: λ·(δ⁺ + δ⁻) + if use_l1 + ExaModels.objective(core, + p_penalty_l1[item.idx] * (delta_pos[item.idx] + delta_neg[item.idx]) + for item in delta_items + ) + end + # ── Constraints 
─────────────────────────────────────────────────────────── n_con = 0 @@ -417,12 +447,13 @@ function _build_dc_hydro_de(power_data::PowerData, n_con += T * nHyd # ── TARGET CONSTRAINTS (ADDED LAST) ─────────────────────────────────────── + # x̂ − x − (δ⁺ − δ⁻) = 0 target_items = [(param_idx = _ri(nHyd, t, r), res_idx = _ri(nHyd, t+1, r), delta_idx = _ri(nHyd, t, r)) for t in 1:T for r in 1:nHyd] ExaModels.constraint(core, - p_target[item.param_idx] - reservoir[item.res_idx] - delta[item.delta_idx] + p_target[item.param_idx] - reservoir[item.res_idx] - delta_pos[item.delta_idx] + delta_neg[item.delta_idx] for item in target_items ) target_con_range = (n_con + 1):(n_con + T * nHyd) @@ -433,6 +464,7 @@ function _build_dc_hydro_de(power_data::PowerData, core, model, p_demand, nothing, p_x0, p_inflow, p_target, p_penalty_half, Float64(ρ / 2), + p_penalty_l1, Float64(ρ_l1), nHyd, nBus, nGen, nBranch, T, :dc, target_con_range, ) @@ -446,6 +478,7 @@ function _build_ac_hydro_de(power_data::PowerData, backend = nothing, float_type::Type{<:AbstractFloat} = Float64, target_penalty::Union{Real,Symbol} = :auto, + target_penalty_l1::Union{Real,Symbol,Nothing} = :auto, demand_matrix = nothing, reactive_demand_matrix = nothing, deficit_cost::Union{Nothing,Real} = nothing, @@ -457,6 +490,14 @@ function _build_ac_hydro_de(power_data::PowerData, nHyd = hydro_data.nHyd K = float_type(hydro_data.K) ρ = float_type(target_penalty === :auto ? auto_target_penalty(power_data, hydro_data) : target_penalty) + ρ_l1 = if target_penalty_l1 === :auto + ρ + elseif target_penalty_l1 === nothing + zero(float_type) + else + float_type(target_penalty_l1) + end + use_l1 = ρ_l1 > 0 baseMVA = float_type(power_data.baseMVA) cd = float_type(deficit_cost !== nothing ? 
deficit_cost : power_data.cost_deficit) @@ -513,8 +554,9 @@ function _build_ac_hydro_de(power_data::PowerData, # Spill: T*nHyd (non-negative) spill = ExaModels.variable(core, T * nHyd; lvar = float_type(0)) - # Target slack δ: T*nHyd (free, penalized quadratically) - delta = ExaModels.variable(core, T * nHyd) + # Target slack: δ = δ⁺ − δ⁻ with δ⁺,δ⁻ ≥ 0 (L1+L2 Lagrangian penalty) + delta_pos = ExaModels.variable(core, T * nHyd; lvar = float_type(0)) + delta_neg = ExaModels.variable(core, T * nHyd; lvar = float_type(0)) # ── Parameters ──────────────────────────────────────────────────────────── @@ -536,6 +578,7 @@ function _build_ac_hydro_de(power_data::PowerData, p_inflow = ExaModels.parameter(core, zeros(float_type, T * nHyd)) p_target = ExaModels.parameter(core, zeros(float_type, T * nHyd)) p_penalty_half = ExaModels.parameter(core, fill(float_type(ρ / 2), T * nHyd)) + p_penalty_l1 = ExaModels.parameter(core, fill(ρ_l1, T * nHyd)) # ── Precompute branch AC coefficients ───────────────────────────────────── @@ -558,12 +601,21 @@ function _build_ac_hydro_de(power_data::PowerData, for item in def_cost_items ) + # L2 penalty: (ρ/2)·(δ⁺ − δ⁻)² delta_items = [(idx = _ri(nHyd, t, r),) for t in 1:T for r in 1:nHyd] ExaModels.objective(core, - p_penalty_half[item.idx] * delta[item.idx]^2 + p_penalty_half[item.idx] * (delta_pos[item.idx] - delta_neg[item.idx])^2 for item in delta_items ) + # L1 penalty: λ·(δ⁺ + δ⁻) + if use_l1 + ExaModels.objective(core, + p_penalty_l1[item.idx] * (delta_pos[item.idx] + delta_neg[item.idx]) + for item in delta_items + ) + end + # ── Constraints ─────────────────────────────────────────────────────────── n_con = 0 @@ -774,12 +826,13 @@ function _build_ac_hydro_de(power_data::PowerData, n_con += T * nHyd # ── TARGET CONSTRAINTS (ADDED LAST) ─────────────────────────────────────── + # x̂ − x − (δ⁺ − δ⁻) = 0 target_items = [(param_idx = _ri(nHyd, t, r), res_idx = _ri(nHyd, t+1, r), delta_idx = _ri(nHyd, t, r)) for t in 1:T for r in 1:nHyd] 
     ExaModels.constraint(core,
-        p_target[item.param_idx] - reservoir[item.res_idx] - delta[item.delta_idx]
+        p_target[item.param_idx] - reservoir[item.res_idx] - delta_pos[item.delta_idx] + delta_neg[item.delta_idx]
         for item in target_items
     )
     target_con_range = (n_con + 1):(n_con + T * nHyd)
@@ -790,6 +843,7 @@ function _build_ac_hydro_de(power_data::PowerData,
         core, model,
         p_demand, p_reactive_demand, p_x0, p_inflow, p_target,
         p_penalty_half, Float64(ρ / 2),
+        p_penalty_l1, Float64(ρ_l1),
         nHyd, nBus, nGen, nBranch, T,
         :ac_polar, target_con_range,
     )
@@ -829,6 +883,7 @@ end
     hydro_solution(prob, result) -> NamedTuple
 
 Reshape the flat solution vector into named components.
+`delta` is reconstructed as `delta_pos - delta_neg`.
 
 DC: (va, pg, pf, deficit, reservoir, outflow, spill, delta)
 AC: (va, vm, pg, qg, p_fr, q_fr, p_to, q_to, deficit, deficit_q,
@@ -851,7 +906,9 @@ function hydro_solution(prob::HydroExaDEProblem, result)
         res_sol = reshape(sol[off .+ (1:(T+1)*nH)], nH, T+1); off += (T+1)*nH
         out_sol = reshape(sol[off .+ (1:T*nH)], nH, T); off += T*nH
         spill_sol = reshape(sol[off .+ (1:T*nH)], nH, T); off += T*nH
-        delta_sol = reshape(sol[off .+ (1:T*nH)], nH, T); off += T*nH
+        dp_sol = reshape(sol[off .+ (1:T*nH)], nH, T); off += T*nH
+        dn_sol = reshape(sol[off .+ (1:T*nH)], nH, T); off += T*nH
+        delta_sol = dp_sol .- dn_sol
         return (va=va_sol, pg=pg_sol, pf=pf_sol, deficit=def_sol,
                 reservoir=res_sol, outflow=out_sol, spill=spill_sol, delta=delta_sol)
     else  # :ac_polar
@@ -868,7 +925,9 @@ function hydro_solution(prob::HydroExaDEProblem, result)
         res_sol = reshape(sol[off .+ (1:(T+1)*nH)], nH, T+1); off += (T+1)*nH
         out_sol = reshape(sol[off .+ (1:T*nH)], nH, T); off += T*nH
         spill_sol = reshape(sol[off .+ (1:T*nH)], nH, T); off += T*nH
-        delta_sol = reshape(sol[off .+ (1:T*nH)], nH, T); off += T*nH
+        dp_sol = reshape(sol[off .+ (1:T*nH)], nH, T); off += T*nH
+        dn_sol = reshape(sol[off .+ (1:T*nH)], nH, T); off += T*nH
+        delta_sol = dp_sol .- dn_sol
         return (va=va_sol, vm=vm_sol, pg=pg_sol, qg=qg_sol,
                 p_fr=p_fr_sol, q_fr=q_fr_sol, p_to=p_to_sol, q_to=q_to_sol,
                 deficit=def_sol, deficit_q=def_q_sol,
diff --git a/examples/HydroPowerModels/train_hydro_exa.jl b/examples/HydroPowerModels/train_hydro_exa.jl
index 62e6f90..3215056 100644
--- a/examples/HydroPowerModels/train_hydro_exa.jl
+++ b/examples/HydroPowerModels/train_hydro_exa.jl
@@ -171,6 +171,7 @@ lg = WandbLogger(
         "layers" => LAYERS,
         "activation" => string(ACTIVATION),
         "target_penalty" => "auto=$(round(resolved_pen; digits=2))",
+        "target_penalty_l1" => "auto=$(round(resolved_pen_l1; digits=2))",
         "deficit_cost" => DEFICIT_COST,
         "num_epochs" => NUM_EPOCHS,
         "num_batches" => NUM_BATCHES,
@@ -223,9 +224,14 @@ end
 hydro_realized_state(stage_prob, result) =
     Array(hydro_solution(stage_prob, result).reservoir[:, end])
 
+resolved_pen_l1 = prob.base_penalty_l1
+
 function hydro_objective_no_target_penalty(stage_prob, result)
     sol = hydro_solution(stage_prob, result)
-    return result.objective - (resolved_pen / 2) * sum(abs2, Array(sol.delta))
+    delta = Array(sol.delta)
+    penalty_l2_cost = (resolved_pen / 2) * sum(abs2, delta)
+    penalty_l1_cost = resolved_pen_l1 * sum(abs, delta)
+    return result.objective - penalty_l2_cost - penalty_l1_cost
 end
 
 Random.seed!(8789)
@@ -296,11 +302,14 @@ train_tsddr(
         if mult != current_penalty_mult[]
             current_penalty_mult[] = mult
             ρ_half_scaled = prob.base_penalty_half * mult
-            penalty_vals = fill(ρ_half_scaled, T * nHyd)
+            ρ_l1_scaled = prob.base_penalty_l1 * mult
+            penalty_vals = fill(ρ_half_scaled, T * nHyd)
+            penalty_l1_vals = fill(ρ_l1_scaled, T * nHyd)
             for (p, _, _, _) in problem_pool
                 ExaModels.set_parameter!(p.core, p.p_penalty_half, penalty_vals)
+                ExaModels.set_parameter!(p.core, p.p_penalty_l1, penalty_l1_vals)
             end
-            @info "Penalty multiplier → $mult (ρ/2 = $(round(ρ_half_scaled; digits=2)))"
+            @info "Penalty multiplier → $mult (ρ/2 = $(round(ρ_half_scaled; digits=2)), λ_l1 = $(round(ρ_l1_scaled; digits=2)))"
         end
         n_eval = _schedule_value(EVAL_SCHEDULE, iter, MAX_EVAL_SCENARIOS)
         rollout_evaluation.active_scenarios = n_eval
From 96b78d1d93539bc526ad527b88fd2c79b4e7f069 Mon Sep 17 00:00:00 2001
From: Andrew Rosemberg
Date: Sun, 21 Jun 2026 13:30:06 -0400
Subject: [PATCH 4/9] update

---
 examples/HydroPowerModels/hydro_power_data.jl | 9 +++++++--
 examples/HydroPowerModels/train_hydro_exa.jl  | 3 ++-
 src/training.jl                               | 5 ++++-
 3 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/examples/HydroPowerModels/hydro_power_data.jl b/examples/HydroPowerModels/hydro_power_data.jl
index 6ee9b95..bc7f5b4 100644
--- a/examples/HydroPowerModels/hydro_power_data.jl
+++ b/examples/HydroPowerModels/hydro_power_data.jl
@@ -354,15 +354,20 @@ end
 """
     sample_scenario(hydro_data, T) -> Vector{Float64}
 
-Sample one inflow trajectory of length T*nHyd (flat, stage-major order).
+Sample one inflow trajectory of length `T*nHyd` (flat, stage-major order).
+
+Uses **joint** scenario sampling: at each stage one scenario index `ω` is drawn
+and applied to all hydro reservoirs, preserving the spatial correlation present
+in the historical inflow data. This matches SDDP's `SDDP.parameterize` semantics.
 """
 function sample_scenario(hydro_data::HydroData, T::Int)
     nHyd = hydro_data.nHyd
     w = Vector{Float64}(undef, T * nHyd)
     for t in 1:T
         t_row = mod1(t, hydro_data.nStagesSample)
+        # One scenario index per stage — all reservoirs share it (joint sampling).
+        j = rand(1:hydro_data.nScenarios)
         for r in 1:nHyd
-            j = rand(1:hydro_data.nScenarios)
             w[(t-1)*nHyd + r] = hydro_data.scenario_inflows[r][t_row, j]
         end
     end
diff --git a/examples/HydroPowerModels/train_hydro_exa.jl b/examples/HydroPowerModels/train_hydro_exa.jl
index 3215056..0de91b8 100644
--- a/examples/HydroPowerModels/train_hydro_exa.jl
+++ b/examples/HydroPowerModels/train_hydro_exa.jl
@@ -290,7 +290,8 @@ train_tsddr(
     () -> sample_scenario(hydro_data, T); # returns flat Float32 vector, length T*nHyd
     num_batches = NUM_EPOCHS * NUM_BATCHES,
     num_train_per_batch = NUM_WORKERS,
-    optimizer = Flux.Optimisers.OptimiserChain(
+    optimizer = isnothing(GRAD_CLIP) ? Flux.Adam(LR) :
+                Flux.Optimisers.OptimiserChain(
                     Flux.Optimisers.ClipGrad(GRAD_CLIP),
                     Flux.Adam(LR),
                 ),
diff --git a/src/training.jl b/src/training.jl
index eddb53d..0407ef6 100644
--- a/src/training.jl
+++ b/src/training.jl
@@ -313,7 +313,10 @@ Arguments:
 - `p_x0` : ExaModels parameter for the initial state
 - `p_target` : ExaModels parameter for policy targets
 - `p_uncertainty` : ExaModels parameter for per-stage uncertainty
-- `uncertainty_sampler`: `() -> w_flat` — flat vector of length `T * nw_per_stage`
+- `uncertainty_sampler`: `() -> w_flat` — flat vector of length `T * nw_per_stage`.
+  For multi-unit problems (e.g., hydro reservoirs) the sampler
+  should draw one joint scenario index per stage to preserve
+  spatial correlation; see `sample_scenario` in examples.
 
 Keyword arguments (mirror `train_multistage`):
 - `num_batches` : total gradient steps (default 100)
From 46f0bf163e82a48e43eee697b5c67e72b89f2354 Mon Sep 17 00:00:00 2001
From: Andrew Rosemberg
Date: Sun, 21 Jun 2026 20:49:36 -0400
Subject: [PATCH 5/9] update script options

---
 .gitignore                                   |  1 +
 examples/HydroPowerModels/train_hydro_exa.jl | 29 ++++++++++++--------
 2 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/.gitignore b/.gitignore
index 01dec9e..a432b01 100644
--- a/.gitignore
+++ b/.gitignore
@@ -41,3 +41,4 @@ logs/
 
 # Slurm batch scripts (user-specific, not part of the package)
 *.sbatch
+*.sh
diff --git a/examples/HydroPowerModels/train_hydro_exa.jl b/examples/HydroPowerModels/train_hydro_exa.jl
index 0de91b8..a86bdb5 100644
--- a/examples/HydroPowerModels/train_hydro_exa.jl
+++ b/examples/HydroPowerModels/train_hydro_exa.jl
@@ -14,7 +14,7 @@ using Wandb, Logging
 using JLD2
 using MadNLP
 using MadNLPGPU, KernelAbstractions, CUDA
-using CUDSS_jll, cuDNN
+using CUDSS, CUDSS_jll, cuDNN
 
 const SCRIPT_DIR = dirname(@__FILE__)
 include(joinpath(SCRIPT_DIR, "hydro_power_data.jl"))
@@ -45,7 +45,7 @@ const EVAL_SCHEDULE = [
     (div(NUM_EPOCHS * NUM_BATCHES, 2) + 1, NUM_EPOCHS * NUM_BATCHES, MAX_EVAL_SCENARIOS),
 ]
 const LR = 1f-3
-const GRAD_CLIP = 10.0f0
+const GRAD_CLIP = parse(Float32, get(ENV, "DR_GRAD_CLIP", "10"))
 
 const TARGET_PEN_ARG = :auto
 const DEFICIT_COST = 1e5
@@ -53,12 +53,17 @@ const USE_GPU = true
 const load_scaler = 0.6
 const NUM_WORKERS = 4
 
-const PENALTY_SCHEDULE = [
-    (1, div(NUM_EPOCHS * NUM_BATCHES, 4), 0.1),
-    (div(NUM_EPOCHS * NUM_BATCHES, 4) + 1, div(NUM_EPOCHS * NUM_BATCHES, 4) * 2, 1.0),
-    (div(NUM_EPOCHS * NUM_BATCHES, 4) * 2 + 1, div(NUM_EPOCHS * NUM_BATCHES, 4) * 3, 10.0),
-    (div(NUM_EPOCHS * NUM_BATCHES, 4) * 3 + 1, NUM_EPOCHS * NUM_BATCHES, 30.0),
-]
+const _PENALTY_MODE = get(ENV, "DR_PENALTY_SCHEDULE", "annealed")
+const PENALTY_SCHEDULE = if _PENALTY_MODE == "annealed"
+    [
+        (1, div(NUM_EPOCHS * NUM_BATCHES, 4), 0.1),
+        (div(NUM_EPOCHS * NUM_BATCHES, 4) + 1, div(NUM_EPOCHS * NUM_BATCHES, 4) * 2, 1.0),
+        (div(NUM_EPOCHS * NUM_BATCHES, 4) * 2 + 1, div(NUM_EPOCHS * NUM_BATCHES, 4) * 3, 10.0),
+        (div(NUM_EPOCHS * NUM_BATCHES, 4) * 3 + 1, NUM_EPOCHS * NUM_BATCHES, 30.0),
+    ]
+else
+    [(1, NUM_EPOCHS * NUM_BATCHES, 1.0)]
+end
 
 const NUM_TRAIN_SCHEDULE = [
     (1, div(NUM_EPOCHS * NUM_BATCHES, 5), NUM_WORKERS),
@@ -70,7 +75,9 @@ const NUM_TRAIN_SCHEDULE = [
 
 const SOLVER_KWARGS = (print_level = MadNLP.ERROR, tol = 1e-6, max_iter = 9000)
 
-const RUN_NAME = "$(CASE_NAME)-$(FORM_LABEL)-h$(NUM_STAGES)-deteq-gpu-$(Dates.format(now(), "yyyymmdd-HHMMSS"))"
+const _CLIP_TAG = GRAD_CLIP > 0 ? "-clip$(Int(GRAD_CLIP))" : ""
+const _SCHED_TAG = _PENALTY_MODE == "annealed" ? "-anneal" : "-const"
+const RUN_NAME = "$(CASE_NAME)-$(FORM_LABEL)-h$(NUM_STAGES)-deteq-gpu$(_CLIP_TAG)$(_SCHED_TAG)-$(Dates.format(now(), "yyyymmdd-HHMMSS"))"
 const MODEL_DIR = joinpath(CASE_DIR, FORM_LABEL, "models")
 mkpath(MODEL_DIR)
 const MODEL_PATH = joinpath(MODEL_DIR, RUN_NAME * ".jld2")
@@ -290,11 +297,11 @@ train_tsddr(
     () -> sample_scenario(hydro_data, T); # returns flat Float32 vector, length T*nHyd
     num_batches = NUM_EPOCHS * NUM_BATCHES,
     num_train_per_batch = NUM_WORKERS,
-    optimizer = isnothing(GRAD_CLIP) ? Flux.Adam(LR) :
+    optimizer = GRAD_CLIP > 0 ?
                 Flux.Optimisers.OptimiserChain(
                     Flux.Optimisers.ClipGrad(GRAD_CLIP),
                     Flux.Adam(LR),
-                ),
+                ) : Flux.Adam(LR),
     madnlp_kwargs = SOLVER_KWARGS,
     warmstart = true,
     problem_pool = problem_pool,
From 9f1ebf800be46d539a967031ddfa2aad252897ce Mon Sep 17 00:00:00 2001
From: Andrew Rosemberg
Date: Sun, 21 Jun 2026 21:56:53 -0400
Subject: [PATCH 6/9] update

---
 examples/HydroPowerModels/train_hydro_exa.jl | 67 +++++++++-----------
 1 file changed, 30 insertions(+), 37 deletions(-)

diff --git a/examples/HydroPowerModels/train_hydro_exa.jl b/examples/HydroPowerModels/train_hydro_exa.jl
index a86bdb5..81d3021 100644
--- a/examples/HydroPowerModels/train_hydro_exa.jl
+++ b/examples/HydroPowerModels/train_hydro_exa.jl
@@ -35,15 +35,11 @@ const DEMAND_FILE = joinpath(CASE_DIR, "demand.csv")
 const LAYERS = [128, 128]
 const ACTIVATION = sigmoid
 const NUM_STAGES = 96
-const NUM_EPOCHS = 20
+const NUM_EPOCHS = 40
 const NUM_BATCHES = 100
-const MAX_EVAL_SCENARIOS = 32
-const EVAL_EVERY = 25
-
-const EVAL_SCHEDULE = [
-    (1, div(NUM_EPOCHS * NUM_BATCHES, 2), 4),
-    (div(NUM_EPOCHS * NUM_BATCHES, 2) + 1, NUM_EPOCHS * NUM_BATCHES, MAX_EVAL_SCENARIOS),
-]
+const NUM_TRAIN_PER_BATCH = 1
+const NUM_EVAL_SCENARIOS = 4
+const EVAL_EVERY = 25
 const LR = 1f-3
 const GRAD_CLIP = parse(Float32, get(ENV, "DR_GRAD_CLIP", "10"))
 
@@ -51,7 +47,7 @@ const TARGET_PEN_ARG = :auto
 const DEFICIT_COST = 1e5
 const USE_GPU = true
 const load_scaler = 0.6
-const NUM_WORKERS = 4
+const NUM_WORKERS = 1
 
 const _PENALTY_MODE = get(ENV, "DR_PENALTY_SCHEDULE", "annealed")
 const PENALTY_SCHEDULE = if _PENALTY_MODE == "annealed"
@@ -65,13 +61,10 @@ else
     [(1, NUM_EPOCHS * NUM_BATCHES, 1.0)]
 end
 
-const NUM_TRAIN_SCHEDULE = [
-    (1, div(NUM_EPOCHS * NUM_BATCHES, 5), NUM_WORKERS),
-    (div(NUM_EPOCHS * NUM_BATCHES, 5) + 1, div(NUM_EPOCHS * NUM_BATCHES, 5) * 2, 2 * NUM_WORKERS),
-    (div(NUM_EPOCHS * NUM_BATCHES, 5) * 2 + 1, div(NUM_EPOCHS * NUM_BATCHES, 5) * 3, 4 * NUM_WORKERS),
-    (div(NUM_EPOCHS * NUM_BATCHES, 5) * 3 + 1, div(NUM_EPOCHS * NUM_BATCHES, 5) * 4, 8 * NUM_WORKERS),
-    (div(NUM_EPOCHS * NUM_BATCHES, 5) * 4 + 1, NUM_EPOCHS * NUM_BATCHES, 8 * NUM_WORKERS),
-]
+# Optional: ramp num_train_per_batch and eval scenarios over training.
+# Set to `nothing` to use fixed NUM_TRAIN_PER_BATCH / NUM_EVAL_SCENARIOS.
+const NUM_TRAIN_SCHEDULE = nothing # e.g. [(1,500,1),(501,2000,4),(2001,4000,8)]
+const EVAL_SCHEDULE = nothing # e.g. [(1,2000,4),(2001,4000,32)]
 
 const SOLVER_KWARGS = (print_level = MadNLP.ERROR, tol = 1e-6, max_iter = 9000)
 
@@ -154,6 +147,8 @@ result0 = MadNLP.madnlp(prob.model; SOLVER_KWARGS..., print_level = MadNLP.WARN)
 isfinite(result0.objective) || error("Smoke test returned non-finite objective")
 solve_succeeded(result0) || @warn "Smoke test did not fully converge; proceeding anyway"
 
+resolved_pen_l1 = prob.base_penalty_l1
+
 # ── Policy ────────────────────────────────────────────────────────────────────
 
 policy = StateConditionedPolicy(nHyd, nHyd, nHyd, LAYERS;
@@ -182,16 +177,17 @@ lg = WandbLogger(
         "deficit_cost" => DEFICIT_COST,
         "num_epochs" => NUM_EPOCHS,
         "num_batches" => NUM_BATCHES,
-        "max_eval_scenarios" => MAX_EVAL_SCENARIOS,
-        "eval_schedule" => string(EVAL_SCHEDULE),
+        "num_train_per_batch" => NUM_TRAIN_PER_BATCH,
+        "num_eval_scenarios" => NUM_EVAL_SCENARIOS,
         "eval_every" => EVAL_EVERY,
         "lr" => LR,
         "grad_clip" => GRAD_CLIP,
         "backend" => USE_GPU ? "GPU" : "CPU",
         "load_scaler" => load_scaler,
         "penalty_schedule" => string(PENALTY_SCHEDULE),
-        "num_train_schedule" => string(NUM_TRAIN_SCHEDULE),
-        "num_workers" => NUM_WORKERS,
+        "num_train_schedule" => string(something(NUM_TRAIN_SCHEDULE, "fixed")),
+        "eval_schedule" => string(something(EVAL_SCHEDULE, "fixed")),
+        "num_workers" => NUM_WORKERS,
     ),
 )
 
@@ -215,8 +211,9 @@ function _build_rollout_de()
     )
 end
 rollout_prob = _build_rollout_de()
-rollout_pool = [_build_rollout_de() for _ in 1:NUM_WORKERS]
-@info "Rollout pool ready: $(NUM_WORKERS) CPU stage-problem copies"
+n_rollout_pool = max(NUM_WORKERS, NUM_EVAL_SCENARIOS)
+rollout_pool = [_build_rollout_de() for _ in 1:n_rollout_pool]
+@info "Rollout pool ready: $(n_rollout_pool) CPU stage-problem copies"
 
 function set_hydro_rollout_stage!(stage_prob, state_in, wt, target, stage)
     ExaModels.set_parameter!(stage_prob.core, stage_prob.p_x0, state_in)
@@ -231,8 +228,6 @@ end
 hydro_realized_state(stage_prob, result) =
     Array(hydro_solution(stage_prob, result).reservoir[:, end])
 
-resolved_pen_l1 = prob.base_penalty_l1
-
 function hydro_objective_no_target_penalty(stage_prob, result)
     sol = hydro_solution(stage_prob, result)
     delta = Array(sol.delta)
@@ -242,7 +237,7 @@ function hydro_objective_no_target_penalty(stage_prob, result)
 end
 
 Random.seed!(8789)
-eval_scenarios = [sample_scenario(hydro_data, T) for _ in 1:MAX_EVAL_SCENARIOS]
+eval_scenarios = [sample_scenario(hydro_data, T) for _ in 1:NUM_EVAL_SCENARIOS]
 rollout_evaluation = RolloutEvaluation(
     rollout_prob,
     x0_init,
@@ -257,7 +252,7 @@ rollout_evaluation = RolloutEvaluation(
     stride = EVAL_EVERY,
     policy_state = :target,
     stage_problem_pool = rollout_pool,
-    active_scenarios = 4,
+    active_scenarios = NUM_EVAL_SCENARIOS,
 )
 realized_rollout_evaluation = RolloutEvaluation(
     rollout_prob,
@@ -273,7 +268,7 @@ realized_rollout_evaluation = RolloutEvaluation(
     stride = EVAL_EVERY,
     policy_state = :realized,
     stage_problem_pool = rollout_pool,
-    active_scenarios = 4,
+    active_scenarios = NUM_EVAL_SCENARIOS,
 )
 
 Random.seed!(8788)
@@ -294,9 +289,9 @@ train_tsddr(
     prob.p_x0,
     prob.p_target,
     prob.p_inflow,
-    () -> sample_scenario(hydro_data, T); # returns flat Float32 vector, length T*nHyd
+    () -> sample_scenario(hydro_data, T);
     num_batches = NUM_EPOCHS * NUM_BATCHES,
-    num_train_per_batch = NUM_WORKERS,
+    num_train_per_batch = NUM_TRAIN_PER_BATCH,
     optimizer = GRAD_CLIP > 0 ?
                 Flux.Optimisers.OptimiserChain(
                     Flux.Optimisers.ClipGrad(GRAD_CLIP),
@@ -319,10 +314,12 @@ train_tsddr(
             end
             @info "Penalty multiplier → $mult (ρ/2 = $(round(ρ_half_scaled; digits=2)), λ_l1 = $(round(ρ_l1_scaled; digits=2)))"
         end
-        n_eval = _schedule_value(EVAL_SCHEDULE, iter, MAX_EVAL_SCENARIOS)
-        rollout_evaluation.active_scenarios = n_eval
-        realized_rollout_evaluation.active_scenarios = n_eval
-        return _schedule_value(NUM_TRAIN_SCHEDULE, iter, n)
+        if !isnothing(EVAL_SCHEDULE)
+            n_eval = _schedule_value(EVAL_SCHEDULE, iter, NUM_EVAL_SCENARIOS)
+            rollout_evaluation.active_scenarios = n_eval
+            realized_rollout_evaluation.active_scenarios = n_eval
+        end
+        return isnothing(NUM_TRAIN_SCHEDULE) ? n : _schedule_value(NUM_TRAIN_SCHEDULE, iter, n)
     end,
     record_loss = (iter, m, loss, tag) -> begin
         metrics = Dict{String, Any}(tag => loss, "batch" => iter)
@@ -350,10 +347,6 @@ train_tsddr(
         if !isnan(current_penalty_mult[])
             metrics["metrics/target_penalty_multiplier"] = current_penalty_mult[]
         end
-        metrics["metrics/num_train_per_batch"] =
-            _schedule_value(NUM_TRAIN_SCHEDULE, iter, 1)
-        metrics["metrics/active_eval_scenarios"] =
-            _schedule_value(EVAL_SCHEDULE, iter, MAX_EVAL_SCENARIOS)
 
         batch_in_epoch = (iter - 1) % NUM_BATCHES + 1
         if batch_in_epoch == NUM_BATCHES
From efb8ff069f0161a007365e612c1c0be95bd3dd76 Mon Sep 17 00:00:00 2001
From: Andrew Rosemberg
Date: Mon, 22 Jun 2026 11:03:12 -0400
Subject: [PATCH 7/9] improve policy evaluation

---
 examples/HydroPowerModels/train_hydro_exa.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/HydroPowerModels/train_hydro_exa.jl b/examples/HydroPowerModels/train_hydro_exa.jl
index 81d3021..6ec71e9 100644
--- a/examples/HydroPowerModels/train_hydro_exa.jl
+++ b/examples/HydroPowerModels/train_hydro_exa.jl
@@ -34,8 +34,8 @@ const DEMAND_FILE = joinpath(CASE_DIR, "demand.csv")
 
 const LAYERS = [128, 128]
 const ACTIVATION = sigmoid
-const NUM_STAGES = 96
-const NUM_EPOCHS = 40
+const NUM_STAGES = parse(Int, get(ENV, "DR_NUM_STAGES", "126"))
+const NUM_EPOCHS = parse(Int, get(ENV, "DR_NUM_EPOCHS", "80"))
 const NUM_BATCHES = 100
 const NUM_TRAIN_PER_BATCH = 1
 const NUM_EVAL_SCENARIOS = 4
From d35d2200cfcca23d8c2fd2b3b55ea2b2ba473f4d Mon Sep 17 00:00:00 2001
From: Andrew Rosemberg
Date: Mon, 22 Jun 2026 14:40:09 -0400
Subject: [PATCH 8/9] fix typo

---
 examples/end_to_end_cpu.jl | 56 ++++++++++++++------------------------
 examples/end_to_end_gpu.jl | 56 ++++++++++++++++----------------------
 2 files changed, 45 insertions(+), 67 deletions(-)

diff --git a/examples/end_to_end_cpu.jl b/examples/end_to_end_cpu.jl
index b874162..f540839 100644
--- a/examples/end_to_end_cpu.jl
+++ b/examples/end_to_end_cpu.jl
@@ -4,50 +4,36 @@
 using Random
 using LinearAlgebra
 using DecisionRulesExa
+using MadNLP
+using Flux
 
 Random.seed!(1)
 
-# Problem dimensions
-T = 8 # horizon
-nx = 1 # state dimension (demo uses scalar)
-# controls u are also dimension nx in build_linear_tracking_problem()
+T = 8
+nx = 1
 
-# Build deterministic-equivalent NLP on CPU
 prob = build_linear_tracking_problem(
     horizon = T,
     nx = nx,
-    backend = nothing, # CPU
+    backend = nothing,
     slack_penalty = 10.0,
     u_bounds = (-2.0, 2.0),
 )
 
-# Policy: input = [x0 ; w(1:T-1)], output = xhat(1:T)
-input_dim = nx + (T - 1) * nx
-output_dim = T * nx
-policy = MLPPolicy(input_dim, output_dim; hidden = (32, 32), act = tanh)
-
-# Optional solver cache (recommended if you solve many times)
-cache = init_madnlp_cache(prob)
-
-# Scenario sampler
-sampler(k) = begin
-    x0 = [1.0]
-    w = 0.1 .* randn(T - 1) # (T-1)*nx
-    return x0, w
-end
-
-# Run a few TS-DDR iterations
-hist = train_tsddr!(
-    policy, prob;
-    n_iters = 5,
-    sampler = sampler,
-    η = 1e-2,
-    cache = cache,
-    tol = 1e-6,
-    max_iter = 200,
+policy = StateConditionedPolicy(nx, nx, nx, [32, 32])
+
+sampler() = Float64.(0.1 .* randn(T * nx))
+
+train_tsddr(
+    policy,
+    [1.0],
+    prob,
+    prob.p_x0,
+    prob.p_target,
+    prob.p_w,
+    sampler;
+    num_batches = 5,
+    num_train_per_batch = 2,
+    optimizer = Flux.Adam(1f-3),
+    madnlp_kwargs = (print_level = MadNLP.ERROR, tol = 1e-6),
 )
-
-# Print last objective and a few duals
-last = hist[end]
-println("\nLast objective: ", last.result.objective)
-println("First 5 target multipliers λ: ", collect(last.lambda[1:min(5, length(last.lambda))]))
diff --git a/examples/end_to_end_gpu.jl b/examples/end_to_end_gpu.jl
index d6bab48..4145332 100644
--- a/examples/end_to_end_gpu.jl
+++ b/examples/end_to_end_gpu.jl
@@ -2,14 +2,15 @@
 # examples/end_to_end_gpu.jl
 #
 # End-to-end GPU demo:
-# - ExaModels model instantiated with `CUDABackend()`
-# - MadNLP solves with cuDSS via MadNLPGPU
+# - ExaModels model instantiated with CUDABackend()
+# - MadNLP solves with CUDSS via MadNLPGPU
 
 using Random
 using DecisionRulesExa
 using ExaModels
 using CUDA
 using MadNLPGPU
+using MadNLP
 using Flux
 
 if !CUDA.functional()
@@ -22,41 +23,32 @@ CUDA.allowscalar(false)
 T = 16
 nx = 1
 
-# Build deterministic-equivalent NLP directly on the GPU
 prob = build_linear_tracking_problem(
     horizon = T,
     nx = nx,
-    backend = CUDABackend(), # <- GPU instantiation
+    backend = CUDABackend(),
     slack_penalty = 10.0,
     u_bounds = (-2.0, 2.0),
 )
 
-# Policy on GPU
-input_dim = nx + (T - 1) * nx
-output_dim = T * nx
-policy = MLPPolicy(input_dim, output_dim; hidden = (64, 64), act = tanh) |> gpu
-
-# Solver cache on GPU (recommended)
-cache = init_madnlp_cache(prob; linear_solver = CUDSSSolver)
-
-sampler(k) = begin
-    x0 = cu([1.0])
-    w = cu(0.1 .* randn(T - 1))
-    return x0, w
-end
-
-hist = train_tsddr!(
-    policy, prob;
-    n_iters = 3,
-    sampler = sampler,
-    η = 5e-3,
-    cache = cache,
-    # MadNLP options (looser tolerances often make sense on GPU)
-    tol = 1e-4,
-    max_iter = 200,
-    linear_solver = CUDSSSolver,
+policy = StateConditionedPolicy(nx, nx, nx, [64, 64])
+
+sampler() = Float64.(0.1 .* randn(T * nx))
+
+train_tsddr(
+    policy,
+    [1.0],
+    prob,
+    prob.p_x0,
+    prob.p_target,
+    prob.p_w,
+    sampler;
+    num_batches = 3,
+    num_train_per_batch = 2,
+    optimizer = Flux.Adam(1f-3),
+    madnlp_kwargs = (
+        print_level = MadNLP.ERROR,
+        tol = 1e-4,
+        linear_solver = CUDSSSolver,
+    ),
 )
-
-last = hist[end]
-println("\nLast objective: ", last.result.objective)
-println("First 5 target multipliers λ: ", Array(last.lambda[1:min(5, length(last.lambda))]))
From 23ee424bbf1792759a955f4f5a0864efd8799fa0 Mon Sep 17 00:00:00 2001
From: Andrew Rosemberg
Date: Mon, 22 Jun 2026 14:56:24 -0400
Subject: [PATCH 9/9] update docs

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index a9ab925..a06eca1 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,7 @@ Pkg.add(url="https://github.com/LearningToOptimize/DecisionRulesExa.jl.git")
 
 ```julia
 using DecisionRulesExa
-using ExaModels, Flux, Random
+using ExaModels, MadNLP, Flux, Random
 
 Random.seed!(1)