@@ -23,8 +23,8 @@ def __init__(self, args, data_path , transform = None, mode = 'Training',plane =
2323 images = sorted (glob (os .path .join (path , "images/*.png" )))
2424 masks = sorted (glob (os .path .join (path , "masks/*.png" )))
2525
26- self .name_list = images [: 2 ]
27- self .label_list = masks [: 2 ]
26+ self .name_list = images
27+ self .label_list = masks
2828 self .data_path = path
2929 self .mode = mode
3030
@@ -44,18 +44,19 @@ def __getitem__(self, index):
4444 img = Image .open (img_path ).convert ('RGB' )
4545 mask = Image .open (msk_path ).convert ('L' )
4646
47- if self .mode == 'Training' :
48- label = 0 if self .label_list [index ] == 'benign' else 1
49- else :
50- label = int (self .label_list [index ])
47+ # if self.mode == 'Training':
48+ # label = 0 if self.label_list[index] == 'benign' else 1
49+ # else:
50+ # label = int(self.label_list[index])
5151
5252 if self .transform :
5353 state = torch .get_rng_state ()
5454 img = self .transform (img )
5555 torch .set_rng_state (state )
5656 mask = self .transform (mask )
5757
58- if self .mode == 'Training' :
59- return (img , mask , name )
60- else :
61- return (img , mask , name )
58+ return (img , mask , name )
59+ # if self.mode == 'Training':
60+ # return (img, mask, name)
61+ # else:
62+ # return (img, mask, name)
0 commit comments