33import os
44import pandas as pd
55from unittest .mock import MagicMock , patch
6+ from ml_grid .model_classes .tabpfn_classifier_class import (
7+ TabPFNClassifierClass ,
8+ )
69
710# Add project root to path so we can import the class under test
811# Assumes this test file is in /tests/ and ml_grid is in the parent dir
2225sys .modules ["ml_grid.util.param_space" ] = MagicMock ()
2326sys .modules ["ml_grid.util.global_params" ] = MagicMock ()
2427
25- # Mock tabpfn to avoid downloading weights during tests
26- mock_tabpfn_module = MagicMock ()
27- # The mock needs to have the attributes that are accessed in the class
28- mock_tabpfn_classifier_instance = MagicMock ()
29- mock_tabpfn_module .TabPFNClassifier .return_value = mock_tabpfn_classifier_instance
30- mock_tabpfn_module .TabPFNClassifier .create_default_for_version .return_value = (
31- mock_tabpfn_classifier_instance
32- )
33- sys .modules ["tabpfn" ] = mock_tabpfn_module
34- sys .modules ["tabpfn.constants" ] = MagicMock ()
35-
36- from ml_grid .model_classes .tabpfn_classifier_class import (
37- TabPFNClassifierClass ,
38- ) # noqa: E402
28+ # Conditionally mock tabpfn. If it's installed, we might want to use it for integration tests.
29+ # If not installed, we mock it to allow importing the wrapper class.
30+ try :
31+ import tabpfn
32+ except ImportError :
33+ mock_tabpfn_module = MagicMock ()
34+ sys .modules ["tabpfn" ] = mock_tabpfn_module
35+ sys .modules ["tabpfn.constants" ] = MagicMock ()
3936
4037
4138class TestTabPFNClassifierClass (unittest .TestCase ):
4239 def setUp (self ):
43- # Reset the mock for TabPFNClassifier before each test
44- mock_tabpfn_module .TabPFNClassifier .reset_mock ()
45- mock_tabpfn_module .TabPFNClassifier .create_default_for_version .reset_mock ()
46- # also reset the instance mock
47- mock_tabpfn_classifier_instance .reset_mock ()
48- # re-assign the return value in case it was modified
49- mock_tabpfn_module .TabPFNClassifier .return_value = (
50- mock_tabpfn_classifier_instance
51- )
52- mock_tabpfn_module .TabPFNClassifier .create_default_for_version .return_value = (
53- mock_tabpfn_classifier_instance
54- )
55-
5640 # Patch global parameters to control bayessearch flag
5741 self .global_params_patch = patch ("ml_grid.util.global_params.global_parameters" )
5842 self .mock_global_params = self .global_params_patch .start ()
@@ -98,8 +82,12 @@ def test_parameter_space_compatibility(self):
9882 for param in removed_params :
9983 self .assertNotIn (param , model .parameter_space )
10084
101- def test_fit_v2_5_default (self ):
85+ @patch ("ml_grid.model_classes.tabpfn_classifier_class.TabPFNClassifier" )
86+ def test_fit_v2_5_default (self , mock_tabpfn_cls ):
10287 """Test fitting of the default v2.5 model."""
88+ # Setup mock return value
89+ mock_estimator_instance = mock_tabpfn_cls .return_value
90+
10391 # Instantiate the wrapper with hyperparameters
10492 model_wrapper = TabPFNClassifierClass (
10593 model_version = "v2.5_default" , n_estimators = 4 , device = "cpu" , random_state = 42
@@ -109,10 +97,10 @@ def test_fit_v2_5_default(self):
10997 model_wrapper .fit (self .X_dummy , self .y_dummy )
11098
11199 # Verify TabPFNClassifier constructor was called directly
112- mock_tabpfn_module . TabPFNClassifier .assert_called_once ()
100+ mock_tabpfn_cls .assert_called_once ()
113101
114102 # Verify arguments passed to the constructor
115- call_kwargs = mock_tabpfn_module . TabPFNClassifier .call_args .kwargs
103+ call_kwargs = mock_tabpfn_cls .call_args .kwargs
116104 self .assertEqual (call_kwargs ["n_estimators" ], 4 )
117105 self .assertEqual (call_kwargs ["device" ], "cpu" )
118106 self .assertEqual (call_kwargs ["random_state" ], 42 )
@@ -122,48 +110,56 @@ def test_fit_v2_5_default(self):
122110 self .assertNotIn ("subsample_samples" , call_kwargs )
123111
124112 # Verify the underlying estimator's fit method was called
125- mock_estimator_instance = mock_tabpfn_module .TabPFNClassifier .return_value
126113 mock_estimator_instance .fit .assert_called_once ()
127114
128- def test_fit_v2_5_synthetic (self ):
115+ @patch ("ml_grid.model_classes.tabpfn_classifier_class.TabPFNClassifier" )
116+ def test_fit_v2_5_synthetic (self , mock_tabpfn_cls ):
129117 """Test fitting of the synthetic v2.5 model (checks model_path logic)."""
118+ mock_estimator_instance = mock_tabpfn_cls .return_value
119+
130120 model_wrapper = TabPFNClassifierClass (
131121 model_version = "v2.5_synthetic" , n_estimators = 2
132122 )
133123
134124 model_wrapper .fit (self .X_dummy , self .y_dummy )
135125
136126 # Verify constructor was called
137- mock_tabpfn_module . TabPFNClassifier .assert_called_once ()
127+ mock_tabpfn_cls .assert_called_once ()
138128
139129 # Verify model_path was injected
140- call_kwargs = mock_tabpfn_module . TabPFNClassifier .call_args .kwargs
130+ call_kwargs = mock_tabpfn_cls .call_args .kwargs
141131 self .assertEqual (
142132 call_kwargs .get ("model_path" ), "tabpfn-v2.5-classifier-v2.5_default-2.ckpt"
143133 )
144134
145135 # Verify fit was called
146- mock_estimator_instance = mock_tabpfn_module .TabPFNClassifier .return_value
147136 mock_estimator_instance .fit .assert_called_once ()
148137
149- def test_fit_v2 (self ):
138+ @patch ("ml_grid.model_classes.tabpfn_classifier_class.TabPFNClassifier" )
139+ def test_fit_v2 (self , mock_tabpfn_cls ):
150140 """Test fitting of the legacy v2 model."""
141+ # Setup mock for create_default_for_version
142+ mock_estimator_instance = MagicMock ()
143+ mock_tabpfn_cls .create_default_for_version .return_value = (
144+ mock_estimator_instance
145+ )
146+
151147 model_wrapper = TabPFNClassifierClass (model_version = "v2" , n_estimators = 1 )
152148
153149 model_wrapper .fit (self .X_dummy , self .y_dummy )
154150
155151 # Verify it called create_default_for_version instead of standard constructor
156- mock_tabpfn_module . TabPFNClassifier .create_default_for_version .assert_called_once ()
157- mock_tabpfn_module . TabPFNClassifier .assert_not_called () # Ensure standard constructor was NOT called
152+ mock_tabpfn_cls .create_default_for_version .assert_called_once ()
153+ mock_tabpfn_cls .assert_not_called () # Ensure standard constructor was NOT called
158154
159155 # Verify fit was called
160- mock_estimator_instance = (
161- mock_tabpfn_module .TabPFNClassifier .create_default_for_version .return_value
162- )
163156 mock_estimator_instance .fit .assert_called_once ()
164157
165- def test_predict_and_predict_proba_delegation (self ):
158+ @patch ("ml_grid.model_classes.tabpfn_classifier_class.TabPFNClassifier" )
159+ def test_predict_and_predict_proba_delegation (self , mock_tabpfn_cls ):
166160 """Test that predict and predict_proba delegate to the internal estimator."""
161+ mock_estimator_instance = mock_tabpfn_cls .return_value
162+
167163 model_wrapper = TabPFNClassifierClass ()
168164
169165 # Fit the model to create the internal _estimator
@@ -180,8 +176,11 @@ def test_predict_and_predict_proba_delegation(self):
180176 model_wrapper .predict_proba (self .X_dummy )
181177 internal_estimator_mock .predict_proba .assert_called_once_with (self .X_dummy )
182178
183- def test_fit_with_subsampling (self ):
179+ @patch ("ml_grid.model_classes.tabpfn_classifier_class.TabPFNClassifier" )
180+ def test_fit_with_subsampling (self , mock_tabpfn_cls ):
184181 """Test that subsampling is applied when configured."""
182+ mock_estimator_instance = mock_tabpfn_cls .return_value
183+
185184 # Create larger dummy data
186185 X_large = pd .DataFrame ({"col1" : range (100 ), "col2" : range (100 )})
187186 y_large = pd .Series ([0 , 1 ] * 50 )
@@ -194,17 +193,53 @@ def test_fit_with_subsampling(self):
194193 model_wrapper .fit (X_large , y_large )
195194
196195 # Verify constructor called without subsample_samples
197- call_kwargs = mock_tabpfn_module . TabPFNClassifier .call_args .kwargs
196+ call_kwargs = mock_tabpfn_cls .call_args .kwargs
198197 self .assertNotIn ("subsample_samples" , call_kwargs )
199198
200199 # Verify fit was called with subsampled data
201- mock_estimator_instance = mock_tabpfn_module .TabPFNClassifier .return_value
202200 args , _ = mock_estimator_instance .fit .call_args
203201 X_passed , y_passed = args
204202
205203 self .assertEqual (len (X_passed ), subsample_size )
206204 self .assertEqual (len (y_passed ), subsample_size )
207205
206+ def test_real_execution_if_available (self ):
207+ """
208+ Integration test: Attempts to run with the real TabPFN library if available.
209+ If the model weights are missing (gated), it catches the RuntimeError and passes.
210+ """
211+ try :
212+ # Try to instantiate and fit a small model
213+ # We use n_estimators=1 for speed
214+ model_wrapper = TabPFNClassifierClass (n_estimators = 1 , device = "cpu" )
215+ model_wrapper .fit (self .X_dummy , self .y_dummy )
216+
217+ # If fit succeeds, try predict
218+ preds = model_wrapper .predict (self .X_dummy )
219+ self .assertEqual (len (preds ), len (self .X_dummy ))
220+
221+ except RuntimeError as e :
222+ # Check for the specific download error
223+ error_msg = str (e ).lower ()
224+ if (
225+ "download" in error_msg
226+ or "gated" in error_msg
227+ or "modelversion" in error_msg
228+ ):
229+ print (
230+ f"Skipping real execution test (Model download required/gated): { e } "
231+ )
232+ return
233+ # If it's another RuntimeError, re-raise it
234+ raise e
235+ except ImportError :
236+ print ("Skipping real execution test (TabPFN not installed)" )
237+ return
238+ except Exception as e :
239+ # Catch-all for other environment issues (e.g. network)
240+ print (f"Skipping real execution test due to unexpected error: { e } " )
241+ return
242+
208243
209244if __name__ == "__main__" :
210245 unittest .main ()
0 commit comments