2020from collections import OrderedDict
2121from contextlib import suppress
2222from copy import deepcopy
23+ from dataclasses import asdict
2324from itertools import chain
2425from typing import Any
2526
3233 _NUMPY_DTYPES_MAPPING ,
3334 _PIL_AVAILABLE ,
3435 _TORCH_DTYPES_MAPPING ,
36+ _TORCH_VISION_LESS_THAN_0_26 ,
3537)
3638
3739
@@ -403,6 +405,13 @@ def serialize(self, filepath: str) -> tuple[bytes, str | None]:
403405 return f .read (), f"video:{ file_extension } "
404406
405407 def deserialize (self , data : bytes ) -> Any :
408+ # if using torchvision <=0.25, we will use torchvision.io to decode the video
409+ # otherwise, we will use torchcodec to decode the video, which is faster and more robust
410+ if _TORCH_VISION_LESS_THAN_0_26 :
411+ return self ._deserialize_with_torchvision_io (data )
412+ return self ._deserialize_with_torchcodec (data )
413+
414+ def _deserialize_with_torchvision_io (self , data : bytes ) -> Any :
406415 if not _AV_AVAILABLE :
407416 raise ModuleNotFoundError ("av is required. Run `pip install av`" )
408417
@@ -416,6 +425,29 @@ def deserialize(self, data: bytes) -> Any:
416425 stream .write (data )
417426 return torchvision .io .read_video (fname , pts_unit = "sec" )
418427
428+ def _deserialize_with_torchcodec (self , data : bytes ) -> Any :
429+ try :
430+ import torch
431+ from torchcodec .decoders import AudioDecoder , VideoDecoder
432+ except ImportError :
433+ raise ModuleNotFoundError ("torchcodec is required. Run `pip install torchcodec>0.11`" )
434+
435+ dec = VideoDecoder (data , dimension_order = "NHWC" ) # NHWC → T,H,W,C after stacking
436+ metadata = asdict (dec .metadata ) if dec .metadata is not None else {}
437+
438+ # get_all_frames() returns a FrameBatch; .data is (N, C, H, W) or (N, H, W, C)
439+ # depending on dimension_order above
440+ frame_batch = dec .get_all_frames ()
441+ video = frame_batch .data # shape: (T, H, W, C) with NHWC
442+
443+ try :
444+ audio_dec = AudioDecoder (data )
445+ audio = audio_dec .get_all_samples ().data # (num_channels, num_samples)
446+ except ValueError :
447+ audio = torch .zeros (1 , 0 ) # old torchvision path returns aframes with shape (1, 0) for no-audio videos.
448+
449+ return video , audio , metadata
450+
419451 def can_serialize (self , data : Any ) -> bool :
420452 return isinstance (data , str ) and os .path .isfile (data ) and any (data .endswith (ext ) for ext in self ._EXTENSIONS )
421453
0 commit comments