Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
1a6e7c9
Set up Julia package and example SSP2 optimizations
lxvm Mar 13, 2026
7a0353b
Delete examples/julia/Manifest.toml
lxvm Mar 14, 2026
422bda4
fix underflow in tanh projection
lxvm Mar 14, 2026
34b8986
use evalpoly
lxvm Mar 14, 2026
fbe944e
fix projection typos/bugs
lxvm Mar 16, 2026
e4e53df
refactor projection into a loop over target points
lxvm Mar 16, 2026
46105ab
refactor interpolation into another module
lxvm Mar 16, 2026
0e66fba
update example to use a higher resolution target grid and reduce allo…
lxvm Mar 16, 2026
a5ddab8
correct the usage of cubic_adjoint deriv api
lxvm Mar 17, 2026
639ed5d
correct the comparison to finite differences because arrays are reused
lxvm Mar 17, 2026
b19bd89
avoid NaN in tanh projection when x == eta at beta = Inf
lxvm Mar 19, 2026
e9f1442
change definition of ProjectionProblem and SSP2
lxvm Mar 20, 2026
000599d
add grid as an optional argument to PaddingProblem
lxvm Mar 20, 2026
db22074
add init! to interface
lxvm Mar 20, 2026
326eeb0
fix typo
lxvm Mar 20, 2026
1562411
add support for SSP1
lxvm Mar 21, 2026
79588e5
add linear interpolation and use bc in cubic interpolation
lxvm Mar 21, 2026
75c442a
rename R_smoothing_factor to smoothing_radius like python api
lxvm Mar 21, 2026
17788f2
add pythonic API
lxvm Mar 21, 2026
f32ed89
add ChainRulesCore rrules for pythonic api
lxvm Mar 21, 2026
7502781
run checks to make sure different apis match
lxvm Mar 22, 2026
7a7ac00
document apis of padding module
lxvm Mar 25, 2026
242c131
document public api in kernels module
lxvm Mar 25, 2026
e4525d8
document public api of convolution module
lxvm Mar 25, 2026
3cfa7d9
Document public api of interpolation module
lxvm Mar 25, 2026
13ac5ea
document the public api of the projection module
lxvm Mar 25, 2026
d10ba70
document public api of ssp package
lxvm Mar 25, 2026
ea998e4
add package tests
lxvm Mar 29, 2026
4d56672
add README.md
lxvm Mar 29, 2026
5baff4f
remove revise dep
lxvm Mar 30, 2026
01ad3b8
use rffts, precompute conv kernel, use perturb in arb direction, and …
lxvm Mar 30, 2026
b4a6b26
Apply suggestion from @stevengj
stevengj Mar 31, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/julia/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Comment thread
lxvm marked this conversation as resolved.
Outdated
SSP = "e5b5d2ee-15bb-40cc-a0da-b305b842b7a8"
197 changes: 197 additions & 0 deletions examples/julia/ssp2_example.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
using SSP: init, solve!, adjoint_solve!
using SSP: Kernel, Pad, Convolve, Project
using .Kernel: conickernel
using .Pad: FillPadding, BoundaryPadding, Inner, PaddingProblem, DefaultPaddingAlgorithm
using .Convolve: DiscreteConvolutionProblem, FFTConvolution
using .Project: ProjectionProblem, SSP2

using Random
using CairoMakie
using CairoMakie: colormap
using NLopt


# Resolution of the design-variable grid on the domain [-1, 1]^2.
Nx = Ny = 100
grid = (
    range(-1, 1, length=Nx),
    range(-1, 1, length=Ny),
)
# Alternative initial designs, kept for experimentation:
# Random.seed!(42)
# design_vars = rand(Nx, Ny)
# design_vars = [sinpi(x) * sinpi(y) for (x, y) in Iterators.product(grid...)]
design_vars = let a = 0.5, b = 0.499
    # Cassini oval: level set ((x^2+y^2)^2 - 2a^2(x^2-y^2) + a^4 - b^4), shifted
    # by 0.5 so the zero level of the oval sits at the projection threshold.
    [((x^2 + y^2)^2 - 2a^2 * (x^2 - y^2) + a^4 - b^4) + 0.5 for (x, y) in Iterators.product(grid...)]
end
# Smoothing (filter) radius in grid units.
radius = 0.1

kernel = conickernel(grid, radius)

# 1) Pad the design so the convolution has data beyond the domain boundary.
#    FillPadding(1.0, ...) pads with the value 1; the commented-out
#    BoundaryPadding alternative pads using boundary values instead.
padprob = PaddingProblem(;
    data = design_vars,
    # boundary = BoundaryPadding(size(kernel) .- 1, size(kernel) .- 1),
    boundary = FillPadding(1.0, size(kernel) .- 1, size(kernel) .- 1),
)
padalg = DefaultPaddingAlgorithm()
padsolver = init(padprob, padalg)
padsol = solve!(padsolver)

# 2) Smooth the padded design by convolution with the conic kernel.
convprob = DiscreteConvolutionProblem(;
    data = padsol.value,
    kernel,
)

convalg = FFTConvolution()
convsolver = init(convprob, convalg)
convsol = solve!(convsolver)

# 3) Remove the padding again — Inner with the same margins recovers the
#    interior region matching the original design size.
depadprob = PaddingProblem(;
    data = convsol.value,
    boundary = Inner(size(kernel) .- 1, size(kernel) .- 1),
)
depadalg = DefaultPaddingAlgorithm()
depadsolver = init(depadprob, depadalg)
depadsol = solve!(depadsolver)

filtered_design_vars = depadsol.value

# 4) Project the filtered design towards 0/1 values at the target points.
#    Projection points need not be the same as the design variable grid;
#    here a 2x finer grid is used.
target_grid = (
    range(-1, 1, length=Nx * 2),
    range(-1, 1, length=Ny * 2),
)
target_points = vec(collect(Iterators.product(target_grid...)))
projprob = ProjectionProblem(;
    data=filtered_design_vars,
    grid,
    target_points,
)
# beta = Inf makes the projection a sharp threshold at eta = 0.5.
projalg = SSP2(; beta=Inf, eta=0.5)
projsolver = init(projprob, projalg)
projsol = solve!(projsolver)

projected_design_vars = projsol.value

# Plot the raw design variables next to the projected (SSP2) output.
let
    fig = Figure()
    ax1 = Axis(fig[1,1]; title = "design variables", aspect=DataAspect())
    # Pass the axis explicitly (consistent with the other figures in this
    # script) instead of relying on the implicit current axis.
    h1 = heatmap!(ax1, grid..., design_vars; colormap=colormap("grays"))
    Colorbar(fig[1,2], h1)

    ax2 = Axis(fig[1,3]; title = "SSP2 output", aspect=DataAspect())
    h2 = heatmap!(ax2, target_grid..., reshape(projected_design_vars, length.(target_grid)); colormap=colormap("grays"))
    Colorbar(fig[1,4], h2)
    save("design.png", fig)
end

"""
    fom(data, grid)

Figure of merit: the sum of squared values of `data` scaled by the grid cell
volume `prod(step, grid)`, i.e. a discrete integral of `abs2` over the
uniform `grid` (a tuple of ranges).
"""
fom(data, grid) = sum(abs2, data) * prod(step, grid)
# Evaluate the figure of merit once on the projected design.
# NOTE(review): `projected_design_vars` lives on `target_grid`, but the cell
# volume here comes from `grid` — confirm which grid the integral is meant
# over (the adjoint below uses the same convention, so the gradient check
# is self-consistent either way).
obj = fom(projected_design_vars, grid)

"""
    adjoint_fom!(adj_data, adj_fom, data, grid)

In-place adjoint (vector-Jacobian product) of [`fom`](@ref): writes
`adj_fom * 2 * data * prod(step, grid)` into `adj_data` and returns it.
"""
function adjoint_fom!(adj_data, adj_fom, data, grid)
    cell = prod(step, grid)
    @. adj_data = 2 * adj_fom * data * cell
    return adj_data
end

"""
    adjoint_fom(adj_fom, data, grid)

Allocating variant of [`adjoint_fom!`](@ref).
"""
adjoint_fom(adj_fom, data, grid) = adjoint_fom!(similar(data), adj_fom, data, grid)

# Seed the reverse pass with d(obj)/d(obj) = 1 ...
adj_projsol = adjoint_fom(1.0, projected_design_vars, grid)

# ... and propagate the adjoint backwards through each pipeline stage in
# reverse order, reusing the solvers and the tapes recorded by the forward
# solves above.
adj_projprob = adjoint_solve!(projsolver, adj_projsol, projsol.tape)
adj_depadsol = adj_projprob.data
adj_depadprob = adjoint_solve!(depadsolver, adj_depadsol, depadsol.tape)
adj_convsol = adj_depadprob.data
adj_convprob = adjoint_solve!(convsolver, adj_convsol, convsol.tape)
adj_padsol = adj_convprob.data
adj_padprob = adjoint_solve!(padsolver, adj_padsol, padsol.tape)
# Gradient of the objective with respect to the raw design variables.
adj_design_vars = adj_padprob.data

# Plot the projected output next to the gradient of the objective with
# respect to the raw design variables.
let
    fig = Figure()
    ax1 = Axis(fig[1,1]; title = "SSP2 output", aspect=DataAspect())
    h1 = heatmap!(ax1, target_grid..., reshape(projected_design_vars, length.(target_grid)); colormap=colormap("grays"))
    Colorbar(fig[1,2], h1)

    ax2 = Axis(fig[1,3]; title = "design variables gradient", aspect=DataAspect())
    # Diverging colormap since the gradient can be signed.
    h2 = heatmap!(ax2, grid..., adj_design_vars; colormap=colormap("RdBu"))
    Colorbar(fig[1,4], h2)
    save("design_gradient.png", fig)
end

# Fused objective + gradient evaluation.  The `let` block captures the
# solvers so their plans and buffers are reused across calls.
# NOTE(review): the returned gradient aliases a solver-owned buffer and is
# overwritten by the next call — copy it if you need to keep it.
fom_withgradient = let grid=grid, padsolver=padsolver, convsolver=convsolver, depadsolver=depadsolver, projsolver=projsolver, adj_projsol=adj_projsol
    function (design_vars)
        # Forward pass: pad -> convolve -> depad -> project.
        padsolver.data = design_vars
        padsol = solve!(padsolver)
        convsolver.data = padsol.value
        convsol = solve!(convsolver)
        depadsolver.data = convsol.value
        depadsol = solve!(depadsolver)
        projsolver.data = depadsol.value
        projsol = solve!(projsolver)

        _fom = fom(projsol.value, grid)
        # Seed the reverse pass in place, reusing the adj_projsol buffer.
        adjoint_fom!(adj_projsol, 1.0, projsol.value, grid)

        # Reverse pass through each stage using the recorded tapes.
        adj_projprob = adjoint_solve!(projsolver, adj_projsol, projsol.tape)
        adj_depadsol = adj_projprob.data
        adj_depadprob = adjoint_solve!(depadsolver, adj_depadsol, depadsol.tape)
        adj_convsol = adj_depadprob.data
        adj_convprob = adjoint_solve!(convsolver, adj_convsol, convsol.tape)
        adj_padsol = adj_convprob.data
        adj_padprob = adjoint_solve!(padsolver, adj_padsol, padsol.tape)
        adj_design_vars = adj_padprob.data
        return _fom, adj_design_vars
    end
end

# Finite-difference check of the adjoint gradient at a single pixel.
h = 1e-5
h_index = (50, 50)
# h_index = (38, 50)
perturb = zero(design_vars)
perturb[h_index...] = h
fom_ph, = fom_withgradient(design_vars + perturb)
fom_mh, = fom_withgradient(design_vars - perturb)
# Central difference: (f(x+h) - f(x-h)) / 2h.
dfomdh_fd = (fom_ph - fom_mh) / 2h

# Evaluate the gradient last: the returned array is a reused solver buffer,
# so the two perturbed calls above would otherwise have clobbered it.
fom_val, adj_design_vars = fom_withgradient(design_vars)
dfomdh = adj_design_vars[h_index...]
@show dfomdh_fd dfomdh

# Optimize the design with NLopt's gradient-based CCSA-quadratic algorithm.
opt = NLopt.Opt(:LD_CCSAQ, length(design_vars))
evaluation_history = Float64[]
my_objective_fn = let fom_withgradient=fom_withgradient, evaluation_history=evaluation_history, design_vars=design_vars
    # NLopt objective callback: `x` is the flattened design; `grad`, when
    # non-empty, must be filled in place with the gradient.
    function (x, grad)
        val, adj_design = fom_withgradient(reshape(x, size(design_vars)))
        if !isempty(grad)
            copy!(grad, vec(adj_design))
        end
        push!(evaluation_history, val)
        return val
    end
end
NLopt.min_objective!(opt, my_objective_fn)
NLopt.maxeval!(opt, 50)
# NOTE(review): this is a minimization, so `fmax`/`xmax` actually hold the
# minimizer — the names are misleading but kept because `xmax` is used below.
fmax, xmax, ret = NLopt.optimize(opt, vec(design_vars))

# Re-run the forward pipeline on the optimized design and plot the
# convergence history alongside the final projected structure.
let
    padsolver.data = reshape(xmax, size(design_vars))
    padsol = solve!(padsolver)
    convsolver.data = padsol.value
    convsol = solve!(convsolver)
    depadsolver.data = convsol.value
    depadsol = solve!(depadsolver)
    projsolver.data = depadsol.value
    projsol = solve!(projsolver)

    fig = Figure()
    ax1 = Axis(fig[1,1]; title = "Objective history", yscale=log10)
    h1 = scatterlines!(ax1, evaluation_history)

    ax2 = Axis(fig[1,2]; title = "Final SSP2 design", aspect=DataAspect())
    # Pass the axis explicitly for consistency with the other heatmap! calls
    # in this script (previously relied on the implicit current axis).
    h2 = heatmap!(ax2, target_grid..., reshape(projsol.value, length.(target_grid)); colormap=colormap("grays"))
    Colorbar(fig[1,3], h2)
    save("optimization.png", fig)
end
12 changes: 12 additions & 0 deletions src/julia/SSP/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
name = "SSP"
uuid = "e5b5d2ee-15bb-40cc-a0da-b305b842b7a8"
version = "0.1.0"
authors = ["Lorenzo Van Munoz <lorenzo@vanmunoz.com>"]

[deps]
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FastInterpolations = "9ea80cae-fc13-4c00-8066-6eaedb12f34b"

[compat]
FFTW = "1.10.0"
FastInterpolations = "0.4.2"
12 changes: 12 additions & 0 deletions src/julia/SSP/src/SSP.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module SSP

# Generic solve / adjoint-solve interface shared by all submodules.
include("definitions.jl")
# `public` (Julia >= 1.11) declares the API without exporting the names.
public init, solve, solve!, adjoint_solve, adjoint_solve!

include("pad.jl")          # Pad: padding / de-padding of design arrays
include("kernel.jl")       # Kernel: filter kernels (e.g. conickernel)
include("convolve.jl")     # Convolve: FFT-based discrete convolution
include("interpolate.jl")  # interpolation used by the projection
include("project.jl")      # Project: SSP1/SSP2 projection

end
103 changes: 103 additions & 0 deletions src/julia/SSP/src/convolve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
module Convolve

using FFTW: plan_fft!, plan_bfft!
import SSP: init, solve!, adjoint_solve!

"""
    DiscreteConvolutionProblem(; data, kernel)

Specification of a discrete convolution of the array `data` with `kernel`.
"""
Base.@kwdef struct DiscreteConvolutionProblem{D,K}
    data::D
    kernel::K
end

# Solver state for a DiscreteConvolutionProblem.  `data` is mutable so the
# same solver (and its FFT plans/buffers in `cacheval`) can be reused across
# forward and adjoint solves with new data of the same size.
mutable struct DiscreteConvolutionSolver{D,K,A,C}
    data::D
    const kernel::K  # `const` field (Julia >= 1.8): kernel is fixed per solver
    alg::A
    cacheval::C      # precomputed plans and scratch arrays (see `init`)
end

"""
    FFTConvolution(; factors=(2, 3, 5, 7), plan_kws=(;))

Convolution algorithm based on in-place FFTs.  Transform sizes are rounded up
with `nextprod(factors, n)` so they factor into small primes; `plan_kws` is
forwarded to the FFTW planners.
"""
Base.@kwdef struct FFTConvolution{F,P}
    factors::F=(2,3,5,7)
    plan_kws::P=(;)
end

"""
    init(prob::DiscreteConvolutionProblem, alg::FFTConvolution)

Allocate FFT plans and scratch buffers for convolving arrays of
`size(prob.data)` with `prob.kernel`, returning a reusable
`DiscreteConvolutionSolver`.
"""
function init(prob::DiscreteConvolutionProblem, alg::FFTConvolution)
    data = prob.data
    kernel = prob.kernel

    # A linear convolution requires transforms of length N + K - 1 in each
    # dimension, rounded up to a product of small primes for FFT efficiency.
    fftsize = map(size(data), size(kernel)) do n, k
        nextprod(alg.factors, n + k - 1)
    end

    fftkernel = zeros(ComplexF64, fftsize...)
    fftconv = zeros(ComplexF64, fftsize...)

    plan_fw = plan_fft!(fftconv; alg.plan_kws...)
    plan_bk = plan_bfft!(fftconv; alg.plan_kws...)

    # Preallocated outputs so forward (`output`) and adjoint (`adj_data`)
    # solves are allocation-free after initialization.
    output = similar(data)
    adj_data = similar(data)
    cacheval = (; fftsize, fftkernel, fftconv, plan_fw, plan_bk, output, adj_data)
    return DiscreteConvolutionSolver(data, kernel, alg, cacheval)
end

"""
    solve!(solver::DiscreteConvolutionSolver)

Run the forward convolution, dispatching on the solver's algorithm.
Returns a named tuple `(; value, tape)`.
"""
solve!(solver::DiscreteConvolutionSolver) = conv_solve!(solver, solver.alg)

# Forward FFT convolution: transform padded kernel and data, multiply in
# frequency space, inverse-transform, and extract the "same"-size region.
function conv_solve!(solver, ::FFTConvolution)
    (; data, kernel, cacheval) = solver
    (; fftsize, fftkernel, fftconv, plan_fw, plan_bk, output) = cacheval

    # Zero-pad the kernel into its scratch buffer and transform in place.
    fill!(fftkernel, zero(eltype(fftkernel)))
    copy!(view(fftkernel, axes(kernel)...), kernel)
    plan_fw * fftkernel

    # Likewise for the data.
    fill!(fftconv, zero(eltype(fftconv)))
    copy!(view(fftconv, axes(data)...), data)
    plan_fw * fftconv

    # Pointwise product in frequency space; dividing by prod(fftsize)
    # normalizes the unscaled backward transform (bfft).
    fftconv .*= fftkernel ./ prod(fftsize)
    plan_bk * fftconv

    # Central region of the linear convolution matching size(data).
    target_indices = map(size(data), size(kernel)) do n, k
        k ÷ 2 + 1 : n + k ÷ 2
    end

    # Real inputs yield (numerically) real output; drop the imaginary part.
    elt = eltype(data) <: Real && eltype(kernel) <: Real ? real : identity
    output .= elt.(view(fftconv, target_indices...))

    return (; value = output, tape = nothing)
end

"""
    adjoint_solve!(solver::DiscreteConvolutionSolver, adj_output, tape)

Propagate the adjoint `adj_output` of the convolution output back to the
input, dispatching on the solver's algorithm.  Returns `(; data, kernel)`.
"""
function adjoint_solve!(solver::DiscreteConvolutionSolver, adj_output, tape)
    return adjoint_conv_solve!(solver, solver.alg, adj_output, tape)
end

# Adjoint (reverse mode) of `conv_solve!`: the adjoint of multiplying by
# FFT(kernel) is multiplying by conj(FFT(kernel)), i.e. a correlation.
function adjoint_conv_solve!(solver, ::FFTConvolution, adj_output, tape)
    (; data, kernel, cacheval) = solver
    (; fftsize, fftkernel, fftconv, plan_fw, plan_bk, output, adj_data) = cacheval

    # Recompute the padded kernel transform (kernel is fixed per solver).
    fill!(fftkernel, zero(eltype(fftkernel)))
    copy!(view(fftkernel, axes(kernel)...), kernel)
    plan_fw * fftkernel

    N = size(data) # number of target points
    K = size(kernel) # support of kernel
    # Same central region the forward pass extracted its output from.
    target_indices = map((n, k) -> k÷2+1:n+k÷2, N, K)

    elt = eltype(data) <: Real && eltype(kernel) <: Real ? real : identity
    # Scatter the output adjoint into that central region, zero elsewhere.
    fill!(fftconv, zero(eltype(fftconv)))
    view(fftconv, target_indices...) .= elt.(adj_output)
    plan_fw * fftconv

    # conj pairs with the forward multiply; 1/prod(fftsize) normalizes bfft.
    fftconv .*= conj.(fftkernel) ./ prod(fftsize)

    plan_bk * fftconv

    # The data adjoint occupies the same corner the forward data was copied to.
    adj_data .= elt.(view(fftconv, axes(data)...))

    # Kernel adjoint not implemented: the kernel is treated as a constant.
    return (; data=adj_data, kernel=nothing)
end
end
18 changes: 18 additions & 0 deletions src/julia/SSP/src/definitions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Define a CommonSolve-like interface for adjoint problems

"""
    init(prob, alg)

Construct a solver (with preallocated caches) for problem `prob` using
algorithm `alg`.  Methods are added by the submodules of this package.
"""
function init end

"""
    solve!(solver)

Run the forward solve on `solver`.  Implementations return a solution with
at least the properties `value` and `tape` (the tape, possibly `nothing`,
is later consumed by [`adjoint_solve!`](@ref)).
"""
function solve! end

"""
    solve(prob, alg)

Convenience wrapper: `init` a solver for `prob` with `alg` and `solve!` it.
"""
function solve(prob, alg)
    solver = init(prob, alg)
    sol = solve!(solver)
    return sol
end

"""
    adjoint_solve!(solver, adj_sol, tape)

Propagate the solution adjoint `adj_sol` backwards through `solver` using the
`tape` recorded by `solve!`, returning adjoints with respect to the problem
inputs.
"""
function adjoint_solve! end

"""
    adjoint_solve(prob, alg, adj_sol, tape)

Convenience wrapper: `init` a solver and run [`adjoint_solve!`](@ref) once.
"""
function adjoint_solve(prob, alg, adj_sol, tape)
    solver = init(prob, alg)
    adj_prob = adjoint_solve!(solver, adj_sol, tape)
    return adj_prob
end
Loading
Loading