├── .dockerignore ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── docker └── Dockerfile ├── frames_into_python.py ├── frames_into_pytorch.py ├── ghetto_nvds.py ├── logs └── .keep ├── media ├── .keep └── in.mp4 ├── tuning_baseline.py ├── tuning_batch.py ├── tuning_concurrency.py ├── tuning_dtod.py ├── tuning_fp16.py ├── tuning_postprocess_1.py └── tuning_postprocess_2.py /.dockerignore: -------------------------------------------------------------------------------- 1 | core 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | media 2 | logs/* 3 | core 4 | *DS_Store* 5 | *qdrep 6 | ._* 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Paul Bridger 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | DOCKER_CMD := docker run -it --rm --gpus=all --privileged=true --ipc=host -v $(shell pwd):/app 3 | DOCKER_PY_CMD := ${DOCKER_CMD} --entrypoint=python 4 | DOCKER_NSYS_CMD := ${DOCKER_CMD} --entrypoint=nsys 5 | PROFILE_CMD := profile -t cuda,cublas,cudnn,nvtx,osrt --force-overwrite=true --delay=2 --duration=30 6 | 7 | PROFILE_TARGETS = logs/tuning_baseline.qdrep logs/tuning_postprocess_1.qdrep 8 | 9 | .PHONY: sleep 10 | 11 | 12 | build-container: docker/Dockerfile 13 | docker build -f $< -t pytorch-video-pipeline:latest . 14 | 15 | 16 | run-container: build-container 17 | ${DOCKER_CMD} pytorch-video-pipeline:latest 18 | 19 | 20 | logs/cli.pipeline.dot: 21 | ${DOCKER_CMD} --entrypoint=gst-launch-1.0 pytorch-video-pipeline:latest filesrc location=media/in.mp4 num-buffers=200 ! decodebin ! progressreport update-freq=1 ! 
fakesink sync=true 22 | 23 | 24 | logs/%.pipeline.dot: %.py 25 | ${DOCKER_PY_CMD} pytorch-video-pipeline:latest $< 26 | 27 | 28 | logs/%.qdrep: %.py 29 | ${DOCKER_NSYS_CMD} pytorch-video-pipeline:latest ${PROFILE_CMD} -o $@ python $< 30 | 31 | 32 | %.pipeline.png: logs/%.pipeline.dot 33 | dot -Tpng -o$@ $< && rm -f $< 34 | 35 | 36 | %.output.svg: %.rec 37 | cat $< | svg-term > $@ 38 | 39 | %.rec: 40 | asciinema rec $@ -c "$(MAKE) --no-print-directory logs/$*.pipeline.dot sleep" 41 | 42 | sleep: 43 | @sleep 2 44 | @echo "---" 45 | 46 | 47 | pipeline: cli.pipeline.png frames_into_python.pipeline.png frames_into_pytorch.pipeline.png 48 | 49 | tuning: logs/tuning_baseline.qdrep logs/tuning_postprocess_1.qdrep logs/tuning_postprocess_2.qdrep logs/tuning_batch.qdrep logs/tuning_fp16.qdrep logs/tuning_dtod.qdrep logs/tuning_concurrency.qdrep 50 | 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-video-pipeline -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/deepstream:5.0-20.07-triton 2 | 3 | RUN apt-get update && apt-get install --no-install-recommends -y \ 4 | ca-certificates \ 5 | python-gst-1.0 \ 6 | wget 7 | 8 | # allow GObject to find typelibs 9 | ENV GI_TYPELIB_PATH /usr/lib/x86_64-linux-gnu/girepository-1.0/ 10 | 11 | # use conda to simplify some dependency management 12 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \ 13 | /bin/bash ~/miniconda.sh -b -p /opt/conda && \ 14 | rm ~/miniconda.sh && \ 15 | ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh 16 | 17 | ENV PATH /opt/conda/bin:${PATH} 18 | ENV LD_LIBRARY_PATH /usr/local/cuda-10.2/compat:/opt/conda/lib:/opt/nvidia/deepstream/deepstream-5.0/lib:${LD_LIBRARY_PATH} 19 | 20 | RUN conda install -y -c pytorch \ 21 | cudatoolkit=10.2 \ 22 | pytorch \ 23 | torchvision 24 | 25 | RUN conda install -y -c conda-forge \ 26 | pygobject \ 27 | scikit-image 28 | 29 | # Nvidia Apex for mixed-precision inference 30 | RUN git clone https://github.com/NVIDIA/apex.git /build/apex 31 | WORKDIR /build/apex 32 | RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . 33 | 34 | RUN pip install --upgrade cython 35 | RUN pip install --upgrade gil_load 36 | 37 | # GStreamer debug output location 38 | ENV GST_DEBUG_DUMP_DOT_DIR=/app/logs 39 | 40 | RUN python -c "import torch; torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math='fp32')" 2>/dev/null || : 41 | RUN python -c "import torch; torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math='fp16')" 2>/dev/null || : 42 | 43 | RUN rm -rf /var/lib/apt/lists/* && \ 44 | conda clean -afy 45 | 46 | WORKDIR /app 47 | -------------------------------------------------------------------------------- /frames_into_python.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import gi 3 | gi.require_version('Gst', '1.0') 4 | from gi.repository import Gst 5 | 6 | frame_format = 'RGBA' 7 | 8 | Gst.init() 9 | pipeline = Gst.parse_launch(f''' 10 | filesrc location=media/in.mp4 num-buffers=200 ! 11 | decodebin ! 
12 | fakesink name=s 13 | ''') 14 | 15 | def on_frame_probe(pad, info): 16 | buf = info.get_buffer() 17 | print(f'[{buf.pts / Gst.SECOND:6.2f}]') 18 | return Gst.PadProbeReturn.OK 19 | 20 | pipeline.get_by_name('s').get_static_pad('sink').add_probe( 21 | Gst.PadProbeType.BUFFER, 22 | on_frame_probe 23 | ) 24 | 25 | pipeline.set_state(Gst.State.PLAYING) 26 | 27 | try: 28 | while True: 29 | msg = pipeline.get_bus().timed_pop_filtered( 30 | Gst.SECOND, 31 | Gst.MessageType.EOS | Gst.MessageType.ERROR 32 | ) 33 | if msg: 34 | text = msg.get_structure().to_string() if msg.get_structure() else '' 35 | msg_type = Gst.message_type_get_name(msg.type) 36 | print(f'{msg.src.name}: [{msg_type}] {text}') 37 | break 38 | finally: 39 | open(f'logs/{os.path.splitext(sys.argv[0])[0]}.pipeline.dot', 'w').write( 40 | Gst.debug_bin_to_dot_data(pipeline, Gst.DebugGraphDetails.ALL) 41 | ) 42 | pipeline.set_state(Gst.State.NULL) 43 | -------------------------------------------------------------------------------- /frames_into_pytorch.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import gi 3 | gi.require_version('Gst', '1.0') 4 | from gi.repository import Gst 5 | import numpy as np 6 | import torch, torchvision 7 | 8 | frame_format, pixel_bytes = 'RGBA', 4 9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 10 | detector = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math='fp32').eval().to(device) 11 | preprocess = torchvision.transforms.ToTensor() 12 | 13 | Gst.init() 14 | pipeline = Gst.parse_launch(f''' 15 | filesrc location=media/in.mp4 num-buffers=200 ! 16 | decodebin ! 17 | nvvideoconvert ! 18 | video/x-raw,format={frame_format} ! 19 | fakesink name=s 20 | ''') 21 | 22 | def on_frame_probe(pad, info): 23 | buf = info.get_buffer() 24 | print(f'[{buf.pts / Gst.SECOND:6.2f}]') 25 | 26 | image_tensor = buffer_to_image_tensor(buf, pad.get_current_caps()) 27 | image_batch = image_tensor.unsqueeze(0).to(device) 28 | with torch.no_grad(): 29 | detections = detector(image_batch)[0] 30 | 31 | return Gst.PadProbeReturn.OK 32 | 33 | def buffer_to_image_tensor(buf, caps): 34 | caps_structure = caps.get_structure(0) 35 | height, width = caps_structure.get_value('height'), caps_structure.get_value('width') 36 | 37 | is_mapped, map_info = buf.map(Gst.MapFlags.READ) 38 | if is_mapped: 39 | try: 40 | image_array = np.ndarray( 41 | (height, width, pixel_bytes), 42 | dtype=np.uint8, 43 | buffer=map_info.data 44 | ).copy() # extend array lifetime beyond subsequent unmap 45 | return preprocess(image_array[:,:,:3]) # RGBA -> RGB 46 | finally: 47 | buf.unmap(map_info) 48 | 49 | pipeline.get_by_name('s').get_static_pad('sink').add_probe( 50 | Gst.PadProbeType.BUFFER, 51 | on_frame_probe 52 | ) 53 | 54 | pipeline.set_state(Gst.State.PLAYING) 55 | 56 | try: 57 | while True: 58 | msg = pipeline.get_bus().timed_pop_filtered( 59 | Gst.SECOND, 60 | Gst.MessageType.EOS | Gst.MessageType.ERROR 61 | ) 62 | if msg: 63 | text = msg.get_structure().to_string() if msg.get_structure() else '' 64 | msg_type = Gst.message_type_get_name(msg.type) 65 | print(f'{msg.src.name}: [{msg_type}] {text}') 66 | break 67 | finally: 68 | open(f'logs/{os.path.splitext(sys.argv[0])[0]}.pipeline.dot', 'w').write( 69 | Gst.debug_bin_to_dot_data(pipeline, Gst.DebugGraphDetails.ALL) 70 | ) 71 | pipeline.set_state(Gst.State.NULL) 72 | -------------------------------------------------------------------------------- /ghetto_nvds.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from ctypes import Structure, POINTER, CDLL, addressof, sizeof, memmove, byref 3 | from ctypes import c_uint, c_int, c_ulong, c_void_p, c_bool 4 | 5 | nvbufsurface = CDLL('libnvbufsurface.so') 6 | 7 | max_planes = 4 8 | structure_padding = 4 9 | 10 | 11 | class NvBufSurfacePlaneParams(Structure): 12 | _fields_ = [ 13 | ("num_planes", c_uint), 14 | ("width", c_uint * max_planes), 15 | ("height", c_uint * max_planes), 16 | ("pitch", c_uint * max_planes), 17 | ("offset", c_uint * max_planes), 18 | ("psize", c_uint * max_planes), 19 | ("bytesPerPix", c_uint * max_planes), 20 | ("_reserved", c_void_p * max_planes * structure_padding) 21 | ] 22 | 23 | 24 | class NvBufSurfaceMappedAddr(Structure): 25 | _fields_ = [ 26 | ("addr", c_void_p * max_planes), 27 | ("eglImage", c_void_p), 28 | ("_reserved", c_void_p * structure_padding) 29 | ] 30 | 31 | 32 | class NvBufSurfaceParams(Structure): 33 | _fields_ = [ 34 | ("width", c_uint), 35 | ("height", c_uint), 36 | ("pitch", c_uint), 37 | ("colorFormat", c_int), 38 | ("layout", c_int), 39 | ("bufferDesc", c_ulong), 40 | ("dataSize", c_uint), 41 | ("dataPtr", c_void_p), 42 | ("planeParams", NvBufSurfacePlaneParams), 43 | ("mappedAddr", NvBufSurfaceMappedAddr), 44 | ("_reserved", c_void_p * structure_padding) 45 | ] 46 | 47 | 48 | class NvBufSurface(Structure): 49 | _fields_ = [ 50 | ("gpuId", c_uint), 51 | ("batchSize", c_uint), 52 | ("numFilled", c_uint), 53 | ("isContiguous", c_bool), 54 | ("memType", c_int), 55 | ("surfaceList", POINTER(NvBufSurfaceParams)), 56 | ("_reserved", c_void_p * structure_padding) 57 | ] 58 | 59 | def __init__(self, gst_map_info): 60 | nvbufsurface.NvBufSurfaceMemSet(byref(self), -1, -1, 0) 61 | memmove(addressof(self), gst_map_info.data, min(sizeof(self), len(gst_map_info.data))) 62 | 63 | def struct_copy_from(self, other_buf_surface): 64 | self.batchSize = other_buf_surface.batchSize 65 | self.numFilled = other_buf_surface.numFilled 66 | self.isContiguous = other_buf_surface.isContiguous 67 | self.memType = other_buf_surface.memType 68 | self.surfaceList = (NvBufSurfaceParams * other_buf_surface.numFilled)() 69 | for surface_ix in range(other_buf_surface.numFilled): 70 | self.surfaceList[surface_ix] = NvBufSurfaceParams() 71 | self.surfaceList[surface_ix].width = other_buf_surface.surfaceList[surface_ix].width 72 | self.surfaceList[surface_ix].height = other_buf_surface.surfaceList[surface_ix].height 73 | self.surfaceList[surface_ix].pitch = other_buf_surface.surfaceList[surface_ix].pitch 74 | self.surfaceList[surface_ix].colorFormat = other_buf_surface.surfaceList[surface_ix].colorFormat 75 | self.surfaceList[surface_ix].layout = other_buf_surface.surfaceList[surface_ix].layout 76 | self.surfaceList[surface_ix].bufferDesc = other_buf_surface.surfaceList[surface_ix].bufferDesc 77 | self.surfaceList[surface_ix].dataSize = other_buf_surface.surfaceList[surface_ix].dataSize 78 | self.surfaceList[surface_ix].planeParams = other_buf_surface.surfaceList[surface_ix].planeParams 79 | 80 | def mem_copy_from(self, other_buf_surface): 81 | copy_result = nvbufsurface.NvBufSurfaceCopy(byref(other_buf_surface), byref(self)) 82 | assert(copy_result == 0) 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /logs/.keep: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pbridger/pytorch-video-pipeline/1796f07d197ca2a7f3d71db561fa29acaefab564/logs/.keep -------------------------------------------------------------------------------- /media/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbridger/pytorch-video-pipeline/1796f07d197ca2a7f3d71db561fa29acaefab564/media/.keep -------------------------------------------------------------------------------- /media/in.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pbridger/pytorch-video-pipeline/1796f07d197ca2a7f3d71db561fa29acaefab564/media/in.mp4 -------------------------------------------------------------------------------- /tuning_baseline.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import math, time 3 | import contextlib 4 | import gi 5 | gi.require_version('Gst', '1.0') 6 | from gi.repository import Gst 7 | import numpy as np 8 | import torch, torchvision 9 | 10 | frame_format, pixel_bytes, model_precision = 'RGBA', 4, 'fp32' 11 | model_dtype = torch.float16 if model_precision == 'fp16' else torch.float32 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | detector = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math=model_precision).eval().to(device) 14 | ssd_utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd_processing_utils') 15 | detection_threshold = 0.4 16 | start_time, frames_processed = None, 0 17 | 18 | # context manager to help keep track of ranges of time, using NVTX 19 | @contextlib.contextmanager 20 | def nvtx_range(msg): 21 | depth = torch.cuda.nvtx.range_push(msg) 22 | try: 23 | yield depth 24 | finally: 25 | torch.cuda.nvtx.range_pop() 26 | 27 | 28 | def on_frame_probe(pad, info): 29 | global start_time, frames_processed 30 | start_time = start_time or time.time() 31 | 32 | with nvtx_range('on_frame_probe'): 33 | buf = info.get_buffer() 34 | print(f'[{buf.pts / Gst.SECOND:6.2f}]') 35 | 36 | image_tensor = buffer_to_image_tensor(buf, pad.get_current_caps()) 37 | image_batch = preprocess(image_tensor.unsqueeze(0)) 38 | frames_processed += image_batch.size(0) 39 | 40 | with torch.no_grad(): 41 | with nvtx_range('inference'): 42 | locs, labels = detector(image_batch) 43 | postprocess(locs, labels) 44 | 45 | return Gst.PadProbeReturn.OK 46 | 47 | 48 | def buffer_to_image_tensor(buf, caps): 49 | with nvtx_range('buffer_to_image_tensor'): 50 | caps_structure = caps.get_structure(0) 51 | height, width = caps_structure.get_value('height'), caps_structure.get_value('width') 52 | 53 | is_mapped, map_info = buf.map(Gst.MapFlags.READ) 54 | if is_mapped: 55 | try: 56 | image_array = np.ndarray( 57 | (height, width, pixel_bytes), 58 | dtype=np.uint8, 59 | buffer=map_info.data 60 | ) 61 | return torch.from_numpy( 62 | image_array[:,:,:3].copy() # RGBA -> RGB, and extend lifetime beyond subsequent unmap 63 | ) 64 | finally: 65 | buf.unmap(map_info) 66 | 67 | 68 | def preprocess(image_batch): 69 | '300x300 centre crop, normalize, HWC -> CHW' 70 | with nvtx_range('preprocess'): 71 | batch_dim, image_height, image_width, image_depth = image_batch.size() 72 | copy_x, copy_y = min(300, image_width), min(300, image_height) 73 | 74 | dest_x_offset = max(0, (300 - image_width) // 2) 75 | source_x_offset = max(0, (image_width - 300) // 2) 76 | dest_y_offset = max(0, (300 - image_height) // 2) 77 | 
source_y_offset = max(0, (image_height - 300) // 2) 78 | 79 | input_batch = torch.zeros((batch_dim, 300, 300, 3), dtype=model_dtype, device=device) 80 | input_batch[:, dest_y_offset:dest_y_offset + copy_y, dest_x_offset:dest_x_offset + copy_x] = \ 81 | image_batch[:, source_y_offset:source_y_offset + copy_y, source_x_offset:source_x_offset + copy_x] 82 | 83 | return torch.einsum( 84 | 'bhwc -> bchw', 85 | normalize(input_batch / 255) 86 | ).contiguous() 87 | 88 | 89 | def normalize(input_tensor): 90 | 'Nvidia SSD300 code uses mean and std-dev of 128/256' 91 | return (2.0 * input_tensor) - 1.0 92 | 93 | 94 | def postprocess(locs, labels): 95 | with nvtx_range('postprocess'): 96 | results_batch = ssd_utils.decode_results((locs, labels)) 97 | results_batch = [ssd_utils.pick_best(results, detection_threshold) for results in results_batch] 98 | for bboxes, classes, scores in results_batch: 99 | if scores.shape[0] > 0: 100 | print(bboxes, classes, scores) 101 | 102 | 103 | Gst.init() 104 | pipeline = Gst.parse_launch(f''' 105 | filesrc location=media/in.mp4 num-buffers=256 ! 106 | decodebin ! 107 | nvvideoconvert ! 108 | video/x-raw,format={frame_format} ! 109 | fakesink name=s 110 | ''') 111 | 112 | pipeline.get_by_name('s').get_static_pad('sink').add_probe( 113 | Gst.PadProbeType.BUFFER, 114 | on_frame_probe 115 | ) 116 | 117 | pipeline.set_state(Gst.State.PLAYING) 118 | 119 | try: 120 | while True: 121 | msg = pipeline.get_bus().timed_pop_filtered( 122 | Gst.SECOND, 123 | Gst.MessageType.EOS | Gst.MessageType.ERROR 124 | ) 125 | if msg: 126 | text = msg.get_structure().to_string() if msg.get_structure() else '' 127 | msg_type = Gst.message_type_get_name(msg.type) 128 | print(f'{msg.src.name}: [{msg_type}] {text}') 129 | break 130 | finally: 131 | finish_time = time.time() 132 | open(f'logs/{os.path.splitext(sys.argv[0])[0]}.pipeline.dot', 'w').write( 133 | Gst.debug_bin_to_dot_data(pipeline, Gst.DebugGraphDetails.ALL) 134 | ) 135 | pipeline.set_state(Gst.State.NULL) 136 | print(f'FPS: {frames_processed / (finish_time - start_time):.2f}') 137 | -------------------------------------------------------------------------------- /tuning_batch.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import math, time 3 | import itertools 4 | import contextlib 5 | import gi 6 | gi.require_version('Gst', '1.0') 7 | from gi.repository import Gst 8 | import numpy as np 9 | import torch, torchvision 10 | 11 | frame_format, pixel_bytes, model_precision = 'RGBA', 4, 'fp32' 12 | model_dtype = torch.float16 if model_precision == 'fp16' else torch.float32 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | detector = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math=model_precision).eval().to(device) 15 | ssd_utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd_processing_utils') 16 | detection_threshold = 0.4 17 | start_time, frames_processed = None, 0 18 | image_batch, batch_size = [], 4 19 | 20 | # context manager to help keep track of ranges of time, using NVTX 21 | @contextlib.contextmanager 22 | def nvtx_range(msg): 23 | depth = torch.cuda.nvtx.range_push(msg) 24 | try: 25 | yield depth 26 | finally: 27 | torch.cuda.nvtx.range_pop() 28 | 29 | 30 | def on_frame_probe(pad, info): 31 | global start_time, frames_processed 32 | start_time = start_time or time.time() 33 | 34 | global image_batch 35 | 36 | if not image_batch: 37 | torch.cuda.nvtx.range_push('batch') 38 | 
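# Editor's note (a hedged annotation, not in the original script): the two range_push
# calls here open nested NVTX ranges — 'batch' spans a batch's whole lifecycle, while
# 'create_batch' covers only the accumulation phase. They are pushed and popped manually
# rather than via the nvtx_range context manager defined above because a batch is
# assembled across several independent probe callbacks, so the range cannot open and
# close within a single `with` block.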
torch.cuda.nvtx.range_push('create_batch') 39 | 40 | buf = info.get_buffer() 41 | print(f'[{buf.pts / Gst.SECOND:6.2f}]') 42 | 43 | image_tensor = buffer_to_image_tensor(buf, pad.get_current_caps()) 44 | image_batch.append(image_tensor) 45 | 46 | if len(image_batch) < batch_size: 47 | return Gst.PadProbeReturn.OK 48 | 49 | torch.cuda.nvtx.range_pop() # create_batch 50 | 51 | image_batch = preprocess(torch.stack(image_batch)) 52 | frames_processed += image_batch.size(0) 53 | 54 | with torch.no_grad(): 55 | with nvtx_range('inference'): 56 | locs, labels = detector(image_batch) 57 | image_batch = [] 58 | postprocess(locs, labels) 59 | 60 | torch.cuda.nvtx.range_pop() # batch 61 | return Gst.PadProbeReturn.OK 62 | 63 | 64 | def buffer_to_image_tensor(buf, caps): 65 | with nvtx_range('buffer_to_image_tensor'): 66 | caps_structure = caps.get_structure(0) 67 | height, width = caps_structure.get_value('height'), caps_structure.get_value('width') 68 | 69 | is_mapped, map_info = buf.map(Gst.MapFlags.READ) 70 | if is_mapped: 71 | try: 72 | image_array = np.ndarray( 73 | (height, width, pixel_bytes), 74 | dtype=np.uint8, 75 | buffer=map_info.data 76 | ) 77 | return torch.from_numpy( 78 | image_array[:,:,:3].copy() # RGBA -> RGB, and extend lifetime beyond subsequent unmap 79 | ) 80 | finally: 81 | buf.unmap(map_info) 82 | 83 | 84 | def preprocess(image_batch): 85 | '300x300 centre crop, normalize, HWC -> CHW' 86 | with nvtx_range('preprocess'): 87 | batch_dim, image_height, image_width, image_depth = image_batch.size() 88 | copy_x, copy_y = min(300, image_width), min(300, image_height) 89 | 90 | dest_x_offset = max(0, (300 - image_width) // 2) 91 | source_x_offset = max(0, (image_width - 300) // 2) 92 | dest_y_offset = max(0, (300 - image_height) // 2) 93 | source_y_offset = max(0, (image_height - 300) // 2) 94 | 95 | input_batch = torch.zeros((batch_dim, 300, 300, 3), dtype=model_dtype, device=device) 96 | input_batch[:, dest_y_offset:dest_y_offset + copy_y, dest_x_offset:dest_x_offset + copy_x] = \ 97 | image_batch[:, source_y_offset:source_y_offset + copy_y, source_x_offset:source_x_offset + copy_x] 98 | 99 | return torch.einsum( 100 | 'bhwc -> bchw', 101 | normalize(input_batch / 255) 102 | ).contiguous() 103 | 104 | 105 | def normalize(input_tensor): 106 | 'Nvidia SSD300 code uses mean and std-dev of 128/256' 107 | return (2.0 * input_tensor) - 1.0 108 | 109 | 110 | def init_dboxes(): 111 | 'adapted from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Detection/SSD/src/utils.py' 112 | fig_size = 300 113 | feat_size = [38, 19, 10, 5, 3, 1] 114 | steps = [8, 16, 32, 64, 100, 300] 115 | scales = [21, 45, 99, 153, 207, 261, 315] 116 | aspect_ratios = [[2], [2, 3], [2, 3], [2, 3], [2], [2]] 117 | 118 | fk = fig_size / torch.tensor(steps).float() 119 | 120 | dboxes = [] 121 | # size of feature and number of feature 122 | for idx, sfeat in enumerate(feat_size): 123 | sk1 = scales[idx] / fig_size 124 | sk2 = scales[idx + 1] / fig_size 125 | sk3 = math.sqrt(sk1 * sk2) 126 | all_sizes = [(sk1, sk1), (sk3, sk3)] 127 | 128 | for alpha in aspect_ratios[idx]: 129 | w, h = sk1 * math.sqrt(alpha), sk1 / math.sqrt(alpha) 130 | all_sizes.append((w, h)) 131 | all_sizes.append((h, w)) 132 | 133 | for w, h in all_sizes: 134 | for i, j in itertools.product(range(sfeat), repeat=2): 135 | cx, cy = (j + 0.5) / fk[idx], (i + 0.5) / fk[idx] 136 | dboxes.append((cx, cy, w, h)) 137 | 138 | return torch.tensor( 139 | dboxes, 140 | dtype=model_dtype, 141 | device=device 142 | ).clamp(0, 1) 143 | 144 | 145 | 
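# Editor's sketch (hedged, not part of the original script): a quick worked check of
# init_dboxes — each feature map above contributes cells * anchors_per_cell default
# boxes, i.e. 38*38*4 + 19*19*6 + 10*10*6 + 5*5*6 + 3*3*4 + 1*1*4 = 8732, the standard
# SSD300 default-box count.
# assert init_dboxes().shape == (8732, 4)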
dboxes_xywh = init_dboxes().unsqueeze(dim=0) 146 | scale_xy = 0.1 147 | scale_wh = 0.2 148 | 149 | 150 | def xywh_to_xyxy(bboxes_batch, scores_batch): 151 | bboxes_batch = bboxes_batch.permute(0, 2, 1) 152 | scores_batch = scores_batch.permute(0, 2, 1) 153 | 154 | bboxes_batch[:, :, :2] = scale_xy * bboxes_batch[:, :, :2] 155 | bboxes_batch[:, :, 2:] = scale_wh * bboxes_batch[:, :, 2:] 156 | 157 | bboxes_batch[:, :, :2] = bboxes_batch[:, :, :2] * dboxes_xywh[:, :, 2:] + dboxes_xywh[:, :, :2] 158 | bboxes_batch[:, :, 2:] = bboxes_batch[:, :, 2:].exp() * dboxes_xywh[:, :, 2:] 159 | 160 | # transform format to ltrb 161 | l, t, r, b = bboxes_batch[:, :, 0] - 0.5 * bboxes_batch[:, :, 2],\ 162 | bboxes_batch[:, :, 1] - 0.5 * bboxes_batch[:, :, 3],\ 163 | bboxes_batch[:, :, 0] + 0.5 * bboxes_batch[:, :, 2],\ 164 | bboxes_batch[:, :, 1] + 0.5 * bboxes_batch[:, :, 3] 165 | 166 | bboxes_batch[:, :, 0] = l 167 | bboxes_batch[:, :, 1] = t 168 | bboxes_batch[:, :, 2] = r 169 | bboxes_batch[:, :, 3] = b 170 | 171 | return bboxes_batch, torch.nn.functional.softmax(scores_batch, dim=-1) 172 | 173 | 174 | def postprocess(locs, labels): 175 | with nvtx_range('postprocess'): 176 | locs, probs = xywh_to_xyxy(locs, labels) 177 | 178 | # flatten batch and classes 179 | batch_dim, box_dim, class_dim = probs.size() 180 | flat_locs = locs.reshape(-1, 4).repeat_interleave(class_dim, dim=0) 181 | flat_probs = probs.view(-1) 182 | class_indexes = torch.arange(class_dim, device=device).repeat(batch_dim * box_dim) 183 | image_indexes = (torch.ones(box_dim * class_dim, device=device) * torch.arange(1, batch_dim + 1, device=device).unsqueeze(-1)).view(-1) 184 | 185 | # only do NMS on detections over threshold, and ignore background (0) 186 | threshold_mask = (flat_probs > detection_threshold) & (class_indexes > 0) 187 | flat_locs = flat_locs[threshold_mask] 188 | flat_probs = flat_probs[threshold_mask] 189 | class_indexes = class_indexes[threshold_mask] 190 | image_indexes = image_indexes[threshold_mask] 191 | 192 | nms_mask = torchvision.ops.boxes.batched_nms( 193 | flat_locs, 194 | flat_probs, 195 | image_indexes * class_dim + class_indexes, # unique group id per (image, class) pair; a plain index product can collide across pairs 196 | iou_threshold=0.7 197 | ) 198 | 199 | bboxes = flat_locs[nms_mask].cpu() 200 | probs = flat_probs[nms_mask].cpu() 201 | class_indexes = class_indexes[nms_mask].cpu() 202 | if bboxes.size(0) > 0: 203 | print(bboxes, class_indexes, probs) 204 | 205 | 206 | Gst.init() 207 | pipeline = Gst.parse_launch(f''' 208 | filesrc location=media/in.mp4 num-buffers=256 ! 209 | decodebin ! 210 | nvvideoconvert ! 211 | video/x-raw,format={frame_format} ! 
212 | fakesink name=s 213 | ''') 214 | 215 | pipeline.get_by_name('s').get_static_pad('sink').add_probe( 216 | Gst.PadProbeType.BUFFER, 217 | on_frame_probe 218 | ) 219 | 220 | pipeline.set_state(Gst.State.PLAYING) 221 | 222 | try: 223 | while True: 224 | msg = pipeline.get_bus().timed_pop_filtered( 225 | Gst.SECOND, 226 | Gst.MessageType.EOS | Gst.MessageType.ERROR 227 | ) 228 | if msg: 229 | text = msg.get_structure().to_string() if msg.get_structure() else '' 230 | msg_type = Gst.message_type_get_name(msg.type) 231 | print(f'{msg.src.name}: [{msg_type}] {text}') 232 | break 233 | finally: 234 | finish_time = time.time() 235 | open(f'logs/{os.path.splitext(sys.argv[0])[0]}.pipeline.dot', 'w').write( 236 | Gst.debug_bin_to_dot_data(pipeline, Gst.DebugGraphDetails.ALL) 237 | ) 238 | pipeline.set_state(Gst.State.NULL) 239 | print(f'FPS: {frames_processed / (finish_time - start_time):.2f}') 240 | -------------------------------------------------------------------------------- /tuning_concurrency.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import math, time 3 | import itertools 4 | import contextlib 5 | import copy 6 | import threading, queue 7 | import gil_load 8 | import gi 9 | gi.require_version('Gst', '1.0') 10 | from gi.repository import Gst 11 | import torch, torchvision 12 | import ghetto_nvds 13 | 14 | frame_format, pixel_bytes, model_precision = 'RGBA', 4, 'fp16' 15 | model_dtype = torch.float16 if model_precision == 'fp16' else torch.float32 16 | detection_threshold = 0.4 17 | start_time, frames_processed = time.time(), 0 18 | batch_size, num_inference_threads = 8, 2 19 | num_devices = torch.cuda.device_count() 20 | detector = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math=model_precision).eval() 21 | 22 | 23 | # context manager to help keep track of ranges of time, using NVTX 24 | @contextlib.contextmanager 25 | def nvtx_range(msg): 26 | depth = torch.cuda.nvtx.range_push(msg) 27 | try: 28 | yield depth 29 | finally: 30 | torch.cuda.nvtx.range_pop() 31 | 32 | 33 | create_tensor_stream = torch.cuda.Stream() 34 | 35 | def on_frame_probe(pad, info): 36 | global start_time, frames_processed 37 | buf = info.get_buffer() 38 | # print(f'[{buf.pts / Gst.SECOND:6.2f}]') 39 | device, detector, dboxes, image_queue = thread_contexts[frames_processed % len(thread_contexts)] 40 | 41 | with torch.no_grad(): 42 | with torch.cuda.stream(create_tensor_stream): 43 | image_tensor = buffer_to_image_tensor(device, buf, pad.get_current_caps()) 44 | copy_event = torch.cuda.Event(); copy_event.record(); image_queue.put((image_tensor, copy_event)) # record on the current stream (create_tensor_stream) so inference threads actually wait for this copy 45 | 46 | start_time = time.time() if frames_processed == 0 else start_time 47 | frames_processed += 1 48 | return Gst.PadProbeReturn.OK 49 | 50 | 51 | def buffer_to_image_tensor(device, buf, caps): 52 | with nvtx_range('buffer_to_image_tensor'): 53 | caps_structure = caps.get_structure(0) 54 | height, width = caps_structure.get_value('height'), caps_structure.get_value('width') 55 | 56 | is_mapped, map_info = buf.map(Gst.MapFlags.READ) 57 | if is_mapped: 58 | try: 59 | source_surface = ghetto_nvds.NvBufSurface(map_info) 60 | torch_surface = ghetto_nvds.NvBufSurface(map_info) 61 | 62 | dest_tensor = torch.zeros( 63 | (torch_surface.surfaceList[0].height, torch_surface.surfaceList[0].width, 4), 64 | dtype=torch.uint8, 65 | device=device 66 | ) 67 | 68 | torch_surface.struct_copy_from(source_surface) 69 | assert(source_surface.numFilled == 1) 70 | assert(source_surface.surfaceList[0].colorFormat == 19) # RGBA 71 | 
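# Editor's note (hedged): the lines below are the device-to-device trick — torch_surface
# begins as a byte-for-byte copy of the decoder's NvBufSurface descriptor, but its dataPtr
# is re-pointed at memory owned by dest_tensor, so NvBufSurfaceCopy performs a GPU-to-GPU
# copy straight into the PyTorch tensor and the frame never round-trips through host
# memory. The magic number 19 asserted above is assumed to correspond to
# NVBUF_COLOR_FORMAT_RGBA in the DeepStream 5.0 headers, per the original '# RGBA' comment.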
72 | # make torch_surface map to dest_tensor memory 73 | torch_surface.surfaceList[0].dataPtr = dest_tensor.data_ptr() 74 | torch_surface.gpuId = device.index 75 | 76 | # copy decoded GPU buffer (source_surface) into Pytorch tensor (torch_surface -> dest_tensor) 77 | torch_surface.mem_copy_from(source_surface) 78 | finally: 79 | buf.unmap(map_info) 80 | 81 | return dest_tensor[:, :, :3] 82 | 83 | 84 | def inference_thread_f(device, detector, dboxes, image_queue): 85 | cuda_stream = torch.cuda.Stream(device) 86 | 87 | while True: 88 | images, events = [], [] 89 | while len(images) < batch_size: 90 | next_image, image_event = image_queue.get() 91 | if next_image is None: 92 | return None 93 | images.append(next_image) 94 | events.append(image_event) 95 | 96 | with torch.cuda.stream(cuda_stream): 97 | with torch.no_grad(): 98 | for e in events: 99 | e.synchronize() 100 | 101 | image_batch = preprocess(device, torch.stack(images)) 102 | 103 | with nvtx_range('inference'): 104 | locs, labels = detector(image_batch) 105 | image_batch = [] 106 | postprocess(device, dboxes, locs, labels) 107 | 108 | 109 | def preprocess(device, image_batch): 110 | '300x300 centre crop, normalize, HWC -> CHW' 111 | with nvtx_range('preprocess'): 112 | batch_dim, image_height, image_width, image_depth = image_batch.size() 113 | copy_x, copy_y = min(300, image_width), min(300, image_height) 114 | 115 | dest_x_offset = max(0, (300 - image_width) // 2) 116 | source_x_offset = max(0, (image_width - 300) // 2) 117 | dest_y_offset = max(0, (300 - image_height) // 2) 118 | source_y_offset = max(0, (image_height - 300) // 2) 119 | 120 | input_batch = torch.zeros((batch_dim, 300, 300, 3), dtype=model_dtype, device=device) 121 | input_batch[:, dest_y_offset:dest_y_offset + copy_y, dest_x_offset:dest_x_offset + copy_x] = \ 122 | image_batch[:, source_y_offset:source_y_offset + copy_y, source_x_offset:source_x_offset + copy_x] 123 | 124 | return torch.einsum( 125 | 'bhwc -> bchw', 126 | normalize(input_batch / 255) 127 | ).contiguous() 128 | 129 | 130 | def normalize(input_tensor): 131 | 'Nvidia SSD300 code uses mean and std-dev of 128/256' 132 | return (2.0 * input_tensor) - 1.0 133 | 134 | 135 | def init_dboxes(device): 136 | 'adapted from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Detection/SSD/src/utils.py' 137 | fig_size = 300 138 | feat_size = [38, 19, 10, 5, 3, 1] 139 | steps = [8, 16, 32, 64, 100, 300] 140 | scales = [21, 45, 99, 153, 207, 261, 315] 141 | aspect_ratios = [[2], [2, 3], [2, 3], [2, 3], [2], [2]] 142 | 143 | fk = fig_size / torch.tensor(steps).float() 144 | 145 | dboxes = [] 146 | # size of feature and number of feature 147 | for idx, sfeat in enumerate(feat_size): 148 | sk1 = scales[idx] / fig_size 149 | sk2 = scales[idx + 1] / fig_size 150 | sk3 = math.sqrt(sk1 * sk2) 151 | all_sizes = [(sk1, sk1), (sk3, sk3)] 152 | 153 | for alpha in aspect_ratios[idx]: 154 | w, h = sk1 * math.sqrt(alpha), sk1 / math.sqrt(alpha) 155 | all_sizes.append((w, h)) 156 | all_sizes.append((h, w)) 157 | 158 | for w, h in all_sizes: 159 | for i, j in itertools.product(range(sfeat), repeat=2): 160 | cx, cy = (j + 0.5) / fk[idx], (i + 0.5) / fk[idx] 161 | dboxes.append((cx, cy, w, h)) 162 | 163 | return torch.tensor( 164 | dboxes, 165 | dtype=model_dtype, 166 | device=device 167 | ).clamp(0, 1) 168 | 169 | 170 | scale_xy = 0.1 171 | scale_wh = 0.2 172 | 173 | 174 | def xywh_to_xyxy(dboxes_xywh, bboxes_batch, scores_batch): 175 | bboxes_batch = bboxes_batch.permute(0, 2, 1) 176 | scores_batch = 
scores_batch.permute(0, 2, 1) 177 | 178 | bboxes_batch[:, :, :2] = scale_xy * bboxes_batch[:, :, :2] 179 | bboxes_batch[:, :, 2:] = scale_wh * bboxes_batch[:, :, 2:] 180 | 181 | bboxes_batch[:, :, :2] = bboxes_batch[:, :, :2] * dboxes_xywh[:, :, 2:] + dboxes_xywh[:, :, :2] 182 | bboxes_batch[:, :, 2:] = bboxes_batch[:, :, 2:].exp() * dboxes_xywh[:, :, 2:] 183 | 184 | # transform format to ltrb 185 | l, t, r, b = bboxes_batch[:, :, 0] - 0.5 * bboxes_batch[:, :, 2],\ 186 | bboxes_batch[:, :, 1] - 0.5 * bboxes_batch[:, :, 3],\ 187 | bboxes_batch[:, :, 0] + 0.5 * bboxes_batch[:, :, 2],\ 188 | bboxes_batch[:, :, 1] + 0.5 * bboxes_batch[:, :, 3] 189 | 190 | bboxes_batch[:, :, 0] = l 191 | bboxes_batch[:, :, 1] = t 192 | bboxes_batch[:, :, 2] = r 193 | bboxes_batch[:, :, 3] = b 194 | 195 | return bboxes_batch, torch.nn.functional.softmax(scores_batch, dim=-1) 196 | 197 | 198 | def postprocess(device, dboxes, locs, labels): 199 | with nvtx_range('postprocess'): 200 | locs, probs = xywh_to_xyxy(dboxes, locs, labels) 201 | 202 | # flatten batch and classes 203 | batch_dim, box_dim, class_dim = probs.size() 204 | flat_locs = locs.reshape(-1, 4).repeat_interleave(class_dim, dim=0) 205 | flat_probs = probs.view(-1) 206 | class_indexes = torch.arange(class_dim, device=device).repeat(batch_dim * box_dim) 207 | image_indexes = (torch.ones(box_dim * class_dim, device=device) * torch.arange(1, batch_dim + 1, device=device).unsqueeze(-1)).view(-1) 208 | 209 | # only do NMS on detections over threshold, and ignore background (0) 210 | threshold_mask = (flat_probs > detection_threshold) & (class_indexes > 0) 211 | flat_locs = flat_locs[threshold_mask] 212 | flat_probs = flat_probs[threshold_mask] 213 | class_indexes = class_indexes[threshold_mask] 214 | image_indexes = image_indexes[threshold_mask] 215 | 216 | nms_mask = torchvision.ops.boxes.batched_nms( 217 | flat_locs, 218 | flat_probs, 219 | image_indexes * class_dim + class_indexes, # unique group id per (image, class) pair; a plain index product can collide across pairs 220 | iou_threshold=0.7 221 | ) 222 | 223 | bboxes = flat_locs[nms_mask].cpu() 224 | probs = flat_probs[nms_mask].cpu() 225 | class_indexes = class_indexes[nms_mask].cpu() 226 | # if bboxes.size(0) > 0: 227 | # print(bboxes, class_indexes, probs) 228 | 229 | 230 | if num_devices: 231 | thread_contexts = [] 232 | 233 | for device_idx in range(num_devices): 234 | device = torch.device(f'cuda:{device_idx}') 235 | device_detector = copy.deepcopy(detector).to(device) 236 | dboxes_xywh = init_dboxes(device).unsqueeze(dim=0) 237 | 238 | for inference_idx in range(num_inference_threads): 239 | thread_queue = queue.Queue(2 * batch_size) 240 | thread_contexts.append((device, device_detector, dboxes_xywh, thread_queue)) 241 | 242 | else: 243 | sys.exit(1) 244 | 245 | try: 246 | gil_load.init() 247 | gil_load_enabled = True 248 | except RuntimeError: 249 | gil_load_enabled = False 250 | 251 | Gst.init() 252 | pipeline = Gst.parse_launch(f''' 253 | filesrc location=media/in.mp4 num-buffers=2048 ! 254 | decodebin ! 255 | nvvideoconvert ! 256 | video/x-raw(memory:NVMM),format={frame_format} ! 
257 | fakesink name=s 258 | ''') 259 | 260 | pipeline.get_by_name('s').get_static_pad('sink').add_probe( 261 | Gst.PadProbeType.BUFFER, 262 | on_frame_probe 263 | ) 264 | 265 | inference_threads = [] 266 | for device, detector, dboxes, image_queue in thread_contexts: 267 | inference_threads.append( 268 | threading.Thread(target=inference_thread_f, args=(device, detector, dboxes, image_queue)) 269 | ) 270 | inference_threads[-1].start() 271 | 272 | # for each thread doing the pointless gil_10_pc, the GIL is busy an additional ~10% of time 273 | def gil_10_pc(): 274 | while True: 275 | for i in range(300): 276 | a = 1 + 1 277 | time.sleep(1e-9) 278 | 279 | gil_threads = [] 280 | for gil_idx in range(0): 281 | gil_threads.append(threading.Thread(target=gil_10_pc, daemon=True)) 282 | gil_threads[-1].daemon = True 283 | gil_threads[-1].start() 284 | 285 | pipeline.set_state(Gst.State.PLAYING) 286 | 287 | if gil_load_enabled: 288 | gil_load.start() 289 | 290 | try: 291 | while True: 292 | msg = pipeline.get_bus().timed_pop_filtered( 293 | Gst.SECOND, 294 | Gst.MessageType.EOS | Gst.MessageType.ERROR 295 | ) 296 | if msg: 297 | text = msg.get_structure().to_string() if msg.get_structure() else '' 298 | msg_type = Gst.message_type_get_name(msg.type) 299 | print(f'{msg.src.name}: [{msg_type}] {text}') 300 | break 301 | finally: 302 | if gil_load_enabled: 303 | gil_load.stop() 304 | for device, detector, dboxes, image_queue in thread_contexts: 305 | image_queue.put((None, None)) 306 | for inference_thread in inference_threads: 307 | inference_thread.join() 308 | finish_time = time.time() 309 | 310 | open(f'logs/{os.path.splitext(sys.argv[0])[0]}.pipeline.dot', 'w').write( 311 | Gst.debug_bin_to_dot_data(pipeline, Gst.DebugGraphDetails.ALL) 312 | ) 313 | pipeline.set_state(Gst.State.NULL) 314 | print(f'FPS: {frames_processed / (finish_time - start_time):.2f}') 315 | if gil_load_enabled: 316 | print() 317 | print(gil_load.format(gil_load.get())) 318 | -------------------------------------------------------------------------------- /tuning_dtod.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import math, time 3 | import itertools 4 | import contextlib 5 | import gi 6 | gi.require_version('Gst', '1.0') 7 | from gi.repository import Gst 8 | import torch, torchvision 9 | import ghetto_nvds 10 | 11 | frame_format, pixel_bytes, model_precision = 'RGBA', 4, 'fp16' 12 | model_dtype = torch.float16 if model_precision == 'fp16' else torch.float32 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | detector = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math=model_precision).eval().to(device) 15 | ssd_utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd_processing_utils') 16 | detection_threshold = 0.4 17 | start_time, frames_processed = None, 0 18 | image_batch, batch_size = [], 8 19 | 20 | # context manager to help keep track of ranges of time, using NVTX 21 | @contextlib.contextmanager 22 | def nvtx_range(msg): 23 | depth = torch.cuda.nvtx.range_push(msg) 24 | try: 25 | yield depth 26 | finally: 27 | torch.cuda.nvtx.range_pop() 28 | 29 | 30 | def on_frame_probe(pad, info): 31 | global start_time, frames_processed 32 | start_time = start_time or time.time() 33 | 34 | global image_batch 35 | 36 | if not image_batch: 37 | torch.cuda.nvtx.range_push('batch') 38 | torch.cuda.nvtx.range_push('create_batch') 39 | 40 | buf = info.get_buffer() 41 | print(f'[{buf.pts / 
Gst.SECOND:6.2f}]') 42 | 43 | image_tensor = buffer_to_image_tensor(buf, pad.get_current_caps()) 44 | image_batch.append(image_tensor) 45 | 46 | if len(image_batch) < batch_size: 47 | return Gst.PadProbeReturn.OK 48 | 49 | torch.cuda.nvtx.range_pop() # create_batch 50 | 51 | image_batch = preprocess(torch.stack(image_batch)) 52 | frames_processed += image_batch.size(0) 53 | 54 | with torch.no_grad(): 55 | with nvtx_range('inference'): 56 | locs, labels = detector(image_batch) 57 | image_batch = [] 58 | postprocess(locs, labels) 59 | 60 | torch.cuda.nvtx.range_pop() # batch 61 | return Gst.PadProbeReturn.OK 62 | 63 | 64 | def buffer_to_image_tensor(buf, caps): 65 | with nvtx_range('buffer_to_image_tensor'): 66 | caps_structure = caps.get_structure(0) 67 | height, width = caps_structure.get_value('height'), caps_structure.get_value('width') 68 | 69 | is_mapped, map_info = buf.map(Gst.MapFlags.READ) 70 | if is_mapped: 71 | try: 72 | source_surface = ghetto_nvds.NvBufSurface(map_info) 73 | torch_surface = ghetto_nvds.NvBufSurface(map_info) 74 | 75 | dest_tensor = torch.zeros( 76 | (torch_surface.surfaceList[0].height, torch_surface.surfaceList[0].width, 4), 77 | dtype=torch.uint8, 78 | device=device 79 | ) 80 | 81 | torch_surface.struct_copy_from(source_surface) 82 | assert(source_surface.numFilled == 1) 83 | assert(source_surface.surfaceList[0].colorFormat == 19) # RGBA 84 | 85 | # make torch_surface map to dest_tensor memory 86 | torch_surface.surfaceList[0].dataPtr = dest_tensor.data_ptr() 87 | 88 | # copy decoded GPU buffer (source_surface) into Pytorch tensor (torch_surface -> dest_tensor) 89 | torch_surface.mem_copy_from(source_surface) 90 | finally: 91 | buf.unmap(map_info) 92 | 93 | return dest_tensor[:, :, :3] 94 | 95 | 96 | def preprocess(image_batch): 97 | '300x300 centre crop, normalize, HWC -> CHW' 98 | with nvtx_range('preprocess'): 99 | batch_dim, image_height, image_width, image_depth = image_batch.size() 100 | copy_x, copy_y = min(300, image_width), min(300, image_height) 101 | 102 | dest_x_offset = max(0, (300 - image_width) // 2) 103 | source_x_offset = max(0, (image_width - 300) // 2) 104 | dest_y_offset = max(0, (300 - image_height) // 2) 105 | source_y_offset = max(0, (image_height - 300) // 2) 106 | 107 | input_batch = torch.zeros((batch_dim, 300, 300, 3), dtype=model_dtype, device=device) 108 | input_batch[:, dest_y_offset:dest_y_offset + copy_y, dest_x_offset:dest_x_offset + copy_x] = \ 109 | image_batch[:, source_y_offset:source_y_offset + copy_y, source_x_offset:source_x_offset + copy_x] 110 | 111 | return torch.einsum( 112 | 'bhwc -> bchw', 113 | normalize(input_batch / 255) 114 | ).contiguous() 115 | 116 | 117 | def normalize(input_tensor): 118 | 'Nvidia SSD300 code uses mean and std-dev of 128/256' 119 | return (2.0 * input_tensor) - 1.0 120 | 121 | 122 | def init_dboxes(): 123 | 'adapted from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Detection/SSD/src/utils.py' 124 | fig_size = 300 125 | feat_size = [38, 19, 10, 5, 3, 1] 126 | steps = [8, 16, 32, 64, 100, 300] 127 | scales = [21, 45, 99, 153, 207, 261, 315] 128 | aspect_ratios = [[2], [2, 3], [2, 3], [2, 3], [2], [2]] 129 | 130 | fk = fig_size / torch.tensor(steps).float() 131 | 132 | dboxes = [] 133 | # size of feature and number of feature 134 | for idx, sfeat in enumerate(feat_size): 135 | sk1 = scales[idx] / fig_size 136 | sk2 = scales[idx + 1] / fig_size 137 | sk3 = math.sqrt(sk1 * sk2) 138 | all_sizes = [(sk1, sk1), (sk3, sk3)] 139 | 140 | for alpha in aspect_ratios[idx]: 141 | w, h 
= sk1 * math.sqrt(alpha), sk1 / math.sqrt(alpha) 142 | all_sizes.append((w, h)) 143 | all_sizes.append((h, w)) 144 | 145 | for w, h in all_sizes: 146 | for i, j in itertools.product(range(sfeat), repeat=2): 147 | cx, cy = (j + 0.5) / fk[idx], (i + 0.5) / fk[idx] 148 | dboxes.append((cx, cy, w, h)) 149 | 150 | return torch.tensor( 151 | dboxes, 152 | dtype=model_dtype, 153 | device=device 154 | ).clamp(0, 1) 155 | 156 | 157 | dboxes_xywh = init_dboxes().unsqueeze(dim=0) 158 | scale_xy = 0.1 159 | scale_wh = 0.2 160 | 161 | 162 | def xywh_to_xyxy(bboxes_batch, scores_batch): 163 | bboxes_batch = bboxes_batch.permute(0, 2, 1) 164 | scores_batch = scores_batch.permute(0, 2, 1) 165 | 166 | bboxes_batch[:, :, :2] = scale_xy * bboxes_batch[:, :, :2] 167 | bboxes_batch[:, :, 2:] = scale_wh * bboxes_batch[:, :, 2:] 168 | 169 | bboxes_batch[:, :, :2] = bboxes_batch[:, :, :2] * dboxes_xywh[:, :, 2:] + dboxes_xywh[:, :, :2] 170 | bboxes_batch[:, :, 2:] = bboxes_batch[:, :, 2:].exp() * dboxes_xywh[:, :, 2:] 171 | 172 | # transform format to ltrb 173 | l, t, r, b = bboxes_batch[:, :, 0] - 0.5 * bboxes_batch[:, :, 2],\ 174 | bboxes_batch[:, :, 1] - 0.5 * bboxes_batch[:, :, 3],\ 175 | bboxes_batch[:, :, 0] + 0.5 * bboxes_batch[:, :, 2],\ 176 | bboxes_batch[:, :, 1] + 0.5 * bboxes_batch[:, :, 3] 177 | 178 | bboxes_batch[:, :, 0] = l 179 | bboxes_batch[:, :, 1] = t 180 | bboxes_batch[:, :, 2] = r 181 | bboxes_batch[:, :, 3] = b 182 | 183 | return bboxes_batch, torch.nn.functional.softmax(scores_batch, dim=-1) 184 | 185 | 186 | def postprocess(locs, labels): 187 | with nvtx_range('postprocess'): 188 | locs, probs = xywh_to_xyxy(locs, labels) 189 | 190 | # flatten batch and classes 191 | batch_dim, box_dim, class_dim = probs.size() 192 | flat_locs = locs.reshape(-1, 4).repeat_interleave(class_dim, dim=0) 193 | flat_probs = probs.view(-1) 194 | class_indexes = torch.arange(class_dim, device=device).repeat(batch_dim * box_dim) 195 | image_indexes = (torch.ones(box_dim * class_dim, device=device) * torch.arange(1, batch_dim + 1, device=device).unsqueeze(-1)).view(-1) 196 | 197 | # only do NMS on detections over threshold, and ignore background (0) 198 | threshold_mask = (flat_probs > detection_threshold) & (class_indexes > 0) 199 | flat_locs = flat_locs[threshold_mask] 200 | flat_probs = flat_probs[threshold_mask] 201 | class_indexes = class_indexes[threshold_mask] 202 | image_indexes = image_indexes[threshold_mask] 203 | 204 | nms_mask = torchvision.ops.boxes.batched_nms( 205 | flat_locs, 206 | flat_probs, 207 | image_indexes * class_dim + class_indexes, # unique group id per (image, class) pair; a plain index product can collide across pairs 208 | iou_threshold=0.7 209 | ) 210 | 211 | bboxes = flat_locs[nms_mask].cpu() 212 | probs = flat_probs[nms_mask].cpu() 213 | class_indexes = class_indexes[nms_mask].cpu() 214 | if bboxes.size(0) > 0: 215 | print(bboxes, class_indexes, probs) 216 | 217 | 218 | Gst.init() 219 | pipeline = Gst.parse_launch(f''' 220 | filesrc location=media/in.mp4 num-buffers=256 ! 221 | decodebin ! 222 | nvvideoconvert ! 223 | video/x-raw(memory:NVMM),format={frame_format} ! 
224 | fakesink name=s 225 | ''') 226 | 227 | pipeline.get_by_name('s').get_static_pad('sink').add_probe( 228 | Gst.PadProbeType.BUFFER, 229 | on_frame_probe 230 | ) 231 | 232 | pipeline.set_state(Gst.State.PLAYING) 233 | 234 | try: 235 | while True: 236 | msg = pipeline.get_bus().timed_pop_filtered( 237 | Gst.SECOND, 238 | Gst.MessageType.EOS | Gst.MessageType.ERROR 239 | ) 240 | if msg: 241 | text = msg.get_structure().to_string() if msg.get_structure() else '' 242 | msg_type = Gst.message_type_get_name(msg.type) 243 | print(f'{msg.src.name}: [{msg_type}] {text}') 244 | break 245 | finally: 246 | finish_time = time.time() 247 | open(f'logs/{os.path.splitext(sys.argv[0])[0]}.pipeline.dot', 'w').write( 248 | Gst.debug_bin_to_dot_data(pipeline, Gst.DebugGraphDetails.ALL) 249 | ) 250 | pipeline.set_state(Gst.State.NULL) 251 | print(f'FPS: {frames_processed / (finish_time - start_time):.2f}') 252 | -------------------------------------------------------------------------------- /tuning_fp16.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import math, time 3 | import itertools 4 | import contextlib 5 | import gi 6 | gi.require_version('Gst', '1.0') 7 | from gi.repository import Gst 8 | import numpy as np 9 | import torch, torchvision 10 | 11 | frame_format, pixel_bytes, model_precision = 'RGBA', 4, 'fp16' 12 | model_dtype = torch.float16 if model_precision == 'fp16' else torch.float32 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | detector = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math=model_precision).eval().to(device) 15 | ssd_utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd_processing_utils') 16 | detection_threshold = 0.4 17 | start_time, frames_processed = None, 0 18 | image_batch, batch_size = [], 8 19 | 20 | # context manager to help keep track of ranges of time, using NVTX 21 | @contextlib.contextmanager 22 | def nvtx_range(msg): 23 | depth = torch.cuda.nvtx.range_push(msg) 24 | try: 25 | yield depth 26 | finally: 27 | torch.cuda.nvtx.range_pop() 28 | 29 | 30 | def on_frame_probe(pad, info): 31 | global start_time, frames_processed 32 | start_time = start_time or time.time() 33 | 34 | global image_batch 35 | 36 | if not image_batch: 37 | torch.cuda.nvtx.range_push('batch') 38 | torch.cuda.nvtx.range_push('create_batch') 39 | 40 | buf = info.get_buffer() 41 | print(f'[{buf.pts / Gst.SECOND:6.2f}]') 42 | 43 | image_tensor = buffer_to_image_tensor(buf, pad.get_current_caps()) 44 | image_batch.append(image_tensor) 45 | 46 | if len(image_batch) < batch_size: 47 | return Gst.PadProbeReturn.OK 48 | 49 | torch.cuda.nvtx.range_pop() # create_batch 50 | 51 | image_batch = preprocess(torch.stack(image_batch)) 52 | frames_processed += image_batch.size(0) 53 | 54 | with torch.no_grad(): 55 | with nvtx_range('inference'): 56 | locs, labels = detector(image_batch) 57 | image_batch = [] 58 | postprocess(locs, labels) 59 | 60 | torch.cuda.nvtx.range_pop() # batch 61 | return Gst.PadProbeReturn.OK 62 | 63 | 64 | def buffer_to_image_tensor(buf, caps): 65 | with nvtx_range('buffer_to_image_tensor'): 66 | caps_structure = caps.get_structure(0) 67 | height, width = caps_structure.get_value('height'), caps_structure.get_value('width') 68 | 69 | is_mapped, map_info = buf.map(Gst.MapFlags.READ) 70 | if is_mapped: 71 | try: 72 | image_array = np.ndarray( 73 | (height, width, pixel_bytes), 74 | dtype=np.uint8, 75 | buffer=map_info.data 76 | ) 77 | return 
torch.from_numpy( 78 | image_array[:,:,:3].copy() # RGBA -> RGB, and extend lifetime beyond subsequent unmap 79 | ) 80 | finally: 81 | buf.unmap(map_info) 82 | 83 | 84 | def preprocess(image_batch): 85 | '300x300 centre crop, normalize, HWC -> CHW' 86 | with nvtx_range('preprocess'): 87 | batch_dim, image_height, image_width, image_depth = image_batch.size() 88 | copy_x, copy_y = min(300, image_width), min(300, image_height) 89 | 90 | dest_x_offset = max(0, (300 - image_width) // 2) 91 | source_x_offset = max(0, (image_width - 300) // 2) 92 | dest_y_offset = max(0, (300 - image_height) // 2) 93 | source_y_offset = max(0, (image_height - 300) // 2) 94 | 95 | input_batch = torch.zeros((batch_dim, 300, 300, 3), dtype=model_dtype, device=device) 96 | input_batch[:, dest_y_offset:dest_y_offset + copy_y, dest_x_offset:dest_x_offset + copy_x] = \ 97 | image_batch[:, source_y_offset:source_y_offset + copy_y, source_x_offset:source_x_offset + copy_x] 98 | 99 | return torch.einsum( 100 | 'bhwc -> bchw', 101 | normalize(input_batch / 255) 102 | ).contiguous() 103 | 104 | 105 | def normalize(input_tensor): 106 | 'Nvidia SSD300 code uses mean and std-dev of 128/256' 107 | return (2.0 * input_tensor) - 1.0 108 | 109 | 110 | def init_dboxes(): 111 | 'adapted from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Detection/SSD/src/utils.py' 112 | fig_size = 300 113 | feat_size = [38, 19, 10, 5, 3, 1] 114 | steps = [8, 16, 32, 64, 100, 300] 115 | scales = [21, 45, 99, 153, 207, 261, 315] 116 | aspect_ratios = [[2], [2, 3], [2, 3], [2, 3], [2], [2]] 117 | 118 | fk = fig_size / torch.tensor(steps).float() 119 | 120 | dboxes = [] 121 | # size of feature and number of feature 122 | for idx, sfeat in enumerate(feat_size): 123 | sk1 = scales[idx] / fig_size 124 | sk2 = scales[idx + 1] / fig_size 125 | sk3 = math.sqrt(sk1 * sk2) 126 | all_sizes = [(sk1, sk1), (sk3, sk3)] 127 | 128 | for alpha in aspect_ratios[idx]: 129 | w, h = sk1 * math.sqrt(alpha), sk1 / math.sqrt(alpha) 130 | all_sizes.append((w, h)) 131 | all_sizes.append((h, w)) 132 | 133 | for w, h in all_sizes: 134 | for i, j in itertools.product(range(sfeat), repeat=2): 135 | cx, cy = (j + 0.5) / fk[idx], (i + 0.5) / fk[idx] 136 | dboxes.append((cx, cy, w, h)) 137 | 138 | return torch.tensor( 139 | dboxes, 140 | dtype=model_dtype, 141 | device=device 142 | ).clamp(0, 1) 143 | 144 | 145 | dboxes_xywh = init_dboxes().unsqueeze(dim=0) 146 | scale_xy = 0.1 147 | scale_wh = 0.2 148 | 149 | 150 | def xywh_to_xyxy(bboxes_batch, scores_batch): 151 | bboxes_batch = bboxes_batch.permute(0, 2, 1) 152 | scores_batch = scores_batch.permute(0, 2, 1) 153 | 154 | bboxes_batch[:, :, :2] = scale_xy * bboxes_batch[:, :, :2] 155 | bboxes_batch[:, :, 2:] = scale_wh * bboxes_batch[:, :, 2:] 156 | 157 | bboxes_batch[:, :, :2] = bboxes_batch[:, :, :2] * dboxes_xywh[:, :, 2:] + dboxes_xywh[:, :, :2] 158 | bboxes_batch[:, :, 2:] = bboxes_batch[:, :, 2:].exp() * dboxes_xywh[:, :, 2:] 159 | 160 | # transform format to ltrb 161 | l, t, r, b = bboxes_batch[:, :, 0] - 0.5 * bboxes_batch[:, :, 2],\ 162 | bboxes_batch[:, :, 1] - 0.5 * bboxes_batch[:, :, 3],\ 163 | bboxes_batch[:, :, 0] + 0.5 * bboxes_batch[:, :, 2],\ 164 | bboxes_batch[:, :, 1] + 0.5 * bboxes_batch[:, :, 3] 165 | 166 | bboxes_batch[:, :, 0] = l 167 | bboxes_batch[:, :, 1] = t 168 | bboxes_batch[:, :, 2] = r 169 | bboxes_batch[:, :, 3] = b 170 | 171 | return bboxes_batch, torch.nn.functional.softmax(scores_batch, dim=-1) 172 | 173 | 174 | def postprocess(locs, labels): 175 | with 
nvtx_range('postprocess'): 176 | locs, probs = xywh_to_xyxy(locs, labels) 177 | 178 | # flatten batch and classes 179 | batch_dim, box_dim, class_dim = probs.size() 180 | flat_locs = locs.reshape(-1, 4).repeat_interleave(class_dim, dim=0) 181 | flat_probs = probs.view(-1) 182 | class_indexes = torch.arange(class_dim, device=device).repeat(batch_dim * box_dim) 183 | image_indexes = (torch.ones(box_dim * class_dim, device=device) * torch.arange(1, batch_dim + 1, device=device).unsqueeze(-1)).view(-1) 184 | 185 | # only do NMS on detections over threshold, and ignore background (0) 186 | threshold_mask = (flat_probs > detection_threshold) & (class_indexes > 0) 187 | flat_locs = flat_locs[threshold_mask] 188 | flat_probs = flat_probs[threshold_mask] 189 | class_indexes = class_indexes[threshold_mask] 190 | image_indexes = image_indexes[threshold_mask] 191 | 192 | nms_mask = torchvision.ops.boxes.batched_nms( 193 | flat_locs, 194 | flat_probs, 195 | image_indexes * class_dim + class_indexes, # unique group id per (image, class) pair; a plain index product can collide across pairs 196 | iou_threshold=0.7 197 | ) 198 | 199 | bboxes = flat_locs[nms_mask].cpu() 200 | probs = flat_probs[nms_mask].cpu() 201 | class_indexes = class_indexes[nms_mask].cpu() 202 | if bboxes.size(0) > 0: 203 | print(bboxes, class_indexes, probs) 204 | 205 | 206 | Gst.init() 207 | pipeline = Gst.parse_launch(f''' 208 | filesrc location=media/in.mp4 num-buffers=256 ! 209 | decodebin ! 210 | nvvideoconvert ! 211 | video/x-raw,format={frame_format} ! 212 | fakesink name=s 213 | ''') 214 | 215 | pipeline.get_by_name('s').get_static_pad('sink').add_probe( 216 | Gst.PadProbeType.BUFFER, 217 | on_frame_probe 218 | ) 219 | 220 | pipeline.set_state(Gst.State.PLAYING) 221 | 222 | try: 223 | while True: 224 | msg = pipeline.get_bus().timed_pop_filtered( 225 | Gst.SECOND, 226 | Gst.MessageType.EOS | Gst.MessageType.ERROR 227 | ) 228 | if msg: 229 | text = msg.get_structure().to_string() if msg.get_structure() else '' 230 | msg_type = Gst.message_type_get_name(msg.type) 231 | print(f'{msg.src.name}: [{msg_type}] {text}') 232 | break 233 | finally: 234 | finish_time = time.time() 235 | open(f'logs/{os.path.splitext(sys.argv[0])[0]}.pipeline.dot', 'w').write( 236 | Gst.debug_bin_to_dot_data(pipeline, Gst.DebugGraphDetails.ALL) 237 | ) 238 | pipeline.set_state(Gst.State.NULL) 239 | print(f'FPS: {frames_processed / (finish_time - start_time):.2f}') 240 | -------------------------------------------------------------------------------- /tuning_postprocess_1.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import math, time 3 | import contextlib 4 | import gi 5 | gi.require_version('Gst', '1.0') 6 | from gi.repository import Gst 7 | import numpy as np 8 | import torch, torchvision 9 | 10 | frame_format, pixel_bytes, model_precision = 'RGBA', 4, 'fp32' 11 | model_dtype = torch.float16 if model_precision == 'fp16' else torch.float32 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | detector = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math=model_precision).eval().to(device) 14 | ssd_utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd_processing_utils') 15 | detection_threshold = 0.4 16 | start_time, frames_processed = None, 0 17 | 18 | # context manager to help keep track of ranges of time, using NVTX 19 | @contextlib.contextmanager 20 | def nvtx_range(msg): 21 | depth = torch.cuda.nvtx.range_push(msg) 22 | try: 23 | yield depth 24 | finally: 25 | torch.cuda.nvtx.range_pop() 26 | 27 | 28 | def
28 | def on_frame_probe(pad, info):
29 |     global start_time, frames_processed
30 |     start_time = start_time or time.time()
31 |
32 |     with nvtx_range('on_frame_probe'):
33 |         buf = info.get_buffer()
34 |         print(f'[{buf.pts / Gst.SECOND:6.2f}]')
35 |
36 |         image_tensor = buffer_to_image_tensor(buf, pad.get_current_caps())
37 |         image_batch = preprocess(image_tensor.unsqueeze(0))
38 |         frames_processed += image_batch.size(0)
39 |
40 |         with torch.no_grad():
41 |             with nvtx_range('inference'):
42 |                 locs, labels = detector(image_batch)
43 |             postprocess(locs, labels)
44 |
45 |     return Gst.PadProbeReturn.OK
46 |
47 |
48 | def buffer_to_image_tensor(buf, caps):
49 |     with nvtx_range('buffer_to_image_tensor'):
50 |         caps_structure = caps.get_structure(0)
51 |         height, width = caps_structure.get_value('height'), caps_structure.get_value('width')
52 |
53 |         is_mapped, map_info = buf.map(Gst.MapFlags.READ)
54 |         if is_mapped:
55 |             try:
56 |                 image_array = np.ndarray(
57 |                     (height, width, pixel_bytes),
58 |                     dtype=np.uint8,
59 |                     buffer=map_info.data
60 |                 )
61 |                 return torch.from_numpy(
62 |                     image_array[:,:,:3].copy() # RGBA -> RGB, and extend lifetime beyond subsequent unmap
63 |                 )
64 |             finally:
65 |                 buf.unmap(map_info)
66 |
67 |
68 | def preprocess(image_batch):
69 |     '300x300 centre crop, normalize, HWC -> CHW'
70 |     with nvtx_range('preprocess'):
71 |         batch_dim, image_height, image_width, image_depth = image_batch.size()
72 |         copy_x, copy_y = min(300, image_width), min(300, image_height)
73 |
74 |         dest_x_offset = max(0, (300 - image_width) // 2)
75 |         source_x_offset = max(0, (image_width - 300) // 2)
76 |         dest_y_offset = max(0, (300 - image_height) // 2)
77 |         source_y_offset = max(0, (image_height - 300) // 2)
78 |
79 |         input_batch = torch.zeros((batch_dim, 300, 300, 3), dtype=model_dtype, device=device)
80 |         input_batch[:, dest_y_offset:dest_y_offset + copy_y, dest_x_offset:dest_x_offset + copy_x] = \
81 |             image_batch[:, source_y_offset:source_y_offset + copy_y, source_x_offset:source_x_offset + copy_x]
82 |
83 |         return torch.einsum(
84 |             'bhwc -> bchw',
85 |             normalize(input_batch / 255)
86 |         ).contiguous()
87 |
88 |
89 | def normalize(input_tensor):
90 |     'Nvidia SSD300 code uses mean and std-dev of 128/256'
91 |     return (2.0 * input_tensor) - 1.0
92 |
93 |
94 | def postprocess(locs, labels):
95 |     with nvtx_range('postprocess'):
96 |         results_batch = ssd_utils.decode_results((locs.cpu(), labels.cpu()))
97 |         results_batch = [ssd_utils.pick_best(results, detection_threshold) for results in results_batch]
98 |         for bboxes, classes, scores in results_batch:
99 |             if scores.shape[0] > 0:
100 |                 print(bboxes, classes, scores)
101 |
102 |
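Note that this postprocess variant moves locs and labels to the CPU and decodes them with the torch.hub utility functions, host-side work that the next file (tuning_postprocess_2.py) replaces with batched GPU tensor ops. A rough way to measure a stage like this one, sketched here with a hypothetical timed helper; torch.cuda.synchronize makes queued GPU work count toward the measurement:

    import time
    import torch

    def timed(fn, *args):
        'hypothetical helper: rough wall-clock timing including queued GPU work'
        torch.cuda.synchronize()
        start = time.perf_counter()
        result = fn(*args)
        torch.cuda.synchronize()
        return result, time.perf_counter() - start

    # e.g. _, seconds = timed(postprocess, locs, labels)
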
103 | Gst.init()
104 | pipeline = Gst.parse_launch(f'''
105 |     filesrc location=media/in.mp4 num-buffers=256 !
106 |     decodebin !
107 |     nvvideoconvert !
108 |     video/x-raw,format={frame_format} !
109 |     fakesink name=s
110 | ''')
111 |
112 | pipeline.get_by_name('s').get_static_pad('sink').add_probe(
113 |     Gst.PadProbeType.BUFFER,
114 |     on_frame_probe
115 | )
116 |
117 | pipeline.set_state(Gst.State.PLAYING)
118 |
119 | try:
120 |     while True:
121 |         msg = pipeline.get_bus().timed_pop_filtered(
122 |             Gst.SECOND,
123 |             Gst.MessageType.EOS | Gst.MessageType.ERROR
124 |         )
125 |         if msg:
126 |             text = msg.get_structure().to_string() if msg.get_structure() else ''
127 |             msg_type = Gst.message_type_get_name(msg.type)
128 |             print(f'{msg.src.name}: [{msg_type}] {text}')
129 |             break
130 | finally:
131 |     finish_time = time.time()
132 |     open(f'logs/{os.path.splitext(sys.argv[0])[0]}.pipeline.dot', 'w').write(
133 |         Gst.debug_bin_to_dot_data(pipeline, Gst.DebugGraphDetails.ALL)
134 |     )
135 |     pipeline.set_state(Gst.State.NULL)
136 |     print(f'FPS: {frames_processed / (finish_time - start_time):.2f}')
137 |
--------------------------------------------------------------------------------
/tuning_postprocess_2.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import math, time
3 | import itertools
4 | import contextlib
5 | import gi
6 | gi.require_version('Gst', '1.0')
7 | from gi.repository import Gst
8 | import numpy as np
9 | import torch, torchvision
10 |
11 | frame_format, pixel_bytes, model_precision = 'RGBA', 4, 'fp32'
12 | model_dtype = torch.float16 if model_precision == 'fp16' else torch.float32
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 | detector = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math=model_precision).eval().to(device)
15 | ssd_utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd_processing_utils')
16 | detection_threshold = 0.4
17 | start_time, frames_processed = None, 0
18 |
19 | # context manager to help keep track of ranges of time, using NVTX
20 | @contextlib.contextmanager
21 | def nvtx_range(msg):
22 |     depth = torch.cuda.nvtx.range_push(msg)
23 |     try:
24 |         yield depth
25 |     finally:
26 |         torch.cuda.nvtx.range_pop()
27 |
28 |
29 | def on_frame_probe(pad, info):
30 |     global start_time, frames_processed
31 |     start_time = start_time or time.time()
32 |
33 |     with nvtx_range('on_frame_probe'):
34 |         buf = info.get_buffer()
35 |         print(f'[{buf.pts / Gst.SECOND:6.2f}]')
36 |
37 |         image_tensor = buffer_to_image_tensor(buf, pad.get_current_caps())
38 |         image_batch = preprocess(image_tensor.unsqueeze(0))
39 |         frames_processed += image_batch.size(0)
40 |
41 |         with torch.no_grad():
42 |             with nvtx_range('inference'):
43 |                 locs, labels = detector(image_batch)
44 |             postprocess(locs, labels)
45 |
46 |     return Gst.PadProbeReturn.OK
47 |
48 |
49 | def buffer_to_image_tensor(buf, caps):
50 |     with nvtx_range('buffer_to_image_tensor'):
51 |         caps_structure = caps.get_structure(0)
52 |         height, width = caps_structure.get_value('height'), caps_structure.get_value('width')
53 |
54 |         is_mapped, map_info = buf.map(Gst.MapFlags.READ)
55 |         if is_mapped:
56 |             try:
57 |                 image_array = np.ndarray(
58 |                     (height, width, pixel_bytes),
59 |                     dtype=np.uint8,
60 |                     buffer=map_info.data
61 |                 )
62 |                 return torch.from_numpy(
63 |                     image_array[:,:,:3].copy() # RGBA -> RGB, and extend lifetime beyond subsequent unmap
64 |                 )
65 |             finally:
66 |                 buf.unmap(map_info)
67 |
68 |
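preprocess below pads or centre-crops each frame to 300x300 and maps byte values into the model's expected range: normalize(x / 255) works out to 2x/255 - 1, taking 0 to -1.0 and 255 to 1.0. A quick check of that mapping:

    import torch

    pixels = torch.tensor([0.0, 128.0, 255.0])
    print((2.0 * (pixels / 255)) - 1.0)  # tensor([-1.0000,  0.0039,  1.0000])
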
69 | def preprocess(image_batch):
70 |     '300x300 centre crop, normalize, HWC -> CHW'
71 |     with nvtx_range('preprocess'):
72 |         batch_dim, image_height, image_width, image_depth = image_batch.size()
73 |         copy_x, copy_y = min(300, image_width), min(300, image_height)
74 |
75 |         dest_x_offset = max(0, (300 - image_width) // 2)
76 |         source_x_offset = max(0, (image_width - 300) // 2)
77 |         dest_y_offset = max(0, (300 - image_height) // 2)
78 |         source_y_offset = max(0, (image_height - 300) // 2)
79 |
80 |         input_batch = torch.zeros((batch_dim, 300, 300, 3), dtype=model_dtype, device=device)
81 |         input_batch[:, dest_y_offset:dest_y_offset + copy_y, dest_x_offset:dest_x_offset + copy_x] = \
82 |             image_batch[:, source_y_offset:source_y_offset + copy_y, source_x_offset:source_x_offset + copy_x]
83 |
84 |         return torch.einsum(
85 |             'bhwc -> bchw',
86 |             normalize(input_batch / 255)
87 |         ).contiguous()
88 |
89 |
90 | def normalize(input_tensor):
91 |     'Nvidia SSD300 code uses mean and std-dev of 128/256'
92 |     return (2.0 * input_tensor) - 1.0
93 |
94 |
95 | def init_dboxes():
96 |     'adapted from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Detection/SSD/src/utils.py'
97 |     fig_size = 300
98 |     feat_size = [38, 19, 10, 5, 3, 1]
99 |     steps = [8, 16, 32, 64, 100, 300]
100 |     scales = [21, 45, 99, 153, 207, 261, 315]
101 |     aspect_ratios = [[2], [2, 3], [2, 3], [2, 3], [2], [2]]
102 |
103 |     fk = fig_size / torch.tensor(steps).float()
104 |
105 |     dboxes = []
106 |     # for each feature map: default boxes at every cell, one pair per aspect ratio
107 |     for idx, sfeat in enumerate(feat_size):
108 |         sk1 = scales[idx] / fig_size
109 |         sk2 = scales[idx + 1] / fig_size
110 |         sk3 = math.sqrt(sk1 * sk2)
111 |         all_sizes = [(sk1, sk1), (sk3, sk3)]
112 |
113 |         for alpha in aspect_ratios[idx]:
114 |             w, h = sk1 * math.sqrt(alpha), sk1 / math.sqrt(alpha)
115 |             all_sizes.append((w, h))
116 |             all_sizes.append((h, w))
117 |
118 |         for w, h in all_sizes:
119 |             for i, j in itertools.product(range(sfeat), repeat=2):
120 |                 cx, cy = (j + 0.5) / fk[idx], (i + 0.5) / fk[idx]
121 |                 dboxes.append((cx, cy, w, h))
122 |
123 |     return torch.tensor(
124 |         dboxes,
125 |         dtype=model_dtype,
126 |         device=device
127 |     ).clamp(0, 1)
128 |
129 |
130 | dboxes_xywh = init_dboxes().unsqueeze(dim=0)
131 | scale_xy = 0.1
132 | scale_wh = 0.2
133 |
134 |
135 | def xywh_to_xyxy(bboxes_batch, scores_batch):
136 |     bboxes_batch = bboxes_batch.permute(0, 2, 1)
137 |     scores_batch = scores_batch.permute(0, 2, 1)
138 |
139 |     bboxes_batch[:, :, :2] = scale_xy * bboxes_batch[:, :, :2]
140 |     bboxes_batch[:, :, 2:] = scale_wh * bboxes_batch[:, :, 2:]
141 |
142 |     bboxes_batch[:, :, :2] = bboxes_batch[:, :, :2] * dboxes_xywh[:, :, 2:] + dboxes_xywh[:, :, :2]
143 |     bboxes_batch[:, :, 2:] = bboxes_batch[:, :, 2:].exp() * dboxes_xywh[:, :, 2:]
144 |
145 |     # transform format to ltrb
146 |     l, t, r, b = bboxes_batch[:, :, 0] - 0.5 * bboxes_batch[:, :, 2],\
147 |                  bboxes_batch[:, :, 1] - 0.5 * bboxes_batch[:, :, 3],\
148 |                  bboxes_batch[:, :, 0] + 0.5 * bboxes_batch[:, :, 2],\
149 |                  bboxes_batch[:, :, 1] + 0.5 * bboxes_batch[:, :, 3]
150 |
151 |     bboxes_batch[:, :, 0] = l
152 |     bboxes_batch[:, :, 1] = t
153 |     bboxes_batch[:, :, 2] = r
154 |     bboxes_batch[:, :, 3] = b
155 |
156 |     return bboxes_batch, torch.nn.functional.softmax(scores_batch, dim=-1)
157 |
158 |
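postprocess below runs decode, thresholding and NMS entirely on the GPU. torchvision's batched_nms only suppresses overlaps between boxes that share the same idxs value, which is why each (image, class) pair is given its own group id. A toy demonstration of that grouping behaviour, with made-up boxes:

    import torch, torchvision

    boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 10., 10.]])  # IoU = 0.81
    scores = torch.tensor([0.9, 0.8])

    same_group = torchvision.ops.boxes.batched_nms(boxes, scores, torch.tensor([1, 1]), iou_threshold=0.5)
    diff_group = torchvision.ops.boxes.batched_nms(boxes, scores, torch.tensor([1, 2]), iou_threshold=0.5)
    print(same_group)  # tensor([0]): within one group, the lower-scoring overlap is suppressed
    print(diff_group)  # tensor([0, 1]): boxes in different groups never suppress each other

The returned tensor holds the indices of the kept boxes, sorted by decreasing score.
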
159 | def postprocess(locs, labels):
160 |     with nvtx_range('postprocess'):
161 |         locs, probs = xywh_to_xyxy(locs, labels)
162 |
163 |         # flatten batch and classes
164 |         batch_dim, box_dim, class_dim = probs.size()
165 |         flat_locs = locs.reshape(-1, 4).repeat_interleave(class_dim, dim=0)
166 |         flat_probs = probs.view(-1)
167 |         class_indexes = torch.arange(class_dim, device=device).repeat(batch_dim * box_dim)
168 |         image_indexes = (torch.ones(box_dim * class_dim, device=device) * torch.arange(1, batch_dim + 1, device=device).unsqueeze(-1)).view(-1)
169 |
170 |         # only do NMS on detections over threshold, and ignore background (0)
171 |         threshold_mask = (flat_probs > detection_threshold) & (class_indexes > 0)
172 |         flat_locs = flat_locs[threshold_mask]
173 |         flat_probs = flat_probs[threshold_mask]
174 |         class_indexes = class_indexes[threshold_mask]
175 |         image_indexes = image_indexes[threshold_mask]
176 |
177 |         nms_mask = torchvision.ops.boxes.batched_nms(
178 |             flat_locs,
179 |             flat_probs,
180 |             image_indexes * class_dim + class_indexes,  # unique group per (image, class); a bare product can collide across images when batch size > 1
181 |             iou_threshold=0.7
182 |         )
183 |
184 |         bboxes = flat_locs[nms_mask].cpu()
185 |         probs = flat_probs[nms_mask].cpu()
186 |         class_indexes = class_indexes[nms_mask].cpu()
187 |         if bboxes.size(0) > 0:
188 |             print(bboxes, class_indexes, probs)
189 |
190 |
191 | Gst.init()
192 | pipeline = Gst.parse_launch(f'''
193 |     filesrc location=media/in.mp4 num-buffers=256 !
194 |     decodebin !
195 |     nvvideoconvert !
196 |     video/x-raw,format={frame_format} !
197 |     fakesink name=s
198 | ''')
199 |
200 | pipeline.get_by_name('s').get_static_pad('sink').add_probe(
201 |     Gst.PadProbeType.BUFFER,
202 |     on_frame_probe
203 | )
204 |
205 | pipeline.set_state(Gst.State.PLAYING)
206 |
207 | try:
208 |     while True:
209 |         msg = pipeline.get_bus().timed_pop_filtered(
210 |             Gst.SECOND,
211 |             Gst.MessageType.EOS | Gst.MessageType.ERROR
212 |         )
213 |         if msg:
214 |             text = msg.get_structure().to_string() if msg.get_structure() else ''
215 |             msg_type = Gst.message_type_get_name(msg.type)
216 |             print(f'{msg.src.name}: [{msg_type}] {text}')
217 |             break
218 | finally:
219 |     finish_time = time.time()
220 |     open(f'logs/{os.path.splitext(sys.argv[0])[0]}.pipeline.dot', 'w').write(
221 |         Gst.debug_bin_to_dot_data(pipeline, Gst.DebugGraphDetails.ALL)
222 |     )
223 |     pipeline.set_state(Gst.State.NULL)
224 |     print(f'FPS: {frames_processed / (finish_time - start_time):.2f}')
225 |
--------------------------------------------------------------------------------
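
Taken together, tuning_postprocess_1.py and tuning_postprocess_2.py isolate the cost of CPU-side detection decoding: both scripts print a final FPS figure, so the two variants can be compared directly. A hypothetical comparison harness along those lines (assumes both scripts run successfully, which requires the media file and a CUDA-capable environment):

    import subprocess

    # hypothetical harness: run each postprocess variant and pull out its final FPS line
    for script in ('tuning_postprocess_1.py', 'tuning_postprocess_2.py'):
        output = subprocess.run(['python', script], capture_output=True, text=True).stdout
        fps_lines = [line for line in output.splitlines() if line.startswith('FPS:')]
        print(script, fps_lines[-1] if fps_lines else '(no FPS line)')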