11import click
22import pytest
33
4- from .client import BaseFlasherClient
4+ from .client import BaseFlasherClient , FlashNonRetryableError , FlashRetryableError
55from jumpstarter .common .exceptions import ArgumentError
66
77
@@ -12,7 +12,14 @@ def __init__(self):
1212 self ._manifest = None
1313 self ._console_debug = False
1414 self .logger = type (
15- "MockLogger" , (), {"warning" : lambda msg : None , "info" : lambda msg : None , "error" : lambda msg : None }
15+ "MockLogger" ,
16+ (),
17+ {
18+ "warning" : lambda * args , ** kwargs : None ,
19+ "info" : lambda * args , ** kwargs : None ,
20+ "error" : lambda * args , ** kwargs : None ,
21+ "exception" : lambda * args , ** kwargs : None ,
22+ },
1623 )()
1724
1825 def close (self ):
@@ -49,3 +56,146 @@ def test_flash_fails_with_invalid_headers():
4956
5057 with pytest .raises (ArgumentError , match = "Invalid header name 'Invalid Header': must be an HTTP token" ):
5158 client .flash ("test.raw" , headers = {"Invalid Header" : "value" })
59+
60+
61+ def test_categorize_exception_returns_non_retryable_when_present ():
62+ """Test that non-retryable errors take priority"""
63+ client = MockFlasherClient ()
64+
65+ # Direct non-retryable error
66+ error = FlashNonRetryableError ("Config error" )
67+ result = client ._categorize_exception (error )
68+ assert isinstance (result , FlashNonRetryableError )
69+ assert str (result ) == "Config error"
70+
71+
72+ def test_categorize_exception_returns_retryable_when_present ():
73+ """Test that retryable errors are returned"""
74+ client = MockFlasherClient ()
75+
76+ # Direct retryable error
77+ error = FlashRetryableError ("Network timeout" )
78+ result = client ._categorize_exception (error )
79+ assert isinstance (result , FlashRetryableError )
80+ assert str (result ) == "Network timeout"
81+
82+
83+ def test_categorize_exception_wraps_unknown_exceptions ():
84+ """Test that unknown exceptions are wrapped as retryable"""
85+ client = MockFlasherClient ()
86+
87+ # Unknown exception type
88+ error = ValueError ("Something went wrong" )
89+ result = client ._categorize_exception (error )
90+ assert isinstance (result , FlashRetryableError )
91+ assert "ValueError" in str (result )
92+ assert "Something went wrong" in str (result )
93+ # Verify the cause chain is preserved
94+ assert result .__cause__ is error
95+
96+
97+ def test_categorize_exception_non_retryable_takes_priority_over_retryable ():
98+ """Test that non-retryable errors take priority in cause chain"""
99+ client = MockFlasherClient ()
100+
101+ # Create a chain: retryable caused by non-retryable
102+ non_retryable = FlashNonRetryableError ("Config issue" )
103+ retryable = FlashRetryableError ("Network error" )
104+ retryable .__cause__ = non_retryable
105+
106+ result = client ._categorize_exception (retryable )
107+ assert isinstance (result , FlashNonRetryableError )
108+ assert str (result ) == "Config issue"
109+
110+
111+ def test_categorize_exception_searches_cause_chain ():
112+ """Test that categorization searches through the cause chain"""
113+ client = MockFlasherClient ()
114+
115+ # Create a chain: generic -> generic -> retryable
116+ root = FlashRetryableError ("Root cause" )
117+ middle = ValueError ("Middle error" )
118+ middle .__cause__ = root
119+ top = RuntimeError ("Top error" )
120+ top .__cause__ = middle
121+
122+ result = client ._categorize_exception (top )
123+ assert isinstance (result , FlashRetryableError )
124+ assert str (result ) == "Root cause"
125+
126+
127+ def test_find_exception_in_chain_finds_target_type ():
128+ """Test that _find_exception_in_chain correctly finds the target type"""
129+ client = MockFlasherClient ()
130+
131+ # Create a chain with retryable error
132+ retryable = FlashRetryableError ("Network error" )
133+ generic = RuntimeError ("Generic error" )
134+ generic .__cause__ = retryable
135+
136+ result = client ._find_exception_in_chain (generic , FlashRetryableError )
137+ assert result is retryable
138+ assert str (result ) == "Network error"
139+
140+
141+ def test_find_exception_in_chain_returns_none_when_not_found ():
142+ """Test that _find_exception_in_chain returns None when target not found"""
143+ client = MockFlasherClient ()
144+
145+ error = ValueError ("Some error" )
146+ result = client ._find_exception_in_chain (error , FlashRetryableError )
147+ assert result is None
148+
149+
150+ def test_find_exception_in_chain_handles_exception_groups ():
151+ """Test that _find_exception_in_chain searches through ExceptionGroups"""
152+ client = MockFlasherClient ()
153+
154+ # Create an ExceptionGroup with a retryable error
155+ retryable = FlashRetryableError ("Network timeout" )
156+ generic = ValueError ("Generic error" )
157+
158+ # Mock an ExceptionGroup (Python 3.11+)
159+ class MockExceptionGroup (Exception ):
160+ def __init__ (self , message , exceptions ):
161+ super ().__init__ (message )
162+ self .exceptions = exceptions
163+
164+ group = MockExceptionGroup ("Multiple errors" , [generic , retryable ])
165+
166+ result = client ._find_exception_in_chain (group , FlashRetryableError )
167+ assert result is retryable
168+
169+
170+ def test_categorize_exception_with_nested_exception_groups ():
171+ """Test categorization with nested ExceptionGroups"""
172+ client = MockFlasherClient ()
173+
174+ # Create nested ExceptionGroups
175+ non_retryable = FlashNonRetryableError ("Config error" )
176+
177+ class MockExceptionGroup (Exception ):
178+ def __init__ (self , message , exceptions ):
179+ super ().__init__ (message )
180+ self .exceptions = exceptions
181+
182+ inner_group = MockExceptionGroup ("Inner errors" , [non_retryable ])
183+ outer_group = MockExceptionGroup ("Outer errors" , [ValueError ("Other" ), inner_group ])
184+
185+ result = client ._categorize_exception (outer_group )
186+ assert isinstance (result , FlashNonRetryableError )
187+ assert str (result ) == "Config error"
188+
189+
190+ def test_categorize_exception_preserves_cause_for_wrapped_exceptions ():
191+ """Test that wrapped unknown exceptions preserve the cause chain"""
192+ client = MockFlasherClient ()
193+
194+ original = IOError ("File not found" )
195+ result = client ._categorize_exception (original )
196+
197+ assert isinstance (result , FlashRetryableError )
198+ assert result .__cause__ is original
199+ # IOError is an alias for OSError in Python 3
200+ assert "OSError" in str (result ) or "IOError" in str (result )
201+ assert "File not found" in str (result )
0 commit comments