Skip to content

Commit e5e819c

Browse files
authored
Add helpers to extract data from NodeResponse arcs (#246)
This adds two helper methods to simplify reading data from a NodeResponse. `extract_connected_dcids`: * An example use case: Starting from a gene, I want to see associated diseases and indicators for that accociations. To accomplish this, I must find all of the dcids for the gene-disease-association nodes, then perform a follow up query for properties of those dcids. Using this helper method, it would look like: ``` node_client = dc_client.node resp = node_client.fetch("bio/APOE", "<-geneID") gene_association_dcids = resp.extract_connected_dcids("bio/APOE", "geneID", connected_node_types="DiseaseGeneAssociation") association_metdata = node_client.fetch(gene_association_dcids, "->[relationshipAssociationType,relationshipEvidenceType, confidence]") ``` * The helper method makes it easy to extract just the dcids from the response that I care about without having to really dive into what the structure of the response is. `extract_connected_nodes`: This has a similar use case to the first helper method, but returns the nodes for the arc that I'm interested in, instead of just the DCIDs. --- This is a revised version of closed #238 @jm-rivera Let me know what you think! I can provide some more use cases if that would be helpful. Would it be more helpful to rename these "extract_nodes_from_arc" / "extract_connected_dcids_from_arc"? Granted, those are pretty long names....
1 parent 3559fc9 commit e5e819c

2 files changed

Lines changed: 299 additions & 1 deletion

File tree

datacommons_client/endpoints/response.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from dataclasses import dataclass
22
from dataclasses import field
3-
from typing import Any, Dict, List
3+
from typing import Any, Dict, List, Optional
44

55
from datacommons_client.models.base import SerializableMixin
66
from datacommons_client.models.node import Arcs
77
from datacommons_client.models.node import NextToken
8+
from datacommons_client.models.node import Node
89
from datacommons_client.models.node import NodeDCID
10+
from datacommons_client.models.node import NodeGroup
911
from datacommons_client.models.node import Properties
12+
from datacommons_client.models.node import Property
1013
from datacommons_client.models.observation import Facet
1114
from datacommons_client.models.observation import facetID
1215
from datacommons_client.models.observation import Variable
@@ -68,6 +71,72 @@ def parse_data(data: Dict[str, Any]) -> Arcs | Properties:
6871
def get_properties(self) -> Dict:
6972
return flatten_properties(self.data)
7073

74+
def extract_connected_nodes(
75+
self,
76+
subject_dcid: NodeDCID,
77+
property_dcid: Property,
78+
connected_node_types: Optional[str | list[str]] = None) -> List[Node]:
79+
"""Retrieves Node objects in the NodeResponse connected to the subject node
80+
via the specified property.
81+
82+
Args:
83+
subject_dcid: The DCID of the starting node in the arc.
84+
property_dcid: The property connecting the subject node to the desired
85+
target nodes.
86+
connected_node_types: Optional. A type or list of types to filter the
87+
connected nodes. If provided, only connected nodes that have at least
88+
one of the specified types will be returned. If omitted, all nodes from
89+
the arc are returned.
90+
91+
Returns:
92+
A list of Node objects that are connected to the subject node via the
93+
specified property.
94+
"""
95+
if isinstance(connected_node_types, str):
96+
connected_node_types = [connected_node_types]
97+
98+
nodes = self.get_properties().get(subject_dcid, {}).get(property_dcid, [])
99+
100+
connected_nodes = []
101+
for node in nodes:
102+
if connected_node_types:
103+
# Filter out nodes that are missing a list of types or do not have the
104+
# desired type
105+
if not node.types or not any(nt in node.types
106+
for nt in connected_node_types):
107+
continue
108+
109+
connected_nodes.append(node)
110+
111+
return connected_nodes
112+
113+
def extract_connected_dcids(
114+
self,
115+
subject_dcid: NodeDCID,
116+
property_dcid: Property,
117+
connected_node_types: Optional[str | list[str]] = None) -> List[NodeDCID]:
118+
"""Retrieves DCIDs of the Nodes in the NodeResponse connected to the subject
119+
node via the specified property.
120+
121+
Args:
122+
subject_dcid: The DCID of the starting node.
123+
property_dcid: The property connecting the subject node to the desired
124+
target nodes.
125+
connected_node_types: Optional. A type or list of types to filter the
126+
connected nodes. If provided, only DCIDs of connected nodes that have at
127+
least one of the specified types will be returned. If omitted, DCIDs of
128+
all nodes from the arc are returned.
129+
130+
Returns:
131+
A list of NodeDCIDs for the nodes connected via the specified property
132+
from the subject node.
133+
"""
134+
135+
connected_nodes = self.extract_connected_nodes(subject_dcid, property_dcid,
136+
connected_node_types)
137+
138+
return [node.dcid for node in connected_nodes if node.dcid]
139+
71140

72141
@dataclass
73142
class ObservationResponse(SerializableMixin):

datacommons_client/tests/endpoints/test_response.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from datacommons_client.endpoints.response import NodeResponse
66
from datacommons_client.endpoints.response import ObservationResponse
77
from datacommons_client.endpoints.response import ResolveResponse
8+
from datacommons_client.models.node import Arcs
89
from datacommons_client.models.node import Node
910
from datacommons_client.models.node import NodeGroup
1011
from datacommons_client.models.observation import Facet
@@ -287,6 +288,234 @@ def test_unpack_arcs_multiple_properties():
287288
assert result == expected
288289

289290

291+
def test_extract_connected_dcids():
292+
"""Test that extract_connected_dcids is successful when multiple dcid and multiple
293+
properties are in the response."""
294+
json_data = {
295+
"data": {
296+
"geoId/06": {
297+
"arcs": {
298+
"containedInPlace": {
299+
"nodes": [{
300+
"dcid": "country/USA",
301+
"name": "United States",
302+
"provenanceId": "dc/base/WikidataOtherIdGeos",
303+
"types": ["Country"]
304+
}, {
305+
"dcid": "usc/PacificDivision",
306+
"name": "Pacific Division",
307+
"provenanceId": "dc/base/WikidataOtherIdGeos",
308+
"types": ["CensusDivision"]
309+
}]
310+
},
311+
"name": {
312+
"nodes": [{
313+
"provenanceId": "dc/base/WikidataOtherIdGeos",
314+
"value": "California"
315+
}]
316+
}
317+
}
318+
},
319+
"geoId/07": {
320+
"arcs": {
321+
"containedInPlace": {
322+
"nodes": [{
323+
"dcid": "country/USA",
324+
"name": "United States",
325+
"provenanceId": "dc/base/WikidataOtherIdGeos",
326+
"types": ["Country"]
327+
}]
328+
},
329+
}
330+
}
331+
}
332+
}
333+
response = NodeResponse.from_json(json_data)
334+
result = response.extract_connected_dcids(subject_dcid='geoId/06',
335+
property_dcid='containedInPlace')
336+
assert result == ['country/USA', 'usc/PacificDivision']
337+
338+
339+
def test_extract_connected_dcids_with_nonexistent_dcid():
340+
"""Test that extract_connected_dcids returns empty when requested dcid is not in the
341+
NodeResponse."""
342+
json_data = {
343+
"data": {
344+
"geoId/06": {
345+
"arcs": {
346+
"name": {
347+
"nodes": [{
348+
"provenanceId": "dc/base/WikidataOtherIdGeos",
349+
"value": "California"
350+
}]
351+
}
352+
}
353+
},
354+
}
355+
}
356+
response = NodeResponse.from_json(json_data)
357+
result = response.extract_connected_dcids(subject_dcid='geoId/07',
358+
property_dcid='name')
359+
assert result == []
360+
361+
362+
def test_extract_connected_dcids_with_nonexistent_property():
363+
"""Test that extract_connected_dcids returns empty when requested property is not in
364+
the NodeResponse."""
365+
json_data = {
366+
"data": {
367+
"geoId/06": {
368+
"arcs": {
369+
"name": {
370+
"nodes": [{
371+
"provenanceId": "dc/base/WikidataOtherIdGeos",
372+
"value": "California"
373+
}]
374+
}
375+
}
376+
},
377+
}
378+
}
379+
response = NodeResponse.from_json(json_data)
380+
result = response.extract_connected_dcids(subject_dcid='geoId/06',
381+
property_dcid='containedInPlace')
382+
assert result == []
383+
384+
385+
def test_extract_connected_dcids_does_not_include_none_for_value_only_nodes():
386+
"""Test that extract_connected_dcids does not include None in the returned list
387+
when the nodes in the response only contain values and not dcids."""
388+
json_data = {
389+
"data": {
390+
"geoId/06": {
391+
"arcs": {
392+
"name": {
393+
"nodes": [{
394+
"provenanceId": "dc/base/WikidataOtherIdGeos",
395+
"value": "California"
396+
}]
397+
}
398+
}
399+
},
400+
}
401+
}
402+
response = NodeResponse.from_json(json_data)
403+
result = response.extract_connected_dcids(subject_dcid='geoId/06',
404+
property_dcid='name')
405+
assert result == []
406+
407+
408+
def test_extract_connected_dcids_with_node_type_filter():
409+
"""Test that extract_connected_dcids returns dcids with the corresponding
410+
node_type."""
411+
412+
json_data = {
413+
"data": {
414+
"geoId/06": {
415+
"arcs": {
416+
"relatedPlaces": {
417+
"nodes": [{
418+
"dcid": "country/USA",
419+
"name": "United States",
420+
"provenanceId": "dc/base/WikidataOtherIdGeos",
421+
"types": ["Country"]
422+
}, {
423+
"dcid": "usc/PacificDivision",
424+
"name": "Pacific Division",
425+
"provenanceId": "dc/base/WikidataOtherIdGeos",
426+
"types": ["CensusDivision"]
427+
}, {
428+
"dcid": "node3",
429+
}]
430+
}
431+
}
432+
},
433+
}
434+
}
435+
response = NodeResponse.from_json(json_data)
436+
result = response.extract_connected_dcids(subject_dcid='geoId/06',
437+
property_dcid='relatedPlaces',
438+
connected_node_types="Country")
439+
assert result == ['country/USA']
440+
441+
442+
def test_extract_connected_dcids_with_multiple_node_type_filter():
443+
"""Test that extract_connected_dcids returns dcids with the corresponding
444+
connected_node_types."""
445+
json_data = {
446+
"data": {
447+
"geoId/06": {
448+
"arcs": {
449+
"relatedPlaces": {
450+
"nodes": [{
451+
"dcid": "country/USA",
452+
"name": "United States",
453+
"provenanceId": "dc/base/WikidataOtherIdGeos",
454+
"types": ["Country"]
455+
}, {
456+
"dcid": "usc/PacificDivision",
457+
"name": "Pacific Division",
458+
"provenanceId": "dc/base/WikidataOtherIdGeos",
459+
"types": ["CensusDivision"]
460+
}, {
461+
"dcid": "node3",
462+
"types": ["City"]
463+
}]
464+
}
465+
}
466+
},
467+
}
468+
}
469+
response = NodeResponse.from_json(json_data)
470+
result = response.extract_connected_dcids(
471+
subject_dcid='geoId/06',
472+
property_dcid='relatedPlaces',
473+
connected_node_types=["Country", "City"])
474+
assert result == ['country/USA', 'node3']
475+
476+
477+
def test_extract_connected_nodes_with_multiple_node_type_filter():
478+
"""Test that extract_connected_nodes returns only nodes with the corresponding
479+
connected_node_types."""
480+
481+
json_data = {
482+
"data": {
483+
"geoId/06": {
484+
"arcs": {
485+
"relatedPlaces": {
486+
"nodes": [{
487+
"dcid": "country/USA",
488+
"name": "United States",
489+
"provenanceId": "dc/base/WikidataOtherIdGeos",
490+
"types": ["Country"]
491+
}, {
492+
"dcid": "usc/PacificDivision",
493+
"name": "Pacific Division",
494+
"provenanceId": "dc/base/WikidataOtherIdGeos",
495+
"types": ["CensusDivision"]
496+
}, {
497+
"dcid": "node3",
498+
"types": ["City"]
499+
}]
500+
}
501+
}
502+
},
503+
}
504+
}
505+
response = NodeResponse.from_json(json_data)
506+
result = response.extract_connected_nodes(
507+
subject_dcid='geoId/06',
508+
property_dcid='relatedPlaces',
509+
connected_node_types=["Country", "City"])
510+
assert result == [
511+
Node(dcid="country/USA",
512+
name="United States",
513+
provenanceId="dc/base/WikidataOtherIdGeos",
514+
types=["Country"]),
515+
Node(dcid="node3", types=["City"])
516+
]
517+
518+
290519
### ----- Test Observation Response ----- ###
291520

292521

0 commit comments

Comments
 (0)