-
Notifications
You must be signed in to change notification settings - Fork 26
Expand file tree
/
Copy path: batch.py
More file actions
376 lines (314 loc) · 13.8 KB
/
batch.py
File metadata and controls
376 lines (314 loc) · 13.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
import sys
import ssl
import asyncio
import traceback
import threading
import random
import json
from contextlib import suppress
from urllib.parse import urlencode
from tqdm import tqdm
import certifi
import backoff
from opencage.geocoder import (
OpenCageGeocode,
OpenCageGeocodeError,
_query_for_reverse_geocoding,
floatify_latlng
)
class OpenCageBatchGeocoder():
    """Batch geocoder that processes CSV files using the OpenCage API.

    Reads rows from a CSV input, geocodes each address using async workers,
    and writes results to a CSV output.

    Args:
        options: Parsed command-line options from argparse.
    """

    def __init__(self, options):
        self.options = options
        # Use certifi's CA bundle so TLS verification works regardless of the
        # system certificate store.
        self.sslcontext = ssl.create_default_context(cafile=certifi.where())
        self.user_agent_comment = 'OpenCage CLI'
        # Next input row number to be written; enforces ordered output unless
        # the --unordered option is set.
        self.write_counter = 1

    def __call__(self, *args, **kwargs):
        """Run the batch geocoder synchronously via asyncio.run."""
        asyncio.run(self.geocode(*args, **kwargs))

    async def geocode(self, csv_input, csv_output):
        """Process a CSV input, geocode each row, and write results.

        Args:
            csv_input: CSV reader for input rows.
            csv_output: CSV writer for output rows.
        """
        if not self.options.dry_run:
            test = await self.test_request()
            if test['error']:
                self.log(test['error'])
                return
            if test['free'] is True and self.options.workers > 1:
                sys.stderr.write("Free trial account detected. Resetting number of workers to 1.\n")
                self.options.workers = 1

        if self.options.headers:
            header_columns = next(csv_input, None)
            if header_columns is None:
                # Empty input file: nothing to do.
                return

        queue = asyncio.Queue(maxsize=self.options.limit)

        read_warnings = await self.read_input(csv_input, queue)

        if self.options.dry_run:
            if not read_warnings:
                print('All good.')
            return

        if self.options.headers:
            csv_output.writerow(header_columns + self.options.add_columns)

        # Either False (disabled) or a tqdm instance.
        progress_bar = not (self.options.no_progress or self.options.quiet) and \
            tqdm(total=queue.qsize(), position=0, desc="Addresses geocoded", dynamic_ncols=True)

        tasks = []
        for _ in range(self.options.workers):
            task = asyncio.create_task(self.worker(csv_output, queue, progress_bar))
            tasks.append(task)

        # This starts the workers and waits until all queued items are processed
        await queue.join()

        # All work done; cancel the now-idle workers.
        for task in tasks:
            task.cancel()

        if progress_bar:
            progress_bar.close()

    async def test_request(self):
        """Send a test geocoding request to verify the API key.

        Returns:
            Dict with 'error' (None or exception) and 'free' (bool indicating
            whether a free trial account is being used).
        """
        try:
            async with OpenCageGeocode(
                self.options.api_key,
                domain=self.options.api_domain,
                sslcontext=self.sslcontext,
                user_agent_comment=self.user_agent_comment
            ) as geocoder:
                result = await geocoder.geocode_async('Kendall Sq, Cambridge, MA', raw_response=True)

                free = False
                with suppress(KeyError):
                    # Free trial accounts are rate-limited to 2500 requests/day.
                    free = result['rate']['limit'] == 2500
                return {'error': None, 'free': free}
        except Exception as exc:
            # Bug fix: always include the 'free' key so the returned dict has
            # a consistent shape for callers.
            return {'error': exc, 'free': False}

    async def read_input(self, csv_input, queue):
        """Read all rows from CSV input and add them to the work queue.

        Args:
            csv_input: CSV reader for input rows.
            queue: Async queue to populate with parsed input items.

        Returns:
            True if any warnings were encountered while reading, False otherwise.
        """
        any_warnings = False
        for index, row in enumerate(csv_input):
            line_number = index + 1

            if len(row) == 0:
                self.log(f"Line {line_number} - Empty line")
                any_warnings = True
                row = ['']

            item = await self.read_one_line(row, line_number)
            if item['warnings'] is True:
                any_warnings = True
            await queue.put(item)
            # The queue's maxsize doubles as the --limit option: stop reading
            # once it is reached.
            if queue.full():
                break
        return any_warnings

    async def read_one_line(self, row, row_id):
        """Parse a single CSV row into a work item for geocoding.

        Args:
            row: List of column values from the CSV reader.
            row_id: 1-based line number of the row in the input.

        Returns:
            Dict with keys 'row_id', 'address', 'original_columns',
            and 'warnings'.
        """
        warnings = False

        if self.options.input_columns:
            input_columns = self.options.input_columns
        elif self.options.command == 'reverse':
            # Reverse geocoding defaults to columns 1 and 2 (lat, lng).
            input_columns = [1, 2]
        else:
            input_columns = None

        if input_columns:
            address = []
            try:
                for column in input_columns:
                    # input_columns option uses 1-based indexing
                    address.append(row[column - 1])
            except IndexError:
                self.log(f"Line {row_id} - Missing input column {column} in {row}")
                warnings = True
        else:
            address = row

        if self.options.command == 'reverse':
            if len(address) != 2:
                self.log(
                    f"Line {row_id} - Expected two comma-separated values for reverse geocoding, got {address}")
                # Bug fix: previously this branch did not flag a warning, so
                # --dry-run would report 'All good.' despite the bad row.
                warnings = True
            else:
                # _query_for_reverse_geocoding attempts to convert into numbers. We rather have it fail
                # now than during the actual geocoding
                try:
                    _query_for_reverse_geocoding(address[0], address[1])
                except Exception:
                    # Bug fix: was 'except BaseException', which would also
                    # swallow KeyboardInterrupt and SystemExit.
                    self.log(
                        f"Line {row_id} - Does not look like latitude and longitude: '{address[0]}' and '{address[1]}'")
                    warnings = True
                    address = []

        return {'row_id': row_id, 'address': ','.join(address), 'original_columns': row, 'warnings': warnings}

    async def worker(self, csv_output, queue, progress):
        """Consume items from the queue and geocode each one.

        Args:
            csv_output: CSV writer for output rows.
            queue: Async queue of work items to process.
            progress: tqdm progress bar, or False if disabled.
        """
        while True:
            item = await queue.get()
            try:
                await self.geocode_one_address(csv_output, item['row_id'], item['address'], item['original_columns'])
            except Exception as exc:
                traceback.print_exception(exc, file=sys.stderr)
            finally:
                # Bug fix: advance the progress bar for failed rows too,
                # otherwise the bar could never reach its total.
                if progress:
                    progress.update(1)
                queue.task_done()

    async def geocode_one_address(self, csv_output, row_id, address, original_columns):
        """Geocode a single address and write the result to the output.

        Args:
            csv_output: CSV writer for output rows.
            row_id: 1-based line number of the row in the input.
            address: Address string (or 'lat,lng' for reverse geocoding).
            original_columns: Original CSV row columns to preserve in output.
        """
        def on_backoff(details):
            if not self.options.quiet:
                # Bug fix: typo 'afters' -> 'after' in the retry warning.
                sys.stderr.write("Backing off {wait:0.1f} seconds after {tries} tries "
                                 "calling function {target} with args {args} and kwargs "
                                 "{kwargs}\n".format(**details))

        @backoff.on_exception(backoff.expo,
                              asyncio.TimeoutError,
                              max_time=self.options.timeout,
                              max_tries=self.options.retries,
                              on_backoff=on_backoff)
        async def _geocode_one_address():
            async with OpenCageGeocode(
                self.options.api_key,
                domain=self.options.api_domain,
                sslcontext=self.sslcontext,
                user_agent_comment=self.user_agent_comment
            ) as geocoder:
                geocoding_results = None
                response = None
                params = {'no_annotations': 1, 'raw_response': True, **self.options.optional_api_params}
                try:
                    if self.options.command == 'reverse':
                        if ',' in address:
                            # read_one_line stored the coordinates as 'lat,lng'.
                            # Bug fix: variables were previously named 'lon, lat'
                            # (the call order was correct, the names misleading).
                            lat, lng = address.split(',')
                            response = await geocoder.reverse_geocode_async(lat, lng, **params)
                            geocoding_results = floatify_latlng(response['results'])
                    else:
                        response = await geocoder.geocode_async(address, **params)
                        geocoding_results = floatify_latlng(response['results'])
                except OpenCageGeocodeError as exc:
                    self.log(str(exc))
                except Exception as exc:
                    traceback.print_exception(exc, file=sys.stderr)

                try:
                    if geocoding_results is not None and len(geocoding_results):
                        geocoding_result = geocoding_results[0]
                    else:
                        geocoding_result = None

                    if self.options.verbose:
                        self.log({
                            'row_id': row_id,
                            'thread_id': threading.get_native_id(),
                            # NOTE(review): _parse_request is a private helper of
                            # the opencage package; used here for debug output only.
                            'request': geocoder.url + '?' + urlencode(geocoder._parse_request(address, params)),
                            'response': response
                        })

                    await self.write_one_geocoding_result(
                        csv_output,
                        row_id,
                        geocoding_result,
                        response,
                        original_columns
                    )
                except Exception as exc:
                    traceback.print_exception(exc, file=sys.stderr)

        await _geocode_one_address()

    async def write_one_geocoding_result(
            self,
            csv_output,
            row_id,
            geocoding_result,
            raw_response,
            original_columns):
        """Write a single geocoding result row to the CSV output.

        Appends the requested output columns to the original CSV columns.
        Rows are written in order unless the --unordered option is set.

        Args:
            csv_output: CSV writer for output rows.
            row_id: 1-based line number of the row in the input.
            geocoding_result: First result dict from the API, or None.
            raw_response: Full API response dict.
            original_columns: Original CSV row columns to preserve in output.
        """
        # Bug fix: copy instead of appending to the caller's list in place.
        row = list(original_columns)

        for column in self.options.add_columns:
            if column == 'status':
                row.append(self.deep_get_result_value(raw_response, ['status', 'message']))
            elif geocoding_result is None:
                row.append('')
            elif column in geocoding_result:
                row.append(self.deep_get_result_value(geocoding_result, [column], ''))
            elif column in geocoding_result['components']:
                row.append(self.deep_get_result_value(geocoding_result, ['components', column], ''))
            elif column in geocoding_result['geometry']:
                row.append(self.deep_get_result_value(geocoding_result, ['geometry', column], ''))
            elif column == 'FIPS':
                row.append(
                    self.deep_get_result_value(
                        geocoding_result, [
                            'annotations', 'FIPS', 'county'], ''))
            elif column == 'json':
                row.append(json.dumps(geocoding_result, separators=(',', ':')))  # Compact JSON
            else:
                row.append('')

        # Enforce that rows are written ordered. That means we might wait for other threads
        # to finish a task and make the overall process slower. Alternative would be to
        # use a second queue, or keep some results in memory.
        if not self.options.unordered:
            while row_id > self.write_counter:
                if self.options.verbose:
                    self.log(f"Want to write row {row_id}, but write_counter is at {self.write_counter}")
                await asyncio.sleep(random.uniform(0.01, 0.1))

        if self.options.verbose:
            self.log(f"Writing row {row_id}")
        try:
            csv_output.writerow(row)
        finally:
            # Bug fix: advance the counter even if writerow raises, otherwise
            # every other ordered worker would spin forever waiting for this
            # row id to be written.
            self.write_counter = self.write_counter + 1

    def log(self, message):
        """Write a message to stderr unless quiet mode is enabled.

        Args:
            message: Message string to display.
        """
        if not self.options.quiet:
            sys.stderr.write(f"{message}\n")

    def deep_get_result_value(self, data, keys, default=None):
        """Retrieve a nested value from a dict using a list of keys.

        Args:
            data: Dict to traverse.
            keys: List of keys to follow in sequence.
            default: Value to return if any key is missing.

        Returns:
            The nested value, or default if the path doesn't exist.

        Example:
            >>> data = {'status': {'code': 200, 'message': 'OK'}}
            >>> self.deep_get_result_value(data, ['status', 'message'])
            'OK'
            >>> self.deep_get_result_value(data, ['missing', 'key'], '')
            ''
        """
        for key in keys:
            if isinstance(data, dict):
                data = data.get(key, default)
            else:
                return default
        return data