Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.

Commit f416d43

Browse files
mangelajo, github-actions[bot]
authored and committed
flasher: make unknown exceptions retriable
(cherry picked from commit 6bced4a)
1 parent 6280cb1 commit f416d43

2 files changed

Lines changed: 203 additions & 51 deletions

File tree

packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py

Lines changed: 51 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -177,14 +177,18 @@ def flash( # noqa: C901
177177
self.logger.info(f"Flash operation succeeded on attempt {attempt + 1}")
178178
break
179179
except Exception as e:
180-
# Check if this is a retryable or non-retryable error (including sub-exceptions)
181-
retryable_error = self._get_retryable_error(e)
182-
non_retryable_error = self._get_non_retryable_error(e)
180+
# Categorize the exception as retryable or non-retryable
181+
categorized_error = self._categorize_exception(e)
183182

184-
if retryable_error is not None:
183+
if isinstance(categorized_error, FlashNonRetryableError):
184+
# Non-retryable error, fail immediately
185+
self.logger.error(f"Flash operation failed with non-retryable error: {categorized_error}")
186+
raise FlashError(f"Flash operation failed: {categorized_error}") from e
187+
else:
188+
# Retryable error
185189
if attempt < retries:
186190
self.logger.warning(
187-
f"Flash attempt {attempt + 1} failed with retryable error: {retryable_error}"
191+
f"Flash attempt {attempt + 1} failed with retryable error: {categorized_error}"
188192
)
189193
self.logger.info(f"Retrying flash operation (attempt {attempt + 2}/{retries + 1})")
190194
# Wait a bit before retrying
@@ -193,86 +197,84 @@ def flash( # noqa: C901
193197
else:
194198
self.logger.error(f"Flash operation failed after {retries + 1} attempts")
195199
raise FlashError(
196-
f"Flash operation failed after {retries + 1} attempts. Last error: {retryable_error}"
200+
f"Flash operation failed after {retries + 1} attempts. Last error: {categorized_error}"
197201
) from e
198-
elif non_retryable_error is not None:
199-
# Non-retryable error, fail immediately
200-
self.logger.error(f"Flash operation failed with non-retryable error: {non_retryable_error}")
201-
raise FlashError(f"Flash operation failed: {non_retryable_error}") from e
202-
else:
203-
# Unexpected error, don't retry
204-
self.logger.error(f"Flash operation failed with unexpected error: {e}")
205-
raise FlashError(f"Flash operation failed: {e}") from e
206202

207203

208204
total_time = time.time() - start_time
209205
# total time in minutes:seconds
210206
minutes, seconds = divmod(total_time, 60)
211207
self.logger.info(f"Flashing completed in {int(minutes)}m {int(seconds):02d}s")
212208

213-
def _get_retryable_error(self, exception: Exception) -> FlashRetryableError | None:
214-
"""Find a retryable error in an exception (or any of its causes).
209+
def _categorize_exception(self, exception: Exception) -> FlashRetryableError | FlashNonRetryableError:
210+
"""Categorize an exception as retryable or non-retryable.
211+
212+
This method searches through the exception chain (including ExceptionGroups)
213+
to find FlashRetryableError or FlashNonRetryableError instances.
214+
215+
Priority:
216+
1. FlashNonRetryableError - highest priority, fail immediately
217+
2. FlashRetryableError - retry with backoff
218+
3. Unknown exceptions - log full stack trace and treat as retryable
215219
216220
Args:
217-
exception: The exception to check
221+
exception: The exception to categorize
218222
219223
Returns:
220-
The FlashRetryableError if found, None otherwise
224+
FlashRetryableError or FlashNonRetryableError
221225
"""
222-
# Check if this is an ExceptionGroup and look through its exceptions
223-
if hasattr(exception, 'exceptions'):
224-
for sub_exc in exception.exceptions:
225-
result = self._get_retryable_error(sub_exc)
226-
if result is not None:
227-
return result
228-
229-
# Check the current exception
230-
if isinstance(exception, FlashRetryableError):
231-
return exception
226+
# First pass: look for non-retryable errors (highest priority)
227+
non_retryable = self._find_exception_in_chain(exception, FlashNonRetryableError)
228+
if non_retryable is not None:
229+
return non_retryable
230+
231+
# Second pass: look for retryable errors
232+
retryable = self._find_exception_in_chain(exception, FlashRetryableError)
233+
if retryable is not None:
234+
return retryable
235+
236+
# Unknown exception - log full stack trace and wrap as retryable
237+
self.logger.exception(
238+
f"Unknown exception encountered during flash operation, treating as retryable: "
239+
f"{type(exception).__name__}: {exception}"
240+
)
241+
wrapped_exception = FlashRetryableError(f"Unknown error occurred: {type(exception).__name__}: {exception}")
242+
wrapped_exception.__cause__ = exception
243+
return wrapped_exception
232244

233-
# Check the cause chain
234-
current = getattr(exception, '__cause__', None)
235-
while current is not None:
236-
if isinstance(current, FlashRetryableError):
237-
return current
238-
# Also check if the cause is an ExceptionGroup
239-
if hasattr(current, 'exceptions'):
240-
for sub_exc in current.exceptions:
241-
result = self._get_retryable_error(sub_exc)
242-
if result is not None:
243-
return result
244-
current = getattr(current, '__cause__', None)
245-
return None
245+
def _find_exception_in_chain(self, exception: Exception, target_type: type) -> Exception | None:
246+
"""Find an exception of a specific type in an exception chain.
246247
247-
def _get_non_retryable_error(self, exception: Exception) -> FlashNonRetryableError | None:
248-
"""Find a non-retryable error in an exception (or any of its causes).
248+
Searches through the exception, its ExceptionGroup members (if any),
249+
and the cause chain recursively.
249250
250251
Args:
251-
exception: The exception to check
252+
exception: The exception to search
253+
target_type: The exception type to search for
252254
253255
Returns:
254-
The FlashNonRetryableError if found, None otherwise
256+
The found exception instance if found, None otherwise
255257
"""
256258
# Check if this is an ExceptionGroup and look through its exceptions
257259
if hasattr(exception, 'exceptions'):
258260
for sub_exc in exception.exceptions:
259-
result = self._get_non_retryable_error(sub_exc)
261+
result = self._find_exception_in_chain(sub_exc, target_type)
260262
if result is not None:
261263
return result
262264

263265
# Check the current exception
264-
if isinstance(exception, FlashNonRetryableError):
266+
if isinstance(exception, target_type):
265267
return exception
266268

267269
# Check the cause chain
268270
current = getattr(exception, '__cause__', None)
269271
while current is not None:
270-
if isinstance(current, FlashNonRetryableError):
272+
if isinstance(current, target_type):
271273
return current
272274
# Also check if the cause is an ExceptionGroup
273275
if hasattr(current, 'exceptions'):
274276
for sub_exc in current.exceptions:
275-
result = self._get_non_retryable_error(sub_exc)
277+
result = self._find_exception_in_chain(sub_exc, target_type)
276278
if result is not None:
277279
return result
278280
current = getattr(current, '__cause__', None)

packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client_test.py

Lines changed: 152 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import click
22
import pytest
33

4-
from .client import BaseFlasherClient
4+
from .client import BaseFlasherClient, FlashNonRetryableError, FlashRetryableError
55
from jumpstarter.common.exceptions import ArgumentError
66

77

@@ -12,7 +12,14 @@ def __init__(self):
1212
self._manifest = None
1313
self._console_debug = False
1414
self.logger = type(
15-
"MockLogger", (), {"warning": lambda msg: None, "info": lambda msg: None, "error": lambda msg: None}
15+
"MockLogger",
16+
(),
17+
{
18+
"warning": lambda *args, **kwargs: None,
19+
"info": lambda *args, **kwargs: None,
20+
"error": lambda *args, **kwargs: None,
21+
"exception": lambda *args, **kwargs: None,
22+
},
1623
)()
1724

1825
def close(self):
@@ -49,3 +56,146 @@ def test_flash_fails_with_invalid_headers():
4956

5057
with pytest.raises(ArgumentError, match="Invalid header name 'Invalid Header': must be an HTTP token"):
5158
client.flash("test.raw", headers={"Invalid Header": "value"})
59+
60+
61+
def test_categorize_exception_returns_non_retryable_when_present():
62+
"""Test that non-retryable errors take priority"""
63+
client = MockFlasherClient()
64+
65+
# Direct non-retryable error
66+
error = FlashNonRetryableError("Config error")
67+
result = client._categorize_exception(error)
68+
assert isinstance(result, FlashNonRetryableError)
69+
assert str(result) == "Config error"
70+
71+
72+
def test_categorize_exception_returns_retryable_when_present():
73+
"""Test that retryable errors are returned"""
74+
client = MockFlasherClient()
75+
76+
# Direct retryable error
77+
error = FlashRetryableError("Network timeout")
78+
result = client._categorize_exception(error)
79+
assert isinstance(result, FlashRetryableError)
80+
assert str(result) == "Network timeout"
81+
82+
83+
def test_categorize_exception_wraps_unknown_exceptions():
84+
"""Test that unknown exceptions are wrapped as retryable"""
85+
client = MockFlasherClient()
86+
87+
# Unknown exception type
88+
error = ValueError("Something went wrong")
89+
result = client._categorize_exception(error)
90+
assert isinstance(result, FlashRetryableError)
91+
assert "ValueError" in str(result)
92+
assert "Something went wrong" in str(result)
93+
# Verify the cause chain is preserved
94+
assert result.__cause__ is error
95+
96+
97+
def test_categorize_exception_non_retryable_takes_priority_over_retryable():
98+
"""Test that non-retryable errors take priority in cause chain"""
99+
client = MockFlasherClient()
100+
101+
# Create a chain: retryable caused by non-retryable
102+
non_retryable = FlashNonRetryableError("Config issue")
103+
retryable = FlashRetryableError("Network error")
104+
retryable.__cause__ = non_retryable
105+
106+
result = client._categorize_exception(retryable)
107+
assert isinstance(result, FlashNonRetryableError)
108+
assert str(result) == "Config issue"
109+
110+
111+
def test_categorize_exception_searches_cause_chain():
112+
"""Test that categorization searches through the cause chain"""
113+
client = MockFlasherClient()
114+
115+
# Create a chain: generic -> generic -> retryable
116+
root = FlashRetryableError("Root cause")
117+
middle = ValueError("Middle error")
118+
middle.__cause__ = root
119+
top = RuntimeError("Top error")
120+
top.__cause__ = middle
121+
122+
result = client._categorize_exception(top)
123+
assert isinstance(result, FlashRetryableError)
124+
assert str(result) == "Root cause"
125+
126+
127+
def test_find_exception_in_chain_finds_target_type():
128+
"""Test that _find_exception_in_chain correctly finds the target type"""
129+
client = MockFlasherClient()
130+
131+
# Create a chain with retryable error
132+
retryable = FlashRetryableError("Network error")
133+
generic = RuntimeError("Generic error")
134+
generic.__cause__ = retryable
135+
136+
result = client._find_exception_in_chain(generic, FlashRetryableError)
137+
assert result is retryable
138+
assert str(result) == "Network error"
139+
140+
141+
def test_find_exception_in_chain_returns_none_when_not_found():
142+
"""Test that _find_exception_in_chain returns None when target not found"""
143+
client = MockFlasherClient()
144+
145+
error = ValueError("Some error")
146+
result = client._find_exception_in_chain(error, FlashRetryableError)
147+
assert result is None
148+
149+
150+
def test_find_exception_in_chain_handles_exception_groups():
151+
"""Test that _find_exception_in_chain searches through ExceptionGroups"""
152+
client = MockFlasherClient()
153+
154+
# Create an ExceptionGroup with a retryable error
155+
retryable = FlashRetryableError("Network timeout")
156+
generic = ValueError("Generic error")
157+
158+
# Mock an ExceptionGroup (Python 3.11+)
159+
class MockExceptionGroup(Exception):
160+
def __init__(self, message, exceptions):
161+
super().__init__(message)
162+
self.exceptions = exceptions
163+
164+
group = MockExceptionGroup("Multiple errors", [generic, retryable])
165+
166+
result = client._find_exception_in_chain(group, FlashRetryableError)
167+
assert result is retryable
168+
169+
170+
def test_categorize_exception_with_nested_exception_groups():
171+
"""Test categorization with nested ExceptionGroups"""
172+
client = MockFlasherClient()
173+
174+
# Create nested ExceptionGroups
175+
non_retryable = FlashNonRetryableError("Config error")
176+
177+
class MockExceptionGroup(Exception):
178+
def __init__(self, message, exceptions):
179+
super().__init__(message)
180+
self.exceptions = exceptions
181+
182+
inner_group = MockExceptionGroup("Inner errors", [non_retryable])
183+
outer_group = MockExceptionGroup("Outer errors", [ValueError("Other"), inner_group])
184+
185+
result = client._categorize_exception(outer_group)
186+
assert isinstance(result, FlashNonRetryableError)
187+
assert str(result) == "Config error"
188+
189+
190+
def test_categorize_exception_preserves_cause_for_wrapped_exceptions():
191+
"""Test that wrapped unknown exceptions preserve the cause chain"""
192+
client = MockFlasherClient()
193+
194+
original = IOError("File not found")
195+
result = client._categorize_exception(original)
196+
197+
assert isinstance(result, FlashRetryableError)
198+
assert result.__cause__ is original
199+
# IOError is an alias for OSError in Python 3
200+
assert "OSError" in str(result) or "IOError" in str(result)
201+
assert "File not found" in str(result)

0 commit comments

Comments
 (0)