From e761355e77371b41ee98f9544aa762136f376f71 Mon Sep 17 00:00:00 2001
From: cclaess
Date: Thu, 7 May 2026 19:44:03 +0200
Subject: [PATCH 1/3] Refactor optional dependency handling (#45)
---
README.md | 30 ++++-
pyproject.toml | 38 ++++--
src/spectre/configs/__init__.py | 47 ++++++--
src/spectre/data/_base_datasets.py | 84 ++++++++++---
src/spectre/data/abdomen_atlas.py | 8 +-
src/spectre/data/abdomenct_1k.py | 8 +-
src/spectre/data/amos.py | 8 +-
src/spectre/data/ct_rate.py | 20 +++-
src/spectre/data/inspect.py | 19 ++-
src/spectre/data/merlin.py | 20 +++-
src/spectre/data/nlst.py | 8 +-
src/spectre/data/panorama.py | 8 +-
src/spectre/data/sinoct.py | 21 +++-
src/spectre/data/total_segmentator.py | 19 ++-
.../losses/mask_classification_loss.py | 13 ++-
src/spectre/ssl/transforms/dino_transform.py | 110 ++++++++++--------
src/spectre/ssl/transforms/mae_transform.py | 78 +++++--------
.../ssl/transforms/siglip_transform.py | 58 ++++-----
src/spectre/transforms/combine_labels.py | 21 +++-
src/spectre/transforms/generate_report.py | 110 ++++--------------
.../transforms/largest_multiple_crop.py | 20 +++-
.../transforms/scale_intensity_range.py | 33 +++++-
src/spectre/utils/_utils.py | 13 ++-
src/spectre/utils/collate.py | 26 ++++-
src/spectre/utils/config.py | 21 +++-
src/spectre/utils/dataloader.py | 17 ++-
src/spectre/utils/distributed.py | 14 ++-
27 files changed, 573 insertions(+), 299 deletions(-)
diff --git a/README.md b/README.md
index a8c2a80..ef11f23 100644
--- a/README.md
+++ b/README.md
@@ -94,12 +94,40 @@ This repository is organized as follows:
## ⚙️ Setting Up the Environment
-To get up and running with SPECTRE, simply install our package using pip:
+To get up and running with SPECTRE, install the base package with pip:
```bash
pip install spectre-fm
```
+This installs only the runtime dependencies needed to load and run the pretrained models.
+
+If you want to fine-tune or pretrain SPECTRE, install the matching extra:
+
+```bash
+pip install "spectre-fm[training]"
+```
+
+If you only need the evaluation stack, install:
+
+```bash
+pip install "spectre-fm[eval]"
+```
+
+If you need to train on systems with GPUDirect Storage (GDS), install the CUDA 12-specific extra:
+
+```bash
+pip install "spectre-fm[gds-cuda12]" # with training stack: "spectre-fm[training,gds-cuda12]"
+```
+
+**Note:** the `gds-cuda12` extra is only compatible with CUDA 12.x environments.
+
+To install everything at once, use:
+
+```bash
+pip install "spectre-fm[all]"
+```
+
or install the latest updates directly from GitHub:
```bash
diff --git a/pyproject.toml b/pyproject.toml
index 944c027..8a7266d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,27 +28,49 @@ keywords = ["medical", "ct", "self-supervised", "vision", "deep-learning"]
dependencies = [
"torch",
- "wandb",
"numpy",
- "monai[all]",
+ "timm",
+ "transformers",
+ "huggingface_hub",
+ "loralib",
+]
+
+[project.optional-dependencies]
+training = [
+ "wandb",
+ "monai",
"nibabel",
"SimpleITK",
- "timm",
"scipy",
"pandas",
"openpyxl",
+ "tqdm",
"accelerate",
"omegaconf",
- "transformers",
- "tokenizers",
- "peft",
- "loralib",
+]
+
+eval = [
+ "monai",
+ "nibabel",
+ "SimpleITK",
+ "pillow",
+ "scipy",
+ "scikit-learn",
+ "pandas",
+ "openpyxl",
+ "tqdm",
"matplotlib",
"seaborn",
- "scikit-learn",
"umap-learn",
]
+gds-cuda12 = [
+ "cupy-cuda12x",
+ "kvikio",
+]
+
+all = ["spectre-fm[training,eval,gds-cuda12]"]
+
[project.urls]
Homepage = "https://github.com/cclaess/SPECTRE"
Source = "https://github.com/cclaess/SPECTRE"
diff --git a/src/spectre/configs/__init__.py b/src/spectre/configs/__init__.py
index 00242cd..370138d 100644
--- a/src/spectre/configs/__init__.py
+++ b/src/spectre/configs/__init__.py
@@ -1,26 +1,57 @@
+from __future__ import annotations
+
+import warnings
+from typing import Any
from pathlib import Path
-from omegaconf import OmegaConf
+_OMEGACONF_IMPORT_ERROR = None
+try:
+ from omegaconf import OmegaConf, DictConfig
+except ImportError as e:
+ OmegaConf, DictConfig = None, Any # type: ignore
+ _OMEGACONF_IMPORT_ERROR = e
-def load_config(config_name: str) -> OmegaConf:
+def load_config(config_name: str) -> "DictConfig":
"""
Load config file from path.
"""
+ if _OMEGACONF_IMPORT_ERROR is not None:
+ raise ImportError(
+ "OmegaConf is required to load config files but not installed. "
+ "Please install OmegaConf to use this feature."
+ ) from _OMEGACONF_IMPORT_ERROR
+
config_filename = config_name + ".yaml"
config_path = Path(__file__).parent.resolve() / config_filename
return OmegaConf.load(config_path)
-default_config_dino = load_config("dino_default")
-default_config_dinov2 = load_config("dinov2_default")
-default_config_mae = load_config("mae_default")
-default_config_siglip = load_config("siglip_default")
+if OmegaConf is not None:
+ default_config_dino = load_config("dino_default")
+ default_config_dinov2 = load_config("dinov2_default")
+ default_config_mae = load_config("mae_default")
+ default_config_siglip = load_config("siglip_default")
+else:
+ default_config_dino = None
+ default_config_dinov2 = None
+ default_config_mae = None
+ default_config_siglip = None
+ warnings.warn(
+ "OmegaConf is not installed. Default configs will not be available and are set to `None`. "
+ "Please install OmegaConf to use default configs."
+ )
-def load_and_merge_config(config_name: str, default_config: OmegaConf) -> OmegaConf:
+def load_and_merge_config(config_name: str, default_config: "DictConfig") -> "DictConfig":
"""
Load and merge config file from path.
"""
+ if _OMEGACONF_IMPORT_ERROR is not None:
+ raise ImportError(
+ "OmegaConf is required to load config files but not installed. "
+ "Please install OmegaConf to use this feature."
+ ) from _OMEGACONF_IMPORT_ERROR
+
config = load_config(config_name)
- return OmegaConf.merge(default_config, config)
\ No newline at end of file
+ return OmegaConf.merge(default_config, config)
diff --git a/src/spectre/data/_base_datasets.py b/src/spectre/data/_base_datasets.py
index 2286c41..0911bd3 100644
--- a/src/spectre/data/_base_datasets.py
+++ b/src/spectre/data/_base_datasets.py
@@ -4,30 +4,85 @@
import shutil
import tempfile
from typing import Any
-from copy import deepcopy
from pathlib import Path
+from copy import deepcopy
import torch
import numpy as np
-import monai.data as data
-from monai.utils import look_up_option, convert_to_tensor
+
+SUPPORTED_PICKLE_MOD = {"pickle": pickle}
+_MONAI_IMPORT_ERROR = None
+_CUPY_IMPORT_ERROR = None
+_KVIKIO_IMPORT_ERROR = None
+
+try:
+ import monai
+except ImportError as e:
+ monai = None # type: ignore
+ _MONAI_IMPORT_ERROR = e
try:
import cupy as cp
+except ImportError as e:
+ cp = None # type: ignore
+ _CUPY_IMPORT_ERROR = e
+
+try:
import kvikio.numpy as kvikio_numpy
-except ImportError:
- cp = None
- kvikio_numpy = None
+except ImportError as e:
+ kvikio_numpy = None # type: ignore
+ _KVIKIO_IMPORT_ERROR = e
-SUPPORTED_PICKLE_MOD = {"pickle": pickle}
+def _require_monai():
+ if monai is None:
+ raise ImportError(
+ "MONAI is required for this functionality. "
+ "Please install it with `pip install monai`."
+ ) from _MONAI_IMPORT_ERROR
+
+
+def _require_gds_dependencies():
+ _require_monai()
+
+ if cp is None:
+ raise ImportError(
+ "cupy is required for this functionality. "
+ "Please install it with the appropriate CUDA build."
+ ) from _CUPY_IMPORT_ERROR
+
+ if kvikio_numpy is None:
+ raise ImportError(
+ "kvikio is required for this functionality. "
+ "Please install it with `pip install kvikio`."
+ ) from _KVIKIO_IMPORT_ERROR
+
+
+if monai is not None:
+ _BaseDataset = monai.data.Dataset
+ _BasePersistentDataset = monai.data.PersistentDataset
+ _BaseGDSDataset = monai.data.GDSDataset
+else:
+ _BaseDataset = object
+ _BasePersistentDataset = object
+ _BaseGDSDataset = object
+
+
+class Dataset(_BaseDataset):
+ """
+ Base dataset class for SPECTRE datasets.
+ """
+ def __init__(self, *args, **kwargs):
+ _require_monai()
+ super().__init__(*args, **kwargs)
-class PersistentDataset(data.PersistentDataset):
+class PersistentDataset(_BasePersistentDataset):
"""
Overwrite MONAI's PersistentDataset to support PyTorch 2.6.
"""
def __init__(self, *args, pickle_protocol=pickle.HIGHEST_PROTOCOL, **kwargs):
+ _require_monai()
super().__init__(*args, pickle_protocol=pickle_protocol, **kwargs)
def _cachecheck(self, item_transformed):
@@ -91,7 +146,7 @@ def _cachecheck(self, item_transformed):
torch.save(
obj=_item_transformed,
f=temp_hash_file,
- pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
+ pickle_module=monai.utils.look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
pickle_protocol=self.pickle_protocol,
)
if temp_hash_file.is_file() and not hashfile.is_file():
@@ -106,12 +161,13 @@ def _cachecheck(self, item_transformed):
return _item_transformed
-class GDSDataset(data.GDSDataset):
+class GDSDataset(_BaseGDSDataset):
"""
Overwrite MONAI's GDSDataset to support PyTorch 2.6 and combined GPU/CPU data (image/text pairs)
without breaking the GDS fast path.
"""
def __init__(self, *args, pickle_protocol=pickle.HIGHEST_PROTOCOL, **kwargs):
+ _require_gds_dependencies()
super().__init__(*args, pickle_protocol=pickle_protocol, **kwargs)
def _cachecheck(self, item_transformed):
@@ -148,7 +204,7 @@ def _cachecheck(self, item_transformed):
except FileNotFoundError:
continue # non-tensor key handled by sidecar
item[k] = kvikio_numpy.fromfile(f"{hashfile}-{k}", dtype=meta_k["dtype"], like=cp.empty(()))
- item[k] = convert_to_tensor(item[k].reshape(meta_k["shape"]), device=f"cuda:{self.device}")
+ item[k] = monai.utils.convert_to_tensor(item[k].reshape(meta_k["shape"]), device=f"cuda:{self.device}")
item[f"{k}_meta_dict"] = meta_k
sidecar_path = f"{hashfile}-aux"
@@ -162,7 +218,7 @@ def _cachecheck(self, item_transformed):
elif isinstance(item_transformed, (np.ndarray, torch.Tensor)):
_meta = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-meta")
_data = kvikio_numpy.fromfile(f"{hashfile}", dtype=_meta["dtype"], like=cp.empty(()))
- _data = convert_to_tensor(_data.reshape(_meta["shape"]), device=f"cuda:{self.device}")
+ _data = monai.utils.convert_to_tensor(_data.reshape(_meta["shape"]), device=f"cuda:{self.device}")
filtered_keys = list(filter(lambda key: key not in ["dtype", "shape"], _meta.keys()))
if bool(filtered_keys):
return (_data, _meta)
@@ -175,7 +231,7 @@ def _cachecheck(self, item_transformed):
item_k = kvikio_numpy.fromfile(
f"{hashfile}-{k}-{i}", dtype=meta_i_k["dtype"], like=cp.empty(())
)
- item_k = convert_to_tensor(item[i].reshape(meta_i_k["shape"]), device=f"cuda:{self.device}")
+ item_k = monai.utils.convert_to_tensor(item[i].reshape(meta_i_k["shape"]), device=f"cuda:{self.device}")
item[i].update({k: item_k, f"{k}_meta_dict": meta_i_k})
return item
@@ -221,7 +277,7 @@ def _create_sidecar_cache(self, aux_dict, sidecar_path):
torch.save(
obj=aux_dict,
f=temp_hash_file,
- pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
+ pickle_module=monai.utils.look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
pickle_protocol=self.pickle_protocol,
)
if temp_hash_file.is_file() and not sidecar_hashfile.is_file():
diff --git a/src/spectre/data/abdomen_atlas.py b/src/spectre/data/abdomen_atlas.py
index 1a06286..d61bc16 100644
--- a/src/spectre/data/abdomen_atlas.py
+++ b/src/spectre/data/abdomen_atlas.py
@@ -2,9 +2,11 @@
from pathlib import Path
from typing import Callable, Dict, List
-from monai.data import Dataset
-
-from spectre.data._base_datasets import PersistentDataset, GDSDataset
+from spectre.data._base_datasets import (
+ Dataset,
+ PersistentDataset,
+ GDSDataset,
+)
def _initialize_dataset(
diff --git a/src/spectre/data/abdomenct_1k.py b/src/spectre/data/abdomenct_1k.py
index 6070714..cd3789e 100644
--- a/src/spectre/data/abdomenct_1k.py
+++ b/src/spectre/data/abdomenct_1k.py
@@ -2,9 +2,11 @@
from pathlib import Path
from typing import Callable, Dict, List
-from monai.data import Dataset
-
-from spectre.data._base_datasets import PersistentDataset, GDSDataset
+from spectre.data._base_datasets import (
+ Dataset,
+ PersistentDataset,
+ GDSDataset,
+)
def _initialize_dataset(
diff --git a/src/spectre/data/amos.py b/src/spectre/data/amos.py
index ee53d5b..7a9e3a2 100644
--- a/src/spectre/data/amos.py
+++ b/src/spectre/data/amos.py
@@ -2,9 +2,11 @@
from pathlib import Path
from typing import Callable, Dict, List
-from monai.data import Dataset
-
-from spectre.data._base_datasets import PersistentDataset, GDSDataset
+from spectre.data._base_datasets import (
+ Dataset,
+ PersistentDataset,
+ GDSDataset,
+)
def _initialize_dataset(
diff --git a/src/spectre/data/ct_rate.py b/src/spectre/data/ct_rate.py
index 7b2cfbc..13540bb 100644
--- a/src/spectre/data/ct_rate.py
+++ b/src/spectre/data/ct_rate.py
@@ -2,9 +2,19 @@
from pathlib import Path
from typing import Callable, Dict, List
-from monai.data import Dataset
+from spectre.data._base_datasets import (
+ Dataset,
+ PersistentDataset,
+ GDSDataset,
+)
+
+_PANDAS_IMPORT_ERROR = None
+try:
+ import pandas as pd
+except ImportError as e:
+ pd = None # type: ignore
+ _PANDAS_IMPORT_ERROR = e
-from spectre.data._base_datasets import PersistentDataset, GDSDataset
def _initialize_dataset(
data_dir: str,
@@ -20,7 +30,11 @@ def _initialize_dataset(
image_paths = image_paths[:n_keep]
if include_reports:
- import pandas as pd
+ if _PANDAS_IMPORT_ERROR is not None:
+ raise ImportError(
+ "Pandas is required to include reports in the dataset but not installed. "
+ "Please install Pandas to use this feature."
+ ) from _PANDAS_IMPORT_ERROR
text_path = os.path.join(Path(data_dir), 'dataset', "radiology_text_reports", f"{subset}_reports.xlsx" )
reports = pd.read_excel(text_path)
if subset == "train":
diff --git a/src/spectre/data/inspect.py b/src/spectre/data/inspect.py
index 8b7b666..7cfaa64 100644
--- a/src/spectre/data/inspect.py
+++ b/src/spectre/data/inspect.py
@@ -2,9 +2,18 @@
from pathlib import Path
from typing import Callable, List, Dict
-from monai.data import Dataset
+from spectre.data._base_datasets import (
+ Dataset,
+ PersistentDataset,
+ GDSDataset,
+)
-from spectre.data._base_datasets import PersistentDataset, GDSDataset
+_PANDAS_IMPORT_ERROR = None
+try:
+ import pandas as pd
+except ImportError as e:
+ pd = None # type: ignore
+ _PANDAS_IMPORT_ERROR = e
def parse_name(image_path):
@@ -24,7 +33,11 @@ def _initialize_dataset(
image_paths = image_paths[:n_keep]
if include_reports:
- import pandas as pd
+ if _PANDAS_IMPORT_ERROR is not None:
+ raise ImportError(
+ "Pandas is required to include reports in the dataset but not installed. "
+ "Please install Pandas to use this feature."
+ ) from _PANDAS_IMPORT_ERROR
text_path = os.path.join(Path(data_dir), "inspect2", "Final_Impressions.xlsx")
reports = pd.read_excel(text_path)
diff --git a/src/spectre/data/merlin.py b/src/spectre/data/merlin.py
index 4f5be87..bf3fa74 100644
--- a/src/spectre/data/merlin.py
+++ b/src/spectre/data/merlin.py
@@ -2,10 +2,19 @@
from pathlib import Path
from typing import Callable, List, Dict
-import pandas as pd
-from monai.data import Dataset
+from spectre.data._base_datasets import (
+ Dataset,
+ PersistentDataset,
+ GDSDataset,
+)
+
+_PANDAS_IMPORT_ERROR = None
+try:
+ import pandas as pd
+except ImportError as e:
+ pd = None # type: ignore
+ _PANDAS_IMPORT_ERROR = e
-from spectre.data._base_datasets import PersistentDataset, GDSDataset
def parse_name(image_path):
return image_path.name.replace(".nii.gz", "")
@@ -17,6 +26,11 @@ def _initialize_dataset(
subset: str = "train",
fraction: float = 1.0,
) -> List[Dict[str, str]]:
+ if _PANDAS_IMPORT_ERROR is not None:
+ raise ImportError(
+ "Pandas is required to initialize the dataset but not installed. "
+ "Please install Pandas to use this dataset."
+ ) from _PANDAS_IMPORT_ERROR
image_paths = sorted(Path(data_dir).glob(os.path.join(
"merlinabdominalctdataset", "merlin_data", "*.nii.gz")))
diff --git a/src/spectre/data/nlst.py b/src/spectre/data/nlst.py
index 74a5010..917ad02 100644
--- a/src/spectre/data/nlst.py
+++ b/src/spectre/data/nlst.py
@@ -2,9 +2,11 @@
from pathlib import Path
from typing import Callable, Dict, List
-from monai.data import Dataset
-
-from spectre.data._base_datasets import PersistentDataset, GDSDataset
+from spectre.data._base_datasets import (
+ Dataset,
+ PersistentDataset,
+ GDSDataset,
+)
def _initialize_dataset(
diff --git a/src/spectre/data/panorama.py b/src/spectre/data/panorama.py
index a1ed0b2..fe568ff 100644
--- a/src/spectre/data/panorama.py
+++ b/src/spectre/data/panorama.py
@@ -1,9 +1,11 @@
from pathlib import Path
from typing import Callable, List, Dict
-from monai.data import Dataset
-
-from spectre.data._base_datasets import PersistentDataset, GDSDataset
+from spectre.data._base_datasets import (
+ Dataset,
+ PersistentDataset,
+ GDSDataset,
+)
def _initialize_dataset(
diff --git a/src/spectre/data/sinoct.py b/src/spectre/data/sinoct.py
index 0a6800d..bfd7bc5 100644
--- a/src/spectre/data/sinoct.py
+++ b/src/spectre/data/sinoct.py
@@ -2,10 +2,18 @@
from pathlib import Path
from typing import Callable, List, Dict, Union
-import pandas as pd
-from monai.data import Dataset
+from spectre.data._base_datasets import (
+ Dataset,
+ PersistentDataset,
+ GDSDataset,
+)
-from spectre.data._base_datasets import PersistentDataset, GDSDataset
+_PANDAS_IMPORT_ERROR = None
+try:
+ import pandas as pd
+except ImportError as e:
+ pd = None # type: ignore
+ _PANDAS_IMPORT_ERROR = e
def _initialize_dataset(
@@ -14,7 +22,12 @@ def _initialize_dataset(
split_ratio: tuple = (0.8, 0.1, 0.1), # train, val, test
seed: int = 0,
) -> List[Dict[str, Union[str, int]]]:
-
+ if _PANDAS_IMPORT_ERROR is not None:
+ raise ImportError(
+ "Pandas is required to initialize the dataset but not installed. "
+ "Please install Pandas to use this dataset."
+ ) from _PANDAS_IMPORT_ERROR
+
labels_df = pd.read_csv(Path(data_dir) / "labels.csv", index_col="patient_id")
labels_df["abnormal"] = labels_df["label"].apply(lambda x: int(x.split(",")[1]))
labels_df = labels_df[["abnormal"]]
diff --git a/src/spectre/data/total_segmentator.py b/src/spectre/data/total_segmentator.py
index 6a3d274..3134582 100644
--- a/src/spectre/data/total_segmentator.py
+++ b/src/spectre/data/total_segmentator.py
@@ -2,10 +2,18 @@
from pathlib import Path
from typing import Callable, List, Union, Dict
-import pandas as pd
-from monai.data import Dataset
+from spectre.data._base_datasets import (
+ Dataset,
+ PersistentDataset,
+ GDSDataset,
+)
-from spectre.data._base_datasets import PersistentDataset, GDSDataset
+_PANDAS_IMPORT_ERROR = None
+try:
+ import pandas as pd
+except ImportError as e:
+ pd = None # type: ignore
+ _PANDAS_IMPORT_ERROR = e
LABEL_GROUPS = {
@@ -68,6 +76,11 @@ def _initialize_dataset(
],
subset: str = "train",
) -> List[Dict[str, str]]:
+ if _PANDAS_IMPORT_ERROR is not None:
+ raise ImportError(
+ "Pandas is required to initialize the dataset but not installed. "
+ "Please install Pandas to use this dataset."
+ ) from _PANDAS_IMPORT_ERROR
image_paths = Path(data_dir).glob(os.path.join("*", "ct.nii.gz"))
diff --git a/src/spectre/losses/mask_classification_loss.py b/src/spectre/losses/mask_classification_loss.py
index 07ea8d6..a56d771 100644
--- a/src/spectre/losses/mask_classification_loss.py
+++ b/src/spectre/losses/mask_classification_loss.py
@@ -9,7 +9,13 @@
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np
-from scipy.optimize import linear_sum_assignment
+
+_SCIPY_IMPORT_ERROR = None
+try:
+ from scipy.optimize import linear_sum_assignment
+except ImportError as e:
+ linear_sum_assignment = None # type: ignore
+ _SCIPY_IMPORT_ERROR = e
class MaskClassificationLoss(nn.Module):
@@ -56,6 +62,11 @@ def __init__(
(e.g. to downweight false-positive penalties). This is stored in
`self.empty_weight` and passed to `CrossEntropyLoss`.
"""
+ if _SCIPY_IMPORT_ERROR is not None:
+ raise ImportError(
+ "Scipy is required to use MaskClassificationLoss but not installed. "
+ "Please install Scipy to use this loss."
+ ) from _SCIPY_IMPORT_ERROR
super().__init__()
self.num_labels = num_labels
self.num_points = num_points
diff --git a/src/spectre/ssl/transforms/dino_transform.py b/src/spectre/ssl/transforms/dino_transform.py
index cd92bd8..b7ef18f 100644
--- a/src/spectre/ssl/transforms/dino_transform.py
+++ b/src/spectre/ssl/transforms/dino_transform.py
@@ -2,35 +2,31 @@
from typing import Tuple, Mapping, Hashable, Any, List
import torch
-from monai.config import KeysCollection
-from monai.transforms import (
- Compose,
- LoadImaged,
- EnsureChannelFirstd,
- ScaleIntensityRanged,
- Orientationd,
- Spacingd,
- CenterSpatialCropd,
- SpatialPadd,
- EnsureTyped,
- RandSpatialCropSamplesd,
- SelectItemsd,
- RandSpatialCropSamples,
- RandFlip,
- OneOf,
- RandGaussianSharpen,
- RandGaussianSmooth,
- RandGaussianNoise,
- RandAdjustContrast,
- Resize,
- MapTransform,
- Randomizable,
- LazyTransform,
-)
+
+MONAI_IMPORT_ERROR = None
+try:
+ import monai.transforms as transforms
+ from monai.config import KeysCollection
+except ImportError as e:
+ transforms = None # type: ignore
+ KeysCollection = Any # type: ignore
+ MONAI_IMPORT_ERROR = e
from spectre.transforms import RandScaleIntensityRange
+if transforms is not None:
+    Compose = transforms.Compose
+    # A tuple is not a valid base class, so combine the MONAI bases into one class.
+    class _BaseClass(
+        transforms.Randomizable, transforms.MapTransform, transforms.LazyTransform
+    ):
+        pass
+else:
+    Compose = object  # type: ignore
+    _BaseClass = object  # type: ignore
+
+
class DINOTransform(Compose):
def __init__(
self,
@@ -42,6 +38,12 @@ def __init__(
dtype: str = "float32",
use_gds: bool = False,
):
+ if MONAI_IMPORT_ERROR is not None:
+ raise ImportError(
+ "MONAI is required to use DINOTransform but not installed. "
+ "Please install MONAI to use this transform."
+ ) from MONAI_IMPORT_ERROR
+
assert dtype in ["float16", "float32"], \
"dtype must be either 'float16' or 'float32'"
@@ -51,12 +53,12 @@ def __init__(
)
super().__init__([
- LoadImaged(keys=("image",)),
- EnsureChannelFirstd(
+ transforms.LoadImaged(keys=("image",)),
+ transforms.EnsureChannelFirstd(
keys=("image",),
channel_dim="no_channel"
),
- ScaleIntensityRanged(
+ transforms.ScaleIntensityRanged(
keys=("image",),
a_min=-1000,
a_max=1000,
@@ -64,26 +66,26 @@ def __init__(
b_max=1.0,
clip=True,
),
- Orientationd(keys=("image",), axcodes="RAS"),
- Spacingd(
+ transforms.Orientationd(keys=("image",), axcodes="RAS"),
+ transforms.Spacingd(
keys=("image",),
pixdim=(0.5, 0.5, 1.0), # comply with newest scanners
mode=("bilinear",),
),
- CenterSpatialCropd(
+ transforms.CenterSpatialCropd(
keys=("image",),
roi_size=(512, 512, 384),
),
- SpatialPadd(
+ transforms.SpatialPadd(
keys=("image",),
spatial_size=base_crop_size,
),
- EnsureTyped(
+ transforms.EnsureTyped(
keys=("image",),
dtype=getattr(torch, dtype),
device=device,
),
- RandSpatialCropSamplesd(
+ transforms.RandSpatialCropSamplesd(
keys=("image",),
num_samples=num_base_patches,
roi_size=base_crop_size,
@@ -99,13 +101,13 @@ def __init__(
num_local_views=num_local_views,
dtype=dtype,
),
- SelectItemsd(
+ transforms.SelectItemsd(
keys=("image_global_views", "image_local_views"),
),
])
-class DINORandomCropTransformd(Randomizable, MapTransform, LazyTransform):
+class DINORandomCropTransformd(_BaseClass):
def __init__(
self,
keys: KeysCollection,
@@ -117,14 +119,20 @@ def __init__(
dtype: str = "float32",
lazy: bool = False,
) -> None:
- MapTransform.__init__(self, keys)
- LazyTransform.__init__(self, lazy)
+ if MONAI_IMPORT_ERROR is not None:
+ raise ImportError(
+ "MONAI is required to use DINORandomCropTransformd but not installed. "
+ "Please install MONAI to use this transform."
+ ) from MONAI_IMPORT_ERROR
+
+ transforms.MapTransform.__init__(self, keys)
+ transforms.LazyTransform.__init__(self, lazy)
self.global_views_size = global_views_size
self.local_views_size = local_views_size
self.local_views_scale = local_views_scale
self.num_local_views = num_local_views
- self.cropper_global = RandSpatialCropSamples(
+ self.cropper_global = transforms.RandSpatialCropSamples(
roi_size=tuple(int(local_views_scale[1] * sz) for sz in base_crop_size),
num_samples=2,
max_roi_size=base_crop_size,
@@ -132,7 +140,7 @@ def __init__(
random_size=True,
lazy=lazy,
)
- self.cropper_local = RandSpatialCropSamples(
+ self.cropper_local = transforms.RandSpatialCropSamples(
roi_size=tuple(int(self.local_views_scale[0] * sz) for sz in base_crop_size),
num_samples=num_local_views,
max_roi_size=tuple(int(self.local_views_scale[1] * sz) for sz in base_crop_size),
@@ -141,14 +149,14 @@ def __init__(
lazy=lazy,
)
- self.resize_global = Resize(
+ self.resize_global = transforms.Resize(
spatial_size=global_views_size,
mode="trilinear",
dtype=getattr(torch, dtype), # worst case 0.1-0.3% error for fp16
anti_aliasing=True, # downsample ratios up to 2
lazy=lazy,
)
- self.resize_local = Resize(
+ self.resize_local = transforms.Resize(
spatial_size=local_views_size,
mode="trilinear",
dtype=getattr(torch, dtype), # worst case 0.1-0.3% error for fp16
@@ -156,23 +164,23 @@ def __init__(
lazy=lazy,
)
- self.augmentor = Compose([
- RandFlip(spatial_axis=0, prob=0.5),
- RandFlip(spatial_axis=1, prob=0.5),
- RandFlip(spatial_axis=2, prob=0.5),
- OneOf([
- RandGaussianSharpen(
+ self.augmentor = transforms.Compose([
+ transforms.RandFlip(spatial_axis=0, prob=0.5),
+ transforms.RandFlip(spatial_axis=1, prob=0.5),
+ transforms.RandFlip(spatial_axis=2, prob=0.5),
+ transforms.OneOf([
+ transforms.RandGaussianSharpen(
sigma1_x=(1.5, 2.5), sigma1_y=(1.5, 2.5), sigma1_z=(0.75, 1.25),
sigma2_x=(0.5, 1.0), sigma2_y=(0.5, 1.0), sigma2_z=(0.25, 0.5),
prob=0.25,
),
- RandGaussianSmooth(
+ transforms.RandGaussianSmooth(
sigma_x=(1.5, 2.5), sigma_y=(1.5, 2.5), sigma_z=(0.75, 1.25),
prob=0.25,
),
]),
- RandAdjustContrast(gamma=(0.9, 1.1), prob=0.25),
- RandGaussianNoise(std=0.1, sample_std=True, prob=0.25),
+ transforms.RandAdjustContrast(gamma=(0.9, 1.1), prob=0.25),
+ transforms.RandGaussianNoise(std=0.1, sample_std=True, prob=0.25),
RandScaleIntensityRange(
a_min=(0.0, 0.4), # [0.0 * 2000 - 1000, 0.4 * 2000 - 1000] = [-1000, -200]
a_max=(0.6, 1.0), # [0.6 * 2000 - 1000, 1.0 * 2000 - 1000] = [200, 1000]
diff --git a/src/spectre/ssl/transforms/mae_transform.py b/src/spectre/ssl/transforms/mae_transform.py
index 759b20d..c34f3b0 100644
--- a/src/spectre/ssl/transforms/mae_transform.py
+++ b/src/spectre/ssl/transforms/mae_transform.py
@@ -1,20 +1,19 @@
from typing import Tuple
import torch
-from monai.transforms import (
- Compose,
- LoadImaged,
- EnsureChannelFirstd,
- ScaleIntensityRanged,
- Orientationd,
- Spacingd,
- SpatialPadd,
- CastToTyped,
- ResizeWithPadOrCropd,
- RandSpatialCropSamplesd,
- RandSpatialCropd,
- Resized,
-)
+
+MONAI_IMPORT_ERROR = None
+try:
+ import monai.transforms as transforms
+except ImportError as e:
+ transforms = None # type: ignore
+ MONAI_IMPORT_ERROR = e
+
+
+if transforms is not None:
+ Compose = transforms.Compose
+else:
+ Compose = object # type: ignore
class MAETransform(Compose):
@@ -23,12 +22,18 @@ def __init__(
input_size: Tuple[int, int, int] = (128, 128, 64),
dtype: str = "float32",
):
+ if MONAI_IMPORT_ERROR is not None:
+ raise ImportError(
+ "MONAI is required to use MAETransform but not installed. "
+ "Please install MONAI to use this transform."
+ ) from MONAI_IMPORT_ERROR
+
assert dtype in ["float16", "float32"], "dtype must be either 'float16' or 'float32'"
super().__init__(
[
- LoadImaged(keys=("image",)),
- EnsureChannelFirstd(keys=("image",), channel_dim="no_channel"),
- ScaleIntensityRanged(
+ transforms.LoadImaged(keys=("image",)),
+ transforms.EnsureChannelFirstd(keys=("image",), channel_dim="no_channel"),
+ transforms.ScaleIntensityRanged(
keys=("image",),
a_min=-1000,
a_max=1000,
@@ -36,12 +41,12 @@ def __init__(
b_max=1.0,
clip=True
),
- Orientationd(keys=("image",), axcodes="RAS"),
- Spacingd(keys=("image",), pixdim=(0.75, 0.75, 1.5), mode=("bilinear",)),
- ResizeWithPadOrCropd(keys=("image",), spatial_size=(384, 384, -1)),
- SpatialPadd(keys=("image",), spatial_size=(-1, -1, input_size[2])),
- CastToTyped(keys=("image",), dtype=getattr(torch, dtype)),
- RandSpatialCropSamplesd(
+ transforms.Orientationd(keys=("image",), axcodes="RAS"),
+ transforms.Spacingd(keys=("image",), pixdim=(0.75, 0.75, 1.5), mode=("bilinear",)),
+ transforms.ResizeWithPadOrCropd(keys=("image",), spatial_size=(384, 384, -1)),
+ transforms.SpatialPadd(keys=("image",), spatial_size=(-1, -1, input_size[2])),
+ transforms.CastToTyped(keys=("image",), dtype=getattr(torch, dtype)),
+ transforms.RandSpatialCropSamplesd(
keys=("image",),
roi_size=input_size,
num_samples=36,
@@ -49,36 +54,13 @@ def __init__(
random_size=False,
),
# Do a random resized crop
- RandSpatialCropd(
+ transforms.RandSpatialCropd(
keys=("image",),
roi_size=tuple(int(sz * 0.34) for sz in input_size), # 0.34 = (0.2 ** 2) ** (1/3)
max_roi_size=input_size,
random_center=True,
random_size=True,
),
- Resized(keys=("image",), spatial_size=input_size),
+ transforms.Resized(keys=("image",), spatial_size=input_size),
]
)
-
-
-if __name__ == "__main__":
-
- # Save some example data after transforming it.
- import os
- import SimpleITK as sitk
-
- data = {"image": r"data/test_data/train_1_a_1.nii.gz"}
- transform = MAETransform()
- transformed_data = transform(data)
-
- # Save the different crops to a folder for visualization.
- output_dir = r"data/test_data/mae_transform_output"
- os.makedirs(output_dir, exist_ok=True)
-
- for i, patch in enumerate(transformed_data):
-
- # Save the crops
- patch_img = sitk.GetImageFromArray(patch["image"].squeeze(0).numpy())
- patch_img.SetSpacing((1.5, 0.75, 0.75))
- patch_path = os.path.join(output_dir, f"{i}_crop.nii.gz")
- sitk.WriteImage(patch_img, patch_path)
diff --git a/src/spectre/ssl/transforms/siglip_transform.py b/src/spectre/ssl/transforms/siglip_transform.py
index b106e7f..e9703a2 100644
--- a/src/spectre/ssl/transforms/siglip_transform.py
+++ b/src/spectre/ssl/transforms/siglip_transform.py
@@ -1,24 +1,23 @@
from typing import Tuple
import torch
-from monai.transforms import (
- Compose,
- LoadImaged,
- EnsureChannelFirstd,
- ScaleIntensityRanged,
- Orientationd,
- Spacingd,
- ResizeWithPadOrCropd,
- EnsureTyped,
- RandSpatialCropd,
- RandFlipd,
- GridPatchd,
- SelectItemsd,
-)
+
+MONAI_IMPORT_ERROR = None
+try:
+ import monai.transforms as transforms
+except ImportError as e:
+ transforms = None # type: ignore
+ MONAI_IMPORT_ERROR = e
from spectre.transforms import RandomReportTransformd
+if transforms is not None:
+ Compose = transforms.Compose
+else:
+ Compose = object # type: ignore
+
+
class SigLIPTransform(Compose):
def __init__(
self,
@@ -32,6 +31,11 @@ def __init__(
dtype: str = "float32",
use_gds: bool = False,
):
+ if MONAI_IMPORT_ERROR is not None:
+ raise ImportError(
+ "MONAI is required to use SigLIPTransform but not installed. "
+ "Please install MONAI to use this transform."
+ ) from MONAI_IMPORT_ERROR
assert dtype in ["float16", "float32"], \
"dtype must be either 'float16' or 'float32'"
@@ -42,12 +46,12 @@ def __init__(
)
super().__init__([
- LoadImaged(keys=("image",)),
- EnsureChannelFirstd(
+ transforms.LoadImaged(keys=("image",)),
+ transforms.EnsureChannelFirstd(
keys=("image",),
channel_dim="no_channel"
),
- ScaleIntensityRanged(
+ transforms.ScaleIntensityRanged(
keys=("image",),
a_min=-1000,
a_max=1000,
@@ -55,30 +59,30 @@ def __init__(
b_max=1.0,
clip=True
),
- Orientationd(keys=("image",), axcodes="RAS"),
- Spacingd(
+ transforms.Orientationd(keys=("image",), axcodes="RAS"),
+ transforms.Spacingd(
keys=("image",),
pixdim=image_pixdim,
mode=("bilinear",)
),
- ResizeWithPadOrCropd(
+ transforms.ResizeWithPadOrCropd(
keys=("image",),
spatial_size=base_crop_size
),
- EnsureTyped(
+ transforms.EnsureTyped(
keys=("image",),
dtype=getattr(torch, dtype),
device=device
),
- RandSpatialCropd(
+ transforms.RandSpatialCropd(
keys=("image",),
roi_size=image_size,
random_size=False,
),
- RandFlipd(keys=("image",), spatial_axis=0, prob=0.5),
- RandFlipd(keys=("image",), spatial_axis=1, prob=0.5),
- RandFlipd(keys=("image",), spatial_axis=2, prob=0.5),
- GridPatchd(
+ transforms.RandFlipd(keys=("image",), spatial_axis=0, prob=0.5),
+ transforms.RandFlipd(keys=("image",), spatial_axis=1, prob=0.5),
+ transforms.RandFlipd(keys=("image",), spatial_axis=2, prob=0.5),
+ transforms.GridPatchd(
keys=("image",),
patch_size=sliding_window_size,
overlap=0.0,
@@ -89,7 +93,7 @@ def __init__(
keep_original_prob=keep_original_prob,
drop_prob=drop_prob,
),
- SelectItemsd(
+ transforms.SelectItemsd(
keys=("image", "report"),
),
])
diff --git a/src/spectre/transforms/combine_labels.py b/src/spectre/transforms/combine_labels.py
index dc7d675..0071347 100644
--- a/src/spectre/transforms/combine_labels.py
+++ b/src/spectre/transforms/combine_labels.py
@@ -1,10 +1,18 @@
-from typing import Sequence, Optional, Hashable, Mapping
+from typing import Any, Sequence, Optional, Hashable, Mapping
import torch
import numpy as np
-from monai.config import KeysCollection
-from monai.transforms import MapTransform
-from monai.config.type_definitions import NdarrayOrTensor
+
+MONAI_IMPORT_ERROR = None
+try:
+ from monai.config import KeysCollection
+ from monai.transforms import MapTransform
+ from monai.config.type_definitions import NdarrayOrTensor
+except ImportError as e:
+ KeysCollection = Any # type: ignore
+ MapTransform = object # type: ignore
+ NdarrayOrTensor = Any # type: ignore
+ MONAI_IMPORT_ERROR = e
class CombineLabelsd(MapTransform):
@@ -25,6 +33,11 @@ def __init__(
labels: Optional[Sequence[int]] = None,
allow_missing_keys: bool = False,
) -> None:
+ if MONAI_IMPORT_ERROR is not None:
+ raise ImportError(
+ "MONAI is required to use CombineLabelsd but not installed. "
+ "Please install MONAI to use this transform."
+ ) from MONAI_IMPORT_ERROR
if labels is not None and len(keys) != len(labels):
raise ValueError("The number of keys must match the number of labels provided.")
diff --git a/src/spectre/transforms/generate_report.py b/src/spectre/transforms/generate_report.py
index cc77a10..dc821ba 100644
--- a/src/spectre/transforms/generate_report.py
+++ b/src/spectre/transforms/generate_report.py
@@ -1,11 +1,22 @@
from copy import deepcopy
from typing import Any, Mapping, Hashable
-from monai.config import KeysCollection
-from monai.transforms import MapTransform, Randomizable
+MONAI_IMPORT_ERROR = None
+try:
+ from monai.config import KeysCollection
+ from monai.transforms import MapTransform, Randomizable
+except ImportError as e:
+ KeysCollection = Any # type: ignore
+ MONAI_IMPORT_ERROR = e
-class RandomReportTransformd(Randomizable, MapTransform):
+if MONAI_IMPORT_ERROR is None:
+ _BaseClass = (Randomizable, MapTransform)
+else:
+ _BaseClass = object
+
+
+class RandomReportTransformd(_BaseClass):
def __init__(
self,
keys: KeysCollection,
@@ -14,6 +25,12 @@ def __init__(
drop_prob=0.3,
allow_missing_keys: bool = False,
):
+ if MONAI_IMPORT_ERROR is not None:
+ raise ImportError(
+ "MONAI is required to use RandomReportTransformd but not installed. "
+ "Please install MONAI to use this transform."
+ ) from MONAI_IMPORT_ERROR
+
assert all(str(key) in ["findings", "impressions", "icd10"] for key in keys), \
"keys must be one of ['findings', 'impressions', 'icd10']"
@@ -89,90 +106,3 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
ret["report"] = f"{findings}{impressions}{icd10}"
return ret
-
-
-# class GenerateReportTransform(Randomizable, MapTransform):
-# def __init__(
-# self,
-# keys: KeysCollection,
-# max_num_icd10=20,
-# likelihood_original=0.5,
-# drop_chance=0.3,
-# allow_missing_keys: bool = False,
-# ):
-# super().__init__(keys, allow_missing_keys)
-# self.max_num_icd10 = max_num_icd10
-# self.likelihood_original = likelihood_original
-# self.drop_chance = drop_chance
-
-# # Random states (purely indices/flags)
-# self.drop_findings = False
-# self.drop_icd10 = False
-# self.finding_idx = None
-# self.impression_idx = None
-# self.icd10_indices = []
-
-# def randomize(self, data):
-# findings = data.get("findings", [])
-# impressions = data.get("impressions", [])
-# icd10_codes = data.get("icd10", [])
-
-# if isinstance(icd10_codes, str):
-# icd10_codes = icd10_codes.split(";")
-# if not isinstance(icd10_codes, list):
-# icd10_codes = []
-
-# self.drop_findings = self.R.random() < self.drop_chance
-# self.drop_icd10 = self.R.random() < self.drop_chance
-# self.finding_idx = None
-# self.impression_idx = None
-# self.icd10_indices = []
-
-# if not self.drop_findings and findings:
-# num_elements = len(findings)
-# if num_elements == 1:
-# self.finding_idx = 0
-# else:
-# weights = [self.likelihood_original] + [(1 - self.likelihood_original) / (num_elements - 1)] * (num_elements - 1)
-# self.finding_idx = int(self.R.choice(np.arange(num_elements), p=weights))
-
-# if impressions:
-# num_elements = len(impressions)
-# if num_elements == 1:
-# self.impression_idx = 0
-# else:
-# weights = [self.likelihood_original] + [(1 - self.likelihood_original) / (num_elements - 1)] * (num_elements - 1)
-# self.impression_idx = int(self.R.choice(np.arange(num_elements), p=weights))
-
-# if not self.drop_icd10 and icd10_codes:
-# num_codes = min(self.max_num_icd10, len(icd10_codes))
-# self.icd10_indices = self.R.choice(len(icd10_codes), size=num_codes, replace=False).tolist()
-
-# def __call__(self, data):
-# self.randomize(data)
-
-# findings = data.get("findings", [])
-# impressions = data.get("impressions", [])
-# icd10_codes = data.get("icd10", [])
-
-# if isinstance(icd10_codes, str):
-# icd10_codes = icd10_codes.split(";")
-# if not isinstance(icd10_codes, list):
-# icd10_codes = []
-
-# report = ""
-
-# if self.finding_idx is not None and self.finding_idx < len(findings):
-# finding = findings[self.finding_idx].replace("Impressions", "").replace("impressions", "")
-# report += f"Findings: {finding}\n"
-
-# if self.impression_idx is not None and self.impression_idx < len(impressions):
-# impression = impressions[self.impression_idx]
-# report += f"Impressions: {impression}\n"
-
-# if self.icd10_indices:
-# selected_icd10 = [icd10_codes[i] for i in self.icd10_indices if i < len(icd10_codes)]
-# report += f"ICD10: {'; '.join(selected_icd10)}\n"
-
-# data["report"] = report
-# return data
\ No newline at end of file
diff --git a/src/spectre/transforms/largest_multiple_crop.py b/src/spectre/transforms/largest_multiple_crop.py
index 6a12f84..8a72c58 100644
--- a/src/spectre/transforms/largest_multiple_crop.py
+++ b/src/spectre/transforms/largest_multiple_crop.py
@@ -1,8 +1,16 @@
-from typing import Sequence
+from typing import Any, Sequence
import numpy as np
-from monai.config import KeysCollection
-from monai.transforms import Cropd, CenterSpatialCrop
+
+MONAI_IMPORT_ERROR = None
+try:
+ from monai.config import KeysCollection
+ from monai.transforms import Cropd, CenterSpatialCrop
+except ImportError as e:
+ KeysCollection = Any # type: ignore
+ Cropd = object # type: ignore
+ CenterSpatialCrop = object # type: ignore
+ MONAI_IMPORT_ERROR = e
class LargestMultipleCenterCropd(Cropd):
@@ -22,6 +30,12 @@ def __init__(
allow_missing_keys: bool = False,
lazy: bool = False,
) -> None:
+ if MONAI_IMPORT_ERROR is not None:
+ raise ImportError(
+ "MONAI is required to use LargestMultipleCenterCropd but not installed. "
+ "Please install MONAI to use this transform."
+ ) from MONAI_IMPORT_ERROR
+
self.patch_size = patch_size
cropper = CenterSpatialCrop(roi_size=patch_size, lazy=lazy) # Placeholder, will be reset per image
super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy)
diff --git a/src/spectre/transforms/scale_intensity_range.py b/src/spectre/transforms/scale_intensity_range.py
index 85055c4..230bc9b 100644
--- a/src/spectre/transforms/scale_intensity_range.py
+++ b/src/spectre/transforms/scale_intensity_range.py
@@ -1,14 +1,29 @@
-from typing import Optional, Tuple, Union
import warnings
-import numpy as np
+from typing import Any, Optional, Tuple, Union
+
import torch
+import numpy as np
+
+MONAI_IMPORT_ERROR = None
+try:
+ from monai.transforms import Transform, Randomizable
+ from monai.config.type_definitions import DtypeLike, NdarrayOrTensor
+ from monai.utils import convert_to_tensor, convert_data_type
+except ImportError as e:
+ DtypeLike = Any # type: ignore
+ NdarrayOrTensor = Any # type: ignore
+ convert_to_tensor = lambda x: x # type: ignore
+ convert_data_type = lambda x, dtype: (x, None, None) # type: ignore
+ MONAI_IMPORT_ERROR = e
+
-from monai.transforms import Transform, Randomizable
-from monai.config.type_definitions import DtypeLike, NdarrayOrTensor
-from monai.utils import convert_to_tensor, convert_data_type
+if MONAI_IMPORT_ERROR is None:
+ _BaseClass = (Randomizable, Transform)
+else:
+ _BaseClass = object
-class RandScaleIntensityRange(Randomizable, Transform):
+class RandScaleIntensityRange(_BaseClass):
"""
Randomizable variant of ScaleIntensityRange that samples the input window
(a_min, a_max) per-call using MONAI's RNG (self.R).
@@ -35,6 +50,12 @@ def __init__(
dtype: DtypeLike = np.float32,
prob: float = 1.0,
) -> None:
+ if MONAI_IMPORT_ERROR is not None:
+ raise ImportError(
+ "MONAI is required to use RandScaleIntensityRange but not installed. "
+ "Please install MONAI to use this transform."
+ ) from MONAI_IMPORT_ERROR
+
Transform.__init__(self)
Randomizable.__init__(self)
self.a_min = a_min
diff --git a/src/spectre/utils/_utils.py b/src/spectre/utils/_utils.py
index e131222..9b9952b 100644
--- a/src/spectre/utils/_utils.py
+++ b/src/spectre/utils/_utils.py
@@ -3,14 +3,25 @@
from itertools import repeat
import torch
-import monai
import numpy as np
+MONAI_IMPORT_ERROR = None
+try:
+ import monai
+except ImportError as e:
+ monai = None # type: ignore
+ MONAI_IMPORT_ERROR = e
def fix_random_seeds(seed: int = 31):
"""
Fix random seeds.
"""
+ if MONAI_IMPORT_ERROR is not None:
+ raise ImportError(
+ "MONAI is required to use fix_random_seeds but not installed. "
+ "Please install MONAI to use this function."
+ ) from MONAI_IMPORT_ERROR
+
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
diff --git a/src/spectre/utils/collate.py b/src/spectre/utils/collate.py
index 6496721..b4c8c0b 100644
--- a/src/spectre/utils/collate.py
+++ b/src/spectre/utils/collate.py
@@ -1,7 +1,13 @@
from typing import List, Callable, Optional
import torch
-from monai.data import list_data_collate
+
+MONAI_IMPORT_ERROR = None
+try:
+ from monai.data import list_data_collate
+except ImportError as e:
+ list_data_collate = lambda x: x # type: ignore
+ MONAI_IMPORT_ERROR = e
def extended_collate_dino(samples_list: List) -> dict:
@@ -19,6 +25,12 @@ def extended_collate_dino(samples_list: List) -> dict:
Returns:
A dictionary with collated global/local crops and corresponding masks.
"""
+ if MONAI_IMPORT_ERROR is not None:
+ raise ImportError(
+ "MONAI is required to use extended_collate_dino but not installed. "
+ "Please install MONAI to use this collate function."
+ ) from MONAI_IMPORT_ERROR
+
# Apply MONAI's list_data_collate
collated_data = list_data_collate(samples_list)
@@ -50,6 +62,12 @@ def extended_collate_siglip(
Returns:
A dictionary with collated images and tokenized text.
"""
+ if MONAI_IMPORT_ERROR is not None:
+ raise ImportError(
+ "MONAI is required to use extended_collate_siglip but not installed. "
+ "Please install MONAI to use this collate function."
+ ) from MONAI_IMPORT_ERROR
+
collated_data = list_data_collate(samples_list)
if return_filenames:
@@ -84,6 +102,12 @@ def collate_add_filenames(samples_list: List) -> dict:
Returns:
A dictionary with collated images and filenames.
"""
+ if MONAI_IMPORT_ERROR is not None:
+ raise ImportError(
+ "MONAI is required to use collate_add_filenames but not installed. "
+ "Please install MONAI to use this collate function."
+ ) from MONAI_IMPORT_ERROR
+
collated_data = list_data_collate(samples_list)
if "image" in collated_data.keys():
diff --git a/src/spectre/utils/config.py b/src/spectre/utils/config.py
index 5206cdb..6ec8f1b 100644
--- a/src/spectre/utils/config.py
+++ b/src/spectre/utils/config.py
@@ -1,10 +1,15 @@
import os
import math
-from omegaconf import OmegaConf
-
from spectre.utils import _utils, distributed
+OMEGACONF_IMPORT_ERROR = None
+try:
+ from omegaconf import OmegaConf
+except ImportError as e:
+ OmegaConf = None # type: ignore
+ OMEGACONF_IMPORT_ERROR = e
+
def apply_scaling_rules_to_cfg(cfg):
"""
@@ -38,6 +43,12 @@ def apply_scaling_rules_to_cfg(cfg):
def write_config(cfg, output_dir, name="config.yaml"):
+ if OMEGACONF_IMPORT_ERROR is not None:
+ raise ImportError(
+ "OmegaConf is required to use write_config but not installed. "
+ "Please install OmegaConf to use this function."
+ ) from OMEGACONF_IMPORT_ERROR
+
saved_cfg_path = os.path.join(output_dir, name)
with open(saved_cfg_path, "w") as f:
OmegaConf.save(config=cfg, f=f)
@@ -45,6 +56,12 @@ def write_config(cfg, output_dir, name="config.yaml"):
def get_cfg_from_args(args, default_config):
+ if OMEGACONF_IMPORT_ERROR is not None:
+ raise ImportError(
+ "OmegaConf is required to use get_cfg_from_args but not installed. "
+ "Please install OmegaConf to use this function."
+ ) from OMEGACONF_IMPORT_ERROR
+
args.output_dir = os.path.abspath(args.output_dir)
args.opts = [] if args.opts is None else args.opts
args.opts += [f"train.output_dir={args.output_dir}"]
diff --git a/src/spectre/utils/dataloader.py b/src/spectre/utils/dataloader.py
index efdcb6d..fb5a4d4 100644
--- a/src/spectre/utils/dataloader.py
+++ b/src/spectre/utils/dataloader.py
@@ -1,10 +1,18 @@
+from __future__ import annotations
import os
from typing import Union, Callable, Optional, List
import torch
-import monai.data as data
from torch.utils.data import ConcatDataset
+MONAI_IMPORT_ERROR = None
+try:
+ import monai.data as data
+except ImportError as e:
+ data = None # type: ignore
+ MONAI_IMPORT_ERROR = e
+
+
def get_dataloader(
datasets: Union[str, List[str]],
@@ -24,10 +32,15 @@ def get_dataloader(
drop_last: bool = True,
persistent_workers: bool = True,
use_thread: bool = False,
-) -> data.DataLoader:
+) -> "DataLoader":
"""
Get dataloader for training.
"""
+ if MONAI_IMPORT_ERROR is not None:
+ raise ImportError(
+ "MONAI is required to use get_dataloader but not installed. "
+ "Please install MONAI to use this function."
+ ) from MONAI_IMPORT_ERROR
if isinstance(datasets, str):
datasets = [datasets]
diff --git a/src/spectre/utils/distributed.py b/src/spectre/utils/distributed.py
index d41a35a..0dfa898 100644
--- a/src/spectre/utils/distributed.py
+++ b/src/spectre/utils/distributed.py
@@ -1,7 +1,14 @@
import os
import torch.distributed as dist
-from accelerate import Accelerator, DataLoaderConfiguration
+
+ACCELERATE_IMPORT_ERROR = None
+try:
+ from accelerate import Accelerator, DataLoaderConfiguration
+except ImportError as e:
+ Accelerator = None # type: ignore
+ DataLoaderConfiguration = None # type: ignore
+ ACCELERATE_IMPORT_ERROR = e
def is_enabled() -> bool:
@@ -56,6 +63,11 @@ def init_distributed(cfg):
"""
Initialize distributed training.
"""
+ if ACCELERATE_IMPORT_ERROR is not None:
+ raise ImportError(
+ "Accelerate is required to use init_distributed but not installed. "
+ "Please install Accelerate to use this function."
+ ) from ACCELERATE_IMPORT_ERROR
# Initialize accelerator
dataloader_config = DataLoaderConfiguration(
From 61d7caa4d5950e08ad66c16294d57068482fb8cd Mon Sep 17 00:00:00 2001
From: "k.galanis"
Date: Mon, 18 May 2026 15:34:41 +0200
Subject: [PATCH 2/3] fix: use type() to build _BaseClass when MONAI is
installed
---
src/spectre/ssl/transforms/dino_transform.py | 4 ++--
src/spectre/transforms/generate_report.py | 2 +-
src/spectre/transforms/scale_intensity_range.py | 2 +-
3 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/src/spectre/ssl/transforms/dino_transform.py b/src/spectre/ssl/transforms/dino_transform.py
index b7ef18f..abf609e 100644
--- a/src/spectre/ssl/transforms/dino_transform.py
+++ b/src/spectre/ssl/transforms/dino_transform.py
@@ -17,11 +17,11 @@
if transforms is not None:
Compose = transforms.Compose
- _BaseClass = (
+ _BaseClass = type("_BaseClass", (
transforms.Randomizable,
transforms.MapTransform,
transforms.LazyTransform,
- )
+ ), {})
else:
Compose = object # type: ignore
_BaseClass = object # type: ignore
diff --git a/src/spectre/transforms/generate_report.py b/src/spectre/transforms/generate_report.py
index dc821ba..2c591b4 100644
--- a/src/spectre/transforms/generate_report.py
+++ b/src/spectre/transforms/generate_report.py
@@ -11,7 +11,7 @@
if MONAI_IMPORT_ERROR is None:
- _BaseClass = (Randomizable, MapTransform)
+ _BaseClass = type("_BaseClass", (Randomizable, MapTransform), {})
else:
_BaseClass = object
diff --git a/src/spectre/transforms/scale_intensity_range.py b/src/spectre/transforms/scale_intensity_range.py
index 230bc9b..b94f469 100644
--- a/src/spectre/transforms/scale_intensity_range.py
+++ b/src/spectre/transforms/scale_intensity_range.py
@@ -18,7 +18,7 @@
if MONAI_IMPORT_ERROR is None:
- _BaseClass = (Randomizable, Transform)
+ _BaseClass = type("_BaseClass", (Randomizable, Transform), {})
else:
_BaseClass = object
From c87114d43122660bae80b5fbad3f753e0fd8f3f8 Mon Sep 17 00:00:00 2001
From: cclaess
Date: Wed, 20 May 2026 14:20:18 +0200
Subject: [PATCH 3/3] Add Hugging Face model export functionality and update
README with usage examples
---
.gitignore | 11 +-
README.md | 23 +-
hf_export/configuration_spectre.py | 32 ++
hf_export/metadata.yaml | 13 +
hf_export/modeling_spectre.py | 41 +++
scripts/export_hf.py | 307 ++++++++++++++++++
src/spectre/models/vision_transformer.py | 13 +-
.../models/vision_transformer_features.py | 13 +-
8 files changed, 435 insertions(+), 18 deletions(-)
create mode 100644 hf_export/configuration_spectre.py
create mode 100644 hf_export/metadata.yaml
create mode 100644 hf_export/modeling_spectre.py
create mode 100644 scripts/export_hf.py
diff --git a/.gitignore b/.gitignore
index 501e2fe..65087a5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -188,4 +188,13 @@ wandb/
*.pth
# VS Code
-.vscode/
\ No newline at end of file
+.vscode/
+
+# Huggingface
+hf_export/spectre/
+hf_export/imgs/
+hf_export/model.safetensors
+hf_export/config.json
+hf_export/README.md
+hf_export/LICENSE
+hf_export/LICENSE_MODELS
\ No newline at end of file
diff --git a/README.md b/README.md
index ef11f23..ef610b6 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,5 @@
+📢 [2026-05-20] The pretrained SPECTRE model can now be loaded directly through the `transformers` library; no separate SPECTRE package installation is required. See below for details and usage examples.
+
📢 [2026-04-10] SPECTRE is now an official baseline for the [**CVPR 2026 Workshop Competition: Foundation Models for General CT Image Diagnosis**](https://www.codabench.org/competitions/12650/)! See `experiments/cvpr26_fm_for_ct_diag_task_1` for scripts and additional details.
📢 [2026-02-21] SPECTRE has been accepted for presentation at **CVPR 2026** (Denver, Colorado, USA)!
@@ -12,8 +14,8 @@
-
-
+
+
@@ -27,14 +29,25 @@ SPECTRE has been trained on a large cohort of **open-source CT scans** of the **
This repository provides pretrained SPECTRE models together with tools for fine-tuning and evaluation.
## 🧠 Pretrained Models
-The pretrained SPECTRE model can easily be imported as follows:
+The pretrained SPECTRE model can easily be imported using the `transformers` library:
```python
-from spectre import SpectreImageFeatureExtractor, MODEL_CONFIGS
-import torch
+from transformers import AutoModel
+model = AutoModel.from_pretrained('cclaess/SPECTRE-Large', trust_remote_code=True)
+```
+
+or by using the `spectre-fm` package as follows:
+```python
+from spectre import SpectreImageFeatureExtractor, MODEL_CONFIGS
config = MODEL_CONFIGS['spectre-large-pretrained']
model = SpectreImageFeatureExtractor.from_config(config)
+```
+
+A simple forward pass would look like:
+```python
+import torch
+
model.eval()
# Dummy input: (batch, crops, channels, height, width, depth)
diff --git a/hf_export/configuration_spectre.py b/hf_export/configuration_spectre.py
new file mode 100644
index 0000000..60445bc
--- /dev/null
+++ b/hf_export/configuration_spectre.py
@@ -0,0 +1,32 @@
+from transformers import PretrainedConfig
+
+
+class SpectreConfig(PretrainedConfig):
+    """Configuration storing the SPECTRE backbone and feature-combiner settings.
+
+    A falsy ``*_kwargs`` value is stored as ``{}``; any other dict is deep-copied.
+    """
+
+    model_type = "spectre"
+
+    def __init__(
+        self,
+        backbone_name="vit_large_patch16_128",
+        backbone_kwargs={
+            "num_classes": 0, "global_pool": "", "pos_embed": "rope",
+            "rope_kwargs": {"base": 1000.0}, "init_values": 1.0,
+        },
+        feature_combiner_name="feat_vit_large",
+        feature_combiner_kwargs={
+            "num_classes": 0, "global_pool": "", "pos_embed": "rope",
+            "rope_kwargs": {"base": 100.0}, "init_values": 1.0,
+        },
+        **kwargs,
+    ):
+        from copy import deepcopy  # local import: the file's import block stays untouched
+        super().__init__(**kwargs)
+        # Copy so instances never alias the shared signature defaults (or a caller's dict).
+        self.backbone_name = backbone_name
+        self.backbone_kwargs = deepcopy(backbone_kwargs) if backbone_kwargs else {}
+        self.feature_combiner_name = feature_combiner_name
+        self.feature_combiner_kwargs = deepcopy(feature_combiner_kwargs) if feature_combiner_kwargs else {}
diff --git a/hf_export/metadata.yaml b/hf_export/metadata.yaml
new file mode 100644
index 0000000..1d5c6e3
--- /dev/null
+++ b/hf_export/metadata.yaml
@@ -0,0 +1,13 @@
+license: cc-by-nc-sa-4.0
+language:
+- en
+tags:
+- medical-imaging
+- ct-scan
+- 3d
+- vision-transformer
+- self-supervised-learning
+- foundation-model
+- radiology
+library_name: transformers
+pipeline_tag: feature-extraction
\ No newline at end of file
diff --git a/hf_export/modeling_spectre.py b/hf_export/modeling_spectre.py
new file mode 100644
index 0000000..3c37b43
--- /dev/null
+++ b/hf_export/modeling_spectre.py
@@ -0,0 +1,41 @@
+import torch
+from transformers import PreTrainedModel
+from transformers.modeling_outputs import BaseModelOutput
+
+from spectre.model import SpectreImageFeatureExtractor
+try:
+ from .configuration_spectre import SpectreConfig
+except ImportError:
+ from configuration_spectre import SpectreConfig
+
+
+class SpectreModel(PreTrainedModel):
+    """``PreTrainedModel`` wrapper that delegates to ``SpectreImageFeatureExtractor``."""
+    config_class = SpectreConfig
+    base_model_prefix = "spectre"
+    main_input_name = "pixel_values"
+
+    def __init__(self, config):
+        super().__init__(config)
+        # The wrapped extractor is built purely from the four config fields.
+        extractor_kwargs = dict(
+            backbone_name=config.backbone_name,
+            backbone_kwargs=config.backbone_kwargs,
+            feature_combiner_name=config.feature_combiner_name,
+            feature_combiner_kwargs=config.feature_combiner_kwargs,
+        )
+        self.model = SpectreImageFeatureExtractor(**extractor_kwargs)
+        self.post_init()
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        grid_size=None,
+        return_dict=False,
+        **kwargs,
+    ):
+        """Run the extractor; wrap the result only when ``return_dict`` is truthy."""
+        features = self.model(pixel_values, grid_size=grid_size)
+        if return_dict:
+            return BaseModelOutput(last_hidden_state=features)
+        return features
diff --git a/scripts/export_hf.py b/scripts/export_hf.py
new file mode 100644
index 0000000..2fa79a1
--- /dev/null
+++ b/scripts/export_hf.py
@@ -0,0 +1,307 @@
+import os
+import re
+import sys
+import shutil
+import argparse
+from pathlib import Path
+
+ROOT = Path(__file__).resolve().parents[1]
+
+SRC_PACKAGE = ROOT / "src" / "spectre"
+EXPORT_DIR = ROOT / "hf_export"
+DEST_PACKAGE = EXPORT_DIR / "spectre"
+
+# Subset of the source package to copy — training code is intentionally excluded.
+# Map name -> True (copy the whole file/directory)
+# -> list (copy only these filenames from that subdirectory)
+# Any __init__.py is automatically patched to drop imports of excluded items.
+INCLUDE = {
+ "__init__.py": True,
+ "model.py": True,
+ "models": [
+ "layers",
+ "__init__.py",
+ "vision_transformer.py",
+ "vision_transformer_features.py",
+ ],
+ "utils": [
+ "__init__.py",
+ "_utils.py",
+ "modeling.py",
+ ],
+}
+
+
+def _patch_init(init_path, keep_modules):
+ """Rewrite the __init__.py at init_path in place: drop relative imports whose module is not in
+ keep_modules (parenthesised multi-line imports included) and prune their names from __all__."""
+ lines = init_path.read_text(encoding="utf-8").splitlines(keepends=True) # keepends so "".join() below restores the text
+
+ # --- Pass 1: collect the names bound by imports of excluded modules ---
+ _KEYWORDS = {"as", "True", "False", "None"} # identifier-shaped tokens that are not treated as imported names
+ excluded_names = set()
+ i = 0
+ while i < len(lines):
+ line = lines[i]
+ m_direct = re.match(r"^from \. import ([\w]+)\b", line) # "from . import mod" (only the first name is captured)
+ m_from = re.match(r"^from \.([\w]+)\b", line) # "from .mod import ..."
+ module = (m_direct or m_from)
+ if module and module.group(1) not in keep_modules:
+ if m_direct:
+ excluded_names.add(m_direct.group(1))
+ else:
+ after = line.partition("import")[2] # text after the first occurrence of "import" on the line
+ excluded_names.update(n for n in re.findall(r"\b([A-Za-z_]\w*)\b", after) if n not in _KEYWORDS)
+ depth = line.count("(") - line.count(")") # > 0 while a parenthesised import is still open
+ while depth > 0 and i + 1 < len(lines):
+ i += 1
+ excluded_names.update(n for n in re.findall(r"\b([A-Za-z_]\w*)\b", lines[i]) if n not in _KEYWORDS)
+ depth += lines[i].count("(") - lines[i].count(")")
+ i += 1
+
+ # --- Pass 2: drop the excluded import statements (all lines of a parenthesised one) ---
+ result = []
+ skip = False
+ paren_depth = 0
+ for line in lines:
+ if not skip:
+ m = re.match(r"^from \.([\w]+)\b", line) or re.match(r"^from \. import ([\w]+)\b", line)
+ if m and m.group(1) not in keep_modules:
+ skip = True
+ paren_depth = line.count("(") - line.count(")")
+ if paren_depth <= 0:
+ skip = False # single-line import: nothing further to skip
+ continue # the matched import line itself is never kept
+ result.append(line)
+ else:
+ paren_depth += line.count("(") - line.count(")")
+ if paren_depth <= 0:
+ skip = False
+
+ # --- Pass 3: remove excluded names from a multi-line __all__ list ---
+ if excluded_names:
+ patched = []
+ in_all = False
+ bracket_depth = 0
+ for line in result:
+ if not in_all:
+ if re.match(r"^\s*__all__\s*=\s*\[", line):
+ in_all = True
+ bracket_depth = line.count("[") - line.count("]")
+ if bracket_depth <= 0:
+ in_all = False # NOTE(review): a one-line __all__ = [...] is therefore kept unchanged
+ patched.append(line)
+ else:
+ m = re.search(r'["\']([A-Za-z_]\w*)["\']', line) # only the first quoted name on the line is checked
+ if m and m.group(1) in excluded_names:
+ pass # drop this __all__ entry
+ else:
+ patched.append(line)
+ bracket_depth += line.count("[") - line.count("]")
+ if bracket_depth <= 0:
+ in_all = False
+ result = patched
+
+ init_path.write_text("".join(result), encoding="utf-8")
+
+
+def _copy_selective(src, dst, include):
+    """Mirror a chosen subset of ``src`` into ``dst``.
+
+    ``include`` maps an entry name to ``True`` (take the file/directory as is)
+    or to a list of names to take from that subdirectory (handled recursively).
+    Afterwards ``dst/__init__.py``, if present, is patched via ``_patch_init``.
+    """
+    dst.mkdir(parents=True, exist_ok=True)
+    for entry, selection in include.items():
+        source, target = src / entry, dst / entry
+        if selection is not True:
+            # Selective subdirectory: recurse with every listed name enabled.
+            _copy_selective(source, target, dict.fromkeys(selection, True))
+        elif source.is_dir():
+            shutil.copytree(source, target, ignore=shutil.ignore_patterns("__pycache__"))
+        else:
+            shutil.copy2(source, target)
+    package_init = dst / "__init__.py"
+    if package_init.exists():
+        kept = {Path(entry).stem for entry in include if entry != "__init__.py"}
+        _patch_init(package_init, kept)
+
+
+def get_args():
+    """Build the command-line parser of the export script and parse ``sys.argv``."""
+    cli = argparse.ArgumentParser(
+        description="Export SPECTRE model to HuggingFace Hub format"
+    )
+    # (flag, add_argument keyword arguments), registered in this order.
+    options = [
+        ("--release", {
+            "action": "store_true",
+            "help": "Upload model to HuggingFace Hub (default: False for safety)",
+        }),
+        ("--repo-id", {
+            "type": str,
+            "default": "cclaess/SPECTRE-Large",
+            "help": "HuggingFace repo ID (default: cclaess/SPECTRE-Large)",
+        }),
+        ("--hf-token", {
+            "type": str,
+            "default": None,
+            "help": "HuggingFace token (if not provided, uses HF_TOKEN env variable)",
+        }),
+        ("--skip-test", {
+            "action": "store_true",
+            "help": "Skip local testing (default: False)",
+        }),
+        ("--commit-message", {
+            "type": str,
+            "default": "Initial commit",
+            "help": "Commit message for HuggingFace upload (default: 'Initial commit')",
+        }),
+    ]
+    for flag, spec in options:
+        cli.add_argument(flag, **spec)
+    return cli.parse_args()
+
+
+def main(args):
+ # Pipeline: copy package subset -> build and save the HF model -> optional local load test -> optional Hub upload.
+ test_locally = not args.skip_test # the local load test runs unless --skip-test is given
+
+ # Get HF token from argument or environment (the argument wins when both are set)
+ hf_token = args.hf_token or os.getenv("HF_TOKEN")
+
+ if args.release and not hf_token:
+ print("ERROR: --release flag requires HF_TOKEN env variable or --hf-token argument")
+ sys.exit(1)
+
+ print("=" * 60)
+ print("SPECTRE Model Export to HuggingFace")
+ print("=" * 60)
+ print(f"Export directory: {EXPORT_DIR}")
+ print(f"Repo ID: {args.repo_id}")
+ print(f"Test locally: {test_locally}")
+ print(f"Release (upload): {args.release}")
+ print("=" * 60)
+
+ # Clean export: only the previously copied package is removed, other files in EXPORT_DIR stay
+ if DEST_PACKAGE.exists():
+ print("Cleaning existing export directory...")
+ shutil.rmtree(DEST_PACKAGE)
+
+ # Copy selected parts of the spectre package
+ print("Copying spectre package...")
+ _copy_selective(SRC_PACKAGE, DEST_PACKAGE, INCLUDE)
+ print(f"✓ Exported spectre package to {DEST_PACKAGE}")
+
+ # Prepend export dir to sys.path so the imports below can find the files living in EXPORT_DIR
+ sys.path.insert(0, str(EXPORT_DIR))
+
+ from spectre import SpectreImageFeatureExtractor, MODEL_CONFIGS # NOTE(review): presumably the exported copy — confirm no already-imported spectre shadows it
+ from configuration_spectre import SpectreConfig
+ from modeling_spectre import SpectreModel
+
+ SpectreConfig.register_for_auto_class()
+ SpectreModel.register_for_auto_class("AutoModel")
+
+ print("Building model configuration...")
+ config = SpectreConfig(
+ backbone_name="vit_large_patch16_128",
+ backbone_kwargs={
+ "num_classes": 0,
+ "global_pool": "",
+ "pos_embed": "rope",
+ "rope_kwargs": {"base": 1000.0},
+ "init_values": 1.0,
+ },
+ feature_combiner_name="feat_vit_large",
+ feature_combiner_kwargs={
+ "num_classes": 0,
+ "global_pool": "",
+ "pos_embed": "rope",
+ "rope_kwargs": {"base": 100.0},
+ "init_values": 1.0,
+ },
+ )
+
+ print("Loading base model weights...")
+ base = SpectreImageFeatureExtractor.from_config(MODEL_CONFIGS["spectre-large-pretrained"])
+
+ print("Creating HuggingFace model...")
+ hf_model = SpectreModel(config)
+
+ hf_model.model.load_state_dict(base.state_dict()) # copy the base model's weights into the wrapped extractor
+
+ print("Saving model and config...")
+ # save_pretrained copies the modeling file into the save directory; saving
+ # directly to EXPORT_DIR (where the file already lives) raises SameFileError,
+ # so we save to a temp dir and promote the outputs afterward.
+ import tempfile
+ with tempfile.TemporaryDirectory() as _tmp:
+ _tmp_path = Path(_tmp)
+ hf_model.save_pretrained(_tmp_path, safe_serialization=True)
+ config.save_pretrained(_tmp_path)
+ for f in _tmp_path.iterdir():
+ shutil.copy2(f, EXPORT_DIR / f.name) # copy2 handles files only — assumes save_pretrained wrote no subdirectories
+ print(f"✓ Saved model and config to {EXPORT_DIR}")
+
+ # Test loading the model locally
+ if test_locally:
+ print("\nTesting local model loading...")
+ from transformers import AutoModel
+
+ try:
+ model = AutoModel.from_pretrained(EXPORT_DIR, trust_remote_code=True) # result is unused; only a successful load matters
+ print("✓ Model loaded successfully")
+ except Exception as e:
+ print(f"✗ Error loading model: {e}")
+ sys.exit(1)
+
+ # Upload to HuggingFace Hub
+ if args.release:
+ # Sync README from the repo root, prepending HF model card metadata as YAML frontmatter
+ metadata = (EXPORT_DIR / "metadata.yaml").read_text(encoding="utf-8")
+ readme = (ROOT / "README.md").read_text(encoding="utf-8")
+ (EXPORT_DIR / "README.md").write_text(f"---\n{metadata}\n---\n\n{readme}", encoding="utf-8")
+ shutil.copytree(ROOT / "imgs", EXPORT_DIR / "imgs", dirs_exist_ok=True)
+
+ shutil.copy2(ROOT / "LICENSE", EXPORT_DIR / "LICENSE")
+ shutil.copy2(ROOT / "LICENSE_MODELS", EXPORT_DIR / "LICENSE_MODELS")
+ print("✓ Copied README and LICENSE files")
+
+
+ # Strip __pycache__ directories so they are not uploaded
+ for pycache in EXPORT_DIR.rglob("__pycache__"): # NOTE(review): deletes while rglob is still iterating — consider list(...) first
+ shutil.rmtree(pycache)
+ print("✓ Removed __pycache__ directories")
+
+ print("\nUploading to HuggingFace Hub...")
+ from huggingface_hub import HfApi
+
+ try:
+ api = HfApi(token=hf_token)
+ api.create_repo(
+ repo_id=args.repo_id,
+ repo_type="model",
+ exist_ok=True,
+ )
+ api.upload_folder(
+ folder_path=str(EXPORT_DIR),
+ path_in_repo=".",
+ repo_id=args.repo_id,
+ repo_type="model",
+ commit_message=args.commit_message,
+ delete_patterns=["*"], # NOTE(review): presumably removes remote files missing locally — confirm in huggingface_hub docs
+ )
+ print(f"✓ Successfully uploaded to {args.repo_id}")
+ except Exception as e:
+ print(f"✗ Error uploading to HuggingFace: {e}")
+ sys.exit(1)
+ else:
+ print("\nSkipping upload (--release not specified)")
+ print(f"To upload, run: python export_hf.py --release --repo-id {args.repo_id}")
+
+
+if __name__ == "__main__":
+ args = get_args()
+ main(args)
diff --git a/src/spectre/models/vision_transformer.py b/src/spectre/models/vision_transformer.py
index 10fb7dd..adbd631 100644
--- a/src/spectre/models/vision_transformer.py
+++ b/src/spectre/models/vision_transformer.py
@@ -232,7 +232,7 @@ def __init__(
self.patch_drop = nn.Identity()
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ dpr = [drop_path_rate * i / (depth - 1) if depth > 1 else 0.0 for i in range(depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[
block_fn(
dim=embed_dim,
@@ -275,19 +275,20 @@ def __init__(
self.init_weights()
def init_weights(self) -> None:
- if self.pos_embed is not None:
+ if self.pos_embed is not None and not self.pos_embed.is_meta:  # skip init when the parameter is a meta tensor (no real storage to fill)
nn.init.trunc_normal_(self.pos_embed, std=.02)
- if self.cls_token is not None:
+ if self.cls_token is not None and not self.cls_token.is_meta:  # same meta-tensor guard as for pos_embed
nn.init.normal_(self.cls_token, std=1e-6)
- if self.reg_token is not None:
+ if self.reg_token is not None and not self.reg_token.is_meta:  # same meta-tensor guard as for pos_embed
nn.init.normal_(self.reg_token, std=1e-6)
self.apply(self._init_weights)
def _init_weights(self, m: nn.Module) -> None:
# this fn left here for compat with downstream users
if isinstance(m, nn.Linear):
- nn.init.trunc_normal_(m.weight, std=.02)
- if m.bias is not None:
+ if not m.weight.is_meta:  # skip init when the weight is a meta tensor (no real storage to fill)
+ nn.init.trunc_normal_(m.weight, std=.02)
+ if m.bias is not None and not m.bias.is_meta:  # bias may be absent, or present only as a meta tensor
nn.init.zeros_(m.bias)
@torch.jit.ignore
diff --git a/src/spectre/models/vision_transformer_features.py b/src/spectre/models/vision_transformer_features.py
index 1975197..ef50f8b 100644
--- a/src/spectre/models/vision_transformer_features.py
+++ b/src/spectre/models/vision_transformer_features.py
@@ -134,7 +134,7 @@ def __init__(
self.patch_drop = nn.Identity()
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ dpr = [drop_path_rate * i / (depth - 1) if depth > 1 else 0.0 for i in range(depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[
block_fn(
dim=embed_dim,
@@ -177,19 +177,20 @@ def __init__(
self.init_weights()
def init_weights(self) -> None:
- if self.pos_embed is not None:
+ if self.pos_embed is not None and not self.pos_embed.is_meta:  # skip init when the parameter is a meta tensor (no real storage to fill)
nn.init.trunc_normal_(self.pos_embed, std=.02)
- if self.cls_token is not None:
+ if self.cls_token is not None and not self.cls_token.is_meta:  # same meta-tensor guard as for pos_embed
nn.init.normal_(self.cls_token, std=1e-6)
- if self.reg_token is not None:
+ if self.reg_token is not None and not self.reg_token.is_meta:  # same meta-tensor guard as for pos_embed
nn.init.normal_(self.reg_token, std=1e-6)
self.apply(self._init_weights)
def _init_weights(self, m: nn.Module) -> None:
# this fn left here for compat with downstream users
if isinstance(m, nn.Linear):
- nn.init.trunc_normal_(m.weight, std=.02)
- if m.bias is not None:
+ if not m.weight.is_meta:  # skip init when the weight is a meta tensor (no real storage to fill)
+ nn.init.trunc_normal_(m.weight, std=.02)
+ if m.bias is not None and not m.bias.is_meta:  # bias may be absent, or present only as a meta tensor
nn.init.zeros_(m.bias)
@torch.jit.ignore