Skip to content

Commit efb32ba

Browse files
fix: prevent data loss on multipart complete and part retry
- Serialize metadata save before state deletion in complete_multipart, so that if the metadata save fails the upload state (and its DEK) is preserved and the completion can be retried.
- Subtract the old part's size before adding the new one on part retry, preventing total_plaintext_size from double-counting retried parts.
1 parent a2d061f commit efb32ba

2 files changed

Lines changed: 16 additions & 12 deletions

File tree

s3proxy/handlers/multipart/lifecycle.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -160,21 +160,22 @@ async def handle_complete_multipart_upload(
160160
e, client, bucket, key, upload_id, s3_parts, completed_parts, total_plaintext
161161
)
162162

163-
# Save metadata and cleanup
163+
# Save metadata first, then delete state.
164+
# Order matters: if metadata save fails, state is preserved
165+
# so the upload can be retried. Deleting state first would
166+
# lose the DEK, making the object permanently undecryptable.
164167
wrapped_dek = crypto.wrap_key(state.dek, self.settings.kek)
165-
await asyncio.gather(
166-
save_multipart_metadata(
167-
client, bucket, key,
168-
MultipartMetadata(
169-
version=1,
170-
part_count=len(completed_parts),
171-
total_plaintext_size=total_plaintext,
172-
parts=completed_parts,
173-
wrapped_dek=wrapped_dek,
174-
),
168+
await save_multipart_metadata(
169+
client, bucket, key,
170+
MultipartMetadata(
171+
version=1,
172+
part_count=len(completed_parts),
173+
total_plaintext_size=total_plaintext,
174+
parts=completed_parts,
175+
wrapped_dek=wrapped_dek,
175176
),
176-
delete_upload_state(client, bucket, key, upload_id),
177177
)
178+
await delete_upload_state(client, bucket, key, upload_id)
178179

179180
logger.info(
180181
"COMPLETE_MULTIPART_SUCCESS",

s3proxy/state/manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ def updater(data: bytes) -> bytes:
152152
if state is None:
153153
raise StateMissingError(f"Upload state corrupted for {bucket}/{key}")
154154

155+
old_part = state.parts.get(part.part_number)
156+
if old_part is not None:
157+
state.total_plaintext_size -= old_part.plaintext_size
155158
state.parts[part.part_number] = part
156159
state.total_plaintext_size += part.plaintext_size
157160
if max_internal >= state.next_internal_part_number:

0 commit comments

Comments (0)