Skip to content

Commit 837a800

Browse files
committed
make ClusterTree more generic
- A `ClusterTree` can now be constructed from `AbstractVector`
1 parent 3507aff commit 837a800

5 files changed

Lines changed: 100 additions & 60 deletions

File tree

src/Trees/abstracttree.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,45 @@ function filter_tree!(f,nodes,tree,isterminal=true)
8686
return nodes
8787
end
8888

89+
"""
90+
partition_by_depth(tree)
91+
92+
Given a `tree`, return a `partition` vector whose `i`-th entry stores all the nodes in
93+
`tree` with `depth=i-1`. Empty nodes are not added to the partition.
94+
"""
95+
function partition_by_depth(tree)
96+
T = eltype(tree)
97+
partition = Vector{Vector{T}}()
98+
depth = 0
99+
_partition_by_depth!(partition,tree,depth)
100+
end
101+
102+
function _partition_by_depth!(partition,tree,depth)
103+
T = eltype(tree)
104+
if length(partition) < depth+1
105+
push!(partition,T[])
106+
end
107+
length(tree) > 0 && push!(partition[depth+1],tree)
108+
for chd in children(tree)
109+
_partition_by_depth!(partition,chd,depth+1)
110+
end
111+
return partition
112+
end
113+
114+
"""
115+
partition_by_height(tree)
116+
117+
Given a `tree`, return a `partition` vector whose `i`-th entry stores all the nodes in
118+
`tree` with `height=i-1`. The `height` of the tree is thus `lenth(partition)`,
119+
with `partition(end)==tree`.
120+
"""
121+
function partition_by_height(tree)
122+
# TODO: how to do this more or less efficiently? One idea is to start at the
123+
# leaves, push them, then push their parents and recurse, making sure call
124+
# `unique!` as you go up in order to avoid duplicate parents.
125+
end
126+
127+
89128
# interface to AbstractTrees. No children is determined by an empty tuple for
90129
# AbstractTrees.
91130
AbstractTrees.children(t::AbstractTree) = isleaf(t) ? () : t.children

src/Trees/clustertree.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""
22
mutable struct ClusterTree{T,S,D}
33
4-
Tree structure used to cluster elements of type `T` into containers of type `S`.
5-
The method `center(::T)::SVector` is required for the clustering algorithms. An
4+
Tree structure used to cluster elements of type `V = eltype(T)` into containers of type `S`.
5+
The method `center(::V)::SVector` is required for the clustering algorithms. An
66
additional `data` field of type `D` can be associated with each node to store
77
node-specific information (it defaults to `D=Nothing`).
88
99
# Fields:
10-
- `_elements::Vector{T}` : vector containing the sorted elements.
10+
- `_elements::T` : vector containing the sorted elements.
1111
- `container::S` : container for the elements in the current node.
1212
- `index_range::UnitRange{Int}` : indices of elements contained in the current node.
1313
- `loc2glob::Vector{Int}` : permutation from the local indexing system to the
@@ -17,15 +17,15 @@ original (global) indexing system used as input in the construction of the tree.
1717
- `data::D` : generic data field of type `D`.
1818
"""
1919
mutable struct ClusterTree{T,S,D} <: AbstractTree
20-
_elements::Vector{T}
20+
_elements::T
2121
container::S
2222
index_range::UnitRange{Int}
2323
loc2glob::Vector{Int}
2424
children::Vector{ClusterTree{T,S,D}}
2525
parent::ClusterTree{T,S,D}
2626
data::D
2727
# inner constructors handling missing fields.
28-
function ClusterTree{D}(els::Vector{T},container::S,loc_idxs,loc2glob,children,parent,data=nothing) where {T,S,D}
28+
function ClusterTree{D}(els::T,container::S,loc_idxs,loc2glob,children,parent,data=nothing) where {T,S,D}
2929
clt = new{T,S,D}(els,container,loc_idxs,loc2glob)
3030
clt.children = isnothing(children) ? Vector{typeof(clt)}() : children
3131
clt.parent = isnothing(parent) ? clt : parent
@@ -87,7 +87,7 @@ container_type(::ClusterTree{T,S}) where {T,S} = S
8787
8888
Type of elements sorted in `clt`.
8989
"""
90-
element_type(::ClusterTree{T}) where {T} = T
90+
element_type(::ClusterTree{T}) where {T} = eltype(T)
9191

9292
isleaf(clt::ClusterTree) = isempty(clt.children)
9393
isroot(clt::ClusterTree) = clt.parent == clt

src/Trees/splitter.jl

Lines changed: 51 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ abstract type AbstractSplitter end
1717
1818
Determine whether or not a `ClusterTree` should be further divided.
1919
"""
20-
function should_split(clt,splitter::AbstractSplitter)
20+
function should_split(clt, splitter::AbstractSplitter)
2121
abstract_method(splitter)
2222
end
2323

@@ -26,7 +26,7 @@ end
2626
2727
Divide `clt` using the strategy implemented by `splitter`.
2828
"""
29-
function split!(clt,splitter::AbstractSplitter)
29+
function split!(clt, splitter::AbstractSplitter)
3030
abstract_method(splitter)
3131
end
3232

@@ -37,24 +37,24 @@ Used to split an `N` dimensional `ClusterTree` into `2^N` children until at most
3737
`nmax` points are contained in node.
3838
"""
3939
Base.@kwdef struct DyadicSplitter <: AbstractSplitter
40-
nmax::Int=typemax(Int)
40+
nmax::Int = typemax(Int)
4141
end
4242

43-
function should_split(node::ClusterTree,splitter::DyadicSplitter)
43+
function should_split(node::ClusterTree, splitter::DyadicSplitter)
4444
return length(node) > splitter.nmax
4545
end
4646

47-
function split!(parentcluster::ClusterTree,::DyadicSplitter)
48-
d = ambient_dimension(parentcluster)
47+
function split!(parentcluster::ClusterTree, ::DyadicSplitter)
48+
d = ambient_dimension(parentcluster)
4949
clusters = [parentcluster]
5050
rec = container(parentcluster)
5151
rec_center = center(rec)
52-
for i in 1:d
52+
for i = 1:d
5353
pos = rec_center[i]
5454
nel = length(clusters) #2^(i-1)
55-
for _ in 1:nel
55+
for _ = 1:nel
5656
clt = popfirst!(clusters)
57-
append!(clusters,_binary_split!(clt,i,pos;parentcluster))
57+
append!(clusters, _binary_split!(clt, i, pos; parentcluster))
5858
end
5959
end
6060
return clusters
@@ -65,16 +65,16 @@ end
6565
6666
Used to split a `ClusterTree` in half along the largest axis.
6767
"""
68-
@Base.kwdef struct GeometricSplitter <: AbstractSplitter
69-
nmax::Int=50
68+
Base.@kwdef struct GeometricSplitter <: AbstractSplitter
69+
nmax::Int = 50
7070
end
7171

72-
should_split(node::ClusterTree,splitter::GeometricSplitter) = length(node) > splitter.nmax
72+
should_split(node::ClusterTree, splitter::GeometricSplitter) = length(node) > splitter.nmax
7373

74-
function split!(cluster::ClusterTree,::GeometricSplitter)
75-
rec = cluster.container
76-
wmax, imax = findmax(high_corner(rec) - low_corner(rec))
77-
left_node, right_node = _binary_split!(cluster, imax, low_corner(rec)[imax]+wmax/2)
74+
function split!(cluster::ClusterTree, ::GeometricSplitter)
75+
rec = cluster.container
76+
wmax, imax = findmax(high_corner(rec) - low_corner(rec))
77+
left_node, right_node = _binary_split!(cluster, imax, low_corner(rec)[imax] + wmax / 2)
7878
return [left_node, right_node]
7979
end
8080

@@ -83,54 +83,54 @@ end
8383
8484
Like [`GeometricSplitter`](@ref), but shrinks the children's containters.
8585
"""
86-
@Base.kwdef struct GeometricMinimalSplitter <: AbstractSplitter
87-
nmax::Int=50
86+
Base.@kwdef struct GeometricMinimalSplitter <: AbstractSplitter
87+
nmax::Int = 50
8888
end
8989

90-
should_split(node::ClusterTree,splitter::GeometricMinimalSplitter) = length(node) > splitter.nmax
90+
should_split(node::ClusterTree, splitter::GeometricMinimalSplitter) = length(node) > splitter.nmax
9191

92-
function split!(cluster::ClusterTree,::GeometricMinimalSplitter)
93-
rec = cluster.container
94-
wmax, imax = findmax(high_corner(rec) - low_corner(rec))
95-
mid = low_corner(rec)[imax]+wmax/2
92+
function split!(cluster::ClusterTree, ::GeometricMinimalSplitter)
93+
rec = cluster.container
94+
wmax, imax = findmax(high_corner(rec) - low_corner(rec))
95+
mid = low_corner(rec)[imax] + wmax / 2
9696
predicate = (x) -> x[imax] < mid
97-
left_node,right_node = _binary_split!(cluster,predicate)
97+
left_node, right_node = _binary_split!(cluster, predicate)
9898
return [left_node, right_node]
9999
end
100100

101101
"""
102102
struct PrincipalComponentSplitter <: AbstractSplitter
103103
"""
104-
@Base.kwdef struct PrincipalComponentSplitter <: AbstractSplitter
105-
nmax::Int=50
104+
Base.@kwdef struct PrincipalComponentSplitter <: AbstractSplitter
105+
nmax::Int = 50
106106
end
107107

108-
should_split(node::ClusterTree,splitter::PrincipalComponentSplitter) = length(node) > splitter.nmax
108+
should_split(node::ClusterTree, splitter::PrincipalComponentSplitter) = length(node) > splitter.nmax
109109

110-
function split!(cluster::ClusterTree,::PrincipalComponentSplitter)
111-
pts = cluster._elements
112-
irange = cluster.index_range
113-
xc = center_of_mass(cluster)
110+
function split!(cluster::ClusterTree, ::PrincipalComponentSplitter)
111+
pts = cluster._elements
112+
irange = cluster.index_range
113+
xc = center_of_mass(cluster)
114114
# compute covariance matrix for principal direction
115-
cov = sum(irange) do i
115+
cov = sum(irange) do i
116116
x = coords(pts[i])
117-
(x - xc)*transpose(x - xc)
117+
(x - xc) * transpose(x - xc)
118118
end
119-
v = eigvecs(cov)[:,end]
120-
predicate = (x) -> dot(x-xc,v) < 0
121-
left_node, right_node = _binary_split!(cluster,predicate)
119+
v = eigvecs(cov)[:, end]
120+
predicate = (x) -> dot(x - xc, v) < 0
121+
left_node, right_node = _binary_split!(cluster, predicate)
122122
return [left_node, right_node]
123123
end
124124

125125
function center_of_mass(clt::ClusterTree)
126-
pts = clt._elements
127-
loc_idxs = clt.index_range
126+
pts = clt._elements
127+
loc_idxs = clt.index_range
128128
# w = clt.weights
129-
n = length(loc_idxs)
129+
n = length(loc_idxs)
130130
# M = isempty(w) ? n : sum(i->w[i],glob_idxs)
131131
# xc = isempty(w) ? sum(i->pts[i]/M,glob_idxs) : sum(i->w[i]*pts[i]/M,glob_idxs)
132-
M = n
133-
xc = sum(i->coords(pts[i])/M,loc_idxs)
132+
M = n
133+
xc = sum(i -> coords(pts[i]) / M, loc_idxs)
134134
return xc
135135
end
136136

@@ -141,19 +141,19 @@ Used to split a `ClusterTree` along the largest dimension if
141141
`length(tree)>nmax`. The split is performed so the `data` is evenly distributed
142142
amongst all children.
143143
"""
144-
@Base.kwdef struct CardinalitySplitter <: AbstractSplitter
145-
nmax::Int=50
144+
Base.@kwdef struct CardinalitySplitter <: AbstractSplitter
145+
nmax::Int = 50
146146
end
147147

148-
should_split(node::ClusterTree,splitter::CardinalitySplitter) = length(node) > splitter.nmax
148+
should_split(node::ClusterTree, splitter::CardinalitySplitter) = length(node) > splitter.nmax
149149

150-
function split!(cluster::ClusterTree,::CardinalitySplitter)
151-
points = cluster._elements
152-
irange = cluster.index_range
153-
rec = container(cluster)
154-
_, imax = findmax(high_corner(rec) - low_corner(rec))
155-
med = median(coords(points[i])[imax] for i in irange) # the median along largest axis `imax`
150+
function split!(cluster::ClusterTree, ::CardinalitySplitter)
151+
points = cluster._elements
152+
irange = cluster.index_range
153+
rec = container(cluster)
154+
_, imax = findmax(high_corner(rec) - low_corner(rec))
155+
med = median(coords(points[i])[imax] for i in irange) # the median along largest axis `imax`
156156
predicate = (x) -> x[imax] < med
157-
left_node, right_node = _binary_split!(cluster,predicate)
157+
left_node, right_node = _binary_split!(cluster, predicate)
158158
return [left_node, right_node]
159159
end

src/Utils/Utils.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,9 @@ that `0 ≤ θ ≤ π` and ` -π < φ ≤ π`.
244244
"""
245245
function cart2sph(x,y,z)
246246
azimuth = atan(y,x)
247-
elevation = atan(sqrt(x^2 + y^2),z)
248-
r = sqrt(x^2 + y^2 + z^2)
247+
a = x^2 + y^2
248+
elevation = atan(sqrt(a),z)
249+
r = sqrt(a + z^2)
249250
return r, elevation, azimuth
250251
end
251252

test/Mesh/cartesianmesh_test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ end
2929
@testset "Two dimensions" begin
3030
l = HyperRectangle((0.0,0.0),(1.0,1.0))
3131
E = typeof(l) # type of mesh element
32-
mesh = UniformCartesianMesh(domain=l,sz=(10,20))
32+
mesh = UniformCartesianMesh(l,(10,20))
3333
iter = ElementIterator(mesh,E)
3434
@test iter[1,1] HyperRectangle((0,0),(0.1,0.05))
3535
@test length(iter) == 200

0 commit comments

Comments
 (0)