Skip to content

Commit e322666

Browse files
committed
finish arg checks
1 parent 9d5767a commit e322666

28 files changed

Lines changed: 1382 additions & 146 deletions

src/alternative_geometries/CircularDDM.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ struct CDDM{T <: Real} <: AbstractCDDM
4646
η::Vector{T}
4747
α::T
4848
τ::T
49+
50+
function CDDM::Vector{T}, σ::T, η::Vector{T}, α::T, τ::T) where {T}
51+
@argcheck σ 0
52+
@argcheck all.≥ 0)
53+
@argcheck α 0
54+
@argcheck τ 0
55+
return new{T}(ν, σ, η, α, τ)
56+
end
4957
end
5058

5159
function CDDM(ν, σ, η, α, τ)
@@ -55,9 +63,7 @@ function CDDM(ν, σ, η, α, τ)
5563
return CDDM(ν, σ, η, α, τ)
5664
end
5765

58-
function params(d::AbstractCDDM)
59-
return (d.ν, d.σ, d.η, d.α, d.τ)
60-
end
66+
params(d::AbstractCDDM) = (d.ν, d.σ, d.η, d.α, d.τ)
6167

6268
function CDDM(; ν = [1, 0.5], η = [1, 1], σ = 1, α = 1.5, τ = 0.30)
6369
return CDDM(ν, σ, η, α, τ)
@@ -94,6 +100,7 @@ end
94100

95101
function logpdf(d::AbstractCDDM, data::Vector{<:Real}; k_max = 50)
96102
θ, rt = data
103+
@argcheck d.τ rt
97104
return logpdf_term1(d, θ, rt) + logpdf_term2(d, rt; k_max)
98105
end
99106

@@ -111,11 +118,13 @@ end
111118

112119
function pdf(d::AbstractCDDM, data::Vector{<:Real}; k_max = 50)
113120
θ, rt = data
121+
@argcheck d.τ rt
114122
return max(0.0, pdf_term1(d, θ, rt) * pdf_term2(d, rt; k_max))
115123
end
116124

117125
function pdf(d::AbstractCDDM, data::Vector{<:Real}, j0, j01, j02; k_max = 50)
118126
θ, rt = data
127+
@argcheck d.τ rt
119128
return max(0.0, pdf_term1(d, θ, rt) * pdf_term2(d, rt, j0, j01, j02; k_max))
120129
end
121130

src/multi_choice_models/ClassicMDFT.jl

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,25 @@ A model type for Multiattribute Decision Field Theory.
66
77
# Parameters
88
- `σ::T = 1.0`: diffusion noise. σ ∈ ℝ⁺.
9-
- `α::T = 15.0`: evidence threshold. α ∈ ℝ⁺.
10-
- `τ::T = .30`: non-decision time. τ ∈ [0, min_rt].
11-
- `w::Vector{T}`: attention weights vector where each element corresponds to the attention given to the corresponding dimension. wᵢ ∈ [0,1], ∑wᵢ = 1.
9+
- `C::Array{T, 2}`: contrast weight matrix where c_ij is the contrast weight when comparing options i and j.
1210
- `S::Array{T, 2}`: feedback matrix allowing self-connections and interconnections between alternatives. Self-connections range from zero to 1, where s_ij < 1 represents decay. Interconnections
1311
between options i and j where i ≠ j are inhibatory if s_ij < 0.
14-
- `C::Array{T, 2}`: contrast weight matrix where c_ij is the contrast weight when comparing options i and j.
12+
- `w::Vector{T}`: attention weights vector where each element corresponds to the attention given to the corresponding dimension. wᵢ ∈ [0,1], ∑wᵢ = 1.
13+
- `α::T = 15.0`: evidence threshold. α ∈ ℝ⁺.
14+
- `τ::T = .30`: non-decision time. τ ∈ [0, min_rt].
1515
1616
# Constructors
1717
18-
ClassicMDFT(σ, α, τ, w, S, C)
18+
ClassicMDFT(σ, C, S, w, α, τ)
1919
20-
ClassicMDFT(σ, α, τ, w, S, C = make_default_contrast(S))
20+
ClassicMDFT(;
21+
σ = 1.0,
22+
α = 15.0,
23+
τ = 10.0,
24+
w,
25+
S,
26+
C = make_default_contrast(size(S, 1))
27+
)
2128
2229
# Example
2330
@@ -34,12 +41,8 @@ M = [
3441
]
3542
3643
model = ClassicMDFT(;
37-
# non-decision time
38-
τ = 0.300,
3944
# diffusion noise
4045
σ = 1.0,
41-
# decision threshold
42-
α = 17.5,
4346
# attribute attention weights
4447
w = [0.5, 0.5],
4548
# feedback matrix
@@ -48,6 +51,10 @@ model = ClassicMDFT(;
4851
-0.0122316 0.9500000 -0.00903030
4952
-0.0499996 -0.0090303 0.95000000
5053
],
54+
# decision threshold
55+
α = 17.5,
56+
# non-decision time
57+
τ = 0.300,
5158
)
5259
choices, rts = rand(model, 10_000, M; Δt = 1.0)
5360
map(c -> mean(choices .== c), 1:3)
@@ -58,19 +65,33 @@ Roe, Robert M., Jermone R. Busemeyer, and James T. Townsend. "Multiattribute Dec
5865
"""
5966
mutable struct ClassicMDFT{T <: Real} <: AbstractMDFT
6067
σ::T
68+
C::Array{T, 2}
69+
S::Array{T, 2}
70+
w::Vector{T}
6171
α::T
6272
τ::T
63-
w::Vector{T}
64-
S::Array{T, 2}
65-
C::Array{T, 2}
73+
function ClassicMDFT(
74+
σ::T,
75+
C::Array{T, 2},
76+
S::Array{T, 2},
77+
w::Vector{T},
78+
α::T,
79+
τ::T
80+
) where {T <: Real}
81+
@argcheck σ 0
82+
@argcheck α 0
83+
@argcheck τ 0
84+
@argcheck all(w .≥ 0) && (sum(w) == 1)
85+
return new{T}(σ, C, S, w, α, τ)
86+
end
6687
end
6788

68-
function ClassicMDFT(σ, α, τ, w, S, C)
69-
σ, α, τ, _, _, _ = promote(σ, α, τ, w[1], S[1], C[1])
70-
w = convert(Vector{typeof(τ)}, w)
71-
S = convert(Array{typeof(τ), 2}, S)
89+
function ClassicMDFT(σ, C, S, w, α, τ)
90+
σ, _, _, _, α, τ, = promote(σ, C[1], S[1], w[1], α, τ)
7291
C = convert(Array{typeof(τ), 2}, C)
73-
return ClassicMDFT(σ, α, τ, w, S, C)
92+
S = convert(Array{typeof(τ), 2}, S)
93+
w = convert(Vector{typeof(τ)}, w)
94+
return ClassicMDFT(σ, C, S, w, α, τ)
7495
end
7596

7697
function ClassicMDFT(;
@@ -81,14 +102,12 @@ function ClassicMDFT(;
81102
S,
82103
C = make_default_contrast(size(S, 1))
83104
)
84-
return ClassicMDFT(σ, α, τ, w, S, C)
105+
return ClassicMDFT(σ, C, S, w, α, τ)
85106
end
86107

87-
get_pdf_type(d::AbstractMDFT) = Approximate
108+
get_pdf_type(::AbstractMDFT) = Approximate
88109

89-
function params(d::ClassicMDFT)
90-
return (d.σ, d.α, d.τ, d.w, d.S, d.C)
91-
end
110+
params(d::ClassicMDFT) = (d.σ, d.C, d.S, d.w, d.α, d.τ)
92111

93112
"""
94113
rand(

src/multi_choice_models/DDM.jl

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,19 @@ mutable struct DDM{T <: Real} <: AbstractDDM
3939
α::T
4040
z::T
4141
τ::T
42+
function DDM::T, α::T, z::T, τ::T) where {T <: Real}
43+
@argcheck α 0
44+
@argcheck (z 0) && (z 1)
45+
@argcheck τ 0
46+
return new{T}(ν, α, z, τ)
47+
end
4248
end
4349

4450
function DDM(ν, α, z, τ)
4551
return DDM(promote(ν, α, z, τ)...)
4652
end
4753

48-
function params(d::DDM)
49-
return (d.ν, d.α, d.z, d.τ)
50-
end
54+
params(d::DDM) = (d.ν, d.α, d.z, d.τ)
5155

5256
function DDM(; ν = 1.00, α = 0.80, τ = 0.30, z = 0.50)
5357
return DDM(ν, α, z, τ)
@@ -64,7 +68,8 @@ end
6468
# Wabersich & Vandekerckhove (2014) #
6569
#####################################
6670

67-
function pdf(d::DDM, choice, rt; ϵ::Real = 1.0e-12)
71+
function pdf(d::DDM{T}, choice, rt; ϵ::Real = 1.0e-12) where {T <: Real}
72+
@argcheck d.τ < rt
6873
if choice == 1
6974
(ν, α, z, τ) = params(d)
7075
return _pdf(DDM(-ν, α, 1 - z, τ), rt; ϵ)
@@ -73,11 +78,8 @@ function pdf(d::DDM, choice, rt; ϵ::Real = 1.0e-12)
7378
end
7479

7580
# probability density function over the lower boundary
76-
function _pdf(d::DDM{T}, t::Real; ϵ::Real = 1.0e-12) where {T <: Real}
81+
function _pdf(d::DDM, t::Real; ϵ::Real = 1.0e-12)
7782
(ν, α, z, τ) = params(d)
78-
if τ t
79-
return T(NaN)
80-
end
8183
u = (t - τ) / α^2 #use normalized time
8284

8385
K_s = 2.0
@@ -97,7 +99,6 @@ function _pdf(d::DDM{T}, t::Real; ϵ::Real = 1.0e-12) where {T <: Real}
9799
if K_s < K_l
98100
return p * _small_time_pdf(u, z, ceil(Int, K_s))
99101
end
100-
101102
return p * _large_time_pdf(u, z, ceil(Int, K_l))
102103
end
103104

@@ -106,7 +107,7 @@ function _small_time_pdf(u::T, z::T, K::Int) where {T <: Real}
106107
inf_sum = zero(T)
107108

108109
k_series = (-floor(Int, 0.5 * (K - 1))):ceil(Int, 0.5 * (K - 1))
109-
for k in k_series
110+
for k k_series
110111
inf_sum += ((2k + z) * exp(-((2k + z)^2 / (2u))))
111112
end
112113

@@ -117,15 +118,14 @@ end
117118
function _large_time_pdf(u::T, z::T, K::Int) where {T <: Real}
118119
inf_sum = zero(T)
119120

120-
for k = 1:K
121+
for k 1:K
121122
inf_sum += (k * exp(-0.5 * (k^2 * π^2 * u)) * sin(k * π * z))
122123
end
123124

124125
return π * inf_sum
125126
end
126127

127128
logpdf(d::DDM, choice, rt; ϵ::Real = 1.0e-12) = log(pdf(d, choice, rt; ϵ))
128-
#logpdf(d::DDM, t::Real; ϵ::Real = 1.0e-12) = log(pdf(d, t; ϵ))
129129

130130
logpdf(d::DDM, data::Tuple) = logpdf(d, data...)
131131

@@ -134,7 +134,8 @@ logpdf(d::DDM, data::Tuple) = logpdf(d, data...)
134134
# Blurton, Kesselmeier, & Gondan (2012) #
135135
#########################################
136136

137-
function cdf(d::DDM, choice::Int, rt::Real = 10; ϵ::Real = 1.0e-12)
137+
function cdf(d::DDM{T}, choice::Int, rt::Real = 10; ϵ::Real = 1.0e-12) where {T <: Real}
138+
@argcheck d.τ < rt
138139
if choice == 1
139140
(ν, α, z, τ) = params(d)
140141
return _cdf(DDM(-ν, α, 1 - z, τ), rt; ϵ)
@@ -144,11 +145,7 @@ function cdf(d::DDM, choice::Int, rt::Real = 10; ϵ::Real = 1.0e-12)
144145
end
145146

146147
# cumulative density function over the lower boundary
147-
function _cdf(d::DDM{T}, t::Real; ϵ::Real = 1.0e-12) where {T <: Real}
148-
if d.τ t
149-
return T(NaN)
150-
end
151-
148+
function _cdf(d::DDM, t::Real; ϵ::Real = 1.0e-12)
152149
K_l = _K_large(d, t; ϵ)
153150
K_s = _K_small(d, t; ϵ)
154151

@@ -163,7 +160,7 @@ function _Fl_lower(d::DDM{T}, K::Int, t::Real) where {T <: Real}
163160
(ν, α, z, τ) = params(d)
164161
F = zero(T)
165162
K_series = K:-1:1
166-
for k in K_series
163+
for k K_series
167164
F -= (
168165
k /^2 + k^2 * π^2 /^2)) *
169166
exp(-ν * α * z - 0.5 * ν^2 * (t - τ) - 0.5 * k^2 * π^2 /^2) * (t - τ)) *
@@ -186,7 +183,7 @@ function _Fs_lower(d::DDM{T}, K::Int, t::Real) where {T <: Real}
186183
S2 = zero(T)
187184
K_series = K:-1:1
188185

189-
for k in K_series
186+
for k K_series
190187
S1 += (
191188
_exp_pnorm(2 * ν * α * k, -sign(ν) * (2 * α * k + α * z + ν * (t - τ)) / sqt) - _exp_pnorm(
192189
-2 * ν * α * k - 2 * ν * α * z,
@@ -214,11 +211,10 @@ function _Fs_lower(d::DDM{T}, K::Int, t::Real) where {T <: Real}
214211
end
215212

216213
# Zero drift version
217-
function _Fs0_lower(d::DDM{T}, K::Int, t::Real) where {T <: Real}
218-
(_, α, z, τ) = params(d)
214+
function _Fs0_lower(dist::DDM{T}, K::Int, t::Real) where {T <: Real}
215+
(; α, z, τ) = dist
219216
F = zero(T)
220-
K_series = K:-1:0
221-
for k in K_series
217+
for k K:-1:0
222218
F -= (
223219
cdf(Distributions.Normal(), (-2 * k - 2 + z) * α / sqrt(t - τ)) +
224220
cdf(Distributions.Normal(), (-2 * k - z) * α / sqrt(t - τ))

src/multi_choice_models/LBA.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,18 @@ mutable struct LBA{T <: Real, T1 <: Union{<:T, Vector{<:T}}} <: AbstractLBA{T, T
4141
A::T
4242
k::T
4343
τ::T
44-
function LBA::Vector{T}, σ::T1, A::T, k::T, τ::T) where {T <: Real, T1 <: Union{<:T, Vector{<:T}}}
44+
function LBA(
45+
ν::Vector{T},
46+
σ::T1,
47+
A::T,
48+
k::T,
49+
τ::T
50+
) where {T <: Real, T1 <: Union{<:T, Vector{<:T}}}
4551
@argcheck all.≥ 0)
4652
@argcheck A 0
4753
@argcheck k 0
4854
@argcheck τ 0
49-
return new{T,T1}(ν, σ, A, k, τ)
55+
return new{T, T1}(ν, σ, A, k, τ)
5056
end
5157
end
5258

@@ -84,7 +90,6 @@ sample_drift_rates(ν, σ) = sample_drift_rates(Random.default_rng(), ν, σ)
8490
function sample_drift_rates(rng::AbstractRNG, ν, σ)
8591
negative = true
8692
v = similar(ν)
87-
n_options = length(ν)
8893
while negative
8994
v = @. rand(rng, Normal(ν, σ))
9095
negative = any(x -> x > 0, v) ? false : true
@@ -106,8 +111,8 @@ end
106111

107112
function logpdf(d::AbstractLBA{T, T1}, c, rt) where {T, T1 <: Vector{<:Real}}
108113
(; τ, ν, σ) = d
114+
@argcheck τ rt
109115
LL = 0.0
110-
rt < τ ? (return -Inf) : nothing
111116
for i 1:length(ν)
112117
if c == i
113118
LL += log_dens(d, ν[i], σ[i], rt)
@@ -123,7 +128,7 @@ end
123128
function logpdf(d::AbstractLBA{T, T1}, c, rt) where {T, T1 <: Real}
124129
(; τ, ν, σ) = d
125130
LL = 0.0
126-
rt < τ ? (return -Inf) : nothing
131+
@argcheck τ rt
127132
for i 1:length(ν)
128133
if c == i
129134
LL += log_dens(d, ν[i], σ, rt)
@@ -137,9 +142,9 @@ function logpdf(d::AbstractLBA{T, T1}, c, rt) where {T, T1 <: Real}
137142
end
138143

139144
function pdf(d::AbstractLBA{T, T1}, c, rt) where {T, T1 <: Vector{<:Real}}
140-
(; τ, A, k, ν, σ) = d
145+
(; τ, ν, σ) = d
141146
den = 1.0
142-
rt < τ ? (return 1e-10) : nothing
147+
@argcheck τ rt
143148
for i 1:length(ν)
144149
if c == i
145150
den *= dens(d, ν[i], σ[i], rt)
@@ -156,7 +161,7 @@ end
156161
function pdf(d::AbstractLBA{T, T1}, c, rt) where {T, T1 <: Real}
157162
(; τ, ν, σ) = d
158163
den = 1.0
159-
rt < τ ? (return 1e-10) : nothing
164+
@argcheck τ rt
160165
for i 1:length(ν)
161166
if c == i
162167
den *= dens(d, ν[i], σ, rt)

0 commit comments

Comments
 (0)