Skip to content

Commit e65ae34

Browse files
sbngrosssgross-emlix
authored andcommitted
driver/sshdriver: add multifile support to scp
Users are accustomed to scp having an option for copying recursively as well as to accept multiple source files. Meet user expectation by adding well known `-r` option and support for multiple source files. Signed-off-by: Sebastian Gross <sgross@emlix.com>
1 parent 2f0bdf1 commit e65ae34

1 file changed

Lines changed: 19 additions & 9 deletions

File tree

labgrid/driver/sshdriver.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -356,25 +356,34 @@ def forward_unix_socket(self, unixsocket, localport=None):
356356
yield localport
357357

358358
@Driver.check_active
359-
@step(args=['src', 'dst'])
360-
def scp(self, *, src, dst):
359+
@step(args=['src', 'dst', 'recursive'])
360+
def scp(self, *, src: str | list(str), dst: str, recursive: bool = False):
361361
if not self._check_keepalive():
362362
raise ExecutionError("Keepalive no longer running")
363363

364-
if src.startswith(':') == dst.startswith(':'):
364+
if isinstance(src, str):
365+
src = [src]
366+
367+
remote_src = [f.startswith(':') for f in src]
368+
if any(remote_src) != all(remote_src):
369+
raise ValueError("All sources must be consistently local or remote (start with :)")
370+
371+
if all(remote_src) == dst.startswith(':'):
365372
raise ValueError("Either source or destination must be remote (start with :)")
366-
if src.startswith(':'):
367-
src = '_' + src
368-
if dst.startswith(':'):
369-
dst = '_' + dst
373+
374+
src = [s.replace(':', '_:') for s in src]
375+
dst = dst.replace(':', '_:')
370376

371377
complete_cmd = [self._scp,
372378
"-S", self._ssh,
373379
"-F", "none",
374380
"-o", f"ControlPath={self.control.replace('%', '%%')}",
375-
src, dst,
381+
*src,
382+
dst,
376383
]
377-
384+
385+
if recursive:
386+
complete_cmd.insert(1, "-r")
378387
if self.explicit_sftp_mode and self._scp_supports_explicit_sftp_mode():
379388
complete_cmd.insert(1, "-s")
380389
if self.explicit_scp_mode and self._scp_supports_explicit_scp_mode():
@@ -594,3 +603,4 @@ def _stop_keepalive(self):
594603
if stdout:
595604
for line in stdout.splitlines():
596605
self.logger.warning("Keepalive %s: %s", self.networkservice.address, line)
606+

0 commit comments

Comments
 (0)