Skip to content

Commit ca9ba6d

Browse files
qubixes authored and alanking committed
[#574] Allow for tqdm progress bars to be used
Works for both parallel as well as single threaded gets and puts.
1 parent 0c9600e commit ca9ba6d

2 files changed

Lines changed: 28 additions & 17 deletions

File tree

irods/manager/data_object_manager.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def should_parallelize_transfer( self,
124124
open_options[kw.DATA_SIZE_KW] = size
125125

126126

127-
def _download(self, obj, local_path, num_threads, **options):
127+
def _download(self, obj, local_path, num_threads, progress_bar, **options):
128128
"""Transfer the contents of a data object to a local file.
129129
130130
Called from get() when a local path is named.
@@ -145,14 +145,17 @@ def _download(self, obj, local_path, num_threads, **options):
145145
f.close()
146146
if not self.parallel_get( (obj,o), local_file, num_threads = num_threads,
147147
target_resource_name = options.get(kw.RESC_NAME_KW,''),
148-
data_open_returned_values = data_open_returned_values_):
148+
data_open_returned_values = data_open_returned_values_,
149+
progress_bar=progress_bar):
149150
raise RuntimeError("parallel get failed")
150151
else:
151152
for chunk in chunks(o, self.READ_BUFFER_SIZE):
152153
f.write(chunk)
154+
if progress_bar is not None:
155+
progress_bar.update(len(chunk))
153156

154157

155-
def get(self, path, local_path = None, num_threads = DEFAULT_NUMBER_OF_THREADS, **options):
158+
def get(self, path, local_path = None, num_threads = DEFAULT_NUMBER_OF_THREADS, progress_bar = None, **options):
156159
"""
157160
Get a reference to the data object at the specified `path'.
158161
@@ -163,7 +166,7 @@ def get(self, path, local_path = None, num_threads = DEFAULT_NUMBER_OF_THREADS,
163166

164167
# TODO: optimize
165168
if local_path:
166-
self._download(path, local_path, num_threads = num_threads, **options)
169+
self._download(path, local_path, num_threads = num_threads, progress_bar=progress_bar, **options)
167170

168171
query = self.sess.query(DataObject)\
169172
.filter(DataObject.name == irods_basename(path))\
@@ -180,7 +183,7 @@ def get(self, path, local_path = None, num_threads = DEFAULT_NUMBER_OF_THREADS,
180183
return iRODSDataObject(self, parent, results)
181184

182185

183-
def put(self, local_path, irods_path, return_data_object = False, num_threads = DEFAULT_NUMBER_OF_THREADS, **options):
186+
def put(self, local_path, irods_path, return_data_object = False, num_threads = DEFAULT_NUMBER_OF_THREADS, progress_bar = None, **options):
184187

185188
if self.sess.collections.exists(irods_path):
186189
obj = iRODSCollection.normalize_path(irods_path, os.path.basename(local_path))
@@ -195,7 +198,7 @@ def put(self, local_path, irods_path, return_data_object = False, num_threads =
195198
if not self.parallel_put( local_path, (obj,o), total_bytes = sizelist[0], num_threads = num_threads,
196199
target_resource_name = options.get(kw.RESC_NAME_KW,'') or
197200
options.get(kw.DEST_RESC_NAME_KW,''),
198-
open_options = options ):
201+
open_options = options, progress_bar = progress_bar):
199202
raise RuntimeError("parallel put failed")
200203
else:
201204
with self.open(obj, 'w', **options) as o:
@@ -204,6 +207,8 @@ def put(self, local_path, irods_path, return_data_object = False, num_threads =
204207
options[kw.OPR_TYPE_KW] = 1 # PUT_OPR
205208
for chunk in chunks(f, self.WRITE_BUFFER_SIZE):
206209
o.write(chunk)
210+
if progress_bar is not None:
211+
progress_bar.update(len(chunk))
207212
if kw.ALL_KW in options:
208213
repl_options = options.copy()
209214
repl_options[kw.UPDATE_REPL_KW] = ''
@@ -259,7 +264,8 @@ def parallel_get(self,
259264
num_threads = 0,
260265
target_resource_name = '',
261266
data_open_returned_values = None,
262-
progressQueue = False):
267+
progressQueue = False,
268+
progress_bar = None):
263269
"""Call into the irods.parallel library for multi-1247 GET.
264270
265271
Called from a session.data_objects.get(...) (via the _download method) on
@@ -270,7 +276,8 @@ def parallel_get(self,
270276
return parallel.io_main( self.sess, data_or_path_, parallel.Oper.GET | (parallel.Oper.NONBLOCKING if async_ else 0), file_,
271277
num_threads = num_threads, target_resource_name = target_resource_name,
272278
data_open_returned_values = data_open_returned_values,
273-
queueLength = (DEFAULT_QUEUE_DEPTH if progressQueue else 0))
279+
queueLength = (DEFAULT_QUEUE_DEPTH if progressQueue else 0),
280+
progress_bar = progress_bar)
274281

275282
def parallel_put(self,
276283
file_ ,
@@ -280,6 +287,7 @@ def parallel_put(self,
280287
num_threads = 0,
281288
target_resource_name = '',
282289
open_options = {},
290+
progress_bar = None,
283291
progressQueue = False):
284292
"""Call into the irods.parallel library for multi-1247 PUT.
285293
@@ -290,6 +298,7 @@ def parallel_put(self,
290298
return parallel.io_main( self.sess, data_or_path_, parallel.Oper.PUT | (parallel.Oper.NONBLOCKING if async_ else 0), file_,
291299
num_threads = num_threads, total_bytes = total_bytes, target_resource_name = target_resource_name,
292300
open_options = open_options,
301+
progress_bar = progress_bar,
293302
queueLength = (DEFAULT_QUEUE_DEPTH if progressQueue else 0)
294303
)
295304

irods/parallel.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def _io_send_bytes_progress (queueObject, item):
223223

224224
COPY_BUF_SIZE = (1024 ** 2) * 4
225225

226-
def _copy_part( src, dst, length, queueObject, debug_info, mgr):
226+
def _copy_part( src, dst, length, queueObject, debug_info, mgr, progress_bar):
227227
"""
228228
The work-horse for performing the copy between file and data object.
229229
@@ -240,6 +240,8 @@ def _copy_part( src, dst, length, queueObject, debug_info, mgr):
240240
bytecount += buf_len
241241
accum += buf_len
242242
if queueObject and accum and _io_send_bytes_progress(queueObject,accum): accum = 0
243+
if progress_bar is not None:
244+
progress_bar.update(buf_len)
243245
if verboseConnection:
244246
print ("("+debug_info+")",end='',file=sys.stderr)
245247
sys.stderr.flush()
@@ -301,7 +303,7 @@ def finalize(self):
301303
self.initial_io.close()
302304

303305

304-
def _io_part (objHandle, range_, file_, opr_, mgr_, thread_debug_id = '', queueObject = None ):
306+
def _io_part (objHandle, range_, file_, opr_, mgr_, thread_debug_id = '', queueObject = None, progress_bar = None):
305307
"""
306308
Runs in a separate thread to manage the transfer of a range of bytes within the data object.
307309
@@ -315,12 +317,12 @@ def _io_part (objHandle, range_, file_, opr_, mgr_, thread_debug_id = '', queueO
315317
file_.seek(offset)
316318
if thread_debug_id == '': # for more succinct thread identifiers while debugging.
317319
thread_debug_id = str(threading.currentThread().ident)
318-
return ( _copy_part (file_, objHandle, length, queueObject, thread_debug_id, mgr_) if Operation.isPut()
319-
else _copy_part (objHandle, file_, length, queueObject, thread_debug_id, mgr_) )
320+
return ( _copy_part (file_, objHandle, length, queueObject, thread_debug_id, mgr_, progress_bar) if Operation.isPut()
321+
else _copy_part (objHandle, file_, length, queueObject, thread_debug_id, mgr_, progress_bar) )
320322

321323

322324
def _io_multipart_threaded(operation_ , dataObj_and_IO, replica_token, hier_str, session, fname,
323-
total_size, num_threads, **extra_options):
325+
total_size, num_threads, progress_bar, **extra_options):
324326
"""Called by _io_main.
325327
326328
Carve up (0,total_size) range into `num_threads` parts and initiate a transfer thread for each one.
@@ -366,7 +368,7 @@ def bytes_range_for_thread( i, num_threads, total_bytes, chunk ):
366368
mgr.add_io( Io )
367369
logger.debug(u'target_host = %s', Io.raw.session.pool.account.host)
368370
if File is None: File = gen_file_handle()
369-
futures.append(executor.submit( _io_part, Io, byte_range, File, Operation, mgr, str(counter), queueObject))
371+
futures.append(executor.submit( _io_part, Io, byte_range, File, Operation, mgr, str(counter), queueObject, progress_bar))
370372
counter += 1
371373
Io = File = None
372374

@@ -381,7 +383,7 @@ def bytes_range_for_thread( i, num_threads, total_bytes, chunk ):
381383

382384

383385

384-
def io_main( session, Data, opr_, fname, R='', **kwopt):
386+
def io_main( session, Data, opr_, fname, R='', progress_bar = None, **kwopt):
385387
"""
386388
The entry point for parallel transfers (multithreaded PUT and GET operations).
387389
@@ -395,7 +397,6 @@ def io_main( session, Data, opr_, fname, R='', **kwopt):
395397
Operation = Oper(opr_)
396398
d_path = None
397399
Io = None
398-
399400
if isinstance(Data,tuple):
400401
(Data, Io) = Data[:2]
401402

@@ -468,7 +469,8 @@ def io_main( session, Data, opr_, fname, R='', **kwopt):
468469
queueLength = kwopt.get('queueLength',0)
469470
retval = _io_multipart_threaded (Operation, (Data, Io), replica_token, resc_hier, session, fname, total_bytes,
470471
num_threads = num_threads,
471-
_queueLength = queueLength)
472+
_queueLength = queueLength,
473+
progress_bar = progress_bar)
472474

473475
# SessionObject.data_objects.parallel_{put,get} will return:
474476
# - immediately with an AsyncNotify instance, if Oper.NONBLOCKING flag is used.

0 commit comments

Comments (0)