diff --git a/.gitignore b/.gitignore index 501e2fe..65087a5 100644 --- a/.gitignore +++ b/.gitignore @@ -188,4 +188,13 @@ wandb/ *.pth # VS Code -.vscode/ \ No newline at end of file +.vscode/ + +# Huggingface +hf_export/spectre/ +hf_export/imgs/ +hf_export/model.safetensors +hf_export/config.json +hf_export/README.md +hf_export/LICENSE +hf_export/LICENSE_MODELS \ No newline at end of file diff --git a/README.md b/README.md index a8c2a80..ef610b6 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +📢 [2026-05-20] The pretrained SPECTRE model can now be loaded directly through the `transformers` library, no separate SPECTRE package installation required. Check below for details and usage examples. + 📢 [2026-04-10] SPECTRE is now an official baseline for the [**CVPR 2026 Workshop Competition: Foundation Models for General CT Image Diagnosis**](https://www.codabench.org/competitions/12650/)! See `experiments/cvpr26_fm_for_ct_diag_task_1` for scripts and additional details. 📢 [2026-02-21] SPECTRE has been accepted for presentation at **CVPR 2026** (Denver, Colorado, USA)! @@ -12,8 +14,8 @@ Python Versions Downloads per Month License - Model weights - Paper + Model weights + Preprint

@@ -27,14 +29,25 @@ SPECTRE has been trained on a large cohort of **open-source CT scans** of the ** This repository provides pretrained SPECTRE models together with tools for fine-tuning and evaluation. ## 🧠 Pretrained Models -The pretrained SPECTRE model can easily be imported as follows: +The pretrained SPECTRE model can easily be imported using the `transformers` library ```python -from spectre import SpectreImageFeatureExtractor, MODEL_CONFIGS -import torch +from transformers import AutoModel +model = AutoModel.from_pretrained('cclaess/SPECTRE-Large', trust_remote_code=True) +``` +or by using the `spectre-fm` package as follows: + +```python +from spectre import SpectreImageFeatureExtractor, MODEL_CONFIGS config = MODEL_CONFIGS['spectre-large-pretrained'] model = SpectreImageFeatureExtractor.from_config(config) +``` + +A simple forward pass would look like: +```python +import torch + model.eval() # Dummy input: (batch, crops, channels, height, width, depth) @@ -94,12 +107,40 @@ This repository is organized as follows: ## ⚙️ Setting Up the Environment -To get up and running with SPECTRE, simply install our package using pip: +To get up and running with SPECTRE, install the base package with pip: ```bash pip install spectre-fm ``` +This installs only the runtime dependencies needed to load and run the pretrained models. + +If you want to fine-tune or pretrain SPECTRE, install the matching extra: + +```bash +pip install "spectre-fm[training]" +``` + +If you only need the evaluation stack, install: + +```bash +pip install "spectre-fm[eval]" +``` + +If training on GDS-enabled systems is required, install the CUDA 12 specific extra: + +```bash +pip install "spectre-fm[gds-cuda12]" # with training stack: "spectre-fm[training,gds-cuda12]" +``` + +**Note that** `gds-cuda12` is only compatible with CUDA 12.x environments. + +To install everything at once, use: + +```bash +pip install "spectre-fm[all]" +``` + or install the latest updates directly from GitHub: ```bash diff --git a/hf_export/configuration_spectre.py b/hf_export/configuration_spectre.py new file mode 100644 index 0000000..60445bc --- /dev/null +++ b/hf_export/configuration_spectre.py @@ -0,0 +1,32 @@ +from transformers import PretrainedConfig + + +class SpectreConfig(PretrainedConfig): + model_type = "spectre" + + def __init__( + self, + backbone_name="vit_large_patch16_128", + backbone_kwargs={ + "num_classes": 0, + "global_pool": '', + "pos_embed": "rope", + "rope_kwargs": {"base": 1000.0}, + "init_values": 1.0, + }, + feature_combiner_name="feat_vit_large", + feature_combiner_kwargs={ + "num_classes": 0, + "global_pool": "", + "pos_embed": "rope", + "rope_kwargs": {"base": 100.0}, + "init_values": 1.0, + }, + **kwargs, + ): + super().__init__(**kwargs) + + self.backbone_name = backbone_name + self.backbone_kwargs = backbone_kwargs or {} + self.feature_combiner_name = feature_combiner_name + self.feature_combiner_kwargs = feature_combiner_kwargs or {} diff --git a/hf_export/metadata.yaml b/hf_export/metadata.yaml new file mode 100644 index 0000000..1d5c6e3 --- /dev/null +++ b/hf_export/metadata.yaml @@ -0,0 +1,13 @@ +license: cc-by-nc-sa-4.0 +language: +- en +tags: +- medical-imaging +- ct-scan +- 3d +- vision-transformer +- self-supervised-learning +- foundation-model +- radiology +library_name: transformers +pipeline_tag: feature-extraction \ No newline at end of file diff --git a/hf_export/modeling_spectre.py b/hf_export/modeling_spectre.py new file mode 100644 index 0000000..3c37b43 --- /dev/null +++ b/hf_export/modeling_spectre.py @@ -0,0 +1,41 @@ +import torch +from transformers import PreTrainedModel +from transformers.modeling_outputs import BaseModelOutput + +from spectre.model import SpectreImageFeatureExtractor +try: + from .configuration_spectre import SpectreConfig +except ImportError: + from configuration_spectre import SpectreConfig + + +class SpectreModel(PreTrainedModel): + config_class = SpectreConfig + base_model_prefix = "spectre" + main_input_name = "pixel_values" + + def __init__(self, config): + super().__init__(config) + + self.model = SpectreImageFeatureExtractor( + backbone_name=config.backbone_name, + backbone_kwargs=config.backbone_kwargs, + feature_combiner_name=config.feature_combiner_name, + feature_combiner_kwargs=config.feature_combiner_kwargs, + ) + + self.post_init() + + def forward( + self, + pixel_values: torch.Tensor, + grid_size=None, + return_dict=False, + **kwargs, + ): + outputs = self.model(pixel_values, grid_size=grid_size) + + if not return_dict: + return outputs + + return BaseModelOutput(last_hidden_state=outputs) diff --git a/pyproject.toml b/pyproject.toml index 944c027..8a7266d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,27 +28,49 @@ keywords = ["medical", "ct", "self-supervised", "vision", "deep-learning"] dependencies = [ "torch", - "wandb", "numpy", - "monai[all]", + "timm", + "transformers", + "huggingface_hub", + "loralib", +] + +[project.optional-dependencies] +training = [ + "wandb", + "monai", "nibabel", "SimpleITK", - "timm", "scipy", "pandas", "openpyxl", + "tqdm", "accelerate", "omegaconf", - "transformers", - "tokenizers", - "peft", - "loralib", +] + +eval = [ + "monai", + "nibabel", + "SimpleITK", + "pillow", + "scipy", + "scikit-learn", + "pandas", + "openpyxl", + "tqdm", "matplotlib", "seaborn", - "scikit-learn", "umap-learn", ] +gds-cuda12 = [ + "cupy-cuda12x", + "kvikio", +] + +all = ["spectre-fm[training,eval,gds-cuda12]"] + [project.urls] Homepage = "https://github.com/cclaess/SPECTRE" Source = "https://github.com/cclaess/SPECTRE" diff --git a/scripts/export_hf.py b/scripts/export_hf.py new file mode 100644 index 0000000..2fa79a1 --- /dev/null +++ b/scripts/export_hf.py @@ -0,0 +1,307 @@ +import os +import re +import sys +import shutil +import argparse +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] + +SRC_PACKAGE = ROOT / "src" / "spectre" +EXPORT_DIR = ROOT / "hf_export" +DEST_PACKAGE = EXPORT_DIR / "spectre" + +# Subset of the source package to copy — training code is intentionally excluded. +# Map name -> True (copy the whole file/directory) +# -> list (copy only these filenames from that subdirectory) +# Any __init__.py is automatically patched to drop imports of excluded items. +INCLUDE = { + "__init__.py": True, + "model.py": True, + "models": [ + "layers", + "__init__.py", + "vision_transformer.py", + "vision_transformer_features.py", + ], + "utils": [ + "__init__.py", + "_utils.py", + "modeling.py", + ], +} + + +def _patch_init(init_path, keep_modules): + """Strip import lines for modules not in keep_modules and remove their + exported names from __all__, handling multi-line imports throughout.""" + lines = init_path.read_text(encoding="utf-8").splitlines(keepends=True) + + # --- Pass 1: collect names exported by excluded imports --- + _KEYWORDS = {"as", "True", "False", "None"} + excluded_names = set() + i = 0 + while i < len(lines): + line = lines[i] + m_direct = re.match(r"^from \. import ([\w]+)\b", line) + m_from = re.match(r"^from \.([\w]+)\b", line) + module = (m_direct or m_from) + if module and module.group(1) not in keep_modules: + if m_direct: + excluded_names.add(m_direct.group(1)) + else: + after = line.partition("import")[2] + excluded_names.update(n for n in re.findall(r"\b([A-Za-z_]\w*)\b", after) if n not in _KEYWORDS) + depth = line.count("(") - line.count(")") + while depth > 0 and i + 1 < len(lines): + i += 1 + excluded_names.update(n for n in re.findall(r"\b([A-Za-z_]\w*)\b", lines[i]) if n not in _KEYWORDS) + depth += lines[i].count("(") - lines[i].count(")") + i += 1 + + # --- Pass 2: remove excluded import lines --- + result = [] + skip = False + paren_depth = 0 + for line in lines: + if not skip: + m = re.match(r"^from \.([\w]+)\b", line) or re.match(r"^from \. import ([\w]+)\b", line) + if m and m.group(1) not in keep_modules: + skip = True + paren_depth = line.count("(") - line.count(")") + if paren_depth <= 0: + skip = False + continue + result.append(line) + else: + paren_depth += line.count("(") - line.count(")") + if paren_depth <= 0: + skip = False + + # --- Pass 3: remove excluded names from __all__ --- + if excluded_names: + patched = [] + in_all = False + bracket_depth = 0 + for line in result: + if not in_all: + if re.match(r"^\s*__all__\s*=\s*\[", line): + in_all = True + bracket_depth = line.count("[") - line.count("]") + if bracket_depth <= 0: + in_all = False + patched.append(line) + else: + m = re.search(r'["\']([A-Za-z_]\w*)["\']', line) + if m and m.group(1) in excluded_names: + pass # drop this __all__ entry + else: + patched.append(line) + bracket_depth += line.count("[") - line.count("]") + if bracket_depth <= 0: + in_all = False + result = patched + + init_path.write_text("".join(result), encoding="utf-8") + + +def _copy_selective(src, dst, include): + """Copy src to dst including only the items in include. + + include maps name -> True (copy everything) or list[str] (selective filenames). + The __init__.py in dst is patched to drop imports of any excluded items. + """ + dst.mkdir(parents=True, exist_ok=True) + keep_modules = {Path(name).stem for name in include if name != "__init__.py"} + for name, what in include.items(): + src_path = src / name + dst_path = dst / name + if what is True: + if src_path.is_dir(): + shutil.copytree(src_path, dst_path, ignore=shutil.ignore_patterns("__pycache__")) + else: + shutil.copy2(src_path, dst_path) + else: + _copy_selective(src_path, dst_path, {f: True for f in what}) + init_path = dst / "__init__.py" + if init_path.exists(): + _patch_init(init_path, keep_modules) + + +def get_args(): + parser = argparse.ArgumentParser( + description="Export SPECTRE model to HuggingFace Hub format" + ) + parser.add_argument( + "--release", + action="store_true", + help="Upload model to HuggingFace Hub (default: False for safety)", + ) + parser.add_argument( + "--repo-id", + type=str, + default="cclaess/SPECTRE-Large", + help="HuggingFace repo ID (default: cclaess/SPECTRE-Large)", + ) + parser.add_argument( + "--hf-token", + type=str, + default=None, + help="HuggingFace token (if not provided, uses HF_TOKEN env variable)", + ) + parser.add_argument( + "--skip-test", + action="store_true", + help="Skip local testing (default: False)", + ) + parser.add_argument( + "--commit-message", + type=str, + default="Initial commit", + help="Commit message for HuggingFace upload (default: 'Initial commit')", + ) + + return parser.parse_args() + + +def main(args): + # Determine if we should test locally + test_locally = not args.skip_test + + # Get HF token from argument or environment + hf_token = args.hf_token or os.getenv("HF_TOKEN") + + if args.release and not hf_token: + print("ERROR: --release flag requires HF_TOKEN env variable or --hf-token argument") + sys.exit(1) + + print("=" * 60) + print("SPECTRE Model Export to HuggingFace") + print("=" * 60) + print(f"Export directory: {EXPORT_DIR}") + print(f"Repo ID: {args.repo_id}") + print(f"Test locally: {test_locally}") + print(f"Release (upload): {args.release}") + print("=" * 60) + + # Clean export + if DEST_PACKAGE.exists(): + print("Cleaning existing export directory...") + shutil.rmtree(DEST_PACKAGE) + + # Copy selected parts of the spectre package + print("Copying spectre package...") + _copy_selective(SRC_PACKAGE, DEST_PACKAGE, INCLUDE) + print(f"✓ Exported spectre package to {DEST_PACKAGE}") + + # Prepend export dir to sys.path + sys.path.insert(0, str(EXPORT_DIR)) + + from spectre import SpectreImageFeatureExtractor, MODEL_CONFIGS + from configuration_spectre import SpectreConfig + from modeling_spectre import SpectreModel + + SpectreConfig.register_for_auto_class() + SpectreModel.register_for_auto_class("AutoModel") + + print("Building model configuration...") + config = SpectreConfig( + backbone_name="vit_large_patch16_128", + backbone_kwargs={ + "num_classes": 0, + "global_pool": "", + "pos_embed": "rope", + "rope_kwargs": {"base": 1000.0}, + "init_values": 1.0, + }, + feature_combiner_name="feat_vit_large", + feature_combiner_kwargs={ + "num_classes": 0, + "global_pool": "", + "pos_embed": "rope", + "rope_kwargs": {"base": 100.0}, + "init_values": 1.0, + }, + ) + + print("Loading base model weights...") + base = SpectreImageFeatureExtractor.from_config(MODEL_CONFIGS["spectre-large-pretrained"]) + + print("Creating HuggingFace model...") + hf_model = SpectreModel(config) + + hf_model.model.load_state_dict(base.state_dict()) + + print("Saving model and config...") + # save_pretrained copies the modeling file into the save directory; saving + # directly to EXPORT_DIR (where the file already lives) raises SameFileError, + # so we save to a temp dir and promote the outputs afterward. + import tempfile + with tempfile.TemporaryDirectory() as _tmp: + _tmp_path = Path(_tmp) + hf_model.save_pretrained(_tmp_path, safe_serialization=True) + config.save_pretrained(_tmp_path) + for f in _tmp_path.iterdir(): + shutil.copy2(f, EXPORT_DIR / f.name) + print(f"✓ Saved model and config to {EXPORT_DIR}") + + # Test loading the model locally + if test_locally: + print("\nTesting local model loading...") + from transformers import AutoModel + + try: + model = AutoModel.from_pretrained(EXPORT_DIR, trust_remote_code=True) + print("✓ Model loaded successfully") + except Exception as e: + print(f"✗ Error loading model: {e}") + sys.exit(1) + + # Upload to HuggingFace Hub + if args.release: + # Sync README from the repo root, prepending HF model card metadata as YAML frontmatter + metadata = (EXPORT_DIR / "metadata.yaml").read_text(encoding="utf-8") + readme = (ROOT / "README.md").read_text(encoding="utf-8") + (EXPORT_DIR / "README.md").write_text(f"---\n{metadata}\n---\n\n{readme}", encoding="utf-8") + shutil.copytree(ROOT / "imgs", EXPORT_DIR / "imgs", dirs_exist_ok=True) + + shutil.copy2(ROOT / "LICENSE", EXPORT_DIR / "LICENSE") + shutil.copy2(ROOT / "LICENSE_MODELS", EXPORT_DIR / "LICENSE_MODELS") + print("✓ Copied README and LICENSE files") + + + # Strip __pycache__ directories so they are not uploaded + for pycache in EXPORT_DIR.rglob("__pycache__"): + shutil.rmtree(pycache) + print("✓ Removed __pycache__ directories") + + print("\nUploading to HuggingFace Hub...") + from huggingface_hub import HfApi + + try: + api = HfApi(token=hf_token) + api.create_repo( + repo_id=args.repo_id, + repo_type="model", + exist_ok=True, + ) + api.upload_folder( + folder_path=str(EXPORT_DIR), + path_in_repo=".", + repo_id=args.repo_id, + repo_type="model", + commit_message=args.commit_message, + delete_patterns=["*"], + ) + print(f"✓ Successfully uploaded to {args.repo_id}") + except Exception as e: + print(f"✗ Error uploading to HuggingFace: {e}") + sys.exit(1) + else: + print("\nSkipping upload (--release not specified)") + print(f"To upload, run: python export_hf.py --release --repo-id {args.repo_id}") + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/src/spectre/configs/__init__.py b/src/spectre/configs/__init__.py index 00242cd..370138d 100644 --- a/src/spectre/configs/__init__.py +++ b/src/spectre/configs/__init__.py @@ -1,26 +1,57 @@ +from __future__ import annotations + +import warnings +from typing import Any from pathlib import Path -from omegaconf import OmegaConf +_OMEGACONF_IMPORT_ERROR = None +try: + from omegaconf import OmegaConf, DictConfig +except ImportError as e: + OmegaConf, DictConfig = None, Any # type: ignore + _OMEGACONF_IMPORT_ERROR = e -def load_config(config_name: str) -> OmegaConf: +def load_config(config_name: str) -> "DictConfig": """ Load config file from path. """ + if _OMEGACONF_IMPORT_ERROR is not None: + raise ImportError( + "OmegaConf is required to load config files but not installed. " + "Please install OmegaConf to use this feature." + ) from _OMEGACONF_IMPORT_ERROR + config_filename = config_name + ".yaml" config_path = Path(__file__).parent.resolve() / config_filename return OmegaConf.load(config_path) -default_config_dino = load_config("dino_default") -default_config_dinov2 = load_config("dinov2_default") -default_config_mae = load_config("mae_default") -default_config_siglip = load_config("siglip_default") +if OmegaConf is not None: + default_config_dino = load_config("dino_default") + default_config_dinov2 = load_config("dinov2_default") + default_config_mae = load_config("mae_default") + default_config_siglip = load_config("siglip_default") +else: + default_config_dino = None + default_config_dinov2 = None + default_config_mae = None + default_config_siglip = None + warnings.warn( + "OmegaConf is not installed. Default configs will not be available and are set to `None`. " + "Please install OmegaConf to use default configs." + ) -def load_and_merge_config(config_name: str, default_config: OmegaConf) -> OmegaConf: +def load_and_merge_config(config_name: str, default_config: "DictConfig") -> "DictConfig": """ Load and merge config file from path. """ + if _OMEGACONF_IMPORT_ERROR is not None: + raise ImportError( + "OmegaConf is required to load config files but not installed. " + "Please install OmegaConf to use this feature." + ) from _OMEGACONF_IMPORT_ERROR + config = load_config(config_name) - return OmegaConf.merge(default_config, config) \ No newline at end of file + return OmegaConf.merge(default_config, config) diff --git a/src/spectre/data/_base_datasets.py b/src/spectre/data/_base_datasets.py index 2286c41..0911bd3 100644 --- a/src/spectre/data/_base_datasets.py +++ b/src/spectre/data/_base_datasets.py @@ -4,30 +4,85 @@ import shutil import tempfile from typing import Any -from copy import deepcopy from pathlib import Path +from copy import deepcopy import torch import numpy as np -import monai.data as data -from monai.utils import look_up_option, convert_to_tensor + +SUPPORTED_PICKLE_MOD = {"pickle": pickle} +_MONAI_IMPORT_ERROR = None +_CUPY_IMPORT_ERROR = None +_KVIKIO_IMPORT_ERROR = None + +try: + import monai +except ImportError as e: + monai = None # type: ignore + _MONAI_IMPORT_ERROR = e try: import cupy as cp +except ImportError as e: + cp = None # type: ignore + _CUPY_IMPORT_ERROR = e + +try: import kvikio.numpy as kvikio_numpy -except ImportError: - cp = None - kvikio_numpy = None +except ImportError as e: + kvikio_numpy = None # type: ignore + _KVIKIO_IMPORT_ERROR = e -SUPPORTED_PICKLE_MOD = {"pickle": pickle} +def _require_monai(): + if monai is None: + raise ImportError( + "MONAI is required for this functionality. " + "Please install it with `pip install monai`." + ) from _MONAI_IMPORT_ERROR + + +def _require_gds_dependencies(): + _require_monai() + + if cp is None: + raise ImportError( + "cupy is required for this functionality. " + "Please install it with the appropriate CUDA build." + ) from _CUPY_IMPORT_ERROR + + if kvikio_numpy is None: + raise ImportError( + "kvikio is required for this functionality. " + "Please install it with `pip install kvikio`." + ) from _KVIKIO_IMPORT_ERROR + + +if monai is not None: + _BaseDataset = monai.data.Dataset + _BasePersistentDataset = monai.data.PersistentDataset + _BaseGDSDataset = monai.data.GDSDataset +else: + _BaseDataset = object + _BasePersistentDataset = object + _BaseGDSDataset = object + + +class Dataset(_BaseDataset): + """ + Base dataset class for SPECTRE datasets. + """ + def __init__(self, *args, **kwargs): + _require_monai() + super().__init__(*args, **kwargs) -class PersistentDataset(data.PersistentDataset): +class PersistentDataset(_BasePersistentDataset): """ Overwrite MONAI's PersistentDataset to support PyTorch 2.6. """ def __init__(self, *args, pickle_protocol=pickle.HIGHEST_PROTOCOL, **kwargs): + _require_monai() super().__init__(*args, pickle_protocol=pickle_protocol, **kwargs) def _cachecheck(self, item_transformed): @@ -91,7 +146,7 @@ def _cachecheck(self, item_transformed): torch.save( obj=_item_transformed, f=temp_hash_file, - pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), + pickle_module=monai.utils.look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), pickle_protocol=self.pickle_protocol, ) if temp_hash_file.is_file() and not hashfile.is_file(): @@ -106,12 +161,13 @@ def _cachecheck(self, item_transformed): return _item_transformed -class GDSDataset(data.GDSDataset): +class GDSDataset(_BaseGDSDataset): """ Overwrite MONAI's GDSDataset to support PyTorch 2.6 and combined GPU/CPU data (image/text pairs) without breaking the GDS fast path. """ def __init__(self, *args, pickle_protocol=pickle.HIGHEST_PROTOCOL, **kwargs): + _require_gds_dependencies() super().__init__(*args, pickle_protocol=pickle_protocol, **kwargs) def _cachecheck(self, item_transformed): @@ -148,7 +204,7 @@ def _cachecheck(self, item_transformed): except FileNotFoundError: continue # non-tensor key handled by sidecar item[k] = kvikio_numpy.fromfile(f"{hashfile}-{k}", dtype=meta_k["dtype"], like=cp.empty(())) - item[k] = convert_to_tensor(item[k].reshape(meta_k["shape"]), device=f"cuda:{self.device}") + item[k] = monai.utils.convert_to_tensor(item[k].reshape(meta_k["shape"]), device=f"cuda:{self.device}") item[f"{k}_meta_dict"] = meta_k sidecar_path = f"{hashfile}-aux" @@ -162,7 +218,7 @@ def _cachecheck(self, item_transformed): elif isinstance(item_transformed, (np.ndarray, torch.Tensor)): _meta = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-meta") _data = kvikio_numpy.fromfile(f"{hashfile}", dtype=_meta["dtype"], like=cp.empty(())) - _data = convert_to_tensor(_data.reshape(_meta["shape"]), device=f"cuda:{self.device}") + _data = monai.utils.convert_to_tensor(_data.reshape(_meta["shape"]), device=f"cuda:{self.device}") filtered_keys = list(filter(lambda key: key not in ["dtype", "shape"], _meta.keys())) if bool(filtered_keys): return (_data, _meta) @@ -175,7 +231,7 @@ def _cachecheck(self, item_transformed): item_k = kvikio_numpy.fromfile( f"{hashfile}-{k}-{i}", dtype=meta_i_k["dtype"], like=cp.empty(()) ) - item_k = convert_to_tensor(item[i].reshape(meta_i_k["shape"]), device=f"cuda:{self.device}") + item_k = monai.utils.convert_to_tensor(item[i].reshape(meta_i_k["shape"]), device=f"cuda:{self.device}") item[i].update({k: item_k, f"{k}_meta_dict": meta_i_k}) return item @@ -221,7 +277,7 @@ def _create_sidecar_cache(self, aux_dict, sidecar_path): torch.save( obj=aux_dict, f=temp_hash_file, - pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), + pickle_module=monai.utils.look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), pickle_protocol=self.pickle_protocol, ) if temp_hash_file.is_file() and not sidecar_hashfile.is_file(): diff --git a/src/spectre/data/abdomen_atlas.py b/src/spectre/data/abdomen_atlas.py index 1a06286..d61bc16 100644 --- a/src/spectre/data/abdomen_atlas.py +++ b/src/spectre/data/abdomen_atlas.py @@ -2,9 +2,11 @@ from pathlib import Path from typing import Callable, Dict, List -from monai.data import Dataset - -from spectre.data._base_datasets import PersistentDataset, GDSDataset +from spectre.data._base_datasets import ( + Dataset, + PersistentDataset, + GDSDataset, +) def _initialize_dataset( diff --git a/src/spectre/data/abdomenct_1k.py b/src/spectre/data/abdomenct_1k.py index 6070714..cd3789e 100644 --- a/src/spectre/data/abdomenct_1k.py +++ b/src/spectre/data/abdomenct_1k.py @@ -2,9 +2,11 @@ from pathlib import Path from typing import Callable, Dict, List -from monai.data import Dataset - -from spectre.data._base_datasets import PersistentDataset, GDSDataset +from spectre.data._base_datasets import ( + Dataset, + PersistentDataset, + GDSDataset, +) def _initialize_dataset( diff --git a/src/spectre/data/amos.py b/src/spectre/data/amos.py index ee53d5b..7a9e3a2 100644 --- a/src/spectre/data/amos.py +++ b/src/spectre/data/amos.py @@ -2,9 +2,11 @@ from pathlib import Path from typing import Callable, Dict, List -from monai.data import Dataset - -from spectre.data._base_datasets import PersistentDataset, GDSDataset +from spectre.data._base_datasets import ( + Dataset, + PersistentDataset, + GDSDataset, +) def _initialize_dataset( diff --git a/src/spectre/data/ct_rate.py b/src/spectre/data/ct_rate.py index 7b2cfbc..13540bb 100644 --- a/src/spectre/data/ct_rate.py +++ b/src/spectre/data/ct_rate.py @@ -2,9 +2,19 @@ from pathlib import Path from typing import Callable, Dict, List -from monai.data import Dataset +from spectre.data._base_datasets import ( + Dataset, + PersistentDataset, + GDSDataset, +) + +_PANDAS_IMPORT_ERROR = None +try: + import pandas as pd +except ImportError as e: + pd = None # type: ignore + _PANDAS_IMPORT_ERROR = e -from spectre.data._base_datasets import PersistentDataset, GDSDataset def _initialize_dataset( data_dir: str, @@ -20,7 +30,11 @@ def _initialize_dataset( image_paths = image_paths[:n_keep] if include_reports: - import pandas as pd + if _PANDAS_IMPORT_ERROR is not None: + raise ImportError( + "Pandas is required to include reports in the dataset but not installed. " + "Please install Pandas to use this feature." + ) from _PANDAS_IMPORT_ERROR text_path = os.path.join(Path(data_dir), 'dataset', "radiology_text_reports", f"{subset}_reports.xlsx" ) reports = pd.read_excel(text_path) if subset == "train": diff --git a/src/spectre/data/inspect.py b/src/spectre/data/inspect.py index 8b7b666..7cfaa64 100644 --- a/src/spectre/data/inspect.py +++ b/src/spectre/data/inspect.py @@ -2,9 +2,18 @@ from pathlib import Path from typing import Callable, List, Dict -from monai.data import Dataset +from spectre.data._base_datasets import ( + Dataset, + PersistentDataset, + GDSDataset, +) -from spectre.data._base_datasets import PersistentDataset, GDSDataset +_PANDAS_IMPORT_ERROR = None +try: + import pandas as pd +except ImportError as e: + pd = None # type: ignore + _PANDAS_IMPORT_ERROR = e def parse_name(image_path): @@ -24,7 +33,11 @@ def _initialize_dataset( image_paths = image_paths[:n_keep] if include_reports: - import pandas as pd + if _PANDAS_IMPORT_ERROR is not None: + raise ImportError( + "Pandas is required to include reports in the dataset but not installed. " + "Please install Pandas to use this feature." + ) from _PANDAS_IMPORT_ERROR text_path = os.path.join(Path(data_dir), "inspect2", "Final_Impressions.xlsx") reports = pd.read_excel(text_path) diff --git a/src/spectre/data/merlin.py b/src/spectre/data/merlin.py index 4f5be87..bf3fa74 100644 --- a/src/spectre/data/merlin.py +++ b/src/spectre/data/merlin.py @@ -2,10 +2,19 @@ from pathlib import Path from typing import Callable, List, Dict -import pandas as pd -from monai.data import Dataset +from spectre.data._base_datasets import ( + Dataset, + PersistentDataset, + GDSDataset, +) + +_PANDAS_IMPORT_ERROR = None +try: + import pandas as pd +except ImportError as e: + pd = None # type: ignore + _PANDAS_IMPORT_ERROR = e -from spectre.data._base_datasets import PersistentDataset, GDSDataset def parse_name(image_path): return image_path.name.replace(".nii.gz", "") @@ -17,6 +26,11 @@ def _initialize_dataset( subset: str = "train", fraction: float = 1.0, ) -> List[Dict[str, str]]: + if _PANDAS_IMPORT_ERROR is not None: + raise ImportError( + "Pandas is required to initialize the dataset but not installed. " + "Please install Pandas to use this dataset." + ) from _PANDAS_IMPORT_ERROR image_paths = sorted(Path(data_dir).glob(os.path.join( "merlinabdominalctdataset", "merlin_data", "*.nii.gz"))) diff --git a/src/spectre/data/nlst.py b/src/spectre/data/nlst.py index 74a5010..917ad02 100644 --- a/src/spectre/data/nlst.py +++ b/src/spectre/data/nlst.py @@ -2,9 +2,11 @@ from pathlib import Path from typing import Callable, Dict, List -from monai.data import Dataset - -from spectre.data._base_datasets import PersistentDataset, GDSDataset +from spectre.data._base_datasets import ( + Dataset, + PersistentDataset, + GDSDataset, +) def _initialize_dataset( diff --git a/src/spectre/data/panorama.py b/src/spectre/data/panorama.py index a1ed0b2..fe568ff 100644 --- a/src/spectre/data/panorama.py +++ b/src/spectre/data/panorama.py @@ -1,9 +1,11 @@ from pathlib import Path from typing import Callable, List, Dict -from monai.data import Dataset - -from spectre.data._base_datasets import PersistentDataset, GDSDataset +from spectre.data._base_datasets import ( + Dataset, + PersistentDataset, + GDSDataset, +) def _initialize_dataset( diff --git a/src/spectre/data/sinoct.py b/src/spectre/data/sinoct.py index 0a6800d..bfd7bc5 100644 --- a/src/spectre/data/sinoct.py +++ b/src/spectre/data/sinoct.py @@ -2,10 +2,18 @@ from pathlib import Path from typing import Callable, List, Dict, Union -import pandas as pd -from monai.data import Dataset +from spectre.data._base_datasets import ( + Dataset, + PersistentDataset, + GDSDataset, +) -from spectre.data._base_datasets import PersistentDataset, GDSDataset +_PANDAS_IMPORT_ERROR = None +try: + import pandas as pd +except ImportError as e: + pd = None # type: ignore + _PANDAS_IMPORT_ERROR = e def _initialize_dataset( @@ -14,7 +22,12 @@ def _initialize_dataset( split_ratio: tuple = (0.8, 0.1, 0.1), # train, val, test seed: int = 0, ) -> List[Dict[str, Union[str, int]]]: - + if _PANDAS_IMPORT_ERROR is not None: + raise ImportError( + "Pandas is required to initialize the dataset but not installed. " + "Please install Pandas to use this dataset." + ) from _PANDAS_IMPORT_ERROR + labels_df = pd.read_csv(Path(data_dir) / "labels.csv", index_col="patient_id") labels_df["abnormal"] = labels_df["label"].apply(lambda x: int(x.split(",")[1])) labels_df = labels_df[["abnormal"]] diff --git a/src/spectre/data/total_segmentator.py b/src/spectre/data/total_segmentator.py index 6a3d274..3134582 100644 --- a/src/spectre/data/total_segmentator.py +++ b/src/spectre/data/total_segmentator.py @@ -2,10 +2,18 @@ from pathlib import Path from typing import Callable, List, Union, Dict -import pandas as pd -from monai.data import Dataset +from spectre.data._base_datasets import ( + Dataset, + PersistentDataset, + GDSDataset, +) -from spectre.data._base_datasets import PersistentDataset, GDSDataset +_PANDAS_IMPORT_ERROR = None +try: + import pandas as pd +except ImportError as e: + pd = None # type: ignore + _PANDAS_IMPORT_ERROR = e LABEL_GROUPS = { @@ -68,6 +76,11 @@ def _initialize_dataset( ], subset: str = "train", ) -> List[Dict[str, str]]: + if _PANDAS_IMPORT_ERROR is not None: + raise ImportError( + "Pandas is required to initialize the dataset but not installed. " + "Please install Pandas to use this dataset." + ) from _PANDAS_IMPORT_ERROR image_paths = Path(data_dir).glob(os.path.join("*", "ct.nii.gz")) diff --git a/src/spectre/losses/mask_classification_loss.py b/src/spectre/losses/mask_classification_loss.py index 07ea8d6..a56d771 100644 --- a/src/spectre/losses/mask_classification_loss.py +++ b/src/spectre/losses/mask_classification_loss.py @@ -9,7 +9,13 @@ import torch.nn.functional as F import torch.distributed as dist import numpy as np -from scipy.optimize import linear_sum_assignment + +_SCIPY_IMPORT_ERROR = None +try: + from scipy.optimize import linear_sum_assignment +except ImportError as e: + linear_sum_assignment = None # type: ignore + _SCIPY_IMPORT_ERROR = e class MaskClassificationLoss(nn.Module): @@ -56,6 +62,11 @@ def __init__( (e.g. to downweight false-positive penalties). This is stored in `self.empty_weight` and passed to `CrossEntropyLoss`. """ + if _SCIPY_IMPORT_ERROR is not None: + raise ImportError( + "Scipy is required to use MaskClassificationLoss but not installed. " + "Please install Scipy to use this loss." + ) from _SCIPY_IMPORT_ERROR super().__init__() self.num_labels = num_labels self.num_points = num_points diff --git a/src/spectre/models/vision_transformer.py b/src/spectre/models/vision_transformer.py index 10fb7dd..adbd631 100644 --- a/src/spectre/models/vision_transformer.py +++ b/src/spectre/models/vision_transformer.py @@ -232,7 +232,7 @@ def __init__( self.patch_drop = nn.Identity() self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + dpr = [drop_path_rate * i / (depth - 1) if depth > 1 else 0.0 for i in range(depth)] # stochastic depth decay rule self.blocks = nn.Sequential(*[ block_fn( dim=embed_dim, @@ -275,19 +275,20 @@ def __init__( self.init_weights() def init_weights(self) -> None: - if self.pos_embed is not None: + if self.pos_embed is not None and not self.pos_embed.is_meta: nn.init.trunc_normal_(self.pos_embed, std=.02) - if self.cls_token is not None: + if self.cls_token is not None and not self.cls_token.is_meta: nn.init.normal_(self.cls_token, std=1e-6) - if self.reg_token is not None: + if self.reg_token is not None and not self.reg_token.is_meta: nn.init.normal_(self.reg_token, std=1e-6) self.apply(self._init_weights) def _init_weights(self, m: nn.Module) -> None: # this fn left here for compat with downstream users if isinstance(m, nn.Linear): - nn.init.trunc_normal_(m.weight, std=.02) - if m.bias is not None: + if not m.weight.is_meta: + nn.init.trunc_normal_(m.weight, std=.02) + if m.bias is not None and not m.bias.is_meta: nn.init.zeros_(m.bias) @torch.jit.ignore diff --git a/src/spectre/models/vision_transformer_features.py b/src/spectre/models/vision_transformer_features.py index 1975197..ef50f8b 100644 --- a/src/spectre/models/vision_transformer_features.py +++ b/src/spectre/models/vision_transformer_features.py @@ -134,7 +134,7 @@ def __init__( self.patch_drop = nn.Identity() self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + dpr = [drop_path_rate * i / (depth - 1) if depth > 1 else 0.0 for i in range(depth)] # stochastic depth decay rule self.blocks = nn.Sequential(*[ block_fn( dim=embed_dim, @@ -177,19 +177,20 @@ def __init__( self.init_weights() def init_weights(self) -> None: - if self.pos_embed is not None: + if self.pos_embed is not None and not self.pos_embed.is_meta: nn.init.trunc_normal_(self.pos_embed, std=.02) - if self.cls_token is not None: + if self.cls_token is not None and not self.cls_token.is_meta: nn.init.normal_(self.cls_token, std=1e-6) - if self.reg_token is not None: + if self.reg_token is not None and not self.reg_token.is_meta: nn.init.normal_(self.reg_token, std=1e-6) self.apply(self._init_weights) def _init_weights(self, m: nn.Module) -> None: # this fn left here for compat with downstream users if isinstance(m, nn.Linear): - nn.init.trunc_normal_(m.weight, std=.02) - if m.bias is not None: + if not m.weight.is_meta: + nn.init.trunc_normal_(m.weight, std=.02) + if m.bias is not None and not m.bias.is_meta: nn.init.zeros_(m.bias) @torch.jit.ignore diff --git a/src/spectre/ssl/transforms/dino_transform.py b/src/spectre/ssl/transforms/dino_transform.py index cd92bd8..abf609e 100644 --- a/src/spectre/ssl/transforms/dino_transform.py +++ b/src/spectre/ssl/transforms/dino_transform.py @@ -2,35 +2,31 @@ from typing import Tuple, Mapping, Hashable, Any, List import torch -from monai.config import KeysCollection -from monai.transforms import ( - Compose, - LoadImaged, - EnsureChannelFirstd, - ScaleIntensityRanged, - Orientationd, - Spacingd, - CenterSpatialCropd, - SpatialPadd, - EnsureTyped, - RandSpatialCropSamplesd, - SelectItemsd, - RandSpatialCropSamples, - RandFlip, - OneOf, - RandGaussianSharpen, - RandGaussianSmooth, - RandGaussianNoise, - RandAdjustContrast, - Resize, - MapTransform, - Randomizable, - LazyTransform, -) + +MONAI_IMPORT_ERROR = None +try: + import monai.transforms as transforms + from monai.config import KeysCollection +except ImportError as e: + transforms = None # type: ignore + KeysCollection = Any # type: ignore + MONAI_IMPORT_ERROR = e from spectre.transforms import RandScaleIntensityRange +if transforms is not None: + Compose = transforms.Compose + _BaseClass = type("_BaseClass", ( + transforms.Randomizable, + transforms.MapTransform, + transforms.LazyTransform, + ), {}) +else: + Compose = object # type: ignore + _BaseClass = object # type: ignore + + class DINOTransform(Compose): def __init__( self, @@ -42,6 +38,12 @@ def __init__( dtype: str = "float32", use_gds: bool = False, ): + if MONAI_IMPORT_ERROR is not None: + raise ImportError( + "MONAI is required to use DINOTransform but not installed. " + "Please install MONAI to use this transform." + ) from MONAI_IMPORT_ERROR + assert dtype in ["float16", "float32"], \ "dtype must be either 'float16' or 'float32'" @@ -51,12 +53,12 @@ def __init__( ) super().__init__([ - LoadImaged(keys=("image",)), - EnsureChannelFirstd( + transforms.LoadImaged(keys=("image",)), + transforms.EnsureChannelFirstd( keys=("image",), channel_dim="no_channel" ), - ScaleIntensityRanged( + transforms.ScaleIntensityRanged( keys=("image",), a_min=-1000, a_max=1000, @@ -64,26 +66,26 @@ def __init__( b_max=1.0, clip=True, ), - Orientationd(keys=("image",), axcodes="RAS"), - Spacingd( + transforms.Orientationd(keys=("image",), axcodes="RAS"), + transforms.Spacingd( keys=("image",), pixdim=(0.5, 0.5, 1.0), # comply with newest scanners mode=("bilinear",), ), - CenterSpatialCropd( + transforms.CenterSpatialCropd( keys=("image",), roi_size=(512, 512, 384), ), - SpatialPadd( + transforms.SpatialPadd( keys=("image",), spatial_size=base_crop_size, ), - EnsureTyped( + transforms.EnsureTyped( keys=("image",), dtype=getattr(torch, dtype), device=device, ), - RandSpatialCropSamplesd( + transforms.RandSpatialCropSamplesd( keys=("image",), num_samples=num_base_patches, roi_size=base_crop_size, @@ -99,13 +101,13 @@ def __init__( num_local_views=num_local_views, dtype=dtype, ), - SelectItemsd( + transforms.SelectItemsd( keys=("image_global_views", "image_local_views"), ), ]) -class DINORandomCropTransformd(Randomizable, MapTransform, LazyTransform): +class DINORandomCropTransformd(_BaseClass): def __init__( self, keys: KeysCollection, @@ -117,14 +119,20 @@ def __init__( dtype: str = "float32", lazy: bool = False, ) -> None: - MapTransform.__init__(self, keys) - LazyTransform.__init__(self, lazy) + if MONAI_IMPORT_ERROR is not None: + raise ImportError( + "MONAI is required to use DINORandomCropTransformd but not installed. " + "Please install MONAI to use this transform." + ) from MONAI_IMPORT_ERROR + + transforms.MapTransform.__init__(self, keys) + transforms.LazyTransform.__init__(self, lazy) self.global_views_size = global_views_size self.local_views_size = local_views_size self.local_views_scale = local_views_scale self.num_local_views = num_local_views - self.cropper_global = RandSpatialCropSamples( + self.cropper_global = transforms.RandSpatialCropSamples( roi_size=tuple(int(local_views_scale[1] * sz) for sz in base_crop_size), num_samples=2, max_roi_size=base_crop_size, @@ -132,7 +140,7 @@ def __init__( random_size=True, lazy=lazy, ) - self.cropper_local = RandSpatialCropSamples( + self.cropper_local = transforms.RandSpatialCropSamples( roi_size=tuple(int(self.local_views_scale[0] * sz) for sz in base_crop_size), num_samples=num_local_views, max_roi_size=tuple(int(self.local_views_scale[1] * sz) for sz in base_crop_size), @@ -141,14 +149,14 @@ def __init__( lazy=lazy, ) - self.resize_global = Resize( + self.resize_global = transforms.Resize( spatial_size=global_views_size, mode="trilinear", dtype=getattr(torch, dtype), # worst case 0.1-0.3% error for fp16 anti_aliasing=True, # downsample ratios up to 2 lazy=lazy, ) - self.resize_local = Resize( + self.resize_local = transforms.Resize( spatial_size=local_views_size, mode="trilinear", dtype=getattr(torch, dtype), # worst case 0.1-0.3% error for fp16 @@ -156,23 +164,23 @@ def __init__( lazy=lazy, ) - self.augmentor = Compose([ - RandFlip(spatial_axis=0, prob=0.5), - RandFlip(spatial_axis=1, prob=0.5), - RandFlip(spatial_axis=2, prob=0.5), - OneOf([ - RandGaussianSharpen( + self.augmentor = transforms.Compose([ + transforms.RandFlip(spatial_axis=0, prob=0.5), + transforms.RandFlip(spatial_axis=1, prob=0.5), + transforms.RandFlip(spatial_axis=2, prob=0.5), + transforms.OneOf([ + transforms.RandGaussianSharpen( sigma1_x=(1.5, 2.5), sigma1_y=(1.5, 2.5), sigma1_z=(0.75, 1.25), sigma2_x=(0.5, 1.0), sigma2_y=(0.5, 1.0), sigma2_z=(0.25, 0.5), prob=0.25, ), - RandGaussianSmooth( + transforms.RandGaussianSmooth( sigma_x=(1.5, 2.5), sigma_y=(1.5, 2.5), sigma_z=(0.75, 1.25), prob=0.25, ), ]), - RandAdjustContrast(gamma=(0.9, 1.1), prob=0.25), - RandGaussianNoise(std=0.1, sample_std=True, prob=0.25), + transforms.RandAdjustContrast(gamma=(0.9, 1.1), prob=0.25), + transforms.RandGaussianNoise(std=0.1, sample_std=True, prob=0.25), RandScaleIntensityRange( a_min=(0.0, 0.4), # [0.0 * 2000 - 1000, 0.4 * 2000 - 1000] = [-1000, -200] a_max=(0.6, 1.0), # [0.6 * 2000 - 1000, 1.0 * 2000 - 1000] = [200, 1000] diff --git a/src/spectre/ssl/transforms/mae_transform.py b/src/spectre/ssl/transforms/mae_transform.py index 759b20d..c34f3b0 100644 --- a/src/spectre/ssl/transforms/mae_transform.py +++ b/src/spectre/ssl/transforms/mae_transform.py @@ -1,20 +1,19 @@ from typing import Tuple import torch -from monai.transforms import ( - Compose, - LoadImaged, - EnsureChannelFirstd, - ScaleIntensityRanged, - Orientationd, - Spacingd, - SpatialPadd, - CastToTyped, - ResizeWithPadOrCropd, - RandSpatialCropSamplesd, - RandSpatialCropd, - Resized, -) + +MONAI_IMPORT_ERROR = None +try: + import monai.transforms as transforms +except ImportError as e: + transforms = None # type: ignore + MONAI_IMPORT_ERROR = e + + +if transforms is not None: + Compose = transforms.Compose +else: + Compose = object # type: ignore class MAETransform(Compose): @@ -23,12 +22,18 @@ def __init__( input_size: Tuple[int, int, int] = (128, 128, 64), dtype: str = "float32", ): + if MONAI_IMPORT_ERROR is not None: + raise ImportError( + "MONAI is required to use MAETransform but not installed. " + "Please install MONAI to use this transform." + ) from MONAI_IMPORT_ERROR + assert dtype in ["float16", "float32"], "dtype must be either 'float16' or 'float32'" super().__init__( [ - LoadImaged(keys=("image",)), - EnsureChannelFirstd(keys=("image",), channel_dim="no_channel"), - ScaleIntensityRanged( + transforms.LoadImaged(keys=("image",)), + transforms.EnsureChannelFirstd(keys=("image",), channel_dim="no_channel"), + transforms.ScaleIntensityRanged( keys=("image",), a_min=-1000, a_max=1000, @@ -36,12 +41,12 @@ def __init__( b_max=1.0, clip=True ), - Orientationd(keys=("image",), axcodes="RAS"), - Spacingd(keys=("image",), pixdim=(0.75, 0.75, 1.5), mode=("bilinear",)), - ResizeWithPadOrCropd(keys=("image",), spatial_size=(384, 384, -1)), - SpatialPadd(keys=("image",), spatial_size=(-1, -1, input_size[2])), - CastToTyped(keys=("image",), dtype=getattr(torch, dtype)), - RandSpatialCropSamplesd( + transforms.Orientationd(keys=("image",), axcodes="RAS"), + transforms.Spacingd(keys=("image",), pixdim=(0.75, 0.75, 1.5), mode=("bilinear",)), + transforms.ResizeWithPadOrCropd(keys=("image",), spatial_size=(384, 384, -1)), + transforms.SpatialPadd(keys=("image",), spatial_size=(-1, -1, input_size[2])), + transforms.CastToTyped(keys=("image",), dtype=getattr(torch, dtype)), + transforms.RandSpatialCropSamplesd( keys=("image",), roi_size=input_size, num_samples=36, @@ -49,36 +54,13 @@ def __init__( random_size=False, ), # Do a random resized crop - RandSpatialCropd( + transforms.RandSpatialCropd( keys=("image",), roi_size=tuple(int(sz * 0.34) for sz in input_size), # 0.34 = (0.2 ** 2) ** (1/3) max_roi_size=input_size, random_center=True, random_size=True, ), - Resized(keys=("image",), spatial_size=input_size), + transforms.Resized(keys=("image",), spatial_size=input_size), ] ) - - -if __name__ == "__main__": - - # Save some example data after transforming it. - import os - import SimpleITK as sitk - - data = {"image": r"data/test_data/train_1_a_1.nii.gz"} - transform = MAETransform() - transformed_data = transform(data) - - # Save the different crops to a folder for visualization. - output_dir = r"data/test_data/mae_transform_output" - os.makedirs(output_dir, exist_ok=True) - - for i, patch in enumerate(transformed_data): - - # Save the crops - patch_img = sitk.GetImageFromArray(patch["image"].squeeze(0).numpy()) - patch_img.SetSpacing((1.5, 0.75, 0.75)) - patch_path = os.path.join(output_dir, f"{i}_crop.nii.gz") - sitk.WriteImage(patch_img, patch_path) diff --git a/src/spectre/ssl/transforms/siglip_transform.py b/src/spectre/ssl/transforms/siglip_transform.py index b106e7f..e9703a2 100644 --- a/src/spectre/ssl/transforms/siglip_transform.py +++ b/src/spectre/ssl/transforms/siglip_transform.py @@ -1,24 +1,23 @@ from typing import Tuple import torch -from monai.transforms import ( - Compose, - LoadImaged, - EnsureChannelFirstd, - ScaleIntensityRanged, - Orientationd, - Spacingd, - ResizeWithPadOrCropd, - EnsureTyped, - RandSpatialCropd, - RandFlipd, - GridPatchd, - SelectItemsd, -) + +MONAI_IMPORT_ERROR = None +try: + import monai.transforms as transforms +except ImportError as e: + transforms = None # type: ignore + MONAI_IMPORT_ERROR = e from spectre.transforms import RandomReportTransformd +if transforms is not None: + Compose = transforms.Compose +else: + Compose = object # type: ignore + + class SigLIPTransform(Compose): def __init__( self, @@ -32,6 +31,11 @@ def __init__( dtype: str = "float32", use_gds: bool = False, ): + if MONAI_IMPORT_ERROR is not None: + raise ImportError( + "MONAI is required to use SigLIPTransform but not installed. " + "Please install MONAI to use this transform." + ) from MONAI_IMPORT_ERROR assert dtype in ["float16", "float32"], \ "dtype must be either 'float16' or 'float32'" @@ -42,12 +46,12 @@ def __init__( ) super().__init__([ - LoadImaged(keys=("image",)), - EnsureChannelFirstd( + transforms.LoadImaged(keys=("image",)), + transforms.EnsureChannelFirstd( keys=("image",), channel_dim="no_channel" ), - ScaleIntensityRanged( + transforms.ScaleIntensityRanged( keys=("image",), a_min=-1000, a_max=1000, @@ -55,30 +59,30 @@ def __init__( b_max=1.0, clip=True ), - Orientationd(keys=("image",), axcodes="RAS"), - Spacingd( + transforms.Orientationd(keys=("image",), axcodes="RAS"), + transforms.Spacingd( keys=("image",), pixdim=image_pixdim, mode=("bilinear",) ), - ResizeWithPadOrCropd( + transforms.ResizeWithPadOrCropd( keys=("image",), spatial_size=base_crop_size ), - EnsureTyped( + transforms.EnsureTyped( keys=("image",), dtype=getattr(torch, dtype), device=device ), - RandSpatialCropd( + transforms.RandSpatialCropd( keys=("image",), roi_size=image_size, random_size=False, ), - RandFlipd(keys=("image",), spatial_axis=0, prob=0.5), - RandFlipd(keys=("image",), spatial_axis=1, prob=0.5), - RandFlipd(keys=("image",), spatial_axis=2, prob=0.5), - GridPatchd( + transforms.RandFlipd(keys=("image",), spatial_axis=0, prob=0.5), + transforms.RandFlipd(keys=("image",), spatial_axis=1, prob=0.5), + transforms.RandFlipd(keys=("image",), spatial_axis=2, prob=0.5), + transforms.GridPatchd( keys=("image",), patch_size=sliding_window_size, overlap=0.0, @@ -89,7 +93,7 @@ def __init__( keep_original_prob=keep_original_prob, drop_prob=drop_prob, ), - SelectItemsd( + transforms.SelectItemsd( keys=("image", "report"), ), ]) diff --git a/src/spectre/transforms/combine_labels.py b/src/spectre/transforms/combine_labels.py index dc7d675..0071347 100644 --- a/src/spectre/transforms/combine_labels.py +++ b/src/spectre/transforms/combine_labels.py @@ -1,10 +1,18 @@ -from typing import Sequence, Optional, Hashable, Mapping +from typing import Any, Sequence, Optional, Hashable, Mapping import torch import numpy as np -from monai.config import KeysCollection -from monai.transforms import MapTransform -from monai.config.type_definitions import NdarrayOrTensor + +MONAI_IMPORT_ERROR = None +try: + from monai.config import KeysCollection + from monai.transforms import MapTransform + from monai.config.type_definitions import NdarrayOrTensor +except ImportError as e: + KeysCollection = Any # type: ignore + MapTransform = object # type: ignore + NdarrayOrTensor = Any # type: ignore + MONAI_IMPORT_ERROR = e class CombineLabelsd(MapTransform): @@ -25,6 +33,11 @@ def __init__( labels: Optional[Sequence[int]] = None, allow_missing_keys: bool = False, ) -> None: + if MONAI_IMPORT_ERROR is not None: + raise ImportError( + "MONAI is required to use CombineLabelsd but not installed. " + "Please install MONAI to use this transform." + ) from MONAI_IMPORT_ERROR if labels is not None and len(keys) != len(labels): raise ValueError("The number of keys must match the number of labels provided.") diff --git a/src/spectre/transforms/generate_report.py b/src/spectre/transforms/generate_report.py index cc77a10..2c591b4 100644 --- a/src/spectre/transforms/generate_report.py +++ b/src/spectre/transforms/generate_report.py @@ -1,11 +1,22 @@ from copy import deepcopy from typing import Any, Mapping, Hashable -from monai.config import KeysCollection -from monai.transforms import MapTransform, Randomizable +MONAI_IMPORT_ERROR = None +try: + from monai.config import KeysCollection + from monai.transforms import MapTransform, Randomizable +except ImportError as e: + KeysCollection = Any # type: ignore + MONAI_IMPORT_ERROR = e -class RandomReportTransformd(Randomizable, MapTransform): +if MONAI_IMPORT_ERROR is None: + _BaseClass = type("_BaseClass", (Randomizable, MapTransform), {}) +else: + _BaseClass = object + + +class RandomReportTransformd(_BaseClass): def __init__( self, keys: KeysCollection, @@ -14,6 +25,12 @@ def __init__( drop_prob=0.3, allow_missing_keys: bool = False, ): + if MONAI_IMPORT_ERROR is not None: + raise ImportError( + "MONAI is required to use RandomReportTransformd but not installed. " + "Please install MONAI to use this transform." + ) from MONAI_IMPORT_ERROR + assert all(str(key) in ["findings", "impressions", "icd10"] for key in keys), \ "keys must be one of ['findings', 'impressions', 'icd10']" @@ -89,90 +106,3 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: ret["report"] = f"{findings}{impressions}{icd10}" return ret - - -# class GenerateReportTransform(Randomizable, MapTransform): -# def __init__( -# self, -# keys: KeysCollection, -# max_num_icd10=20, -# likelihood_original=0.5, -# drop_chance=0.3, -# allow_missing_keys: bool = False, -# ): -# super().__init__(keys, allow_missing_keys) -# self.max_num_icd10 = max_num_icd10 -# self.likelihood_original = likelihood_original -# self.drop_chance = drop_chance - -# # Random states (purely indices/flags) -# self.drop_findings = False -# self.drop_icd10 = False -# self.finding_idx = None -# self.impression_idx = None -# self.icd10_indices = [] - -# def randomize(self, data): -# findings = data.get("findings", []) -# impressions = data.get("impressions", []) -# icd10_codes = data.get("icd10", []) - -# if isinstance(icd10_codes, str): -# icd10_codes = icd10_codes.split(";") -# if not isinstance(icd10_codes, list): -# icd10_codes = [] - -# self.drop_findings = self.R.random() < self.drop_chance -# self.drop_icd10 = self.R.random() < self.drop_chance -# self.finding_idx = None -# self.impression_idx = None -# self.icd10_indices = [] - -# if not self.drop_findings and findings: -# num_elements = len(findings) -# if num_elements == 1: -# self.finding_idx = 0 -# else: -# weights = [self.likelihood_original] + [(1 - self.likelihood_original) / (num_elements - 1)] * (num_elements - 1) -# self.finding_idx = int(self.R.choice(np.arange(num_elements), p=weights)) - -# if impressions: -# num_elements = len(impressions) -# if num_elements == 1: -# self.impression_idx = 0 -# else: -# weights = [self.likelihood_original] + [(1 - self.likelihood_original) / (num_elements - 1)] * (num_elements - 1) -# self.impression_idx = int(self.R.choice(np.arange(num_elements), p=weights)) - -# if not self.drop_icd10 and icd10_codes: -# num_codes = min(self.max_num_icd10, len(icd10_codes)) -# self.icd10_indices = self.R.choice(len(icd10_codes), size=num_codes, replace=False).tolist() - -# def __call__(self, data): -# self.randomize(data) - -# findings = data.get("findings", []) -# impressions = data.get("impressions", []) -# icd10_codes = data.get("icd10", []) - -# if isinstance(icd10_codes, str): -# icd10_codes = icd10_codes.split(";") -# if not isinstance(icd10_codes, list): -# icd10_codes = [] - -# report = "" - -# if self.finding_idx is not None and self.finding_idx < len(findings): -# finding = findings[self.finding_idx].replace("Impressions", "").replace("impressions", "") -# report += f"Findings: {finding}\n" - -# if self.impression_idx is not None and self.impression_idx < len(impressions): -# impression = impressions[self.impression_idx] -# report += f"Impressions: {impression}\n" - -# if self.icd10_indices: -# selected_icd10 = [icd10_codes[i] for i in self.icd10_indices if i < len(icd10_codes)] -# report += f"ICD10: {'; '.join(selected_icd10)}\n" - -# data["report"] = report -# return data \ No newline at end of file diff --git a/src/spectre/transforms/largest_multiple_crop.py b/src/spectre/transforms/largest_multiple_crop.py index 6a12f84..8a72c58 100644 --- a/src/spectre/transforms/largest_multiple_crop.py +++ b/src/spectre/transforms/largest_multiple_crop.py @@ -1,8 +1,16 @@ -from typing import Sequence +from typing import Any, Sequence import numpy as np -from monai.config import KeysCollection -from monai.transforms import Cropd, CenterSpatialCrop + +MONAI_IMPORT_ERROR = None +try: + from monai.config import KeysCollection + from monai.transforms import Cropd, CenterSpatialCrop +except ImportError as e: + KeysCollection = Any # type: ignore + Cropd = object # type: ignore + CenterSpatialCrop = object # type: ignore + MONAI_IMPORT_ERROR = e class LargestMultipleCenterCropd(Cropd): @@ -22,6 +30,12 @@ def __init__( allow_missing_keys: bool = False, lazy: bool = False, ) -> None: + if MONAI_IMPORT_ERROR is not None: + raise ImportError( + "MONAI is required to use LargestMultipleCenterCropd but not installed. " + "Please install MONAI to use this transform." + ) from MONAI_IMPORT_ERROR + self.patch_size = patch_size cropper = CenterSpatialCrop(roi_size=patch_size, lazy=lazy) # Placeholder, will be reset per image super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy) diff --git a/src/spectre/transforms/scale_intensity_range.py b/src/spectre/transforms/scale_intensity_range.py index 85055c4..b94f469 100644 --- a/src/spectre/transforms/scale_intensity_range.py +++ b/src/spectre/transforms/scale_intensity_range.py @@ -1,14 +1,29 @@ -from typing import Optional, Tuple, Union import warnings -import numpy as np +from typing import Any, Optional, Tuple, Union + import torch +import numpy as np + +MONAI_IMPORT_ERROR = None +try: + from monai.transforms import Transform, Randomizable + from monai.config.type_definitions import DtypeLike, NdarrayOrTensor + from monai.utils import convert_to_tensor, convert_data_type +except ImportError as e: + DtypeLike = Any # type: ignore + NdarrayOrTensor = Any # type: ignore + convert_to_tensor = lambda x: x # type: ignore + convert_data_type = lambda x, dtype: (x, None, None) # type: ignore + MONAI_IMPORT_ERROR = e + -from monai.transforms import Transform, Randomizable -from monai.config.type_definitions import DtypeLike, NdarrayOrTensor -from monai.utils import convert_to_tensor, convert_data_type +if MONAI_IMPORT_ERROR is None: + _BaseClass = type("_BaseClass", (Randomizable, Transform), {}) +else: + _BaseClass = object -class RandScaleIntensityRange(Randomizable, Transform): +class RandScaleIntensityRange(_BaseClass): """ Randomizable variant of ScaleIntensityRange that samples the input window (a_min, a_max) per-call using MONAI's RNG (self.R). @@ -35,6 +50,12 @@ def __init__( dtype: DtypeLike = np.float32, prob: float = 1.0, ) -> None: + if MONAI_IMPORT_ERROR is not None: + raise ImportError( + "MONAI is required to use RandScaleIntensityRange but not installed. " + "Please install MONAI to use this transform." + ) from MONAI_IMPORT_ERROR + Transform.__init__(self) Randomizable.__init__(self) self.a_min = a_min diff --git a/src/spectre/utils/_utils.py b/src/spectre/utils/_utils.py index e131222..9b9952b 100644 --- a/src/spectre/utils/_utils.py +++ b/src/spectre/utils/_utils.py @@ -3,14 +3,25 @@ from itertools import repeat import torch -import monai import numpy as np +MONAI_IMPORT_ERROR = None +try: + import monai +except ImportError as e: + monai = None # type: ignore + MONAI_IMPORT_ERROR = e def fix_random_seeds(seed: int = 31): """ Fix random seeds. """ + if MONAI_IMPORT_ERROR is not None: + raise ImportError( + "MONAI is required to use fix_random_seeds but not installed. " + "Please install MONAI to use this function." + ) from MONAI_IMPORT_ERROR + torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) diff --git a/src/spectre/utils/collate.py b/src/spectre/utils/collate.py index 6496721..b4c8c0b 100644 --- a/src/spectre/utils/collate.py +++ b/src/spectre/utils/collate.py @@ -1,7 +1,13 @@ from typing import List, Callable, Optional import torch -from monai.data import list_data_collate + +MONAI_IMPORT_ERROR = None +try: + from monai.data import list_data_collate +except ImportError as e: + list_data_collate = lambda x: x # type: ignore + MONAI_IMPORT_ERROR = e def extended_collate_dino(samples_list: List) -> dict: @@ -19,6 +25,12 @@ def extended_collate_dino(samples_list: List) -> dict: Returns: A dictionary with collated global/local crops and corresponding masks. """ + if MONAI_IMPORT_ERROR is not None: + raise ImportError( + "MONAI is required to use extended_collate_dino but not installed. " + "Please install MONAI to use this collate function." + ) from MONAI_IMPORT_ERROR + # Apply MONAI's list_data_collate collated_data = list_data_collate(samples_list) @@ -50,6 +62,12 @@ def extended_collate_siglip( Returns: A dictionary with collated images and tokenized text. """ + if MONAI_IMPORT_ERROR is not None: + raise ImportError( + "MONAI is required to use extended_collate_siglip but not installed. " + "Please install MONAI to use this collate function." + ) from MONAI_IMPORT_ERROR + collated_data = list_data_collate(samples_list) if return_filenames: @@ -84,6 +102,12 @@ def collate_add_filenames(samples_list: List) -> dict: Returns: A dictionary with collated images and filenames. """ + if MONAI_IMPORT_ERROR is not None: + raise ImportError( + "MONAI is required to use collate_add_filenames but not installed. " + "Please install MONAI to use this collate function." + ) from MONAI_IMPORT_ERROR + collated_data = list_data_collate(samples_list) if "image" in collated_data.keys(): diff --git a/src/spectre/utils/config.py b/src/spectre/utils/config.py index 5206cdb..6ec8f1b 100644 --- a/src/spectre/utils/config.py +++ b/src/spectre/utils/config.py @@ -1,10 +1,15 @@ import os import math -from omegaconf import OmegaConf - from spectre.utils import _utils, distributed +OMEGACONF_IMPORT_ERROR = None +try: + from omegaconf import OmegaConf +except ImportError as e: + OmegaConf = None # type: ignore + OMEGACONF_IMPORT_ERROR = e + def apply_scaling_rules_to_cfg(cfg): """ @@ -38,6 +43,12 @@ def apply_scaling_rules_to_cfg(cfg): def write_config(cfg, output_dir, name="config.yaml"): + if OMEGACONF_IMPORT_ERROR is not None: + raise ImportError( + "OmegaConf is required to use write_config but not installed. " + "Please install OmegaConf to use this function." + ) from OMEGACONF_IMPORT_ERROR + saved_cfg_path = os.path.join(output_dir, name) with open(saved_cfg_path, "w") as f: OmegaConf.save(config=cfg, f=f) @@ -45,6 +56,12 @@ def write_config(cfg, output_dir, name="config.yaml"): def get_cfg_from_args(args, default_config): + if OMEGACONF_IMPORT_ERROR is not None: + raise ImportError( + "OmegaConf is required to use get_cfg_from_args but not installed. " + "Please install OmegaConf to use this function." + ) from OMEGACONF_IMPORT_ERROR + args.output_dir = os.path.abspath(args.output_dir) args.opts = [] if args.opts is None else args.opts args.opts += [f"train.output_dir={args.output_dir}"] diff --git a/src/spectre/utils/dataloader.py b/src/spectre/utils/dataloader.py index efdcb6d..fb5a4d4 100644 --- a/src/spectre/utils/dataloader.py +++ b/src/spectre/utils/dataloader.py @@ -1,10 +1,18 @@ +from __future__ import annotations import os from typing import Union, Callable, Optional, List import torch -import monai.data as data from torch.utils.data import ConcatDataset +MONAI_IMPORT_ERROR = None +try: + import monai.data as data +except ImportError as e: + data = None # type: ignore + MONAI_IMPORT_ERROR = e + + def get_dataloader( datasets: Union[str, List[str]], @@ -24,10 +32,15 @@ def get_dataloader( drop_last: bool = True, persistent_workers: bool = True, use_thread: bool = False, -) -> data.DataLoader: +) -> "DataLoader": """ Get dataloader for training. """ + if MONAI_IMPORT_ERROR is not None: + raise ImportError( + "MONAI is required to use get_dataloader but not installed. " + "Please install MONAI to use this function." + ) from MONAI_IMPORT_ERROR if isinstance(datasets, str): datasets = [datasets] diff --git a/src/spectre/utils/distributed.py b/src/spectre/utils/distributed.py index d41a35a..0dfa898 100644 --- a/src/spectre/utils/distributed.py +++ b/src/spectre/utils/distributed.py @@ -1,7 +1,14 @@ import os import torch.distributed as dist -from accelerate import Accelerator, DataLoaderConfiguration + +ACCELERATE_IMPORT_ERROR = None +try: + from accelerate import Accelerator, DataLoaderConfiguration +except ImportError as e: + Accelerator = None # type: ignore + DataLoaderConfiguration = None # type: ignore + ACCELERATE_IMPORT_ERROR = e def is_enabled() -> bool: @@ -56,6 +63,11 @@ def init_distributed(cfg): """ Initialize distributed training. """ + if ACCELERATE_IMPORT_ERROR is not None: + raise ImportError( + "Accelerate is required to use init_distributed but not installed. " + "Please install Accelerate to use this function." + ) from ACCELERATE_IMPORT_ERROR # Initialize accelerator dataloader_config = DataLoaderConfiguration(