Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.

Commit ab6b285

Browse files
authored
Merge pull request #734 from jumpstarter-dev/backport-lazy-ssh
Backport lazy ssh
2 parents 5814cec + 228a4e8 commit ab6b285

2 files changed

Lines changed: 317 additions & 7 deletions

File tree

packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,6 @@ def __post_init__(self):
2424
if self.ssh_identity and self.ssh_identity_file:
2525
raise ConfigurationError("Cannot specify both ssh_identity and ssh_identity_file")
2626

27-
# If ssh_identity_file is provided, read it into ssh_identity
28-
if self.ssh_identity_file:
29-
try:
30-
self.ssh_identity = Path(self.ssh_identity_file).read_text()
31-
except Exception as e:
32-
raise ConfigurationError(f"Failed to read ssh_identity_file '{self.ssh_identity_file}': {e}") from None
33-
3427
@classmethod
3528
def client(cls) -> str:
3629
return "jumpstarter_driver_ssh.client.SSHWrapperClient"
@@ -48,4 +41,10 @@ def get_ssh_command(self):
4841
@export
4942
def get_ssh_identity(self):
5043
"""Get the SSH identity key content"""
44+
# If ssh_identity_file is provided, read it lazily and cache in ssh_identity
45+
if self.ssh_identity is None and self.ssh_identity_file:
46+
try:
47+
self.ssh_identity = Path(self.ssh_identity_file).read_text()
48+
except Exception as e:
49+
raise ConfigurationError(f"Failed to read ssh_identity_file '{self.ssh_identity_file}': {e}") from None
5150
return self.ssh_identity

packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py

Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@
1010
from jumpstarter.common.exceptions import ConfigurationError
1111
from jumpstarter.common.utils import serve
1212

13+
# Test SSH key content used in multiple tests
14+
TEST_SSH_KEY = (
15+
"-----BEGIN OPENSSH PRIVATE KEY-----\n"
16+
"test-key-content\n"
17+
"-----END OPENSSH PRIVATE KEY-----"
18+
)
19+
1320

1421
def test_ssh_wrapper_defaults():
1522
"""Test SSH wrapper with default configuration"""
@@ -348,3 +355,307 @@ def test_ssh_command_with_command_l_flag_does_not_interfere_with_username_inject
348355
assert ssh_l_index < hostname_index < command_l_index
349356

350357
assert result == 0
358+
359+
360+
def test_ssh_identity_string_configuration():
361+
"""Test SSH wrapper with ssh_identity string configuration"""
362+
instance = SSHWrapper(
363+
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
364+
default_username="testuser",
365+
ssh_identity=TEST_SSH_KEY
366+
)
367+
368+
# Test that the instance was created correctly
369+
assert instance.ssh_identity == TEST_SSH_KEY
370+
assert instance.ssh_identity_file is None
371+
372+
# Test that the client class is correct
373+
assert instance.client() == "jumpstarter_driver_ssh.client.SSHWrapperClient"
374+
375+
376+
def test_ssh_identity_file_configuration():
377+
"""Test SSH wrapper with ssh_identity_file configuration"""
378+
import os
379+
import tempfile
380+
381+
# Create a temporary file with SSH key content
382+
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_test_key') as temp_file:
383+
temp_file.write(TEST_SSH_KEY)
384+
temp_file_path = temp_file.name
385+
386+
try:
387+
instance = SSHWrapper(
388+
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
389+
default_username="testuser",
390+
ssh_identity_file=temp_file_path
391+
)
392+
393+
# Test that the instance was created correctly
394+
# ssh_identity should be None until first use (lazy loading)
395+
assert instance.ssh_identity is None
396+
assert instance.ssh_identity_file == temp_file_path
397+
398+
# Test that get_ssh_identity() reads the file on first use
399+
identity = instance.get_ssh_identity()
400+
assert identity == TEST_SSH_KEY
401+
402+
# Test that ssh_identity is now cached
403+
assert instance.ssh_identity == TEST_SSH_KEY
404+
405+
# Test that the client class is correct
406+
assert instance.client() == "jumpstarter_driver_ssh.client.SSHWrapperClient"
407+
finally:
408+
# Clean up the temporary file
409+
os.unlink(temp_file_path)
410+
411+
412+
def test_ssh_identity_validation_error():
413+
"""Test SSH wrapper raises error when both ssh_identity and ssh_identity_file are provided"""
414+
with pytest.raises(ConfigurationError, match="Cannot specify both ssh_identity and ssh_identity_file"):
415+
SSHWrapper(
416+
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
417+
default_username="testuser",
418+
ssh_identity="test-key-content",
419+
ssh_identity_file="/path/to/key"
420+
)
421+
422+
423+
def test_ssh_identity_file_read_error():
424+
"""Test SSH wrapper raises error when ssh_identity_file cannot be read on first use"""
425+
# Instance creation should succeed (lazy loading)
426+
instance = SSHWrapper(
427+
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
428+
default_username="testuser",
429+
ssh_identity_file="/nonexistent/path/to/key"
430+
)
431+
432+
# Error should be raised when get_ssh_identity() is called
433+
with pytest.raises(ConfigurationError, match="Failed to read ssh_identity_file"):
434+
instance.get_ssh_identity()
435+
436+
437+
def test_ssh_command_with_identity_string():
438+
"""Test SSH command execution with ssh_identity string"""
439+
instance = SSHWrapper(
440+
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
441+
default_username="testuser",
442+
ssh_identity=TEST_SSH_KEY
443+
)
444+
445+
with serve(instance) as client:
446+
with patch('subprocess.run') as mock_run:
447+
mock_run.return_value = MagicMock(returncode=0)
448+
449+
# Test SSH command with identity string
450+
result = client.run(False, ["hostname"])
451+
452+
# Verify subprocess.run was called
453+
assert mock_run.called
454+
call_args = mock_run.call_args[0][0] # First positional argument
455+
456+
# Should include -i flag with temporary identity file
457+
assert "-i" in call_args
458+
identity_file_index = call_args.index("-i")
459+
identity_file_path = call_args[identity_file_index + 1]
460+
461+
# The identity file should be a temporary file
462+
assert identity_file_path.endswith("_ssh_key")
463+
assert "/tmp" in identity_file_path or "/var/tmp" in identity_file_path
464+
465+
# Should include -l testuser
466+
assert "-l" in call_args
467+
assert "testuser" in call_args
468+
469+
# Should include the actual hostname (127.0.0.1) at the end
470+
assert "127.0.0.1" in call_args
471+
assert "hostname" in call_args
472+
473+
assert result == 0
474+
475+
476+
def test_ssh_command_with_identity_file():
477+
"""Test SSH command execution with ssh_identity_file"""
478+
import os
479+
import tempfile
480+
481+
# Create a temporary file with SSH key content
482+
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_test_key') as temp_file:
483+
temp_file.write(TEST_SSH_KEY)
484+
temp_file_path = temp_file.name
485+
486+
try:
487+
instance = SSHWrapper(
488+
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
489+
default_username="testuser",
490+
ssh_identity_file=temp_file_path
491+
)
492+
493+
with serve(instance) as client:
494+
with patch('subprocess.run') as mock_run:
495+
mock_run.return_value = MagicMock(returncode=0)
496+
497+
# Test SSH command with identity file
498+
result = client.run(False, ["hostname"])
499+
500+
# Verify subprocess.run was called
501+
assert mock_run.called
502+
call_args = mock_run.call_args[0][0] # First positional argument
503+
504+
# Should include -i flag with temporary identity file
505+
assert "-i" in call_args
506+
identity_file_index = call_args.index("-i")
507+
identity_file_path = call_args[identity_file_index + 1]
508+
509+
# The identity file should be a temporary file (not the original file)
510+
assert identity_file_path.endswith("_ssh_key")
511+
assert "/tmp" in identity_file_path or "/var/tmp" in identity_file_path
512+
assert identity_file_path != temp_file_path
513+
514+
# Should include -l testuser
515+
assert "-l" in call_args
516+
assert "testuser" in call_args
517+
518+
# Should include the actual hostname (127.0.0.1) at the end
519+
assert "127.0.0.1" in call_args
520+
assert "hostname" in call_args
521+
522+
assert result == 0
523+
finally:
524+
# Clean up the temporary file
525+
os.unlink(temp_file_path)
526+
527+
528+
def test_ssh_command_without_identity():
529+
"""Test SSH command execution without identity (should not include -i flag)"""
530+
instance = SSHWrapper(
531+
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
532+
default_username="testuser"
533+
)
534+
535+
with serve(instance) as client:
536+
with patch('subprocess.run') as mock_run:
537+
mock_run.return_value = MagicMock(returncode=0)
538+
539+
# Test SSH command without identity
540+
result = client.run(False, ["hostname"])
541+
542+
# Verify subprocess.run was called
543+
assert mock_run.called
544+
call_args = mock_run.call_args[0][0] # First positional argument
545+
546+
# Should NOT include -i flag
547+
assert "-i" not in call_args
548+
549+
# Should include -l testuser
550+
assert "-l" in call_args
551+
assert "testuser" in call_args
552+
553+
# Should include the actual hostname (127.0.0.1) at the end
554+
assert "127.0.0.1" in call_args
555+
assert "hostname" in call_args
556+
557+
assert result == 0
558+
559+
560+
def test_ssh_identity_temp_file_creation_and_cleanup():
561+
"""Test that temporary identity file is created and cleaned up properly"""
562+
instance = SSHWrapper(
563+
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
564+
default_username="testuser",
565+
ssh_identity=TEST_SSH_KEY
566+
)
567+
568+
with serve(instance) as client:
569+
with patch('subprocess.run') as mock_run:
570+
mock_run.return_value = MagicMock(returncode=0)
571+
572+
with patch('tempfile.NamedTemporaryFile') as mock_temp_file:
573+
with patch('os.chmod') as mock_chmod:
574+
with patch('os.unlink') as mock_unlink:
575+
# Mock the temporary file
576+
mock_temp_file_instance = MagicMock()
577+
mock_temp_file_instance.name = "/tmp/test_ssh_key_12345"
578+
mock_temp_file_instance.write = MagicMock()
579+
mock_temp_file_instance.close = MagicMock()
580+
mock_temp_file.return_value = mock_temp_file_instance
581+
582+
# Test SSH command with identity
583+
result = client.run(False, ["hostname"])
584+
585+
# Verify temporary file was created
586+
mock_temp_file.assert_called_once_with(mode='w', delete=False, suffix='_ssh_key')
587+
mock_temp_file_instance.write.assert_called_once_with(TEST_SSH_KEY)
588+
mock_temp_file_instance.close.assert_called_once()
589+
590+
# Verify proper permissions were set
591+
mock_chmod.assert_called_once_with("/tmp/test_ssh_key_12345", 0o600)
592+
593+
# Verify temporary file was cleaned up
594+
mock_unlink.assert_called_once_with("/tmp/test_ssh_key_12345")
595+
596+
assert result == 0
597+
598+
599+
def test_ssh_identity_temp_file_creation_error():
600+
"""Test error handling when temporary identity file creation fails"""
601+
instance = SSHWrapper(
602+
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
603+
default_username="testuser",
604+
ssh_identity=TEST_SSH_KEY
605+
)
606+
607+
with serve(instance) as client:
608+
with patch('subprocess.run') as mock_run:
609+
mock_run.return_value = MagicMock(returncode=0)
610+
611+
with patch('tempfile.NamedTemporaryFile') as mock_temp_file:
612+
mock_temp_file.side_effect = OSError("Permission denied")
613+
614+
# Test SSH command with identity should raise an error
615+
# The exception will be wrapped in an ExceptionGroup due to the context manager
616+
with pytest.raises(ExceptionGroup) as exc_info:
617+
client.run(False, ["hostname"])
618+
619+
# Check that the original OSError is in the exception group
620+
assert any(isinstance(e, OSError) and "Permission denied" in str(e) for e in exc_info.value.exceptions)
621+
622+
623+
def test_ssh_identity_temp_file_cleanup_error():
624+
"""Test error handling when temporary identity file cleanup fails"""
625+
instance = SSHWrapper(
626+
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
627+
default_username="testuser",
628+
ssh_identity=TEST_SSH_KEY
629+
)
630+
631+
with serve(instance) as client:
632+
with patch('subprocess.run') as mock_run:
633+
mock_run.return_value = MagicMock(returncode=0)
634+
635+
with patch('tempfile.NamedTemporaryFile') as mock_temp_file:
636+
with patch('os.chmod') as mock_chmod:
637+
with patch('os.unlink') as mock_unlink:
638+
# Mock the temporary file
639+
mock_temp_file_instance = MagicMock()
640+
mock_temp_file_instance.name = "/tmp/test_ssh_key_12345"
641+
mock_temp_file_instance.write = MagicMock()
642+
mock_temp_file_instance.close = MagicMock()
643+
mock_temp_file.return_value = mock_temp_file_instance
644+
645+
# Mock cleanup failure
646+
mock_unlink.side_effect = OSError("Permission denied")
647+
648+
# Test SSH command with identity - should still succeed but log warning
649+
with patch.object(client, 'logger') as mock_logger:
650+
result = client.run(False, ["hostname"])
651+
652+
# Verify chmod was called
653+
mock_chmod.assert_called_once_with("/tmp/test_ssh_key_12345", 0o600)
654+
655+
# Verify warning was logged
656+
mock_logger.warning.assert_called_once()
657+
warning_call = mock_logger.warning.call_args[0][0]
658+
assert "Failed to clean up temporary identity file" in warning_call
659+
assert "/tmp/test_ssh_key_12345" in warning_call
660+
661+
assert result == 0

0 commit comments

Comments
 (0)