Skip to content

Commit 63fc6e8

Browse files
Alexey Stukalovalyst
authored andcommitted
RAMMatrices ctor: dupl. vars check
1 parent 191be7d commit 63fc6e8

1 file changed

Lines changed: 11 additions & 13 deletions

File tree

src/frontend/specification/RAMMatrices.jl

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ function RAMMatrices(;
7474
"colnames length ($(length(colnames))) does not match the number of columns in A ($ncols)",
7575
),
7676
)
77+
dup_cols = nonunique(colnames)
78+
isempty(dup_cols) ||
79+
throw(ArgumentError("Duplicate variables detected: $(join(dup_cols, ", "))"))
7780
end
7881
size(A, 1) == size(A, 2) || throw(DimensionMismatch("A must be a square matrix"))
7982
size(S, 1) == size(S, 2) || throw(DimensionMismatch("S must be a square matrix"))
@@ -99,6 +102,10 @@ function RAMMatrices(;
99102
),
100103
)
101104
end
105+
dup_params = nonunique(params)
106+
isempty(dup_params) ||
107+
throw(ArgumentError("Duplicate parameters detected: $(join(dup_params, ", "))"))
108+
102109
A = ParamsMatrix{Float64}(A, params)
103110
S = ParamsMatrix{Float64}(S, params)
104111
M = !isnothing(M) ? ParamsVector{Float64}(M, params) : nothing
@@ -118,20 +125,11 @@ function RAMMatrices(
118125
params::Union{AbstractVector{Symbol}, Nothing} = nothing,
119126
)
120127
params = copy(isnothing(params) ? SEM.params(partable) : params)
121-
params_index = Dict(param => i for (i, param) in enumerate(params))
122-
if length(params) != length(params_index)
123-
params_seen = Set{Symbol}()
124-
params_nonunique = Vector{Symbol}()
125-
for par in params
126-
push!(par in params_seen ? params_nonunique : params_seen, par)
127-
end
128-
throw(
129-
ArgumentError(
130-
"Duplicate names in the parameters vector: $(join(params_nonunique, ", "))",
131-
),
132-
)
133-
end
128+
dup_params = nonunique(params)
129+
isempty(dup_params) ||
130+
throw(ArgumentError("Duplicate parameters detected: $(join(dup_params, ", "))"))
134131

132+
params_index = Dict(param => i for (i, param) in enumerate(params))
135133
n_observed = length(partable.variables.observed)
136134
n_latent = length(partable.variables.latent)
137135
n_vars = n_observed + n_latent

0 commit comments

Comments
 (0)