-
Notifications
You must be signed in to change notification settings - Fork 95
fix: Allow loading CUDA-saved models on CPU-only machines #296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
acd63c7
97d5b90
55f7589
c02a95f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -74,13 +74,32 @@ def _safe_torch_load(filename, weights_only=False, **kwargs): | |||
| legacy_mode = packaging.version.parse( | ||||
| torch.__version__) < packaging.version.parse("2.6.0") | ||||
|
|
||||
| if legacy_mode: | ||||
| checkpoint = torch.load(filename, weights_only=False, **kwargs) | ||||
| else: | ||||
| with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS): | ||||
| checkpoint = torch.load(filename, | ||||
| weights_only=weights_only, | ||||
| **kwargs) | ||||
| try: | ||||
| if legacy_mode: | ||||
| checkpoint = torch.load(filename, weights_only=False, **kwargs) | ||||
| else: | ||||
| with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS): | ||||
| checkpoint = torch.load(filename, | ||||
| weights_only=weights_only, | ||||
| **kwargs) | ||||
| except RuntimeError as e: | ||||
| # Handle CUDA deserialization errors by retrying with map_location='cpu' | ||||
| if "CUDA" in str(e) or "cuda" in str(e).lower(): | ||||
| if "map_location" not in kwargs: | ||||
| kwargs["map_location"] = torch.device("cpu") | ||||
| if legacy_mode: | ||||
| checkpoint = torch.load(filename, weights_only=False, | ||||
| **kwargs) | ||||
| else: | ||||
| with torch.serialization.safe_globals( | ||||
| CEBRA_LOAD_SAFE_GLOBALS): | ||||
| checkpoint = torch.load(filename, | ||||
| weights_only=weights_only, | ||||
| **kwargs) | ||||
| else: | ||||
| raise | ||||
| else: | ||||
| raise | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. raise a meaningful error here |
||||
|
|
||||
| if not isinstance(checkpoint, dict): | ||||
| _check_type_checkpoint(checkpoint) | ||||
|
|
@@ -334,6 +353,32 @@ def _check_type_checkpoint(checkpoint): | |||
| return checkpoint | ||||
|
|
||||
|
|
||||
| def _resolve_checkpoint_device(device): | ||||
| """Resolve the device stored in a checkpoint for the current runtime. | ||||
|
|
||||
| If a checkpoint was saved on CUDA and CUDA is unavailable at load time, this | ||||
| falls back to CPU. | ||||
|
|
||||
| Args: | ||||
| device: The device from the checkpoint (str or torch.device). | ||||
|
|
||||
| Returns: | ||||
| str: The resolved device string ('cpu' or validated device). | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pls dont mention types in args/returns. type annotate instead |
||||
| """ | ||||
| if isinstance(device, torch.device): | ||||
| device = str(device) | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not robust. use torch.device type instead of string parsing |
||||
|
|
||||
| if not isinstance(device, str): | ||||
| raise TypeError( | ||||
| "Expected checkpoint device to be a string or torch.device, " | ||||
| f"got {type(device)}.") | ||||
|
|
||||
| if device.startswith("cuda") and not torch.cuda.is_available(): | ||||
| return "cpu" | ||||
|
|
||||
| return sklearn_utils.check_device(device) | ||||
|
|
||||
|
|
||||
| def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": | ||||
| """Loads a CEBRA model with a Sklearn backend. | ||||
|
|
||||
|
|
@@ -357,11 +402,24 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": | |||
|
|
||||
| args, state, state_dict = cebra_info['args'], cebra_info[ | ||||
| 'state'], cebra_info['state_dict'] | ||||
|
|
||||
| # Resolve device: use CPU when checkpoint was saved on CUDA but CUDA is not available | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
remove comments that are obvious from context |
||||
| saved_device = state["device_"] | ||||
| load_device = _resolve_checkpoint_device(saved_device) | ||||
|
|
||||
|
||||
| cebra_ = cebra.CEBRA(**args) | ||||
|
|
||||
| for key, value in state.items(): | ||||
| setattr(cebra_, key, value) | ||||
|
|
||||
| # Update device attributes to the resolved device for the current runtime | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see above
Suggested change
|
||||
| cebra_.device_ = load_device | ||||
| saved_device_str = str(saved_device) if isinstance(saved_device, | ||||
| torch.device) else saved_device | ||||
| if isinstance(saved_device_str, | ||||
| str) and saved_device_str.startswith("cuda") and load_device == "cpu": | ||||
| cebra_.device = "cpu" | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see above; lets use torch.device instead of string operations. e.g. instead of startswith you can check the |
||||
|
|
||||
| #TODO(stes): unused right now | ||||
| #state_and_args = {**args, **state} | ||||
|
|
||||
|
|
@@ -375,7 +433,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": | |||
| num_neurons=state["n_features_in_"], | ||||
| num_units=args["num_hidden_units"], | ||||
| num_output=args["output_dimension"], | ||||
| ).to(state['device_']) | ||||
| ).to(load_device) | ||||
|
|
||||
| elif isinstance(cebra_.num_sessions_, int): | ||||
| model = nn.ModuleList([ | ||||
|
|
@@ -385,10 +443,10 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": | |||
| num_units=args["num_hidden_units"], | ||||
| num_output=args["output_dimension"], | ||||
| ) for n_features in state["n_features_in_"] | ||||
| ]).to(state['device_']) | ||||
| ]).to(load_device) | ||||
|
|
||||
| criterion = cebra_._prepare_criterion() | ||||
| criterion.to(state['device_']) | ||||
| criterion.to(load_device) | ||||
|
|
||||
| optimizer = torch.optim.Adam( | ||||
| itertools.chain(model.parameters(), criterion.parameters()), | ||||
|
|
@@ -404,7 +462,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": | |||
| tqdm_on=args['verbose'], | ||||
| ) | ||||
| solver.load_state_dict(state_dict) | ||||
| solver.to(state['device_']) | ||||
| solver.to(load_device) | ||||
|
|
||||
| cebra_.model_ = model | ||||
| cebra_.solver_ = solver | ||||
|
|
||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| 1742156819804473408410328967400915377170 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| 1 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| 64 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| little |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| �UI�T�J=�����<L�ս�=c����=i͓;��ǽ4�z�����[UŽ=;����ħ��|���آ<K���sO�Gt�\4���ϻ� ��ց/=R$��ٴ<>nN:+�=�g.�����y/� | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| �Lƻ�U�q/3<k�亿�ʻ��;C�����;��������Ie��=�����<�<9�����ݚ�S:�;6)��\��?��';���m���"����������;~�ӻ�m | ||
| <\�;�[�� |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| ���8Ͱ�9���9�mD6u�9n�.8��-9�:�.�8�]�7�`?9Ņ�8�9��79 z9!�W9u!�8��8�W%7L��9 ��9E-:��B8/�K9���7�p�9vv�8ɔ9�&�8A� | ||
| :�E�8s.�9 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| @8uN�8��O8��7X8Ni�8�d7�t�8>lp6A�77�.8F�8���8ˇ�7�r*9r#�7��5n�7�46�z�8���8���8�>6.��8��7�j6(ϱ51� 9��6N��8 | ||
| 8q� | ||
| 8 | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| � f�~&��v�:㙖9%������:%�DZ�:��[9�۸��,��B775ٹNҔ9�lعN��C9�XԹ�f���T��.�f:�k���\��_!:�:Ϲ2�:��9ݼ���94~:�r� |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| ͍�6���6:̓7&�6WU6��8%7�k8��6$4�6�7N6��05İ�5(*W6���8c!7uc5���7�z�6�/�7|3|8���7�6IF�7� | ||
| �7r1�6�a�7��5�w7��7�w�7��7 | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| ��;�r�;b�w� |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| 19�*�:��79 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| �5���k��R�=<��=�X`=�?>�˳��|����?=_I=�Kb=4�H�=a�X>k<�ɻZ�&<���=�0x=нڼYsͽ!�*>�D=h���.}7=l��<��=���>�����r�� | ||
| >z3d� | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| Lf=���<��=�*O<�{= | ||
| �3>�^����O=q_O>5L>m=f��=��<<���=�=O��<�L����=S��Ū<|�;]�=�#���=�=�oӽP��=�����ƪ=t�=C}�=I�~= |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| �{����=7�= | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| 3 |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -1097,6 +1097,169 @@ def get_ordered_cuda_devices(): | |||||
| ) else [] | ||||||
|
|
||||||
|
|
||||||
| @pytest.mark.parametrize("saved_device", [ | ||||||
| "cuda", | ||||||
| "cuda:0", | ||||||
| torch.device("cuda"), | ||||||
| torch.device("cuda", 0), | ||||||
| ]) | ||||||
| @pytest.mark.parametrize("model_architecture", ["offset1-model", "parametrized-model-5"]) | ||||||
| def test_load_cuda_checkpoint_falls_back_to_cpu(saved_device, model_architecture, monkeypatch): | ||||||
| """Test that CUDA-saved checkpoints can be loaded on CPU-only machines. | ||||||
|
|
||||||
| This tests the fix for: Loading a model saved on CUDA when only CPU is available | ||||||
| should gracefully fall back to CPU instead of raising RuntimeError. | ||||||
| """ | ||||||
| X = np.random.uniform(0, 1, (100, 5)) | ||||||
|
|
||||||
| # Train a model on CPU | ||||||
| cebra_model = cebra_sklearn_cebra.CEBRA( | ||||||
| model_architecture=model_architecture, | ||||||
| max_iterations=5, | ||||||
| device="cpu" | ||||||
| ).fit(X) | ||||||
|
|
||||||
| with _windows_compatible_tempfile(mode="w+b") as tempname: | ||||||
| # Save the model | ||||||
| cebra_model.save(tempname) | ||||||
|
|
||||||
| # Modify the checkpoint to have a CUDA device | ||||||
| checkpoint = cebra_sklearn_cebra._safe_torch_load(tempname) | ||||||
| checkpoint["state"]["device_"] = saved_device | ||||||
| torch.save(checkpoint, tempname) | ||||||
|
|
||||||
|
Comment on lines
+1115
to
+1130
|
||||||
| # Mock CUDA as unavailable (simulating CPU-only machine) | ||||||
| monkeypatch.setattr(torch.cuda, "is_available", lambda: False) | ||||||
|
|
||||||
| # This should NOT raise RuntimeError: No CUDA GPUs are available | ||||||
| loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname) | ||||||
|
|
||||||
| # Verify model is on CPU | ||||||
| assert loaded_model.device_ == "cpu", f"Expected device_='cpu', got {loaded_model.device_!r}" | ||||||
| assert loaded_model.device == "cpu", f"Expected device='cpu', got {loaded_model.device!r}" | ||||||
| assert next(loaded_model.solver_.model.parameters()).device == torch.device("cpu") | ||||||
|
|
||||||
| # Verify model actually works (can do inference) | ||||||
| X_test = np.random.uniform(0, 1, (10, 5)) | ||||||
| embedding = loaded_model.transform(X_test) | ||||||
| assert embedding.shape[0] == 10 # Correct number of samples | ||||||
| assert embedding.shape[1] > 0 # Has some output dimensions | ||||||
| assert isinstance(embedding, np.ndarray) | ||||||
|
|
||||||
|
|
||||||
| def test_safe_torch_load_cuda_fallback(monkeypatch): | ||||||
| """Test that _safe_torch_load retries with map_location='cpu' on CUDA errors. | ||||||
|
|
||||||
| This exercises the actual torch.load failure path when CUDA tensors are | ||||||
| present but CUDA is unavailable. | ||||||
| """ | ||||||
| import tempfile | ||||||
| import os | ||||||
|
|
||||||
| # Create a simple checkpoint | ||||||
| checkpoint = {"test": torch.tensor([1.0, 2.0, 3.0])} | ||||||
|
|
||||||
| with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: | ||||||
| tempname = f.name | ||||||
| torch.save(checkpoint, tempname) | ||||||
|
|
||||||
| try: | ||||||
| # Mock torch.load to fail on first call (simulating CUDA tensor load error) | ||||||
| original_torch_load = torch.load | ||||||
| call_count = [0] | ||||||
|
|
||||||
| def mock_torch_load(*args, **kwargs): | ||||||
| call_count[0] += 1 | ||||||
| if call_count[0] == 1 and "map_location" not in kwargs: | ||||||
| raise RuntimeError("Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False") | ||||||
| return original_torch_load(*args, **kwargs) | ||||||
|
|
||||||
| monkeypatch.setattr(torch, "load", mock_torch_load) | ||||||
|
|
||||||
| # Should retry with map_location='cpu' and succeed | ||||||
| result = cebra_sklearn_cebra._safe_torch_load(tempname) | ||||||
| assert "test" in result | ||||||
| assert torch.allclose(result["test"], checkpoint["test"]) | ||||||
| assert call_count[0] == 2 # First call failed, second retry succeeded | ||||||
|
|
||||||
| finally: | ||||||
| os.unlink(tempname) | ||||||
|
|
||||||
|
|
||||||
| @pytest.mark.parametrize("saved_device", ["cuda", "cuda:0"]) | ||||||
| def test_load_cuda_checkpoint_with_device_override(saved_device, monkeypatch): | ||||||
| """Test that automatic CPU fallback works with CUDA checkpoints.""" | ||||||
| X = np.random.uniform(0, 1, (100, 5)) | ||||||
|
|
||||||
| cebra_model = cebra_sklearn_cebra.CEBRA( | ||||||
| model_architecture="offset1-model", | ||||||
| max_iterations=5, | ||||||
| device="cpu" | ||||||
| ).fit(X) | ||||||
|
|
||||||
| with _windows_compatible_tempfile(mode="w+b") as tempname: | ||||||
| cebra_model.save(tempname) | ||||||
| checkpoint = cebra_sklearn_cebra._safe_torch_load(tempname) | ||||||
| checkpoint["state"]["device_"] = saved_device | ||||||
| torch.save(checkpoint, tempname) | ||||||
|
|
||||||
| monkeypatch.setattr(torch.cuda, "is_available", lambda: False) | ||||||
|
|
||||||
| # Load should automatically fall back to CPU | ||||||
| loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname) | ||||||
|
||||||
| loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname) | |
| loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname, map_location='cpu') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
duplicate; I recommend to use
cebra.loadagain but adapt the map storage parameter instead, and fail on second attempt