1414from skimage .transform import rotate
1515from glob import glob
1616from sklearn .model_selection import train_test_split
17+ import nibabel
18+
1719
1820class CustomDataset (Dataset ):
19- def __init__ (self , args , data_path , transform = None , mode = ' Training' , plane = False ):
21+ def __init__ (self , args , data_path , transform = None , mode = " Training" , plane = False ):
2022
21- print ("loading data from the directory :" ,data_path )
22- path = data_path
23+ print ("loading data from the directory :" , data_path )
24+ path = data_path
2325 images = sorted (glob (os .path .join (path , "images/*.png" )))
2426 masks = sorted (glob (os .path .join (path , "masks/*.png" )))
2527
@@ -37,12 +39,12 @@ def __getitem__(self, index):
3739 """Get the images"""
3840 name = self .name_list [index ]
3941 img_path = os .path .join (name )
40-
42+
4143 mask_name = self .label_list [index ]
4244 msk_path = os .path .join (mask_name )
4345
44- img = Image .open (img_path ).convert (' RGB' )
45- mask = Image .open (msk_path ).convert ('L' )
46+ img = Image .open (img_path ).convert (" RGB" )
47+ mask = Image .open (msk_path ).convert ("L" )
4648
4749 # if self.mode == 'Training':
4850 # label = 0 if self.label_list[index] == 'benign' else 1
@@ -60,3 +62,60 @@ def __getitem__(self, index):
6062 # return (img, mask, name)
6163 # else:
6264 # return (img, mask, name)
65+
66+
67+ class CustomDataset3D (torch .utils .data .Dataset ):
68+ def __init__ (self , data_path , transform ):
69+ super ().__init__ ()
70+
71+ print ("loading data from the directory :" , data_path )
72+ path = data_path
73+ images = sorted (glob (os .path .join (path , "images/*.nii.gz" )))
74+ masks = sorted (glob (os .path .join (path , "masks/*.nii.gz" )))
75+
76+ assert len (images ) == len (masks ), "Number of images and masks must be the same"
77+
78+ self .valid_cases = [(img_path , seg_path ) for img_path , seg_path in zip (images , masks )]
79+
80+ self .all_slices = []
81+ for case_idx , (img_path , seg_path ) in enumerate (self .valid_cases ):
82+ seg_vol = nibabel .load (seg_path )
83+ img = nibabel .load (img_path )
84+ assert (
85+ img .shape == seg_vol .shape
86+ ), f"Image and segmentation shape mismatch: { img .shape } vs { seg_vol .shape } , Flies: { img_path } , { seg_path } "
87+ num_slices = img .shape [- 1 ]
88+ self .all_slices .extend (
89+ [(case_idx , slice_idx ) for slice_idx in range (num_slices )]
90+ )
91+
92+ self .data_path = path
93+
94+ self .transform = transform
95+
96+ def __len__ (self ):
97+ return len (self .all_slices )
98+
99+ def __getitem__ (self , x ):
100+ case_idx , slice_idx = self .all_slices [x ]
101+ img_path , seg_path = self .valid_cases [case_idx ]
102+
103+ nib_img = nibabel .load (img_path )
104+ nib_seg = nibabel .load (seg_path )
105+
106+ image = torch .tensor (nib_img .get_fdata (),dtype = torch .float32 )[:, :, slice_idx ].unsqueeze (0 ).unsqueeze (0 )
107+ label = torch .tensor (nib_seg .get_fdata (),dtype = torch .float32 )[:, :, slice_idx ].unsqueeze (0 ).unsqueeze (0 )
108+ label = torch .where (
109+ label > 0 , 1 , 0
110+ ).float () # merge all tumor classes into one
111+
112+ if self .transform :
113+ state = torch .get_rng_state ()
114+ image = self .transform (image )
115+ torch .set_rng_state (state )
116+ label = self .transform (label )
117+ return (
118+ image ,
119+ label ,
120+ img_path .split (".nii" )[0 ] + "_slice" + str (slice_idx ) + ".nii" ,
121+ ) # virtual path
0 commit comments