Skip to content

Commit 08c667d

Browse files
committed
Merge branch 'master' of github.com:itsdfish/SequentialSamplingModels.jl
2 parents d32edc0 + 03959ab commit 08c667d

8 files changed

Lines changed: 87 additions & 19 deletions

File tree

.github/workflows/CI.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ jobs:
1616
matrix:
1717
version:
1818
- '1.11' # Replace this with the minimum Julia version that your package supports. E.g. if your package requires Julia 1.5 or higher, change this to '1.5'.
19+
- '1'
1920
os:
2021
- ubuntu-latest
2122
arch:

Project.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,7 @@ ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
88
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
99
FunctionZeros = "b21f74c0-b399-568f-9643-d20f4fa2c814"
1010
HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"
11-
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
12-
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
1311
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
14-
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
1512
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
1613
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1714
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
@@ -20,14 +17,18 @@ StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
2017
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2118

2219
[weakdeps]
20+
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
21+
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
22+
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
2323
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
2424

2525
[extensions]
26-
PlotsExt = "Plots"
26+
PlotsExt = ["Plots", "Interpolations", "KernelDensity"]
2727

2828
[compat]
2929
ArgCheck = "2.5.0"
3030
Distributions = "v0.24.6, 0.25"
31+
DynamicPPL = "0.25 - 0.39"
3132
FunctionZeros = "0.2.0,0.3.0, 1"
3233
HCubature = "1"
3334
Interpolations = "0.14.0,0.15.0,0.16.0"

src/SequentialSamplingModels.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ export PoissonRace
6464
export ShiftedLogNormal
6565
export SSM1D
6666
export SSM2D
67+
export SSMProductDistribution
6768
export stDDM
6869
export ContinuousMultivariateSSM
6970
export Wald
@@ -85,7 +86,7 @@ export plot_model
8586
export plot_model!
8687
export plot_quantiles
8788
export plot_quantiles!
88-
export predict_distribution
89+
export product_distribution
8990
export rand
9091
export simulate
9192
export std

src/multi_choice_models/MDFT.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ make_default_contrast(3)
272272
-0.5 -0.5 1.0
273273
```
274274
"""
275-
function make_default_contrast(n)
275+
function make_default_contrast(n::Integer)
276276
C = fill(0.0, n, n)
277277
C .= -1 / (n - 1)
278278
for r 1:n

src/product_distribution.jl

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,48 @@
1+
"""
2+
SSMProductDistribution
3+
4+
Wrapper around `ProductDistribution` for sequential sampling models.
5+
This type allows us to define `logpdf` methods for `NamedTuple` data
6+
without type piracy.
7+
"""
8+
struct SSMProductDistribution{D <: ProductDistribution}
9+
dist::D
10+
end
11+
12+
"""
13+
product_distribution(dists)
14+
15+
Create a product distribution from a vector of distributions.
16+
Returns an `SSMProductDistribution` for SSM types, or a standard
17+
`ProductDistribution` for other types.
18+
"""
19+
function product_distribution(dists::AbstractVector)
20+
pd = ProductDistribution(dists)
21+
# Check if this is an SSM that produces NamedTuple data
22+
if eltype(dists) <: SSM2D
23+
return SSMProductDistribution(pd)
24+
else
25+
return pd
26+
end
27+
end
28+
29+
Base.size(s::SSMProductDistribution, dims...) = size(s.dist, dims...)
30+
Base.length(s::SSMProductDistribution) = length(s.dist)
31+
132
function rand(
233
rng::AbstractRNG,
3-
s::Sampleable{T, R}
4-
) where {T <: Matrixvariate, R <: SequentialSamplingModels.Mixed}
34+
s::SSMProductDistribution
35+
)
536
n = size(s, 2)
637
data = (; choice = fill(0, n), rt = fill(0.0, n))
738
return rand!(rng, s, data)
839
end
940

1041
function rand(
1142
rng::AbstractRNG,
12-
s::Sampleable{T, R},
43+
s::SSMProductDistribution,
1344
dims::Dims
14-
) where {T <: Matrixvariate, R <: SequentialSamplingModels.Mixed}
45+
)
1546
n = size(s, 2)
1647
ax = map(Base.OneTo, dims)
1748
data = [(; choice = fill(0, n), rt = fill(0.0, n)) for _ in Iterators.product(ax...)]
@@ -20,23 +51,23 @@ end
2051

2152
function rand!(
2253
rng::AbstractRNG,
23-
s::Sampleable{T, R},
54+
s::SSMProductDistribution,
2455
data::NamedTuple
25-
) where {T <: Matrixvariate, R <: SequentialSamplingModels.Mixed}
56+
)
2657
for i 1:size(s, 2)
27-
data.choice[i], data.rt[i] = rand(rng, s.dists[i])
58+
data.choice[i], data.rt[i] = rand(rng, s.dist.dists[i])
2859
end
2960
return data
3061
end
3162

32-
function logpdf(d::ProductDistribution, data_array::Array{<:NamedTuple, N}) where {N}
63+
function logpdf(d::SSMProductDistribution, data_array::Array{<:NamedTuple, N}) where {N}
3364
return [logpdf(d, data) for data data_array]
3465
end
3566

36-
function logpdf(d::ProductDistribution, data::NamedTuple)
67+
function logpdf(d::SSMProductDistribution, data::NamedTuple)
3768
LL = 0.0
38-
for i 1:length(d.dists)
39-
LL += logpdf(d.dists[i], data.choice[i], data.rt[i])
69+
for i 1:length(d.dist.dists)
70+
LL += logpdf(d.dist.dists[i], data.choice[i], data.rt[i])
4071
end
4172
return LL
4273
end

test/Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
[deps]
2+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
23
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
34
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
5+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
46
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
57
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
68
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
@@ -14,5 +16,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1416
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
1517
TuringUtilities = "35dc62cd-6c01-44e1-a736-6cc36bfce0cc"
1618

17-
[sources.TuringUtilities]
18-
url = "https://github.com/itsdfish/TuringUtilities.jl"
19+
[sources]
20+
TuringUtilities = {rev = "main", url = "https://github.com/itsdfish/TuringUtilities.jl"}

test/codequality.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
@safetestset "Code Quality" begin
2+
3+
# check code quality via Aqua
4+
@safetestset "Aqua" begin
5+
using Aqua
6+
using SequentialSamplingModels
7+
Aqua.test_all(
8+
SequentialSamplingModels;
9+
ambiguities = false,
10+
deps_compat = (check_extras = false,),
11+
project_extras = false
12+
)
13+
end
14+
15+
# test JET
16+
@safetestset "JET" begin
17+
using JET
18+
using SequentialSamplingModels
19+
JET.test_package(
20+
SequentialSamplingModels;
21+
target_modules = (SequentialSamplingModels,)
22+
)
23+
end
24+
end

test/product_distribution_tests.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
@safetestset "rand SSM1D 1" begin
33
using Distributions
44
using SequentialSamplingModels
5+
using SequentialSamplingModels: product_distribution
56
using Test
67

78
walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)]
@@ -15,6 +16,7 @@
1516
@safetestset "rand SSM1D 2" begin
1617
using Distributions
1718
using SequentialSamplingModels
19+
using SequentialSamplingModels: product_distribution
1820
using Test
1921

2022
walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)]
@@ -28,6 +30,7 @@
2830
@safetestset "rand logpdf 1" begin
2931
using Distributions
3032
using SequentialSamplingModels
33+
using SequentialSamplingModels: product_distribution
3134
using Test
3235

3336
walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)]
@@ -41,6 +44,7 @@
4144
@safetestset "logpdf SSM1D 2" begin
4245
using Distributions
4346
using SequentialSamplingModels
47+
using SequentialSamplingModels: product_distribution
4448
using Test
4549

4650
walds = [Wald(; ν = 2.5, α = 0.1, τ = 0.2), Wald(; ν = 1.5, α = 1, τ = 10)]
@@ -54,6 +58,7 @@
5458
@safetestset "rand SSM2D 1" begin
5559
using Distributions
5660
using SequentialSamplingModels
61+
using SequentialSamplingModels: product_distribution
5762
using Test
5863

5964
lbas = [
@@ -70,6 +75,7 @@
7075
@safetestset "rand SSM2D 2" begin
7176
using Distributions
7277
using SequentialSamplingModels
78+
using SequentialSamplingModels: product_distribution
7379
using Test
7480

7581
lbas = [
@@ -86,6 +92,7 @@
8692
@safetestset "logpdf SSM2D 1" begin
8793
using Distributions
8894
using SequentialSamplingModels
95+
using SequentialSamplingModels: product_distribution
8996
using Test
9097

9198
lbas = [
@@ -103,6 +110,7 @@
103110
@safetestset "logpdf SSM2D 2" begin
104111
using Distributions
105112
using SequentialSamplingModels
113+
using SequentialSamplingModels: product_distribution
106114
using Test
107115

108116
lbas = [

0 commit comments

Comments
 (0)