Skip to content

Commit c04a866

Browse files
committed
improve performance of ClusterTree constructor
- use a buffer to store the intermediate permutations
- only permute the elements at the end of the tree construction
1 parent 7683739 commit c04a866

3 files changed

Lines changed: 56 additions & 29 deletions

File tree

benchmarks/clustertree_bench.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import WavePropBase as WPB
2+
using StaticArrays
3+
using BenchmarkTools
4+
5+
n = 1_000_000
6+
7+
pts = rand(WPB.Point3D,n)
8+
splitter = WPB.GeometricSplitter(nmax=100)
9+
@btime WPB.ClusterTree($pts,$splitter;threads=false)

src/Trees/clustertree.jl

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@ mutable struct ClusterTree{T,S,D} <: AbstractTree
2222
container::S
2323
index_range::UnitRange{Int}
2424
loc2glob::Vector{Int}
25+
buffer::Vector{Int}
2526
children::Vector{ClusterTree{T,S,D}}
2627
parent::ClusterTree{T,S,D}
2728
data::D
2829
# inner constructors handling missing fields.
29-
function ClusterTree{D}(els::T,container::S,loc_idxs,loc2glob,children,parent,data=nothing) where {T,S,D}
30-
clt = new{T,S,D}(els,container,loc_idxs,loc2glob)
30+
function ClusterTree{D}(els::T,container::S,loc_idxs,loc2glob,buffer,children,parent,data=nothing) where {T,S,D}
31+
clt = new{T,S,D}(els,container,loc_idxs,loc2glob,buffer)
3132
clt.children = isnothing(children) ? Vector{typeof(clt)}() : children
3233
clt.parent = isnothing(parent) ? clt : parent
3334
clt.data = isnothing(data) ? D() : data
@@ -76,6 +77,8 @@ the (global) indexes used upon the construction of the tree.
7677
"""
7778
loc2glob(clt::ClusterTree) = clt.loc2glob
7879

80+
buffer(clt::ClusterTree) = clt.buffer
81+
7982
"""
8083
container_type(clt::ClusterTree)
8184
@@ -126,27 +129,29 @@ function ClusterTree{D}(elements,splitter=CardinalitySplitter();copy_elements=tr
126129
bbox = HyperRectangle(elements)
127130
end
128131
n = length(elements)
129-
irange = 1:n
132+
irange = 1:n
130133
loc2glob = collect(irange)
134+
buffer = collect(irange)
131135
children = nothing
132136
parent = nothing
133137
#build the root, then recurse
134-
root = ClusterTree{D}(elements,bbox,irange,loc2glob,children,parent)
138+
root = ClusterTree{D}(elements,bbox,irange,loc2glob,buffer,children,parent)
135139
_build_cluster_tree!(root,splitter,threads)
140+
# finally, permute the elements so as to use the local indexing
141+
copy!(elements,elements[loc2glob]) # faster than permute!
136142
return root
137143
end
138144
ClusterTree(args...;kwargs...) = ClusterTree{Nothing}(args...;kwargs...)
139145

140146
function _build_cluster_tree!(current_node,splitter,threads)
141147
if should_split(current_node,splitter)
142-
children = split!(current_node,splitter)
143-
current_node.children = children
148+
split!(current_node,splitter)
144149
if threads
145-
Threads.@threads for child in children
150+
Threads.@threads for child in children(current_node)
146151
_build_cluster_tree!(child,splitter,threads)
147152
end
148153
else
149-
for child in children
154+
for child in children(current_node)
150155
_build_cluster_tree!(child,splitter,threads)
151156
end
152157
end
@@ -180,18 +185,18 @@ function _common_binary_split!(cluster::ClusterTree{T,S,D},conditions;
180185
parentcluster) where {T,S,D}
181186
els = root_elements(cluster)
182187
l2g = loc2glob(cluster)
188+
buf = buffer(cluster)
183189
irange = index_range(cluster)
184190
# get split data
185191
npts_left,npts_right,left_rec,right_rec,buff = binary_split_data(cluster,conditions)
186192
@assert npts_left + npts_right == length(irange) "elements lost during split"
187-
l2g[irange] = l2g[buff]
188-
els[irange] = els[buff] # reorders the global index set
193+
copy!(view(l2g,irange),buff)
189194
# new ranges for cluster
190195
left_indices = irange.start:(irange.start)+npts_left-1
191196
right_indices = (irange.start+npts_left):irange.stop
192197
# create children
193-
clt1 = ClusterTree{D}(els,left_rec,left_indices,l2g,nothing,parentcluster)
194-
clt2 = ClusterTree{D}(els,right_rec,right_indices,l2g,nothing,parentcluster)
198+
clt1 = ClusterTree{D}(els,left_rec,left_indices,l2g,buf,nothing,parentcluster)
199+
clt2 = ClusterTree{D}(els,right_rec,right_indices,l2g,buf,nothing,parentcluster)
195200
return clt1, clt2
196201
end
197202
function binary_split_data(cluster::ClusterTree{T,S},conditions::Function) where {T,S}
@@ -200,23 +205,24 @@ function binary_split_data(cluster::ClusterTree{T,S},conditions::Function) where
200205
els = root_elements(cluster)
201206
irange = index_range(cluster)
202207
n = length(irange)
208+
buff = view(cluster.buffer,irange)
209+
l2g = loc2glob(cluster)
203210
npts_left = 0
204211
npts_right = 0
205-
buff = Vector{Int}(undef,length(cluster))
206212
xl_left = xl_right = high_corner(rec)
207213
xu_left = xu_right = low_corner(rec)
208214
#sort the points into left and right rectangle
209215
for i in irange
210-
pt = els[i] |> center
216+
pt = els[l2g[i]] |> center
211217
if f(pt)
212218
xl_left = min.(xl_left,pt)
213219
xu_left = max.(xu_left,pt)
214220
npts_left += 1
215-
buff[npts_left] = i
221+
buff[npts_left] = l2g[i]
216222
else
217223
xl_right = min.(xl_right,pt)
218224
xu_right = max.(xu_right,pt)
219-
buff[n-npts_right] = i
225+
buff[n-npts_right] = l2g[i]
220226
npts_right += 1
221227
end
222228
end
@@ -231,19 +237,21 @@ function binary_split_data(cluster::ClusterTree,conditions::Tuple{Integer,Real})
231237
els = root_elements(cluster)
232238
irange = index_range(cluster)
233239
n = length(irange)
240+
l2g = loc2glob(cluster)
234241
npts_left = 0
235242
npts_right = 0
236-
buff = Vector{Int}(undef,length(cluster))
243+
buff = view(cluster.buffer,irange)
237244
# bounding boxes
238245
left_rec, right_rec = split(rec,dir,pos)
239246
#sort the points into left and right rectangle
240247
for i in irange
241-
pt = els[i] |> center
248+
pt = els[l2g[i]] |> center
242249
if pt in left_rec
243250
npts_left += 1
244-
buff[npts_left] = i
251+
buff[npts_left] = l2g[i]
245252
else # pt in right_rec
246-
buff[n-npts_right] = i
253+
# @assert pt in right_rec
254+
buff[n-npts_right] = l2g[i]
247255
npts_right += 1
248256
end
249257
end

src/Trees/splitter.jl

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ end
2424
"""
2525
split!(clt::ClusterTree,splitter::AbstractSplitter)
2626
27-
Divide `clt` using the strategy implemented by `splitter`.
27+
Divide `clt` using the strategy implemented by `splitter`. This function is
28+
responsible for assigning the `children` and `parent` fields, as well as for
29+
permuting the data of `clt`.
2830
"""
2931
function split!(clt, splitter::AbstractSplitter)
3032
abstractmethod(splitter)
@@ -59,7 +61,8 @@ function split!(parentcluster::ClusterTree, ::DyadicSplitter)
5961
append!(clusters, _binary_split!(clt, i, pos; parentcluster))
6062
end
6163
end
62-
return clusters
64+
parentcluster.children = clusters
65+
return parentcluster
6366
end
6467

6568
"""
@@ -77,7 +80,8 @@ function split!(cluster::ClusterTree, ::GeometricSplitter)
7780
rec = cluster.container
7881
wmax, imax = findmax(high_corner(rec) - low_corner(rec))
7982
left_node, right_node = _binary_split!(cluster, imax, low_corner(rec)[imax] + wmax / 2)
80-
return [left_node, right_node]
83+
cluster.children = [left_node,right_node]
84+
return cluster
8185
end
8286

8387
"""
@@ -97,7 +101,8 @@ function split!(cluster::ClusterTree, ::GeometricMinimalSplitter)
97101
mid = low_corner(rec)[imax] + wmax / 2
98102
predicate = (x) -> x[imax] < mid
99103
left_node, right_node = _binary_split!(cluster, predicate)
100-
return [left_node, right_node]
104+
cluster.children = [left_node,right_node]
105+
return cluster
101106
end
102107

103108
"""
@@ -114,25 +119,28 @@ function split!(cluster::ClusterTree, ::PrincipalComponentSplitter)
114119
irange = cluster.index_range
115120
xc = center_of_mass(cluster)
116121
# compute covariance matrix for principal direction
122+
l2g = loc2glob(cluster)
117123
cov = sum(irange) do i
118-
x = coords(pts[i])
124+
x = coords(pts[l2g[i]])
119125
(x - xc) * transpose(x - xc)
120126
end
121127
v = eigvecs(cov)[:, end]
122128
predicate = (x) -> dot(x - xc, v) < 0
123129
left_node, right_node = _binary_split!(cluster, predicate)
124-
return [left_node, right_node]
130+
cluster.children = [left_node,right_node]
131+
return cluster
125132
end
126133

127134
function center_of_mass(clt::ClusterTree)
128135
pts = clt._elements
129136
loc_idxs = clt.index_range
137+
l2g = loc2glob(clt)
130138
# w = clt.weights
131139
n = length(loc_idxs)
132140
# M = isempty(w) ? n : sum(i->w[i],glob_idxs)
133141
# xc = isempty(w) ? sum(i->pts[i]/M,glob_idxs) : sum(i->w[i]*pts[i]/M,glob_idxs)
134142
M = n
135-
xc = sum(i -> coords(pts[i]) / M, loc_idxs)
143+
xc = sum(i -> coords(pts[l2g[i]]) / M, loc_idxs)
136144
return xc
137145
end
138146

@@ -156,8 +164,10 @@ function split!(cluster::ClusterTree, ::CardinalitySplitter)
156164
irange = cluster.index_range
157165
rec = container(cluster)
158166
_, imax = findmax(high_corner(rec) - low_corner(rec))
159-
med = median(coords(points[i])[imax] for i in irange) # the median along largest axis `imax`
167+
l2g = loc2glob(cluster)
168+
med = median(coords(points[l2g[i]])[imax] for i in irange) # the median along largest axis `imax`
160169
predicate = (x) -> x[imax] < med
161170
left_node, right_node = _binary_split!(cluster, predicate)
162-
return [left_node, right_node]
171+
cluster.children = [left_node,right_node]
172+
return cluster
163173
end

0 commit comments

Comments
 (0)