Skip to content
12 changes: 6 additions & 6 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Stage 0 - Create from julia image and install OS packages
FROM julia:1.11.6 as stage0
FROM julia:1.12.5 as stage0
RUN apt update && apt -y install bzip2 build-essential libxml2

# STAGE 1 - Python and python packages for S3 functionality
Expand All @@ -15,8 +15,8 @@ ENV PYTHON="/usr/bin/python3"
COPY deps.jl /app/deps.jl
ENV JULIA_CPU_TARGET="generic;sandybridge,-xsaveopt,clone_all;haswell,-rdrnd,base(1)"
RUN julia /app/deps.jl \
&& find /usr/local/bin/julia_pkgs -type d -exec chmod 755 {} \; \
&& find /usr/local/bin/julia_pkgs -type f -exec chmod 644 {} \;
&& find /usr/local/bin/julia_pkgs -type d -exec chmod 755 {} \; \
&& find /usr/local/bin/julia_pkgs -type f -exec chmod 644 {} \;

# Stage 3 - Copy SWOT script
FROM stage2 as stage3
Expand All @@ -26,7 +26,7 @@ COPY ./sos_read /app/sos_read/
# Stage 4 - Execute algorithm
FROM stage3 as stage4
LABEL version="1.0" \
description="Containerized SAD algorithm." \
"confluence.contact"="ntebaldi@umass.edu" \
"algorithm.contact"="kandread@umass.edu"
description="Containerized SAD algorithm." \
"confluence.contact"="ntebaldi@umass.edu" \
"algorithm.contact"="kandread@umass.edu"
ENTRYPOINT ["/usr/local/julia/bin/julia", "/app/swot.jl"]
50 changes: 22 additions & 28 deletions swot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ Load SWOT observations.
"""
function read_swot_obs(ncfile::String, nids::Vector{Int})
Dataset(ncfile) do ds
nodes = NCDatasets.group(ds, "node")
reaches = NCDatasets.group(ds, "reach")
S = permutedims(nodes["slope2"][:])
H = permutedims(nodes["wse"][:])
W = permutedims(nodes["width"][:])
nodes = ds.group["node"]
reaches = ds.group["reach"]
S = permutedims(nodes["slope2"][:, :])
H = permutedims(nodes["wse"][:, :])
W = permutedims(nodes["width"][:, :])
dA = reaches["d_x_area"][:]
dA = convert(Vector{Sad.FloatM}, dA)
Hr = convert(Vector{Sad.FloatM}, reaches["wse"][:])
Expand All @@ -64,10 +64,8 @@ function read_swot_obs(ncfile::String, nids::Vector{Int})
nid = nodes["node_id"][:]
dmap = Dict(nid[k] => k for k=1:length(nid))
i = [dmap[k] for k in nids]
time_str_var = reaches["time_str"].var
time_str_raw = permutedims(time_str_var[:])
time_str = [join(time_str_raw[i, :]) for i in 1:size(time_str_raw, 1)]

time = reaches["time"][:]
time_str = [string(t) for t in time]

H[i, :], W[i, :], S[i, :], dA, Hr, Wr, Sr, time_str
end
Expand All @@ -86,7 +84,7 @@ Retrieve information about river reach cross sections.
"""
function river_info(id::Int, swordfile::String)
Dataset(swordfile) do fd
g = NCDatasets.group(fd, "nodes")
g = fd.group["nodes"]
i = findall(g["reach_id"][:] .== id)
nid = g["node_id"][i]
x = g["dist_out"][i]
Expand Down Expand Up @@ -116,9 +114,9 @@ function write_output(reachid, valid, outdir, A0, n, Qa, Qu, W, time_str)
ridv = defVar(out, "reach_id", Int64, (), fillvalue = FILL)
ridv[:] = reachid
A0v = defVar(out, "A0", Float64, (), fillvalue = FILL)
A0v[:] = A0
A0v[:] = coalesce(A0, FILL)
nv = defVar(out, "n", Float64, (), fillvalue = FILL)
nv[:] = n
nv[:] = coalesce(n, FILL)
Qav = defVar(out, "Qa", Float64, ("nt",), fillvalue = FILL)
Qav[:] = replace!(Qa, NaN=>FILL)
Quv = defVar(out, "Q_u", Float64, ("nt",), fillvalue = FILL)
Expand Down Expand Up @@ -171,7 +169,7 @@ function main()
else
index = parsed_args["index"] + 1
end

reachfile = parsed_args["reachfile"]
bucketkey = parsed_args["bucketkey"]
println("Index: $(index)")
Expand All @@ -187,31 +185,27 @@ function main()
nids, x = river_info(reachid, swordfile)
H, W, S, dA, Hr, Wr, Sr, time_str = read_swot_obs(swotfile, nids)

try
x, H, W, S = Sad.drop_unobserved(x, H, W, S)
catch e
if e isa MethodError
println("Error loading swot observation")
end
end
reach = Sad.preprocess(x, H, W, S)

A0 = missing
n = missing
Qa = Array{Missing}(missing, 1, size(W, 2))
Qu = Array{Missing}(missing, 1, size(W, 2))
Qa = Matrix{Sad.FloatM}(missing, 1, size(W, 2))
Qu = Matrix{Sad.FloatM}(missing, 1, size(W, 2))
if all(ismissing, H) || all(ismissing, W) || all(ismissing, S)
println("$(reachid): INVALID")
write_output(reachid, 0, outdir, A0, n, Qa, Qu, W, time_str)
else
Hmin = minimum(skipmissing(H[1, :]))
Qp, np, rp, zp = Sad.priors(sosfile, Hmin, reachid)
if ismissing(Qp)
p = Sad.priors(sosfile, reach.hmin, reachid)
if ismissing(p)
println("$(reachid): INVALID, missing mean discharge")
write_output(reachid, 0, outdir, A0, n, Qa, Qu, W, time_str)
else
try
nens = 100 # default ensemble size
nsamples = 1000 # default sampling size
Qa, Qu, A0, n = Sad.estimate(x, H, W, S, dA, Qp, np, rp, zp, nens, nsamples, Hr, Wr, Sr)
res = Sad.infer(p, reach, time_str=time_str)
A0 = Sad.compute_A0(reach, res.reach_ensemble)
n = mean(res.reach_ensemble[1, :])
Qa[1, :] = [isnan(q) ? missing : q for q in res.Q_post]
Qu[1, :] = [isnothing(res.A_post[t]) ? missing : std(res.A_post[t][1, :]) for t in 1:reach.nt]
println("$(reachid): VALID")
write_output(reachid, 1, outdir, A0, n, Qa, Qu, W, time_str)
catch
Expand Down
Loading