Skip to content

Commit 5c2bca3

Browse files
committed
Changes from Christie
2 parents 9c2fa16 + 17b7d76 commit 5c2bca3

5 files changed

Lines changed: 336 additions & 20 deletions

File tree

datacommons_client/client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,18 @@ def observations_dataframe(
136136
entity_type (Optional[str]): The type of entities to filter by when `entity_dcids="all"`.
137137
Required if `entity_dcids="all"`. Defaults to None.
138138
parent_entity (Optional[str]): The parent entity under which the target entities fall.
139-
Used only when `entity_dcids="all"`. Defaults to None.
139+
Required if `entity_dcids="all"`. Defaults to None.
140140
property_filters (Optional[dict[str, str | list[str]]]): An optional dictionary used to filter
141141
the data by using observation properties like `measurementMethod`, `unit`, or `observationPeriod`.
142142
143143
Returns:
144144
pd.DataFrame: A DataFrame containing the requested observations.
145145
"""
146146

147-
if entity_dcids == "all" and not entity_type:
147+
if entity_dcids == "all" and not (entity_type and parent_entity):
148148
raise ValueError(
149-
"When 'entity_dcids' is 'all', 'entity_type' must be specified.")
149+
"When 'entity_dcids' is 'all', both 'parent_entity' and 'entity_type' must be specified."
150+
)
150151

151152
if entity_dcids != "all" and (entity_type or parent_entity):
152153
raise ValueError(

datacommons_client/endpoints/response.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from dataclasses import dataclass
22
from dataclasses import field
3-
from typing import Any, Dict, List
3+
from typing import Any, Dict, List, Optional
44

55
from datacommons_client.models.base import SerializableMixin
66
from datacommons_client.models.node import Arcs
77
from datacommons_client.models.node import NextToken
8+
from datacommons_client.models.node import Node
89
from datacommons_client.models.node import NodeDCID
10+
from datacommons_client.models.node import NodeGroup
911
from datacommons_client.models.node import Properties
12+
from datacommons_client.models.node import Property
1013
from datacommons_client.models.observation import Facet
1114
from datacommons_client.models.observation import facetID
1215
from datacommons_client.models.observation import Variable
@@ -68,6 +71,72 @@ def parse_data(data: Dict[str, Any]) -> Arcs | Properties:
6871
def get_properties(self) -> Dict:
6972
return flatten_properties(self.data)
7073

74+
def extract_connected_nodes(
75+
self,
76+
subject_dcid: NodeDCID,
77+
property_dcid: Property,
78+
connected_node_types: Optional[str | list[str]] = None) -> List[Node]:
79+
"""Retrieves Node objects in the NodeResponse connected to the subject node
80+
via the specified property.
81+
82+
Args:
83+
subject_dcid: The DCID of the starting node in the arc.
84+
property_dcid: The property connecting the subject node to the desired
85+
target nodes.
86+
connected_node_types: Optional. A type or list of types to filter the
87+
connected nodes. If provided, only connected nodes that have at least
88+
one of the specified types will be returned. If omitted, all nodes from
89+
the arc are returned.
90+
91+
Returns:
92+
A list of Node objects that are connected to the subject node via the
93+
specified property.
94+
"""
95+
if isinstance(connected_node_types, str):
96+
connected_node_types = [connected_node_types]
97+
98+
nodes = self.get_properties().get(subject_dcid, {}).get(property_dcid, [])
99+
100+
connected_nodes = []
101+
for node in nodes:
102+
if connected_node_types:
103+
# Filter out nodes that are missing a list of types or do not have the
104+
# desired type
105+
if not node.types or not any(nt in node.types
106+
for nt in connected_node_types):
107+
continue
108+
109+
connected_nodes.append(node)
110+
111+
return connected_nodes
112+
113+
def extract_connected_dcids(
114+
self,
115+
subject_dcid: NodeDCID,
116+
property_dcid: Property,
117+
connected_node_types: Optional[str | list[str]] = None) -> List[NodeDCID]:
118+
"""Retrieves DCIDs of the Nodes in the NodeResponse connected to the subject
119+
node via the specified property.
120+
121+
Args:
122+
subject_dcid: The DCID of the starting node.
123+
property_dcid: The property connecting the subject node to the desired
124+
target nodes.
125+
connected_node_types: Optional. A type or list of types to filter the
126+
connected nodes. If provided, only DCIDs of connected nodes that have at
127+
least one of the specified types will be returned. If omitted, DCIDs of
128+
all nodes from the arc are returned.
129+
130+
Returns:
131+
A list of NodeDCIDs for the nodes connected via the specified property
132+
from the subject node.
133+
"""
134+
135+
connected_nodes = self.extract_connected_nodes(subject_dcid, property_dcid,
136+
connected_node_types)
137+
138+
return [node.dcid for node in connected_nodes if node.dcid]
139+
71140

72141
@dataclass
73142
class ObservationResponse(SerializableMixin):

datacommons_client/tests/endpoints/test_response.py

Lines changed: 234 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from datacommons_client.endpoints.response import NodeResponse
66
from datacommons_client.endpoints.response import ObservationResponse
77
from datacommons_client.endpoints.response import ResolveResponse
8+
from datacommons_client.models.node import Arcs
89
from datacommons_client.models.node import Node
910
from datacommons_client.models.node import NodeGroup
1011
from datacommons_client.models.observation import Facet
@@ -182,7 +183,11 @@ def test_flatten_arcs():
182183
result = flatten_properties(response.data)
183184

184185
assert "dc/03lw9rhpendw5" in result
185-
assert result["dc/03lw9rhpendw5"].value == "191 Peachtree Tower"
186+
assert result["dc/03lw9rhpendw5"] == {
187+
"name": [
188+
Node(value="191 Peachtree Tower", provenanceId="dc/base/EIA_860")
189+
]
190+
}
186191

187192

188193
def test_flatten_multiple_arcs_with_multiple_nodes():
@@ -283,6 +288,234 @@ def test_unpack_arcs_multiple_properties():
283288
assert result == expected
284289

285290

291+
def test_extract_connected_dcids():
292+
"""Test that extract_connected_dcids is successful when multiple dcids and multiple
293+
properties are in the response."""
294+
json_data = {
295+
"data": {
296+
"geoId/06": {
297+
"arcs": {
298+
"containedInPlace": {
299+
"nodes": [{
300+
"dcid": "country/USA",
301+
"name": "United States",
302+
"provenanceId": "dc/base/WikidataOtherIdGeos",
303+
"types": ["Country"]
304+
}, {
305+
"dcid": "usc/PacificDivision",
306+
"name": "Pacific Division",
307+
"provenanceId": "dc/base/WikidataOtherIdGeos",
308+
"types": ["CensusDivision"]
309+
}]
310+
},
311+
"name": {
312+
"nodes": [{
313+
"provenanceId": "dc/base/WikidataOtherIdGeos",
314+
"value": "California"
315+
}]
316+
}
317+
}
318+
},
319+
"geoId/07": {
320+
"arcs": {
321+
"containedInPlace": {
322+
"nodes": [{
323+
"dcid": "country/USA",
324+
"name": "United States",
325+
"provenanceId": "dc/base/WikidataOtherIdGeos",
326+
"types": ["Country"]
327+
}]
328+
},
329+
}
330+
}
331+
}
332+
}
333+
response = NodeResponse.from_json(json_data)
334+
result = response.extract_connected_dcids(subject_dcid='geoId/06',
335+
property_dcid='containedInPlace')
336+
assert result == ['country/USA', 'usc/PacificDivision']
337+
338+
339+
def test_extract_connected_dcids_with_nonexistent_dcid():
340+
"""Test that extract_connected_dcids returns empty when requested dcid is not in the
341+
NodeResponse."""
342+
json_data = {
343+
"data": {
344+
"geoId/06": {
345+
"arcs": {
346+
"name": {
347+
"nodes": [{
348+
"provenanceId": "dc/base/WikidataOtherIdGeos",
349+
"value": "California"
350+
}]
351+
}
352+
}
353+
},
354+
}
355+
}
356+
response = NodeResponse.from_json(json_data)
357+
result = response.extract_connected_dcids(subject_dcid='geoId/07',
358+
property_dcid='name')
359+
assert result == []
360+
361+
362+
def test_extract_connected_dcids_with_nonexistent_property():
363+
"""Test that extract_connected_dcids returns empty when requested property is not in
364+
the NodeResponse."""
365+
json_data = {
366+
"data": {
367+
"geoId/06": {
368+
"arcs": {
369+
"name": {
370+
"nodes": [{
371+
"provenanceId": "dc/base/WikidataOtherIdGeos",
372+
"value": "California"
373+
}]
374+
}
375+
}
376+
},
377+
}
378+
}
379+
response = NodeResponse.from_json(json_data)
380+
result = response.extract_connected_dcids(subject_dcid='geoId/06',
381+
property_dcid='containedInPlace')
382+
assert result == []
383+
384+
385+
def test_extract_connected_dcids_does_not_include_none_for_value_only_nodes():
386+
"""Test that extract_connected_dcids does not include None in the returned list
387+
when the nodes in the response only contain values and not dcids."""
388+
json_data = {
389+
"data": {
390+
"geoId/06": {
391+
"arcs": {
392+
"name": {
393+
"nodes": [{
394+
"provenanceId": "dc/base/WikidataOtherIdGeos",
395+
"value": "California"
396+
}]
397+
}
398+
}
399+
},
400+
}
401+
}
402+
response = NodeResponse.from_json(json_data)
403+
result = response.extract_connected_dcids(subject_dcid='geoId/06',
404+
property_dcid='name')
405+
assert result == []
406+
407+
408+
def test_extract_connected_dcids_with_node_type_filter():
409+
"""Test that extract_connected_dcids returns dcids with the corresponding
410+
connected_node_types."""
411+
412+
json_data = {
413+
"data": {
414+
"geoId/06": {
415+
"arcs": {
416+
"relatedPlaces": {
417+
"nodes": [{
418+
"dcid": "country/USA",
419+
"name": "United States",
420+
"provenanceId": "dc/base/WikidataOtherIdGeos",
421+
"types": ["Country"]
422+
}, {
423+
"dcid": "usc/PacificDivision",
424+
"name": "Pacific Division",
425+
"provenanceId": "dc/base/WikidataOtherIdGeos",
426+
"types": ["CensusDivision"]
427+
}, {
428+
"dcid": "node3",
429+
}]
430+
}
431+
}
432+
},
433+
}
434+
}
435+
response = NodeResponse.from_json(json_data)
436+
result = response.extract_connected_dcids(subject_dcid='geoId/06',
437+
property_dcid='relatedPlaces',
438+
connected_node_types="Country")
439+
assert result == ['country/USA']
440+
441+
442+
def test_extract_connected_dcids_with_multiple_node_type_filter():
443+
"""Test that extract_connected_dcids returns dcids with the corresponding
444+
connected_node_types."""
445+
json_data = {
446+
"data": {
447+
"geoId/06": {
448+
"arcs": {
449+
"relatedPlaces": {
450+
"nodes": [{
451+
"dcid": "country/USA",
452+
"name": "United States",
453+
"provenanceId": "dc/base/WikidataOtherIdGeos",
454+
"types": ["Country"]
455+
}, {
456+
"dcid": "usc/PacificDivision",
457+
"name": "Pacific Division",
458+
"provenanceId": "dc/base/WikidataOtherIdGeos",
459+
"types": ["CensusDivision"]
460+
}, {
461+
"dcid": "node3",
462+
"types": ["City"]
463+
}]
464+
}
465+
}
466+
},
467+
}
468+
}
469+
response = NodeResponse.from_json(json_data)
470+
result = response.extract_connected_dcids(
471+
subject_dcid='geoId/06',
472+
property_dcid='relatedPlaces',
473+
connected_node_types=["Country", "City"])
474+
assert result == ['country/USA', 'node3']
475+
476+
477+
def test_extract_connected_nodes_with_multiple_node_type_filter():
478+
"""Test that extract_connected_nodes returns only nodes with the corresponding
479+
connected_node_types."""
480+
481+
json_data = {
482+
"data": {
483+
"geoId/06": {
484+
"arcs": {
485+
"relatedPlaces": {
486+
"nodes": [{
487+
"dcid": "country/USA",
488+
"name": "United States",
489+
"provenanceId": "dc/base/WikidataOtherIdGeos",
490+
"types": ["Country"]
491+
}, {
492+
"dcid": "usc/PacificDivision",
493+
"name": "Pacific Division",
494+
"provenanceId": "dc/base/WikidataOtherIdGeos",
495+
"types": ["CensusDivision"]
496+
}, {
497+
"dcid": "node3",
498+
"types": ["City"]
499+
}]
500+
}
501+
}
502+
},
503+
}
504+
}
505+
response = NodeResponse.from_json(json_data)
506+
result = response.extract_connected_nodes(
507+
subject_dcid='geoId/06',
508+
property_dcid='relatedPlaces',
509+
connected_node_types=["Country", "City"])
510+
assert result == [
511+
Node(dcid="country/USA",
512+
name="United States",
513+
provenanceId="dc/base/WikidataOtherIdGeos",
514+
types=["Country"]),
515+
Node(dcid="node3", types=["City"])
516+
]
517+
518+
286519
### ----- Test Observation Response ----- ###
287520

288521

datacommons_client/tests/test_client.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,27 @@ def test_observations_dataframe_raises_error_when_entities_all_but_no_entity_typ
7070
"""Tests that ValueError is raised if 'entities' is 'all' but 'entity_type' is not specified."""
7171
with pytest.raises(
7272
ValueError,
73-
match="When 'entity_dcids' is 'all', 'entity_type' must be specified.",
73+
match=
74+
"When 'entity_dcids' is 'all', both 'parent_entity' and 'entity_type' must be specified.",
7475
):
7576
mock_client.observations_dataframe(variable_dcids="var1",
7677
date="2024",
77-
entity_dcids="all")
78+
entity_dcids="all",
79+
parent_entity="africa")
80+
81+
82+
def test_observations_dataframe_raises_error_when_entities_all_but_no_parent_entity(
83+
mock_client,):
84+
"""Tests that ValueError is raised if 'entities' is 'all' but 'parent_entity' is not specified."""
85+
with pytest.raises(
86+
ValueError,
87+
match=
88+
"When 'entity_dcids' is 'all', both 'parent_entity' and 'entity_type' must be specified.",
89+
):
90+
mock_client.observations_dataframe(variable_dcids="var1",
91+
date="2024",
92+
entity_dcids="all",
93+
entity_type="Country")
7894

7995

8096
def test_observations_dataframe_raises_error_when_invalid_entity_type_usage(

0 commit comments

Comments
 (0)