Commit f0ff3a7

created example notebooks for 2D use case
1 parent 967c8d4 commit f0ff3a7

2 files changed

Lines changed: 225 additions & 0 deletions

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "825a7d90",
6+
"metadata": {},
7+
"source": [
8+
"## CellCycleNet Example - Fine-tune the pretrained model and predict cell cycle stage on 2D DAPI images WITH ground truth labels.\n",
9+
"This notebook demonstrates how to use CellCycleNet to fine-tune the pretrained model and predict cell cycle stage from images of DAPI-stained nuclei that have associated ground truth labels for cell cycle stage.\n",
10+
"\n",
11+
"CellCycleNet requires the following data:\n",
12+
" - A directory of 2D DAPI-stained fields of view named as `tile_<tile_num>.tiff`\n",
13+
" - A directory of 2D segmentation masks named as `mask_<tile_num>.tiff`\n",
14+
" - A directory of 2D ground truth label arrays named as `label_<tile_num>.npy` (pixel values must be 0, 1, or 2 where 0 = background, 1 = G1 nucleus, 2 = S/G2 nucleus)\n",
15+
" - Where `<tile_num>` is an integer that uniquely identifies each field of view and its corresponding segmentation mask and label array\n",
16+
"\n",
17+
"### Step 1: Create single-nucleus images from segmented FOVs."
18+
]
19+
},
20+
{
21+
"cell_type": "code",
22+
"execution_count": null,
23+
"id": "001a1733-cda4-49db-833b-278f3c4a2099",
24+
"metadata": {},
25+
"outputs": [],
26+
"source": [
27+
"from cellcyclenet import utils\n",
28+
"\n",
29+
"IMAGE_DIR = '../data/test_tiles/' # path to DAPI-stained FOVs\n",
30+
"MASK_DIR = '../data/test_masks/' # path to segmentation masks of FOVs \n",
31+
"LABEL_DIR = '../data/test_labels/' # path to label arrays of FOVs\n",
32+
"OUTPUT_DIR = '../data/test_SNI_label/' # path where labeled single-nucleus images will be saved\n",
33+
"\n",
34+
"# generate labeled SNIs #\n",
35+
"'''\n",
36+
"Optional Arguments for utils.generate_images_labeled():\n",
37+
" - return_df: boolean, if True, returns a pandas dataframe with the tile numbers, object numbers, and labels of the labeled SNIs\n",
38+
" - num_cores: integer, number of cores to use for parallel processing; if 'None', no parallel processing will be used\n",
39+
" - is_3d: boolean, set to True if your data is 3D; set to False if your data is 2D\n",
40+
"'''\n",
41+
"df = utils.generate_images_labeled(IMAGE_DIR, MASK_DIR, LABEL_DIR, OUTPUT_DIR, return_df=True, num_cores=None, is_3d=False)\n",
42+
"df"
43+
]
44+
},
45+
{
46+
"cell_type": "markdown",
47+
"id": "63f4651f-d75b-4d53-8f78-464fd8145a55",
48+
"metadata": {},
49+
"source": [
50+
"### Step 2: Load pretrained model and fine-tune with labeled data."
51+
]
52+
},
53+
{
54+
"cell_type": "code",
55+
"execution_count": null,
56+
"id": "dd2f8298-c077-436e-827b-89a3d7d2486b",
57+
"metadata": {},
58+
"outputs": [],
59+
"source": [
60+
"from cellcyclenet import CellCycleNet\n",
61+
"\n",
62+
"# create 2D model instance (pretrained weights are loaded by default) #\n",
63+
"model = CellCycleNet(is_3d=False)\n",
64+
"\n",
65+
"# convert dataframe to CellCycleNet dataset; split_data=True --> 70% training, 20% validation, 10% test #\n",
66+
"train, val, test = model.create_dataset(dataframe=df, split_data=True)\n",
67+
"\n",
68+
"# train model on your labeled data #\n",
69+
"'''\n",
70+
"Optional Arguments for CellCycleNet.train():\n",
71+
" - transform: callable, a function that takes an image and returns a transformed image; if None, no transformation will be applied\n",
72+
" - lazy_load: boolean, if True, the dataset will be loaded lazily in each epoch (slower, but uses less memory); if False, the dataset will be loaded into memory (faster, but uses more memory) \n",
73+
" - verbose: boolean, if True, training progress will be printed to the console\n",
74+
"'''\n",
75+
"model.train(train, val, n_epochs=10, batch_size=4, initial_LR=1e-5, transform=None, lazy_load=True, verbose=True)\n",
76+
"\n",
77+
"# save fine-tuned model weights #\n",
78+
"model.save_model('fine_tuned_model.pt')"
79+
]
80+
},
81+
{
82+
"cell_type": "markdown",
83+
"id": "1577521d-4aaa-49c0-96de-91e3ea26ad97",
84+
"metadata": {},
85+
"source": [
86+
"### Step 3: Predict cell cycle stage for each single-nucleus image in the test set."
87+
]
88+
},
89+
{
90+
"cell_type": "code",
91+
"execution_count": null,
92+
"id": "49b5b1a2",
93+
"metadata": {},
94+
"outputs": [],
95+
"source": [
96+
"# generate cell cycle stage predictions (0 = G1, 1 = S/G2); with_labels=True --> predicting on labeled data #\n",
97+
"test_predictions = model.predict(test, with_labels=True)\n",
98+
"\n",
99+
"# display predictions #\n",
100+
"test_predictions = test_predictions.sort_values(['tile_num', 'obj_num'])\n",
101+
"display(test_predictions)\n",
102+
"\n",
103+
"# plot an ROC curve to evaluate model performance #\n",
104+
"model.plot_ROC(test_predictions['label'], test_predictions['pred'], test_predictions['prob'])"
105+
]
106+
}
107+
],
108+
"metadata": {
109+
"kernelspec": {
110+
"display_name": "Python 3 (ipykernel)",
111+
"language": "python",
112+
"name": "python3"
113+
},
114+
"language_info": {
115+
"codemirror_mode": {
116+
"name": "ipython",
117+
"version": 3
118+
},
119+
"file_extension": ".py",
120+
"mimetype": "text/x-python",
121+
"name": "python",
122+
"nbconvert_exporter": "python",
123+
"pygments_lexer": "ipython3",
124+
"version": "3.12.4"
125+
}
126+
},
127+
"nbformat": 4,
128+
"nbformat_minor": 5
129+
}
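
The `split_data=True` comment in the notebook above describes a 70% / 20% / 10% train/validation/test split. How `create_dataset` performs that split is not shown in this commit, so the following is only a minimal sketch of a split with those proportions, written against the standard library. The helper name `split_records`, the fixed seed, and the made-up `(tile_num, obj_num)` records are assumptions for illustration and are not part of CellCycleNet.

```python
import random


def split_records(records, fractions=(0.7, 0.2, 0.1), seed=0):
    """Shuffle records and split them into train/val/test lists.

    Hypothetical helper for illustration only; not part of CellCycleNet.
    """
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    shuffled = list(records)
    # A seeded generator keeps the split reproducible between runs.
    random.Random(seed).shuffle(shuffled)
    n_train = int(round(fractions[0] * len(shuffled)))
    n_val = int(round(fractions[1] * len(shuffled)))
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    # Whatever remains after train and val becomes the test set.
    test = shuffled[n_train + n_val:]
    return train, val, test


# 100 made-up (tile_num, obj_num) pairs standing in for single-nucleus images
records = [(tile, obj) for tile in range(10) for obj in range(10)]
train, val, test = split_records(records)
print(len(train), len(val), len(test))
```

Shuffling before slicing matters here because the records are generated tile by tile; without it, the test set would come entirely from the last tiles.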

notebooks/2D_prediction_demo.ipynb

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "77b4fd88-3e35-4778-8a50-868827774cbb",
6+
"metadata": {},
7+
"source": [
8+
"## CellCycleNet Example - Predict cell cycle stage from 2D DAPI images WITHOUT ground truth labels.\n",
9+
"This notebook demonstrates how to use CellCycleNet to predict cell cycle stage from images of DAPI-stained nuclei that do not have associated ground truth labels for cell cycle stage.\n",
10+
"\n",
11+
"CellCycleNet requires the following data:\n",
12+
" - A directory of 2D DAPI-stained fields of view named as `tile_<tile_num>.tiff`\n",
13+
" - A directory of 2D segmentation masks named as `mask_<tile_num>.tiff`\n",
14+
" - Where `<tile_num>` is an integer that uniquely identifies each field of view and its corresponding segmentation mask\n",
15+
"\n",
16+
"### Step 1: Create single-nucleus images from segmented FOVs."
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": null,
22+
"id": "001a1733-cda4-49db-833b-278f3c4a2099",
23+
"metadata": {},
24+
"outputs": [],
25+
"source": [
26+
"from cellcyclenet import utils\n",
27+
"\n",
28+
"IMAGE_DIR = '../data/test_tiles/' # path to DAPI-stained FOVs\n",
29+
"MASK_DIR = '../data/test_masks/' # path to segmentation masks of FOVs\n",
30+
"OUTPUT_DIR = '../data/test_SNI/' # path where single-nucleus images will be saved\n",
31+
"\n",
32+
"# generate unlabeled SNIs #\n",
33+
"'''\n",
34+
"Optional Arguments for utils.generate_images():\n",
35+
" - return_df: boolean, if True, returns a dataframe with the paths to the generated SNIs; dataframe will be saved as a .csv file in any case\n",
36+
" - num_cores: integer, number of cores to use for parallel processing; if 'None', no parallel processing will be used\n",
37+
" - is_3d: boolean, set to True if your data is 3D; set to False if your data is 2D\n",
38+
"'''\n",
39+
"df = utils.generate_images(IMAGE_DIR, MASK_DIR, OUTPUT_DIR, return_df=True, num_cores=None, is_3d=False)\n",
40+
"df"
41+
]
42+
},
43+
{
44+
"cell_type": "markdown",
45+
"id": "63f4651f-d75b-4d53-8f78-464fd8145a55",
46+
"metadata": {},
47+
"source": [
48+
"### Step 2: Predict cell cycle stage for each single-nucleus image."
49+
]
50+
},
51+
{
52+
"cell_type": "code",
53+
"execution_count": null,
54+
"id": "d68416d9-5afb-45cd-9899-0be6a7548f21",
55+
"metadata": {},
56+
"outputs": [],
57+
"source": [
58+
"from cellcyclenet import CellCycleNet\n",
59+
"\n",
60+
"# create 2D model instance (pretrained weights are loaded by default) #\n",
61+
"model = CellCycleNet(is_3d=False)\n",
62+
"\n",
63+
"# convert dataframe to CellCycleNet dataset; split_data=False --> no validation or testing sets will be created #\n",
64+
"dataset = model.create_dataset(dataframe=df, split_data=False)\n",
65+
"\n",
66+
"# generate cell cycle stage predictions (0 = G1, 1 = S/G2); with_labels=False --> predicting on unlabeled data #\n",
67+
"predictions = model.predict(dataset, with_labels=False)\n",
68+
"\n",
69+
"# display predictions #\n",
70+
"predictions = predictions.sort_values(['tile_num', 'obj_num'])\n",
71+
"predictions"
72+
]
73+
}
74+
],
75+
"metadata": {
76+
"kernelspec": {
77+
"display_name": "dask-cellpose",
78+
"language": "python",
79+
"name": "python3"
80+
},
81+
"language_info": {
82+
"codemirror_mode": {
83+
"name": "ipython",
84+
"version": 3
85+
},
86+
"file_extension": ".py",
87+
"mimetype": "text/x-python",
88+
"name": "python",
89+
"nbconvert_exporter": "python",
90+
"pygments_lexer": "ipython3",
91+
"version": "3.12.2"
92+
}
93+
},
94+
"nbformat": 4,
95+
"nbformat_minor": 5
96+
}
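
Both notebooks require that every `tile_<tile_num>.tiff` has a matching `mask_<tile_num>.tiff`. Whether `utils.generate_images` checks this itself is not shown in this commit, so a quick check before running it may be useful. The sketch below uses only the standard library; the helper names `numbered_files` and `unpaired_tiles` are hypothetical, and the demonstration runs on throwaway directories containing empty placeholder files rather than real images.

```python
import re
import tempfile
from pathlib import Path

TILE_RE = re.compile(r"^tile_(\d+)\.tiff$")
MASK_RE = re.compile(r"^mask_(\d+)\.tiff$")


def numbered_files(directory, pattern):
    """Map each integer tile number to its file path in `directory`."""
    found = {}
    for path in Path(directory).iterdir():
        match = pattern.match(path.name)
        if match:
            found[int(match.group(1))] = path
    return found


def unpaired_tiles(image_dir, mask_dir):
    """Return (tile numbers lacking a mask, mask numbers lacking a tile)."""
    tiles = numbered_files(image_dir, TILE_RE)
    masks = numbered_files(mask_dir, MASK_RE)
    missing_masks = sorted(tiles.keys() - masks.keys())
    missing_tiles = sorted(masks.keys() - tiles.keys())
    return missing_masks, missing_tiles


# Demonstration on temporary directories with empty placeholder files
with tempfile.TemporaryDirectory() as root:
    image_dir = Path(root, "tiles")
    mask_dir = Path(root, "masks")
    image_dir.mkdir()
    mask_dir.mkdir()
    for n in (1, 2, 3):
        (image_dir / f"tile_{n}.tiff").touch()
    for n in (1, 3, 4):
        (mask_dir / f"mask_{n}.tiff").touch()
    missing_masks, missing_tiles = unpaired_tiles(image_dir, mask_dir)

print("tiles without a mask:", missing_masks)
print("masks without a tile:", missing_tiles)
```

For the labeled workflow, the same pattern could be extended with a third regular expression for `label_<tile_num>.npy`.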
