Skip to content

Commit cb79b00

Browse files
committed
update
1 parent 7a8dfb8 commit cb79b00

3 files changed

Lines changed: 18 additions & 5 deletions

File tree

src/AbstractWavefunctionMC.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function Carlo.init!(mc::AbstractWavefunctionMC{N}, ctx::MCContext, params::Abst
3434
mc_state = state(mc)
3535
N_active, dynamic_pos = dynamic_positions(mc)
3636

37-
mc_state[:] = coordinate_projector(mc).(rand(ctx.rng, distribution(mc), N))
37+
mc_state[:] = [coordinate_projector(mc)(rand(ctx.rng, distribution(mc))) for _ in 1:N]
3838

3939
mc_state.logdensity = logdensity(wavefunction(mc), position(mc_state).buffer)
4040
mc_state.num_accepts = 0
@@ -62,7 +62,7 @@ end
6262

6363
function _update_coordinates(ctx::MCContext, dynamic_pos, bare_position, d, coordinate_trafo, coordinate_proj, wavefunc::T, buffer::X, mc_state) where {T<:AbstractWavefunction, X<:Buffer}
6464
@inbounds for dim in dynamic_pos
65-
x_new = coordinate_proj(bare_position[dim] + rand(ctx.rng, d))
65+
x_new = coordinate_proj(bare_position[dim] .+ rand(ctx.rng, d))
6666
x_new_transformed = coordinate_trafo(x_new)
6767

6868
logα = delta_logdensity(wavefunc, x_new_transformed, buffer.buffer, dim)

src/WavefunctionMC.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,24 @@ function WavefunctionMC(params::AbstractDict)
5252
N = inputlength(wavefunction)
5353
sigma_dist = get(params, :sigma_distribution, 0.3)
5454
distribution = get(params, :distribution, S <: Complex ? ComplexNormal(0, sigma_dist) : Normal(0, sigma_dist))
55-
position = coordinate_proj.(100 * rand(distribution, N))
55+
position = [coordinate_proj(100 * rand(distribution)) for _ in 1:N]
5656

5757
observables = get(params, :observables, NoObservables())
5858
dynamic_pos = get(params, :dynamic_positions, (N, 1:N))
59-
return WavefunctionMC{N}(State(position, 0.0, coordinate_transf), distribution, coordinate_proj, wavefunction, observables; adapt_interval = div(get(params, :thermalization, 10_000), 10), dynamic_positions = dynamic_pos)
59+
60+
adapt_interval = get(params, :adapt_interval,div(get(params, :thermalization, 10_000), 10))
61+
coordinate_update = get(params, :coordinate_update, CoordinateUpdate())
62+
target_accept = get(params, :target_accept, optimal_acceptance_rate(coordinate_update))
63+
adaptive = get(params, :adaptive, true)
64+
active_during_run = get(params, :active_during_run, false)
65+
update_distribution = get(params, :update_distribution, (d, adjustment) -> begin
66+
sigma_new = d.σ * Base.exp(-adjustment)
67+
return typeof(d)(d.μ, sigma_new)
68+
end)
69+
adaptor = adaptive ? AcceptanceAdapter(; acceptance_rate = target_accept, adapt_interval = adapt_interval, update_distribution = update_distribution, active_during_run = active_during_run) : NoAcceptanceAdapter()
70+
71+
72+
return WavefunctionMC{N}(State(position, 0.0, coordinate_transf), distribution, coordinate_proj, wavefunction, observables; dynamic_positions = dynamic_pos, adaptor = adaptor, coordinate_update = coordinate_update)
6073
end
6174

6275
function Carlo.write_checkpoint(mc::WavefunctionMC, out::HDF5.Group)

src/buffer/buffer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ Base.length(b::Buffer) = length(parent(b))
1818
Base.@propagate_inbounds Base.getindex(b::Buffer, i) = getindex(parent(b), i)
1919
Base.firstindex(b::Buffer) = firstindex(parent(b))
2020
Base.lastindex(b::Buffer) = lastindex(parent(b))
21-
Base.@propagate_inbounds Base.setindex!(b::Buffer, v, i) = (parent(b)[i] = b.map(v))
21+
Base.@propagate_inbounds Base.setindex!(b::Buffer, v, i) = (parent(b)[i] = b.map.(v))
2222

2323
function Buffer(def::T, size; map::G = Base.identity) where {T, G}
2424
return Buffer{T, AbstractVector{T}, G}(fill(def, size), map)

0 commit comments

Comments
 (0)