Skip to content

Commit 13ad648

Browse files
committed
add HeatmapView of SpatialQuery motif
1 parent 3a96130 commit 13ad648

1 file changed

Lines changed: 173 additions & 14 deletions

File tree

src/vitessce/widget_plugins/spatial_query.py

Lines changed: 173 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import base64
12
import colorsys
3+
import io
24
import json
35
from oxc_py import transform
46
from ..widget import VitesscePlugin
@@ -16,6 +18,7 @@ def _build_plugin_esm(cell_type_list):
1618
PluginJointFileType,
1719
z,
1820
useCoordination,
21+
invokeCommand,
1922
}} = utilsForPlugins;
2023
2124
const CELL_TYPE_LIST = {ct_list_js};
@@ -75,20 +78,20 @@ def _build_plugin_esm(cell_type_list):
7578
</label>
7679
)}}
7780
{{isAnchorMode && <br/>}}
78-
<label>
79-
{{radiusLabel}}
80-
<input type="range" value={{maxDist}} onChange={{e => setMaxDist(parseFloat(e.target.value))}} min={{5}} max={{30}} step={{1}} />
81-
{{maxDist}}
82-
</label>
83-
<br/>
8481
{{queryType === 'anchor-type-knn' && (
8582
<label>
8683
k (neighbors)
87-
<input type="range" value={{k}} onChange={{e => setK(parseFloat(e.target.value))}} min={{5}} max={{50}} step={{1}} />
84+
<input type="range" value={{k}} onChange={{e => setK(parseFloat(e.target.value))}} min={{5}} max={{100}} step={{1}} />
8885
{{k}}
8986
<br/>
9087
</label>
9188
)}}
89+
<label>
90+
{{radiusLabel}}
91+
<input type="range" value={{maxDist}} onChange={{e => setMaxDist(parseFloat(e.target.value))}} min={{5}} max={{30}} step={{1}} />
92+
{{maxDist}}
93+
</label>
94+
<br/>
9295
{{queryType === 'rand' && (
9396
<label>
9497
Sample points
@@ -119,6 +122,37 @@ def _build_plugin_esm(cell_type_list):
119122
);
120123
}}
121124
125+
function HeatmapView(props) {{
126+
const {{ coordinationScopes }} = props;
127+
const [{{ queryParams }}] = useCoordination(['queryParams'], coordinationScopes);
128+
129+
const [imgSrc, setImgSrc] = React.useState(null);
130+
const [loading, setLoading] = React.useState(false);
131+
132+
const uuid = queryParams?.uuid;
133+
134+
React.useEffect(() => {{
135+
if (uuid == null) return;
136+
setLoading(true);
137+
invokeCommand('get_heatmap', {{ uuid }}, []).then(([result]) => {{
138+
if (result?.img) {{
139+
setImgSrc('data:image/png;base64,' + result.img);
140+
}} else {{
141+
setImgSrc(null);
142+
}}
143+
setLoading(false);
144+
}}).catch(() => setLoading(false));
145+
}}, [uuid]);
146+
147+
return (
148+
<div className="spatial-query-heatmap" style={{{{ width: '100%', height: '100%', overflow: 'auto', display: 'flex', alignItems: 'center', justifyContent: 'center' }}}}>
149+
{{loading && <p>Loading heatmap...</p>}}
150+
{{!loading && imgSrc && <img src={{imgSrc}} style={{{{ maxWidth: '100%', maxHeight: '100%' }}}} />}}
151+
{{!loading && !imgSrc && <p style={{{{ color: '#888' }}}}>Run a query to see the heatmap.</p>}}
152+
</div>
153+
);
154+
}}
155+
122156
const pluginCoordinationTypes = [
123157
new PluginCoordinationType('queryParams', null, z.object({{
124158
cellTypeOfInterest: z.string().nullable(),
@@ -133,6 +167,7 @@ def _build_plugin_esm(cell_type_list):
133167
134168
const pluginViewTypes = [
135169
new PluginViewType('spatialQuery', SpatialQueryView, ['queryParams', 'obsType']),
170+
new PluginViewType('spatialQueryHeatmap', HeatmapView, ['queryParams']),
136171
];
137172
return {{ pluginViewTypes, pluginCoordinationTypes }};
138173
}}
@@ -145,7 +180,6 @@ class SpatialQueryPlugin(VitesscePlugin):
145180
"""
146181
Spatial-Query plugin view renders controls to change parameters passed to the Spatial-Query methods.
147182
"""
148-
commands = {}
149183

150184
def __init__(self,
151185
adata,
@@ -225,6 +259,88 @@ def __init__(self,
225259
self.cell_i_to_cell_id = dict(zip(range(adata.obs.shape[0]), adata.obs.index.tolist()))
226260
self.cell_id_to_cell_type = dict(zip(adata.obs.index.tolist(), adata.obs[label_key].tolist()))
227261

262+
self._last_fp_tree = None
263+
self._last_query_type = None
264+
self._last_cell_type_of_interest = None
265+
266+
self.commands = {"get_heatmap": self._handle_get_heatmap}
267+
268+
def _render_heatmap_base64(self, fp_tree, query_type, cell_type_of_interest):
269+
import matplotlib
270+
import matplotlib.pyplot as plt
271+
import seaborn as sns
272+
_prev_backend = matplotlib.get_backend()
273+
matplotlib.use("Agg")
274+
275+
is_enrichment = query_type in ("anchor-type-knn", "anchor-type-dist")
276+
277+
if is_enrichment:
278+
enrich = fp_tree.copy()
279+
enrich["frequency"] = enrich["n_center_motif"] / enrich["n_center"]
280+
enrich = enrich.sort_values(by="frequency", ascending=False)
281+
enrich["motif_group"] = [f"motif_{i+1}" for i in range(len(enrich))]
282+
expanded = enrich.explode("motifs")
283+
heatmap_data = expanded.pivot_table(
284+
index="motifs", columns="motif_group", values="frequency", aggfunc="first"
285+
)
286+
# Sort columns by descending frequency of each motif
287+
col_order = enrich.sort_values("frequency", ascending=False)["motif_group"].tolist()
288+
heatmap_data = heatmap_data[[c for c in col_order if c in heatmap_data.columns]]
289+
cbar_label = "Frequency"
290+
if cell_type_of_interest:
291+
title = f"Distribution of enriched motifs around {cell_type_of_interest}"
292+
else:
293+
title = "Distribution of enriched motifs"
294+
xlabel = "Motifs"
295+
else:
296+
fp = fp_tree.copy()
297+
fp = fp.sort_values(by="support", ascending=False).reset_index(drop=True)
298+
fp["motif_group"] = [f"motif_{i+1}" for i in range(len(fp))]
299+
expanded = fp.explode("itemsets")
300+
heatmap_data = expanded.pivot_table(
301+
index="itemsets", columns="motif_group", values="support", aggfunc="first"
302+
)
303+
# Sort columns by descending support of each motif
304+
col_order = fp.sort_values("support", ascending=False)["motif_group"].tolist()
305+
heatmap_data = heatmap_data[[c for c in col_order if c in heatmap_data.columns]]
306+
cbar_label = "Support"
307+
title = "Distribution of frequent patterns"
308+
xlabel = "Patterns"
309+
310+
fig, ax = plt.subplots(figsize=(7, 5))
311+
sns.heatmap(
312+
heatmap_data,
313+
cmap="GnBu",
314+
linewidths=0.1,
315+
linecolor="lightgrey",
316+
annot=False,
317+
cbar_kws={"label": cbar_label},
318+
ax=ax,
319+
)
320+
ax.set_title(title, fontsize=13, pad=12)
321+
ax.set_ylabel("")
322+
ax.set_xlabel(xlabel, fontsize=11)
323+
ax.tick_params(axis="x", rotation=90, labelsize=10)
324+
ax.tick_params(axis="y", rotation=0, labelsize=10)
325+
fig.tight_layout()
326+
327+
buf = io.BytesIO()
328+
fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
329+
plt.close(fig)
330+
matplotlib.use(_prev_backend)
331+
buf.seek(0)
332+
return base64.b64encode(buf.read()).decode("utf-8")
333+
334+
def _handle_get_heatmap(self, _message, _buffers):
335+
if self._last_fp_tree is None:
336+
return {"img": None}, []
337+
img_b64 = self._render_heatmap_base64(
338+
self._last_fp_tree,
339+
self._last_query_type,
340+
self._last_cell_type_of_interest,
341+
)
342+
return {"img": img_b64}, []
343+
228344
def get_matching_cell_ids(self, cell_type, cell_i):
229345
cell_ids = [self.cell_i_to_cell_id[i] for i in cell_i]
230346
matches = []
@@ -234,7 +350,7 @@ def get_matching_cell_ids(self, cell_type, cell_i):
234350
matches.append([cell_id, None])
235351
return matches
236352

237-
def fp_tree_to_obs_sets_tree(self, fp_tree, sq_id):
353+
def fp_tree_to_obs_sets_tree(self, fp_tree, sq_id, anchor_ct=None):
238354
sq_motif_name = f"SpatialQuery Results {sq_id} — By Motif"
239355
sq_ct_name = f"SpatialQuery Results {sq_id} — By Cell Type"
240356

@@ -248,7 +364,8 @@ def fp_tree_to_obs_sets_tree(self, fp_tree, sq_id):
248364
except KeyError:
249365
motif = row["motifs"]
250366
cell_i = row["neighbor_id"] if "neighbor_id" in row.index else row["cell_id"]
251-
motif_rows.append((motif, cell_i))
367+
center_i = row["center_id"] if "center_id" in row.index else None
368+
motif_rows.append((motif, cell_i, center_i))
252369
for cell_type in motif:
253370
matching = {i for i in cell_i if self.cell_id_to_cell_type.get(self.cell_i_to_cell_id.get(i)) == cell_type}
254371
ct_to_cell_ids.setdefault(cell_type, set()).update(matching)
@@ -258,7 +375,7 @@ def fp_tree_to_obs_sets_tree(self, fp_tree, sq_id):
258375

259376
# Node 1: "By Motif" — each motif is a leaf, colored by motif index
260377
by_motif_children = []
261-
for motif_i, (motif, cell_i) in enumerate(motif_rows):
378+
for motif_i, (motif, cell_i, _center_i) in enumerate(motif_rows):
262379
motif_name = str(list(motif))
263380
hue = motif_i / max(n_motifs, 1)
264381
r, g, b = colorsys.hls_to_rgb(hue, 0.55, 0.75)
@@ -283,12 +400,48 @@ def fp_tree_to_obs_sets_tree(self, fp_tree, sq_id):
283400
})
284401
obs_set_color.append({"color": self.ct_to_color[cell_type], "path": [sq_ct_name, cell_type]})
285402

403+
# Nodes 3+: one top-level node per motif, children are cell types within that motif.
404+
# For anchor queries, also include an "{anchor_ct} (anchor)" child for center cells.
405+
per_motif_nodes = []
406+
for motif_i, (motif, cell_i, center_i) in enumerate(motif_rows):
407+
motif_name = str(list(motif))
408+
node_name = f"SpatialQuery Results {sq_id} \u2014 Motif {motif_i + 1}: {motif_name}"
409+
motif_cell_types = set(motif)
410+
motif_ct_children = []
411+
for cell_type in motif_cell_types:
412+
ct_ids_in_motif = list(dict.fromkeys(
413+
i for i in cell_i
414+
if self.cell_id_to_cell_type.get(self.cell_i_to_cell_id.get(i)) == cell_type
415+
))
416+
motif_ct_children.append({
417+
"name": cell_type,
418+
"set": [[self.cell_i_to_cell_id[i], None] for i in ct_ids_in_motif if i in self.cell_i_to_cell_id]
419+
})
420+
obs_set_color.append({
421+
"color": self.ct_to_color[cell_type],
422+
"path": [node_name, cell_type]
423+
})
424+
# Add anchor child node if this is an anchor-mode query
425+
if anchor_ct is not None and center_i is not None and len(center_i) > 0:
426+
anchor_label = f"{anchor_ct} (anchor)"
427+
anchor_ids = list(dict.fromkeys(int(i) for i in center_i if i in self.cell_i_to_cell_id))
428+
motif_ct_children.append({
429+
"name": anchor_label,
430+
"set": [[self.cell_i_to_cell_id[i], None] for i in anchor_ids]
431+
})
432+
obs_set_color.append({
433+
"color": self.ct_to_color.get(anchor_ct, [200, 200, 200]),
434+
"path": [node_name, anchor_label]
435+
})
436+
per_motif_nodes.append({"name": node_name, "children": motif_ct_children})
437+
obs_set_color.append({"color": [255, 255, 255], "path": [node_name]})
438+
286439
additional_obs_sets = {
287440
"version": "0.1.3",
288441
"tree": [
289442
{"name": sq_motif_name, "children": by_motif_children},
290443
{"name": sq_ct_name, "children": by_ct_children},
291-
]
444+
] + per_motif_nodes
292445
}
293446

294447
obs_set_color.insert(0, {"color": [255, 255, 255], "path": [sq_motif_name]})
@@ -336,22 +489,28 @@ def run_sq(self, prev_config):
336489
min_support=min_support,
337490
return_cellID=True,
338491
)
492+
fp_tree = fp_tree[fp_tree["if_significant"]]
339493
elif query_type == "anchor-type-dist":
340494
fp_tree = self.tt.motif_enrichment_dist(
341495
ct=cell_type_of_interest,
342496
max_dist=max_dist,
343497
min_support=min_support,
344498
return_cellID=True,
345499
)
500+
fp_tree = fp_tree[fp_tree["if_significant"]]
346501

347-
# TODO: implement query types that are dependent on motif selection.
502+
# Cache for heatmap rendering
503+
self._last_fp_tree = fp_tree
504+
self._last_query_type = query_type
505+
self._last_cell_type_of_interest = cell_type_of_interest
348506

349507
# Previous values
350508
additional_obs_sets = prev_config["coordinationSpace"]["additionalObsSets"]["A"]
351509
obs_set_color = prev_config["coordinationSpace"]["obsSetColor"]["A"]
352510

353511
# Perform query
354-
(new_additional_obs_sets, new_obs_set_color) = self.fp_tree_to_obs_sets_tree(fp_tree, query_uuid)
512+
anchor_ct = cell_type_of_interest if query_type in ("anchor-type-knn", "anchor-type-dist") else None
513+
(new_additional_obs_sets, new_obs_set_color) = self.fp_tree_to_obs_sets_tree(fp_tree, query_uuid, anchor_ct=anchor_ct)
355514

356515
# Replace any existing SpatialQuery Results nodes (both By Motif and By Cell Type)
357516
existing_tree = additional_obs_sets["tree"]

0 commit comments

Comments
 (0)