|
| 1 | +""" |
| 2 | + SSMProductDistribution |
| 3 | +
|
| 4 | +Wrapper around `ProductDistribution` for sequential sampling models. |
| 5 | +This type allows us to define `logpdf` methods for `NamedTuple` data |
| 6 | +without type piracy. |
| 7 | +""" |
| 8 | +struct SSMProductDistribution{D <: ProductDistribution} |
| 9 | + dist::D |
| 10 | +end |
| 11 | + |
| 12 | +""" |
| 13 | + product_distribution(dists) |
| 14 | +
|
| 15 | +Create a product distribution from a vector of distributions. |
| 16 | +Returns an `SSMProductDistribution` for SSM types, or a standard |
| 17 | +`ProductDistribution` for other types. |
| 18 | +""" |
| 19 | +function product_distribution(dists::AbstractVector) |
| 20 | + pd = ProductDistribution(dists) |
| 21 | + # Check if this is an SSM that produces NamedTuple data |
| 22 | + if eltype(dists) <: SSM2D |
| 23 | + return SSMProductDistribution(pd) |
| 24 | + else |
| 25 | + return pd |
| 26 | + end |
| 27 | +end |
| 28 | + |
| 29 | +Base.size(s::SSMProductDistribution, dims...) = size(s.dist, dims...) |
| 30 | +Base.length(s::SSMProductDistribution) = length(s.dist) |
| 31 | + |
1 | 32 | function rand( |
2 | 33 | rng::AbstractRNG, |
3 | | - s::Sampleable{T, R} |
4 | | -) where {T <: Matrixvariate, R <: SequentialSamplingModels.Mixed} |
| 34 | + s::SSMProductDistribution |
| 35 | +) |
5 | 36 | n = size(s, 2) |
6 | 37 | data = (; choice = fill(0, n), rt = fill(0.0, n)) |
7 | 38 | return rand!(rng, s, data) |
8 | 39 | end |
9 | 40 |
|
10 | 41 | function rand( |
11 | 42 | rng::AbstractRNG, |
12 | | - s::Sampleable{T, R}, |
| 43 | + s::SSMProductDistribution, |
13 | 44 | dims::Dims |
14 | | -) where {T <: Matrixvariate, R <: SequentialSamplingModels.Mixed} |
| 45 | +) |
15 | 46 | n = size(s, 2) |
16 | 47 | ax = map(Base.OneTo, dims) |
17 | 48 | data = [(; choice = fill(0, n), rt = fill(0.0, n)) for _ in Iterators.product(ax...)] |
|
20 | 51 |
|
21 | 52 | function rand!( |
22 | 53 | rng::AbstractRNG, |
23 | | - s::Sampleable{T, R}, |
| 54 | + s::SSMProductDistribution, |
24 | 55 | data::NamedTuple |
25 | | -) where {T <: Matrixvariate, R <: SequentialSamplingModels.Mixed} |
| 56 | +) |
26 | 57 | for i ∈ 1:size(s, 2) |
27 | | - data.choice[i], data.rt[i] = rand(rng, s.dists[i]) |
| 58 | + data.choice[i], data.rt[i] = rand(rng, s.dist.dists[i]) |
28 | 59 | end |
29 | 60 | return data |
30 | 61 | end |
31 | 62 |
|
32 | | -function logpdf(d::ProductDistribution, data_array::Array{<:NamedTuple, N}) where {N} |
| 63 | +function logpdf(d::SSMProductDistribution, data_array::Array{<:NamedTuple, N}) where {N} |
33 | 64 | return [logpdf(d, data) for data ∈ data_array] |
34 | 65 | end |
35 | 66 |
|
36 | | -function logpdf(d::ProductDistribution, data::NamedTuple) |
| 67 | +function logpdf(d::SSMProductDistribution, data::NamedTuple) |
37 | 68 | LL = 0.0 |
38 | | - for i ∈ 1:length(d.dists) |
39 | | - LL += logpdf(d.dists[i], data.choice[i], data.rt[i]) |
| 69 | + for i ∈ 1:length(d.dist.dists) |
| 70 | + LL += logpdf(d.dist.dists[i], data.choice[i], data.rt[i]) |
40 | 71 | end |
41 | 72 | return LL |
42 | 73 | end |
0 commit comments