Skip to content

Commit 3517ff7

Browse files
Merge pull request #2 from Noble-Lab/v0.2_dev
V0.2 dev
2 parents 627a2cb + 9aa79f8 commit 3517ff7

13 files changed

Lines changed: 909 additions & 1496 deletions

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,9 @@ cython_debug/
160160
# and can be added to the global gitignore or merged into this file. For a more nuclear
161161
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
162162
#.idea/
163+
164+
# data folder for dev
165+
data/
166+
167+
# notebook for dev
168+
cellcyclenet/dev.ipynb

cellcyclenet/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.1.1'
1+
__version__ = '0.2.0'

cellcyclenet/interface.py

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from glob import glob
3030
from cellcyclenet.unet3d.model import UNet3D
3131
from cellcyclenet import models
32+
from cellcyclenet.unet2d import UNet2D
3233
from torch.utils.data import Dataset, DataLoader
3334
import torchvision.transforms.v2 as transforms
3435
from skimage.transform import downscale_local_mean
@@ -42,12 +43,10 @@
4243

4344
class CCN_Dataset(Dataset):
4445
'''Create a class to hold the PyTorch Dataset, input to PyTorch Dataloader.'''
45-
def __init__(self, X, y, norm_factor, scale_factors, transform, lazy_load):
46+
def __init__(self, X, y, transform, lazy_load):
4647
self.X = X
4748
self.y = y
4849
self.transform = transform
49-
self.norm_factor = norm_factor
50-
self.scale_factors = scale_factors
5150
self.lazy_load = lazy_load
5251

5352
def __len__(self):
@@ -66,8 +65,6 @@ def __getitem__(self, index):
6665
# if initialized with image fns (lazy loading), load, normalize, and scale image #
6766
else:
6867
image = imread(self.X[index])
69-
image = downscale_local_mean(image, self.scale_factors)
70-
image = image / self.norm_factor
7168

7269
# convert image to float 32 #
7370
X = np.asarray(image, dtype=np.float32)
@@ -92,40 +89,47 @@ def __getitem__(self, index):
9289

9390
class CellCycleNet:
9491

95-
def __init__(self, state_dict_path=None, is_3d=True):
    """Set up the compute device, build the 2D or 3D network, and load weights.

    Arguments:
        - state_dict_path [str|None] : path to a weights file; when None, the
          packaged pretrained weights for the chosen dimensionality are used
        - is_3d [bool] : build the 3D UNet when True, the 2D UNet otherwise
    """
    # run on the GPU when one is available, otherwise fall back to CPU #
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.is_3d = is_3d

    # build the network for the requested dimensionality #
    if self.is_3d:
        self.model = UNet3D(in_channels=1, out_channels=1, is_segmentation=False, f_maps=32)
    else:
        self.model = UNet2D(in_channels=1, init_features=32)
    self.model = torch.nn.DataParallel(self.model)

    # no user-supplied path: use the packaged pretrained weights #
    if state_dict_path is None:
        weight_name = 'pretrained-model_3D.pt' if self.is_3d else 'pretrained-model_2D.pt'
        state_dict_path = pkg_resources.files(models).joinpath(weight_name)

    # load weights to the GPU when available, otherwise remap onto CPU #
    if torch.cuda.is_available():
        state_dict = torch.load(state_dict_path)
    else:
        state_dict = torch.load(state_dict_path, map_location=torch.device('cpu'))
    self.model.load_state_dict(state_dict)
    self.model.to(self.device)

109119
################################################################################################
110120

111-
def create_dataset(self, dataframe, norm_factor, scale_factors, split_data=False, seed=15):
121+
def create_dataset(self, dataframe, split_data=False, seed=15):
112122
'''
113123
Creates training, validation, and testing dataframes containing GT labels and image filenames for each nucleus.
114124
(NOTE: it is assumed that the order of labels in the dataframe is the same as order of images in image_dir.)
115125
Arguments:
116126
- image_dir [str] : path to directory of single nucleus images
117127
- dataframe [pd.DataFrame] : dataframe containing GT labels; if None, it is assumed that user wants to proceed without using labels (use case 1)
118-
- norm_factor [int] : images are divided pixelwise by this value during loading (calculated as median value of the medians of each input image)
119-
- scale_factor [tuple] : Z,Y,X scale factors
120128
- split_data [bool] : flag to determine if input data is split into train/val/test sets (use case 2) or just a single dataset (use case 1)
121129
- seed [int] : random seed to use for train/val/test split
122130
Outputs:
123131
- train, val, test [pd.DataFrame] : dataframe containing GT label (if inputted) + image filename for each nucleus
124132
'''
125-
# set norm / scale factors as attributes (to be called by .train() and .predict() when creating dataloaders)
126-
self.norm_factor = norm_factor
127-
self.scale_factors = scale_factors
128-
129133
### FIXME small DF for debugging ###
130134
# dataframe = dataframe.iloc[::15]
131135

@@ -166,9 +170,10 @@ def run_epoch(self, dataloader, is_train):
166170
images.to(self.device)
167171
labels.to(self.device)
168172

169-
# reshape images to [batch, channel, Z, Y, X] #
170-
images = torch.swapaxes(images, 1, 2)
171-
images = torch.unsqueeze(images, 1)
173+
# for 3D images, reshape images to [batch, channel, Z, Y, X] #
174+
if self.is_3d:
175+
images = torch.swapaxes(images, 1, 2)
176+
images = torch.unsqueeze(images, 1)
172177

173178
### FORWARD PASS ###
174179
outputs = self.model(images)
@@ -234,13 +239,11 @@ def train(self, train_df, val_df, n_epochs, batch_size=4, initial_LR=1e-5, trans
234239
train_X_fn = train_df['filename'].values
235240
train_y = np.where(train_df['label'].values == 'G1', 0, 1)
236241
if lazy_load:
237-
train_dataloader = DataLoader(CCN_Dataset(train_X_fn, train_y, self.norm_factor, self.scale_factors, transform=transform, lazy_load=lazy_load),
242+
train_dataloader = DataLoader(CCN_Dataset(train_X_fn, train_y, transform=transform, lazy_load=lazy_load),
238243
batch_size=batch_size, shuffle=False)
239244
else:
240245
train_X = np.asarray([imread(fn) for fn in train_X_fn])
241-
train_X_ds = np.asarray([downscale_local_mean(image, self.scale_factors) for image in train_X])
242-
train_X_norm = np.asarray([image / self.norm_factor for image in train_X_ds])
243-
train_dataloader = DataLoader(CCN_Dataset(train_X_norm, train_y, self.norm_factor, self.scale_factors, transform=transform, lazy_load=lazy_load),
246+
train_dataloader = DataLoader(CCN_Dataset(train_X, train_y, transform=transform, lazy_load=lazy_load),
244247
batch_size=batch_size, shuffle=False)
245248

246249
# create dataloader for validation data #
@@ -249,13 +252,11 @@ def train(self, train_df, val_df, n_epochs, batch_size=4, initial_LR=1e-5, trans
249252
val_y = np.where(val_df['label'].values == 'G1', 0, 1)
250253

251254
if lazy_load:
252-
val_dataloader = DataLoader(CCN_Dataset(val_X_fn, val_y, self.norm_factor, self.scale_factors, transform=None, lazy_load=lazy_load),
255+
val_dataloader = DataLoader(CCN_Dataset(val_X_fn, val_y, transform=None, lazy_load=lazy_load),
253256
batch_size=batch_size, shuffle=False)
254257
else:
255258
val_X = np.asarray([imread(fn) for fn in val_X_fn])
256-
val_X_ds = np.asarray([downscale_local_mean(image, self.scale_factors) for image in val_X])
257-
val_X_norm = np.asarray([image / self.norm_factor for image in val_X_ds])
258-
val_dataloader = DataLoader(CCN_Dataset(val_X_norm, val_y, self.norm_factor, self.scale_factors, transform=None, lazy_load=lazy_load),
259+
val_dataloader = DataLoader(CCN_Dataset(val_X, val_y, transform=None, lazy_load=lazy_load),
259260
batch_size=batch_size, shuffle=False)
260261

261262
# track val acc for each epoch to check for improvement #
@@ -313,7 +314,7 @@ def predict(self, dataframe, with_labels, decision_threshold=0.5):
313314

314315
X_fn = dataframe['filename'].values
315316
y = np.where(dataframe['label'].values == 'G1', 0, 1) if with_labels else np.zeros(len(X_fn))
316-
dataloader = DataLoader(CCN_Dataset(X_fn, y, self.norm_factor, self.scale_factors, transform=None, lazy_load=True), batch_size=4, shuffle=False)
317+
dataloader = DataLoader(CCN_Dataset(X_fn, y, transform=None, lazy_load=True), batch_size=4, shuffle=False)
317318

318319
# Run through network #
319320
with torch.no_grad():
@@ -322,9 +323,10 @@ def predict(self, dataframe, with_labels, decision_threshold=0.5):
322323
images = images.to(self.device)
323324
labels = labels.to(self.device)
324325

325-
# Reshape to [batch, channels, Z, Y, X] #
326-
images = torch.swapaxes(images, 1, 2)
327-
images = torch.unsqueeze(images, 1)
326+
# for 3D images, reshape to [batch, channels, Z, Y, X] #
327+
if self.is_3d:
328+
images = torch.swapaxes(images, 1, 2)
329+
images = torch.unsqueeze(images, 1)
328330

329331
# Run inference #
330332
outputs = self.model(images)
@@ -422,7 +424,10 @@ def show_image(self, dataframe, index=None, with_preds=True, hide_plot=False, fi
422424
if not hide_plot:
423425
plt.figure(figsize=figsize)
424426
plt.axis('off')
425-
plt.imshow(np.max(image, axis=0))
427+
if self.is_3d:
428+
plt.imshow(np.max(image, axis=0))
429+
else:
430+
plt.imshow(image)
426431
plt.title(f'Index: {index} / Label: {label} / Pred: {pred} / Prob: {prob:.3f}', fontsize=10)
427432
plt.show()
428433

1.49 MB
Binary file not shown.
File renamed without changes.

cellcyclenet/unet2d.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
class UNet2D(nn.Module):
    """
    2D version of the classification UNet architecture.

    An encoder of four conv levels (GroupNorm -> Conv2d -> ReLU), each level
    after the first preceded by 2x2 max pooling, followed by a global
    max-pool classification head that produces one logit per image.

    Args:
        in_channels (int): Number of input channels (default: 1)
        init_features (int): Number of features in first layer (default: 32).
            NOTE(review): levels 2-4 use GroupNorm with 8 groups, so
            init_features must be a multiple of 8.
        dropout_prob (float): Dropout probability in classification head (default: 0.5)
    """
    def __init__(self, in_channels=1, init_features=32, dropout_prob=0.5):
        super().__init__()

        # Store feature number for the first level
        features = init_features

        # Level 1 (No pooling). GroupNorm requires num_groups to divide the
        # channel count, so use a single group when there is only one channel.
        num_groups = 1 if in_channels == 1 else 8
        self.level1 = nn.Sequential(
            nn.GroupNorm(num_groups, in_channels),
            nn.Conv2d(in_channels, features, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )

        # Level 2
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.level2 = nn.Sequential(
            nn.GroupNorm(8, features),
            nn.Conv2d(features, features*2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )

        # Level 3
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.level3 = nn.Sequential(
            nn.GroupNorm(8, features*2),
            nn.Conv2d(features*2, features*4, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )

        # Level 4
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.level4 = nn.Sequential(
            nn.GroupNorm(8, features*4),
            nn.Conv2d(features*4, features*8, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True)
        )

        # Classification head: global max pool -> flatten -> dropout -> linear
        self.classifier = nn.Sequential(
            nn.AdaptiveMaxPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Dropout(p=dropout_prob),
            nn.Linear(features*8, 1)
        )

    def _encode(self, x):
        """
        Run the shared encoder path: (batch, C, H, W) -> (batch, 8*features, H/8, W/8).

        Factored out of forward() and get_embedding() so the two paths cannot
        drift apart (they previously duplicated these four levels verbatim).
        """
        x1 = self.level1(x)                  # (batch, F,   H,   W)
        x2 = self.level2(self.pool2(x1))     # (batch, 2F,  H/2, W/2)
        x3 = self.level3(self.pool3(x2))     # (batch, 4F,  H/4, W/4)
        x4 = self.level4(self.pool4(x3))     # (batch, 8F,  H/8, W/8)
        return x4

    def forward(self, x):
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, channels, H, W)

        Returns:
            torch.Tensor: Output tensor of shape (batch, 1)
        """
        return self.classifier(self._encode(x))

    def get_embedding(self, x):
        """
        Return the pooled, flattened encoder features of shape (batch, 8*features).

        Applies the classification head up to (but not including) dropout and
        the final linear layer.
        """
        x4 = self._encode(x)
        pooled = self.classifier[0](x4)        # global max pool -> (batch, 8F, 1, 1)
        embedding = self.classifier[1](pooled)  # flatten -> (batch, 8F)
        return embedding
111+
112+
if __name__ == "__main__":
    # Smoke-test: run one random batch through the network and report shapes.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = UNet2D().to(device)

    # Random input batch (batch_size=4, channels=1, height=64, width=64)
    sample = torch.randn(4, 1, 64, 64).to(device)
    preds = net(sample)

    print(f"Input shape: {sample.shape}")
    print(f"Output shape: {preds.shape}")

    # Dump the module tree for a quick architecture sanity check
    print("\nModel Architecture:")
    print(net)

0 commit comments

Comments
 (0)