├── DATASET.md
├── LICENSE
├── README.md
├── convert.py
├── datasets
│   ├── __init__.py
│   ├── functional.py
│   ├── kinetics.py
│   ├── ssv2
│   │   ├── train.csv
│   │   └── val.csv
│   ├── video_datasets.py
│   ├── video_transforms.py
│   ├── videomae_transforms.py
│   └── volume_transforms.py
├── engine_finetune.py
├── figs
│   └── petls_patt.png
├── main.py
├── models
│   ├── __init__.py
│   ├── logger.py
│   └── video_swin_transformer_patt.py
└── util
    ├── crop.py
    ├── datasets.py
    ├── lars.py
    ├── lr_decay.py
    ├── lr_sched.py
    ├── misc.py
    └── pos_embed.py
/DATASET.md:
--------------------------------------------------------------------------------
1 | # Data Preparation
2 | 
3 | 
4 | - The pre-processing of **Something-Something-V2** follows [VideoMAE](https://github.com/MCG-NJU/VideoMAE),
5 | which can be summarized into 3 steps:
6 | 
7 | 1. Download the dataset from the [official website](https://developer.qualcomm.com/software/ai-datasets/something-something).
8 | 
9 | 2. Preprocess the dataset by converting the videos from `.webm` to `.mp4` while keeping the **original** height of **240px**. You can simply run `ffmpeg -i [input.webm] -c:v libx264 [output.mp4]`.
10 | 
11 | 3. Generate the annotation files needed by the dataloader. Each line gives a video path and its integer label, separated by a comma (the dataloader in `datasets/kinetics.py` parses these files with `delimiter=','`). The annotations usually include `train.csv`, `val.csv` and `test.csv` (here `test.csv` is the same as `val.csv`). The format of the `*.csv` files is:
12 | 
13 | ```
14 | dataset_root/video_1.mp4,label_1
15 | dataset_root/video_2.mp4,label_2
16 | dataset_root/video_3.mp4,label_3
17 | ...
18 | dataset_root/video_N.mp4,label_N
19 | ```
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Xinbo Yu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 3 | ## V-PETL: A Unified View of Visual PETL Techniques 4 | 5 | ![teaser](figs/petls_patt.png) 6 |
7 | 
8 | This is a PyTorch implementation of the paper [Towards a Unified View on Visual Parameter-Efficient Transfer Learning](http://arxiv.org/abs/2210.00788).
9 | 
10 | [Bruce X.B. Yu](https://bruceyo.github.io/)<sup>1</sup>,
11 | [Jianlong Chang](https://scholar.google.com/citations?user=RDwnNsQAAAAJ)<sup>2</sup>,
12 | [Lingbo Liu](https://lingboliu.com/)<sup>1</sup>,
13 | [Qi Tian](https://scholar.google.com/citations?user=61b6eYkAAAAJ)<sup>2</sup>,
14 | [Chang Wen Chen](https://chenlab.comp.polyu.edu.hk/)<sup>1</sup>\*
15 | 
16 | <sup>1</sup>The Hong Kong Polytechnic University, <sup>2</sup>Huawei Inc.
17 | 
18 | \*denotes the corresponding author
19 | 
20 | ### Usage
21 | 
22 | #### Install
23 | * GeForce RTX 3090 (24 GB): CUDA 11.4+, PyTorch 1.13.0 + torchvision 0.14.0
24 | * timm 0.4.8
25 | * einops
26 | * easydict
27 | 
28 | #### Data Preparation
29 | See [DATASET.md](DATASET.md).
30 | 
31 | #### Prepare Pre-trained Checkpoints
32 | 
33 | We use Swin-B pre-trained on Kinetics-400 and Kinetics-600. The pre-trained models are available from [Video Swin Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer). Put them in the folder `./pre_trained`.
34 | 
35 | #### Training
36 | Start training with:
37 | ```bash
38 | CUDA_VISIBLE_DEVICES=3 torchrun --standalone --nnodes=1 \
39 |     --nproc_per_node=1 --master_port=22253 \
40 |     main.py \
41 |     --num_frames 8 \
42 |     --sampling_rate 2 \
43 |     --model swin_transformer \
44 |     --finetune pre_trained/swin_base_patch244_window877_kinetics400_22k.pth \
45 |     --output_dir output \
46 |     --tuned_backbone_layer_fc True \
47 |     --batch_size 16 --epochs 70 --blr 0.1 --weight_decay 0.0 --dist_eval \
48 |     --data_path /media/bruce/ssd1/data/hmdb51 --data_set HMDB51 \
49 |     --ffn_adapt \
50 |     --att_prefix \
51 |     --att_preseqlen 16 \
52 |     --att_mid_dim 128 \
53 |     --att_prefix_mode patt_kv \
54 |     --att_prefix_scale 0.8
55 | ```
56 | 
57 | ### Acknowledgement
58 | 
59 | The project is based on [PETL](https://github.com/jxhe/unify-parameter-efficient-tuning),
60 | [Video Swin Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer) and [AdaptFormer](https://github.com/ShoufaChen/AdaptFormer). Thanks for their awesome work.
61 | 
62 | ### Citation
63 | ```
64 | @article{yu2022vpetl,
65 |   title={Towards a Unified View on Visual Parameter-Efficient Transfer Learning},
66 |   author={Yu, Bruce X.B. and Chang, Jianlong and Liu, Lingbo and Tian, Qi and Chen, Chang Wen},
67 |   journal={arXiv preprint arXiv:2210.00788},
68 |   year={2022}
69 | }
70 | ```
71 | 
72 | ### License
73 | 
74 | This project is under the MIT license. See [LICENSE](LICENSE) for details.
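The annotation files described in [DATASET.md](DATASET.md) are plain text listings with one `<video_path>,<label>` entry per line. The snippet below is a minimal illustrative sketch (not part of this repository) of how such `train.csv`/`val.csv` files could be generated; the per-class folder layout and the class-to-index mapping are assumptions to adapt to your own data.

```python
# Hypothetical helper, not part of V-PETL: build annotation csv files in the
# "<video_path>,<label>" format parsed by datasets/kinetics.py (comma-delimited).
# Assumed layout: <root>/<split>/<class_name>/<video>.mp4 -- adapt as needed.
import os


def write_annotation_csv(root, split, class_to_idx, out_csv):
    """Walk <root>/<split>/<class_name>/*.mp4 and write one 'path,label' line per video."""
    split_dir = os.path.join(root, split)
    with open(out_csv, 'w') as f:
        for class_name, label in sorted(class_to_idx.items()):
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.isdir(class_dir):
                continue
            for fname in sorted(os.listdir(class_dir)):
                if fname.endswith('.mp4'):
                    f.write('{},{}\n'.format(os.path.join(class_dir, fname), label))


if __name__ == '__main__':
    # Example class mapping and paths are placeholders, not values from the repo.
    classes = {'brush_hair': 0, 'cartwheel': 1}  # ...one entry per action class
    write_annotation_csv('/path/to/hmdb51', 'train', classes, 'train.csv')
    write_annotation_csv('/path/to/hmdb51', 'val', classes, 'val.csv')
```

The dataloader looks for `train.csv`, `val.csv` and `test.csv` inside the directory passed as `--data_path` in the training command above.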
75 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch 3 | 4 | 5 | def convert_videomae_pretrain(path): 6 | old_ckpts = torch.load(path, map_location='cpu') 7 | new_ckpts = OrderedDict() 8 | 9 | for k, v in old_ckpts['model'].items(): 10 | if not k.startswith('encoder.'): 11 | continue 12 | if k.startswith('encoder.blocks.'): 13 | spk = k.split('.') 14 | if '.'.join(spk[3:]) == 'attn.qkv.weight': 15 | assert v.shape[0] % 3 == 0, v.shape 16 | qi, ki, vi = torch.split(v, v.shape[0] // 3, dim=0) 17 | new_ckpts['.'.join(spk[:3] + ['attn', 'q_proj', 'weight'])] = qi 18 | new_ckpts['.'.join(spk[:3] + ['attn', 'k_proj', 'weight'])] = ki 19 | new_ckpts['.'.join(spk[:3] + ['attn', 'v_proj', 'weight'])] = vi 20 | elif '.'.join(spk[3:]) == 'mlp.fc1.bias': # 'blocks.1.norm1.weight' --> 'norm1.weight' 21 | new_ckpts['.'.join(spk[:3] + ['fc1', 'bias'])] = v 22 | elif '.'.join(spk[3:]) == 'mlp.fc1.weight': 23 | new_ckpts['.'.join(spk[:3] + ['fc1', 'weight'])] = v 24 | elif '.'.join(spk[3:]) == 'mlp.fc2.bias': 25 | new_ckpts['.'.join(spk[:3] + ['fc2', 'bias'])] = v 26 | elif '.'.join(spk[3:]) == 'mlp.fc2.weight': 27 | new_ckpts['.'.join(spk[:3] + ['fc2', 'weight'])] = v 28 | else: 29 | new_ckpts[k] = v 30 | else: 31 | new_ckpts[k] = v 32 | 33 | assert path.endswith('.pth'), path 34 | new_path = path[:-4] + '_new.pth' 35 | torch.save(OrderedDict(model=new_ckpts), new_path) 36 | print('Finished :', path) 37 | 38 | if __name__ == '__main__': 39 | path = '/path/to/videomae/pretrained/checkpoint.pth' 40 | convert_videomae_pretrain(path) 41 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bruceyo/V-PETL/661198dd3810d23b36808368930aba69e635c34c/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/functional.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import cv2 3 | import numpy as np 4 | import PIL 5 | import torch 6 | 7 | 8 | def _is_tensor_clip(clip): 9 | return torch.is_tensor(clip) and clip.ndimension() == 4 10 | 11 | 12 | def crop_clip(clip, min_h, min_w, h, w): 13 | if isinstance(clip[0], np.ndarray): 14 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] 15 | 16 | elif isinstance(clip[0], PIL.Image.Image): 17 | cropped = [ 18 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip 19 | ] 20 | else: 21 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 22 | 'but got list of {0}'.format(type(clip[0]))) 23 | return cropped 24 | 25 | 26 | def resize_clip(clip, size, interpolation='bilinear'): 27 | if isinstance(clip[0], np.ndarray): 28 | if isinstance(size, numbers.Number): 29 | im_h, im_w, im_c = clip[0].shape 30 | # Min spatial dim already matches minimal size 31 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 32 | and im_h == size): 33 | return clip 34 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 35 | size = (new_w, new_h) 36 | else: 37 | size = size[0], size[1] 38 | if interpolation == 'bilinear': 39 | np_inter = cv2.INTER_LINEAR 40 | else: 41 | np_inter = cv2.INTER_NEAREST 42 | scaled = [ 43 | cv2.resize(img, size, interpolation=np_inter) for img in clip 44 | ] 45 | elif 
isinstance(clip[0], PIL.Image.Image): 46 | if isinstance(size, numbers.Number): 47 | im_w, im_h = clip[0].size 48 | # Min spatial dim already matches minimal size 49 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 50 | and im_h == size): 51 | return clip 52 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 53 | size = (new_w, new_h) 54 | else: 55 | size = size[1], size[0] 56 | if interpolation == 'bilinear': 57 | pil_inter = PIL.Image.BILINEAR 58 | else: 59 | pil_inter = PIL.Image.NEAREST 60 | scaled = [img.resize(size, pil_inter) for img in clip] 61 | else: 62 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 63 | 'but got list of {0}'.format(type(clip[0]))) 64 | return scaled 65 | 66 | 67 | def get_resize_sizes(im_h, im_w, size): 68 | if im_w < im_h: 69 | ow = size 70 | oh = int(size * im_h / im_w) 71 | else: 72 | oh = size 73 | ow = int(size * im_w / im_h) 74 | return oh, ow 75 | 76 | 77 | def normalize(clip, mean, std, inplace=False): 78 | if not _is_tensor_clip(clip): 79 | raise TypeError('tensor is not a torch clip.') 80 | 81 | if not inplace: 82 | clip = clip.clone() 83 | 84 | dtype = clip.dtype 85 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) 86 | std = torch.as_tensor(std, dtype=dtype, device=clip.device) 87 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 88 | 89 | return clip 90 | -------------------------------------------------------------------------------- /datasets/kinetics.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # References: 3 | # VideoMAE: https://github.com/MCG-NJU/VideoMAE 4 | # -------------------------------------------------------- 5 | 6 | import os 7 | import numpy as np 8 | import torch 9 | import decord 10 | from PIL import Image 11 | from torchvision import transforms 12 | import warnings 13 | from decord import VideoReader, cpu 14 | from torch.utils.data import Dataset 15 | import datasets.video_transforms as video_transforms 16 | import datasets.volume_transforms as volume_transforms 17 | from datasets.videomae_transforms import GroupMultiScaleCrop, GroupNormalize, Stack, ToTorchFormatTensor 18 | from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 19 | 20 | 21 | class VideoClsDataset(Dataset): 22 | """Load your own video classification dataset.""" 23 | 24 | def __init__(self, anno_path, data_path, mode='train', clip_len=8, 25 | frame_sample_rate=2, crop_size=224, short_side_size=256, 26 | new_height=256, new_width=340, keep_aspect_ratio=True, 27 | num_segment=1, num_crop=1, test_num_segment=10, test_num_crop=3,args=None): 28 | self.anno_path = anno_path 29 | self.data_path = data_path 30 | self.mode = mode 31 | self.clip_len = clip_len 32 | self.frame_sample_rate = frame_sample_rate 33 | self.crop_size = crop_size 34 | self.short_side_size = short_side_size 35 | self.new_height = new_height 36 | self.new_width = new_width 37 | self.keep_aspect_ratio = keep_aspect_ratio 38 | self.num_segment = num_segment 39 | self.test_num_segment = test_num_segment 40 | self.num_crop = num_crop 41 | self.test_num_crop = test_num_crop 42 | self.args = args 43 | self.aug = False 44 | self.rand_erase = False 45 | if self.mode in ['train']: 46 | self.aug = True 47 | if self.args.reprob > 0: 48 | self.rand_erase = True 49 | if VideoReader is None: 50 | raise ImportError("Unable to import `decord` which is required to read videos.") 51 | 52 | 
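# Parse the annotation csv: one comma-separated "<video_path>,<label>" row per
# sample; column 0 holds the video paths and column 1 the integer class labels.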
import pandas as pd 53 | cleaned = pd.read_csv(self.anno_path, header=None, delimiter=',') 54 | self.dataset_samples = list(cleaned.values[:, 0]) 55 | self.label_array = list(cleaned.values[:, 1]) 56 | 57 | if (mode == 'train'): 58 | if args.linprob: 59 | self.input_mean = IMAGENET_INCEPTION_MEAN if args.inception else IMAGENET_DEFAULT_MEAN 60 | self.input_std = IMAGENET_INCEPTION_STD if args.inception else IMAGENET_DEFAULT_STD 61 | if isinstance(args.input_size, int): 62 | _size = (args.input_size, args.input_size) 63 | self.data_transform = transforms.Compose([ 64 | GroupMultiScaleCrop(_size, [1, .875, .75, .66]), 65 | Stack(roll=False), 66 | ToTorchFormatTensor(div=True), 67 | GroupNormalize(self.input_mean, self.input_std), 68 | ]) 69 | # self.data_transform = None 70 | else: 71 | self.data_transform = None 72 | 73 | elif (mode == 'validation'): 74 | self.data_transform = video_transforms.Compose([ 75 | video_transforms.Resize(self.short_side_size, interpolation='bilinear'), 76 | video_transforms.CenterCrop(size=(self.crop_size, self.crop_size)), 77 | volume_transforms.ClipToTensor(), 78 | video_transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN if args.inception else IMAGENET_DEFAULT_MEAN, 79 | std=IMAGENET_INCEPTION_STD if args.inception else IMAGENET_DEFAULT_STD) 80 | ]) 81 | elif mode == 'test': 82 | self.data_resize = video_transforms.Compose([ 83 | video_transforms.Resize(size=(short_side_size), interpolation='bilinear'), 84 | # bruce 85 | video_transforms.CenterCrop(size=(self.crop_size, self.crop_size)), 86 | ]) 87 | self.data_transform = video_transforms.Compose([ 88 | volume_transforms.ClipToTensor(), 89 | video_transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN if args.inception else IMAGENET_DEFAULT_MEAN, 90 | std=IMAGENET_INCEPTION_STD if args.inception else IMAGENET_DEFAULT_STD) 91 | ]) 92 | self.test_seg = [] 93 | self.test_dataset = [] 94 | self.test_label_array = [] 95 | for ck in range(self.test_num_segment): 96 | for cp in range(self.test_num_crop): 97 | for idx in range(len(self.label_array)): 98 | sample_label = self.label_array[idx] 99 | self.test_label_array.append(sample_label) 100 | self.test_dataset.append(self.dataset_samples[idx]) 101 | self.test_seg.append((ck, cp)) 102 | 103 | def __getitem__(self, index): 104 | if self.mode == 'train': 105 | args = self.args 106 | scale_t = 1 107 | 108 | sample = self.dataset_samples[index] 109 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) # T H W C 110 | if len(buffer) == 0: 111 | while len(buffer) == 0: 112 | warnings.warn("video {} not correctly loaded during training".format(sample)) 113 | index = np.random.randint(self.__len__()) 114 | sample = self.dataset_samples[index] 115 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) 116 | 117 | if args.num_sample > 1: 118 | assert not args.linprob 119 | frame_list = [] 120 | label_list = [] 121 | index_list = [] 122 | for _ in range(args.num_sample): 123 | new_frames = self._aug_frame(buffer, args) 124 | label = self.label_array[index] 125 | frame_list.append(new_frames) 126 | label_list.append(label) 127 | index_list.append(index) 128 | return frame_list, label_list, index_list, {} 129 | else: 130 | if self.data_transform is None: 131 | buffer = self._aug_frame(buffer, args) 132 | else: 133 | sampled_list = [Image.fromarray(buffer[vid, :, :, :]).convert('RGB') for vid in range(buffer.shape[0])] 134 | process_data, _ = self.data_transform((sampled_list, None)) 135 | buffer = process_data.view((-1, 3) + 
process_data.size()[-2:]).transpose(0, 1) # T*C,H,W -> T,C,H,W -> C,T,H,W 136 | 137 | return buffer, self.label_array[index], index, {} 138 | 139 | elif self.mode == 'validation': 140 | sample = self.dataset_samples[index] 141 | buffer = self.loadvideo_decord(sample) 142 | if len(buffer) == 0: 143 | while len(buffer) == 0: 144 | warnings.warn("video {} not correctly loaded during validation".format(sample)) 145 | index = np.random.randint(self.__len__()) 146 | sample = self.dataset_samples[index] 147 | buffer = self.loadvideo_decord(sample) 148 | buffer = self.data_transform(buffer) 149 | return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0] 150 | 151 | elif self.mode == 'test': 152 | sample = self.test_dataset[index] 153 | chunk_nb, split_nb = self.test_seg[index] 154 | buffer = self.loadvideo_decord(sample) 155 | 156 | while len(buffer) == 0: 157 | warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\ 158 | str(self.test_dataset[index]), chunk_nb, split_nb)) 159 | index = np.random.randint(self.__len__()) 160 | sample = self.test_dataset[index] 161 | chunk_nb, split_nb = self.test_seg[index] 162 | buffer = self.loadvideo_decord(sample) 163 | 164 | buffer = self.data_resize(buffer) 165 | if isinstance(buffer, list): 166 | buffer = np.stack(buffer, 0) 167 | 168 | spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \ 169 | / (self.test_num_crop - 1) 170 | temporal_step = max(1.0 * (buffer.shape[0] - self.clip_len) \ 171 | / (self.test_num_segment - 1), 0) 172 | temporal_start = int(chunk_nb * temporal_step) 173 | spatial_start = int(split_nb * spatial_step) 174 | if buffer.shape[1] >= buffer.shape[2]: 175 | buffer = buffer[temporal_start:temporal_start + self.clip_len, \ 176 | spatial_start:spatial_start + self.short_side_size, :, :] 177 | else: 178 | buffer = buffer[temporal_start:temporal_start + self.clip_len, \ 179 | :, spatial_start:spatial_start + self.short_side_size, :] 180 | 181 | buffer = self.data_transform(buffer) 182 | return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \ 183 | chunk_nb, split_nb 184 | else: 185 | raise NameError('mode {} unkown'.format(self.mode)) 186 | 187 | def _aug_frame( 188 | self, 189 | buffer, 190 | args, 191 | ): 192 | 193 | aug_transform = video_transforms.create_random_augment( 194 | input_size=(self.crop_size, self.crop_size), 195 | auto_augment=args.aa, 196 | interpolation=args.train_interpolation, 197 | ) 198 | 199 | buffer = [ 200 | transforms.ToPILImage()(frame) for frame in buffer 201 | ] 202 | 203 | buffer = aug_transform(buffer) 204 | 205 | buffer = [transforms.ToTensor()(img) for img in buffer] 206 | buffer = torch.stack(buffer) # T C H W 207 | buffer = buffer.permute(0, 2, 3, 1) # T H W C 208 | 209 | # T H W C 210 | buffer = tensor_normalize( 211 | buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 212 | ) 213 | # T H W C -> C T H W. 214 | buffer = buffer.permute(3, 0, 1, 2) 215 | # Perform data augmentation. 
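# Inception-style crop parameters for spatial_sampling below: sample a crop whose
# area covers 8%-100% of the frame and whose aspect ratio lies in [3/4, 4/3],
# then resize it to crop_size x crop_size.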
216 | scl, asp = ( 217 | [0.08, 1.0], 218 | [0.75, 1.3333], 219 | ) 220 | 221 | buffer = spatial_sampling( 222 | buffer, 223 | spatial_idx=-1, 224 | min_scale=256, 225 | max_scale=320, 226 | crop_size=self.crop_size, 227 | random_horizontal_flip=False if args.data_set == 'SSV2' else True , 228 | inverse_uniform_sampling=False, 229 | aspect_ratio=asp, 230 | scale=scl, 231 | motion_shift=False 232 | ) 233 | 234 | if self.rand_erase: 235 | erase_transform = RandomErasing( 236 | args.reprob, 237 | mode=args.remode, 238 | max_count=args.recount, 239 | num_splits=args.recount, 240 | device="cpu", 241 | ) 242 | buffer = buffer.permute(1, 0, 2, 3) 243 | buffer = erase_transform(buffer) 244 | buffer = buffer.permute(1, 0, 2, 3) 245 | 246 | return buffer 247 | 248 | def loadvideo_decord(self, sample, sample_rate_scale=1): 249 | """Load video content using Decord""" 250 | fname = sample 251 | 252 | if not (os.path.exists(fname)): 253 | return [] 254 | 255 | # avoid hanging issue 256 | if os.path.getsize(fname) < 1 * 1024: 257 | print('SKIP: ', fname, " - ", os.path.getsize(fname)) 258 | return [] 259 | try: 260 | if self.keep_aspect_ratio: 261 | vr = VideoReader(fname, num_threads=1, ctx=cpu(0)) 262 | else: 263 | vr = VideoReader(fname, width=self.new_width, height=self.new_height, 264 | num_threads=1, ctx=cpu(0)) 265 | except: 266 | print("video cannot be loaded by decord: ", fname) 267 | return [] 268 | 269 | if self.mode == 'test': 270 | all_index = [x for x in range(0, len(vr), self.frame_sample_rate)] 271 | while len(all_index) < self.clip_len: 272 | all_index.append(all_index[-1]) 273 | vr.seek(0) 274 | buffer = vr.get_batch(all_index).asnumpy() 275 | return buffer 276 | 277 | # handle temporal segments 278 | converted_len = int(self.clip_len * self.frame_sample_rate) 279 | seg_len = len(vr) // self.num_segment 280 | 281 | all_index = [] 282 | for i in range(self.num_segment): 283 | if seg_len <= converted_len: 284 | index = np.linspace(0, seg_len, num=seg_len // self.frame_sample_rate) 285 | index = np.concatenate((index, np.ones(self.clip_len - seg_len // self.frame_sample_rate) * seg_len)) 286 | index = np.clip(index, 0, seg_len - 1).astype(np.int64) 287 | else: 288 | #import pdb; pdb.set_trace() 289 | #print('..............seg_len > converted_len') 290 | end_idx = np.random.randint(converted_len, seg_len) 291 | str_idx = end_idx - converted_len 292 | index = np.linspace(str_idx, end_idx, num=self.clip_len) 293 | index = np.clip(index, str_idx, end_idx - 1).astype(np.int64) 294 | index = index + i*seg_len 295 | all_index.extend(list(index)) 296 | 297 | all_index = all_index[::int(sample_rate_scale)] 298 | vr.seek(0) 299 | buffer = vr.get_batch(all_index).asnumpy() 300 | return buffer 301 | 302 | def __len__(self): 303 | if self.mode != 'test': 304 | return len(self.dataset_samples) 305 | else: 306 | return len(self.test_dataset) 307 | 308 | 309 | def spatial_sampling( 310 | frames, 311 | spatial_idx=-1, 312 | min_scale=256, 313 | max_scale=320, 314 | crop_size=224, 315 | random_horizontal_flip=True, 316 | inverse_uniform_sampling=False, 317 | aspect_ratio=None, 318 | scale=None, 319 | motion_shift=False, 320 | ): 321 | """ 322 | Perform spatial sampling on the given video frames. If spatial_idx is 323 | -1, perform random scale, random crop, and random flip on the given 324 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling 325 | with the given spatial_idx. 326 | Args: 327 | frames (tensor): frames of images sampled from the video. 
The 328 | dimension is `num frames` x `height` x `width` x `channel`. 329 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1, 330 | or 2, perform left, center, right crop if width is larger than 331 | height, and perform top, center, buttom crop if height is larger 332 | than width. 333 | min_scale (int): the minimal size of scaling. 334 | max_scale (int): the maximal size of scaling. 335 | crop_size (int): the size of height and width used to crop the 336 | frames. 337 | inverse_uniform_sampling (bool): if True, sample uniformly in 338 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the 339 | scale. If False, take a uniform sample from [min_scale, 340 | max_scale]. 341 | aspect_ratio (list): Aspect ratio range for resizing. 342 | scale (list): Scale range for resizing. 343 | motion_shift (bool): Whether to apply motion shift for resizing. 344 | Returns: 345 | frames (tensor): spatially sampled frames. 346 | """ 347 | assert spatial_idx in [-1, 0, 1, 2] 348 | if spatial_idx == -1: 349 | if aspect_ratio is None and scale is None: 350 | frames, _ = video_transforms.random_short_side_scale_jitter( 351 | images=frames, 352 | min_size=min_scale, 353 | max_size=max_scale, 354 | inverse_uniform_sampling=inverse_uniform_sampling, 355 | ) 356 | frames, _ = video_transforms.random_crop(frames, crop_size) 357 | else: 358 | transform_func = ( 359 | video_transforms.random_resized_crop_with_shift 360 | if motion_shift 361 | else video_transforms.random_resized_crop 362 | ) 363 | frames = transform_func( 364 | images=frames, 365 | target_height=crop_size, 366 | target_width=crop_size, 367 | scale=scale, 368 | ratio=aspect_ratio, 369 | ) 370 | if random_horizontal_flip: 371 | frames, _ = video_transforms.horizontal_flip(0.5, frames) 372 | else: 373 | # The testing is deterministic and no jitter should be performed. 374 | # min_scale, max_scale, and crop_size are expect to be the same. 375 | assert len({min_scale, max_scale, crop_size}) == 1 376 | frames, _ = video_transforms.random_short_side_scale_jitter( 377 | frames, min_scale, max_scale 378 | ) 379 | frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx) 380 | return frames 381 | 382 | 383 | def tensor_normalize(tensor, mean, std): 384 | """ 385 | Normalize a given tensor by subtracting the mean and dividing the std. 386 | Args: 387 | tensor (tensor): tensor to normalize. 388 | mean (tensor or list): mean value to subtract. 389 | std (tensor or list): std to divide. 390 | """ 391 | if tensor.dtype == torch.uint8: 392 | tensor = tensor.float() 393 | tensor = tensor / 255.0 394 | if type(mean) == list: 395 | mean = torch.tensor(mean) 396 | if type(std) == list: 397 | std = torch.tensor(std) 398 | tensor = tensor - mean 399 | tensor = tensor / std 400 | return tensor 401 | 402 | 403 | class VideoMAE(torch.utils.data.Dataset): 404 | """Load your own video classification dataset. 405 | Parameters 406 | ---------- 407 | root : str, required. 408 | Path to the root folder storing the dataset. 409 | setting : str, required. 410 | A text file describing the dataset, each line per video sample. 411 | There are three items in each line: (1) video path; (2) video length and (3) video label. 412 | train : bool, default True. 413 | Whether to load the training or validation set. 414 | test_mode : bool, default False. 415 | Whether to perform evaluation on the test set. 416 | Usually there is three-crop or ten-crop evaluation strategy involved. 417 | name_pattern : str, default None. 
418 | The naming pattern of the decoded video frames. 419 | For example, img_00012.jpg. 420 | video_ext : str, default 'mp4'. 421 | If video_loader is set to True, please specify the video format accordinly. 422 | is_color : bool, default True. 423 | Whether the loaded image is color or grayscale. 424 | modality : str, default 'rgb'. 425 | Input modalities, we support only rgb video frames for now. 426 | Will add support for rgb difference image and optical flow image later. 427 | num_segments : int, default 1. 428 | Number of segments to evenly divide the video into clips. 429 | A useful technique to obtain global video-level information. 430 | Limin Wang, etal, Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016. 431 | num_crop : int, default 1. 432 | Number of crops for each image. default is 1. 433 | Common choices are three crops and ten crops during evaluation. 434 | new_length : int, default 1. 435 | The length of input video clip. Default is a single image, but it can be multiple video frames. 436 | For example, new_length=16 means we will extract a video clip of consecutive 16 frames. 437 | new_step : int, default 1. 438 | Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames. 439 | new_step=2 means we will extract a video clip of every other frame. 440 | temporal_jitter : bool, default False. 441 | Whether to temporally jitter if new_step > 1. 442 | video_loader : bool, default False. 443 | Whether to use video loader to load data. 444 | use_decord : bool, default True. 445 | Whether to use Decord video loader to load data. Otherwise use mmcv video loader. 446 | transform : function, default None. 447 | A function that takes data and label and transforms them. 448 | data_aug : str, default 'v1'. 449 | Different types of data augmentation auto. Supports v1, v2, v3 and v4. 450 | lazy_init : bool, default False. 451 | If set to True, build a dataset instance without loading any dataset. 452 | """ 453 | def __init__(self, 454 | root, 455 | setting, 456 | train=True, 457 | test_mode=False, 458 | name_pattern='img_%05d.jpg', 459 | video_ext='mp4', 460 | is_color=True, 461 | modality='rgb', 462 | num_segments=1, 463 | num_crop=1, 464 | new_length=1, 465 | new_step=1, 466 | transform=None, 467 | temporal_jitter=False, 468 | video_loader=False, 469 | use_decord=False, 470 | lazy_init=False): 471 | 472 | super(VideoMAE, self).__init__() 473 | self.root = root 474 | self.setting = setting 475 | self.train = train 476 | self.test_mode = test_mode 477 | self.is_color = is_color 478 | self.modality = modality 479 | self.num_segments = num_segments 480 | self.num_crop = num_crop 481 | self.new_length = new_length 482 | self.new_step = new_step 483 | self.skip_length = self.new_length * self.new_step 484 | self.temporal_jitter = temporal_jitter 485 | self.name_pattern = name_pattern 486 | self.video_loader = video_loader 487 | self.video_ext = video_ext 488 | self.use_decord = use_decord 489 | self.transform = transform 490 | self.lazy_init = lazy_init 491 | 492 | 493 | if not self.lazy_init: 494 | self.clips = self._make_dataset(root, setting) 495 | if len(self.clips) == 0: 496 | raise(RuntimeError("Found 0 video clips in subfolders of: " + root + "\n" 497 | "Check your data directory (opt.data-dir).")) 498 | 499 | def __getitem__(self, index): 500 | 501 | directory, target = self.clips[index] 502 | if self.video_loader: 503 | if '.' 
in directory.split('/')[-1]: 504 | # data in the "setting" file already have extension, e.g., demo.mp4 505 | video_name = directory 506 | else: 507 | # data in the "setting" file do not have extension, e.g., demo 508 | # So we need to provide extension (i.e., .mp4) to complete the file name. 509 | video_name = '{}.{}'.format(directory, self.video_ext) 510 | 511 | decord_vr = decord.VideoReader(video_name, num_threads=1) 512 | duration = len(decord_vr) 513 | 514 | segment_indices, skip_offsets = self._sample_train_indices(duration) 515 | 516 | images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices, skip_offsets) 517 | 518 | process_data = self.transform((images, None)) # T*C,H,W 519 | process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0,1) # T*C,H,W -> T,C,H,W -> C,T,H,W 520 | 521 | return (process_data, target, -1, -1) 522 | 523 | def __len__(self): 524 | return len(self.clips) 525 | 526 | def _make_dataset(self, directory, setting): 527 | if not os.path.exists(setting): 528 | raise(RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting))) 529 | clips = [] 530 | with open(setting) as split_f: 531 | data = split_f.readlines() 532 | for line in data: 533 | line_info = line.split(',') 534 | # line format: video_path, video_duration, video_label 535 | if len(line_info) < 2: 536 | raise(RuntimeError('Video input format is not correct, missing one or more element. %s' % line)) 537 | clip_path = os.path.join(line_info[0]) 538 | target = int(line_info[1]) 539 | item = (clip_path, target) 540 | clips.append(item) 541 | return clips 542 | 543 | def _sample_train_indices(self, num_frames): 544 | average_duration = (num_frames - self.skip_length + 1) // self.num_segments 545 | if average_duration > 0: 546 | offsets = np.multiply(list(range(self.num_segments)), 547 | average_duration) 548 | offsets = offsets + np.random.randint(average_duration, 549 | size=self.num_segments) 550 | elif num_frames > max(self.num_segments, self.skip_length): 551 | offsets = np.sort(np.random.randint( 552 | num_frames - self.skip_length + 1, 553 | size=self.num_segments)) 554 | else: 555 | offsets = np.zeros((self.num_segments,)) 556 | 557 | if self.temporal_jitter: 558 | skip_offsets = np.random.randint( 559 | self.new_step, size=self.skip_length // self.new_step) 560 | else: 561 | skip_offsets = np.zeros( 562 | self.skip_length // self.new_step, dtype=int) 563 | return offsets + 1, skip_offsets 564 | 565 | 566 | def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets): 567 | sampled_list = [] 568 | frame_id_list = [] 569 | for seg_ind in indices: 570 | offset = int(seg_ind) 571 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)): 572 | if offset + skip_offsets[i] <= duration: 573 | frame_id = offset + skip_offsets[i] - 1 574 | else: 575 | frame_id = offset - 1 576 | frame_id_list.append(frame_id) 577 | if offset + self.new_step < duration: 578 | offset += self.new_step 579 | try: 580 | video_data = video_reader.get_batch(frame_id_list).asnumpy() 581 | sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)] 582 | except: 583 | raise RuntimeError('Error occured in reading frames {} from video {} of duration {}.'.format(frame_id_list, directory, duration)) 584 | return sampled_list 585 | 586 | 587 | class DataAugmentationForVideoMAE(object): 588 | def __init__(self, args): 589 | 
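# Fixed train-time augmentation pipeline used by build_training_dataset:
# multi-scale group crop -> stack frames -> convert to tensor -> normalize
# with ImageNet (or Inception) mean/std.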
self.input_mean = IMAGENET_INCEPTION_MEAN if args.inception else IMAGENET_DEFAULT_MEAN 590 | self.input_std = IMAGENET_INCEPTION_STD if args.inception else IMAGENET_DEFAULT_STD 591 | normalize = GroupNormalize(self.input_mean, self.input_std) 592 | self.train_augmentation = GroupMultiScaleCrop(args.input_size, [1, .875, .75, .66]) 593 | self.transform = transforms.Compose([ 594 | # RandomCrop(), 595 | self.train_augmentation, 596 | Stack(roll=False), 597 | ToTorchFormatTensor(div=True), 598 | normalize, 599 | ]) 600 | 601 | def __call__(self, images): 602 | process_data , _ = self.transform(images) 603 | return process_data 604 | 605 | def __repr__(self): 606 | repr = "(DataAugmentationForVideoMAE,\n" 607 | repr += " transform = %s,\n" % str(self.transform) 608 | repr += ")" 609 | return repr 610 | 611 | 612 | def build_training_dataset(args): 613 | transform = DataAugmentationForVideoMAE(args) 614 | dataset = VideoMAE( 615 | root=None, 616 | setting=os.path.join(args.data_path, 'train.csv'), 617 | video_ext='mp4', 618 | is_color=True, 619 | modality='rgb', 620 | new_length=args.num_frames, 621 | new_step=args.sampling_rate, 622 | transform=transform, 623 | temporal_jitter=False, 624 | video_loader=True, 625 | use_decord=True, 626 | lazy_init=False) 627 | print("Data Aug = %s" % str(transform)) 628 | return dataset 629 | -------------------------------------------------------------------------------- /datasets/video_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datasets.video_transforms import * 3 | from datasets.kinetics import VideoClsDataset, VideoMAE 4 | 5 | 6 | class DataAugmentationForVideoMAE(object): 7 | def __init__(self, args): 8 | self.input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN 9 | self.input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD 10 | normalize = GroupNormalize(self.input_mean, self.input_std) 11 | self.train_augmentation = GroupMultiScaleCrop(args.input_size, [1, .875, .75, .66]) 12 | self.transform = transforms.Compose([ 13 | self.train_augmentation, 14 | Stack(roll=False), 15 | ToTorchFormatTensor(div=True), 16 | normalize, 17 | ]) 18 | if args.linprob: 19 | self.masked_position_generator = None 20 | elif args.mask_type == 'tube': 21 | self.masked_position_generator = TubeMaskingGenerator( 22 | args.window_size, args.mask_ratio 23 | ) 24 | 25 | def __call__(self, images): 26 | process_data , _ = self.transform(images) 27 | if self.masked_position_generator is None: 28 | return process_data 29 | return process_data, self.masked_position_generator() 30 | 31 | def __repr__(self): 32 | repr = "(DataAugmentationForVideoMAE,\n" 33 | repr += " transform = %s,\n" % str(self.transform) 34 | repr += " Masked position generator = %s,\n" % str(self.masked_position_generator) 35 | repr += ")" 36 | return repr 37 | 38 | 39 | def build_pretraining_dataset(args): 40 | transform = DataAugmentationForVideoMAE(args) 41 | dataset = VideoMAE( 42 | root=None, 43 | setting=args.data_path, 44 | video_ext='mp4', 45 | is_color=True, 46 | modality='rgb', 47 | new_length=args.num_frames, 48 | new_step=args.sampling_rate, 49 | transform=transform, 50 | temporal_jitter=True, 51 | video_loader=True, 52 | use_decord=True, 53 | lazy_init=False) 54 | print("Data Aug = %s" % str(transform)) 55 | return dataset 56 | 57 | 58 | def build_dataset(is_train, test_mode, args): 59 | if args.data_set == 'Kinetics-400': 60 | mode = None 61 | anno_path = None 62 | if is_train == True: 63 | mode = 'train' 64 | anno_path = 
os.path.join(args.data_path, 'train.csv') 65 | elif test_mode == True: 66 | mode = 'test' 67 | anno_path = os.path.join(args.data_path, 'val.csv') 68 | else: 69 | mode = 'validation' 70 | anno_path = os.path.join(args.data_path, 'test.csv') 71 | 72 | dataset = VideoClsDataset( 73 | anno_path=anno_path, 74 | data_path='/', 75 | mode=mode, 76 | clip_len=args.num_frames, 77 | frame_sample_rate=args.sampling_rate, 78 | num_segment=1, 79 | test_num_segment=args.test_num_segment, 80 | test_num_crop=args.test_num_crop, 81 | num_crop=1 if not test_mode else 3, 82 | keep_aspect_ratio=True, 83 | crop_size=args.input_size, 84 | short_side_size=args.short_side_size, 85 | new_height=256, 86 | new_width=320, 87 | args=args) 88 | nb_classes = 400 89 | 90 | elif args.data_set == 'SSV2': 91 | mode = None 92 | anno_path = None 93 | if is_train == True: 94 | mode = 'train' 95 | anno_path = os.path.join(args.data_path, 'train.csv') 96 | elif test_mode == True: 97 | mode = 'test' 98 | anno_path = os.path.join(args.data_path, 'val.csv') 99 | else: 100 | mode = 'validation' 101 | anno_path = os.path.join(args.data_path, 'test.csv') 102 | 103 | dataset = VideoClsDataset( 104 | anno_path=anno_path, 105 | data_path='/', 106 | mode=mode, 107 | clip_len=args.num_frames, 108 | frame_sample_rate=args.sampling_rate, 109 | num_segment=1, 110 | test_num_segment=args.test_num_segment, 111 | test_num_crop=args.test_num_crop, 112 | num_crop=1 if not test_mode else 3, 113 | keep_aspect_ratio=True, 114 | crop_size=args.input_size, 115 | short_side_size=args.short_side_size, 116 | new_height=256, 117 | new_width=320, 118 | args=args) 119 | nb_classes = 174 120 | 121 | elif args.data_set == 'UCF101': 122 | mode = None 123 | anno_path = None 124 | if is_train == True: 125 | mode = 'train' 126 | anno_path = os.path.join(args.data_path, 'train.csv') 127 | elif test_mode == True: 128 | mode = 'test' 129 | anno_path = os.path.join(args.data_path, 'val.csv') 130 | else: 131 | mode = 'validation' 132 | anno_path = os.path.join(args.data_path, 'test.csv') 133 | 134 | dataset = VideoClsDataset( 135 | anno_path=anno_path, 136 | data_path='/', 137 | mode=mode, 138 | clip_len=args.num_frames, 139 | frame_sample_rate=args.sampling_rate, 140 | num_segment=1, 141 | test_num_segment=args.test_num_segment, 142 | test_num_crop=args.test_num_crop, 143 | num_crop=1 if not test_mode else 3, 144 | keep_aspect_ratio=True, 145 | crop_size=args.input_size, 146 | short_side_size=args.short_side_size, 147 | new_height=256, 148 | new_width=320, 149 | args=args) 150 | nb_classes = 101 151 | 152 | elif args.data_set == 'HMDB51': 153 | mode = None 154 | anno_path = None 155 | if is_train == True: 156 | mode = 'train' 157 | anno_path = os.path.join(args.data_path, 'train.csv') 158 | elif test_mode == True: 159 | mode = 'test' 160 | anno_path = os.path.join(args.data_path, 'val.csv') 161 | else: 162 | mode = 'validation' 163 | anno_path = os.path.join(args.data_path, 'test.csv') 164 | 165 | dataset = VideoClsDataset( 166 | anno_path=anno_path, 167 | data_path='/', 168 | mode=mode, 169 | clip_len=args.num_frames, 170 | frame_sample_rate=args.sampling_rate, 171 | num_segment=1, 172 | test_num_segment=args.test_num_segment, 173 | test_num_crop=args.test_num_crop, 174 | num_crop=1 if not test_mode else 3, 175 | keep_aspect_ratio=True, 176 | crop_size=args.input_size, 177 | short_side_size=args.short_side_size, 178 | new_height=256, 179 | new_width=320, 180 | args=args) 181 | nb_classes = 51 182 | else: 183 | raise NotImplementedError() 184 | assert nb_classes == 
args.nb_classes 185 | print("Number of the class = %d" % args.nb_classes) 186 | 187 | return dataset, nb_classes 188 | -------------------------------------------------------------------------------- /datasets/video_transforms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import math 3 | import numpy as np 4 | import random 5 | import torch 6 | import torchvision.transforms.functional as F 7 | from PIL import Image 8 | from torchvision import transforms 9 | 10 | # from rand_augment import rand_augment_transform 11 | # from random_erasing import RandomErasing 12 | 13 | 14 | import numbers 15 | import PIL 16 | import torchvision 17 | 18 | import datasets.functional as FF 19 | 20 | _pil_interpolation_to_str = { 21 | Image.NEAREST: "PIL.Image.NEAREST", 22 | Image.BILINEAR: "PIL.Image.BILINEAR", 23 | Image.BICUBIC: "PIL.Image.BICUBIC", 24 | Image.LANCZOS: "PIL.Image.LANCZOS", 25 | Image.HAMMING: "PIL.Image.HAMMING", 26 | Image.BOX: "PIL.Image.BOX", 27 | } 28 | 29 | 30 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 31 | 32 | 33 | def _pil_interp(method): 34 | if method == "bicubic": 35 | return Image.BICUBIC 36 | elif method == "lanczos": 37 | return Image.LANCZOS 38 | elif method == "hamming": 39 | return Image.HAMMING 40 | else: 41 | return Image.BILINEAR 42 | 43 | 44 | def random_short_side_scale_jitter( 45 | images, min_size, max_size, boxes=None, inverse_uniform_sampling=False 46 | ): 47 | """ 48 | Perform a spatial short scale jittering on the given images and 49 | corresponding boxes. 50 | Args: 51 | images (tensor): images to perform scale jitter. Dimension is 52 | `num frames` x `channel` x `height` x `width`. 53 | min_size (int): the minimal size to scale the frames. 54 | max_size (int): the maximal size to scale the frames. 55 | boxes (ndarray): optional. Corresponding boxes to images. 56 | Dimension is `num boxes` x 4. 57 | inverse_uniform_sampling (bool): if True, sample uniformly in 58 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the 59 | scale. If False, take a uniform sample from [min_scale, max_scale]. 60 | Returns: 61 | (tensor): the scaled images with dimension of 62 | `num frames` x `channel` x `new height` x `new width`. 63 | (ndarray or None): the scaled boxes with dimension of 64 | `num boxes` x 4. 65 | """ 66 | if inverse_uniform_sampling: 67 | size = int( 68 | round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size)) 69 | ) 70 | else: 71 | size = int(round(np.random.uniform(min_size, max_size))) 72 | 73 | height = images.shape[2] 74 | width = images.shape[3] 75 | if (width <= height and width == size) or ( 76 | height <= width and height == size 77 | ): 78 | return images, boxes 79 | new_width = size 80 | new_height = size 81 | if width < height: 82 | new_height = int(math.floor((float(height) / width) * size)) 83 | if boxes is not None: 84 | boxes = boxes * float(new_height) / height 85 | else: 86 | new_width = int(math.floor((float(width) / height) * size)) 87 | if boxes is not None: 88 | boxes = boxes * float(new_width) / width 89 | 90 | return ( 91 | torch.nn.functional.interpolate( 92 | images, 93 | size=(new_height, new_width), 94 | mode="bilinear", 95 | align_corners=False, 96 | ), 97 | boxes, 98 | ) 99 | 100 | 101 | def crop_boxes(boxes, x_offset, y_offset): 102 | """ 103 | Peform crop on the bounding boxes given the offsets. 104 | Args: 105 | boxes (ndarray or None): bounding boxes to peform crop. The dimension 106 | is `num boxes` x 4. 
107 | x_offset (int): cropping offset in the x axis. 108 | y_offset (int): cropping offset in the y axis. 109 | Returns: 110 | cropped_boxes (ndarray or None): the cropped boxes with dimension of 111 | `num boxes` x 4. 112 | """ 113 | cropped_boxes = boxes.copy() 114 | cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset 115 | cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset 116 | 117 | return cropped_boxes 118 | 119 | 120 | def random_crop(images, size, boxes=None): 121 | """ 122 | Perform random spatial crop on the given images and corresponding boxes. 123 | Args: 124 | images (tensor): images to perform random crop. The dimension is 125 | `num frames` x `channel` x `height` x `width`. 126 | size (int): the size of height and width to crop on the image. 127 | boxes (ndarray or None): optional. Corresponding boxes to images. 128 | Dimension is `num boxes` x 4. 129 | Returns: 130 | cropped (tensor): cropped images with dimension of 131 | `num frames` x `channel` x `size` x `size`. 132 | cropped_boxes (ndarray or None): the cropped boxes with dimension of 133 | `num boxes` x 4. 134 | """ 135 | if images.shape[2] == size and images.shape[3] == size: 136 | return images 137 | height = images.shape[2] 138 | width = images.shape[3] 139 | y_offset = 0 140 | if height > size: 141 | y_offset = int(np.random.randint(0, height - size)) 142 | x_offset = 0 143 | if width > size: 144 | x_offset = int(np.random.randint(0, width - size)) 145 | cropped = images[ 146 | :, :, y_offset : y_offset + size, x_offset : x_offset + size 147 | ] 148 | 149 | cropped_boxes = ( 150 | crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None 151 | ) 152 | 153 | return cropped, cropped_boxes 154 | 155 | 156 | def horizontal_flip(prob, images, boxes=None): 157 | """ 158 | Perform horizontal flip on the given images and corresponding boxes. 159 | Args: 160 | prob (float): probility to flip the images. 161 | images (tensor): images to perform horizontal flip, the dimension is 162 | `num frames` x `channel` x `height` x `width`. 163 | boxes (ndarray or None): optional. Corresponding boxes to images. 164 | Dimension is `num boxes` x 4. 165 | Returns: 166 | images (tensor): images with dimension of 167 | `num frames` x `channel` x `height` x `width`. 168 | flipped_boxes (ndarray or None): the flipped boxes with dimension of 169 | `num boxes` x 4. 170 | """ 171 | if boxes is None: 172 | flipped_boxes = None 173 | else: 174 | flipped_boxes = boxes.copy() 175 | 176 | if np.random.uniform() < prob: 177 | images = images.flip((-1)) 178 | 179 | if len(images.shape) == 3: 180 | width = images.shape[2] 181 | elif len(images.shape) == 4: 182 | width = images.shape[3] 183 | else: 184 | raise NotImplementedError("Dimension does not supported") 185 | if boxes is not None: 186 | flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1 187 | 188 | return images, flipped_boxes 189 | 190 | 191 | def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): 192 | """ 193 | Perform uniform spatial sampling on the images and corresponding boxes. 194 | Args: 195 | images (tensor): images to perform uniform crop. The dimension is 196 | `num frames` x `channel` x `height` x `width`. 197 | size (int): size of height and weight to crop the images. 198 | spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width 199 | is larger than height. Or 0, 1, or 2 for top, center, and bottom 200 | crop if height is larger than width. 201 | boxes (ndarray or None): optional. Corresponding boxes to images. 
202 | Dimension is `num boxes` x 4. 203 | scale_size (int): optinal. If not None, resize the images to scale_size before 204 | performing any crop. 205 | Returns: 206 | cropped (tensor): images with dimension of 207 | `num frames` x `channel` x `size` x `size`. 208 | cropped_boxes (ndarray or None): the cropped boxes with dimension of 209 | `num boxes` x 4. 210 | """ 211 | assert spatial_idx in [0, 1, 2] 212 | ndim = len(images.shape) 213 | if ndim == 3: 214 | images = images.unsqueeze(0) 215 | height = images.shape[2] 216 | width = images.shape[3] 217 | 218 | if scale_size is not None: 219 | if width <= height: 220 | width, height = scale_size, int(height / width * scale_size) 221 | else: 222 | width, height = int(width / height * scale_size), scale_size 223 | images = torch.nn.functional.interpolate( 224 | images, 225 | size=(height, width), 226 | mode="bilinear", 227 | align_corners=False, 228 | ) 229 | 230 | y_offset = int(math.ceil((height - size) / 2)) 231 | x_offset = int(math.ceil((width - size) / 2)) 232 | 233 | if height > width: 234 | if spatial_idx == 0: 235 | y_offset = 0 236 | elif spatial_idx == 2: 237 | y_offset = height - size 238 | else: 239 | if spatial_idx == 0: 240 | x_offset = 0 241 | elif spatial_idx == 2: 242 | x_offset = width - size 243 | cropped = images[ 244 | :, :, y_offset : y_offset + size, x_offset : x_offset + size 245 | ] 246 | cropped_boxes = ( 247 | crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None 248 | ) 249 | if ndim == 3: 250 | cropped = cropped.squeeze(0) 251 | return cropped, cropped_boxes 252 | 253 | 254 | def clip_boxes_to_image(boxes, height, width): 255 | """ 256 | Clip an array of boxes to an image with the given height and width. 257 | Args: 258 | boxes (ndarray): bounding boxes to perform clipping. 259 | Dimension is `num boxes` x 4. 260 | height (int): given image height. 261 | width (int): given image width. 262 | Returns: 263 | clipped_boxes (ndarray): the clipped boxes with dimension of 264 | `num boxes` x 4. 265 | """ 266 | clipped_boxes = boxes.copy() 267 | clipped_boxes[:, [0, 2]] = np.minimum( 268 | width - 1.0, np.maximum(0.0, boxes[:, [0, 2]]) 269 | ) 270 | clipped_boxes[:, [1, 3]] = np.minimum( 271 | height - 1.0, np.maximum(0.0, boxes[:, [1, 3]]) 272 | ) 273 | return clipped_boxes 274 | 275 | 276 | def blend(images1, images2, alpha): 277 | """ 278 | Blend two images with a given weight alpha. 279 | Args: 280 | images1 (tensor): the first images to be blended, the dimension is 281 | `num frames` x `channel` x `height` x `width`. 282 | images2 (tensor): the second images to be blended, the dimension is 283 | `num frames` x `channel` x `height` x `width`. 284 | alpha (float): the blending weight. 285 | Returns: 286 | (tensor): blended images, the dimension is 287 | `num frames` x `channel` x `height` x `width`. 288 | """ 289 | return images1 * alpha + images2 * (1 - alpha) 290 | 291 | 292 | def grayscale(images): 293 | """ 294 | Get the grayscale for the input images. The channels of images should be 295 | in order BGR. 296 | Args: 297 | images (tensor): the input images for getting grayscale. Dimension is 298 | `num frames` x `channel` x `height` x `width`. 299 | Returns: 300 | img_gray (tensor): blended images, the dimension is 301 | `num frames` x `channel` x `height` x `width`. 302 | """ 303 | # R -> 0.299, G -> 0.587, B -> 0.114. 
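# Channels are assumed to be in BGR order (see docstring), so index 2 is R,
# index 1 is G and index 0 is B; the weighted sum is written back to all channels.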
304 | img_gray = torch.tensor(images) 305 | gray_channel = ( 306 | 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0] 307 | ) 308 | img_gray[:, 0] = gray_channel 309 | img_gray[:, 1] = gray_channel 310 | img_gray[:, 2] = gray_channel 311 | return img_gray 312 | 313 | 314 | def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0): 315 | """ 316 | Perfrom a color jittering on the input images. The channels of images 317 | should be in order BGR. 318 | Args: 319 | images (tensor): images to perform color jitter. Dimension is 320 | `num frames` x `channel` x `height` x `width`. 321 | img_brightness (float): jitter ratio for brightness. 322 | img_contrast (float): jitter ratio for contrast. 323 | img_saturation (float): jitter ratio for saturation. 324 | Returns: 325 | images (tensor): the jittered images, the dimension is 326 | `num frames` x `channel` x `height` x `width`. 327 | """ 328 | 329 | jitter = [] 330 | if img_brightness != 0: 331 | jitter.append("brightness") 332 | if img_contrast != 0: 333 | jitter.append("contrast") 334 | if img_saturation != 0: 335 | jitter.append("saturation") 336 | 337 | if len(jitter) > 0: 338 | order = np.random.permutation(np.arange(len(jitter))) 339 | for idx in range(0, len(jitter)): 340 | if jitter[order[idx]] == "brightness": 341 | images = brightness_jitter(img_brightness, images) 342 | elif jitter[order[idx]] == "contrast": 343 | images = contrast_jitter(img_contrast, images) 344 | elif jitter[order[idx]] == "saturation": 345 | images = saturation_jitter(img_saturation, images) 346 | return images 347 | 348 | 349 | def brightness_jitter(var, images): 350 | """ 351 | Perfrom brightness jittering on the input images. The channels of images 352 | should be in order BGR. 353 | Args: 354 | var (float): jitter ratio for brightness. 355 | images (tensor): images to perform color jitter. Dimension is 356 | `num frames` x `channel` x `height` x `width`. 357 | Returns: 358 | images (tensor): the jittered images, the dimension is 359 | `num frames` x `channel` x `height` x `width`. 360 | """ 361 | alpha = 1.0 + np.random.uniform(-var, var) 362 | 363 | img_bright = torch.zeros(images.shape) 364 | images = blend(images, img_bright, alpha) 365 | return images 366 | 367 | 368 | def contrast_jitter(var, images): 369 | """ 370 | Perfrom contrast jittering on the input images. The channels of images 371 | should be in order BGR. 372 | Args: 373 | var (float): jitter ratio for contrast. 374 | images (tensor): images to perform color jitter. Dimension is 375 | `num frames` x `channel` x `height` x `width`. 376 | Returns: 377 | images (tensor): the jittered images, the dimension is 378 | `num frames` x `channel` x `height` x `width`. 379 | """ 380 | alpha = 1.0 + np.random.uniform(-var, var) 381 | 382 | img_gray = grayscale(images) 383 | img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True) 384 | images = blend(images, img_gray, alpha) 385 | return images 386 | 387 | 388 | def saturation_jitter(var, images): 389 | """ 390 | Perfrom saturation jittering on the input images. The channels of images 391 | should be in order BGR. 392 | Args: 393 | var (float): jitter ratio for saturation. 394 | images (tensor): images to perform color jitter. Dimension is 395 | `num frames` x `channel` x `height` x `width`. 396 | Returns: 397 | images (tensor): the jittered images, the dimension is 398 | `num frames` x `channel` x `height` x `width`. 
399 | """ 400 | alpha = 1.0 + np.random.uniform(-var, var) 401 | img_gray = grayscale(images) 402 | images = blend(images, img_gray, alpha) 403 | 404 | return images 405 | 406 | 407 | def lighting_jitter(images, alphastd, eigval, eigvec): 408 | """ 409 | Perform AlexNet-style PCA jitter on the given images. 410 | Args: 411 | images (tensor): images to perform lighting jitter. Dimension is 412 | `num frames` x `channel` x `height` x `width`. 413 | alphastd (float): jitter ratio for PCA jitter. 414 | eigval (list): eigenvalues for PCA jitter. 415 | eigvec (list[list]): eigenvectors for PCA jitter. 416 | Returns: 417 | out_images (tensor): the jittered images, the dimension is 418 | `num frames` x `channel` x `height` x `width`. 419 | """ 420 | if alphastd == 0: 421 | return images 422 | # generate alpha1, alpha2, alpha3. 423 | alpha = np.random.normal(0, alphastd, size=(1, 3)) 424 | eig_vec = np.array(eigvec) 425 | eig_val = np.reshape(eigval, (1, 3)) 426 | rgb = np.sum( 427 | eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0), 428 | axis=1, 429 | ) 430 | out_images = torch.zeros_like(images) 431 | if len(images.shape) == 3: 432 | # C H W 433 | channel_dim = 0 434 | elif len(images.shape) == 4: 435 | # T C H W 436 | channel_dim = 1 437 | else: 438 | raise NotImplementedError(f"Unsupported dimension {len(images.shape)}") 439 | 440 | for idx in range(images.shape[channel_dim]): 441 | # C H W 442 | if len(images.shape) == 3: 443 | out_images[idx] = images[idx] + rgb[2 - idx] 444 | # T C H W 445 | elif len(images.shape) == 4: 446 | out_images[:, idx] = images[:, idx] + rgb[2 - idx] 447 | else: 448 | raise NotImplementedError( 449 | f"Unsupported dimension {len(images.shape)}" 450 | ) 451 | 452 | return out_images 453 | 454 | 455 | def color_normalization(images, mean, stddev): 456 | """ 457 | Perform color nomration on the given images. 458 | Args: 459 | images (tensor): images to perform color normalization. Dimension is 460 | `num frames` x `channel` x `height` x `width`. 461 | mean (list): mean values for normalization. 462 | stddev (list): standard deviations for normalization. 463 | 464 | Returns: 465 | out_images (tensor): the noramlized images, the dimension is 466 | `num frames` x `channel` x `height` x `width`. 467 | """ 468 | if len(images.shape) == 3: 469 | assert ( 470 | len(mean) == images.shape[0] 471 | ), "channel mean not computed properly" 472 | assert ( 473 | len(stddev) == images.shape[0] 474 | ), "channel stddev not computed properly" 475 | elif len(images.shape) == 4: 476 | assert ( 477 | len(mean) == images.shape[1] 478 | ), "channel mean not computed properly" 479 | assert ( 480 | len(stddev) == images.shape[1] 481 | ), "channel stddev not computed properly" 482 | else: 483 | raise NotImplementedError(f"Unsupported dimension {len(images.shape)}") 484 | 485 | out_images = torch.zeros_like(images) 486 | for idx in range(len(mean)): 487 | # C H W 488 | if len(images.shape) == 3: 489 | out_images[idx] = (images[idx] - mean[idx]) / stddev[idx] 490 | elif len(images.shape) == 4: 491 | out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx] 492 | else: 493 | raise NotImplementedError( 494 | f"Unsupported dimension {len(images.shape)}" 495 | ) 496 | return out_images 497 | 498 | 499 | def _get_param_spatial_crop( 500 | scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False 501 | ): 502 | """ 503 | Given scale, ratio, height and width, return sampled coordinates of the videos. 
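Returns a tuple (i, j, h, w): the top-left corner and the height/width of the
sampled crop. Falls back to a central crop if no valid crop is found within
num_repeat attempts.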
504 | """ 505 | for _ in range(num_repeat): 506 | area = height * width 507 | target_area = random.uniform(*scale) * area 508 | if log_scale: 509 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 510 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 511 | else: 512 | aspect_ratio = random.uniform(*ratio) 513 | 514 | w = int(round(math.sqrt(target_area * aspect_ratio))) 515 | h = int(round(math.sqrt(target_area / aspect_ratio))) 516 | 517 | if np.random.uniform() < 0.5 and switch_hw: 518 | w, h = h, w 519 | 520 | if 0 < w <= width and 0 < h <= height: 521 | i = random.randint(0, height - h) 522 | j = random.randint(0, width - w) 523 | return i, j, h, w 524 | 525 | # Fallback to central crop 526 | in_ratio = float(width) / float(height) 527 | if in_ratio < min(ratio): 528 | w = width 529 | h = int(round(w / min(ratio))) 530 | elif in_ratio > max(ratio): 531 | h = height 532 | w = int(round(h * max(ratio))) 533 | else: # whole image 534 | w = width 535 | h = height 536 | i = (height - h) // 2 537 | j = (width - w) // 2 538 | return i, j, h, w 539 | 540 | 541 | def random_resized_crop( 542 | images, 543 | target_height, 544 | target_width, 545 | scale=(0.8, 1.0), 546 | ratio=(3.0 / 4.0, 4.0 / 3.0), 547 | ): 548 | """ 549 | Crop the given images to random size and aspect ratio. A crop of random 550 | size (default: of 0.08 to 1.0) of the original size and a random aspect 551 | ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This 552 | crop is finally resized to given size. This is popularly used to train the 553 | Inception networks. 554 | 555 | Args: 556 | images: Images to perform resizing and cropping. 557 | target_height: Desired height after cropping. 558 | target_width: Desired width after cropping. 559 | scale: Scale range of Inception-style area based random resizing. 560 | ratio: Aspect ratio range of Inception-style area based random resizing. 561 | """ 562 | 563 | height = images.shape[2] 564 | width = images.shape[3] 565 | 566 | i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) 567 | cropped = images[:, :, i : i + h, j : j + w] 568 | return torch.nn.functional.interpolate( 569 | cropped, 570 | size=(target_height, target_width), 571 | mode="bilinear", 572 | align_corners=False, 573 | ) 574 | 575 | 576 | def random_resized_crop_with_shift( 577 | images, 578 | target_height, 579 | target_width, 580 | scale=(0.8, 1.0), 581 | ratio=(3.0 / 4.0, 4.0 / 3.0), 582 | ): 583 | """ 584 | This is similar to random_resized_crop. However, it samples two different 585 | boxes (for cropping) for the first and last frame. It then linearly 586 | interpolates the two boxes for other frames. 587 | 588 | Args: 589 | images: Images to perform resizing and cropping. 590 | target_height: Desired height after cropping. 591 | target_width: Desired width after cropping. 592 | scale: Scale range of Inception-style area based random resizing. 593 | ratio: Aspect ratio range of Inception-style area based random resizing. 
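`_get_param_spatial_crop` implements standard Inception-style box sampling (random area within `scale`, random aspect ratio within `ratio`, with a central-crop fallback after `num_repeat` tries). `random_resized_crop` applies one sampled box to every frame, while `random_resized_crop_with_shift` (continued below) samples two boxes and linearly interpolates them across time. A hedged usage sketch on a `(C, T, H, W)` clip, under the same import-path assumption as the earlier examples:

```python
import torch

# Assumed import path; defined in datasets/video_transforms.py.
from datasets.video_transforms import (random_resized_crop,
                                       random_resized_crop_with_shift)

clip = torch.rand(3, 8, 256, 340)  # (C, T, H, W)

fixed = random_resized_crop(clip, 224, 224, scale=(0.08, 1.0))
moving = random_resized_crop_with_shift(clip, 224, 224, scale=(0.08, 1.0))
print(fixed.shape, moving.shape)  # both torch.Size([3, 8, 224, 224])
```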
594 | """ 595 | t = images.shape[1] 596 | height = images.shape[2] 597 | width = images.shape[3] 598 | 599 | i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) 600 | i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width) 601 | i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()] 602 | j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()] 603 | h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()] 604 | w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()] 605 | out = torch.zeros((3, t, target_height, target_width)) 606 | for ind in range(t): 607 | out[:, ind : ind + 1, :, :] = torch.nn.functional.interpolate( 608 | images[ 609 | :, 610 | ind : ind + 1, 611 | i_s[ind] : i_s[ind] + h_s[ind], 612 | j_s[ind] : j_s[ind] + w_s[ind], 613 | ], 614 | size=(target_height, target_width), 615 | mode="bilinear", 616 | align_corners=False, 617 | ) 618 | return out 619 | 620 | 621 | def create_random_augment( 622 | input_size, 623 | auto_augment=None, 624 | interpolation="bilinear", 625 | ): 626 | """ 627 | Get video randaug transform. 628 | 629 | Args: 630 | input_size: The size of the input video in tuple. 631 | auto_augment: Parameters for randaug. An example: 632 | "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number 633 | of operations to apply). 634 | interpolation: Interpolation method. 635 | """ 636 | if isinstance(input_size, tuple): 637 | img_size = input_size[-2:] 638 | else: 639 | img_size = input_size 640 | 641 | if auto_augment: 642 | assert isinstance(auto_augment, str) 643 | if isinstance(img_size, tuple): 644 | img_size_min = min(img_size) 645 | else: 646 | img_size_min = img_size 647 | aa_params = {"translate_const": int(img_size_min * 0.45)} 648 | if interpolation and interpolation != "random": 649 | aa_params["interpolation"] = _pil_interp(interpolation) 650 | if auto_augment.startswith("rand"): 651 | return transforms.Compose( 652 | [rand_augment_transform(auto_augment, aa_params)] 653 | ) 654 | raise NotImplementedError 655 | 656 | 657 | def random_sized_crop_img( 658 | im, 659 | size, 660 | jitter_scale=(0.08, 1.0), 661 | jitter_aspect=(3.0 / 4.0, 4.0 / 3.0), 662 | max_iter=10, 663 | ): 664 | """ 665 | Performs Inception-style cropping (used for training). 666 | """ 667 | assert ( 668 | len(im.shape) == 3 669 | ), "Currently only support image for random_sized_crop" 670 | h, w = im.shape[1:3] 671 | i, j, h, w = _get_param_spatial_crop( 672 | scale=jitter_scale, 673 | ratio=jitter_aspect, 674 | height=h, 675 | width=w, 676 | num_repeat=max_iter, 677 | log_scale=False, 678 | switch_hw=True, 679 | ) 680 | cropped = im[:, i : i + h, j : j + w] 681 | return torch.nn.functional.interpolate( 682 | cropped.unsqueeze(0), 683 | size=(size, size), 684 | mode="bilinear", 685 | align_corners=False, 686 | ).squeeze(0) 687 | 688 | 689 | # The following code are modified based on timm lib, we will replace the following 690 | # contents with dependency from PyTorchVideo. 691 | # https://github.com/facebookresearch/pytorchvideo 692 | class RandomResizedCropAndInterpolation: 693 | """Crop the given PIL Image to random size and aspect ratio with random interpolation. 694 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 695 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 696 | is finally resized to given size. 697 | This is popularly used to train the Inception networks. 
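`create_random_augment` is a thin wrapper that parses a timm-style RandAugment spec (e.g. `rand-m7-n4-mstd0.5-inc1`) into a composed transform. A hedged sketch of calling it; whether the returned transform accepts a single PIL frame or a list of frames depends on the `rand_augment_transform` implementation imported at the top of this file (not shown here), so the example assumes the single-image timm behaviour:

```python
from PIL import Image

# Assumed import path; defined in datasets/video_transforms.py.
from datasets.video_transforms import create_random_augment

aug = create_random_augment(
    input_size=(224, 224),
    auto_augment="rand-m7-n4-mstd0.5-inc1",  # magnitude 7, 4 ops, std 0.5, increasing severity
    interpolation="bicubic",
)

frame = Image.new("RGB", (224, 224))
augmented = aug(frame)  # randomly augmented PIL image (or list; see note above)
```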
698 | Args: 699 | size: expected output size of each edge 700 | scale: range of size of the origin size cropped 701 | ratio: range of aspect ratio of the origin aspect ratio cropped 702 | interpolation: Default: PIL.Image.BILINEAR 703 | """ 704 | 705 | def __init__( 706 | self, 707 | size, 708 | scale=(0.08, 1.0), 709 | ratio=(3.0 / 4.0, 4.0 / 3.0), 710 | interpolation="bilinear", 711 | ): 712 | if isinstance(size, tuple): 713 | self.size = size 714 | else: 715 | self.size = (size, size) 716 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 717 | print("range should be of kind (min, max)") 718 | 719 | if interpolation == "random": 720 | self.interpolation = _RANDOM_INTERPOLATION 721 | else: 722 | self.interpolation = _pil_interp(interpolation) 723 | self.scale = scale 724 | self.ratio = ratio 725 | 726 | @staticmethod 727 | def get_params(img, scale, ratio): 728 | """Get parameters for ``crop`` for a random sized crop. 729 | Args: 730 | img (PIL Image): Image to be cropped. 731 | scale (tuple): range of size of the origin size cropped 732 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 733 | Returns: 734 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 735 | sized crop. 736 | """ 737 | area = img.size[0] * img.size[1] 738 | 739 | for _ in range(10): 740 | target_area = random.uniform(*scale) * area 741 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 742 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 743 | 744 | w = int(round(math.sqrt(target_area * aspect_ratio))) 745 | h = int(round(math.sqrt(target_area / aspect_ratio))) 746 | 747 | if w <= img.size[0] and h <= img.size[1]: 748 | i = random.randint(0, img.size[1] - h) 749 | j = random.randint(0, img.size[0] - w) 750 | return i, j, h, w 751 | 752 | # Fallback to central crop 753 | in_ratio = img.size[0] / img.size[1] 754 | if in_ratio < min(ratio): 755 | w = img.size[0] 756 | h = int(round(w / min(ratio))) 757 | elif in_ratio > max(ratio): 758 | h = img.size[1] 759 | w = int(round(h * max(ratio))) 760 | else: # whole image 761 | w = img.size[0] 762 | h = img.size[1] 763 | i = (img.size[1] - h) // 2 764 | j = (img.size[0] - w) // 2 765 | return i, j, h, w 766 | 767 | def __call__(self, img): 768 | """ 769 | Args: 770 | img (PIL Image): Image to be cropped and resized. 771 | Returns: 772 | PIL Image: Randomly cropped and resized image. 
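`RandomResizedCropAndInterpolation.get_params` mirrors torchvision's RandomResizedCrop but lets the interpolation mode be fixed or sampled at random per call. A small hedged example on a single PIL image (older torchvision versions may warn about the integer PIL interpolation constants returned by timm's `_pil_interp`):

```python
from PIL import Image

# Assumed import path; the class is defined in datasets/video_transforms.py.
from datasets.video_transforms import RandomResizedCropAndInterpolation

rrc = RandomResizedCropAndInterpolation(size=224, scale=(0.08, 1.0), interpolation="bicubic")
img = Image.new("RGB", (340, 256))
out = rrc(img)
print(out.size)  # (224, 224)
```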
773 | """ 774 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 775 | if isinstance(self.interpolation, (tuple, list)): 776 | interpolation = random.choice(self.interpolation) 777 | else: 778 | interpolation = self.interpolation 779 | return F.resized_crop(img, i, j, h, w, self.size, interpolation) 780 | 781 | def __repr__(self): 782 | if isinstance(self.interpolation, (tuple, list)): 783 | interpolate_str = " ".join( 784 | [_pil_interpolation_to_str[x] for x in self.interpolation] 785 | ) 786 | else: 787 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 788 | format_string = self.__class__.__name__ + "(size={0}".format(self.size) 789 | format_string += ", scale={0}".format( 790 | tuple(round(s, 4) for s in self.scale) 791 | ) 792 | format_string += ", ratio={0}".format( 793 | tuple(round(r, 4) for r in self.ratio) 794 | ) 795 | format_string += ", interpolation={0})".format(interpolate_str) 796 | return format_string 797 | 798 | 799 | def transforms_imagenet_train( 800 | img_size=224, 801 | scale=None, 802 | ratio=None, 803 | hflip=0.5, 804 | vflip=0.0, 805 | color_jitter=0.4, 806 | auto_augment=None, 807 | interpolation="random", 808 | use_prefetcher=False, 809 | mean=(0.485, 0.456, 0.406), 810 | std=(0.229, 0.224, 0.225), 811 | re_prob=0.0, 812 | re_mode="const", 813 | re_count=1, 814 | re_num_splits=0, 815 | separate=False, 816 | ): 817 | """ 818 | If separate==True, the transforms are returned as a tuple of 3 separate transforms 819 | for use in a mixing dataset that passes 820 | * all data through the first (primary) transform, called the 'clean' data 821 | * a portion of the data through the secondary transform 822 | * normalizes and converts the branches above with the third, final transform 823 | """ 824 | if isinstance(img_size, tuple): 825 | img_size = img_size[-2:] 826 | else: 827 | img_size = img_size 828 | 829 | scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range 830 | ratio = tuple( 831 | ratio or (3.0 / 4.0, 4.0 / 3.0) 832 | ) # default imagenet ratio range 833 | primary_tfl = [ 834 | RandomResizedCropAndInterpolation( 835 | img_size, scale=scale, ratio=ratio, interpolation=interpolation 836 | ) 837 | ] 838 | if hflip > 0.0: 839 | primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)] 840 | if vflip > 0.0: 841 | primary_tfl += [transforms.RandomVerticalFlip(p=vflip)] 842 | 843 | secondary_tfl = [] 844 | if auto_augment: 845 | assert isinstance(auto_augment, str) 846 | if isinstance(img_size, tuple): 847 | img_size_min = min(img_size) 848 | else: 849 | img_size_min = img_size 850 | aa_params = dict( 851 | translate_const=int(img_size_min * 0.45), 852 | img_mean=tuple([min(255, round(255 * x)) for x in mean]), 853 | ) 854 | if interpolation and interpolation != "random": 855 | aa_params["interpolation"] = _pil_interp(interpolation) 856 | if auto_augment.startswith("rand"): 857 | secondary_tfl += [rand_augment_transform(auto_augment, aa_params)] 858 | elif auto_augment.startswith("augmix"): 859 | raise NotImplementedError("Augmix not implemented") 860 | else: 861 | raise NotImplementedError("Auto aug not implemented") 862 | elif color_jitter is not None: 863 | # color jitter is enabled when not using AA 864 | if isinstance(color_jitter, (list, tuple)): 865 | # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation 866 | # or 4 if also augmenting hue 867 | assert len(color_jitter) in (3, 4) 868 | else: 869 | # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue 870 | color_jitter = 
(float(color_jitter),) * 3 871 | secondary_tfl += [transforms.ColorJitter(*color_jitter)] 872 | 873 | final_tfl = [] 874 | final_tfl += [ 875 | transforms.ToTensor(), 876 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 877 | ] 878 | if re_prob > 0.0: 879 | final_tfl.append( 880 | RandomErasing( 881 | re_prob, 882 | mode=re_mode, 883 | max_count=re_count, 884 | num_splits=re_num_splits, 885 | device="cpu", 886 | cube=False, 887 | ) 888 | ) 889 | 890 | if separate: 891 | return ( 892 | transforms.Compose(primary_tfl), 893 | transforms.Compose(secondary_tfl), 894 | transforms.Compose(final_tfl), 895 | ) 896 | else: 897 | return transforms.Compose(primary_tfl + secondary_tfl + final_tfl) 898 | 899 | ############################################################################################################ 900 | ############################################################################################################ 901 | 902 | class Compose(object): 903 | """Composes several transforms 904 | Args: 905 | transforms (list of ``Transform`` objects): list of transforms 906 | to compose 907 | """ 908 | 909 | def __init__(self, transforms): 910 | self.transforms = transforms 911 | 912 | def __call__(self, clip): 913 | for t in self.transforms: 914 | clip = t(clip) 915 | return clip 916 | 917 | 918 | class RandomHorizontalFlip(object): 919 | """Horizontally flip the list of given images randomly 920 | with a probability 0.5 921 | """ 922 | 923 | def __call__(self, clip): 924 | """ 925 | Args: 926 | img (PIL.Image or numpy.ndarray): List of images to be cropped 927 | in format (h, w, c) in numpy.ndarray 928 | Returns: 929 | PIL.Image or numpy.ndarray: Randomly flipped clip 930 | """ 931 | if random.random() < 0.5: 932 | if isinstance(clip[0], np.ndarray): 933 | return [np.fliplr(img) for img in clip] 934 | elif isinstance(clip[0], PIL.Image.Image): 935 | return [ 936 | img.transpose(PIL.Image.FLIP_LEFT_RIGHT) for img in clip 937 | ] 938 | else: 939 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 940 | ' but got list of {0}'.format(type(clip[0]))) 941 | return clip 942 | 943 | 944 | class RandomResize(object): 945 | """Resizes a list of (H x W x C) numpy.ndarray to the final size 946 | The larger the original image is, the more times it takes to 947 | interpolate 948 | Args: 949 | interpolation (str): Can be one of 'nearest', 'bilinear' 950 | defaults to nearest 951 | size (tuple): (widht, height) 952 | """ 953 | 954 | def __init__(self, ratio=(3. / 4., 4. 
/ 3.), interpolation='nearest'): 955 | self.ratio = ratio 956 | self.interpolation = interpolation 957 | 958 | def __call__(self, clip): 959 | scaling_factor = random.uniform(self.ratio[0], self.ratio[1]) 960 | 961 | if isinstance(clip[0], np.ndarray): 962 | im_h, im_w, im_c = clip[0].shape 963 | elif isinstance(clip[0], PIL.Image.Image): 964 | im_w, im_h = clip[0].size 965 | 966 | new_w = int(im_w * scaling_factor) 967 | new_h = int(im_h * scaling_factor) 968 | new_size = (new_w, new_h) 969 | resized = FF.resize_clip( 970 | clip, new_size, interpolation=self.interpolation) 971 | return resized 972 | 973 | 974 | class Resize(object): 975 | """Resizes a list of (H x W x C) numpy.ndarray to the final size 976 | The larger the original image is, the more times it takes to 977 | interpolate 978 | Args: 979 | interpolation (str): Can be one of 'nearest', 'bilinear' 980 | defaults to nearest 981 | size (tuple): (widht, height) 982 | """ 983 | 984 | def __init__(self, size, interpolation='nearest'): 985 | self.size = size 986 | self.interpolation = interpolation 987 | 988 | def __call__(self, clip): 989 | resized = FF.resize_clip( 990 | clip, self.size, interpolation=self.interpolation) 991 | return resized 992 | 993 | 994 | class RandomCrop(object): 995 | """Extract random crop at the same location for a list of images 996 | Args: 997 | size (sequence or int): Desired output size for the 998 | crop in format (h, w) 999 | """ 1000 | 1001 | def __init__(self, size): 1002 | if isinstance(size, numbers.Number): 1003 | size = (size, size) 1004 | 1005 | self.size = size 1006 | 1007 | def __call__(self, clip): 1008 | """ 1009 | Args: 1010 | img (PIL.Image or numpy.ndarray): List of images to be cropped 1011 | in format (h, w, c) in numpy.ndarray 1012 | Returns: 1013 | PIL.Image or numpy.ndarray: Cropped list of images 1014 | """ 1015 | h, w = self.size 1016 | if isinstance(clip[0], np.ndarray): 1017 | im_h, im_w, im_c = clip[0].shape 1018 | elif isinstance(clip[0], PIL.Image.Image): 1019 | im_w, im_h = clip[0].size 1020 | else: 1021 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 1022 | 'but got list of {0}'.format(type(clip[0]))) 1023 | if w > im_w or h > im_h: 1024 | error_msg = ( 1025 | 'Initial image size should be larger then ' 1026 | 'cropped size but got cropped sizes : ({w}, {h}) while ' 1027 | 'initial image is ({im_w}, {im_h})'.format( 1028 | im_w=im_w, im_h=im_h, w=w, h=h)) 1029 | raise ValueError(error_msg) 1030 | 1031 | x1 = random.randint(0, im_w - w) 1032 | y1 = random.randint(0, im_h - h) 1033 | cropped = FF.crop_clip(clip, y1, x1, h, w) 1034 | 1035 | return cropped 1036 | 1037 | 1038 | class ThreeCrop(object): 1039 | """Extract random crop at the same location for a list of images 1040 | Args: 1041 | size (sequence or int): Desired output size for the 1042 | crop in format (h, w) 1043 | """ 1044 | 1045 | def __init__(self, size): 1046 | if isinstance(size, numbers.Number): 1047 | size = (size, size) 1048 | 1049 | self.size = size 1050 | 1051 | def __call__(self, clip): 1052 | """ 1053 | Args: 1054 | img (PIL.Image or numpy.ndarray): List of images to be cropped 1055 | in format (h, w, c) in numpy.ndarray 1056 | Returns: 1057 | PIL.Image or numpy.ndarray: Cropped list of images 1058 | """ 1059 | h, w = self.size 1060 | if isinstance(clip[0], np.ndarray): 1061 | im_h, im_w, im_c = clip[0].shape 1062 | elif isinstance(clip[0], PIL.Image.Image): 1063 | im_w, im_h = clip[0].size 1064 | else: 1065 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 1066 | 'but got list of 
{0}'.format(type(clip[0]))) 1067 | if w != im_w and h != im_h: 1068 | clip = FF.resize_clip(clip, self.size, interpolation="bilinear") 1069 | im_h, im_w, im_c = clip[0].shape 1070 | 1071 | step = np.max((np.max((im_w, im_h)) - self.size[0]) // 2, 0) 1072 | cropped = [] 1073 | for i in range(3): 1074 | if (im_h > self.size[0]): 1075 | x1 = 0 1076 | y1 = i * step 1077 | cropped.extend(FF.crop_clip(clip, y1, x1, h, w)) 1078 | else: 1079 | x1 = i * step 1080 | y1 = 0 1081 | cropped.extend(FF.crop_clip(clip, y1, x1, h, w)) 1082 | return cropped 1083 | 1084 | 1085 | class RandomRotation(object): 1086 | """Rotate entire clip randomly by a random angle within 1087 | given bounds 1088 | Args: 1089 | degrees (sequence or int): Range of degrees to select from 1090 | If degrees is a number instead of sequence like (min, max), 1091 | the range of degrees, will be (-degrees, +degrees). 1092 | """ 1093 | 1094 | def __init__(self, degrees): 1095 | if isinstance(degrees, numbers.Number): 1096 | if degrees < 0: 1097 | raise ValueError('If degrees is a single number,' 1098 | 'must be positive') 1099 | degrees = (-degrees, degrees) 1100 | else: 1101 | if len(degrees) != 2: 1102 | raise ValueError('If degrees is a sequence,' 1103 | 'it must be of len 2.') 1104 | 1105 | self.degrees = degrees 1106 | 1107 | def __call__(self, clip): 1108 | """ 1109 | Args: 1110 | img (PIL.Image or numpy.ndarray): List of images to be cropped 1111 | in format (h, w, c) in numpy.ndarray 1112 | Returns: 1113 | PIL.Image or numpy.ndarray: Cropped list of images 1114 | """ 1115 | import skimage 1116 | angle = random.uniform(self.degrees[0], self.degrees[1]) 1117 | if isinstance(clip[0], np.ndarray): 1118 | rotated = [skimage.transform.rotate(img, angle) for img in clip] 1119 | elif isinstance(clip[0], PIL.Image.Image): 1120 | rotated = [img.rotate(angle) for img in clip] 1121 | else: 1122 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 1123 | 'but got list of {0}'.format(type(clip[0]))) 1124 | 1125 | return rotated 1126 | 1127 | 1128 | class CenterCrop(object): 1129 | """Extract center crop at the same location for a list of images 1130 | Args: 1131 | size (sequence or int): Desired output size for the 1132 | crop in format (h, w) 1133 | """ 1134 | 1135 | def __init__(self, size): 1136 | if isinstance(size, numbers.Number): 1137 | size = (size, size) 1138 | 1139 | self.size = size 1140 | 1141 | def __call__(self, clip): 1142 | """ 1143 | Args: 1144 | img (PIL.Image or numpy.ndarray): List of images to be cropped 1145 | in format (h, w, c) in numpy.ndarray 1146 | Returns: 1147 | PIL.Image or numpy.ndarray: Cropped list of images 1148 | """ 1149 | h, w = self.size 1150 | if isinstance(clip[0], np.ndarray): 1151 | im_h, im_w, im_c = clip[0].shape 1152 | elif isinstance(clip[0], PIL.Image.Image): 1153 | im_w, im_h = clip[0].size 1154 | else: 1155 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 1156 | 'but got list of {0}'.format(type(clip[0]))) 1157 | if w > im_w or h > im_h: 1158 | error_msg = ( 1159 | 'Initial image size should be larger then ' 1160 | 'cropped size but got cropped sizes : ({w}, {h}) while ' 1161 | 'initial image is ({im_w}, {im_h})'.format( 1162 | im_w=im_w, im_h=im_h, w=w, h=h)) 1163 | raise ValueError(error_msg) 1164 | 1165 | x1 = int(round((im_w - w) / 2.)) 1166 | y1 = int(round((im_h - h) / 2.)) 1167 | cropped = FF.crop_clip(clip, y1, x1, h, w) 1168 | 1169 | return cropped 1170 | 1171 | 1172 | class ColorJitter(object): 1173 | """Randomly change the brightness, contrast and saturation and hue of 
the clip 1174 | Args: 1175 | brightness (float): How much to jitter brightness. brightness_factor 1176 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 1177 | contrast (float): How much to jitter contrast. contrast_factor 1178 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 1179 | saturation (float): How much to jitter saturation. saturation_factor 1180 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 1181 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 1182 | [-hue, hue]. Should be >=0 and <= 0.5. 1183 | """ 1184 | 1185 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 1186 | self.brightness = brightness 1187 | self.contrast = contrast 1188 | self.saturation = saturation 1189 | self.hue = hue 1190 | 1191 | def get_params(self, brightness, contrast, saturation, hue): 1192 | if brightness > 0: 1193 | brightness_factor = random.uniform( 1194 | max(0, 1 - brightness), 1 + brightness) 1195 | else: 1196 | brightness_factor = None 1197 | 1198 | if contrast > 0: 1199 | contrast_factor = random.uniform( 1200 | max(0, 1 - contrast), 1 + contrast) 1201 | else: 1202 | contrast_factor = None 1203 | 1204 | if saturation > 0: 1205 | saturation_factor = random.uniform( 1206 | max(0, 1 - saturation), 1 + saturation) 1207 | else: 1208 | saturation_factor = None 1209 | 1210 | if hue > 0: 1211 | hue_factor = random.uniform(-hue, hue) 1212 | else: 1213 | hue_factor = None 1214 | return brightness_factor, contrast_factor, saturation_factor, hue_factor 1215 | 1216 | def __call__(self, clip): 1217 | """ 1218 | Args: 1219 | clip (list): list of PIL.Image 1220 | Returns: 1221 | list PIL.Image : list of transformed PIL.Image 1222 | """ 1223 | if isinstance(clip[0], np.ndarray): 1224 | raise TypeError( 1225 | 'Color jitter not yet implemented for numpy arrays') 1226 | elif isinstance(clip[0], PIL.Image.Image): 1227 | brightness, contrast, saturation, hue = self.get_params( 1228 | self.brightness, self.contrast, self.saturation, self.hue) 1229 | 1230 | # Create img transform function sequence 1231 | img_transforms = [] 1232 | if brightness is not None: 1233 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) 1234 | if saturation is not None: 1235 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) 1236 | if hue is not None: 1237 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) 1238 | if contrast is not None: 1239 | img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) 1240 | random.shuffle(img_transforms) 1241 | 1242 | # Apply to all images 1243 | jittered_clip = [] 1244 | for img in clip: 1245 | for func in img_transforms: 1246 | jittered_img = func(img) 1247 | jittered_clip.append(jittered_img) 1248 | 1249 | else: 1250 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 1251 | 'but got list of {0}'.format(type(clip[0]))) 1252 | return jittered_clip 1253 | 1254 | 1255 | class Normalize(object): 1256 | """Normalize a clip with mean and standard deviation. 1257 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform 1258 | will normalize each channel of the input ``torch.*Tensor`` i.e. 1259 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 1260 | .. note:: 1261 | This transform acts out of place, i.e., it does not mutates the input tensor. 
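The class-based transforms in this second half of the file (`Compose`, `Resize`, `RandomCrop`, `CenterCrop`, `ColorJitter`, `Normalize`, ...) operate on a whole clip, i.e. a list of PIL images or numpy arrays, delegating the pixel work to the `FF` functional module imported at the top of the file. A hedged sketch of a typical validation-style pipeline that combines them with `ClipToTensor` from `datasets/volume_transforms.py` (shown later in this listing), assuming `FF` provides the usual `resize_clip`/`crop_clip`/`normalize` helpers:

```python
from PIL import Image

# Assumed import paths within the repository.
from datasets import video_transforms, volume_transforms

data_transform = video_transforms.Compose([
    video_transforms.Resize(256, interpolation="bilinear"),
    video_transforms.CenterCrop(size=(224, 224)),
    volume_transforms.ClipToTensor(),                 # list of frames -> (C, T, H, W) in [0, 1]
    video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225]),
])

clip = [Image.new("RGB", (320, 240)) for _ in range(8)]
tensor_clip = data_transform(clip)  # float tensor of shape (3, 8, 224, 224)
```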
1262 | Args: 1263 | mean (sequence): Sequence of means for each channel. 1264 | std (sequence): Sequence of standard deviations for each channel. 1265 | """ 1266 | 1267 | def __init__(self, mean, std): 1268 | self.mean = mean 1269 | self.std = std 1270 | 1271 | def __call__(self, clip): 1272 | """ 1273 | Args: 1274 | clip (Tensor): Tensor clip of size (T, C, H, W) to be normalized. 1275 | Returns: 1276 | Tensor: Normalized Tensor clip. 1277 | """ 1278 | return FF.normalize(clip, self.mean, self.std) 1279 | 1280 | def __repr__(self): 1281 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 1282 | 1283 | 1284 | class RandomResizedCrop(transforms.RandomResizedCrop): 1285 | """ 1286 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 1287 | This may lead to results different with torchvision's version. 1288 | Following BYOL's TF code: 1289 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 1290 | """ 1291 | @staticmethod 1292 | def get_params(img, scale, ratio): 1293 | width, height = F._get_image_size(img) 1294 | area = height * width 1295 | 1296 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 1297 | log_ratio = torch.log(torch.tensor(ratio)) 1298 | aspect_ratio = torch.exp( 1299 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 1300 | ).item() 1301 | 1302 | w = int(round(math.sqrt(target_area * aspect_ratio))) 1303 | h = int(round(math.sqrt(target_area / aspect_ratio))) 1304 | 1305 | w = min(w, width) 1306 | h = min(h, height) 1307 | 1308 | i = torch.randint(0, height - h + 1, size=(1,)).item() 1309 | j = torch.randint(0, width - w + 1, size=(1,)).item() 1310 | 1311 | return i, j, h, w 1312 | -------------------------------------------------------------------------------- /datasets/videomae_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | import warnings 4 | import math 5 | import random 6 | import numpy as np 7 | import torchvision 8 | from PIL import Image, ImageOps 9 | import numbers 10 | 11 | 12 | class GroupRandomCrop(object): 13 | def __init__(self, size): 14 | if isinstance(size, numbers.Number): 15 | self.size = (int(size), int(size)) 16 | else: 17 | self.size = size 18 | 19 | def __call__(self, img_tuple): 20 | img_group, label = img_tuple 21 | 22 | w, h = img_group[0].size 23 | th, tw = self.size 24 | 25 | out_images = list() 26 | 27 | x1 = random.randint(0, w - tw) 28 | y1 = random.randint(0, h - th) 29 | 30 | for img in img_group: 31 | assert(img.size[0] == w and img.size[1] == h) 32 | if w == tw and h == th: 33 | out_images.append(img) 34 | else: 35 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 36 | 37 | return (out_images, label) 38 | 39 | 40 | class GroupCenterCrop(object): 41 | def __init__(self, size): 42 | self.worker = torchvision.transforms.CenterCrop(size) 43 | 44 | def __call__(self, img_tuple): 45 | img_group, label = img_tuple 46 | return ([self.worker(img) for img in img_group], label) 47 | 48 | 49 | class GroupNormalize(object): 50 | def __init__(self, mean, std): 51 | self.mean = mean 52 | self.std = std 53 | 54 | def __call__(self, tensor_tuple): 55 | tensor, label = tensor_tuple 56 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 57 | rep_std = self.std * (tensor.size()[0]//len(self.std)) 58 | 59 | # TODO: make efficient 60 | for t, m, s in zip(tensor, rep_mean, rep_std): 61 | t.sub_(m).div_(s) 62 | 63 
| return (tensor,label) 64 | 65 | 66 | class GroupGrayScale(object): 67 | def __init__(self, size): 68 | self.worker = torchvision.transforms.Grayscale(size) 69 | 70 | def __call__(self, img_tuple): 71 | img_group, label = img_tuple 72 | return ([self.worker(img) for img in img_group], label) 73 | 74 | 75 | class GroupScale(object): 76 | """ Rescales the input PIL.Image to the given 'size'. 77 | 'size' will be the size of the smaller edge. 78 | For example, if height > width, then image will be 79 | rescaled to (size * height / width, size) 80 | size: size of the smaller edge 81 | interpolation: Default: PIL.Image.BILINEAR 82 | """ 83 | 84 | def __init__(self, size, interpolation=Image.BILINEAR): 85 | self.worker = torchvision.transforms.Resize(size, interpolation) 86 | 87 | def __call__(self, img_tuple): 88 | img_group, label = img_tuple 89 | return ([self.worker(img) for img in img_group], label) 90 | 91 | 92 | class GroupMultiScaleCrop(object): 93 | 94 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 95 | self.scales = scales if scales is not None else [1, 875, .75, .66] 96 | self.max_distort = max_distort 97 | self.fix_crop = fix_crop 98 | self.more_fix_crop = more_fix_crop 99 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 100 | self.interpolation = Image.BILINEAR 101 | 102 | def __call__(self, img_tuple): 103 | img_group, label = img_tuple 104 | 105 | im_size = img_group[0].size # (454, 256) 106 | 107 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 108 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 109 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group] 110 | return (ret_img_group, label) 111 | 112 | def _sample_crop_size(self, im_size): 113 | image_w, image_h = im_size[0], im_size[1] 114 | 115 | # find a crop size 116 | base_size = min(image_w, image_h) 117 | crop_sizes = [int(base_size * x) for x in self.scales] 118 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 119 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 120 | 121 | pairs = [] 122 | for i, h in enumerate(crop_h): 123 | for j, w in enumerate(crop_w): 124 | if abs(i - j) <= self.max_distort: 125 | pairs.append((w, h)) 126 | 127 | crop_pair = random.choice(pairs) 128 | if not self.fix_crop: 129 | w_offset = random.randint(0, image_w - crop_pair[0]) 130 | h_offset = random.randint(0, image_h - crop_pair[1]) 131 | else: 132 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 133 | 134 | return crop_pair[0], crop_pair[1], w_offset, h_offset 135 | 136 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 137 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 138 | return random.choice(offsets) 139 | 140 | @staticmethod 141 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 142 | w_step = (image_w - crop_w) // 4 143 | h_step = (image_h - crop_h) // 4 144 | 145 | ret = list() 146 | ret.append((0, 0)) # upper left 147 | ret.append((4 * w_step, 0)) # upper right 148 | ret.append((0, 4 * h_step)) # lower left 149 | ret.append((4 * w_step, 4 * h_step)) # lower right 150 | ret.append((2 * w_step, 2 * h_step)) # center 151 | 152 | if more_fix_crop: 153 | ret.append((0, 2 * h_step)) # center 
left 154 | ret.append((4 * w_step, 2 * h_step)) # center right 155 | ret.append((2 * w_step, 4 * h_step)) # lower center 156 | ret.append((2 * w_step, 0 * h_step)) # upper center 157 | 158 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 159 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 160 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 161 | ret.append((3 * w_step, 3 * h_step)) # lower righ quarter 162 | return ret 163 | 164 | 165 | class Stack(object): 166 | 167 | def __init__(self, roll=False): 168 | self.roll = roll 169 | 170 | def __call__(self, img_tuple): 171 | img_group, label = img_tuple 172 | 173 | if img_group[0].mode == 'L': 174 | return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label) 175 | elif img_group[0].mode == 'RGB': 176 | if self.roll: 177 | return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label) 178 | else: 179 | return (np.concatenate(img_group, axis=2), label) 180 | 181 | 182 | class ToTorchFormatTensor(object): 183 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 184 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 185 | def __init__(self, div=True): 186 | self.div = div 187 | 188 | def __call__(self, pic_tuple): 189 | pic, label = pic_tuple 190 | 191 | if isinstance(pic, np.ndarray): 192 | # handle numpy array 193 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 194 | else: 195 | # handle PIL Image 196 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 197 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 198 | # put it from HWC to CHW format 199 | # yikes, this transpose takes 80% of the loading time/CPU 200 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 201 | return (img.float().div(255.) if self.div else img.float(), label) 202 | 203 | 204 | class IdentityTransform(object): 205 | 206 | def __call__(self, data): 207 | return data 208 | 209 | 210 | class RandomCrop(object): 211 | def __init__(self, min_size=256, max_size=320, crop_size=224, inverse_uniform_sampling=True): 212 | self.min_size = min_size 213 | self.max_size = max_size 214 | self.crop_size = crop_size 215 | self.inverse_uniform_sampling = inverse_uniform_sampling 216 | 217 | def __call__(self, img_tuple): 218 | """ 219 | Perform a spatial short scale jittering on the given images and 220 | corresponding boxes. 221 | Args: 222 | images (tensor): images to perform scale jitter. Dimension is 223 | `num frames` x `channel` x `height` x `width`. 224 | min_size (int): the minimal size to scale the frames. 225 | max_size (int): the maximal size to scale the frames. 226 | inverse_uniform_sampling (bool): if True, sample uniformly in 227 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the 228 | scale. If False, take a uniform sample from [min_scale, max_scale]. 229 | Returns: 230 | (tensor): the scaled images with dimension of 231 | `num frames` x `channel` x `new height` x `new width`. 232 | (ndarray or None): the scaled boxes with dimension of 233 | `num boxes` x 4. 
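Unlike the clip-level classes above, every transform in `datasets/videomae_transforms.py` consumes and returns an `(img_group, label)` tuple, so they can be chained with a plain `torchvision.transforms.Compose`. Note that the default `scales=[1, 875, .75, .66]` in `GroupMultiScaleCrop` looks like a typo for `.875`, so the hedged sketch below passes the scales explicitly; import paths are assumed as before:

```python
import torchvision
from PIL import Image

# Assumed import path within the repository.
from datasets import videomae_transforms as VT

train_aug = torchvision.transforms.Compose([
    # Pass scales explicitly; the in-file default `[1, 875, ...]` appears to be a typo for `.875`.
    VT.GroupMultiScaleCrop(input_size=224, scales=[1, .875, .75, .66]),
    VT.Stack(roll=False),              # list of RGB frames -> H x W x (3*T) ndarray
    VT.ToTorchFormatTensor(div=True),  # -> (3*T, H, W) float tensor in [0, 1]
    VT.GroupNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

frames = [Image.new("RGB", (340, 256)) for _ in range(8)]
tensor, label = train_aug((frames, 0))
print(tensor.shape)  # torch.Size([24, 224, 224])
```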
234 | """ 235 | images, label = img_tuple 236 | if self.inverse_uniform_sampling: 237 | size = int( 238 | round(1.0 / np.random.uniform(1.0 / self.max_size, 1.0 / self.min_size)) 239 | ) 240 | else: 241 | size = int(round(np.random.uniform(self.min_size, self.max_size))) 242 | 243 | # height = images.shape[2] 244 | # width = images.shape[3] 245 | width, height = images[0].size # first w, then h 246 | 247 | if (width <= height and width == size) or ( 248 | height <= width and height == size 249 | ): 250 | return images 251 | new_width = size 252 | new_height = size 253 | if width < height: 254 | new_height = int(math.floor((float(height) / width) * size)) 255 | else: 256 | new_width = int(math.floor((float(width) / height) * size)) 257 | 258 | resized_images = [img.resize((new_width, new_height), Image.BILINEAR) for img in images] 259 | 260 | # crop 261 | images = resized_images 262 | size = self.crop_size 263 | 264 | """ 265 | Perform random spatial crop on the given images and corresponding boxes. 266 | Args: 267 | images (tensor): images to perform random crop. The dimension is 268 | `num frames` x `channel` x `height` x `width`. 269 | size (int): the size of height and width to crop on the image. 270 | boxes (ndarray or None): optional. Corresponding boxes to images. 271 | Dimension is `num boxes` x 4. 272 | Returns: 273 | cropped (tensor): cropped images with dimension of 274 | `num frames` x `channel` x `size` x `size`. 275 | cropped_boxes (ndarray or None): the cropped boxes with dimension of 276 | `num boxes` x 4. 277 | """ 278 | width, height = images[0].size # first w, then h 279 | if width == size and height == size: 280 | return images 281 | y_offset = 0 282 | if height > size: 283 | y_offset = int(np.random.randint(0, height - size)) 284 | x_offset = 0 285 | if width > size: 286 | x_offset = int(np.random.randint(0, width - size)) 287 | 288 | cropped_images = [img.crop((x_offset, y_offset, x_offset + size, y_offset + size)) for img in images] 289 | 290 | return cropped_images, label -------------------------------------------------------------------------------- /datasets/volume_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | 5 | 6 | def convert_img(img): 7 | """Converts (H, W, C) numpy.ndarray to (C, W, H) format 8 | """ 9 | if len(img.shape) == 3: 10 | img = img.transpose(2, 0, 1) 11 | if len(img.shape) == 2: 12 | img = np.expand_dims(img, 0) 13 | return img 14 | 15 | 16 | class ClipToTensor(object): 17 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 18 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 19 | """ 20 | 21 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 22 | self.channel_nb = channel_nb 23 | self.div_255 = div_255 24 | self.numpy = numpy 25 | 26 | def __call__(self, clip): 27 | """ 28 | Args: clip (list of numpy.ndarray): clip (list of images) 29 | to be converted to tensor. 
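The `RandomCrop` at the end of `videomae_transforms.py` combines short-side scale jittering (optionally sampled inverse-uniformly) with a random spatial crop, again on `(list_of_PIL_frames, label)` tuples. A hedged usage sketch:

```python
from PIL import Image

# Assumed import path within the repository.
from datasets.videomae_transforms import RandomCrop

crop = RandomCrop(min_size=256, max_size=320, crop_size=224)
frames = [Image.new("RGB", (454, 240)) for _ in range(8)]  # (W, H) = (454, 240)
cropped, label = crop((frames, 0))
print(cropped[0].size)  # (224, 224)
```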
30 | """ 31 | # Retrieve shape 32 | if isinstance(clip[0], np.ndarray): 33 | h, w, ch = clip[0].shape 34 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( 35 | ch) 36 | elif isinstance(clip[0], Image.Image): 37 | w, h = clip[0].size 38 | else: 39 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 40 | but got list of {0}'.format(type(clip[0]))) 41 | 42 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 43 | 44 | # Convert 45 | for img_idx, img in enumerate(clip): 46 | if isinstance(img, np.ndarray): 47 | pass 48 | elif isinstance(img, Image.Image): 49 | img = np.array(img, copy=False) 50 | else: 51 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 52 | but got list of {0}'.format(type(clip[0]))) 53 | img = convert_img(img) 54 | np_clip[:, img_idx, :, :] = img 55 | if self.numpy: 56 | if self.div_255: 57 | np_clip = np_clip / 255.0 58 | return np_clip 59 | 60 | else: 61 | tensor_clip = torch.from_numpy(np_clip) 62 | 63 | if not isinstance(tensor_clip, torch.FloatTensor): 64 | tensor_clip = tensor_clip.float() 65 | if self.div_255: 66 | tensor_clip = torch.div(tensor_clip, 255) 67 | return tensor_clip 68 | 69 | 70 | # Note this norms data to -1/1 71 | class ClipToTensor_K(object): 72 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 73 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 74 | """ 75 | 76 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 77 | self.channel_nb = channel_nb 78 | self.div_255 = div_255 79 | self.numpy = numpy 80 | 81 | def __call__(self, clip): 82 | """ 83 | Args: clip (list of numpy.ndarray): clip (list of images) 84 | to be converted to tensor. 85 | """ 86 | # Retrieve shape 87 | if isinstance(clip[0], np.ndarray): 88 | h, w, ch = clip[0].shape 89 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( 90 | ch) 91 | elif isinstance(clip[0], Image.Image): 92 | w, h = clip[0].size 93 | else: 94 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 95 | but got list of {0}'.format(type(clip[0]))) 96 | 97 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 98 | 99 | # Convert 100 | for img_idx, img in enumerate(clip): 101 | if isinstance(img, np.ndarray): 102 | pass 103 | elif isinstance(img, Image.Image): 104 | img = np.array(img, copy=False) 105 | else: 106 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 107 | but got list of {0}'.format(type(clip[0]))) 108 | img = convert_img(img) 109 | np_clip[:, img_idx, :, :] = img 110 | if self.numpy: 111 | if self.div_255: 112 | np_clip = (np_clip - 127.5) / 127.5 113 | return np_clip 114 | 115 | else: 116 | tensor_clip = torch.from_numpy(np_clip) 117 | 118 | if not isinstance(tensor_clip, torch.FloatTensor): 119 | tensor_clip = tensor_clip.float() 120 | if self.div_255: 121 | tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) 122 | return tensor_clip 123 | 124 | 125 | class ToTensor(object): 126 | """Converts numpy array to tensor 127 | """ 128 | 129 | def __call__(self, array): 130 | tensor = torch.from_numpy(array) 131 | return tensor 132 | -------------------------------------------------------------------------------- /engine_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
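`ClipToTensor` and `ClipToTensor_K` both stack a list of `H x W x C` frames into a `(C, T, H, W)` float tensor, but they scale differently: the former divides by 255 to land in `[0, 1]`, while the latter maps values through `(x - 127.5) / 127.5` to roughly `[-1, 1]` (its docstring is copied from the first class and still says `[0, 1.0]`). A short sketch:

```python
import numpy as np

# Assumed import path within the repository.
from datasets.volume_transforms import ClipToTensor, ClipToTensor_K

clip = [np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8) for _ in range(16)]

to_unit = ClipToTensor()    # -> values in [0, 1]
to_sym = ClipToTensor_K()   # -> values in [-1, 1]

print(to_unit(clip).shape)        # torch.Size([3, 16, 224, 224])
print(float(to_sym(clip).min()))  # >= -1.0
```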
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import os 13 | import numpy as np 14 | import math 15 | import sys 16 | from typing import Iterable, Optional 17 | 18 | import torch 19 | 20 | from timm.data import Mixup 21 | from timm.utils import accuracy 22 | 23 | import util.misc as misc 24 | import util.lr_sched as lr_sched 25 | 26 | 27 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 28 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 29 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 30 | mixup_fn: Optional[Mixup] = None, log_writer=None, 31 | args=None): 32 | model.train(True) 33 | metric_logger = misc.MetricLogger(delimiter=" ") 34 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 35 | header = 'Epoch: [{}]'.format(epoch) 36 | print_freq = 500 37 | 38 | accum_iter = args.accum_iter 39 | 40 | optimizer.zero_grad() 41 | 42 | if log_writer is not None: 43 | print('log_dir: {}'.format(log_writer.log_dir)) 44 | 45 | for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 46 | samples, targets = batch[0], batch[1] 47 | 48 | # we use a per iteration (instead of per epoch) lr scheduler 49 | if data_iter_step % accum_iter == 0: 50 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 51 | 52 | samples = samples.to(device, non_blocking=True) 53 | targets = targets.to(device, non_blocking=True) 54 | 55 | if mixup_fn is not None: 56 | samples, targets = mixup_fn(samples, targets) 57 | 58 | with torch.cuda.amp.autocast(): 59 | outputs = model(samples) 60 | loss = criterion(outputs, targets) 61 | 62 | loss_value = loss.item() 63 | 64 | if not math.isfinite(loss_value): 65 | print("Loss is {}, stopping training".format(loss_value)) 66 | sys.exit(1) 67 | 68 | loss /= accum_iter 69 | loss_scaler(loss, optimizer, clip_grad=max_norm, 70 | parameters=model.parameters(), create_graph=False, 71 | update_grad=(data_iter_step + 1) % accum_iter == 0) 72 | if (data_iter_step + 1) % accum_iter == 0: 73 | optimizer.zero_grad() 74 | 75 | torch.cuda.synchronize() 76 | 77 | metric_logger.update(loss=loss_value) 78 | min_lr = 10. 79 | max_lr = 0. 80 | for group in optimizer.param_groups: 81 | min_lr = min(min_lr, group["lr"]) 82 | max_lr = max(max_lr, group["lr"]) 83 | 84 | metric_logger.update(lr=max_lr) 85 | 86 | loss_value_reduce = misc.all_reduce_mean(loss_value) 87 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 88 | """ We use epoch_1000x as the x-axis in tensorboard. 89 | This calibrates different curves when batch size changes. 
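`train_one_epoch` above adjusts the learning rate every iteration (using a fractional epoch value) and only steps the optimizer every `accum_iter` iterations, routing the backward pass and gradient clipping through the AMP-aware `loss_scaler`; the effective batch size is therefore `batch_size * accum_iter * num_gpus`. A minimal, self-contained sketch of the same gradient-accumulation pattern with a toy model (no AMP or loss scaler, which the real loop adds on top):

```python
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
accum_iter = 4

optimizer.zero_grad()
for step in range(16):
    x = torch.randn(8, 10)
    y = torch.randint(0, 2, (8,))
    loss = criterion(model(x), y) / accum_iter  # scale so accumulated gradients average correctly
    loss.backward()
    if (step + 1) % accum_iter == 0:            # update weights only every accum_iter steps
        optimizer.step()
        optimizer.zero_grad()
```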
90 | """ 91 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 92 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) 93 | log_writer.add_scalar('lr', max_lr, epoch_1000x) 94 | 95 | # gather the stats from all processes 96 | metric_logger.synchronize_between_processes() 97 | print("Averaged stats:", metric_logger) 98 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 99 | 100 | 101 | @torch.no_grad() 102 | def evaluate(data_loader, model, device): 103 | criterion = torch.nn.CrossEntropyLoss() 104 | 105 | metric_logger = misc.MetricLogger(delimiter=" ") 106 | header = 'Test:' 107 | 108 | # switch to evaluation mode 109 | model.eval() 110 | for batch in metric_logger.log_every(data_loader, 100, header): 111 | images = batch[0] 112 | target = batch[1] # TODO: check why default use -1 113 | images = images.to(device, non_blocking=True) 114 | target = target.to(device, non_blocking=True) 115 | #images = images.permute(0,2,1,3,4).contiguous() 116 | # compute output 117 | with torch.cuda.amp.autocast(): 118 | output = model(images) 119 | loss = criterion(output, target) 120 | 121 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 122 | 123 | batch_size = images.shape[0] 124 | metric_logger.update(loss=loss.item()) 125 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 126 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 127 | 128 | # gather the stats from all processes 129 | metric_logger.synchronize_between_processes() 130 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 131 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 132 | 133 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 134 | 135 | 136 | @torch.no_grad() 137 | def final_test(data_loader, model, device, file): 138 | criterion = torch.nn.CrossEntropyLoss() 139 | 140 | metric_logger = misc.MetricLogger(delimiter=" ") 141 | header = 'Final_Test:' 142 | 143 | # switch to evaluation mode 144 | model.eval() 145 | final_result = [] 146 | 147 | for batch in metric_logger.log_every(data_loader, 100, header): 148 | images = batch[0] 149 | target = batch[1] 150 | ids = batch[2] 151 | chunk_nb = batch[3] 152 | split_nb = batch[4] 153 | images = images.to(device, non_blocking=True) 154 | target = target.to(device, non_blocking=True) 155 | 156 | # compute output 157 | with torch.cuda.amp.autocast(): 158 | output = model(images) 159 | loss = criterion(output, target) 160 | 161 | for i in range(output.size(0)): 162 | string = "{} {} {} {} {}\n".format( 163 | ids[i], str(output.data[i].cpu().numpy().tolist()), str(int(target[i].cpu().numpy())), 164 | str(int(chunk_nb[i].cpu().numpy())), str(int(split_nb[i].cpu().numpy())) 165 | ) 166 | final_result.append(string) 167 | 168 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 169 | 170 | batch_size = images.shape[0] 171 | metric_logger.update(loss=loss.item()) 172 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 173 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 174 | 175 | if not os.path.exists(file): 176 | os.mknod(file) 177 | with open(file, 'w') as f: 178 | f.write("{}, {}\n".format(acc1, acc5)) 179 | for line in final_result: 180 | f.write(line) 181 | 182 | # gather the stats from all processes 183 | metric_logger.synchronize_between_processes() 184 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 185 | 
.format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 186 | 187 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 188 | 189 | 190 | def merge(eval_path, num_tasks, is_hmdb=False): 191 | dict_feats = {} 192 | dict_label = {} 193 | dict_pos = {} 194 | print("Reading individual output files") 195 | 196 | for x in range(num_tasks): 197 | file = os.path.join(eval_path, str(x) + '.txt') 198 | lines = open(file, 'r').readlines()[1:] 199 | for line in lines: 200 | line = line.strip() 201 | name = line.split('[')[0] 202 | label = line.split(']')[1].split(' ')[1] 203 | chunk_nb = line.split(']')[1].split(' ')[2] 204 | split_nb = line.split(']')[1].split(' ')[3] 205 | data = np.fromstring(line.split('[')[1].split(']')[0], dtype=np.float, sep=',') 206 | if not name in dict_feats: 207 | dict_feats[name] = [] 208 | dict_label[name] = 0 209 | dict_pos[name] = [] 210 | if chunk_nb + split_nb in dict_pos[name]: 211 | continue 212 | dict_feats[name].append(data) 213 | dict_pos[name].append(chunk_nb + split_nb) 214 | dict_label[name] = label 215 | print("Computing final results") 216 | 217 | input_lst = [] 218 | print(len(dict_feats)) 219 | for i, item in enumerate(dict_feats): 220 | input_lst.append([i, item, dict_feats[item], dict_label[item]]) 221 | from multiprocessing import Pool 222 | p = Pool(64) 223 | ans = p.map(compute_video_hmdb if is_hmdb else compute_video, input_lst) 224 | top1 = [x[1] for x in ans] 225 | top5 = [x[2] for x in ans] 226 | pred = [x[0] for x in ans] 227 | label = [x[3] for x in ans] 228 | final_top1 ,final_top5 = np.mean(top1), np.mean(top5) 229 | 230 | return final_top1*100 ,final_top5*100 231 | 232 | 233 | def compute_video(lst): 234 | i, video_id, data, label = lst 235 | feat = [x for x in data] 236 | feat = np.mean(feat, axis=0) 237 | pred = np.argmax(feat) 238 | top1 = (int(pred) == int(label)) * 1.0 239 | top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0 240 | return [pred, top1, top5, int(label)] 241 | 242 | 243 | def compute_video_hmdb(lst): 244 | i, video_id, data, label = lst 245 | feat = [x for x in data] 246 | feat = np.mean(feat, axis=0) 247 | try: 248 | pred = np.argmax(feat) 249 | top1 = (int(pred) == int(label)) * 1.0 250 | top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0 251 | except: 252 | pred = 0 253 | top1 = 1.0 254 | top5 = 1.0 255 | label = 0 256 | return [pred, top1, top5, int(label)] 257 | -------------------------------------------------------------------------------- /figs/petls_patt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bruceyo/V-PETL/661198dd3810d23b36808368930aba69e635c34c/figs/petls_patt.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # References: 3 | # A Unified Framework for Parameter-Efficient Transfer Learning:https://github.com/jxhe/unify-parameter-efficient-tuning 4 | # Video Swin Transformer: https://github.com/SwinTransformer/Video-Swin-Transformer 5 | # AdaptFormer: https://github.com/ShoufaChen/AdaptFormer 6 | # -------------------------------------------------------- 7 | 8 | import argparse 9 | import datetime 10 | import json 11 | import numpy as np 12 | import os 13 | import time 14 | from pathlib import Path 15 | from collections import OrderedDict 16 | from easydict import EasyDict 17 | import torch 18 
| import torch.backends.cudnn as cudnn 19 | from torch.utils.tensorboard import SummaryWriter 20 | from datasets.video_datasets import build_dataset 21 | from datasets.kinetics import build_training_dataset 22 | 23 | # assert timm.__version__ == "0.3.2" # version check 24 | from timm.models.layers import trunc_normal_ 25 | from timm.models import create_model 26 | import util.misc as misc 27 | from util.pos_embed import interpolate_pos_embed_ori as interpolate_pos_embed 28 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 29 | 30 | from engine_finetune import train_one_epoch, evaluate 31 | from engine_finetune import merge, final_test 32 | import models 33 | from models.video_swin_transformer_patt import * 34 | local_rank_ = int(os.environ["LOCAL_RANK"]) 35 | 36 | def construct_optimizer(model, args): 37 | # Batchnorm parameters. 38 | bn_params = [] 39 | # Non-batchnorm parameters. 40 | non_bn_parameters = [] 41 | for name, p in model.named_parameters(): 42 | if p.requires_grad: 43 | if "bn" in name: 44 | bn_params.append(p) 45 | else: 46 | non_bn_parameters.append(p) 47 | optim_params = [ 48 | {"params": bn_params, "weight_decay": 0.}, 49 | {"params": non_bn_parameters, "weight_decay": args.weight_decay}, 50 | ] 51 | return torch.optim.SGD( 52 | optim_params, 53 | lr=args.lr, weight_decay=args.weight_decay, momentum=0.9, 54 | ) 55 | 56 | 57 | def get_args_parser(): 58 | parser = argparse.ArgumentParser('AdaptFormer fine-tuning for action recognition', add_help=False) 59 | parser.add_argument('--batch_size', default=512, type=int, 60 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 61 | parser.add_argument('--epochs', default=90, type=int) 62 | parser.add_argument('--accum_iter', default=1, type=int, 63 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 64 | 65 | # Model parameters 66 | parser.add_argument('--model', default='vit_base_patch16_224', type=str, metavar='MODEL', 67 | help='Name of model to train') 68 | # Optimizer parameters 69 | parser.add_argument('--weight_decay', type=float, default=0, 70 | help='weight decay (default: 0 for linear probe following MoCo v1)') 71 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 72 | help='learning rate (absolute lr)') 73 | parser.add_argument('--blr', type=float, default=0.1, metavar='LR', 74 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 75 | 76 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 77 | help='lower lr bound for cyclic schedulers that hit 0') 78 | 79 | parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N', 80 | help='epochs to warmup LR') 81 | 82 | # * Finetuning params 83 | parser.add_argument('--finetune', default='', 84 | help='finetune from checkpoint') 85 | parser.add_argument('--global_pool', action='store_true') 86 | parser.set_defaults(global_pool=False) 87 | parser.add_argument('--cls_token', action='store_false', dest='global_pool', 88 | help='Use class token instead of global pool for classification') 89 | 90 | # Dataset parameters 91 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 92 | help='dataset path') 93 | parser.add_argument('--nb_classes', default=174, type=int, 94 | help='number of the classification types') 95 | 96 | parser.add_argument('--output_dir', default='./output_dir', 97 | help='path where to save, empty for no saving') 98 | 
parser.add_argument('--log_dir', default=None, 99 | help='path where to tensorboard log') 100 | parser.add_argument('--device', default='cuda', 101 | help='device to use for training / testing') 102 | parser.add_argument('--seed', default=0, type=int) 103 | parser.add_argument('--resume', default='', 104 | help='resume from checkpoint') 105 | 106 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 107 | help='start epoch') 108 | parser.add_argument('--eval', action='store_true', 109 | help='Perform evaluation only') 110 | parser.add_argument('--dist_eval', action='store_true', default=False, 111 | help='Enabling distributed evaluation (recommended during training for faster monitor') 112 | parser.add_argument('--num_workers', default=10, type=int) 113 | parser.add_argument('--pin_mem', action='store_true', 114 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 115 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 116 | parser.set_defaults(pin_mem=True) 117 | 118 | # distributed training parameters 119 | parser.add_argument('--world_size', default=1, type=int, 120 | help='number of distributed processes') 121 | parser.add_argument('--local_rank', default=-1, type=int) 122 | parser.add_argument('--dist_on_itp', action='store_true') 123 | parser.add_argument('--dist_url', default='env://', 124 | help='url used to set up distributed training') 125 | 126 | # custom parameters 127 | parser.add_argument('--linprob', default=True) 128 | parser.add_argument('--tubelet_size', type=int, default=2) 129 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 130 | help='Dropout rate (default: 0.)') 131 | parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT', 132 | help='Attention dropout rate (default: 0.)') 133 | parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT', 134 | help='No drop path for linear probe') 135 | parser.add_argument('--use_mean_pooling', default=True) 136 | parser.add_argument('--init_scale', default=0.001, type=float) 137 | 138 | # video data parameters 139 | parser.add_argument('--data_set', default='SSV2', 140 | choices=['SSV2', 'HMDB51', 'image_folder'], 141 | type=str, help='dataset') 142 | parser.add_argument('--num_segments', type=int, default=1) 143 | parser.add_argument('--num_frames', type=int, default=8) 144 | parser.add_argument('--sampling_rate', type=int, default=4) 145 | parser.add_argument('--num_sample', type=int, default=1, 146 | help='Repeated_aug (default: 1)') 147 | parser.add_argument('--crop_pct', type=float, default=None) 148 | parser.add_argument('--short_side_size', type=int, default=224) 149 | parser.add_argument('--test_num_segment', type=int, default=4) 150 | parser.add_argument('--test_num_crop', type=int, default=3) 151 | parser.add_argument('--input_size', default=224, type=int, help='videos input size') 152 | 153 | # AdaptFormer related parameters 154 | parser.add_argument('--ffn_adapt', default=False, action='store_true', help='whether activate AdaptFormer') 155 | parser.add_argument('--ffn_num', default=64, type=int, help='bottleneck middle dimension') 156 | parser.add_argument('--vpt', default=False, action='store_true', help='whether activate VPT') 157 | parser.add_argument('--vpt_num', default=1, type=int, help='number of VPT prompts') 158 | parser.add_argument('--fulltune', default=False, action='store_true', help='full finetune model') 159 | parser.add_argument('--inception', default=False, 
action='store_true', help='whether to use Inception mean and std ' 160 | '(for the Jx-provided IN-21K pretrained model)') 161 | 162 | parser.add_argument('--att_prefix_mode', default='patt_kv', 163 | choices=['patt_kv','patt_qv','patt_qk','patt_qkv', 164 | 'prefix_kv', 'prefix_qk'], type=str, help='attention prefix mode') 165 | parser.add_argument('--att_prefix_scale', default=0.2, type=float) 166 | parser.add_argument('--att_prefix', default=False, action='store_true', help='whether to activate the attention prefix (prefix/parallel attention tuning)') 167 | parser.add_argument('--att_preseqlen', default=20, type=int, help='prefix sequence length') 168 | parser.add_argument('--att_mid_dim', default=64, type=int, help='bottleneck middle dimension') 169 | parser.add_argument('--tuned_backbone_layer_fc', '--tbl_fc', default='False', type=str, help='whether to also tune the backbone classification (fc) layer (True/False)') 170 |
 171 | return parser 172 | 173 | 174 | def main(args): 175 | if args.log_dir is None: 176 | args.log_dir = args.output_dir 177 | misc.init_distributed_mode(args) 178 | 179 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 180 | print("{}".format(args).replace(', ', ',\n')) 181 | 182 | device = torch.device(args.device) 183 | 184 | # fix the seed for reproducibility 185 | seed = args.seed + misc.get_rank() 186 | torch.manual_seed(seed) 187 | np.random.seed(seed) 188 | 189 | cudnn.benchmark = True 190 |
 191 | # dataset_train, args.nb_classes = build_dataset(is_train=True, test_mode=False, args=args) 192 | if args.data_set == 'SSV2': 193 | args.nb_classes = 174 194 | elif args.data_set == 'HMDB51': 195 | args.nb_classes = 51 196 | else: 197 | raise ValueError(args.data_set) 198 | dataset_train = build_training_dataset(args) 199 | dataset_val, _ = build_dataset(is_train=False, test_mode=False, args=args) 200 | dataset_test, _ = build_dataset(is_train=False, test_mode=True, args=args) 201 | 202 | if True: # args.distributed: 203 | num_tasks = misc.get_world_size() 204 | global_rank = misc.get_rank() 205 | sampler_train = torch.utils.data.DistributedSampler( 206 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 207 | ) 208 | print("Sampler_train = %s" % str(sampler_train)) 209 | 210 | if args.dist_eval: 211 | if len(dataset_val) % num_tasks != 0: 212 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
' 213 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 214 | 'equal num of samples per-process.') 215 | sampler_val = torch.utils.data.DistributedSampler( 216 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) # shuffle=True to reduce monitor bias 217 | sampler_test = torch.utils.data.DistributedSampler( 218 | dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=False) 219 | else: 220 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 221 | 222 | if global_rank == 0 and args.log_dir is not None and not args.eval: 223 | os.makedirs(args.log_dir, exist_ok=True) 224 | log_writer = SummaryWriter(log_dir=args.log_dir) 225 | else: 226 | log_writer = None 227 | 228 | data_loader_train = torch.utils.data.DataLoader( 229 | dataset_train, sampler=sampler_train, 230 | batch_size=args.batch_size, 231 | num_workers=args.num_workers, 232 | pin_memory=args.pin_mem, 233 | drop_last=True, 234 | ) 235 | 236 | data_loader_val = torch.utils.data.DataLoader( 237 | dataset_val, sampler=sampler_val, 238 | batch_size=args.batch_size, 239 | num_workers=args.num_workers, 240 | pin_memory=args.pin_mem, 241 | drop_last=False 242 | ) 243 | 244 | data_loader_test = torch.utils.data.DataLoader( 245 | dataset_test, sampler=sampler_test, 246 | batch_size=args.batch_size, 247 | num_workers=args.num_workers, 248 | pin_memory=args.pin_mem, 249 | drop_last=False 250 | ) 251 | 252 | # fine-tuning configs 253 | tuning_config = EasyDict( 254 | # AdaptFormer 255 | ffn_adapt=args.ffn_adapt, 256 | ffn_option="parallel", 257 | ffn_adapter_layernorm_option="none", 258 | ffn_adapter_init_option="lora", 259 | ffn_adapter_scalar="0.1", 260 | ffn_num=args.ffn_num, 261 | d_model= 768, 262 | # VPT related 263 | vpt_on=args.vpt, 264 | vpt_num=args.vpt_num, 265 | # Prefix 266 | att_prefix_mode=args.att_prefix_mode, 267 | att_prefix_scale=args.att_prefix_scale, 268 | att_prefix=args.att_prefix, 269 | att_preseqlen=args.att_preseqlen, 270 | att_mid_dim=args.att_mid_dim, 271 | ) 272 | 273 | if args.model.startswith('swin_'): 274 | model = SwinTransformer3D(embed_dim=128, 275 | depths=[2, 2, 18, 2], 276 | num_heads=[4, 8, 16, 32], 277 | patch_size=(2,4,4), 278 | window_size=(8,7,7), 279 | drop_path_rate=0.4, 280 | patch_norm=True, 281 | tuning_config=tuning_config,) 282 | checkpoint = torch.load(args.finetune) 283 | new_state_dict = OrderedDict() 284 | for k, v in checkpoint['state_dict'].items(): 285 | if 'backbone' in k: 286 | name = k[9:] 287 | new_state_dict[name] = v 288 | 289 | msg = model.load_state_dict(new_state_dict, strict=False) 290 | model.fc = nn.Sequential(nn.Conv3d(1024, args.nb_classes, kernel_size=1, stride=1, bias=True),) 291 | 292 | else: 293 | model = create_model( 294 | args.model, 295 | pretrained=False, 296 | num_classes=args.nb_classes, 297 | all_frames=args.num_frames * args.num_segments, 298 | tubelet_size=args.tubelet_size, 299 | drop_rate=args.drop, 300 | drop_path_rate=args.drop_path, 301 | attn_drop_rate=args.attn_drop_rate, 302 | drop_block_rate=None, 303 | use_mean_pooling=args.use_mean_pooling, 304 | init_scale=args.init_scale, 305 | tuning_config=tuning_config, 306 | ) 307 | patch_size = model.patch_embed.patch_size 308 | print("Patch size = %s" % str(patch_size)) 309 | args.window_size = (args.num_frames // 2, args.input_size // patch_size[0], args.input_size // patch_size[1]) 310 | args.patch_size = patch_size 311 | 312 | if args.finetune and not args.eval and not args.model.startswith('s3d') and not 
args.model.startswith('swin'): 313 | checkpoint = torch.load(args.finetune, map_location='cpu') 314 | 315 | print("Load pre-trained checkpoint from: %s" % args.finetune) 316 | if 'model' in checkpoint: 317 | raw_checkpoint_model = checkpoint['model'] 318 | elif 'module' in checkpoint: 319 | raw_checkpoint_model = checkpoint['module'] 320 | else: 321 | raw_checkpoint_model = checkpoint 322 | 323 | if os.path.basename(args.finetune).startswith('pretrain'): 324 | checkpoint_model = OrderedDict() 325 | for k, v in raw_checkpoint_model.items(): 326 | if k.startswith('encoder.'): 327 | checkpoint_model[k[8:]] = v # remove 'encoder.' prefix 328 | del checkpoint_model['norm.weight'] 329 | del checkpoint_model['norm.bias'] 330 | elif os.path.basename(args.finetune).startswith('finetune'): 331 | checkpoint_model = raw_checkpoint_model 332 | elif os.path.basename(args.finetune).startswith('videomae'): 333 | checkpoint_model = raw_checkpoint_model 334 | elif os.path.basename(args.finetune) == "vit_base_patch16_224_in21k_tongzhan_new.pth": 335 | checkpoint_model = raw_checkpoint_model 336 | del checkpoint_model['norm.weight'] 337 | del checkpoint_model['norm.bias'] 338 | elif os.path.basename(args.finetune).startswith('swin_base_patch244'): 339 | checkpoint_model = OrderedDict() 340 | for k, v in raw_checkpoint_model['state_dict'].items(): 341 | if k.startswith('backbone.'): 342 | checkpoint_model[k[9:]] = v 343 | else: 344 | raise ValueError("Warning: Double Check!") 345 | 346 | # load pre-trained model 347 | msg = model.load_state_dict(checkpoint_model, strict=False) 348 | # manually initialize fc layer: following MoCo v3 349 | trunc_normal_(model.head.weight, std=0.01) 350 | 351 | if not args.resume or True: 352 | for name, p in model.named_parameters(): 353 | if name in msg.missing_keys: 354 | p.requires_grad = True 355 | else: 356 | p.requires_grad = False if not args.fulltune else True 357 | 358 | if args.tuned_backbone_layer_fc.lower() == 'false': 359 | for name2, params in model.fc.named_parameters(): 360 | params.requires_grad = False 361 | else: 362 | for name2, params in model.fc.named_parameters(): 363 | params.requires_grad = True 364 | 365 | model.to(device) 366 | 367 | model_without_ddp = model 368 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 369 | n_parameters_freeze = sum(p.numel() for p in model.parameters() if not p.requires_grad) 370 | 371 | print('number of params tune (M): %.5f' % (n_parameters / 1.e6)) 372 | print('number of params freeze (M): %.5f' % (n_parameters_freeze / 1.e6)) 373 | 374 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 375 | 376 | if args.lr is None: # only base_lr is specified 377 | args.lr = args.blr * eff_batch_size / 256 378 | 379 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 380 | print("actual lr: %.2e" % args.lr) 381 | 382 | print("accumulate grad iterations: %d" % args.accum_iter) 383 | print("effective batch size: %d" % eff_batch_size) 384 | 385 | if args.distributed: 386 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 387 | model_without_ddp = model.module 388 | optimizer = construct_optimizer(model_without_ddp, args) 389 | print(optimizer) 390 | loss_scaler = NativeScaler() 391 | 392 | criterion = torch.nn.CrossEntropyLoss() 393 | 394 | print("criterion = %s" % str(criterion)) 395 | 396 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 397 | 398 | if args.eval: 399 | preds_file = 
os.path.join(args.output_dir, str(global_rank) + '.txt') 400 | test_stats = evaluate(data_loader_val, model, device) 401 | print(f'Max accuracy: {test_stats["acc1"]:.2f}%') 402 | test_stats = final_test(data_loader_test, model, device, preds_file) 403 | torch.distributed.barrier() 404 | if global_rank == 0: 405 | print("Start merging results...") 406 | final_top1, final_top5 = merge(args.output_dir, num_tasks, is_hmdb=args.data_set == 'HMDB51') 407 | print(f"Accuracy of the network on the {len(dataset_test)} test videos: Top-1: {final_top1:.2f}%, Top-5: {final_top5:.2f}%") 408 | log_stats = {'Final top-1': final_top1, 'Final Top-5': final_top5} 409 | if args.output_dir and misc.is_main_process(): 410 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 411 | f.write(json.dumps(log_stats) + "\n") 412 | exit(0) 413 | 414 | print(f"Start training for {args.epochs} epochs") 415 | start_time = time.time() 416 | max_accuracy = 0.0 417 | for epoch in range(args.start_epoch, args.epochs): 418 | if args.distributed: 419 | data_loader_train.sampler.set_epoch(epoch) 420 | train_stats = train_one_epoch( 421 | model, criterion, data_loader_train, 422 | optimizer, device, epoch, loss_scaler, 423 | max_norm=None, 424 | log_writer=log_writer, 425 | args=args 426 | ) 427 | 428 | test_stats = evaluate(data_loader_val, model, device) 429 | 430 | print(f"Accuracy of the network on the {len(data_loader_test)} test images: {test_stats['acc1']:.2f}%") 431 | 432 | if test_stats["acc1"] > max_accuracy: 433 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 434 | print(f'Max accuracy: {max_accuracy:.2f}%') 435 | model_best = model 436 | if args.output_dir: 437 | misc.save_model( 438 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 439 | loss_scaler=loss_scaler, epoch=epoch) 440 | 441 | misc.is_main_process() 442 | if log_writer is not None: 443 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch) 444 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch) 445 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch) 446 | 447 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 448 | **{f'test_{k}': v for k, v in test_stats.items()}, 449 | 'epoch': epoch, 450 | 'n_parameters': n_parameters} 451 | 452 | if args.output_dir and misc.is_main_process(): 453 | if log_writer is not None: 454 | log_writer.flush() 455 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 456 | f.write(json.dumps(log_stats) + "\n") 457 | 458 | test_stats = evaluate(data_loader_val, model_best, device) 459 | preds_file = os.path.join(args.output_dir, str(global_rank) + '.txt') 460 | test_stats = final_test(data_loader_test, model_best, device, preds_file) 461 | torch.distributed.barrier() 462 | if global_rank == 0: 463 | print("Start merging results...") 464 | final_top1, final_top5 = merge(args.output_dir, num_tasks, is_hmdb=args.data_set == 'HMDB51') 465 | print(f"Accuracy of the network on the {len(dataset_test)} test videos: Top-1: {final_top1:.2f}%, Top-5: {final_top5:.2f}%") 466 | log_stats = {'Final top-1': final_top1, 'Final Top-5': final_top5} 467 | if args.output_dir and misc.is_main_process(): 468 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 469 | f.write(json.dumps(log_stats) + "\n") 470 | 471 | total_time = time.time() - start_time 472 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 473 | print('Training 
time {}'.format(total_time_str)) 474 | 475 | 476 | if __name__ == '__main__': 477 | args = get_args_parser() 478 | args = args.parse_args() 479 | args.local_rank = local_rank_ 480 | if args.output_dir: 481 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 482 | main(args) 483 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import models.video_swin_transformer_patt 2 | -------------------------------------------------------------------------------- /models/logger.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import logging 11 | import functools 12 | from termcolor import colored 13 | 14 | 15 | @functools.lru_cache() 16 | def create_logger(output_dir, dist_rank=0, name=''): 17 | # create logger 18 | logger = logging.getLogger(name) 19 | logger.setLevel(logging.DEBUG) 20 | logger.propagate = False 21 | 22 | # create formatter 23 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 24 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 25 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 26 | 27 | # create console handlers for master process 28 | if dist_rank == 0: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setLevel(logging.DEBUG) 31 | console_handler.setFormatter( 32 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 33 | logger.addHandler(console_handler) 34 | 35 | # create file handlers 36 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') 37 | file_handler.setLevel(logging.DEBUG) 38 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 39 | logger.addHandler(file_handler) 40 | 41 | return logger 42 | -------------------------------------------------------------------------------- /models/video_swin_transformer_patt.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Credit to the official implementation: https://github.com/SwinTransformer/Video-Swin-Transformer 3 | ''' 4 | 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.checkpoint as checkpoint 10 | import numpy as np 11 | from timm.models.layers import DropPath, trunc_normal_ 12 | 13 | from functools import reduce, lru_cache 14 | from operator import mul 15 | from einops import rearrange 16 | 17 | import logging 18 | from mmcv.utils import get_logger 19 | from mmcv.runner import load_checkpoint 20 | 21 | import math 22 | import torch 23 | import torch.nn as nn 24 | 25 | class Adapter(nn.Module): 26 | def __init__(self, 27 | config=None, 28 | d_model=None, 29 | bottleneck=None, 30 | dropout=0.0, 31 | init_option="bert", 32 | adapter_scalar="1.0", 33 | adapter_layernorm_option="in"): 34 | super().__init__() 35 | self.n_embd = config.d_model if d_model is None else d_model 36 | self.down_size = config.attn_bn if bottleneck is None else bottleneck 37 | 38 | #_before 39 | self.adapter_layernorm_option = adapter_layernorm_option 40 | 41 | 
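        # Descriptive note (editorial): this module is the standard bottleneck
        # adapter used by AdaptFormer/LoRA-style PETL methods,
        #     output = residual + scale * up_proj(ReLU(down_proj(x))),
        # optionally wrapped in a LayerNorm before ("in") or after ("out") the
        # bottleneck, with down_proj: n_embd -> down_size and
        # up_proj: down_size -> n_embd, so the (B, N, n_embd) input shape is
        # preserved. With init_option="lora" the up-projection starts at zero,
        # making the adapter an identity mapping at the start of fine-tuning.
        # Because these weights are absent from the pre-trained checkpoint,
        # they appear in load_state_dict(..., strict=False).missing_keys and
        # are exactly the parameters that main.py leaves trainable.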
self.adapter_layer_norm_before = None 42 | if adapter_layernorm_option == "in" or adapter_layernorm_option == "out": 43 | self.adapter_layer_norm_before = nn.LayerNorm(self.n_embd) 44 | 45 | if adapter_scalar == "learnable_scalar": 46 | self.scale = nn.Parameter(torch.ones(1)) 47 | else: 48 | self.scale = float(adapter_scalar) 49 | 50 | self.down_proj = nn.Linear(self.n_embd, self.down_size) 51 | self.non_linear_func = nn.ReLU() 52 | self.up_proj = nn.Linear(self.down_size, self.n_embd) 53 | 54 | self.dropout = dropout 55 | if init_option == "bert": 56 | raise NotImplementedError 57 | elif init_option == "lora": 58 | with torch.no_grad(): 59 | nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5)) 60 | nn.init.zeros_(self.up_proj.weight) 61 | nn.init.zeros_(self.down_proj.bias) 62 | nn.init.zeros_(self.up_proj.bias) 63 | 64 | def forward(self, x, add_residual=True, residual=None): 65 | residual = x if residual is None else residual 66 | if self.adapter_layernorm_option == 'in': 67 | x = self.adapter_layer_norm_before(x) 68 | 69 | down = self.down_proj(x) 70 | down = self.non_linear_func(down) 71 | down = nn.functional.dropout(down, p=self.dropout, training=self.training) 72 | up = self.up_proj(down) 73 | 74 | up = up * self.scale 75 | 76 | if self.adapter_layernorm_option == 'out': 77 | up = self.adapter_layer_norm_before(up) 78 | 79 | if add_residual: 80 | output = up + residual 81 | else: 82 | output = up 83 | 84 | return output 85 | 86 | def get_root_logger(log_file=None, log_level=logging.INFO): 87 | """Use ``get_logger`` method in mmcv to get the root logger. 88 | The logger will be initialized if it has not been initialized. By default a 89 | StreamHandler will be added. If ``log_file`` is specified, a FileHandler 90 | will also be added. The name of the root logger is the top-level package 91 | name, e.g., "mmaction". 92 | Args: 93 | log_file (str | None): The log filename. If specified, a FileHandler 94 | will be added to the root logger. 95 | log_level (int): The root logger level. Note that only the process of 96 | rank 0 is affected, while other processes will set the level to 97 | "Error" and be silent most of the time. 98 | Returns: 99 | :obj:`logging.Logger`: The root logger. 
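    Example (illustrative)::

        logger = get_root_logger()
        logger.info('building the Video Swin backbone')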
100 | """ 101 | return get_logger(__name__.split('.')[0], log_file, log_level) 102 | 103 | 104 | class Mlp(nn.Module): 105 | """ Multilayer perceptron.""" 106 | 107 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., config=None): 108 | super().__init__() 109 | out_features = out_features or in_features 110 | hidden_features = hidden_features or in_features 111 | self.fc1 = nn.Linear(in_features, hidden_features) 112 | self.act = act_layer() 113 | self.fc2 = nn.Linear(hidden_features, out_features) 114 | self.drop = nn.Dropout(drop) 115 | 116 | # implementation of AdapterMLP 117 | self.config = config 118 | if config.ffn_adapt: 119 | self.adaptmlp = Adapter(self.config, d_model=in_features, dropout=drop, bottleneck=config.ffn_num, 120 | init_option=config.ffn_adapter_init_option, 121 | adapter_scalar=config.ffn_adapter_scalar, 122 | adapter_layernorm_option=config.ffn_adapter_layernorm_option 123 | ) 124 | 125 | def forward(self, x): 126 | if self.config.ffn_adapt and self.config.ffn_option == 'parallel': 127 | adapt_x = self.adaptmlp(x, add_residual=False) 128 | 129 | x = self.fc1(x) 130 | x = self.act(x) 131 | x = self.drop(x) 132 | x = self.fc2(x) 133 | x = self.drop(x) 134 | 135 | if self.config.ffn_adapt: 136 | if self.config.ffn_option == 'sequential': 137 | x = self.adaptmlp(x) 138 | elif self.config.ffn_option == 'parallel': 139 | x = x + adapt_x 140 | else: 141 | raise ValueError(self.config.ffn_adapt) 142 | 143 | return x 144 | 145 | def window_partition(x, window_size): 146 | """ 147 | Args: 148 | x: (B, D, H, W, C) 149 | window_size (tuple[int]): window size 150 | Returns: 151 | windows: (B*num_windows, window_size*window_size, C) 152 | """ 153 | B, D, H, W, C = x.shape # torch.Size([2, 4, 56, 56, 128]), (4,7,7) 154 | x = x.view(B, D // window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2], window_size[2], C) 155 | windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, reduce(mul, window_size), C) 156 | return windows 157 | 158 | 159 | def window_reverse(windows, window_size, B, D, H, W): 160 | """ 161 | Args: 162 | windows: (B*num_windows, window_size, window_size, C) 163 | window_size (tuple[int]): Window size 164 | H (int): Height of image 165 | W (int): Width of image 166 | Returns: 167 | x: (B, D, H, W, C) 168 | """ 169 | x = windows.view(B, D // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1], window_size[2], -1) 170 | x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1) 171 | return x 172 | 173 | 174 | def get_window_size(x_size, window_size, shift_size=None): 175 | use_window_size = list(window_size) 176 | if shift_size is not None: 177 | use_shift_size = list(shift_size) 178 | for i in range(len(x_size)): 179 | if x_size[i] <= window_size[i]: 180 | use_window_size[i] = x_size[i] 181 | if shift_size is not None: 182 | use_shift_size[i] = 0 183 | 184 | if shift_size is None: 185 | return tuple(use_window_size) 186 | else: 187 | return tuple(use_window_size), tuple(use_shift_size) 188 | 189 | 190 | class WindowAttention3D(nn.Module): 191 | """ Window based multi-head self attention (W-MSA) module with relative position bias. 192 | It supports both of shifted and non-shifted window. 193 | Args: 194 | dim (int): Number of input channels. 195 | window_size (tuple[int]): The temporal length, height and width of the window. 196 | num_heads (int): Number of attention heads. 
197 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 198 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 199 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 200 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 201 | """ 202 | def __init__(self, dim, window_size, num_heads, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., config=None): 203 | super().__init__() 204 | self.dim = dim 205 | self.window_size = window_size # Wd, Wh, Ww 206 | self.num_heads = num_heads 207 | head_dim = dim // num_heads 208 | self.scale = qk_scale or head_dim ** -0.5 209 | self.config = config 210 | 211 | # define a parameter table of relative position bias 212 | self.relative_position_bias_table = nn.Parameter( 213 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH 214 | 215 | # get pair-wise relative position index for each token inside the window 216 | coords_d = torch.arange(self.window_size[0]) 217 | coords_h = torch.arange(self.window_size[1]) 218 | coords_w = torch.arange(self.window_size[2]) 219 | coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww 220 | coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww 221 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww 222 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 223 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 224 | relative_coords[:, :, 1] += self.window_size[1] - 1 225 | relative_coords[:, :, 2] += self.window_size[2] - 1 226 | 227 | relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) 228 | relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1) 229 | relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww 230 | self.register_buffer("relative_position_index", relative_position_index) 231 | 232 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 233 | self.attn_drop = nn.Dropout(attn_drop) 234 | self.proj = nn.Linear(dim, dim) 235 | self.proj_drop = nn.Dropout(proj_drop) 236 | 237 | trunc_normal_(self.relative_position_bias_table, std=.02) 238 | self.softmax = nn.Softmax(dim=-1) 239 | 240 | if self.config.att_prefix: 241 | 242 | n_embd = self.get_runtim_n_embd(128, 196, dim) 243 | if self.config.att_prefix_mode in ['patt_kv','patt_qv','patt_qk']: 244 | self.config.att_preseqlen=dim 245 | self.qkv_adapter = nn.Sequential( 246 | nn.Linear(dim, self.config.att_mid_dim, bias=qkv_bias), 247 | nn.Tanh(), 248 | nn.Linear(self.config.att_mid_dim, 2 * dim, bias=qkv_bias)) 249 | elif self.config.att_prefix_mode in ['patt_qkv']: 250 | self.config.att_preseqlen=dim 251 | self.qkv_adapter = nn.Sequential( 252 | nn.Linear(dim, self.config.att_mid_dim, bias=qkv_bias), 253 | nn.Tanh(), 254 | nn.Linear(self.config.att_mid_dim, 3 * dim, bias=qkv_bias)) 255 | elif self.config.att_prefix_mode in ['prefix_kv','prefix_qk']: 256 | self.input_tokens = torch.arange(self.config.att_preseqlen).long() 257 | self.wte = nn.Embedding(self.config.att_preseqlen, n_embd) 258 | self.control_trans = nn.Sequential( 259 | nn.Linear(n_embd, self.config.att_mid_dim), 260 | nn.Tanh(), 261 | nn.Linear(self.config.att_mid_dim, 2 * n_embd)) 262 | else: 263 | raise Exception("Sorry, no attention prefix_mode found") 264 | 265 | self.get_prompt = 
self.get_prompt_p5 266 | 267 | def get_runtim_n_embd(self, B, N, C): 268 | B_ = B 269 | x = torch.randn(B, N, C) 270 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 271 | q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C 272 | if self.config.att_prefix_mode in ['prefix_kv']: 273 | return int(self.num_heads * v.shape[-1]) 274 | else: 275 | return int(self.num_heads * v.shape[-2]) 276 | 277 | def get_prompt_p5(self, control_code=None, gpt2=None, bsz=None, device=None, match_n_embd=32): 278 | 279 | n_head = self.num_heads 280 | 281 | match_n_layer = 1 282 | match_n_head = n_head 283 | if self.config.att_prefix_mode in ['patt_kv','prefix_qk']: 284 | match_n_embd = match_n_embd // n_head 285 | 286 | input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(device) 287 | 288 | temp_control = self.wte(input_tokens) 289 | past_key_values = self.control_trans(temp_control) #bsz, seqlen, layer*emb 290 | bsz, seqlen, _ = past_key_values.shape 291 | past_key_values = past_key_values.view(bsz, seqlen, match_n_layer * 2, match_n_head, 292 | match_n_embd) 293 | 294 | past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(match_n_layer) 295 | return past_key_values 296 | 297 | def forward(self, x, past_layer=None,mask=None): 298 | """ Forward function. 299 | Args: 300 | x: input features with shape of (num_windows*B, N, C) 301 | mask: (0/-inf) mask with shape of (num_windows, N, N) or None 302 | """ 303 | B_, N, C = x.shape 304 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 305 | q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C 306 | if self.config.att_prefix: 307 | if self.config.att_prefix_mode == 'patt_kv': 308 | qkv_adapter = self.qkv_adapter(x).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 309 | k_, v_ = qkv_adapter[0], qkv_adapter[1] 310 | k = k + self.config.att_prefix_scale * k_ 311 | v = v + self.config.att_prefix_scale * v_ 312 | elif self.config.att_prefix_mode == 'patt_qv': 313 | qkv_adapter = self.qkv_adapter(x).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 314 | q_, v_ = qkv_adapter[0], qkv_adapter[1] 315 | q = q + self.config.att_prefix_scale * q_ 316 | v = v + self.config.att_prefix_scale * v_ 317 | elif self.config.att_prefix_mode == 'patt_qk': 318 | qkv_adapter = self.qkv_adapter(x).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 319 | q_, k_ = qkv_adapter[0], qkv_adapter[1] 320 | q = q + self.config.att_prefix_scale * q_ 321 | k = k + self.config.att_prefix_scale * k_ 322 | elif self.config.att_prefix_mode == 'patt_qkv': 323 | qkv_adapter = self.qkv_adapter(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 324 | q_, k_, v_ = qkv_adapter[0], qkv_adapter[1], qkv_adapter[2] 325 | q = q + self.config.att_prefix_scale * q_ 326 | k = k + self.config.att_prefix_scale * k_ 327 | v = v + self.config.att_prefix_scale * v_ 328 | elif self.config.att_prefix_mode == 'prefix_kv': 329 | past_key_values_prompt = self.get_prompt(bsz=x.shape[0], device=x.device, match_n_embd=int(self.num_heads * v.shape[-2])) 330 | past_key_values = past_key_values_prompt 331 | past_query, past_key = past_key_values[0], past_key_values[1] 332 | q = torch.cat([q, past_query[0].transpose(-2, -1)], dim=-1) 333 | k = torch.cat([k, past_key[0].transpose(-2, -1)], dim=-1) 334 | elif self.config.att_prefix_mode == 'prefix_qk': 335 | past_key_values_prompt = self.get_prompt(bsz=x.shape[0], 
device=x.device, match_n_embd=v.shape[-1]) 336 | past_key_values = past_key_values_prompt 337 | past_key, past_value = past_key_values[0], past_key_values[1] 338 | k = torch.cat([k, past_key[0]], dim=-2) 339 | v = torch.cat([v, past_value[0]], dim=-2) 340 | else: 341 | raise Exception("Sorry, no attention prefix_mode found") 342 | 343 | q = q * self.scale 344 | attn = q @ k.transpose(-2, -1) 345 | 346 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index[:N, :N].reshape(-1)].reshape( 347 | N, N, -1) # Wd*Wh*Ww,Wd*Wh*Ww,nH 348 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww 349 | if self.config.att_prefix_mode in ['prefix_qk','patt_kv','patt_qv','prefix_qk','patt_qkv']: 350 | attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N 351 | 352 | if mask is not None: 353 | nW = mask.shape[0] 354 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 355 | attn = attn.view(-1, self.num_heads, N, N) 356 | attn = self.softmax(attn) 357 | else: 358 | attn = self.softmax(attn) 359 | 360 | attn = self.attn_drop(attn) 361 | if self.config.att_prefix_mode =='prefix_kv': 362 | attn[:, :, :, :attn.shape[2]] = attn[:, :, :, :attn.shape[2]] + relative_position_bias.unsqueeze(0) # B_, nH, N, N 363 | 364 | if mask is not None: 365 | nW = mask.shape[0] 366 | 367 | org_attn = attn[:, :, :, :attn.shape[2]] 368 | org_attn = org_attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 369 | org_attn = org_attn.view(-1, self.num_heads, N, N) 370 | org_attn = self.softmax(org_attn) 371 | attn[:, :, :, :attn.shape[2]] = org_attn 372 | else: 373 | attn = self.softmax(attn) 374 | attn = self.attn_drop(attn) 375 | 376 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 377 | x = self.proj(x) 378 | x = self.proj_drop(x) 379 | return x 380 | 381 | class SwinTransformerBlock3D(nn.Module): 382 | """ Swin Transformer Block. 383 | Args: 384 | dim (int): Number of input channels. 385 | num_heads (int): Number of attention heads. 386 | window_size (tuple[int]): Window size. 387 | shift_size (tuple[int]): Shift size for SW-MSA. 388 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 389 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 390 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 391 | drop (float, optional): Dropout rate. Default: 0.0 392 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 393 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 394 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 395 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 396 | """ 397 | 398 | def __init__(self, dim, num_heads, window_size=(2,7,7), shift_size=(0,0,0), 399 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 400 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False, 401 | config=None): 402 | super().__init__() 403 | self.dim = dim 404 | self.num_heads = num_heads 405 | self.window_size = window_size 406 | self.shift_size = shift_size 407 | self.mlp_ratio = mlp_ratio 408 | self.use_checkpoint=use_checkpoint 409 | 410 | assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must in 0-window_size" 411 | assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must in 0-window_size" 412 | assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must in 0-window_size" 413 | 414 | self.norm1 = norm_layer(dim) 415 | self.attn = WindowAttention3D( 416 | dim, window_size=self.window_size, num_heads=num_heads, 417 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, config=config) 418 | 419 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 420 | self.norm2 = norm_layer(dim) 421 | mlp_hidden_dim = int(dim * mlp_ratio) 422 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, config=config) 423 | 424 | def forward_part1(self, x, mask_matrix): 425 | B, D, H, W, C = x.shape 426 | window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size) 427 | 428 | x = self.norm1(x) 429 | # pad feature maps to multiples of window size 430 | pad_l = pad_t = pad_d0 = 0 431 | pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0] 432 | pad_b = (window_size[1] - H % window_size[1]) % window_size[1] 433 | pad_r = (window_size[2] - W % window_size[2]) % window_size[2] 434 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) 435 | _, Dp, Hp, Wp, _ = x.shape 436 | # cyclic shift 437 | if any(i > 0 for i in shift_size): 438 | shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) 439 | attn_mask = mask_matrix 440 | else: 441 | shifted_x = x 442 | attn_mask = None 443 | 444 | # partition windows 445 | x_windows = window_partition(shifted_x, window_size) # B*nW, Wd*Wh*Ww, C window_size: (4, 7, 7) 446 | # W-MSA/SW-MSA 447 | attn_windows = self.attn(x_windows, mask=attn_mask) # B*nW, Wd*Wh*Ww, C 448 | # merge windows 449 | attn_windows = attn_windows.view(-1, *(window_size+(C,))) 450 | shifted_x = window_reverse(attn_windows, window_size, B, Dp, Hp, Wp) # B D' H' W' C 451 | # reverse cyclic shift 452 | if any(i > 0 for i in shift_size): 453 | x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) 454 | else: 455 | x = shifted_x 456 | 457 | if pad_d1 >0 or pad_r > 0 or pad_b > 0: 458 | x = x[:, :D, :H, :W, :].contiguous() 459 | return x 460 | 461 | def forward_part2(self, x): 462 | return self.drop_path(self.mlp(self.norm2(x))) 463 | 464 | def forward(self, x, mask_matrix): 465 | """ Forward function. 466 | Args: 467 | x: Input feature, tensor size (B, D, H, W, C). 468 | mask_matrix: Attention mask for cyclic shift. 
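        Shape sketch (illustrative, e.g. for window_size=(8,7,7))::

            x: (B, D, H, W, C)
              -> norm1, pad D/H/W up to multiples of the window size
              -> optional cyclic shift by shift_size (SW-MSA blocks only)
              -> window_partition: (B*nW, 8*7*7, C) token windows
              -> WindowAttention3D (optionally modified by the PATT/prefix branch)
              -> window_reverse, reverse shift, remove padding -> (B, D, H, W, C)
            x = shortcut + drop_path(x), followed by the MLP (+ parallel
            adapter) residual in forward_part2.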
469 | """ 470 | 471 | shortcut = x 472 | if self.use_checkpoint: 473 | x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix) 474 | else: 475 | x = self.forward_part1(x, mask_matrix) 476 | x = shortcut + self.drop_path(x) 477 | 478 | if self.use_checkpoint: 479 | x = x + checkpoint.checkpoint(self.forward_part2, x) 480 | else: 481 | x = x + self.forward_part2(x) 482 | 483 | return x 484 | 485 | 486 | class PatchMerging(nn.Module): 487 | """ Patch Merging Layer 488 | Args: 489 | dim (int): Number of input channels. 490 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 491 | """ 492 | def __init__(self, dim, norm_layer=nn.LayerNorm): 493 | super().__init__() 494 | self.dim = dim 495 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 496 | self.norm = norm_layer(4 * dim) 497 | 498 | def forward(self, x): 499 | """ Forward function. 500 | Args: 501 | x: Input feature, tensor size (B, D, H, W, C). 502 | """ 503 | B, D, H, W, C = x.shape 504 | 505 | # padding 506 | pad_input = (H % 2 == 1) or (W % 2 == 1) 507 | if pad_input: 508 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) 509 | 510 | x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C 511 | x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C 512 | x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C 513 | x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C 514 | x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C 515 | 516 | x = self.norm(x) 517 | x = self.reduction(x) 518 | 519 | return x 520 | 521 | # cache each stage results 522 | @lru_cache() 523 | def compute_mask(D, H, W, window_size, shift_size, device): 524 | img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1 525 | cnt = 0 526 | for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0],None): 527 | for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1],None): 528 | for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2],None): 529 | img_mask[:, d, h, w, :] = cnt 530 | cnt += 1 531 | mask_windows = window_partition(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1 532 | mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2] 533 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 534 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 535 | return attn_mask 536 | 537 | 538 | class BasicLayer(nn.Module): 539 | """ A basic Swin Transformer layer for one stage. 540 | Args: 541 | dim (int): Number of feature channels 542 | depth (int): Depths of this stage. 543 | num_heads (int): Number of attention head. 544 | window_size (tuple[int]): Local window size. Default: (1,7,7). 545 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 546 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 547 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 548 | drop (float, optional): Dropout rate. Default: 0.0 549 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 550 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 551 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 552 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None 553 | """ 554 | 555 | def __init__(self, 556 | dim, 557 | depth, 558 | num_heads, 559 | window_size=(1,7,7), 560 | mlp_ratio=4., 561 | qkv_bias=False, 562 | qk_scale=None, 563 | drop=0., 564 | attn_drop=0., 565 | drop_path=0., 566 | norm_layer=nn.LayerNorm, 567 | downsample=None, 568 | use_checkpoint=False, 569 | config=None,): 570 | super().__init__() 571 | self.window_size = window_size 572 | self.shift_size = tuple(i // 2 for i in window_size) 573 | self.depth = depth 574 | self.use_checkpoint = use_checkpoint 575 | self.config = config 576 | 577 | # build blocks 578 | self.blocks = nn.ModuleList([ 579 | SwinTransformerBlock3D( 580 | dim=dim, 581 | num_heads=num_heads, 582 | window_size=window_size, 583 | shift_size=(0,0,0) if (i % 2 == 0) else self.shift_size, 584 | mlp_ratio=mlp_ratio, 585 | qkv_bias=qkv_bias, 586 | qk_scale=qk_scale, 587 | drop=drop, 588 | attn_drop=attn_drop, 589 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 590 | norm_layer=norm_layer, 591 | use_checkpoint=use_checkpoint, 592 | config=self.config, 593 | ) 594 | for i in range(depth)]) 595 | 596 | self.downsample = downsample 597 | if self.downsample is not None: 598 | self.downsample = downsample(dim=dim, norm_layer=norm_layer) 599 | 600 | def forward(self, x): 601 | """ Forward function. 602 | Args: 603 | x: Input feature, tensor size (B, C, D, H, W). 604 | """ 605 | # calculate attention mask for SW-MSA 606 | B, C, D, H, W = x.shape 607 | window_size, shift_size = get_window_size((D,H,W), self.window_size, self.shift_size) 608 | x = rearrange(x, 'b c d h w -> b d h w c') 609 | Dp = int(np.ceil(D / window_size[0])) * window_size[0] 610 | Hp = int(np.ceil(H / window_size[1])) * window_size[1] 611 | Wp = int(np.ceil(W / window_size[2])) * window_size[2] 612 | attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) 613 | for blk in self.blocks: 614 | x = blk(x, attn_mask) 615 | x = x.view(B, D, H, W, -1) 616 | 617 | if self.downsample is not None: 618 | x = self.downsample(x) 619 | x = rearrange(x, 'b d h w c -> b c d h w') 620 | return x 621 | 622 | 623 | class PatchEmbed3D(nn.Module): 624 | """ Video to Patch Embedding. 625 | Args: 626 | patch_size (int): Patch token size. Default: (2,4,4). 627 | in_chans (int): Number of input video channels. Default: 3. 628 | embed_dim (int): Number of linear projection output channels. Default: 96. 629 | norm_layer (nn.Module, optional): Normalization layer. 
Default: None 630 | """ 631 | def __init__(self, patch_size=(2,4,4), in_chans=3, embed_dim=96, norm_layer=None): 632 | super().__init__() 633 | self.patch_size = patch_size 634 | 635 | self.in_chans = in_chans 636 | self.embed_dim = embed_dim 637 | 638 | self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 639 | if norm_layer is not None: 640 | self.norm = norm_layer(embed_dim) 641 | else: 642 | self.norm = None 643 | 644 | def forward(self, x): 645 | """Forward function.""" 646 | # padding 647 | _, _, D, H, W = x.size() 648 | if W % self.patch_size[2] != 0: 649 | x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) 650 | if H % self.patch_size[1] != 0: 651 | x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) 652 | if D % self.patch_size[0] != 0: 653 | x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) 654 | 655 | x = self.proj(x) # B C D Wh Ww 656 | if self.norm is not None: 657 | D, Wh, Ww = x.size(2), x.size(3), x.size(4) 658 | x = x.flatten(2).transpose(1, 2) 659 | x = self.norm(x) 660 | x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) 661 | 662 | return x 663 | 664 | 665 | class SwinTransformer3D(nn.Module): 666 | """ Swin Transformer backbone. 667 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 668 | https://arxiv.org/pdf/2103.14030 669 | Args: 670 | patch_size (int | tuple(int)): Patch size. Default: (4,4,4). 671 | in_chans (int): Number of input image channels. Default: 3. 672 | embed_dim (int): Number of linear projection output channels. Default: 96. 673 | depths (tuple[int]): Depths of each Swin Transformer stage. 674 | num_heads (tuple[int]): Number of attention head of each stage. 675 | window_size (int): Window size. Default: 7. 676 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 677 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee 678 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. 679 | drop_rate (float): Dropout rate. 680 | attn_drop_rate (float): Attention dropout rate. Default: 0. 681 | drop_path_rate (float): Stochastic depth rate. Default: 0.2. 682 | norm_layer: Normalization layer. Default: nn.LayerNorm. 683 | patch_norm (bool): If True, add normalization after patch embedding. Default: False. 684 | frozen_stages (int): Stages to be frozen (stop grad and set eval mode). 685 | -1 means not freezing any parameters. 
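        tuning_config (EasyDict | None): PETL configuration consumed by the
            attention and MLP blocks: ffn_adapt/ffn_num and the related
            ffn_adapter_* fields control the parallel bottleneck Adapter in
            Mlp, while att_prefix, att_prefix_mode, att_preseqlen,
            att_mid_dim and att_prefix_scale control the attention prefix
            branch in WindowAttention3D. In the 'patt_*' modes the windowed
            attention is modified in parallel, e.g. for 'patt_kv'
            (illustrative notation)::

                k = k + s * k_hat,  v = v + s * v_hat,
                [k_hat; v_hat] = W_up @ tanh(W_down @ x),  s = att_prefix_scale

            while the 'prefix_*' modes concatenate att_preseqlen learned
            prefix tokens into the attention computation (see
            WindowAttention3D.forward). A runnable configuration example is
            given in the __main__ block at the end of this file.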
686 | """ 687 | 688 | def __init__(self, 689 | pretrained=None, 690 | pretrained2d=True, 691 | patch_size=(4,4,4), 692 | in_chans=3, 693 | embed_dim=96, 694 | depths=[2, 2, 6, 2], 695 | num_heads=[3, 6, 12, 24], 696 | window_size=(2,7,7), 697 | mlp_ratio=4., 698 | qkv_bias=True, 699 | qk_scale=None, 700 | drop_rate=0., 701 | attn_drop_rate=0., 702 | drop_path_rate=0.2, 703 | norm_layer=nn.LayerNorm, 704 | patch_norm=False, 705 | tuning_config=None, 706 | frozen_stages=-1, 707 | use_checkpoint=False, 708 | ): 709 | super().__init__() 710 | self.tuning_config = tuning_config 711 | self.pretrained = pretrained 712 | self.pretrained2d = pretrained2d 713 | self.num_layers = len(depths) 714 | self.embed_dim = embed_dim 715 | self.patch_norm = patch_norm 716 | self.frozen_stages = frozen_stages 717 | self.window_size = window_size 718 | self.patch_size = patch_size 719 | 720 | # split image into non-overlapping patches 721 | self.patch_embed = PatchEmbed3D( 722 | patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 723 | norm_layer=norm_layer if self.patch_norm else None) 724 | 725 | self.pos_drop = nn.Dropout(p=drop_rate) 726 | 727 | # stochastic depth 728 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 729 | 730 | # build layers 731 | self.layers = nn.ModuleList() 732 | for i_layer in range(self.num_layers): 733 | layer = BasicLayer( 734 | dim=int(embed_dim * 2**i_layer), 735 | depth=depths[i_layer], 736 | num_heads=num_heads[i_layer], 737 | window_size=window_size, 738 | mlp_ratio=mlp_ratio, 739 | qkv_bias=qkv_bias, 740 | qk_scale=qk_scale, 741 | drop=drop_rate, 742 | attn_drop=attn_drop_rate, 743 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 744 | norm_layer=norm_layer, 745 | downsample=PatchMerging if i_layer= 0: 761 | self.patch_embed.eval() 762 | for param in self.patch_embed.parameters(): 763 | param.requires_grad = False 764 | 765 | if self.frozen_stages >= 1: 766 | self.pos_drop.eval() 767 | for i in range(0, self.frozen_stages): 768 | m = self.layers[i] 769 | m.eval() 770 | for param in m.parameters(): 771 | param.requires_grad = False 772 | 773 | def inflate_weights(self, logger): 774 | """Inflate the swin2d parameters to swin3d. 775 | The differences between swin3d and swin2d mainly lie in an extra 776 | axis. To utilize the pretrained parameters in 2d model, 777 | the weight of swin2d models should be inflated to fit in the shapes of 778 | the 3d counterpart. 779 | Args: 780 | logger (logging.Logger): The logger used to print 781 | debugging infomation. 
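        The inflation repeats the 2D patch-embedding kernel along the new
        temporal axis and rescales it (illustrative)::

            w3d = w2d.unsqueeze(2).repeat(1, 1, T, 1, 1) / T   # T = patch_size[0]

        so a clip of T identical frames yields the same patch embedding as
        the 2D model on a single frame; relative-position bias tables are
        bicubically resized in space and tiled (2 * window_size[0] - 1)
        times along the temporal window axis.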
782 | """ 783 | checkpoint = torch.load(self.pretrained, map_location='cpu') 784 | state_dict = checkpoint['model'] 785 | 786 | # delete relative_position_index since we always re-init it 787 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k] 788 | for k in relative_position_index_keys: 789 | del state_dict[k] 790 | 791 | # delete attn_mask since we always re-init it 792 | attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] 793 | for k in attn_mask_keys: 794 | del state_dict[k] 795 | 796 | state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).repeat(1,1,self.patch_size[0],1,1) / self.patch_size[0] 797 | 798 | # bicubic interpolate relative_position_bias_table if not match 799 | relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] 800 | for k in relative_position_bias_table_keys: 801 | relative_position_bias_table_pretrained = state_dict[k] 802 | relative_position_bias_table_current = self.state_dict()[k] 803 | L1, nH1 = relative_position_bias_table_pretrained.size() 804 | L2, nH2 = relative_position_bias_table_current.size() 805 | L2 = (2*self.window_size[1]-1) * (2*self.window_size[2]-1) 806 | wd = self.window_size[0] 807 | if nH1 != nH2: 808 | logger.warning(f"Error in loading {k}, passing") 809 | else: 810 | if L1 != L2: 811 | S1 = int(L1 ** 0.5) 812 | relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( 813 | relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(2*self.window_size[1]-1, 2*self.window_size[2]-1), 814 | mode='bicubic') 815 | relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) 816 | state_dict[k] = relative_position_bias_table_pretrained.repeat(2*wd-1,1) 817 | 818 | msg = self.load_state_dict(state_dict, strict=False) 819 | logger.info(msg) 820 | logger.info(f"=> loaded successfully '{self.pretrained}'") 821 | del checkpoint 822 | torch.cuda.empty_cache() 823 | 824 | def init_weights(self, pretrained=None): 825 | """Initialize the weights in backbone. 826 | Args: 827 | pretrained (str, optional): Path to pre-trained weights. 828 | Defaults to None. 829 | """ 830 | def _init_weights(m): 831 | if isinstance(m, nn.Linear): 832 | trunc_normal_(m.weight, std=.02) 833 | if isinstance(m, nn.Linear) and m.bias is not None: 834 | nn.init.constant_(m.bias, 0) 835 | elif isinstance(m, nn.LayerNorm): 836 | nn.init.constant_(m.bias, 0) 837 | nn.init.constant_(m.weight, 1.0) 838 | 839 | if pretrained: 840 | self.pretrained = pretrained 841 | if isinstance(self.pretrained, str): 842 | self.apply(_init_weights) 843 | logger = get_root_logger() 844 | logger.info(f'load model from: {self.pretrained}') 845 | 846 | if self.pretrained2d: 847 | # Inflate 2D model into 3D model. 848 | self.inflate_weights(logger) 849 | else: 850 | # Directly load 3D model. 
851 | load_checkpoint(self, self.pretrained, strict=False, logger=logger) 852 | elif self.pretrained is None: 853 | self.apply(_init_weights) 854 | else: 855 | raise TypeError('pretrained must be a str or None') 856 | 857 | def forward(self, x): 858 | """Forward function.""" 859 | #import pdb; pdb.set_trace() 860 | x = self.patch_embed(x) 861 | 862 | x = self.pos_drop(x) 863 | 864 | for layer in self.layers: 865 | x = layer(x.contiguous()) 866 | 867 | x = rearrange(x, 'n c d h w -> n d h w c') 868 | x = self.norm(x) 869 | x = rearrange(x, 'n d h w c -> n c d h w') 870 | 871 | #import pdb; pdb.set_trace() 872 | y = F.avg_pool3d(x, (2, x.size(3), x.size(4)), stride=1) 873 | y = self.fc(y) 874 | y = y.view(y.size(0), y.size(1), y.size(2)) 875 | logits = torch.mean(y, 2) 876 | 877 | return logits 878 | 879 | def train(self, mode=True): 880 | """Convert the model into training mode while keep layers freezed.""" 881 | super(SwinTransformer3D, self).train(mode) 882 | self._freeze_stages() 883 | 884 | if __name__ == '__main__': 885 | 886 | from easydict import EasyDict 887 | 888 | tuning_config = EasyDict( 889 | # AdaptFormer 890 | ffn_adapt=False, 891 | ffn_option="parallel", 892 | ffn_adapter_layernorm_option="none", 893 | ffn_adapter_init_option="lora", 894 | ffn_adapter_scalar="0.1", 895 | ffn_num=64, 896 | d_model= 768, 897 | # VPT related 898 | vpt_on=False, 899 | vpt_num=1, 900 | # prefix from here 901 | att_prefix=True, 902 | att_preseqlen=2, 903 | org_seqlen=196, 904 | att_mid_dim=8 905 | ) 906 | 907 | mock_input = torch.randn(2,3,8,244,244) 908 | mock_model = SwinTransformer3D(embed_dim=128, 909 | depths=[2, 2, 18, 2], 910 | num_heads=[4, 8, 16, 32], 911 | patch_size=(2,4,4), 912 | window_size=(8,7,7), 913 | drop_path_rate=0.4, 914 | patch_norm=True, 915 | tuning_config=tuning_config) 916 | print(mock_model(mock_input).shape) 917 | -------------------------------------------------------------------------------- /util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import PIL.Image 10 | import torch 11 | 12 | from torchvision import transforms 13 | from torchvision.transforms import functional as F 14 | 15 | 16 | class RandomResizedCrop(transforms.RandomResizedCrop): 17 | """ 18 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 19 | This may lead to results different with torchvision's version. 
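    Example (illustrative)::

        import PIL.Image
        img = PIL.Image.new('RGB', (320, 240))
        i, j, h, w = RandomResizedCrop.get_params(img, scale=(0.2, 1.0), ratio=(3 / 4, 4 / 3))
        # (i, j) is the sampled top-left corner and (h, w) the crop size,
        # clamped so the crop always fits inside the image.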
20 | Following BYOL's TF code: 21 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 22 | """ 23 | @staticmethod 24 | def get_params(img, scale, ratio): 25 | assert isinstance(img, PIL.Image.Image) 26 | # width, height = F._get_image_size(img) 27 | width, height = img.width, img.height 28 | area = height * width 29 | 30 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 31 | log_ratio = torch.log(torch.tensor(ratio)) 32 | aspect_ratio = torch.exp( 33 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 34 | ).item() 35 | 36 | w = int(round(math.sqrt(target_area * aspect_ratio))) 37 | h = int(round(math.sqrt(target_area / aspect_ratio))) 38 | 39 | w = min(w, width) 40 | h = min(h, height) 41 | 42 | i = torch.randint(0, height - h + 1, size=(1,)).item() 43 | j = torch.randint(0, width - w + 1, size=(1,)).item() 44 | 45 | return i, j, h, w -------------------------------------------------------------------------------- /util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | 14 | from torchvision import datasets, transforms 15 | 16 | from timm.data import create_transform 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | 19 | 20 | def build_dataset(is_train, args): 21 | transform = build_transform(is_train, args) 22 | 23 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 24 | dataset = datasets.ImageFolder(root, transform=transform) 25 | 26 | print(dataset) 27 | 28 | return dataset 29 | 30 | 31 | def build_transform(is_train, args): 32 | mean = IMAGENET_DEFAULT_MEAN 33 | std = IMAGENET_DEFAULT_STD 34 | # train transform 35 | if is_train: 36 | # this should always dispatch to transforms_imagenet_train 37 | transform = create_transform( 38 | input_size=args.input_size, 39 | is_training=True, 40 | color_jitter=args.color_jitter, 41 | auto_augment=args.aa, 42 | interpolation='bicubic', 43 | re_prob=args.reprob, 44 | re_mode=args.remode, 45 | re_count=args.recount, 46 | mean=mean, 47 | std=std, 48 | ) 49 | return transform 50 | 51 | # eval transform 52 | t = [] 53 | if args.input_size <= 224: 54 | crop_pct = 224 / 256 55 | else: 56 | crop_pct = 1.0 57 | size = int(args.input_size / crop_pct) 58 | t.append( 59 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images 60 | ) 61 | t.append(transforms.CenterCrop(args.input_size)) 62 | 63 | t.append(transforms.ToTensor()) 64 | t.append(transforms.Normalize(mean, std)) 65 | return transforms.Compose(t) 66 | -------------------------------------------------------------------------------- /util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
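# Editorial note (illustrative): the LARS optimizer defined below rescales the
# update of every parameter with ndim > 1 by the layer-wise trust ratio
#     q = trust_coefficient * ||p|| / ||grad + weight_decay * p||
# before the usual momentum step; 1-D parameters (biases, norm gains/betas)
# are neither weight-decayed nor rescaled. A minimal usage sketch (the model
# here is hypothetical; main.py in this repo builds its optimizer with
# construct_optimizer/SGD instead):
#
#     optimizer = LARS(model.parameters(), lr=1.6, weight_decay=1e-6, momentum=0.9)
#     loss.backward()
#     optimizer.step()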
6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 
35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch._six import inf 22 | 23 | 24 | class SmoothedValue(object): 25 | """Track a series of values and provide access to smoothed values over a 26 | window or the global series average. 
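The `param_groups_lrd` and `adjust_learning_rate` helpers above are meant to be used together: the former builds per-layer parameter groups (it assumes a ViT-style backbone exposing `blocks`, `cls_token` and `pos_embed`), and the latter is called every iteration with a fractional epoch so each group's learning rate follows warmup plus half-cycle cosine, scaled by the stored `lr_scale`. A minimal sketch under those assumptions; the schedule values are placeholders and `data_loader` is assumed to exist:

```python
import timm
import torch
from types import SimpleNamespace

import util.lr_decay as lrd
import util.lr_sched as lr_sched

# Placeholder schedule settings; the field names match what adjust_learning_rate reads.
args = SimpleNamespace(lr=1e-3, min_lr=1e-6, warmup_epochs=5, epochs=70)

# Any ViT-style backbone with .blocks / cls_token / pos_embed works here;
# a timm ViT is used purely for illustration.
model = timm.create_model('vit_base_patch16_224', pretrained=False)

param_groups = lrd.param_groups_lrd(
    model, weight_decay=0.05,
    no_weight_decay_list={'pos_embed', 'cls_token'},
    layer_decay=0.75,
)
optimizer = torch.optim.AdamW(param_groups, lr=args.lr)

for epoch in range(args.epochs):
    for data_iter_step, (samples, targets) in enumerate(data_loader):  # data_loader assumed
        # Fractional epoch: warmup + half-cycle cosine is applied per iteration,
        # then multiplied by each group's lr_scale.
        lr = lr_sched.adjust_learning_rate(
            optimizer, epoch + data_iter_step / len(data_loader), args)
        ...  # forward / backward / optimizer.step()
```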
27 | """ 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | def update(self, value, n=1): 38 | self.deque.append(value) 39 | self.count += n 40 | self.total += value * n 41 | 42 | def synchronize_between_processes(self): 43 | """ 44 | Warning: does not synchronize the deque! 45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value) 84 | 85 | 86 | class MetricLogger(object): 87 | def __init__(self, delimiter="\t"): 88 | self.meters = defaultdict(SmoothedValue) 89 | self.delimiter = delimiter 90 | 91 | def update(self, **kwargs): 92 | for k, v in kwargs.items(): 93 | if v is None: 94 | continue 95 | if isinstance(v, torch.Tensor): 96 | v = v.item() 97 | assert isinstance(v, (float, int)) 98 | self.meters[k].update(v) 99 | 100 | def __getattr__(self, attr): 101 | if attr in self.meters: 102 | return self.meters[attr] 103 | if attr in self.__dict__: 104 | return self.__dict__[attr] 105 | raise AttributeError("'{}' object has no attribute '{}'".format( 106 | type(self).__name__, attr)) 107 | 108 | def __str__(self): 109 | loss_str = [] 110 | for name, meter in self.meters.items(): 111 | loss_str.append( 112 | "{}: {}".format(name, str(meter)) 113 | ) 114 | return self.delimiter.join(loss_str) 115 | 116 | def synchronize_between_processes(self): 117 | for meter in self.meters.values(): 118 | meter.synchronize_between_processes() 119 | 120 | def add_meter(self, name, meter): 121 | self.meters[name] = meter 122 | 123 | def log_every(self, iterable, print_freq, header=None): 124 | i = 0 125 | if not header: 126 | header = '' 127 | start_time = time.time() 128 | end = time.time() 129 | iter_time = SmoothedValue(fmt='{avg:.4f}') 130 | data_time = SmoothedValue(fmt='{avg:.4f}') 131 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 132 | log_msg = [ 133 | header, 134 | '[{0' + space_fmt + '}/{1}]', 135 | 'eta: {eta}', 136 | '{meters}', 137 | 'time: {time}', 138 | 'data: {data}' 139 | ] 140 | if torch.cuda.is_available(): 141 | log_msg.append('max mem: {memory:.0f}') 142 | log_msg = self.delimiter.join(log_msg) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if i % print_freq == 0 or i == len(iterable) - 1: 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | print(log_msg.format( 153 | i, len(iterable), eta=eta_string, 154 | meters=str(self), 155 | 
time=str(iter_time), data=str(data_time), 156 | memory=torch.cuda.max_memory_allocated() / MB)) 157 | else: 158 | print(log_msg.format( 159 | i, len(iterable), eta=eta_string, 160 | meters=str(self), 161 | time=str(iter_time), data=str(data_time))) 162 | i += 1 163 | end = time.time() 164 | total_time = time.time() - start_time 165 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 166 | print('{} Total time: {} ({:.4f} s / it)'.format( 167 | header, total_time_str, total_time / len(iterable))) 168 | 169 | 170 | def setup_for_distributed(is_master): 171 | """ 172 | This function disables printing when not in master process 173 | """ 174 | builtin_print = builtins.print 175 | 176 | def print(*args, **kwargs): 177 | force = kwargs.pop('force', False) 178 | # force = force or (get_world_size() > 8) 179 | if is_master or force: 180 | now = datetime.datetime.now().time() 181 | builtin_print('[{}] '.format(now), end='') # print with time stamp 182 | builtin_print(*args, **kwargs) 183 | 184 | builtins.print = print 185 | 186 | 187 | def is_dist_avail_and_initialized(): 188 | if not dist.is_available(): 189 | return False 190 | if not dist.is_initialized(): 191 | return False 192 | return True 193 | 194 | 195 | def get_world_size(): 196 | if not is_dist_avail_and_initialized(): 197 | return 1 198 | return dist.get_world_size() 199 | 200 | 201 | def get_rank(): 202 | if not is_dist_avail_and_initialized(): 203 | return 0 204 | return dist.get_rank() 205 | 206 | 207 | def is_main_process(): 208 | return get_rank() == 0 209 | 210 | 211 | def save_on_master(*args, **kwargs): 212 | if is_main_process(): 213 | torch.save(*args, **kwargs) 214 | 215 | 216 | def init_distributed_mode(args): 217 | if args.dist_on_itp: 218 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 219 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 220 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 221 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 222 | os.environ['LOCAL_RANK'] = str(args.gpu) 223 | os.environ['RANK'] = str(args.rank) 224 | os.environ['WORLD_SIZE'] = str(args.world_size) 225 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 226 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 227 | args.rank = int(os.environ["RANK"]) 228 | args.world_size = int(os.environ['WORLD_SIZE']) 229 | args.gpu = int(os.environ['LOCAL_RANK']) 230 | elif 'SLURM_PROCID' in os.environ: 231 | args.rank = int(os.environ['SLURM_PROCID']) 232 | args.gpu = args.rank % torch.cuda.device_count() 233 | else: 234 | print('Not using distributed mode') 235 | setup_for_distributed(is_master=True) # hack 236 | args.distributed = False 237 | return 238 | 239 | args.distributed = True 240 | 241 | torch.cuda.set_device(args.gpu) 242 | args.dist_backend = 'nccl' 243 | print('| distributed init (rank {}): {}, gpu {}'.format( 244 | args.rank, args.dist_url, args.gpu), flush=True) 245 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 246 | world_size=args.world_size, rank=args.rank) 247 | torch.distributed.barrier() 248 | setup_for_distributed(args.rank == 0) 249 | 250 | 251 | class NativeScalerWithGradNormCount: 252 | state_dict_key = "amp_scaler" 253 | 254 | def __init__(self): 255 | self._scaler = torch.cuda.amp.GradScaler() 256 | 257 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 258 | 
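`MetricLogger` above provides smoothed, periodically printed metrics, while `NativeScalerWithGradNormCount`, whose `__call__` continues just below, runs the scaled backward pass, optional gradient clipping, the optimizer step and the scale update in a single call and returns the gradient norm. A minimal sketch of the call pattern, where `model`, `criterion`, `optimizer` and `data_loader` are assumed to exist:

```python
import torch
from util import misc

loss_scaler = misc.NativeScalerWithGradNormCount()
metric_logger = misc.MetricLogger(delimiter="  ")

for samples, targets in metric_logger.log_every(data_loader, 20, header='Epoch: [0]'):
    samples = samples.cuda(non_blocking=True)
    targets = targets.cuda(non_blocking=True)

    with torch.cuda.amp.autocast():
        loss = criterion(model(samples), targets)

    # backward + unscale + (optional) clip + optimizer.step() + scaler.update() in one call;
    # the return value is the gradient norm of `parameters`.
    grad_norm = loss_scaler(loss, optimizer, clip_grad=None,
                            parameters=model.parameters(), update_grad=True)
    optimizer.zero_grad()

    metric_logger.update(loss=loss.item(), grad_norm=grad_norm)

metric_logger.synchronize_between_processes()
```

Checkpointing then goes through `misc.save_model` / `misc.load_model`, defined further below in this file, which store and restore the model, optimizer, epoch and scaler state together.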
self._scaler.scale(loss).backward(create_graph=create_graph) 259 | if update_grad: 260 | if clip_grad is not None: 261 | assert parameters is not None 262 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 263 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 264 | else: 265 | self._scaler.unscale_(optimizer) 266 | norm = get_grad_norm_(parameters) 267 | self._scaler.step(optimizer) 268 | self._scaler.update() 269 | else: 270 | norm = None 271 | return norm 272 | 273 | def state_dict(self): 274 | return self._scaler.state_dict() 275 | 276 | def load_state_dict(self, state_dict): 277 | self._scaler.load_state_dict(state_dict) 278 | 279 | 280 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 281 | if isinstance(parameters, torch.Tensor): 282 | parameters = [parameters] 283 | parameters = [p for p in parameters if p.grad is not None] 284 | norm_type = float(norm_type) 285 | if len(parameters) == 0: 286 | return torch.tensor(0.) 287 | device = parameters[0].grad.device 288 | if norm_type == inf: 289 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 290 | else: 291 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 292 | return total_norm 293 | 294 | 295 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 296 | output_dir = Path(args.output_dir) 297 | epoch_name = str(epoch) 298 | if loss_scaler is not None: 299 | #checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 300 | checkpoint_paths = [output_dir / 'checkpoint-best.pth'] 301 | for checkpoint_path in checkpoint_paths: 302 | to_save = { 303 | 'model': model_without_ddp.state_dict(), 304 | 'optimizer': optimizer.state_dict(), 305 | 'epoch': epoch, 306 | 'scaler': loss_scaler.state_dict(), 307 | 'args': args, 308 | } 309 | 310 | save_on_master(to_save, checkpoint_path) 311 | else: 312 | client_state = {'epoch': epoch} 313 | #model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 314 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-best", client_state=client_state) 315 | 316 | 317 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 318 | if args.resume: 319 | if args.resume.startswith('https'): 320 | checkpoint = torch.hub.load_state_dict_from_url( 321 | args.resume, map_location='cpu', check_hash=True) 322 | else: 323 | checkpoint = torch.load(args.resume, map_location='cpu') 324 | if 'model' in checkpoint: 325 | _ckp = checkpoint['model'] 326 | elif 'module' in checkpoint: 327 | _ckp = checkpoint['module'] 328 | else: 329 | _ckp = checkpoint 330 | model_without_ddp.load_state_dict(_ckp) 331 | print("Resume checkpoint %s" % args.resume) 332 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 333 | optimizer.load_state_dict(checkpoint['optimizer']) 334 | args.start_epoch = checkpoint['epoch'] + 1 335 | if 'scaler' in checkpoint: 336 | loss_scaler.load_state_dict(checkpoint['scaler']) 337 | print("With optim & sched!") 338 | 339 | 340 | def all_reduce_mean(x): 341 | world_size = get_world_size() 342 | if world_size > 1: 343 | x_reduce = torch.tensor(x).cuda() 344 | dist.all_reduce(x_reduce) 345 | x_reduce /= world_size 346 | return x_reduce.item() 347 | else: 348 | return x 349 | -------------------------------------------------------------------------------- /util/pos_embed.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=float) # builtin float (float64); the np.float alias is removed in NumPy >= 1.24 57 | omega /= embed_dim / 2. 58 | omega = 1.
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model, args): 76 | # video 77 | if 'pos_embed' in checkpoint_model: 78 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 79 | embedding_size = pos_embed_checkpoint.shape[-1] # channel dim 80 | num_patches = model.patch_embed.num_patches # 81 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1 82 | 83 | # height (== width) for the checkpoint position embedding 84 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 85 | # height (== width) for the new position embedding 86 | new_size = int((num_patches // (args.num_frames // model.patch_embed.tubelet_size)) ** 0.5) 87 | # class_token and dist_token are kept unchanged 88 | if orig_size != new_size: 89 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 90 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 91 | # only the position tokens are interpolated 92 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 93 | # B, L, C -> BT, H, W, C -> BT, C, H, W 94 | pos_tokens = pos_tokens.reshape(-1, args.num_frames // model.patch_embed.tubelet_size, orig_size, orig_size, 95 | embedding_size) 96 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 97 | pos_tokens = torch.nn.functional.interpolate( 98 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 99 | # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C 100 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, args.num_frames // model.patch_embed.tubelet_size, 101 | new_size, new_size, embedding_size) 102 | pos_tokens = pos_tokens.flatten(1, 3) # B, L, C 103 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 104 | checkpoint_model['pos_embed'] = new_pos_embed 105 | 106 | def interpolate_pos_embed_ori(model, checkpoint_model): 107 | if 'pos_embed' in checkpoint_model: 108 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 109 | embedding_size = pos_embed_checkpoint.shape[-1] 110 | num_patches = model.patch_embed.num_patches 111 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 112 | # height (== width) for the checkpoint position embedding 113 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 114 | # height (== width) for the new position embedding 115 | new_size = int(num_patches ** 0.5) 116 | # class_token and dist_token are kept unchanged 117 | if orig_size != new_size: 118 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 119 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 120 | # only the position tokens are interpolated 121 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 122 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 123 | pos_tokens = torch.nn.functional.interpolate( 124 | pos_tokens, size=(new_size, new_size), mode='bicubic', 
align_corners=False) 125 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 126 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 127 | checkpoint_model['pos_embed'] = new_pos_embed 128 | --------------------------------------------------------------------------------
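As a quick reference for `util/pos_embed.py` above: `get_2d_sincos_pos_embed` returns a fixed (non-learned) sine-cosine table, and `interpolate_pos_embed` resizes a checkpoint's `pos_embed` before `load_state_dict` when the spatial grid changes. A small shape-check sketch; the 768/14 values are illustrative and not tied to this project's Swin configuration:

```python
import torch

from util.pos_embed import get_2d_sincos_pos_embed

# 14x14 patch grid (e.g. a 224px image with 16px patches) plus a cls-token slot.
pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos_embed.shape)  # (1 + 14*14, 768) = (197, 768)

# Typically copied into a frozen buffer/parameter of shape (1, 197, 768):
pe = torch.from_numpy(pos_embed).float().unsqueeze(0)
```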