├── data
│   └── .gitkeep
├── .dockerignore
├── requirements.txt
├── readme_imgs
│   └── benchmarks.png
├── src
│   ├── video_io
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── nvcodec_reader.py
│   │   ├── torchvision_reader.py
│   │   ├── torchcodec_reader.py
│   │   ├── opencv_reader.py
│   │   ├── vali_reader.py
│   │   └── abstract_reader.py
│   └── transforms.py
├── .pre-commit-config.yaml
├── pyproject.toml
├── Makefile
├── tests
│   └── video_io
│       ├── test_utils.py
│       └── test_readers.py
├── .gitignore
├── Dockerfile
├── scripts
│   └── benchmark.py
└── README.md

/data/.gitkeep:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | data
2 | **/__pycache__/
3 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | kornia==0.8.2
2 | PyNvVideoCodec==2.0.3
3 | python-vali==4.8.6
4 | av==16.0.1
5 | pytest==9.0.2
6 | 
--------------------------------------------------------------------------------
/readme_imgs/benchmarks.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NikolasEnt/decode-video-pytorch/HEAD/readme_imgs/benchmarks.png
--------------------------------------------------------------------------------
/src/video_io/__init__.py:
--------------------------------------------------------------------------------
1 | from .vali_reader import VALIVideoReader
2 | from .opencv_reader import OpenCVVideoReader
3 | from .nvcodec_reader import PyNvVideoCodecReader
4 | from .abstract_reader import AbstractVideoReader
5 | from .torchcodec_reader import TorchcodecVideoReader
6 | from .torchvision_reader import TorchvisionVideoReader
7 | 
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/pre-commit/pre-commit-hooks
3 |     rev: v6.0.0
4 |     hooks:
5 |       - id: trailing-whitespace
6 |       - id: end-of-file-fixer
7 |       - id: check-yaml
8 |       - id: check-toml
9 |       - id: check-added-large-files
10 |       - id: debug-statements
11 |       - id: detect-private-key
12 |       - id: check-merge-conflict
13 |   - repo: https://github.com/pycqa/isort
14 |     rev: 7.0.0
15 |     hooks:
16 |       - id: isort
17 |         name: isort (python)
18 |   - repo: https://github.com/astral-sh/ruff-pre-commit
19 |     rev: v0.14.10
20 |     hooks:
21 |       - id: ruff
22 |         types_or: [ python, pyi, jupyter ]
23 |         args: [ --fix ]
24 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.ruff]
2 | line-length = 79
3 | src = ["src", "scripts"]
4 | extend-exclude = ["data/"]
5 | target-version = "py311"
6 | 
7 | [tool.ruff.lint]
8 | select = [
9 |     "E",  # pycodestyle error
10 |     "W",  # pycodestyle warnings
11 |     "F",  # Pyflakes
12 |     "B",  # flake8-bugbear
13 |     "D",  # pydocstyle
14 |     "N",  # pep8-naming
15 | ]
16 | ignore = [
17 |     "D10",  # D10?: Missing docstring
18 |     "I001"  # Imports are sorted by isort
19 | ]
20 | unfixable = ["B"]
21 | 
22 | [tool.ruff.lint.per-file-ignores]
23 | "__init__.py" = ["F401"]
24 | 
25 | [tool.ruff.lint.pydocstyle]
26 | convention = "google"
27 | 
28 | [tool.ruff.format]
29 | quote-style = "single"
30 | 
31 | [tool.isort]
32 | length_sort = true
33 | multi_line_output = 2  # Hanging indent
34 | atomic = true
35 | line_length = 79
36 | 
--------------------------------------------------------------------------------
/src/transforms.py:
--------------------------------------------------------------------------------
1 | """An example of kornia-based video transforms.
2 | 
3 | See documentation in:
4 | https://kornia.readthedocs.io/en/latest/augmentation.container.html#video-data-augmentation
5 | """
6 | import kornia.augmentation as augm
7 | 
8 | IMAGENET_MEAN = [0.485, 0.456, 0.406]
9 | IMAGENET_STD = [0.229, 0.224, 0.225]
10 | 
11 | 
12 | train_transform = augm.VideoSequential(
13 |     augm.Normalize(mean=IMAGENET_MEAN,
14 |                    std=IMAGENET_STD),
15 |     augm.RandomRotation(degrees=(-10.0, 10.0)),
16 |     augm.RandomHorizontalFlip(p=0.5),
17 |     augm.RandomBrightness(brightness=(0.8, 1.2)),
18 |     data_format="BTCHW",
19 |     same_on_frame=True
20 | )
21 | 
22 | val_transform = augm.VideoSequential(
23 |     augm.Normalize(mean=IMAGENET_MEAN,
24 |                    std=IMAGENET_STD),
25 |     data_format="BTCHW",
26 |     same_on_frame=True
27 | )
28 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | PROJECT_NAME=decode-video-pytorch
2 | VERSION=0.1.1
3 | 
4 | IMAGE_NAME=$(PROJECT_NAME):$(VERSION)
5 | CONTAINER_NAME=--name=$(PROJECT_NAME)
6 | 
7 | NET=--net=host
8 | IPC=--ipc=host
9 | BUILD_NET=--network=host
10 | 
11 | # Specify which GPUs the container can see: all GPUs if called without
12 | # options, a specific GPU with GPUS='"device=1"', or none with GPUS=none.
13 | GPUS?=all
14 | ifeq ($(GPUS),none)
15 | 	GPUS_OPTION=
16 | else
17 | 	GPUS_OPTION=--gpus=$(GPUS)
18 | endif
19 | 
20 | .PHONY: all build stop run logs
21 | 
22 | all: build stop run logs
23 | 
24 | build:
25 | 	docker build $(BUILD_NET) -t $(IMAGE_NAME) -f Dockerfile .
26 | 
27 | stop:
28 | 	docker stop $(shell docker container ls -q --filter name=$(PROJECT_NAME)*)
29 | 
30 | kill:
31 | 	docker kill $(shell docker container ls -q --filter name=$(PROJECT_NAME)*)
32 | 	docker rm $(shell docker container ls -q --filter name=$(PROJECT_NAME)*)
33 | 
34 | run:
35 | 	docker run --rm -it $(GPUS_OPTION) \
36 | 		$(NET) $(IPC) \
37 | 		-v $(shell pwd):/workdir/ \
38 | 		$(CONTAINER_NAME) \
39 | 		$(IMAGE_NAME) \
40 | 		bash
41 | 
42 | logs:
43 | 	docker logs -f $(PROJECT_NAME)
44 | 
--------------------------------------------------------------------------------
/src/video_io/utils.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | from torch.cuda import device_count, is_available
4 | 
5 | 
6 | def get_device_id(device: str) -> int:
7 |     """Convert a device name to a device ID.
8 | 
9 |     "cpu" is decoded to -1.
10 | 
11 |     Args:
12 |         device (str): The device name ("cpu", "cuda", or "cuda:<id>").
13 | 
14 |     Returns:
15 |         int: The device ID.
16 | 
17 |     Raises:
18 |         ValueError: If the requested CUDA device ID is not available.
19 |     """
20 |     if device.startswith("cuda:"):
21 |         try:
22 |             device_id = int(device.split(":")[1])
23 |         except (IndexError, ValueError) as e:
24 |             raise ValueError(f"Invalid CUDA device format: {device}") from e
25 |         # Check the ID outside the try block, so that this message is not
26 |         # masked by the generic "invalid format" error above.
27 |         if device_id < 0 or device_id >= device_count():
28 |             raise ValueError(
29 |                 f"CUDA device {device_id} does not exist. "
30 |                 f"Available devices: {list(range(device_count()))}")
31 |     elif device == "cuda":
32 |         if is_available():
33 |             device_id = 0
34 |         else:
35 |             warnings.warn(
36 |                 "CUDA is not available. Using CPU instead.", stacklevel=2)
37 |             return -1
38 |     elif device == "cpu":
39 |         device_id = -1
40 |     else:
41 |         warnings.warn(f"Unknown device {device}, using CPU instead.",
42 |                       stacklevel=2)
43 |         device_id = -1
44 | 
45 |     return device_id
46 | 
--------------------------------------------------------------------------------
/tests/video_io/test_utils.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import patch
2 | 
3 | import pytest
4 | 
5 | from src.video_io.utils import get_device_id
6 | 
7 | 
8 | def mock_device_count():
9 |     return 2
10 | 
11 | 
12 | def mock_is_available_true():
13 |     return True
14 | 
15 | 
16 | def mock_is_available_false():
17 |     return False
18 | 
19 | 
20 | @pytest.mark.parametrize(
21 |     "device_input, expected_output",
22 |     [
23 |         ("cpu", -1),
24 |         ("cuda", 0),
25 |         ("cuda:0", 0),
26 |         ("cuda:1", 1),
27 |     ]
28 | )
29 | @patch('src.video_io.utils.device_count', side_effect=mock_device_count)
30 | @patch('src.video_io.utils.is_available', side_effect=mock_is_available_true)
31 | def test_get_device_id_happy_cases(mock_is_available, mock_device_count,
32 |                                    device_input, expected_output):
33 |     assert get_device_id(device_input) == expected_output
34 | 
35 | 
36 | @pytest.mark.parametrize(
37 |     "device_input",
38 |     [
39 |         ("cuda:-1"),
40 |         ("cuda:4"),
41 |     ]
42 | )
43 | @patch('src.video_io.utils.device_count', side_effect=mock_device_count)
44 | @patch('src.video_io.utils.is_available', side_effect=mock_is_available_true)
45 | def test_get_device_id_error_cases(mock_is_available, mock_device_count,
46 |                                    device_input):
47 |     with pytest.raises(ValueError):
48 |         get_device_id(device_input)
49 | 
50 | 
51 | @pytest.mark.parametrize(
52 |     "device_input, expected_output",
53 |     [
54 |         ("cpu", -1),
55 |         ("cuda", -1),
56 |     ]
57 | )
58 | @patch('src.video_io.utils.is_available', side_effect=mock_is_available_false)
59 | def test_get_device_id_no_gpus(mock_is_available, device_input,
60 |                                expected_output):
61 |     assert get_device_id(device_input) == expected_output
62 | 
--------------------------------------------------------------------------------
/src/video_io/nvcodec_reader.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 | from pathlib import Path
3 | 
4 | import torch
5 | import PyNvVideoCodec as nvc  # noqa: N813
6 | 
7 | from src.video_io.abstract_reader import AbstractVideoReader
8 | 
9 | 
10 | class PyNvVideoCodecReader(AbstractVideoReader):
11 |     def __init__(self,
12 |                  video_path: str | Path,
13 |                  mode: Literal["seek", "stream"] = "stream",
14 |                  output_format: Literal["THWC", "TCHW"] = "THWC",
15 |                  device: str = "cuda:0"):
16 |         self.decoder = None
17 |         super().__init__(video_path, mode, output_format, device)
18 |         if self.gpu_id < 0:
19 |             raise ValueError("PyNvVideoCodecReader supports Nvidia GPU decoding only; "
20 |                              f"provide a valid device. 
{self.device} was specified.") 21 | 22 | def _initialize_reader(self) -> None: 23 | self.decoder = nvc.SimpleDecoder( 24 | enc_file_path=self.video_path, 25 | gpu_id=self.gpu_id, 26 | use_device_memory=True, 27 | output_color_type=nvc.OutputColorType.RGB 28 | ) 29 | metadata = self.decoder.get_stream_metadata() 30 | self.num_frames = metadata.num_frames 31 | # Docs suggests it's avg_frame_rate, but it does not exist 32 | self.fps = metadata.average_fps 33 | 34 | def seek_read(self, frame_indices: list[int]) -> list[torch.Tensor]: 35 | frames = [] 36 | for idx in frame_indices: 37 | frames.append(torch.from_dlpack(self.decoder[idx])) 38 | return frames 39 | 40 | def stream_read(self, frame_indices: list[int]) -> list[torch.Tensor]: 41 | return [torch.from_dlpack(frame) for frame in 42 | self.decoder.get_batch_frames_by_index(frame_indices)] 43 | 44 | def release(self) -> None: 45 | if hasattr(self, 'decoder') and self.decoder is not None: 46 | del self.decoder 47 | self.decoder = None 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/** 2 | !data/.gitkeep 3 | 4 | .env 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | **.coverage 11 | .vscode/ 12 | .devcontainer/ 13 | .mypy_cache/ 14 | 15 | # Hydra outputs 16 | **/outputs/ 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | 66 | **/*.log 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | .pybuilder/ 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | .pdm.toml 83 | 84 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
85 | __pypackages__/
86 | 
87 | # Celery stuff
88 | celerybeat-schedule
89 | celerybeat.pid
90 | 
91 | # SageMath parsed files
92 | *.sage.py
93 | 
94 | # Environments
95 | .env
96 | .venv
97 | env/
98 | venv/
99 | ENV/
100 | env.bak/
101 | venv.bak/
102 | 
103 | # Rope project settings
104 | .ropeproject
105 | 
106 | # mypy
107 | .mypy_cache/
108 | .dmypy.json
109 | dmypy.json
110 | 
111 | # Pyre type checker
112 | .pyre/
113 | 
114 | # pytype static type analyzer
115 | .pytype/
116 | 
117 | # Cython debug symbols
118 | cython_debug/
119 | 
120 | # results from notebooks
121 | notebooks/**
122 | !notebooks/*.ipynb
123 | 
--------------------------------------------------------------------------------
/src/video_io/torchvision_reader.py:
--------------------------------------------------------------------------------
1 | 
2 | from typing import Literal
3 | from pathlib import Path
4 | 
5 | import torch
6 | from torchvision.io import read_video, read_video_timestamps
7 | 
8 | from src.video_io.abstract_reader import AbstractVideoReader
9 | 
10 | 
11 | class TorchvisionVideoReader(AbstractVideoReader):
12 |     """Video reader using PyTorch's torchvision.io.read_video.
13 | 
14 |     Note:
15 |         TorchVision video decoding and encoding will be removed in a future
16 |         release of TorchVision; TorchCodec is the recommended alternative.
17 | 
18 |     Args:
19 |         video_path (str | Path): Path to the input video file.
20 |         mode (Literal["seek", "stream"], optional): Reading mode: "seek" -
21 |             find each frame individually, "stream" - decode all frames from
22 |             the range of requested indices and subsample.
23 |             Defaults to "stream".
24 |         output_format (Literal["THWC", "TCHW"], optional): Data format:
25 |             channel last or first. Defaults to "THWC".
26 |         device (str, optional): Device to send the resulting tensor to.
27 |             Defaults to "cuda:0". 
28 | """ 29 | 30 | def __init__(self, video_path: str | Path, 31 | mode: Literal["seek", "stream"] = "stream", 32 | output_format: Literal["THWC", "TCHW"] = "THWC", 33 | device: str = "cuda:0"): 34 | super().__init__(video_path, mode=mode, output_format=output_format, 35 | device=device) 36 | 37 | def _initialize_reader(self) -> None: 38 | timestamps, fps = read_video_timestamps(self.video_path, 39 | pts_unit="sec") 40 | self.timestamps = timestamps 41 | self.num_frames = len(timestamps) 42 | self.fps = fps 43 | 44 | def seek_read(self, frame_indices: list[int]) -> list[torch.Tensor]: 45 | frame_timestamps = [self.timestamps[fid] for fid in frame_indices] 46 | frames = [] 47 | for ts in frame_timestamps: 48 | frame, _, _ = read_video( 49 | self.video_path, start_pts=ts, end_pts=ts, 50 | pts_unit="sec", output_format="THWC") 51 | frames.append(self._process_frame(frame[0])) 52 | return frames 53 | 54 | def stream_read(self, frame_indices: list[int]) -> torch.Tensor: 55 | frame_timestamps = [self.timestamps[fid] for fid in frame_indices] 56 | frames, _, _ = read_video( 57 | self.video_path, start_pts=min(frame_timestamps), 58 | end_pts=max(frame_timestamps) + (1 / self.fps), 59 | pts_unit="sec", output_format="THWC") 60 | frame_indices_sample = [fid - min(frame_indices) 61 | for fid in frame_indices] 62 | frames = frames[frame_indices_sample] 63 | return frames 64 | 65 | def release(self) -> None: 66 | pass 67 | -------------------------------------------------------------------------------- /src/video_io/torchcodec_reader.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | from pathlib import Path 3 | 4 | import torch 5 | from torchcodec.decoders import VideoDecoder 6 | 7 | from src.video_io.abstract_reader import AbstractVideoReader 8 | 9 | 10 | class TorchcodecVideoReader(AbstractVideoReader): 11 | """Videoreader using TorchCodec library. 12 | 13 | Args: 14 | video_path (str | Path): Path to the input video file. 15 | mode (Literal["seek", "stream"], optional): Reading mode: "seek" - 16 | find each frame individually, "stream" - decode all frames from 17 | the range of requested indeces and subsample. 18 | Defaults to "stream". 19 | output_format (Literal["THWC", "TCHW"], optional): Data format: 20 | channel last or first. Defaults to "THWC". 21 | device (str, optional): Device to send the resulted tensor to. 22 | Defaults to "cuda:0". 
23 | """ 24 | 25 | def __init__(self, video_path: str | Path, 26 | mode: Literal["seek", "stream"] = "stream", 27 | output_format: Literal["THWC", "TCHW"] = "THWC", 28 | device: str = "cuda:0"): 29 | super().__init__(video_path, mode=mode, output_format=output_format, 30 | device=device) 31 | 32 | def _initialize_reader(self) -> None: 33 | # Tensor axis order rearranged in _finalize_tensor if required 34 | self.decoder = VideoDecoder(self.video_path, 35 | dimension_order="NHWC", 36 | device=self.device) 37 | self.num_frames = self.decoder.metadata.num_frames 38 | self.fps = self.decoder.metadata.average_fps 39 | 40 | def seek_read(self, frame_indices: list[int]) -> list[torch.Tensor]: 41 | """Retrieve frames by their indices using random access.""" 42 | frames = [] 43 | for idx in frame_indices: 44 | if idx < 0 or idx >= self.num_frames: 45 | raise ValueError(f"Invalid frame index: {idx}") 46 | frame = self.decoder[idx] 47 | frames.append(self._process_frame(frame)) 48 | return frames 49 | 50 | def stream_read(self, frame_indices: list[int]) -> list[torch.Tensor]: 51 | start_idx, end_idx = min(frame_indices), max(frame_indices) 52 | if start_idx < 0 or end_idx >= self.num_frames: 53 | raise ValueError(f"Invalid frame indices: {frame_indices}") 54 | batch = self.decoder.get_frames_in_range( 55 | start=start_idx, stop=end_idx + 1, step=1) 56 | frames = batch.data # Frames in TCHW format 57 | relative_indices = [idx - start_idx for idx in frame_indices] 58 | return frames[relative_indices] 59 | 60 | def _read_frames_slice(self, start_idx: int, stop_idx: int, step: int)\ 61 | -> torch.Tensor: 62 | return self.decoder[start_idx:stop_idx:step] 63 | 64 | def release(self) -> None: 65 | self.decoder = None 66 | -------------------------------------------------------------------------------- /src/video_io/opencv_reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Literal 3 | from pathlib import Path 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | 9 | from src.video_io.abstract_reader import AbstractVideoReader 10 | 11 | os.environ["OPENCV_FFMPEG_CAPTURE_OPTIONS"] = "video_codec;h264_cuvid" 12 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 13 | 14 | 15 | class OpenCVVideoReader(AbstractVideoReader): 16 | """OpenCV-based video reader. 17 | 18 | To enable Nvidia GPU decoding, add the following: 19 | 20 | os.environ["OPENCV_FFMPEG_CAPTURE_OPTIONS"] = "video_codec;h264_cuvid" 21 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Set the GPU to use for decoding 22 | 23 | Adjust the codec 'h264_cuvid' in accordance to the input video file codec. 24 | 25 | Note: 26 | Similarly, the video reader can be used to decode videos with other 27 | hardware codecs, if FFmpeg is compiled with the appropriate hardware 28 | support (e.g., Intel or AMD GPUs). 29 | 30 | Args: 31 | video_path (str | Path): Path to the input video file. 32 | mode (Literal["seek", "stream"], optional): Reading mode: "seek" - 33 | find each frame individually, "stream" - decode all frames from 34 | the range of requested indeces and subsample. 35 | Defaults to "stream". 36 | output_format (Literal["THWC", "TCHW"], optional): Data format: 37 | channel last or first. Defaults to "THWC". 38 | device (str, optional): Device to send the resulted tensor to. 39 | Defaults to "cuda:0". 
40 | """ 41 | 42 | def __init__(self, video_path: str | Path, 43 | mode: Literal["seek", "stream"] = "stream", 44 | output_format: Literal["THWC", "TCHW"] = "THWC", 45 | device: str = "cuda:0"): 46 | super().__init__(video_path, mode=mode, output_format=output_format, 47 | device=device) 48 | 49 | def _initialize_reader(self) -> None: 50 | self._cap = cv2.VideoCapture(self.video_path, cv2.CAP_FFMPEG) 51 | self.num_frames = int(self._cap.get(cv2.CAP_PROP_FRAME_COUNT)) 52 | self.fps = self._cap.get(cv2.CAP_PROP_FPS) 53 | 54 | def _process_frame(self, frame: np.ndarray) -> torch.Tensor: 55 | return torch.from_numpy(frame).to(self.device) 56 | 57 | def seek_read(self, frame_indices: list[int]) -> list[torch.Tensor]: 58 | frames = [] 59 | for idx in frame_indices: 60 | self._cap.set(cv2.CAP_PROP_POS_FRAMES, idx) 61 | ret, frame = self._cap.read() 62 | if ret: 63 | frames.append(self._process_frame(frame)) 64 | else: 65 | break 66 | return frames 67 | 68 | def stream_read(self, frame_indices: list[int]) -> list[torch.Tensor]: 69 | frames = [] 70 | start_idx = min(frame_indices) 71 | end_idx = max(frame_indices) + 1 72 | self._cap.set(cv2.CAP_PROP_POS_FRAMES, start_idx) 73 | for idx in range(start_idx, end_idx): 74 | ret, frame = self._cap.read() 75 | if not ret: 76 | break 77 | if idx in frame_indices: 78 | frames.append(self._process_frame(frame)) 79 | return frames 80 | 81 | def release(self): 82 | self._cap.release() 83 | -------------------------------------------------------------------------------- /tests/video_io/test_readers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import pytest 5 | 6 | from src.video_io import VALIVideoReader, OpenCVVideoReader, \ 7 | PyNvVideoCodecReader, TorchcodecVideoReader, TorchvisionVideoReader 8 | 9 | VIDEO_PATH = '/workdir/data/videos/test.mp4' 10 | FRAMES_TO_READ = [10, 20, 30, 40] 11 | DEVICE = "cuda:0" 12 | 13 | VIDEO_READERS = [ 14 | TorchcodecVideoReader, 15 | OpenCVVideoReader, 16 | TorchvisionVideoReader, 17 | VALIVideoReader, 18 | PyNvVideoCodecReader 19 | ] 20 | 21 | MODES = ["seek", "stream"] 22 | OUTPUT_FORMATS = {"THWC": (4, None, None, 3), "TCHW": (4, 3, None, None)} 23 | 24 | TORCH_DEVICE = torch.device(DEVICE) 25 | 26 | 27 | @pytest.mark.skipif(not torch.cuda.is_available(), 28 | reason="CUDA is not available") 29 | @pytest.mark.skipif(not os.path.exists(VIDEO_PATH), 30 | reason=(f"Video file {VIDEO_PATH} does not exist, " 31 | "please provide a test video file")) 32 | @pytest.mark.parametrize("reader_class", VIDEO_READERS) 33 | @pytest.mark.parametrize("mode", MODES) 34 | @pytest.mark.parametrize("output_format", OUTPUT_FORMATS.keys()) 35 | def test_video_reader(reader_class, mode, output_format): 36 | reader = reader_class( 37 | video_path=VIDEO_PATH, 38 | mode=mode, 39 | output_format=output_format, 40 | device=DEVICE 41 | ) 42 | 43 | frames = reader.read_frames(FRAMES_TO_READ) 44 | expected_shape = OUTPUT_FORMATS[output_format] 45 | 46 | assert len(frames.shape) == len(expected_shape), \ 47 | f"Expected {len(expected_shape)} dims, but got {len(frames.shape)}" 48 | 49 | for i, (exp_dim, actual_dim) in enumerate( 50 | zip(expected_shape, frames.shape, strict=True)): 51 | if exp_dim is not None: 52 | assert exp_dim == actual_dim, \ 53 | (f"Dimension mismatch at position {i}: " 54 | f"expected {exp_dim}, got {actual_dim}") 55 | 56 | assert frames.device == TORCH_DEVICE, \ 57 | f"Expected device {DEVICE}, but got {frames.device}" 58 | 59 | 60 | MAX_AVG_PIXEL_DIFF = 
6.0 61 | 62 | 63 | @pytest.mark.skipif(not torch.cuda.is_available(), 64 | reason="CUDA is not available") 65 | @pytest.mark.skipif(not os.path.exists(VIDEO_PATH), 66 | reason=(f"Video file {VIDEO_PATH} does not exist, " 67 | "please provide a test video file")) 68 | @pytest.mark.parametrize("reader_class", VIDEO_READERS) 69 | @pytest.mark.parametrize("output_format", OUTPUT_FORMATS.keys()) 70 | def test_modes_equality(reader_class, output_format): 71 | reader_seek = reader_class( 72 | video_path=VIDEO_PATH, 73 | mode="seek", 74 | output_format=output_format, 75 | device=DEVICE 76 | ) 77 | 78 | reader_stream = reader_class( 79 | video_path=VIDEO_PATH, 80 | mode="stream", 81 | output_format=output_format, 82 | device=DEVICE 83 | ) 84 | 85 | frames_seek = reader_seek.read_frames(FRAMES_TO_READ).float() 86 | frames_stream = reader_stream.read_frames(FRAMES_TO_READ).float() 87 | m_err = torch.abs(frames_seek - frames_stream).mean().item() 88 | assert m_err < MAX_AVG_PIXEL_DIFF, \ 89 | ("Frames by 'seek' mode should be roughly equal to those by " 90 | f"'stream' mode. Avg pixel diff {m_err:.2f} > {MAX_AVG_PIXEL_DIFF}.") 91 | -------------------------------------------------------------------------------- /src/video_io/vali_reader.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | from pathlib import Path 3 | 4 | import torch 5 | import python_vali as vali 6 | 7 | from src.video_io.abstract_reader import AbstractVideoReader 8 | 9 | 10 | class VALIVideoReader(AbstractVideoReader): 11 | """Videoreader using VALI. 12 | 13 | See details on VALI at https://github.com/RomanArzumanyan/VALI. 14 | 15 | Args: 16 | video_path (str | Path): Path to the input video file. 17 | mode (Literal["seek", "stream"], optional): Reading mode: "seek" - 18 | find each frame individually, "stream" - decode all frames from 19 | the range of requested indeces and subsample. 20 | Defaults to "stream". 21 | output_format (Literal["THWC", "TCHW"], optional): Data format: 22 | channel last or first. Defaults to "THWC". 23 | device (str, optional): Device to send the resulted tensor to. 24 | Defaults to "cuda:0". 
25 | """ 26 | 27 | def __init__(self, video_path: str | Path, 28 | mode: Literal["seek", "stream"] = "stream", 29 | output_format: Literal["THWC", "TCHW"] = "THWC", 30 | device: str = "cuda:0"): 31 | super().__init__(video_path, mode=mode, output_format=output_format, 32 | device=device) 33 | 34 | def _initialize_reader(self) -> None: 35 | self._decoder = vali.PyDecoder(self.video_path, opts={}, 36 | gpu_id=self.gpu_id) 37 | self.num_frames = self._decoder.NumFrames 38 | self.width = self._decoder.Width 39 | self.height = self._decoder.Height 40 | self.fps = self._decoder.AvgFramerate 41 | 42 | target_format = vali.PixelFormat.RGB 43 | self._raw_to_rgb = vali.PySurfaceConverter(gpu_id=self.gpu_id) 44 | 45 | self.surf_raw = vali.Surface.Make( 46 | format=self._decoder.Format, width=self.width, height=self.height, 47 | gpu_id=self.gpu_id) 48 | 49 | self.surf_rgb = vali.Surface.Make( 50 | format=target_format, width=self.width, height=self.height, 51 | gpu_id=self.gpu_id) 52 | 53 | def _decode_surface(self, surface: vali.Surface) -> torch.Tensor: 54 | self._raw_to_rgb.Run(surface, self.surf_rgb) 55 | frame_tensor = torch.from_dlpack(self.surf_rgb) 56 | return frame_tensor 57 | 58 | def seek_read(self, frame_indices: list[int]) -> list[torch.Tensor]: 59 | frame_tensors = [] 60 | for idx in frame_indices: 61 | seek_ctx = vali.SeekContext(idx) 62 | success, details = self._decoder.DecodeSingleSurface( 63 | self.surf_raw, seek_ctx=seek_ctx) 64 | if not success: 65 | raise RuntimeError(f"Failed to decode frame {idx}: {details}") 66 | frame_tensors.append(self._decode_surface(self.surf_raw)) 67 | return frame_tensors 68 | 69 | def stream_read(self, frame_indices: list[int]) -> list[torch.Tensor]: 70 | start_idx = frame_indices[0] # Assuming the indices are sorted 71 | seek_ctx = vali.SeekContext(start_idx) 72 | success, details = self._decoder.DecodeSingleSurface( 73 | self.surf_raw, seek_ctx=seek_ctx) 74 | if not success: 75 | raise RuntimeError( 76 | f"Failed to decode frame {start_idx}: {details}") 77 | frame_tensors = [self._decode_surface(self.surf_raw)] 78 | for idx in range(start_idx, max(frame_indices)): 79 | success, details = self._decoder.DecodeSingleSurface( 80 | self.surf_raw) 81 | if not success: 82 | raise RuntimeError(f"Failed to decode frame {idx}: {details}") 83 | if idx in frame_indices: 84 | frame_tensors.append(self._decode_surface(self.surf_raw)) 85 | return frame_tensors 86 | 87 | def release(self) -> None: 88 | del self._decoder 89 | del self.surf_raw 90 | del self.surf_rgb 91 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CUDA_VERSION=12.8.1 2 | FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu24.04 3 | 4 | ENV LANG=C.UTF-8 5 | ENV NVIDIA_DRIVER_CAPABILITIES=video,compute,utility 6 | ENV DEBIAN_FRONTEND=noninteractive 7 | 8 | RUN apt-get update && apt -y upgrade && \ 9 | apt-get -y install software-properties-common apt-utils git \ 10 | build-essential yasm nasm cmake unzip wget curl \ 11 | libtcmalloc-minimal4 pkgconf autoconf libtool libc6 libc6-dev \ 12 | libnuma1 libnuma-dev libx264-dev libx265-dev libmp3lame-dev \ 13 | python3-pip python3.12-dev python3-numpy && \ 14 | ln -s /usr/bin/python3.12 /usr/bin/python && \ 15 | ln -sf /usr/bin/python3.12 /usr/bin/python3 && \ 16 | ln -sf /usr/bin/pip3 /usr/bin/pip && \ 17 | apt-get clean &&\ 18 | apt-get autoremove &&\ 19 | rm -rf /var/lib/apt/lists/* &&\ 20 | rm -rf 
/var/cache/apt/archives/* 21 | 22 | ENV PIP_BREAK_SYSTEM_PACKAGES=1 23 | 24 | # Build nvidia codec headers 25 | RUN git clone --depth=1 --branch=n13.0.19.0 \ 26 | --single-branch https://github.com/FFmpeg/nv-codec-headers.git && \ 27 | cd nv-codec-headers && make install && \ 28 | cd .. && rm -rf nv-codec-headers 29 | 30 | # Build FFmpeg with NVENC support 31 | RUN git clone --depth=1 --branch=n7.1.1 --single-branch https://github.com/FFmpeg/FFmpeg.git && \ 32 | cd FFmpeg && \ 33 | mkdir ffmpeg_build && cd ffmpeg_build && \ 34 | ../configure \ 35 | --enable-nonfree \ 36 | --enable-cuda \ 37 | --enable-libnpp \ 38 | --enable-cuvid \ 39 | --enable-ffnvcodec \ 40 | --enable-nvdec \ 41 | --enable-nvenc \ 42 | --enable-shared \ 43 | --disable-static \ 44 | --disable-doc \ 45 | --extra-cflags=-I/usr/local/cuda/include \ 46 | --extra-ldflags=-L/usr/local/cuda/lib64 \ 47 | --enable-gpl \ 48 | --enable-libx264 \ 49 | --enable-libx265 \ 50 | --enable-libmp3lame \ 51 | --extra-libs=-lpthread \ 52 | --nvccflags="-arch=sm_75 \ 53 | -gencode=arch=compute_75,code=sm_75 \ 54 | -gencode=arch=compute_80,code=sm_80 \ 55 | -gencode=arch=compute_86,code=sm_86 \ 56 | -gencode=arch=compute_89,code=sm_89 \ 57 | -gencode=arch=compute_90,code=sm_90" && \ 58 | make -j$(nproc) && make install && ldconfig && \ 59 | cd ../.. && rm -rf FFmpeg 60 | 61 | RUN mkdir /tmp/opencv && cd /tmp/opencv && \ 62 | wget https://github.com/opencv/opencv/archive/4.11.0.zip -O opencv-4.11.0.zip && \ 63 | unzip opencv-4.11.0.zip && cd opencv-4.11.0 && \ 64 | wget https://github.com/opencv/opencv_contrib/archive/4.11.0.zip -O opencv_contrib-4.11.0.zip && \ 65 | unzip opencv_contrib-4.11.0.zip && \ 66 | mkdir build && cd build && \ 67 | cmake -D CMAKE_BUILD_TYPE=Release\ 68 | -D CMAKE_INSTALL_PREFIX=/usr/local \ 69 | -D BUILD_ZLIB=OFF \ 70 | -D BUILD_EXAMPLES=OFF \ 71 | -D BUILD_opencv_java=OFF \ 72 | -D BUILD_opencv_python2=OFF \ 73 | -D BUILD_opencv_python3=ON \ 74 | -D ENABLE_PRECOMPILED_HEADERS=OFF \ 75 | -D WITH_OPENCL=OFF \ 76 | -D WITH_FFMPEG=ON \ 77 | -D WITH_GSTREAMER=OFF \ 78 | -D WITH_CUDA=ON \ 79 | -D WITH_GTK=OFF \ 80 | -D WITH_OPENEXR=OFF \ 81 | -D WITH_PROTOBUF=OFF \ 82 | -D BUILD_LIST=python3,core,imgproc,imgcodecs,videoio,video,calib3d,flann,cudev,cudacodec \ 83 | -D CUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda \ 84 | -D OPENCV_EXTRA_MODULES_PATH=/tmp/opencv/opencv-4.11.0/opencv_contrib-4.11.0/modules/ .. && \ 85 | make -j$(nproc) && make install && ldconfig && cd && rm -r /tmp/opencv 86 | 87 | # Install PyTorch 88 | RUN pip3 install --no-cache-dir torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 torchcodec==0.5 --index-url\ 89 | https://download.pytorch.org/whl/cu128 90 | 91 | # Main system requirements 92 | COPY requirements.txt /tmp/requirements.txt 93 | RUN pip3 install --no-cache-dir -r /tmp/requirements.txt 94 | 95 | ENV CUDA_DEVICE_ORDER=PCI_BUS_ID 96 | ENV PYTHONPATH=$PYTHONPATH:/workdir 97 | ENV TORCH_HOME=/workdir/data/.torch 98 | 99 | WORKDIR /workdir 100 | -------------------------------------------------------------------------------- /scripts/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from src.video_io import VALIVideoReader, OpenCVVideoReader, \ 8 | AbstractVideoReader, PyNvVideoCodecReader, TorchcodecVideoReader, \ 9 | TorchvisionVideoReader 10 | 11 | # Provide your test video files here. 
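# All listed videos are benchmarked in turn and their timings are pooled in
# the final summary, so group files with comparable characteristics.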
12 | VIDEO_PATHS = [ 13 | "/workdir/data/videos/test.mp4", 14 | ] 15 | N_PASSES = 3 # Number of times to repeat the benchmark for each video 16 | 17 | # Make sure the frames exist in the video file and all frames fit into VRAM! 18 | FRAMES_TO_READ_SEQUENTIAL = list(range(10, 20)) 19 | FRAMES_TO_READ_SLICE = list(range(10, 200, 20)) 20 | 21 | # Define the video readers to test 22 | VIDEO_READERS = [ 23 | TorchcodecVideoReader, 24 | OpenCVVideoReader, 25 | TorchvisionVideoReader, 26 | VALIVideoReader, 27 | PyNvVideoCodecReader 28 | ] 29 | 30 | MODES_TO_USE = ["seek", "stream"] 31 | # Note that some of the video readers don't support 'cpu' 32 | DEVICE = "cuda:0" 33 | 34 | if DEVICE.startswith("cuda:"): 35 | os.environ["OPENCV_FFMPEG_CAPTURE_OPTIONS"] = "video_codec;h264_cuvid" 36 | os.environ["CUDA_VISIBLE_DEVICES"] = str(DEVICE.split(":")[1]) 37 | DEVICE = "cuda:0" # Since other devices are invisible now 38 | 39 | 40 | def run_benchmark(video_reader: AbstractVideoReader, 41 | frames_to_read: list[int] = FRAMES_TO_READ_SEQUENTIAL)\ 42 | -> list[float]: 43 | """Run a benchmark for reading a video using a specific video reader class. 44 | 45 | Args: 46 | video_reader (AbstractVideoReader): The video reader to profile. 47 | frames_to_read (list[int]): List of frame indices to read. 48 | 49 | Returns: 50 | list[float]: Lists containing all individual reads time. 51 | """ 52 | reading_time = [] 53 | 54 | assert max(frames_to_read) < video_reader.num_frames, \ 55 | (f"Frame index {max(frames_to_read)} is out of range " 56 | f"for video {video_reader.video_path}") 57 | 58 | _ = video_reader[0] # Warmup 59 | 60 | for _ in range(N_PASSES): 61 | start_time = time.perf_counter() 62 | _ = video_reader.read_frames(frames_to_read) 63 | torch.cuda.synchronize() # Ensure all CUDA operations are completed 64 | end_time = time.perf_counter() 65 | reading_time.append(end_time - start_time) 66 | 67 | return reading_time 68 | 69 | 70 | def aggregate_results(timings: list[float]) -> tuple[float, float]: 71 | """Aggregate benchmark results from timings. 72 | 73 | Args: 74 | timings (list[float]): List of individual times. 75 | 76 | Returns: 77 | tuple[float, float]: Mean and std of the times. 
78 |     """
79 |     mean_time = float(np.mean(timings))
80 |     std_dev = float(np.std(timings))
81 | 
82 |     return mean_time, std_dev
83 | 
84 | 
85 | def main():
86 |     all_results = {
87 |         reader_class.__name__: {mode: {'sequential': [], 'slice': []}
88 |                                 for mode in MODES_TO_USE}
89 |         for reader_class in VIDEO_READERS
90 |     }
91 | 
92 |     for video_path in VIDEO_PATHS:
93 |         print(f"\nBenchmarking on video: {video_path}")
94 |         for video_reader_class in VIDEO_READERS:
95 |             for mode in MODES_TO_USE:
96 |                 print(f"  - {video_reader_class.__name__}: {mode} mode")
97 |                 reader = video_reader_class(
98 |                     video_path=video_path,
99 |                     mode=mode,
100 |                     device=DEVICE)
101 | 
102 |                 timings_sequential = run_benchmark(
103 |                     reader, FRAMES_TO_READ_SEQUENTIAL)
104 |                 print(
105 |                     f"    Sequential frame reading times: {timings_sequential}")
106 |                 all_results[video_reader_class.__name__][mode]['sequential'].extend(
107 |                     timings_sequential)
108 | 
109 |                 timings_slice = run_benchmark(reader, FRAMES_TO_READ_SLICE)
110 |                 print(f"    Slice frame reading times: {timings_slice}")
111 |                 all_results[video_reader_class.__name__][mode]['slice'].extend(
112 |                     timings_slice)
113 | 
114 |                 # Free decoder resources before the next configuration
115 |                 reader.release()
116 | 
117 |     for reader_name in all_results:
118 |         for mode in MODES_TO_USE:
119 |             sequential_mean, sequential_std_dev = aggregate_results(
120 |                 all_results[reader_name][mode]['sequential'])
121 |             print(f"Final result for {reader_name} ({mode} mode) - Sequential:"
122 |                   f" {sequential_mean:.4f} ± {sequential_std_dev:.4f} s")
123 | 
124 |             slice_mean, slice_std_dev = aggregate_results(
125 |                 all_results[reader_name][mode]['slice'])
126 |             print(f"Final result for {reader_name} ({mode} mode) - Slice:"
127 |                   f" {slice_mean:.4f} ± {slice_std_dev:.4f} s")
128 | 
129 | 
130 | if __name__ == "__main__":
131 |     main()
132 | 
--------------------------------------------------------------------------------
/src/video_io/abstract_reader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import asyncio
3 | from abc import ABCMeta, abstractmethod
4 | from typing import Any, Literal
5 | from pathlib import Path
6 | 
7 | import torch
8 | 
9 | from src.video_io.utils import get_device_id
10 | 
11 | 
12 | class AbstractVideoReader(metaclass=ABCMeta):
13 |     def __init__(self, video_path: str | Path,
14 |                  mode: Literal["seek", "stream"] = "stream",
15 |                  output_format: Literal["THWC", "TCHW"] = "THWC",
16 |                  device: str = "cuda:0") -> None:
17 |         self.video_path = self._validate_and_convert_path(video_path)
18 |         self.mode = self._validate_mode(mode)
19 |         self.output_format = self._validate_output_format(output_format)
20 |         self.device = device
21 |         self.gpu_id = get_device_id(device)
22 | 
23 |         # Values to be initialised
24 |         self.num_frames: int = 0
25 |         self.fps: int | float = 0
26 |         self._initialize_reader()
27 | 
28 |     def _validate_and_convert_path(self, video_path: str | Path) -> str:
29 |         if isinstance(video_path, Path):
30 |             video_path = str(video_path)
31 |         if not os.path.exists(video_path):
32 |             raise FileNotFoundError(f"Video file {video_path} does not exist")
33 |         return video_path
34 | 
35 |     def _validate_mode(self, mode: Literal["seek", "stream"]) -> str:
36 |         if mode not in ["seek", "stream"]:
37 |             raise ValueError(
38 |                 f"Invalid mode '{mode}'. Must be one of ['seek', 'stream']")
39 |         return mode
40 | 
41 |     def _validate_output_format(
42 |             self, output_format: Literal["THWC", "TCHW"]) -> str:
43 |         if output_format not in ["THWC", "TCHW"]:
44 |             raise ValueError(
45 |                 f"Invalid output format '{output_format}'. "
46 |                 "Must be one of ['THWC', 'TCHW']")
47 |         return output_format
48 | 
49 |     @abstractmethod
50 |     def _initialize_reader(self) -> None:
51 |         """Initialise metadata and prepare reader for reading the video."""
52 |         pass
53 | 
54 |     def _finalize_tensor(
55 |             self, frames: list[torch.Tensor] | torch.Tensor) -> torch.Tensor:
56 |         """Combine frame tensors and finalize the output format.
57 | 
58 |         Args:
59 |             frames (list[torch.Tensor] | torch.Tensor): A list of frame
60 |                 tensors to be combined or a single tensor of shape THWC.
61 | 
62 |         Returns:
63 |             torch.Tensor: The unified and finalized tensor.
64 |         """
65 |         if isinstance(frames, list):
66 |             tensor = torch.stack(frames, dim=0)
67 |         else:
68 |             tensor = frames
69 |         tensor = tensor.to(self.device)
70 |         if self.output_format == "TCHW":
71 |             tensor = tensor.permute(0, 3, 1, 2)
72 |         return tensor
73 | 
74 |     def _process_frame(self, frame: Any) -> torch.Tensor:
75 |         """Process an individual frame if required and convert it to tensor."""
76 |         return frame
77 | 
78 |     @abstractmethod
79 |     def seek_read(self, frame_indices: list[int]) -> list[torch.Tensor]:
80 |         """Seek to each frame and read the frames from the video one by one.
81 | 
82 |         Args:
83 |             frame_indices (list[int]): List of frame indices to read.
84 |                 Indices are expected to be sorted, and at least one index
85 |                 is expected in the list.
86 | 
87 |         Returns:
88 |             list[torch.Tensor]: List of frames from the video.
89 |         """
90 |         pass
91 | 
92 |     @abstractmethod
93 |     def stream_read(self, frame_indices: list[int]) -> list[torch.Tensor]:
94 |         """Read all frames in the range of the given indices and subset them.
95 | 
96 |         Args:
97 |             frame_indices (list[int]): List of frame indices to read.
98 |                 Indices are expected to be sorted, and at least one index
99 |                 is expected in the list.
100 | 
101 |         Returns:
102 |             list[torch.Tensor]: List of frames from the video.
103 | 
104 |         """
105 |         pass
106 | 
107 |     def read_frames(self, frame_indices: list[int]) -> torch.Tensor:
108 |         if min(frame_indices) < 0 or max(frame_indices) >= self.num_frames:
109 |             raise ValueError(f"Invalid frame indices {frame_indices} "
110 |                              f"in {self.video_path} video. "
111 |                              f"Must be in range [0, {self.num_frames - 1}]")
112 |         frames = []
113 |         if self.mode == "seek":
114 |             frames = self.seek_read(frame_indices)
115 |         elif self.mode == "stream":
116 |             frames = self.stream_read(frame_indices)
117 |         return self._finalize_tensor(frames)
118 | 
119 |     @abstractmethod
120 |     def release(self) -> None:
121 |         """Release any resources used by the reader."""
122 |         pass
123 | 
124 |     def __len__(self):
125 |         return self.num_frames
126 | 
127 |     def _read_frames_slice(self, start_idx: int, stop_idx: int,
128 |                            step: int) -> torch.Tensor:
129 |         indices = list(range(start_idx, stop_idx, step))
130 |         return self.read_frames(indices)
131 | 
132 |     def __getitem__(self, index: int | slice) -> torch.Tensor:
133 |         if isinstance(index, int):
134 |             if index < 0 or index >= self.num_frames:
135 |                 raise IndexError(
136 |                     f"Index {index} is out of bounds for video "
137 |                     f"{self.video_path} with {self.num_frames} frames.")
138 |             return self.read_frames([index])
139 | 
140 |         if isinstance(index, slice):
141 |             start, stop, step = index.start, index.stop, index.step
142 |             start = start if start is not None else 0
143 |             stop = stop if stop is not None else self.num_frames
144 |             step = step if step is not None else 1
145 | 
146 |             if start < 0 or stop > self.num_frames or step <= 0:
147 |                 raise ValueError(f"Invalid slice {index} for video "
148 |                                  f"with {self.num_frames} frames.")
149 | 
150 |             return self._read_frames_slice(start_idx=start, stop_idx=stop,
151 |                                            step=step)
152 | 
153 |         raise TypeError(
154 |             f"Index must be an integer or slice, not {type(index)}")
155 | 
156 |     async def seek_read_async(self, frame_indices: list[int])\
157 |             -> list[torch.Tensor]:
158 |         """Asynchronously seek and read frames.
159 | 
160 |         Subclasses should override this method.
161 |         This default implementation calls the synchronous `seek_read`.
162 | 
163 |         Args:
164 |             frame_indices (list[int]): List of frame indices to read.
165 |                 Indices are expected to be sorted, and at least one index
166 |                 is expected in the list.
167 | 
168 |         Returns:
169 |             list[torch.Tensor]: List of frames from the video.
170 |         """
171 |         return await asyncio.to_thread(self.seek_read, frame_indices)
172 | 
173 |     async def stream_read_async(self, frame_indices: list[int])\
174 |             -> list[torch.Tensor]:
175 |         """Asynchronously stream and read frames.
176 | 
177 |         Subclasses should override this method.
178 |         This default implementation calls the synchronous `stream_read`.
179 | 
180 |         Args:
181 |             frame_indices (list[int]): List of frame indices to read.
182 |                 Indices are expected to be sorted, and at least one index
183 |                 is expected in the list.
184 | 
185 |         Returns:
186 |             list[torch.Tensor]: List of frames from the video.
187 |         """
188 |         return await asyncio.to_thread(self.stream_read, frame_indices)
189 | 
190 |     async def read_frames_async(self, frame_indices: list[int])\
191 |             -> torch.Tensor:
192 |         """Asynchronously read frames.
193 | 
194 |         This method will delegate to the `..._async` versions of the read
195 |         methods.
196 | 
197 |         Args:
198 |             frame_indices (list[int]): List of frame indices to read.
199 |                 Indices are expected to be sorted, and at least one index
200 |                 is expected in the list.
201 | 
202 |         Returns:
203 |             torch.Tensor: Decoded video frames tensor.
204 |         """
205 |         if min(frame_indices) < 0 or max(frame_indices) >= self.num_frames:
206 |             raise ValueError(f"Invalid frame indices {frame_indices} "
207 |                              f"in {self.video_path} video. "
208 |                              f"Must be in range [0, {self.num_frames - 1}]")
209 |         frames = []
210 |         if self.mode == "seek":
211 |             frames = await self.seek_read_async(frame_indices)
212 |         elif self.mode == "stream":
213 |             frames = await self.stream_read_async(frame_indices)
214 |         return self._finalize_tensor(frames)
215 | 
216 |     def __repr__(self) -> str:
217 |         return (f"Video {self.video_path}: "
218 |                 f"{self.num_frames} frames @ {self.fps}fps")
219 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Video decoding for DL models training with PyTorch
2 | 
3 | This project demonstrates various approaches to decoding video frames into PyTorch tensors with hardware acceleration, providing benchmarks and examples to help users choose the most efficient video reader for their deep learning workflows.
4 | 
5 | The repo was originally developed to illustrate a talk given at the [London PyTorch Meetup](https://www.meetup.com/London-PyTorch-Meetup/):
7 | **Optimising Video Pipelines for Neural Network Training with PyTorch**
8 | by Nikolay Falaleev on 21/11/2024
10 | 
11 | The talk's slides are available [here](https://docs.google.com/presentation/d/1Qw9Cy0Pjikf5IBdZIGVqK968cKepKN2GuZD6hA1At8s/edit?usp=sharing). Note that the code has been substantially updated since the talk, including new video readers and improvements to the code structure.
12 | 
13 | ![Benchmarks results](/readme_imgs/benchmarks.png)
14 | _Time of video decoding into PyTorch tensors for different video readers in different modes. The reported values are for decoding 10 frames into PyTorch tensors from a 1080p 30fps video file: [Big Buck Bunny](https://download.blender.org/demo/movies/BBB/). The results were obtained with an Nvidia RTX 3090 used for hardware acceleration of all decoders, using v0.1.0 of this project._
15 | 
16 | ## Prerequisites
17 | 
18 | * Nvidia GPU with the Video Encode and Decode feature ([CUVID](https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new)). Nvidia driver version >= 570.
19 | * GNU [make](https://www.gnu.org/software/make/) - it is quite likely already installed on your system.
20 | * [Docker](https://docs.docker.com/engine/install/) and the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
21 | * Some video files for testing; put them in the `data/videos` directory.
22 | 
23 | ## How to run
24 | 
25 | The project provides a Docker environment that includes PyTorch, along with FFmpeg and OpenCV, both compiled from source with NVIDIA hardware acceleration support.
26 | 
27 | 1. Build the Docker image:
28 | 
29 | ```
30 | make build
31 | ```
32 | 
33 | 2. Run the container:
34 | ```
35 | make run
36 | ```
37 | 
38 | The Docker container will have the project folder mounted to `/workdir`, including the contents of `data` and all the code.
39 | 
40 | All of the following can be executed inside the running container.
41 | 
42 | ## Benchmarking
43 | A simple benchmark script is provided in [scripts/benchmark.py](scripts/benchmark.py). It compares the performance of the different readers available in the project when running in different modes.
44 | 
45 | To run benchmarks, provide representative video files and update the parameters of the benchmarking process in [scripts/benchmark.py](scripts/benchmark.py). Note that the results depend heavily on the video files' characteristics, including encoding parameters and resolution. Another critical aspect is the required sampling strategy: whether individual frames are sampled randomly, as a contiguous sequence, or as a sparse subset. It is therefore recommended to run the benchmark with parameters that represent the actual use case, both to select the most appropriate reader and to pick the best frame-reading strategy.
46 | 
47 | Adjust the benchmark parameters as required, then run the following command in the project container:
48 | 
49 | ```bash
50 | python scripts/benchmark.py
51 | ```
52 | 
53 | When selecting a particular video decoding approach, one should also consider the additional features offered by the tools. For example, although VALI may not be the fastest in the provided benchmarking framework, it offers significant flexibility and can outperform other readers when additional transforms, such as colour space conversion or resizing, are required as part of the pipeline.
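Whichever reader is under consideration, the script's helpers can also be used directly for a quick one-off measurement. A minimal sketch, assuming a test video exists at the placeholder path below and CUDA is available:

```python
# Minimal sketch: time a single reader on one set of frame indices,
# reusing the helpers from scripts/benchmark.py.
# "/workdir/data/videos/test.mp4" is a placeholder path.
from src.video_io import TorchcodecVideoReader
from scripts.benchmark import run_benchmark, aggregate_results

reader = TorchcodecVideoReader("/workdir/data/videos/test.mp4",
                               mode="stream", device="cuda:0")
timings = run_benchmark(reader, frames_to_read=list(range(10, 200, 20)))
mean_t, std_t = aggregate_results(timings)
print(f"TorchcodecVideoReader (stream): {mean_t:.4f} ± {std_t:.4f} s")
reader.release()
```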
54 | 
55 | ## Code navigation
56 | 
57 | Several video reader classes are provided in [src/video_io](src/video_io); they follow the same interface and inherit from [AbstractVideoReader](src/video_io/abstract_reader.py).
58 | 
59 | * [OpenCVVideoReader](src/video_io/opencv_reader.py) - uses OpenCV's `cv2.VideoCapture` with the FFmpeg backend. It is the most straightforward way to read videos. Use `os.environ["OPENCV_FFMPEG_CAPTURE_OPTIONS"] = "video_codec;h264_cuvid"` to enable hardware acceleration. Adjust the `h264_cuvid` video codec parameter to match your video codec, e.g. `h264_cuvid` for the h.264 codec and `hevc_cuvid` for the HEVC codec; list all available codecs with Nvidia HW acceleration via `ffmpeg -decoders | grep -i nvidia`. The provided Docker image includes OpenCV with hardware acceleration enabled, as well as FFmpeg compiled with Nvidia components.
60 | * [TorchvisionVideoReader](src/video_io/torchvision_reader.py) - uses PyTorch's `torchvision.io` module.
61 | * [TorchcodecVideoReader](src/video_io/torchcodec_reader.py) - uses the [TorchCodec](https://github.com/pytorch/torchcodec) library. TorchCodec is still in the early stages of development, so its API may change, but it is the recommended native approach for PyTorch.
62 | * [VALIVideoReader](src/video_io/vali_reader.py) - uses the [VALI](https://github.com/RomanArzumanyan/VALI) library, a continuation of the [VideoProcessingFramework](https://github.com/NVIDIA/VideoProcessingFramework) project discontinued by Nvidia. Unlike [PyNvVideoCodec](https://pypi.org/project/PyNvVideoCodec/), Nvidia's current replacement, VALI offers a more flexible solution that includes pixel format and colour space conversion capabilities, as well as some low-level operations on surfaces. This makes it more powerful than PyNvVideoCodec: although VALI has a steeper learning curve, it allows for building more complex and optimised pipelines.
63 | * [PyNvVideoCodecReader](src/video_io/nvcodec_reader.py) - uses the [PyNvVideoCodec](https://developer.nvidia.com/pynvvideocodec) project by Nvidia. It is one of the highest-performing options for decoding videos on Nvidia GPUs. Documentation on PyNvVideoCodec can be found [here](https://docs.nvidia.com/video-technologies/pynvvideocodec/index.html).
64 | 
65 | In addition, there are some other examples of video-related components in the project:
66 | * [Kornia video augmentations](src/transforms.py) transforms.
67 | 
68 | 
69 | ### Try one of the video readers:
70 | 
71 | ```python
72 | from src.video_io import TorchcodecVideoReader
73 | 
74 | video_reader = TorchcodecVideoReader(
75 |     "/workdir/data/videos/test.mp4", mode="stream", output_format="TCHW",
76 |     device="cuda:0")
77 | 
78 | frames_to_read = list(range(0, 100, 5))  # Read every 5th frame
79 | tensor = video_reader.read_frames(frames_to_read)
80 | print(tensor.shape, tensor.device)  # Should be (20, 3, H, W), cuda:0
81 | ```
82 | 
83 | All video reader classes use the same interface and return PyTorch tensors.
84 | 
85 | Arguments:
86 | 
87 | _video_path_ (str or Path): Path to the input video file.
88 | 
89 | _mode_ (`seek` or `stream`): Reading mode: `seek` -
90 | find each frame individually, `stream` - read all frames in
91 | the range of requested indices (but not necessarily decode all frames) and subsample them. When using `mode = 'stream'`,
92 | one needs to ensure that all frames in the range
93 | (min(frames_to_read), max(frames_to_read)) fit into VRAM.
94 | Defaults to `stream`.
95 | 
96 | _output_format_ (`THWC` or `TCHW`): Data format:
97 | channels-last or channels-first. Defaults to `THWC`.
98 | 
99 | _device_ (str, optional): Device to send the resulting tensor to. If possible, the same device will be used for HW acceleration of decoding. Defaults to `cuda:0`.
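Beyond `read_frames`, the base class also provides integer indexing, slicing and asyncio wrappers, so every reader gets these for free. A short sketch (the video path is a placeholder):

```python
import asyncio

from src.video_io import TorchcodecVideoReader

reader = TorchcodecVideoReader("/workdir/data/videos/test.mp4",
                               device="cuda:0")

frame = reader[0]        # Single frame, a (1, H, W, C) tensor in THWC mode
clip = reader[10:60:10]  # Every 10th frame from the range [10, 60)
print(frame.shape, clip.shape)


async def main() -> None:
    # The default async methods run the synchronous readers in a thread
    tensor = await reader.read_frames_async([10, 20, 30])
    print(tensor.shape, tensor.device)

asyncio.run(main())
```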
100 | 
101 | 
102 | ## Known Limitations
103 | 
104 | * TorchVision video decoding and encoding features are deprecated and will be removed in a future release of TorchVision. TorchCodec is the recommended alternative and is actively being developed for native integration with PyTorch.
105 | * The project does not implement truly asynchronous decoding: the default `..._async` methods in [AbstractVideoReader](src/video_io/abstract_reader.py) simply run the synchronous readers in a worker thread.
106 | * For now, RGB is the only supported target colour space. The main purpose of the project is to provide a unified interface to different video readers for convenient testing and selection of the most suitable one. In real scenarios, one may need to further customise the functionality to support particular formats and transforms in the most optimal way for specific use cases.
107 | * The project is focused on Nvidia-based hardware acceleration, so the `cpu` device is not properly supported and many readers are Nvidia-only.
108 | 
109 | ## Acknowledgements
110 | 
111 | This project demonstrates the use of several great open-source libraries and frameworks:
112 | 
113 | - **[TorchCodec](https://github.com/pytorch/torchcodec)** – an experimental PyTorch library for video decoding, which is actively developed and offers promising native integration with PyTorch.
114 | - **[VALI](https://github.com/RomanArzumanyan/VALI)** – a powerful and flexible video processing library, based on the discontinued NVIDIA Video Processing Framework. It provides low-level control and is particularly well-suited for complex hardware-accelerated pipelines, where additional frame processing (colour space conversion, resizing, etc.) is required as part of the pipeline.
115 | - **[PyNvVideoCodec](https://developer.nvidia.com/pynvvideocodec)** – an official NVIDIA project that provides Python bindings for video decoding using CUDA and NVDEC.
116 | - **[OpenCV](https://opencv.org/)** – a widely-used computer vision library, with hardware-accelerated video decoding capabilities when compiled with FFmpeg and CUDA support.
117 | - **[Kornia](https://kornia.org/)** – an open-source computer vision library for PyTorch, used in this project for video data augmentation examples.
118 | - **[FFmpeg](https://ffmpeg.org/)** – the foundational multimedia framework used, directly or indirectly, by most of the readers above.
119 | 
--------------------------------------------------------------------------------