Skip to content

Commit 1785d51

Browse files
committed
fix volume mount decode, where only base volume fields were decoded
1 parent a18ec94 commit 1785d51

2 files changed

Lines changed: 103 additions & 2 deletions

File tree

tests/unit_tests/job_deployments/test_job_deployments.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,15 @@
33
import pytest
44
import responses # https://github.com/getsentry/responses
55

6+
from dataclasses import replace
7+
68
from verda.containers import ComputeResource, Container, ContainerRegistrySettings
9+
from verda.containers._containers import (
10+
GeneralStorageMount,
11+
MemoryMount,
12+
SecretMount,
13+
SharedFileSystemMount,
14+
)
715
from verda.exceptions import APIException
816
from verda.job_deployments import (
917
JobDeployment,
@@ -207,3 +215,71 @@ def test_purge_job_deployment_queue(self, service, endpoint):
207215
service.purge_queue(JOB_NAME)
208216

209217
assert responses.assert_call_count(url, 1) is True
218+
219+
@responses.activate
220+
def test_update_preserves_volume_mounts_round_trip(self, service, endpoint):
221+
"""Regression test: volume mount subclass fields (volume_id, secret_name, etc.)
222+
must survive a get → update round trip without being dropped during deserialization."""
223+
volume_id = '550e8400-e29b-41d4-a716-446655440000'
224+
api_payload = {
225+
'name': JOB_NAME,
226+
'containers': [
227+
{
228+
'name': CONTAINER_NAME,
229+
'image': 'busybox:latest',
230+
'exposed_port': 8080,
231+
'env': [],
232+
'volume_mounts': [
233+
{'type': 'scratch', 'mount_path': '/data'},
234+
{'type': 'shared', 'mount_path': '/sfs', 'volume_id': volume_id},
235+
{
236+
'type': 'secret',
237+
'mount_path': '/secrets',
238+
'secret_name': 'my-secret',
239+
'file_names': ['key.pem'],
240+
},
241+
{'type': 'memory', 'mount_path': '/dev/shm', 'size_in_mb': 512},
242+
],
243+
}
244+
],
245+
'endpoint_base_url': 'https://test-job.datacrunch.io',
246+
'created_at': '2024-01-01T00:00:00Z',
247+
'compute': {'name': 'H100', 'size': 1},
248+
'container_registry_settings': {'is_private': False, 'credentials': None},
249+
}
250+
251+
get_url = f'{endpoint}/{JOB_NAME}'
252+
responses.add(responses.GET, get_url, json=api_payload, status=200)
253+
responses.add(responses.PATCH, get_url, json=api_payload, status=200)
254+
255+
# Simulate the user's flow: get → modify image → update
256+
deployment = service.get_by_name(JOB_NAME)
257+
258+
# Verify deserialization produced the correct subclasses
259+
vms = deployment.containers[0].volume_mounts
260+
assert isinstance(vms[0], GeneralStorageMount)
261+
assert isinstance(vms[1], SharedFileSystemMount)
262+
assert vms[1].volume_id == volume_id
263+
assert isinstance(vms[2], SecretMount)
264+
assert vms[2].secret_name == 'my-secret'
265+
assert vms[2].file_names == ['key.pem']
266+
assert isinstance(vms[3], MemoryMount)
267+
assert vms[3].size_in_mb == 512
268+
269+
# Update only the image (exactly what the reported user script does)
270+
containers = list(deployment.containers)
271+
containers[0] = replace(containers[0], image='busybox:v2')
272+
updated_deployment = replace(deployment, containers=containers)
273+
274+
service.update(JOB_NAME, updated_deployment)
275+
276+
# Verify the PATCH request body still contains volume_id
277+
request_body = json.loads(responses.calls[1].request.body.decode('utf-8'))
278+
sent_vms = request_body['containers'][0]['volume_mounts']
279+
assert sent_vms[0]['type'] == 'scratch'
280+
assert sent_vms[1]['type'] == 'shared'
281+
assert sent_vms[1]['volume_id'] == volume_id
282+
assert sent_vms[2]['type'] == 'secret'
283+
assert sent_vms[2]['secret_name'] == 'my-secret'
284+
assert sent_vms[3]['type'] == 'memory'
285+
assert sent_vms[3]['size_in_mb'] == 512

verda/containers/_containers.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from enum import Enum
1111
from typing import Any
1212

13-
from dataclasses_json import Undefined, dataclass_json # type: ignore
13+
from dataclasses_json import Undefined, config, dataclass_json # type: ignore
1414

1515
from verda.http_client import HTTPClient
1616
from verda.inference_client import InferenceClient, InferenceResponse
@@ -203,6 +203,29 @@ def __init__(self, mount_path: str, volume_id: str):
203203
self.volume_id = volume_id
204204

205205

206+
def _decode_volume_mount(data: dict) -> VolumeMount:
207+
"""Decode a volume mount dict into the correct VolumeMount subclass based on type."""
208+
mount_type = data.get('type')
209+
if mount_type == VolumeMountType.SHARED or mount_type == 'shared':
210+
return SharedFileSystemMount(mount_path=data['mount_path'], volume_id=data['volume_id'])
211+
if mount_type == VolumeMountType.SECRET or mount_type == 'secret':
212+
return SecretMount(
213+
mount_path=data['mount_path'],
214+
secret_name=data['secret_name'],
215+
file_names=data.get('file_names'),
216+
)
217+
if mount_type == VolumeMountType.MEMORY or mount_type == 'memory':
218+
return MemoryMount(size_in_mb=data['size_in_mb'])
219+
return GeneralStorageMount(mount_path=data['mount_path'])
220+
221+
222+
def _decode_volume_mounts(data: list[dict] | None) -> list[VolumeMount] | None:
223+
"""Decode a list of volume mount dicts into the correct VolumeMount subclasses."""
224+
if not data:
225+
return None
226+
return [_decode_volume_mount(v) for v in data]
227+
228+
206229
@dataclass_json
207230
@dataclass
208231
class Container:
@@ -224,7 +247,9 @@ class Container:
224247
healthcheck: HealthcheckSettings | None = None
225248
entrypoint_overrides: EntrypointOverridesSettings | None = None
226249
env: list[EnvVar] | None = None
227-
volume_mounts: list[VolumeMount] | None = None
250+
volume_mounts: list[VolumeMount] | None = field(
251+
default=None, metadata=config(decoder=_decode_volume_mounts)
252+
)
228253

229254

230255
@dataclass_json

0 commit comments

Comments
 (0)