Skip to content

Commit e8f07eb

Browse files
committed
Convergence check for NTU
1 parent c51b54b commit e8f07eb

2 files changed

Lines changed: 91 additions & 15 deletions

File tree

src/PEPSKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module PEPSKit
33
using LinearAlgebra, Statistics, Base.Threads, Base.Iterators, Printf
44
using Random
55
using Compat
6-
using Accessors: @set, @reset
6+
using Accessors: @set, @reset, @insert
77
using VectorInterface
88
import VectorInterface as VI
99

src/algorithms/time_evolution/ntupdate.jl

Lines changed: 90 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ function ntu_iter(
7979
state::InfiniteState, circuit::LocalCircuit, alg::NeighbourUpdate
8080
)
8181
Nr, Nc, = size(state)
82-
state2, wts = deepcopy(state), SUWeight(state)
82+
state2, wts = copy(state), SUWeight(state)
8383
info = (; fid = 1.0)
8484
for (sites, gate) in circuit.gates
8585
if length(sites) == 1
@@ -96,14 +96,14 @@ function ntu_iter(
9696
(!alg.bipartite) && continue
9797
if d == 1
9898
rp1, cp1 = _next(r, Nr), _next(c, Nc)
99-
state2[rp1, cp1] = deepcopy(state2[r, c])
100-
state2[rp1, c] = deepcopy(state2[r, cp1])
101-
wts[1, rp1, cp1] = deepcopy(wts[1, r, c])
99+
state2[rp1, cp1] = copy(state2[r, c])
100+
state2[rp1, c] = copy(state2[r, cp1])
101+
wts[1, rp1, cp1] = copy(wts[1, r, c])
102102
else
103103
rm1, cm1 = _prev(r, Nr), _prev(c, Nc)
104-
state2[rm1, cm1] = deepcopy(state2[r, c])
105-
state2[r, cm1] = deepcopy(state2[rm1, c])
106-
wts[2, rm1, cm1] = deepcopy(wts[2, r, c])
104+
state2[rm1, cm1] = copy(state2[r, c])
105+
state2[r, cm1] = copy(state2[rm1, c])
106+
wts[2, rm1, cm1] = copy(wts[2, r, c])
107107
end
108108
else
109109
# N-site MPO gate (N ≥ 2)
@@ -156,14 +156,91 @@ end
156156

157157
"""
158158
time_evolve(
159-
it::TimeEvolver{<:NeighbourUpdate}; check_interval::Int = 500
159+
it::TimeEvolver{<:NeighbourUpdate},
160+
[H::LocalOperator, env::CTMRGEnv, ctm_alg::CTMRGAlgorithm];
161+
tol::Float64 = 1.0e-7, check_interval::Int = 10
160162
) -> (psi, info)
161163
162-
Perform time evolution to the end of `TimeEvolver` iterator `it`.
164+
Perform time evolution to the end of `NeighbourUpdate` TimeEvolver `it`,
165+
or until convergence of energy set by a positive `tol`.
163166
164-
- `check_interval` sets the number of iterations between outputs of information.
167+
To enable convergence check (for imaginary time evolution of InfinitePEPS only),
168+
provide the Hamiltonian `H`, CTMRG environment `env`, CTMRG algorithm `ctm_alg`
169+
and setting `tol > 0`.
170+
171+
`check_interval` sets the number of iterations between energy checks
172+
(for ground state search) and outputs of information.
165173
"""
166-
function MPSKit.time_evolve(it::TimeEvolver{<:NeighbourUpdate}; check_interval::Int = 500)
174+
function MPSKit.time_evolve(
175+
it::TimeEvolver{<:NeighbourUpdate},
176+
H::LocalOperator, env::CTMRGEnv, ctm_alg::CTMRGAlgorithm;
177+
tol::Float64 = 1.0e-7, check_interval::Int = 10
178+
)
179+
@info "--- Time evolution (neighbourhood tensor update), dt = $(it.dt) ---"
180+
time_start = time0 = time()
181+
psi0 = copy(it.state.psi)
182+
@assert (psi0 isa InfinitePEPS) && it.alg.imaginary_time "Only imaginary time evolution of InfinitePEPS allows convergence checking."
183+
# initial energy
184+
env, = leading_boundary(env, psi0, ctm_alg)
185+
energy = expectation_value(psi0, H, env) / prod(size(psi0))
186+
@info @sprintf("NTU iter 0: E = %.4e", energy)
187+
info0 = (; energy, env)
188+
# start evolving
189+
energy0, ΔE = energy, 0.0
190+
iter0, t0 = it.state.iter, it.state.t
191+
for (psi, info) in it
192+
iter = it.state.iter
193+
showinfo = (iter == 1) || (iter % check_interval == 0) || (iter == it.nstep)
194+
!showinfo && continue
195+
# bond weight change
196+
Δλ = hasproperty(info0, :wts) ? compare_weights(info.wts, info0.wts) : NaN
197+
# reconverge environment
198+
if all(space(t) == space(t0) for (t, t0) in zip(psi.A, psi0.A))
199+
# recreate `env` from bond weights if psi virtual space changed
200+
env = CTMRGEnv(info.wts)
201+
end
202+
env, = leading_boundary(env, psi, ctm_alg)
203+
# measure energy
204+
energy = expectation_value(psi, H, env) / prod(size(psi))
205+
ΔE = energy - energy0
206+
info = @insert info.energy = energy
207+
info = @insert info.env = env
208+
# show information
209+
time1 = time()
210+
@info @sprintf(
211+
"NTU iter %-6d: E = %.5f, ΔE = %.3e, |Δλ| = %.3e. Time: %.2f s",
212+
it.state.iter, energy, ΔE, Δλ, time1 - time0
213+
)
214+
# determine whether to stop evolution
215+
stop = false
216+
if (ΔE <= 0 && abs(ΔE) < tol)
217+
stop = true
218+
@info "NTU: energy has converged."
219+
end
220+
if ΔE > 0
221+
stop = true
222+
@warn "NTU: energy has increased. Abort evolution and return results from last check."
223+
psi, info, energy = psi0, info0, energy0
224+
it.state = NTUState(iter0, t0, psi0)
225+
end
226+
if iter == it.nstep
227+
stop = true
228+
@info "NTU: reached maximum iteration."
229+
end
230+
if stop
231+
time_end = time()
232+
@info @sprintf("Time evolution finished in %.2f s", time_end - time_start)
233+
return psi, info
234+
else
235+
iter0, t0 = it.state.iter, it.state.t
236+
psi0, energy0, info0 = psi, energy, info
237+
end
238+
time0 = time()
239+
end
240+
return
241+
end
242+
243+
function MPSKit.time_evolve(it::TimeEvolver{<:NeighbourUpdate}; check_interval::Int = 50)
167244
time_start = time0 = time()
168245
@info "--- Time evolution (neighbourhood tensor update), dt = $(it.dt) ---"
169246
info0 = nothing
@@ -197,11 +274,10 @@ end
197274
check_interval::Int = 10
198275
) -> (psi, info)
199276
200-
Perform time evolution on the initial state `psi0` and initial environment `env0`
201-
with Hamiltonian `H`, using `NeighbourUpdate` algorithm `alg`, time step `dt` for
277+
Perform time evolution on the initial state `psi0` with Hamiltonian `H`,
278+
using `NeighbourUpdate` algorithm `alg`, time step `dt` for
202279
`nstep` number of steps.
203280
204-
- Convergence check for ground state search is not supported.
205281
- Set `symmetrize_gates = true` for second-order Trotter decomposition.
206282
- Use `t0` to specify the initial time of `psi0`.
207283
- `check_interval` sets the interval to output information (and check convergence).

0 commit comments

Comments
 (0)