Skip to content

Commit 90bd404

Browse files
authored
feat: add support for video deserialization with torchcodec when torchvision>0.25 (#802)
1 parent 36431bd commit 90bd404

5 files changed

Lines changed: 52 additions & 1 deletion

File tree

.github/workflows/ci-testing.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ jobs:
3434
UV_TORCH_BACKEND: "cpu"
3535

3636
steps:
37+
# FFmpeg is required for Video Deserializer tests
38+
- name: Setup FFmpeg (shared libs, Linux)
39+
if: runner.os == 'Linux'
40+
run: sudo apt-get update && sudo apt-get install -y ffmpeg
41+
42+
- name: Setup FFmpeg (shared libs, macOS)
43+
if: runner.os == 'macOS'
44+
run: brew install ffmpeg
45+
3746
- uses: actions/checkout@v6
3847
- name: Install uv and setup python ${{ matrix.python-version }}
3948
uses: astral-sh/setup-uv@v7

requirements/test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ lightning
1616
transformers >=4.51.0
1717
zstd; python_version < "3.14"
1818
soundfile >=0.13.0
19+
torchcodec >=0.1.0 # check compatibility table: https://github.com/meta-pytorch/torchcodec?tab=readme-ov-file#installing-torchcodec

src/litdata/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
_POLARS_AVAILABLE = RequirementCache("polars>1.0.0")
5050
_PIL_AVAILABLE = RequirementCache("PIL")
5151
_TORCH_VISION_AVAILABLE = RequirementCache("torchvision")
52+
_TORCH_VISION_LESS_THAN_0_26 = RequirementCache("torchvision<0.26.0")
5253
_AV_AVAILABLE = RequirementCache("av")
5354
_OBSTORE_AVAILABLE = RequirementCache("obstore")
5455

src/litdata/streaming/serializers.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from collections import OrderedDict
2121
from contextlib import suppress
2222
from copy import deepcopy
23+
from dataclasses import asdict
2324
from itertools import chain
2425
from typing import Any
2526

@@ -32,6 +33,7 @@
3233
_NUMPY_DTYPES_MAPPING,
3334
_PIL_AVAILABLE,
3435
_TORCH_DTYPES_MAPPING,
36+
_TORCH_VISION_LESS_THAN_0_26,
3537
)
3638

3739

@@ -403,6 +405,13 @@ def serialize(self, filepath: str) -> tuple[bytes, str | None]:
403405
return f.read(), f"video:{file_extension}"
404406

405407
def deserialize(self, data: bytes) -> Any:
    """Decode serialized video bytes with the backend matching the installed torchvision.

    When ``_TORCH_VISION_LESS_THAN_0_26`` is truthy (torchvision <=0.25), decoding is
    delegated to the ``torchvision.io`` based helper; otherwise the torchcodec based
    helper is used.

    Args:
        data: Raw bytes of an encoded video file.

    Returns:
        Whatever the selected backend helper returns.
    """
    # Select the bound helper first, then call it once with the payload.
    decode = (
        self._deserialize_with_torchvision_io
        if _TORCH_VISION_LESS_THAN_0_26
        else self._deserialize_with_torchcodec
    )
    return decode(data)
413+
414+
def _deserialize_with_torchvision_io(self, data: bytes) -> Any:
406415
if not _AV_AVAILABLE:
407416
raise ModuleNotFoundError("av is required. Run `pip install av`")
408417

@@ -416,6 +425,29 @@ def deserialize(self, data: bytes) -> Any:
416425
stream.write(data)
417426
return torchvision.io.read_video(fname, pts_unit="sec")
418427

428+
def _deserialize_with_torchcodec(self, data: bytes) -> Any:
429+
try:
430+
import torch
431+
from torchcodec.decoders import AudioDecoder, VideoDecoder
432+
except ImportError:
433+
raise ModuleNotFoundError("torchcodec is required. Run `pip install torchcodec>0.11`")
434+
435+
dec = VideoDecoder(data, dimension_order="NHWC") # NHWC → T,H,W,C after stacking
436+
metadata = asdict(dec.metadata) if dec.metadata is not None else {}
437+
438+
# get_all_frames() returns a FrameBatch; .data is (N, C, H, W) or (N, H, W, C)
439+
# depending on dimension_order above
440+
frame_batch = dec.get_all_frames()
441+
video = frame_batch.data # shape: (T, H, W, C) with NHWC
442+
443+
try:
444+
audio_dec = AudioDecoder(data)
445+
audio = audio_dec.get_all_samples().data # (num_channels, num_samples)
446+
except ValueError:
447+
audio = torch.zeros(1, 0) # old torchvision path returns aframes with shape (1, 0) for no-audio videos.
448+
449+
return video, audio, metadata
450+
419451
def can_serialize(self, data: Any) -> bool:
420452
return isinstance(data, str) and os.path.isfile(data) and any(data.endswith(ext) for ext in self._EXTENSIONS)
421453

tests/streaming/test_serializer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ def test_assert_no_header_numpy_serializer():
277277
np.testing.assert_equal(t, new_t)
278278

279279

280+
@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows")
280281
@pytest.mark.skipif(condition=not _AV_AVAILABLE, reason="Requires: 'av'")
281282
def test_wav_deserialization(tmpdir):
282283
from torch.hub import download_url_to_file
@@ -293,7 +294,14 @@ def test_wav_deserialization(tmpdir):
293294
vframes, aframes, info = serializer.deserialize(data)
294295
assert vframes.shape == torch.Size([301, 512, 512, 3])
295296
assert aframes.shape == torch.Size([1, 0])
296-
assert info == {"video_fps": 25.0}
297+
# The metadata keys returned by video deserialization depend on the decoding backend.
298+
# For example, `torchvision` typically uses `video_fps`, while `torchcodec` uses `average_fps`.
299+
# Despite these naming differences, both keys represent the same fps value,
300+
# ensuring consistency in video frame rate representation across serialization methods.
301+
assert "video_fps" in info or "average_fps" in info
302+
fps = info.get("video_fps", info.get("average_fps"))
303+
assert fps is not None
304+
assert fps == 25.0
297305

298306

299307
def test_get_serializers():

0 commit comments

Comments
 (0)