-
Notifications
You must be signed in to change notification settings - Fork 0
Set up Julia package and example SSP2 optimizations #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Changes from 11 commits
Commits
Show all changes
32 commits
Select commit
Hold shift + click to select a range
1a6e7c9
Set up Julia package and example SSP2 optimizations
lxvm 7a0353b
Delete examples/julia/Manifest.toml
lxvm 422bda4
fix underflow in tanh projection
lxvm 34b8986
use evalpoly
lxvm fbe944e
fix projection typos/bugs
lxvm e4e53df
refactor projection into a loop over target points
lxvm 46105ab
refactor interpolation into another module
lxvm 0e66fba
update example to use a higher resolution target grid and reduce allo…
lxvm a5ddab8
correct the usage of cubic_adjoint deriv api
lxvm 639ed5d
correct the comparison to finite differences because arrays are reused
lxvm b19bd89
avoid NaN in tanh projection when x == eta at beta = Inf
lxvm e9f1442
change definition of ProjectionProblem and SSP2
lxvm 000599d
add grid as an optional argument to PaddingProblem
lxvm db22074
add init! to interface
lxvm 326eeb0
fix typo
lxvm 1562411
add support for SSP1
lxvm 79588e5
add linear interpolation and use bc in cubic interpolation
lxvm 75c442a
rename R_smoothing_factor to smoothing_radius like python api
lxvm 17788f2
add pythonic API
lxvm f32ed89
add ChainRulesCore rrules for pythonic api
lxvm 7502781
run checks to make sure different apis match
lxvm 7a7ac00
document apis of padding module
lxvm 242c131
document public api in kernels module
lxvm e4525d8
document public api of convolution module
lxvm 3cfa7d9
Document public api of interpolation module
lxvm 13ac5ea
document the public api of the projection module
lxvm d10ba70
document public api of ssp package
lxvm ea998e4
add package tests
lxvm 4d56672
add README.md
lxvm 5baff4f
remove revise dep
lxvm 01ad3b8
use rffts, precompute conv kernel, use perturb in arb direction, and …
lxvm b4a6b26
Apply suggestion from @stevengj
stevengj File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| [deps] | ||
| CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" | ||
| NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd" | ||
| Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
| Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" | ||
| SSP = "e5b5d2ee-15bb-40cc-a0da-b305b842b7a8" | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,197 @@ | ||
| using SSP: init, solve!, adjoint_solve! | ||
| using SSP: Kernel, Pad, Convolve, Project | ||
| using .Kernel: conickernel | ||
| using .Pad: FillPadding, BoundaryPadding, Inner, PaddingProblem, DefaultPaddingAlgorithm | ||
| using .Convolve: DiscreteConvolutionProblem, FFTConvolution | ||
| using .Project: ProjectionProblem, SSP2 | ||
|
|
||
| using Random | ||
| using CairoMakie | ||
| using CairoMakie: colormap | ||
| using NLopt | ||
|
|
||
|
|
||
| Nx = Ny = 100 | ||
| grid = ( | ||
| range(-1, 1, length=Nx), | ||
| range(-1, 1, length=Ny), | ||
| ) | ||
| # Random.seed!(42) | ||
| # design_vars = rand(Nx, Ny) | ||
| # design_vars = [sinpi(x) * sinpi(y) for (x, y) in Iterators.product(grid...)] | ||
| design_vars = let a = 0.5, b = 0.499 | ||
| # Cassini oval | ||
| [((x^2 + y^2)^2 - 2a^2 * (x^2 - y^2) + a^4 - b^4) + 0.5 for (x, y) in Iterators.product(grid...)] | ||
| end | ||
| radius = 0.1 | ||
|
|
||
| kernel = conickernel(grid, radius) | ||
|
|
||
| padprob = PaddingProblem(; | ||
| data = design_vars, | ||
| # boundary = BoundaryPadding(size(kernel) .- 1, size(kernel) .- 1), | ||
| boundary = FillPadding(1.0, size(kernel) .- 1, size(kernel) .- 1), | ||
| ) | ||
| padalg = DefaultPaddingAlgorithm() | ||
| padsolver = init(padprob, padalg) | ||
| padsol = solve!(padsolver) | ||
|
|
||
| convprob = DiscreteConvolutionProblem(; | ||
| data = padsol.value, | ||
| kernel, | ||
| ) | ||
|
|
||
| convalg = FFTConvolution() | ||
| convsolver = init(convprob, convalg) | ||
| convsol = solve!(convsolver) | ||
|
|
||
| depadprob = PaddingProblem(; | ||
| data = convsol.value, | ||
| boundary = Inner(size(kernel) .- 1, size(kernel) .- 1), | ||
| ) | ||
| depadalg = DefaultPaddingAlgorithm() | ||
| depadsolver = init(depadprob, depadalg) | ||
| depadsol = solve!(depadsolver) | ||
|
|
||
| filtered_design_vars = depadsol.value | ||
|
|
||
| # projection points need not be the same as design variable grid | ||
| target_grid = ( | ||
| range(-1, 1, length=Nx * 2), | ||
| range(-1, 1, length=Ny * 2), | ||
| ) | ||
| target_points = vec(collect(Iterators.product(target_grid...))) | ||
| projprob = ProjectionProblem(; | ||
| data=filtered_design_vars, | ||
| grid, | ||
| target_points, | ||
| ) | ||
| projalg = SSP2(; beta=Inf, eta=0.5) | ||
| projsolver = init(projprob, projalg) | ||
| projsol = solve!(projsolver) | ||
|
|
||
| projected_design_vars = projsol.value | ||
|
|
||
| let | ||
| fig = Figure() | ||
| ax1 = Axis(fig[1,1]; title = "design variables", aspect=DataAspect()) | ||
| h1 = heatmap!(grid..., design_vars; colormap=colormap("grays")) | ||
| Colorbar(fig[1,2], h1) | ||
|
|
||
| ax2 = Axis(fig[1,3]; title = "SSP2 output", aspect=DataAspect()) | ||
| h2 = heatmap!(target_grid..., reshape(projected_design_vars, length.(target_grid)); colormap=colormap("grays")) | ||
| Colorbar(fig[1,4], h2) | ||
| save("design.png", fig) | ||
| end | ||
|
|
||
| function fom(data, grid) | ||
| return sum(abs2, data) * prod(step, grid) | ||
| end | ||
| obj = fom(projected_design_vars, grid) | ||
|
|
||
| function adjoint_fom(adj_fom, data, grid) | ||
| adjoint_fom!(similar(data), adj_fom, data, grid) | ||
| end | ||
| function adjoint_fom!(adj_data, adj_fom, data, grid) | ||
| adj_data .= adj_fom .* 2 .* data .* prod(step, grid) | ||
| return adj_data | ||
| end | ||
|
|
||
| adj_projsol = adjoint_fom(1.0, projected_design_vars, grid) | ||
|
|
||
| adj_projprob = adjoint_solve!(projsolver, adj_projsol, projsol.tape) | ||
| adj_depadsol = adj_projprob.data | ||
| adj_depadprob = adjoint_solve!(depadsolver, adj_depadsol, depadsol.tape) | ||
| adj_convsol = adj_depadprob.data | ||
| adj_convprob = adjoint_solve!(convsolver, adj_convsol, convsol.tape) | ||
| adj_padsol = adj_convprob.data | ||
| adj_padprob = adjoint_solve!(padsolver, adj_padsol, padsol.tape) | ||
| adj_design_vars = adj_padprob.data | ||
|
|
||
| let | ||
| fig = Figure() | ||
| ax1 = Axis(fig[1,1]; title = "SSP2 output", aspect=DataAspect()) | ||
| h1 = heatmap!(ax1, target_grid..., reshape(projected_design_vars, length.(target_grid)); colormap=colormap("grays")) | ||
| Colorbar(fig[1,2], h1) | ||
|
|
||
| ax2 = Axis(fig[1,3]; title = "design variables gradient", aspect=DataAspect()) | ||
| h2 = heatmap!(ax2, grid..., adj_design_vars; colormap=colormap("RdBu")) | ||
| Colorbar(fig[1,4], h2) | ||
| save("design_gradient.png", fig) | ||
| end | ||
|
|
||
| fom_withgradient = let grid=grid, padsolver=padsolver, convsolver=convsolver, depadsolver=depadsolver, projsolver=projsolver, adj_projsol=adj_projsol | ||
| function (design_vars) | ||
|
|
||
| padsolver.data = design_vars | ||
| padsol = solve!(padsolver) | ||
| convsolver.data = padsol.value | ||
| convsol = solve!(convsolver) | ||
| depadsolver.data = convsol.value | ||
| depadsol = solve!(depadsolver) | ||
| projsolver.data = depadsol.value | ||
| projsol = solve!(projsolver) | ||
|
|
||
| _fom = fom(projsol.value, grid) | ||
| adjoint_fom!(adj_projsol, 1.0, projsol.value, grid) | ||
|
|
||
| adj_projprob = adjoint_solve!(projsolver, adj_projsol, projsol.tape) | ||
| adj_depadsol = adj_projprob.data | ||
| adj_depadprob = adjoint_solve!(depadsolver, adj_depadsol, depadsol.tape) | ||
| adj_convsol = adj_depadprob.data | ||
| adj_convprob = adjoint_solve!(convsolver, adj_convsol, convsol.tape) | ||
| adj_padsol = adj_convprob.data | ||
| adj_padprob = adjoint_solve!(padsolver, adj_padsol, padsol.tape) | ||
| adj_design_vars = adj_padprob.data | ||
| return _fom, adj_design_vars | ||
| end | ||
| end | ||
|
|
||
| h = 1e-5 | ||
| h_index = (50, 50) | ||
| # h_index = (38, 50) | ||
| perturb = zero(design_vars) | ||
| perturb[h_index...] = h | ||
| fom_ph, = fom_withgradient(design_vars + perturb) | ||
| fom_mh, = fom_withgradient(design_vars - perturb) | ||
| dfomdh_fd = (fom_ph - fom_mh) / 2h | ||
|
|
||
| fom_val, adj_design_vars = fom_withgradient(design_vars) | ||
| dfomdh = adj_design_vars[h_index...] | ||
| @show dfomdh_fd dfomdh | ||
|
|
||
| opt = NLopt.Opt(:LD_CCSAQ, length(design_vars)) | ||
| evaluation_history = Float64[] | ||
| my_objective_fn = let fom_withgradient=fom_withgradient, evaluation_history=evaluation_history, design_vars=design_vars | ||
| function (x, grad) | ||
| val, adj_design = fom_withgradient(reshape(x, size(design_vars))) | ||
| if !isempty(grad) | ||
| copy!(grad, vec(adj_design)) | ||
| end | ||
| push!(evaluation_history, val) | ||
| return val | ||
| end | ||
| end | ||
| NLopt.min_objective!(opt, my_objective_fn) | ||
| NLopt.maxeval!(opt, 50) | ||
| fmax, xmax, ret = NLopt.optimize(opt, vec(design_vars)) | ||
|
|
||
| let | ||
| padsolver.data = reshape(xmax, size(design_vars)) | ||
| padsol = solve!(padsolver) | ||
| convsolver.data = padsol.value | ||
| convsol = solve!(convsolver) | ||
| depadsolver.data = convsol.value | ||
| depadsol = solve!(depadsolver) | ||
| projsolver.data = depadsol.value | ||
| projsol = solve!(projsolver) | ||
|
|
||
| fig = Figure() | ||
| ax1 = Axis(fig[1,1]; title = "Objective history", yscale=log10) | ||
| h1 = scatterlines!(ax1, evaluation_history) | ||
|
|
||
| ax2 = Axis(fig[1,2]; title = "Final SSP2 design", aspect=DataAspect()) | ||
| h2 = heatmap!(target_grid..., reshape(projsol.value, length.(target_grid)); colormap=colormap("grays")) | ||
| Colorbar(fig[1,3], h2) | ||
| save("optimization.png", fig) | ||
| end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| name = "SSP" | ||
| uuid = "e5b5d2ee-15bb-40cc-a0da-b305b842b7a8" | ||
| version = "0.1.0" | ||
| authors = ["Lorenzo Van Munoz <lorenzo@vanmunoz.com>"] | ||
|
|
||
| [deps] | ||
| FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" | ||
| FastInterpolations = "9ea80cae-fc13-4c00-8066-6eaedb12f34b" | ||
|
|
||
| [compat] | ||
| FFTW = "1.10.0" | ||
| FastInterpolations = "0.4.2" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| module SSP | ||
|
|
||
| include("definitions.jl") | ||
| public init, solve, solve!, adjoint_solve, adjoint_solve! | ||
|
|
||
| include("pad.jl") | ||
| include("kernel.jl") | ||
| include("convolve.jl") | ||
| include("interpolate.jl") | ||
| include("project.jl") | ||
|
|
||
| end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| module Convolve | ||
|
|
||
| using FFTW: plan_fft!, plan_bfft! | ||
| import SSP: init, solve!, adjoint_solve! | ||
|
|
||
| Base.@kwdef struct DiscreteConvolutionProblem{D,K} | ||
| data::D | ||
| kernel::K | ||
| end | ||
|
|
||
| mutable struct DiscreteConvolutionSolver{D,K,A,C} | ||
| data::D | ||
| const kernel::K | ||
| alg::A | ||
| cacheval::C | ||
| end | ||
|
|
||
| Base.@kwdef struct FFTConvolution{F,P} | ||
| factors::F=(2,3,5,7) | ||
| plan_kws::P=(;) | ||
| end | ||
|
|
||
| function init(prob::DiscreteConvolutionProblem, alg::FFTConvolution) | ||
|
|
||
| (; data, kernel) = prob | ||
|
|
||
| N = size(data) # number of target points | ||
| K = size(kernel) # support of kernel | ||
| transformsize = N .+ K .- 1 | ||
| fftsize = map(x -> nextprod(alg.factors, x), transformsize) | ||
|
|
||
| fftkernel = zeros(ComplexF64, fftsize...) | ||
| fftconv = zeros(ComplexF64, fftsize...) | ||
|
|
||
| plan_fw = plan_fft!(fftconv; alg.plan_kws...) | ||
| plan_bk = plan_bfft!(fftconv; alg.plan_kws...) | ||
|
lxvm marked this conversation as resolved.
Outdated
|
||
|
|
||
| output = similar(data) | ||
| adj_data = similar(data) | ||
| cacheval = (; fftsize, fftkernel, fftconv, plan_fw, plan_bk, output, adj_data) | ||
| return DiscreteConvolutionSolver(data, kernel, alg, cacheval) | ||
| end | ||
|
|
||
| function solve!(solver::DiscreteConvolutionSolver) | ||
| conv_solve!(solver, solver.alg) | ||
| end | ||
|
|
||
| function conv_solve!(solver, ::FFTConvolution) | ||
| (; data, kernel, cacheval) = solver | ||
| (; fftsize, fftkernel, fftconv, plan_fw, plan_bk, output, adj_data) = cacheval | ||
|
|
||
| fill!(fftkernel, zero(eltype(fftkernel))) | ||
| copy!(view(fftkernel, axes(kernel)...), kernel) | ||
| plan_fw * fftkernel | ||
|
lxvm marked this conversation as resolved.
Outdated
|
||
|
|
||
| fill!(fftconv, zero(eltype(fftconv))) | ||
| copy!(view(fftconv, axes(data)...), data) | ||
| plan_fw * fftconv | ||
|
|
||
| fftconv .*= fftkernel ./ prod(fftsize) | ||
|
stevengj marked this conversation as resolved.
Outdated
|
||
|
|
||
| plan_bk * fftconv | ||
|
|
||
| N = size(data) # number of target points | ||
| K = size(kernel) # support of kernel | ||
| target_indices = map((n, k) -> k÷2+1:n+k÷2, N, K) | ||
|
|
||
| elt = eltype(data) <: Real && eltype(kernel) <: Real ? real : identity | ||
| output .= elt.(view(fftconv, target_indices...)) | ||
|
|
||
| return (; value = output, tape = nothing) | ||
| end | ||
|
|
||
| function adjoint_solve!(solver::DiscreteConvolutionSolver, adj_output, tape) | ||
| adjoint_conv_solve!(solver, solver.alg, adj_output, tape) | ||
| end | ||
|
|
||
| function adjoint_conv_solve!(solver, ::FFTConvolution, adj_output, tape) | ||
| (; data, kernel, cacheval) = solver | ||
| (; fftsize, fftkernel, fftconv, plan_fw, plan_bk, output, adj_data) = cacheval | ||
|
|
||
| fill!(fftkernel, zero(eltype(fftkernel))) | ||
| copy!(view(fftkernel, axes(kernel)...), kernel) | ||
| plan_fw * fftkernel | ||
|
|
||
| N = size(data) # number of target points | ||
| K = size(kernel) # support of kernel | ||
| target_indices = map((n, k) -> k÷2+1:n+k÷2, N, K) | ||
|
|
||
| elt = eltype(data) <: Real && eltype(kernel) <: Real ? real : identity | ||
| fill!(fftconv, zero(eltype(fftconv))) | ||
| view(fftconv, target_indices...) .= elt.(adj_output) | ||
| plan_fw * fftconv | ||
|
|
||
| fftconv .*= conj.(fftkernel) ./ prod(fftsize) | ||
|
|
||
| plan_bk * fftconv | ||
|
|
||
| adj_data .= elt.(view(fftconv, axes(data)...)) | ||
|
|
||
| return (; data=adj_data, kernel=nothing) | ||
| end | ||
| end | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| # Define a CommonSolve-like interface for adjoint problems | ||
|
|
||
| function init end | ||
|
|
||
| function solve! end | ||
|
|
||
| function solve(prob, alg) | ||
| solver = init(prob, alg) | ||
| sol = solve!(solver) | ||
| return sol | ||
| end | ||
| function adjoint_solve! end | ||
|
|
||
| function adjoint_solve(prob, alg, adj_sol, tape) | ||
| solver = init(prob, alg) | ||
| adj_prob = adjoint_solve!(solver, adj_sol, tape) | ||
| return adj_prob | ||
| end |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.