@@ -482,14 +482,22 @@ def _fit_and_get_history(X, y):
482482
483483
484484@pytest .mark .parametrize ("seed" , [42 , 24 , 10 ])
485- def test_infonce_to_goodness_of_fit (seed ):
485+ @pytest .mark .parametrize ("batch_size" , [100 , 200 ])
486+ @pytest .mark .parametrize ("num_negatives" , [None , 100 , 200 ])
487+ def test_infonce_to_goodness_of_fit (seed , batch_size , num_negatives ):
486488 """Test the conversion from InfoNCE loss to goodness of fit metric."""
489+ nats_to_bits = np .log2 (np .e )
490+
487491 # Test with model
488492 cebra_model = cebra_sklearn_cebra .CEBRA (
489493 model_architecture = "offset10-model" ,
490494 max_iterations = 5 ,
491- batch_size = 128 ,
495+ batch_size = batch_size ,
496+ num_negatives = num_negatives ,
492497 )
498+ if num_negatives is None :
499+ num_negatives = batch_size
500+
493501 generator = torch .Generator ().manual_seed (seed )
494502 X = torch .rand (1000 , 50 , dtype = torch .float32 , generator = generator )
495503 cebra_model .fit (X )
@@ -498,19 +506,22 @@ def test_infonce_to_goodness_of_fit(seed):
498506 gof = cebra_sklearn_metrics .infonce_to_goodness_of_fit (1.0 ,
499507 model = cebra_model )
500508 assert isinstance (gof , float )
509+ assert np .isclose (gof , (np .log (num_negatives ) - 1.0 ) * nats_to_bits )
501510
502511 # Test array of values
503512 infonce_values = np .array ([1.0 , 2.0 , 3.0 ])
504513 gof_array = cebra_sklearn_metrics .infonce_to_goodness_of_fit (
505514 infonce_values , model = cebra_model )
506515 assert isinstance (gof_array , np .ndarray )
507516 assert gof_array .shape == infonce_values .shape
517+ assert np .allclose (gof_array ,
518+ (np .log (num_negatives ) - infonce_values ) * nats_to_bits )
508519
509520 # Test with explicit batch_size and num_sessions
510- gof = cebra_sklearn_metrics .infonce_to_goodness_of_fit (1.0 ,
511- batch_size = 128 ,
512- num_sessions = 1 )
521+ gof = cebra_sklearn_metrics .infonce_to_goodness_of_fit (
522+ 1.0 , batch_size = batch_size , num_sessions = 1 )
513523 assert isinstance (gof , float )
524+ assert np .isclose (gof , (np .log (batch_size ) - 1.0 ) * nats_to_bits )
514525
515526 # Test error cases
516527 with pytest .raises (ValueError , match = "batch_size.*should not be provided" ):
0 commit comments