1+ import base64
12import colorsys
3+ import io
24import json
35from oxc_py import transform
46from ..widget import VitesscePlugin
@@ -16,6 +18,7 @@ def _build_plugin_esm(cell_type_list):
1618 PluginJointFileType,
1719 z,
1820 useCoordination,
21+ invokeCommand,
1922 }} = utilsForPlugins;
2023
2124 const CELL_TYPE_LIST = { ct_list_js } ;
@@ -75,20 +78,20 @@ def _build_plugin_esm(cell_type_list):
7578 </label>
7679 )}}
7780 {{isAnchorMode && <br/>}}
78- <label>
79- {{radiusLabel}}
80- <input type="range" value={{maxDist}} onChange={{e => setMaxDist(parseFloat(e.target.value))}} min={{5}} max={{30}} step={{1}} />
81- {{maxDist}}
82- </label>
83- <br/>
8481 {{queryType === 'anchor-type-knn' && (
8582 <label>
8683 k (neighbors)
87- <input type="range" value={{k}} onChange={{e => setK(parseFloat(e.target.value))}} min={{5}} max={{50 }} step={{1}} />
84+ <input type="range" value={{k}} onChange={{e => setK(parseFloat(e.target.value))}} min={{5}} max={{100 }} step={{1}} />
8885 {{k}}
8986 <br/>
9087 </label>
9188 )}}
89+ <label>
90+ {{radiusLabel}}
91+ <input type="range" value={{maxDist}} onChange={{e => setMaxDist(parseFloat(e.target.value))}} min={{5}} max={{30}} step={{1}} />
92+ {{maxDist}}
93+ </label>
94+ <br/>
9295 {{queryType === 'rand' && (
9396 <label>
9497 Sample points
@@ -119,6 +122,37 @@ def _build_plugin_esm(cell_type_list):
119122 );
120123 }}
121124
125+ function HeatmapView(props) {{
126+ const {{ coordinationScopes }} = props;
127+ const [{{ queryParams }}] = useCoordination(['queryParams'], coordinationScopes);
128+
129+ const [imgSrc, setImgSrc] = React.useState(null);
130+ const [loading, setLoading] = React.useState(false);
131+
132+ const uuid = queryParams?.uuid;
133+
134+ React.useEffect(() => {{
135+ if (uuid == null) return;
136+ setLoading(true);
137+ invokeCommand('get_heatmap', {{ uuid }}, []).then(([result]) => {{
138+ if (result?.img) {{
139+ setImgSrc('data:image/png;base64,' + result.img);
140+ }} else {{
141+ setImgSrc(null);
142+ }}
143+ setLoading(false);
144+ }}).catch(() => setLoading(false));
145+ }}, [uuid]);
146+
147+ return (
148+ <div className="spatial-query-heatmap" style={{{{ width: '100%', height: '100%', overflow: 'auto', display: 'flex', alignItems: 'center', justifyContent: 'center' }}}}>
149+ {{loading && <p>Loading heatmap...</p>}}
150+ {{!loading && imgSrc && <img src={{imgSrc}} style={{{{ maxWidth: '100%', maxHeight: '100%' }}}} />}}
151+ {{!loading && !imgSrc && <p style={{{{ color: '#888' }}}}>Run a query to see the heatmap.</p>}}
152+ </div>
153+ );
154+ }}
155+
122156 const pluginCoordinationTypes = [
123157 new PluginCoordinationType('queryParams', null, z.object({{
124158 cellTypeOfInterest: z.string().nullable(),
@@ -133,6 +167,7 @@ def _build_plugin_esm(cell_type_list):
133167
134168 const pluginViewTypes = [
135169 new PluginViewType('spatialQuery', SpatialQueryView, ['queryParams', 'obsType']),
170+ new PluginViewType('spatialQueryHeatmap', HeatmapView, ['queryParams']),
136171 ];
137172 return {{ pluginViewTypes, pluginCoordinationTypes }};
138173}}
@@ -145,7 +180,6 @@ class SpatialQueryPlugin(VitesscePlugin):
145180 """
146181 Spatial-Query plugin view renders controls to change parameters passed to the Spatial-Query methods.
147182 """
148- commands = {}
149183
150184 def __init__ (self ,
151185 adata ,
@@ -225,6 +259,88 @@ def __init__(self,
225259 self .cell_i_to_cell_id = dict (zip (range (adata .obs .shape [0 ]), adata .obs .index .tolist ()))
226260 self .cell_id_to_cell_type = dict (zip (adata .obs .index .tolist (), adata .obs [label_key ].tolist ()))
227261
262+ self ._last_fp_tree = None
263+ self ._last_query_type = None
264+ self ._last_cell_type_of_interest = None
265+
266+ self .commands = {"get_heatmap" : self ._handle_get_heatmap }
267+
268+ def _render_heatmap_base64 (self , fp_tree , query_type , cell_type_of_interest ):
269+ import matplotlib
270+ import matplotlib .pyplot as plt
271+ import seaborn as sns
272+ _prev_backend = matplotlib .get_backend ()
273+ matplotlib .use ("Agg" )
274+
275+ is_enrichment = query_type in ("anchor-type-knn" , "anchor-type-dist" )
276+
277+ if is_enrichment :
278+ enrich = fp_tree .copy ()
279+ enrich ["frequency" ] = enrich ["n_center_motif" ] / enrich ["n_center" ]
280+ enrich = enrich .sort_values (by = "frequency" , ascending = False )
281+ enrich ["motif_group" ] = [f"motif_{ i + 1 } " for i in range (len (enrich ))]
282+ expanded = enrich .explode ("motifs" )
283+ heatmap_data = expanded .pivot_table (
284+ index = "motifs" , columns = "motif_group" , values = "frequency" , aggfunc = "first"
285+ )
286+ # Sort columns by descending frequency of each motif
287+ col_order = enrich .sort_values ("frequency" , ascending = False )["motif_group" ].tolist ()
288+ heatmap_data = heatmap_data [[c for c in col_order if c in heatmap_data .columns ]]
289+ cbar_label = "Frequency"
290+ if cell_type_of_interest :
291+ title = f"Distribution of enriched motifs around { cell_type_of_interest } "
292+ else :
293+ title = "Distribution of enriched motifs"
294+ xlabel = "Motifs"
295+ else :
296+ fp = fp_tree .copy ()
297+ fp = fp .sort_values (by = "support" , ascending = False ).reset_index (drop = True )
298+ fp ["motif_group" ] = [f"motif_{ i + 1 } " for i in range (len (fp ))]
299+ expanded = fp .explode ("itemsets" )
300+ heatmap_data = expanded .pivot_table (
301+ index = "itemsets" , columns = "motif_group" , values = "support" , aggfunc = "first"
302+ )
303+ # Sort columns by descending support of each motif
304+ col_order = fp .sort_values ("support" , ascending = False )["motif_group" ].tolist ()
305+ heatmap_data = heatmap_data [[c for c in col_order if c in heatmap_data .columns ]]
306+ cbar_label = "Support"
307+ title = "Distribution of frequent patterns"
308+ xlabel = "Patterns"
309+
310+ fig , ax = plt .subplots (figsize = (7 , 5 ))
311+ sns .heatmap (
312+ heatmap_data ,
313+ cmap = "GnBu" ,
314+ linewidths = 0.1 ,
315+ linecolor = "lightgrey" ,
316+ annot = False ,
317+ cbar_kws = {"label" : cbar_label },
318+ ax = ax ,
319+ )
320+ ax .set_title (title , fontsize = 13 , pad = 12 )
321+ ax .set_ylabel ("" )
322+ ax .set_xlabel (xlabel , fontsize = 11 )
323+ ax .tick_params (axis = "x" , rotation = 90 , labelsize = 10 )
324+ ax .tick_params (axis = "y" , rotation = 0 , labelsize = 10 )
325+ fig .tight_layout ()
326+
327+ buf = io .BytesIO ()
328+ fig .savefig (buf , format = "png" , dpi = 150 , bbox_inches = "tight" )
329+ plt .close (fig )
330+ matplotlib .use (_prev_backend )
331+ buf .seek (0 )
332+ return base64 .b64encode (buf .read ()).decode ("utf-8" )
333+
334+ def _handle_get_heatmap (self , _message , _buffers ):
335+ if self ._last_fp_tree is None :
336+ return {"img" : None }, []
337+ img_b64 = self ._render_heatmap_base64 (
338+ self ._last_fp_tree ,
339+ self ._last_query_type ,
340+ self ._last_cell_type_of_interest ,
341+ )
342+ return {"img" : img_b64 }, []
343+
228344 def get_matching_cell_ids (self , cell_type , cell_i ):
229345 cell_ids = [self .cell_i_to_cell_id [i ] for i in cell_i ]
230346 matches = []
@@ -234,7 +350,7 @@ def get_matching_cell_ids(self, cell_type, cell_i):
234350 matches .append ([cell_id , None ])
235351 return matches
236352
237- def fp_tree_to_obs_sets_tree (self , fp_tree , sq_id ):
353+ def fp_tree_to_obs_sets_tree (self , fp_tree , sq_id , anchor_ct = None ):
238354 sq_motif_name = f"SpatialQuery Results { sq_id } — By Motif"
239355 sq_ct_name = f"SpatialQuery Results { sq_id } — By Cell Type"
240356
@@ -248,7 +364,8 @@ def fp_tree_to_obs_sets_tree(self, fp_tree, sq_id):
248364 except KeyError :
249365 motif = row ["motifs" ]
250366 cell_i = row ["neighbor_id" ] if "neighbor_id" in row .index else row ["cell_id" ]
251- motif_rows .append ((motif , cell_i ))
367+ center_i = row ["center_id" ] if "center_id" in row .index else None
368+ motif_rows .append ((motif , cell_i , center_i ))
252369 for cell_type in motif :
253370 matching = {i for i in cell_i if self .cell_id_to_cell_type .get (self .cell_i_to_cell_id .get (i )) == cell_type }
254371 ct_to_cell_ids .setdefault (cell_type , set ()).update (matching )
@@ -258,7 +375,7 @@ def fp_tree_to_obs_sets_tree(self, fp_tree, sq_id):
258375
259376 # Node 1: "By Motif" — each motif is a leaf, colored by motif index
260377 by_motif_children = []
261- for motif_i , (motif , cell_i ) in enumerate (motif_rows ):
378+ for motif_i , (motif , cell_i , _center_i ) in enumerate (motif_rows ):
262379 motif_name = str (list (motif ))
263380 hue = motif_i / max (n_motifs , 1 )
264381 r , g , b = colorsys .hls_to_rgb (hue , 0.55 , 0.75 )
@@ -283,12 +400,48 @@ def fp_tree_to_obs_sets_tree(self, fp_tree, sq_id):
283400 })
284401 obs_set_color .append ({"color" : self .ct_to_color [cell_type ], "path" : [sq_ct_name , cell_type ]})
285402
403+ # Nodes 3+: one top-level node per motif, children are cell types within that motif.
404+ # For anchor queries, also include an "{anchor_ct} (anchor)" child for center cells.
405+ per_motif_nodes = []
406+ for motif_i , (motif , cell_i , center_i ) in enumerate (motif_rows ):
407+ motif_name = str (list (motif ))
408+ node_name = f"SpatialQuery Results { sq_id } \u2014 Motif { motif_i + 1 } : { motif_name } "
409+ motif_cell_types = set (motif )
410+ motif_ct_children = []
411+ for cell_type in motif_cell_types :
412+ ct_ids_in_motif = list (dict .fromkeys (
413+ i for i in cell_i
414+ if self .cell_id_to_cell_type .get (self .cell_i_to_cell_id .get (i )) == cell_type
415+ ))
416+ motif_ct_children .append ({
417+ "name" : cell_type ,
418+ "set" : [[self .cell_i_to_cell_id [i ], None ] for i in ct_ids_in_motif if i in self .cell_i_to_cell_id ]
419+ })
420+ obs_set_color .append ({
421+ "color" : self .ct_to_color [cell_type ],
422+ "path" : [node_name , cell_type ]
423+ })
424+ # Add anchor child node if this is an anchor-mode query
425+ if anchor_ct is not None and center_i is not None and len (center_i ) > 0 :
426+ anchor_label = f"{ anchor_ct } (anchor)"
427+ anchor_ids = list (dict .fromkeys (int (i ) for i in center_i if i in self .cell_i_to_cell_id ))
428+ motif_ct_children .append ({
429+ "name" : anchor_label ,
430+ "set" : [[self .cell_i_to_cell_id [i ], None ] for i in anchor_ids ]
431+ })
432+ obs_set_color .append ({
433+ "color" : self .ct_to_color .get (anchor_ct , [200 , 200 , 200 ]),
434+ "path" : [node_name , anchor_label ]
435+ })
436+ per_motif_nodes .append ({"name" : node_name , "children" : motif_ct_children })
437+ obs_set_color .append ({"color" : [255 , 255 , 255 ], "path" : [node_name ]})
438+
286439 additional_obs_sets = {
287440 "version" : "0.1.3" ,
288441 "tree" : [
289442 {"name" : sq_motif_name , "children" : by_motif_children },
290443 {"name" : sq_ct_name , "children" : by_ct_children },
291- ]
444+ ] + per_motif_nodes
292445 }
293446
294447 obs_set_color .insert (0 , {"color" : [255 , 255 , 255 ], "path" : [sq_motif_name ]})
@@ -336,22 +489,28 @@ def run_sq(self, prev_config):
336489 min_support = min_support ,
337490 return_cellID = True ,
338491 )
492+ fp_tree = fp_tree [fp_tree ["if_significant" ]]
339493 elif query_type == "anchor-type-dist" :
340494 fp_tree = self .tt .motif_enrichment_dist (
341495 ct = cell_type_of_interest ,
342496 max_dist = max_dist ,
343497 min_support = min_support ,
344498 return_cellID = True ,
345499 )
500+ fp_tree = fp_tree [fp_tree ["if_significant" ]]
346501
347- # TODO: implement query types that are dependent on motif selection.
502+ # Cache for heatmap rendering
503+ self ._last_fp_tree = fp_tree
504+ self ._last_query_type = query_type
505+ self ._last_cell_type_of_interest = cell_type_of_interest
348506
349507 # Previous values
350508 additional_obs_sets = prev_config ["coordinationSpace" ]["additionalObsSets" ]["A" ]
351509 obs_set_color = prev_config ["coordinationSpace" ]["obsSetColor" ]["A" ]
352510
353511 # Perform query
354- (new_additional_obs_sets , new_obs_set_color ) = self .fp_tree_to_obs_sets_tree (fp_tree , query_uuid )
512+ anchor_ct = cell_type_of_interest if query_type in ("anchor-type-knn" , "anchor-type-dist" ) else None
513+ (new_additional_obs_sets , new_obs_set_color ) = self .fp_tree_to_obs_sets_tree (fp_tree , query_uuid , anchor_ct = anchor_ct )
355514
356515 # Replace any existing SpatialQuery Results nodes (both By Motif and By Cell Type)
357516 existing_tree = additional_obs_sets ["tree" ]
0 commit comments