Skip to content

Commit 9d817be

Browse files
committed
add query type with specified anchor cells
1 parent c0d4b17 commit 9d817be

1 file changed

Lines changed: 146 additions & 99 deletions

File tree

src/vitessce/widget_plugins/spatial_query.py

Lines changed: 146 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,131 +1,155 @@
1+
import colorsys
2+
import json
13
from oxc_py import transform
24
from ..widget import VitesscePlugin
35

46

5-
PLUGIN_ESM = transform("""
6-
function createPlugins(utilsForPlugins) {
7-
const {
7+
def _build_plugin_esm(cell_type_list):
8+
ct_list_js = json.dumps(cell_type_list)
9+
js_source = f"""
10+
function createPlugins(utilsForPlugins) {{
11+
const {{
812
React,
913
PluginFileType,
1014
PluginViewType,
1115
PluginCoordinationType,
1216
PluginJointFileType,
1317
z,
1418
useCoordination,
15-
} = utilsForPlugins;
16-
function SpatialQueryView(props) {
17-
const { coordinationScopes } = props;
18-
const [{
19+
}} = utilsForPlugins;
20+
21+
const CELL_TYPE_LIST = {ct_list_js};
22+
23+
function SpatialQueryView(props) {{
24+
const {{ coordinationScopes }} = props;
25+
const [{{
1926
queryParams,
20-
obsSetSelection,
21-
}, {
27+
}}, {{
2228
setQueryParams,
23-
}] = useCoordination(['queryParams', 'obsSetSelection', 'obsType'], coordinationScopes);
29+
}}] = useCoordination(['queryParams', 'obsType'], coordinationScopes);
2430
2531
const [uuid, setUuid] = React.useState(1);
2632
const [queryType, setQueryType] = React.useState('grid');
33+
const [anchorCellType, setAnchorCellType] = React.useState(CELL_TYPE_LIST[0] ?? '');
2734
const [maxDist, setMaxDist] = React.useState(10);
28-
const [minSize, setMinSize] = React.useState(0);
29-
const [minCount, setMinCount] = React.useState(0);
3035
const [minSupport, setMinSupport] = React.useState(0.5);
36+
const [k, setK] = React.useState(20);
37+
const [nPoints, setNPoints] = React.useState(1000);
38+
39+
const isAnchorMode = queryType === 'anchor-type-knn' || queryType === 'anchor-type-dist';
3140
32-
const cellTypeOfInterest = obsSetSelection?.length === 1 && obsSetSelection[0][0] === "Cell Type"
33-
? obsSetSelection[0][1]
34-
: null;
41+
const onQueryTypeChange = React.useCallback((e) => {{
42+
const newType = e.target.value;
43+
setQueryType(newType);
44+
// Update maxDist default when switching modes
45+
if (newType === 'anchor-type-knn') {{
46+
setMaxDist(20);
47+
}} else {{
48+
setMaxDist(10);
49+
}}
50+
}}, []);
3551
36-
const onQueryTypeChange = React.useCallback((e) => {
37-
setQueryType(e.target.value);
38-
}, []);
52+
const radiusLabel = queryType === 'anchor-type-knn' ? 'Max. Dist.' : 'Radius';
3953
4054
return (
4155
<div className="spatial-query">
42-
<p>Spatial Query Manager</p>
56+
<p>SpatialQuery Manager</p>
4357
<label>
4458
Query type&nbsp;
45-
<select onChange={onQueryTypeChange}>
59+
<select value={{queryType}} onChange={{onQueryTypeChange}}>
4660
<option value="grid">Grid-based</option>
4761
<option value="rand">Random-based</option>
48-
<option value="ct-center" disabled={cellTypeOfInterest === null}>Cell type of interest</option>
62+
<option value="anchor-type-knn">Anchor cell - kNN</option>
63+
<option value="anchor-type-dist">Anchor cell - Radius</option>
4964
</select>
5065
</label>
5166
<br/>
67+
{{isAnchorMode && (
5268
<label>
53-
{/* Maximum distance to consider a cell as a neighbor. */}
54-
Max. Dist.
55-
<input type="range" value={maxDist} onChange={e => setMaxDist(parseFloat(e.target.value))} min={1} max={20} step={1} />
56-
{maxDist}
69+
Anchor cell&nbsp;
70+
<select value={{anchorCellType}} onChange={{e => setAnchorCellType(e.target.value)}}>
71+
{{CELL_TYPE_LIST.map(ct => (
72+
<option key={{ct}} value={{ct}}>{{ct}}</option>
73+
))}}
74+
</select>
5775
</label>
58-
<br/>
76+
)}}
77+
{{isAnchorMode && <br/>}}
5978
<label>
60-
{/* Minimum neighborhood size for each point to consider. */}
61-
Min. Size
62-
<input type="range" value={minSize} onChange={e => setMinSize(parseFloat(e.target.value))} min={0} max={20} step={0.5} />
63-
{minSize}
79+
{{radiusLabel}}
80+
<input type="range" value={{maxDist}} onChange={{e => setMaxDist(parseFloat(e.target.value))}} min={{5}} max={{30}} step={{1}} />
81+
{{maxDist}}
6482
</label>
6583
<br/>
84+
{{queryType === 'anchor-type-knn' && (
6685
<label>
67-
{/* Minimum number of cell type to consider. */}
68-
Min. Count
69-
<input type="range" value={minCount} onChange={e => setMinCount(parseFloat(e.target.value))} min={0} max={100} step={1} />
70-
{minCount}
86+
k (neighbors)
87+
<input type="range" value={{k}} onChange={{e => setK(parseFloat(e.target.value))}} min={{5}} max={{50}} step={{1}} />
88+
{{k}}
89+
<br/>
7190
</label>
72-
<br/>
91+
)}}
92+
{{queryType === 'rand' && (
93+
<label>
94+
Sample points
95+
<input type="range" value={{nPoints}} onChange={{e => setNPoints(parseFloat(e.target.value))}} min={{100}} max={{5000}} step={{100}} />
96+
{{nPoints}}
97+
<br/>
98+
</label>
99+
)}}
73100
<label>
74-
{/* Threshold of frequency to consider a pattern as a frequent pattern. */}
75101
Min. Support
76-
<input type="range" value={minSupport} onChange={e => setMinSupport(parseFloat(e.target.value))} min={0} max={1} step={0.01} />
77-
{minSupport}
102+
<input type="range" value={{minSupport}} onChange={{e => setMinSupport(parseFloat(e.target.value))}} min={{0}} max={{1}} step={{0.01}} />
103+
{{minSupport}}
78104
</label>
79105
<br/>
80-
{/* TODO: disDuplicates: Distinguish duplicates in patterns. */}
81-
<button onClick={(e) => {
82-
setQueryParams({
83-
cellTypeOfInterest,
106+
<button onClick={{(e) => {{
107+
setQueryParams({{
108+
cellTypeOfInterest: isAnchorMode ? anchorCellType : null,
84109
queryType,
85110
maxDist,
86-
minSize,
87-
minCount,
88111
minSupport,
112+
k,
113+
nPoints,
89114
uuid,
90-
});
115+
}});
91116
setUuid(uuid+1);
92-
}}>Find patterns</button>
117+
}}}}>Find patterns</button>
93118
</div>
94119
);
95-
}
120+
}}
96121
97122
const pluginCoordinationTypes = [
98-
new PluginCoordinationType('queryParams', null, z.object({
123+
new PluginCoordinationType('queryParams', null, z.object({{
99124
cellTypeOfInterest: z.string().nullable(),
100-
queryType: z.enum(['grid', 'rand', 'ct-center']),
125+
queryType: z.enum(['grid', 'rand', 'anchor-type-knn', 'anchor-type-dist']),
101126
maxDist: z.number(),
102-
minSize: z.number(),
103-
minCount: z.number(),
104127
minSupport: z.number(),
105-
disDuplicates: z.boolean(),
128+
k: z.number(),
129+
nPoints: z.number(),
106130
uuid: z.number(),
107-
}).partial().nullable()),
131+
}}).partial().nullable()),
108132
];
109133
110134
const pluginViewTypes = [
111-
new PluginViewType('spatialQuery', SpatialQueryView, ['queryParams', 'obsSetSelection', 'obsType']),
135+
new PluginViewType('spatialQuery', SpatialQueryView, ['queryParams', 'obsType']),
112136
];
113-
return { pluginViewTypes, pluginCoordinationTypes };
114-
}
115-
export default { createPlugins };
116-
""")
137+
return {{ pluginViewTypes, pluginCoordinationTypes }};
138+
}}
139+
export default {{ createPlugins }};
140+
"""
141+
return transform(js_source)
117142

118143

119144
class SpatialQueryPlugin(VitesscePlugin):
120145
"""
121146
Spatial-Query plugin view renders controls to change parameters passed to the Spatial-Query methods.
122147
"""
123-
plugin_esm = PLUGIN_ESM
124148
commands = {}
125149

126-
def __init__(self,
127-
adata,
128-
spatial_key="X_spatial",
150+
def __init__(self,
151+
adata,
152+
spatial_key="X_spatial",
129153
label_key="cell_type",
130154
feature_name="gene_name",
131155
if_lognorm=True,
@@ -156,10 +180,10 @@ def __init__(self,
156180
self.label_key = label_key
157181

158182
self.tt = spatial_query(
159-
adata=adata,
160-
dataset='test',
161-
spatial_key=spatial_key,
162-
label_key=label_key,
183+
adata=adata,
184+
dataset='test',
185+
spatial_key=spatial_key,
186+
label_key=label_key,
163187
feature_name=feature_name,
164188
leaf_size=10,
165189
build_gene_index=False,
@@ -168,14 +192,19 @@ def __init__(self,
168192

169193
self.tab20_rgb = [[int(r * 255), int(g * 255), int(b * 255)] for (r, g, b, a) in [plt.cm.tab20(i) for i in range(20)]]
170194

195+
cell_type_list = adata.obs[label_key].unique().tolist()
196+
self.cell_type_list = cell_type_list
197+
self.initial_query_params = {}
198+
199+
# Build ESM with cell type list embedded directly as a JS constant
200+
self.plugin_esm = _build_plugin_esm(cell_type_list)
201+
171202
self.additional_obs_sets = {
172203
"version": "0.1.3",
173204
"tree": [
174205
{
175206
"name": "SpatialQuery Results",
176-
"children": [
177-
178-
]
207+
"children": []
179208
}
180209
]
181210
}
@@ -228,13 +257,18 @@ def fp_tree_to_obs_sets_tree(self, fp_tree, sq_id):
228257
}
229258

230259
obs_set_color = []
260+
n_motifs = len(fp_tree)
231261

232-
for row_i, row in fp_tree.iterrows():
262+
for motif_i, (row_i, row) in enumerate(fp_tree.iterrows()):
233263
try:
234264
motif = row["itemsets"]
235265
except KeyError:
236266
motif = row["motifs"]
237-
cell_i = row["neighbor_id"]
267+
# anchor-type queries: use neighbor_id for motif cells, grid/rand use cell_id
268+
if "neighbor_id" in row.index:
269+
cell_i = row["neighbor_id"]
270+
else:
271+
cell_i = row["cell_id"]
238272

239273
motif_name = str(list(motif))
240274

@@ -249,9 +283,12 @@ def fp_tree_to_obs_sets_tree(self, fp_tree, sq_id):
249283
]
250284
})
251285

252-
first_ct_color = self.ct_to_color.get(list(motif)[0], [255, 255, 255])
286+
# Assign each motif a unique color by evenly spacing hues around the color wheel
287+
hue = motif_i / max(n_motifs, 1)
288+
r, g, b = colorsys.hls_to_rgb(hue, 0.55, 0.75)
289+
motif_color = [int(r * 255), int(g * 255), int(b * 255)]
253290
obs_set_color.append({
254-
"color": first_ct_color,
291+
"color": motif_color,
255292
"path": [additional_obs_sets["tree"][0]["name"], motif_name]
256293
})
257294

@@ -267,45 +304,50 @@ def fp_tree_to_obs_sets_tree(self, fp_tree, sq_id):
267304
def run_sq(self, prev_config):
268305
query_params = prev_config["coordinationSpace"]["queryParams"]["A"]
269306

270-
max_dist = query_params.get("maxDist", 150)
271-
min_size = query_params.get("minSize", 4)
272-
# min_count = query_params.get("minCount", 10)
307+
max_dist = query_params.get("maxDist", 10)
273308
min_support = query_params.get("minSupport", 0.5)
274-
# dis_duplicates = query_params.get("disDuplicates", False) # if distinguish duplicates of cell types in neighborhood
309+
k = query_params.get("k", 20)
310+
n_points = query_params.get("nPoints", 1000)
275311
query_type = query_params.get("queryType", "grid")
276312
cell_type_of_interest = query_params.get("cellTypeOfInterest", None)
277313

278314
query_uuid = query_params["uuid"]
279315

280-
params_dict = dict(
281-
max_dist=max_dist,
282-
min_size=min_size,
283-
# min_count=min_count,
284-
min_support=min_support,
285-
# dis_duplicates=dis_duplicates,
286-
if_display=True,
287-
figsize=(9, 6),
288-
return_cellID=True,
289-
)
290-
print(params_dict)
291-
292-
# TODO: add unit tests for this functionality
316+
print(query_params)
293317

294318
if query_type == "rand":
295-
# TODO: implement param similar to return_grid for find_patterns_rand (to return the random points used)
296-
fp_tree = self.tt.find_patterns_rand(**params_dict)
319+
fp_tree = self.tt.find_patterns_rand(
320+
max_dist=max_dist,
321+
n_points=int(n_points),
322+
min_support=min_support,
323+
if_display=True,
324+
figsize=(9, 6),
325+
return_cellID=True,
326+
)
297327
elif query_type == "grid":
298-
params_dict["return_grid"] = True
299-
fp_tree, grid_pos = self.tt.find_patterns_grid(**params_dict)
300-
elif query_type == "ct-center":
328+
fp_tree, grid_pos = self.tt.find_patterns_grid(
329+
max_dist=max_dist,
330+
min_support=min_support,
331+
if_display=True,
332+
figsize=(9, 6),
333+
return_cellID=True,
334+
return_grid=True,
335+
)
336+
elif query_type == "anchor-type-knn":
301337
fp_tree = self.tt.motif_enrichment_knn(
302338
ct=cell_type_of_interest,
303-
k=20, # TODO: make this a parameter in the UI.
339+
k=int(k),
340+
max_dist=max_dist,
341+
min_support=min_support,
342+
return_cellID=True,
343+
)
344+
elif query_type == "anchor-type-dist":
345+
fp_tree = self.tt.motif_enrichment_dist(
346+
ct=cell_type_of_interest,
347+
max_dist=max_dist,
304348
min_support=min_support,
305-
# dis_duplicates=dis_duplicates,
306349
return_cellID=True,
307350
)
308-
print(fp_tree)
309351

310352
# TODO: implement query types that are dependent on motif selection.
311353

@@ -316,7 +358,12 @@ def run_sq(self, prev_config):
316358
# Perform query
317359
(new_additional_obs_sets, new_obs_set_color) = self.fp_tree_to_obs_sets_tree(fp_tree, query_uuid)
318360

319-
additional_obs_sets["tree"][0] = new_additional_obs_sets["tree"][0]
361+
new_sq_node = new_additional_obs_sets["tree"][0]
362+
sq_idx = next((i for i, n in enumerate(additional_obs_sets["tree"]) if n["name"].startswith("SpatialQuery Results")), None)
363+
if sq_idx is not None:
364+
additional_obs_sets["tree"][sq_idx] = new_sq_node
365+
else:
366+
additional_obs_sets["tree"].append(new_sq_node)
320367
prev_config["coordinationSpace"]["additionalObsSets"]["A"] = additional_obs_sets
321368

322369
obs_set_color += new_obs_set_color

0 commit comments

Comments
 (0)