|
| 1 | +# ============================================================================= |
| 2 | +# PlotInstability.jl — Regularization penalty selection diagnostic plots |
| 3 | +# src/grn/PlotInstability.jl |
| 4 | +# |
| 5 | +# Generates two figures per pipeline run to validate λ selection: |
| 6 | +# |
| 7 | +# Figure 1 instability_diagnostic_<mode>.png (2 panels side by side) |
| 8 | +# ───────────────────────────────────────────────────────────────────── |
| 9 | +# Left panel — Model Size vs λ |
| 10 | +# Shows how many regulators appear in the network as λ decreases. |
| 11 | +# At network level : total unique TFs with at least one edge anywhere. |
| 12 | +# At gene level : # TFs regulating that specific gene. |
| 13 | +# A steep drop as λ increases confirms that regularization is working. |
| 14 | +# |
| 15 | +# Right panel — Instability (Lower + Upper Bound) vs λ |
| 16 | +# bStARS stability bounds computed during the warm-start pass over the |
| 17 | +# full λ search range [minLambda, maxLambda]. |
| 18 | +# The dashed line marks targetInstability (default 0.05). |
| 19 | +# The fine-pass window is the region where Lb ≤ target ≤ Ub. |
| 20 | +# |
| 21 | +# Figure 2 instability_selection_<mode>.png (1 panel) |
| 22 | +# ───────────────────────────────────────────────────────────────────── |
| 23 | +# |netInstabilities − targetInstability| vs λ from the fine estimation pass. |
| 24 | +# The dot marks the chosen λ — the point closest to the target threshold. |
| 25 | +# This is a direct visualisation of what chooseLambda! computes. |
| 26 | +# |
| 27 | +# How Figures 1 and 2 relate |
| 28 | +# ───────────────────────────────────────────────────────────────────── |
| 29 | +# Fig 1 right panel shows the broad warm-start search window. |
| 30 | +# Fig 2 zooms into the fine-pass λ grid within that window and confirms |
| 31 | +# that the chosen λ sits at the minimum of |instability − target|. |
| 32 | +# Together they give a complete picture of the regularisation selection. |
| 33 | +# |
| 34 | +# Standalone usage (per-gene diagnostics) |
| 35 | +# ───────────────────────────────────────────────────────────────────── |
| 36 | +# Load the saved GrnData and call plotInstabilityCurves with mode = :gene: |
| 37 | +# |
| 38 | +# using JLD2, InferelatorJL |
| 39 | +# grnData = load_object("TFA/instabOutMat.jld") |
| 40 | +# import InferelatorJL: plotInstabilityCurves |
| 41 | +# plotInstabilityCurves(grnData; |
| 42 | +# mode = :gene, |
| 43 | +# geneName = "IFNG", |
| 44 | +# targetInstability = 0.05, |
| 45 | +# outputDir = "TFA/plots") |
| 46 | +# ============================================================================= |
| 47 | + |
| 48 | + |
| 49 | +""" |
| 50 | + plotInstabilityCurves(grnData; mode, geneName, targetInstability, outputDir) |
| 51 | +
|
| 52 | +Generate two regularization penalty selection diagnostic figures. |
| 53 | +
|
| 54 | +**Figure 1** (`instability_diagnostic_<mode>.png`): two panels showing |
| 55 | +model size vs λ (left) and bStARS instability bounds vs λ (right). |
| 56 | +The right panel uses the warm-start λ range; the left uses the fine-pass range. |
| 57 | +
|
| 58 | +**Figure 2** (`instability_selection_<mode>.png`): |instability − target| |
| 59 | +vs λ from the fine estimation pass, with a dot at the selected λ. |
| 60 | +
|
| 61 | +# Arguments |
| 62 | +- `grnData` : `GrnData` populated by `bstartsEstimateInstability` |
| 63 | +- `mode` : `:network` (default) — network-level plots using all genes; |
| 64 | + `:gene` — per-gene plots for one target gene |
| 65 | +- `geneName` : Target gene name string; required when `mode = :gene` |
| 66 | +- `targetInstability` : Instability threshold used during λ selection (default 0.05) |
| 67 | +- `outputDir` : Directory to save both figures; `nothing` displays instead |
| 68 | +
|
| 69 | +# Called automatically |
| 70 | +Inside `bstartsEstimateInstability` with `mode = :network` after every run. |
| 71 | +The figures are saved to the same `outputDir` as `instabOutMat.jld`. |
| 72 | +
|
| 73 | +# Standalone (per-gene) |
| 74 | +```julia |
| 75 | +using JLD2 |
| 76 | +grnData = load_object("TFA/instabOutMat.jld") |
| 77 | +import InferelatorJL: plotInstabilityCurves |
| 78 | +plotInstabilityCurves(grnData; mode = :gene, geneName = "IFNG", outputDir = "TFA/plots") |
| 79 | +``` |
| 80 | +""" |
| 81 | +function plotInstabilityCurves( |
| 82 | + grnData::GrnData; |
| 83 | + mode::Symbol = :network, |
| 84 | + geneName::Union{String, Nothing} = nothing, |
| 85 | + targetInstability::Float64 = 0.05, |
| 86 | + outputDir::Union{String, Nothing} = nothing |
| 87 | +) |
| 88 | + if !isnothing(outputDir) |
| 89 | + mkpath(outputDir) |
| 90 | + end |
| 91 | + |
| 92 | + lambdaFine = grnData.lambdaRange # fine-pass λ grid (x-axis for Fig 2 + model size) |
| 93 | + lambdaWarm = grnData.lambdaRangeWarm # warm-start λ grid (x-axis for instability bounds) |
| 94 | + totLambdas = length(lambdaFine) |
| 95 | + |
| 96 | + # ── Select data arrays based on mode ────────────────────────────────────── |
| 97 | + if mode == :network |
| 98 | + # Instability bounds from warm-start pass (one curve per bound) |
| 99 | + instabLb = grnData.netInstabilitiesLb |
| 100 | + instabUb = grnData.netInstabilitiesUb |
| 101 | + |
| 102 | + # Fine-pass instability (one value per λ in lambdaFine) |
| 103 | + instabFine = grnData.netInstabilities |
| 104 | + |
| 105 | + # Model size: unique TFs with ≥ 1 non-zero edge across all genes at each λ |
| 106 | + modelSize = zeros(Int, totLambdas) |
| 107 | + for lind in 1:totLambdas |
| 108 | + slab = grnData.stabilityMat[lind, :, :] # totResponses × totPreds |
| 109 | + modelSize[lind] = sum(vec(any(isfinite.(slab) .& (slab .> 0), dims=1))) |
| 110 | + end |
| 111 | + |
| 112 | + modeLabel = "Network" |
| 113 | + modelLabel = "# Unique TF Regulators" |
| 114 | + figSuffix = "network" |
| 115 | + |
| 116 | + elseif mode == :gene |
| 117 | + isnothing(geneName) && error("geneName must be provided when mode = :gene") |
| 118 | + geneIdx = findfirst(==(geneName), grnData.targGenes) |
| 119 | + isnothing(geneIdx) && error("Gene \"$geneName\" not found in grnData.targGenes") |
| 120 | + |
| 121 | + # Per-gene instability bounds from warm-start pass |
| 122 | + instabLb = grnData.instabilitiesLb[geneIdx, :] |
| 123 | + instabUb = grnData.instabilitiesUb[geneIdx, :] |
| 124 | + |
| 125 | + # Per-gene fine-pass instability |
| 126 | + instabFine = grnData.geneInstabilities[geneIdx, :] |
| 127 | + |
| 128 | + # Model size: # TFs with non-zero edge to this gene at each fine-pass λ |
| 129 | + modelSize = zeros(Int, totLambdas) |
| 130 | + for lind in 1:totLambdas |
| 131 | + row = grnData.stabilityMat[lind, geneIdx, :] |
| 132 | + modelSize[lind] = sum(isfinite.(row) .& (row .> 0)) |
| 133 | + end |
| 134 | + |
| 135 | + modeLabel = "Gene: $geneName" |
| 136 | + modelLabel = "# TF Regulators" |
| 137 | + figSuffix = replace(geneName, r"[/\\]" => "_") |
| 138 | + |
| 139 | + else |
| 140 | + error("mode must be :network or :gene, got :$mode") |
| 141 | + end |
| 142 | + |
| 143 | + # ── Figure 1: Model size (left) + Instability bounds (right) ───────────── |
| 144 | + fig1, (ax1, ax2) = PyPlot.subplots(1, 2; figsize = (10, 4)) |
| 145 | + fig1.suptitle("Regularization Penalty Selection — $modeLabel", fontsize = 13) |
| 146 | + |
| 147 | + # Left panel — model size vs λ (fine-pass range) |
| 148 | + ax1.plot(lambdaFine, modelSize; color = "steelblue", linewidth = 1.5) |
| 149 | + ax1.set_xscale("log") |
| 150 | + ax1.set_xlabel("λ", fontsize = 11) |
| 151 | + ax1.set_ylabel(modelLabel, fontsize = 10) |
| 152 | + ax1.set_title("Model Size vs λ", fontsize = 10) |
| 153 | + |
| 154 | + # Right panel — instability bounds vs λ (warm-start range) |
| 155 | + ax2.plot(lambdaWarm, instabLb; color = "steelblue", linewidth = 1.5, label = "Lower Bound") |
| 156 | + ax2.plot(lambdaWarm, instabUb; color = "darkorange", linewidth = 1.5, label = "Upper Bound") |
| 157 | + ax2.axhline(y = targetInstability; color = "black", linestyle = "--", |
| 158 | + linewidth = 1.0, label = "Target ($targetInstability)") |
| 159 | + ax2.set_xscale("log") |
| 160 | + ax2.set_xlabel("λ", fontsize = 11) |
| 161 | + ax2.set_ylabel("Instability", fontsize = 10) |
| 162 | + ax2.set_title("Instability vs λ (warm-start range)", fontsize = 10) |
| 163 | + ax2.legend(fontsize = 8) |
| 164 | + |
| 165 | + PyPlot.tight_layout() |
| 166 | + |
| 167 | + if !isnothing(outputDir) |
| 168 | + savePath1 = joinpath(outputDir, "instability_diagnostic_$(figSuffix).png") |
| 169 | + PyPlot.savefig(savePath1; dpi = 150) |
| 170 | + @info "Saved instability diagnostic" file = savePath1 |
| 171 | + else |
| 172 | + PyPlot.show() |
| 173 | + end |
| 174 | + PyPlot.close(fig1) |
| 175 | + |
| 176 | + # ── Figure 2: |instability − target| with chosen λ dot ─────────────────── |
| 177 | + devs = abs.(instabFine .- targetInstability) |
| 178 | + minIdx = argmin(devs) |
| 179 | + chosenλ = lambdaFine[minIdx] |
| 180 | + |
| 181 | + fig2, ax = PyPlot.subplots(1, 1; figsize = (5, 4)) |
| 182 | + fig2.suptitle("λ Selection — $modeLabel", fontsize = 13) |
| 183 | + |
| 184 | + ax.plot(lambdaFine, devs; color = "steelblue", linewidth = 1.5) |
| 185 | + ax.scatter([chosenλ], [devs[minIdx]]; color = "steelblue", s = 80, zorder = 5) |
| 186 | + ax.set_xscale("log") |
| 187 | + ax.set_xlabel("λ", fontsize = 11) |
| 188 | + ax.set_ylabel("|Instability − Target|", fontsize = 10) |
| 189 | + ax.set_title("|Instability − Target| vs λ", fontsize = 10) |
| 190 | + |
| 191 | + PyPlot.tight_layout() |
| 192 | + |
| 193 | + if !isnothing(outputDir) |
| 194 | + savePath2 = joinpath(outputDir, "instability_selection_$(figSuffix).png") |
| 195 | + PyPlot.savefig(savePath2; dpi = 150) |
| 196 | + @info "Saved λ selection diagnostic" file = savePath2 chosenλ = chosenλ |
| 197 | + else |
| 198 | + PyPlot.show() |
| 199 | + end |
| 200 | + PyPlot.close(fig2) |
| 201 | + |
| 202 | + return nothing |
| 203 | +end |
0 commit comments