Skip to content

Commit c9432b0

Browse files
committed
add more dyadic trees
1 parent e12b389 commit c9432b0

3 files changed

Lines changed: 90 additions & 14 deletions

File tree

src/Trees/clustertree.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ during the tree construction.
122122
"""
123123
function ClusterTree{D}(elements,splitter=CardinalitySplitter();copy_elements=true, threads=false) where {D}
124124
copy_elements && (elements = deepcopy(elements))
125-
if splitter isa DyadicSplitter
125+
if splitter isa DyadicSplitter || splitter isa DyadicMinimalSplitter || splitter isa DyadicMaxDepthSplitter
126126
# make a cube for bounding box for quad/oct trees
127127
bbox = HyperRectangle(elements,true)
128128
else
@@ -143,16 +143,16 @@ function ClusterTree{D}(elements,splitter=CardinalitySplitter();copy_elements=tr
143143
end
144144
ClusterTree(args...;kwargs...) = ClusterTree{Nothing}(args...;kwargs...)
145145

146-
function _build_cluster_tree!(current_node,splitter,threads)
147-
if should_split(current_node,splitter)
146+
function _build_cluster_tree!(current_node,splitter,threads,depth=0)
147+
if should_split(current_node,depth,splitter)
148148
split!(current_node,splitter)
149149
if threads
150150
Threads.@threads for child in children(current_node)
151-
_build_cluster_tree!(child,splitter,threads)
151+
_build_cluster_tree!(child,splitter,threads,depth+1)
152152
end
153153
else
154154
for child in children(current_node)
155-
_build_cluster_tree!(child,splitter,threads)
155+
_build_cluster_tree!(child,splitter,threads,depth+1)
156156
end
157157
end
158158
end

src/Trees/splitter.jl

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ See [`GeometricSplitter`](@ref) for an example of an implementation.
1313
abstract type AbstractSplitter end
1414

1515
"""
16-
should_split(clt::ClusterTree,splitter::AbstractSplitter)
16+
should_split(clt::ClusterTree, depth, splitter::AbstractSplitter)
1717
1818
Determine whether or not a `ClusterTree` should be further divided.
1919
"""
20-
function should_split(clt, splitter::AbstractSplitter)
20+
function should_split(clt, depth, splitter::AbstractSplitter)
2121
abstractmethod(splitter)
2222
end
2323

@@ -41,14 +41,15 @@ Used to split an `N` dimensional `ClusterTree` into `2^N` children until at most
4141
## See also: [`AbstractSplitter`](@ref)
4242
"""
4343
Base.@kwdef struct DyadicSplitter <: AbstractSplitter
44-
nmax::Int = typemax(Int)
44+
nmax::Int = 50
45+
keep_empty::Bool = false
4546
end
4647

47-
function should_split(node::ClusterTree, splitter::DyadicSplitter)
48+
function should_split(node::ClusterTree, depth, splitter::DyadicSplitter)
4849
return length(node) > splitter.nmax
4950
end
5051

51-
function split!(parentcluster::ClusterTree, ::DyadicSplitter)
52+
function split!(parentcluster::ClusterTree, spl::DyadicSplitter)
5253
d = ambient_dimension(parentcluster)
5354
clusters = [parentcluster]
5455
rec = container(parentcluster)
@@ -61,10 +62,73 @@ function split!(parentcluster::ClusterTree, ::DyadicSplitter)
6162
append!(clusters, _binary_split!(clt, i, pos; parentcluster))
6263
end
6364
end
65+
if !spl.keep_empty
66+
iempty = Int[]
67+
for (i,cluster) in enumerate(clusters)
68+
irange = index_range(cluster)
69+
isempty(irange) && (push!(iempty,i))
70+
end
71+
clusters = deleteat!(clusters,iempty)
72+
end
6473
parentcluster.children = clusters
6574
return parentcluster
6675
end
6776

77+
"""
78+
struct DyadicMinimalSplitter <: AbstractSplitter
79+
80+
Similar to [`DiadicSplitter`](@ref), but the boundin boxes are shrank to the
81+
minimal axis-aligned boxes at the end.
82+
83+
## See also: [`AbstractSplitter`](@ref)
84+
"""
85+
Base.@kwdef struct DyadicMinimalSplitter <: AbstractSplitter
86+
nmax::Int = 50
87+
keep_empty::Bool = false
88+
end
89+
90+
should_split(node::ClusterTree, depth, splitter::DyadicMinimalSplitter) = length(node) > splitter.nmax
91+
92+
function split!(parentcluster::ClusterTree, spl::DyadicMinimalSplitter)
93+
# split as a dyadic splitter, then shrink bounding boxes
94+
split!(parentcluster,DyadicSplitter(spl.nmax,spl.keep_empty))
95+
# shrink bounding boxes to minimal one
96+
l2g = loc2glob(parentcluster)
97+
root_els = root_elements(parentcluster)
98+
clusters = children(parentcluster)
99+
for (i,cluster) in enumerate(clusters)
100+
irange = index_range(cluster)
101+
isempty(irange) && continue
102+
els = (root_els[l2g[j]] for j in irange)
103+
cluster.container = HyperRectangle(els,true)
104+
end
105+
return parentcluster
106+
end
107+
108+
"""
109+
struct DyadicMaxDepthSplitter <: AbstractSplitter
110+
111+
Similar to [`DyadicSplitter`](@ref), but splits nodes until a maximum `depth` is reached.
112+
113+
## See also: [`AbstractSplitter`](@ref)
114+
"""
115+
Base.@kwdef struct DyadicMaxDepthSplitter <: AbstractSplitter
116+
depth::Int
117+
keep_empty::Bool = false
118+
end
119+
120+
function should_split(node::ClusterTree, depth, spl::DyadicMaxDepthSplitter)
121+
if !spl.keep_empty
122+
length(node) == 0 && (return false)
123+
end
124+
depth < spl.depth
125+
end
126+
127+
function split!(parentcluster::ClusterTree, spl::DyadicMaxDepthSplitter)
128+
# exactly like the dyadic splitter
129+
split!(parentcluster, DyadicSplitter(-1,spl.keep_empty))
130+
end
131+
68132
"""
69133
struct GeometricSplitter <: AbstractSplitter
70134
@@ -74,7 +138,7 @@ Base.@kwdef struct GeometricSplitter <: AbstractSplitter
74138
nmax::Int = 50
75139
end
76140

77-
should_split(node::ClusterTree, splitter::GeometricSplitter) = length(node) > splitter.nmax
141+
should_split(node::ClusterTree, depth, splitter::GeometricSplitter) = length(node) > splitter.nmax
78142

79143
function split!(cluster::ClusterTree, ::GeometricSplitter)
80144
rec = cluster.container
@@ -93,7 +157,7 @@ Base.@kwdef struct GeometricMinimalSplitter <: AbstractSplitter
93157
nmax::Int = 50
94158
end
95159

96-
should_split(node::ClusterTree, splitter::GeometricMinimalSplitter) = length(node) > splitter.nmax
160+
should_split(node::ClusterTree, depth, splitter::GeometricMinimalSplitter) = length(node) > splitter.nmax
97161

98162
function split!(cluster::ClusterTree, ::GeometricMinimalSplitter)
99163
rec = cluster.container
@@ -112,7 +176,7 @@ Base.@kwdef struct PrincipalComponentSplitter <: AbstractSplitter
112176
nmax::Int = 50
113177
end
114178

115-
should_split(node::ClusterTree, splitter::PrincipalComponentSplitter) = length(node) > splitter.nmax
179+
should_split(node::ClusterTree, depth, splitter::PrincipalComponentSplitter) = length(node) > splitter.nmax
116180

117181
function split!(cluster::ClusterTree, ::PrincipalComponentSplitter)
118182
pts = cluster._elements
@@ -157,7 +221,7 @@ Base.@kwdef struct CardinalitySplitter <: AbstractSplitter
157221
nmax::Int = 50
158222
end
159223

160-
should_split(node::ClusterTree, splitter::CardinalitySplitter) = length(node) > splitter.nmax
224+
should_split(node::ClusterTree, depth, splitter::CardinalitySplitter) = length(node) > splitter.nmax
161225

162226
function split!(cluster::ClusterTree, ::CardinalitySplitter)
163227
points = cluster._elements

test/Trees/clustertree_test.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ end
3838
splitter = WPB.DyadicSplitter(nmax=1)
3939
clt = WPB.ClusterTree(points,splitter)
4040
@test sortperm(points) == clt.loc2glob
41+
splitter = WPB.DyadicMinimalSplitter(nmax=1)
42+
clt = WPB.ClusterTree(points,splitter)
43+
@test sortperm(points) == clt.loc2glob
4144
end
4245

4346
@testset "2d" begin
@@ -57,6 +60,9 @@ end
5760
splitter = WPB.DyadicSplitter(nmax=32)
5861
clt = WPB.ClusterTree(points,splitter)
5962
@test test_cluster_tree(clt)
63+
splitter = WPB.DyadicMinimalSplitter(nmax=1)
64+
clt = WPB.ClusterTree(points,splitter)
65+
@test test_cluster_tree(clt)
6066
end
6167

6268
@testset "3d" begin
@@ -76,6 +82,9 @@ end
7682
splitter = WPB.DyadicSplitter(nmax=32)
7783
clt = WPB.ClusterTree(points,splitter)
7884
@test test_cluster_tree(clt)
85+
splitter = WPB.DyadicMinimalSplitter(nmax=1)
86+
clt = WPB.ClusterTree(points,splitter)
87+
@test test_cluster_tree(clt)
7988
end
8089

8190
@testset "3d + threads" begin
@@ -96,5 +105,8 @@ end
96105
splitter = WPB.DyadicSplitter(nmax=32)
97106
clt = WPB.ClusterTree(points,splitter;threads)
98107
@test test_cluster_tree(clt)
108+
splitter = WPB.DyadicMinimalSplitter(nmax=32)
109+
clt = WPB.ClusterTree(points,splitter;threads)
110+
@test test_cluster_tree(clt)
99111
end
100112
end

0 commit comments

Comments
 (0)