Skip to content

Commit 1db5152

Browse files
committed
add test for MixedDataLoader including additional keywords
1 parent 8af5b0e commit 1db5152

1 file changed

Lines changed: 31 additions & 0 deletions

File tree

tests/test_loader.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,37 @@ def test_continuous(conditional, device, benchmark):
176176
benchmark(load_speed)
177177

178178

179+
@parametrize_device
180+
@pytest.mark.parametrize(
181+
"conditional, positive_sampling, discrete_sampling_prior",
182+
[
183+
("time", "discrete_variable", "empirical"),
184+
("time", "conditional", "empirical"),
185+
("time", "discrete_variable", "uniform"),
186+
("time", "conditional", "uniform"),
187+
("time_delta", "discrete_variable", "empirical"),
188+
("time_delta", "conditional", "empirical"),
189+
("time_delta", "discrete_variable", "uniform"),
190+
("time_delta", "conditional", "uniform"),
191+
],
192+
)
193+
def test_mixed(
194+
conditional, positive_sampling, discrete_sampling_prior, device, benchmark
195+
):
196+
dataset = RandomDataset(N=100, d=5, device=device)
197+
loader = cebra.data.MixedDataLoader(
198+
dataset=dataset,
199+
num_steps=10,
200+
batch_size=8,
201+
conditional=conditional,
202+
positive_sampling=positive_sampling,
203+
discrete_sampling_prior=discrete_sampling_prior,
204+
)
205+
_assert_dataset_on_correct_device(loader, device)
206+
load_speed = LoadSpeed(loader)
207+
benchmark(load_speed)
208+
209+
179210
def _check_attributes(obj, is_list=False):
180211
if is_list:
181212
for obj_ in obj:

0 commit comments

Comments
 (0)