|
| 1 | +import numpy as np |
| 2 | +import dask |
| 3 | +import spatialdata as sd |
| 4 | +import txsim as tx |
| 5 | +import anndata as ad |
| 6 | +import os |
| 7 | +import shutil |
| 8 | + |
| 9 | +## VIASH START |
| 10 | +# Note: this section is auto-generated by viash at runtime. To edit it, make changes |
| 11 | +# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`. |
| 12 | +par = { |
| 13 | + 'input_ist': 'resources_test/task_ist_preprocessing/mouse_brain_combined/raw_ist.zarr', |
| 14 | + 'input_segmentation': 'resources_test/task_ist_preprocessing/mouse_brain_combined/segmentation.zarr', |
| 15 | + 'transcripts_key': 'transcripts', |
| 16 | + 'coordinate_system': 'global', |
| 17 | + 'output': '../pciSeq_assigned_transcripts.zarr', |
| 18 | + |
| 19 | + 'input_scrnaseq': 'resources_test/task_ist_preprocessing/mouse_brain_combined/scrnaseq_reference.h5ad', |
| 20 | + 'sc_cell_type_key': 'cell_type', |
| 21 | + |
| 22 | + 'exclude_genes': None, |
| 23 | + 'max_iter': 1000, |
| 24 | + 'CellCallTolerance': 0.02, |
| 25 | + 'rGene': 20, |
| 26 | + 'Inefficiency': 0.2, |
| 27 | + 'InsideCellBonus': 2, |
| 28 | + 'MisreadDensity': 0.00001, |
| 29 | + 'SpotReg': 0.1, |
| 30 | + 'nNeighbors': 3, |
| 31 | + 'rSpot': 2, |
| 32 | + 'save_data': False, |
| 33 | + 'dtype': np.float64 |
| 34 | +} |
| 35 | +meta = { |
| 36 | + 'name': 'pciSeq_transcript_assignment' |
| 37 | +} |
| 38 | +## VIASH END |
| 39 | + |
| 40 | +# Read input |
| 41 | +print('Reading input files', flush=True) |
| 42 | +sdata = sd.read_zarr(par['input_ist']) |
| 43 | +sdata_segm = sd.read_zarr(par['input_segmentation']) |
| 44 | + |
| 45 | +# Check if coordinate system is available in input data |
| 46 | +transcripts_coord_systems = sd.transformations.get_transformation(sdata[par["transcripts_key"]], get_all=True).keys() |
| 47 | +assert par['coordinate_system'] in transcripts_coord_systems, f"Coordinate system '{par['coordinate_system']}' not found in input data." |
| 48 | +segmentation_coord_systems = sd.transformations.get_transformation(sdata_segm["segmentation"], get_all=True).keys() |
| 49 | +assert par['coordinate_system'] in segmentation_coord_systems, f"Coordinate system '{par['coordinate_system']}' not found in input data." |
| 50 | + |
| 51 | +# Transform transcript coordinates to the coordinate system |
| 52 | +print('Transforming transcripts coordinates', flush=True) |
| 53 | +transcripts = sd.transform(sdata[par['transcripts_key']], to_coordinate_system=par['coordinate_system']) |
| 54 | + |
| 55 | +# In case of a translation transformation of the segmentation (e.g. crop of the data), we need to adjust the transcript coordinates |
| 56 | +trans = sd.transformations.get_transformation(sdata_segm["segmentation"], get_all=True)[par['coordinate_system']].inverse() |
| 57 | +transcripts = sd.transform(transcripts, trans, par['coordinate_system']) |
| 58 | + |
| 59 | +# Assign cell ids to transcripts |
| 60 | +print('Assigning transcripts to cell ids', flush=True) |
| 61 | +y_coords = transcripts.y.compute().to_numpy() |
| 62 | +x_coords = transcripts.x.compute().to_numpy() |
| 63 | + |
| 64 | +#Added for pciSeq |
| 65 | +#TODO this will immediately break when the name of the gene isn't feature_name |
| 66 | +transcripts_dataframe = sdata[par['transcripts_key']].compute()[['feature_name']] |
| 67 | +transcripts_dataframe['x'] = x_coords |
| 68 | +transcripts_dataframe['y'] = y_coords |
| 69 | + |
| 70 | +#same as before |
| 71 | +label_image = sdata_segm["segmentation"]["scale0"].image.to_numpy() #TODO: mabye this line needs generalization (DataTree vs DataArray) |
| 72 | + |
| 73 | +# Grab all the pciSeq parameters |
| 74 | +opts_keys = [#'exclude_genes', |
| 75 | + 'max_iter', |
| 76 | + 'CellCallTolerance', |
| 77 | + 'rGene', |
| 78 | + 'Inefficiency', |
| 79 | + 'InsideCellBonus', |
| 80 | + 'MisreadDensity', |
| 81 | + 'SpotReg', |
| 82 | + 'nNeighbors', |
| 83 | + 'rSpot', |
| 84 | + 'save_data'] |
| 85 | + |
| 86 | +opts = {k: par[k] for k in opts_keys} |
| 87 | + |
| 88 | +input_scrnaseq = ad.read_h5ad(par['input_scrnaseq']) |
| 89 | +input_scrnaseq.X = input_scrnaseq.layers['counts'] |
| 90 | + |
| 91 | +assignments, cell_types = tx.preprocessing.run_pciSeq( |
| 92 | + transcripts_dataframe, |
| 93 | + label_image, |
| 94 | + input_scrnaseq, |
| 95 | + par['sc_cell_type_key'], |
| 96 | + opts |
| 97 | +) |
| 98 | + |
| 99 | +#assign transcript -> cell |
| 100 | +cell_id_dask_series = dask.dataframe.from_dask_array( |
| 101 | + dask.array.from_array( |
| 102 | + assignments['cell'].to_numpy(), chunks=tuple(sdata[par['transcripts_key']].map_partitions(len).compute()) |
| 103 | + ), |
| 104 | + index=sdata[par['transcripts_key']].index |
| 105 | +) |
| 106 | + |
| 107 | +sdata[par['transcripts_key']]["cell_id"] = cell_id_dask_series |
| 108 | + |
| 109 | +# create new .obs for cells based on the segmentation output (corresponding with the transcripts 'cell_id') |
| 110 | +cell_types['type'] = cell_types['type'].replace({'None':'None_sp'}) |
| 111 | +cell_types.insert(0, 'cell_id', cell_types.index) |
| 112 | +cell_types.rename(columns={'type':'cell_type','prob':'cell_type_prob'}, inplace=True) |
| 113 | + |
| 114 | +assert 0 not in cell_types['cell_id'], "Found '0' in cell_id column of assingment output cell matrix" |
| 115 | + |
| 116 | +output_table = ad.AnnData( |
| 117 | + obs=cell_types[['cell_id','cell_type','cell_type_prob']], |
| 118 | + var=sdata.tables["table"].var[[]] |
| 119 | + ) |
| 120 | + |
| 121 | +# TODO: Also take care of the following cases: |
| 122 | +# - segmentation 3D, transcripts 3D |
| 123 | +# - segmentation 3D, transcripts 2D |
| 124 | +# - segmentation 2D, transcripts 3D |
| 125 | + |
| 126 | +# Subset sdata to transcripts with cell ids |
| 127 | + |
| 128 | +print('Subsetting to transcripts cell id and cell type data', flush=True) |
| 129 | +sdata_transcripts_only = sd.SpatialData( |
| 130 | + points={ |
| 131 | + "transcripts": sdata[par['transcripts_key']] |
| 132 | + }, |
| 133 | + tables={ |
| 134 | + "table": output_table |
| 135 | + } |
| 136 | +) |
| 137 | + |
| 138 | +print('Write transcripts with cell ids and cell types', flush=True) |
| 139 | +if os.path.exists(par["output"]): |
| 140 | + shutil.rmtree(par["output"]) |
| 141 | +sdata_transcripts_only.write(par['output']) |
| 142 | + |
| 143 | + |
0 commit comments