Skip to content

Commit e040984

Browse files
committed
:
Update README and add instability diagnostic plots - Update README: add step 3b/7 to pipeline, fix output filenames, add timeLagFile/timeLag params, fix evaluateNetwork signature, add dev example files to examples table - Add plotInstabilityCurves (src/grn/PlotInstability.jl): auto-generated network-level λ selection diagnostics; standalone per-gene mode - Add lambdaRangeWarm and targGenes to GrnData for plot support
1 parent 283471e commit e040984

5 files changed

Lines changed: 223 additions & 7 deletions

File tree

src/API.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -249,10 +249,11 @@ function buildNetwork(
249249
# Fine estimation pass
250250
constructSubsamples(data, grnData; totSS = totSS, subsampleFrac = subsampleFrac)
251251
bstartsEstimateInstability(grnData;
252-
totLambdas = totLambdas,
253-
instabilityLevel = instabilityLevel,
254-
zTarget = zScoreLASSO,
255-
outputDir = outputDir)
252+
totLambdas = totLambdas,
253+
instabilityLevel = instabilityLevel,
254+
zTarget = zScoreLASSO,
255+
targetInstability = targetInstability, # ADDED: forward for λ selection diagnostic plot
256+
outputDir = outputDir)
256257

257258
buildGrn = BuildGrn()
258259
chooseLambda!(grnData, buildGrn;

src/InferelatorJL.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ include("grn/BuildGRN.jl") # bstarsWarmStart, bstartsEstimateInstab
2525
include("grn/AggregateNetworks.jl") # combineGRNs / aggregateNetworks
2626
include("grn/RefineTFA.jl") # combineGRNS2 / refineTFA
2727
include("grn/UtilsGRN.jl") # GRN utility helpers
28+
include("grn/PlotInstability.jl") # ADDED: plotInstabilityCurves — λ selection diagnostic plots
2829

2930
# ── Metrics ───────────────────────────────────────────────────────────────────
3031
include("metrics/Constants.jl")
@@ -81,7 +82,7 @@ export
8182
# preparePenaltyMatrix! constructSubsamples bstarsWarmStart
8283
# bstartsEstimateInstability chooseLambda! rankEdges!
8384
# computePR plotPRCurves plotAUPR
84-
# loadPRData
85+
# loadPRData plotInstabilityCurves
8586
# -------------------------------------------------------------------------
8687
# -------------------------------------------------------------------------
8788

src/Types.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,11 @@ mutable struct GrnData
100100
geneInstabilities::Matrix{Float64}
101101
lambdaRange::Vector{Float64}
102102
lambdaRangeGene::Vector{Vector{Float64}}
103+
lambdaRangeWarm::Vector{Float64} # ADDED: λ grid from warm-start pass (needed for instability bound plots)
103104
stabilityMat::Array{Float64}
104105
priorMatProcessed::Matrix{Float64}
105106
betas::Array{Float64,3}
107+
targGenes::Vector{String} # ADDED: target gene names — row index of stabilityMat (needed for per-gene plots)
106108

107109
function GrnData()
108110
return new(
@@ -123,9 +125,11 @@ mutable struct GrnData
123125
Matrix{Float64}(undef, 0, 0), # geneInstabilities
124126
[], # lambdaRange
125127
Vector{Vector{Float64}}(undef, 0), # lambdaRangeGene
128+
[], # lambdaRangeWarm
126129
Array{Float64}(undef, 0, 0, 0), # stabilityMat (3-D)
127130
Matrix{Float64}(undef, 0, 0), # priorMatProcessed
128-
Array{Float64,3}(undef, 0, 0, 0) # betas
131+
Array{Float64,3}(undef, 0, 0, 0), # betas
132+
[] # targGenes
129133
)
130134
end
131135
end

src/grn/PlotInstability.jl

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
# =============================================================================
2+
# PlotInstability.jl — Regularization penalty selection diagnostic plots
3+
# src/grn/PlotInstability.jl
4+
#
5+
# Generates two figures per pipeline run to validate λ selection:
6+
#
7+
# Figure 1 instability_diagnostic_<mode>.png (2 panels side by side)
8+
# ─────────────────────────────────────────────────────────────────────
9+
# Left panel — Model Size vs λ
10+
# Shows how many regulators appear in the network as λ decreases.
11+
# At network level : total unique TFs with at least one edge anywhere.
12+
# At gene level : # TFs regulating that specific gene.
13+
# A steep drop as λ increases confirms that regularization is working.
14+
#
15+
# Right panel — Instability (Lower + Upper Bound) vs λ
16+
# bStARS stability bounds computed during the warm-start pass over the
17+
# full λ search range [minLambda, maxLambda].
18+
# The dashed line marks targetInstability (default 0.05).
19+
# The fine-pass window is the region where Lb ≤ target ≤ Ub.
20+
#
21+
# Figure 2 instability_selection_<mode>.png (1 panel)
22+
# ─────────────────────────────────────────────────────────────────────
23+
# |netInstabilities − targetInstability| vs λ from the fine estimation pass.
24+
# The dot marks the chosen λ — the point closest to the target threshold.
25+
# This is a direct visualisation of what chooseLambda! computes.
26+
#
27+
# How Figures 1 and 2 relate
28+
# ─────────────────────────────────────────────────────────────────────
29+
# Fig 1 right panel shows the broad warm-start search window.
30+
# Fig 2 zooms into the fine-pass λ grid within that window and confirms
31+
# that the chosen λ sits at the minimum of |instability − target|.
32+
# Together they give a complete picture of the regularisation selection.
33+
#
34+
# Standalone usage (per-gene diagnostics)
35+
# ─────────────────────────────────────────────────────────────────────
36+
# Load the saved GrnData and call plotInstabilityCurves with mode = :gene:
37+
#
38+
# using JLD2, InferelatorJL
39+
# grnData = load_object("TFA/instabOutMat.jld")
40+
# import InferelatorJL: plotInstabilityCurves
41+
# plotInstabilityCurves(grnData;
42+
# mode = :gene,
43+
# geneName = "IFNG",
44+
# targetInstability = 0.05,
45+
# outputDir = "TFA/plots")
46+
# =============================================================================
47+
48+
49+
"""
50+
plotInstabilityCurves(grnData; mode, geneName, targetInstability, outputDir)
51+
52+
Generate two regularization penalty selection diagnostic figures.
53+
54+
**Figure 1** (`instability_diagnostic_<mode>.png`): two panels showing
55+
model size vs λ (left) and bStARS instability bounds vs λ (right).
56+
The right panel uses the warm-start λ range; the left uses the fine-pass range.
57+
58+
**Figure 2** (`instability_selection_<mode>.png`): |instability − target|
59+
vs λ from the fine estimation pass, with a dot at the selected λ.
60+
61+
# Arguments
62+
- `grnData` : `GrnData` populated by `bstartsEstimateInstability`
63+
- `mode` : `:network` (default) — network-level plots using all genes;
64+
`:gene` — per-gene plots for one target gene
65+
- `geneName` : Target gene name string; required when `mode = :gene`
66+
- `targetInstability` : Instability threshold used during λ selection (default 0.05)
67+
- `outputDir` : Directory to save both figures; `nothing` displays instead
68+
69+
# Called automatically
70+
Inside `bstartsEstimateInstability` with `mode = :network` after every run.
71+
The figures are saved to the same `outputDir` as `instabOutMat.jld`.
72+
73+
# Standalone (per-gene)
74+
```julia
75+
using JLD2
76+
grnData = load_object("TFA/instabOutMat.jld")
77+
import InferelatorJL: plotInstabilityCurves
78+
plotInstabilityCurves(grnData; mode = :gene, geneName = "IFNG", outputDir = "TFA/plots")
79+
```
80+
"""
81+
function plotInstabilityCurves(
82+
grnData::GrnData;
83+
mode::Symbol = :network,
84+
geneName::Union{String, Nothing} = nothing,
85+
targetInstability::Float64 = 0.05,
86+
outputDir::Union{String, Nothing} = nothing
87+
)
88+
if !isnothing(outputDir)
89+
mkpath(outputDir)
90+
end
91+
92+
lambdaFine = grnData.lambdaRange # fine-pass λ grid (x-axis for Fig 2 + model size)
93+
lambdaWarm = grnData.lambdaRangeWarm # warm-start λ grid (x-axis for instability bounds)
94+
totLambdas = length(lambdaFine)
95+
96+
# ── Select data arrays based on mode ──────────────────────────────────────
97+
if mode == :network
98+
# Instability bounds from warm-start pass (one curve per bound)
99+
instabLb = grnData.netInstabilitiesLb
100+
instabUb = grnData.netInstabilitiesUb
101+
102+
# Fine-pass instability (one value per λ in lambdaFine)
103+
instabFine = grnData.netInstabilities
104+
105+
# Model size: unique TFs with ≥ 1 non-zero edge across all genes at each λ
106+
modelSize = zeros(Int, totLambdas)
107+
for lind in 1:totLambdas
108+
slab = grnData.stabilityMat[lind, :, :] # totResponses × totPreds
109+
modelSize[lind] = sum(vec(any(isfinite.(slab) .& (slab .> 0), dims=1)))
110+
end
111+
112+
modeLabel = "Network"
113+
modelLabel = "# Unique TF Regulators"
114+
figSuffix = "network"
115+
116+
elseif mode == :gene
117+
isnothing(geneName) && error("geneName must be provided when mode = :gene")
118+
geneIdx = findfirst(==(geneName), grnData.targGenes)
119+
isnothing(geneIdx) && error("Gene \"$geneName\" not found in grnData.targGenes")
120+
121+
# Per-gene instability bounds from warm-start pass
122+
instabLb = grnData.instabilitiesLb[geneIdx, :]
123+
instabUb = grnData.instabilitiesUb[geneIdx, :]
124+
125+
# Per-gene fine-pass instability
126+
instabFine = grnData.geneInstabilities[geneIdx, :]
127+
128+
# Model size: # TFs with non-zero edge to this gene at each fine-pass λ
129+
modelSize = zeros(Int, totLambdas)
130+
for lind in 1:totLambdas
131+
row = grnData.stabilityMat[lind, geneIdx, :]
132+
modelSize[lind] = sum(isfinite.(row) .& (row .> 0))
133+
end
134+
135+
modeLabel = "Gene: $geneName"
136+
modelLabel = "# TF Regulators"
137+
figSuffix = replace(geneName, r"[/\\]" => "_")
138+
139+
else
140+
error("mode must be :network or :gene, got :$mode")
141+
end
142+
143+
# ── Figure 1: Model size (left) + Instability bounds (right) ─────────────
144+
fig1, (ax1, ax2) = PyPlot.subplots(1, 2; figsize = (10, 4))
145+
fig1.suptitle("Regularization Penalty Selection — $modeLabel", fontsize = 13)
146+
147+
# Left panel — model size vs λ (fine-pass range)
148+
ax1.plot(lambdaFine, modelSize; color = "steelblue", linewidth = 1.5)
149+
ax1.set_xscale("log")
150+
ax1.set_xlabel("λ", fontsize = 11)
151+
ax1.set_ylabel(modelLabel, fontsize = 10)
152+
ax1.set_title("Model Size vs λ", fontsize = 10)
153+
154+
# Right panel — instability bounds vs λ (warm-start range)
155+
ax2.plot(lambdaWarm, instabLb; color = "steelblue", linewidth = 1.5, label = "Lower Bound")
156+
ax2.plot(lambdaWarm, instabUb; color = "darkorange", linewidth = 1.5, label = "Upper Bound")
157+
ax2.axhline(y = targetInstability; color = "black", linestyle = "--",
158+
linewidth = 1.0, label = "Target ($targetInstability)")
159+
ax2.set_xscale("log")
160+
ax2.set_xlabel("λ", fontsize = 11)
161+
ax2.set_ylabel("Instability", fontsize = 10)
162+
ax2.set_title("Instability vs λ (warm-start range)", fontsize = 10)
163+
ax2.legend(fontsize = 8)
164+
165+
PyPlot.tight_layout()
166+
167+
if !isnothing(outputDir)
168+
savePath1 = joinpath(outputDir, "instability_diagnostic_$(figSuffix).png")
169+
PyPlot.savefig(savePath1; dpi = 150)
170+
@info "Saved instability diagnostic" file = savePath1
171+
else
172+
PyPlot.show()
173+
end
174+
PyPlot.close(fig1)
175+
176+
# ── Figure 2: |instability − target| with chosen λ dot ───────────────────
177+
devs = abs.(instabFine .- targetInstability)
178+
minIdx = argmin(devs)
179+
chosenλ = lambdaFine[minIdx]
180+
181+
fig2, ax = PyPlot.subplots(1, 1; figsize = (5, 4))
182+
fig2.suptitle("λ Selection — $modeLabel", fontsize = 13)
183+
184+
ax.plot(lambdaFine, devs; color = "steelblue", linewidth = 1.5)
185+
ax.scatter([chosenλ], [devs[minIdx]]; color = "steelblue", s = 80, zorder = 5)
186+
ax.set_xscale("log")
187+
ax.set_xlabel("λ", fontsize = 11)
188+
ax.set_ylabel("|Instability − Target|", fontsize = 10)
189+
ax.set_title("|Instability − Target| vs λ", fontsize = 10)
190+
191+
PyPlot.tight_layout()
192+
193+
if !isnothing(outputDir)
194+
savePath2 = joinpath(outputDir, "instability_selection_$(figSuffix).png")
195+
PyPlot.savefig(savePath2; dpi = 150)
196+
@info "Saved λ selection diagnostic" file = savePath2 chosenλ = chosenλ
197+
else
198+
PyPlot.show()
199+
end
200+
PyPlot.close(fig2)
201+
202+
return nothing
203+
end

src/grn/PrepareGRN.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ function preparePredictorMat!(grnData::GrnData, expressionData::GeneExpressionDa
118118
grnData.allPredictors = allPredictors
119119
grnData.responseMat = expressionData.targGeneMat
120120
grnData.priorMatProcessed = priorMat
121+
grnData.targGenes = copy(expressionData.targGenes) # ADDED: store gene names for per-gene instability plots
121122
end
122123

123124

@@ -334,9 +335,10 @@ function bstarsWarmStart(expressionData::GeneExpressionData, tfaData::PriorTFADa
334335
grnData.netInstabilitiesUb = netInstabilitiesUb
335336
grnData.instabilitiesLb = instabilitiesLb
336337
grnData.instabilitiesUb = instabilitiesUb
338+
grnData.lambdaRangeWarm = lambdaRange # ADDED: store warm-start λ grid for instability bound plots
337339
end
338340

339-
function bstartsEstimateInstability(grnData::GrnData; totLambdas = 10, instabilityLevel = "Gene", zTarget = false, outputDir::Union{String, Nothing}=nothing)
341+
function bstartsEstimateInstability(grnData::GrnData; totLambdas = 10, instabilityLevel = "Gene", zTarget = false, targetInstability::Float64 = 0.05, outputDir::Union{String, Nothing}=nothing) # ADDED: targetInstability for λ selection diagnostic plot
340342

341343
totResponses,totSamps = size(grnData.responseMat) # totResponses is same as length(grnData["targGenes"])
342344
totPreds = size(grnData.predictorMat,1)
@@ -450,5 +452,10 @@ function bstartsEstimateInstability(grnData::GrnData; totLambdas = 10, instabili
450452
if outputDir !== nothing && outputDir !== ""
451453
outputFile = joinpath(outputDir, "instabOutMat.jld")
452454
save_object(outputFile, grnData)
455+
# ADDED: auto-generate network-level diagnostic plots alongside saved data
456+
plotInstabilityCurves(grnData;
457+
mode = :network,
458+
targetInstability = targetInstability,
459+
outputDir = outputDir)
453460
end
454461
end

0 commit comments

Comments
 (0)