├── .gitignore ├── LICENSE ├── README.md ├── figs ├── coin.png ├── k400.png ├── main.png ├── ssv2.png ├── teaser.png └── zs_vt_retrieval.png └── single_modality ├── MODEL_ZOO.md ├── __pycache__ ├── functional.cpython-39.pyc ├── optim_factory.cpython-39.pyc └── utils.cpython-39.pyc ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── build.cpython-39.pyc │ ├── hmdb.cpython-39.pyc │ ├── kinetics.cpython-39.pyc │ ├── kinetics_sparse.cpython-39.pyc │ ├── kinetics_sparse_o4a.cpython-39.pyc │ ├── mae.cpython-39.pyc │ ├── mae_multi.cpython-39.pyc │ ├── mae_multi_ofa.cpython-39.pyc │ ├── mae_ofa.cpython-39.pyc │ ├── masking_generator.cpython-39.pyc │ ├── mixup.cpython-39.pyc │ ├── rand_augment.cpython-39.pyc │ ├── random_erasing.cpython-39.pyc │ ├── ssv2.cpython-39.pyc │ ├── ssv2_ofa.cpython-39.pyc │ ├── transforms.cpython-39.pyc │ ├── video_transforms.cpython-39.pyc │ └── volume_transforms.cpython-39.pyc ├── build.py ├── hmdb.py ├── kinetics.py ├── kinetics_sparse.py ├── kinetics_sparse_o4a.py ├── mae.py ├── mae_multi.py ├── mae_multi_ofa.py ├── mae_ofa.py ├── masking_generator.py ├── mixup.py ├── multiloader.py ├── rand_augment.py ├── random_erasing.py ├── ssv2.py ├── ssv2_ofa.py ├── transforms.py ├── video_transforms.py └── volume_transforms.py ├── engines ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ └── engine_for_flexible_tuning.cpython-39.pyc ├── engine_for_flexible_tuning.py └── engine_for_flexible_tuning_single.py ├── exp ├── base │ ├── eval │ │ └── k400_eval.sh │ └── ft │ │ └── k400 │ │ ├── k400_multi.sh │ │ └── k400_single.sh └── small │ └── eval │ └── k400_eval.sh ├── functional.py ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── flash_attention_class.cpython-39.pyc │ ├── fluxvit.cpython-39.pyc │ ├── pos_embed.cpython-39.pyc │ └── vid_tldr.cpython-39.pyc ├── flash_attention_class.py ├── fluxvit.py ├── pos_embed.py └── vid_tldr.py ├── optim_factory.py ├── run_flexible_finetune.py ├── run_flexible_finetune_single.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | multi_modality/* 2 | *.pt 3 | single_modality/rename* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 OpenGVLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Make Your Training Flexible: Towards Deployment-Efficient Video Models 2 | 3 | This repo is the official implementation of "[Make Your Training Flexible: Towards Deployment-Efficient Video Models](https://arxiv.org/abs/2503.14237)". By Chenting Wang, Kunchang Li, Tianxiang Jiang, Xiangyu Zeng, Yi Wang, and Limin Wang. 4 | 5 | teaser 6 | 7 | ## Update 8 | 9 | - **2025/04/29**: Released the single-modality evaluation script and part of the model weights; see [MODEL_ZOO](./single_modality/MODEL_ZOO.md). 10 | - **2025/03/18**: We built the repo and released the [paper](https://arxiv.org/abs/2503.14237). 11 | 12 | ## Introduction 13 | 14 | Popular video training methods mainly operate on a fixed number of tokens sampled from a predetermined spatiotemporal grid, resulting in sub-optimal accuracy-computation trade-offs due to inherent video redundancy. They also lack adaptability to varying computational budgets for downstream tasks, hindering the application of the most competitive models in real-world scenarios. We thus propose a new test setting, Token Optimization, which maximizes input information under a given budget by optimizing the size-limited set of input tokens through token selection from more suitably sampled videos. To this end, we propose a novel augmentation tool termed Flux. By making the sampling grid flexible and leveraging token selection, Flux is easily adopted in most popular video training frameworks, boosting model robustness at nearly no additional cost. We integrate Flux into large-scale video pre-training, and the resulting FluxViT establishes new state-of-the-art results across extensive tasks at standard costs. Notably, with only 1/4 of the tokens, it can still match the performance of previous state-of-the-art models under Token Optimization, yielding nearly 90\% savings in compute.
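To make Token Optimization concrete, the sketch below shows the core idea only: tokenize a more densely sampled clip than the token budget allows, score the patch tokens, and keep the top-scoring subset. The scoring rule (a plain L2 norm), the helper name, and all shapes here are illustrative assumptions for exposition, not FluxViT's actual selector (see `single_modality/models/fluxvit.py` and `single_modality/models/vid_tldr.py` for the real components).

```python
import torch

def select_tokens(tokens: torch.Tensor, budget: int) -> torch.Tensor:
    """Toy token selector: keep the `budget` highest-scoring tokens.

    tokens: (B, N, C) patch tokens from a densely sampled clip.
    The L2-norm score is only a placeholder; FluxViT learns its selection.
    """
    scores = tokens.norm(dim=-1)                    # (B, N)
    keep = scores.topk(budget, dim=1).indices       # (B, budget)
    keep = keep.sort(dim=1).values                  # preserve spatiotemporal order
    batch_idx = torch.arange(tokens.size(0)).unsqueeze(1)
    return tokens[batch_idx, keep]                  # (B, budget, C)

# Example: a 16-frame clip tokenized into 16*14*14 = 3136 tokens, reduced to the
# budget of a fixed 8-frame grid (8*14*14 = 1568 tokens).
dense = torch.randn(2, 16 * 14 * 14, 768)
compact = select_tokens(dense, budget=8 * 14 * 14)
print(compact.shape)  # torch.Size([2, 1568, 768])
```

As one concrete reading of the "nearly 90% savings" claim, the K400 table below shows FluxViT-B with Token Optimization reaching 87.4 top-1 at 49×12 GFLOPs, matching InternVideo2-B at 440×12 GFLOPs, i.e., roughly 1 − 49/440 ≈ 89% less compute.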
15 | 16 | Sampling Strategy 17 | 18 | ## Performance 19 | 20 | ### Single Modality Action Recognition 21 | 22 | #### K400 23 | 24 | - Base Scale 25 | 26 | | **Model** | **GFLOPs** | **Top-1** | **Top-1 + TO** | 27 | |------------------|------------------|-----------|----------------| 28 | | InternVideo2-B | 440×12 | 87.4 | - | 29 | | FluxViT-B | 440×12 | 89.6 | 90.0 | 30 | | FluxViT-B | 255×12 | 89.3 | 89.7 | 31 | | FluxViT-B | 108×12 | 87.3 | 88.9 | 32 | | FluxViT-B | 49×12 | 84.7 | 87.4 | 33 | 34 | - Small Scale 35 | 36 | | **Model** | **GFLOPs** | **Top-1** | **Top-1 + TO** | 37 | |------------------|------------------|-----------|----------------| 38 | | InternVideo2-S | 154×12 | 85.8 | - | 39 | | FluxViT-S | 154×12 | 87.7 | 88.0 | 40 | | FluxViT-S | 83×12 | 87.3 | 87.7 | 41 | | FluxViT-S | 32×12 | 84.7 | 86.6 | 42 | | FluxViT-S | 13×12 | 80.1 | 84.7 | 43 | 44 | #### SSv2 45 | 46 | - Base Scale 47 | 48 | | **Model** | **GFLOPs** | **Top-1** | **Top-1 + TO** | 49 | |------------------|-----------------|-----------|----------------| 50 | | InternVideo2-B | 253×6 | 73.5 | - | 51 | | FluxViT-B | 440×6 | 75.3 | 75.6 | 52 | | FluxViT-B | 255×6 | 75.1 | 75.7 | 53 | | FluxViT-B | 108×6 | 72.0 | 75.1 | 54 | | FluxViT-B | 49×6 | 56.8 | 73.9 | 55 | 56 | - Small Scale 57 | 58 | | **Model** | **GFLOPs** | **Top-1** | **Top-1 + TO** | 59 | |------------------|-----------------|-----------|----------------| 60 | | InternVideo2-S | 154×6 | 71.5 | - | 61 | | FluxViT-S | 154×6 | 73.4 | 73.8 | 62 | | FluxViT-S | 83×6 | 72.9 | 73.4 | 63 | | FluxViT-S | 32×6 | 70.0 | 72.5 | 64 | | FluxViT-S | 13×6 | 55.3 | 70.9 | 65 | 66 | ### Multi Modality Video-Text Retrieval 67 | 68 | zs_vt_retrieval 69 | 70 | ### Multi Modality VideoChat Model 71 | 72 | Coming Soon 73 | 74 | ## Acknowledgement 75 | 76 | This repository is built upon the [UniFormer](https://github.com/Sense-X/UniFormer), [VideoMAE](https://github.com/MCG-NJU/VideoMAE), [VINDLU](https://github.com/klauscc/VindLU), and [Unmasked Teacher](https://github.com/OpenGVLab/unmasked_teacher/) repositories.
77 | -------------------------------------------------------------------------------- /figs/coin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/figs/coin.png -------------------------------------------------------------------------------- /figs/k400.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/figs/k400.png -------------------------------------------------------------------------------- /figs/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/figs/main.png -------------------------------------------------------------------------------- /figs/ssv2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/figs/ssv2.png -------------------------------------------------------------------------------- /figs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/figs/teaser.png -------------------------------------------------------------------------------- /figs/zs_vt_retrieval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/figs/zs_vt_retrieval.png -------------------------------------------------------------------------------- /single_modality/MODEL_ZOO.md: -------------------------------------------------------------------------------- 1 | # Model Zoo 2 | 3 | ## Note 4 | 5 | - For all pretraining and finetuning, we adopt sparse/uniform sampling.
6 | - `#Frame` $=$ `#input_frame` $\times$ `#crop` $\times$ `#clip` 7 | - `#input_frame` is the number of frames fed to the model per inference 8 | - `#crop` means spatial crops (e.g., 3 for left/right/center) 9 | - `#clip` means temporal clips (e.g., 4 means repeatedly sampling four clips with different start indices); see the multi-view evaluation sketch below 10 | 11 | ## Pretraining 12 | 13 | TBD 14 | 15 | ## Distillation 16 | 17 | TBD 18 | 19 | ## Finetuning 20 | 21 | ### K710 22 | 23 | TBD 24 | 25 | 26 | ### K400 27 | 28 | | Model | Setting | #Frame | Top-1 | Weights | Shell | 29 | | -------- | ------------- | -------- | ------ | ------ | ------ | 30 | | $\text{InternVideo2}_{s1}$-1B | K-Mash PT + K710 FT | 8x3x4 | 91.3 | [:hugs: HF link](https://huggingface.co/OpenGVLab/InternVideo2-Stage1-1B-224p-f8-K400/blob/main/1B_ft_k710_ft_k400_f8.pth) | TBD | 31 | | $\text{InternVideo2}_{s1}$-1B | K-Mash PT + K710 FT | 16x3x4 | 91.6 | [:hugs: HF link](https://huggingface.co/OpenGVLab/InternVideo2-Stage1-1B-224p-f8-K400/blob/main/1B_ft_k710_ft_k400_f16.pth) | TBD | 32 | | $\text{InternVideo2}_{s1}$-6B | K-Mash PT + K710 FT | 8x3x4 | 91.9 | TBD | TBD | 33 | | $\text{InternVideo2}_{s1}$-6B | K-Mash PT + K710 FT | 16x3x4 | 92.1 | TBD | TBD | 34 | | $\text{InternVideo2}_{dist}$-S/14 | K-Mash PT + K710 FT | 8x3x4 | 85.4 | [:hugs: HF link](https://huggingface.co/OpenGVLab/InternVideo2_distillation_models/resolve/main/stage1/S14/S14_ft_k710_ft_k400_f8/pytorch_model.bin) | TBD | 35 | | $\text{InternVideo2}_{dist}$-B/14 | K-Mash PT + K710 FT | 8x3x4 | 88.4 | [:hugs: HF link](https://huggingface.co/OpenGVLab/InternVideo2_distillation_models/resolve/main/stage1/B14/B14_ft_k710_ft_k400_f8/pytorch_model.bin) | TBD | 36 | | $\text{InternVideo2}_{dist}$-L/14 | K-Mash PT + K710 FT | 8x3x4 | 90.4 | [:hugs: HF link](https://huggingface.co/OpenGVLab/InternVideo2_distillation_models/resolve/main/stage1/L14/L14_ft_k710_ft_k400_f8/pytorch_model.bin) | TBD | 37 | | $\text{FluxViT}$-S/14 | K-Mash PT + K710 FT | 8x3x4 | 87.3 | [Link](https://drive.google.com/file/d/1OTjTsAnZGaq7AufDaw8IYLeSgmLYZjds/view?usp=sharing) | [run.sh](./exp/small/eval/k400_eval.sh) | 38 | | $\text{FluxViT}$-B/14 | K-Mash PT + K710 FT | 8x3x4 | 89.3 | [Link](https://drive.google.com/file/d/1YsxsB3_pkpdvmXQIhD3YOmWlqCyzJskg/view?usp=sharing) | [run.sh](./exp/base/eval/k400_eval.sh) | 39 | 40 | 41 | 42 | ### SthSth V2 43 | 44 | TBD 45 | -------------------------------------------------------------------------------- /single_modality/__pycache__/functional.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/__pycache__/functional.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/__pycache__/optim_factory.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/__pycache__/optim_factory.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/__pycache__/utils.cpython-39.pyc --------------------------------------------------------------------------------
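To make the `#Frame` convention in `MODEL_ZOO.md` above concrete: an entry like `8x3x4` means each view feeds 8 frames to the model, and the reported Top-1 averages predictions over 3 spatial crops × 4 temporal clips = 12 views per video (the ×12 and ×6 GFLOPs multipliers in the README tables are this crops × clips count). The snippet below is a minimal multi-view aggregation sketch using a stand-in model and random tensors; it is not the repo's actual evaluation loop (see `single_modality/engines/engine_for_flexible_tuning.py` for that).

```python
import torch

@torch.no_grad()
def multi_view_top1(model, views, labels):
    """Average logits over all spatial-crop x temporal-clip views of each video.

    views:  (num_views, B, C, T, H, W), e.g. num_views = 3 crops * 4 clips = 12
    labels: (B,)
    """
    logits = torch.stack([model(v) for v in views]).mean(dim=0)   # (B, num_classes)
    return (logits.argmax(dim=-1) == labels).float().mean().item()

# Toy check: 12 views of 8-frame 224x224 clips, 400 classes (K400), batch of 2.
toy_model = lambda x: torch.randn(x.size(0), 400)   # stands in for FluxViT + cls head
views = torch.randn(12, 2, 3, 8, 224, 224)
labels = torch.randint(0, 400, (2,))
print(multi_view_top1(toy_model, views, labels))
```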
/single_modality/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_dataset, build_pretraining_dataset, build_multi_pretraining_dataset, build_pretraining_dataset_ofa -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/build.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/build.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/hmdb.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/hmdb.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/kinetics.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/kinetics.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/kinetics_sparse.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/kinetics_sparse.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/kinetics_sparse_o4a.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/kinetics_sparse_o4a.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/mae.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/mae.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/mae_multi.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/mae_multi.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/mae_multi_ofa.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/mae_multi_ofa.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/mae_ofa.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/mae_ofa.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/masking_generator.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/masking_generator.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/mixup.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/mixup.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/rand_augment.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/rand_augment.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/random_erasing.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/random_erasing.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/ssv2.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/ssv2.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/ssv2_ofa.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/ssv2_ofa.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/transforms.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/transforms.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/video_transforms.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/video_transforms.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/__pycache__/volume_transforms.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/volume_transforms.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/datasets/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torchvision import transforms 3 | from .transforms import * 4 | from .masking_generator import TubeMaskingGenerator, RandomMaskingGenerator 5 | from .mae import VideoMAE 6 | from .masking_generator import RandomMaskingGenerator, TemporalConsistencyMaskingGenerator, TemporalProgressiveMaskingGenerator, TemporalCenteringProgressiveMaskingGenerator 7 | from .mae_multi import VideoMAE_multi 8 | from .mae_multi_ofa import VideoMAE_multi_ofa 9 | from .kinetics import VideoClsDataset 10 | from .kinetics_sparse import VideoClsDataset_sparse 11 | from .kinetics_sparse_o4a import VideoClsDataset_sparse_ofa 12 | # from .anet import ANetDataset 13 | from .ssv2 import SSVideoClsDataset, SSRawFrameClsDataset 14 | from .ssv2_ofa import SSRawFrameClsDataset_OFA, SSVideoClsDataset_OFA 15 | from .hmdb import HMDBVideoClsDataset, HMDBRawFrameClsDataset 16 | from .mae_ofa import VideoMAE_ofa 17 | 18 | 19 | class DataAugmentationForVideoMAE(object): 20 | def __init__(self, args): 21 | self.input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN 22 | self.input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD 23 | normalize = GroupNormalize(self.input_mean, self.input_std) 24 | self.train_augmentation = GroupMultiScaleCrop(args.input_size, [1, .875, .75, .66]) 25 | if args.color_jitter > 0: 26 | self.transform = transforms.Compose([ 27 | self.train_augmentation, 28 | GroupColorJitter(args.color_jitter), 29 | GroupRandomHorizontalFlip(flip=args.flip), 30 | Stack(roll=False), 31 | ToTorchFormatTensor(div=True), 32 | normalize, 33 | ]) 34 | else: 35 | self.transform = transforms.Compose([ 36 | self.train_augmentation, 37 | GroupRandomHorizontalFlip(flip=args.flip), 38 | Stack(roll=False), 39 | ToTorchFormatTensor(div=True), 40 | normalize, 41 | ]) 42 | if args.mask_type == 'tube': 43 | self.masked_position_generator = TubeMaskingGenerator( 44 | args.window_size, args.mask_ratio 45 | ) 46 | elif args.mask_type == 'random': 47 | self.masked_position_generator = RandomMaskingGenerator( 48 | args.window_size, args.mask_ratio 49 | ) 50 | elif args.mask_type == 't_consist': 51 | self.masked_position_generator = TemporalConsistencyMaskingGenerator( 52 | args.window_size, args.student_mask_ratio, args.teacher_mask_ratio 53 | ) 54 | elif args.mask_type == 't_progressive': 55 | self.masked_position_generator = TemporalProgressiveMaskingGenerator( 56 | args.window_size, args.student_mask_ratio 57 | ) 58 | elif args.mask_type == 't_center_prog': 59 | self.masked_position_generator = TemporalCenteringProgressiveMaskingGenerator( 60 | args.window_size, args.student_mask_ratio 61 | ) 62 | elif args.mask_type in 'attention': 63 | self.masked_position_generator = None 64 | 65 | def __call__(self, images): 66 | process_data, _ = self.transform(images) 67 | if 
self.masked_position_generator is None: 68 | return process_data, -1 69 | else: 70 | return process_data, self.masked_position_generator() 71 | 72 | def __repr__(self): 73 | repr = "(DataAugmentationForVideoMAE,\n" 74 | repr += " transform = %s,\n" % str(self.transform) 75 | repr += " Masked position generator = %s,\n" % str(self.masked_position_generator) 76 | repr += ")" 77 | return repr 78 | 79 | 80 | def build_pretraining_dataset(args): 81 | transform = DataAugmentationForVideoMAE(args) 82 | dataset = VideoMAE( 83 | root=None, 84 | setting=args.data_path, 85 | prefix=args.prefix, 86 | split=args.split, 87 | video_ext='mp4', 88 | is_color=True, 89 | modality='rgb', 90 | num_segments=args.num_segments, 91 | new_length=args.num_frames, 92 | new_step=args.sampling_rate, 93 | transform=transform, 94 | temporal_jitter=False, 95 | video_loader=True, 96 | use_decord=args.use_decord, 97 | lazy_init=False, 98 | num_sample=args.num_sample) 99 | print("Data Aug = %s" % str(transform)) 100 | return dataset 101 | 102 | def build_pretraining_dataset_ofa(args): 103 | transform = DataAugmentationForVideoMAE(args) 104 | dataset = VideoMAE_ofa( 105 | root=None, 106 | setting=args.data_path, 107 | prefix=args.prefix, 108 | split=args.split, 109 | video_ext='mp4', 110 | is_color=True, 111 | modality='rgb', 112 | num_segments=args.num_segments, 113 | new_length=args.num_frames, 114 | new_step=args.sampling_rate, 115 | transform=transform, 116 | temporal_jitter=False, 117 | video_loader=True, 118 | use_decord=args.use_decord, 119 | lazy_init=False, 120 | num_sample=args.num_sample) 121 | print("Data Aug = %s" % str(transform)) 122 | return dataset 123 | 124 | 125 | def build_once4all_pretraining_dataset(args, num_datasets): 126 | datasets = [] 127 | for i in range(num_datasets): 128 | args.input_size = ... 
129 | datasets.append(build_pretraining_dataset(args)) 130 | return datasets 131 | 132 | 133 | def build_dataset(is_train, test_mode, args): 134 | print(f'Use Dataset: {args.data_set}') 135 | if args.data_set in [ 136 | 'Kinetics', 137 | 'Kinetics_sparse', 138 | 'Kinetics_sparse_ofa', 139 | 'mitv1_sparse' 140 | ]: 141 | mode = None 142 | anno_path = None 143 | if is_train is True: 144 | mode = 'train' 145 | anno_path = os.path.join(args.data_path, 'train.csv') 146 | elif test_mode is True: 147 | mode = 'test' 148 | anno_path = os.path.join(args.data_path, 'test.csv') 149 | else: 150 | mode = 'validation' 151 | anno_path = os.path.join(args.data_path, 'val.csv') 152 | 153 | if 'sparse_ofa' in args.data_set: 154 | func = VideoClsDataset_sparse_ofa 155 | elif 'sparse' in args.data_set: 156 | func = VideoClsDataset_sparse 157 | else: 158 | func = VideoClsDataset 159 | 160 | dataset = func( 161 | anno_path=anno_path, 162 | prefix=args.prefix, 163 | split=args.split, 164 | mode=mode, 165 | clip_len=args.num_frames if is_train else args.eval_true_frame, 166 | frame_sample_rate=args.sampling_rate, 167 | num_segment=1, 168 | test_num_segment=args.test_num_segment, 169 | test_num_crop=args.test_num_crop, 170 | num_crop=1 if not test_mode else 3, 171 | keep_aspect_ratio=True, 172 | crop_size=args.input_size if is_train else args.eval_input_size, 173 | short_side_size=args.short_side_size if is_train else args.eval_short_side_size, 174 | new_height=256, 175 | new_width=320, 176 | args=args) 177 | 178 | nb_classes = args.nb_classes 179 | 180 | elif 'SSV2' in args.data_set: 181 | mode = None 182 | anno_path = None 183 | if is_train is True: 184 | mode = 'train' 185 | anno_path = os.path.join(args.data_path, 'train.csv') 186 | elif test_mode is True: 187 | mode = 'test' 188 | anno_path = os.path.join(args.data_path, 'test.csv') 189 | else: 190 | mode = 'validation' 191 | anno_path = os.path.join(args.data_path, 'val.csv') 192 | 193 | if args.use_decord: 194 | if 'ofa' in args.data_set: 195 | func = SSVideoClsDataset_OFA 196 | else: 197 | func = SSVideoClsDataset 198 | else: 199 | if 'ofa' in args.data_set: 200 | func = SSRawFrameClsDataset_OFA 201 | else: 202 | func = SSRawFrameClsDataset 203 | 204 | dataset = func( 205 | anno_path=anno_path, 206 | prefix=args.prefix, 207 | split=args.split, 208 | mode=mode, 209 | clip_len=1, 210 | num_segment=args.num_frames if is_train else args.eval_true_frame, 211 | test_num_segment=args.test_num_segment, 212 | test_num_crop=args.test_num_crop, 213 | num_crop=1 if not test_mode else 3, 214 | keep_aspect_ratio=True, 215 | crop_size=args.input_size if is_train else args.eval_input_size, 216 | short_side_size=args.short_side_size if is_train else args.eval_short_side_size, 217 | new_height=256, 218 | new_width=320, 219 | filename_tmpl=args.filename_tmpl, 220 | args=args) 221 | nb_classes = 174 222 | 223 | elif args.data_set == 'UCF101': 224 | mode = None 225 | anno_path = None 226 | if is_train is True: 227 | mode = 'train' 228 | anno_path = os.path.join(args.data_path, 'train.csv') 229 | elif test_mode is True: 230 | mode = 'test' 231 | anno_path = os.path.join(args.data_path, 'test.csv') 232 | else: 233 | mode = 'validation' 234 | anno_path = os.path.join(args.data_path, 'val.csv') 235 | 236 | dataset = VideoClsDataset( 237 | anno_path=anno_path, 238 | prefix=args.prefix, 239 | split=args.split, 240 | mode=mode, 241 | clip_len=args.num_frames if is_train else args.eval_true_frame, 242 | frame_sample_rate=args.sampling_rate, 243 | num_segment=1, 244 | 
test_num_segment=args.test_num_segment, 245 | test_num_crop=args.test_num_crop, 246 | num_crop=1 if not test_mode else 3, 247 | keep_aspect_ratio=True, 248 | crop_size=args.input_size if is_train else args.eval_input_size, 249 | short_side_size=args.short_side_size if is_train else args.eval_short_side_size, 250 | new_height=256, 251 | new_width=320, 252 | args=args) 253 | nb_classes = 101 254 | 255 | elif args.data_set == 'HMDB51': 256 | mode = None 257 | anno_path = None 258 | if is_train is True: 259 | mode = 'train' 260 | anno_path = os.path.join(args.data_path, 'train.csv') 261 | elif test_mode is True: 262 | mode = 'test' 263 | anno_path = os.path.join(args.data_path, 'test.csv') 264 | else: 265 | mode = 'validation' 266 | anno_path = os.path.join(args.data_path, 'val.csv') 267 | 268 | if args.use_decord: 269 | func = HMDBVideoClsDataset 270 | else: 271 | func = HMDBRawFrameClsDataset 272 | 273 | dataset = func( 274 | anno_path=anno_path, 275 | prefix=args.prefix, 276 | split=args.split, 277 | mode=mode, 278 | clip_len=1, 279 | num_segment=args.num_frames if is_train else args.eval_true_frame, 280 | test_num_segment=args.test_num_segment, 281 | test_num_crop=args.test_num_crop, 282 | num_crop=1 if not test_mode else 3, 283 | keep_aspect_ratio=True, 284 | crop_size=args.input_size if is_train else args.eval_input_size, 285 | short_side_size=args.short_side_size if is_train else args.eval_short_side_size, 286 | new_height=256, 287 | new_width=320, 288 | filename_tmpl=args.filename_tmpl, 289 | args=args) 290 | nb_classes = 51 291 | 292 | elif args.data_set in [ 293 | 'ANet', 294 | 'HACS', 295 | 'ANet_interval', 296 | 'HACS_interval' 297 | ]: 298 | mode = None 299 | anno_path = None 300 | if is_train is True: 301 | mode = 'train' 302 | anno_path = os.path.join(args.data_path, 'train.csv') 303 | elif test_mode is True: 304 | mode = 'test' 305 | anno_path = os.path.join(args.data_path, 'test.csv') 306 | else: 307 | mode = 'validation' 308 | anno_path = os.path.join(args.data_path, 'val.csv') 309 | 310 | if 'interval' in args.data_set: 311 | func = ANetDataset 312 | else: 313 | func = VideoClsDataset_sparse 314 | 315 | dataset = func( 316 | anno_path=anno_path, 317 | prefix=args.prefix, 318 | split=args.split, 319 | mode=mode, 320 | clip_len=args.num_frames if is_train else args.eval_true_frame, 321 | frame_sample_rate=args.sampling_rate, 322 | num_segment=1, 323 | test_num_segment=args.test_num_segment, 324 | test_num_crop=args.test_num_crop, 325 | num_crop=1 if not test_mode else 3, 326 | keep_aspect_ratio=True, 327 | crop_size=args.input_size if is_train else args.eval_input_size, 328 | short_side_size=args.short_side_size if is_train else args.eval_short_side_size, 329 | new_height=256, 330 | new_width=320, 331 | args=args) 332 | nb_classes = args.nb_classes 333 | 334 | else: 335 | print(f'Wrong: {args.data_set}') 336 | raise NotImplementedError() 337 | assert nb_classes == args.nb_classes 338 | print("Number of the class = %d" % args.nb_classes) 339 | 340 | return dataset, nb_classes 341 | 342 | 343 | def build_multi_pretraining_dataset(args): 344 | origianl_flip = args.flip 345 | transform = DataAugmentationForVideoMAE(args) 346 | args.flip = False 347 | transform_ssv2 = DataAugmentationForVideoMAE(args) 348 | args.flip = origianl_flip 349 | 350 | dataset = VideoMAE_multi_ofa( 351 | root=None, 352 | setting=args.data_path, 353 | prefix=args.prefix, 354 | split=args.split, 355 | is_color=True, 356 | modality='rgb', 357 | num_segments=args.num_segments, 358 | new_length=args.num_frames, 
359 | new_step=args.sampling_rate, 360 | transform=transform, 361 | transform_ssv2=transform_ssv2, 362 | temporal_jitter=False, 363 | video_loader=True, 364 | use_decord=args.use_decord, 365 | lazy_init=False, 366 | num_sample=args.num_sample) 367 | print("Data Aug = %s" % str(transform)) 368 | print("Data Aug for SSV2 = %s" % str(transform_ssv2)) 369 | return dataset 370 | -------------------------------------------------------------------------------- /single_modality/datasets/kinetics_sparse.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os 3 | import io 4 | import random 5 | import numpy as np 6 | from numpy.lib.function_base import disp 7 | import torch 8 | from torchvision import transforms 9 | import warnings 10 | from decord import VideoReader, cpu 11 | from torch.utils.data import Dataset 12 | from .random_erasing import RandomErasing 13 | from .video_transforms import ( 14 | Compose, Resize, CenterCrop, Normalize, 15 | create_random_augment, random_short_side_scale_jitter, 16 | random_crop, random_resized_crop_with_shift, random_resized_crop, 17 | horizontal_flip, random_short_side_scale_jitter, uniform_crop, 18 | ) 19 | from .volume_transforms import ClipToTensor 20 | 21 | try: 22 | from petrel_client.client import Client 23 | has_client = True 24 | except ImportError: 25 | has_client = False 26 | 27 | class VideoClsDataset_sparse(Dataset): 28 | """Load your own video classification dataset.""" 29 | 30 | def __init__(self, anno_path, prefix='', split=' ', mode='train', clip_len=8, 31 | frame_sample_rate=2, crop_size=224, short_side_size=256, 32 | new_height=256, new_width=340, keep_aspect_ratio=True, 33 | num_segment=1, num_crop=1, test_num_segment=10, test_num_crop=3, 34 | args=None): 35 | self.anno_path = anno_path 36 | self.prefix = prefix 37 | self.split = split 38 | self.mode = mode 39 | self.clip_len = clip_len 40 | self.frame_sample_rate = frame_sample_rate 41 | self.crop_size = crop_size 42 | self.short_side_size = short_side_size 43 | self.new_height = new_height 44 | self.new_width = new_width 45 | self.keep_aspect_ratio = keep_aspect_ratio 46 | self.num_segment = num_segment 47 | self.test_num_segment = test_num_segment 48 | self.num_crop = num_crop 49 | self.test_num_crop = test_num_crop 50 | self.args = args 51 | self.aug = False 52 | self.rand_erase = False 53 | assert num_segment == 1 54 | if self.mode in ['train']: 55 | self.aug = True 56 | if self.args.reprob > 0: 57 | self.rand_erase = True 58 | if VideoReader is None: 59 | raise ImportError("Unable to import `decord` which is required to read videos.") 60 | 61 | import pandas as pd 62 | cleaned = pd.read_csv(self.anno_path, header=None, delimiter=self.split) 63 | self.dataset_samples = list(cleaned.values[:, 0]) 64 | self.label_array = list(cleaned.values[:, 1]) 65 | 66 | self.client = None 67 | if has_client: 68 | self.client = Client('~/petreloss.conf') 69 | 70 | if (mode == 'train'): 71 | pass 72 | 73 | elif (mode == 'validation'): 74 | self.data_transform = Compose([ 75 | Resize(self.short_side_size, interpolation='bilinear'), 76 | CenterCrop(size=(self.crop_size, self.crop_size)), 77 | ClipToTensor(), 78 | Normalize(mean=[0.485, 0.456, 0.406], 79 | std=[0.229, 0.224, 0.225]) 80 | ]) 81 | elif mode == 'test': 82 | self.data_resize = Compose([ 83 | Resize(size=(short_side_size), interpolation='bilinear') 84 | ]) 85 | self.data_transform = Compose([ 86 | ClipToTensor(), 87 | Normalize(mean=[0.485, 0.456, 0.406], 88 | std=[0.229, 0.224, 0.225]) 89 
| ]) 90 | self.test_seg = [] 91 | self.test_dataset = [] 92 | self.test_label_array = [] 93 | for ck in range(self.test_num_segment): 94 | for cp in range(self.test_num_crop): 95 | for idx in range(len(self.label_array)): 96 | sample_label = self.label_array[idx] 97 | self.test_label_array.append(sample_label) 98 | self.test_dataset.append(self.dataset_samples[idx]) 99 | self.test_seg.append((ck, cp)) 100 | 101 | def __getitem__(self, index): 102 | if self.mode == 'train': 103 | args = self.args 104 | 105 | sample = self.dataset_samples[index] 106 | buffer = self.loadvideo_decord(sample, chunk_nb=-1) # T H W C 107 | if len(buffer) == 0: 108 | while len(buffer) == 0: 109 | warnings.warn("video {} not correctly loaded during training".format(sample)) 110 | index = np.random.randint(self.__len__()) 111 | sample = self.dataset_samples[index] 112 | buffer = self.loadvideo_decord(sample, chunk_nb=-1) 113 | 114 | if args.num_sample > 1: 115 | frame_list = [] 116 | label_list = [] 117 | index_list = [] 118 | for _ in range(args.num_sample): 119 | new_frames = self._aug_frame(buffer, args) 120 | label = self.label_array[index] 121 | frame_list.append(new_frames) 122 | label_list.append(label) 123 | index_list.append(index) 124 | return frame_list, label_list, index_list, {} 125 | else: 126 | buffer = self._aug_frame(buffer, args) 127 | 128 | return buffer, self.label_array[index], index, {} 129 | 130 | elif self.mode == 'validation': 131 | sample = self.dataset_samples[index] 132 | buffer = self.loadvideo_decord(sample, chunk_nb=0) 133 | if len(buffer) == 0: 134 | while len(buffer) == 0: 135 | warnings.warn("video {} not correctly loaded during validation".format(sample)) 136 | index = np.random.randint(self.__len__()) 137 | sample = self.dataset_samples[index] 138 | buffer = self.loadvideo_decord(sample, chunk_nb=0) 139 | buffer = self.data_transform(buffer) 140 | return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0] 141 | 142 | elif self.mode == 'test': 143 | sample = self.test_dataset[index] 144 | chunk_nb, split_nb = self.test_seg[index] 145 | buffer = self.loadvideo_decord(sample, chunk_nb=chunk_nb) 146 | 147 | while len(buffer) == 0: 148 | warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\ 149 | str(self.test_dataset[index]), chunk_nb, split_nb)) 150 | index = np.random.randint(self.__len__()) 151 | sample = self.test_dataset[index] 152 | chunk_nb, split_nb = self.test_seg[index] 153 | buffer = self.loadvideo_decord(sample, chunk_nb=chunk_nb) 154 | 155 | buffer = self.data_resize(buffer) 156 | if isinstance(buffer, list): 157 | buffer = np.stack(buffer, 0) 158 | 159 | if self.test_num_crop == 1: 160 | spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) / 2 161 | spatial_start = int(spatial_step) 162 | else: 163 | spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \ 164 | / (self.test_num_crop - 1) 165 | spatial_start = int(split_nb * spatial_step) 166 | 167 | if buffer.shape[1] >= buffer.shape[2]: 168 | buffer = buffer[:, spatial_start:spatial_start + self.short_side_size, :, :] 169 | else: 170 | buffer = buffer[:, :, spatial_start:spatial_start + self.short_side_size, :] 171 | 172 | buffer = self.data_transform(buffer) 173 | return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \ 174 | chunk_nb, split_nb 175 | else: 176 | raise NameError('mode {} unkown'.format(self.mode)) 177 | 178 | def _aug_frame( 179 | self, 180 | buffer, 181 | args, 182 | ): 
183 | 184 | aug_transform = create_random_augment( 185 | input_size=(self.crop_size, self.crop_size), 186 | auto_augment=args.aa, 187 | interpolation=args.train_interpolation, 188 | ) 189 | 190 | buffer = [ 191 | transforms.ToPILImage()(frame) for frame in buffer 192 | ] 193 | 194 | buffer = aug_transform(buffer) 195 | 196 | buffer = [transforms.ToTensor()(img) for img in buffer] 197 | buffer = torch.stack(buffer) # T C H W 198 | buffer = buffer.permute(0, 2, 3, 1) # T H W C 199 | 200 | # T H W C 201 | buffer = tensor_normalize( 202 | buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 203 | ) 204 | # T H W C -> C T H W. 205 | buffer = buffer.permute(3, 0, 1, 2) 206 | # Perform data augmentation. 207 | scl, asp = ( 208 | [0.08, 1.0], 209 | [0.75, 1.3333], 210 | ) 211 | 212 | buffer = spatial_sampling( 213 | buffer, 214 | spatial_idx=-1, 215 | min_scale=256, 216 | max_scale=320, 217 | crop_size=self.crop_size, 218 | random_horizontal_flip=False if args.data_set == 'SSV2' else True , 219 | inverse_uniform_sampling=False, 220 | aspect_ratio=asp, 221 | scale=scl, 222 | motion_shift=False 223 | ) 224 | 225 | if self.rand_erase: 226 | erase_transform = RandomErasing( 227 | args.reprob, 228 | mode=args.remode, 229 | max_count=args.recount, 230 | num_splits=args.recount, 231 | device="cpu", 232 | ) 233 | buffer = buffer.permute(1, 0, 2, 3) 234 | buffer = erase_transform(buffer) 235 | buffer = buffer.permute(1, 0, 2, 3) 236 | 237 | return buffer 238 | 239 | def _get_seq_frames(self, video_size, num_frames, clip_idx=-1): 240 | seg_size = max(0., float(video_size - 1) / num_frames) 241 | max_frame = int(video_size) - 1 242 | seq = [] 243 | # index from 1, must add 1 244 | if clip_idx == -1: 245 | for i in range(num_frames): 246 | start = int(np.round(seg_size * i)) 247 | end = int(np.round(seg_size * (i + 1))) 248 | idx = min(random.randint(start, end), max_frame) 249 | seq.append(idx) 250 | else: 251 | num_segment = 1 252 | if self.mode == 'test': 253 | num_segment = self.test_num_segment 254 | duration = seg_size / (num_segment + 1) 255 | for i in range(num_frames): 256 | start = int(np.round(seg_size * i)) 257 | frame_index = start + int(duration * (clip_idx + 1)) 258 | idx = min(frame_index, max_frame) 259 | seq.append(idx) 260 | return seq 261 | 262 | def loadvideo_decord(self, sample, chunk_nb=0): 263 | """Load video content using Decord""" 264 | fname = sample 265 | fname = os.path.join(self.prefix, fname) 266 | 267 | try: 268 | if self.keep_aspect_ratio: 269 | if "s3://" in fname: 270 | video_bytes = self.client.get(fname) 271 | vr = VideoReader(io.BytesIO(video_bytes), 272 | num_threads=1, 273 | ctx=cpu(0)) 274 | else: 275 | vr = VideoReader(fname, num_threads=1, ctx=cpu(0)) 276 | else: 277 | if "s3://" in fname: 278 | video_bytes = self.client.get(fname) 279 | vr = VideoReader(io.BytesIO(video_bytes), 280 | width=self.new_width, 281 | height=self.new_height, 282 | num_threads=1, 283 | ctx=cpu(0)) 284 | else: 285 | vr = VideoReader(fname, width=self.new_width, height=self.new_height, 286 | num_threads=1, ctx=cpu(0)) 287 | 288 | all_index = self._get_seq_frames(len(vr), self.clip_len, clip_idx=chunk_nb) 289 | vr.seek(0) 290 | buffer = vr.get_batch(all_index).asnumpy() 291 | return buffer 292 | except: 293 | print("video cannot be loaded by decord: ", fname) 294 | return [] 295 | 296 | def __len__(self): 297 | if self.mode != 'test': 298 | return len(self.dataset_samples) 299 | else: 300 | return len(self.test_dataset) 301 | 302 | 303 | def spatial_sampling( 304 | frames, 305 | 
spatial_idx=-1, 306 | min_scale=256, 307 | max_scale=320, 308 | crop_size=224, 309 | random_horizontal_flip=True, 310 | inverse_uniform_sampling=False, 311 | aspect_ratio=None, 312 | scale=None, 313 | motion_shift=False, 314 | ): 315 | """ 316 | Perform spatial sampling on the given video frames. If spatial_idx is 317 | -1, perform random scale, random crop, and random flip on the given 318 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling 319 | with the given spatial_idx. 320 | Args: 321 | frames (tensor): frames of images sampled from the video. The 322 | dimension is `num frames` x `height` x `width` x `channel`. 323 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1, 324 | or 2, perform left, center, right crop if width is larger than 325 | height, and perform top, center, buttom crop if height is larger 326 | than width. 327 | min_scale (int): the minimal size of scaling. 328 | max_scale (int): the maximal size of scaling. 329 | crop_size (int): the size of height and width used to crop the 330 | frames. 331 | inverse_uniform_sampling (bool): if True, sample uniformly in 332 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the 333 | scale. If False, take a uniform sample from [min_scale, 334 | max_scale]. 335 | aspect_ratio (list): Aspect ratio range for resizing. 336 | scale (list): Scale range for resizing. 337 | motion_shift (bool): Whether to apply motion shift for resizing. 338 | Returns: 339 | frames (tensor): spatially sampled frames. 340 | """ 341 | assert spatial_idx in [-1, 0, 1, 2] 342 | if spatial_idx == -1: 343 | if aspect_ratio is None and scale is None: 344 | frames, _ = random_short_side_scale_jitter( 345 | images=frames, 346 | min_size=min_scale, 347 | max_size=max_scale, 348 | inverse_uniform_sampling=inverse_uniform_sampling, 349 | ) 350 | frames, _ = random_crop(frames, crop_size) 351 | else: 352 | transform_func = ( 353 | random_resized_crop_with_shift 354 | if motion_shift 355 | else random_resized_crop 356 | ) 357 | frames = transform_func( 358 | images=frames, 359 | target_height=crop_size, 360 | target_width=crop_size, 361 | scale=scale, 362 | ratio=aspect_ratio, 363 | ) 364 | if random_horizontal_flip: 365 | frames, _ = horizontal_flip(0.5, frames) 366 | else: 367 | # The testing is deterministic and no jitter should be performed. 368 | # min_scale, max_scale, and crop_size are expect to be the same. 369 | assert len({min_scale, max_scale, crop_size}) == 1 370 | frames, _ = random_short_side_scale_jitter( 371 | frames, min_scale, max_scale 372 | ) 373 | frames, _ = uniform_crop(frames, crop_size, spatial_idx) 374 | return frames 375 | 376 | 377 | def tensor_normalize(tensor, mean, std): 378 | """ 379 | Normalize a given tensor by subtracting the mean and dividing the std. 380 | Args: 381 | tensor (tensor): tensor to normalize. 382 | mean (tensor or list): mean value to subtract. 383 | std (tensor or list): std to divide. 
384 | """ 385 | if tensor.dtype == torch.uint8: 386 | tensor = tensor.float() 387 | tensor = tensor / 255.0 388 | if type(mean) == list: 389 | mean = torch.tensor(mean) 390 | if type(std) == list: 391 | std = torch.tensor(std) 392 | tensor = tensor - mean 393 | tensor = tensor / std 394 | return tensor 395 | 396 | -------------------------------------------------------------------------------- /single_modality/datasets/mae.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import io 4 | import numpy as np 5 | import torch 6 | import decord 7 | from PIL import Image 8 | from decord import VideoReader, cpu 9 | import random 10 | 11 | try: 12 | from petrel_client.client import Client 13 | has_client = True 14 | except ImportError: 15 | has_client = False 16 | 17 | 18 | class VideoMAE(torch.utils.data.Dataset): 19 | """Load your own video classification dataset. 20 | Parameters 21 | ---------- 22 | root : str, required. 23 | Path to the root folder storing the dataset. 24 | setting : str, required. 25 | A text file describing the dataset, each line per video sample. 26 | There are three items in each line: (1) video path; (2) video length and (3) video label. 27 | prefix : str, required. 28 | The prefix for loading data. 29 | split : str, required. 30 | The split character for metadata. 31 | train : bool, default True. 32 | Whether to load the training or validation set. 33 | test_mode : bool, default False. 34 | Whether to perform evaluation on the test set. 35 | Usually there is three-crop or ten-crop evaluation strategy involved. 36 | name_pattern : str, default None. 37 | The naming pattern of the decoded video frames. 38 | For example, img_00012.jpg. 39 | video_ext : str, default 'mp4'. 40 | If video_loader is set to True, please specify the video format accordinly. 41 | is_color : bool, default True. 42 | Whether the loaded image is color or grayscale. 43 | modality : str, default 'rgb'. 44 | Input modalities, we support only rgb video frames for now. 45 | Will add support for rgb difference image and optical flow image later. 46 | num_segments : int, default 1. 47 | Number of segments to evenly divide the video into clips. 48 | A useful technique to obtain global video-level information. 49 | Limin Wang, etal, Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016. 50 | num_crop : int, default 1. 51 | Number of crops for each image. default is 1. 52 | Common choices are three crops and ten crops during evaluation. 53 | new_length : int, default 1. 54 | The length of input video clip. Default is a single image, but it can be multiple video frames. 55 | For example, new_length=16 means we will extract a video clip of consecutive 16 frames. 56 | new_step : int, default 1. 57 | Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames. 58 | new_step=2 means we will extract a video clip of every other frame. 59 | temporal_jitter : bool, default False. 60 | Whether to temporally jitter if new_step > 1. 61 | video_loader : bool, default False. 62 | Whether to use video loader to load data. 63 | use_decord : bool, default True. 64 | Whether to use Decord video loader to load data. Otherwise load image. 65 | transform : function, default None. 66 | A function that takes data and label and transforms them. 67 | data_aug : str, default 'v1'. 68 | Different types of data augmentation auto. Supports v1, v2, v3 and v4. 
69 | lazy_init : bool, default False. 70 | If set to True, build a dataset instance without loading any dataset. 71 | """ 72 | def __init__(self, 73 | root, 74 | setting, 75 | prefix='', 76 | split=' ', 77 | train=True, 78 | test_mode=False, 79 | name_pattern='img_%05d.jpg', 80 | video_ext='mp4', 81 | is_color=True, 82 | modality='rgb', 83 | num_segments=1, 84 | num_crop=1, 85 | new_length=1, 86 | new_step=1, 87 | transform=None, 88 | temporal_jitter=False, 89 | video_loader=False, 90 | use_decord=True, 91 | lazy_init=False, 92 | num_sample=1, 93 | ): 94 | 95 | super(VideoMAE, self).__init__() 96 | self.root = root 97 | self.setting = setting 98 | self.prefix = prefix 99 | self.split = split 100 | self.train = train 101 | self.test_mode = test_mode 102 | self.is_color = is_color 103 | self.modality = modality 104 | self.num_segments = num_segments 105 | self.num_crop = num_crop 106 | self.new_length = new_length 107 | self.new_step = new_step 108 | self.skip_length = self.new_length * self.new_step 109 | self.temporal_jitter = temporal_jitter 110 | self.name_pattern = name_pattern 111 | self.video_loader = video_loader 112 | self.video_ext = video_ext 113 | self.use_decord = use_decord 114 | self.transform = transform 115 | self.lazy_init = lazy_init 116 | self.num_sample = num_sample 117 | 118 | # sparse sampling, num_segments != 1 119 | if self.num_segments != 1: 120 | print('Use sparse sampling, change frame and stride') 121 | self.new_length = self.num_segments 122 | self.skip_length = 1 123 | 124 | self.client = None 125 | if has_client: 126 | self.client = Client('~/petreloss.conf') 127 | 128 | if not self.lazy_init: 129 | self.clips = self._make_dataset(root, setting) 130 | if len(self.clips) == 0: 131 | raise(RuntimeError("Found 0 video clips in subfolders of: " + root + "\n" 132 | "Check your data directory (opt.data-dir).")) 133 | 134 | def __getitem__(self, index): 135 | while True: 136 | try: 137 | images = None 138 | if self.use_decord: 139 | directory, target = self.clips[index] 140 | if self.video_loader: 141 | if '.' in directory.split('/')[-1]: 142 | # data in the "setting" file already have extension, e.g., demo.mp4 143 | video_name = directory 144 | else: 145 | # data in the "setting" file do not have extension, e.g., demo 146 | # So we need to provide extension (i.e., .mp4) to complete the file name. 
147 | video_name = '{}.{}'.format(directory, self.video_ext) 148 | 149 | video_name = os.path.join(self.prefix, video_name) 150 | if video_name.startswith('s3') or video_name.startswith('p2:s3'): 151 | video_bytes = self.client.get(video_name) 152 | decord_vr = VideoReader(io.BytesIO(video_bytes), 153 | num_threads=1, 154 | ctx=cpu(0)) 155 | else: 156 | decord_vr = decord.VideoReader(video_name, num_threads=1, ctx=cpu(0)) 157 | duration = len(decord_vr) 158 | 159 | segment_indices, skip_offsets = self._sample_train_indices(duration) 160 | images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices, skip_offsets) 161 | 162 | else: 163 | video_name, total_frame, target = self.clips[index] 164 | video_name = os.path.join(self.prefix, video_name) 165 | 166 | segment_indices, skip_offsets = self._sample_train_indices(total_frame) 167 | frame_id_list = self._get_frame_id_list(total_frame, segment_indices, skip_offsets) 168 | images = [] 169 | for idx in frame_id_list: 170 | frame_fname = os.path.join(video_name, self.name_pattern.format(idx)) 171 | img_bytes = self.client.get(frame_fname) 172 | img_np = np.frombuffer(img_bytes, np.uint8) 173 | img = cv2.imdecode(img_np, cv2.IMREAD_COLOR) 174 | cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) 175 | images.append(Image.fromarray(img)) 176 | if images is not None: 177 | break 178 | except Exception as e: 179 | print("Failed to load video from {} with error {}".format( 180 | video_name, e)) 181 | index = random.randint(0, len(self.clips) - 1) 182 | 183 | if self.num_sample > 1: 184 | process_data_list = [] 185 | mask_list = [] 186 | for _ in range(self.num_sample): 187 | process_data, mask = self.transform((images, None)) 188 | process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1) 189 | process_data_list.append(process_data) 190 | mask_list.append(mask) 191 | return process_data_list, mask_list 192 | else: 193 | process_data, mask = self.transform((images, None)) # T*C,H,W 194 | process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1) # T*C,H,W -> T,C,H,W -> C,T,H,W 195 | return (process_data, mask) 196 | 197 | def __len__(self): 198 | return len(self.clips) 199 | 200 | def _make_dataset(self, directory, setting): 201 | if not os.path.exists(setting): 202 | raise(RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting))) 203 | clips = [] 204 | 205 | print(f'Load dataset using decord: {self.use_decord}') 206 | with open(setting) as split_f: 207 | data = split_f.readlines() 208 | for line in data: 209 | line_info = line.split(self.split) 210 | if len(line_info) < 2: 211 | raise(RuntimeError('Video input format is not correct, missing one or more element. 
%s' % line)) 212 | if self.use_decord: 213 | # line format: video_path, video_label 214 | clip_path = os.path.join(line_info[0]) 215 | target = int(line_info[1]) 216 | item = (clip_path, target) 217 | else: 218 | # line format: video_path, video_duration, video_label 219 | clip_path = os.path.join(line_info[0]) 220 | total_frame = int(line_info[1]) 221 | target = int(line_info[2]) 222 | item = (clip_path, total_frame, target) 223 | clips.append(item) 224 | return clips 225 | 226 | def _sample_train_indices(self, num_frames): 227 | average_duration = (num_frames - self.skip_length + 1) // self.num_segments 228 | if average_duration > 0: 229 | offsets = np.multiply(list(range(self.num_segments)), 230 | average_duration) 231 | offsets = offsets + np.random.randint(average_duration, 232 | size=self.num_segments) 233 | elif num_frames > max(self.num_segments, self.skip_length): 234 | offsets = np.sort(np.random.randint( 235 | num_frames - self.skip_length + 1, 236 | size=self.num_segments)) 237 | else: 238 | offsets = np.zeros((self.num_segments,)) 239 | 240 | if self.temporal_jitter: 241 | skip_offsets = np.random.randint( 242 | self.new_step, size=self.skip_length // self.new_step) 243 | else: 244 | skip_offsets = np.zeros( 245 | self.skip_length // self.new_step, dtype=int) 246 | return offsets + 1, skip_offsets 247 | 248 | def _get_frame_id_list(self, duration, indices, skip_offsets): 249 | frame_id_list = [] 250 | for seg_ind in indices: 251 | offset = int(seg_ind) 252 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)): 253 | if offset + skip_offsets[i] <= duration: 254 | frame_id = offset + skip_offsets[i] - 1 255 | else: 256 | frame_id = offset - 1 257 | frame_id_list.append(frame_id) 258 | if offset + self.new_step < duration: 259 | offset += self.new_step 260 | return frame_id_list 261 | 262 | def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets): 263 | sampled_list = [] 264 | frame_id_list = [] 265 | for seg_ind in indices: 266 | offset = int(seg_ind) 267 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)): 268 | if offset + skip_offsets[i] <= duration: 269 | frame_id = offset + skip_offsets[i] - 1 270 | else: 271 | frame_id = offset - 1 272 | frame_id_list.append(frame_id) 273 | if offset + self.new_step < duration: 274 | offset += self.new_step 275 | try: 276 | video_data = video_reader.get_batch(frame_id_list).asnumpy() 277 | sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)] 278 | except: 279 | raise RuntimeError('Error occured in reading frames {} from video {} of duration {}.'.format(frame_id_list, directory, duration)) 280 | return sampled_list 281 | -------------------------------------------------------------------------------- /single_modality/datasets/mae_multi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import io 4 | import numpy as np 5 | import torch 6 | import decord 7 | from PIL import Image 8 | from decord import VideoReader, cpu 9 | import random 10 | 11 | try: 12 | from petrel_client.client import Client 13 | has_client = True 14 | except ImportError: 15 | has_client = False 16 | 17 | 18 | class VideoMAE_multi(torch.utils.data.Dataset): 19 | """Load your own video classification dataset. 20 | Parameters 21 | ---------- 22 | root : str, required. 23 | Path to the root folder storing the dataset. 24 | setting : str, required. 
25 | A text file describing the dataset, each line per video sample. 26 | There are three items in each line: (1) video path; (2) video length and (3) video label. 27 | prefix : str, required. 28 | The prefix for loading data. 29 | split : str, required. 30 | The split character for metadata. 31 | train : bool, default True. 32 | Whether to load the training or validation set. 33 | test_mode : bool, default False. 34 | Whether to perform evaluation on the test set. 35 | Usually there is three-crop or ten-crop evaluation strategy involved. 36 | name_pattern : str, default None. 37 | The naming pattern of the decoded video frames. 38 | For example, img_00012.jpg. 39 | is_color : bool, default True. 40 | Whether the loaded image is color or grayscale. 41 | modality : str, default 'rgb'. 42 | Input modalities, we support only rgb video frames for now. 43 | Will add support for rgb difference image and optical flow image later. 44 | num_segments : int, default 1. 45 | Number of segments to evenly divide the video into clips. 46 | A useful technique to obtain global video-level information. 47 | Limin Wang, etal, Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016. 48 | num_crop : int, default 1. 49 | Number of crops for each image. default is 1. 50 | Common choices are three crops and ten crops during evaluation. 51 | new_length : int, default 1. 52 | The length of input video clip. Default is a single image, but it can be multiple video frames. 53 | For example, new_length=16 means we will extract a video clip of consecutive 16 frames. 54 | new_step : int, default 1. 55 | Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames. 56 | new_step=2 means we will extract a video clip of every other frame. 57 | temporal_jitter : bool, default False. 58 | Whether to temporally jitter if new_step > 1. 59 | video_loader : bool, default False. 60 | Whether to use video loader to load data. 61 | use_decord : bool, default True. 62 | Whether to use Decord video loader to load data. Otherwise load image. 63 | transform : function, default None. 64 | A function that takes data and label and transforms them. 65 | transform_ssv2 : function, default None. 66 | A function that takes data and label and transforms them. 67 | data_aug : str, default 'v1'. 68 | Different types of data augmentation auto. Supports v1, v2, v3 and v4. 69 | lazy_init : bool, default False. 70 | If set to True, build a dataset instance without loading any dataset. 
71 | """ 72 | def __init__(self, 73 | root, 74 | setting, 75 | prefix='', 76 | split=' ', 77 | train=True, 78 | test_mode=False, 79 | name_pattern='img_%05d.jpg', 80 | is_color=True, 81 | modality='rgb', 82 | num_segments=1, 83 | num_crop=1, 84 | new_length=1, 85 | new_step=1, 86 | transform=None, 87 | transform_ssv2=None, 88 | temporal_jitter=False, 89 | video_loader=False, 90 | use_decord=True, 91 | lazy_init=False, 92 | num_sample=1, 93 | ): 94 | 95 | super(VideoMAE_multi, self).__init__() 96 | self.root = root 97 | self.setting = setting 98 | self.prefix = prefix 99 | self.split = split 100 | self.train = train 101 | self.test_mode = test_mode 102 | self.is_color = is_color 103 | self.modality = modality 104 | self.num_segments = num_segments 105 | self.num_crop = num_crop 106 | self.new_length = new_length 107 | self.new_step = new_step 108 | self.skip_length = self.new_length * self.new_step 109 | self.temporal_jitter = temporal_jitter 110 | self.name_pattern = name_pattern 111 | self.video_loader = video_loader 112 | self.use_decord = use_decord 113 | self.transform = transform 114 | self.transform_ssv2 = transform_ssv2 115 | self.lazy_init = lazy_init 116 | self.num_sample = num_sample 117 | 118 | assert use_decord == True, "Only support to read video now!" 119 | 120 | # sparse sampling, num_segments != 1 121 | if self.num_segments != 1: 122 | print('Use sparse sampling, change frame and stride') 123 | self.new_length = self.num_segments 124 | self.skip_length = 1 125 | 126 | self.client = None 127 | if has_client: 128 | self.client = Client() 129 | 130 | if not self.lazy_init: 131 | self.clips = self._make_dataset(root, setting) 132 | if len(self.clips) == 0: 133 | raise(RuntimeError("Found 0 video clips in subfolders of: " + root + "\n" 134 | "Check your data directory (opt.data-dir).")) 135 | 136 | def __getitem__(self, index): 137 | idx = 0 138 | while True: 139 | try: 140 | images = None 141 | if self.use_decord: 142 | source, path, total_time, start_time, end_time, target = self.clips[index] 143 | if self.video_loader: 144 | video_name = os.path.join(self.prefix, path) 145 | # print(video_name) 146 | if "s3://" in video_name: 147 | video_bytes = self.client.get(video_name) 148 | # print(f'Got {video_name}') 149 | decord_vr = VideoReader(io.BytesIO(video_bytes), 150 | num_threads=1, 151 | ctx=cpu(0)) 152 | else: 153 | decord_vr = decord.VideoReader(video_name, num_threads=1, ctx=cpu(0)) 154 | duration = len(decord_vr) 155 | start_index = 0 156 | 157 | if total_time!= -1 and start_time != -1 and end_time != -1: 158 | fps = duration / total_time 159 | duration = int(fps * (end_time - start_time)) 160 | start_index = int(fps * start_time) 161 | segment_indices, skip_offsets = self._sample_train_indices(duration, start_index) 162 | images = self._video_TSN_decord_batch_loader(video_name, decord_vr, duration, segment_indices, skip_offsets) 163 | else: 164 | raise NotImplementedError 165 | 166 | if images is not None: 167 | break 168 | except Exception as e: 169 | print("Failed to load video from {} with error {}".format( 170 | video_name, e)) 171 | 172 | idx += 1 173 | if idx >= 11: 174 | idx = 0 175 | index = random.randint(0, len(self.clips) - 1) 176 | print(f'retry with new index {index}') 177 | else: 178 | print(f'retry with video_name {video_name}') 179 | 180 | if self.num_sample > 1: 181 | process_data_list = [] 182 | mask_list = [] 183 | for _ in range(self.num_sample): 184 | if source == "ssv2": 185 | process_data, mask = self.transform_ssv2((images, None)) 186 | else: 187 
| process_data, mask = self.transform((images, None)) 188 | process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1) 189 | process_data_list.append(process_data) 190 | mask_list.append(mask) 191 | return process_data_list, mask_list 192 | else: 193 | if source == "ssv2": 194 | process_data, mask = self.transform_ssv2((images, None)) # T*C,H,W 195 | else: 196 | process_data, mask = self.transform((images, None)) # T*C,H,W 197 | process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1) # T*C,H,W -> T,C,H,W -> C,T,H,W 198 | return (process_data, mask) 199 | 200 | def __len__(self): 201 | return len(self.clips) 202 | 203 | def _make_dataset(self, directory, setting): 204 | if not os.path.exists(setting): 205 | raise(RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting))) 206 | clips = [] 207 | 208 | print(f'Load dataset using decord: {self.use_decord}') 209 | with open(setting) as split_f: 210 | data = split_f.readlines() 211 | for line in data: 212 | line_info = line.split(self.split) 213 | if len(line_info) < 2: 214 | raise(RuntimeError('Video input format is not correct, missing one or more element. %s' % line)) 215 | if self.use_decord: 216 | # line format: source, path, total_time, start_time, end_time, target 217 | source = line_info[0] 218 | path = line_info[1] 219 | total_time = float(line_info[2]) 220 | start_time = float(line_info[3]) 221 | end_time = float(line_info[4]) 222 | target = int(line_info[5]) 223 | item = (source, path, total_time, start_time, end_time, target) 224 | else: 225 | raise NotImplementedError 226 | 227 | clips.append(item) 228 | return clips 229 | 230 | def _sample_train_indices(self, num_frames, start_index=0): 231 | average_duration = (num_frames - self.skip_length + 1) // self.num_segments 232 | if average_duration > 0: 233 | offsets = np.multiply(list(range(self.num_segments)), 234 | average_duration) 235 | offsets = offsets + np.random.randint(average_duration, 236 | size=self.num_segments) 237 | elif num_frames > max(self.num_segments, self.skip_length): 238 | offsets = np.sort(np.random.randint( 239 | num_frames - self.skip_length + 1, 240 | size=self.num_segments)) 241 | else: 242 | offsets = np.zeros((self.num_segments,)) 243 | 244 | if self.temporal_jitter: 245 | skip_offsets = np.random.randint( 246 | self.new_step, size=self.skip_length // self.new_step) 247 | else: 248 | skip_offsets = np.zeros( 249 | self.skip_length // self.new_step, dtype=int) 250 | return offsets + start_index, skip_offsets 251 | 252 | def _get_frame_id_list(self, duration, indices, skip_offsets): 253 | frame_id_list = [] 254 | for seg_ind in indices: 255 | offset = int(seg_ind) 256 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)): 257 | if offset + skip_offsets[i] <= duration: 258 | frame_id = offset + skip_offsets[i] - 1 259 | else: 260 | frame_id = offset - 1 261 | frame_id_list.append(frame_id) 262 | if offset + self.new_step < duration: 263 | offset += self.new_step 264 | return frame_id_list 265 | 266 | def _video_TSN_decord_batch_loader(self, video_name, video_reader, duration, indices, skip_offsets): 267 | sampled_list = [] 268 | frame_id_list = [] 269 | for seg_ind in indices: 270 | offset = int(seg_ind) 271 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)): 272 | if offset + skip_offsets[i] <= duration: 273 | frame_id = offset + skip_offsets[i] - 1 274 | else: 275 | frame_id = offset - 1 276 | 
frame_id_list.append(frame_id) 277 | if offset + self.new_step < duration: 278 | offset += self.new_step 279 | try: 280 | video_data = video_reader.get_batch(frame_id_list).asnumpy() 281 | sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)] 282 | except: 283 | raise RuntimeError('Error occured in reading frames {} from video {} of duration {}.'.format(frame_id_list, video_name, duration)) 284 | return sampled_list 285 | -------------------------------------------------------------------------------- /single_modality/datasets/mae_multi_ofa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import io 4 | import numpy as np 5 | import torch 6 | import decord 7 | from PIL import Image 8 | from decord import VideoReader, cpu 9 | import random 10 | 11 | from torchvision import transforms 12 | 13 | from .transforms import * 14 | 15 | try: 16 | from petrel_client.client import Client 17 | has_client = True 18 | except ImportError: 19 | has_client = False 20 | 21 | 22 | class VideoMAE_multi_ofa(torch.utils.data.Dataset): 23 | """Load your own video classification dataset. 24 | Parameters 25 | ---------- 26 | root : str, required. 27 | Path to the root folder storing the dataset. 28 | setting : str, required. 29 | A text file describing the dataset, each line per video sample. 30 | There are three items in each line: (1) video path; (2) video length and (3) video label. 31 | prefix : str, required. 32 | The prefix for loading data. 33 | split : str, required. 34 | The split character for metadata. 35 | train : bool, default True. 36 | Whether to load the training or validation set. 37 | test_mode : bool, default False. 38 | Whether to perform evaluation on the test set. 39 | Usually there is three-crop or ten-crop evaluation strategy involved. 40 | name_pattern : str, default None. 41 | The naming pattern of the decoded video frames. 42 | For example, img_00012.jpg. 43 | is_color : bool, default True. 44 | Whether the loaded image is color or grayscale. 45 | modality : str, default 'rgb'. 46 | Input modalities, we support only rgb video frames for now. 47 | Will add support for rgb difference image and optical flow image later. 48 | num_segments : int, default 1. 49 | Number of segments to evenly divide the video into clips. 50 | A useful technique to obtain global video-level information. 51 | Limin Wang, etal, Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016. 52 | num_crop : int, default 1. 53 | Number of crops for each image. default is 1. 54 | Common choices are three crops and ten crops during evaluation. 55 | new_length : int, default 1. 56 | The length of input video clip. Default is a single image, but it can be multiple video frames. 57 | For example, new_length=16 means we will extract a video clip of consecutive 16 frames. 58 | new_step : int, default 1. 59 | Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames. 60 | new_step=2 means we will extract a video clip of every other frame. 61 | temporal_jitter : bool, default False. 62 | Whether to temporally jitter if new_step > 1. 63 | video_loader : bool, default False. 64 | Whether to use video loader to load data. 65 | use_decord : bool, default True. 66 | Whether to use Decord video loader to load data. Otherwise load image. 67 | transform : function, default None. 
68 | A function that takes data and label and transforms them. 69 | transform_ssv2 : function, default None. 70 | A function that takes data and label and transforms them. 71 | data_aug : str, default 'v1'. 72 | Different types of data augmentation auto. Supports v1, v2, v3 and v4. 73 | lazy_init : bool, default False. 74 | If set to True, build a dataset instance without loading any dataset. 75 | """ 76 | def __init__(self, 77 | root, 78 | setting, 79 | prefix='', 80 | split=' ', 81 | train=True, 82 | test_mode=False, 83 | name_pattern='img_%05d.jpg', 84 | is_color=True, 85 | modality='rgb', 86 | num_segments=1, 87 | num_crop=1, 88 | new_length=1, 89 | new_step=1, 90 | transform=None, 91 | transform_ssv2=None, 92 | temporal_jitter=False, 93 | video_loader=False, 94 | use_decord=True, 95 | lazy_init=False, 96 | num_sample=1, 97 | ): 98 | 99 | super(VideoMAE_multi_ofa, self).__init__() 100 | self.root = root 101 | self.setting = setting 102 | self.prefix = prefix 103 | self.split = split 104 | self.train = train 105 | self.test_mode = test_mode 106 | self.is_color = is_color 107 | self.modality = modality 108 | self.num_segments = num_segments 109 | self.num_crop = num_crop 110 | self.new_length = new_length 111 | self.new_step = new_step 112 | self.skip_length = self.new_length * self.new_step 113 | self.temporal_jitter = temporal_jitter 114 | self.name_pattern = name_pattern 115 | self.video_loader = video_loader 116 | self.use_decord = use_decord 117 | self.transform = transform 118 | 119 | self.transforms = {} 120 | for crop_size in range(140, 280, 28): 121 | input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN 122 | input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD 123 | normalize = GroupNormalize(input_mean, input_std) 124 | train_augmentation = GroupMultiScaleCrop(crop_size, [1, .875, .75, .66]) 125 | transform = transforms.Compose([ 126 | train_augmentation, 127 | GroupRandomHorizontalFlip(flip=True), 128 | Stack(roll=False), 129 | ToTorchFormatTensor(div=True), 130 | normalize, 131 | ]) 132 | self.transforms[crop_size] = transform 133 | 134 | self.transform_ssv2 = transform_ssv2 135 | 136 | self.transform_ssv2s = {} 137 | for crop_size in range(140, 280, 28): 138 | input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN 139 | input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD 140 | normalize = GroupNormalize(input_mean, input_std) 141 | train_augmentation = GroupMultiScaleCrop(crop_size, [1, .875, .75, .66]) 142 | transform = transforms.Compose([ 143 | train_augmentation, 144 | GroupRandomHorizontalFlip(flip=False), 145 | Stack(roll=False), 146 | ToTorchFormatTensor(div=True), 147 | normalize, 148 | ]) 149 | self.transform_ssv2s[crop_size] = transform 150 | 151 | self.lazy_init = lazy_init 152 | self.num_sample = num_sample 153 | 154 | assert use_decord == True, "Only support to read video now!" 
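        # NOTE (editor's comment, not part of the original file): the self.transforms /
        # self.transform_ssv2s dicts built above hold one augmentation pipeline per candidate
        # crop size, i.e. range(140, 280, 28) -> 140, 168, 196, 224, 252 pixels. __getitem__
        # below receives an (index, num_segments, crop_size) tuple and picks the matching
        # pipeline, which allows the frame count and crop size to vary across batches.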
155 | 156 | # sparse sampling, num_segments != 1 157 | if self.num_segments != 1: 158 | print('Use sparse sampling, change frame and stride') 159 | self.new_length = self.num_segments 160 | self.skip_length = 1 161 | 162 | self.client = None 163 | if has_client: 164 | self.client = Client() 165 | 166 | if not self.lazy_init: 167 | self.clips = self._make_dataset(root, setting) 168 | if len(self.clips) == 0: 169 | raise(RuntimeError("Found 0 video clips in subfolders of: " + root + "\n" 170 | "Check your data directory (opt.data-dir).")) 171 | 172 | def __getitem__(self, index): 173 | idx = 0 174 | index, self.num_segments, self.crop_size = index 175 | self.new_length = self.num_segments 176 | 177 | while True: 178 | try: 179 | images = None 180 | if self.use_decord: 181 | source, path, total_time, start_time, end_time, target = self.clips[index] 182 | if self.video_loader: 183 | video_name = os.path.join(self.prefix, path) 184 | # print(video_name) 185 | if "s3://" in video_name: 186 | video_bytes = self.client.get(video_name) 187 | # print(f'Got {video_name}') 188 | decord_vr = VideoReader(io.BytesIO(video_bytes), 189 | num_threads=1, 190 | ctx=cpu(0)) 191 | else: 192 | decord_vr = decord.VideoReader(video_name, num_threads=1, ctx=cpu(0)) 193 | duration = len(decord_vr) 194 | start_index = 0 195 | 196 | if total_time!= -1 and start_time != -1 and end_time != -1: 197 | fps = duration / total_time 198 | duration = int(fps * (end_time - start_time)) 199 | start_index = int(fps * start_time) 200 | segment_indices, skip_offsets = self._sample_train_indices(duration, start_index) 201 | images = self._video_TSN_decord_batch_loader(video_name, decord_vr, duration, segment_indices, skip_offsets) 202 | else: 203 | raise NotImplementedError 204 | 205 | if images is not None: 206 | break 207 | except Exception as e: 208 | print("Failed to load video from {} with error {}".format( 209 | video_name, e)) 210 | 211 | idx += 1 212 | if idx >= 11: 213 | idx = 0 214 | index = random.randint(0, len(self.clips) - 1) 215 | print(f'retry with new index {index}') 216 | else: 217 | print(f'retry with video_name {video_name}') 218 | 219 | if self.num_sample > 1: 220 | raise NotImplementedError 221 | else: 222 | if source == "ssv2": 223 | process_data, _ = self.transform_ssv2s[self.crop_size]((images, None)) # T*C,H,W 224 | else: 225 | process_data, _ = self.transforms[self.crop_size]((images, None)) # T*C,H,W 226 | process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1) # T*C,H,W -> T,C,H,W -> C,T,H,W 227 | return process_data 228 | 229 | def __len__(self): 230 | return len(self.clips) 231 | 232 | def _make_dataset(self, directory, setting): 233 | if not os.path.exists(setting): 234 | raise(RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting))) 235 | clips = [] 236 | 237 | print(f'Load dataset using decord: {self.use_decord}') 238 | with open(setting) as split_f: 239 | data = split_f.readlines() 240 | for line in data: 241 | line_info = line.split(self.split) 242 | if len(line_info) < 2: 243 | raise(RuntimeError('Video input format is not correct, missing one or more element. 
%s' % line)) 244 | if self.use_decord: 245 | # line format: source, path, total_time, start_time, end_time, target 246 | source = line_info[0] 247 | path = line_info[1] 248 | total_time = float(line_info[2]) 249 | start_time = float(line_info[3]) 250 | end_time = float(line_info[4]) 251 | target = int(line_info[5]) 252 | item = (source, path, total_time, start_time, end_time, target) 253 | else: 254 | raise NotImplementedError 255 | 256 | clips.append(item) 257 | return clips 258 | 259 | def _sample_train_indices(self, num_frames, start_index=0): 260 | average_duration = (num_frames - self.skip_length + 1) // self.num_segments 261 | if average_duration > 0: 262 | offsets = np.multiply(list(range(self.num_segments)), 263 | average_duration) 264 | offsets = offsets + np.random.randint(average_duration, 265 | size=self.num_segments) 266 | elif num_frames > max(self.num_segments, self.skip_length): 267 | offsets = np.sort(np.random.randint( 268 | num_frames - self.skip_length + 1, 269 | size=self.num_segments)) 270 | else: 271 | offsets = np.zeros((self.num_segments,)) 272 | 273 | if self.temporal_jitter: 274 | skip_offsets = np.random.randint( 275 | self.new_step, size=self.skip_length // self.new_step) 276 | else: 277 | skip_offsets = np.zeros( 278 | self.skip_length // self.new_step, dtype=int) 279 | return offsets + start_index, skip_offsets 280 | 281 | def _get_frame_id_list(self, duration, indices, skip_offsets): 282 | frame_id_list = [] 283 | for seg_ind in indices: 284 | offset = int(seg_ind) 285 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)): 286 | if offset + skip_offsets[i] <= duration: 287 | frame_id = offset + skip_offsets[i] - 1 288 | else: 289 | frame_id = offset - 1 290 | frame_id_list.append(frame_id) 291 | if offset + self.new_step < duration: 292 | offset += self.new_step 293 | return frame_id_list 294 | 295 | def _video_TSN_decord_batch_loader(self, video_name, video_reader, duration, indices, skip_offsets): 296 | sampled_list = [] 297 | frame_id_list = [] 298 | for seg_ind in indices: 299 | offset = int(seg_ind) 300 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)): 301 | if offset + skip_offsets[i] <= duration: 302 | frame_id = offset + skip_offsets[i] - 1 303 | else: 304 | frame_id = offset - 1 305 | frame_id_list.append(frame_id) 306 | if offset + self.new_step < duration: 307 | offset += self.new_step 308 | try: 309 | video_data = video_reader.get_batch(frame_id_list).asnumpy() 310 | sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)] 311 | except: 312 | raise RuntimeError('Error occured in reading frames {} from video {} of duration {}.'.format(frame_id_list, video_name, duration)) 313 | return sampled_list 314 | -------------------------------------------------------------------------------- /single_modality/datasets/mae_ofa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import io 4 | import numpy as np 5 | import torch 6 | import decord 7 | from PIL import Image 8 | from decord import VideoReader, cpu 9 | import random 10 | 11 | try: 12 | from petrel_client.client import Client 13 | has_client = True 14 | except ImportError: 15 | has_client = False 16 | 17 | from torchvision import transforms 18 | 19 | from .transforms import * 20 | 21 | import time 22 | 23 | class VideoMAE_ofa(torch.utils.data.Dataset): 24 | """Load your own video classification dataset. 
25 | Parameters 26 | ---------- 27 | root : str, required. 28 | Path to the root folder storing the dataset. 29 | setting : str, required. 30 | A text file describing the dataset, each line per video sample. 31 | There are three items in each line: (1) video path; (2) video length and (3) video label. 32 | prefix : str, required. 33 | The prefix for loading data. 34 | split : str, required. 35 | The split character for metadata. 36 | train : bool, default True. 37 | Whether to load the training or validation set. 38 | test_mode : bool, default False. 39 | Whether to perform evaluation on the test set. 40 | Usually there is three-crop or ten-crop evaluation strategy involved. 41 | name_pattern : str, default None. 42 | The naming pattern of the decoded video frames. 43 | For example, img_00012.jpg. 44 | video_ext : str, default 'mp4'. 45 | If video_loader is set to True, please specify the video format accordinly. 46 | is_color : bool, default True. 47 | Whether the loaded image is color or grayscale. 48 | modality : str, default 'rgb'. 49 | Input modalities, we support only rgb video frames for now. 50 | Will add support for rgb difference image and optical flow image later. 51 | num_segments : int, default 1. 52 | Number of segments to evenly divide the video into clips. 53 | A useful technique to obtain global video-level information. 54 | Limin Wang, etal, Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016. 55 | num_crop : int, default 1. 56 | Number of crops for each image. default is 1. 57 | Common choices are three crops and ten crops during evaluation. 58 | new_length : int, default 1. 59 | The length of input video clip. Default is a single image, but it can be multiple video frames. 60 | For example, new_length=16 means we will extract a video clip of consecutive 16 frames. 61 | new_step : int, default 1. 62 | Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames. 63 | new_step=2 means we will extract a video clip of every other frame. 64 | temporal_jitter : bool, default False. 65 | Whether to temporally jitter if new_step > 1. 66 | video_loader : bool, default False. 67 | Whether to use video loader to load data. 68 | use_decord : bool, default True. 69 | Whether to use Decord video loader to load data. Otherwise load image. 70 | transform : function, default None. 71 | A function that takes data and label and transforms them. 72 | data_aug : str, default 'v1'. 73 | Different types of data augmentation auto. Supports v1, v2, v3 and v4. 74 | lazy_init : bool, default False. 75 | If set to True, build a dataset instance without loading any dataset. 
76 | """ 77 | def __init__(self, 78 | root, 79 | setting, 80 | prefix='', 81 | split=' ', 82 | train=True, 83 | test_mode=False, 84 | name_pattern='img_%05d.jpg', 85 | video_ext='mp4', 86 | is_color=True, 87 | modality='rgb', 88 | num_segments=1, 89 | num_crop=1, 90 | new_length=1, 91 | new_step=1, 92 | transform=None, 93 | temporal_jitter=False, 94 | video_loader=False, 95 | use_decord=True, 96 | lazy_init=False, 97 | num_sample=1, 98 | ): 99 | 100 | super(VideoMAE_ofa, self).__init__() 101 | 102 | self.transforms = {} 103 | for crop_size in range(140, 280, 28): 104 | input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN 105 | input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD 106 | normalize = GroupNormalize(input_mean, input_std) 107 | train_augmentation = GroupMultiScaleCrop(crop_size, [1, .875, .75, .66]) 108 | transform = transforms.Compose([ 109 | train_augmentation, 110 | GroupRandomHorizontalFlip(flip=True), 111 | Stack(roll=False), 112 | ToTorchFormatTensor(div=True), 113 | normalize, 114 | ]) 115 | self.transforms[crop_size] = transform 116 | 117 | self.root = root 118 | self.setting = setting 119 | self.prefix = prefix 120 | self.split = split 121 | self.train = train 122 | self.test_mode = test_mode 123 | self.is_color = is_color 124 | self.modality = modality 125 | self.num_segments = num_segments 126 | self.num_crop = num_crop 127 | self.new_length = new_length 128 | self.new_step = new_step 129 | self.skip_length = self.new_length * self.new_step 130 | self.temporal_jitter = temporal_jitter 131 | self.name_pattern = name_pattern 132 | self.video_loader = video_loader 133 | self.video_ext = video_ext 134 | self.use_decord = use_decord 135 | self.transform = transform 136 | self.lazy_init = lazy_init 137 | self.num_sample = num_sample 138 | 139 | # sparse sampling, num_segments != 1 140 | if self.num_segments != 1: 141 | print('Use sparse sampling, change frame and stride') 142 | self.new_length = self.num_segments 143 | self.skip_length = 1 144 | 145 | self.client = None 146 | if has_client: 147 | self.client = Client('~/petreloss.conf') 148 | 149 | if not self.lazy_init: 150 | self.clips = self._make_dataset(root, setting) 151 | if len(self.clips) == 0: 152 | raise(RuntimeError("Found 0 video clips in subfolders of: " + root + "\n" 153 | "Check your data directory (opt.data-dir).")) 154 | 155 | def __getitem__(self, index): 156 | index, self.num_segments, self.crop_size = index 157 | self.new_length = self.num_segments 158 | 159 | while True: 160 | try: 161 | images = None 162 | if self.use_decord: 163 | directory, target = self.clips[index] 164 | if self.video_loader: 165 | if '.' in directory.split('/')[-1]: 166 | # data in the "setting" file already have extension, e.g., demo.mp4 167 | video_name = directory 168 | else: 169 | # data in the "setting" file do not have extension, e.g., demo 170 | # So we need to provide extension (i.e., .mp4) to complete the file name. 
171 | video_name = '{}.{}'.format(directory, self.video_ext) 172 | 173 | video_name = os.path.join(self.prefix, video_name) 174 | if video_name.startswith('s3') or video_name.startswith('p2:s3'): 175 | video_bytes = self.client.get(video_name) 176 | decord_vr = VideoReader(io.BytesIO(video_bytes), 177 | num_threads=1, 178 | ctx=cpu(0)) 179 | else: 180 | decord_vr = decord.VideoReader(video_name, num_threads=1, ctx=cpu(0)) 181 | duration = len(decord_vr) 182 | 183 | segment_indices, skip_offsets = self._sample_train_indices(duration) 184 | images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices, skip_offsets) 185 | 186 | else: 187 | video_name, total_frame, target = self.clips[index] 188 | video_name = os.path.join(self.prefix, video_name) 189 | 190 | segment_indices, skip_offsets = self._sample_train_indices(total_frame) 191 | frame_id_list = self._get_frame_id_list(total_frame, segment_indices, skip_offsets) 192 | images = [] 193 | for idx in frame_id_list: 194 | frame_fname = os.path.join(video_name, self.name_pattern.format(idx)) 195 | img_bytes = self.client.get(frame_fname) 196 | img_np = np.frombuffer(img_bytes, np.uint8) 197 | img = cv2.imdecode(img_np, cv2.IMREAD_COLOR) 198 | cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) 199 | images.append(Image.fromarray(img)) 200 | if images is not None: 201 | break 202 | except Exception as e: 203 | print("Failed to load video from {} with error {}".format( 204 | video_name, e)) 205 | index = random.randint(0, len(self.clips) - 1) 206 | 207 | if self.num_sample > 1: 208 | raise NotImplementedError 209 | else: 210 | process_data, _ = self.transforms[self.crop_size]((images, None)) # T*C,H,W 211 | process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1) # T*C,H,W -> T,C,H,W -> C,T,H,W 212 | return process_data 213 | 214 | def __len__(self): 215 | return len(self.clips) 216 | 217 | def _make_dataset(self, directory, setting): 218 | if not os.path.exists(setting): 219 | raise(RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting))) 220 | clips = [] 221 | 222 | print(f'Load dataset using decord: {self.use_decord}') 223 | with open(setting) as split_f: 224 | data = split_f.readlines() 225 | for line in data: 226 | line_info = line.split(self.split) 227 | if len(line_info) < 2: 228 | raise(RuntimeError('Video input format is not correct, missing one or more element. 
%s' % line)) 229 | if self.use_decord: 230 | # line format: video_path, video_label 231 | clip_path = os.path.join(line_info[0]) 232 | target = int(line_info[1]) 233 | item = (clip_path, target) 234 | else: 235 | # line format: video_path, video_duration, video_label 236 | clip_path = os.path.join(line_info[0]) 237 | total_frame = int(line_info[1]) 238 | target = int(line_info[2]) 239 | item = (clip_path, total_frame, target) 240 | clips.append(item) 241 | return clips 242 | 243 | def _sample_train_indices(self, num_frames): 244 | average_duration = (num_frames - self.skip_length + 1) // self.num_segments 245 | if average_duration > 0: 246 | offsets = np.multiply(list(range(self.num_segments)), 247 | average_duration) 248 | offsets = offsets + np.random.randint(average_duration, 249 | size=self.num_segments) 250 | elif num_frames > max(self.num_segments, self.skip_length): 251 | offsets = np.sort(np.random.randint( 252 | num_frames - self.skip_length + 1, 253 | size=self.num_segments)) 254 | else: 255 | offsets = np.zeros((self.num_segments,)) 256 | 257 | if self.temporal_jitter: 258 | skip_offsets = np.random.randint( 259 | self.new_step, size=self.skip_length // self.new_step) 260 | else: 261 | skip_offsets = np.zeros( 262 | self.skip_length // self.new_step, dtype=int) 263 | return offsets + 1, skip_offsets 264 | 265 | def _get_frame_id_list(self, duration, indices, skip_offsets): 266 | frame_id_list = [] 267 | for seg_ind in indices: 268 | offset = int(seg_ind) 269 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)): 270 | if offset + skip_offsets[i] <= duration: 271 | frame_id = offset + skip_offsets[i] - 1 272 | else: 273 | frame_id = offset - 1 274 | frame_id_list.append(frame_id) 275 | if offset + self.new_step < duration: 276 | offset += self.new_step 277 | return frame_id_list 278 | 279 | def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets): 280 | sampled_list = [] 281 | frame_id_list = [] 282 | for seg_ind in indices: 283 | offset = int(seg_ind) 284 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)): 285 | if offset + skip_offsets[i] <= duration: 286 | frame_id = offset + skip_offsets[i] - 1 287 | else: 288 | frame_id = offset - 1 289 | frame_id_list.append(frame_id) 290 | if offset + self.new_step < duration: 291 | offset += self.new_step 292 | try: 293 | video_data = video_reader.get_batch(frame_id_list).asnumpy() 294 | sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)] 295 | except: 296 | raise RuntimeError('Error occured in reading frames {} from video {} of duration {}.'.format(frame_id_list, directory, duration)) 297 | return sampled_list 298 | -------------------------------------------------------------------------------- /single_modality/datasets/masking_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def topk(matrix, K, axis=1): 5 | if axis == 0: 6 | row_index = np.arange(matrix.shape[1 - axis]) 7 | topk_index = np.argpartition(-matrix, K, axis=axis)[0:K, :] 8 | topk_data = matrix[topk_index, row_index] 9 | topk_index_sort = np.argsort(-topk_data,axis=axis) 10 | topk_data_sort = topk_data[topk_index_sort,row_index] 11 | topk_index_sort = topk_index[0:K,:][topk_index_sort,row_index] 12 | else: 13 | column_index = np.arange(matrix.shape[1 - axis])[:, None] 14 | topk_index = np.argpartition(-matrix, K, axis=axis)[:, 0:K] 15 | topk_data = 
matrix[column_index, topk_index] 16 | topk_index_sort = np.argsort(-topk_data, axis=axis) 17 | topk_data_sort = topk_data[column_index, topk_index_sort] 18 | topk_index_sort = topk_index[:,0:K][column_index,topk_index_sort] 19 | return (topk_data_sort, topk_index_sort) 20 | 21 | 22 | class TubeMaskingGenerator: 23 | def __init__(self, input_size, mask_ratio): 24 | self.frames, self.height, self.width = input_size 25 | self.num_patches_per_frame = self.height * self.width 26 | self.total_patches = self.frames * self.num_patches_per_frame 27 | self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame) 28 | self.total_masks = self.frames * self.num_masks_per_frame 29 | 30 | def __repr__(self): 31 | repr_str = "Masks: total patches {}, mask patches {}".format( 32 | self.total_patches, self.total_masks 33 | ) 34 | return repr_str 35 | 36 | def __call__(self): 37 | mask_per_frame = np.hstack([ 38 | np.zeros(self.num_patches_per_frame - self.num_masks_per_frame), 39 | np.ones(self.num_masks_per_frame), 40 | ]) 41 | np.random.shuffle(mask_per_frame) 42 | mask = np.tile(mask_per_frame, (self.frames, 1)).flatten() 43 | return mask 44 | 45 | 46 | class RandomMaskingGenerator: 47 | def __init__(self, input_size, mask_ratio): 48 | if not isinstance(input_size, tuple): 49 | input_size = (input_size, ) * 3 50 | 51 | self.frames, self.height, self.width = input_size 52 | 53 | self.num_patches = self.frames * self.height * self.width # 8x14x14 54 | self.num_mask = int(mask_ratio * self.num_patches) 55 | 56 | def __repr__(self): 57 | repr_str = "Masks: total patches {}, mask patches {}".format( 58 | self.num_patches, self.num_mask) 59 | return repr_str 60 | 61 | def __call__(self): 62 | mask = np.hstack([ 63 | np.zeros(self.num_patches - self.num_mask), 64 | np.ones(self.num_mask), 65 | ]) 66 | np.random.shuffle(mask) 67 | return mask # [196*8] 68 | 69 | 70 | class TemporalConsistencyMaskingGenerator: 71 | def __init__(self, input_size, mask_ratio, mask_ratio_teacher): 72 | self.frames, self.height, self.width = input_size 73 | self.num_patches_per_frame = self.height * self.width 74 | self.total_patches = self.frames * self.num_patches_per_frame 75 | self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame) 76 | self.num_masks_per_frame_teacher = int(mask_ratio_teacher * self.num_patches_per_frame) 77 | self.total_masks = self.frames * self.num_masks_per_frame 78 | 79 | def __repr__(self): 80 | repr_str = "Masks: total patches {}, mask patches {}".format( 81 | self.total_patches, self.total_masks 82 | ) 83 | return repr_str 84 | 85 | def __call__(self): 86 | mask_per_frame = np.hstack([ 87 | np.zeros(self.num_patches_per_frame - self.num_masks_per_frame), 88 | np.ones(self.num_masks_per_frame), 89 | ]) 90 | if self.num_masks_per_frame_teacher != 0: 91 | mask_per_frame[-self.num_masks_per_frame_teacher:] = 2 92 | np.random.shuffle(mask_per_frame) 93 | mask_per_frame_student = (mask_per_frame > 0).astype(int) 94 | mask_per_frame_teacher = (mask_per_frame > 1).astype(int) 95 | mask_per_frame_diff = (mask_per_frame == 1).astype(int) 96 | mask_student = np.tile(mask_per_frame_student, (self.frames, 1)).flatten() 97 | mask_teacher = np.tile(mask_per_frame_teacher, (self.frames, 1)).flatten() 98 | mask_diff = np.tile(mask_per_frame_diff, (self.frames, 1)).flatten() 99 | mask_student = np.insert(mask_student, 0, 0) # For cls token 100 | mask_teacher = np.insert(mask_teacher, 0, 0) 101 | mask_diff = np.insert(mask_diff, 0, 0) 102 | return (mask_student, mask_teacher, mask_diff) 103 | 104 |
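# --- Editor's note (illustrative sketch, not part of the original file) ---
# The generators above are plain callables over a (frames, height, width) patch grid
# that return flat 0/1 numpy arrays. For example, a tube mask over an 8x14x14 grid at a
# 75% ratio hides int(0.75 * 196) = 147 of the 196 patches per frame and repeats the same
# spatial pattern in every frame. The helper name below is hypothetical and only
# demonstrates the expected shape and mask count; it relies on the definitions above.
def _example_tube_mask():
    gen = TubeMaskingGenerator(input_size=(8, 14, 14), mask_ratio=0.75)
    mask = gen()                       # flat array of length 8 * 14 * 14
    assert mask.shape == (1568,)       # 1568 patch positions in total
    assert int(mask.sum()) == 8 * 147  # 147 masked patches in each of the 8 frames
    return mask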
105 | class TemporalProgressiveMaskingGenerator: 106 | def __init__(self, input_size, mask_ratio): 107 | self.frames, self.height, self.width = input_size 108 | self.num_patches_per_frame = self.height * self.width 109 | self.total_patches = self.frames * self.num_patches_per_frame 110 | max_keep_patch = int((1 - mask_ratio) * self.num_patches_per_frame) 111 | min_keep_patch = int(0.05 * self.num_patches_per_frame) 112 | self.keep_patches_list = np.linspace(max_keep_patch, min_keep_patch, self.frames).astype(int) 113 | self.total_masks = self.total_patches - self.keep_patches_list.sum() 114 | def __repr__(self): 115 | repr_str = "Masks: total patches {}, mask patches {}".format( 116 | self.total_patches, self.total_masks 117 | ) 118 | return repr_str 119 | 120 | def __call__(self): 121 | 122 | rand = np.random.randn(1, self.num_patches_per_frame) 123 | mask = np.zeros((self.frames, self.num_patches_per_frame), dtype=bool) 124 | for i in range(self.frames): 125 | top_k, _ = topk(rand, self.keep_patches_list[i]) 126 | the_topk = top_k[0][-1] 127 | mask[i] = rand<=the_topk 128 | mask = mask.flatten().astype(int) 129 | return mask # [196*8] 130 | 131 | 132 | class TemporalCenteringProgressiveMaskingGenerator: 133 | def __init__(self, input_size, mask_ratio): 134 | self.num_frames, self.height, self.width = input_size 135 | self.num_patches_per_frame = self.height * self.width 136 | self.total_patches = self.num_frames * self.num_patches_per_frame 137 | min_mask_ratio = mask_ratio 138 | 139 | max_mask_ratio = 0.95 140 | max_keep_patch = int((1 - min_mask_ratio) * self.num_patches_per_frame) 141 | min_keep_patch = int((1 - max_mask_ratio) * self.num_patches_per_frame) 142 | patches_list = np.linspace(max_keep_patch, min_keep_patch, self.num_frames//2 ).astype(int).tolist() 143 | self.keep_patches_list = patches_list.copy() 144 | patches_list.reverse() 145 | self.keep_patches_list = patches_list + self.keep_patches_list 146 | self.total_masks = self.total_patches - sum(self.keep_patches_list) 147 | def __repr__(self): 148 | repr_str = "Masks: total patches {}, mask patches {}".format( 149 | self.total_patches, self.total_masks 150 | ) 151 | return repr_str 152 | 153 | def __call__(self): 154 | 155 | rand = np.random.randn(1, self.num_patches_per_frame) 156 | mask = np.zeros((self.num_frames, self.num_patches_per_frame), dtype=bool) 157 | for i in range(self.num_frames): 158 | top_k, _ = topk(rand, self.keep_patches_list[i]) 159 | the_topk = top_k[0][-1] 160 | mask[i] = rand<=the_topk 161 | mask = mask.flatten().astype(int) 162 | return mask 163 | -------------------------------------------------------------------------------- /single_modality/datasets/mixup.py: -------------------------------------------------------------------------------- 1 | """ Mixup and Cutmix 2 | 3 | Papers: 4 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) 5 | 6 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) 7 | 8 | Code Reference: 9 | CutMix: https://github.com/clovaai/CutMix-PyTorch 10 | 11 | Hacked together by / Copyright 2019, Ross Wightman 12 | """ 13 | import numpy as np 14 | import torch 15 | 16 | 17 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): 18 | x = x.long().view(-1, 1) 19 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) 20 | 21 | 22 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): 23 | off_value
= smoothing / num_classes 24 | on_value = 1. - smoothing + off_value 25 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) 26 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) 27 | return y1 * lam + y2 * (1. - lam) 28 | 29 | 30 | def rand_bbox(img_shape, lam, margin=0., count=None): 31 | """ Standard CutMix bounding-box 32 | Generates a random square bbox based on lambda value. This impl includes 33 | support for enforcing a border margin as percent of bbox dimensions. 34 | 35 | Args: 36 | img_shape (tuple): Image shape as tuple 37 | lam (float): Cutmix lambda value 38 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image) 39 | count (int): Number of bbox to generate 40 | """ 41 | ratio = np.sqrt(1 - lam) 42 | img_h, img_w = img_shape[-2:] 43 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) 44 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) 45 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) 46 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) 47 | yl = np.clip(cy - cut_h // 2, 0, img_h) 48 | yh = np.clip(cy + cut_h // 2, 0, img_h) 49 | xl = np.clip(cx - cut_w // 2, 0, img_w) 50 | xh = np.clip(cx + cut_w // 2, 0, img_w) 51 | return yl, yh, xl, xh 52 | 53 | 54 | def rand_bbox_minmax(img_shape, minmax, count=None): 55 | """ Min-Max CutMix bounding-box 56 | Inspired by Darknet cutmix impl, generates a random rectangular bbox 57 | based on min/max percent values applied to each dimension of the input image. 58 | 59 | Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max. 60 | 61 | Args: 62 | img_shape (tuple): Image shape as tuple 63 | minmax (tuple or list): Min and max bbox ratios (as percent of image size) 64 | count (int): Number of bbox to generate 65 | """ 66 | assert len(minmax) == 2 67 | img_h, img_w = img_shape[-2:] 68 | cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count) 69 | cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count) 70 | yl = np.random.randint(0, img_h - cut_h, size=count) 71 | xl = np.random.randint(0, img_w - cut_w, size=count) 72 | yu = yl + cut_h 73 | xu = xl + cut_w 74 | return yl, yu, xl, xu 75 | 76 | 77 | def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None): 78 | """ Generate bbox and apply lambda correction. 79 | """ 80 | if ratio_minmax is not None: 81 | yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count) 82 | else: 83 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) 84 | if correct_lam or ratio_minmax is not None: 85 | bbox_area = (yu - yl) * (xu - xl) 86 | lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1]) 87 | return (yl, yu, xl, xu), lam 88 | 89 | 90 | class Mixup: 91 | """ Mixup/Cutmix that applies different params to each element or whole batch 92 | 93 | Args: 94 | mixup_alpha (float): mixup alpha value, mixup is active if > 0. 95 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. 96 | cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None. 
97 | prob (float): probability of applying mixup or cutmix per batch or element 98 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active 99 | mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)) 100 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders 101 | label_smoothing (float): apply label smoothing to the mixed target tensor 102 | num_classes (int): number of classes for target 103 | """ 104 | def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5, 105 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000): 106 | self.mixup_alpha = mixup_alpha 107 | self.cutmix_alpha = cutmix_alpha 108 | self.cutmix_minmax = cutmix_minmax 109 | if self.cutmix_minmax is not None: 110 | assert len(self.cutmix_minmax) == 2 111 | # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe 112 | self.cutmix_alpha = 1.0 113 | self.mix_prob = prob 114 | self.switch_prob = switch_prob 115 | self.label_smoothing = label_smoothing 116 | self.num_classes = num_classes 117 | self.mode = mode 118 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix 119 | self.mixup_enabled = True # set to false to disable mixing (intended to be set by train loop) 120 | 121 | def _params_per_elem(self, batch_size): 122 | lam = np.ones(batch_size, dtype=np.float32) 123 | use_cutmix = np.zeros(batch_size, dtype=bool) 124 | if self.mixup_enabled: 125 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: 126 | use_cutmix = np.random.rand(batch_size) < self.switch_prob 127 | lam_mix = np.where( 128 | use_cutmix, 129 | np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size), 130 | np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)) 131 | elif self.mixup_alpha > 0.: 132 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size) 133 | elif self.cutmix_alpha > 0.: 134 | use_cutmix = np.ones(batch_size, dtype=bool) 135 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size) 136 | else: 137 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." 138 | lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam) 139 | return lam, use_cutmix 140 | 141 | def _params_per_batch(self): 142 | lam = 1. 143 | use_cutmix = False 144 | if self.mixup_enabled and np.random.rand() < self.mix_prob: 145 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: 146 | use_cutmix = np.random.rand() < self.switch_prob 147 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \ 148 | np.random.beta(self.mixup_alpha, self.mixup_alpha) 149 | elif self.mixup_alpha > 0.: 150 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha) 151 | elif self.cutmix_alpha > 0.: 152 | use_cutmix = True 153 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) 154 | else: 155 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
156 | lam = float(lam_mix) 157 | return lam, use_cutmix 158 | 159 | def _mix_elem(self, x): 160 | batch_size = len(x) 161 | lam_batch, use_cutmix = self._params_per_elem(batch_size) 162 | x_orig = x.clone() # need to keep an unmodified original for mixing source 163 | for i in range(batch_size): 164 | j = batch_size - i - 1 165 | lam = lam_batch[i] 166 | if lam != 1.: 167 | if use_cutmix[i]: 168 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 169 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 170 | x[i][..., yl:yh, xl:xh] = x_orig[j][..., yl:yh, xl:xh] 171 | lam_batch[i] = lam 172 | else: 173 | x[i] = x[i] * lam + x_orig[j] * (1 - lam) 174 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) 175 | 176 | def _mix_pair(self, x): 177 | batch_size = len(x) 178 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) 179 | x_orig = x.clone() # need to keep an unmodified original for mixing source 180 | for i in range(batch_size // 2): 181 | j = batch_size - i - 1 182 | lam = lam_batch[i] 183 | if lam != 1.: 184 | if use_cutmix[i]: 185 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 186 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 187 | x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] 188 | x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh] 189 | lam_batch[i] = lam 190 | else: 191 | x[i] = x[i] * lam + x_orig[j] * (1 - lam) 192 | x[j] = x[j] * lam + x_orig[i] * (1 - lam) 193 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) 194 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) 195 | 196 | def _mix_batch(self, x): 197 | lam, use_cutmix = self._params_per_batch() 198 | if lam == 1.: 199 | return 1. 200 | if use_cutmix: 201 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 202 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 203 | x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh] 204 | else: 205 | x_flipped = x.flip(0).mul_(1. - lam) 206 | x.mul_(lam).add_(x_flipped) 207 | return lam 208 | 209 | def __call__(self, x, target): 210 | assert len(x) % 2 == 0, 'Batch size should be even when using this' 211 | if self.mode == 'elem': 212 | lam = self._mix_elem(x) 213 | elif self.mode == 'pair': 214 | lam = self._mix_pair(x) 215 | else: 216 | lam = self._mix_batch(x) 217 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device) 218 | return x, target 219 | 220 | 221 | class FastCollateMixup(Mixup): 222 | """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch 223 | 224 | A Mixup impl that's performed while collating the batches. 
225 | """ 226 | 227 | def _mix_elem_collate(self, output, batch, half=False): 228 | batch_size = len(batch) 229 | num_elem = batch_size // 2 if half else batch_size 230 | assert len(output) == num_elem 231 | lam_batch, use_cutmix = self._params_per_elem(num_elem) 232 | for i in range(num_elem): 233 | j = batch_size - i - 1 234 | lam = lam_batch[i] 235 | mixed = batch[i][0] 236 | if lam != 1.: 237 | if use_cutmix[i]: 238 | if not half: 239 | mixed = mixed.copy() 240 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 241 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 242 | mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] 243 | lam_batch[i] = lam 244 | else: 245 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) 246 | np.rint(mixed, out=mixed) 247 | output[i] += torch.from_numpy(mixed.astype(np.uint8)) 248 | if half: 249 | lam_batch = np.concatenate((lam_batch, np.ones(num_elem))) 250 | return torch.tensor(lam_batch).unsqueeze(1) 251 | 252 | def _mix_pair_collate(self, output, batch): 253 | batch_size = len(batch) 254 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) 255 | for i in range(batch_size // 2): 256 | j = batch_size - i - 1 257 | lam = lam_batch[i] 258 | mixed_i = batch[i][0] 259 | mixed_j = batch[j][0] 260 | assert 0 <= lam <= 1.0 261 | if lam < 1.: 262 | if use_cutmix[i]: 263 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 264 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 265 | patch_i = mixed_i[:, yl:yh, xl:xh].copy() 266 | mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh] 267 | mixed_j[:, yl:yh, xl:xh] = patch_i 268 | lam_batch[i] = lam 269 | else: 270 | mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam) 271 | mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam) 272 | mixed_i = mixed_temp 273 | np.rint(mixed_j, out=mixed_j) 274 | np.rint(mixed_i, out=mixed_i) 275 | output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) 276 | output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) 277 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) 278 | return torch.tensor(lam_batch).unsqueeze(1) 279 | 280 | def _mix_batch_collate(self, output, batch): 281 | batch_size = len(batch) 282 | lam, use_cutmix = self._params_per_batch() 283 | if use_cutmix: 284 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 285 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 286 | for i in range(batch_size): 287 | j = batch_size - i - 1 288 | mixed = batch[i][0] 289 | if lam != 1.: 290 | if use_cutmix: 291 | mixed = mixed.copy() # don't want to modify the original while iterating 292 | mixed[..., yl:yh, xl:xh] = batch[j][0][..., yl:yh, xl:xh] 293 | else: 294 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) 295 | np.rint(mixed, out=mixed) 296 | output[i] += torch.from_numpy(mixed.astype(np.uint8)) 297 | return lam 298 | 299 | def __call__(self, batch, _=None): 300 | batch_size = len(batch) 301 | assert batch_size % 2 == 0, 'Batch size should be even when using this' 302 | half = 'half' in self.mode 303 | if half: 304 | batch_size //= 2 305 | output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) 306 | if self.mode == 'elem' or self.mode == 'half': 307 | lam = self._mix_elem_collate(output, batch, half=half) 308 | elif self.mode == 'pair': 309 | lam = self._mix_pair_collate(output, batch) 310 | else: 311 | lam = 
self._mix_batch_collate(output, batch) 312 | target = torch.tensor([b[1] for b in batch], dtype=torch.int64) 313 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') 314 | target = target[:batch_size] 315 | return output, target 316 | 317 | -------------------------------------------------------------------------------- /single_modality/datasets/multiloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from utils import is_dist_avail_and_initialized 4 | import random 5 | import logging 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class MetaLoader(object): 11 | """ wraps multiple data loader """ 12 | def __init__(self, loaders): 13 | """Iterates over multiple dataloaders, it ensures all processes 14 | work on data from the same dataloader. This loader will end when 15 | the shorter dataloader raises StopIteration exception. 16 | 17 | loaders: Dict, {name: dataloader} 18 | """ 19 | self.loaders = loaders 20 | iter_order = [i for i in range(len(self.loaders))] 21 | random.shuffle(iter_order) 22 | iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8) 23 | 24 | # sync 25 | if is_dist_avail_and_initialized(): 26 | # make sure all processes have the same order so that 27 | # each step they will have data from the same loader 28 | dist.broadcast(iter_order, src=0) 29 | self.iter_order = [e for e in iter_order.cpu()] 30 | 31 | def __len__(self): 32 | return len(self.iter_order) 33 | 34 | def __iter__(self): 35 | """ this iterator will run indefinitely """ 36 | while True: 37 | try: 38 | for i in self.iter_order: 39 | _iter = self.loaders[i] 40 | batch = next(_iter) 41 | yield batch 42 | except: 43 | iter_order = [i for i in range(len(self.loaders))] 44 | random.shuffle(iter_order) 45 | iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8) 46 | if is_dist_avail_and_initialized(): 47 | # make sure all processes have the same order so that 48 | # each step they will have data from the same loader 49 | dist.broadcast(iter_order, src=0) 50 | self.iter_order = [e for e in iter_order.cpu()] 51 | continue -------------------------------------------------------------------------------- /single_modality/datasets/random_erasing.py: -------------------------------------------------------------------------------- 1 | """ 2 | This implementation is based on 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py 4 | pulished under an Apache License 2.0. 5 | """ 6 | import math 7 | import random 8 | import torch 9 | 10 | 11 | def _get_pixels( 12 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda" 13 | ): 14 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 15 | # paths, flip the order so normal is run on CPU if this becomes a problem 16 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 17 | if per_pixel: 18 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 19 | elif rand_color: 20 | return torch.empty( 21 | (patch_size[0], 1, 1), dtype=dtype, device=device 22 | ).normal_() 23 | else: 24 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 25 | 26 | 27 | class RandomErasing: 28 | """Randomly selects a rectangle region in an image and erases its pixels. 29 | 'Random Erasing Data Augmentation' by Zhong et al. 
30 | See https://arxiv.org/pdf/1708.04896.pdf 31 | This variant of RandomErasing is intended to be applied to either a batch 32 | or single image tensor after it has been normalized by dataset mean and std. 33 | Args: 34 | probability: Probability that the Random Erasing operation will be performed. 35 | min_area: Minimum percentage of erased area wrt input image area. 36 | max_area: Maximum percentage of erased area wrt input image area. 37 | min_aspect: Minimum aspect ratio of erased area. 38 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 39 | 'const' - erase block is constant color of 0 for all channels 40 | 'rand' - erase block is same per-channel random (normal) color 41 | 'pixel' - erase block is per-pixel random (normal) color 42 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 43 | per-image count is randomly chosen between 1 and this value. 44 | """ 45 | 46 | def __init__( 47 | self, 48 | probability=0.5, 49 | min_area=0.02, 50 | max_area=1 / 3, 51 | min_aspect=0.3, 52 | max_aspect=None, 53 | mode="const", 54 | min_count=1, 55 | max_count=None, 56 | num_splits=0, 57 | device="cuda", 58 | cube=True, 59 | ): 60 | self.probability = probability 61 | self.min_area = min_area 62 | self.max_area = max_area 63 | max_aspect = max_aspect or 1 / min_aspect 64 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 65 | self.min_count = min_count 66 | self.max_count = max_count or min_count 67 | self.num_splits = num_splits 68 | mode = mode.lower() 69 | self.rand_color = False 70 | self.per_pixel = False 71 | self.cube = cube 72 | if mode == "rand": 73 | self.rand_color = True # per block random normal 74 | elif mode == "pixel": 75 | self.per_pixel = True # per pixel random normal 76 | else: 77 | assert not mode or mode == "const" 78 | self.device = device 79 | 80 | def _erase(self, img, chan, img_h, img_w, dtype): 81 | if random.random() > self.probability: 82 | return 83 | area = img_h * img_w 84 | count = ( 85 | self.min_count 86 | if self.min_count == self.max_count 87 | else random.randint(self.min_count, self.max_count) 88 | ) 89 | for _ in range(count): 90 | for _ in range(10): 91 | target_area = ( 92 | random.uniform(self.min_area, self.max_area) * area / count 93 | ) 94 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 95 | h = int(round(math.sqrt(target_area * aspect_ratio))) 96 | w = int(round(math.sqrt(target_area / aspect_ratio))) 97 | if w < img_w and h < img_h: 98 | top = random.randint(0, img_h - h) 99 | left = random.randint(0, img_w - w) 100 | img[:, top : top + h, left : left + w] = _get_pixels( 101 | self.per_pixel, 102 | self.rand_color, 103 | (chan, h, w), 104 | dtype=dtype, 105 | device=self.device, 106 | ) 107 | break 108 | 109 | def _erase_cube( 110 | self, 111 | img, 112 | batch_start, 113 | batch_size, 114 | chan, 115 | img_h, 116 | img_w, 117 | dtype, 118 | ): 119 | if random.random() > self.probability: 120 | return 121 | area = img_h * img_w 122 | count = ( 123 | self.min_count 124 | if self.min_count == self.max_count 125 | else random.randint(self.min_count, self.max_count) 126 | ) 127 | for _ in range(count): 128 | for _ in range(100): 129 | target_area = ( 130 | random.uniform(self.min_area, self.max_area) * area / count 131 | ) 132 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 133 | h = int(round(math.sqrt(target_area * aspect_ratio))) 134 | w = int(round(math.sqrt(target_area / aspect_ratio))) 135 | if w < img_w and h < img_h: 136 | top = 
random.randint(0, img_h - h) 137 | left = random.randint(0, img_w - w) 138 | for i in range(batch_start, batch_size): 139 | img_instance = img[i] 140 | img_instance[ 141 | :, top : top + h, left : left + w 142 | ] = _get_pixels( 143 | self.per_pixel, 144 | self.rand_color, 145 | (chan, h, w), 146 | dtype=dtype, 147 | device=self.device, 148 | ) 149 | break 150 | 151 | def __call__(self, input): 152 | if len(input.size()) == 3: 153 | self._erase(input, *input.size(), input.dtype) 154 | else: 155 | batch_size, chan, img_h, img_w = input.size() 156 | # skip first slice of batch if num_splits is set (for clean portion of samples) 157 | batch_start = ( 158 | batch_size // self.num_splits if self.num_splits > 1 else 0 159 | ) 160 | if self.cube: 161 | self._erase_cube( 162 | input, 163 | batch_start, 164 | batch_size, 165 | chan, 166 | img_h, 167 | img_w, 168 | input.dtype, 169 | ) 170 | else: 171 | for i in range(batch_start, batch_size): 172 | self._erase(input[i], chan, img_h, img_w, input.dtype) 173 | return input 174 | -------------------------------------------------------------------------------- /single_modality/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | import warnings 4 | import random 5 | import numpy as np 6 | import torchvision 7 | from PIL import Image, ImageOps 8 | import numbers 9 | 10 | 11 | class GroupRandomCrop(object): 12 | def __init__(self, size): 13 | if isinstance(size, numbers.Number): 14 | self.size = (int(size), int(size)) 15 | else: 16 | self.size = size 17 | 18 | def __call__(self, img_tuple): 19 | img_group, label = img_tuple 20 | 21 | w, h = img_group[0].size 22 | th, tw = self.size 23 | 24 | out_images = list() 25 | 26 | x1 = random.randint(0, w - tw) 27 | y1 = random.randint(0, h - th) 28 | 29 | for img in img_group: 30 | assert(img.size[0] == w and img.size[1] == h) 31 | if w == tw and h == th: 32 | out_images.append(img) 33 | else: 34 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 35 | 36 | return (out_images, label) 37 | 38 | 39 | class GroupCenterCrop(object): 40 | def __init__(self, size): 41 | self.worker = torchvision.transforms.CenterCrop(size) 42 | 43 | def __call__(self, img_tuple): 44 | img_group, label = img_tuple 45 | return ([self.worker(img) for img in img_group], label) 46 | 47 | 48 | class GroupRandomHorizontalFlip(object): 49 | def __init__(self, flip=False): 50 | self.flip = flip 51 | 52 | def __call__(self, img_tuple): 53 | v = random.random() 54 | if self.flip and v < 0.5: 55 | img_group, label = img_tuple 56 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 57 | return (ret, label) 58 | else: 59 | return img_tuple 60 | 61 | 62 | class GroupNormalize(object): 63 | def __init__(self, mean, std): 64 | self.mean = mean 65 | self.std = std 66 | 67 | def __call__(self, tensor_tuple): 68 | tensor, label = tensor_tuple 69 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 70 | rep_std = self.std * (tensor.size()[0]//len(self.std)) 71 | 72 | # TODO: make efficient 73 | for t, m, s in zip(tensor, rep_mean, rep_std): 74 | t.sub_(m).div_(s) 75 | 76 | return (tensor,label) 77 | 78 | 79 | class GroupGrayScale(object): 80 | def __init__(self, size): 81 | self.worker = torchvision.transforms.Grayscale(size) 82 | 83 | def __call__(self, img_tuple): 84 | img_group, label = img_tuple 85 | return ([self.worker(img) for img in img_group], label) 86 | 87 | 88 | class GroupColorJitter(object): 89 
| def __init__(self, size): 90 | self.worker = torchvision.transforms.ColorJitter( 91 | brightness=size, contrast=size, saturation=size 92 | ) 93 | 94 | def __call__(self, img_tuple): 95 | img_group, label = img_tuple 96 | return ([self.worker(img) for img in img_group], label) 97 | 98 | 99 | class GroupScale(object): 100 | """ Rescales the input PIL.Image to the given 'size'. 101 | 'size' will be the size of the smaller edge. 102 | For example, if height > width, then image will be 103 | rescaled to (size * height / width, size) 104 | size: size of the smaller edge 105 | interpolation: Default: PIL.Image.BILINEAR 106 | """ 107 | 108 | def __init__(self, size, interpolation=Image.BILINEAR): 109 | self.worker = torchvision.transforms.Resize(size, interpolation) 110 | 111 | def __call__(self, img_tuple): 112 | img_group, label = img_tuple 113 | return ([self.worker(img) for img in img_group], label) 114 | 115 | 116 | class GroupMultiScaleCrop(object): 117 | 118 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 119 | self.scales = scales if scales is not None else [1, .875, .75, .66] # crop scale ratios w.r.t. the shorter side 120 | self.max_distort = max_distort 121 | self.fix_crop = fix_crop 122 | self.more_fix_crop = more_fix_crop 123 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 124 | self.interpolation = Image.BILINEAR 125 | 126 | def __call__(self, img_tuple): 127 | img_group, label = img_tuple 128 | 129 | im_size = img_group[0].size 130 | 131 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 132 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 133 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group] 134 | return (ret_img_group, label) 135 | 136 | def _sample_crop_size(self, im_size): 137 | image_w, image_h = im_size[0], im_size[1] 138 | 139 | # find a crop size 140 | base_size = min(image_w, image_h) 141 | crop_sizes = [int(base_size * x) for x in self.scales] 142 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 143 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 144 | 145 | pairs = [] 146 | for i, h in enumerate(crop_h): 147 | for j, w in enumerate(crop_w): 148 | if abs(i - j) <= self.max_distort: 149 | pairs.append((w, h)) 150 | 151 | crop_pair = random.choice(pairs) 152 | if not self.fix_crop: 153 | w_offset = random.randint(0, image_w - crop_pair[0]) 154 | h_offset = random.randint(0, image_h - crop_pair[1]) 155 | else: 156 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 157 | 158 | return crop_pair[0], crop_pair[1], w_offset, h_offset 159 | 160 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 161 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 162 | return random.choice(offsets) 163 | 164 | @staticmethod 165 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 166 | w_step = (image_w - crop_w) // 4 167 | h_step = (image_h - crop_h) // 4 168 | 169 | ret = list() 170 | ret.append((0, 0)) # upper left 171 | ret.append((4 * w_step, 0)) # upper right 172 | ret.append((0, 4 * h_step)) # lower left 173 | ret.append((4 * w_step, 4 * h_step)) # lower right 174 | ret.append((2 * w_step, 2 * h_step)) # center 175 | 176 | if more_fix_crop: 177 | ret.append((0, 2 * h_step)) # center left
178 | ret.append((4 * w_step, 2 * h_step)) # center right 179 | ret.append((2 * w_step, 4 * h_step)) # lower center 180 | ret.append((2 * w_step, 0 * h_step)) # upper center 181 | 182 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 183 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 184 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 185 | ret.append((3 * w_step, 3 * h_step)) # lower right quarter 186 | return ret 187 | 188 | 189 | class Stack(object): 190 | 191 | def __init__(self, roll=False): 192 | self.roll = roll 193 | 194 | def __call__(self, img_tuple): 195 | img_group, label = img_tuple 196 | 197 | if img_group[0].mode == 'L': 198 | return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label) 199 | elif img_group[0].mode == 'RGB': 200 | if self.roll: 201 | return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label) 202 | else: 203 | return (np.concatenate(img_group, axis=2), label) 204 | 205 | 206 | class ToTorchFormatTensor(object): 207 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 208 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 209 | def __init__(self, div=True): 210 | self.div = div 211 | 212 | def __call__(self, pic_tuple): 213 | pic, label = pic_tuple 214 | 215 | if isinstance(pic, np.ndarray): 216 | # handle numpy array 217 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 218 | else: 219 | # handle PIL Image 220 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 221 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 222 | # put it from HWC to CHW format 223 | # yikes, this transpose takes 80% of the loading time/CPU 224 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 225 | return (img.float().div(255.) if self.div else img.float(), label) 226 | 227 | 228 | class IdentityTransform(object): 229 | 230 | def __call__(self, data): 231 | return data 232 | -------------------------------------------------------------------------------- /single_modality/datasets/volume_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | 5 | 6 | def convert_img(img): 7 | """Converts (H, W, C) numpy.ndarray to (C, H, W) format 8 | """ 9 | if len(img.shape) == 3: 10 | img = img.transpose(2, 0, 1) 11 | if len(img.shape) == 2: 12 | img = np.expand_dims(img, 0) 13 | return img 14 | 15 | 16 | class ClipToTensor(object): 17 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 18 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 19 | """ 20 | 21 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 22 | self.channel_nb = channel_nb 23 | self.div_255 = div_255 24 | self.numpy = numpy 25 | 26 | def __call__(self, clip): 27 | """ 28 | Args: clip (list of numpy.ndarray): clip (list of images) 29 | to be converted to tensor.
30 | """ 31 | # Retrieve shape 32 | if isinstance(clip[0], np.ndarray): 33 | h, w, ch = clip[0].shape 34 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( 35 | ch) 36 | elif isinstance(clip[0], Image.Image): 37 | w, h = clip[0].size 38 | else: 39 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 40 | but got list of {0}'.format(type(clip[0]))) 41 | 42 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 43 | 44 | # Convert 45 | for img_idx, img in enumerate(clip): 46 | if isinstance(img, np.ndarray): 47 | pass 48 | elif isinstance(img, Image.Image): 49 | img = np.array(img, copy=False) 50 | else: 51 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 52 | but got list of {0}'.format(type(clip[0]))) 53 | img = convert_img(img) 54 | np_clip[:, img_idx, :, :] = img 55 | if self.numpy: 56 | if self.div_255: 57 | np_clip = np_clip / 255.0 58 | return np_clip 59 | 60 | else: 61 | tensor_clip = torch.from_numpy(np_clip) 62 | 63 | if not isinstance(tensor_clip, torch.FloatTensor): 64 | tensor_clip = tensor_clip.float() 65 | if self.div_255: 66 | tensor_clip = torch.div(tensor_clip, 255) 67 | return tensor_clip 68 | 69 | 70 | # Note this norms data to -1/1 71 | class ClipToTensor_K(object): 72 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 73 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 74 | """ 75 | 76 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 77 | self.channel_nb = channel_nb 78 | self.div_255 = div_255 79 | self.numpy = numpy 80 | 81 | def __call__(self, clip): 82 | """ 83 | Args: clip (list of numpy.ndarray): clip (list of images) 84 | to be converted to tensor. 85 | """ 86 | # Retrieve shape 87 | if isinstance(clip[0], np.ndarray): 88 | h, w, ch = clip[0].shape 89 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( 90 | ch) 91 | elif isinstance(clip[0], Image.Image): 92 | w, h = clip[0].size 93 | else: 94 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 95 | but got list of {0}'.format(type(clip[0]))) 96 | 97 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 98 | 99 | # Convert 100 | for img_idx, img in enumerate(clip): 101 | if isinstance(img, np.ndarray): 102 | pass 103 | elif isinstance(img, Image.Image): 104 | img = np.array(img, copy=False) 105 | else: 106 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 107 | but got list of {0}'.format(type(clip[0]))) 108 | img = convert_img(img) 109 | np_clip[:, img_idx, :, :] = img 110 | if self.numpy: 111 | if self.div_255: 112 | np_clip = (np_clip - 127.5) / 127.5 113 | return np_clip 114 | 115 | else: 116 | tensor_clip = torch.from_numpy(np_clip) 117 | 118 | if not isinstance(tensor_clip, torch.FloatTensor): 119 | tensor_clip = tensor_clip.float() 120 | if self.div_255: 121 | tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) 122 | return tensor_clip 123 | 124 | 125 | class ToTensor(object): 126 | """Converts numpy array to tensor 127 | """ 128 | 129 | def __call__(self, array): 130 | tensor = torch.from_numpy(array) 131 | return tensor 132 | -------------------------------------------------------------------------------- /single_modality/engines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/engines/__init__.py -------------------------------------------------------------------------------- 
/single_modality/engines/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/engines/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/engines/__pycache__/engine_for_flexible_tuning.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/engines/__pycache__/engine_for_flexible_tuning.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/engines/engine_for_flexible_tuning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import math 5 | import sys 6 | from typing import Iterable, Optional 7 | import torch 8 | from datasets.mixup import Mixup 9 | from timm.utils import accuracy, ModelEma 10 | import utils 11 | from scipy.special import softmax 12 | import random 13 | import torch.nn.functional as F 14 | import torch.distributed as dist 15 | 16 | 17 | def get_loss_scale_for_deepspeed(model): 18 | optimizer = model.optimizer 19 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale 20 | 21 | 22 | def train_one_epoch( 23 | model: torch.nn.Module, criterion: torch.nn.Module, 24 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 25 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 26 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, 27 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None, 28 | num_training_steps_per_epoch=None, update_freq=None, args=None, 29 | bf16=False, 30 | ): 31 | model.train(True) 32 | metric_logger = utils.MetricLogger(delimiter=" ") 33 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 34 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 35 | header = 'Epoch: [{}]'.format(epoch) 36 | print_freq = 1 37 | 38 | if loss_scaler is None: 39 | model.zero_grad() 40 | model.micro_steps = 0 41 | else: 42 | optimizer.zero_grad() 43 | 44 | for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 45 | 46 | step = data_iter_step // update_freq 47 | if step >= num_training_steps_per_epoch: 48 | continue 49 | it = start_steps + step # global training iteration 50 | # Update LR & WD for the first acc 51 | if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0: 52 | for i, param_group in enumerate(optimizer.param_groups): 53 | if lr_schedule_values is not None: 54 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 55 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 56 | param_group["weight_decay"] = wd_schedule_values[it] 57 | 58 | samples = samples.to(device, non_blocking=True) 59 | targets = targets.to(device, non_blocking=True) 60 | 61 | if mixup_fn is not None: 62 | samples, targets = mixup_fn(samples, targets) 63 | 64 | if loss_scaler is None: 65 | 66 | B, C, T, H, W = samples.shape 67 | 68 | samples = samples.to(device, non_blocking=True) 69 | 70 | nums_token = T * (H // 14) ** 2 71 | 
samples = samples.bfloat16() if bf16 else samples.half() 72 | 73 | def compute_mask_ratio(num_current_token, num_input_token): 74 | if num_current_token <= num_input_token: 75 | mask_ratio = 0.0 76 | else: 77 | mask_ratio = 1 - num_input_token / num_current_token 78 | return mask_ratio 79 | 80 | def compute_smooth_l1_loss(src, tgt): 81 | tgt = tgt.detach() 82 | src = src / src.norm(dim=-1, keepdim=True).to(torch.float32) 83 | tgt = tgt / tgt.norm(dim=-1, keepdim=True).to(torch.float32) 84 | return torch.nn.SmoothL1Loss()(src, tgt.detach()) 85 | 86 | mask_ratio = compute_mask_ratio(nums_token, args.largest_num_input_token) 87 | outputs1, x_final_1 = model(samples, masking_ratio=mask_ratio, output_head=0, return_cls=True) 88 | 89 | mask_ratio = compute_mask_ratio(nums_token, args.middle_num_input_token) 90 | outputs2, x_final_2 = model(samples, masking_ratio=mask_ratio, output_head=1, return_cls=True, return_projected=True, align_proj=0) 91 | 92 | mask_ratio = compute_mask_ratio(nums_token, args.least_num_input_token) 93 | outputs3, x_final_3 = model(samples, masking_ratio=mask_ratio, output_head=2, return_cls=True, return_projected=True, align_proj=1) 94 | 95 | pred = criterion(outputs1, targets) + criterion(outputs2, targets) + criterion(outputs3, targets) 96 | align_loss = compute_smooth_l1_loss(x_final_2, x_final_1) + compute_smooth_l1_loss(x_final_3, x_final_1) 97 | 98 | loss = pred + align_loss 99 | 100 | else: 101 | raise NotImplementedError 102 | 103 | loss_value = loss.item() 104 | 105 | loss_list = [torch.zeros_like(loss) for _ in range(dist.get_world_size())] 106 | dist.all_gather(loss_list, loss) 107 | loss_list = torch.tensor(loss_list) 108 | loss_list_isnan = torch.isnan(loss_list).any() 109 | loss_list_isinf = torch.isinf(loss_list).any() 110 | 111 | if loss_list_isnan or loss_list_isinf: 112 | print("Loss is {}, stopping training".format(loss_value), force=True) 113 | model.zero_grad() 114 | continue 115 | 116 | if loss_scaler is None: 117 | loss /= update_freq 118 | model.backward(loss) 119 | model.step() 120 | 121 | if (data_iter_step + 1) % update_freq == 0: 122 | # model.zero_grad() 123 | # Deepspeed will call step() & model.zero_grad() automatic 124 | if model_ema is not None: 125 | model_ema.update(model) 126 | grad_norm = None 127 | loss_scale_value = get_loss_scale_for_deepspeed(model) 128 | else: 129 | # this attribute is added by timm on one optimizer (adahessian) 130 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 131 | loss /= update_freq 132 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 133 | parameters=model.parameters(), create_graph=is_second_order, 134 | update_grad=(data_iter_step + 1) % update_freq == 0) 135 | if (data_iter_step + 1) % update_freq == 0: 136 | optimizer.zero_grad() 137 | if model_ema is not None: 138 | model_ema.update(model) 139 | loss_scale_value = loss_scaler.state_dict()["scale"] 140 | 141 | torch.cuda.synchronize() 142 | 143 | class_acc = None 144 | 145 | metric_logger.update(loss=loss_value) 146 | metric_logger.update(class_acc=class_acc) 147 | metric_logger.update(loss_scale=loss_scale_value) 148 | min_lr = 10. 149 | max_lr = 0. 
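                # note: with layer-wise lr decay each param group carries its own scaled lr (via param_group["lr_scale"] above), so the loop below records the smallest and largest values for logging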
150 | for group in optimizer.param_groups: 151 | min_lr = min(min_lr, group["lr"]) 152 | max_lr = max(max_lr, group["lr"]) 153 | 154 | metric_logger.update(lr=max_lr) 155 | metric_logger.update(min_lr=min_lr) 156 | weight_decay_value = None 157 | for group in optimizer.param_groups: 158 | if group["weight_decay"] > 0: 159 | weight_decay_value = group["weight_decay"] 160 | metric_logger.update(weight_decay=weight_decay_value) 161 | metric_logger.update(grad_norm=grad_norm) 162 | 163 | if log_writer is not None: 164 | log_writer.update(loss=loss_value, head="loss") 165 | log_writer.update(class_acc=class_acc, head="loss") 166 | log_writer.update(loss_scale=loss_scale_value, head="opt") 167 | log_writer.update(lr=max_lr, head="opt") 168 | log_writer.update(min_lr=min_lr, head="opt") 169 | log_writer.update(weight_decay=weight_decay_value, head="opt") 170 | log_writer.update(grad_norm=grad_norm, head="opt") 171 | 172 | log_writer.set_step() 173 | 174 | # gather the stats from all processes 175 | metric_logger.synchronize_between_processes() 176 | print("Averaged stats:", metric_logger) 177 | 178 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 179 | 180 | 181 | @torch.no_grad() 182 | def validation_one_epoch(data_loader, model, device, args, ds=False, bf16=False): 183 | criterion = torch.nn.CrossEntropyLoss() 184 | 185 | metric_logger = utils.MetricLogger(delimiter=" ") 186 | header = 'Val:' 187 | 188 | # switch to evaluation mode 189 | model.eval() 190 | 191 | current_tokens = args.eval_true_frame * ((args.eval_input_size // 14) ** 2) 192 | 193 | if current_tokens <= args.eval_max_token_num: 194 | mask_ratio = None 195 | nums_token = current_tokens 196 | else: 197 | mask_ratio = 1 - args.eval_max_token_num / current_tokens 198 | nums_token = args.eval_max_token_num 199 | 200 | print('Nums Token: ', nums_token) 201 | 202 | num_merging_to = args.eval_num_merging_to 203 | 204 | print("Num Merging Ratio: ", num_merging_to) 205 | 206 | for batch in metric_logger.log_every(data_loader, 10, header): 207 | videos = batch[0] 208 | target = batch[1] 209 | videos = videos.to(device, non_blocking=True) 210 | target = target.to(device, non_blocking=True) 211 | 212 | if ds: 213 | videos = videos.bfloat16() if bf16 else videos.half() 214 | output = model(videos, num_merging_to=num_merging_to, masking_ratio=mask_ratio, output_head=args.output_head) 215 | loss = criterion(output, target) 216 | else: 217 | with torch.cuda.amp.autocast(): 218 | output = model(videos, num_merging_to=num_merging_to, masking_ratio=mask_ratio, output_head=args.output_head) 219 | loss = criterion(output, target) 220 | 221 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 222 | 223 | batch_size = videos.shape[0] 224 | metric_logger.update(loss=loss.item()) 225 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 226 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 227 | # gather the stats from all processes 228 | metric_logger.synchronize_between_processes() 229 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 230 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 231 | 232 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 233 | 234 | 235 | @torch.no_grad() 236 | def final_test(data_loader, model, device, file, args, ds=False, bf16=False, test_time=0): 237 | criterion = torch.nn.CrossEntropyLoss() 238 | 239 | metric_logger = utils.MetricLogger(delimiter=" ") 240 | 
header = 'Test:' 241 | 242 | # switch to evaluation mode 243 | model.eval() 244 | final_result = [] 245 | 246 | # args.eval_masking_ratio = random.uniform(args.eval_masking_ratio / 3, args.eval_masking_ratio) 247 | current_tokens = args.eval_true_frame * ((args.eval_input_size // 14) ** 2) 248 | 249 | if current_tokens <= args.eval_max_token_num: 250 | mask_ratio = None 251 | nums_token = current_tokens 252 | else: 253 | mask_ratio = 1 - args.eval_max_token_num / current_tokens 254 | nums_token = args.eval_max_token_num 255 | 256 | for batch in metric_logger.log_every(data_loader, 10, header): 257 | videos = batch[0] 258 | target = batch[1] 259 | ids = batch[2] 260 | chunk_nb = batch[3] 261 | split_nb = batch[4] 262 | videos = videos.to(device, non_blocking=True) 263 | target = target.to(device, non_blocking=True) 264 | 265 | # compute output 266 | if ds: 267 | videos = videos.bfloat16() if bf16 else videos.half() 268 | output = model(videos, masking_ratio=mask_ratio, output_head=args.output_head) 269 | loss = criterion(output, target) 270 | else: 271 | with torch.cuda.amp.autocast(): 272 | output = model(videos, masking_ratio=mask_ratio, output_head=args.output_head) 273 | loss = criterion(output, target) 274 | 275 | for i in range(output.size(0)): 276 | string = "{} {} {} {} {}\n".format(ids[i], \ 277 | str(output.data[i].float().cpu().numpy().tolist()), \ 278 | str(int(target[i].cpu().numpy())), \ 279 | str(int(chunk_nb[i].cpu().numpy())), \ 280 | str(int(split_nb[i].cpu().numpy()))) 281 | final_result.append(string) 282 | 283 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 284 | 285 | batch_size = videos.shape[0] 286 | metric_logger.update(loss=loss.item()) 287 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 288 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 289 | 290 | if not os.path.exists(file): 291 | os.mknod(file) 292 | with open(file, 'w') as f: 293 | f.write("{}, {}\n".format(acc1, acc5)) 294 | for line in final_result: 295 | f.write(line) 296 | # gather the stats from all processes 297 | metric_logger.synchronize_between_processes() 298 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 299 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 300 | 301 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 302 | 303 | 304 | def merge(eval_path, num_tasks): 305 | dict_feats = {} 306 | dict_label = {} 307 | dict_pos = {} 308 | print("Reading individual output files") 309 | 310 | for x in range(num_tasks): 311 | file = os.path.join(eval_path, str(x) + '.txt') 312 | lines = open(file, 'r').readlines()[1:] 313 | for line in lines: 314 | line = line.strip() 315 | name = line.split(' ')[0] 316 | label = line.split(']')[-1].split(' ')[1] 317 | chunk_nb = line.split(']')[-1].split(' ')[2] 318 | split_nb = line.split(']')[-1].split(' ')[3] 319 | data = np.fromstring(' '.join(line.split(' ')[1:]).split('[')[1].split(']')[0], dtype=np.float32, sep=',') 320 | data = softmax(data) 321 | if not name in dict_feats: 322 | dict_feats[name] = [] 323 | dict_label[name] = 0 324 | dict_pos[name] = [] 325 | if chunk_nb + split_nb in dict_pos[name]: 326 | continue 327 | dict_feats[name].append(data) 328 | dict_pos[name].append(chunk_nb + split_nb) 329 | dict_label[name] = label 330 | print("Computing final results") 331 | 332 | input_lst = [] 333 | print(len(dict_feats)) 334 | for i, item in enumerate(dict_feats): 335 | input_lst.append([i, item, 
dict_feats[item], dict_label[item]]) 336 | from multiprocessing import Pool 337 | p = Pool(64) 338 | ans = p.map(compute_video, input_lst) 339 | top1 = [x[1] for x in ans] 340 | top5 = [x[2] for x in ans] 341 | pred = [x[0] for x in ans] 342 | label = [x[3] for x in ans] 343 | final_top1 ,final_top5 = np.mean(top1), np.mean(top5) 344 | return final_top1*100 ,final_top5*100 345 | 346 | def compute_video(lst): 347 | i, video_id, data, label = lst 348 | feat = [x for x in data] 349 | feat = np.mean(feat, axis=0) 350 | pred = np.argmax(feat) 351 | top1 = (int(pred) == int(label)) * 1.0 352 | top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0 353 | return [pred, top1, top5, int(label)] 354 | -------------------------------------------------------------------------------- /single_modality/engines/engine_for_flexible_tuning_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import math 5 | import sys 6 | from typing import Iterable, Optional 7 | import torch 8 | from datasets.mixup import Mixup 9 | from timm.utils import accuracy, ModelEma 10 | import utils 11 | from scipy.special import softmax 12 | import random 13 | import torch.nn.functional as F 14 | import torch.distributed as dist 15 | 16 | 17 | def get_loss_scale_for_deepspeed(model): 18 | optimizer = model.optimizer 19 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale 20 | 21 | 22 | def train_one_epoch( 23 | model: torch.nn.Module, criterion: torch.nn.Module, 24 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 25 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 26 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, 27 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None, 28 | num_training_steps_per_epoch=None, update_freq=None, args=None, 29 | bf16=False, 30 | ): 31 | model.train(True) 32 | metric_logger = utils.MetricLogger(delimiter=" ") 33 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 34 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 35 | header = 'Epoch: [{}]'.format(epoch) 36 | print_freq = 1 37 | 38 | if loss_scaler is None: 39 | model.zero_grad() 40 | model.micro_steps = 0 41 | else: 42 | optimizer.zero_grad() 43 | 44 | for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 45 | 46 | step = data_iter_step // update_freq 47 | if step >= num_training_steps_per_epoch: 48 | continue 49 | it = start_steps + step # global training iteration 50 | # Update LR & WD for the first acc 51 | if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0: 52 | for i, param_group in enumerate(optimizer.param_groups): 53 | if lr_schedule_values is not None: 54 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 55 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 56 | param_group["weight_decay"] = wd_schedule_values[it] 57 | 58 | samples = samples.to(device, non_blocking=True) 59 | targets = targets.to(device, non_blocking=True) 60 | 61 | if mixup_fn is not None: 62 | samples, targets = mixup_fn(samples, targets) 63 | 64 | if loss_scaler is None: 65 | 66 | B, C, T, H, W = samples.shape 67 | 68 | samples = samples.to(device, non_blocking=True) 69 | 70 | nums_token = T * (H // 14) ** 2 71 | samples = 
samples.bfloat16() if bf16 else samples.half() 72 | 73 | def compute_mask_ratio(num_current_token, num_input_token): 74 | if num_current_token <= num_input_token: 75 | mask_ratio = 0.0 76 | else: 77 | mask_ratio = 1 - num_input_token / num_current_token 78 | return mask_ratio 79 | 80 | mask_ratio = compute_mask_ratio(nums_token, args.eval_max_token_num) 81 | outputs, _ = model(samples, masking_ratio=mask_ratio, output_head=args.output_head, return_cls=True) 82 | 83 | pred = criterion(outputs, targets) 84 | 85 | loss = pred 86 | 87 | else: 88 | raise NotImplementedError 89 | 90 | loss_value = loss.item() 91 | 92 | loss_list = [torch.zeros_like(loss) for _ in range(dist.get_world_size())] 93 | dist.all_gather(loss_list, loss) 94 | loss_list = torch.tensor(loss_list) 95 | loss_list_isnan = torch.isnan(loss_list).any() 96 | loss_list_isinf = torch.isinf(loss_list).any() 97 | 98 | if loss_list_isnan or loss_list_isinf: 99 | print("Loss is {}, stopping training".format(loss_value), force=True) 100 | model.zero_grad() 101 | continue 102 | 103 | if loss_scaler is None: 104 | loss /= update_freq 105 | model.backward(loss) 106 | model.step() 107 | 108 | if (data_iter_step + 1) % update_freq == 0: 109 | # model.zero_grad() 110 | # Deepspeed will call step() & model.zero_grad() automatic 111 | if model_ema is not None: 112 | model_ema.update(model) 113 | grad_norm = None 114 | loss_scale_value = get_loss_scale_for_deepspeed(model) 115 | else: 116 | # this attribute is added by timm on one optimizer (adahessian) 117 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 118 | loss /= update_freq 119 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 120 | parameters=model.parameters(), create_graph=is_second_order, 121 | update_grad=(data_iter_step + 1) % update_freq == 0) 122 | if (data_iter_step + 1) % update_freq == 0: 123 | optimizer.zero_grad() 124 | if model_ema is not None: 125 | model_ema.update(model) 126 | loss_scale_value = loss_scaler.state_dict()["scale"] 127 | 128 | torch.cuda.synchronize() 129 | 130 | class_acc = None 131 | 132 | metric_logger.update(loss=loss_value) 133 | metric_logger.update(class_acc=class_acc) 134 | metric_logger.update(loss_scale=loss_scale_value) 135 | min_lr = 10. 136 | max_lr = 0. 
137 | for group in optimizer.param_groups: 138 | min_lr = min(min_lr, group["lr"]) 139 | max_lr = max(max_lr, group["lr"]) 140 | 141 | metric_logger.update(lr=max_lr) 142 | metric_logger.update(min_lr=min_lr) 143 | weight_decay_value = None 144 | for group in optimizer.param_groups: 145 | if group["weight_decay"] > 0: 146 | weight_decay_value = group["weight_decay"] 147 | metric_logger.update(weight_decay=weight_decay_value) 148 | metric_logger.update(grad_norm=grad_norm) 149 | 150 | if log_writer is not None: 151 | log_writer.update(loss=loss_value, head="loss") 152 | log_writer.update(class_acc=class_acc, head="loss") 153 | log_writer.update(loss_scale=loss_scale_value, head="opt") 154 | log_writer.update(lr=max_lr, head="opt") 155 | log_writer.update(min_lr=min_lr, head="opt") 156 | log_writer.update(weight_decay=weight_decay_value, head="opt") 157 | log_writer.update(grad_norm=grad_norm, head="opt") 158 | 159 | log_writer.set_step() 160 | 161 | # gather the stats from all processes 162 | metric_logger.synchronize_between_processes() 163 | print("Averaged stats:", metric_logger) 164 | 165 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 166 | 167 | 168 | @torch.no_grad() 169 | def validation_one_epoch(data_loader, model, device, args, ds=False, bf16=False): 170 | criterion = torch.nn.CrossEntropyLoss() 171 | 172 | metric_logger = utils.MetricLogger(delimiter=" ") 173 | header = 'Val:' 174 | 175 | # switch to evaluation mode 176 | model.eval() 177 | 178 | current_tokens = args.eval_true_frame * ((args.eval_input_size // 14) ** 2) 179 | 180 | if current_tokens <= args.eval_max_token_num: 181 | mask_ratio = None 182 | nums_token = current_tokens 183 | else: 184 | mask_ratio = 1 - args.eval_max_token_num / current_tokens 185 | nums_token = args.eval_max_token_num 186 | 187 | print('Nums Token: ', nums_token) 188 | 189 | num_merging_to = args.eval_num_merging_to 190 | 191 | print("Num Merging Ratio: ", num_merging_to) 192 | 193 | for batch in metric_logger.log_every(data_loader, 10, header): 194 | videos = batch[0] 195 | target = batch[1] 196 | videos = videos.to(device, non_blocking=True) 197 | target = target.to(device, non_blocking=True) 198 | 199 | if ds: 200 | videos = videos.bfloat16() if bf16 else videos.half() 201 | output = model(videos, num_merging_to=num_merging_to, masking_ratio=mask_ratio, output_head=args.output_head) 202 | loss = criterion(output, target) 203 | else: 204 | with torch.cuda.amp.autocast(): 205 | output = model(videos, num_merging_to=num_merging_to, masking_ratio=mask_ratio, output_head=args.output_head) 206 | loss = criterion(output, target) 207 | 208 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 209 | 210 | batch_size = videos.shape[0] 211 | metric_logger.update(loss=loss.item()) 212 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 213 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 214 | # gather the stats from all processes 215 | metric_logger.synchronize_between_processes() 216 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 217 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 218 | 219 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 220 | 221 | 222 | @torch.no_grad() 223 | def final_test(data_loader, model, device, file, args, ds=False, bf16=False, test_time=0): 224 | criterion = torch.nn.CrossEntropyLoss() 225 | 226 | metric_logger = utils.MetricLogger(delimiter=" ") 227 | 
header = 'Test:' 228 | 229 | # switch to evaluation mode 230 | model.eval() 231 | final_result = [] 232 | 233 | # args.eval_masking_ratio = random.uniform(args.eval_masking_ratio / 3, args.eval_masking_ratio) 234 | current_tokens = args.eval_true_frame * ((args.eval_input_size // 14) ** 2) 235 | 236 | if current_tokens <= args.eval_max_token_num: 237 | mask_ratio = None 238 | nums_token = current_tokens 239 | else: 240 | mask_ratio = 1 - args.eval_max_token_num / current_tokens 241 | nums_token = args.eval_max_token_num 242 | 243 | for batch in metric_logger.log_every(data_loader, 10, header): 244 | videos = batch[0] 245 | target = batch[1] 246 | ids = batch[2] 247 | chunk_nb = batch[3] 248 | split_nb = batch[4] 249 | videos = videos.to(device, non_blocking=True) 250 | target = target.to(device, non_blocking=True) 251 | 252 | # compute output 253 | if ds: 254 | videos = videos.bfloat16() if bf16 else videos.half() 255 | output = model(videos, masking_ratio=mask_ratio, output_head=args.output_head) 256 | loss = criterion(output, target) 257 | else: 258 | with torch.cuda.amp.autocast(): 259 | output = model(videos, masking_ratio=mask_ratio, output_head=args.output_head) 260 | loss = criterion(output, target) 261 | 262 | for i in range(output.size(0)): 263 | string = "{} {} {} {} {}\n".format(ids[i], \ 264 | str(output.data[i].float().cpu().numpy().tolist()), \ 265 | str(int(target[i].cpu().numpy())), \ 266 | str(int(chunk_nb[i].cpu().numpy())), \ 267 | str(int(split_nb[i].cpu().numpy()))) 268 | final_result.append(string) 269 | 270 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 271 | 272 | batch_size = videos.shape[0] 273 | metric_logger.update(loss=loss.item()) 274 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 275 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 276 | 277 | if not os.path.exists(file): 278 | os.mknod(file) 279 | with open(file, 'w') as f: 280 | f.write("{}, {}\n".format(acc1, acc5)) 281 | for line in final_result: 282 | f.write(line) 283 | # gather the stats from all processes 284 | metric_logger.synchronize_between_processes() 285 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 286 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 287 | 288 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 289 | 290 | 291 | def merge(eval_path, num_tasks): 292 | dict_feats = {} 293 | dict_label = {} 294 | dict_pos = {} 295 | print("Reading individual output files") 296 | 297 | for x in range(num_tasks): 298 | file = os.path.join(eval_path, str(x) + '.txt') 299 | lines = open(file, 'r').readlines()[1:] 300 | for line in lines: 301 | line = line.strip() 302 | name = line.split(' ')[0] 303 | label = line.split(']')[-1].split(' ')[1] 304 | chunk_nb = line.split(']')[-1].split(' ')[2] 305 | split_nb = line.split(']')[-1].split(' ')[3] 306 | data = np.fromstring(' '.join(line.split(' ')[1:]).split('[')[1].split(']')[0], dtype=np.float32, sep=',') 307 | data = softmax(data) 308 | if not name in dict_feats: 309 | dict_feats[name] = [] 310 | dict_label[name] = 0 311 | dict_pos[name] = [] 312 | if chunk_nb + split_nb in dict_pos[name]: 313 | continue 314 | dict_feats[name].append(data) 315 | dict_pos[name].append(chunk_nb + split_nb) 316 | dict_label[name] = label 317 | print("Computing final results") 318 | 319 | input_lst = [] 320 | print(len(dict_feats)) 321 | for i, item in enumerate(dict_feats): 322 | input_lst.append([i, item, 
dict_feats[item], dict_label[item]]) 323 | from multiprocessing import Pool 324 | p = Pool(64) 325 | ans = p.map(compute_video, input_lst) 326 | top1 = [x[1] for x in ans] 327 | top5 = [x[2] for x in ans] 328 | pred = [x[0] for x in ans] 329 | label = [x[3] for x in ans] 330 | final_top1 ,final_top5 = np.mean(top1), np.mean(top5) 331 | return final_top1*100 ,final_top5*100 332 | 333 | def compute_video(lst): 334 | i, video_id, data, label = lst 335 | feat = [x for x in data] 336 | feat = np.mean(feat, axis=0) 337 | pred = np.argmax(feat) 338 | top1 = (int(pred) == int(label)) * 1.0 339 | top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0 340 | return [pred, top1, top5, int(label)] 341 | -------------------------------------------------------------------------------- /single_modality/exp/base/eval/k400_eval.sh: -------------------------------------------------------------------------------- 1 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 2 | export OMP_NUM_THREADS=1 3 | 4 | JOB_NAME='k400_base_f8_lr2e-4_w5e35_dp0.1_hdp0.1_ld0.75_g64' 5 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME" 6 | LOG_DIR="./logs/${JOB_NAME}" 7 | PREFIX='p2:s3://k400' 8 | DATA_PATH='' # replace with your data_path 9 | MODEL_PATH='' # replace with your model_path 10 | 11 | PARTITION='video' 12 | GPUS=64 13 | GPUS_PER_NODE=8 14 | CPUS_PER_TASK=16 15 | 16 | MAX_TOKEN_COUNT=2048 17 | TEST_RES=224 18 | TEST_FRAME=8 19 | OUTPUT_HEAD=1 # can be set with {0, 1, 2}, corresponding to {1024, 2048, 3072} tokens, use nearest head for your token budget 20 | 21 | srun -p $PARTITION \ 22 | --gres=gpu:${GPUS_PER_NODE} \ 23 | --ntasks=${GPUS} \ 24 | --ntasks-per-node=${GPUS_PER_NODE} \ 25 | --cpus-per-task=${CPUS_PER_TASK} \ 26 | --kill-on-bad-exit=1 \ 27 | python -u run_flexible_finetune.py \ 28 | --eval \ 29 | --limit_token_num 6144 \ 30 | --largest_num_input_token 3072 \ 31 | --middle_num_input_token 2048 \ 32 | --least_num_input_token 1024 \ 33 | --eval_max_token_num ${MAX_TOKEN_COUNT} \ 34 | --eval_input_size ${TEST_RES} \ 35 | --eval_short_side_size ${TEST_RES} \ 36 | --eval_true_frame ${TEST_FRAME} \ 37 | --output_head ${OUTPUT_HEAD} \ 38 | --finetune ${MODEL_PATH} \ 39 | --model fluxvit_base_patch14 \ 40 | --data_path ${DATA_PATH} \ 41 | --data_set 'Kinetics_sparse_ofa' \ 42 | --prefix ${PREFIX} \ 43 | --split ',' \ 44 | --nb_classes 400 \ 45 | --finetune ${MODEL_PATH} \ 46 | --log_dir ${OUTPUT_DIR} \ 47 | --output_dir ${OUTPUT_DIR} \ 48 | --steps_per_print 10 \ 49 | --batch_size 8 \ 50 | --num_sample 2 \ 51 | --input_size 252 \ 52 | --short_side_size 252 \ 53 | --model_res 252 \ 54 | --save_ckpt_freq 100 \ 55 | --num_frames 32 \ 56 | --eval_orig_frame 32 \ 57 | --orig_t_size 32 \ 58 | --num_workers 12 \ 59 | --warmup_epochs 1 \ 60 | --tubelet_size 1 \ 61 | --epochs 5 \ 62 | --lr 2e-4 \ 63 | --drop_path 0.1 \ 64 | --head_drop_path 0.1 \ 65 | --fc_drop_rate 0.0 \ 66 | --layer_decay 0.75 \ 67 | --layer_scale_init_value 1e-5 \ 68 | --opt adamw \ 69 | --opt_betas 0.9 0.999 \ 70 | --weight_decay 0.05 \ 71 | --test_num_segment 4 \ 72 | --test_num_crop 3 \ 73 | --dist_eval \ 74 | --enable_deepspeed \ 75 | --bf16 \ 76 | --zero_stage 1 \ 77 | --log_dir ${OUTPUT_DIR} \ 78 | --output_dir ${OUTPUT_DIR} \ 79 | --test_best -------------------------------------------------------------------------------- /single_modality/exp/base/ft/k400/k400_multi.sh: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/exp/base/ft/k400/k400_multi.sh -------------------------------------------------------------------------------- /single_modality/exp/base/ft/k400/k400_single.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/exp/base/ft/k400/k400_single.sh -------------------------------------------------------------------------------- /single_modality/exp/small/eval/k400_eval.sh: -------------------------------------------------------------------------------- 1 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 2 | export OMP_NUM_THREADS=1 3 | 4 | JOB_NAME='k400_small_f8_lr2e-4_w5e35_dp0.1_hdp0.1_ld0.75_g64' 5 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME" 6 | LOG_DIR="./logs/${JOB_NAME}" 7 | PREFIX='p2:s3://k400' 8 | DATA_PATH='' # replace with your data_path 9 | MODEL_PATH='' # replace with your model_path 10 | 11 | PARTITION='video' 12 | GPUS=64 13 | GPUS_PER_NODE=8 14 | CPUS_PER_TASK=16 15 | 16 | MAX_TOKEN_COUNT=2048 17 | TEST_RES=224 18 | TEST_FRAME=8 19 | OUTPUT_HEAD=1 # can be set with {0, 1, 2}, corresponding to {1024, 2048, 3072} tokens, use nearest head for your token budget 20 | 21 | srun -p $PARTITION \ 22 | --gres=gpu:${GPUS_PER_NODE} \ 23 | --ntasks=${GPUS} \ 24 | --ntasks-per-node=${GPUS_PER_NODE} \ 25 | --cpus-per-task=${CPUS_PER_TASK} \ 26 | --kill-on-bad-exit=1 \ 27 | python -u run_flexible_finetune.py \ 28 | --eval \ 29 | --limit_token_num 6144 \ 30 | --largest_num_input_token 3072 \ 31 | --middle_num_input_token 2048 \ 32 | --least_num_input_token 1024 \ 33 | --eval_max_token_num ${MAX_TOKEN_COUNT} \ 34 | --eval_input_size ${TEST_RES} \ 35 | --eval_short_side_size ${TEST_RES} \ 36 | --eval_true_frame ${TEST_FRAME} \ 37 | --output_head ${OUTPUT_HEAD} \ 38 | --finetune ${MODEL_PATH} \ 39 | --model fluxvit_small_patch14 \ 40 | --data_path ${DATA_PATH} \ 41 | --data_set 'Kinetics_sparse_ofa' \ 42 | --prefix ${PREFIX} \ 43 | --split ',' \ 44 | --nb_classes 400 \ 45 | --finetune ${MODEL_PATH} \ 46 | --log_dir ${OUTPUT_DIR} \ 47 | --output_dir ${OUTPUT_DIR} \ 48 | --steps_per_print 10 \ 49 | --batch_size 8 \ 50 | --num_sample 2 \ 51 | --input_size 252 \ 52 | --short_side_size 252 \ 53 | --model_res 252 \ 54 | --save_ckpt_freq 100 \ 55 | --num_frames 24 \ 56 | --eval_orig_frame 24 \ 57 | --orig_t_size 24 \ 58 | --num_workers 12 \ 59 | --warmup_epochs 1 \ 60 | --tubelet_size 1 \ 61 | --epochs 5 \ 62 | --lr 2e-4 \ 63 | --drop_path 0.1 \ 64 | --head_drop_path 0.1 \ 65 | --fc_drop_rate 0.0 \ 66 | --layer_decay 0.75 \ 67 | --layer_scale_init_value 1e-5 \ 68 | --opt adamw \ 69 | --opt_betas 0.9 0.999 \ 70 | --weight_decay 0.05 \ 71 | --test_num_segment 4 \ 72 | --test_num_crop 3 \ 73 | --dist_eval \ 74 | --enable_deepspeed \ 75 | --bf16 \ 76 | --zero_stage 1 \ 77 | --log_dir ${OUTPUT_DIR} \ 78 | --output_dir ${OUTPUT_DIR} \ 79 | --test_best -------------------------------------------------------------------------------- /single_modality/functional.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import cv2 3 | import numpy as np 4 | import PIL 5 | import torch 6 | 7 | 8 | def _is_tensor_clip(clip): 9 | return torch.is_tensor(clip) and clip.ndimension() == 4 10 | 11 | 12 | def crop_clip(clip, min_h, min_w, h, w): 13 | if isinstance(clip[0], np.ndarray): 14 | cropped = [img[min_h:min_h + h, min_w:min_w + 
w, :] for img in clip] 15 | 16 | elif isinstance(clip[0], PIL.Image.Image): 17 | cropped = [ 18 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip 19 | ] 20 | else: 21 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 22 | 'but got list of {0}'.format(type(clip[0]))) 23 | return cropped 24 | 25 | 26 | def resize_clip(clip, size, interpolation='bilinear'): 27 | if isinstance(clip[0], np.ndarray): 28 | if isinstance(size, numbers.Number): 29 | im_h, im_w, im_c = clip[0].shape 30 | # Min spatial dim already matches minimal size 31 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 32 | and im_h == size): 33 | return clip 34 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 35 | size = (new_w, new_h) 36 | else: 37 | size = size[0], size[1] 38 | if interpolation == 'bilinear': 39 | np_inter = cv2.INTER_LINEAR 40 | else: 41 | np_inter = cv2.INTER_NEAREST 42 | scaled = [ 43 | cv2.resize(img, size, interpolation=np_inter) for img in clip 44 | ] 45 | elif isinstance(clip[0], PIL.Image.Image): 46 | if isinstance(size, numbers.Number): 47 | im_w, im_h = clip[0].size 48 | # Min spatial dim already matches minimal size 49 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 50 | and im_h == size): 51 | return clip 52 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 53 | size = (new_w, new_h) 54 | else: 55 | size = size[1], size[0] 56 | if interpolation == 'bilinear': 57 | pil_inter = PIL.Image.BILINEAR 58 | else: 59 | pil_inter = PIL.Image.NEAREST 60 | scaled = [img.resize(size, pil_inter) for img in clip] 61 | else: 62 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 63 | 'but got list of {0}'.format(type(clip[0]))) 64 | return scaled 65 | 66 | 67 | def get_resize_sizes(im_h, im_w, size): 68 | if im_w < im_h: 69 | ow = size 70 | oh = int(size * im_h / im_w) 71 | else: 72 | oh = size 73 | ow = int(size * im_w / im_h) 74 | return oh, ow 75 | 76 | 77 | def normalize(clip, mean, std, inplace=False): 78 | if not _is_tensor_clip(clip): 79 | raise TypeError('tensor is not a torch clip.') 80 | 81 | if not inplace: 82 | clip = clip.clone() 83 | 84 | dtype = clip.dtype 85 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) 86 | std = torch.as_tensor(std, dtype=dtype, device=clip.device) 87 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 88 | 89 | return clip 90 | -------------------------------------------------------------------------------- /single_modality/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .fluxvit import fluxvit_small_patch14, fluxvit_base_patch14 -------------------------------------------------------------------------------- /single_modality/models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/models/__pycache__/flash_attention_class.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/models/__pycache__/flash_attention_class.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/models/__pycache__/fluxvit.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/models/__pycache__/fluxvit.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/models/__pycache__/pos_embed.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/models/__pycache__/pos_embed.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/models/__pycache__/vid_tldr.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/models/__pycache__/vid_tldr.cpython-39.pyc -------------------------------------------------------------------------------- /single_modality/models/flash_attention_class.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from einops import rearrange 5 | 6 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func 7 | from flash_attn.bert_padding import unpad_input, pad_input 8 | 9 | 10 | class FlashAttention(nn.Module): 11 | """Implement the scaled dot product attention with softmax. 12 | Arguments 13 | --------- 14 | softmax_scale: The temperature to use for the softmax attention. 15 | (default: 1/sqrt(d_keys) where d_keys is computed at 16 | runtime) 17 | attention_dropout: The dropout rate to apply to the attention 18 | (default: 0.0) 19 | """ 20 | 21 | def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None): 22 | super().__init__() 23 | self.softmax_scale = softmax_scale 24 | self.dropout_p = attention_dropout 25 | 26 | def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None, 27 | max_s=None, need_weights=False): 28 | """Implements the multihead softmax attention. 29 | Arguments 30 | --------- 31 | qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None 32 | if unpadded: (nnz, 3, h, d) 33 | key_padding_mask: a bool tensor of shape (B, S) 34 | """ 35 | assert not need_weights 36 | assert qkv.dtype in [torch.float16, torch.bfloat16] 37 | assert qkv.is_cuda 38 | 39 | if cu_seqlens is None: 40 | batch_size = qkv.shape[0] 41 | seqlen = qkv.shape[1] 42 | if key_padding_mask is None: 43 | qkv = rearrange(qkv, 'b s ... -> (b s) ...') 44 | max_s = seqlen 45 | cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, 46 | device=qkv.device) 47 | output = flash_attn_varlen_qkvpacked_func( 48 | qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, 49 | softmax_scale=self.softmax_scale, causal=causal 50 | ) 51 | output = rearrange(output, '(b s) ... 
-> b s ...', b=batch_size) 52 | else: 53 | nheads = qkv.shape[-2] 54 | x = rearrange(qkv, 'b s three h d -> b s (three h d)') 55 | x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask) 56 | x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) 57 | output_unpad = flash_attn_varlen_qkvpacked_func( 58 | x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, 59 | softmax_scale=self.softmax_scale, causal=causal 60 | ) 61 | output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), 62 | indices, batch_size, seqlen), 63 | 'b s (h d) -> b s h d', h=nheads) 64 | else: 65 | assert max_s is not None 66 | output = flash_attn_varlen_qkvpacked_func( 67 | qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, 68 | softmax_scale=self.softmax_scale, causal=causal 69 | ) 70 | 71 | return output, None 72 | -------------------------------------------------------------------------------- /single_modality/models/pos_embed.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | 6 | 7 | # -------------------------------------------------------- 8 | # 3D sine-cosine position embedding 9 | # References: 10 | # MVD: https://github.com/ruiwang2021/mvd/blob/main/modeling_finetune.py 11 | # -------------------------------------------------------- 12 | def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False): 13 | """ 14 | grid_size: int of the grid height and width 15 | t_size: int of the temporal size 16 | return: 17 | pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 18 | """ 19 | assert embed_dim % 4 == 0 20 | embed_dim_spatial = embed_dim // 4 * 3 21 | embed_dim_temporal = embed_dim // 4 22 | 23 | # spatial 24 | grid_h = np.arange(grid_size, dtype=np.float32) 25 | grid_w = np.arange(grid_size, dtype=np.float32) 26 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 27 | grid = np.stack(grid, axis=0) 28 | 29 | grid = grid.reshape([2, 1, grid_size, grid_size]) 30 | pos_embed_spatial = get_2d_sincos_pos_embed_from_grid( 31 | embed_dim_spatial, grid 32 | ) 33 | 34 | # temporal 35 | grid_t = np.arange(t_size, dtype=np.float32) 36 | pos_embed_temporal = get_1d_sincos_pos_embed_from_grid( 37 | embed_dim_temporal, grid_t 38 | ) 39 | 40 | # concate: [T, H, W] order 41 | pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :] 42 | pos_embed_temporal = np.repeat( 43 | pos_embed_temporal, grid_size**2, axis=1 44 | ) # [T, H*W, D // 4] 45 | pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :] 46 | pos_embed_spatial = np.repeat( 47 | pos_embed_spatial, t_size, axis=0 48 | ) # [T, H*W, D // 4 * 3] 49 | 50 | pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) 51 | pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D] 52 | 53 | if cls_token: 54 | pos_embed = np.concatenate( 55 | [np.zeros([1, embed_dim]), pos_embed], axis=0 56 | ) 57 | return pos_embed 58 | 59 | 60 | # -------------------------------------------------------- 61 | # 2D sine-cosine position embedding 62 | # References: 63 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 64 | # MoCo v3: https://github.com/facebookresearch/moco-v3 65 | # -------------------------------------------------------- 66 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 67 | """ 68 | grid_size: int of the grid height and 
width 69 | return: 70 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 71 | """ 72 | grid_h = np.arange(grid_size, dtype=np.float32) 73 | grid_w = np.arange(grid_size, dtype=np.float32) 74 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 75 | grid = np.stack(grid, axis=0) 76 | 77 | grid = grid.reshape([2, 1, grid_size, grid_size]) 78 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 79 | if cls_token: 80 | pos_embed = np.concatenate( 81 | [np.zeros([1, embed_dim]), pos_embed], axis=0 82 | ) 83 | return pos_embed 84 | 85 | 86 | def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False): 87 | """ 88 | t_size: int of the temporal size 89 | return: 90 | pos_embed: [t_size, embed_dim] or [1+t_size, embed_dim] (w/ or w/o cls_token) 91 | """ 92 | grid_t = np.arange(t_size, dtype=np.float32) 93 | pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t) 94 | if cls_token: 95 | pos_embed = np.concatenate( 96 | [np.zeros([1, embed_dim]), pos_embed], axis=0 97 | ) 98 | return pos_embed 99 | 100 | 101 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 102 | assert embed_dim % 2 == 0 103 | 104 | # use half of dimensions to encode grid_h 105 | emb_h = get_1d_sincos_pos_embed_from_grid( 106 | embed_dim // 2, grid[0] 107 | ) # (H*W, D/2) 108 | emb_w = get_1d_sincos_pos_embed_from_grid( 109 | embed_dim // 2, grid[1] 110 | ) # (H*W, D/2) 111 | 112 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 113 | return emb 114 | 115 | 116 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 117 | """ 118 | embed_dim: output dimension for each position 119 | pos: a list of positions to be encoded: size (M,) 120 | out: (M, D) 121 | """ 122 | assert embed_dim % 2 == 0 123 | omega = np.arange(embed_dim // 2, dtype=np.float32) 124 | omega /= embed_dim / 2.0 125 | omega = 1.0 / 10000**omega # (D/2,) 126 | 127 | pos = pos.reshape(-1) # (M,) 128 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 129 | 130 | emb_sin = np.sin(out) # (M, D/2) 131 | emb_cos = np.cos(out) # (M, D/2) 132 | 133 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 134 | return emb 135 | 136 | 137 | def interpolate_pos_embed(checkpoint_model, model, orig_t_size=4, pos_name='vision_encoder.pos_embed'): 138 | if pos_name in checkpoint_model: 139 | pos_embed_checkpoint = checkpoint_model[pos_name] 140 | embedding_size = pos_embed_checkpoint.shape[-1] # channel dim 141 | num_patches = model.patch_embed.num_patches # 142 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1 143 | 144 | # we use 4 frames for pretraining 145 | new_t_size = model.T 146 | # height (== width) for the checkpoint position embedding 147 | orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(orig_t_size)) ** 0.5) 148 | # height (== width) for the new position embedding 149 | new_size = int((num_patches // (new_t_size))** 0.5) 150 | 151 | # class_token and dist_token are kept unchanged 152 | if orig_t_size != new_t_size: 153 | print(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})") 154 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 155 | # only the position tokens are interpolated 156 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 157 | # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1) 158 | pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size) 159 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size) 160 | pos_tokens = 
torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear') 161 | pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size) 162 | pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size) 163 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 164 | checkpoint_model[pos_name] = new_pos_embed 165 | pos_embed_checkpoint = new_pos_embed 166 | 167 | # class_token and dist_token are kept unchanged 168 | if orig_size != new_size: 169 | print(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})") 170 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 171 | # only the position tokens are interpolated 172 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 173 | # B, L, C -> BT, H, W, C -> BT, C, H, W 174 | pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size) 175 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 176 | pos_tokens = torch.nn.functional.interpolate( 177 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 178 | # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C 179 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size) 180 | pos_tokens = pos_tokens.flatten(1, 3) # B, L, C 181 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 182 | checkpoint_model[pos_name] = new_pos_embed 183 | else: 184 | raise NotImplementedError 185 | -------------------------------------------------------------------------------- /single_modality/models/vid_tldr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | import math 8 | import torch 9 | from einops import rearrange 10 | from typing import Callable, Tuple, List, Union 11 | import torch.nn.functional as F 12 | 13 | def get_objective_score(score_attn, eps=1e-17): 14 | # Mean across the first dimension 15 | score_attn = score_attn.mean(dim=1) 16 | 17 | # Use torch.clamp with a smaller range 18 | score_attn = torch.clamp(score_attn, min=eps, max=1.0 - eps) 19 | score_attn = F.normalize(score_attn, p=1, dim=-1) 20 | 21 | # Compute entropy using a numerically stable method 22 | scores = score_attn * torch.log(score_attn) 23 | scores = scores.sum(dim=2).unsqueeze(-1) 24 | 25 | # BACKGROUND REMOVING 26 | B, T_R, _ = scores.shape 27 | scores = scores - scores.amin(dim=1, keepdim=True) 28 | scores = scores / (scores.amax(dim=1, keepdim=True)) 29 | score_mean = scores.mean(dim=1, keepdim=True) 30 | score_mask = scores < score_mean 31 | 32 | # FOREGROUND SHARPENING 33 | scores = scores - score_mean 34 | scores = scores / (scores.amax(dim=1, keepdim=True)) 35 | scores = scores.masked_fill(score_mask, 0.0) 36 | 37 | return scores 38 | 39 | def gini_impurity(probabilities): 40 | return 1 - torch.sum(probabilities ** 2, dim=-1) 41 | 42 | 43 | def get_objective_score_gini(score_attn): 44 | score_attn = score_attn.mean(dim=1) 45 | 46 | if torch.isnan(score_attn).any(): 47 | raise ValueError('The Score Value has NAN before impurity operation.') 48 | 49 | # score_attn = torch.clamp(score_attn, min=eps, max=1.0 - eps) 50 | 51 | # Normalize to ensure it sums to 1 along the last dimension 52 | # score_attn = F.normalize(score_attn, p=1, dim=-1) 53 | 54 | # Compute Gini impurity 55 | scores = gini_impurity(score_attn).unsqueeze(-1) 56 | 57 | if torch.isnan(scores).any(): 58 | raise ValueError('The Score Value has NAN after impurity computation.') 59 | 60 | # BACKGROUND REMOVING 61 | B, T_R, _ = scores.shape 62 | scores = scores - scores.amin(dim=1, keepdim=True) 63 | scores = scores / (scores.amax(dim=1, keepdim=True)) 64 | score_mean = scores.mean(dim=1, keepdim=True) 65 | score_mask = scores < score_mean 66 | 67 | # FOREGROUND SHARPENING 68 | scores = scores - score_mean 69 | scores = scores / (scores.amax(dim=1, keepdim=True)) 70 | scores = scores.masked_fill(score_mask, 0.0) 71 | 72 | return scores 73 | 74 | 75 | def vidTLDR(x, attn, r, with_cls_token=True, use_gini=False): 76 | 77 | B, T, _ = x.shape 78 | r_merge = T - r 79 | r_merge = max(min(r_merge, T // 2, T), 0) 80 | if not r_merge: 81 | return x 82 | 83 | with torch.no_grad(): 84 | if use_gini: 85 | score_obj = get_objective_score_gini(attn) 86 | else: 87 | score_obj = get_objective_score(attn) 88 | 89 | merge = merging( 90 | x, 91 | r_merge = r_merge, 92 | score_obj = score_obj, 93 | with_cls_token = with_cls_token, 94 | ) 95 | 96 | return merge 97 | 98 | 99 | def merging( 100 | metric: torch.Tensor, 101 | r_merge : int, 102 | score_obj : torch.Tensor, 103 | with_cls_token = True, 104 | ): 105 | 106 | with torch.no_grad(): 107 | metric = metric / metric.norm(dim=-1, keepdim=True) # (1, 2352, 768) 108 | 109 | # SECTION I. 
TOKEN MERGING 110 | a, b = metric[..., ::2, :], metric[..., 1::2, :] # (12, 99, 64), (12, 98, 64) 111 | n, s, t1, t2 = a.shape[0], a.shape[1], a.shape[-2], b.shape[-2] 112 | 113 | scores = (a @ b.transpose(-1, -2) + 1) / 2 # 0 - 1 114 | 115 | if with_cls_token: 116 | scores[..., 0, :] = -math.inf 117 | 118 | # TOKEN MERGING 119 | node_max, node_idx = scores.max(dim=-1) # (12, 99), (12, 99) 120 | edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] # (12, 99, 1) 121 | unm_idx = edge_idx[..., r_merge:, :] # Unmerged Tokens (12, 83, 1) 122 | src_idx = edge_idx[..., :r_merge, :] # Merged Tokens (12, 16, 1) 123 | dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx) # (12, 16, 1) 124 | unm_idx = unm_idx.sort(dim=1)[0] 125 | 126 | src_so = None 127 | if score_obj is not None: 128 | src_so, dst_so = score_obj[..., ::2, :], score_obj[..., 1::2, :] # (1, 1176, 1) 129 | src_so = src_so.gather(dim=-2, index=src_idx) # (12, 91, 197) 130 | 131 | def merge(x: torch.Tensor, mode = "sum", dtype=torch.float32): 132 | ori_dtype = x.dtype 133 | x = x.to(dtype=dtype) 134 | src, dst = x[..., ::2, :], x[..., 1::2, :] # (12, 99, 197), (12, 98, 197) 135 | n, mid, c = src.shape[0], src.shape[1:-2], src.shape[-1] 136 | unm = src.gather(dim=-2, index=unm_idx.expand(n, *mid, t1 - r_merge, c)) # (12, 91, 197) 137 | src = src.gather(dim=-2, index=src_idx.expand(n, *mid, r_merge, c)) 138 | 139 | if score_obj is not None: 140 | src = src * src_so 141 | 142 | dst = dst.scatter_reduce(-2, dst_idx.expand(n, *mid, r_merge, c), src, reduce=mode) # (12, 98, 197) 143 | 144 | x = torch.cat([unm, dst], dim=-2) # (12, 1 + 180, 197) 145 | x = x.to(dtype=ori_dtype) 146 | return x 147 | 148 | return merge -------------------------------------------------------------------------------- /single_modality/optim_factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim as optim 3 | 4 | from timm.optim.adafactor import Adafactor 5 | from timm.optim.adahessian import Adahessian 6 | from timm.optim.adamp import AdamP 7 | from timm.optim.lookahead import Lookahead 8 | # from timm.optim.nadam import Nadam 9 | # from timm.optim.novograd import NovoGrad 10 | # from timm.optim.nvnovograd import NvNovoGrad 11 | # from timm.optim.radam import RAdam 12 | # from timm.optim.rmsprop_tf import RMSpropTF 13 | from timm.optim.sgdp import SGDP 14 | 15 | import json 16 | 17 | try: 18 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 19 | has_apex = True 20 | except ImportError: 21 | has_apex = False 22 | 23 | 24 | def get_num_layer_for_vit(var_name, num_max_layer): 25 | if var_name in ("cls_token", "mask_token", "pos_embed"): 26 | return 0 27 | elif var_name.startswith("patch_embed"): 28 | return 0 29 | elif var_name.startswith("rel_pos_bias"): 30 | return num_max_layer - 1 31 | elif var_name.startswith("blocks"): 32 | layer_id = int(var_name.split('.')[1]) 33 | return layer_id + 1 34 | elif var_name.startswith("transformer.resblocks"): 35 | layer_id = int(var_name.split('.')[2]) 36 | return layer_id + 1 37 | elif var_name in ("class_embedding", "positional_embedding", "temporal_positional_embedding"): 38 | return 0 39 | elif var_name.startswith("conv1"): 40 | return 0 41 | else: 42 | return num_max_layer - 1 43 | 44 | 45 | class LayerDecayValueAssigner(object): 46 | def __init__(self, values): 47 | self.values = values 48 | 49 | def get_scale(self, layer_id): 50 | return self.values[layer_id] 51 | 52 | def get_layer_id(self, var_name): 
53 | return get_num_layer_for_vit(var_name, len(self.values)) 54 | 55 | 56 | def get_parameter_groups( 57 | model, weight_decay=1e-5, skip_list=(), get_num_layer=None, 58 | get_layer_scale=None, 59 | ): 60 | parameter_group_names = {} 61 | parameter_group_vars = {} 62 | 63 | for name, param in model.named_parameters(): 64 | if not param.requires_grad: 65 | continue # frozen weights 66 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 67 | group_name = "no_decay" 68 | this_weight_decay = 0. 69 | else: 70 | group_name = "decay" 71 | this_weight_decay = weight_decay 72 | if get_num_layer is not None: 73 | layer_id = get_num_layer(name) 74 | group_name = "layer_%d_%s" % (layer_id, group_name) 75 | else: 76 | layer_id = None 77 | 78 | if group_name not in parameter_group_names: 79 | if get_layer_scale is not None: 80 | scale = get_layer_scale(layer_id) 81 | else: 82 | scale = 1. 83 | 84 | parameter_group_names[group_name] = { 85 | "weight_decay": this_weight_decay, 86 | "params": [], 87 | "lr_scale": scale 88 | } 89 | parameter_group_vars[group_name] = { 90 | "weight_decay": this_weight_decay, 91 | "params": [], 92 | "lr_scale": scale 93 | } 94 | 95 | parameter_group_vars[group_name]["params"].append(param) 96 | parameter_group_names[group_name]["params"].append(name) 97 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 98 | return list(parameter_group_vars.values()) 99 | 100 | 101 | def create_optimizer( 102 | args, model, get_num_layer=None, get_layer_scale=None, 103 | filter_bias_and_bn=True, skip_list=None 104 | ): 105 | opt_lower = args.opt.lower() 106 | weight_decay = args.weight_decay 107 | if weight_decay and filter_bias_and_bn: 108 | skip = {} 109 | if skip_list is not None: 110 | skip = skip_list 111 | elif hasattr(model, 'no_weight_decay'): 112 | skip = model.no_weight_decay() 113 | parameters = get_parameter_groups( 114 | model, weight_decay, skip, get_num_layer, get_layer_scale, 115 | ) 116 | weight_decay = 0. 
117 | else: 118 | parameters = model.parameters() 119 | 120 | if 'fused' in opt_lower: 121 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 122 | 123 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 124 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 125 | opt_args['eps'] = args.opt_eps 126 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 127 | opt_args['betas'] = args.opt_betas 128 | 129 | print("optimizer settings:", opt_args) 130 | 131 | opt_split = opt_lower.split('_') 132 | opt_lower = opt_split[-1] 133 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 134 | opt_args.pop('eps', None) 135 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 136 | elif opt_lower == 'momentum': 137 | opt_args.pop('eps', None) 138 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 139 | elif opt_lower == 'adam': 140 | optimizer = optim.Adam(parameters, **opt_args) 141 | elif opt_lower == 'adamw': 142 | optimizer = optim.AdamW(parameters, **opt_args) 143 | # elif opt_lower == 'radam': 144 | # optimizer = RAdam(parameters, **opt_args) 145 | elif opt_lower == 'adamp': 146 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 147 | elif opt_lower == 'sgdp': 148 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 149 | elif opt_lower == 'adadelta': 150 | optimizer = optim.Adadelta(parameters, **opt_args) 151 | elif opt_lower == 'adafactor': 152 | if not args.lr: 153 | opt_args['lr'] = None 154 | optimizer = Adafactor(parameters, **opt_args) 155 | elif opt_lower == 'adahessian': 156 | optimizer = Adahessian(parameters, **opt_args) 157 | elif opt_lower == 'rmsprop': 158 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 159 | # elif opt_lower == 'rmsproptf': 160 | # optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 161 | # elif opt_lower == 'novograd': 162 | # optimizer = NovoGrad(parameters, **opt_args) 163 | # elif opt_lower == 'nvnovograd': 164 | # optimizer = NvNovoGrad(parameters, **opt_args) 165 | elif opt_lower == 'fusedsgd': 166 | opt_args.pop('eps', None) 167 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 168 | elif opt_lower == 'fusedmomentum': 169 | opt_args.pop('eps', None) 170 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 171 | elif opt_lower == 'fusedadam': 172 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 173 | elif opt_lower == 'fusedadamw': 174 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 175 | elif opt_lower == 'fusedlamb': 176 | optimizer = FusedLAMB(parameters, **opt_args) 177 | elif opt_lower == 'fusednovograd': 178 | opt_args.setdefault('betas', (0.95, 0.98)) 179 | optimizer = FusedNovoGrad(parameters, **opt_args) 180 | else: 181 | # unknown optimizer name: fail with an informative error instead of a bare assert 182 | raise ValueError(f"Invalid optimizer: {opt_lower}") 183 | 184 | if len(opt_split) > 1: 185 | if opt_split[0] == 'lookahead': 186 | optimizer = Lookahead(optimizer) 187 | 188 | return optimizer 189 | --------------------------------------------------------------------------------
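
A minimal usage sketch for the sine-cosine positional-embedding helper in `single_modality/models/pos_embed.py` (this is not a file in the repository): it assumes `single_modality/models` has been added to `sys.path` so that `pos_embed.py` imports standalone, and the sizes below are arbitrary examples rather than FluxViT's actual configuration.

```python
# Hypothetical standalone check of get_3d_sincos_pos_embed from
# single_modality/models/pos_embed.py (the path adjustment is an assumption;
# numpy and torch must be installed for the module import to succeed).
import sys
sys.path.append("single_modality/models")

from pos_embed import get_3d_sincos_pos_embed

embed_dim, grid_size, t_size = 384, 16, 8  # example sizes; embed_dim must be divisible by 4
pos = get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=True)

# One row per token: 1 cls token + t_size * grid_size**2 patch tokens,
# with 1/4 of the channels encoding time and 3/4 encoding space.
assert pos.shape == (1 + t_size * grid_size ** 2, embed_dim)
print(pos.shape)  # (2049, 384)
```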