├── .gitignore
├── LICENSE
├── README.md
├── figs
│   ├── coin.png
│   ├── k400.png
│   ├── main.png
│   ├── ssv2.png
│   ├── teaser.png
│   └── zs_vt_retrieval.png
└── single_modality
    ├── MODEL_ZOO.md
    ├── __pycache__
    │   ├── functional.cpython-39.pyc
    │   ├── optim_factory.cpython-39.pyc
    │   └── utils.cpython-39.pyc
    ├── datasets
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-39.pyc
    │   │   ├── build.cpython-39.pyc
    │   │   ├── hmdb.cpython-39.pyc
    │   │   ├── kinetics.cpython-39.pyc
    │   │   ├── kinetics_sparse.cpython-39.pyc
    │   │   ├── kinetics_sparse_o4a.cpython-39.pyc
    │   │   ├── mae.cpython-39.pyc
    │   │   ├── mae_multi.cpython-39.pyc
    │   │   ├── mae_multi_ofa.cpython-39.pyc
    │   │   ├── mae_ofa.cpython-39.pyc
    │   │   ├── masking_generator.cpython-39.pyc
    │   │   ├── mixup.cpython-39.pyc
    │   │   ├── rand_augment.cpython-39.pyc
    │   │   ├── random_erasing.cpython-39.pyc
    │   │   ├── ssv2.cpython-39.pyc
    │   │   ├── ssv2_ofa.cpython-39.pyc
    │   │   ├── transforms.cpython-39.pyc
    │   │   ├── video_transforms.cpython-39.pyc
    │   │   └── volume_transforms.cpython-39.pyc
    │   ├── build.py
    │   ├── hmdb.py
    │   ├── kinetics.py
    │   ├── kinetics_sparse.py
    │   ├── kinetics_sparse_o4a.py
    │   ├── mae.py
    │   ├── mae_multi.py
    │   ├── mae_multi_ofa.py
    │   ├── mae_ofa.py
    │   ├── masking_generator.py
    │   ├── mixup.py
    │   ├── multiloader.py
    │   ├── rand_augment.py
    │   ├── random_erasing.py
    │   ├── ssv2.py
    │   ├── ssv2_ofa.py
    │   ├── transforms.py
    │   ├── video_transforms.py
    │   └── volume_transforms.py
    ├── engines
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-39.pyc
    │   │   └── engine_for_flexible_tuning.cpython-39.pyc
    │   ├── engine_for_flexible_tuning.py
    │   └── engine_for_flexible_tuning_single.py
    ├── exp
    │   ├── base
    │   │   ├── eval
    │   │   │   └── k400_eval.sh
    │   │   └── ft
    │   │       └── k400
    │   │           ├── k400_multi.sh
    │   │           └── k400_single.sh
    │   └── small
    │       └── eval
    │           └── k400_eval.sh
    ├── functional.py
    ├── models
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-39.pyc
    │   │   ├── flash_attention_class.cpython-39.pyc
    │   │   ├── fluxvit.cpython-39.pyc
    │   │   ├── pos_embed.cpython-39.pyc
    │   │   └── vid_tldr.cpython-39.pyc
    │   ├── flash_attention_class.py
    │   ├── fluxvit.py
    │   ├── pos_embed.py
    │   └── vid_tldr.py
    ├── optim_factory.py
    ├── run_flexible_finetune.py
    ├── run_flexible_finetune_single.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | multi_modality/*
2 | *.pt
3 | single_modality/rename*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 OpenGVLab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Make Your Training Flexible: Towards Deployment-Efficient Video Models
2 |
3 | This repo is the official implementation of "[Make Your Training Flexible: Towards Deployment-Efficient Video Models](https://arxiv.org/abs/2503.14237)". By Chenting Wang, Kunchang Li, Tianxiang Jiang, Xiangyu Zeng, Yi Wang, and Limin Wang.
4 |
5 |
6 |
7 | ## Update
8 |
9 | - **2025/04/29**: Released the single-modality evaluation script and part of the model weights; see [MODEL_ZOO](./single_modality/MODEL_ZOO.md).
10 | - **2025/03/18**: Built the repo and released the [paper](https://arxiv.org/abs/2503.14237).
11 |
12 | ## Introduction
13 |
14 | Popular video training methods mainly operate on a fixed number of tokens sampled from a predetermined spatiotemporal grid, resulting in sub-optimal accuracy-computation trade-offs due to inherent video redundancy. They also lack adaptability to varying computational budgets for downstream tasks, which hinders the deployment of the most competitive models in real-world scenarios. We therefore propose a new test setting, Token Optimization, which maximizes input information across budgets by optimizing the size-limited set of input tokens through token selection from more suitably sampled videos. To this end, we introduce a novel augmentation tool termed Flux. By making the sampling grid flexible and leveraging token selection, it is easily adopted in most popular video training frameworks, boosting model robustness at nearly no additional cost. We integrate Flux into large-scale video pre-training, and the resulting FluxViT establishes new state-of-the-art results across extensive tasks at standard costs. Notably, with only 1/4 of the tokens, it can still match the performance of previous state-of-the-art models under Token Optimization, yielding nearly 90% savings.
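
As a rough illustration of Token Optimization: at test time, sample a denser spatiotemporal grid than the token budget allows, score the resulting tokens, and keep only a budget-sized subset before the transformer. The sketch below is ours, with a placeholder L2-norm score and made-up shapes; the selector actually used in this repo appears to build on vid-TLDR (see `single_modality/models/vid_tldr.py`).

```python
import torch

def select_tokens(tokens: torch.Tensor, budget: int) -> torch.Tensor:
    """Keep the `budget` highest-scoring tokens from an oversampled set.

    `tokens` is (B, N, C) with N > budget; the L2-norm score is only a stand-in
    for a real saliency / attention-based selector.
    """
    scores = tokens.norm(dim=-1)                 # (B, N) proxy importance
    keep = scores.topk(budget, dim=1).indices    # budget best tokens per sample
    keep = keep.sort(dim=1).values               # restore spatiotemporal order
    return torch.gather(tokens, 1, keep.unsqueeze(-1).expand(-1, -1, tokens.size(-1)))

# Oversample (e.g. more frames / finer patches), then reduce to a fixed budget,
# so the model sees richer content at the same compute.
feats = torch.randn(2, 2048, 384)            # hypothetical patch embeddings
pruned = select_tokens(feats, budget=512)    # (2, 512, 384) enters the ViT
```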
15 |
16 |
17 |
18 | ## Performance
19 |
20 | ### Single Modality Action Recognition
21 |
22 | #### K400
23 |
24 | - Base Scale
25 |
26 | | **Model** | **GFLOPs** | **Top-1** | **Top-1 + TO** |
27 | |------------------|------------------|-----------|----------------|
28 | | InternVideo2-B | 440×12 | 87.4 | - |
29 | | FluxViT-B | 440×12 | 89.6 | 90.0 |
30 | | FluxViT-B | 255×12 | 89.3 | 89.7 |
31 | | FluxViT-B | 108×12 | 87.3 | 88.9 |
32 | | FluxViT-B | 49×12 | 84.7 | 87.4 |
33 |
34 | - Small Scale
35 |
36 | | **Model** | **GFLOPs** | **Top-1** | **Top-1 + TO** |
37 | |------------------|------------------|-----------|----------------|
38 | | InternVideo2-S | 154×12 | 85.8 | - |
39 | | FluxViT-S | 154×12 | 87.7 | 88.0 |
40 | | FluxViT-S | 83×12 | 87.3 | 87.7 |
41 | | FluxViT-S | 32×12 | 84.7 | 86.6 |
42 | | FluxViT-S | 13×12 | 80.1 | 84.7 |
43 |
44 | #### SSv2
45 |
46 | - Base Scale
47 |
48 | | **Model** | **GFLOPs** | **Top-1** | **Top-1 + TO** |
49 | |------------------|-----------------|-----------|----------------|
50 | | InternVideo2-B | 253×6 | 73.5 | - |
51 | | FluxViT-B | 440×6 | 75.3 | 75.6 |
52 | | FluxViT-B | 255×6 | 75.1 | 75.7 |
53 | | FluxViT-B | 108×6 | 72.0 | 75.1 |
54 | | FluxViT-B | 49×6 | 56.8 | 73.9 |
55 |
56 | - Small Scale
57 |
58 | | **Model** | **GFLOPs** | **Top-1** | **Top-1 + TO** |
59 | |------------------|-----------------|-----------|----------------|
60 | | InternVideo2-S | 154×6 | 71.5 | - |
61 | | FluxViT-S | 154×6 | 73.4 | 73.8 |
62 | | FluxViT-S | 83×6 | 72.9 | 73.4 |
63 | | FluxViT-S | 32×6 | 70.0 | 72.5 |
64 | | FluxViT-S | 13×6 | 55.3 | 70.9 |
65 |
66 | ### Multi Modality Video-Text Retrieval
67 |
68 |
69 |
70 | ### Multi Modality VideoChat Model
71 |
72 | Coming Soon
73 |
74 | ## Acknowledgement
75 |
76 | This repository is built upon the [UniFormer](https://github.com/Sense-X/UniFormer), [VideoMAE](https://github.com/MCG-NJU/VideoMAE), [VINDLU](https://github.com/klauscc/VindLU) and [Unmasked Teacher](https://github.com/OpenGVLab/unmasked_teacher/) repositories.
77 |
--------------------------------------------------------------------------------
/figs/coin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/figs/coin.png
--------------------------------------------------------------------------------
/figs/k400.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/figs/k400.png
--------------------------------------------------------------------------------
/figs/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/figs/main.png
--------------------------------------------------------------------------------
/figs/ssv2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/figs/ssv2.png
--------------------------------------------------------------------------------
/figs/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/figs/teaser.png
--------------------------------------------------------------------------------
/figs/zs_vt_retrieval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/figs/zs_vt_retrieval.png
--------------------------------------------------------------------------------
/single_modality/MODEL_ZOO.md:
--------------------------------------------------------------------------------
1 | # Model Zoo
2 |
3 | ## Note
4 |
5 | - For all pretraining and finetuning, we adopt sparse/uniform sampling.
6 | - `#Frame` $=$ `#input_frame` $\times$ `#crop` $\times$ `#clip` (see the worked example below)
7 | - `#input_frame` means how many frames are fed to the model per inference
8 | - `#crop` means spatial crops (e.g., 3 for left/right/center)
9 | - `#clip` means temporal clips (e.g., 4 means repeated sampling of four clips with different start indices)
10 |
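As a worked example of this notation (illustrative only, not repo code), an `8x3x4` entry feeds 8 frames to the model per view and typically averages predictions over 3 spatial crops × 4 temporal clips:

```python
# How an 8x3x4 entry expands: 8 input frames per view, 3 crops, 4 clips.
input_frames, num_crops, num_clips = 8, 3, 4
views_per_video = num_crops * num_clips             # 12 forward passes per video
frames_per_video = input_frames * views_per_video   # 96 frames decoded in total
```
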
11 | ## Pretraining
12 |
13 | TBD
14 |
15 | ## Distillation
16 |
17 | TBD
18 |
19 | ## Finetuning
20 |
21 | ### K710
22 |
23 | TBD
24 |
25 |
26 | ### K400
27 |
28 | | Model | Setting | #Frame | Top-1 | Model | Shell |
29 | | -------- | ------------- | -------- | ------ | ------ | ------ |
30 | | $\text{InternVideo2}_{s1}$-1B | K-Mash PT + K710 FT | 8x3x4 | 91.3 | [:hugs: HF link](https://huggingface.co/OpenGVLab/InternVideo2-Stage1-1B-224p-f8-K400/blob/main/1B_ft_k710_ft_k400_f8.pth) | TBD |
31 | | $\text{InternVideo2}_{s1}$-1B | K-Mash PT + K710 FT | 16x3x4 | 91.6 | [:hugs: HF link](https://huggingface.co/OpenGVLab/InternVideo2-Stage1-1B-224p-f8-K400/blob/main/1B_ft_k710_ft_k400_f16.pth) | TBD |
32 | | $\text{InternVideo2}_{s1}$-6B | K-Mash PT + K710 FT | 8x3x4 | 91.9 | TBD | TBD |
33 | | $\text{InternVideo2}_{s1}$-6B | K-Mash PT + K710 FT | 16x3x4 | 92.1 | TBD | TBD |
34 | | $\text{InternVideo2}_{dist}$-S/14 | K-Mash PT + K710 FT | 8x3x4 | 85.4 | [:hugs: HF link](https://huggingface.co/OpenGVLab/InternVideo2_distillation_models/resolve/main/stage1/S14/S14_ft_k710_ft_k400_f8/pytorch_model.bin) | TBD |
35 | | $\text{InternVideo2}_{dist}$-B/14 | K-Mash PT + K710 FT | 8x3x4 | 88.4 | [:hugs: HF link](https://huggingface.co/OpenGVLab/InternVideo2_distillation_models/resolve/main/stage1/B14/B14_ft_k710_ft_k400_f8/pytorch_model.bin) | TBD |
36 | | $\text{InternVideo2}_{dist}$-L/14 | K-Mash PT + K710 FT | 8x3x4 | 90.4 | [:hugs: HF link](https://huggingface.co/OpenGVLab/InternVideo2_distillation_models/resolve/main/stage1/L14/L14_ft_k710_ft_k400_f8/pytorch_model.bin) | TBD |
37 | | $\text{FluxViT}$-S/14 | K-Mash PT + K710 FT | 8x3x4 | 87.3 | [Link](https://drive.google.com/file/d/1OTjTsAnZGaq7AufDaw8IYLeSgmLYZjds/view?usp=sharing) | [run.sh](./exp/small/eval/k400_eval.sh) |
38 | | $\text{FluxViT}$-B/14 | K-Mash PT + K710 FT | 8x3x4 | 89.3 | [Link](https://drive.google.com/file/d/1YsxsB3_pkpdvmXQIhD3YOmWlqCyzJskg/view?usp=sharing) | [run.sh](./exp/base/eval/k400_eval.sh) |
39 |
40 |
41 |
42 | ### SthSth V2
43 |
44 | TBD
45 |
--------------------------------------------------------------------------------
/single_modality/__pycache__/functional.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/__pycache__/functional.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/__pycache__/optim_factory.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/__pycache__/optim_factory.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_dataset, build_pretraining_dataset, build_multi_pretraining_dataset, build_pretraining_dataset_ofa
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/build.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/build.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/hmdb.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/hmdb.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/kinetics.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/kinetics.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/kinetics_sparse.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/kinetics_sparse.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/kinetics_sparse_o4a.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/kinetics_sparse_o4a.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/mae.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/mae.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/mae_multi.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/mae_multi.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/mae_multi_ofa.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/mae_multi_ofa.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/mae_ofa.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/mae_ofa.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/masking_generator.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/masking_generator.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/mixup.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/mixup.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/rand_augment.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/rand_augment.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/random_erasing.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/random_erasing.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/ssv2.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/ssv2.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/ssv2_ofa.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/ssv2_ofa.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/transforms.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/transforms.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/video_transforms.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/video_transforms.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/__pycache__/volume_transforms.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/datasets/__pycache__/volume_transforms.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/datasets/build.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torchvision import transforms
3 | from .transforms import *
4 | from .masking_generator import TubeMaskingGenerator, RandomMaskingGenerator
5 | from .mae import VideoMAE
6 | from .masking_generator import RandomMaskingGenerator, TemporalConsistencyMaskingGenerator, TemporalProgressiveMaskingGenerator, TemporalCenteringProgressiveMaskingGenerator
7 | from .mae_multi import VideoMAE_multi
8 | from .mae_multi_ofa import VideoMAE_multi_ofa
9 | from .kinetics import VideoClsDataset
10 | from .kinetics_sparse import VideoClsDataset_sparse
11 | from .kinetics_sparse_o4a import VideoClsDataset_sparse_ofa
12 | # from .anet import ANetDataset
13 | from .ssv2 import SSVideoClsDataset, SSRawFrameClsDataset
14 | from .ssv2_ofa import SSRawFrameClsDataset_OFA, SSVideoClsDataset_OFA
15 | from .hmdb import HMDBVideoClsDataset, HMDBRawFrameClsDataset
16 | from .mae_ofa import VideoMAE_ofa
17 |
18 |
19 | class DataAugmentationForVideoMAE(object):
20 | def __init__(self, args):
21 | self.input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN
22 | self.input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD
23 | normalize = GroupNormalize(self.input_mean, self.input_std)
24 | self.train_augmentation = GroupMultiScaleCrop(args.input_size, [1, .875, .75, .66])
25 | if args.color_jitter > 0:
26 | self.transform = transforms.Compose([
27 | self.train_augmentation,
28 | GroupColorJitter(args.color_jitter),
29 | GroupRandomHorizontalFlip(flip=args.flip),
30 | Stack(roll=False),
31 | ToTorchFormatTensor(div=True),
32 | normalize,
33 | ])
34 | else:
35 | self.transform = transforms.Compose([
36 | self.train_augmentation,
37 | GroupRandomHorizontalFlip(flip=args.flip),
38 | Stack(roll=False),
39 | ToTorchFormatTensor(div=True),
40 | normalize,
41 | ])
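        # Masked-position generator for MAE-style pretraining, selected by mask_type:
        #  - 'tube':   one spatial mask repeated across all frames (VideoMAE-style tube masking)
        #  - 'random': a random mask over the full spatiotemporal token grid
        #  - 't_consist' / 't_progressive' / 't_center_prog': temporally structured masks,
        #    parameterized by the student (and, for 't_consist', teacher) mask ratios
        #  - 'attention': masking is handled later by the model, so no generator is built here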
42 | if args.mask_type == 'tube':
43 | self.masked_position_generator = TubeMaskingGenerator(
44 | args.window_size, args.mask_ratio
45 | )
46 | elif args.mask_type == 'random':
47 | self.masked_position_generator = RandomMaskingGenerator(
48 | args.window_size, args.mask_ratio
49 | )
50 | elif args.mask_type == 't_consist':
51 | self.masked_position_generator = TemporalConsistencyMaskingGenerator(
52 | args.window_size, args.student_mask_ratio, args.teacher_mask_ratio
53 | )
54 | elif args.mask_type == 't_progressive':
55 | self.masked_position_generator = TemporalProgressiveMaskingGenerator(
56 | args.window_size, args.student_mask_ratio
57 | )
58 | elif args.mask_type == 't_center_prog':
59 | self.masked_position_generator = TemporalCenteringProgressiveMaskingGenerator(
60 | args.window_size, args.student_mask_ratio
61 | )
62 |         elif args.mask_type == 'attention':
63 | self.masked_position_generator = None
64 |
65 | def __call__(self, images):
66 | process_data, _ = self.transform(images)
67 | if self.masked_position_generator is None:
68 | return process_data, -1
69 | else:
70 | return process_data, self.masked_position_generator()
71 |
72 | def __repr__(self):
73 | repr = "(DataAugmentationForVideoMAE,\n"
74 | repr += " transform = %s,\n" % str(self.transform)
75 | repr += " Masked position generator = %s,\n" % str(self.masked_position_generator)
76 | repr += ")"
77 | return repr
78 |
79 |
80 | def build_pretraining_dataset(args):
81 | transform = DataAugmentationForVideoMAE(args)
82 | dataset = VideoMAE(
83 | root=None,
84 | setting=args.data_path,
85 | prefix=args.prefix,
86 | split=args.split,
87 | video_ext='mp4',
88 | is_color=True,
89 | modality='rgb',
90 | num_segments=args.num_segments,
91 | new_length=args.num_frames,
92 | new_step=args.sampling_rate,
93 | transform=transform,
94 | temporal_jitter=False,
95 | video_loader=True,
96 | use_decord=args.use_decord,
97 | lazy_init=False,
98 | num_sample=args.num_sample)
99 | print("Data Aug = %s" % str(transform))
100 | return dataset
101 |
102 | def build_pretraining_dataset_ofa(args):
103 | transform = DataAugmentationForVideoMAE(args)
104 | dataset = VideoMAE_ofa(
105 | root=None,
106 | setting=args.data_path,
107 | prefix=args.prefix,
108 | split=args.split,
109 | video_ext='mp4',
110 | is_color=True,
111 | modality='rgb',
112 | num_segments=args.num_segments,
113 | new_length=args.num_frames,
114 | new_step=args.sampling_rate,
115 | transform=transform,
116 | temporal_jitter=False,
117 | video_loader=True,
118 | use_decord=args.use_decord,
119 | lazy_init=False,
120 | num_sample=args.num_sample)
121 | print("Data Aug = %s" % str(transform))
122 | return dataset
123 |
124 |
125 | def build_once4all_pretraining_dataset(args, num_datasets):
126 | datasets = []
127 | for i in range(num_datasets):
128 | args.input_size = ...
129 | datasets.append(build_pretraining_dataset(args))
130 | return datasets
131 |
132 |
133 | def build_dataset(is_train, test_mode, args):
134 | print(f'Use Dataset: {args.data_set}')
135 | if args.data_set in [
136 | 'Kinetics',
137 | 'Kinetics_sparse',
138 | 'Kinetics_sparse_ofa',
139 | 'mitv1_sparse'
140 | ]:
141 | mode = None
142 | anno_path = None
143 | if is_train is True:
144 | mode = 'train'
145 | anno_path = os.path.join(args.data_path, 'train.csv')
146 | elif test_mode is True:
147 | mode = 'test'
148 | anno_path = os.path.join(args.data_path, 'test.csv')
149 | else:
150 | mode = 'validation'
151 | anno_path = os.path.join(args.data_path, 'val.csv')
152 |
153 | if 'sparse_ofa' in args.data_set:
154 | func = VideoClsDataset_sparse_ofa
155 | elif 'sparse' in args.data_set:
156 | func = VideoClsDataset_sparse
157 | else:
158 | func = VideoClsDataset
159 |
160 | dataset = func(
161 | anno_path=anno_path,
162 | prefix=args.prefix,
163 | split=args.split,
164 | mode=mode,
165 | clip_len=args.num_frames if is_train else args.eval_true_frame,
166 | frame_sample_rate=args.sampling_rate,
167 | num_segment=1,
168 | test_num_segment=args.test_num_segment,
169 | test_num_crop=args.test_num_crop,
170 | num_crop=1 if not test_mode else 3,
171 | keep_aspect_ratio=True,
172 | crop_size=args.input_size if is_train else args.eval_input_size,
173 | short_side_size=args.short_side_size if is_train else args.eval_short_side_size,
174 | new_height=256,
175 | new_width=320,
176 | args=args)
177 |
178 | nb_classes = args.nb_classes
179 |
180 | elif 'SSV2' in args.data_set:
181 | mode = None
182 | anno_path = None
183 | if is_train is True:
184 | mode = 'train'
185 | anno_path = os.path.join(args.data_path, 'train.csv')
186 | elif test_mode is True:
187 | mode = 'test'
188 | anno_path = os.path.join(args.data_path, 'test.csv')
189 | else:
190 | mode = 'validation'
191 | anno_path = os.path.join(args.data_path, 'val.csv')
192 |
193 | if args.use_decord:
194 | if 'ofa' in args.data_set:
195 | func = SSVideoClsDataset_OFA
196 | else:
197 | func = SSVideoClsDataset
198 | else:
199 | if 'ofa' in args.data_set:
200 | func = SSRawFrameClsDataset_OFA
201 | else:
202 | func = SSRawFrameClsDataset
203 |
204 | dataset = func(
205 | anno_path=anno_path,
206 | prefix=args.prefix,
207 | split=args.split,
208 | mode=mode,
209 | clip_len=1,
210 | num_segment=args.num_frames if is_train else args.eval_true_frame,
211 | test_num_segment=args.test_num_segment,
212 | test_num_crop=args.test_num_crop,
213 | num_crop=1 if not test_mode else 3,
214 | keep_aspect_ratio=True,
215 | crop_size=args.input_size if is_train else args.eval_input_size,
216 | short_side_size=args.short_side_size if is_train else args.eval_short_side_size,
217 | new_height=256,
218 | new_width=320,
219 | filename_tmpl=args.filename_tmpl,
220 | args=args)
221 | nb_classes = 174
222 |
223 | elif args.data_set == 'UCF101':
224 | mode = None
225 | anno_path = None
226 | if is_train is True:
227 | mode = 'train'
228 | anno_path = os.path.join(args.data_path, 'train.csv')
229 | elif test_mode is True:
230 | mode = 'test'
231 | anno_path = os.path.join(args.data_path, 'test.csv')
232 | else:
233 | mode = 'validation'
234 | anno_path = os.path.join(args.data_path, 'val.csv')
235 |
236 | dataset = VideoClsDataset(
237 | anno_path=anno_path,
238 | prefix=args.prefix,
239 | split=args.split,
240 | mode=mode,
241 | clip_len=args.num_frames if is_train else args.eval_true_frame,
242 | frame_sample_rate=args.sampling_rate,
243 | num_segment=1,
244 | test_num_segment=args.test_num_segment,
245 | test_num_crop=args.test_num_crop,
246 | num_crop=1 if not test_mode else 3,
247 | keep_aspect_ratio=True,
248 | crop_size=args.input_size if is_train else args.eval_input_size,
249 | short_side_size=args.short_side_size if is_train else args.eval_short_side_size,
250 | new_height=256,
251 | new_width=320,
252 | args=args)
253 | nb_classes = 101
254 |
255 | elif args.data_set == 'HMDB51':
256 | mode = None
257 | anno_path = None
258 | if is_train is True:
259 | mode = 'train'
260 | anno_path = os.path.join(args.data_path, 'train.csv')
261 | elif test_mode is True:
262 | mode = 'test'
263 | anno_path = os.path.join(args.data_path, 'test.csv')
264 | else:
265 | mode = 'validation'
266 | anno_path = os.path.join(args.data_path, 'val.csv')
267 |
268 | if args.use_decord:
269 | func = HMDBVideoClsDataset
270 | else:
271 | func = HMDBRawFrameClsDataset
272 |
273 | dataset = func(
274 | anno_path=anno_path,
275 | prefix=args.prefix,
276 | split=args.split,
277 | mode=mode,
278 | clip_len=1,
279 | num_segment=args.num_frames if is_train else args.eval_true_frame,
280 | test_num_segment=args.test_num_segment,
281 | test_num_crop=args.test_num_crop,
282 | num_crop=1 if not test_mode else 3,
283 | keep_aspect_ratio=True,
284 | crop_size=args.input_size if is_train else args.eval_input_size,
285 | short_side_size=args.short_side_size if is_train else args.eval_short_side_size,
286 | new_height=256,
287 | new_width=320,
288 | filename_tmpl=args.filename_tmpl,
289 | args=args)
290 | nb_classes = 51
291 |
292 | elif args.data_set in [
293 | 'ANet',
294 | 'HACS',
295 | 'ANet_interval',
296 | 'HACS_interval'
297 | ]:
298 | mode = None
299 | anno_path = None
300 | if is_train is True:
301 | mode = 'train'
302 | anno_path = os.path.join(args.data_path, 'train.csv')
303 | elif test_mode is True:
304 | mode = 'test'
305 | anno_path = os.path.join(args.data_path, 'test.csv')
306 | else:
307 | mode = 'validation'
308 | anno_path = os.path.join(args.data_path, 'val.csv')
309 |
310 | if 'interval' in args.data_set:
311 | func = ANetDataset
312 | else:
313 | func = VideoClsDataset_sparse
314 |
315 | dataset = func(
316 | anno_path=anno_path,
317 | prefix=args.prefix,
318 | split=args.split,
319 | mode=mode,
320 | clip_len=args.num_frames if is_train else args.eval_true_frame,
321 | frame_sample_rate=args.sampling_rate,
322 | num_segment=1,
323 | test_num_segment=args.test_num_segment,
324 | test_num_crop=args.test_num_crop,
325 | num_crop=1 if not test_mode else 3,
326 | keep_aspect_ratio=True,
327 | crop_size=args.input_size if is_train else args.eval_input_size,
328 | short_side_size=args.short_side_size if is_train else args.eval_short_side_size,
329 | new_height=256,
330 | new_width=320,
331 | args=args)
332 | nb_classes = args.nb_classes
333 |
334 | else:
335 | print(f'Wrong: {args.data_set}')
336 | raise NotImplementedError()
337 | assert nb_classes == args.nb_classes
338 | print("Number of the class = %d" % args.nb_classes)
339 |
340 | return dataset, nb_classes
341 |
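# Rough usage sketch (hypothetical args namespace; the real flags are presumably
# defined in run_flexible_finetune*.py):
#   args.data_set = 'Kinetics_sparse'; args.data_path = '/path/to/k400_annotations'
#   train_set, nb_classes = build_dataset(is_train=True, test_mode=False, args=args)
#   test_set, _ = build_dataset(is_train=False, test_mode=True, args=args)
# Annotation CSVs are expected as train.csv / val.csv / test.csv under data_path.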
342 |
343 | def build_multi_pretraining_dataset(args):
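    # SSv2 labels are direction-sensitive (e.g. "pushing something from left to right"),
    # so a second transform with horizontal flip disabled is built for that source.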
344 |     original_flip = args.flip
345 | transform = DataAugmentationForVideoMAE(args)
346 | args.flip = False
347 | transform_ssv2 = DataAugmentationForVideoMAE(args)
348 |     args.flip = original_flip
349 |
350 | dataset = VideoMAE_multi_ofa(
351 | root=None,
352 | setting=args.data_path,
353 | prefix=args.prefix,
354 | split=args.split,
355 | is_color=True,
356 | modality='rgb',
357 | num_segments=args.num_segments,
358 | new_length=args.num_frames,
359 | new_step=args.sampling_rate,
360 | transform=transform,
361 | transform_ssv2=transform_ssv2,
362 | temporal_jitter=False,
363 | video_loader=True,
364 | use_decord=args.use_decord,
365 | lazy_init=False,
366 | num_sample=args.num_sample)
367 | print("Data Aug = %s" % str(transform))
368 | print("Data Aug for SSV2 = %s" % str(transform_ssv2))
369 | return dataset
370 |
--------------------------------------------------------------------------------
/single_modality/datasets/kinetics_sparse.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os
3 | import io
4 | import random
5 | import numpy as np
6 | from numpy.lib.function_base import disp
7 | import torch
8 | from torchvision import transforms
9 | import warnings
10 | from decord import VideoReader, cpu
11 | from torch.utils.data import Dataset
12 | from .random_erasing import RandomErasing
13 | from .video_transforms import (
14 | Compose, Resize, CenterCrop, Normalize,
15 | create_random_augment, random_short_side_scale_jitter,
16 | random_crop, random_resized_crop_with_shift, random_resized_crop,
17 | horizontal_flip, random_short_side_scale_jitter, uniform_crop,
18 | )
19 | from .volume_transforms import ClipToTensor
20 |
21 | try:
22 | from petrel_client.client import Client
23 | has_client = True
24 | except ImportError:
25 | has_client = False
26 |
27 | class VideoClsDataset_sparse(Dataset):
28 | """Load your own video classification dataset."""
29 |
30 | def __init__(self, anno_path, prefix='', split=' ', mode='train', clip_len=8,
31 | frame_sample_rate=2, crop_size=224, short_side_size=256,
32 | new_height=256, new_width=340, keep_aspect_ratio=True,
33 | num_segment=1, num_crop=1, test_num_segment=10, test_num_crop=3,
34 | args=None):
35 | self.anno_path = anno_path
36 | self.prefix = prefix
37 | self.split = split
38 | self.mode = mode
39 | self.clip_len = clip_len
40 | self.frame_sample_rate = frame_sample_rate
41 | self.crop_size = crop_size
42 | self.short_side_size = short_side_size
43 | self.new_height = new_height
44 | self.new_width = new_width
45 | self.keep_aspect_ratio = keep_aspect_ratio
46 | self.num_segment = num_segment
47 | self.test_num_segment = test_num_segment
48 | self.num_crop = num_crop
49 | self.test_num_crop = test_num_crop
50 | self.args = args
51 | self.aug = False
52 | self.rand_erase = False
53 | assert num_segment == 1
54 | if self.mode in ['train']:
55 | self.aug = True
56 | if self.args.reprob > 0:
57 | self.rand_erase = True
58 | if VideoReader is None:
59 | raise ImportError("Unable to import `decord` which is required to read videos.")
60 |
61 | import pandas as pd
62 | cleaned = pd.read_csv(self.anno_path, header=None, delimiter=self.split)
63 | self.dataset_samples = list(cleaned.values[:, 0])
64 | self.label_array = list(cleaned.values[:, 1])
65 |
66 | self.client = None
67 | if has_client:
68 | self.client = Client('~/petreloss.conf')
69 |
70 | if (mode == 'train'):
71 | pass
72 |
73 | elif (mode == 'validation'):
74 | self.data_transform = Compose([
75 | Resize(self.short_side_size, interpolation='bilinear'),
76 | CenterCrop(size=(self.crop_size, self.crop_size)),
77 | ClipToTensor(),
78 | Normalize(mean=[0.485, 0.456, 0.406],
79 | std=[0.229, 0.224, 0.225])
80 | ])
81 | elif mode == 'test':
82 | self.data_resize = Compose([
83 | Resize(size=(short_side_size), interpolation='bilinear')
84 | ])
85 | self.data_transform = Compose([
86 | ClipToTensor(),
87 | Normalize(mean=[0.485, 0.456, 0.406],
88 | std=[0.229, 0.224, 0.225])
89 | ])
90 | self.test_seg = []
91 | self.test_dataset = []
92 | self.test_label_array = []
93 | for ck in range(self.test_num_segment):
94 | for cp in range(self.test_num_crop):
95 | for idx in range(len(self.label_array)):
96 | sample_label = self.label_array[idx]
97 | self.test_label_array.append(sample_label)
98 | self.test_dataset.append(self.dataset_samples[idx])
99 | self.test_seg.append((ck, cp))
100 |
101 | def __getitem__(self, index):
102 | if self.mode == 'train':
103 | args = self.args
104 |
105 | sample = self.dataset_samples[index]
106 | buffer = self.loadvideo_decord(sample, chunk_nb=-1) # T H W C
107 | if len(buffer) == 0:
108 | while len(buffer) == 0:
109 | warnings.warn("video {} not correctly loaded during training".format(sample))
110 | index = np.random.randint(self.__len__())
111 | sample = self.dataset_samples[index]
112 | buffer = self.loadvideo_decord(sample, chunk_nb=-1)
113 |
114 | if args.num_sample > 1:
115 | frame_list = []
116 | label_list = []
117 | index_list = []
118 | for _ in range(args.num_sample):
119 | new_frames = self._aug_frame(buffer, args)
120 | label = self.label_array[index]
121 | frame_list.append(new_frames)
122 | label_list.append(label)
123 | index_list.append(index)
124 | return frame_list, label_list, index_list, {}
125 | else:
126 | buffer = self._aug_frame(buffer, args)
127 |
128 | return buffer, self.label_array[index], index, {}
129 |
130 | elif self.mode == 'validation':
131 | sample = self.dataset_samples[index]
132 | buffer = self.loadvideo_decord(sample, chunk_nb=0)
133 | if len(buffer) == 0:
134 | while len(buffer) == 0:
135 | warnings.warn("video {} not correctly loaded during validation".format(sample))
136 | index = np.random.randint(self.__len__())
137 | sample = self.dataset_samples[index]
138 | buffer = self.loadvideo_decord(sample, chunk_nb=0)
139 | buffer = self.data_transform(buffer)
140 | return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0]
141 |
142 | elif self.mode == 'test':
143 | sample = self.test_dataset[index]
144 | chunk_nb, split_nb = self.test_seg[index]
145 | buffer = self.loadvideo_decord(sample, chunk_nb=chunk_nb)
146 |
147 | while len(buffer) == 0:
148 | warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\
149 | str(self.test_dataset[index]), chunk_nb, split_nb))
150 | index = np.random.randint(self.__len__())
151 | sample = self.test_dataset[index]
152 | chunk_nb, split_nb = self.test_seg[index]
153 | buffer = self.loadvideo_decord(sample, chunk_nb=chunk_nb)
154 |
155 | buffer = self.data_resize(buffer)
156 | if isinstance(buffer, list):
157 | buffer = np.stack(buffer, 0)
158 |
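            # Multi-crop testing: the short side is already resized to short_side_size,
            # so slide a square window along the longer side; split_nb selects one of
            # test_num_crop evenly spaced positions.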
159 | if self.test_num_crop == 1:
160 | spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) / 2
161 | spatial_start = int(spatial_step)
162 | else:
163 | spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
164 | / (self.test_num_crop - 1)
165 | spatial_start = int(split_nb * spatial_step)
166 |
167 | if buffer.shape[1] >= buffer.shape[2]:
168 | buffer = buffer[:, spatial_start:spatial_start + self.short_side_size, :, :]
169 | else:
170 | buffer = buffer[:, :, spatial_start:spatial_start + self.short_side_size, :]
171 |
172 | buffer = self.data_transform(buffer)
173 | return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \
174 | chunk_nb, split_nb
175 | else:
176 |             raise NameError('mode {} unknown'.format(self.mode))
177 |
178 | def _aug_frame(
179 | self,
180 | buffer,
181 | args,
182 | ):
183 |
184 | aug_transform = create_random_augment(
185 | input_size=(self.crop_size, self.crop_size),
186 | auto_augment=args.aa,
187 | interpolation=args.train_interpolation,
188 | )
189 |
190 | buffer = [
191 | transforms.ToPILImage()(frame) for frame in buffer
192 | ]
193 |
194 | buffer = aug_transform(buffer)
195 |
196 | buffer = [transforms.ToTensor()(img) for img in buffer]
197 | buffer = torch.stack(buffer) # T C H W
198 | buffer = buffer.permute(0, 2, 3, 1) # T H W C
199 |
200 | # T H W C
201 | buffer = tensor_normalize(
202 | buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
203 | )
204 | # T H W C -> C T H W.
205 | buffer = buffer.permute(3, 0, 1, 2)
206 | # Perform data augmentation.
207 | scl, asp = (
208 | [0.08, 1.0],
209 | [0.75, 1.3333],
210 | )
211 |
212 | buffer = spatial_sampling(
213 | buffer,
214 | spatial_idx=-1,
215 | min_scale=256,
216 | max_scale=320,
217 | crop_size=self.crop_size,
218 | random_horizontal_flip=False if args.data_set == 'SSV2' else True ,
219 | inverse_uniform_sampling=False,
220 | aspect_ratio=asp,
221 | scale=scl,
222 | motion_shift=False
223 | )
224 |
225 | if self.rand_erase:
226 | erase_transform = RandomErasing(
227 | args.reprob,
228 | mode=args.remode,
229 | max_count=args.recount,
230 | num_splits=args.recount,
231 | device="cpu",
232 | )
233 | buffer = buffer.permute(1, 0, 2, 3)
234 | buffer = erase_transform(buffer)
235 | buffer = buffer.permute(1, 0, 2, 3)
236 |
237 | return buffer
238 |
239 | def _get_seq_frames(self, video_size, num_frames, clip_idx=-1):
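        # TSN-style sparse sampling: split the video into num_frames equal segments.
        # Training (clip_idx == -1) draws a random index inside each segment, while
        # evaluation uses a deterministic per-segment offset so that different
        # clip_idx values (test chunks) cover different frames.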
240 | seg_size = max(0., float(video_size - 1) / num_frames)
241 | max_frame = int(video_size) - 1
242 | seq = []
243 | # index from 1, must add 1
244 | if clip_idx == -1:
245 | for i in range(num_frames):
246 | start = int(np.round(seg_size * i))
247 | end = int(np.round(seg_size * (i + 1)))
248 | idx = min(random.randint(start, end), max_frame)
249 | seq.append(idx)
250 | else:
251 | num_segment = 1
252 | if self.mode == 'test':
253 | num_segment = self.test_num_segment
254 | duration = seg_size / (num_segment + 1)
255 | for i in range(num_frames):
256 | start = int(np.round(seg_size * i))
257 | frame_index = start + int(duration * (clip_idx + 1))
258 | idx = min(frame_index, max_frame)
259 | seq.append(idx)
260 | return seq
261 |
262 | def loadvideo_decord(self, sample, chunk_nb=0):
263 | """Load video content using Decord"""
264 | fname = sample
265 | fname = os.path.join(self.prefix, fname)
266 |
267 | try:
268 | if self.keep_aspect_ratio:
269 | if "s3://" in fname:
270 | video_bytes = self.client.get(fname)
271 | vr = VideoReader(io.BytesIO(video_bytes),
272 | num_threads=1,
273 | ctx=cpu(0))
274 | else:
275 | vr = VideoReader(fname, num_threads=1, ctx=cpu(0))
276 | else:
277 | if "s3://" in fname:
278 | video_bytes = self.client.get(fname)
279 | vr = VideoReader(io.BytesIO(video_bytes),
280 | width=self.new_width,
281 | height=self.new_height,
282 | num_threads=1,
283 | ctx=cpu(0))
284 | else:
285 | vr = VideoReader(fname, width=self.new_width, height=self.new_height,
286 | num_threads=1, ctx=cpu(0))
287 |
288 | all_index = self._get_seq_frames(len(vr), self.clip_len, clip_idx=chunk_nb)
289 | vr.seek(0)
290 | buffer = vr.get_batch(all_index).asnumpy()
291 | return buffer
292 | except:
293 | print("video cannot be loaded by decord: ", fname)
294 | return []
295 |
296 | def __len__(self):
297 | if self.mode != 'test':
298 | return len(self.dataset_samples)
299 | else:
300 | return len(self.test_dataset)
301 |
302 |
303 | def spatial_sampling(
304 | frames,
305 | spatial_idx=-1,
306 | min_scale=256,
307 | max_scale=320,
308 | crop_size=224,
309 | random_horizontal_flip=True,
310 | inverse_uniform_sampling=False,
311 | aspect_ratio=None,
312 | scale=None,
313 | motion_shift=False,
314 | ):
315 | """
316 | Perform spatial sampling on the given video frames. If spatial_idx is
317 | -1, perform random scale, random crop, and random flip on the given
318 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
319 | with the given spatial_idx.
320 | Args:
321 | frames (tensor): frames of images sampled from the video. The
322 | dimension is `num frames` x `height` x `width` x `channel`.
323 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
324 | or 2, perform left, center, right crop if width is larger than
325 |             height, and perform top, center, bottom crop if height is larger
326 | than width.
327 | min_scale (int): the minimal size of scaling.
328 | max_scale (int): the maximal size of scaling.
329 | crop_size (int): the size of height and width used to crop the
330 | frames.
331 | inverse_uniform_sampling (bool): if True, sample uniformly in
332 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
333 | scale. If False, take a uniform sample from [min_scale,
334 | max_scale].
335 | aspect_ratio (list): Aspect ratio range for resizing.
336 | scale (list): Scale range for resizing.
337 | motion_shift (bool): Whether to apply motion shift for resizing.
338 | Returns:
339 | frames (tensor): spatially sampled frames.
340 | """
341 | assert spatial_idx in [-1, 0, 1, 2]
342 | if spatial_idx == -1:
343 | if aspect_ratio is None and scale is None:
344 | frames, _ = random_short_side_scale_jitter(
345 | images=frames,
346 | min_size=min_scale,
347 | max_size=max_scale,
348 | inverse_uniform_sampling=inverse_uniform_sampling,
349 | )
350 | frames, _ = random_crop(frames, crop_size)
351 | else:
352 | transform_func = (
353 | random_resized_crop_with_shift
354 | if motion_shift
355 | else random_resized_crop
356 | )
357 | frames = transform_func(
358 | images=frames,
359 | target_height=crop_size,
360 | target_width=crop_size,
361 | scale=scale,
362 | ratio=aspect_ratio,
363 | )
364 | if random_horizontal_flip:
365 | frames, _ = horizontal_flip(0.5, frames)
366 | else:
367 | # The testing is deterministic and no jitter should be performed.
368 |         # min_scale, max_scale, and crop_size are expected to be the same.
369 | assert len({min_scale, max_scale, crop_size}) == 1
370 | frames, _ = random_short_side_scale_jitter(
371 | frames, min_scale, max_scale
372 | )
373 | frames, _ = uniform_crop(frames, crop_size, spatial_idx)
374 | return frames
375 |
376 |
377 | def tensor_normalize(tensor, mean, std):
378 | """
379 | Normalize a given tensor by subtracting the mean and dividing the std.
380 | Args:
381 | tensor (tensor): tensor to normalize.
382 | mean (tensor or list): mean value to subtract.
383 | std (tensor or list): std to divide.
384 | """
385 | if tensor.dtype == torch.uint8:
386 | tensor = tensor.float()
387 | tensor = tensor / 255.0
388 | if type(mean) == list:
389 | mean = torch.tensor(mean)
390 | if type(std) == list:
391 | std = torch.tensor(std)
392 | tensor = tensor - mean
393 | tensor = tensor / std
394 | return tensor
395 |
396 |
--------------------------------------------------------------------------------
/single_modality/datasets/mae.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import io
4 | import numpy as np
5 | import torch
6 | import decord
7 | from PIL import Image
8 | from decord import VideoReader, cpu
9 | import random
10 |
11 | try:
12 | from petrel_client.client import Client
13 | has_client = True
14 | except ImportError:
15 | has_client = False
16 |
17 |
18 | class VideoMAE(torch.utils.data.Dataset):
19 | """Load your own video classification dataset.
20 | Parameters
21 | ----------
22 | root : str, required.
23 | Path to the root folder storing the dataset.
24 | setting : str, required.
25 | A text file describing the dataset, each line per video sample.
26 | There are three items in each line: (1) video path; (2) video length and (3) video label.
27 | prefix : str, required.
28 | The prefix for loading data.
29 | split : str, required.
30 | The split character for metadata.
31 | train : bool, default True.
32 | Whether to load the training or validation set.
33 | test_mode : bool, default False.
34 | Whether to perform evaluation on the test set.
35 |         Usually a three-crop or ten-crop evaluation strategy is involved.
36 | name_pattern : str, default None.
37 | The naming pattern of the decoded video frames.
38 | For example, img_00012.jpg.
39 | video_ext : str, default 'mp4'.
40 |         If video_loader is set to True, please specify the video format accordingly.
41 | is_color : bool, default True.
42 | Whether the loaded image is color or grayscale.
43 | modality : str, default 'rgb'.
44 | Input modalities, we support only rgb video frames for now.
45 | Will add support for rgb difference image and optical flow image later.
46 | num_segments : int, default 1.
47 | Number of segments to evenly divide the video into clips.
48 | A useful technique to obtain global video-level information.
49 |         Limin Wang, et al., Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016.
50 | num_crop : int, default 1.
51 | Number of crops for each image. default is 1.
52 | Common choices are three crops and ten crops during evaluation.
53 | new_length : int, default 1.
54 | The length of input video clip. Default is a single image, but it can be multiple video frames.
55 | For example, new_length=16 means we will extract a video clip of consecutive 16 frames.
56 | new_step : int, default 1.
57 | Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames.
58 | new_step=2 means we will extract a video clip of every other frame.
59 | temporal_jitter : bool, default False.
60 | Whether to temporally jitter if new_step > 1.
61 | video_loader : bool, default False.
62 | Whether to use video loader to load data.
63 | use_decord : bool, default True.
64 | Whether to use Decord video loader to load data. Otherwise load image.
65 | transform : function, default None.
66 | A function that takes data and label and transforms them.
67 | data_aug : str, default 'v1'.
68 |         Different types of automatic data augmentation. Supports v1, v2, v3 and v4.
69 | lazy_init : bool, default False.
70 | If set to True, build a dataset instance without loading any dataset.
71 | """
72 | def __init__(self,
73 | root,
74 | setting,
75 | prefix='',
76 | split=' ',
77 | train=True,
78 | test_mode=False,
79 | name_pattern='img_%05d.jpg',
80 | video_ext='mp4',
81 | is_color=True,
82 | modality='rgb',
83 | num_segments=1,
84 | num_crop=1,
85 | new_length=1,
86 | new_step=1,
87 | transform=None,
88 | temporal_jitter=False,
89 | video_loader=False,
90 | use_decord=True,
91 | lazy_init=False,
92 | num_sample=1,
93 | ):
94 |
95 | super(VideoMAE, self).__init__()
96 | self.root = root
97 | self.setting = setting
98 | self.prefix = prefix
99 | self.split = split
100 | self.train = train
101 | self.test_mode = test_mode
102 | self.is_color = is_color
103 | self.modality = modality
104 | self.num_segments = num_segments
105 | self.num_crop = num_crop
106 | self.new_length = new_length
107 | self.new_step = new_step
108 | self.skip_length = self.new_length * self.new_step
109 | self.temporal_jitter = temporal_jitter
110 | self.name_pattern = name_pattern
111 | self.video_loader = video_loader
112 | self.video_ext = video_ext
113 | self.use_decord = use_decord
114 | self.transform = transform
115 | self.lazy_init = lazy_init
116 | self.num_sample = num_sample
117 |
118 | # sparse sampling, num_segments != 1
119 | if self.num_segments != 1:
120 | print('Use sparse sampling, change frame and stride')
121 | self.new_length = self.num_segments
122 | self.skip_length = 1
123 |
124 | self.client = None
125 | if has_client:
126 | self.client = Client('~/petreloss.conf')
127 |
128 | if not self.lazy_init:
129 | self.clips = self._make_dataset(root, setting)
130 | if len(self.clips) == 0:
131 | raise(RuntimeError("Found 0 video clips in subfolders of: " + root + "\n"
132 | "Check your data directory (opt.data-dir)."))
133 |
134 | def __getitem__(self, index):
135 | while True:
136 | try:
137 | images = None
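                # Two loading paths: (1) decode the video container with decord, or
                # (2) when use_decord is False, read pre-extracted frames through the
                # petrel client. Note the default name_pattern is %-style while
                # .format() is used below, so a '{}'-style pattern must be supplied
                # for the raw-frame path.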
138 | if self.use_decord:
139 | directory, target = self.clips[index]
140 | if self.video_loader:
141 | if '.' in directory.split('/')[-1]:
142 | # data in the "setting" file already have extension, e.g., demo.mp4
143 | video_name = directory
144 | else:
145 | # data in the "setting" file do not have extension, e.g., demo
146 | # So we need to provide extension (i.e., .mp4) to complete the file name.
147 | video_name = '{}.{}'.format(directory, self.video_ext)
148 |
149 | video_name = os.path.join(self.prefix, video_name)
150 | if video_name.startswith('s3') or video_name.startswith('p2:s3'):
151 | video_bytes = self.client.get(video_name)
152 | decord_vr = VideoReader(io.BytesIO(video_bytes),
153 | num_threads=1,
154 | ctx=cpu(0))
155 | else:
156 | decord_vr = decord.VideoReader(video_name, num_threads=1, ctx=cpu(0))
157 | duration = len(decord_vr)
158 |
159 | segment_indices, skip_offsets = self._sample_train_indices(duration)
160 | images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices, skip_offsets)
161 |
162 | else:
163 | video_name, total_frame, target = self.clips[index]
164 | video_name = os.path.join(self.prefix, video_name)
165 |
166 | segment_indices, skip_offsets = self._sample_train_indices(total_frame)
167 | frame_id_list = self._get_frame_id_list(total_frame, segment_indices, skip_offsets)
168 | images = []
169 | for idx in frame_id_list:
170 | frame_fname = os.path.join(video_name, self.name_pattern.format(idx))
171 | img_bytes = self.client.get(frame_fname)
172 | img_np = np.frombuffer(img_bytes, np.uint8)
173 | img = cv2.imdecode(img_np, cv2.IMREAD_COLOR)
174 | cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
175 | images.append(Image.fromarray(img))
176 | if images is not None:
177 | break
178 | except Exception as e:
179 | print("Failed to load video from {} with error {}".format(
180 | video_name, e))
181 | index = random.randint(0, len(self.clips) - 1)
182 |
183 | if self.num_sample > 1:
184 | process_data_list = []
185 | mask_list = []
186 | for _ in range(self.num_sample):
187 | process_data, mask = self.transform((images, None))
188 | process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1)
189 | process_data_list.append(process_data)
190 | mask_list.append(mask)
191 | return process_data_list, mask_list
192 | else:
193 | process_data, mask = self.transform((images, None)) # T*C,H,W
194 | process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1) # T*C,H,W -> T,C,H,W -> C,T,H,W
195 | return (process_data, mask)
196 |
197 | def __len__(self):
198 | return len(self.clips)
199 |
200 | def _make_dataset(self, directory, setting):
201 | if not os.path.exists(setting):
202 | raise(RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting)))
203 | clips = []
204 |
205 | print(f'Load dataset using decord: {self.use_decord}')
206 | with open(setting) as split_f:
207 | data = split_f.readlines()
208 | for line in data:
209 | line_info = line.split(self.split)
210 | if len(line_info) < 2:
211 |                     raise(RuntimeError('Video input format is not correct, missing one or more elements. %s' % line))
212 | if self.use_decord:
213 | # line format: video_path, video_label
214 | clip_path = os.path.join(line_info[0])
215 | target = int(line_info[1])
216 | item = (clip_path, target)
217 | else:
218 | # line format: video_path, video_duration, video_label
219 | clip_path = os.path.join(line_info[0])
220 | total_frame = int(line_info[1])
221 | target = int(line_info[2])
222 | item = (clip_path, total_frame, target)
223 | clips.append(item)
224 | return clips
225 |
226 | def _sample_train_indices(self, num_frames):
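        # Spread num_segments window starts evenly across the video and jitter each
        # start randomly inside its window; short videos fall back to random sorted
        # or all-zero offsets. Returned offsets are 1-based (hence the +1 below).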
227 | average_duration = (num_frames - self.skip_length + 1) // self.num_segments
228 | if average_duration > 0:
229 | offsets = np.multiply(list(range(self.num_segments)),
230 | average_duration)
231 | offsets = offsets + np.random.randint(average_duration,
232 | size=self.num_segments)
233 | elif num_frames > max(self.num_segments, self.skip_length):
234 | offsets = np.sort(np.random.randint(
235 | num_frames - self.skip_length + 1,
236 | size=self.num_segments))
237 | else:
238 | offsets = np.zeros((self.num_segments,))
239 |
240 | if self.temporal_jitter:
241 | skip_offsets = np.random.randint(
242 | self.new_step, size=self.skip_length // self.new_step)
243 | else:
244 | skip_offsets = np.zeros(
245 | self.skip_length // self.new_step, dtype=int)
246 | return offsets + 1, skip_offsets
247 |
248 | def _get_frame_id_list(self, duration, indices, skip_offsets):
249 | frame_id_list = []
250 | for seg_ind in indices:
251 | offset = int(seg_ind)
252 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)):
253 | if offset + skip_offsets[i] <= duration:
254 | frame_id = offset + skip_offsets[i] - 1
255 | else:
256 | frame_id = offset - 1
257 | frame_id_list.append(frame_id)
258 | if offset + self.new_step < duration:
259 | offset += self.new_step
260 | return frame_id_list
261 |
262 | def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets):
263 | sampled_list = []
264 | frame_id_list = []
265 | for seg_ind in indices:
266 | offset = int(seg_ind)
267 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)):
268 | if offset + skip_offsets[i] <= duration:
269 | frame_id = offset + skip_offsets[i] - 1
270 | else:
271 | frame_id = offset - 1
272 | frame_id_list.append(frame_id)
273 | if offset + self.new_step < duration:
274 | offset += self.new_step
275 | try:
276 | video_data = video_reader.get_batch(frame_id_list).asnumpy()
277 | sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)]
278 |         except Exception as e:
279 |             raise RuntimeError('Error occurred in reading frames {} from video {} of duration {}.'.format(frame_id_list, directory, duration)) from e
280 | return sampled_list
281 |
--------------------------------------------------------------------------------
/single_modality/datasets/mae_multi.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import io
4 | import numpy as np
5 | import torch
6 | import decord
7 | from PIL import Image
8 | from decord import VideoReader, cpu
9 | import random
10 |
11 | try:
12 | from petrel_client.client import Client
13 | has_client = True
14 | except ImportError:
15 | has_client = False
16 |
17 |
18 | class VideoMAE_multi(torch.utils.data.Dataset):
19 | """Load your own video classification dataset.
20 | Parameters
21 | ----------
22 | root : str, required.
23 | Path to the root folder storing the dataset.
24 | setting : str, required.
25 |         A text file describing the dataset, with one line per video sample.
26 |         Each line has six items: (1) source; (2) video path; (3) total time; (4) start time; (5) end time and (6) video label.
27 | prefix : str, required.
28 | The prefix for loading data.
29 | split : str, required.
30 | The split character for metadata.
31 | train : bool, default True.
32 | Whether to load the training or validation set.
33 | test_mode : bool, default False.
34 | Whether to perform evaluation on the test set.
35 |         Usually a three-crop or ten-crop evaluation strategy is involved.
36 | name_pattern : str, default None.
37 | The naming pattern of the decoded video frames.
38 | For example, img_00012.jpg.
39 | is_color : bool, default True.
40 | Whether the loaded image is color or grayscale.
41 | modality : str, default 'rgb'.
42 |         Input modalities. Only RGB video frames are supported for now;
43 |         support for RGB difference and optical flow images may be added later.
44 | num_segments : int, default 1.
45 | Number of segments to evenly divide the video into clips.
46 | A useful technique to obtain global video-level information.
47 |         Limin Wang, et al., Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016.
48 | num_crop : int, default 1.
49 | Number of crops for each image. default is 1.
50 | Common choices are three crops and ten crops during evaluation.
51 | new_length : int, default 1.
52 | The length of input video clip. Default is a single image, but it can be multiple video frames.
53 |         For example, new_length=16 means we will extract a video clip of 16 consecutive frames.
54 | new_step : int, default 1.
55 | Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames.
56 | new_step=2 means we will extract a video clip of every other frame.
57 | temporal_jitter : bool, default False.
58 | Whether to temporally jitter if new_step > 1.
59 | video_loader : bool, default False.
60 | Whether to use video loader to load data.
61 | use_decord : bool, default True.
62 | Whether to use Decord video loader to load data. Otherwise load image.
63 | transform : function, default None.
64 | A function that takes data and label and transforms them.
65 | transform_ssv2 : function, default None.
66 | A function that takes data and label and transforms them.
67 | data_aug : str, default 'v1'.
68 |         Different types of data augmentation. Supports v1, v2, v3 and v4.
69 | lazy_init : bool, default False.
70 | If set to True, build a dataset instance without loading any dataset.
71 | """
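    # Illustrative annotation line (hypothetical path and label), matching the
    # six-field format parsed in `_make_dataset` below:
    #   ssv2 path/to/video.mp4 -1 -1 -1 42
    # total_time/start_time/end_time of -1 mean the whole video is used.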
72 | def __init__(self,
73 | root,
74 | setting,
75 | prefix='',
76 | split=' ',
77 | train=True,
78 | test_mode=False,
79 | name_pattern='img_%05d.jpg',
80 | is_color=True,
81 | modality='rgb',
82 | num_segments=1,
83 | num_crop=1,
84 | new_length=1,
85 | new_step=1,
86 | transform=None,
87 | transform_ssv2=None,
88 | temporal_jitter=False,
89 | video_loader=False,
90 | use_decord=True,
91 | lazy_init=False,
92 | num_sample=1,
93 | ):
94 |
95 | super(VideoMAE_multi, self).__init__()
96 | self.root = root
97 | self.setting = setting
98 | self.prefix = prefix
99 | self.split = split
100 | self.train = train
101 | self.test_mode = test_mode
102 | self.is_color = is_color
103 | self.modality = modality
104 | self.num_segments = num_segments
105 | self.num_crop = num_crop
106 | self.new_length = new_length
107 | self.new_step = new_step
108 | self.skip_length = self.new_length * self.new_step
109 | self.temporal_jitter = temporal_jitter
110 | self.name_pattern = name_pattern
111 | self.video_loader = video_loader
112 | self.use_decord = use_decord
113 | self.transform = transform
114 | self.transform_ssv2 = transform_ssv2
115 | self.lazy_init = lazy_init
116 | self.num_sample = num_sample
117 |
118 |         assert use_decord, "Only reading videos is supported for now!"
119 |
120 | # sparse sampling, num_segments != 1
121 | if self.num_segments != 1:
122 | print('Use sparse sampling, change frame and stride')
123 | self.new_length = self.num_segments
124 | self.skip_length = 1
125 |
126 | self.client = None
127 | if has_client:
128 | self.client = Client()
129 |
130 | if not self.lazy_init:
131 | self.clips = self._make_dataset(root, setting)
132 | if len(self.clips) == 0:
133 | raise(RuntimeError("Found 0 video clips in subfolders of: " + root + "\n"
134 | "Check your data directory (opt.data-dir)."))
135 |
136 | def __getitem__(self, index):
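        # Retry loop: `idx` counts consecutive decode failures for the current
        # clip; after 11 failed attempts a new random index is drawn so a single
        # corrupted video cannot stall the worker indefinitely.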
137 | idx = 0
138 | while True:
139 | try:
140 | images = None
141 | if self.use_decord:
142 | source, path, total_time, start_time, end_time, target = self.clips[index]
143 | if self.video_loader:
144 | video_name = os.path.join(self.prefix, path)
145 | # print(video_name)
146 | if "s3://" in video_name:
147 | video_bytes = self.client.get(video_name)
148 | # print(f'Got {video_name}')
149 | decord_vr = VideoReader(io.BytesIO(video_bytes),
150 | num_threads=1,
151 | ctx=cpu(0))
152 | else:
153 | decord_vr = decord.VideoReader(video_name, num_threads=1, ctx=cpu(0))
154 | duration = len(decord_vr)
155 | start_index = 0
156 |
157 |                         if total_time != -1 and start_time != -1 and end_time != -1:
158 | fps = duration / total_time
159 | duration = int(fps * (end_time - start_time))
160 | start_index = int(fps * start_time)
161 | segment_indices, skip_offsets = self._sample_train_indices(duration, start_index)
162 | images = self._video_TSN_decord_batch_loader(video_name, decord_vr, duration, segment_indices, skip_offsets)
163 | else:
164 | raise NotImplementedError
165 |
166 | if images is not None:
167 | break
168 | except Exception as e:
169 | print("Failed to load video from {} with error {}".format(
170 | video_name, e))
171 |
172 | idx += 1
173 | if idx >= 11:
174 | idx = 0
175 | index = random.randint(0, len(self.clips) - 1)
176 | print(f'retry with new index {index}')
177 | else:
178 | print(f'retry with video_name {video_name}')
179 |
180 | if self.num_sample > 1:
181 | process_data_list = []
182 | mask_list = []
183 | for _ in range(self.num_sample):
184 | if source == "ssv2":
185 | process_data, mask = self.transform_ssv2((images, None))
186 | else:
187 | process_data, mask = self.transform((images, None))
188 | process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1)
189 | process_data_list.append(process_data)
190 | mask_list.append(mask)
191 | return process_data_list, mask_list
192 | else:
193 | if source == "ssv2":
194 | process_data, mask = self.transform_ssv2((images, None)) # T*C,H,W
195 | else:
196 | process_data, mask = self.transform((images, None)) # T*C,H,W
197 | process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1) # T*C,H,W -> T,C,H,W -> C,T,H,W
198 | return (process_data, mask)
199 |
200 | def __len__(self):
201 | return len(self.clips)
202 |
203 | def _make_dataset(self, directory, setting):
204 | if not os.path.exists(setting):
205 | raise(RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting)))
206 | clips = []
207 |
208 | print(f'Load dataset using decord: {self.use_decord}')
209 | with open(setting) as split_f:
210 | data = split_f.readlines()
211 | for line in data:
212 | line_info = line.split(self.split)
213 | if len(line_info) < 2:
214 |                     raise RuntimeError('Video input format is not correct, missing one or more elements. %s' % line)
215 | if self.use_decord:
216 | # line format: source, path, total_time, start_time, end_time, target
217 | source = line_info[0]
218 | path = line_info[1]
219 | total_time = float(line_info[2])
220 | start_time = float(line_info[3])
221 | end_time = float(line_info[4])
222 | target = int(line_info[5])
223 | item = (source, path, total_time, start_time, end_time, target)
224 | else:
225 | raise NotImplementedError
226 |
227 | clips.append(item)
228 | return clips
229 |
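    # Example (illustrative): for a trimmed clip with duration=240 frames,
    # self.num_segments=8 and self.skip_length=1, average_duration = 30, so one
    # frame is drawn uniformly from each of the eight 30-frame windows; the
    # returned offsets are shifted by `start_index` to land inside the clip.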
230 | def _sample_train_indices(self, num_frames, start_index=0):
231 | average_duration = (num_frames - self.skip_length + 1) // self.num_segments
232 | if average_duration > 0:
233 | offsets = np.multiply(list(range(self.num_segments)),
234 | average_duration)
235 | offsets = offsets + np.random.randint(average_duration,
236 | size=self.num_segments)
237 | elif num_frames > max(self.num_segments, self.skip_length):
238 | offsets = np.sort(np.random.randint(
239 | num_frames - self.skip_length + 1,
240 | size=self.num_segments))
241 | else:
242 | offsets = np.zeros((self.num_segments,))
243 |
244 | if self.temporal_jitter:
245 | skip_offsets = np.random.randint(
246 | self.new_step, size=self.skip_length // self.new_step)
247 | else:
248 | skip_offsets = np.zeros(
249 | self.skip_length // self.new_step, dtype=int)
250 | return offsets + start_index, skip_offsets
251 |
252 | def _get_frame_id_list(self, duration, indices, skip_offsets):
253 | frame_id_list = []
254 | for seg_ind in indices:
255 | offset = int(seg_ind)
256 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)):
257 | if offset + skip_offsets[i] <= duration:
258 | frame_id = offset + skip_offsets[i] - 1
259 | else:
260 | frame_id = offset - 1
261 | frame_id_list.append(frame_id)
262 | if offset + self.new_step < duration:
263 | offset += self.new_step
264 | return frame_id_list
265 |
266 | def _video_TSN_decord_batch_loader(self, video_name, video_reader, duration, indices, skip_offsets):
267 | sampled_list = []
268 | frame_id_list = []
269 | for seg_ind in indices:
270 | offset = int(seg_ind)
271 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)):
272 | if offset + skip_offsets[i] <= duration:
273 | frame_id = offset + skip_offsets[i] - 1
274 | else:
275 | frame_id = offset - 1
276 | frame_id_list.append(frame_id)
277 | if offset + self.new_step < duration:
278 | offset += self.new_step
279 | try:
280 | video_data = video_reader.get_batch(frame_id_list).asnumpy()
281 | sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)]
282 |         except Exception as e:
283 |             raise RuntimeError('Error occurred in reading frames {} from video {} of duration {}.'.format(frame_id_list, video_name, duration)) from e
284 | return sampled_list
285 |
--------------------------------------------------------------------------------
/single_modality/datasets/mae_multi_ofa.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import io
4 | import numpy as np
5 | import torch
6 | import decord
7 | from PIL import Image
8 | from decord import VideoReader, cpu
9 | import random
10 |
11 | from torchvision import transforms
12 |
13 | from .transforms import *
14 |
15 | try:
16 | from petrel_client.client import Client
17 | has_client = True
18 | except ImportError:
19 | has_client = False
20 |
21 |
22 | class VideoMAE_multi_ofa(torch.utils.data.Dataset):
23 | """Load your own video classification dataset.
24 | Parameters
25 | ----------
26 | root : str, required.
27 | Path to the root folder storing the dataset.
28 | setting : str, required.
29 |         A text file describing the dataset, with one line per video sample.
30 |         Each line has six items: (1) source; (2) video path; (3) total time; (4) start time; (5) end time and (6) video label.
31 | prefix : str, required.
32 | The prefix for loading data.
33 | split : str, required.
34 | The split character for metadata.
35 | train : bool, default True.
36 | Whether to load the training or validation set.
37 | test_mode : bool, default False.
38 | Whether to perform evaluation on the test set.
39 |         Usually a three-crop or ten-crop evaluation strategy is involved.
40 | name_pattern : str, default None.
41 | The naming pattern of the decoded video frames.
42 | For example, img_00012.jpg.
43 | is_color : bool, default True.
44 | Whether the loaded image is color or grayscale.
45 | modality : str, default 'rgb'.
46 |         Input modalities. Only RGB video frames are supported for now;
47 |         support for RGB difference and optical flow images may be added later.
48 | num_segments : int, default 1.
49 | Number of segments to evenly divide the video into clips.
50 | A useful technique to obtain global video-level information.
51 |         Limin Wang, et al., Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016.
52 | num_crop : int, default 1.
53 | Number of crops for each image. default is 1.
54 | Common choices are three crops and ten crops during evaluation.
55 | new_length : int, default 1.
56 | The length of input video clip. Default is a single image, but it can be multiple video frames.
57 |         For example, new_length=16 means we will extract a video clip of 16 consecutive frames.
58 | new_step : int, default 1.
59 | Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames.
60 | new_step=2 means we will extract a video clip of every other frame.
61 | temporal_jitter : bool, default False.
62 | Whether to temporally jitter if new_step > 1.
63 | video_loader : bool, default False.
64 | Whether to use video loader to load data.
65 | use_decord : bool, default True.
66 | Whether to use Decord video loader to load data. Otherwise load image.
67 | transform : function, default None.
68 | A function that takes data and label and transforms them.
69 | transform_ssv2 : function, default None.
70 | A function that takes data and label and transforms them.
71 | data_aug : str, default 'v1'.
72 |         Different types of data augmentation. Supports v1, v2, v3 and v4.
73 | lazy_init : bool, default False.
74 | If set to True, build a dataset instance without loading any dataset.
75 | """
76 | def __init__(self,
77 | root,
78 | setting,
79 | prefix='',
80 | split=' ',
81 | train=True,
82 | test_mode=False,
83 | name_pattern='img_%05d.jpg',
84 | is_color=True,
85 | modality='rgb',
86 | num_segments=1,
87 | num_crop=1,
88 | new_length=1,
89 | new_step=1,
90 | transform=None,
91 | transform_ssv2=None,
92 | temporal_jitter=False,
93 | video_loader=False,
94 | use_decord=True,
95 | lazy_init=False,
96 | num_sample=1,
97 | ):
98 |
99 | super(VideoMAE_multi_ofa, self).__init__()
100 | self.root = root
101 | self.setting = setting
102 | self.prefix = prefix
103 | self.split = split
104 | self.train = train
105 | self.test_mode = test_mode
106 | self.is_color = is_color
107 | self.modality = modality
108 | self.num_segments = num_segments
109 | self.num_crop = num_crop
110 | self.new_length = new_length
111 | self.new_step = new_step
112 | self.skip_length = self.new_length * self.new_step
113 | self.temporal_jitter = temporal_jitter
114 | self.name_pattern = name_pattern
115 | self.video_loader = video_loader
116 | self.use_decord = use_decord
117 | self.transform = transform
118 |
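        # Pre-build one train-time transform per candidate crop size; with
        # range(140, 280, 28) the candidates are 140, 168, 196, 224 and 252.
        # __getitem__ later picks self.transforms[self.crop_size]
        # (or self.transform_ssv2s[self.crop_size] for SSV2 clips).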
119 | self.transforms = {}
120 | for crop_size in range(140, 280, 28):
121 | input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN
122 | input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD
123 | normalize = GroupNormalize(input_mean, input_std)
124 | train_augmentation = GroupMultiScaleCrop(crop_size, [1, .875, .75, .66])
125 | transform = transforms.Compose([
126 | train_augmentation,
127 | GroupRandomHorizontalFlip(flip=True),
128 | Stack(roll=False),
129 | ToTorchFormatTensor(div=True),
130 | normalize,
131 | ])
132 | self.transforms[crop_size] = transform
133 |
134 | self.transform_ssv2 = transform_ssv2
135 |
136 | self.transform_ssv2s = {}
137 | for crop_size in range(140, 280, 28):
138 | input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN
139 | input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD
140 | normalize = GroupNormalize(input_mean, input_std)
141 | train_augmentation = GroupMultiScaleCrop(crop_size, [1, .875, .75, .66])
142 | transform = transforms.Compose([
143 | train_augmentation,
144 | GroupRandomHorizontalFlip(flip=False),
145 | Stack(roll=False),
146 | ToTorchFormatTensor(div=True),
147 | normalize,
148 | ])
149 | self.transform_ssv2s[crop_size] = transform
150 |
151 | self.lazy_init = lazy_init
152 | self.num_sample = num_sample
153 |
154 |         assert use_decord, "Only reading videos is supported for now!"
155 |
156 | # sparse sampling, num_segments != 1
157 | if self.num_segments != 1:
158 | print('Use sparse sampling, change frame and stride')
159 | self.new_length = self.num_segments
160 | self.skip_length = 1
161 |
162 | self.client = None
163 | if has_client:
164 | self.client = Client()
165 |
166 | if not self.lazy_init:
167 | self.clips = self._make_dataset(root, setting)
168 | if len(self.clips) == 0:
169 | raise(RuntimeError("Found 0 video clips in subfolders of: " + root + "\n"
170 | "Check your data directory (opt.data-dir)."))
171 |
172 | def __getitem__(self, index):
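        # NOTE: in this OFA variant `index` is expected to be a
        # (clip_index, num_segments, crop_size) tuple, presumably produced by a
        # custom batch sampler (assumption based on the unpacking below); the
        # per-call attribute updates let each call use a different frame count
        # and crop size.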
173 | idx = 0
174 | index, self.num_segments, self.crop_size = index
175 | self.new_length = self.num_segments
176 |
177 | while True:
178 | try:
179 | images = None
180 | if self.use_decord:
181 | source, path, total_time, start_time, end_time, target = self.clips[index]
182 | if self.video_loader:
183 | video_name = os.path.join(self.prefix, path)
184 | # print(video_name)
185 | if "s3://" in video_name:
186 | video_bytes = self.client.get(video_name)
187 | # print(f'Got {video_name}')
188 | decord_vr = VideoReader(io.BytesIO(video_bytes),
189 | num_threads=1,
190 | ctx=cpu(0))
191 | else:
192 | decord_vr = decord.VideoReader(video_name, num_threads=1, ctx=cpu(0))
193 | duration = len(decord_vr)
194 | start_index = 0
195 |
196 |                         if total_time != -1 and start_time != -1 and end_time != -1:
197 | fps = duration / total_time
198 | duration = int(fps * (end_time - start_time))
199 | start_index = int(fps * start_time)
200 | segment_indices, skip_offsets = self._sample_train_indices(duration, start_index)
201 | images = self._video_TSN_decord_batch_loader(video_name, decord_vr, duration, segment_indices, skip_offsets)
202 | else:
203 | raise NotImplementedError
204 |
205 | if images is not None:
206 | break
207 | except Exception as e:
208 | print("Failed to load video from {} with error {}".format(
209 | video_name, e))
210 |
211 | idx += 1
212 | if idx >= 11:
213 | idx = 0
214 | index = random.randint(0, len(self.clips) - 1)
215 | print(f'retry with new index {index}')
216 | else:
217 | print(f'retry with video_name {video_name}')
218 |
219 | if self.num_sample > 1:
220 | raise NotImplementedError
221 | else:
222 | if source == "ssv2":
223 | process_data, _ = self.transform_ssv2s[self.crop_size]((images, None)) # T*C,H,W
224 | else:
225 | process_data, _ = self.transforms[self.crop_size]((images, None)) # T*C,H,W
226 | process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1) # T*C,H,W -> T,C,H,W -> C,T,H,W
227 | return process_data
228 |
229 | def __len__(self):
230 | return len(self.clips)
231 |
232 | def _make_dataset(self, directory, setting):
233 | if not os.path.exists(setting):
234 | raise(RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting)))
235 | clips = []
236 |
237 | print(f'Load dataset using decord: {self.use_decord}')
238 | with open(setting) as split_f:
239 | data = split_f.readlines()
240 | for line in data:
241 | line_info = line.split(self.split)
242 | if len(line_info) < 2:
243 |                     raise RuntimeError('Video input format is not correct, missing one or more elements. %s' % line)
244 | if self.use_decord:
245 | # line format: source, path, total_time, start_time, end_time, target
246 | source = line_info[0]
247 | path = line_info[1]
248 | total_time = float(line_info[2])
249 | start_time = float(line_info[3])
250 | end_time = float(line_info[4])
251 | target = int(line_info[5])
252 | item = (source, path, total_time, start_time, end_time, target)
253 | else:
254 | raise NotImplementedError
255 |
256 | clips.append(item)
257 | return clips
258 |
259 | def _sample_train_indices(self, num_frames, start_index=0):
260 | average_duration = (num_frames - self.skip_length + 1) // self.num_segments
261 | if average_duration > 0:
262 | offsets = np.multiply(list(range(self.num_segments)),
263 | average_duration)
264 | offsets = offsets + np.random.randint(average_duration,
265 | size=self.num_segments)
266 | elif num_frames > max(self.num_segments, self.skip_length):
267 | offsets = np.sort(np.random.randint(
268 | num_frames - self.skip_length + 1,
269 | size=self.num_segments))
270 | else:
271 | offsets = np.zeros((self.num_segments,))
272 |
273 | if self.temporal_jitter:
274 | skip_offsets = np.random.randint(
275 | self.new_step, size=self.skip_length // self.new_step)
276 | else:
277 | skip_offsets = np.zeros(
278 | self.skip_length // self.new_step, dtype=int)
279 | return offsets + start_index, skip_offsets
280 |
281 | def _get_frame_id_list(self, duration, indices, skip_offsets):
282 | frame_id_list = []
283 | for seg_ind in indices:
284 | offset = int(seg_ind)
285 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)):
286 | if offset + skip_offsets[i] <= duration:
287 | frame_id = offset + skip_offsets[i] - 1
288 | else:
289 | frame_id = offset - 1
290 | frame_id_list.append(frame_id)
291 | if offset + self.new_step < duration:
292 | offset += self.new_step
293 | return frame_id_list
294 |
295 | def _video_TSN_decord_batch_loader(self, video_name, video_reader, duration, indices, skip_offsets):
296 | sampled_list = []
297 | frame_id_list = []
298 | for seg_ind in indices:
299 | offset = int(seg_ind)
300 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)):
301 | if offset + skip_offsets[i] <= duration:
302 | frame_id = offset + skip_offsets[i] - 1
303 | else:
304 | frame_id = offset - 1
305 | frame_id_list.append(frame_id)
306 | if offset + self.new_step < duration:
307 | offset += self.new_step
308 | try:
309 | video_data = video_reader.get_batch(frame_id_list).asnumpy()
310 | sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)]
311 |         except Exception as e:
312 |             raise RuntimeError('Error occurred in reading frames {} from video {} of duration {}.'.format(frame_id_list, video_name, duration)) from e
313 | return sampled_list
314 |
--------------------------------------------------------------------------------
/single_modality/datasets/mae_ofa.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import io
4 | import numpy as np
5 | import torch
6 | import decord
7 | from PIL import Image
8 | from decord import VideoReader, cpu
9 | import random
10 |
11 | try:
12 | from petrel_client.client import Client
13 | has_client = True
14 | except ImportError:
15 | has_client = False
16 |
17 | from torchvision import transforms
18 |
19 | from .transforms import *
20 |
21 | import time
22 |
23 | class VideoMAE_ofa(torch.utils.data.Dataset):
24 | """Load your own video classification dataset.
25 | Parameters
26 | ----------
27 | root : str, required.
28 | Path to the root folder storing the dataset.
29 | setting : str, required.
30 |         A text file describing the dataset, with one line per video sample.
31 |         Each line has two items (video path and label) when use_decord is True, otherwise three items (video path, number of frames and label).
32 | prefix : str, required.
33 | The prefix for loading data.
34 | split : str, required.
35 | The split character for metadata.
36 | train : bool, default True.
37 | Whether to load the training or validation set.
38 | test_mode : bool, default False.
39 | Whether to perform evaluation on the test set.
40 |         Usually a three-crop or ten-crop evaluation strategy is involved.
41 | name_pattern : str, default None.
42 | The naming pattern of the decoded video frames.
43 | For example, img_00012.jpg.
44 | video_ext : str, default 'mp4'.
45 |         If video_loader is set to True, please specify the video format accordingly.
46 | is_color : bool, default True.
47 | Whether the loaded image is color or grayscale.
48 | modality : str, default 'rgb'.
49 |         Input modalities. Only RGB video frames are supported for now;
50 |         support for RGB difference and optical flow images may be added later.
51 | num_segments : int, default 1.
52 | Number of segments to evenly divide the video into clips.
53 | A useful technique to obtain global video-level information.
54 |         Limin Wang, et al., Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016.
55 | num_crop : int, default 1.
56 | Number of crops for each image. default is 1.
57 | Common choices are three crops and ten crops during evaluation.
58 | new_length : int, default 1.
59 | The length of input video clip. Default is a single image, but it can be multiple video frames.
60 |         For example, new_length=16 means we will extract a video clip of 16 consecutive frames.
61 | new_step : int, default 1.
62 | Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames.
63 | new_step=2 means we will extract a video clip of every other frame.
64 | temporal_jitter : bool, default False.
65 | Whether to temporally jitter if new_step > 1.
66 | video_loader : bool, default False.
67 | Whether to use video loader to load data.
68 | use_decord : bool, default True.
69 | Whether to use Decord video loader to load data. Otherwise load image.
70 | transform : function, default None.
71 | A function that takes data and label and transforms them.
72 | data_aug : str, default 'v1'.
73 |         Different types of data augmentation. Supports v1, v2, v3 and v4.
74 | lazy_init : bool, default False.
75 | If set to True, build a dataset instance without loading any dataset.
76 | """
77 | def __init__(self,
78 | root,
79 | setting,
80 | prefix='',
81 | split=' ',
82 | train=True,
83 | test_mode=False,
84 | name_pattern='img_%05d.jpg',
85 | video_ext='mp4',
86 | is_color=True,
87 | modality='rgb',
88 | num_segments=1,
89 | num_crop=1,
90 | new_length=1,
91 | new_step=1,
92 | transform=None,
93 | temporal_jitter=False,
94 | video_loader=False,
95 | use_decord=True,
96 | lazy_init=False,
97 | num_sample=1,
98 | ):
99 |
100 | super(VideoMAE_ofa, self).__init__()
101 |
102 | self.transforms = {}
103 | for crop_size in range(140, 280, 28):
104 | input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN
105 | input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD
106 | normalize = GroupNormalize(input_mean, input_std)
107 | train_augmentation = GroupMultiScaleCrop(crop_size, [1, .875, .75, .66])
108 | transform = transforms.Compose([
109 | train_augmentation,
110 | GroupRandomHorizontalFlip(flip=True),
111 | Stack(roll=False),
112 | ToTorchFormatTensor(div=True),
113 | normalize,
114 | ])
115 | self.transforms[crop_size] = transform
116 |
117 | self.root = root
118 | self.setting = setting
119 | self.prefix = prefix
120 | self.split = split
121 | self.train = train
122 | self.test_mode = test_mode
123 | self.is_color = is_color
124 | self.modality = modality
125 | self.num_segments = num_segments
126 | self.num_crop = num_crop
127 | self.new_length = new_length
128 | self.new_step = new_step
129 | self.skip_length = self.new_length * self.new_step
130 | self.temporal_jitter = temporal_jitter
131 | self.name_pattern = name_pattern
132 | self.video_loader = video_loader
133 | self.video_ext = video_ext
134 | self.use_decord = use_decord
135 | self.transform = transform
136 | self.lazy_init = lazy_init
137 | self.num_sample = num_sample
138 |
139 | # sparse sampling, num_segments != 1
140 | if self.num_segments != 1:
141 | print('Use sparse sampling, change frame and stride')
142 | self.new_length = self.num_segments
143 | self.skip_length = 1
144 |
145 | self.client = None
146 | if has_client:
147 | self.client = Client('~/petreloss.conf')
148 |
149 | if not self.lazy_init:
150 | self.clips = self._make_dataset(root, setting)
151 | if len(self.clips) == 0:
152 | raise(RuntimeError("Found 0 video clips in subfolders of: " + root + "\n"
153 | "Check your data directory (opt.data-dir)."))
154 |
155 | def __getitem__(self, index):
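        # As in VideoMAE_multi_ofa, `index` is expected to be a
        # (clip_index, num_segments, crop_size) tuple (assumption based on the
        # unpacking below), so frame count and crop size can vary per batch.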
156 | index, self.num_segments, self.crop_size = index
157 | self.new_length = self.num_segments
158 |
159 | while True:
160 | try:
161 | images = None
162 | if self.use_decord:
163 | directory, target = self.clips[index]
164 | if self.video_loader:
165 | if '.' in directory.split('/')[-1]:
166 | # data in the "setting" file already have extension, e.g., demo.mp4
167 | video_name = directory
168 | else:
169 | # data in the "setting" file do not have extension, e.g., demo
170 | # So we need to provide extension (i.e., .mp4) to complete the file name.
171 | video_name = '{}.{}'.format(directory, self.video_ext)
172 |
173 | video_name = os.path.join(self.prefix, video_name)
174 | if video_name.startswith('s3') or video_name.startswith('p2:s3'):
175 | video_bytes = self.client.get(video_name)
176 | decord_vr = VideoReader(io.BytesIO(video_bytes),
177 | num_threads=1,
178 | ctx=cpu(0))
179 | else:
180 | decord_vr = decord.VideoReader(video_name, num_threads=1, ctx=cpu(0))
181 | duration = len(decord_vr)
182 |
183 | segment_indices, skip_offsets = self._sample_train_indices(duration)
184 | images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices, skip_offsets)
185 |
186 | else:
187 | video_name, total_frame, target = self.clips[index]
188 | video_name = os.path.join(self.prefix, video_name)
189 |
190 | segment_indices, skip_offsets = self._sample_train_indices(total_frame)
191 | frame_id_list = self._get_frame_id_list(total_frame, segment_indices, skip_offsets)
192 | images = []
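                    # Frame-folder path: note that the default name_pattern
                    # ('img_%05d.jpg') is printf-style, so the str.format() call
                    # below only takes effect if a '{}'-style pattern is passed in
                    # (observation; behaviour left unchanged).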
193 | for idx in frame_id_list:
194 | frame_fname = os.path.join(video_name, self.name_pattern.format(idx))
195 | img_bytes = self.client.get(frame_fname)
196 | img_np = np.frombuffer(img_bytes, np.uint8)
197 | img = cv2.imdecode(img_np, cv2.IMREAD_COLOR)
198 | cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
199 | images.append(Image.fromarray(img))
200 | if images is not None:
201 | break
202 | except Exception as e:
203 | print("Failed to load video from {} with error {}".format(
204 | video_name, e))
205 | index = random.randint(0, len(self.clips) - 1)
206 |
207 | if self.num_sample > 1:
208 | raise NotImplementedError
209 | else:
210 | process_data, _ = self.transforms[self.crop_size]((images, None)) # T*C,H,W
211 | process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1) # T*C,H,W -> T,C,H,W -> C,T,H,W
212 | return process_data
213 |
214 | def __len__(self):
215 | return len(self.clips)
216 |
217 | def _make_dataset(self, directory, setting):
218 | if not os.path.exists(setting):
219 | raise(RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting)))
220 | clips = []
221 |
222 | print(f'Load dataset using decord: {self.use_decord}')
223 | with open(setting) as split_f:
224 | data = split_f.readlines()
225 | for line in data:
226 | line_info = line.split(self.split)
227 | if len(line_info) < 2:
228 |                     raise RuntimeError('Video input format is not correct, missing one or more elements. %s' % line)
229 | if self.use_decord:
230 | # line format: video_path, video_label
231 | clip_path = os.path.join(line_info[0])
232 | target = int(line_info[1])
233 | item = (clip_path, target)
234 | else:
235 | # line format: video_path, video_duration, video_label
236 | clip_path = os.path.join(line_info[0])
237 | total_frame = int(line_info[1])
238 | target = int(line_info[2])
239 | item = (clip_path, total_frame, target)
240 | clips.append(item)
241 | return clips
242 |
243 | def _sample_train_indices(self, num_frames):
244 | average_duration = (num_frames - self.skip_length + 1) // self.num_segments
245 | if average_duration > 0:
246 | offsets = np.multiply(list(range(self.num_segments)),
247 | average_duration)
248 | offsets = offsets + np.random.randint(average_duration,
249 | size=self.num_segments)
250 | elif num_frames > max(self.num_segments, self.skip_length):
251 | offsets = np.sort(np.random.randint(
252 | num_frames - self.skip_length + 1,
253 | size=self.num_segments))
254 | else:
255 | offsets = np.zeros((self.num_segments,))
256 |
257 | if self.temporal_jitter:
258 | skip_offsets = np.random.randint(
259 | self.new_step, size=self.skip_length // self.new_step)
260 | else:
261 | skip_offsets = np.zeros(
262 | self.skip_length // self.new_step, dtype=int)
263 | return offsets + 1, skip_offsets
264 |
265 | def _get_frame_id_list(self, duration, indices, skip_offsets):
266 | frame_id_list = []
267 | for seg_ind in indices:
268 | offset = int(seg_ind)
269 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)):
270 | if offset + skip_offsets[i] <= duration:
271 | frame_id = offset + skip_offsets[i] - 1
272 | else:
273 | frame_id = offset - 1
274 | frame_id_list.append(frame_id)
275 | if offset + self.new_step < duration:
276 | offset += self.new_step
277 | return frame_id_list
278 |
279 | def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets):
280 | sampled_list = []
281 | frame_id_list = []
282 | for seg_ind in indices:
283 | offset = int(seg_ind)
284 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)):
285 | if offset + skip_offsets[i] <= duration:
286 | frame_id = offset + skip_offsets[i] - 1
287 | else:
288 | frame_id = offset - 1
289 | frame_id_list.append(frame_id)
290 | if offset + self.new_step < duration:
291 | offset += self.new_step
292 | try:
293 | video_data = video_reader.get_batch(frame_id_list).asnumpy()
294 | sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)]
295 |         except Exception as e:
296 |             raise RuntimeError('Error occurred in reading frames {} from video {} of duration {}.'.format(frame_id_list, directory, duration)) from e
297 | return sampled_list
298 |
--------------------------------------------------------------------------------
/single_modality/datasets/masking_generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def topk(matrix, K, axis=1):
5 | if axis == 0:
6 | row_index = np.arange(matrix.shape[1 - axis])
7 | topk_index = np.argpartition(-matrix, K, axis=axis)[0:K, :]
8 | topk_data = matrix[topk_index, row_index]
9 | topk_index_sort = np.argsort(-topk_data,axis=axis)
10 | topk_data_sort = topk_data[topk_index_sort,row_index]
11 | topk_index_sort = topk_index[0:K,:][topk_index_sort,row_index]
12 | else:
13 | column_index = np.arange(matrix.shape[1 - axis])[:, None]
14 | topk_index = np.argpartition(-matrix, K, axis=axis)[:, 0:K]
15 | topk_data = matrix[column_index, topk_index]
16 | topk_index_sort = np.argsort(-topk_data, axis=axis)
17 | topk_data_sort = topk_data[column_index, topk_index_sort]
18 | topk_index_sort = topk_index[:,0:K][column_index,topk_index_sort]
19 | return (topk_data_sort, topk_index_sort)
20 |
21 |
22 | class TubeMaskingGenerator:
23 | def __init__(self, input_size, mask_ratio):
24 | self.frames, self.height, self.width = input_size
25 | self.num_patches_per_frame = self.height * self.width
26 | self.total_patches = self.frames * self.num_patches_per_frame
27 | self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
28 | self.total_masks = self.frames * self.num_masks_per_frame
29 |
30 | def __repr__(self):
31 |         repr_str = "Masks: total patches {}, mask patches {}".format(
32 | self.total_patches, self.total_masks
33 | )
34 | return repr_str
35 |
36 | def __call__(self):
37 | mask_per_frame = np.hstack([
38 | np.zeros(self.num_patches_per_frame - self.num_masks_per_frame),
39 | np.ones(self.num_masks_per_frame),
40 | ])
41 | np.random.shuffle(mask_per_frame)
42 | mask = np.tile(mask_per_frame, (self.frames, 1)).flatten()
43 | return mask
44 |
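# Example (illustrative): TubeMaskingGenerator((8, 14, 14), 0.75) masks
# int(0.75 * 196) = 147 of the 196 patches in every frame, reusing the same
# spatial pattern across all 8 frames ("tube" masking); the returned mask has
# 8 * 196 = 1568 entries with 1 = masked and 0 = visible.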
45 |
46 | class RandomMaskingGenerator:
47 | def __init__(self, input_size, mask_ratio):
48 | if not isinstance(input_size, tuple):
49 | input_size = (input_size, ) * 3
50 |
51 | self.frames, self.height, self.width = input_size
52 |
53 | self.num_patches = self.frames * self.height * self.width # 8x14x14
54 | self.num_mask = int(mask_ratio * self.num_patches)
55 |
56 | def __repr__(self):
57 |         repr_str = "Masks: total patches {}, mask patches {}".format(
58 | self.num_patches, self.num_mask)
59 | return repr_str
60 |
61 | def __call__(self):
62 | mask = np.hstack([
63 | np.zeros(self.num_patches - self.num_mask),
64 | np.ones(self.num_mask),
65 | ])
66 | np.random.shuffle(mask)
67 | return mask # [196*8]
68 |
69 |
70 | class TemporalConsistencyMaskingGenerator:
71 | def __init__(self, input_size, mask_ratio, mask_ratio_teacher):
72 | self.frames, self.height, self.width = input_size
73 | self.num_patches_per_frame = self.height * self.width
74 | self.total_patches = self.frames * self.num_patches_per_frame
75 | self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
76 | self.num_masks_per_frame_teacher = int(mask_ratio_teacher * self.num_patches_per_frame)
77 | self.total_masks = self.frames * self.num_masks_per_frame
78 |
79 | def __repr__(self):
80 |         repr_str = "Masks: total patches {}, mask patches {}".format(
81 | self.total_patches, self.total_masks
82 | )
83 | return repr_str
84 |
85 | def __call__(self):
86 | mask_per_frame = np.hstack([
87 | np.zeros(self.num_patches_per_frame - self.num_masks_per_frame),
88 | np.ones(self.num_masks_per_frame),
89 | ])
90 | if self.num_masks_per_frame_teacher != 0:
91 | mask_per_frame[-self.num_masks_per_frame_teacher:] = 2
92 | np.random.shuffle(mask_per_frame)
93 | mask_per_frame_student = (mask_per_frame > 0).astype(int)
94 | mask_per_frame_teacher = (mask_per_frame > 1).astype(int)
95 | mask_per_frame_diff = (mask_per_frame == 1).astype(int)
96 | mask_student = np.tile(mask_per_frame_student, (self.frames, 1)).flatten()
97 | mask_teacher = np.tile(mask_per_frame_teacher, (self.frames, 1)).flatten()
98 | mask_diff = np.tile(mask_per_frame_diff, (self.frames, 1)).flatten()
99 | mask_student = np.insert(mask_student, 0, 0) # For cls token
100 | mask_teacher = np.insert(mask_teacher, 0, 0)
101 | mask_diff = np.insert(mask_diff, 0, 0)
102 | return (mask_student, mask_teacher, mask_diff)
103 |
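# The three masks returned by TemporalConsistencyMaskingGenerator share one
# spatial pattern per frame: `mask_student` marks every token hidden from the
# student, `mask_teacher` marks the (smaller) subset also hidden from the
# teacher, and `mask_diff` marks tokens the student must reconstruct while the
# teacher still sees them; index 0 is the cls token, which is always visible.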
104 |
105 | class TemporalProgressiveMaskingGenerator:
106 | def __init__(self, input_size, mask_ratio):
107 | self.frames, self.height, self.width = input_size
108 | self.num_patches_per_frame = self.height * self.width
109 | self.total_patches = self.frames * self.num_patches_per_frame
110 | max_keep_patch = int((1 - mask_ratio) * self.num_patches_per_frame)
111 | min_keep_patch = int(0.05 * self.num_patches_per_frame)
112 | self.keep_patches_list = np.linspace(max_keep_patch, min_keep_patch, self.frames).astype(int)
113 | self.total_masks = self.total_patches - self.keep_patches_list.sum()
114 | def __repr__(self):
115 |         repr_str = "Masks: total patches {}, mask patches {}".format(
116 | self.total_patches, self.total_masks
117 | )
118 | return repr_str
119 |
120 | def __call__(self):
121 |
122 | rand = np.random.randn(1, self.num_patches_per_frame)
123 |         mask = np.zeros((self.frames, self.num_patches_per_frame), dtype=bool)  # np.bool was removed in NumPy >= 1.24
124 | for i in range(self.frames):
125 | top_k, _ = topk(rand, self.keep_patches_list[i])
126 | the_topk = top_k[0][-1]
127 | mask[i] = rand<=the_topk
128 | mask = mask.flatten().astype(int)
129 | return mask # [196*8]
130 |
131 |
132 | class TemporalCenteringProgressiveMaskingGenerator:
133 | def __init__(self, input_size, mask_ratio):
134 | self.num_frames, self.height, self.width = input_size
135 | self.num_patches_per_frame = self.height * self.width
136 | self.total_patches = self.num_frames * self.num_patches_per_frame
137 | min_mask_ratio = mask_ratio
138 |
139 | max_mask_ratio = 0.95
140 | max_keep_patch = int((1 - min_mask_ratio) * self.num_patches_per_frame)
141 | min_keep_patch = int((1 - max_mask_ratio) * self.num_patches_per_frame)
142 | patches_list = np.linspace(max_keep_patch, min_keep_patch, self.num_frames//2 ).astype(int).tolist()
143 | self.keep_patches_list = patches_list.copy()
144 | patches_list.reverse()
145 | self.keep_patches_list = patches_list + self.keep_patches_list
146 | self.total_masks = self.total_patches - sum(self.keep_patches_list)
147 | def __repr__(self):
148 |         repr_str = "Masks: total patches {}, mask patches {}".format(
149 | self.total_patches, self.total_masks
150 | )
151 | return repr_str
152 |
153 | def __call__(self):
154 |
155 | rand = np.random.randn(1, self.num_patches_per_frame)
156 |         mask = np.zeros((self.num_frames, self.num_patches_per_frame), dtype=bool)  # np.bool was removed in NumPy >= 1.24
157 | for i in range(self.num_frames):
158 | top_k, _ = topk(rand, self.keep_patches_list[i])
159 | the_topk = top_k[0][-1]
160 | mask[i] = rand<=the_topk
161 | mask = mask.flatten().astype(int)
162 | return mask
163 |
--------------------------------------------------------------------------------
/single_modality/datasets/mixup.py:
--------------------------------------------------------------------------------
1 | """ Mixup and Cutmix
2 |
3 | Papers:
4 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
5 |
6 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
7 |
8 | Code Reference:
9 | CutMix: https://github.com/clovaai/CutMix-PyTorch
10 |
11 | Hacked together by / Copyright 2019, Ross Wightman
12 | """
13 | import numpy as np
14 | import torch
15 |
16 |
17 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
18 | x = x.long().view(-1, 1)
19 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
20 |
21 |
22 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
23 | off_value = smoothing / num_classes
24 | on_value = 1. - smoothing + off_value
25 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
26 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
27 | return y1 * lam + y2 * (1. - lam)
28 |
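# Worked example for mixup_target (illustrative): num_classes=4, smoothing=0.1
# gives off_value=0.025 and on_value=0.925; with lam=0.7, a sample whose own
# label is 2 and whose flipped partner's label is 0 receives the soft target
# 0.7*[0.025, 0.025, 0.925, 0.025] + 0.3*[0.925, 0.025, 0.025, 0.025]
# = [0.295, 0.025, 0.655, 0.025].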
29 |
30 | def rand_bbox(img_shape, lam, margin=0., count=None):
31 | """ Standard CutMix bounding-box
32 | Generates a random square bbox based on lambda value. This impl includes
33 | support for enforcing a border margin as percent of bbox dimensions.
34 |
35 | Args:
36 | img_shape (tuple): Image shape as tuple
37 | lam (float): Cutmix lambda value
38 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
39 | count (int): Number of bbox to generate
40 | """
41 | ratio = np.sqrt(1 - lam)
42 | img_h, img_w = img_shape[-2:]
43 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
44 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
45 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
46 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
47 | yl = np.clip(cy - cut_h // 2, 0, img_h)
48 | yh = np.clip(cy + cut_h // 2, 0, img_h)
49 | xl = np.clip(cx - cut_w // 2, 0, img_w)
50 | xh = np.clip(cx + cut_w // 2, 0, img_w)
51 | return yl, yh, xl, xh
52 |
53 |
54 | def rand_bbox_minmax(img_shape, minmax, count=None):
55 | """ Min-Max CutMix bounding-box
56 | Inspired by Darknet cutmix impl, generates a random rectangular bbox
57 | based on min/max percent values applied to each dimension of the input image.
58 |
59 | Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.
60 |
61 | Args:
62 | img_shape (tuple): Image shape as tuple
63 | minmax (tuple or list): Min and max bbox ratios (as percent of image size)
64 | count (int): Number of bbox to generate
65 | """
66 | assert len(minmax) == 2
67 | img_h, img_w = img_shape[-2:]
68 | cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
69 | cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
70 | yl = np.random.randint(0, img_h - cut_h, size=count)
71 | xl = np.random.randint(0, img_w - cut_w, size=count)
72 | yu = yl + cut_h
73 | xu = xl + cut_w
74 | return yl, yu, xl, xu
75 |
76 |
77 | def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
78 | """ Generate bbox and apply lambda correction.
79 | """
80 | if ratio_minmax is not None:
81 | yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
82 | else:
83 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
84 | if correct_lam or ratio_minmax is not None:
85 | bbox_area = (yu - yl) * (xu - xl)
86 | lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
87 | return (yl, yu, xl, xu), lam
88 |
89 |
90 | class Mixup:
91 | """ Mixup/Cutmix that applies different params to each element or whole batch
92 |
93 | Args:
94 | mixup_alpha (float): mixup alpha value, mixup is active if > 0.
95 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
96 | cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
97 | prob (float): probability of applying mixup or cutmix per batch or element
98 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active
99 |         mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element))
100 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
101 | label_smoothing (float): apply label smoothing to the mixed target tensor
102 | num_classes (int): number of classes for target
103 | """
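    # Minimal usage sketch (hypothetical argument values, not taken from this
    # repo's training scripts):
    #   mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0,
    #                    label_smoothing=0.1, num_classes=400)
    #   samples, targets = mixup_fn(samples, targets)  # inside the train loop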
104 | def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
105 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
106 | self.mixup_alpha = mixup_alpha
107 | self.cutmix_alpha = cutmix_alpha
108 | self.cutmix_minmax = cutmix_minmax
109 | if self.cutmix_minmax is not None:
110 | assert len(self.cutmix_minmax) == 2
111 | # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
112 | self.cutmix_alpha = 1.0
113 | self.mix_prob = prob
114 | self.switch_prob = switch_prob
115 | self.label_smoothing = label_smoothing
116 | self.num_classes = num_classes
117 | self.mode = mode
118 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
119 |         self.mixup_enabled = True  # set to false to disable mixing (intended to be set by the train loop)
120 |
121 | def _params_per_elem(self, batch_size):
122 | lam = np.ones(batch_size, dtype=np.float32)
123 |         use_cutmix = np.zeros(batch_size, dtype=bool)
124 | if self.mixup_enabled:
125 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
126 | use_cutmix = np.random.rand(batch_size) < self.switch_prob
127 | lam_mix = np.where(
128 | use_cutmix,
129 | np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
130 | np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
131 | elif self.mixup_alpha > 0.:
132 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
133 | elif self.cutmix_alpha > 0.:
134 |                 use_cutmix = np.ones(batch_size, dtype=bool)
135 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
136 | else:
137 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
138 | lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
139 | return lam, use_cutmix
140 |
141 | def _params_per_batch(self):
142 | lam = 1.
143 | use_cutmix = False
144 | if self.mixup_enabled and np.random.rand() < self.mix_prob:
145 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
146 | use_cutmix = np.random.rand() < self.switch_prob
147 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
148 | np.random.beta(self.mixup_alpha, self.mixup_alpha)
149 | elif self.mixup_alpha > 0.:
150 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
151 | elif self.cutmix_alpha > 0.:
152 | use_cutmix = True
153 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
154 | else:
155 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
156 | lam = float(lam_mix)
157 | return lam, use_cutmix
158 |
159 | def _mix_elem(self, x):
160 | batch_size = len(x)
161 | lam_batch, use_cutmix = self._params_per_elem(batch_size)
162 | x_orig = x.clone() # need to keep an unmodified original for mixing source
163 | for i in range(batch_size):
164 | j = batch_size - i - 1
165 | lam = lam_batch[i]
166 | if lam != 1.:
167 | if use_cutmix[i]:
168 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
169 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
170 | x[i][..., yl:yh, xl:xh] = x_orig[j][..., yl:yh, xl:xh]
171 | lam_batch[i] = lam
172 | else:
173 | x[i] = x[i] * lam + x_orig[j] * (1 - lam)
174 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
175 |
176 | def _mix_pair(self, x):
177 | batch_size = len(x)
178 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
179 | x_orig = x.clone() # need to keep an unmodified original for mixing source
180 | for i in range(batch_size // 2):
181 | j = batch_size - i - 1
182 | lam = lam_batch[i]
183 | if lam != 1.:
184 | if use_cutmix[i]:
185 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
186 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
187 | x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
188 | x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
189 | lam_batch[i] = lam
190 | else:
191 | x[i] = x[i] * lam + x_orig[j] * (1 - lam)
192 | x[j] = x[j] * lam + x_orig[i] * (1 - lam)
193 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
194 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
195 |
196 | def _mix_batch(self, x):
197 | lam, use_cutmix = self._params_per_batch()
198 | if lam == 1.:
199 | return 1.
200 | if use_cutmix:
201 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
202 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
203 | x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh]
204 | else:
205 | x_flipped = x.flip(0).mul_(1. - lam)
206 | x.mul_(lam).add_(x_flipped)
207 | return lam
208 |
209 | def __call__(self, x, target):
210 | assert len(x) % 2 == 0, 'Batch size should be even when using this'
211 | if self.mode == 'elem':
212 | lam = self._mix_elem(x)
213 | elif self.mode == 'pair':
214 | lam = self._mix_pair(x)
215 | else:
216 | lam = self._mix_batch(x)
217 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
218 | return x, target
219 |
220 |
221 | class FastCollateMixup(Mixup):
222 | """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
223 |
224 | A Mixup impl that's performed while collating the batches.
225 | """
226 |
227 | def _mix_elem_collate(self, output, batch, half=False):
228 | batch_size = len(batch)
229 | num_elem = batch_size // 2 if half else batch_size
230 | assert len(output) == num_elem
231 | lam_batch, use_cutmix = self._params_per_elem(num_elem)
232 | for i in range(num_elem):
233 | j = batch_size - i - 1
234 | lam = lam_batch[i]
235 | mixed = batch[i][0]
236 | if lam != 1.:
237 | if use_cutmix[i]:
238 | if not half:
239 | mixed = mixed.copy()
240 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
241 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
242 | mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
243 | lam_batch[i] = lam
244 | else:
245 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
246 | np.rint(mixed, out=mixed)
247 | output[i] += torch.from_numpy(mixed.astype(np.uint8))
248 | if half:
249 | lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
250 | return torch.tensor(lam_batch).unsqueeze(1)
251 |
252 | def _mix_pair_collate(self, output, batch):
253 | batch_size = len(batch)
254 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
255 | for i in range(batch_size // 2):
256 | j = batch_size - i - 1
257 | lam = lam_batch[i]
258 | mixed_i = batch[i][0]
259 | mixed_j = batch[j][0]
260 | assert 0 <= lam <= 1.0
261 | if lam < 1.:
262 | if use_cutmix[i]:
263 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
264 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
265 | patch_i = mixed_i[:, yl:yh, xl:xh].copy()
266 | mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
267 | mixed_j[:, yl:yh, xl:xh] = patch_i
268 | lam_batch[i] = lam
269 | else:
270 | mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
271 | mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
272 | mixed_i = mixed_temp
273 | np.rint(mixed_j, out=mixed_j)
274 | np.rint(mixed_i, out=mixed_i)
275 | output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
276 | output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
277 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
278 | return torch.tensor(lam_batch).unsqueeze(1)
279 |
280 | def _mix_batch_collate(self, output, batch):
281 | batch_size = len(batch)
282 | lam, use_cutmix = self._params_per_batch()
283 | if use_cutmix:
284 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
285 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
286 | for i in range(batch_size):
287 | j = batch_size - i - 1
288 | mixed = batch[i][0]
289 | if lam != 1.:
290 | if use_cutmix:
291 | mixed = mixed.copy() # don't want to modify the original while iterating
292 | mixed[..., yl:yh, xl:xh] = batch[j][0][..., yl:yh, xl:xh]
293 | else:
294 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
295 | np.rint(mixed, out=mixed)
296 | output[i] += torch.from_numpy(mixed.astype(np.uint8))
297 | return lam
298 |
299 | def __call__(self, batch, _=None):
300 | batch_size = len(batch)
301 | assert batch_size % 2 == 0, 'Batch size should be even when using this'
302 | half = 'half' in self.mode
303 | if half:
304 | batch_size //= 2
305 | output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
306 | if self.mode == 'elem' or self.mode == 'half':
307 | lam = self._mix_elem_collate(output, batch, half=half)
308 | elif self.mode == 'pair':
309 | lam = self._mix_pair_collate(output, batch)
310 | else:
311 | lam = self._mix_batch_collate(output, batch)
312 | target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
313 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
314 | target = target[:batch_size]
315 | return output, target
316 |
317 |
--------------------------------------------------------------------------------
/single_modality/datasets/multiloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from utils import is_dist_avail_and_initialized
4 | import random
5 | import logging
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class MetaLoader(object):
11 | """ wraps multiple data loader """
12 | def __init__(self, loaders):
13 |         """Iterates over multiple dataloaders and ensures all processes
14 |         work on data from the same dataloader at each step. When a dataloader
15 |         is exhausted, the order is reshuffled and iteration continues indefinitely.
16 |
17 | loaders: Dict, {name: dataloader}
18 | """
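        # NOTE: `iter_order` holds positional indices, so `loaders` is indexed by
        # position below; despite the Dict annotation above, it is expected to
        # behave like a sequence of iterators (observation, behaviour unchanged).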
19 | self.loaders = loaders
20 | iter_order = [i for i in range(len(self.loaders))]
21 | random.shuffle(iter_order)
22 | iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8)
23 |
24 | # sync
25 | if is_dist_avail_and_initialized():
26 | # make sure all processes have the same order so that
27 | # each step they will have data from the same loader
28 | dist.broadcast(iter_order, src=0)
29 | self.iter_order = [e for e in iter_order.cpu()]
30 |
31 | def __len__(self):
32 | return len(self.iter_order)
33 |
34 | def __iter__(self):
35 | """ this iterator will run indefinitely """
36 | while True:
37 | try:
38 | for i in self.iter_order:
39 | _iter = self.loaders[i]
40 | batch = next(_iter)
41 | yield batch
42 | except StopIteration:
43 | iter_order = [i for i in range(len(self.loaders))]
44 | random.shuffle(iter_order)
45 | iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8)
46 | if is_dist_avail_and_initialized():
47 | # make sure all processes have the same order so that
48 | # each step they will have data from the same loader
49 | dist.broadcast(iter_order, src=0)
50 | self.iter_order = [e for e in iter_order.cpu()]
51 | continue
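52 |
53 | # Minimal usage sketch (assumes distributed training is initialised and CUDA is available,
54 | # since the iteration order is broadcast as a CUDA tensor; loaders are shown as a list of
55 | # iterators because __iter__ indexes by position and calls next() on each entry):
56 | #
57 | #   meta_loader = MetaLoader([iter(loader_a), iter(loader_b)])
58 | #   for step, batch in zip(range(num_steps), meta_loader):
59 | #       ...  # every rank draws `batch` from the same underlying loader at each step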
--------------------------------------------------------------------------------
/single_modality/datasets/random_erasing.py:
--------------------------------------------------------------------------------
1 | """
2 | This implementation is based on
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
4 | published under an Apache License 2.0.
5 | """
6 | import math
7 | import random
8 | import torch
9 |
10 |
11 | def _get_pixels(
12 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"
13 | ):
14 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
15 | # paths, flip the order so normal is run on CPU if this becomes a problem
16 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
17 | if per_pixel:
18 | return torch.empty(patch_size, dtype=dtype, device=device).normal_()
19 | elif rand_color:
20 | return torch.empty(
21 | (patch_size[0], 1, 1), dtype=dtype, device=device
22 | ).normal_()
23 | else:
24 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
25 |
26 |
27 | class RandomErasing:
28 | """Randomly selects a rectangle region in an image and erases its pixels.
29 | 'Random Erasing Data Augmentation' by Zhong et al.
30 | See https://arxiv.org/pdf/1708.04896.pdf
31 | This variant of RandomErasing is intended to be applied to either a batch
32 | or single image tensor after it has been normalized by dataset mean and std.
33 | Args:
34 | probability: Probability that the Random Erasing operation will be performed.
35 | min_area: Minimum percentage of erased area wrt input image area.
36 | max_area: Maximum percentage of erased area wrt input image area.
37 | min_aspect: Minimum aspect ratio of erased area.
38 | mode: pixel color mode, one of 'const', 'rand', or 'pixel'
39 | 'const' - erase block is constant color of 0 for all channels
40 | 'rand' - erase block is same per-channel random (normal) color
41 | 'pixel' - erase block is per-pixel random (normal) color
42 | max_count: maximum number of erasing blocks per image, area per box is scaled by count.
43 | per-image count is randomly chosen between 1 and this value.
44 | """
45 |
46 | def __init__(
47 | self,
48 | probability=0.5,
49 | min_area=0.02,
50 | max_area=1 / 3,
51 | min_aspect=0.3,
52 | max_aspect=None,
53 | mode="const",
54 | min_count=1,
55 | max_count=None,
56 | num_splits=0,
57 | device="cuda",
58 | cube=True,
59 | ):
60 | self.probability = probability
61 | self.min_area = min_area
62 | self.max_area = max_area
63 | max_aspect = max_aspect or 1 / min_aspect
64 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
65 | self.min_count = min_count
66 | self.max_count = max_count or min_count
67 | self.num_splits = num_splits
68 | mode = mode.lower()
69 | self.rand_color = False
70 | self.per_pixel = False
71 | self.cube = cube
72 | if mode == "rand":
73 | self.rand_color = True # per block random normal
74 | elif mode == "pixel":
75 | self.per_pixel = True # per pixel random normal
76 | else:
77 | assert not mode or mode == "const"
78 | self.device = device
79 |
80 | def _erase(self, img, chan, img_h, img_w, dtype):
81 | if random.random() > self.probability:
82 | return
83 | area = img_h * img_w
84 | count = (
85 | self.min_count
86 | if self.min_count == self.max_count
87 | else random.randint(self.min_count, self.max_count)
88 | )
89 | for _ in range(count):
90 | for _ in range(10):
91 | target_area = (
92 | random.uniform(self.min_area, self.max_area) * area / count
93 | )
94 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
95 | h = int(round(math.sqrt(target_area * aspect_ratio)))
96 | w = int(round(math.sqrt(target_area / aspect_ratio)))
97 | if w < img_w and h < img_h:
98 | top = random.randint(0, img_h - h)
99 | left = random.randint(0, img_w - w)
100 | img[:, top : top + h, left : left + w] = _get_pixels(
101 | self.per_pixel,
102 | self.rand_color,
103 | (chan, h, w),
104 | dtype=dtype,
105 | device=self.device,
106 | )
107 | break
108 |
109 | def _erase_cube(
110 | self,
111 | img,
112 | batch_start,
113 | batch_size,
114 | chan,
115 | img_h,
116 | img_w,
117 | dtype,
118 | ):
119 | if random.random() > self.probability:
120 | return
121 | area = img_h * img_w
122 | count = (
123 | self.min_count
124 | if self.min_count == self.max_count
125 | else random.randint(self.min_count, self.max_count)
126 | )
127 | for _ in range(count):
128 | for _ in range(100):
129 | target_area = (
130 | random.uniform(self.min_area, self.max_area) * area / count
131 | )
132 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
133 | h = int(round(math.sqrt(target_area * aspect_ratio)))
134 | w = int(round(math.sqrt(target_area / aspect_ratio)))
135 | if w < img_w and h < img_h:
136 | top = random.randint(0, img_h - h)
137 | left = random.randint(0, img_w - w)
138 | for i in range(batch_start, batch_size):
139 | img_instance = img[i]
140 | img_instance[
141 | :, top : top + h, left : left + w
142 | ] = _get_pixels(
143 | self.per_pixel,
144 | self.rand_color,
145 | (chan, h, w),
146 | dtype=dtype,
147 | device=self.device,
148 | )
149 | break
150 |
151 | def __call__(self, input):
152 | if len(input.size()) == 3:
153 | self._erase(input, *input.size(), input.dtype)
154 | else:
155 | batch_size, chan, img_h, img_w = input.size()
156 | # skip first slice of batch if num_splits is set (for clean portion of samples)
157 | batch_start = (
158 | batch_size // self.num_splits if self.num_splits > 1 else 0
159 | )
160 | if self.cube:
161 | self._erase_cube(
162 | input,
163 | batch_start,
164 | batch_size,
165 | chan,
166 | img_h,
167 | img_w,
168 | input.dtype,
169 | )
170 | else:
171 | for i in range(batch_start, batch_size):
172 | self._erase(input[i], chan, img_h, img_w, input.dtype)
173 | return input
174 |
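175 | # Minimal usage sketch (the probability/mode values below are illustrative assumptions;
176 | # device="cpu" is passed only so the example runs without a GPU):
177 | #
178 | #   clip = torch.randn(8, 3, 224, 224)   # an already-normalized (T, C, H, W) clip
179 | #   eraser = RandomErasing(probability=1.0, mode="pixel", device="cpu", cube=True)
180 | #   erased = eraser(clip)                # same shape; with cube=True the same region is
181 | #                                        # erased in every frame of the clip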
--------------------------------------------------------------------------------
/single_modality/datasets/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms.functional as F
3 | import warnings
4 | import random
5 | import numpy as np
6 | import torchvision
7 | from PIL import Image, ImageOps
8 | import numbers
9 |
10 |
11 | class GroupRandomCrop(object):
12 | def __init__(self, size):
13 | if isinstance(size, numbers.Number):
14 | self.size = (int(size), int(size))
15 | else:
16 | self.size = size
17 |
18 | def __call__(self, img_tuple):
19 | img_group, label = img_tuple
20 |
21 | w, h = img_group[0].size
22 | th, tw = self.size
23 |
24 | out_images = list()
25 |
26 | x1 = random.randint(0, w - tw)
27 | y1 = random.randint(0, h - th)
28 |
29 | for img in img_group:
30 | assert(img.size[0] == w and img.size[1] == h)
31 | if w == tw and h == th:
32 | out_images.append(img)
33 | else:
34 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
35 |
36 | return (out_images, label)
37 |
38 |
39 | class GroupCenterCrop(object):
40 | def __init__(self, size):
41 | self.worker = torchvision.transforms.CenterCrop(size)
42 |
43 | def __call__(self, img_tuple):
44 | img_group, label = img_tuple
45 | return ([self.worker(img) for img in img_group], label)
46 |
47 |
48 | class GroupRandomHorizontalFlip(object):
49 | def __init__(self, flip=False):
50 | self.flip = flip
51 |
52 | def __call__(self, img_tuple):
53 | v = random.random()
54 | if self.flip and v < 0.5:
55 | img_group, label = img_tuple
56 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
57 | return (ret, label)
58 | else:
59 | return img_tuple
60 |
61 |
62 | class GroupNormalize(object):
63 | def __init__(self, mean, std):
64 | self.mean = mean
65 | self.std = std
66 |
67 | def __call__(self, tensor_tuple):
68 | tensor, label = tensor_tuple
69 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean))
70 | rep_std = self.std * (tensor.size()[0]//len(self.std))
71 |
72 | # TODO: make efficient
73 | for t, m, s in zip(tensor, rep_mean, rep_std):
74 | t.sub_(m).div_(s)
75 |
76 | return (tensor,label)
77 |
78 |
79 | class GroupGrayScale(object):
80 | def __init__(self, size):
81 | self.worker = torchvision.transforms.Grayscale(size)
82 |
83 | def __call__(self, img_tuple):
84 | img_group, label = img_tuple
85 | return ([self.worker(img) for img in img_group], label)
86 |
87 |
88 | class GroupColorJitter(object):
89 | def __init__(self, size):
90 | self.worker = torchvision.transforms.ColorJitter(
91 | brightness=size, contrast=size, saturation=size
92 | )
93 |
94 | def __call__(self, img_tuple):
95 | img_group, label = img_tuple
96 | return ([self.worker(img) for img in img_group], label)
97 |
98 |
99 | class GroupScale(object):
100 | """ Rescales the input PIL.Image to the given 'size'.
101 | 'size' will be the size of the smaller edge.
102 | For example, if height > width, then image will be
103 | rescaled to (size * height / width, size)
104 | size: size of the smaller edge
105 | interpolation: Default: PIL.Image.BILINEAR
106 | """
107 |
108 | def __init__(self, size, interpolation=Image.BILINEAR):
109 | self.worker = torchvision.transforms.Resize(size, interpolation)
110 |
111 | def __call__(self, img_tuple):
112 | img_group, label = img_tuple
113 | return ([self.worker(img) for img in img_group], label)
114 |
115 |
116 | class GroupMultiScaleCrop(object):
117 |
118 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
119 | self.scales = scales if scales is not None else [1, .875, .75, .66]
120 | self.max_distort = max_distort
121 | self.fix_crop = fix_crop
122 | self.more_fix_crop = more_fix_crop
123 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
124 | self.interpolation = Image.BILINEAR
125 |
126 | def __call__(self, img_tuple):
127 | img_group, label = img_tuple
128 |
129 | im_size = img_group[0].size
130 |
131 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
132 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
133 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group]
134 | return (ret_img_group, label)
135 |
136 | def _sample_crop_size(self, im_size):
137 | image_w, image_h = im_size[0], im_size[1]
138 |
139 | # find a crop size
140 | base_size = min(image_w, image_h)
141 | crop_sizes = [int(base_size * x) for x in self.scales]
142 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
143 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
144 |
145 | pairs = []
146 | for i, h in enumerate(crop_h):
147 | for j, w in enumerate(crop_w):
148 | if abs(i - j) <= self.max_distort:
149 | pairs.append((w, h))
150 |
151 | crop_pair = random.choice(pairs)
152 | if not self.fix_crop:
153 | w_offset = random.randint(0, image_w - crop_pair[0])
154 | h_offset = random.randint(0, image_h - crop_pair[1])
155 | else:
156 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
157 |
158 | return crop_pair[0], crop_pair[1], w_offset, h_offset
159 |
160 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
161 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
162 | return random.choice(offsets)
163 |
164 | @staticmethod
165 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
166 | w_step = (image_w - crop_w) // 4
167 | h_step = (image_h - crop_h) // 4
168 |
169 | ret = list()
170 | ret.append((0, 0)) # upper left
171 | ret.append((4 * w_step, 0)) # upper right
172 | ret.append((0, 4 * h_step)) # lower left
173 | ret.append((4 * w_step, 4 * h_step)) # lower right
174 | ret.append((2 * w_step, 2 * h_step)) # center
175 |
176 | if more_fix_crop:
177 | ret.append((0, 2 * h_step)) # center left
178 | ret.append((4 * w_step, 2 * h_step)) # center right
179 | ret.append((2 * w_step, 4 * h_step)) # lower center
180 | ret.append((2 * w_step, 0 * h_step)) # upper center
181 |
182 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter
183 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter
184 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter
185 | ret.append((3 * w_step, 3 * h_step)) # lower right quarter
186 | return ret
187 |
188 |
189 | class Stack(object):
190 |
191 | def __init__(self, roll=False):
192 | self.roll = roll
193 |
194 | def __call__(self, img_tuple):
195 | img_group, label = img_tuple
196 |
197 | if img_group[0].mode == 'L':
198 | return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label)
199 | elif img_group[0].mode == 'RGB':
200 | if self.roll:
201 | return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label)
202 | else:
203 | return (np.concatenate(img_group, axis=2), label)
204 |
205 |
206 | class ToTorchFormatTensor(object):
207 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
208 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
209 | def __init__(self, div=True):
210 | self.div = div
211 |
212 | def __call__(self, pic_tuple):
213 | pic, label = pic_tuple
214 |
215 | if isinstance(pic, np.ndarray):
216 | # handle numpy array
217 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
218 | else:
219 | # handle PIL Image
220 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
221 | img = img.view(pic.size[1], pic.size[0], len(pic.mode))
222 | # put it from HWC to CHW format
223 | # yikes, this transpose takes 80% of the loading time/CPU
224 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
225 | return (img.float().div(255.) if self.div else img.float(), label)
226 |
227 |
228 | class IdentityTransform(object):
229 |
230 | def __call__(self, data):
231 | return data
232 |
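233 | # Minimal usage sketch of chaining the group transforms on an (image list, label) tuple;
234 | # the particular chain and the normalization statistics are illustrative assumptions:
235 | #
236 | #   frames = [Image.new('RGB', (320, 240)) for _ in range(4)]
237 | #   pipeline = torchvision.transforms.Compose([
238 | #       GroupMultiScaleCrop(224, scales=[1, .875, .75, .66]),
239 | #       GroupRandomHorizontalFlip(flip=True),
240 | #       Stack(roll=False),
241 | #       ToTorchFormatTensor(div=True),
242 | #       GroupNormalize(mean=[0.485, 0.456, 0.406] * 4, std=[0.229, 0.224, 0.225] * 4),
243 | #   ])
244 | #   clip, label = pipeline((frames, 0))
245 | #   # clip.shape == torch.Size([12, 224, 224]): 4 RGB frames stacked along the channel dim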
--------------------------------------------------------------------------------
/single_modality/datasets/volume_transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch
4 |
5 |
6 | def convert_img(img):
7 | """Converts (H, W, C) numpy.ndarray to (C, W, H) format
8 | """
9 | if len(img.shape) == 3:
10 | img = img.transpose(2, 0, 1)
11 | if len(img.shape) == 2:
12 | img = np.expand_dims(img, 0)
13 | return img
14 |
15 |
16 | class ClipToTensor(object):
17 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
18 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
19 | """
20 |
21 | def __init__(self, channel_nb=3, div_255=True, numpy=False):
22 | self.channel_nb = channel_nb
23 | self.div_255 = div_255
24 | self.numpy = numpy
25 |
26 | def __call__(self, clip):
27 | """
28 | Args: clip (list of numpy.ndarray): clip (list of images)
29 | to be converted to tensor.
30 | """
31 | # Retrieve shape
32 | if isinstance(clip[0], np.ndarray):
33 | h, w, ch = clip[0].shape
34 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format(
35 | ch)
36 | elif isinstance(clip[0], Image.Image):
37 | w, h = clip[0].size
38 | else:
39 | raise TypeError('Expected numpy.ndarray or PIL.Image\
40 | but got list of {0}'.format(type(clip[0])))
41 |
42 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
43 |
44 | # Convert
45 | for img_idx, img in enumerate(clip):
46 | if isinstance(img, np.ndarray):
47 | pass
48 | elif isinstance(img, Image.Image):
49 | img = np.array(img, copy=False)
50 | else:
51 | raise TypeError('Expected numpy.ndarray or PIL.Image\
52 | but got list of {0}'.format(type(clip[0])))
53 | img = convert_img(img)
54 | np_clip[:, img_idx, :, :] = img
55 | if self.numpy:
56 | if self.div_255:
57 | np_clip = np_clip / 255.0
58 | return np_clip
59 |
60 | else:
61 | tensor_clip = torch.from_numpy(np_clip)
62 |
63 | if not isinstance(tensor_clip, torch.FloatTensor):
64 | tensor_clip = tensor_clip.float()
65 | if self.div_255:
66 | tensor_clip = torch.div(tensor_clip, 255)
67 | return tensor_clip
68 |
69 |
70 | # Note this norms data to -1/1
71 | class ClipToTensor_K(object):
72 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
73 | to a torch.FloatTensor of shape (C x m x H x W) in the range [-1.0, 1.0]
74 | """
75 |
76 | def __init__(self, channel_nb=3, div_255=True, numpy=False):
77 | self.channel_nb = channel_nb
78 | self.div_255 = div_255
79 | self.numpy = numpy
80 |
81 | def __call__(self, clip):
82 | """
83 | Args: clip (list of numpy.ndarray): clip (list of images)
84 | to be converted to tensor.
85 | """
86 | # Retrieve shape
87 | if isinstance(clip[0], np.ndarray):
88 | h, w, ch = clip[0].shape
89 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format(
90 | ch)
91 | elif isinstance(clip[0], Image.Image):
92 | w, h = clip[0].size
93 | else:
94 | raise TypeError('Expected numpy.ndarray or PIL.Image\
95 | but got list of {0}'.format(type(clip[0])))
96 |
97 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
98 |
99 | # Convert
100 | for img_idx, img in enumerate(clip):
101 | if isinstance(img, np.ndarray):
102 | pass
103 | elif isinstance(img, Image.Image):
104 | img = np.array(img, copy=False)
105 | else:
106 | raise TypeError('Expected numpy.ndarray or PIL.Image\
107 | but got list of {0}'.format(type(clip[0])))
108 | img = convert_img(img)
109 | np_clip[:, img_idx, :, :] = img
110 | if self.numpy:
111 | if self.div_255:
112 | np_clip = (np_clip - 127.5) / 127.5
113 | return np_clip
114 |
115 | else:
116 | tensor_clip = torch.from_numpy(np_clip)
117 |
118 | if not isinstance(tensor_clip, torch.FloatTensor):
119 | tensor_clip = tensor_clip.float()
120 | if self.div_255:
121 | tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5)
122 | return tensor_clip
123 |
124 |
125 | class ToTensor(object):
126 | """Converts numpy array to tensor
127 | """
128 |
129 | def __call__(self, array):
130 | tensor = torch.from_numpy(array)
131 | return tensor
132 |
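133 | # Minimal usage sketch: turn a list of HWC uint8 frames into a (C, T, H, W) float clip.
134 | #
135 | #   frames = [np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8) for _ in range(8)]
136 | #   clip = ClipToTensor(channel_nb=3, div_255=True)(frames)      # shape (3, 8, 224, 224), values in [0, 1]
137 | #   clip_k = ClipToTensor_K(channel_nb=3, div_255=True)(frames)  # same shape, values roughly in [-1, 1]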
--------------------------------------------------------------------------------
/single_modality/engines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/engines/__init__.py
--------------------------------------------------------------------------------
/single_modality/engines/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/engines/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/engines/__pycache__/engine_for_flexible_tuning.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/engines/__pycache__/engine_for_flexible_tuning.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/engines/engine_for_flexible_tuning.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | import math
5 | import sys
6 | from typing import Iterable, Optional
7 | import torch
8 | from datasets.mixup import Mixup
9 | from timm.utils import accuracy, ModelEma
10 | import utils
11 | from scipy.special import softmax
12 | import random
13 | import torch.nn.functional as F
14 | import torch.distributed as dist
15 |
16 |
17 | def get_loss_scale_for_deepspeed(model):
18 | optimizer = model.optimizer
19 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale
20 |
21 |
22 | def train_one_epoch(
23 | model: torch.nn.Module, criterion: torch.nn.Module,
24 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
25 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
26 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
27 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
28 | num_training_steps_per_epoch=None, update_freq=None, args=None,
29 | bf16=False,
30 | ):
31 | model.train(True)
32 | metric_logger = utils.MetricLogger(delimiter=" ")
33 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
34 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
35 | header = 'Epoch: [{}]'.format(epoch)
36 | print_freq = 1
37 |
38 | if loss_scaler is None:
39 | model.zero_grad()
40 | model.micro_steps = 0
41 | else:
42 | optimizer.zero_grad()
43 |
44 | for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
45 |
46 | step = data_iter_step // update_freq
47 | if step >= num_training_steps_per_epoch:
48 | continue
49 | it = start_steps + step # global training iteration
50 | # Update LR & WD for the first acc
51 | if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
52 | for i, param_group in enumerate(optimizer.param_groups):
53 | if lr_schedule_values is not None:
54 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
55 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
56 | param_group["weight_decay"] = wd_schedule_values[it]
57 |
58 | samples = samples.to(device, non_blocking=True)
59 | targets = targets.to(device, non_blocking=True)
60 |
61 | if mixup_fn is not None:
62 | samples, targets = mixup_fn(samples, targets)
63 |
64 | if loss_scaler is None:
65 |
66 | B, C, T, H, W = samples.shape
67 |
68 | samples = samples.to(device, non_blocking=True)
69 |
70 | nums_token = T * (H // 14) ** 2
71 | samples = samples.bfloat16() if bf16 else samples.half()
72 |
73 | def compute_mask_ratio(num_current_token, num_input_token):
74 | if num_current_token <= num_input_token:
75 | mask_ratio = 0.0
76 | else:
77 | mask_ratio = 1 - num_input_token / num_current_token
78 | return mask_ratio
79 |
80 | def compute_smooth_l1_loss(src, tgt):
81 | tgt = tgt.detach()
82 | src = src / src.norm(dim=-1, keepdim=True).to(torch.float32)
83 | tgt = tgt / tgt.norm(dim=-1, keepdim=True).to(torch.float32)
84 | return torch.nn.SmoothL1Loss()(src, tgt.detach())
85 |
86 | mask_ratio = compute_mask_ratio(nums_token, args.largest_num_input_token)
87 | outputs1, x_final_1 = model(samples, masking_ratio=mask_ratio, output_head=0, return_cls=True)
88 |
89 | mask_ratio = compute_mask_ratio(nums_token, args.middle_num_input_token)
90 | outputs2, x_final_2 = model(samples, masking_ratio=mask_ratio, output_head=1, return_cls=True, return_projected=True, align_proj=0)
91 |
92 | mask_ratio = compute_mask_ratio(nums_token, args.least_num_input_token)
93 | outputs3, x_final_3 = model(samples, masking_ratio=mask_ratio, output_head=2, return_cls=True, return_projected=True, align_proj=1)
94 |
95 | pred = criterion(outputs1, targets) + criterion(outputs2, targets) + criterion(outputs3, targets)
96 | align_loss = compute_smooth_l1_loss(x_final_2, x_final_1) + compute_smooth_l1_loss(x_final_3, x_final_1)
97 |
98 | loss = pred + align_loss
99 |
100 | else:
101 | raise NotImplementedError
102 |
103 | loss_value = loss.item()
104 |
105 | loss_list = [torch.zeros_like(loss) for _ in range(dist.get_world_size())]
106 | dist.all_gather(loss_list, loss)
107 | loss_list = torch.tensor(loss_list)
108 | loss_list_isnan = torch.isnan(loss_list).any()
109 | loss_list_isinf = torch.isinf(loss_list).any()
110 |
111 | if loss_list_isnan or loss_list_isinf:
112 | print("Loss is {}, stopping training".format(loss_value), force=True)
113 | model.zero_grad()
114 | continue
115 |
116 | if loss_scaler is None:
117 | loss /= update_freq
118 | model.backward(loss)
119 | model.step()
120 |
121 | if (data_iter_step + 1) % update_freq == 0:
122 | # model.zero_grad()
123 | # Deepspeed will call step() & model.zero_grad() automatic
124 | if model_ema is not None:
125 | model_ema.update(model)
126 | grad_norm = None
127 | loss_scale_value = get_loss_scale_for_deepspeed(model)
128 | else:
129 | # this attribute is added by timm on one optimizer (adahessian)
130 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
131 | loss /= update_freq
132 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
133 | parameters=model.parameters(), create_graph=is_second_order,
134 | update_grad=(data_iter_step + 1) % update_freq == 0)
135 | if (data_iter_step + 1) % update_freq == 0:
136 | optimizer.zero_grad()
137 | if model_ema is not None:
138 | model_ema.update(model)
139 | loss_scale_value = loss_scaler.state_dict()["scale"]
140 |
141 | torch.cuda.synchronize()
142 |
143 | class_acc = None
144 |
145 | metric_logger.update(loss=loss_value)
146 | metric_logger.update(class_acc=class_acc)
147 | metric_logger.update(loss_scale=loss_scale_value)
148 | min_lr = 10.
149 | max_lr = 0.
150 | for group in optimizer.param_groups:
151 | min_lr = min(min_lr, group["lr"])
152 | max_lr = max(max_lr, group["lr"])
153 |
154 | metric_logger.update(lr=max_lr)
155 | metric_logger.update(min_lr=min_lr)
156 | weight_decay_value = None
157 | for group in optimizer.param_groups:
158 | if group["weight_decay"] > 0:
159 | weight_decay_value = group["weight_decay"]
160 | metric_logger.update(weight_decay=weight_decay_value)
161 | metric_logger.update(grad_norm=grad_norm)
162 |
163 | if log_writer is not None:
164 | log_writer.update(loss=loss_value, head="loss")
165 | log_writer.update(class_acc=class_acc, head="loss")
166 | log_writer.update(loss_scale=loss_scale_value, head="opt")
167 | log_writer.update(lr=max_lr, head="opt")
168 | log_writer.update(min_lr=min_lr, head="opt")
169 | log_writer.update(weight_decay=weight_decay_value, head="opt")
170 | log_writer.update(grad_norm=grad_norm, head="opt")
171 |
172 | log_writer.set_step()
173 |
174 | # gather the stats from all processes
175 | metric_logger.synchronize_between_processes()
176 | print("Averaged stats:", metric_logger)
177 |
178 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
179 |
180 |
181 | @torch.no_grad()
182 | def validation_one_epoch(data_loader, model, device, args, ds=False, bf16=False):
183 | criterion = torch.nn.CrossEntropyLoss()
184 |
185 | metric_logger = utils.MetricLogger(delimiter=" ")
186 | header = 'Val:'
187 |
188 | # switch to evaluation mode
189 | model.eval()
190 |
191 | current_tokens = args.eval_true_frame * ((args.eval_input_size // 14) ** 2)
192 |
193 | if current_tokens <= args.eval_max_token_num:
194 | mask_ratio = None
195 | nums_token = current_tokens
196 | else:
197 | mask_ratio = 1 - args.eval_max_token_num / current_tokens
198 | nums_token = args.eval_max_token_num
199 |
200 | print('Nums Token: ', nums_token)
201 |
202 | num_merging_to = args.eval_num_merging_to
203 |
204 | print("Num Merging Ratio: ", num_merging_to)
205 |
206 | for batch in metric_logger.log_every(data_loader, 10, header):
207 | videos = batch[0]
208 | target = batch[1]
209 | videos = videos.to(device, non_blocking=True)
210 | target = target.to(device, non_blocking=True)
211 |
212 | if ds:
213 | videos = videos.bfloat16() if bf16 else videos.half()
214 | output = model(videos, num_merging_to=num_merging_to, masking_ratio=mask_ratio, output_head=args.output_head)
215 | loss = criterion(output, target)
216 | else:
217 | with torch.cuda.amp.autocast():
218 | output = model(videos, num_merging_to=num_merging_to, masking_ratio=mask_ratio, output_head=args.output_head)
219 | loss = criterion(output, target)
220 |
221 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
222 |
223 | batch_size = videos.shape[0]
224 | metric_logger.update(loss=loss.item())
225 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
226 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
227 | # gather the stats from all processes
228 | metric_logger.synchronize_between_processes()
229 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
230 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
231 |
232 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
233 |
234 |
235 | @torch.no_grad()
236 | def final_test(data_loader, model, device, file, args, ds=False, bf16=False, test_time=0):
237 | criterion = torch.nn.CrossEntropyLoss()
238 |
239 | metric_logger = utils.MetricLogger(delimiter=" ")
240 | header = 'Test:'
241 |
242 | # switch to evaluation mode
243 | model.eval()
244 | final_result = []
245 |
246 | # args.eval_masking_ratio = random.uniform(args.eval_masking_ratio / 3, args.eval_masking_ratio)
247 | current_tokens = args.eval_true_frame * ((args.eval_input_size // 14) ** 2)
248 |
249 | if current_tokens <= args.eval_max_token_num:
250 | mask_ratio = None
251 | nums_token = current_tokens
252 | else:
253 | mask_ratio = 1 - args.eval_max_token_num / current_tokens
254 | nums_token = args.eval_max_token_num
255 |
256 | for batch in metric_logger.log_every(data_loader, 10, header):
257 | videos = batch[0]
258 | target = batch[1]
259 | ids = batch[2]
260 | chunk_nb = batch[3]
261 | split_nb = batch[4]
262 | videos = videos.to(device, non_blocking=True)
263 | target = target.to(device, non_blocking=True)
264 |
265 | # compute output
266 | if ds:
267 | videos = videos.bfloat16() if bf16 else videos.half()
268 | output = model(videos, masking_ratio=mask_ratio, output_head=args.output_head)
269 | loss = criterion(output, target)
270 | else:
271 | with torch.cuda.amp.autocast():
272 | output = model(videos, masking_ratio=mask_ratio, output_head=args.output_head)
273 | loss = criterion(output, target)
274 |
275 | for i in range(output.size(0)):
276 | string = "{} {} {} {} {}\n".format(ids[i], \
277 | str(output.data[i].float().cpu().numpy().tolist()), \
278 | str(int(target[i].cpu().numpy())), \
279 | str(int(chunk_nb[i].cpu().numpy())), \
280 | str(int(split_nb[i].cpu().numpy())))
281 | final_result.append(string)
282 |
283 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
284 |
285 | batch_size = videos.shape[0]
286 | metric_logger.update(loss=loss.item())
287 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
288 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
289 |
290 | if not os.path.exists(file):
291 | os.mknod(file)
292 | with open(file, 'w') as f:
293 | f.write("{}, {}\n".format(acc1, acc5))
294 | for line in final_result:
295 | f.write(line)
296 | # gather the stats from all processes
297 | metric_logger.synchronize_between_processes()
298 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
299 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
300 |
301 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
302 |
303 |
304 | def merge(eval_path, num_tasks):
305 | dict_feats = {}
306 | dict_label = {}
307 | dict_pos = {}
308 | print("Reading individual output files")
309 |
310 | for x in range(num_tasks):
311 | file = os.path.join(eval_path, str(x) + '.txt')
312 | lines = open(file, 'r').readlines()[1:]
313 | for line in lines:
314 | line = line.strip()
315 | name = line.split(' ')[0]
316 | label = line.split(']')[-1].split(' ')[1]
317 | chunk_nb = line.split(']')[-1].split(' ')[2]
318 | split_nb = line.split(']')[-1].split(' ')[3]
319 | data = np.fromstring(' '.join(line.split(' ')[1:]).split('[')[1].split(']')[0], dtype=np.float32, sep=',')
320 | data = softmax(data)
321 | if not name in dict_feats:
322 | dict_feats[name] = []
323 | dict_label[name] = 0
324 | dict_pos[name] = []
325 | if chunk_nb + split_nb in dict_pos[name]:
326 | continue
327 | dict_feats[name].append(data)
328 | dict_pos[name].append(chunk_nb + split_nb)
329 | dict_label[name] = label
330 | print("Computing final results")
331 |
332 | input_lst = []
333 | print(len(dict_feats))
334 | for i, item in enumerate(dict_feats):
335 | input_lst.append([i, item, dict_feats[item], dict_label[item]])
336 | from multiprocessing import Pool
337 | p = Pool(64)
338 | ans = p.map(compute_video, input_lst)
339 | top1 = [x[1] for x in ans]
340 | top5 = [x[2] for x in ans]
341 | pred = [x[0] for x in ans]
342 | label = [x[3] for x in ans]
343 | final_top1, final_top5 = np.mean(top1), np.mean(top5)
344 | return final_top1 * 100, final_top5 * 100
345 |
346 | def compute_video(lst):
347 | i, video_id, data, label = lst
348 | feat = [x for x in data]
349 | feat = np.mean(feat, axis=0)
350 | pred = np.argmax(feat)
351 | top1 = (int(pred) == int(label)) * 1.0
352 | top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0
353 | return [pred, top1, top5, int(label)]
354 |
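355 | # Worked example of the token budgeting in train_one_epoch above, using the settings from
356 | # exp/base/eval/k400_eval.sh (32 frames at 252x252; token budgets 3072 / 2048 / 1024):
357 | #
358 | #   nums_token = T * (H // 14) ** 2 = 32 * (252 // 14) ** 2 = 32 * 324 = 10368
359 | #   largest budget: compute_mask_ratio(10368, 3072) = 1 - 3072 / 10368 ~= 0.704
360 | #   middle budget:  compute_mask_ratio(10368, 2048) = 1 - 2048 / 10368 ~= 0.802
361 | #   least budget:   compute_mask_ratio(10368, 1024) = 1 - 1024 / 10368 ~= 0.901
362 | #
363 | # The two smaller-budget passes are additionally aligned to the largest-budget pass through
364 | # the normalized SmoothL1 loss in compute_smooth_l1_loss.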
--------------------------------------------------------------------------------
/single_modality/engines/engine_for_flexible_tuning_single.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | import math
5 | import sys
6 | from typing import Iterable, Optional
7 | import torch
8 | from datasets.mixup import Mixup
9 | from timm.utils import accuracy, ModelEma
10 | import utils
11 | from scipy.special import softmax
12 | import random
13 | import torch.nn.functional as F
14 | import torch.distributed as dist
15 |
16 |
17 | def get_loss_scale_for_deepspeed(model):
18 | optimizer = model.optimizer
19 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale
20 |
21 |
22 | def train_one_epoch(
23 | model: torch.nn.Module, criterion: torch.nn.Module,
24 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
25 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
26 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
27 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
28 | num_training_steps_per_epoch=None, update_freq=None, args=None,
29 | bf16=False,
30 | ):
31 | model.train(True)
32 | metric_logger = utils.MetricLogger(delimiter=" ")
33 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
34 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
35 | header = 'Epoch: [{}]'.format(epoch)
36 | print_freq = 1
37 |
38 | if loss_scaler is None:
39 | model.zero_grad()
40 | model.micro_steps = 0
41 | else:
42 | optimizer.zero_grad()
43 |
44 | for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
45 |
46 | step = data_iter_step // update_freq
47 | if step >= num_training_steps_per_epoch:
48 | continue
49 | it = start_steps + step # global training iteration
50 | # Update LR & WD for the first acc
51 | if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
52 | for i, param_group in enumerate(optimizer.param_groups):
53 | if lr_schedule_values is not None:
54 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
55 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
56 | param_group["weight_decay"] = wd_schedule_values[it]
57 |
58 | samples = samples.to(device, non_blocking=True)
59 | targets = targets.to(device, non_blocking=True)
60 |
61 | if mixup_fn is not None:
62 | samples, targets = mixup_fn(samples, targets)
63 |
64 | if loss_scaler is None:
65 |
66 | B, C, T, H, W = samples.shape
67 |
68 | samples = samples.to(device, non_blocking=True)
69 |
70 | nums_token = T * (H // 14) ** 2
71 | samples = samples.bfloat16() if bf16 else samples.half()
72 |
73 | def compute_mask_ratio(num_current_token, num_input_token):
74 | if num_current_token <= num_input_token:
75 | mask_ratio = 0.0
76 | else:
77 | mask_ratio = 1 - num_input_token / num_current_token
78 | return mask_ratio
79 |
80 | mask_ratio = compute_mask_ratio(nums_token, args.eval_max_token_num)
81 | outputs, _ = model(samples, masking_ratio=mask_ratio, output_head=args.output_head, return_cls=True)
82 |
83 | pred = criterion(outputs, targets)
84 |
85 | loss = pred
86 |
87 | else:
88 | raise NotImplementedError
89 |
90 | loss_value = loss.item()
91 |
92 | loss_list = [torch.zeros_like(loss) for _ in range(dist.get_world_size())]
93 | dist.all_gather(loss_list, loss)
94 | loss_list = torch.tensor(loss_list)
95 | loss_list_isnan = torch.isnan(loss_list).any()
96 | loss_list_isinf = torch.isinf(loss_list).any()
97 |
98 | if loss_list_isnan or loss_list_isinf:
99 | print("Loss is {}, stopping training".format(loss_value), force=True)
100 | model.zero_grad()
101 | continue
102 |
103 | if loss_scaler is None:
104 | loss /= update_freq
105 | model.backward(loss)
106 | model.step()
107 |
108 | if (data_iter_step + 1) % update_freq == 0:
109 | # model.zero_grad()
110 | # Deepspeed will call step() & model.zero_grad() automatic
111 | if model_ema is not None:
112 | model_ema.update(model)
113 | grad_norm = None
114 | loss_scale_value = get_loss_scale_for_deepspeed(model)
115 | else:
116 | # this attribute is added by timm on one optimizer (adahessian)
117 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
118 | loss /= update_freq
119 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
120 | parameters=model.parameters(), create_graph=is_second_order,
121 | update_grad=(data_iter_step + 1) % update_freq == 0)
122 | if (data_iter_step + 1) % update_freq == 0:
123 | optimizer.zero_grad()
124 | if model_ema is not None:
125 | model_ema.update(model)
126 | loss_scale_value = loss_scaler.state_dict()["scale"]
127 |
128 | torch.cuda.synchronize()
129 |
130 | class_acc = None
131 |
132 | metric_logger.update(loss=loss_value)
133 | metric_logger.update(class_acc=class_acc)
134 | metric_logger.update(loss_scale=loss_scale_value)
135 | min_lr = 10.
136 | max_lr = 0.
137 | for group in optimizer.param_groups:
138 | min_lr = min(min_lr, group["lr"])
139 | max_lr = max(max_lr, group["lr"])
140 |
141 | metric_logger.update(lr=max_lr)
142 | metric_logger.update(min_lr=min_lr)
143 | weight_decay_value = None
144 | for group in optimizer.param_groups:
145 | if group["weight_decay"] > 0:
146 | weight_decay_value = group["weight_decay"]
147 | metric_logger.update(weight_decay=weight_decay_value)
148 | metric_logger.update(grad_norm=grad_norm)
149 |
150 | if log_writer is not None:
151 | log_writer.update(loss=loss_value, head="loss")
152 | log_writer.update(class_acc=class_acc, head="loss")
153 | log_writer.update(loss_scale=loss_scale_value, head="opt")
154 | log_writer.update(lr=max_lr, head="opt")
155 | log_writer.update(min_lr=min_lr, head="opt")
156 | log_writer.update(weight_decay=weight_decay_value, head="opt")
157 | log_writer.update(grad_norm=grad_norm, head="opt")
158 |
159 | log_writer.set_step()
160 |
161 | # gather the stats from all processes
162 | metric_logger.synchronize_between_processes()
163 | print("Averaged stats:", metric_logger)
164 |
165 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
166 |
167 |
168 | @torch.no_grad()
169 | def validation_one_epoch(data_loader, model, device, args, ds=False, bf16=False):
170 | criterion = torch.nn.CrossEntropyLoss()
171 |
172 | metric_logger = utils.MetricLogger(delimiter=" ")
173 | header = 'Val:'
174 |
175 | # switch to evaluation mode
176 | model.eval()
177 |
178 | current_tokens = args.eval_true_frame * ((args.eval_input_size // 14) ** 2)
179 |
180 | if current_tokens <= args.eval_max_token_num:
181 | mask_ratio = None
182 | nums_token = current_tokens
183 | else:
184 | mask_ratio = 1 - args.eval_max_token_num / current_tokens
185 | nums_token = args.eval_max_token_num
186 |
187 | print('Nums Token: ', nums_token)
188 |
189 | num_merging_to = args.eval_num_merging_to
190 |
191 | print("Num Merging Ratio: ", num_merging_to)
192 |
193 | for batch in metric_logger.log_every(data_loader, 10, header):
194 | videos = batch[0]
195 | target = batch[1]
196 | videos = videos.to(device, non_blocking=True)
197 | target = target.to(device, non_blocking=True)
198 |
199 | if ds:
200 | videos = videos.bfloat16() if bf16 else videos.half()
201 | output = model(videos, num_merging_to=num_merging_to, masking_ratio=mask_ratio, output_head=args.output_head)
202 | loss = criterion(output, target)
203 | else:
204 | with torch.cuda.amp.autocast():
205 | output = model(videos, num_merging_to=num_merging_to, masking_ratio=mask_ratio, output_head=args.output_head)
206 | loss = criterion(output, target)
207 |
208 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
209 |
210 | batch_size = videos.shape[0]
211 | metric_logger.update(loss=loss.item())
212 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
213 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
214 | # gather the stats from all processes
215 | metric_logger.synchronize_between_processes()
216 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
217 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
218 |
219 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
220 |
221 |
222 | @torch.no_grad()
223 | def final_test(data_loader, model, device, file, args, ds=False, bf16=False, test_time=0):
224 | criterion = torch.nn.CrossEntropyLoss()
225 |
226 | metric_logger = utils.MetricLogger(delimiter=" ")
227 | header = 'Test:'
228 |
229 | # switch to evaluation mode
230 | model.eval()
231 | final_result = []
232 |
233 | # args.eval_masking_ratio = random.uniform(args.eval_masking_ratio / 3, args.eval_masking_ratio)
234 | current_tokens = args.eval_true_frame * ((args.eval_input_size // 14) ** 2)
235 |
236 | if current_tokens <= args.eval_max_token_num:
237 | mask_ratio = None
238 | nums_token = current_tokens
239 | else:
240 | mask_ratio = 1 - args.eval_max_token_num / current_tokens
241 | nums_token = args.eval_max_token_num
242 |
243 | for batch in metric_logger.log_every(data_loader, 10, header):
244 | videos = batch[0]
245 | target = batch[1]
246 | ids = batch[2]
247 | chunk_nb = batch[3]
248 | split_nb = batch[4]
249 | videos = videos.to(device, non_blocking=True)
250 | target = target.to(device, non_blocking=True)
251 |
252 | # compute output
253 | if ds:
254 | videos = videos.bfloat16() if bf16 else videos.half()
255 | output = model(videos, masking_ratio=mask_ratio, output_head=args.output_head)
256 | loss = criterion(output, target)
257 | else:
258 | with torch.cuda.amp.autocast():
259 | output = model(videos, masking_ratio=mask_ratio, output_head=args.output_head)
260 | loss = criterion(output, target)
261 |
262 | for i in range(output.size(0)):
263 | string = "{} {} {} {} {}\n".format(ids[i], \
264 | str(output.data[i].float().cpu().numpy().tolist()), \
265 | str(int(target[i].cpu().numpy())), \
266 | str(int(chunk_nb[i].cpu().numpy())), \
267 | str(int(split_nb[i].cpu().numpy())))
268 | final_result.append(string)
269 |
270 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
271 |
272 | batch_size = videos.shape[0]
273 | metric_logger.update(loss=loss.item())
274 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
275 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
276 |
277 | if not os.path.exists(file):
278 | os.mknod(file)
279 | with open(file, 'w') as f:
280 | f.write("{}, {}\n".format(acc1, acc5))
281 | for line in final_result:
282 | f.write(line)
283 | # gather the stats from all processes
284 | metric_logger.synchronize_between_processes()
285 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
286 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
287 |
288 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
289 |
290 |
291 | def merge(eval_path, num_tasks):
292 | dict_feats = {}
293 | dict_label = {}
294 | dict_pos = {}
295 | print("Reading individual output files")
296 |
297 | for x in range(num_tasks):
298 | file = os.path.join(eval_path, str(x) + '.txt')
299 | lines = open(file, 'r').readlines()[1:]
300 | for line in lines:
301 | line = line.strip()
302 | name = line.split(' ')[0]
303 | label = line.split(']')[-1].split(' ')[1]
304 | chunk_nb = line.split(']')[-1].split(' ')[2]
305 | split_nb = line.split(']')[-1].split(' ')[3]
306 | data = np.fromstring(' '.join(line.split(' ')[1:]).split('[')[1].split(']')[0], dtype=np.float32, sep=',')
307 | data = softmax(data)
308 | if not name in dict_feats:
309 | dict_feats[name] = []
310 | dict_label[name] = 0
311 | dict_pos[name] = []
312 | if chunk_nb + split_nb in dict_pos[name]:
313 | continue
314 | dict_feats[name].append(data)
315 | dict_pos[name].append(chunk_nb + split_nb)
316 | dict_label[name] = label
317 | print("Computing final results")
318 |
319 | input_lst = []
320 | print(len(dict_feats))
321 | for i, item in enumerate(dict_feats):
322 | input_lst.append([i, item, dict_feats[item], dict_label[item]])
323 | from multiprocessing import Pool
324 | p = Pool(64)
325 | ans = p.map(compute_video, input_lst)
326 | top1 = [x[1] for x in ans]
327 | top5 = [x[2] for x in ans]
328 | pred = [x[0] for x in ans]
329 | label = [x[3] for x in ans]
330 | final_top1, final_top5 = np.mean(top1), np.mean(top5)
331 | return final_top1 * 100, final_top5 * 100
332 |
333 | def compute_video(lst):
334 | i, video_id, data, label = lst
335 | feat = [x for x in data]
336 | feat = np.mean(feat, axis=0)
337 | pred = np.argmax(feat)
338 | top1 = (int(pred) == int(label)) * 1.0
339 | top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0
340 | return [pred, top1, top5, int(label)]
341 |
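342 | # Worked example of the evaluation token budget in validation_one_epoch / final_test, using
343 | # the k400 eval script settings (TEST_FRAME=8, TEST_RES=224, MAX_TOKEN_COUNT=2048):
344 | #
345 | #   current_tokens = eval_true_frame * (eval_input_size // 14) ** 2 = 8 * 16 ** 2 = 2048
346 | #   current_tokens <= eval_max_token_num, so mask_ratio stays None and no tokens are dropped.
347 | #   With 16 frames at 224 (4096 tokens) the same budget would give mask_ratio = 1 - 2048/4096 = 0.5.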
--------------------------------------------------------------------------------
/single_modality/exp/base/eval/k400_eval.sh:
--------------------------------------------------------------------------------
1 | export MASTER_PORT=$((12000 + $RANDOM % 20000))
2 | export OMP_NUM_THREADS=1
3 |
4 | JOB_NAME='k400_base_f8_lr2e-4_w5e35_dp0.1_hdp0.1_ld0.75_g64'
5 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
6 | LOG_DIR="./logs/${JOB_NAME}"
7 | PREFIX='p2:s3://k400'
8 | DATA_PATH='' # replace with your data_path
9 | MODEL_PATH='' # replace with your model_path
10 |
11 | PARTITION='video'
12 | GPUS=64
13 | GPUS_PER_NODE=8
14 | CPUS_PER_TASK=16
15 |
16 | MAX_TOKEN_COUNT=2048
17 | TEST_RES=224
18 | TEST_FRAME=8
19 | OUTPUT_HEAD=1 # can be set with {0, 1, 2}, corresponding to the {3072, 2048, 1024}-token training budgets; use the head nearest to your token budget
20 |
21 | srun -p $PARTITION \
22 | --gres=gpu:${GPUS_PER_NODE} \
23 | --ntasks=${GPUS} \
24 | --ntasks-per-node=${GPUS_PER_NODE} \
25 | --cpus-per-task=${CPUS_PER_TASK} \
26 | --kill-on-bad-exit=1 \
27 | python -u run_flexible_finetune.py \
28 | --eval \
29 | --limit_token_num 6144 \
30 | --largest_num_input_token 3072 \
31 | --middle_num_input_token 2048 \
32 | --least_num_input_token 1024 \
33 | --eval_max_token_num ${MAX_TOKEN_COUNT} \
34 | --eval_input_size ${TEST_RES} \
35 | --eval_short_side_size ${TEST_RES} \
36 | --eval_true_frame ${TEST_FRAME} \
37 | --output_head ${OUTPUT_HEAD} \
38 | --finetune ${MODEL_PATH} \
39 | --model fluxvit_base_patch14 \
40 | --data_path ${DATA_PATH} \
41 | --data_set 'Kinetics_sparse_ofa' \
42 | --prefix ${PREFIX} \
43 | --split ',' \
44 | --nb_classes 400 \
45 | --finetune ${MODEL_PATH} \
46 | --log_dir ${OUTPUT_DIR} \
47 | --output_dir ${OUTPUT_DIR} \
48 | --steps_per_print 10 \
49 | --batch_size 8 \
50 | --num_sample 2 \
51 | --input_size 252 \
52 | --short_side_size 252 \
53 | --model_res 252 \
54 | --save_ckpt_freq 100 \
55 | --num_frames 32 \
56 | --eval_orig_frame 32 \
57 | --orig_t_size 32 \
58 | --num_workers 12 \
59 | --warmup_epochs 1 \
60 | --tubelet_size 1 \
61 | --epochs 5 \
62 | --lr 2e-4 \
63 | --drop_path 0.1 \
64 | --head_drop_path 0.1 \
65 | --fc_drop_rate 0.0 \
66 | --layer_decay 0.75 \
67 | --layer_scale_init_value 1e-5 \
68 | --opt adamw \
69 | --opt_betas 0.9 0.999 \
70 | --weight_decay 0.05 \
71 | --test_num_segment 4 \
72 | --test_num_crop 3 \
73 | --dist_eval \
74 | --enable_deepspeed \
75 | --bf16 \
76 | --zero_stage 1 \
77 | --log_dir ${OUTPUT_DIR} \
78 | --output_dir ${OUTPUT_DIR} \
79 | --test_best
--------------------------------------------------------------------------------
/single_modality/exp/base/ft/k400/k400_multi.sh:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/exp/base/ft/k400/k400_multi.sh
--------------------------------------------------------------------------------
/single_modality/exp/base/ft/k400/k400_single.sh:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/exp/base/ft/k400/k400_single.sh
--------------------------------------------------------------------------------
/single_modality/exp/small/eval/k400_eval.sh:
--------------------------------------------------------------------------------
1 | export MASTER_PORT=$((12000 + $RANDOM % 20000))
2 | export OMP_NUM_THREADS=1
3 |
4 | JOB_NAME='k400_small_f8_lr2e-4_w5e35_dp0.1_hdp0.1_ld0.75_g64'
5 | OUTPUT_DIR="$(dirname $0)/$JOB_NAME"
6 | LOG_DIR="./logs/${JOB_NAME}"
7 | PREFIX='p2:s3://k400'
8 | DATA_PATH='' # replace with your data_path
9 | MODEL_PATH='' # replace with your model_path
10 |
11 | PARTITION='video'
12 | GPUS=64
13 | GPUS_PER_NODE=8
14 | CPUS_PER_TASK=16
15 |
16 | MAX_TOKEN_COUNT=2048
17 | TEST_RES=224
18 | TEST_FRAME=8
19 | OUTPUT_HEAD=1 # can be set with {0, 1, 2}, corresponding to the {3072, 2048, 1024}-token training budgets; use the head nearest to your token budget
20 |
21 | srun -p $PARTITION \
22 | --gres=gpu:${GPUS_PER_NODE} \
23 | --ntasks=${GPUS} \
24 | --ntasks-per-node=${GPUS_PER_NODE} \
25 | --cpus-per-task=${CPUS_PER_TASK} \
26 | --kill-on-bad-exit=1 \
27 | python -u run_flexible_finetune.py \
28 | --eval \
29 | --limit_token_num 6144 \
30 | --largest_num_input_token 3072 \
31 | --middle_num_input_token 2048 \
32 | --least_num_input_token 1024 \
33 | --eval_max_token_num ${MAX_TOKEN_COUNT} \
34 | --eval_input_size ${TEST_RES} \
35 | --eval_short_side_size ${TEST_RES} \
36 | --eval_true_frame ${TEST_FRAME} \
37 | --output_head ${OUTPUT_HEAD} \
38 | --finetune ${MODEL_PATH} \
39 | --model fluxvit_small_patch14 \
40 | --data_path ${DATA_PATH} \
41 | --data_set 'Kinetics_sparse_ofa' \
42 | --prefix ${PREFIX} \
43 | --split ',' \
44 | --nb_classes 400 \
45 | --finetune ${MODEL_PATH} \
46 | --log_dir ${OUTPUT_DIR} \
47 | --output_dir ${OUTPUT_DIR} \
48 | --steps_per_print 10 \
49 | --batch_size 8 \
50 | --num_sample 2 \
51 | --input_size 252 \
52 | --short_side_size 252 \
53 | --model_res 252 \
54 | --save_ckpt_freq 100 \
55 | --num_frames 24 \
56 | --eval_orig_frame 24 \
57 | --orig_t_size 24 \
58 | --num_workers 12 \
59 | --warmup_epochs 1 \
60 | --tubelet_size 1 \
61 | --epochs 5 \
62 | --lr 2e-4 \
63 | --drop_path 0.1 \
64 | --head_drop_path 0.1 \
65 | --fc_drop_rate 0.0 \
66 | --layer_decay 0.75 \
67 | --layer_scale_init_value 1e-5 \
68 | --opt adamw \
69 | --opt_betas 0.9 0.999 \
70 | --weight_decay 0.05 \
71 | --test_num_segment 4 \
72 | --test_num_crop 3 \
73 | --dist_eval \
74 | --enable_deepspeed \
75 | --bf16 \
76 | --zero_stage 1 \
77 | --log_dir ${OUTPUT_DIR} \
78 | --output_dir ${OUTPUT_DIR} \
79 | --test_best
--------------------------------------------------------------------------------
/single_modality/functional.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import cv2
3 | import numpy as np
4 | import PIL
5 | import torch
6 |
7 |
8 | def _is_tensor_clip(clip):
9 | return torch.is_tensor(clip) and clip.ndimension() == 4
10 |
11 |
12 | def crop_clip(clip, min_h, min_w, h, w):
13 | if isinstance(clip[0], np.ndarray):
14 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
15 |
16 | elif isinstance(clip[0], PIL.Image.Image):
17 | cropped = [
18 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
19 | ]
20 | else:
21 | raise TypeError('Expected numpy.ndarray or PIL.Image ' +
22 | 'but got list of {0}'.format(type(clip[0])))
23 | return cropped
24 |
25 |
26 | def resize_clip(clip, size, interpolation='bilinear'):
27 | if isinstance(clip[0], np.ndarray):
28 | if isinstance(size, numbers.Number):
29 | im_h, im_w, im_c = clip[0].shape
30 | # Min spatial dim already matches minimal size
31 | if (im_w <= im_h and im_w == size) or (im_h <= im_w
32 | and im_h == size):
33 | return clip
34 | new_h, new_w = get_resize_sizes(im_h, im_w, size)
35 | size = (new_w, new_h)
36 | else:
37 | size = size[0], size[1]
38 | if interpolation == 'bilinear':
39 | np_inter = cv2.INTER_LINEAR
40 | else:
41 | np_inter = cv2.INTER_NEAREST
42 | scaled = [
43 | cv2.resize(img, size, interpolation=np_inter) for img in clip
44 | ]
45 | elif isinstance(clip[0], PIL.Image.Image):
46 | if isinstance(size, numbers.Number):
47 | im_w, im_h = clip[0].size
48 | # Min spatial dim already matches minimal size
49 | if (im_w <= im_h and im_w == size) or (im_h <= im_w
50 | and im_h == size):
51 | return clip
52 | new_h, new_w = get_resize_sizes(im_h, im_w, size)
53 | size = (new_w, new_h)
54 | else:
55 | size = size[1], size[0]
56 | if interpolation == 'bilinear':
57 | pil_inter = PIL.Image.BILINEAR
58 | else:
59 | pil_inter = PIL.Image.NEAREST
60 | scaled = [img.resize(size, pil_inter) for img in clip]
61 | else:
62 | raise TypeError('Expected numpy.ndarray or PIL.Image ' +
63 | 'but got list of {0}'.format(type(clip[0])))
64 | return scaled
65 |
66 |
67 | def get_resize_sizes(im_h, im_w, size):
68 | if im_w < im_h:
69 | ow = size
70 | oh = int(size * im_h / im_w)
71 | else:
72 | oh = size
73 | ow = int(size * im_w / im_h)
74 | return oh, ow
75 |
76 |
77 | def normalize(clip, mean, std, inplace=False):
78 | if not _is_tensor_clip(clip):
79 | raise TypeError('tensor is not a torch clip.')
80 |
81 | if not inplace:
82 | clip = clip.clone()
83 |
84 | dtype = clip.dtype
85 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device)
86 | std = torch.as_tensor(std, dtype=dtype, device=clip.device)
87 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
88 |
89 | return clip
90 |
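91 | # Minimal usage sketch (the sizes and normalization statistics are illustrative assumptions):
92 | #
93 | #   clip = [np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8) for _ in range(4)]
94 | #   clip = resize_clip(clip, 256, interpolation='bilinear')   # shorter side -> 256
95 | #   clip = crop_clip(clip, 0, 0, 224, 224)                    # 224x224 window from the top-left
96 | #   tensor = torch.from_numpy(np.stack(clip)).permute(3, 0, 1, 2).float().div(255.)
97 | #   tensor = normalize(tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
98 | #   # tensor.shape == torch.Size([3, 4, 224, 224])  (C, T, H, W)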
--------------------------------------------------------------------------------
/single_modality/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .fluxvit import fluxvit_small_patch14, fluxvit_base_patch14
--------------------------------------------------------------------------------
/single_modality/models/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/models/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/models/__pycache__/flash_attention_class.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/models/__pycache__/flash_attention_class.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/models/__pycache__/fluxvit.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/models/__pycache__/fluxvit.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/models/__pycache__/pos_embed.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/models/__pycache__/pos_embed.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/models/__pycache__/vid_tldr.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/FluxViT/4b33eec598c03c6bcd28e3487345bb0e990cbaa0/single_modality/models/__pycache__/vid_tldr.cpython-39.pyc
--------------------------------------------------------------------------------
/single_modality/models/flash_attention_class.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from einops import rearrange
5 |
6 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
7 | from flash_attn.bert_padding import unpad_input, pad_input
8 |
9 |
10 | class FlashAttention(nn.Module):
11 | """Implement the scaled dot product attention with softmax.
12 | Arguments
13 | ---------
14 | softmax_scale: The temperature to use for the softmax attention.
15 | (default: 1/sqrt(d_keys) where d_keys is computed at
16 | runtime)
17 | attention_dropout: The dropout rate to apply to the attention
18 | (default: 0.0)
19 | """
20 |
21 | def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
22 | super().__init__()
23 | self.softmax_scale = softmax_scale
24 | self.dropout_p = attention_dropout
25 |
26 | def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
27 | max_s=None, need_weights=False):
28 | """Implements the multihead softmax attention.
29 | Arguments
30 | ---------
31 | qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
32 | if unpadded: (nnz, 3, h, d)
33 | key_padding_mask: a bool tensor of shape (B, S)
34 | """
35 | assert not need_weights
36 | assert qkv.dtype in [torch.float16, torch.bfloat16]
37 | assert qkv.is_cuda
38 |
39 | if cu_seqlens is None:
40 | batch_size = qkv.shape[0]
41 | seqlen = qkv.shape[1]
42 | if key_padding_mask is None:
43 | qkv = rearrange(qkv, 'b s ... -> (b s) ...')
44 | max_s = seqlen
45 | cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
46 | device=qkv.device)
47 | output = flash_attn_varlen_qkvpacked_func(
48 | qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
49 | softmax_scale=self.softmax_scale, causal=causal
50 | )
51 | output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
52 | else:
53 | nheads = qkv.shape[-2]
54 | x = rearrange(qkv, 'b s three h d -> b s (three h d)')
55 | x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
56 | x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
57 | output_unpad = flash_attn_varlen_qkvpacked_func(
58 | x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
59 | softmax_scale=self.softmax_scale, causal=causal
60 | )
61 | output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
62 | indices, batch_size, seqlen),
63 | 'b s (h d) -> b s h d', h=nheads)
64 | else:
65 | assert max_s is not None
66 | output = flash_attn_varlen_qkvpacked_func(
67 | qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
68 | softmax_scale=self.softmax_scale, causal=causal
69 | )
70 |
71 | return output, None
72 |
--------------------------------------------------------------------------------
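
A minimal sketch of driving the FlashAttention wrapper above; it assumes a CUDA device with the flash-attn package installed, and the shapes are illustrative only (the qkv tensor would normally come from a fused qkv projection):

import torch

# hypothetical shapes: batch 2, 196 tokens, 6 heads of width 64
B, S, H, D = 2, 196, 6, 64
qkv = torch.randn(B, S, 3, H, D, dtype=torch.float16, device='cuda')

attn = FlashAttention(attention_dropout=0.0)
out, _ = attn(qkv)               # out: (B, S, H, D); attention weights are never returned
out = out.reshape(B, S, H * D)   # fold heads back before the output projection
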
/single_modality/models/pos_embed.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 |
6 |
7 | # --------------------------------------------------------
8 | # 3D sine-cosine position embedding
9 | # References:
10 | # MVD: https://github.com/ruiwang2021/mvd/blob/main/modeling_finetune.py
11 | # --------------------------------------------------------
12 | def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False):
13 | """
14 | grid_size: int of the grid height and width
15 | t_size: int of the temporal size
16 | return:
17 | pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
18 | """
19 | assert embed_dim % 4 == 0
20 | embed_dim_spatial = embed_dim // 4 * 3
21 | embed_dim_temporal = embed_dim // 4
22 |
23 | # spatial
24 | grid_h = np.arange(grid_size, dtype=np.float32)
25 | grid_w = np.arange(grid_size, dtype=np.float32)
26 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
27 | grid = np.stack(grid, axis=0)
28 |
29 | grid = grid.reshape([2, 1, grid_size, grid_size])
30 | pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(
31 | embed_dim_spatial, grid
32 | )
33 |
34 | # temporal
35 | grid_t = np.arange(t_size, dtype=np.float32)
36 | pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(
37 | embed_dim_temporal, grid_t
38 | )
39 |
40 | # concate: [T, H, W] order
41 | pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
42 | pos_embed_temporal = np.repeat(
43 | pos_embed_temporal, grid_size**2, axis=1
44 | ) # [T, H*W, D // 4]
45 | pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
46 | pos_embed_spatial = np.repeat(
47 | pos_embed_spatial, t_size, axis=0
48 | ) # [T, H*W, D // 4 * 3]
49 |
50 | pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
51 | pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D]
52 |
53 | if cls_token:
54 | pos_embed = np.concatenate(
55 | [np.zeros([1, embed_dim]), pos_embed], axis=0
56 | )
57 | return pos_embed
58 |
59 |
60 | # --------------------------------------------------------
61 | # 2D sine-cosine position embedding
62 | # References:
63 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
64 | # MoCo v3: https://github.com/facebookresearch/moco-v3
65 | # --------------------------------------------------------
66 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
67 | """
68 | grid_size: int of the grid height and width
69 | return:
70 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
71 | """
72 | grid_h = np.arange(grid_size, dtype=np.float32)
73 | grid_w = np.arange(grid_size, dtype=np.float32)
74 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
75 | grid = np.stack(grid, axis=0)
76 |
77 | grid = grid.reshape([2, 1, grid_size, grid_size])
78 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
79 | if cls_token:
80 | pos_embed = np.concatenate(
81 | [np.zeros([1, embed_dim]), pos_embed], axis=0
82 | )
83 | return pos_embed
84 |
85 |
86 | def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False):
87 | """
88 | t_size: int of the temporal size
89 | return:
90 | pos_embed: [t_size, embed_dim] or [1+t_size, embed_dim] (w/ or w/o cls_token)
91 | """
92 | grid_t = np.arange(t_size, dtype=np.float32)
93 | pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t)
94 | if cls_token:
95 | pos_embed = np.concatenate(
96 | [np.zeros([1, embed_dim]), pos_embed], axis=0
97 | )
98 | return pos_embed
99 |
100 |
101 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
102 | assert embed_dim % 2 == 0
103 |
104 | # use half of dimensions to encode grid_h
105 | emb_h = get_1d_sincos_pos_embed_from_grid(
106 | embed_dim // 2, grid[0]
107 | ) # (H*W, D/2)
108 | emb_w = get_1d_sincos_pos_embed_from_grid(
109 | embed_dim // 2, grid[1]
110 | ) # (H*W, D/2)
111 |
112 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
113 | return emb
114 |
115 |
116 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
117 | """
118 | embed_dim: output dimension for each position
119 | pos: a list of positions to be encoded: size (M,)
120 | out: (M, D)
121 | """
122 | assert embed_dim % 2 == 0
123 | omega = np.arange(embed_dim // 2, dtype=np.float32)
124 | omega /= embed_dim / 2.0
125 | omega = 1.0 / 10000**omega # (D/2,)
126 |
127 | pos = pos.reshape(-1) # (M,)
128 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
129 |
130 | emb_sin = np.sin(out) # (M, D/2)
131 | emb_cos = np.cos(out) # (M, D/2)
132 |
133 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
134 | return emb
135 |
136 |
137 | def interpolate_pos_embed(checkpoint_model, model, orig_t_size=4, pos_name='vision_encoder.pos_embed'):
138 | if pos_name in checkpoint_model:
139 | pos_embed_checkpoint = checkpoint_model[pos_name]
140 | embedding_size = pos_embed_checkpoint.shape[-1] # channel dim
141 | num_patches = model.patch_embed.num_patches #
142 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1
143 |
144 | # we use 4 frames for pretraining
145 | new_t_size = model.T
146 | # height (== width) for the checkpoint position embedding
147 | orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(orig_t_size)) ** 0.5)
148 | # height (== width) for the new position embedding
149 | new_size = int((num_patches // (new_t_size))** 0.5)
150 |
151 | # class_token and dist_token are kept unchanged
152 | if orig_t_size != new_t_size:
153 | print(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})")
154 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
155 | # only the position tokens are interpolated
156 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
157 | # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1)
158 | pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
159 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
160 | pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
161 | pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
162 | pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
163 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
164 | checkpoint_model[pos_name] = new_pos_embed
165 | pos_embed_checkpoint = new_pos_embed
166 |
167 | # class_token and dist_token are kept unchanged
168 | if orig_size != new_size:
169 | print(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})")
170 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
171 | # only the position tokens are interpolated
172 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
173 | # B, L, C -> BT, H, W, C -> BT, C, H, W
174 | pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
175 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
176 | pos_tokens = torch.nn.functional.interpolate(
177 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
178 | # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
179 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
180 | pos_tokens = pos_tokens.flatten(1, 3) # B, L, C
181 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
182 | checkpoint_model[pos_name] = new_pos_embed
183 | else:
184 | raise NotImplementedError
185 |
--------------------------------------------------------------------------------
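
As a quick sanity check of the embedding shapes produced above, a sketch assuming a 14x14 patch grid, 8 temporal positions, and a 384-dim embedding (divisible by 4, as the assert requires):

import torch

pos = get_3d_sincos_pos_embed(embed_dim=384, grid_size=14, t_size=8, cls_token=True)
print(pos.shape)   # (1 + 8 * 14 * 14, 384) = (1569, 384); the first row is zeros for the cls token

# typically registered as a fixed (non-trainable) positional table on the model
pos_embed = torch.from_numpy(pos).float().unsqueeze(0)   # (1, 1569, 384)
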
/single_modality/models/vid_tldr.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | import math
8 | import torch
9 | from einops import rearrange
10 | from typing import Callable, Tuple, List, Union
11 | import torch.nn.functional as F
12 |
13 | def get_objective_score(score_attn, eps=1e-17):
14 | # Mean across the first dimension
15 | score_attn = score_attn.mean(dim=1)
16 |
17 | # Clamp away from 0 and 1 so the log below stays finite
18 | score_attn = torch.clamp(score_attn, min=eps, max=1.0 - eps)
19 | score_attn = F.normalize(score_attn, p=1, dim=-1)
20 |
21 | # Negative entropy per query token: sum of p * log(p) over the attention dim
22 | scores = score_attn * torch.log(score_attn)
23 | scores = scores.sum(dim=2).unsqueeze(-1)
24 |
25 | # BACKGROUND REMOVING
26 | B, T_R, _ = scores.shape
27 | scores = scores - scores.amin(dim=1, keepdim=True)
28 | scores = scores / (scores.amax(dim=1, keepdim=True))
29 | score_mean = scores.mean(dim=1, keepdim=True)
30 | score_mask = scores < score_mean
31 |
32 | # FOREGROUND SHARPENING
33 | scores = scores - score_mean
34 | scores = scores / (scores.amax(dim=1, keepdim=True))
35 | scores = scores.masked_fill(score_mask, 0.0)
36 |
37 | return scores
38 |
39 | def gini_impurity(probabilities):
40 | return 1 - torch.sum(probabilities ** 2, dim=-1)
41 |
42 |
43 | def get_objective_score_gini(score_attn):
44 | score_attn = score_attn.mean(dim=1)
45 |
46 | if torch.isnan(score_attn).any():
47 | raise ValueError('The Score Value has NAN before impurity operation.')
48 |
49 | # score_attn = torch.clamp(score_attn, min=eps, max=1.0 - eps)
50 |
51 | # Normalize to ensure it sums to 1 along the last dimension
52 | # score_attn = F.normalize(score_attn, p=1, dim=-1)
53 |
54 | # Compute Gini impurity
55 | scores = gini_impurity(score_attn).unsqueeze(-1)
56 |
57 | if torch.isnan(scores).any():
58 | raise ValueError('The Score Value has NAN after impurity computation.')
59 |
60 | # BACKGROUND REMOVING
61 | B, T_R, _ = scores.shape
62 | scores = scores - scores.amin(dim=1, keepdim=True)
63 | scores = scores / (scores.amax(dim=1, keepdim=True))
64 | score_mean = scores.mean(dim=1, keepdim=True)
65 | score_mask = scores < score_mean
66 |
67 | # FOREGROUND SHARPENING
68 | scores = scores - score_mean
69 | scores = scores / (scores.amax(dim=1, keepdim=True))
70 | scores = scores.masked_fill(score_mask, 0.0)
71 |
72 | return scores
73 |
74 |
75 | def vidTLDR(x, attn, r, with_cls_token=True, use_gini=False):
76 |
77 | B, T, _ = x.shape
78 | r_merge = T - r
79 | r_merge = max(min(r_merge, T // 2, T), 0)
80 | if not r_merge:
81 | return x
82 |
83 | with torch.no_grad():
84 | if use_gini:
85 | score_obj = get_objective_score_gini(attn)
86 | else:
87 | score_obj = get_objective_score(attn)
88 |
89 | merge = merging(
90 | x,
91 | r_merge = r_merge,
92 | score_obj = score_obj,
93 | with_cls_token = with_cls_token,
94 | )
95 |
96 | return merge
97 |
98 |
99 | def merging(
100 | metric: torch.Tensor,
101 | r_merge : int,
102 | score_obj : torch.Tensor,
103 | with_cls_token = True,
104 | ):
105 |
106 | with torch.no_grad():
107 | metric = metric / metric.norm(dim=-1, keepdim=True) # (1, 2352, 768)
108 |
109 | # SECTION I. TOKEN MERGING
110 | a, b = metric[..., ::2, :], metric[..., 1::2, :] # (12, 99, 64), (12, 98, 64)
111 | n, s, t1, t2 = a.shape[0], a.shape[1], a.shape[-2], b.shape[-2]
112 |
113 | scores = (a @ b.transpose(-1, -2) + 1) / 2 # 0 - 1
114 |
115 | if with_cls_token:
116 | scores[..., 0, :] = -math.inf
117 |
118 | # TOKEN MERGING
119 | node_max, node_idx = scores.max(dim=-1) # (12, 99), (12, 99)
120 | edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] # (12, 99, 1)
121 | unm_idx = edge_idx[..., r_merge:, :] # Unmerged Tokens (12, 83, 1)
122 | src_idx = edge_idx[..., :r_merge, :] # Merged Tokens (12, 16, 1)
123 | dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx) # (12, 16, 1)
124 | unm_idx = unm_idx.sort(dim=1)[0]
125 |
126 | src_so = None
127 | if score_obj is not None:
128 | src_so, dst_so = score_obj[..., ::2, :], score_obj[..., 1::2, :] # (1, 1176, 1)
129 | src_so = src_so.gather(dim=-2, index=src_idx) # (12, 91, 197)
130 |
131 | def merge(x: torch.Tensor, mode = "sum", dtype=torch.float32):
132 | ori_dtype = x.dtype
133 | x = x.to(dtype=dtype)
134 | src, dst = x[..., ::2, :], x[..., 1::2, :] # (12, 99, 197), (12, 98, 197)
135 | n, mid, c = src.shape[0], src.shape[1:-2], src.shape[-1]
136 | unm = src.gather(dim=-2, index=unm_idx.expand(n, *mid, t1 - r_merge, c)) # (12, 91, 197)
137 | src = src.gather(dim=-2, index=src_idx.expand(n, *mid, r_merge, c))
138 |
139 | if score_obj is not None:
140 | src = src * src_so
141 |
142 | dst = dst.scatter_reduce(-2, dst_idx.expand(n, *mid, r_merge, c), src, reduce=mode) # (12, 98, 197)
143 |
144 | x = torch.cat([unm, dst], dim=-2) # (12, 1 + 180, 197)
145 | x = x.to(dtype=ori_dtype)
146 | return x
147 |
148 | return merge
--------------------------------------------------------------------------------
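
Note that when any merging is to be done, vidTLDR above returns a merge closure built from the token similarities rather than the merged tokens themselves; the caller applies it to x (and to anything that must stay row-aligned with x). A minimal sketch with made-up shapes, where the attention map would normally come from the preceding attention block:

import torch

B, N, C, H = 2, 197, 384, 6                      # hypothetical: cls token + 196 patch tokens
x = torch.randn(B, N, C)
attn = torch.rand(B, H, N, N).softmax(dim=-1)    # stand-in attention map

merge = vidTLDR(x, attn, r=128, with_cls_token=True)   # merge the 197 tokens down to r = 128
x_merged = merge(x, mode="sum")
print(x_merged.shape)   # torch.Size([2, 128, 384]): 30 unmerged + 98 destination tokens
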
/single_modality/optim_factory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import optim as optim
3 |
4 | from timm.optim.adafactor import Adafactor
5 | from timm.optim.adahessian import Adahessian
6 | from timm.optim.adamp import AdamP
7 | from timm.optim.lookahead import Lookahead
8 | # from timm.optim.nadam import Nadam
9 | # from timm.optim.novograd import NovoGrad
10 | # from timm.optim.nvnovograd import NvNovoGrad
11 | # from timm.optim.radam import RAdam
12 | # from timm.optim.rmsprop_tf import RMSpropTF
13 | from timm.optim.sgdp import SGDP
14 |
15 | import json
16 |
17 | try:
18 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
19 | has_apex = True
20 | except ImportError:
21 | has_apex = False
22 |
23 |
24 | def get_num_layer_for_vit(var_name, num_max_layer):
25 | if var_name in ("cls_token", "mask_token", "pos_embed"):
26 | return 0
27 | elif var_name.startswith("patch_embed"):
28 | return 0
29 | elif var_name.startswith("rel_pos_bias"):
30 | return num_max_layer - 1
31 | elif var_name.startswith("blocks"):
32 | layer_id = int(var_name.split('.')[1])
33 | return layer_id + 1
34 | elif var_name.startswith("transformer.resblocks"):
35 | layer_id = int(var_name.split('.')[2])
36 | return layer_id + 1
37 | elif var_name in ("class_embedding", "positional_embedding", "temporal_positional_embedding"):
38 | return 0
39 | elif var_name.startswith("conv1"):
40 | return 0
41 | else:
42 | return num_max_layer - 1
43 |
44 |
45 | class LayerDecayValueAssigner(object):
46 | def __init__(self, values):
47 | self.values = values
48 |
49 | def get_scale(self, layer_id):
50 | return self.values[layer_id]
51 |
52 | def get_layer_id(self, var_name):
53 | return get_num_layer_for_vit(var_name, len(self.values))
54 |
55 |
56 | def get_parameter_groups(
57 | model, weight_decay=1e-5, skip_list=(), get_num_layer=None,
58 | get_layer_scale=None,
59 | ):
60 | parameter_group_names = {}
61 | parameter_group_vars = {}
62 |
63 | for name, param in model.named_parameters():
64 | if not param.requires_grad:
65 | continue # frozen weights
66 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
67 | group_name = "no_decay"
68 | this_weight_decay = 0.
69 | else:
70 | group_name = "decay"
71 | this_weight_decay = weight_decay
72 | if get_num_layer is not None:
73 | layer_id = get_num_layer(name)
74 | group_name = "layer_%d_%s" % (layer_id, group_name)
75 | else:
76 | layer_id = None
77 |
78 | if group_name not in parameter_group_names:
79 | if get_layer_scale is not None:
80 | scale = get_layer_scale(layer_id)
81 | else:
82 | scale = 1.
83 |
84 | parameter_group_names[group_name] = {
85 | "weight_decay": this_weight_decay,
86 | "params": [],
87 | "lr_scale": scale
88 | }
89 | parameter_group_vars[group_name] = {
90 | "weight_decay": this_weight_decay,
91 | "params": [],
92 | "lr_scale": scale
93 | }
94 |
95 | parameter_group_vars[group_name]["params"].append(param)
96 | parameter_group_names[group_name]["params"].append(name)
97 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
98 | return list(parameter_group_vars.values())
99 |
100 |
101 | def create_optimizer(
102 | args, model, get_num_layer=None, get_layer_scale=None,
103 | filter_bias_and_bn=True, skip_list=None
104 | ):
105 | opt_lower = args.opt.lower()
106 | weight_decay = args.weight_decay
107 | if weight_decay and filter_bias_and_bn:
108 | skip = {}
109 | if skip_list is not None:
110 | skip = skip_list
111 | elif hasattr(model, 'no_weight_decay'):
112 | skip = model.no_weight_decay()
113 | parameters = get_parameter_groups(
114 | model, weight_decay, skip, get_num_layer, get_layer_scale,
115 | )
116 | weight_decay = 0.
117 | else:
118 | parameters = model.parameters()
119 |
120 | if 'fused' in opt_lower:
121 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
122 |
123 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
124 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
125 | opt_args['eps'] = args.opt_eps
126 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
127 | opt_args['betas'] = args.opt_betas
128 |
129 | print("optimizer settings:", opt_args)
130 |
131 | opt_split = opt_lower.split('_')
132 | opt_lower = opt_split[-1]
133 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
134 | opt_args.pop('eps', None)
135 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
136 | elif opt_lower == 'momentum':
137 | opt_args.pop('eps', None)
138 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
139 | elif opt_lower == 'adam':
140 | optimizer = optim.Adam(parameters, **opt_args)
141 | elif opt_lower == 'adamw':
142 | optimizer = optim.AdamW(parameters, **opt_args)
143 | # elif opt_lower == 'radam':
144 | # optimizer = RAdam(parameters, **opt_args)
145 | elif opt_lower == 'adamp':
146 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
147 | elif opt_lower == 'sgdp':
148 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
149 | elif opt_lower == 'adadelta':
150 | optimizer = optim.Adadelta(parameters, **opt_args)
151 | elif opt_lower == 'adafactor':
152 | if not args.lr:
153 | opt_args['lr'] = None
154 | optimizer = Adafactor(parameters, **opt_args)
155 | elif opt_lower == 'adahessian':
156 | optimizer = Adahessian(parameters, **opt_args)
157 | elif opt_lower == 'rmsprop':
158 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
159 | # elif opt_lower == 'rmsproptf':
160 | # optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
161 | # elif opt_lower == 'novograd':
162 | # optimizer = NovoGrad(parameters, **opt_args)
163 | # elif opt_lower == 'nvnovograd':
164 | # optimizer = NvNovoGrad(parameters, **opt_args)
165 | elif opt_lower == 'fusedsgd':
166 | opt_args.pop('eps', None)
167 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
168 | elif opt_lower == 'fusedmomentum':
169 | opt_args.pop('eps', None)
170 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
171 | elif opt_lower == 'fusedadam':
172 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
173 | elif opt_lower == 'fusedadamw':
174 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
175 | elif opt_lower == 'fusedlamb':
176 | optimizer = FusedLAMB(parameters, **opt_args)
177 | elif opt_lower == 'fusednovograd':
178 | opt_args.setdefault('betas', (0.95, 0.98))
179 | optimizer = FusedNovoGrad(parameters, **opt_args)
180 | else:
181 | # unknown optimizer name: fail with the offending value
182 | raise ValueError(f"Invalid optimizer: {opt_lower}")
183 |
184 | if len(opt_split) > 1:
185 | if opt_split[0] == 'lookahead':
186 | optimizer = Lookahead(optimizer)
187 |
188 | return optimizer
189 |
--------------------------------------------------------------------------------
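
A sketch of wiring create_optimizer together with layer-wise LR decay; args is a stand-in argparse.Namespace whose field names mirror the attributes read above, and the timm ViT is only a placeholder backbone for illustration (the repo would pass its FluxViT models instead):

import argparse
import timm

args = argparse.Namespace(opt='adamw', lr=1e-3, weight_decay=0.05,
                          opt_eps=1e-8, opt_betas=(0.9, 0.999), momentum=0.9)

model = timm.create_model('vit_base_patch16_224')          # placeholder backbone
num_layers = 12                                            # assumed block count
assigner = LayerDecayValueAssigner(
    [0.75 ** (num_layers + 1 - i) for i in range(num_layers + 2)])

optimizer = create_optimizer(
    args, model,
    get_num_layer=assigner.get_layer_id,
    get_layer_scale=assigner.get_scale)
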