Skip to content

Commit 2981309

Browse files
authored
Merge pull request #221 from Refound-445/master
Support For 3D Custom Dataset
2 parents b4aee37 + 7c14e34 commit 2981309

2 files changed

Lines changed: 74 additions & 8 deletions

File tree

guided_diffusion/custom_dataset_loader.py

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
from skimage.transform import rotate
1515
from glob import glob
1616
from sklearn.model_selection import train_test_split
17+
import nibabel
18+
1719

1820
class CustomDataset(Dataset):
19-
def __init__(self, args, data_path , transform = None, mode = 'Training',plane = False):
21+
def __init__(self, args, data_path, transform=None, mode="Training", plane=False):
2022

21-
print("loading data from the directory :",data_path)
22-
path=data_path
23+
print("loading data from the directory :", data_path)
24+
path = data_path
2325
images = sorted(glob(os.path.join(path, "images/*.png")))
2426
masks = sorted(glob(os.path.join(path, "masks/*.png")))
2527

@@ -37,12 +39,12 @@ def __getitem__(self, index):
3739
"""Get the images"""
3840
name = self.name_list[index]
3941
img_path = os.path.join(name)
40-
42+
4143
mask_name = self.label_list[index]
4244
msk_path = os.path.join(mask_name)
4345

44-
img = Image.open(img_path).convert('RGB')
45-
mask = Image.open(msk_path).convert('L')
46+
img = Image.open(img_path).convert("RGB")
47+
mask = Image.open(msk_path).convert("L")
4648

4749
# if self.mode == 'Training':
4850
# label = 0 if self.label_list[index] == 'benign' else 1
@@ -60,3 +62,60 @@ def __getitem__(self, index):
6062
# return (img, mask, name)
6163
# else:
6264
# return (img, mask, name)
65+
66+
67+
class CustomDataset3D(torch.utils.data.Dataset):
68+
def __init__(self, data_path, transform):
69+
super().__init__()
70+
71+
print("loading data from the directory :", data_path)
72+
path = data_path
73+
images = sorted(glob(os.path.join(path, "images/*.nii.gz")))
74+
masks = sorted(glob(os.path.join(path, "masks/*.nii.gz")))
75+
76+
assert len(images) == len(masks), "Number of images and masks must be the same"
77+
78+
self.valid_cases = [(img_path, seg_path) for img_path, seg_path in zip(images, masks)]
79+
80+
self.all_slices = []
81+
for case_idx, (img_path, seg_path) in enumerate(self.valid_cases):
82+
seg_vol = nibabel.load(seg_path)
83+
img = nibabel.load(img_path)
84+
assert (
85+
img.shape == seg_vol.shape
86+
), f"Image and segmentation shape mismatch: {img.shape} vs {seg_vol.shape}, Flies: {img_path}, {seg_path}"
87+
num_slices = img.shape[-1]
88+
self.all_slices.extend(
89+
[(case_idx, slice_idx) for slice_idx in range(num_slices)]
90+
)
91+
92+
self.data_path = path
93+
94+
self.transform = transform
95+
96+
def __len__(self):
97+
return len(self.all_slices)
98+
99+
def __getitem__(self, x):
100+
case_idx, slice_idx = self.all_slices[x]
101+
img_path, seg_path = self.valid_cases[case_idx]
102+
103+
nib_img = nibabel.load(img_path)
104+
nib_seg = nibabel.load(seg_path)
105+
106+
image = torch.tensor(nib_img.get_fdata(),dtype=torch.float32)[:, :, slice_idx].unsqueeze(0).unsqueeze(0)
107+
label = torch.tensor(nib_seg.get_fdata(),dtype=torch.float32)[:, :, slice_idx].unsqueeze(0).unsqueeze(0)
108+
label = torch.where(
109+
label > 0, 1, 0
110+
).float() # merge all tumor classes into one
111+
112+
if self.transform:
113+
state = torch.get_rng_state()
114+
image = self.transform(image)
115+
torch.set_rng_state(state)
116+
label = self.transform(label)
117+
return (
118+
image,
119+
label,
120+
img_path.split(".nii")[0] + "_slice" + str(slice_idx) + ".nii",
121+
) # virtual path

scripts/segmentation_train.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
from guided_diffusion.resample import create_named_schedule_sampler
88
from guided_diffusion.bratsloader import BRATSDataset, BRATSDataset3D
99
from guided_diffusion.isicloader import ISICDataset
10-
from guided_diffusion.custom_dataset_loader import CustomDataset
10+
from guided_diffusion.custom_dataset_loader import CustomDataset,CustomDataset3D
1111
from guided_diffusion.script_util import (
1212
model_and_diffusion_defaults,
1313
create_model_and_diffusion,
1414
args_to_dict,
1515
add_dict_to_argparser,
1616
)
1717
import torch as th
18+
from pathlib import Path
1819
from guided_diffusion.train_util import TrainLoop
1920
from visdom import Visdom
2021
viz = Visdom(port=8850)
@@ -40,7 +41,13 @@ def main():
4041

4142
ds = BRATSDataset3D(args.data_dir, transform_train, test_flag=False)
4243
args.in_ch = 5
43-
else :
44+
elif any(Path(args.data_dir).glob("*\*.nii.gz")):
45+
tran_list = [transforms.Resize((args.image_size,args.image_size)),]
46+
transform_train = transforms.Compose(tran_list)
47+
print("Your current directory : ",args.data_dir)
48+
ds = CustomDataset3D(args, args.data_dir, transform_train)
49+
args.in_ch = 4
50+
else:
4451
tran_list = [transforms.Resize((args.image_size,args.image_size)), transforms.ToTensor(),]
4552
transform_train = transforms.Compose(tran_list)
4653
print("Your current directory : ",args.data_dir)

0 commit comments

Comments
 (0)