@@ -176,6 +176,37 @@ def test_continuous(conditional, device, benchmark):
176176 benchmark (load_speed )
177177
178178
179+ @parametrize_device
180+ @pytest .mark .parametrize (
181+ "conditional, positive_sampling, discrete_sampling_prior" ,
182+ [
183+ ("time" , "discrete_variable" , "empirical" ),
184+ ("time" , "conditional" , "empirical" ),
185+ ("time" , "discrete_variable" , "uniform" ),
186+ ("time" , "conditional" , "uniform" ),
187+ ("time_delta" , "discrete_variable" , "empirical" ),
188+ ("time_delta" , "conditional" , "empirical" ),
189+ ("time_delta" , "discrete_variable" , "uniform" ),
190+ ("time_delta" , "conditional" , "uniform" ),
191+ ],
192+ )
193+ def test_mixed (
194+ conditional , positive_sampling , discrete_sampling_prior , device , benchmark
195+ ):
196+ dataset = RandomDataset (N = 100 , d = 5 , device = device )
197+ loader = cebra .data .MixedDataLoader (
198+ dataset = dataset ,
199+ num_steps = 10 ,
200+ batch_size = 8 ,
201+ conditional = conditional ,
202+ positive_sampling = positive_sampling ,
203+ discrete_sampling_prior = discrete_sampling_prior ,
204+ )
205+ _assert_dataset_on_correct_device (loader , device )
206+ load_speed = LoadSpeed (loader )
207+ benchmark (load_speed )
208+
209+
179210def _check_attributes (obj , is_list = False ):
180211 if is_list :
181212 for obj_ in obj :
0 commit comments