-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathscript.py
More file actions
230 lines (196 loc) · 8.6 KB
/
script.py
File metadata and controls
230 lines (196 loc) · 8.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import os
import shutil
from pathlib import Path
import xarray as xr
import dask
import numpy as np
import pandas as pd
import anndata as ad
import spatialdata as sd
import sopa
import cv2
import torch
from rna2seg.dataset_zarr.patches import create_patch_rna2seg
import albumentations as A
from rna2seg.dataset_zarr import RNA2segDataset
from rna2seg.models import RNA2seg
from tqdm import tqdm
from rna2seg.utils import save_shapes2zarr
# Parameter/meta defaults for running this script directly. Per the note below,
# viash auto-generates the section between the START/END markers at runtime, so
# presumably these literal values only matter for local runs — TODO confirm.
# NOTE(review): meta['temp_dir'] is a developer-specific absolute path; on other
# machines it likely does not exist. Consider None so the "/tmp" fallback
# (TMP_DIR = Path(meta["temp_dir"] or "/tmp")) applies.
## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
par = {
'input_ist': 'resources_test/task_ist_preprocessing/mouse_brain_combined/raw_ist.zarr',
'input_segmentation': 'resources_test/task_ist_preprocessing/mouse_brain_combined/segmentation.zarr',
'transcripts_key': 'transcripts',
'coordinate_system': 'global',
'output': './temp/sopa_testing/rna2seg_transcripts.zarr',
'flow_threshold': 0.9,
'cellbound_flow_threshold': 0.4,
'create_cytoplasm_image': False,
'cytoplasm_min_threshold': 0.25,
'cytoplasm_max_threshold': 0.75,
'patch_width': 1000,
'patch_overlap': 50,
}
meta = {
'name': 'rna2seg',
'temp_dir': "/Users/habib/Projects/txsim_project/task_ist_preprocessing/temp/rna2seg",
'cpus': 10
}
## VIASH END
# Scratch location for the intermediate SpatialData store that sopa/RNA2seg work
# on; falls back to /tmp when no temp_dir is provided.
TMP_DIR = Path(meta["temp_dir"] or "/tmp")
TMP_DIR.mkdir(parents=True, exist_ok=True)
TMP_ZARR = TMP_DIR / 'temp_rna2seg_sdata.zarr'

# Read input
print('Reading input files', flush=True)
sdata = sd.read_zarr(par['input_ist'])
# NOTE: par['input_segmentation'] is not read here; the segmentation used for
# transcript assignment is produced by RNA2seg further down in this script.

# Check that the requested coordinate system is available for the transcripts.
# An explicit raise (not `assert`) keeps this validation active under `python -O`.
transcripts_coord_systems = sd.transformations.get_transformation(
    sdata[par["transcripts_key"]], get_all=True
).keys()
if par['coordinate_system'] not in transcripts_coord_systems:
    raise ValueError(
        f"Coordinate system '{par['coordinate_system']}' not found in input data. "
        f"Available coordinate systems: {sorted(transcripts_coord_systems)}"
    )
### Run RNA2seg with sopa

### CREATE CYTOPLASM IMAGE FUNCTION
# TODO define this function somewhere else and import
def get_nuclear_outline(nuclear_image, threshold_min=0.25, threshold_max=0.75):
    """Derive a pseudo-cytoplasm image from a nuclear-stain image.

    Despite the name, the result is a cytoplasm-like image. Intensities are
    clipped to the ``[threshold_min, threshold_max]`` quantile range. Pixels
    strictly above the upper quantile (treated as nuclear signal) then have a
    rescaled copy of their own intensity subtracted, so the brightest pixel
    ends up at the lower quantile value.

    Parameters
    ----------
    nuclear_image : numpy.ndarray
        Nuclear-stain intensities. Quantiles are taken over all elements.
    threshold_min, threshold_max : float
        Lower/upper quantiles in ``[0, 1]`` used for clipping and for
        selecting nuclear pixels.

    Returns
    -------
    numpy.ndarray
        Array with the same shape and dtype as ``nuclear_image``; negative
        values are clipped to 0 before casting back.

    Raises
    ------
    ValueError
        If ``threshold_min > threshold_max`` (``np.quantile`` itself rejects
        quantiles outside ``[0, 1]``).
    """
    if threshold_min > threshold_max:
        raise ValueError(
            f"threshold_min ({threshold_min}) must not exceed threshold_max ({threshold_max})"
        )
    # Compute both quantiles once instead of re-evaluating the upper one.
    low, high = np.quantile(nuclear_image, [threshold_min, threshold_max])
    threshold_image = np.clip(nuclear_image, low, high)
    # Nuclear values: original intensities above the upper quantile, 0 elsewhere.
    nuclear_mask = (nuclear_image > high) * nuclear_image
    mask_max = np.max(nuclear_mask)
    if mask_max > 0:
        # Scale nuclear values onto the clipped (whole-cell) intensity range.
        scaling_factor = (np.max(threshold_image) - np.min(threshold_image)) / mask_max
    else:
        # No pixel exceeds the upper quantile (e.g. a flat image): there is no
        # nuclear signal to subtract, and dividing by 0 would produce NaN.
        scaling_factor = 0.0
    # Subtract nucleus from whole cell to get cytoplasm.
    cyto_image = threshold_image - (nuclear_mask * scaling_factor)
    return np.clip(cyto_image, 0, np.inf).astype(nuclear_image.dtype)
# Two-channel (c, y, x) composite for RNA2seg: channel 0 carries the
# highest-resolution morphology image, channel 1 either a derived
# pseudo-cytoplasm image or stays all zeros.
mip_element = sdata['morphology_mip']
nuclear_image = mip_element['scale0'].image.compute().to_numpy()
n_rows, n_cols = nuclear_image.shape[1], nuclear_image.shape[2]
composite = np.zeros((2, n_rows, n_cols), dtype=nuclear_image.dtype)
composite[0] = nuclear_image
if par['create_cytoplasm_image']:
    composite[1] = get_nuclear_outline(
        nuclear_image=nuclear_image,
        threshold_min=par['cytoplasm_min_threshold'],
        threshold_max=par['cytoplasm_max_threshold'],
    )

# One factor-2 downscale per extra pyramid level of the source image.
# NOTE(review): assumes `.groups` lists the root node plus one node per scale,
# hence the "- 2" — TODO confirm against the DataTree layout.
n_downscales = len(mip_element.groups) - 2
morphology_mip = sd.models.Image2DModel.parse(
    data=composite,
    scale_factors=[2] * n_downscales,
    dims=['c', 'y', 'x'],
    chunks=composite.shape,
)

# Give the composite the same transformation as the source image.
sd.transformations.set_transformation(
    morphology_mip,
    sd.transformations.get_transformation(
        mip_element, to_coordinate_system=par['coordinate_system']
    ),
    to_coordinate_system=par['coordinate_system'],
)
# Reduced SpatialData holding only the composite image and the transcripts
# (stored under the fixed key "transcripts"), persisted to the scratch store.
print("Creating sopa SpatialData object")
reduced_images = {"morphology_mip": morphology_mip}
reduced_points = {"transcripts": sdata[par['transcripts_key']]}
sdata_sopa = sd.SpatialData(points=reduced_points, images=reduced_images)
sdata_sopa.write(TMP_ZARR, overwrite=True)
print("Running RNA2Seg", flush=True)
# Create patches in sdata_sopa and precompute the per-patch transcript files
# with sopa, written under folder_patch_rna2seg inside the scratch zarr store.
image_key = 'morphology_mip'
# sdata_sopa stores the points under the fixed key "transcripts"; the input
# key par["transcripts_key"] only applies to the original `sdata` and would
# raise a KeyError here whenever it differs from "transcripts".
points_key = "transcripts"
gene_column_name = "feature_name"  # typically "feature_name" for Xenium
patch_width = par['patch_width']
patch_overlap = par['patch_overlap']
min_points_per_patch = 1
folder_patch_rna2seg = TMP_ZARR / f".rna2seg_{patch_width}_{patch_overlap}"
create_patch_rna2seg(
    sdata=sdata_sopa,
    image_key=image_key,
    points_key=points_key,
    patch_width=patch_width,
    patch_overlap=patch_overlap,
    min_points_per_patch=min_points_per_patch,
    folder_patch_rna2seg=folder_patch_rna2seg,
    gene_column_name=gene_column_name,
    overwrite=True,
)
# Every patch is resized to 512x512 using nearest-neighbour interpolation
# (no blending of pixel values), then wrapped in an RNA2seg dataset.
resize_to_512 = A.Compose(
    [A.Resize(height=512, width=512, interpolation=cv2.INTER_NEAREST)]
)
dataset = RNA2segDataset(
    sdata=sdata_sopa,
    channels_dapi=[0],  # composite channel 0: morphology (nuclear) image
    channels_cellbound=[1],  # composite channel 1: pseudo-cytoplasm or zeros
    patch_width=patch_width,
    patch_overlap=patch_overlap,
    gene_column=gene_column_name,
    transform_resize=resize_to_512,
    patch_dir=folder_patch_rna2seg,
)
# Build the pretrained RNA2seg U-Net (on GPU when available) and run it on
# every patch, saving per-patch results under folder_patch_rna2seg.
model = RNA2seg(
    'cuda' if torch.cuda.is_available() else 'cpu',
    net='unet',
    flow_threshold=par['flow_threshold'],
    cellbound_flow_threshold=par['cellbound_flow_threshold'],
    pretrained_model="default_pretrained",
)
for patch_idx in tqdm(range(len(dataset))):
    patch_input = dataset[patch_idx]
    model.run(path_temp_save=folder_patch_rna2seg, input_dict=patch_input)
# Collect the per-patch RNA2seg outputs into a single shapes element.
segmentation_shape_name = "rna2seg_boundaries"
save_shapes2zarr(
    dataset=dataset,
    path_parquet_files=folder_patch_rna2seg,
    segmentation_key=segmentation_shape_name,
    overwrite=True,
)

# HACK (only relevant when testing on a cropped dataset): cropping appears to
# break the transformation of the RNA2seg boundaries. When every transcript lies
# to the right of every boundary, reuse the morphology image's transformation
# for the boundaries.
# The min is reduced lazily instead of materializing the whole transformed
# transcript table just to read one value.
transcript_min = (
    sd.transform(sdata_sopa['transcripts'], to_coordinate_system=par['coordinate_system'])['x']
    .min()
    .compute()
)
shapes_max = sd.transform(
    sdata_sopa[segmentation_shape_name], to_coordinate_system=par['coordinate_system']
).bounds['maxx'].max()
if transcript_min > shapes_max:
    print(f"crop detected ({transcript_min} > {shapes_max}), reformatting", flush=True)
    trans = sd.transformations.get_transformation(
        sdata_sopa['morphology_mip'], to_coordinate_system=par['coordinate_system']
    )
    sd.transformations.set_transformation(
        sdata_sopa[segmentation_shape_name], trans, to_coordinate_system=par['coordinate_system']
    )
# Write each transcript's matching RNA2seg boundary id into the "cell_id"
# column of the transcripts; transcripts without a match get 0.
sopa.spatial.assign_transcript_to_cell(
    sdata_sopa,
    points_key="transcripts",
    shapes_key=segmentation_shape_name,
    key_added="cell_id",
    unassigned_value=0,
)
# Create objects for cells table
print('Creating objects for cells table', flush=True)
# One obs row per segmented cell, keyed by the ids assigned to the transcripts.
unique_cells = np.unique(sdata_sopa["transcripts"]["cell_id"])
# Drop id 0 (the value used for unassigned transcripts / background).
unique_cells = unique_cells[unique_cells != 0]
cell_id_col = pd.Series(unique_cells, name='cell_id', index=unique_cells)
# `x in series` tests the Series INDEX, not its values, so check values explicitly.
assert not (cell_id_col.to_numpy() == 0).any(), "Found '0' in cell_id column of assingment output cell matrix"
# Create transcripts only sdata
print('Subsetting to transcripts cell id data', flush=True)
# Output table: one obs row per cell id; var reuses the input table's var index
# with no columns, and no expression matrix is provided.
cells_table = ad.AnnData(
    obs=cell_id_col.to_frame(),
    var=sdata.tables["table"].var[[]],
)
sdata_transcripts_only = sd.SpatialData(
    points={"transcripts": sdata_sopa['transcripts']},
    tables={"table": cells_table},
)
# Write output, removing any previous result at the target path first.
print('Write transcripts with cell ids', flush=True)
output_location = par["output"]
if os.path.exists(output_location):
    shutil.rmtree(output_location)
sdata_transcripts_only.write(output_location)