Skip to content

Commit cc00c9c

Browse files
authored
Make get_stats more user-friendly. Also help Mixer out by chunking requests. (#124)
* Make get_stats more user-friendly. Also help Mixer out by chunking requests.
* Update mock to return dictionary.
* Add test for batching get_stats.
* Remove unnecessary test that also did not get mocked.
* Update mock for new batch test.
* Correct mocked responses.
* Optimize list to set since order doesn't matter.
* Use six.moves.urllib in places_test, since places uses six.
* Use full six.moves.urllib path for mock. Improve style.
* Missed replace.

Co-authored-by: tjann <tjann@google.com>
1 parent d5d312a commit cc00c9c

4 files changed

Lines changed: 177 additions & 23 deletions

File tree

datacommons/examples/places.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,26 @@ def main():
4141
print(' - {}'.format(dcid))
4242

4343
# Get place stats.
44-
print('Get place stats')
44+
print('Get place stats -- all')
45+
stats = dc.get_stats(['geoId/05', 'geoId/06', 'dc/madDcid'], 'dc/0hyp6tkn18vcb', obs_dates='all')
46+
print(stats)
47+
48+
print('Get place stats -- latest')
4549
stats = dc.get_stats(['geoId/05', 'geoId/06', 'dc/madDcid'], 'dc/0hyp6tkn18vcb')
4650
print(stats)
4751

52+
print('Get place stats -- 2014')
53+
stats = dc.get_stats(['geoId/05', 'geoId/06', 'dc/madDcid'], 'dc/0hyp6tkn18vcb', obs_dates=['2014'])
54+
print(stats)
55+
56+
print('Get place stats -- 2014 badly formatted')
57+
stats = dc.get_stats(['geoId/05', 'geoId/06', 'dc/madDcid'], 'dc/0hyp6tkn18vcb', obs_dates='2014')
58+
print(stats)
59+
60+
print('Get place stats -- 2015-2016')
61+
stats = dc.get_stats(['geoId/05', 'geoId/06', 'dc/madDcid'], 'dc/0hyp6tkn18vcb', obs_dates=['2015', '2016'])
62+
print(stats)
63+
4864
# Get related places.
4965
# TODO(*): Fix the related places example.
5066
# print('Get related places')

datacommons/places.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,17 +72,21 @@ def get_places_in(dcids, place_type):
7272
result = utils._format_expand_payload(payload, 'place', must_exist=dcids)
7373
return result
7474

75-
def get_stats(dcids, stats_var):
75+
def get_stats(dcids, stats_var, obs_dates='latest'):
7676
""" Returns :obj:`TimeSeries` for :code:`dcids` \
7777
based on the :code:`stats_var`.
7878
7979
Args:
8080
dcids (:obj:`iterable` of :obj:`str`): Dcids of places to query for.
8181
stats_var (:obj:`str`): The dcid of the :obj:StatisticalVariable.
82+
obs_dates (:obj:`str` or :obj:`iterable` of :obj:`str`):
83+
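Which observations to return.
84+
Can be 'latest', 'all', or an iterable of date strings written exactly as they appear as keys in the time series (e.g. ['2014'], or 'YYYY-MM-DD' where the series uses full dates). A bare string such as '2014' is treated as an iterable of characters and will not match.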
8285
Returns:
8386
A :obj:`dict` mapping the :obj:`Place` identified by the given :code:`dcid`
8487
to its place name and the :obj:`TimeSeries` associated with the
85-
:obj:`StatisticalVariable` identified by the given :code:`stats_var`.
88+
:obj:`StatisticalVariable` identified by the given :code:`stats_var`
89+
and filtered by :code:`obs_dates`.
8690
See example below for more detail about how the returned :obj:`dict` is
8791
structured.
8892
@@ -131,12 +135,35 @@ def get_stats(dcids, stats_var):
131135
dcids = filter(lambda v: v==v, dcids) # Filter out NaN values
132136
dcids = list(dcids)
133137
url = utils._API_ROOT + utils._API_ENDPOINTS['get_stats']
134-
payload = utils._send_request(url, req_json={
135-
'place': dcids,
136-
'stats_var': stats_var,
137-
})
138+
batches = -(-len(dcids) // utils._QUERY_BATCH_SIZE) # Ceil to get # of batches.
139+
res = {}
140+
for i in range(batches):
141+
payload = utils._send_request(url, req_json={
142+
'place': dcids[i * utils._QUERY_BATCH_SIZE:(i+1) * utils._QUERY_BATCH_SIZE],
143+
'stats_var': stats_var,
144+
})
145+
if obs_dates == 'all':
146+
res.update(payload)
147+
elif obs_dates == 'latest':
148+
for geo, stats in payload.items():
149+
time_series = stats.get('data')
150+
if not time_series: continue
151+
max_date = max(time_series)
152+
max_date_stat = time_series[max_date]
153+
time_series.clear()
154+
time_series[max_date] = max_date_stat
155+
res[geo] = stats
156+
elif obs_dates:
157+
obs_dates = set(obs_dates)
158+
for geo, stats in payload.items():
159+
time_series = stats.get('data')
160+
if not time_series: continue
161+
for date in list(time_series):
162+
if date not in obs_dates:
163+
time_series.pop(date)
164+
res[geo] = stats
165+
return res
138166

139-
return payload
140167

141168
def get_related_places(dcids, population_type, measured_property,
142169
measurement_method, stat_type, constraining_properties={},

datacommons/test/places_test.py

Lines changed: 123 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import datacommons.utils as utils
2727
import json
2828
import unittest
29-
import urllib
29+
import six.moves.urllib
3030

3131

3232
def request_mock(*args, **kwargs):
@@ -122,9 +122,10 @@ def read(self):
122122
}
123123
})
124124
return MockResponse(json.dumps({'payload': res_json}))
125-
if (data['place'] == ['geoId/05', 'dc/MadDcid'] and
125+
if ((data['place'] == ['geoId/05', 'dc/MadDcid'] or
126+
data['place'] == ['geoId/05']) and
126127
data['stats_var'] == 'dc/0hyp6tkn18vcb'):
127-
# Response returned when querying for a dcid that does not exist.
128+
# Response ignores dcid that does not exist.
128129
res_json = json.dumps({
129130
'geoId/05': {
130131
'data': {
@@ -141,13 +142,31 @@ def read(self):
141142
}
142143
})
143144
return MockResponse(json.dumps({'payload': res_json}))
145+
if (data['place'] == ['geoId/06'] and
146+
data['stats_var'] == 'dc/0hyp6tkn18vcb'):
147+
res_json = json.dumps({
148+
'geoId/06': {
149+
'data': {
150+
'2011': 316667,
151+
'2012': 324116,
152+
'2013': 331853,
153+
'2014': 342818,
154+
'2015': 348979,
155+
'2016': 354806,
156+
'2017': 360645,
157+
'2018': 366331
158+
},
159+
'place_name': 'California'
160+
}
161+
})
162+
return MockResponse(json.dumps({'payload': res_json}))
144163
if (data['place'] == ['dc/MadDcid', 'dc/MadderDcid'] and
145164
data['stats_var'] == 'dc/0hyp6tkn18vcb'):
146165
# Response returned when both given dcids do not exist.
147-
res_json = json.dumps([])
166+
res_json = json.dumps({})
148167
return MockResponse(json.dumps({'payload': res_json}))
149168
if data['place'] == [] and data['stats_var'] == 'dc/0hyp6tkn18vcb':
150-
res_json = json.dumps([])
169+
res_json = json.dumps({})
151170
# Response returned when no dcids are given.
152171
return MockResponse(json.dumps({'payload': res_json}))
153172

@@ -216,7 +235,7 @@ def test_multiple_dcids(self, urlopen):
216235
dc.set_api_key('TEST-API-KEY')
217236

218237
# Call get_stats
219-
stats = dc.get_stats(['geoId/05', 'geoId/06'], 'dc/0hyp6tkn18vcb')
238+
stats = dc.get_stats(['geoId/05', 'geoId/06'], 'dc/0hyp6tkn18vcb', 'all')
220239
self.assertDictEqual(
221240
stats, {
222241
'geoId/05': {
@@ -247,6 +266,60 @@ def test_multiple_dcids(self, urlopen):
247266
}
248267
})
249268

269+
# Call get_stats for latest obs
270+
stats = dc.get_stats(['geoId/05', 'geoId/06'], 'dc/0hyp6tkn18vcb', 'latest')
271+
self.assertDictEqual(
272+
stats, {
273+
'geoId/05': {
274+
'data': {
275+
'2018': 18003
276+
},
277+
'place_name': 'Arkansas'
278+
},
279+
'geoId/06': {
280+
'data': {
281+
'2018': 366331
282+
},
283+
'place_name': 'California'
284+
}
285+
})
286+
287+
# Call get_stats for specific obs
288+
stats = dc.get_stats(['geoId/05', 'geoId/06'], 'dc/0hyp6tkn18vcb', ['2013', '2018'])
289+
self.assertDictEqual(
290+
stats, {
291+
'geoId/05': {
292+
'data': {
293+
'2013': 17459,
294+
'2018': 18003
295+
},
296+
'place_name': 'Arkansas'
297+
},
298+
'geoId/06': {
299+
'data': {
300+
'2013': 331853,
301+
'2018': 366331
302+
},
303+
'place_name': 'California'
304+
}
305+
})
306+
307+
# Call get_stats -- dates must be in an iterable
308+
stats = dc.get_stats(['geoId/05', 'geoId/06'], 'dc/0hyp6tkn18vcb', '2018')
309+
self.assertDictEqual(
310+
stats, {
311+
'geoId/05': {
312+
'data': {
313+
},
314+
'place_name': 'Arkansas'
315+
},
316+
'geoId/06': {
317+
'data': {
318+
},
319+
'place_name': 'California'
320+
}
321+
})
322+
250323
@mock.patch('urllib.request.urlopen', side_effect=request_mock)
251324
def test_bad_dcids(self, urlopen):
252325
""" Calling get_stats with dcids that do not exist returns empty
@@ -261,13 +334,6 @@ def test_bad_dcids(self, urlopen):
261334
bad_dcids_1, {
262335
'geoId/05': {
263336
'data': {
264-
'2011': 18136,
265-
'2012': 17279,
266-
'2013': 17459,
267-
'2014': 16966,
268-
'2015': 17173,
269-
'2016': 17041,
270-
'2017': 17783,
271337
'2018': 18003
272338
},
273339
'place_name': 'Arkansas'
@@ -277,7 +343,7 @@ def test_bad_dcids(self, urlopen):
277343
# Call get_stats when both dcids do not exist
278344
bad_dcids_2 = dc.get_stats(['dc/MadDcid', 'dc/MadderDcid'],
279345
'dc/0hyp6tkn18vcb')
280-
self.assertFalse(bad_dcids_2)
346+
self.assertDictEqual({}, bad_dcids_2)
281347

282348
@mock.patch('urllib.request.urlopen', side_effect=request_mock)
283349
def test_no_dcids(self, urlopen):
@@ -287,7 +353,49 @@ def test_no_dcids(self, urlopen):
287353

288354
# Call get_stats with no dcids.
289355
no_dcids = dc.get_stats([], 'dc/0hyp6tkn18vcb')
290-
self.assertFalse(no_dcids)
356+
self.assertDictEqual({}, no_dcids)
357+
358+
@mock.patch('six.moves.urllib.request.urlopen', side_effect=request_mock)
359+
def test_batch_request(self, mock_urlopen):
360+
""" Make multiple calls to REST API when number of geos exceeds the batch size. """
361+
# Set the API key
362+
dc.set_api_key('TEST-API-KEY')
363+
364+
save_batch_size = dc.utils._QUERY_BATCH_SIZE
365+
dc.utils._QUERY_BATCH_SIZE = 1
366+
367+
self.assertEqual(0, mock_urlopen.call_count)
368+
stats = dc.get_stats(['geoId/05'], 'dc/0hyp6tkn18vcb', 'latest')
369+
self.assertDictEqual(
370+
stats, {
371+
'geoId/05': {
372+
'data': {
373+
'2018': 18003
374+
},
375+
'place_name': 'Arkansas'
376+
},
377+
})
378+
self.assertEqual(1, mock_urlopen.call_count)
379+
380+
stats = dc.get_stats(['geoId/05', 'geoId/06'], 'dc/0hyp6tkn18vcb', 'latest')
381+
self.assertDictEqual(
382+
stats, {
383+
'geoId/05': {
384+
'data': {
385+
'2018': 18003
386+
},
387+
'place_name': 'Arkansas'
388+
},
389+
'geoId/06': {
390+
'data': {
391+
'2018': 366331
392+
},
393+
'place_name': 'California'
394+
}
395+
})
396+
self.assertEqual(3, mock_urlopen.call_count)
397+
398+
dc.utils._QUERY_BATCH_SIZE = save_batch_size
291399

292400

293401
if __name__ == '__main__':

datacommons/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@
5454
# The default value to limit to
5555
_MAX_LIMIT = 100
5656

57+
# Batch size for heavyweight queries.
58+
_QUERY_BATCH_SIZE = 500
59+
5760
# Environment variable names used by the package
5861
_ENV_VAR_API_KEY = 'DC_API_KEY' # Name the API key variable
5962

0 commit comments

Comments
 (0)