Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 0 additions & 33 deletions benchmark_utils/image_dataset.py

This file was deleted.

5 changes: 0 additions & 5 deletions config.yml

This file was deleted.

118 changes: 0 additions & 118 deletions datasets/bsd500_bsd20.py

This file was deleted.

13 changes: 6 additions & 7 deletions datasets/bsd500_cbsd68.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from benchopt import BaseDataset, safe_import_context, config
from benchopt import BaseDataset, safe_import_context
from benchopt.config import get_data_path

with safe_import_context() as import_ctx:
import deepinv as dinv
import torch
from torchvision import transforms
from datasets import load_dataset
from benchmark_utils.image_dataset import ImageDataset
from benchmark_utils.hugging_face_torch_dataset import (
HuggingFaceTorchDataset
)
Expand Down Expand Up @@ -77,8 +77,9 @@ def get_data(self):
transforms.ToTensor()
])

train_dataset = ImageDataset(
config.get_data_path("BSD500") / "train", transform=transform
path = get_data_path("BSD500")
train_dataset = dinv.datasets.BSDS500(
path, download=True, transform=transform
)

dataset_cbsd68 = load_dataset("deepinv/CBSD68")
Expand All @@ -90,9 +91,7 @@ def get_data(self):
train_dataset=train_dataset,
test_dataset=test_dataset,
physics=physics,
save_dir=config.get_data_path(
key="generated_datasets"
) / "bsd500_cbsd68",
save_dir=get_data_path("bsd500_cbsd68"),
dataset_filename=self.task,
device=device
)
Expand Down
14 changes: 6 additions & 8 deletions datasets/bsd500_imnet100.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from benchopt import BaseDataset, safe_import_context, config
from benchopt import BaseDataset, safe_import_context
from benchopt.config import get_data_path

with safe_import_context() as import_ctx:
import deepinv as dinv
import torch
from torchvision import transforms
from benchmark_utils.image_dataset import ImageDataset
from benchmark_utils.hugging_face_torch_dataset import (
HuggingFaceTorchDataset
)
Expand Down Expand Up @@ -77,9 +77,9 @@ def get_data(self):
transforms.ToTensor()
])

train_dataset = ImageDataset(
config.get_data_path("BSD500") / "train",
transform=transform
path = get_data_path("BSD500")
train_dataset = dinv.datasets.BSDS500(
path, download=True, transform=transform
)

dataset_miniImnet100 = load_dataset("mterris/miniImnet100")
Expand All @@ -93,9 +93,7 @@ def get_data(self):
train_dataset=train_dataset,
test_dataset=test_dataset,
physics=physics,
save_dir=config.get_data_path(
key="generated_datasets"
) / "bsd500_imnet100",
save_dir=get_data_path("bsd500_imnet100"),
dataset_filename=self.task,
device=device
)
Expand Down
7 changes: 3 additions & 4 deletions datasets/cbsd68_set3c.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from benchopt import BaseDataset, safe_import_context, config
from benchopt import BaseDataset, safe_import_context
from benchopt.config import get_data_path

with safe_import_context() as import_ctx:
import deepinv as dinv
Expand Down Expand Up @@ -91,9 +92,7 @@ def get_data(self):
train_dataset=train_dataset,
test_dataset=test_dataset,
physics=physics,
save_dir=config.get_data_path(
key="generated_datasets"
) / "sbsd68_set3c",
save_dir=get_data_path("cbsd68_set3c"),
dataset_filename=self.task,
device=device
)
Expand Down
7 changes: 7 additions & 0 deletions test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,10 @@ def check_test_solver_install(solver_class):
detecting the situation.
"""
pass


def check_test_dataset_get_data(benchmark, dataset_class):
if sys.platform == "darwin":
pytest.skip(
"Skipping test_dataset_get_data on MacOS."
)
Loading