Skip to content

Commit eb1b192

Browse files
committed
updated example notebooks
1 parent a458139 commit eb1b192

4 files changed

Lines changed: 225 additions & 191 deletions

File tree

notebooks/01_prediction_demo.ipynb

Lines changed: 0 additions & 85 deletions
This file was deleted.

notebooks/02_fine_tune_training_demo.ipynb

Lines changed: 0 additions & 106 deletions
This file was deleted.
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "825a7d90",
6+
"metadata": {},
7+
"source": [
8+
"## CellCycleNet Example - Fine tune pretrained model and predict cell cycle stage on 3D DAPI images WITH ground truth labels.\n",
9+
"This notebook demonstrates how to use CellCycleNet to fine tune the pretrained model and predict cell cycle stage from images of DAPI-stained nuclei that have associated ground truth labels for cell cycle stage.\n",
10+
"\n",
11+
"CellCycleNet requires the following data:\n",
12+
" - A directory of 3D DAPI-stained fields of view named as `tile_<tile_num>.tiff`\n",
13+
" - A directory of 3D segmentation masks named as `mask_<tile_num>.tiff`\n",
14+
" - A directory of 3D ground truth label arrays named as `label_<tile_num>.npy` (pixel values must be 0, 1, or 2 where 0 = background, 1 = G1 nucleus, 2 = S/G2 nucleus)\n",
15+
" - Where `<tile_num>` is an integer that uniquely identifies each field of view and its corresponding segmentation mask and label array\n",
16+
"\n",
17+
"### Step 1: Create single-nucleus images from segmented FOVs."
18+
]
19+
},
20+
{
21+
"cell_type": "code",
22+
"execution_count": null,
23+
"id": "001a1733-cda4-49db-833b-278f3c4a2099",
24+
"metadata": {},
25+
"outputs": [],
26+
"source": [
27+
"from cellcyclenet import utils\n",
28+
"\n",
29+
"IMAGE_DIR = '../data/test_tiles/' # path to DAPI-stained FOVs\n",
30+
"MASK_DIR = '../data/test_masks/' # path to segmentation masks of FOVs \n",
31+
"LABEL_DIR = '../data/test_labels/' # path to label arrays of FOVs\n",
32+
"OUTPUT_DIR = '../data/test_SNI_label/' # path where labeled single-nucleus images will be saved\n",
33+
"\n",
34+
"# generate labeled SNIs #\n",
35+
"'''\n",
36+
"Optional Arguments for utils.generate_images_labeled():\n",
37+
" - return_df: boolean, if True, returns a pandas dataframe with the tile numbers, object numbers, and labels of the labeled SNIs\n",
38+
" - num_cores: integer, number of cores to use for parallel processing; if 'None', no parallel processing will be used\n",
39+
" - is_3d: boolean, set to True if your data is 3D; set to False if your data is 2D\n",
40+
"'''\n",
41+
"df = utils.generate_images_labeled(IMAGE_DIR, MASK_DIR, LABEL_DIR, OUTPUT_DIR, return_df=True, num_cores=None, is_3d=True)\n",
42+
"display(df)"
43+
]
44+
},
45+
{
46+
"cell_type": "markdown",
47+
"id": "63f4651f-d75b-4d53-8f78-464fd8145a55",
48+
"metadata": {},
49+
"source": [
50+
"### Step 2: Load pretrained model and finetune with labeled data."
51+
]
52+
},
53+
{
54+
"cell_type": "code",
55+
"execution_count": null,
56+
"id": "dd2f8298-c077-436e-827b-89a3d7d2486b",
57+
"metadata": {},
58+
"outputs": [],
59+
"source": [
60+
"from cellcyclenet import CellCycleNet\n",
61+
"\n",
62+
"# create 3D model instance (pretrained weights are loaded by default) #\n",
63+
"model = CellCycleNet(is_3d=True)\n",
64+
"\n",
65+
"# convert dataframe to CellCycleNet dataset; split_data=True --> 70% training, 20% validation, 10% test #\n",
66+
"train, val, test = model.create_dataset(dataframe=df, split_data=True)\n",
67+
"\n",
68+
"# train model on your labeled data #\n",
69+
"'''\n",
70+
"Optional Arguments for CellCycleNet.train():\n",
71+
" - transform: callable, a function that takes an image and returns a transformed image; if None, no transformation will be applied\n",
72+
" - lazy_load: boolean, if True, the dataset will be loaded lazily in each epoch (slower, but uses less memory); if False, the dataset will be loaded into memory (faster, but uses more memory) \n",
73+
" - verbose: boolean, if True, training progress will be printed to the console\n",
74+
"'''\n",
75+
"model.train(train, val, n_epochs=10, batch_size=4, initial_LR=1e-5, transform=None, lazy_load=True, verbose=True)\n",
76+
"\n",
77+
"# save finetuned model weights #\n",
78+
"model.save_model('fine_tuned_model.pt')"
79+
]
80+
},
81+
{
82+
"cell_type": "markdown",
83+
"id": "1577521d-4aaa-49c0-96de-91e3ea26ad97",
84+
"metadata": {},
85+
"source": [
86+
"### Step 3: Predict cell cycle stage for each single-nucleus image in the test set."
87+
]
88+
},
89+
{
90+
"cell_type": "code",
91+
"execution_count": null,
92+
"id": "49b5b1a2",
93+
"metadata": {},
94+
"outputs": [],
95+
"source": [
96+
"# generate cell cycle stage predictions (0 = G1, 1 = S/G2); with_labels=True --> predicting on labeled data #\n",
97+
"test_predictions = model.predict(test, with_labels=True)\n",
98+
"\n",
99+
"# display predictions #\n",
100+
"test_predictions.sort_values(['tile_num', 'obj_num'])\n",
101+
"display(predictions)\n",
102+
"\n",
103+
"# plot an ROC curve to evaluate model peformance #\n",
104+
"model.plot_ROC(test_predictions['label'], test_predictions['pred'], test_predictions['prob']"
105+
]
106+
}
107+
],
108+
"metadata": {
109+
"kernelspec": {
110+
"display_name": "Python 3 (ipykernel)",
111+
"language": "python",
112+
"name": "python3"
113+
},
114+
"language_info": {
115+
"codemirror_mode": {
116+
"name": "ipython",
117+
"version": 3
118+
},
119+
"file_extension": ".py",
120+
"mimetype": "text/x-python",
121+
"name": "python",
122+
"nbconvert_exporter": "python",
123+
"pygments_lexer": "ipython3",
124+
"version": "3.12.4"
125+
}
126+
},
127+
"nbformat": 4,
128+
"nbformat_minor": 5
129+
}

0 commit comments

Comments
 (0)