├── README.md
├── data_trainer.py
├── data_transform.py
├── dataset.py
├── demo
│   ├── 9r8wpMS2iEk_000048_000058.mp4
│   ├── YABnJL_bDzw.mp4
│   ├── kinetics400_train_list_videos_25fps.txt
│   ├── kinetics400_val_list_videos_25fps.txt
│   └── log_arch_timesformer_lr5e-3_bs8_nw4_open.txt
├── k400_classmap.json
├── k600_classmap.json
├── mask_generator.py
├── mixup.py
├── model_pretrain.py
├── model_trainer.py
├── notebook
│   └── VideoTransformer_demo.ipynb
├── optimizer.py
├── requirements.txt
├── transformer.py
├── utils.py
├── video_transformer.py
├── visualize_attention.py
└── weight_init.py

/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch implementation of Video Transformer Benchmarks
2 | This repository is mainly built upon [Pytorch](https://pytorch.org/) and [Pytorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/). We aim to maintain a collection of scalable video transformer benchmarks and to discuss training recipes for large video transformer models.
3 | 
4 | We currently implement [TimeSformer](https://arxiv.org/abs/2102.05095), [ViViT](https://arxiv.org/abs/2103.15691) and [MaskFeat](https://arxiv.org/abs/2112.09133). We have pre-trained `TimeSformer-B`, `ViViT-B` and `MaskFeat` on [Kinetics400/600](https://deepmind.com/research/open-source/kinetics), but we still cannot guarantee the performance reported in the papers. However, we have found some relevant hyper-parameters which may help to reach the target performance.
5 | 
6 | ## Update
7 | 1. We have fixed several known issues and can now build scripts to pretrain `MViT-B` with `MaskFeat` or finetune `MViT-B`/`TimeSformer-B`/`ViViT-B` on K400.
8 | 2. We have reimplemented the HOG extraction and HOG prediction of [MaskFeat](https://arxiv.org/abs/2112.09133), which makes pre-training more efficient.
9 | 3. Note that if you want to train `TimeSformer-B` or `ViViT-B` with the current repo, you need to carefully adjust the learning rate and weight decay for better performance. For example, you can choose 0.005 as the peak learning rate and 0.0001 as the weight decay by default.
10 | 
11 | ## Table of Contents
12 | 1. [Difference](#difference)
13 | 2. [TODO](#todo)
14 | 3. [Setup](#setup)
15 | 4. [Usage](#usage)
16 | 5. [Result](#result)
17 | 6. [Acknowledge](#acknowledge)
18 | 7. [Contribution](#contribution)
19 | 
20 | ## Difference
21 | In order to share the basic divided spatial-temporal attention module across different video transformers, we make the changes described below.
22 | 
23 | ### 1. Position embedding
24 | 
25 | We split the `position embedding` from *R(nt\*h\*w×d)*, as used in the [ViViT](https://arxiv.org/abs/2103.15691) paper, into a spatial *R(nh\*w×d)*
26 | and a temporal *R(nt×d)* embedding to stay consistent with [TimeSformer](https://arxiv.org/abs/2102.05095).
27 | 
28 | ### 2. Class token
29 | 
30 | To make it explicit whether the `class_token` enters the forward computation of a module, we only compute the interaction between the `class_token` and the `query` when the current layer is the last layer (excluding the `FFN`) of each transformer block.
31 | 
32 | ### 3. Initialize from the pre-trained model
33 | 
34 | * Tokenization: the token embedding filter can be either `Conv2D` or `Conv3D`. When initializing the `Conv3D` filters from `Conv2D` weights, the 2D kernel can be replicated along the temporal dimension and averaged, or placed only at the center position `t/2` with the other temporal positions initialized to zero (see the sketch after this list).
35 | * Temporal `MSA` module weights: one can choose to copy the weights from the spatial `MSA` module or initialize all weights with zeros.
36 | * Initialize from the `MAE` pre-trained model provided by [ZhiLiang](https://github.com/pengzhiliang/MAE-pytorch), where the `class_token` that does not appear in the `MAE` pre-trained model is initialized from a truncated normal distribution.
37 | * Initialize from the `ViT` pre-trained model, which can be found [here](https://drive.google.com/file/d/1QjGpbR8K4Cf4TJaDc60liVhBvPtrc2v4/view?usp=sharing).
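The `Conv2D`-to-`Conv3D` inflation mentioned in the tokenization bullet above can be summarized with a short sketch. This is only an illustration of the two strategies (averaging vs. central-frame), not the repo's own initialization code; the helper name `inflate_conv2d_to_conv3d` and the example shapes are hypothetical.

```python
import torch

def inflate_conv2d_to_conv3d(w2d: torch.Tensor, t: int, mode: str = 'central') -> torch.Tensor:
    """Inflate a Conv2D patch-embedding weight (out, in, h, w) into a Conv3D weight (out, in, t, h, w).

    'average': replicate the 2D kernel along time and divide by t, so a clip of
               identical frames gives the same response as the 2D filter.
    'central': keep the 2D kernel only at the center position t // 2 and fill the
               remaining temporal positions with zeros.
    """
    w3d = w2d.unsqueeze(2).repeat(1, 1, t, 1, 1)  # (out, in, t, h, w)
    if mode == 'average':
        return w3d / t
    if mode == 'central':
        mask = torch.zeros(1, 1, t, 1, 1, dtype=w2d.dtype)
        mask[:, :, t // 2] = 1.0
        return w3d * mask
    raise ValueError(f'unknown mode: {mode}')

# e.g. turn a ViT-B patch embedding (768, 3, 16, 16) into a 2-frame tubelet embedding
w2d = torch.randn(768, 3, 16, 16)
w3d = inflate_conv2d_to_conv3d(w2d, t=2, mode='central')
assert w3d.shape == (768, 3, 2, 16, 16)
```

The division by `t` in the averaging variant keeps the activations of the inflated filter on a static clip identical to those of the original 2D filter, which is the usual motivation for that choice.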
38 | 
39 | ## TODO
40 | - [√] add pre-trained weights for more `TimeSformer` and `ViViT` variants.
41 | 	- A larger version and other operation types.
42 | - [√] add `linear probe` and `finetune recipe`.
43 | 	- Make it possible to transfer the pre-trained models to downstream tasks.
44 | - [ ] add more scalable Video Transformer benchmarks.
45 | 	- We will mainly focus on data-efficient models.
46 | - [ ] add more robust objective functions.
47 | 	- Pre-train the model with the dominant self-supervised methods, e.g. [Masked Image Modeling](https://arxiv.org/abs/2111.06377).
48 | 
49 | ## Setup
50 | ```shell
51 | pip install -r requirements.txt
52 | ```
53 | 
54 | ## Usage
55 | ### Training
56 | ```shell
57 | # path to Kinetics400 train set and val set
58 | TRAIN_DATA_PATH='/path/to/Kinetics400/train_list.txt'
59 | VAL_DATA_PATH='/path/to/Kinetics400/val_list.txt'
60 | # path to root directory
61 | ROOT_DIR='/path/to/work_space'
62 | # path to pretrain weights
63 | PRETRAIN_WEIGHTS='/path/to/weights'
64 | 
65 | # pretrain mvit using maskfeat
66 | python model_pretrain.py \
67 | 	-lr 8e-4 -epoch 300 -batch_size 16 -num_workers 8 -frame_interval 4 -num_frames 16 -num_class 400 \
68 | 	-root_dir $ROOT_DIR -train_data_path $TRAIN_DATA_PATH
69 | 
70 | # finetune mvit with maskfeat pretrain weights
71 | python model_pretrain.py \
72 | 	-lr 0.005 -epoch 200 -batch_size 8 -num_workers 4 -num_frames 16 -frame_interval 4 -num_class 400 \
73 | 	-arch 'mvit' -optim_type 'adamw' -lr_schedule 'cosine' -objective 'supervised' -mixup True \
74 | 	-auto_augment 'rand_aug' -root_dir $ROOT_DIR -train_data_path $TRAIN_DATA_PATH \
75 | 	-val_data_path $VAL_DATA_PATH -pretrain_pth $PRETRAIN_WEIGHTS
76 | 
77 | # finetune timesformer with imagenet pretrain weights
78 | python model_pretrain.py \
79 | 	-lr 0.005 -epoch 30 -batch_size 8 -num_workers 4 -num_frames 8 -frame_interval 32 -num_class 400 \
80 | 	-arch 'timesformer' -attention_type 'divided_space_time' -optim_type 'sgd' -lr_schedule 'cosine' \
81 | 	-objective 'supervised' -root_dir $ROOT_DIR -train_data_path $TRAIN_DATA_PATH \
82 | 	-val_data_path $VAL_DATA_PATH -pretrain_pth $PRETRAIN_WEIGHTS -weights_from 'imagenet'
83 | 
84 | # finetune vivit with imagenet pretrain weights
85 | python model_pretrain.py \
86 | 	-lr 0.005 -epoch 30 -batch_size 8 -num_workers 4 -num_frames 16 -frame_interval 16 -num_class 400 \
87 | 	-arch 'vivit' -attention_type 'fact_encoder' -optim_type 'sgd' -lr_schedule 'cosine' \
88 | 	-objective 'supervised' -root_dir $ROOT_DIR -train_data_path $TRAIN_DATA_PATH \
89 | 	-val_data_path $VAL_DATA_PATH -pretrain_pth $PRETRAIN_WEIGHTS -weights_from 'imagenet'
90 | 
91 | ```
92 | The minimal folder structure will look as below.
93 | ```
94 | root_dir
95 | ├── results
96 | │   ├── experiment_tag
97 | │   │   ├── ckpt
98 | │   │   ├── log
99 | ```
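In the finetuning commands above, `-num_frames` and `-frame_interval` jointly decide how much of a video each training clip covers: a window of `num_frames * frame_interval` frames is randomly placed and `num_frames` roughly evenly spaced indices are sampled from it, mirroring `TemporalRandomCrop` in `data_transform.py` and the `np.linspace` sampling in `dataset.py`. The sketch below makes the arithmetic explicit (the helper name `sample_clip_indices` is ours, not part of the repo); with `-num_frames 8 -frame_interval 32` the window spans 256 frames, i.e. roughly 10 seconds of a 25 fps Kinetics clip.

```python
import random
import numpy as np

def sample_clip_indices(total_frames: int, num_frames: int, frame_interval: int) -> np.ndarray:
    """Pick `num_frames` frame indices from a randomly placed window of
    `num_frames * frame_interval` frames, following TemporalRandomCrop + np.linspace
    as used in dataset.py."""
    span = num_frames * frame_interval
    rand_end = max(0, total_frames - span - 1)
    begin_index = random.randint(0, rand_end)
    end_index = min(begin_index + span, total_frames)
    return np.linspace(begin_index, end_index - 1, num_frames, dtype=int)

# a ~10s Kinetics clip at 25 fps has about 250 frames
print(sample_clip_indices(total_frames=250, num_frames=8, frame_interval=32))
```

A larger `frame_interval` therefore widens the temporal context of each clip at the cost of slower data loading, which is exactly the trade-off reported in the ablation tables below.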
100 | 
101 | ## Result
102 | ### Kinetics-400/600
103 | 
104 | #### 1. Model Zoo
105 | 
106 | | name | weights from | dataset | epochs | num frames | spatial crop | top1_acc | top5_acc | weight | log |
107 | |:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
108 | | TimeSformer-B | ImageNet-21K | K600 | 15e | 8 | 224 | 78.4 | 93.6 | [Google drive](https://drive.google.com/file/d/1-BSNROh35fiOIBcmtFNgWHEY_JC5UNDx/view?usp=sharing) or [BaiduYun](https://pan.baidu.com/s/1I5L41ZFHHSvFJttYt8F0Og) (code: yr4j) | [log](demo/log_arch_timesformer_lr5e-3_bs8_nw4_open.txt) |
109 | | ViViT-B | ImageNet-21K | K400 | 30e | 16 | 224 | 75.2 | 91.5 | [Google drive](https://drive.google.com/file/d/1-JVhSN3QHKUOLkXLWXWn5drdvKn0gPll/view?usp=sharing) | |
110 | | MaskFeat | from scratch | K400 | 100e | 16 | 224 | | | [Google drive](https://drive.google.com/file/d/1h3Q-267qV9kIcTT9Sct-zQzVvXljhyWW/view?usp=sharing) | |
111 | 
112 | #### 1.1 Visualize
113 | 
114 | For each column, we show the masked input (left), HOG predictions (middle) and the original video frame (right).
115 |

116 | 117 |

118 | 119 | Here, we show the extracted attention map of a random frame sampled from the demo video. 120 |

121 | 122 |

123 | 124 |
125 | 126 | #### 2. Train Recipe(ablation study) 127 | #### 2.1 Acc 128 | 129 | | operation | top1_acc | top5_acc | top1_acc (three crop) | 130 | |:----|:----:|:----:|:----:| 131 | | base | 68.2 | 87.6 | - | 132 | | + `frame_interval` 4 -> 16 (span more time) | 72.9(+4.7) | 91.0(+3.4) | - | 133 | | + RandomCrop, flip (overcome overfit) | 75.7(+2.8) | 92.5(+1.5) | - | 134 | | + `batch size` 16 -> 8 (more iterations) | 75.8(+0.1) | 92.4(-0.1) | - | 135 | | + `frame_interval` 16 -> 24 (span more time) | 77.7(+1.9) | 93.3(+0.9) | 78.4 | 136 | | + `frame_interval` 24 -> 32 (span more time) | 78.4(+0.7) | 94.0(+0.7) | 79.1 | 137 | 138 | tips: `frame_interval` and `data augment` counts for the validation accuracy. 139 | 140 |
141 | 142 | #### 2.2 Time 143 | 144 | | operation | epoch_time | 145 | |:----|:----:| 146 | | base (start with DDP) | 9h+ | 147 | | + `speed up training recipes` | 1h+ | 148 | | + switch from `get_batch first` to `sample_Indice first` | 0.5h | 149 | | + `batch size` 16 -> 8 | 33.32m | 150 | | + `num_workers` 8 -> 4 | 35.52m | 151 | | + `frame_interval` 16 -> 24 | 44.35m | 152 | 153 | tips: Improve the `frame_interval` will drop a lot on time performance. 154 | 155 | 1.`speed up training recipes`: 156 | * More GPU device. 157 | * `pin_memory=True`. 158 | * Avoid CPU->GPU Device transfer (such as `.item()`, `.numpy()`, `.cpu()` operations on tensor or `log` to disk). 159 | 160 | 2.`get_batch first` means that we firstly read all frames through the video reader, and then get the target slice of frames, so it largely slow down the data-loading speed. 161 | 162 |
163 | 164 | 165 | ## Acknowledge 166 | this repo is built on top of [Pytorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/), [pytorchvideo](https://github.com/facebookresearch/pytorchvideo/tree/9d0ca900f0427ed9b47b6182ad05f75c0e66274b), [skimage](https://github.com/scikit-image/scikit-image), [decord](https://github.com/dmlc/decord) and [kornia](https://github.com/kornia/kornia). I also learn many code designs from [MMaction2](https://github.com/open-mmlab/mmaction2). I thank the authors for releasing their code. 167 | 168 | ## Contribution 169 | I look forward to seeing one can provide some ideas about the repo, please feel free to report it in the issue, or even better, submit a pull request. 170 | 171 | And your star is my motivation, thank u~ 172 | -------------------------------------------------------------------------------- /data_trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytorch_lightning as pl 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.data.dataloader import DataLoader 6 | 7 | from dataset import Kinetics 8 | import data_transform as T 9 | 10 | class Collator(object): 11 | 12 | def __init__(self, objective): 13 | self.objective = objective 14 | 15 | def collate(self, minibatch): 16 | image_list = [] 17 | label_list = [] 18 | mask_list = [] 19 | marker_list = [] 20 | for record in minibatch: 21 | image_list.append(record[0]) 22 | label_list.append(record[1]) 23 | if self.objective == 'mim': 24 | mask_list.append(record[2]) 25 | marker_list.append(record[3]) 26 | minibatch = [] 27 | minibatch.append(torch.stack(image_list)) 28 | if self.objective == 'mim': 29 | minibatch.append(torch.stack(label_list)) 30 | minibatch.append(torch.stack(mask_list)) 31 | minibatch.append(marker_list) 32 | else: 33 | label = np.stack(label_list) 34 | minibatch.append(torch.from_numpy(label)) 35 | 36 | return minibatch 37 | 38 | class KineticsDataModule(pl.LightningDataModule): 39 | def __init__(self, 40 | configs, 41 | train_ann_path, 42 | val_ann_path=None, 43 | test_ann_path=None, 44 | ): 45 | super().__init__() 46 | self.train_ann_path = train_ann_path 47 | self.val_ann_path = val_ann_path 48 | self.test_ann_path = test_ann_path 49 | self.configs = configs 50 | 51 | def get_dataset(self, annotation_path, transform, temporal_sample): 52 | dataset = Kinetics( 53 | self.configs, 54 | annotation_path, 55 | transform=transform, 56 | temporal_sample=temporal_sample) 57 | 58 | return dataset 59 | 60 | def setup(self, stage): 61 | if self.configs.objective == 'mim': 62 | scale = (0.5, 1.0) 63 | color_jitter = None 64 | else: 65 | color_jitter = 0.4 66 | scale = None 67 | 68 | if self.configs.data_statics == 'imagenet': 69 | mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) 70 | elif self.configs.data_statics == 'kinetics': 71 | mean, std = (0.45, 0.45, 0.45), (0.225, 0.225, 0.225) 72 | else: 73 | mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5) 74 | 75 | train_transform = T.create_video_transform( 76 | objective=self.configs.objective, 77 | input_size=self.configs.img_size, 78 | is_training=True, 79 | scale=scale, 80 | hflip=0.5, 81 | color_jitter=color_jitter, 82 | auto_augment=self.configs.auto_augment, 83 | interpolation='bicubic', 84 | mean=mean, 85 | std=std) 86 | train_temporal_sample = T.TemporalRandomCrop( 87 | self.configs.num_frames * self.configs.frame_interval) 88 | 89 | self.train_dataset = self.get_dataset( 90 | self.train_ann_path, 91 | train_transform, 92 | 
train_temporal_sample) 93 | 94 | if self.val_ann_path is not None: 95 | val_transform = T.create_video_transform( 96 | input_size=self.configs.img_size, 97 | is_training=False, 98 | interpolation='bicubic', 99 | mean=mean, 100 | std=std) 101 | val_temporal_sample = T.TemporalRandomCrop( 102 | self.configs.num_frames * self.configs.frame_interval) 103 | self.val_dataset = self.get_dataset( 104 | self.val_ann_path, 105 | val_transform, 106 | val_temporal_sample) 107 | 108 | if self.test_ann_path is not None: 109 | # need to update 110 | test_transform = T.Compose([ 111 | T.Resize(scale_range=(-1, 256)), 112 | T.ThreeCrop(size=self.configs.img_size), 113 | T.ToTensor(), 114 | T.Normalize(mean, std), 115 | ]) 116 | test_temporal_sample = T.TemporalRandomCrop( 117 | self.configs.num_frames * self.configs.frame_interval) 118 | self.test_dataset = self.get_dataset( 119 | self.test_ann_path, 120 | test_transform, 121 | test_temporal_sample) 122 | 123 | def train_dataloader(self): 124 | return DataLoader( 125 | self.train_dataset, 126 | batch_size=self.configs.batch_size, 127 | num_workers=self.configs.num_workers, 128 | collate_fn=Collator(self.configs.objective).collate, 129 | shuffle=True, 130 | drop_last=True, 131 | pin_memory=True 132 | ) 133 | 134 | def val_dataloader(self): 135 | if self.val_ann_path is not None: 136 | return DataLoader( 137 | self.val_dataset, 138 | batch_size=self.configs.batch_size, 139 | num_workers=self.configs.num_workers, 140 | collate_fn=Collator(self.configs.objective).collate, 141 | shuffle=False, 142 | drop_last=False, 143 | ) 144 | 145 | def test_dataloader(self): 146 | if self.test_ann_path is not None: 147 | return DataLoader( 148 | self.test_dataset, 149 | batch_size=self.configs.batch_size, 150 | num_workers=self.configs.num_workers, 151 | collate_fn=Collator(self.configs.objective).collate, 152 | shuffle=False, 153 | drop_last=False, 154 | ) -------------------------------------------------------------------------------- /data_transform.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | import random 3 | import math 4 | 5 | from einops import rearrange 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from PIL import Image 10 | from torchvision import transforms 11 | from torchvision.transforms.functional import InterpolationMode 12 | 13 | DEFAULT_CROP_PCT = 0.875 14 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) 15 | _torch_interpolation_to_str = { 16 | InterpolationMode.NEAREST: 'nearest', 17 | InterpolationMode.BILINEAR: 'bilinear', 18 | InterpolationMode.BICUBIC: 'bicubic', 19 | InterpolationMode.BOX: 'box', 20 | InterpolationMode.HAMMING: 'hamming', 21 | InterpolationMode.LANCZOS: 'lanczos', 22 | } 23 | _str_to_torch_interpolation = {b: a for a, b in _torch_interpolation_to_str.items()} 24 | 25 | def str_to_interp_mode(mode_str): 26 | return _str_to_torch_interpolation[mode_str] 27 | 28 | # ------------------------------------------------------------ 29 | # ---------------------- Common ---------------------------- 30 | # ------------------------------------------------------------ 31 | class Compose(object): 32 | """Composes several transforms together. 33 | 34 | Args: 35 | transforms (list of transform objects): list of data transforms to compose. 
36 | """ 37 | 38 | def __init__(self, transforms): 39 | self.transforms = transforms 40 | 41 | def __call__(self, img): 42 | for t in self.transforms: 43 | img = t(img) 44 | return img 45 | 46 | def randomize_parameters(self): 47 | for t in self.transforms: 48 | if hasattr(t, 'randomize_parameters'): 49 | t.randomize_parameters() 50 | 51 | 52 | class ToTensor(object): 53 | """Convert a tensor to torch.FloatTensor in the range [0.0, 1.0]. 54 | 55 | Args: 56 | norm_value (int): the max value of the input image tensor, default to 255. 57 | """ 58 | 59 | def __init__(self, norm_value=255): 60 | self.norm_value = norm_value 61 | 62 | def __call__(self, pic): 63 | if isinstance(pic, torch.Tensor): 64 | return pic.float().div(self.norm_value) 65 | 66 | def randomize_parameters(self): 67 | pass 68 | 69 | 70 | # ------------------------------------------------------------ 71 | # ------------------- Transformation ----------------------- 72 | # ------------------------------------------------------------ 73 | class RandomCrop(object): 74 | """Random crop a fixed size region in a given image. 75 | 76 | Args: 77 | size (int, Tuple[int]): Desired output size (out_h, out_w) of the crop 78 | """ 79 | 80 | def __init__(self, size): 81 | if isinstance(size, tuple): 82 | if size[0] != size[1]: 83 | raise ValueError(f'crop size {size[0], size[1]}, must be equal.') 84 | else: 85 | self.size = size[0] 86 | else: 87 | self.size = size 88 | 89 | def __call__(self, imgs): 90 | # Crop size 91 | size = self.size 92 | 93 | # Location 94 | img_height, img_width = imgs.size(2), imgs.size(3) 95 | y_offset = int(self.y_jitter * (img_height - size)) 96 | x_offset = int(self.x_jitter * (img_width - size)) 97 | 98 | imgs = imgs[..., y_offset : y_offset + size, x_offset : x_offset + size] 99 | return imgs 100 | 101 | def __repr__(self): 102 | repr_str = (f'{self.__class__.__name__}(' 103 | f'size={self.size})') 104 | return repr_str 105 | 106 | def randomize_parameters(self): 107 | self.x_jitter = random.random() 108 | self.y_jitter = random.random() 109 | 110 | 111 | class Resize(object): 112 | """Resize images to a specific size. 113 | 114 | Args: 115 | scale_range (Tuple[int]): If the first value equals to -1, the second value 116 | serves as a short edge of the resized image: else if it is a tuple of 2 117 | integers, the short edge of resized image will be random choice from 118 | [scale_range[0], scale_range[1]]. 119 | """ 120 | 121 | def __init__(self, scale_range): 122 | if not isinstance(scale_range, tuple): 123 | raise ValueError(f'Scale_range {scale_range}, must be tuple.') 124 | self.scale_range = scale_range 125 | 126 | def __call__(self, imgs): 127 | imgs = self._resize(imgs) 128 | return imgs 129 | 130 | def __repr__(self): 131 | repr_str = (f'{self.__class__.__name__}(' 132 | f'size={self.size})') 133 | return repr_str 134 | 135 | def randomize_parameters(self): 136 | if self.scale_range[0] == -1: 137 | self._resize = transforms.Resize(self.scale_range[1]) 138 | else: 139 | short_edge = np.random.randint(self.scale_range[0], 140 | self.scale_range[1]+1) 141 | self._resize = transforms.Resize(short_edge) 142 | 143 | 144 | class RandomResizedCrop: 145 | """Random crop that specifics the area and height-weight ratio range. 146 | 147 | Args: 148 | area_range (Tuple[float]): The candidate area scales range of 149 | output cropped images. Default: (0.08, 1.0). 150 | aspect_ratio_range (Tuple[float]): The candidate aspect ratio range of 151 | output cropped images. Default: (3 / 4, 4 / 3). 
152 | """ 153 | 154 | def __init__(self, 155 | size, 156 | interpolation=3, 157 | scale=(0.08, 1.0), 158 | ratio=(3 / 4, 4 / 3)): 159 | self.size = size 160 | self.area_range = scale 161 | self.aspect_ratio_range = ratio 162 | self.interpolation = interpolation 163 | 164 | def __call__(self, imgs): 165 | """Performs the RandomResizeCrop augmentation. 166 | 167 | Args: 168 | results (dict): The resulting dict to be modified and passed 169 | to the next transform in pipeline. 170 | """ 171 | # version one- frame diverse 172 | #imgs = self._crop_imgs(imgs) 173 | 174 | # version two- frame consistent 175 | img_width = imgs.shape[-1] 176 | img_height = imgs.shape[-2] 177 | # crop size 178 | min_length = min(img_width, img_height) 179 | crop_size = int(min_length * self.scale) 180 | width = crop_size 181 | height = crop_size*self.ratio 182 | 183 | # location 184 | left = self.tl_x * (img_width - width) 185 | top = self.tl_y * (img_height - height) 186 | 187 | imgs = transforms.functional.resized_crop( 188 | imgs, int(top), int(left), int(height), int(width), self.size, interpolation=self.interpolation) 189 | 190 | return imgs 191 | 192 | def __repr__(self): 193 | repr_str = (f'{self.__class__.__name__}(' 194 | f'area_range={self.area_range}, ' 195 | f'aspect_ratio_range={self.aspect_ratio_range}, ' 196 | f'size={self.size})') 197 | return repr_str 198 | 199 | def randomize_parameters(self): 200 | self.scale = random.uniform(self.area_range[0], self.area_range[1]) 201 | self.ratio = random.uniform(self.aspect_ratio_range[0], self.aspect_ratio_range[1]) 202 | ''' 203 | # version one- frame diverse 204 | self._crop_imgs = transforms.RandomResizedCrop( 205 | self.size, scale=(scale, scale), ratio=(ratio, ratio)) 206 | ''' 207 | # version two- frame consistent 208 | self.tl_x = random.random() 209 | self.tl_y = random.random() 210 | 211 | 212 | class Flip(object): 213 | """Flip the input images with a probability. 214 | 215 | Args: 216 | flip_ratio (float): Probability of implementing flip. Default: 0.5. 217 | """ 218 | 219 | def __init__(self, 220 | flip_ratio=0.5): 221 | self.flip_ratio = flip_ratio 222 | 223 | def __call__(self, imgs): 224 | imgs = self._flip(imgs) 225 | return imgs 226 | 227 | def __repr__(self): 228 | repr_str = ( 229 | f'{self.__class__.__name__}(' 230 | f'flip_ratio={self.flip_ratio})') 231 | return repr_str 232 | 233 | def randomize_parameters(self): 234 | p = random.random() 235 | if p > self.flip_ratio: 236 | self._flip = transforms.RandomHorizontalFlip(p=1) 237 | else: 238 | self._flip = transforms.RandomHorizontalFlip(p=0) 239 | 240 | 241 | class RandomGrayscale(object): 242 | """Flip the input images with a probability. 243 | 244 | Args: 245 | flip_ratio (float): Probability of implementing flip. Default: 0.5. 246 | """ 247 | 248 | def __init__(self, 249 | p=0.1): 250 | self.p = p 251 | 252 | def __call__(self, imgs): 253 | imgs = self._grayscale(imgs) 254 | return imgs 255 | 256 | def __repr__(self): 257 | repr_str = ( 258 | f'{self.__class__.__name__}(' 259 | f'p={self.p})') 260 | return repr_str 261 | 262 | def randomize_parameters(self): 263 | p = random.random() 264 | if p > self.p: 265 | self._grayscale = transforms.RandomGrayscale(p=0) 266 | else: 267 | self._grayscale = transforms.RandomGrayscale(p=1) 268 | 269 | 270 | class RandomApply(object): 271 | """Flip the input images with a probability. 272 | 273 | Args: 274 | flip_ratio (float): Probability of implementing flip. Default: 0.5. 
275 | """ 276 | 277 | def __init__(self, 278 | transform, 279 | p=0.5): 280 | self.p = p 281 | self.transform = transform 282 | 283 | def __call__(self, imgs): 284 | imgs = self._random_apply(imgs) 285 | return imgs 286 | 287 | def __repr__(self): 288 | repr_str = ( 289 | f'{self.__class__.__name__}(' 290 | f'p={self.p})') 291 | return repr_str 292 | 293 | def randomize_parameters(self): 294 | p = random.random() 295 | if p > self.p: 296 | self._random_apply = transforms.RandomApply(self.transform, p=0) 297 | else: 298 | self._random_apply = transforms.RandomApply(self.transform, p=1) 299 | 300 | 301 | class Normalize(object): 302 | """Normalize the images with the given mean and std value. 303 | 304 | Args: 305 | mean (Sequence[float]): Mean values of different channels. 306 | std (Sequence[float]): Std values of different channels. 307 | """ 308 | 309 | def __init__(self, mean, std): 310 | if not isinstance(mean, Sequence): 311 | raise TypeError( 312 | f'Mean must be list, tuple or np.ndarray, but got {type(mean)}' 313 | ) 314 | 315 | if not isinstance(std, Sequence): 316 | raise TypeError( 317 | f'Std must be list, tuple or np.ndarray, but got {type(std)}') 318 | 319 | self._normalize = transforms.Normalize(mean, std) 320 | self.mean = mean 321 | self.std = std 322 | 323 | #@profile 324 | def __call__(self, imgs): 325 | imgs = self._normalize(imgs) 326 | return imgs 327 | 328 | def __repr__(self): 329 | repr_str = (f'{self.__class__.__name__}(' 330 | f'mean={self.mean}, ' 331 | f'std={self.std})') 332 | return repr_str 333 | 334 | def randomize_parameters(self): 335 | pass 336 | 337 | 338 | class ColorJitter(object): 339 | """Randomly distort the brightness, contrast, saturation and hue of images. 340 | 341 | Note: The input images should be in RGB channel order. 342 | 343 | Args: 344 | brightness (float): the std values of brightness distortion. 345 | contrast (float): the std values of contrast distortion. 346 | saturation (float): the std values of saturation distortion. 347 | hue (float): the std values of hue distortion. 348 | """ 349 | 350 | def __init__(self, 351 | brightness=0, 352 | contrast=0, 353 | saturation=0, 354 | hue=0): 355 | self.brightness = brightness 356 | self.contrast = contrast 357 | self.saturation = saturation 358 | self.hue = hue 359 | 360 | def __call__(self, imgs): 361 | print(imgs.shape) 362 | if imgs.ndim == 3: 363 | imgs = rearrange(imgs, '(t c) h w -> t c h w', c=3) 364 | imgs = self._color_jit(imgs) 365 | imgs = rearrange(imgs, 't c h w -> (t c) h w') 366 | return imgs 367 | 368 | def __repr__(self): 369 | repr_str = (f'{self.__class__.__name__}(' 370 | f'brightness={self.brightness}, ' 371 | f'contrast={self.contrast}, ' 372 | f'saturation={self.saturation}, ' 373 | f'hue={self.hue})') 374 | return repr_str 375 | 376 | def randomize_parameters(self): 377 | brightness = random.uniform(max(0,1-self.brightness), 1+self.brightness) 378 | contrast = random.uniform(max(0,1-self.contrast), 1+self.contrast) 379 | saturation = random.uniform(max(0,1-self.saturation), 1+self.saturation) 380 | hue = random.uniform(-self.hue, self.hue) 381 | 382 | self._color_jit = transforms.ColorJitter( 383 | brightness=(brightness,brightness), 384 | contrast=(contrast,contrast), 385 | saturation=(saturation,saturation), 386 | hue=(hue,hue)) 387 | 388 | 389 | class CenterCrop(object): 390 | """Crop the center area from images. 391 | 392 | Args: 393 | crop_size (int | tuple[int]): (w, h) of crop size. 
394 | """ 395 | 396 | def __init__(self, size): 397 | self.size = size 398 | self._center_crop = transforms.CenterCrop(size=size) 399 | 400 | def __call__(self, imgs): 401 | imgs = self._center_crop(imgs) 402 | return imgs 403 | 404 | def __repr__(self): 405 | repr_str = (f'{self.__class__.__name__}(size={self.size})') 406 | return repr_str 407 | 408 | def randomize_parameters(self): 409 | pass 410 | 411 | 412 | class ThreeCrop(object): 413 | """Random crop the three pre-define regions of image. 414 | 415 | Args: 416 | size (int, Tuple[int]): Desired output size (out_h, out_w) of the crop 417 | """ 418 | 419 | def __init__(self, size): 420 | if isinstance(size, tuple): 421 | if size[0] != size[1]: 422 | raise ValueError(f'crop size {size[0], size[1]}, must be equal.') 423 | else: 424 | self.size = size[0] 425 | else: 426 | self.size = size 427 | 428 | def __call__(self, imgs): 429 | # Crop size 430 | size = int(self.size) 431 | img_height, img_width = imgs.size(2), imgs.size(3) 432 | if size > img_height or size > img_width: 433 | msg = "Requested crop size {} is bigger than input size {}" 434 | raise ValueError(msg.format(size, (img_height, img_width))) 435 | 436 | # Location 437 | crops = [] 438 | left_y_offset = (img_height - size) // 2 439 | left_x_offset = 0 440 | left = imgs[..., 441 | left_y_offset : left_y_offset + size, 442 | left_x_offset : left_x_offset + size] 443 | crops.append(left) 444 | 445 | right_y_offset = (img_height - size) // 2 446 | right_x_offset = img_width - size 447 | right = imgs[..., 448 | right_y_offset : right_y_offset + size, 449 | right_x_offset : right_x_offset + size] 450 | crops.append(right) 451 | 452 | center_y_offset = (img_height - size) // 2 453 | center_x_offset = (img_width - size) // 2 454 | center = imgs[..., 455 | center_y_offset : center_y_offset + size, 456 | center_x_offset : center_x_offset + size] 457 | crops.append(center) 458 | 459 | # (N_Crops T C H W) 460 | imgs = torch.stack(crops) 461 | return imgs 462 | 463 | def __repr__(self): 464 | repr_str = (f'{self.__class__.__name__}(' 465 | f'size={self.size})') 466 | return repr_str 467 | 468 | def randomize_parameters(self): 469 | pass 470 | 471 | 472 | # ------------------------------------------------------------ 473 | # --------------------- Sampling --------------------------- 474 | # ------------------------------------------------------------ 475 | class TemporalRandomCrop(object): 476 | """Temporally crop the given frame indices at a random location. 477 | 478 | Args: 479 | size (int): Desired length of frames will be seen in the model. 
480 | """ 481 | 482 | def __init__(self, size): 483 | self.size = size 484 | 485 | def __call__(self, total_frames): 486 | rand_end = max(0, total_frames - self.size - 1) 487 | begin_index = random.randint(0, rand_end) 488 | end_index = min(begin_index + self.size, total_frames) 489 | return begin_index, end_index 490 | 491 | 492 | # ------------------------------------------------------------ 493 | # --------------------- AdvancedAugment -------------------- 494 | # ------------------------------------------------------------ 495 | def transforms_train(img_size=224, 496 | scale=None, 497 | ratio=None, 498 | hflip=0.5, 499 | color_jitter=0.4, 500 | auto_augment=None, 501 | interpolation='random', 502 | mean=IMAGENET_DEFAULT_MEAN, 503 | std=IMAGENET_DEFAULT_STD, 504 | objective='supervised'): 505 | """ 506 | If separate==True, the transforms are returned as a tuple of 3 separate transforms 507 | for use in a mixing dataset that passes 508 | * all data through the first (primary) transform, called the 'clean' data 509 | * a portion of the data through the secondary transform 510 | * normalizes and converts the branches above with the third, final transform 511 | """ 512 | scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range 513 | ratio = tuple(ratio or (3./4., 4./3.)) # default imagenet ratio range 514 | primary_tfl = [ 515 | transforms.RandomResizedCrop(img_size, scale=scale, ratio=ratio, interpolation=str_to_interp_mode(interpolation))] 516 | if hflip > 0.: 517 | primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)] 518 | 519 | secondary_tfl = [] 520 | if auto_augment: 521 | secondary_tfl += [transforms.autoaugment.RandAugment()] 522 | elif color_jitter is not None: 523 | # color jitter is enabled when not using AA 524 | if isinstance(color_jitter, (list, tuple)): 525 | # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation 526 | # or 4 if also augmenting hue 527 | assert len(color_jitter) in (3, 4) 528 | else: 529 | # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue 530 | color_jitter = (float(color_jitter),) * 3 531 | secondary_tfl += [transforms.ColorJitter(*color_jitter)] 532 | 533 | final_tfl = [] 534 | final_tfl += [ 535 | ToTensor(), 536 | transforms.Normalize( 537 | mean=torch.tensor(mean), 538 | std=torch.tensor(std)) 539 | ] 540 | if objective == 'mim': 541 | return [Compose(primary_tfl + secondary_tfl), Compose(final_tfl)] 542 | else: 543 | return Compose(primary_tfl + secondary_tfl + final_tfl) 544 | 545 | 546 | def transforms_eval(img_size=224, 547 | crop_pct=None, 548 | interpolation='bilinear', 549 | mean=IMAGENET_DEFAULT_MEAN, 550 | std=IMAGENET_DEFAULT_STD): 551 | crop_pct = crop_pct or DEFAULT_CROP_PCT 552 | 553 | if isinstance(img_size, (tuple, list)): 554 | assert len(img_size) == 2 555 | if img_size[-1] == img_size[-2]: 556 | # fall-back to older behaviour so Resize scales to shortest edge if target is square 557 | scale_size = int(math.floor(img_size[0] / crop_pct)) 558 | else: 559 | scale_size = tuple([int(x / crop_pct) for x in img_size]) 560 | else: 561 | scale_size = int(math.floor(img_size / crop_pct)) 562 | 563 | tfl = [ 564 | transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)), 565 | transforms.CenterCrop(img_size), 566 | ] 567 | tfl += [ 568 | ToTensor(), 569 | transforms.Normalize( 570 | mean=torch.tensor(mean), 571 | std=torch.tensor(std)) 572 | ] 573 | 574 | return Compose(tfl) 575 | 576 | 577 | def create_video_transform(input_size=224, 578 | is_training=False, 
579 | scale=None, 580 | ratio=None, 581 | hflip=0.5, 582 | color_jitter=0.4, 583 | auto_augment=None, 584 | interpolation='bilinear', 585 | mean=IMAGENET_DEFAULT_MEAN, 586 | std=IMAGENET_DEFAULT_STD, 587 | objective='supervised', 588 | crop_pct=None): 589 | 590 | if isinstance(input_size, (tuple, list)): 591 | img_size = input_size[-2:] 592 | else: 593 | img_size = input_size 594 | 595 | if is_training: 596 | transform = transforms_train( 597 | img_size, 598 | scale=scale, 599 | ratio=ratio, 600 | hflip=hflip, 601 | color_jitter=color_jitter, 602 | auto_augment=auto_augment, 603 | interpolation=interpolation, 604 | mean=mean, 605 | std=std, 606 | objective=objective) 607 | else: 608 | transform = transforms_eval( 609 | img_size, 610 | interpolation=interpolation, 611 | mean=mean, 612 | std=std, 613 | crop_pct=crop_pct) 614 | 615 | return transform 616 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | import decord 5 | import numpy as np 6 | import torch 7 | 8 | from einops import rearrange 9 | from skimage.feature import hog 10 | from mask_generator import CubeMaskGenerator 11 | 12 | class_labels_map = None 13 | cls_sample_cnt = None 14 | 15 | def temporal_sampling(frames, start_idx, end_idx, num_samples): 16 | """ 17 | Given the start and end frame index, sample num_samples frames between 18 | the start and end with equal interval. 19 | Args: 20 | frames (tensor): a tensor of video frames, dimension is 21 | `num video frames` x `channel` x `height` x `width`. 22 | start_idx (int): the index of the start frame. 23 | end_idx (int): the index of the end frame. 24 | num_samples (int): number of frames to sample. 25 | Returns: 26 | frames (tersor): a tensor of temporal sampled video frames, dimension is 27 | `num clip frames` x `channel` x `height` x `width`. 
28 | """ 29 | index = torch.linspace(start_idx, end_idx, num_samples) 30 | index = torch.clamp(index, 0, frames.shape[0] - 1).long() 31 | frames = torch.index_select(frames, 0, index) 32 | return frames 33 | 34 | 35 | def numpy2tensor(x): 36 | return torch.from_numpy(x) 37 | 38 | 39 | def extract_hog_features(image): 40 | hog_features_r = hog(image[:,:,0], orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1), block_norm='L2', feature_vector=False) 41 | hog_features_g = hog(image[:,:,1], orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1), block_norm='L2', feature_vector=False) 42 | hog_features_b = hog(image[:,:,2], orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1), block_norm='L2', feature_vector=False) #visualize=True 43 | hog_features = np.concatenate([hog_features_r,hog_features_g,hog_features_b], axis=-1) 44 | hog_features = rearrange(hog_features, '(ph dh) (pw dw) ch cw c -> ph pw (dh dw ch cw c)', ph=14, pw=14) 45 | return hog_features 46 | 47 | 48 | def load_annotation_data(data_file_path): 49 | with open(data_file_path, 'r') as data_file: 50 | return json.load(data_file) 51 | 52 | 53 | def get_class_labels(num_class, anno_pth='./k400_classmap.json'): 54 | global class_labels_map, cls_sample_cnt 55 | 56 | if class_labels_map is not None: 57 | return class_labels_map, cls_sample_cnt 58 | else: 59 | cls_sample_cnt = {} 60 | class_labels_map = load_annotation_data(anno_pth) 61 | for cls in class_labels_map: 62 | cls_sample_cnt[cls] = 0 63 | return class_labels_map, cls_sample_cnt 64 | 65 | 66 | def load_annotations(ann_file, num_class, num_samples_per_cls): 67 | dataset = [] 68 | class_to_idx, cls_sample_cnt = get_class_labels(num_class) 69 | with open(ann_file, 'r') as fin: 70 | for line in fin: 71 | line_split = line.strip().split('\t') 72 | sample = {} 73 | idx = 0 74 | # idx for frame_dir 75 | frame_dir = line_split[idx] 76 | sample['video'] = frame_dir 77 | idx += 1 78 | 79 | # idx for label[s] 80 | label = [x for x in line_split[idx:]] 81 | assert label, f'missing label in line: {line}' 82 | assert len(label) == 1 83 | class_name = label[0] 84 | class_index = int(class_to_idx[class_name]) 85 | 86 | # choose a class subset of whole dataset 87 | if class_index < num_class: 88 | sample['label'] = class_index 89 | if cls_sample_cnt[class_name] < num_samples_per_cls: 90 | dataset.append(sample) 91 | cls_sample_cnt[class_name]+=1 92 | 93 | return dataset 94 | 95 | 96 | class DecordInit(object): 97 | """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader.""" 98 | 99 | def __init__(self, num_threads=1, **kwargs): 100 | self.num_threads = num_threads 101 | self.ctx = decord.cpu(0) 102 | self.kwargs = kwargs 103 | 104 | def __call__(self, filename): 105 | """Perform the Decord initialization. 106 | Args: 107 | results (dict): The resulting dict to be modified and passed 108 | to the next transform in pipeline. 109 | """ 110 | reader = decord.VideoReader(filename, 111 | ctx=self.ctx, 112 | num_threads=self.num_threads) 113 | return reader 114 | 115 | def __repr__(self): 116 | repr_str = (f'{self.__class__.__name__}(' 117 | f'sr={self.sr},' 118 | f'num_threads={self.num_threads})') 119 | return repr_str 120 | 121 | 122 | class Kinetics(torch.utils.data.Dataset): 123 | """Load the Kinetics video files 124 | 125 | Args: 126 | annotation_path (string): Annotation file path. 127 | num_class (int): The number of the class. 128 | num_samples_per_cls (int): the max samples used in each class. 
129 | target_video_len (int): the number of video frames will be load. 130 | align_transform (callable): Align different videos in a specified size. 131 | temporal_sample (callable): Sample the target length of a video. 132 | """ 133 | 134 | def __init__(self, 135 | configs, 136 | annotation_path, 137 | transform=None, 138 | temporal_sample=None): 139 | self.configs = configs 140 | self.data = load_annotations(annotation_path, self.configs.num_class, self.configs.num_samples_per_cls) 141 | 142 | self.transform = transform 143 | self.temporal_sample = temporal_sample 144 | self.target_video_len = self.configs.num_frames 145 | self.objective = self.configs.objective 146 | self.v_decoder = DecordInit() 147 | 148 | # mask 149 | if self.objective == 'mim': 150 | self.mask_generator = CubeMaskGenerator(input_size=(self.target_video_len//2,14,14),min_num_patches=16) 151 | 152 | def __getitem__(self, index): 153 | while True: 154 | try: 155 | path = self.data[index]['video'] 156 | v_reader = self.v_decoder(path) 157 | total_frames = len(v_reader) 158 | 159 | # Sampling video frames 160 | start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) 161 | assert end_frame_ind-start_frame_ind >= self.target_video_len 162 | frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int) 163 | video = v_reader.get_batch(frame_indice).asnumpy() 164 | del v_reader 165 | break 166 | except Exception as e: 167 | print(e) 168 | index = random.randint(0, len(self.data) - 1) 169 | 170 | # Video align transform: T C H W 171 | with torch.no_grad(): 172 | video = torch.from_numpy(video).permute(0,3,1,2) 173 | if self.transform is not None: 174 | if self.objective == 'mim': 175 | pre_transform, post_transform = self.transform 176 | video = pre_transform(video) # align shape 177 | else: 178 | video = self.transform(video) 179 | 180 | # Label (depends) 181 | if self.objective == 'mim': 182 | # old version 183 | ''' 184 | mask, cube_marker = self.mask_generator() # T' H' W' 185 | label = np.stack(list(map(extract_hog_features, video.permute(0,2,3,1).numpy())), axis=0) # T H W C -> T H' W' C' 186 | ''' 187 | # new version 188 | mask, cube_marker = self.mask_generator() # T' H' W' 189 | hog_inputs = video.permute(0,2,3,1).numpy() 190 | hog_features = np.zeros((self.target_video_len,14,14,2*2*3*9)) 191 | # speed up the extraction of hog features 192 | for marker in cube_marker: # [[start, span]] 193 | start_frame, span_frame = marker 194 | center_index = start_frame*2 + span_frame*2//2 # fix the temporal stride to 2 195 | hog_features[center_index] = extract_hog_features(hog_inputs[center_index]) 196 | label = hog_features 197 | else: 198 | label = self.data[index]['label'] 199 | 200 | if self.objective == 'mim': 201 | if self.transform is not None: 202 | video = post_transform(video) # to tensor & norm 203 | return video, numpy2tensor(label), numpy2tensor(mask), cube_marker 204 | else: 205 | return video, label 206 | 207 | def __len__(self): 208 | return len(self.data) 209 | 210 | 211 | if __name__ == '__main__': 212 | # Unit test for loading video and computing time cost 213 | import data_transform as T 214 | import time 215 | path = './YABnJL_bDzw.mp4' 216 | color_jitter = 0.4 217 | auto_augment = 'rand-m9-mstd0.5-inc1' 218 | scale = None 219 | mean, std = (0.45, 0.45, 0.45), (0.225, 0.225, 0.225) 220 | transform = T.create_video_transform( 221 | input_size=224, 222 | is_training=True, 223 | scale=scale, 224 | hflip=0.5, 225 | color_jitter=color_jitter, 226 | 
auto_augment=auto_augment, 227 | interpolation='bicubic', 228 | mean=mean, 229 | std=std) 230 | 231 | v_decoder = DecordInit() 232 | v_reader = v_decoder(path) 233 | total_frames = len(v_reader) 234 | target_video_len = 16 235 | # Sampling video frames 236 | temporal_sample = T.TemporalRandomCrop(target_video_len*16) 237 | start_frame_ind, end_frame_ind = temporal_sample(total_frames) 238 | frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, target_video_len, dtype=int) 239 | video = v_reader.get_batch(frame_indice).asnumpy() 240 | del v_reader 241 | 242 | # Video align transform: T C H W 243 | with torch.no_grad(): 244 | video = torch.from_numpy(video).permute(0,3,1,2) 245 | if transform is not None: 246 | video = transform(video) 247 | 248 | show_processed_image(video.permute(0,2,3,1), save_dir='./', mean=mean, std=std) 249 | ''' 250 | mask_generator = CubeMaskGenerator(input_size=(8,14,14),min_num_patches=16) 251 | counts = 1 252 | while True: 253 | if counts > 100: 254 | break 255 | start_time = time.perf_counter() 256 | v_decoder = DecordInit() 257 | v_reader = v_decoder(path) 258 | # Sampling video frames 259 | total_frames = len(v_reader) 260 | align_transform = T.Compose([ 261 | T.RandomResizedCrop(size=(224, 224), area_range=(0.5, 1.0), interpolation=3), #InterpolationMode.BICUBIC 262 | T.Flip(), 263 | ]) 264 | temporal_sample = T.TemporalRandomCrop(16*4) 265 | start_frame_ind, end_frame_ind = temporal_sample(total_frames) 266 | frame_indice = np.linspace(0, end_frame_ind-start_frame_ind-1, 267 | 16, dtype=int) 268 | video = v_reader.get_batch(frame_indice).asnumpy() 269 | del v_reader 270 | 271 | # Video align transform: T C H W 272 | with torch.no_grad(): 273 | video = torch.from_numpy(video).permute(0,3,1,2) 274 | align_transform.randomize_parameters() 275 | video = align_transform(video) 276 | #label = np.stack(list(map(extract_hog_features, video.permute(0,2,3,1).numpy())), axis=0) # T H W C -> T H' W' C' 277 | _, hog_image = hog(video.permute(0,2,3,1).numpy()[0][:,:,2], orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1), block_norm='L2', feature_vector=False, visualize=True) 278 | mask, cube_marker = mask_generator() # T' H' W' 279 | counts += 1 280 | print(f'{(time.perf_counter()-start_time):.3f}') 281 | print('finish') 282 | ''' 283 | #_, hog_image = hog(video.permute(0,2,3,1).numpy()[0][:,:,2], orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1), block_norm='L2', feature_vector=False, visualize=True) 284 | #from skimage import io 285 | #io.imsave('./test_img_hog.jpg',hog_image) 286 | #show_processed_image(video.permute(0,2,3,1), save_dir='./') 287 | -------------------------------------------------------------------------------- /demo/9r8wpMS2iEk_000048_000058.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mx-mark/VideoTransformer-pytorch/194cae69722eb5efad031c59f4ff03bc60633fa8/demo/9r8wpMS2iEk_000048_000058.mp4 -------------------------------------------------------------------------------- /demo/YABnJL_bDzw.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mx-mark/VideoTransformer-pytorch/194cae69722eb5efad031c59f4ff03bc60633fa8/demo/YABnJL_bDzw.mp4 -------------------------------------------------------------------------------- /k400_classmap.json: -------------------------------------------------------------------------------- 1 | {"abseiling": "0", "air_drumming": "1", "answering_questions": 
"2", "applauding": "3", "applying_cream": "4", "archery": "5", "arm_wrestling": "6", "arranging_flowers": "7", "assembling_computer": "8", "auctioning": "9", "baby_waking_up": "10", "baking_cookies": "11", "balloon_blowing": "12", "bandaging": "13", "barbequing": "14", "bartending": "15", "beatboxing": "16", "bee_keeping": "17", "belly_dancing": "18", "bench_pressing": "19", "bending_back": "20", "bending_metal": "21", "biking_through_snow": "22", "blasting_sand": "23", "blowing_glass": "24", "blowing_leaves": "25", "blowing_nose": "26", "blowing_out_candles": "27", "bobsledding": "28", "bookbinding": "29", "bouncing_on_trampoline": "30", "bowling": "31", "braiding_hair": "32", "breading_or_breadcrumbing": "33", "breakdancing": "34", "brush_painting": "35", "brushing_hair": "36", "brushing_teeth": "37", "building_cabinet": "38", "building_shed": "39", "bungee_jumping": "40", "busking": "41", "canoeing_or_kayaking": "42", "capoeira": "43", "carrying_baby": "44", "cartwheeling": "45", "carving_pumpkin": "46", "catching_fish": "47", "catching_or_throwing_baseball": "48", "catching_or_throwing_frisbee": "49", "catching_or_throwing_softball": "50", "celebrating": "51", "changing_oil": "52", "changing_wheel": "53", "checking_tires": "54", "cheerleading": "55", "chopping_wood": "56", "clapping": "57", "clay_pottery_making": "58", "clean_and_jerk": "59", "cleaning_floor": "60", "cleaning_gutters": "61", "cleaning_pool": "62", "cleaning_shoes": "63", "cleaning_toilet": "64", "cleaning_windows": "65", "climbing_a_rope": "66", "climbing_ladder": "67", "climbing_tree": "68", "contact_juggling": "69", "cooking_chicken": "70", "cooking_egg": "71", "cooking_on_campfire": "72", "cooking_sausages": "73", "counting_money": "74", "country_line_dancing": "75", "cracking_neck": "76", "crawling_baby": "77", "crossing_river": "78", "crying": "79", "curling_hair": "80", "cutting_nails": "81", "cutting_pineapple": "82", "cutting_watermelon": "83", "dancing_ballet": "84", "dancing_charleston": "85", "dancing_gangnam_style": "86", "dancing_macarena": "87", "deadlifting": "88", "decorating_the_christmas_tree": "89", "digging": "90", "dining": "91", "disc_golfing": "92", "diving_cliff": "93", "dodgeball": "94", "doing_aerobics": "95", "doing_laundry": "96", "doing_nails": "97", "drawing": "98", "dribbling_basketball": "99", "drinking": "100", "drinking_beer": "101", "drinking_shots": "102", "driving_car": "103", "driving_tractor": "104", "drop_kicking": "105", "drumming_fingers": "106", "dunking_basketball": "107", "dying_hair": "108", "eating_burger": "109", "eating_cake": "110", "eating_carrots": "111", "eating_chips": "112", "eating_doughnuts": "113", "eating_hotdog": "114", "eating_ice_cream": "115", "eating_spaghetti": "116", "eating_watermelon": "117", "egg_hunting": "118", "exercising_arm": "119", "exercising_with_an_exercise_ball": "120", "extinguishing_fire": "121", "faceplanting": "122", "feeding_birds": "123", "feeding_fish": "124", "feeding_goats": "125", "filling_eyebrows": "126", "finger_snapping": "127", "fixing_hair": "128", "flipping_pancake": "129", "flying_kite": "130", "folding_clothes": "131", "folding_napkins": "132", "folding_paper": "133", "front_raises": "134", "frying_vegetables": "135", "garbage_collecting": "136", "gargling": "137", "getting_a_haircut": "138", "getting_a_tattoo": "139", "giving_or_receiving_award": "140", "golf_chipping": "141", "golf_driving": "142", "golf_putting": "143", "grinding_meat": "144", "grooming_dog": "145", "grooming_horse": "146", "gymnastics_tumbling": "147", 
"hammer_throw": "148", "headbanging": "149", "headbutting": "150", "high_jump": "151", "high_kick": "152", "hitting_baseball": "153", "hockey_stop": "154", "holding_snake": "155", "hopscotch": "156", "hoverboarding": "157", "hugging": "158", "hula_hooping": "159", "hurdling": "160", "hurling_(sport)": "161", "ice_climbing": "162", "ice_fishing": "163", "ice_skating": "164", "ironing": "165", "javelin_throw": "166", "jetskiing": "167", "jogging": "168", "juggling_balls": "169", "juggling_fire": "170", "juggling_soccer_ball": "171", "jumping_into_pool": "172", "jumpstyle_dancing": "173", "kicking_field_goal": "174", "kicking_soccer_ball": "175", "kissing": "176", "kitesurfing": "177", "knitting": "178", "krumping": "179", "laughing": "180", "laying_bricks": "181", "long_jump": "182", "lunge": "183", "making_a_cake": "184", "making_a_sandwich": "185", "making_bed": "186", "making_jewelry": "187", "making_pizza": "188", "making_snowman": "189", "making_sushi": "190", "making_tea": "191", "marching": "192", "massaging_back": "193", "massaging_feet": "194", "massaging_legs": "195", "massaging_person's_head": "196", "milking_cow": "197", "mopping_floor": "198", "motorcycling": "199", "moving_furniture": "200", "mowing_lawn": "201", "news_anchoring": "202", "opening_bottle": "203", "opening_present": "204", "paragliding": "205", "parasailing": "206", "parkour": "207", "passing_American_football_(in_game)": "208", "passing_American_football_(not_in_game)": "209", "peeling_apples": "210", "peeling_potatoes": "211", "petting_animal_(not_cat)": "212", "petting_cat": "213", "picking_fruit": "214", "planting_trees": "215", "plastering": "216", "playing_accordion": "217", "playing_badminton": "218", "playing_bagpipes": "219", "playing_basketball": "220", "playing_bass_guitar": "221", "playing_cards": "222", "playing_cello": "223", "playing_chess": "224", "playing_clarinet": "225", "playing_controller": "226", "playing_cricket": "227", "playing_cymbals": "228", "playing_didgeridoo": "229", "playing_drums": "230", "playing_flute": "231", "playing_guitar": "232", "playing_harmonica": "233", "playing_harp": "234", "playing_ice_hockey": "235", "playing_keyboard": "236", "playing_kickball": "237", "playing_monopoly": "238", "playing_organ": "239", "playing_paintball": "240", "playing_piano": "241", "playing_poker": "242", "playing_recorder": "243", "playing_saxophone": "244", "playing_squash_or_racquetball": "245", "playing_tennis": "246", "playing_trombone": "247", "playing_trumpet": "248", "playing_ukulele": "249", "playing_violin": "250", "playing_volleyball": "251", "playing_xylophone": "252", "pole_vault": "253", "presenting_weather_forecast": "254", "pull_ups": "255", "pumping_fist": "256", "pumping_gas": "257", "punching_bag": "258", "punching_person_(boxing)": "259", "push_up": "260", "pushing_car": "261", "pushing_cart": "262", "pushing_wheelchair": "263", "reading_book": "264", "reading_newspaper": "265", "recording_music": "266", "riding_a_bike": "267", "riding_camel": "268", "riding_elephant": "269", "riding_mechanical_bull": "270", "riding_mountain_bike": "271", "riding_mule": "272", "riding_or_walking_with_horse": "273", "riding_scooter": "274", "riding_unicycle": "275", "ripping_paper": "276", "robot_dancing": "277", "rock_climbing": "278", "rock_scissors_paper": "279", "roller_skating": "280", "running_on_treadmill": "281", "sailing": "282", "salsa_dancing": "283", "sanding_floor": "284", "scrambling_eggs": "285", "scuba_diving": "286", "setting_table": "287", "shaking_hands": "288", 
"shaking_head": "289", "sharpening_knives": "290", "sharpening_pencil": "291", "shaving_head": "292", "shaving_legs": "293", "shearing_sheep": "294", "shining_shoes": "295", "shooting_basketball": "296", "shooting_goal_(soccer)": "297", "shot_put": "298", "shoveling_snow": "299", "shredding_paper": "300", "shuffling_cards": "301", "side_kick": "302", "sign_language_interpreting": "303", "singing": "304", "situp": "305", "skateboarding": "306", "ski_jumping": "307", "skiing_(not_slalom_or_crosscountry)": "308", "skiing_crosscountry": "309", "skiing_slalom": "310", "skipping_rope": "311", "skydiving": "312", "slacklining": "313", "slapping": "314", "sled_dog_racing": "315", "smoking": "316", "smoking_hookah": "317", "snatch_weight_lifting": "318", "sneezing": "319", "sniffing": "320", "snorkeling": "321", "snowboarding": "322", "snowkiting": "323", "snowmobiling": "324", "somersaulting": "325", "spinning_poi": "326", "spray_painting": "327", "spraying": "328", "springboard_diving": "329", "squat": "330", "sticking_tongue_out": "331", "stomping_grapes": "332", "stretching_arm": "333", "stretching_leg": "334", "strumming_guitar": "335", "surfing_crowd": "336", "surfing_water": "337", "sweeping_floor": "338", "swimming_backstroke": "339", "swimming_breast_stroke": "340", "swimming_butterfly_stroke": "341", "swing_dancing": "342", "swinging_legs": "343", "swinging_on_something": "344", "sword_fighting": "345", "tai_chi": "346", "taking_a_shower": "347", "tango_dancing": "348", "tap_dancing": "349", "tapping_guitar": "350", "tapping_pen": "351", "tasting_beer": "352", "tasting_food": "353", "testifying": "354", "texting": "355", "throwing_axe": "356", "throwing_ball": "357", "throwing_discus": "358", "tickling": "359", "tobogganing": "360", "tossing_coin": "361", "tossing_salad": "362", "training_dog": "363", "trapezing": "364", "trimming_or_shaving_beard": "365", "trimming_trees": "366", "triple_jump": "367", "tying_bow_tie": "368", "tying_knot_(not_on_a_tie)": "369", "tying_tie": "370", "unboxing": "371", "unloading_truck": "372", "using_computer": "373", "using_remote_controller_(not_gaming)": "374", "using_segway": "375", "vault": "376", "waiting_in_line": "377", "walking_the_dog": "378", "washing_dishes": "379", "washing_feet": "380", "washing_hair": "381", "washing_hands": "382", "water_skiing": "383", "water_sliding": "384", "watering_plants": "385", "waxing_back": "386", "waxing_chest": "387", "waxing_eyebrows": "388", "waxing_legs": "389", "weaving_basket": "390", "welding": "391", "whistling": "392", "windsurfing": "393", "wrapping_present": "394", "wrestling": "395", "writing": "396", "yawning": "397", "yoga": "398", "zumba": "399"} -------------------------------------------------------------------------------- /k600_classmap.json: -------------------------------------------------------------------------------- 1 | { 2 | "arguing": 0, 3 | "throwing ball (not baseball or American football)": 1, 4 | "falling off chair": 2, 5 | "shooting basketball": 3, 6 | "burping": 4, 7 | "ice swimming": 5, 8 | "assembling computer": 6, 9 | "playing chess": 7, 10 | "yawning": 8, 11 | "tackling": 9, 12 | "using a sledge hammer": 10, 13 | "sneezing": 11, 14 | "putting in contact lenses": 12, 15 | "massaging back": 13, 16 | "playing paintball": 14, 17 | "looking at phone": 15, 18 | "massaging neck": 16, 19 | "making sushi": 17, 20 | "poking bellybutton": 18, 21 | "scrambling eggs": 19, 22 | "hammer throw": 20, 23 | "ice skating": 21, 24 | "playing rubiks cube": 22, 25 | "catching or throwing frisbee": 
23, 26 | "auctioning": 24, 27 | "person collecting garbage": 25, 28 | "contorting": 26, 29 | "playing piano": 27, 30 | "ski jumping": 28, 31 | "cutting pineapple": 29, 32 | "riding unicycle": 30, 33 | "playing lute": 31, 34 | "blowing bubble gum": 32, 35 | "backflip (human)": 33, 36 | "spelunking": 34, 37 | "getting a piercing": 35, 38 | "parkour": 36, 39 | "frying vegetables": 37, 40 | "planting trees": 38, 41 | "long jump": 39, 42 | "setting table": 40, 43 | "using puppets": 41, 44 | "whistling": 42, 45 | "cooking egg": 43, 46 | "bench pressing": 44, 47 | "scrubbing face": 45, 48 | "combing hair": 46, 49 | "snowmobiling": 47, 50 | "ice climbing": 48, 51 | "drawing": 49, 52 | "massaging feet": 50, 53 | "headbanging": 51, 54 | "opening wine bottle": 52, 55 | "marching": 53, 56 | "feeding goats": 54, 57 | "jumping jacks": 55, 58 | "playing hand clapping games": 56, 59 | "brushing teeth": 57, 60 | "dumpster diving": 58, 61 | "dancing gangnam style": 59, 62 | "springboard diving": 60, 63 | "high kick": 61, 64 | "making tea": 62, 65 | "adjusting glasses": 63, 66 | "putting on mascara": 64, 67 | "slapping": 65, 68 | "juggling fire": 66, 69 | "situp": 67, 70 | "snatch weight lifting": 68, 71 | "using segway": 69, 72 | "motorcycling": 70, 73 | "base jumping": 71, 74 | "playing violin": 72, 75 | "wood burning (art)": 73, 76 | "pushing car": 74, 77 | "kitesurfing": 75, 78 | "acting in play": 76, 79 | "cleaning shoes": 77, 80 | "pinching": 78, 81 | "rolling pastry": 79, 82 | "eating carrots": 80, 83 | "jumping into pool": 81, 84 | "tasting beer": 82, 85 | "blowing leaves": 83, 86 | "playing squash or racquetball": 84, 87 | "directing traffic": 85, 88 | "drooling": 86, 89 | "catching fish": 87, 90 | "bungee jumping": 88, 91 | "geocaching": 89, 92 | "blowing glass": 90, 93 | "slacklining": 91, 94 | "bending metal": 92, 95 | "changing gear in car": 93, 96 | "cooking scallops": 94, 97 | "rock scissors paper": 95, 98 | "training dog": 96, 99 | "chiseling stone": 97, 100 | "running on treadmill": 98, 101 | "zumba": 99, 102 | "using a power drill": 100, 103 | "ripping paper": 101, 104 | "riding a bike": 102, 105 | "testifying": 103, 106 | "jetskiing": 104, 107 | "salsa dancing": 105, 108 | "bouncing on bouncy castle": 106, 109 | "windsurfing": 107, 110 | "playing blackjack": 108, 111 | "playing cricket": 109, 112 | "applying cream": 110, 113 | "drop kicking": 111, 114 | "waving hand": 112, 115 | "eating hotdog": 113, 116 | "hockey stop": 114, 117 | "paragliding": 115, 118 | "sharpening knives": 116, 119 | "assembling bicycle": 117, 120 | "playing polo": 118, 121 | "visiting the zoo": 119, 122 | "flying kite": 120, 123 | "lunge": 121, 124 | "shaping bread dough": 122, 125 | "playing didgeridoo": 123, 126 | "shining shoes": 124, 127 | "installing carpet": 125, 128 | "trapezing": 126, 129 | "smoking": 127, 130 | "sausage making": 128, 131 | "water sliding": 129, 132 | "wrestling": 130, 133 | "changing oil": 131, 134 | "dyeing eyebrows": 132, 135 | "bouncing on trampoline": 133, 136 | "clapping": 134, 137 | "playing tennis": 135, 138 | "busking": 136, 139 | "building sandcastle": 137, 140 | "cutting apple": 138, 141 | "getting a haircut": 139, 142 | "ironing": 140, 143 | "tiptoeing": 141, 144 | "pouring beer": 142, 145 | "riding or walking with horse": 143, 146 | "luge": 144, 147 | "playing pan pipes": 145, 148 | "making a cake": 146, 149 | "pushing wheelbarrow": 147, 150 | "yoga": 148, 151 | "passing american football (not in game)": 149, 152 | "somersaulting": 150, 153 | "smoking pipe": 151, 154 | "card 
throwing": 152, 155 | "riding snow blower": 153, 156 | "playing harmonica": 154, 157 | "ice fishing": 155, 158 | "building lego": 156, 159 | "playing scrabble": 157, 160 | "dribbling basketball": 158, 161 | "dancing charleston": 159, 162 | "golf putting": 160, 163 | "mushroom foraging": 161, 164 | "lighting fire": 162, 165 | "standing on hands": 163, 166 | "diving cliff": 164, 167 | "petting animal (not cat)": 165, 168 | "spinning poi": 166, 169 | "changing wheel (not on bike)": 167, 170 | "stretching arm": 168, 171 | "packing": 169, 172 | "mowing lawn": 170, 173 | "chewing gum": 171, 174 | "using a microscope": 172, 175 | "playing dominoes": 173, 176 | "falling off bike": 174, 177 | "krumping": 175, 178 | "twiddling fingers": 176, 179 | "shearing sheep": 177, 180 | "folding paper": 178, 181 | "playing marbles": 179, 182 | "holding snake": 180, 183 | "sweeping floor": 181, 184 | "smashing": 182, 185 | "cutting nails": 183, 186 | "kicking soccer ball": 184, 187 | "playing flute": 185, 188 | "using atm": 186, 189 | "playing cymbals": 187, 190 | "opening door": 188, 191 | "photobombing": 189, 192 | "fly tying": 190, 193 | "roasting marshmallows": 191, 194 | "playing ping pong": 192, 195 | "using inhaler": 193, 196 | "swimming backstroke": 194, 197 | "crossing eyes": 195, 198 | "kicking field goal": 196, 199 | "air drumming": 197, 200 | "throwing tantrum": 198, 201 | "biking through snow": 199, 202 | "waking up": 200, 203 | "playing controller": 201, 204 | "beatboxing": 202, 205 | "baking cookies": 203, 206 | "snorkeling": 204, 207 | "carrying baby": 205, 208 | "grooming horse": 206, 209 | "threading needle": 207, 210 | "tickling": 208, 211 | "plastering": 209, 212 | "making balloon shapes": 210, 213 | "drinking shots": 211, 214 | "passing American football (in game)": 212, 215 | "clean and jerk": 213, 216 | "building shed": 214, 217 | "eating spaghetti": 215, 218 | "sipping cup": 216, 219 | "walking the dog": 217, 220 | "steer roping": 218, 221 | "stretching leg": 219, 222 | "belly dancing": 220, 223 | "playing keyboard": 221, 224 | "playing netball": 222, 225 | "cleaning gutters": 223, 226 | "playing basketball": 224, 227 | "playing accordion": 225, 228 | "dining": 226, 229 | "dyeing hair": 227, 230 | "making jewelry": 228, 231 | "massaging legs": 229, 232 | "card stacking": 230, 233 | "clam digging": 231, 234 | "weaving basket": 232, 235 | "unboxing": 233, 236 | "news anchoring": 234, 237 | "using a paint roller": 235, 238 | "using circular saw": 236, 239 | "cracking back": 237, 240 | "wading through water": 238, 241 | "cheerleading": 239, 242 | "gospel singing in church": 240, 243 | "playing drums": 241, 244 | "checking tires": 242, 245 | "land sailing": 243, 246 | "tapping pen": 244, 247 | "laughing": 245, 248 | "lifting hat": 246, 249 | "capsizing": 247, 250 | "barbequing": 248, 251 | "tango dancing": 249, 252 | "swimming butterfly stroke": 250, 253 | "bathing dog": 251, 254 | "scuba diving": 252, 255 | "using remote controller (not gaming)": 253, 256 | "javelin throw": 254, 257 | "cutting orange": 255, 258 | "sign language interpreting": 256, 259 | "spray painting": 257, 260 | "bulldozing": 258, 261 | "laying concrete": 259, 262 | "walking through snow": 260, 263 | "applauding": 261, 264 | "sharpening pencil": 262, 265 | "putting on eyeliner": 263, 266 | "gymnastics tumbling": 264, 267 | "braiding hair": 265, 268 | "flint knapping": 266, 269 | "shaking head": 267, 270 | "making cheese": 268, 271 | "trimming trees": 269, 272 | "hula hooping": 270, 273 | "deadlifting": 271, 274 | "washing 
hands": 272, 275 | "tying necktie": 273, 276 | "sword fighting": 274, 277 | "playing gong": 275, 278 | "using a wrench": 276, 279 | "putting on foundation": 277, 280 | "pillow fight": 278, 281 | "crawling baby": 279, 282 | "moon walking": 280, 283 | "doing jigsaw puzzle": 281, 284 | "making horseshoes": 282, 285 | "bookbinding": 283, 286 | "stomping grapes": 284, 287 | "brush painting": 285, 288 | "playing darts": 286, 289 | "presenting weather forecast": 287, 290 | "throwing discus": 288, 291 | "sucking lolly": 289, 292 | "playing recorder": 290, 293 | "embroidering": 291, 294 | "climbing tree": 292, 295 | "petting cat": 293, 296 | "reading newspaper": 294, 297 | "contact juggling": 295, 298 | "riding elephant": 296, 299 | "picking fruit": 297, 300 | "sleeping": 298, 301 | "cartwheeling": 299, 302 | "eating watermelon": 300, 303 | "shooting goal (soccer)": 301, 304 | "jogging": 302, 305 | "head stand": 303, 306 | "yarn spinning": 304, 307 | "blowing nose": 305, 308 | "calligraphy": 306, 309 | "rope pushdown": 307, 310 | "cleaning toilet": 308, 311 | "making the bed": 309, 312 | "separating eggs": 310, 313 | "answering questions": 311, 314 | "laying bricks": 312, 315 | "opening present": 313, 316 | "robot dancing": 314, 317 | "shaking hands": 315, 318 | "weaving fabric": 316, 319 | "longboarding": 317, 320 | "washing dishes": 318, 321 | "giving or receiving award": 319, 322 | "curling (sport)": 320, 323 | "extinguishing fire": 321, 324 | "vacuuming floor": 322, 325 | "arranging flowers": 323, 326 | "mopping floor": 324, 327 | "shaving legs": 325, 328 | "pushing cart": 326, 329 | "cooking sausages (not on barbeque)": 327, 330 | "planing wood": 328, 331 | "throwing snowballs": 329, 332 | "swimming front crawl": 330, 333 | "pumping gas": 331, 334 | "winking": 332, 335 | "exercising with an exercise ball": 333, 336 | "fixing hair": 334, 337 | "feeding birds": 335, 338 | "surfing water": 336, 339 | "making bubbles": 337, 340 | "skiing mono": 338, 341 | "opening refrigerator": 339, 342 | "folding clothes": 340, 343 | "marriage proposal": 341, 344 | "shining flashlight": 342, 345 | "sawing wood": 343, 346 | "fencing (sport)": 344, 347 | "drumming fingers": 345, 348 | "playing xylophone": 346, 349 | "hitting baseball": 347, 350 | "fixing bicycle": 348, 351 | "sanding floor": 349, 352 | "swing dancing": 350, 353 | "moving furniture": 351, 354 | "side kick": 352, 355 | "snowkiting": 353, 356 | "opening bottle (not wine)": 354, 357 | "playing pinball": 355, 358 | "playing saxophone": 356, 359 | "texting": 357, 360 | "chopping wood": 358, 361 | "playing cello": 359, 362 | "scrapbooking": 360, 363 | "ironing hair": 361, 364 | "arm wrestling": 362, 365 | "playing poker": 363, 366 | "eating doughnuts": 364, 367 | "waxing legs": 365, 368 | "capoeira": 366, 369 | "playing bagpipes": 367, 370 | "shucking oysters": 368, 371 | "cosplaying": 369, 372 | "bending back": 370, 373 | "breakdancing": 371, 374 | "playing trombone": 372, 375 | "building cabinet": 373, 376 | "hurdling": 374, 377 | "mountain climber (exercise)": 375, 378 | "playing volleyball": 376, 379 | "bee keeping": 377, 380 | "wading through mud": 378, 381 | "hugging baby": 379, 382 | "abseiling": 380, 383 | "carving pumpkin": 381, 384 | "breading or breadcrumbing": 382, 385 | "pole vault": 383, 386 | "juggling soccer ball": 384, 387 | "counting money": 385, 388 | "dancing macarena": 386, 389 | "dunking basketball": 387, 390 | "roller skating": 388, 391 | "knitting": 389, 392 | "laying stone": 390, 393 | "tapping guitar": 391, 394 | "shuffling 
feet": 392, 395 | "breaking boards": 393, 396 | "waxing back": 394, 397 | "coloring in": 395, 398 | "chopping meat": 396, 399 | "archery": 397, 400 | "playing monopoly": 398, 401 | "swinging baseball bat": 399, 402 | "repairing puncture": 400, 403 | "making paper aeroplanes": 401, 404 | "cracking neck": 402, 405 | "tying knot (not on a tie)": 403, 406 | "wrapping present": 404, 407 | "using bagging machine": 405, 408 | "jaywalking": 406, 409 | "raising eyebrows": 407, 410 | "dancing ballet": 408, 411 | "smelling feet": 409, 412 | "celebrating": 410, 413 | "bodysurfing": 411, 414 | "bartending": 412, 415 | "shaving head": 413, 416 | "squat": 414, 417 | "chopping vegetables": 415, 418 | "skydiving": 416, 419 | "jumpstyle dancing": 417, 420 | "archaeological excavation": 418, 421 | "waiting in line": 419, 422 | "recording music": 420, 423 | "passing soccer ball": 421, 424 | "washing feet": 422, 425 | "waxing eyebrows": 423, 426 | "making snowman": 424, 427 | "hoverboarding": 425, 428 | "rock climbing": 426, 429 | "catching or throwing softball": 427, 430 | "smoking hookah": 428, 431 | "skateboarding": 429, 432 | "climbing a rope": 430, 433 | "watching tv": 431, 434 | "delivering mail": 432, 435 | "putting on sari": 433, 436 | "swimming breast stroke": 434, 437 | "playing beer pong": 435, 438 | "eating chips": 436, 439 | "playing laser tag": 437, 440 | "blowing out candles": 438, 441 | "sled dog racing": 439, 442 | "shopping": 440, 443 | "folding napkins": 441, 444 | "roasting pig": 442, 445 | "writing": 443, 446 | "hand washing clothes": 444, 447 | "playing trumpet": 445, 448 | "riding mechanical bull": 446, 449 | "carving ice": 447, 450 | "punching bag": 448, 451 | "shot put": 449, 452 | "sewing": 450, 453 | "crossing river": 451, 454 | "hugging (not baby)": 452, 455 | "tying shoe laces": 453, 456 | "throwing water balloon": 454, 457 | "putting on lipstick": 455, 458 | "playing bass guitar": 456, 459 | "attending conference": 457, 460 | "playing kickball": 458, 461 | "peeling potatoes": 459, 462 | "trimming shrubs": 460, 463 | "massaging person's head": 461, 464 | "throwing knife": 462, 465 | "skipping stone": 463, 466 | "inflating balloons": 464, 467 | "playing harp": 465, 468 | "lock picking": 466, 469 | "skiing crosscountry": 467, 470 | "washing hair": 468, 471 | "cooking on campfire": 469, 472 | "grinding meat": 470, 473 | "preparing salad": 471, 474 | "huddling": 472, 475 | "tobogganing": 473, 476 | "tasting wine": 474, 477 | "bowling": 475, 478 | "tightrope walking": 476, 479 | "home roasting coffee": 477, 480 | "curling hair": 478, 481 | "chiseling wood": 479, 482 | "water skiing": 480, 483 | "calculating": 481, 484 | "tagging graffiti": 482, 485 | "dodgeball": 483, 486 | "tying bow tie": 484, 487 | "square dancing": 485, 488 | "laying tiles": 486, 489 | "tai chi": 487, 490 | "tie dying": 488, 491 | "throwing axe": 489, 492 | "playing ocarina": 490, 493 | "cutting watermelon": 491, 494 | "eating ice cream": 492, 495 | "cleaning pool": 493, 496 | "making pizza": 494, 497 | "playing field hockey": 495, 498 | "shuffling cards": 496, 499 | "singing": 497, 500 | "driving tractor": 498, 501 | "pull ups": 499, 502 | "golf chipping": 500, 503 | "faceplanting": 501, 504 | "shoveling snow": 502, 505 | "sailing": 503, 506 | "playing ukulele": 504, 507 | "needle felting": 505, 508 | "milking cow": 506, 509 | "catching or throwing baseball": 507, 510 | "trimming or shaving beard": 508, 511 | "playing with trains": 509, 512 | "high jump": 510, 513 | "riding mule": 511, 514 | "bobsledding": 512, 515 
| "cumbia": 513, 516 | "jumping bicycle": 514, 517 | "riding scooter": 515, 518 | "reading book": 516, 519 | "climbing ladder": 517, 520 | "getting a tattoo": 518, 521 | "mosh pit dancing": 519, 522 | "historical reenactment": 520, 523 | "docking boat": 521, 524 | "bottling": 522, 525 | "swinging on something": 523, 526 | "parasailing": 524, 527 | "doing nails": 525, 528 | "talking on cell phone": 526, 529 | "blowdrying hair": 527, 530 | "clay pottery making": 528, 531 | "eating burger": 529, 532 | "juggling balls": 530, 533 | "doing laundry": 531, 534 | "finger snapping": 532, 535 | "tasting food": 533, 536 | "playing clarinet": 534, 537 | "playing organ": 535, 538 | "decorating the christmas tree": 536, 539 | "grooming dog": 537, 540 | "eating cake": 538, 541 | "sticking tongue out": 539, 542 | "disc golfing": 540, 543 | "kissing": 541, 544 | "cleaning windows": 542, 545 | "making a sandwich": 543, 546 | "staring": 544, 547 | "front raises": 545, 548 | "hurling (sport)": 546, 549 | "leatherworking": 547, 550 | "playing maracas": 548, 551 | "brushing hair": 549, 552 | "playing badminton": 550, 553 | "photocopying": 551, 554 | "canoeing or kayaking": 552, 555 | "polishing metal": 553, 556 | "playing guitar": 554, 557 | "battle rope training": 555, 558 | "egg hunting": 556, 559 | "pirouetting": 557, 560 | "hopscotch": 558, 561 | "karaoke": 559, 562 | "golf driving": 560, 563 | "tossing coin": 561, 564 | "bandaging": 562, 565 | "alligator wrestling": 563, 566 | "gold panning": 564, 567 | "flipping pancake": 565, 568 | "skiing slalom": 566, 569 | "casting fishing line": 567, 570 | "licking": 568, 571 | "bull fighting": 569, 572 | "cracking knuckles": 570, 573 | "headbutting": 571, 574 | "triple jump": 572, 575 | "fidgeting": 573, 576 | "punching person (boxing)": 574, 577 | "push up": 575, 578 | "lawn mower racing": 576, 579 | "riding camel": 577, 580 | "skipping rope": 578, 581 | "sword swallowing": 579, 582 | "feeding fish": 580, 583 | "peeling apples": 581, 584 | "doing aerobics": 582, 585 | "country line dancing": 583, 586 | "driving car": 584, 587 | "breathing fire": 585, 588 | "watering plants": 586, 589 | "popping balloons": 587, 590 | "snowboarding": 588, 591 | "blasting sand": 589, 592 | "pushing wheelchair": 590, 593 | "waxing chest": 591, 594 | "playing ice hockey": 592, 595 | "putting on shoes": 593, 596 | "pumping fist": 594, 597 | "surfing crowd": 595, 598 | "crying": 596, 599 | "unloading truck": 597, 600 | "welding": 598, 601 | "tap dancing": 599 602 | } -------------------------------------------------------------------------------- /mask_generator.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import numpy as np 4 | 5 | class RandomMaskGenerator: 6 | def __init__(self, input_size=224, mask_ratio=0.6): 7 | if not isinstance(input_size, tuple): 8 | input_size = (input_size,) * 2 9 | 10 | self.height, self.width = input_size 11 | 12 | self.num_patches = self.height * self.width 13 | self.num_mask = int(mask_ratio * self.num_patches) 14 | 15 | def __call__(self): 16 | mask = np.hstack([ 17 | np.zeros(self.num_patches - self.num_mask), 18 | np.ones(self.num_mask), 19 | ]) 20 | np.random.shuffle(mask) # 21 | return mask # [1024] 22 | 23 | class CubeMaskGenerator: 24 | def __init__( 25 | self, input_size=(8,14,14), mask_ratio=0.4, min_num_patches=16, max_num_patches=None, 26 | min_aspect=0.3, max_aspect=None): 27 | self.temporal ,self.height, self.width = input_size 28 | 29 | self.num_patches = self.height 
* self.width 30 | self.num_masking_patches = int(self.num_patches * mask_ratio) 31 | self.num_masking_frames = int(self.temporal * mask_ratio) 32 | 33 | self.min_num_patches = min_num_patches # smaller than max_num_patches 34 | self.max_num_patches = self.num_masking_patches if max_num_patches is None else max_num_patches 35 | 36 | max_aspect = max_aspect or 1 / min_aspect 37 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 38 | 39 | def __repr__(self): 40 | repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( 41 | self.height, self.width, self.min_num_patches, self.max_num_patches, 42 | self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1]) 43 | return repr_str 44 | 45 | def get_shape(self): 46 | return self.temporal, self.height, self.width 47 | 48 | def _mask(self, mask, max_mask_patches): 49 | delta = 0 50 | for attempt in range(10): 51 | target_area = random.uniform(self.min_num_patches, max_mask_patches) 52 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 53 | h = int(round(math.sqrt(target_area * aspect_ratio))) 54 | w = int(round(math.sqrt(target_area / aspect_ratio))) 55 | if w < self.width and h < self.height: 56 | top = random.randint(0, self.height - h) 57 | left = random.randint(0, self.width - w) 58 | 59 | num_masked = mask[top: top + h, left: left + w].sum() 60 | # Overlap 61 | if 0 < h * w - num_masked <= max_mask_patches: 62 | for i in range(top, top + h): 63 | for j in range(left, left + w): 64 | if mask[i, j] == 0: 65 | mask[i, j] = 1 66 | delta += 1 67 | 68 | if delta > 0: 69 | break 70 | return delta 71 | 72 | def __call__(self): 73 | time_marker = np.zeros(shape=self.temporal, dtype=np.int32) 74 | cube_mask = np.zeros(shape=self.get_shape(), dtype=np.int32) 75 | cube_marker = [] 76 | temp_mask_count = 0 77 | while temp_mask_count < self.num_masking_frames: 78 | # generate 2D block-wise mask 79 | mask = np.zeros(shape=self.get_shape()[1:], dtype=np.int32) 80 | mask_count = 0 81 | while mask_count < self.num_masking_patches: 82 | max_mask_patches = self.num_masking_patches - mask_count 83 | max_mask_patches = min(max_mask_patches, self.max_num_patches) 84 | 85 | delta = self._mask(mask, max_mask_patches) 86 | if delta == 0: 87 | break 88 | else: 89 | mask_count += delta 90 | # assign to cube mask 91 | start_frame = random.randint(0, self.temporal) 92 | accumulate_frames = random.randint(1, self.num_masking_frames - temp_mask_count) 93 | mask_count = 0 94 | for i in range(start_frame, start_frame+accumulate_frames): 95 | if i > self.temporal-1: 96 | break 97 | if time_marker[i] == 0: # only update the unmask frame 98 | time_marker[i] = 1 99 | cube_mask[i] = mask 100 | mask_count+=1 101 | else: #avoid to overlap the orginal mask 102 | break 103 | temp_mask_count += mask_count 104 | if mask_count > 0: # mark the center frame index(mask_count > 0) 105 | cube_marker.append([start_frame, mask_count]) 106 | 107 | return cube_mask, cube_marker 108 | 109 | if __name__ == '__main__': 110 | # Unit test for computing cube mask and extracting hog features 111 | 112 | ''' 113 | mask_generator = CubeMaskGenerator(input_size=(8,14,14),min_num_patches=16) 114 | mask, cube_marker = mask_generator() 115 | print(mask) 116 | from einops import repeat 117 | mask = repeat(mask, 't h w -> t (h dh) (w dw)', dh=56//14, dw=56//14) # nearest-neighbor resize 118 | print(mask) 119 | 120 | #print(cube_marker) 121 | center_index = np.zeros(8).astype('bool') 122 | for marker in cube_marker: 123 | 
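        # Each cube_marker entry is [start_frame, masked_frame_count]; the lines
        # below keep the 2D mask only at each cube's center frame
        # (start_frame + masked_frame_count // 2) and zero it everywhere else.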
center_index[marker[0]+marker[1]//2] = 1 124 | mask[~center_index] = 0 125 | print(mask) 126 | #for i in cube_marker: 127 | #print(mask[i].sum()/(56*56)) 128 | ''' 129 | 130 | #''' 131 | from skimage.feature import hog 132 | from skimage import io 133 | from skimage import data 134 | from einops import rearrange 135 | import torch 136 | 137 | def extract_hog(image): 138 | hog_features_r = hog(image[:,:,0], orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1), block_norm='L2', feature_vector=False) 139 | hog_features_g = hog(image[:,:,1], orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1), block_norm='L2', feature_vector=False) 140 | hog_features_b, hog_image = hog(image[:,:,2], orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1), block_norm='L2', feature_vector=False, visualize=True) 141 | hog_features = np.concatenate([hog_features_r,hog_features_g,hog_features_b], axis=-1) 142 | hog_features = rearrange(hog_features, '(ph dh) (pw dw) ch cw c -> ph pw (dh dw ch cw c)', ph=14, pw=14) 143 | return hog_features 144 | 145 | images = np.zeros((2,224,224,3)) 146 | image = data.astronaut()[:224,:224,:] # h w c 147 | image = torch.from_numpy(image).numpy() 148 | images[0] = image 149 | image = io.imread('./test_1.jpg')[:224,:224,:] 150 | images[1] = image 151 | hog_features = np.stack(list(map(extract_hog, images)), axis=0) 152 | print(hog_features.shape, np.min(hog_features), np.max(hog_features)) 153 | #io.imsave('./test_img_hog.jpg',hog_image) 154 | #''' -------------------------------------------------------------------------------- /mixup.py: -------------------------------------------------------------------------------- 1 | """ Mixup and Cutmix 2 | 3 | Papers: 4 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) 5 | 6 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) 7 | 8 | Code Reference: 9 | CutMix: https://github.com/clovaai/CutMix-PyTorch 10 | 11 | Hacked together by / Copyright 2019, Ross Wightman 12 | """ 13 | import numpy as np 14 | import torch 15 | 16 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): 17 | x = x.long().view(-1, 1) 18 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) 19 | 20 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): 21 | off_value = smoothing / num_classes 22 | on_value = 1. - smoothing + off_value 23 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) 24 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) 25 | return y1 * lam + y2 * (1. - lam) 26 | 27 | def rand_bbox(img_shape, lam, margin=0., count=None): 28 | """ Standard CutMix bounding-box 29 | Generates a random square bbox based on lambda value. This impl includes 30 | support for enforcing a border margin as percent of bbox dimensions. 
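    The cut height/width are img_h * sqrt(1 - lam) and img_w * sqrt(1 - lam), so the
    box covers roughly a (1 - lam) fraction of the image area before clipping at the
    borders.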
31 | 32 | Args: 33 | img_shape (tuple): Image shape as tuple 34 | lam (float): Cutmix lambda value 35 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image) 36 | count (int): Number of bbox to generate 37 | """ 38 | ratio = np.sqrt(1 - lam) 39 | img_h, img_w = img_shape[-2:] 40 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) 41 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) 42 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) 43 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) 44 | yl = np.clip(cy - cut_h // 2, 0, img_h) 45 | yh = np.clip(cy + cut_h // 2, 0, img_h) 46 | xl = np.clip(cx - cut_w // 2, 0, img_w) 47 | xh = np.clip(cx + cut_w // 2, 0, img_w) 48 | return yl, yh, xl, xh 49 | 50 | def cutmix_bbox_and_lam(img_shape, lam, correct_lam=True, count=None): 51 | """ Generate bbox and apply lambda correction. 52 | """ 53 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) 54 | if correct_lam: 55 | bbox_area = (yu - yl) * (xu - xl) 56 | lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1]) 57 | return (yl, yu, xl, xu), lam 58 | 59 | class Mixup: 60 | """ Mixup/Cutmix that applies different params to each element or whole batch 61 | 62 | Args: 63 | mixup_alpha (float): mixup alpha value, mixup is active if > 0. 64 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. 65 | prob (float): probability of applying mixup or cutmix per batch or element 66 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active 67 | mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element) 68 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders 69 | label_smoothing (float): apply label smoothing to the mixed target tensor 70 | num_classes (int): number of classes for target 71 | """ 72 | def __init__(self, mixup_alpha=0.8, cutmix_alpha=1.0, prob=1.0, switch_prob=0.5, 73 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000): 74 | self.mixup_alpha = mixup_alpha 75 | self.cutmix_alpha = cutmix_alpha 76 | self.mix_prob = prob 77 | self.switch_prob = switch_prob 78 | self.label_smoothing = label_smoothing 79 | self.num_classes = num_classes 80 | self.mode = mode 81 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix 82 | self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop) 83 | 84 | def _params_per_batch(self): 85 | lam = 1. 86 | use_cutmix = False 87 | if self.mixup_enabled and np.random.rand() < self.mix_prob: 88 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: 89 | use_cutmix = np.random.rand() < self.switch_prob 90 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \ 91 | np.random.beta(self.mixup_alpha, self.mixup_alpha) 92 | elif self.mixup_alpha > 0.: 93 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha) 94 | elif self.cutmix_alpha > 0.: 95 | use_cutmix = True 96 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) 97 | else: 98 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0." 99 | lam = float(lam_mix) 100 | return lam, use_cutmix 101 | 102 | def _mix_batch(self, x): 103 | lam, use_cutmix = self._params_per_batch() 104 | if lam == 1.: 105 | return 1. 
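        # lam == 1 means the mix_prob gate did not fire for this batch: the inputs
        # are returned unmodified, and the target from mixup_target() reduces to a
        # one-hot encoding with label smoothing only.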
106 | if use_cutmix: 107 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 108 | x.shape, lam, correct_lam=self.correct_lam) 109 | x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh] 110 | else: 111 | x_flipped = x.flip(0).mul_(1. - lam) 112 | x.mul_(lam).add_(x_flipped) 113 | return lam 114 | 115 | def __call__(self, x, target): 116 | assert len(x) % 2 == 0, 'Batch size should be even when using this' # [B,C,H,W] -> [B,T,C,H,W] 117 | need_reshape = False 118 | if x.ndim == 5: 119 | need_reshape = True 120 | b,t,c,h,w = x.shape 121 | x = x.view(b,t*c,h,w) 122 | lam = self._mix_batch(x) 123 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device) 124 | if need_reshape: 125 | x = x.view(b,t,c,h,w) 126 | return x, target 127 | 128 | if __name__ == '__main__': 129 | SEED = 0 130 | torch.random.manual_seed(SEED) 131 | np.random.seed(SEED) 132 | mixupfn = Mixup(num_classes=4) 133 | x = torch.rand(2,2,1,10,10) 134 | label = [0, 1] 135 | print(x, label) 136 | y = torch.from_numpy(np.array(label)) 137 | x, y = mixupfn(x, y) 138 | print(x.shape, y.shape) 139 | print(x, y) -------------------------------------------------------------------------------- /model_pretrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | import warnings 5 | import argparse 6 | 7 | import kornia.augmentation as K 8 | import numpy as np 9 | import pytorch_lightning as pl 10 | from pytorch_lightning.plugins import DDPPlugin 11 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 12 | import torch 13 | import torch.utils.data as data 14 | 15 | from data_trainer import KineticsDataModule 16 | from model_trainer import VideoTransformer 17 | import data_transform as T 18 | from utils import print_on_rank_zero 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description='lr receiver') 23 | # Common 24 | parser.add_argument( 25 | '-epoch', type=int, required=True, 26 | help='the max epochs of training') 27 | parser.add_argument( 28 | '-batch_size', type=int, required=True, 29 | help='the batch size of data inputs') 30 | parser.add_argument( 31 | '-num_workers', type=int, default=4, 32 | help='the num workers of loading data') 33 | parser.add_argument( 34 | '-resume', default=False, action='store_true') 35 | parser.add_argument( 36 | '-resume_from_checkpoint', type=str, default=None, 37 | help='the pretrain params from specific path') 38 | parser.add_argument( 39 | '-log_interval', type=int, default=30, 40 | help='the intervals of logging') 41 | parser.add_argument( 42 | '-save_ckpt_freq', type=int, default=20, 43 | help='the intervals of saving model') 44 | parser.add_argument( 45 | '-objective', type=str, default='mim', 46 | help='the learning objective from [mim, supervised]') 47 | parser.add_argument( 48 | '-eval_metrics', type=str, default='finetune', 49 | help='the eval metrics choosen from [linear_prob, finetune]') 50 | 51 | # Environment 52 | parser.add_argument( 53 | '-gpus', nargs='+', type=int, default=-1, 54 | help='the avaiable gpus in this experiment') 55 | parser.add_argument( 56 | '-root_dir', type=str, required=True, 57 | help='the path to root dir for work space') 58 | 59 | # Data 60 | parser.add_argument( 61 | '-num_class', type=int, required=True, 62 | help='the num class of dataset used') 63 | parser.add_argument( 64 | '-num_samples_per_cls', type=int, default=10000, 65 | help='the num samples of per class') 66 | parser.add_argument( 67 | 
'-img_size', type=int, default=224, 68 | help='the size of processed image') 69 | parser.add_argument( 70 | '-num_frames', type=int, required=True, 71 | help='the mumber of frame sampling') 72 | parser.add_argument( 73 | '-frame_interval', type=int, required=True, 74 | help='the intervals of frame sampling') 75 | parser.add_argument( 76 | '-data_statics', type=str, default='kinetics', 77 | help='choose data statics from [imagenet, kinetics]') 78 | parser.add_argument( 79 | '-train_data_path', type=str, required=True, 80 | help='the path to train set') 81 | parser.add_argument( 82 | '-val_data_path', type=str, default=None, 83 | help='the path to val set') 84 | parser.add_argument( 85 | '-test_data_path', type=str, default=None, 86 | help='the path to test set') 87 | parser.add_argument( 88 | '-multi_crop', type=bool, default=False, 89 | help="""Whether or not to use multi crop.""") 90 | parser.add_argument( 91 | '-mixup', type=bool, default=False, 92 | help="""Whether or not to use multi crop.""") 93 | parser.add_argument( 94 | '-auto_augment', type=str, default=None, 95 | help='the used Autoaugment policy') 96 | 97 | # Model 98 | parser.add_argument( 99 | '-arch', type=str, default='timesformer', 100 | help='the choosen model arch from [timesformer, vivit]') 101 | parser.add_argument( 102 | '-attention_type', type=str, default='divided_space_time', 103 | help='the choosen attention type using in model') 104 | parser.add_argument( 105 | '-pretrain_pth', type=str, default=None, 106 | help='the path to the pretrain weights') 107 | parser.add_argument( 108 | '-weights_from', type=str, default='imagenet', 109 | help='the pretrain params from [imagenet, kinetics]') 110 | 111 | # Training/Optimization parameters 112 | parser.add_argument( 113 | '-seed', type=int, default=0, 114 | help='the seed of exp') 115 | parser.add_argument( 116 | '-optim_type', type=str, default='adamw', 117 | help='the optimizer using in the training') 118 | parser.add_argument( 119 | '-lr_schedule', type=str, default='cosine', 120 | help='the lr schedule using in the training') 121 | parser.add_argument( 122 | '-lr', type=float, required=True, 123 | help='the initial learning rate') 124 | parser.add_argument( 125 | '-layer_decay', type=float, default=0.75, 126 | help='the value of layer_decay') 127 | parser.add_argument( 128 | '--min_lr', type=float, default=1e-6, 129 | help="""Target LR at the end of optimization. We use a cosine LR schedule with linear warmup.""") 130 | parser.add_argument( 131 | '-use_fp16', type=bool, default=True, 132 | help="""Whether or not to use half precision for training. Improves training time and memory requirements, 133 | but can provoke instability and slight decay of performance. We recommend disabling 134 | mixed precision if the loss is unstable, if reducing the patch size or if training with bigger ViTs.""") 135 | parser.add_argument( 136 | '-weight_decay', type=float, default=0.05, 137 | help="""Initial value of the weight decay. With ViT, a smaller value at the beginning of training works well.""") 138 | parser.add_argument( 139 | '-weight_decay_end', type=float, default=0.05, 140 | help="""Final value of the weight decay. We use a cosine schedule for WD and using a larger decay by 141 | the end of training improves performance for ViTs.""") 142 | parser.add_argument( 143 | '-clip_grad', type=float, default=0, 144 | help="""Maximal parameter gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can 145 | help optimization for larger ViT architectures. 
0 for disabling.""") 146 | parser.add_argument( 147 | "-warmup_epochs", default=5, type=int, 148 | help="Number of epochs for the linear learning-rate warm up.") 149 | 150 | args = parser.parse_args() 151 | 152 | return args 153 | 154 | def single_run(): 155 | args = parse_args() 156 | warnings.filterwarnings('ignore') 157 | 158 | # linear learning rate scale 159 | if isinstance(args.gpus, int): 160 | num_gpus = torch.cuda.device_count() 161 | else: 162 | num_gpus = len(args.gpus) 163 | effective_batch_size = args.batch_size * num_gpus 164 | args.lr = args.lr * effective_batch_size / 256 165 | 166 | # Experiment Settings 167 | ROOT_DIR = args.root_dir 168 | exp_tag = (f'objective_{args.objective}_arch_{args.arch}_lr_{args.lr}_' 169 | f'optim_{args.optim_type}_lr_schedule_{args.lr_schedule}_' 170 | f'fp16_{args.use_fp16}_weight_decay_{args.weight_decay}_' 171 | f'weight_decay_end_{args.weight_decay_end}_warmup_epochs_{args.warmup_epochs}_' 172 | f'pretrain_{args.pretrain_pth}_weights_from_{args.weights_from}_seed_{args.seed}_' 173 | f'img_size_{args.img_size}_num_frames_{args.num_frames}_eval_metrics_{args.eval_metrics}_' 174 | f'frame_interval_{args.frame_interval}_mixup_{args.mixup}_' 175 | f'multi_crop_{args.multi_crop}_auto_augment_{args.auto_augment}_') 176 | ckpt_dir = os.path.join(ROOT_DIR, f'results/{exp_tag}/ckpt') 177 | log_dir = os.path.join(ROOT_DIR, f'results/{exp_tag}/log') 178 | os.makedirs(ckpt_dir, exist_ok=True) 179 | os.makedirs(log_dir, exist_ok=True) 180 | 181 | # Data 182 | do_eval = True if args.val_data_path is not None else False 183 | do_test = True if args.test_data_path is not None else False 184 | 185 | data_module = KineticsDataModule(configs=args, 186 | train_ann_path=args.train_data_path, 187 | val_ann_path=args.val_data_path, 188 | test_ann_path=args.test_data_path) 189 | 190 | # Resume from the last checkpoint 191 | if args.resume and not args.resume_from_checkpoint: 192 | args.resume_from_checkpoint = os.path.join(ckpt_dir, 'last_checkpoint.pth') 193 | 194 | # Trainer 195 | if args.arch == 'mvit' and args.objective == 'supervised': 196 | find_unused_parameters = True 197 | else: 198 | find_unused_parameters = False 199 | 200 | trainer = pl.Trainer( 201 | gpus=args.gpus, 202 | accelerator="ddp", 203 | precision=16, 204 | plugins=[DDPPlugin(find_unused_parameters=find_unused_parameters),], 205 | max_epochs=args.epoch, 206 | callbacks=[ 207 | LearningRateMonitor(logging_interval='step'), 208 | ], 209 | resume_from_checkpoint=args.resume_from_checkpoint, 210 | check_val_every_n_epoch=1, 211 | log_every_n_steps=args.log_interval, 212 | progress_bar_refresh_rate=args.log_interval, 213 | flush_logs_every_n_steps=args.log_interval*5) 214 | 215 | # To be reproducable 216 | torch.random.manual_seed(args.seed) 217 | np.random.seed(args.seed) 218 | random.seed(args.seed) 219 | pl.seed_everything(args.seed, workers=True) 220 | 221 | # Model 222 | model = VideoTransformer(configs=args, 223 | trainer=trainer, 224 | ckpt_dir=ckpt_dir, 225 | do_eval=do_eval, 226 | do_test=do_test) 227 | print_on_rank_zero(args) 228 | timestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()) 229 | print_on_rank_zero(f'{timestamp} - INFO - Start running,') 230 | trainer.fit(model, data_module) 231 | 232 | if __name__ == '__main__': 233 | single_run() -------------------------------------------------------------------------------- /model_trainer.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import math 3 | import time 
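# Note on get_cosine_schedule_with_warmup() below: the LambdaLR multiplier ramps
# linearly from 1/num_warmup_steps up to 1 over the warmup epochs (the scheduler is
# stepped per epoch), then follows a half-cosine decay. For objective='mim' it decays
# to 0; for the supervised objective it is floored at min_lr / base_lr.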
4 | 5 | import pytorch_lightning as pl 6 | import torch 7 | import torch.nn.functional as F 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torchvision 11 | from torchmetrics import Accuracy 12 | from timm.loss import SoftTargetCrossEntropy 13 | 14 | import utils 15 | from mixup import Mixup 16 | from optimizer import build_optimizer 17 | from transformer import ClassificationHead 18 | from video_transformer import TimeSformer, ViViT, MaskFeat 19 | 20 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, base_lr, objective, min_lr=5e-5, last_epoch=-1): 21 | """ Create a schedule with a learning rate that decreases following the 22 | values of the cosine function between 0 and `pi * cycles` after a warmup 23 | period during which it increases linearly between 0 and base_lr. 24 | """ 25 | # step means epochs here 26 | def lr_lambda(current_step): 27 | current_step += 1 28 | if current_step <= num_warmup_steps: 29 | return float(current_step) / float(max(1, num_warmup_steps)) # * base_lr 30 | progress = min(float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)), 1) 31 | if objective == 'mim': 32 | return 0.5 * (1. + math.cos(math.pi * progress)) 33 | else: 34 | factor = 0.5 * (1. + math.cos(math.pi * progress)) 35 | return factor*(1 - min_lr/base_lr) + min_lr/base_lr 36 | 37 | return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch) 38 | 39 | class VideoTransformer(pl.LightningModule): 40 | 41 | def __init__(self, 42 | configs, 43 | trainer, 44 | ckpt_dir, 45 | do_eval, 46 | do_test, 47 | n_crops=3): 48 | super().__init__() 49 | self.configs = configs 50 | self.trainer = trainer 51 | 52 | # build models 53 | if self.configs.objective =='mim': 54 | self.model = MaskFeat(pool_q_stride_size=[[1, 1, 2, 2], [3, 1, 2, 2]], feature_dim=2*2*2*3*9) 55 | else: # supervised 56 | # load pretrain weights from pretrained weight path and model.init_weights method 57 | if self.configs.arch == 'vivit': 58 | self.model = ViViT( 59 | pretrain_pth=self.configs.pretrain_pth, 60 | weights_from=self.configs.weights_from, 61 | img_size=self.configs.img_size, 62 | num_frames=self.configs.num_frames, 63 | attention_type=self.configs.attention_type) 64 | elif self.configs.arch == 'timesformer': 65 | self.model = TimeSformer( 66 | pretrain_pth=self.configs.pretrain_pth, 67 | weights_from=self.configs.weights_from, 68 | img_size=self.configs.img_size, 69 | num_frames=self.configs.num_frames, 70 | attention_type=self.configs.attention_type) 71 | else: # arch-mvit 72 | self.model = MaskFeat( 73 | pool_q_stride_size=[[1, 1, 2, 2], [3, 1, 2, 2]], 74 | feature_dim=2*2*2*3*9, 75 | pretrain_pth=self.configs.pretrain_pth, 76 | img_size=self.configs.img_size, 77 | num_frames=self.configs.num_frames) 78 | for name, param in self.model.decoder_pred.named_parameters(): 79 | param.requires_grad = False 80 | 81 | self.cls_head = ClassificationHead( 82 | self.configs.num_class, self.model.embed_dims, eval_metrics=self.configs.eval_metrics) 83 | 84 | self.max_top1_acc = 0 85 | self.train_top1_acc = Accuracy() 86 | self.train_top5_acc = Accuracy(top_k=5) 87 | if self.configs.mixup: 88 | self.mixup_fn = Mixup(num_classes=self.configs.num_class) 89 | self.loss_fn = SoftTargetCrossEntropy() 90 | else: 91 | self.loss_fn = nn.CrossEntropyLoss() 92 | 93 | # common 94 | self.iteration = 0 95 | self.data_start = 0 96 | self.ckpt_dir = ckpt_dir 97 | self.do_eval = do_eval 98 | self.do_test = do_test 99 | if self.do_eval: 100 | 
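            # Validation/test Accuracy meters are only created when a val/test
            # annotation path was passed on the command line (see do_eval / do_test
            # in model_pretrain.single_run).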
self.val_top1_acc = Accuracy() 101 | self.val_top5_acc = Accuracy(top_k=5) 102 | if self.do_test: 103 | self.n_crops = n_crops 104 | self.test_top1_acc = Accuracy() 105 | self.test_top5_acc = Accuracy(top_k=5) 106 | 107 | @torch.jit.ignore 108 | def no_weight_decay_keywords(self): 109 | return {'pos_embed', 'cls_token', 'mask_token'} 110 | 111 | def configure_optimizers(self): 112 | # build optimzer 113 | is_pretrain = not (self.configs.objective == 'supervised') 114 | if self.configs.objective == 'supervised' and self.configs.eval_metrics == 'linear_prob': 115 | model = self.cls_head.module if hasattr(self.cls_head, 'module') else self.cls_head 116 | optimizer = build_optimizer(self.configs, model, is_pretrain=is_pretrain) 117 | else: 118 | optimizer = build_optimizer(self.configs, self, is_pretrain=is_pretrain) 119 | 120 | # lr schedule 121 | lr_scheduler = None 122 | lr_schedule = self.configs.lr_schedule 123 | if lr_schedule == 'multistep': 124 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, 125 | milestones=[5, 11], 126 | gamma=0.1) 127 | elif lr_schedule == 'cosine': 128 | lr_scheduler = get_cosine_schedule_with_warmup(optimizer, 129 | num_warmup_steps=self.configs.warmup_epochs, 130 | num_training_steps=self.trainer.max_epochs, 131 | base_lr=self.configs.lr, 132 | min_lr=self.configs.min_lr, 133 | objective=self.configs.objective) 134 | return [optimizer], [lr_scheduler] 135 | 136 | def parse_batch(self, batch, train): 137 | if self.configs.objective == 'mim': 138 | inputs, labels, mask, cube_marker, = *batch, 139 | return inputs, labels, mask, cube_marker 140 | else: 141 | inputs, labels, = *batch, 142 | if self.configs.mixup and train: 143 | inputs, labels = self.mixup_fn(inputs, labels) 144 | return inputs, labels 145 | 146 | # epoch schedule 147 | def _get_momentum(self, base_value, final_value): 148 | return final_value - (final_value - base_value) * (math.cos(math.pi * self.trainer.current_epoch / self.trainer.max_epochs) + 1) / 2 149 | 150 | def _weight_decay_update(self): 151 | for i, param_group in enumerate(self.optimizers().optimizer.param_groups): 152 | if i == 1: # only the first group is regularized 153 | param_group["weight_decay"] = self._get_momentum(base_value=self.configs.weight_decay, final_value=self.configs.weight_decay_end) 154 | 155 | def clip_gradients(self, clip_grad, norm_type=2): 156 | layer_norm = [] 157 | if self.configs.objective == 'supervised' and self.configs.eval_metrics == 'linear_prob': 158 | model_wo_ddp = self.cls_head.module if hasattr(self.cls_head, 'module') else self.cls_head 159 | else: 160 | model_wo_ddp = self.module if hasattr(self, 'module') else self 161 | for name, p in model_wo_ddp.named_parameters(): 162 | if p.grad is not None: 163 | param_norm = torch.norm(p.grad.detach(), norm_type) 164 | layer_norm.append(param_norm) 165 | if clip_grad: 166 | clip_coef = clip_grad / (param_norm + 1e-6) 167 | if clip_coef < 1: 168 | p.grad.data.mul_(clip_coef) 169 | total_grad_norm = torch.norm(torch.stack(layer_norm), norm_type) 170 | return total_grad_norm 171 | 172 | def log_step_state(self, data_time, top1_acc=0, top5_acc=0): 173 | self.log("time",float(f'{time.perf_counter()-self.data_start:.3f}'),prog_bar=True) 174 | self.log("data_time", data_time, prog_bar=True) 175 | if self.configs.objective == 'supervised': 176 | self.log("top1_acc",top1_acc,on_step=True,on_epoch=False,prog_bar=True) 177 | self.log("top5_acc",top5_acc,on_step=True,on_epoch=False,prog_bar=True) 178 | 179 | return None 180 | 181 | def 
get_progress_bar_dict(self): 182 | # don't show the version number 183 | items = super().get_progress_bar_dict() 184 | items.pop("v_num", None) 185 | 186 | return items 187 | 188 | # Trainer Pipeline 189 | def training_step(self, batch, batch_idx): 190 | data_time = float(f'{time.perf_counter() - self.data_start:.3f}') 191 | if self.configs.objective == 'mim': 192 | inputs, labels, mask, cube_marker = self.parse_batch(batch, train=True) 193 | preds, loss = self.model(inputs, labels, mask, cube_marker) 194 | self.log_step_state(data_time) 195 | return {'loss': loss, 'data_time': data_time} 196 | else: 197 | inputs, labels = self.parse_batch(batch, train=True) 198 | if self.configs.eval_metrics == 'linear_prob': 199 | with torch.no_grad(): 200 | self.model.eval() 201 | preds = self.model(inputs) 202 | else: 203 | if self.configs.arch == 'mvit': 204 | preds = self.model.forward_features(inputs)[:, 0] 205 | else: 206 | preds = self.model(inputs) 207 | preds = self.cls_head(preds) 208 | loss = self.loss_fn(preds, labels) 209 | if self.configs.mixup: 210 | top1_acc = self.train_top1_acc(preds.softmax(dim=-1), labels.argmax(-1)) 211 | top5_acc = self.train_top5_acc(preds.softmax(dim=-1), labels.argmax(-1)) 212 | else: 213 | top1_acc = self.train_top1_acc(preds.softmax(dim=-1), labels) 214 | top5_acc = self.train_top5_acc(preds.softmax(dim=-1), labels) 215 | self.log_step_state(data_time, top1_acc, top5_acc) 216 | return {'loss': loss, 'data_time': data_time} 217 | 218 | def on_after_backward(self): 219 | param_norms = self.clip_gradients(self.configs.clip_grad) 220 | self._weight_decay_update() 221 | # log learning daynamic 222 | lr = self.optimizers().optimizer.param_groups[0]['lr'] 223 | self.log("lr",lr,on_step=True,on_epoch=False,prog_bar=True) 224 | self.log("grad_norm",param_norms,on_step=True,on_epoch=False,prog_bar=True) 225 | 226 | def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, 227 | optimizer_closure, on_tpu, using_native_amp, using_lbfgs): 228 | 229 | optimizer.step(closure=optimizer_closure) 230 | self.data_start = time.perf_counter() 231 | self.iteration += 1 232 | 233 | def training_epoch_end(self, outputs): 234 | timestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()) 235 | if self.configs.objective == 'supervised': 236 | mean_top1_acc = self.train_top1_acc.compute() 237 | mean_top5_acc = self.train_top5_acc.compute() 238 | self.print(f'{timestamp} - Evaluating mean ', 239 | f'top1_acc:{mean_top1_acc:.3f},', 240 | f'top5_acc:{mean_top5_acc:.3f} of current training epoch') 241 | self.train_top1_acc.reset() 242 | self.train_top5_acc.reset() 243 | 244 | # save last checkpoint 245 | save_path = osp.join(self.ckpt_dir, 'last_checkpoint.pth') 246 | self.trainer.save_checkpoint(save_path) 247 | 248 | if self.configs.objective != 'supervised' and (self.trainer.current_epoch+1) % self.configs.save_ckpt_freq == 0: 249 | save_path = osp.join(self.ckpt_dir, 250 | f'{timestamp}_'+ 251 | f'ep_{self.trainer.current_epoch}.pth') 252 | self.trainer.save_checkpoint(save_path) 253 | 254 | def validation_step(self, batch, batch_indx): 255 | if self.do_eval: 256 | inputs, labels = self.parse_batch(batch, train=False) 257 | if self.configs.eval_metrics == 'linear_prob': 258 | with torch.no_grad(): 259 | preds = self.model(inputs) 260 | else: 261 | if self.configs.arch == 'mvit': 262 | preds = self.model.forward_features(inputs)[:, 0] 263 | else: 264 | preds = self.model(inputs) 265 | preds = self.cls_head(preds) 266 | 267 | self.val_top1_acc(preds.softmax(dim=-1), labels) 
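                # Calling the torchmetrics Accuracy objects here only updates their
                # running state; the epoch-level values are read via compute() and
                # cleared via reset() in validation_epoch_end().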
268 | self.val_top5_acc(preds.softmax(dim=-1), labels) 269 | self.data_start = time.perf_counter() 270 | 271 | def validation_epoch_end(self, outputs): 272 | if self.do_eval: 273 | mean_top1_acc = self.val_top1_acc.compute() 274 | mean_top5_acc = self.val_top5_acc.compute() 275 | timestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()) 276 | self.print(f'{timestamp} - Evaluating mean ', 277 | f'top1_acc:{mean_top1_acc:.3f}, ', 278 | f'top5_acc:{mean_top5_acc:.3f} of current validation epoch') 279 | self.val_top1_acc.reset() 280 | self.val_top5_acc.reset() 281 | 282 | # save best checkpoint 283 | if mean_top1_acc > self.max_top1_acc: 284 | save_path = osp.join(self.ckpt_dir, 285 | f'{timestamp}_'+ 286 | f'ep_{self.trainer.current_epoch}_'+ 287 | f'top1_acc_{mean_top1_acc:.3f}.pth') 288 | self.trainer.save_checkpoint(save_path) 289 | self.max_top1_acc = mean_top1_acc 290 | 291 | def test_step(self, batch, batch_idx): 292 | if self.do_test: 293 | inputs, labels = self.parse_batch(batch) 294 | preds = self.cls_head(self.model(inputs)) 295 | preds = preds.view(-1, self.n_crops, self.configs.num_class).mean(1) 296 | 297 | self.test_top1_acc(preds.softmax(dim=-1), labels) 298 | self.test_top5_acc(preds.softmax(dim=-1), labels) 299 | self.data_start = time.perf_counter() 300 | 301 | def test_epoch_end(self, outputs): 302 | if self.do_test: 303 | mean_top1_acc = self.test_top1_acc.compute() 304 | mean_top5_acc = self.test_top5_acc.compute() 305 | timestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()) 306 | self.print(f'{timestamp} - Evaluating mean ', 307 | f'top1_acc:{mean_top1_acc:.3f}, ', 308 | f'top5_acc:{mean_top5_acc:.3f} of current test epoch') 309 | self.test_top1_acc.reset() 310 | self.test_top5_acc.reset() 311 | -------------------------------------------------------------------------------- /notebook/VideoTransformer_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# VideoTransformer\n", 8 | "\n", 9 | "TimeSformer(https://arxiv.org/abs/2102.05095), ViViT(https://arxiv.org/abs/2103.15691)\n", 10 | "\n", 11 | "Welcome to the demo notebook for VideoTransformer. We'll showcase the prediction result by the above pre-trained models." 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "## Preliminaries\n", 19 | "\n", 20 | "This section contains initial setup. Run it first." 
21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": { 27 | "colab": { 28 | "base_uri": "https://localhost:8080/" 29 | }, 30 | "id": "Gv1qdarnNNVE", 31 | "outputId": "eaac44da-a976-4fda-b323-03dfa6790e21", 32 | "pycharm": {} 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "!pip3 install --user torch\n", 37 | "!pip3 install --user torchvision\n", 38 | "!pip3 install --user matplotlib\n", 39 | "!pip3 install --user decord\n", 40 | "!pip3 install --user einops\n", 41 | "!pip3 install --user scikit-image\n", 42 | "!pip3 install --user pytorch-lightning" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": { 49 | "colab": { 50 | "base_uri": "https://localhost:8080/" 51 | }, 52 | "id": "tBBG_T32pzPH", 53 | "outputId": "76921b65-9e35-4260-fa0a-bbb484ac285c", 54 | "pycharm": {} 55 | }, 56 | "outputs": [], 57 | "source": [ 58 | "import torch\n", 59 | "import torch.nn as nn\n", 60 | "import numpy as np\n", 61 | "\n", 62 | "from einops import rearrange, reduce, repeat\n", 63 | "from IPython.display import display\n", 64 | "\n", 65 | "!git clone https://github.com/mx-mark/VideoTransformer-pytorch.git\n", 66 | "%cd VideoTransformer-pytorch\n", 67 | "\n", 68 | "import data_transform as T\n", 69 | "from dataset import DecordInit, load_annotation_data\n", 70 | "from transformer import PatchEmbed, TransformerContainer, ClassificationHead" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "### Note\n", 78 | "Please firstly dowload the weights and move to the current path `./VideoTransformer-pytorch/`\n", 79 | "1. TimeSformer-B pre-trained on K400 https://drive.google.com/file/d/1jLkS24jkpmakPi3e5J8KH3FOPv370zvo/view?usp=sharing\n", 80 | "2. ViViT-B pre-trained on K400 from https://drive.google.com/file/d/1-JVhSN3QHKUOLkXLWXWn5drdvKn0gPll/view?usp=sharing" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": { 86 | "id": "0yLko87R2_m4", 87 | "pycharm": {} 88 | }, 89 | "source": [ 90 | "## Video Transformer Model\n", 91 | "\n", 92 | "We here load the pretrained weights of the transformer model TimeSformer-B or ViViT-B." 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 19, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "class TimeSformer(nn.Module):\n", 102 | " \"\"\"TimeSformer. A PyTorch impl of `Is Space-Time Attention All You Need for\n", 103 | " Video Understanding? `_\n", 104 | "\n", 105 | " Args:\n", 106 | " num_frames (int): Number of frames in the video.\n", 107 | " img_size (int | tuple): Size of input image.\n", 108 | " patch_size (int): Size of one patch.\n", 109 | " pretrained (str | None): Name of pretrained model. Default: None.\n", 110 | " embed_dims (int): Dimensions of embedding. Defaults to 768.\n", 111 | " num_heads (int): Number of parallel attention heads in\n", 112 | " TransformerCoder. Defaults to 12.\n", 113 | " num_transformer_layers (int): Number of transformer layers. Defaults to\n", 114 | " 12.\n", 115 | " in_channels (int): Channel num of input features. Defaults to 3.\n", 116 | " dropout_p (float): Probability of dropout layer. Defaults to 0.\n", 117 | " conv_type (str): Type of the convolution in PatchEmbed layer. Defaults to Conv2d.\n", 118 | " attention_type (str): Type of attentions in TransformerCoder. Choices\n", 119 | " are 'divided_space_time', 'space_only' and 'joint_space_time'.\n", 120 | " Defaults to 'divided_space_time'.\n", 121 | " norm_layer (dict): Config for norm layers. 
Defaults to nn.LayerNorm.\n", 122 | " return_cls_token (bool): Whether to use cls_token to predict class label.\n", 123 | " \"\"\"\n", 124 | " supported_attention_types = [\n", 125 | " 'divided_space_time', 'space_only', 'joint_space_time'\n", 126 | " ]\n", 127 | "\n", 128 | " def __init__(self,\n", 129 | " num_frames,\n", 130 | " img_size,\n", 131 | " patch_size,\n", 132 | " embed_dims=768,\n", 133 | " num_heads=12,\n", 134 | " num_transformer_layers=12,\n", 135 | " in_channels=3,\n", 136 | " conv_type='Conv2d',\n", 137 | " dropout_p=0.,\n", 138 | " attention_type='divided_space_time',\n", 139 | " norm_layer=nn.LayerNorm,\n", 140 | " return_cls_token=True,\n", 141 | " **kwargs):\n", 142 | " super().__init__()\n", 143 | " assert attention_type in self.supported_attention_types, (\n", 144 | " f'Unsupported Attention Type {attention_type}!')\n", 145 | "\n", 146 | " self.num_frames = num_frames\n", 147 | " self.embed_dims = embed_dims\n", 148 | " self.num_transformer_layers = num_transformer_layers\n", 149 | " self.attention_type = attention_type\n", 150 | " self.conv_type = conv_type\n", 151 | " self.return_cls_token = return_cls_token\n", 152 | "\n", 153 | " #tokenize & position embedding\n", 154 | " self.patch_embed = PatchEmbed(\n", 155 | " img_size=img_size,\n", 156 | " patch_size=patch_size,\n", 157 | " in_channels=in_channels,\n", 158 | " embed_dims=embed_dims,\n", 159 | " conv_type=conv_type)\n", 160 | " num_patches = self.patch_embed.num_patches\n", 161 | "\n", 162 | " # Divided Space Time Attention\n", 163 | " operator_order = ['time_attn','space_attn','ffn']\n", 164 | " container = TransformerContainer(\n", 165 | " num_transformer_layers=num_transformer_layers,\n", 166 | " embed_dims=embed_dims,\n", 167 | " num_heads=num_heads,\n", 168 | " num_frames=num_frames,\n", 169 | " norm_layer=norm_layer,\n", 170 | " hidden_channels=embed_dims*4,\n", 171 | " operator_order=operator_order)\n", 172 | "\n", 173 | " self.transformer_layers = container\n", 174 | " self.norm = norm_layer(embed_dims, eps=1e-6)\n", 175 | "\n", 176 | " self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dims))\n", 177 | " num_patches = num_patches + 1\n", 178 | "\n", 179 | " # spatial pos_emb\n", 180 | " self.pos_embed = nn.Parameter(torch.zeros(1,num_patches,embed_dims))\n", 181 | " self.drop_after_pos = nn.Dropout(p=dropout_p)\n", 182 | "\n", 183 | " # temporal pos_emb\n", 184 | " self.time_embed = nn.Parameter(torch.zeros(1,num_frames,embed_dims))\n", 185 | " self.drop_after_time = nn.Dropout(p=dropout_p)\n", 186 | "\n", 187 | " def interpolate_pos_encoding(self, x, w, h):\n", 188 | " npatch = x.shape[1] - 1\n", 189 | " N = self.pos_embed.shape[1] - 1\n", 190 | " if npatch == N and w == h:\n", 191 | " return self.pos_embed\n", 192 | " class_pos_embed = self.pos_embed[:, 0]\n", 193 | " patch_pos_embed = self.pos_embed[:, 1:]\n", 194 | " dim = x.shape[-1]\n", 195 | " w0 = w // self.patch_embed.patch_size[0]\n", 196 | " h0 = h // self.patch_embed.patch_size[0]\n", 197 | " # we add a small number to avoid floating point error in the interpolation\n", 198 | " # see discussion at https://github.com/facebookresearch/dino/issues/8\n", 199 | " w0, h0 = w0 + 0.1, h0 + 0.1\n", 200 | " patch_pos_embed = nn.functional.interpolate(\n", 201 | " patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),\n", 202 | " scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),\n", 203 | " mode='bicubic',\n", 204 | " )\n", 205 | " assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == 
patch_pos_embed.shape[-1]\n", 206 | " patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)\n", 207 | " return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)\n", 208 | "\n", 209 | " def forward(self, x):\n", 210 | " #Tokenize\n", 211 | " b, t, c, h, w = x.shape\n", 212 | " x = self.patch_embed(x)\n", 213 | "\n", 214 | " # Add Position Embedding\n", 215 | " cls_tokens = repeat(self.cls_token, 'b ... -> (repeat b) ...', repeat=x.shape[0])\n", 216 | " x = torch.cat((cls_tokens, x), dim=1)\n", 217 | " x = x + self.interpolate_pos_encoding(x, w, h) #self.pos_embed\n", 218 | " x = self.drop_after_pos(x)\n", 219 | "\n", 220 | " # Add Time Embedding\n", 221 | " cls_tokens = x[:b, 0, :].unsqueeze(1)\n", 222 | " x = rearrange(x[:, 1:, :], '(b t) p d -> (b p) t d', b=b)\n", 223 | " x = x + self.time_embed\n", 224 | " x = rearrange(x, '(b p) t d -> b (p t) d', b=b)\n", 225 | " x = torch.cat((cls_tokens, x), dim=1)\n", 226 | " x = self.drop_after_time(x)\n", 227 | "\n", 228 | " # Video transformer forward\n", 229 | " x = self.transformer_layers(x)\n", 230 | "\n", 231 | " x = self.norm(x)\n", 232 | " # Return Class Token\n", 233 | " if self.return_cls_token:\n", 234 | " return x[:, 0]\n", 235 | " else:\n", 236 | " return x[:, 1:].mean(1)" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 21, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "class ViViT(nn.Module):\n", 246 | " \"\"\"ViViT. A PyTorch impl of `ViViT: A Video Vision Transformer`\n", 247 | " \n", 248 | "\n", 249 | " Args:\n", 250 | " num_frames (int): Number of frames in the video.\n", 251 | " img_size (int | tuple): Size of input image.\n", 252 | " patch_size (int): Size of one patch.\n", 253 | " pretrained (str | None): Name of pretrained model. Default: None.\n", 254 | " embed_dims (int): Dimensions of embedding. Defaults to 768.\n", 255 | " num_heads (int): Number of parallel attention heads. Defaults to 12.\n", 256 | " num_transformer_layers (int): Number of transformer layers. Defaults to 12.\n", 257 | " in_channels (int): Channel num of input features. Defaults to 3.\n", 258 | " dropout_p (float): Probability of dropout layer. Defaults to 0..\n", 259 | " tube_size (int): Dimension of the kernel size in Conv3d. Defaults to 2.\n", 260 | " conv_type (str): Type of the convolution in PatchEmbed layer. Defaults to Conv3d.\n", 261 | " attention_type (str): Type of attentions in TransformerCoder. Choices\n", 262 | " are 'divided_space_time', 'fact_encoder' and 'joint_space_time'.\n", 263 | " Defaults to 'fact_encoder'.\n", 264 | " norm_layer (dict): Config for norm layers. 
Defaults to nn.LayerNorm.\n", 265 | " copy_strategy (str): Copy or Initial to zero towards the new additional layer.\n", 266 | " extend_strategy (str): How to initialize the weights of Conv3d from pre-trained Conv2d.\n", 267 | " use_learnable_pos_emb (bool): Whether to use learnable position embeddings.\n", 268 | " return_cls_token (bool): Whether to use cls_token to predict class label.\n", 269 | " \"\"\"\n", 270 | " supported_attention_types = [\n", 271 | " 'fact_encoder', 'joint_space_time', 'divided_space_time'\n", 272 | " ]\n", 273 | "\n", 274 | " def __init__(self,\n", 275 | " num_frames,\n", 276 | " img_size,\n", 277 | " patch_size,\n", 278 | " embed_dims=768,\n", 279 | " num_heads=12,\n", 280 | " num_transformer_layers=12,\n", 281 | " in_channels=3,\n", 282 | " dropout_p=0.,\n", 283 | " tube_size=2,\n", 284 | " conv_type='Conv3d',\n", 285 | " attention_type='fact_encoder',\n", 286 | " norm_layer=nn.LayerNorm,\n", 287 | " return_cls_token=True,\n", 288 | " **kwargs):\n", 289 | " super().__init__()\n", 290 | " assert attention_type in self.supported_attention_types, (\n", 291 | " f'Unsupported Attention Type {attention_type}!')\n", 292 | "\n", 293 | " num_frames = num_frames//tube_size\n", 294 | " self.num_frames = num_frames\n", 295 | " self.embed_dims = embed_dims\n", 296 | " self.num_transformer_layers = num_transformer_layers\n", 297 | " self.attention_type = attention_type\n", 298 | " self.conv_type = conv_type\n", 299 | " self.tube_size = tube_size\n", 300 | " self.num_time_transformer_layers = 4\n", 301 | " self.return_cls_token = return_cls_token\n", 302 | "\n", 303 | " #tokenize & position embedding\n", 304 | " self.patch_embed = PatchEmbed(\n", 305 | " img_size=img_size,\n", 306 | " patch_size=patch_size,\n", 307 | " in_channels=in_channels,\n", 308 | " embed_dims=embed_dims,\n", 309 | " tube_size=tube_size,\n", 310 | " conv_type=conv_type)\n", 311 | " num_patches = self.patch_embed.num_patches\n", 312 | "\n", 313 | " # Divided Space Time Transformer Encoder - Model 2\n", 314 | " transformer_layers = nn.ModuleList([])\n", 315 | "\n", 316 | " spatial_transformer = TransformerContainer(\n", 317 | " num_transformer_layers=num_transformer_layers,\n", 318 | " embed_dims=embed_dims,\n", 319 | " num_heads=num_heads,\n", 320 | " num_frames=num_frames,\n", 321 | " norm_layer=norm_layer,\n", 322 | " hidden_channels=embed_dims*4,\n", 323 | " operator_order=['self_attn','ffn'])\n", 324 | "\n", 325 | " temporal_transformer = TransformerContainer(\n", 326 | " num_transformer_layers=self.num_time_transformer_layers,\n", 327 | " embed_dims=embed_dims,\n", 328 | " num_heads=num_heads,\n", 329 | " num_frames=num_frames,\n", 330 | " norm_layer=norm_layer,\n", 331 | " hidden_channels=embed_dims*4,\n", 332 | " operator_order=['self_attn','ffn'])\n", 333 | "\n", 334 | " transformer_layers.append(spatial_transformer)\n", 335 | " transformer_layers.append(temporal_transformer)\n", 336 | "\n", 337 | " self.transformer_layers = transformer_layers\n", 338 | " self.norm = norm_layer(embed_dims, eps=1e-6)\n", 339 | "\n", 340 | " self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dims))\n", 341 | " # whether to add one cls_token in temporal pos_enb\n", 342 | " num_frames = num_frames + 1\n", 343 | " num_patches = num_patches + 1\n", 344 | "\n", 345 | " self.pos_embed = nn.Parameter(torch.zeros(1,num_patches,embed_dims))\n", 346 | " self.time_embed = nn.Parameter(torch.zeros(1,num_frames,embed_dims))\n", 347 | " self.drop_after_pos = nn.Dropout(p=dropout_p)\n", 348 | " self.drop_after_time = 
nn.Dropout(p=dropout_p)\n", 349 | " \n", 350 | " def interpolate_pos_encoding(self, x, w, h):\n", 351 | " npatch = x.shape[1] - 1\n", 352 | " N = self.pos_embed.shape[1] - 1\n", 353 | " if npatch == N and w == h:\n", 354 | " return self.pos_embed\n", 355 | " class_pos_embed = self.pos_embed[:, 0]\n", 356 | " patch_pos_embed = self.pos_embed[:, 1:]\n", 357 | " dim = x.shape[-1]\n", 358 | " w0 = w // self.patch_embed.patch_size[0]\n", 359 | " h0 = h // self.patch_embed.patch_size[0]\n", 360 | " # we add a small number to avoid floating point error in the interpolation\n", 361 | " # see discussion at https://github.com/facebookresearch/dino/issues/8\n", 362 | " w0, h0 = w0 + 0.1, h0 + 0.1\n", 363 | " patch_pos_embed = nn.functional.interpolate(\n", 364 | " patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),\n", 365 | " scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),\n", 366 | " mode='bicubic',\n", 367 | " )\n", 368 | " assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]\n", 369 | " patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)\n", 370 | " return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)\n", 371 | "\n", 372 | " def forward(self, x):\n", 373 | " #Tokenize\n", 374 | " b, t, c, h, w = x.shape\n", 375 | " x = self.patch_embed(x)\n", 376 | "\n", 377 | " # Add Position Embedding\n", 378 | " cls_tokens = repeat(self.cls_token, 'b ... -> (repeat b) ...', repeat=x.shape[0])\n", 379 | " x = torch.cat((cls_tokens, x), dim=1)\n", 380 | " x = x + self.interpolate_pos_encoding(x, w, h)\n", 381 | " x = self.drop_after_pos(x)\n", 382 | "\n", 383 | " # fact encoder - CRNN style\n", 384 | " spatial_transformer, temporal_transformer, = *self.transformer_layers,\n", 385 | " x = spatial_transformer(x)\n", 386 | "\n", 387 | " # Add Time Embedding\n", 388 | " cls_tokens = x[:b, 0, :].unsqueeze(1)\n", 389 | " x = rearrange(x[:, 1:, :], '(b t) p d -> b t p d', b=b)\n", 390 | " x = reduce(x, 'b t p d -> b t d', 'mean')\n", 391 | " x = torch.cat((cls_tokens, x), dim=1)\n", 392 | " x = x + self.time_embed\n", 393 | " x = self.drop_after_time(x)\n", 394 | "\n", 395 | " x = temporal_transformer(x)\n", 396 | "\n", 397 | " x = self.norm(x)\n", 398 | " # Return Class Token\n", 399 | " if self.return_cls_token:\n", 400 | " return x[:, 0]\n", 401 | " else:\n", 402 | " return x[:, 1:].mean(1)" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": 23, 408 | "metadata": {}, 409 | "outputs": [], 410 | "source": [ 411 | "def replace_state_dict(state_dict):\n", 412 | "\tfor old_key in list(state_dict.keys()):\n", 413 | "\t\tif old_key.startswith('model'):\n", 414 | "\t\t\tnew_key = old_key[6:]\n", 415 | "\t\t\tstate_dict[new_key] = state_dict.pop(old_key)\n", 416 | "\t\telse:\n", 417 | "\t\t\tnew_key = old_key[9:]\n", 418 | "\t\t\tstate_dict[new_key] = state_dict.pop(old_key)" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": 25, 424 | "metadata": {}, 425 | "outputs": [], 426 | "source": [ 427 | "def init_from_pretrain_(module, pretrained, init_module):\n", 428 | " if torch.cuda.is_available():\n", 429 | " state_dict = torch.load(pretrained)\n", 430 | " else:\n", 431 | " state_dict = torch.load(pretrained, map_location=torch.device('cpu'))\n", 432 | " if init_module == 'transformer':\n", 433 | " replace_state_dict(state_dict)\n", 434 | " elif init_module == 'cls_head':\n", 435 | " replace_state_dict(state_dict)\n", 436 | " else:\n", 437 | " raise 
TypeError(f'pretrained weights do not include the {init_module} module')\n", 438 | " msg = module.load_state_dict(state_dict, strict=False)\n", 439 | " return msg" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": null, 445 | "metadata": { 446 | "colab": { 447 | "base_uri": "https://localhost:8080/", 448 | "height": 179 449 | }, 450 | "id": "u5J7lGPJ2bLJ", 451 | "outputId": "41c0568b-082f-4609-8a5e-9ff4b80ffd92", 452 | "pycharm": {} 453 | }, 454 | "outputs": [], 455 | "source": [ 456 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\") \n", 457 | "num_frames = 8\n", 458 | "frame_interval = 32\n", 459 | "num_class = 400\n", 460 | "arch = 'timesformer' # turn to vivit for initializing vivit model\n", 461 | "\n", 462 | "if arch == 'timesformer':\n", 463 | " pretrain_pth = './timesformer_k400.pth'\n", 464 | " model = TimeSformer(num_frames=num_frames,\n", 465 | " img_size=224,\n", 466 | " patch_size=16,\n", 467 | " embed_dims=768,\n", 468 | " in_channels=3,\n", 469 | " attention_type='divided_space_time',\n", 470 | " return_cls_token=True)\n", 471 | "elif arch == 'vivit':\n", 472 | " pretrain_pth = './vivit_model.pth'\n", 473 | " num_frames = num_frames * 2\n", 474 | " frame_interval = frame_interval // 2\n", 475 | " model = ViViT(num_frames=num_frames,\n", 476 | " img_size=224,\n", 477 | " patch_size=16,\n", 478 | " embed_dims=768,\n", 479 | " in_channels=3,\n", 480 | " attention_type='fact_encoder',\n", 481 | " return_cls_token=True)\n", 482 | "else:\n", 483 | " raise TypeError(f'not supported arch type {arch}, chosen in (timesformer, vivit)')\n", 484 | "\n", 485 | "cls_head = ClassificationHead(num_classes=num_class, in_channels=768)\n", 486 | "msg_trans = init_from_pretrain_(model, pretrain_pth, init_module='transformer')\n", 487 | "msg_cls = init_from_pretrain_(cls_head, pretrain_pth, init_module='cls_head')\n", 488 | "model.eval()\n", 489 | "cls_head.eval()\n", 490 | "model = model.to(device)\n", 491 | "cls_head = cls_head.to(device)\n", 492 | "print(f'load model finished, the missing key of transformer is:{msg_trans[0]}, cls is:{msg_cls[0]}')" 493 | ] 494 | }, 495 | { 496 | "cell_type": "markdown", 497 | "metadata": {}, 498 | "source": [ 499 | "## Data preprocess\n", 500 | "\n", 501 | "Here we show the video demo and transform the video input for the model processing." 
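Note that the transform pipeline below resizes the decoded frames with `Resize(scale_range=(-1, 256))`, takes three 224x224 crops with `ThreeCrop`, and normalizes with mean 0.45 / std 0.225, while `TemporalRandomCrop(num_frames * frame_interval)` selects a clip window from which `np.linspace` samples `num_frames` evenly spaced frame indices.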
502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": null, 507 | "metadata": {}, 508 | "outputs": [], 509 | "source": [ 510 | "from IPython.display import display, HTML\n", 511 | "\n", 512 | "video_path = './demo/YABnJL_bDzw.mp4'\n", 513 | "html_str = '''\n", 514 | "\n", 515 | "'''.format(video_path)\n", 516 | "display(HTML(html_str))" 517 | ] 518 | }, 519 | { 520 | "cell_type": "code", 521 | "execution_count": 54, 522 | "metadata": {}, 523 | "outputs": [], 524 | "source": [ 525 | "# Prepare data preprocess\n", 526 | "mean, std = (0.45, 0.45, 0.45), (0.225, 0.225, 0.225)\n", 527 | "data_transform = T.Compose([\n", 528 | " T.Resize(scale_range=(-1, 256)),\n", 529 | " T.ThreeCrop(size=224),\n", 530 | " T.ToTensor(),\n", 531 | " T.Normalize(mean, std)\n", 532 | " ])\n", 533 | "temporal_sample = T.TemporalRandomCrop(num_frames*frame_interval)\n", 534 | "\n", 535 | "# Sampling video frames\n", 536 | "video_decoder = DecordInit()\n", 537 | "v_reader = video_decoder(video_path)\n", 538 | "total_frames = len(v_reader)\n", 539 | "start_frame_ind, end_frame_ind = temporal_sample(total_frames)\n", 540 | "if end_frame_ind-start_frame_ind < num_frames:\n", 541 | " raise ValueError(f'the total frames of the video {video_path} is less than {num_frames}')\n", 542 | "frame_indice = np.linspace(0, end_frame_ind-start_frame_ind-1, num_frames, dtype=int)\n", 543 | "video = v_reader.get_batch(frame_indice).asnumpy()\n", 544 | "del v_reader\n", 545 | "\n", 546 | "video = torch.from_numpy(video).permute(0,3,1,2) # Video transform: T C H W\n", 547 | "data_transform.randomize_parameters()\n", 548 | "video = data_transform(video)\n", 549 | "video = video.to(device)" 550 | ] 551 | }, 552 | { 553 | "cell_type": "markdown", 554 | "metadata": {}, 555 | "source": [ 556 | "## Video Classification\n", 557 | "\n", 558 | "Here we use the pre-trained video transformer to classify the input video." 
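Because `ThreeCrop` stacks three spatial crops along the leading dimension, the model returns one set of logits per crop; the cell below therefore reshapes the output to `(3, 400)` and averages over the crops before taking the argmax over the 400 Kinetics classes.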
559 | ] 560 | }, 561 | { 562 | "cell_type": "code", 563 | "execution_count": null, 564 | "metadata": {}, 565 | "outputs": [], 566 | "source": [ 567 | "# Predict class label\n", 568 | "with torch.no_grad():\n", 569 | " logits = model(video)\n", 570 | " output = cls_head(logits)\n", 571 | " output = output.view(3, 400).mean(0)\n", 572 | " cls_pred = output.argmax().item()\n", 573 | "\n", 574 | "class_map = './k400_classmap.json'\n", 575 | "class_map = load_annotation_data(class_map)\n", 576 | "for key, value in class_map.items():\n", 577 | " if int(value) == int(cls_pred):\n", 578 | " print(f'the shape of ouptut: {output.shape}, and the prediction is: {key}')\n", 579 | " break" 580 | ] 581 | } 582 | ], 583 | "metadata": { 584 | "accelerator": "GPU", 585 | "colab": { 586 | "collapsed_sections": [], 587 | "name": "iBOT_demo", 588 | "provenance": [], 589 | "toc_visible": true 590 | }, 591 | "kernelspec": { 592 | "display_name": "Python 3.8 (XPython)", 593 | "language": "python", 594 | "name": "xpython" 595 | }, 596 | "language_info": { 597 | "file_extension": ".py", 598 | "mimetype": "text/x-python", 599 | "name": "python", 600 | "version": "3.8.6" 601 | }, 602 | "widgets": { 603 | "application/vnd.jupyter.widget-state+json": { 604 | "10cb415f29954cb0a511e263fcb6c69f": { 605 | "model_module": "@jupyter-widgets/controls", 606 | "model_module_version": "1.5.0", 607 | "model_name": "HTMLModel", 608 | "state": { 609 | "_dom_classes": [], 610 | "_model_module": "@jupyter-widgets/controls", 611 | "_model_module_version": "1.5.0", 612 | "_model_name": "HTMLModel", 613 | "_view_count": null, 614 | "_view_module": "@jupyter-widgets/controls", 615 | "_view_module_version": "1.5.0", 616 | "_view_name": "HTMLView", 617 | "description": "", 618 | "description_tooltip": null, 619 | "layout": "IPY_MODEL_ecfd9523d50f4f7f87083460dc0e3dd7", 620 | "placeholder": "​", 621 | "style": "IPY_MODEL_6d68baa05ccc427890286e0e74452155", 622 | "value": " 327M/327M [00:09<00:00, 36.6MB/s]" 623 | } 624 | }, 625 | "3cb57e69cf054551b1978ddd6be3553b": { 626 | "model_module": "@jupyter-widgets/controls", 627 | "model_module_version": "1.5.0", 628 | "model_name": "HTMLModel", 629 | "state": { 630 | "_dom_classes": [], 631 | "_model_module": "@jupyter-widgets/controls", 632 | "_model_module_version": "1.5.0", 633 | "_model_name": "HTMLModel", 634 | "_view_count": null, 635 | "_view_module": "@jupyter-widgets/controls", 636 | "_view_module_version": "1.5.0", 637 | "_view_name": "HTMLView", 638 | "description": "", 639 | "description_tooltip": null, 640 | "layout": "IPY_MODEL_5bbd072a4a424ef88d5354793942a84b", 641 | "placeholder": "​", 642 | "style": "IPY_MODEL_6668b553344743e2a1450251c5a237e2", 643 | "value": "100%" 644 | } 645 | }, 646 | "557451188ee74f3698a8b358ee70d29d": { 647 | "model_module": "@jupyter-widgets/controls", 648 | "model_module_version": "1.5.0", 649 | "model_name": "FloatProgressModel", 650 | "state": { 651 | "_dom_classes": [], 652 | "_model_module": "@jupyter-widgets/controls", 653 | "_model_module_version": "1.5.0", 654 | "_model_name": "FloatProgressModel", 655 | "_view_count": null, 656 | "_view_module": "@jupyter-widgets/controls", 657 | "_view_module_version": "1.5.0", 658 | "_view_name": "ProgressView", 659 | "bar_style": "success", 660 | "description": "", 661 | "description_tooltip": null, 662 | "layout": "IPY_MODEL_af53b2f6e65b43fb8b3df5c656b40c92", 663 | "max": 343279349, 664 | "min": 0, 665 | "orientation": "horizontal", 666 | "style": "IPY_MODEL_9e962b8959ca4e90b496ab36d764f073", 667 | "value": 
343279349 668 | } 669 | }, 670 | "5bbd072a4a424ef88d5354793942a84b": { 671 | "model_module": "@jupyter-widgets/base", 672 | "model_module_version": "1.2.0", 673 | "model_name": "LayoutModel", 674 | "state": { 675 | "_model_module": "@jupyter-widgets/base", 676 | "_model_module_version": "1.2.0", 677 | "_model_name": "LayoutModel", 678 | "_view_count": null, 679 | "_view_module": "@jupyter-widgets/base", 680 | "_view_module_version": "1.2.0", 681 | "_view_name": "LayoutView", 682 | "align_content": null, 683 | "align_items": null, 684 | "align_self": null, 685 | "border": null, 686 | "bottom": null, 687 | "display": null, 688 | "flex": null, 689 | "flex_flow": null, 690 | "grid_area": null, 691 | "grid_auto_columns": null, 692 | "grid_auto_flow": null, 693 | "grid_auto_rows": null, 694 | "grid_column": null, 695 | "grid_gap": null, 696 | "grid_row": null, 697 | "grid_template_areas": null, 698 | "grid_template_columns": null, 699 | "grid_template_rows": null, 700 | "height": null, 701 | "justify_content": null, 702 | "justify_items": null, 703 | "left": null, 704 | "margin": null, 705 | "max_height": null, 706 | "max_width": null, 707 | "min_height": null, 708 | "min_width": null, 709 | "object_fit": null, 710 | "object_position": null, 711 | "order": null, 712 | "overflow": null, 713 | "overflow_x": null, 714 | "overflow_y": null, 715 | "padding": null, 716 | "right": null, 717 | "top": null, 718 | "visibility": null, 719 | "width": null 720 | } 721 | }, 722 | "6668b553344743e2a1450251c5a237e2": { 723 | "model_module": "@jupyter-widgets/controls", 724 | "model_module_version": "1.5.0", 725 | "model_name": "DescriptionStyleModel", 726 | "state": { 727 | "_model_module": "@jupyter-widgets/controls", 728 | "_model_module_version": "1.5.0", 729 | "_model_name": "DescriptionStyleModel", 730 | "_view_count": null, 731 | "_view_module": "@jupyter-widgets/base", 732 | "_view_module_version": "1.2.0", 733 | "_view_name": "StyleView", 734 | "description_width": "" 735 | } 736 | }, 737 | "6d68baa05ccc427890286e0e74452155": { 738 | "model_module": "@jupyter-widgets/controls", 739 | "model_module_version": "1.5.0", 740 | "model_name": "DescriptionStyleModel", 741 | "state": { 742 | "_model_module": "@jupyter-widgets/controls", 743 | "_model_module_version": "1.5.0", 744 | "_model_name": "DescriptionStyleModel", 745 | "_view_count": null, 746 | "_view_module": "@jupyter-widgets/base", 747 | "_view_module_version": "1.2.0", 748 | "_view_name": "StyleView", 749 | "description_width": "" 750 | } 751 | }, 752 | "8902dbcd84b4439a987e908be8c8e19e": { 753 | "model_module": "@jupyter-widgets/controls", 754 | "model_module_version": "1.5.0", 755 | "model_name": "HBoxModel", 756 | "state": { 757 | "_dom_classes": [], 758 | "_model_module": "@jupyter-widgets/controls", 759 | "_model_module_version": "1.5.0", 760 | "_model_name": "HBoxModel", 761 | "_view_count": null, 762 | "_view_module": "@jupyter-widgets/controls", 763 | "_view_module_version": "1.5.0", 764 | "_view_name": "HBoxView", 765 | "box_style": "", 766 | "children": [ 767 | "IPY_MODEL_3cb57e69cf054551b1978ddd6be3553b", 768 | "IPY_MODEL_557451188ee74f3698a8b358ee70d29d", 769 | "IPY_MODEL_10cb415f29954cb0a511e263fcb6c69f" 770 | ], 771 | "layout": "IPY_MODEL_95c73030d06d4d578252f47989c208fc" 772 | } 773 | }, 774 | "95c73030d06d4d578252f47989c208fc": { 775 | "model_module": "@jupyter-widgets/base", 776 | "model_module_version": "1.2.0", 777 | "model_name": "LayoutModel", 778 | "state": { 779 | "_model_module": "@jupyter-widgets/base", 780 | 
"_model_module_version": "1.2.0", 781 | "_model_name": "LayoutModel", 782 | "_view_count": null, 783 | "_view_module": "@jupyter-widgets/base", 784 | "_view_module_version": "1.2.0", 785 | "_view_name": "LayoutView", 786 | "align_content": null, 787 | "align_items": null, 788 | "align_self": null, 789 | "border": null, 790 | "bottom": null, 791 | "display": null, 792 | "flex": null, 793 | "flex_flow": null, 794 | "grid_area": null, 795 | "grid_auto_columns": null, 796 | "grid_auto_flow": null, 797 | "grid_auto_rows": null, 798 | "grid_column": null, 799 | "grid_gap": null, 800 | "grid_row": null, 801 | "grid_template_areas": null, 802 | "grid_template_columns": null, 803 | "grid_template_rows": null, 804 | "height": null, 805 | "justify_content": null, 806 | "justify_items": null, 807 | "left": null, 808 | "margin": null, 809 | "max_height": null, 810 | "max_width": null, 811 | "min_height": null, 812 | "min_width": null, 813 | "object_fit": null, 814 | "object_position": null, 815 | "order": null, 816 | "overflow": null, 817 | "overflow_x": null, 818 | "overflow_y": null, 819 | "padding": null, 820 | "right": null, 821 | "top": null, 822 | "visibility": null, 823 | "width": null 824 | } 825 | }, 826 | "9e962b8959ca4e90b496ab36d764f073": { 827 | "model_module": "@jupyter-widgets/controls", 828 | "model_module_version": "1.5.0", 829 | "model_name": "ProgressStyleModel", 830 | "state": { 831 | "_model_module": "@jupyter-widgets/controls", 832 | "_model_module_version": "1.5.0", 833 | "_model_name": "ProgressStyleModel", 834 | "_view_count": null, 835 | "_view_module": "@jupyter-widgets/base", 836 | "_view_module_version": "1.2.0", 837 | "_view_name": "StyleView", 838 | "bar_color": null, 839 | "description_width": "" 840 | } 841 | }, 842 | "af53b2f6e65b43fb8b3df5c656b40c92": { 843 | "model_module": "@jupyter-widgets/base", 844 | "model_module_version": "1.2.0", 845 | "model_name": "LayoutModel", 846 | "state": { 847 | "_model_module": "@jupyter-widgets/base", 848 | "_model_module_version": "1.2.0", 849 | "_model_name": "LayoutModel", 850 | "_view_count": null, 851 | "_view_module": "@jupyter-widgets/base", 852 | "_view_module_version": "1.2.0", 853 | "_view_name": "LayoutView", 854 | "align_content": null, 855 | "align_items": null, 856 | "align_self": null, 857 | "border": null, 858 | "bottom": null, 859 | "display": null, 860 | "flex": null, 861 | "flex_flow": null, 862 | "grid_area": null, 863 | "grid_auto_columns": null, 864 | "grid_auto_flow": null, 865 | "grid_auto_rows": null, 866 | "grid_column": null, 867 | "grid_gap": null, 868 | "grid_row": null, 869 | "grid_template_areas": null, 870 | "grid_template_columns": null, 871 | "grid_template_rows": null, 872 | "height": null, 873 | "justify_content": null, 874 | "justify_items": null, 875 | "left": null, 876 | "margin": null, 877 | "max_height": null, 878 | "max_width": null, 879 | "min_height": null, 880 | "min_width": null, 881 | "object_fit": null, 882 | "object_position": null, 883 | "order": null, 884 | "overflow": null, 885 | "overflow_x": null, 886 | "overflow_y": null, 887 | "padding": null, 888 | "right": null, 889 | "top": null, 890 | "visibility": null, 891 | "width": null 892 | } 893 | }, 894 | "ecfd9523d50f4f7f87083460dc0e3dd7": { 895 | "model_module": "@jupyter-widgets/base", 896 | "model_module_version": "1.2.0", 897 | "model_name": "LayoutModel", 898 | "state": { 899 | "_model_module": "@jupyter-widgets/base", 900 | "_model_module_version": "1.2.0", 901 | "_model_name": "LayoutModel", 902 | "_view_count": null, 903 | 
"_view_module": "@jupyter-widgets/base", 904 | "_view_module_version": "1.2.0", 905 | "_view_name": "LayoutView", 906 | "align_content": null, 907 | "align_items": null, 908 | "align_self": null, 909 | "border": null, 910 | "bottom": null, 911 | "display": null, 912 | "flex": null, 913 | "flex_flow": null, 914 | "grid_area": null, 915 | "grid_auto_columns": null, 916 | "grid_auto_flow": null, 917 | "grid_auto_rows": null, 918 | "grid_column": null, 919 | "grid_gap": null, 920 | "grid_row": null, 921 | "grid_template_areas": null, 922 | "grid_template_columns": null, 923 | "grid_template_rows": null, 924 | "height": null, 925 | "justify_content": null, 926 | "justify_items": null, 927 | "left": null, 928 | "margin": null, 929 | "max_height": null, 930 | "max_width": null, 931 | "min_height": null, 932 | "min_width": null, 933 | "object_fit": null, 934 | "object_position": null, 935 | "order": null, 936 | "overflow": null, 937 | "overflow_x": null, 938 | "overflow_y": null, 939 | "padding": null, 940 | "right": null, 941 | "top": null, 942 | "visibility": null, 943 | "width": null 944 | } 945 | } 946 | } 947 | } 948 | }, 949 | "nbformat": 4, 950 | "nbformat_minor": 4 951 | } 952 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # Modified by MX 7 | # -------------------------------------------------------- 8 | from functools import partial 9 | from torch import optim as optim 10 | 11 | from utils import print_on_rank_zero 12 | 13 | 14 | def build_optimizer(hparams, model, is_pretrain): 15 | if is_pretrain: 16 | return build_pretrain_optimizer(hparams, model) 17 | else: 18 | return build_finetune_optimizer(hparams, model) 19 | 20 | 21 | def build_pretrain_optimizer(hparams, model): 22 | skip = {} 23 | skip_keywords = {} 24 | if hasattr(model, 'no_weight_decay'): 25 | skip = model.no_weight_decay() 26 | if hasattr(model, 'no_weight_decay_keywords'): 27 | skip_keywords = model.no_weight_decay_keywords() 28 | 29 | parameters = get_pretrain_param_groups(model, skip, skip_keywords) 30 | 31 | opt_lower = hparams.optim_type.lower() 32 | optimizer = None 33 | if opt_lower == 'sgd': 34 | optimizer = optim.SGD(parameters, momentum=0.9, nesterov=True, 35 | lr=hparams.lr, weight_decay=hparams.weight_decay) 36 | elif opt_lower == 'adamw': 37 | optimizer = optim.AdamW(parameters, betas=(0.9, 0.999), 38 | lr=hparams.lr, weight_decay=hparams.weight_decay) 39 | 40 | return optimizer 41 | 42 | 43 | def get_pretrain_param_groups(model, skip_list=(), skip_keywords=()): 44 | has_decay = [] 45 | no_decay = [] 46 | has_decay_name = [] 47 | no_decay_name = [] 48 | 49 | for name, param in model.named_parameters(): 50 | if not param.requires_grad: 51 | continue 52 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 53 | check_keywords_in_name(name, skip_keywords): 54 | no_decay.append(param) 55 | no_decay_name.append(name) 56 | else: 57 | has_decay.append(param) 58 | has_decay_name.append(name) 59 | 60 | print_on_rank_zero(f'params_no_decay_name: {no_decay_name} \n params_decay_name: {has_decay_name}') 61 | return [{'params': no_decay, 'weight_decay': 0.}, 62 | {'params': has_decay},] 63 | 64 | 65 | def build_finetune_optimizer(hparams, model): 66 | if hparams.arch 
== 'mvit': 67 | if hparams.layer_decay == 1: 68 | get_layer_func = None 69 | scales = None 70 | else: 71 | num_layers = 16 72 | get_layer_func = partial(get_mvit_layer, num_layers=num_layers + 2) 73 | scales = list(hparams.layer_decay ** i for i in reversed(range(num_layers + 2))) #layer_decay=1 disable 74 | else: 75 | return build_pretrain_optimizer(hparams, model) 76 | 77 | skip = {} 78 | skip_keywords = {} 79 | if hasattr(model, 'no_weight_decay'): 80 | skip = model.no_weight_decay() 81 | if hasattr(model, 'no_weight_decay_keywords'): 82 | skip_keywords = model.no_weight_decay_keywords() 83 | 84 | parameters = get_finetune_param_groups( 85 | model, hparams.lr, hparams.weight_decay, 86 | get_layer_func, scales, skip, skip_keywords) 87 | 88 | opt_lower = hparams.optim_type.lower() 89 | optimizer = None 90 | if opt_lower == 'sgd': 91 | optimizer = optim.SGD(parameters, momentum=0.9, nesterov=True, 92 | lr=hparams.lr, weight_decay=hparams.weight_decay) 93 | elif opt_lower == 'adamw': 94 | optimizer = optim.AdamW(parameters, betas=(0.9, 0.999), 95 | lr=hparams.lr, weight_decay=hparams.weight_decay) 96 | 97 | return optimizer 98 | 99 | 100 | def get_mvit_layer(name, num_layers): 101 | layer_name = name.replace('mvit.', '') 102 | layer_name = layer_name.replace('model.', '') 103 | if layer_name in ("mask_token"): 104 | return 0 105 | elif layer_name.startswith("patch_embed") or layer_name.startswith('cls_positional_encoding'): 106 | return 0 107 | elif layer_name.startswith("blocks"): 108 | layer_id = int(layer_name.split('.')[1]) 109 | return layer_id + 1 110 | else: 111 | return num_layers - 1 112 | 113 | 114 | def get_finetune_param_groups(model, lr, weight_decay, get_layer_func, scales, skip_list=(), skip_keywords=()): 115 | parameter_group_names = {} 116 | parameter_group_vars = {} 117 | 118 | for name, param in model.named_parameters(): 119 | if not param.requires_grad: 120 | continue 121 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 122 | check_keywords_in_name(name, skip_keywords): 123 | group_name = "no_decay" 124 | this_weight_decay = 0. 125 | else: 126 | group_name = "decay" 127 | this_weight_decay = weight_decay 128 | if get_layer_func is not None: 129 | layer_id = get_layer_func(name) 130 | group_name = "layer_%d_%s" % (layer_id, group_name) 131 | #print(name, group_name) 132 | else: 133 | layer_id = None 134 | 135 | if group_name not in parameter_group_names: 136 | if scales is not None: 137 | scale = scales[layer_id] 138 | else: 139 | scale = 1. 
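            # Layer-wise lr decay (sketch): get_mvit_layer maps mask_token / patch_embed to depth 0,
            # blocks.k to depth k + 1, and everything else (e.g. the head) to the top depth, while
            # scales[depth] = layer_decay ** (top_depth - depth), so earlier layers train with a
            # smaller learning rate and the top-level parameters keep the base lr (scale 1.0).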
140 | 141 | parameter_group_names[group_name] = { 142 | "group_name": group_name, 143 | "weight_decay": this_weight_decay, 144 | "params": [], 145 | "lr": lr * scale, 146 | "lr_scale": scale, 147 | } 148 | parameter_group_vars[group_name] = { 149 | "group_name": group_name, 150 | "weight_decay": this_weight_decay, 151 | "params": [], 152 | "lr": lr * scale, 153 | "lr_scale": scale 154 | } 155 | 156 | parameter_group_vars[group_name]["params"].append(param) 157 | parameter_group_names[group_name]["params"].append(name) 158 | return list(parameter_group_vars.values()) 159 | 160 | 161 | def check_keywords_in_name(name, keywords=()): 162 | isin = False 163 | for keyword in keywords: 164 | if keyword in name: 165 | isin = True 166 | return isin -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchmetrics==0.5.1 3 | torchvision 4 | pytorch-lightning==1.3.8 5 | pytorchvideo 6 | scikit-image 7 | decord==0.6.0 8 | einops==0.3.2 9 | kornia==0.6.1 10 | matplotlib==3.4.3 11 | timm==0.4.12 -------------------------------------------------------------------------------- /transformer.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange, repeat, reduce 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.modules.utils import _pair 6 | 7 | from weight_init import trunc_normal_, constant_init_, kaiming_init_ 8 | 9 | 10 | # sin-cos position encoding 11 | # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31 12 | def get_sine_cosine_pos_emb(n_position, d_hid): 13 | ''' Sinusoid position encoding table ''' 14 | # TODO: make it with torch instead of numpy 15 | def get_position_angle_vec(position): 16 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 17 | 18 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 19 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 20 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 21 | 22 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 23 | 24 | 25 | class DropPath(nn.Module): 26 | 27 | def __init__(self, dropout_p=None): 28 | super(DropPath, self).__init__() 29 | self.dropout_p = dropout_p 30 | 31 | def forward(self, x): 32 | return self.drop_path(x, self.dropout_p, self.training) 33 | 34 | def drop_path(self, x, dropout_p=0., training=False): 35 | if dropout_p == 0. or not training: 36 | return x 37 | keep_prob = 1 - dropout_p 38 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 39 | random_tensor = keep_prob + torch.rand(shape).type_as(x) 40 | random_tensor.floor_() # binarize 41 | output = x.div(keep_prob) * random_tensor 42 | return output 43 | 44 | 45 | class ClassificationHead(nn.Module): 46 | """Classification head for Video Transformer. 47 | 48 | Args: 49 | num_classes (int): Number of classes to be classified. 50 | in_channels (int): Number of channels in input feature. 51 | init_std (float): Std value for Initiation. Defaults to 0.02. 52 | kwargs (dict, optional): Any keyword argument to be used to initialize 53 | the head. 
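	Example (illustrative shapes; values assumed):
		>>> head = ClassificationHead(num_classes=400, in_channels=768)
		>>> head(torch.zeros(2, 768)).shape
		torch.Size([2, 400])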
54 | """ 55 | 56 | def __init__(self, 57 | num_classes, 58 | in_channels, 59 | init_std=0.02, 60 | eval_metrics='finetune', 61 | **kwargs): 62 | super().__init__() 63 | self.init_std = init_std 64 | self.eval_metrics = eval_metrics 65 | self.cls_head = nn.Linear(in_channels, num_classes) 66 | 67 | self.init_weights(self.cls_head) 68 | 69 | def init_weights(self, module): 70 | if hasattr(module, 'weight') and module.weight is not None: 71 | if self.eval_metrics == 'finetune': 72 | trunc_normal_(module.weight, std=self.init_std) 73 | else: 74 | module.weight.data.normal_(mean=0.0, std=0.01) 75 | if hasattr(module, 'bias') and module.bias is not None: 76 | constant_init_(module.bias, constant_value=0) 77 | 78 | def forward(self, x): 79 | cls_score = self.cls_head(x) 80 | return cls_score 81 | 82 | 83 | class PatchEmbed(nn.Module): 84 | """Images to Patch Embedding. 85 | 86 | Args: 87 | img_size (int | tuple): Size of input image. 88 | patch_size (int): Size of one patch. 89 | tube_size (int): Size of temporal field of one 3D patch. 90 | in_channels (int): Channel num of input features. Defaults to 3. 91 | embed_dims (int): Dimensions of embedding. Defaults to 768. 92 | conv_type (str): Type for convolution layer. Defaults to 'Conv2d'. 93 | """ 94 | 95 | def __init__(self, 96 | img_size, 97 | patch_size, 98 | tube_size=2, 99 | in_channels=3, 100 | embed_dims=768, 101 | conv_type='Conv2d'): 102 | super().__init__() 103 | self.img_size = _pair(img_size) 104 | self.patch_size = _pair(patch_size) 105 | 106 | num_patches = \ 107 | (self.img_size[1] // self.patch_size[1]) * \ 108 | (self.img_size[0] // self.patch_size[0]) 109 | assert (num_patches * self.patch_size[0] * self.patch_size[1] == 110 | self.img_size[0] * self.img_size[1], 111 | 'The image size H*W must be divisible by patch size') 112 | self.num_patches = num_patches 113 | 114 | # Use conv layer to embed 115 | if conv_type == 'Conv2d': 116 | self.projection = nn.Conv2d( 117 | in_channels, 118 | embed_dims, 119 | kernel_size=patch_size, 120 | stride=patch_size) 121 | elif conv_type == 'Conv3d': 122 | self.projection = nn.Conv3d( 123 | in_channels, 124 | embed_dims, 125 | kernel_size=(tube_size,patch_size,patch_size), 126 | stride=(tube_size,patch_size,patch_size)) 127 | else: 128 | raise TypeError(f'Unsupported conv layer type {conv_type}') 129 | 130 | self.init_weights(self.projection) 131 | 132 | def init_weights(self, module): 133 | if hasattr(module, 'weight') and module.weight is not None: 134 | kaiming_init_(module.weight, mode='fan_in', nonlinearity='relu') 135 | if hasattr(module, 'bias') and module.bias is not None: 136 | constant_init_(module.bias, constant_value=0) 137 | 138 | def forward(self, x): 139 | layer_type = type(self.projection) 140 | if layer_type == nn.Conv3d: 141 | x = rearrange(x, 'b t c h w -> b c t h w') 142 | x = self.projection(x) 143 | x = rearrange(x, 'b c t h w -> (b t) (h w) c') 144 | elif layer_type == nn.Conv2d: 145 | x = rearrange(x, 'b t c h w -> (b t) c h w') 146 | x = self.projection(x) 147 | x = rearrange(x, 'b c h w -> b (h w) c') 148 | else: 149 | raise TypeError(f'Unsupported conv layer type {layer_type}') 150 | 151 | return x 152 | 153 | class Attention(nn.Module): 154 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 155 | super().__init__() 156 | self.num_heads = num_heads 157 | head_dim = dim // num_heads 158 | self.scale = qk_scale or head_dim ** -0.5 159 | 160 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 161 | self.attn_drop = 
nn.Dropout(attn_drop) 162 | self.proj = nn.Linear(dim, dim) 163 | self.proj_drop = nn.Dropout(proj_drop) 164 | 165 | def forward(self, x): 166 | B, N, C = x.shape 167 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 168 | q, k, v = qkv[0], qkv[1], qkv[2] 169 | 170 | attn = (q @ k.transpose(-2, -1)) * self.scale 171 | attn = attn.softmax(dim=-1) 172 | attn = self.attn_drop(attn) 173 | 174 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 175 | x = self.proj(x) 176 | x = self.proj_drop(x) 177 | return x, attn 178 | 179 | class DividedTemporalAttentionWithPreNorm(nn.Module): 180 | """Temporal Attention in Divided Space Time Attention. 181 | A warp for torch.nn.MultiheadAttention. 182 | 183 | Args: 184 | embed_dims (int): Dimensions of embedding. 185 | num_heads (int): Number of parallel attention heads in 186 | TransformerCoder. 187 | num_frames (int): Number of frames in the video. 188 | use_cls_token (bool): Whether to perform MSA on cls_token. 189 | attn_drop (float): A Dropout layer on attn_output_weights. Defaults to 190 | 0.. 191 | proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. 192 | Defaults to 0.. 193 | layer_drop (dict): The layer_drop used when adding the shortcut. 194 | Defaults to `dict(type=DropPath, dropout_p=0.1)`. 195 | norm_layer (class): Class name for normalization layer. Defaults to 196 | nn.LayerNorm. 197 | """ 198 | 199 | def __init__(self, 200 | embed_dims, 201 | num_heads, 202 | num_frames, 203 | use_cls_token, 204 | attn_drop=0., 205 | proj_drop=0., 206 | layer_drop=dict(type=DropPath, dropout_p=0.1), 207 | norm_layer=nn.LayerNorm, 208 | **kwargs): 209 | super().__init__() 210 | self.embed_dims = embed_dims 211 | self.num_heads = num_heads 212 | self.num_frames = num_frames 213 | self.use_cls_token = use_cls_token 214 | 215 | self.norm = norm_layer(embed_dims) 216 | #self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop, 217 | # **kwargs) 218 | self.attn = Attention(embed_dims, num_heads, qkv_bias=True, attn_drop=attn_drop) # batch first 219 | 220 | self.proj_drop = nn.Dropout(proj_drop) 221 | dropout_p = layer_drop.pop('dropout_p') 222 | layer_drop= layer_drop.pop('type') 223 | self.layer_drop = layer_drop(dropout_p) if layer_drop else nn.Identity() 224 | if not use_cls_token: 225 | self.temporal_fc = nn.Linear(self.embed_dims, self.embed_dims) 226 | self.init_weights(self.temporal_fc) 227 | 228 | def init_weights(self, module): 229 | if hasattr(module, 'weight') and module.weight is not None: 230 | constant_init_(module.weight, constant_value=0) 231 | if hasattr(module, 'bias') and module.bias is not None: 232 | constant_init_(module.bias, constant_value=0) 233 | 234 | def forward(self, query, key=None, value=None, residual=None, return_attention=False, **kwargs): 235 | assert residual is None, ( 236 | 'Always adding the shortcut in the forward function') 237 | 238 | cls_token = query[:, 0, :].unsqueeze(1) 239 | if self.use_cls_token: 240 | residual = query 241 | query = query[:, 1:, :] 242 | else: 243 | query = query[:, 1:, :] 244 | residual = query 245 | 246 | b, n, d = query.size() 247 | p, t = n // self.num_frames, self.num_frames 248 | 249 | # Pre-Process 250 | query = rearrange(query, 'b (p t) d -> (b p) t d', p=p, t=t) 251 | if self.use_cls_token: 252 | cls_token = repeat(cls_token, 'b n d -> b (p n) d', p=p) 253 | cls_token = rearrange(cls_token, 'b p d -> (b p) 1 d') 254 | query = torch.cat((cls_token, query), 1) 255 | 256 | # Forward MSA 257 | query = self.norm(query) 258 | 
#query = rearrange(query, 'b n d -> n b d') 259 | #attn_out = self.attn(query, query, query)[0] 260 | #attn_out = rearrange(attn_out, 'n b d -> b n d') 261 | attn_out, attn_weights = self.attn(query) 262 | if return_attention: 263 | return attn_weights 264 | 265 | attn_out = self.layer_drop(self.proj_drop(attn_out.contiguous())) 266 | if not self.use_cls_token: 267 | attn_out = self.temporal_fc(attn_out) 268 | 269 | # Post-Process 270 | if self.use_cls_token: 271 | cls_token, attn_out = attn_out[:, 0, :], attn_out[:, 1:, :] 272 | cls_token = rearrange(cls_token, '(b p) d -> b p d', b=b) 273 | cls_token = reduce(cls_token, 'b p d -> b 1 d', 'mean') 274 | 275 | attn_out = rearrange(attn_out, '(b p) t d -> b (p t) d', p=p, t=t) 276 | attn_out = torch.cat((cls_token, attn_out), 1) 277 | new_query = residual + attn_out 278 | else: 279 | attn_out = rearrange(attn_out, '(b p) t d -> b (p t) d', p=p, t=t) 280 | new_query = residual + attn_out 281 | new_query = torch.cat((cls_token, new_query), 1) 282 | return new_query 283 | 284 | 285 | class DividedSpatialAttentionWithPreNorm(nn.Module): 286 | """Spatial Attention in Divided Space Time Attention. 287 | A warp for torch.nn.MultiheadAttention. 288 | 289 | Args: 290 | embed_dims (int): Dimensions of embedding. 291 | num_heads (int): Number of parallel attention heads in 292 | TransformerCoder. 293 | num_frames (int): Number of frames in the video. 294 | use_cls_token (bool): Whether to perform MSA on cls_token. 295 | attn_drop (float): A Dropout layer on attn_output_weights. Defaults to 296 | 0.. 297 | proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. 298 | Defaults to 0.. 299 | layer_drop (dict): The layer_drop used when adding the shortcut. 300 | Defaults to `dict(type=DropPath, dropout_p=0.1)`. 301 | norm_layer (class): Class name for normalization layer. Defaults to 302 | nn.LayerNorm. 
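	Example (shape sketch; sizes assumed, and layer_drop is passed explicitly because
	__init__ pops entries from the dict it receives):
		>>> attn = DividedSpatialAttentionWithPreNorm(
		...     embed_dims=768, num_heads=12, num_frames=8, use_cls_token=True,
		...     layer_drop=dict(type=DropPath, dropout_p=0.))
		>>> tokens = torch.zeros(2, 1 + 196 * 8, 768)  # [batch, cls + patches * frames, dim]
		>>> attn(tokens).shape
		torch.Size([2, 1569, 768])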
303 | """ 304 | 305 | def __init__(self, 306 | embed_dims, 307 | num_heads, 308 | num_frames, 309 | use_cls_token, 310 | attn_drop=0., 311 | proj_drop=0., 312 | layer_drop=dict(type=DropPath, dropout_p=0.1), 313 | norm_layer=nn.LayerNorm, 314 | **kwargs): 315 | super().__init__() 316 | self.embed_dims = embed_dims 317 | self.num_heads = num_heads 318 | self.num_frames = num_frames 319 | self.use_cls_token = use_cls_token 320 | 321 | self.norm = norm_layer(embed_dims) 322 | #self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop, 323 | # **kwargs) 324 | self.attn = Attention(embed_dims, num_heads, qkv_bias=True, attn_drop=attn_drop) # batch first 325 | 326 | self.proj_drop = nn.Dropout(proj_drop) 327 | dropout_p = layer_drop.pop('dropout_p') 328 | layer_drop= layer_drop.pop('type') 329 | self.layer_drop = layer_drop(dropout_p) if layer_drop else nn.Identity() 330 | 331 | self.init_weights() 332 | 333 | def init_weights(self): 334 | pass 335 | 336 | def forward(self, query, key=None, value=None, residual=None, return_attention=False, **kwargs): 337 | assert residual is None, ( 338 | 'Always adding the shortcut in the forward function') 339 | 340 | cls_token = query[:, 0, :].unsqueeze(1) 341 | if self.use_cls_token: 342 | residual = query 343 | query = query[:, 1:, :] 344 | else: 345 | query = query[:, 1:, :] 346 | residual = query 347 | 348 | b, n, d = query.size() 349 | p, t = n // self.num_frames, self.num_frames 350 | 351 | # Pre-Process 352 | query = rearrange(query, 'b (p t) d -> (b t) p d', p=p, t=t) 353 | if self.use_cls_token: 354 | cls_token = repeat(cls_token, 'b n d -> b (t n) d', t=t) 355 | cls_token = rearrange(cls_token, 'b t d -> (b t) 1 d') 356 | query = torch.cat((cls_token, query), 1) 357 | 358 | # Forward MSA 359 | query = self.norm(query) 360 | #query = rearrange(query, 'b n d -> n b d') 361 | #attn_out = self.attn(query, query, query)[0] 362 | #attn_out = rearrange(attn_out, 'n b d -> b n d') 363 | attn_out, attn_weights = self.attn(query) 364 | if return_attention: 365 | return attn_weights 366 | 367 | attn_out = self.layer_drop(self.proj_drop(attn_out.contiguous())) 368 | 369 | # Post-Process 370 | if self.use_cls_token: 371 | cls_token, attn_out = attn_out[:, 0, :], attn_out[:, 1:, :] 372 | cls_token = rearrange(cls_token, '(b t) d -> b t d', b=b) 373 | cls_token = reduce(cls_token, 'b t d -> b 1 d', 'mean') 374 | 375 | attn_out = rearrange(attn_out, '(b t) p d -> b (p t) d', p=p, t=t) 376 | attn_out = torch.cat((cls_token, attn_out), 1) 377 | new_query = residual + attn_out 378 | else: 379 | attn_out = rearrange(attn_out, '(b t) p d -> b (p t) d', p=p, t=t) 380 | new_query = residual + attn_out 381 | new_query = torch.cat((cls_token, new_query), 1) 382 | return new_query 383 | 384 | 385 | class MultiheadAttentionWithPreNorm(nn.Module): 386 | """Implements MultiheadAttention with residual connection. 387 | 388 | Args: 389 | embed_dims (int): The embedding dimension. 390 | num_heads (int): Parallel attention heads. 391 | attn_drop (float): A Dropout layer on attn_output_weights. 392 | Default: 0.0. 393 | proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. 394 | Default: 0.0. 395 | norm_layer (class): Class name for normalization layer. Defaults to 396 | nn.LayerNorm. 397 | layer_drop (obj:`ConfigDict`): The layer_drop used 398 | when adding the shortcut. 399 | batch_first (bool): When it is True, Key, Query and Value are shape of 400 | (batch, n, embed_dim), otherwise (n, batch, embed_dim). 401 | Default to False. 
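	Example (shape sketch; sizes assumed, and layer_drop is passed explicitly because
	__init__ pops entries from the dict it receives):
		>>> attn = MultiheadAttentionWithPreNorm(
		...     embed_dims=768, num_heads=12,
		...     layer_drop=dict(type=DropPath, dropout_p=0.))
		>>> tokens = torch.zeros(2, 197, 768)  # [batch, cls + 14 * 14 patches, dim]
		>>> attn(tokens).shape
		torch.Size([2, 197, 768])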
402 | """ 403 | 404 | def __init__(self, 405 | embed_dims, 406 | num_heads, 407 | attn_drop=0., 408 | proj_drop=0., 409 | norm_layer=nn.LayerNorm, 410 | layer_drop=dict(type=DropPath, dropout_p=0.), 411 | batch_first=False, 412 | **kwargs): 413 | super().__init__() 414 | self.embed_dims = embed_dims 415 | self.num_heads = num_heads 416 | #self.batch_first = batch_first 417 | 418 | self.norm = norm_layer(embed_dims) 419 | #self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop, 420 | # **kwargs) 421 | self.attn = Attention(embed_dims, num_heads, qkv_bias=True, attn_drop=attn_drop) # batch first 422 | 423 | self.proj_drop = nn.Dropout(proj_drop) 424 | dropout_p = layer_drop.pop('dropout_p') 425 | layer_drop= layer_drop.pop('type') 426 | self.layer_drop = layer_drop(dropout_p) if layer_drop else nn.Identity() 427 | 428 | def forward(self, 429 | query, 430 | key=None, 431 | value=None, 432 | residual=None, 433 | attn_mask=None, 434 | key_padding_mask=None, 435 | return_attention=False, 436 | **kwargs): 437 | residual = query 438 | 439 | query = self.norm(query) 440 | #if self.batch_first: 441 | # query = query.transpose(0, 1) 442 | #attn_out = self.attn( 443 | # query=query, 444 | # key=query, 445 | # value=query, 446 | # attn_mask=attn_mask, 447 | # key_padding_mask=key_padding_mask)[0] 448 | #attn_out = self.attn(query, query, query)[0] 449 | #if self.batch_first: 450 | # attn_out = attn_out.transpose(0, 1) 451 | attn_out, attn_weights = self.attn(query) 452 | if return_attention: 453 | return attn_weights 454 | 455 | new_query = residual + self.layer_drop(self.proj_drop(attn_out)) 456 | return new_query 457 | 458 | 459 | class FFNWithPreNorm(nn.Module): 460 | """Implements feed-forward networks (FFNs) with residual connection. 461 | 462 | Args: 463 | embed_dims (int): The feature dimension. Same as 464 | `MultiheadAttention`. Defaults: 256. 465 | hidden_channels (int): The hidden dimension of FFNs. 466 | Defaults: 1024. 467 | num_layers (int, optional): The number of fully-connected layers in 468 | FFNs. Default: 2. 469 | act_layer (dict, optional): The activation layer for FFNs. 470 | Default: nn.GELU 471 | norm_layer (class): Class name for normalization layer. Defaults to 472 | nn.LayerNorm. 473 | dropout_p (float, optional): Probability of an element to be 474 | zeroed in FFN. Default 0.0. 475 | layer_drop (obj:`ConfigDict`): The layer_drop used 476 | when adding the shortcut. 477 | """ 478 | 479 | def __init__(self, 480 | embed_dims=256, 481 | hidden_channels=1024, 482 | num_layers=2, 483 | act_layer=nn.GELU, 484 | norm_layer=nn.LayerNorm, 485 | dropout_p=0., 486 | layer_drop=None, 487 | **kwargs): 488 | super().__init__() 489 | assert num_layers >= 2, 'num_layers should be no less ' \ 490 | f'than 2. got {num_layers}.' 
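		# The layers built below form a pre-norm MLP: LayerNorm, then (num_layers - 1) blocks of
		# Linear(embed_dims -> hidden_channels) + activation + Dropout, followed by a final
		# Linear(hidden_channels -> embed_dims) + Dropout; forward() adds the result back to the
		# input through an optional DropPath shortcut.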
491 | self.embed_dims = embed_dims 492 | self.hidden_channels = hidden_channels 493 | self.num_layers = num_layers 494 | 495 | self.norm = norm_layer(embed_dims) 496 | layers = [] 497 | in_channels = embed_dims 498 | for _ in range(num_layers - 1): 499 | layers.append( 500 | nn.Sequential( 501 | nn.Linear(in_channels, hidden_channels), 502 | act_layer(), 503 | nn.Dropout(dropout_p))) 504 | in_channels = hidden_channels 505 | layers.append(nn.Linear(hidden_channels, embed_dims)) 506 | layers.append(nn.Dropout(dropout_p)) 507 | self.layers = nn.ModuleList(layers) 508 | 509 | if layer_drop: 510 | dropout_p = layer_drop.pop('dropout_p') 511 | layer_drop= layer_drop.pop('type') 512 | self.layer_drop = layer_drop(dropout_p) 513 | else: 514 | self.layer_drop = nn.Identity() 515 | 516 | def forward(self, x): 517 | residual = x 518 | 519 | x = self.norm(x) 520 | for layer in self.layers: 521 | x = layer(x) 522 | 523 | return residual + self.layer_drop(x) 524 | 525 | 526 | class TransformerContainer(nn.Module): 527 | 528 | def __init__(self, 529 | num_transformer_layers, 530 | embed_dims, 531 | num_heads, 532 | num_frames, 533 | hidden_channels, 534 | operator_order, 535 | drop_path_rate=0.1, 536 | norm_layer=nn.LayerNorm, 537 | act_layer=nn.GELU, 538 | num_layers=2): 539 | super().__init__() 540 | self.layers = nn.ModuleList([]) 541 | self.num_transformer_layers = num_transformer_layers 542 | 543 | dpr = np.linspace(0, drop_path_rate, num_transformer_layers) 544 | for i in range(num_transformer_layers): 545 | self.layers.append( 546 | BasicTransformerBlock( 547 | embed_dims=embed_dims, 548 | num_heads=num_heads, 549 | num_frames=num_frames, 550 | hidden_channels=hidden_channels, 551 | operator_order=operator_order, 552 | norm_layer=norm_layer, 553 | act_layer=act_layer, 554 | num_layers=num_layers, 555 | dpr=dpr[i])) 556 | 557 | def forward(self, x, return_attention=False): 558 | layer_idx = 0 559 | for layer in self.layers: 560 | if layer_idx >= self.num_transformer_layers-1 and return_attention: 561 | x = layer(x, return_attention=True) 562 | else: 563 | x = layer(x) 564 | layer_idx += 1 565 | return x 566 | 567 | 568 | class BasicTransformerBlock(nn.Module): 569 | 570 | def __init__(self, 571 | embed_dims, 572 | num_heads, 573 | num_frames, 574 | hidden_channels, 575 | operator_order, 576 | norm_layer=nn.LayerNorm, 577 | act_layer=nn.GELU, 578 | num_layers=2, 579 | dpr=0, 580 | ): 581 | 582 | super().__init__() 583 | self.attentions = nn.ModuleList([]) 584 | self.ffns = nn.ModuleList([]) 585 | 586 | for i, operator in enumerate(operator_order): 587 | if operator == 'self_attn': 588 | self.attentions.append( 589 | MultiheadAttentionWithPreNorm( 590 | embed_dims=embed_dims, 591 | num_heads=num_heads, 592 | batch_first=True, 593 | norm_layer=nn.LayerNorm, 594 | layer_drop=dict(type=DropPath, dropout_p=dpr))) 595 | elif operator == 'time_attn': 596 | self.attentions.append( 597 | DividedTemporalAttentionWithPreNorm( 598 | embed_dims=embed_dims, 599 | num_heads=num_heads, 600 | num_frames=num_frames, 601 | norm_layer=norm_layer, 602 | use_cls_token=(i==len(operator_order)-2), 603 | layer_drop=dict(type=DropPath, dropout_p=dpr))) 604 | elif operator == 'space_attn': 605 | self.attentions.append( 606 | DividedSpatialAttentionWithPreNorm( 607 | embed_dims=embed_dims, 608 | num_heads=num_heads, 609 | num_frames=num_frames, 610 | norm_layer=norm_layer, 611 | use_cls_token=(i==len(operator_order)-2), 612 | layer_drop=dict(type=DropPath, dropout_p=dpr))) 613 | elif operator == 'ffn': 614 | self.ffns.append( 
615 | FFNWithPreNorm( 616 | embed_dims=embed_dims, 617 | hidden_channels=hidden_channels, 618 | num_layers=num_layers, 619 | act_layer=act_layer, 620 | norm_layer=norm_layer, 621 | layer_drop=dict(type=DropPath, dropout_p=dpr))) 622 | else: 623 | raise TypeError(f'Unsupported operator type {operator}') 624 | 625 | def forward(self, x, return_attention=False): 626 | attention_idx = 0 627 | for layer in self.attentions: 628 | if attention_idx >= len(self.attentions)-1 and return_attention: 629 | x = layer(x, return_attention=True) 630 | return x 631 | else: 632 | x = layer(x) 633 | attention_idx += 1 634 | for layer in self.ffns: 635 | x = layer(x) 636 | return x 637 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import os.path as osp 4 | 5 | import numpy as np 6 | import torch 7 | import matplotlib.pyplot as plt 8 | import torch.distributed as dist 9 | from pytorch_lightning.utilities.distributed import rank_zero_only 10 | 11 | @rank_zero_only 12 | def print_on_rank_zero(content): 13 | if is_main_process(): 14 | print(content) 15 | 16 | def is_dist_avail_and_initialized(): 17 | if not dist.is_available(): 18 | return False 19 | if not dist.is_initialized(): 20 | return False 21 | return True 22 | 23 | def get_world_size(): 24 | if not is_dist_avail_and_initialized(): 25 | return 1 26 | return dist.get_world_size() 27 | 28 | def get_rank(): 29 | if not is_dist_avail_and_initialized(): 30 | return 0 31 | return dist.get_rank() 32 | 33 | def is_main_process(): 34 | return get_rank() == 0 35 | 36 | def timeit_wrapper(func, *args, **kwargs): 37 | start = time.perf_counter() 38 | func_return_val = func(*args, **kwargs) 39 | end = time.perf_counter() 40 | return func_return_val, float(f'{end - start:.4f}') 41 | 42 | def show_trainable_params(named_parameters): 43 | for name, param in named_parameters: 44 | print(name, param.size()) 45 | 46 | def build_param_groups(model): 47 | params_no_decay = [] 48 | params_has_decay = [] 49 | params_no_decay_name = [] 50 | params_decay_name = [] 51 | for name, param in model.named_parameters(): 52 | if not param.requires_grad: 53 | continue 54 | if len(param) == 1 or name.endswith('.bias'): 55 | params_no_decay.append(param) 56 | params_no_decay_name.append(name) 57 | else: 58 | params_has_decay.append(param) 59 | params_decay_name.append(name) 60 | 61 | param_groups = [ 62 | {'params': params_no_decay, 'weight_decay': 0}, 63 | {'params': params_has_decay}, 64 | ] 65 | print_on_rank_zero(f'params_no_decay_name: {params_no_decay_name} \n params_decay_name: {params_decay_name}') 66 | return param_groups 67 | 68 | 69 | def denormalize(data, mean, std): 70 | """Denormalize an image/video tensor with mean and standard deviation. 71 | 72 | Args: 73 | input: Image tensor of size : (H W C). 74 | mean: Mean for each channel. 75 | std: Standard deviations for each channel. 76 | 77 | Return: 78 | Denormalised tensor with same size as input : (H W C). 
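	Example (assumed values, matching the normalization used elsewhere in this repo):
		>>> img = torch.zeros(224, 224, 3)
		>>> denormalize(img, (0.45, 0.45, 0.45), (0.225, 0.225, 0.225))[0, 0]
		tensor([0.4500, 0.4500, 0.4500])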
79 | """ 80 | shape = data.shape 81 | 82 | if isinstance(mean, tuple): 83 | mean = np.array(mean, dtype=float) 84 | mean = torch.tensor(mean, device=data.device, dtype=data.dtype) 85 | 86 | if isinstance(std, tuple): 87 | std = np.array(std, dtype=float) 88 | std = torch.tensor(std, device=data.device, dtype=data.dtype) 89 | 90 | if mean.shape: 91 | mean = mean[None, :] 92 | if std.shape: 93 | std = std[None, :] 94 | 95 | out = (data.contiguous().view(-1, shape[-1]) * std) + mean 96 | 97 | return out.view(shape) 98 | 99 | 100 | def show_processed_image(imgs, save_dir, mean, std, index=0): 101 | """Plot the transformed images into figure and save to disk. 102 | 103 | Args: 104 | imgs: Image tensor of size : (T H W C). 105 | save_dir: The path to save the images. 106 | index: The index of current clips. 107 | """ 108 | os.makedirs(save_dir, exist_ok=True) 109 | if not isinstance(imgs[0], list): 110 | imgs = [imgs] 111 | 112 | num_show_clips = 5 113 | num_rows = len(imgs) 114 | num_cols = num_show_clips 115 | fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False) 116 | for row_idx, row in enumerate(imgs): 117 | row = row[:num_show_clips] 118 | for col_idx, img in enumerate(row): 119 | ax = axs[row_idx, col_idx] 120 | img = denormalize(img, mean, std).cpu().numpy() 121 | img = (img * 255).astype(np.uint8) 122 | #img = img.cpu().numpy().astype(np.uint8) 123 | ax.imshow(np.asarray(img)) 124 | ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) 125 | 126 | plt.tight_layout() 127 | filename = osp.join(save_dir, f'clip_transformed_b{index}.png') 128 | plt.savefig(filename) -------------------------------------------------------------------------------- /visualize_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Copy-paste from DINO library: 9 | https://github.com/facebookresearch/dino 10 | """ 11 | import os 12 | import argparse 13 | import cv2 14 | import random 15 | import colorsys 16 | import matplotlib 17 | import matplotlib.pyplot as plt 18 | import torch 19 | import torch.nn as nn 20 | import torchvision 21 | import numpy as np 22 | from PIL import Image 23 | from skimage.measure import find_contours 24 | from matplotlib.patches import Polygon 25 | from torch.utils.data import DataLoader 26 | 27 | import utils 28 | from video_transformer import TimeSformer 29 | import data_transform as T 30 | from dataset import DecordInit 31 | import weight_init 32 | 33 | matplotlib.use('Agg') 34 | 35 | company_colors = [ 36 | (0,160,215), # blue 37 | (220,55,60), # red 38 | (245,180,0), # yellow 39 | (10,120,190), # navy 40 | (40,150,100), # green 41 | (135,75,145), # purple 42 | ] 43 | company_colors = [(float(c[0]) / 255.0, float(c[1]) / 255.0, float(c[2]) / 255.0) for c in company_colors] 44 | 45 | def apply_mask2(image, mask, color, alpha=0.5): 46 | """Apply the given mask to the image. 47 | """ 48 | t= 0.2 49 | mi = np.min(mask) 50 | ma = np.max(mask) 51 | mask = (mask - mi) / (ma - mi) 52 | for c in range(3): 53 | image[:, :, c] = image[:, :, c] * (1 - alpha * np.sqrt(mask) * (mask>t))+ alpha * np.sqrt(mask) * (mask>t) * color[c] * 255 54 | return image 55 | 56 | def random_colors(N, bright=True): 57 | """ 58 | Generate random colors. 
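    Each color is an (r, g, b) tuple of floats in [0, 1]; hues are evenly spaced in HSV space
    and the list is shuffled so that adjacent attention heads get visually distinct colors.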
59 | """ 60 | brightness = 1.0 if bright else 0.7 61 | hsv = [(i / N, 1, brightness) for i in range(N)] 62 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) 63 | random.shuffle(colors) 64 | return colors 65 | 66 | def show_attn(img, attentions, w_featmap, h_featmap, frame_index, index=None): 67 | 68 | nh = attentions.shape[0] # number of head 69 | 70 | # we keep only the output patch attention 71 | attentions = attentions[:, 0, 1:].reshape(nh, -1) 72 | 73 | if args.threshold is not None: 74 | # we keep only a certain percentage of the mass 75 | val, idx = torch.sort(attentions) 76 | val /= torch.sum(val, dim=1, keepdim=True) 77 | cumval = torch.cumsum(val, dim=1) 78 | th_attn = cumval > (1 - args.threshold) 79 | idx2 = torch.argsort(idx) 80 | for head in range(nh): 81 | th_attn[head] = th_attn[head][idx2[head]] 82 | th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float() 83 | # interpolate 84 | th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].detach().cpu().numpy() 85 | 86 | attentions = attentions.reshape(nh, w_featmap, h_featmap) 87 | attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].detach().cpu().numpy() 88 | 89 | # save attentions heatmaps 90 | prefix = f'id{index}_' if index is not None else '' 91 | os.makedirs(args.output_dir, exist_ok=True) 92 | torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True, scale_each=True), os.path.join(args.output_dir, f"img{frame_index}" + ".png")) 93 | img = Image.open(os.path.join(args.output_dir, f"img{frame_index}" + ".png")) 94 | 95 | attns = Image.new('RGB', (attentions.shape[2] * nh, attentions.shape[1])) 96 | for j in range(nh): 97 | #fname = os.path.join(args.output_dir, prefix + "attn-head" + str(j) + ".png") 98 | fname = os.path.join(args.output_dir, "attn-head" + str(j) + ".png") 99 | plt.imsave(fname=fname, arr=attentions[j], format='png') 100 | attns.paste(Image.open(fname), (j * attentions.shape[2], 0)) 101 | 102 | return attentions, th_attn, img, attns 103 | 104 | def show_attn_color(image, attentions, th_attn, index=None, head=[0,1,2,3,4,5]): 105 | M = image.max() 106 | m = image.min() 107 | span = 64 108 | image = ((image - m) / (M-m)) * span + (256 - span) 109 | image = image.mean(axis=2) 110 | image = np.repeat(image[:, :, np.newaxis], 3, axis=2) 111 | 112 | for j in head: 113 | m = attentions[j] 114 | m *= th_attn[j] 115 | attentions[j] = m 116 | mask = np.stack([attentions[j] for j in head]) 117 | 118 | blur = False 119 | contour = False 120 | alpha = 1 121 | figsize = tuple([i / 100 for i in args.image_size]) 122 | fig = plt.figure(figsize=figsize, frameon=False, dpi=100) 123 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 124 | ax.set_axis_off() 125 | fig.add_axes(ax) 126 | ax = plt.gca() 127 | 128 | if len(mask.shape) == 3: 129 | N = mask.shape[0] 130 | else: 131 | N = 1 132 | mask = mask[None, :, :] 133 | 134 | # AJ 135 | for i in range(N): 136 | mask[i] = mask[i] * ( mask[i] == np.amax(mask, axis=0)) 137 | a = np.cumsum(mask, axis=0) 138 | for i in range(N): 139 | mask[i] = mask[i] * (mask[i] == a[i]) 140 | 141 | colors = company_colors[:N] 142 | 143 | # Show area outside image boundaries. 
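    # The frame is dimmed to a grayscale background (0.1x) and each selected head's thresholded
    # attention map is alpha-blended on top with its own color; the argmax/cumsum filtering above
    # keeps only the strongest head per pixel so the overlays do not pile up.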
144 | height, width = image.shape[:2] 145 | margin = 0 146 | ax.set_ylim(height + margin, -margin) 147 | ax.set_xlim(-margin, width + margin) 148 | ax.axis('off') 149 | masked_image = 0.1*image.astype(np.uint32).copy() 150 | for i in range(N): 151 | color = colors[i] 152 | _mask = mask[i] 153 | if blur: 154 | _mask = cv2.blur(_mask,(10,10)) 155 | # Mask 156 | masked_image = apply_mask2(masked_image, _mask, color, alpha) 157 | # Mask Polygon 158 | # Pad to ensure proper polygons for masks that touch image edges. 159 | if contour: 160 | padded_mask = np.zeros( 161 | (_mask.shape[0] + 2, _mask.shape[1] + 2))#, dtype=np.uint8) 162 | padded_mask[1:-1, 1:-1] = _mask 163 | contours = find_contours(padded_mask, 0.5) 164 | for verts in contours: 165 | # Subtract the padding and flip (y, x) to (x, y) 166 | verts = np.fliplr(verts) - 1 167 | p = Polygon(verts, facecolor="none", edgecolor=color) 168 | ax.add_patch(p) 169 | ax.imshow(masked_image.astype(np.uint8), aspect='auto') 170 | ax.axis('image') 171 | #fname = os.path.join(output_dir, 'bnw-{:04d}'.format(imid)) 172 | prefix = f'id{index}_' if index is not None else '' 173 | fname = os.path.join(args.output_dir, "attn_color.png") 174 | fig.savefig(fname) 175 | attn_color = Image.open(fname) 176 | 177 | return attn_color 178 | 179 | if __name__ == '__main__': 180 | parser = argparse.ArgumentParser('Visualize Self-Attention maps') 181 | parser.add_argument('--arch', default='timesformer', type=str, choices=['timesformer'], help='Architecture.') 182 | parser.add_argument('--pretrained_weights', default='', type=str, help="""Path to pretrained 183 | weights to evaluate. Set to `download` to automatically load the pretrained DINO from url. 184 | Otherwise the model is randomly initialized""") 185 | parser.add_argument('--output_dir', default='./attention_map', help='Path where to save visualizations.') 186 | parser.add_argument("--threshold", type=float, default=0.6, help="""We visualize masks 187 | obtained by thresholding the self-attention maps to keep xx% of the mass.""") 188 | parser.add_argument("--patch_size", type=int, default=16, help="""patch size.""") 189 | parser.add_argument("--image_size", default=(224, 224), type=int, nargs="+", help="Resize image.") 190 | args = parser.parse_args() 191 | 192 | #utils.fix_random_seeds(0) 193 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 194 | 195 | # build model 196 | num_frames = 8 197 | frame_interval = 32 198 | num_class = 400 199 | arch = args.arch # turn to vivit for initializing vivit model 200 | if arch == 'timesformer': 201 | pretrain_pth = args.pretrained_weights #'./timesformer_k400.pth' 202 | model = TimeSformer(num_frames=num_frames, 203 | img_size=args.image_size, 204 | patch_size=16, 205 | embed_dims=768, 206 | in_channels=3, 207 | attention_type='divided_space_time', 208 | return_cls_token=True) 209 | else: 210 | raise TypeError(f'not supported arch type {arch}, chosen in (timesformer, vivit)') 211 | 212 | msg_trans = weight_init.init_from_kinetics_pretrain_(model, pretrain_pth, init_module='transformer') 213 | model.eval() 214 | model = model.to(device) 215 | print(f'load model finished, the missing key of transformer is:{msg_trans[0]}, unexpect_key is:{msg_trans[1]}') 216 | 217 | # build data 218 | video_path = './demo/YABnJL_bDzw.mp4' 219 | mean, std = (0.45, 0.45, 0.45), (0.225, 0.225, 0.225) 220 | data_transform = T.Compose([ 221 | T.Resize(scale_range=(-1, 256)), 222 | T.CenterCrop(args.image_size), 223 | T.ToTensor(), 224 | T.Normalize(mean, 
std) 225 | ]) 226 | temporal_sample = T.TemporalRandomCrop(num_frames*frame_interval) 227 | 228 | video_decoder = DecordInit() 229 | v_reader = video_decoder(video_path) 230 | total_frames = len(v_reader) 231 | start_frame_ind, end_frame_ind = temporal_sample(total_frames) 232 | if end_frame_ind-start_frame_ind < num_frames: 233 | raise ValueError(f'the total frames of the video {video_path} is less than {num_frames}') 234 | frame_indice = np.linspace(0, end_frame_ind-start_frame_ind-1, num_frames, dtype=int) 235 | video = v_reader.get_batch(frame_indice).asnumpy() 236 | del v_reader 237 | 238 | video = torch.from_numpy(video).permute(0,3,1,2) # Video transform: T C H W 239 | data_transform.randomize_parameters() 240 | video = data_transform(video) 241 | video = video.to(device) 242 | 243 | # extract the attention maps 244 | w_featmap = video.shape[-2] // args.patch_size 245 | h_featmap = video.shape[-1] // args.patch_size 246 | attentions = model.get_last_selfattention(video.unsqueeze(0).to(device)) # 247 | print(attentions.shape) # [8 12 197 197] 248 | for i,(frame, attention) in enumerate(zip(video, attentions)): 249 | # make the video frame divisible by the patch size 250 | attentions, th_attn, pic_i, pic_attn = show_attn(frame, attention, w_featmap, h_featmap, frame_index=i) 251 | pic_attn_color = show_attn_color(frame.permute(1, 2, 0).cpu().numpy(), attentions, th_attn) 252 | final_pic = Image.new('RGB', (pic_i.size[1] * 2 + pic_attn.size[0], pic_i.size[1])) 253 | final_pic.paste(pic_i, (0, 0)) 254 | final_pic.paste(pic_attn_color, (pic_i.size[1], 0)) 255 | final_pic.paste(pic_attn, (pic_i.size[1] * 2, 0)) 256 | final_pic.save(os.path.join(args.output_dir, f"attn_img{i}.png")) -------------------------------------------------------------------------------- /weight_init.py: -------------------------------------------------------------------------------- 1 | import math 2 | import re 3 | import warnings 4 | 5 | from einops import repeat 6 | import torch 7 | import torch.nn as nn 8 | 9 | from utils import print_on_rank_zero 10 | 11 | 12 | def show_state_dict(state_dict): 13 | for name, value in state_dict.items(): 14 | print(name) 15 | 16 | 17 | def replace_state_dict(state_dict): 18 | for old_key in list(state_dict.keys()): 19 | if old_key.startswith('model'): 20 | new_key = old_key[6:] # skip 'model.' 21 | if 'in_proj' in new_key: 22 | new_key = new_key.replace('in_proj_', 'qkv.') #in_proj_weight -> qkv.weight 23 | elif 'out_proj' in new_key: 24 | new_key = new_key.replace('out_proj', 'proj') 25 | state_dict[new_key] = state_dict.pop(old_key) 26 | else: # cls_head 27 | new_key = old_key[9:] 28 | state_dict[new_key] = state_dict.pop(old_key) 29 | 30 | 31 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 32 | def norm_cdf(x): 33 | # Computes standard normal cumulative distribution function 34 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 35 | 36 | if (mean < a - 2 * std) or (mean > b + 2 * std): 37 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 38 | "The distribution of values may be incorrect.", 39 | stacklevel=2) 40 | 41 | with torch.no_grad(): 42 | # Values are generated by using a truncated uniform distribution and 43 | # then using the inverse CDF for the normal distribution. 44 | # Get upper and lower cdf values 45 | l = norm_cdf((a - mean) / std) 46 | u = norm_cdf((b - mean) / std) 47 | 48 | # Uniformly fill tensor with values from [l, u], then translate to 49 | # [2l-1, 2u-1]. 
50 | tensor.uniform_(2 * l - 1, 2 * u - 1) 51 | 52 | # Use inverse cdf transform for normal distribution to get truncated 53 | # standard normal 54 | tensor.erfinv_() 55 | 56 | # Transform to proper mean, std 57 | tensor.mul_(std * math.sqrt(2.)) 58 | tensor.add_(mean) 59 | 60 | # Clamp to ensure it's in the proper range 61 | tensor.clamp_(min=a, max=b) 62 | return tensor 63 | 64 | 65 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 66 | # type: (Tensor, float, float, float, float) -> Tensor 67 | r"""Fills the input Tensor with values drawn from a truncated 68 | normal distribution. The values are effectively drawn from the 69 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 70 | with values outside :math:`[a, b]` redrawn until they are within 71 | the bounds. The method used for generating the random values works 72 | best when :math:`a \leq \text{mean} \leq b`. 73 | Args: 74 | tensor: an n-dimensional `torch.Tensor` 75 | mean: the mean of the normal distribution 76 | std: the standard deviation of the normal distribution 77 | a: the minimum cutoff value 78 | b: the maximum cutoff value 79 | Examples: 80 | >>> w = torch.empty(3, 5) 81 | >>> nn.init.trunc_normal_(w) 82 | """ 83 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 84 | 85 | 86 | @torch.no_grad() 87 | def constant_init_(tensor, constant_value=0): 88 | nn.init.constant_(tensor, constant_value) 89 | 90 | 91 | @torch.no_grad() 92 | def kaiming_init_(tensor, 93 | a=0, 94 | mode='fan_out', 95 | nonlinearity='relu', 96 | distribution='normal'): 97 | assert distribution in ['uniform', 'normal'] 98 | if distribution == 'uniform': 99 | nn.init.kaiming_uniform_( 100 | tensor, a=a, mode=mode, nonlinearity=nonlinearity) 101 | else: 102 | nn.init.kaiming_normal_( 103 | tensor, a=a, mode=mode, nonlinearity=nonlinearity) 104 | 105 | 106 | @torch.no_grad() 107 | def init_from_vit_pretrain_(module, 108 | pretrained, 109 | conv_type, 110 | attention_type, 111 | copy_strategy, 112 | extend_strategy='temporal_avg', 113 | tube_size=2, 114 | num_time_transformer_layers=4): 115 | 116 | if isinstance(pretrained, str): 117 | if torch.cuda.is_available(): 118 | state_dict = torch.load(pretrained) 119 | else: 120 | state_dict = torch.load(pretrained, map_location=torch.device('cpu')) 121 | 122 | if 'state_dict' in state_dict: 123 | state_dict = state_dict['state_dict'] 124 | 125 | old_state_dict_keys = list(state_dict.keys()) 126 | for old_key in old_state_dict_keys: 127 | # extend the Conv2d params to Conv3d 128 | if conv_type == 'Conv3d': 129 | if 'patch_embed.projection.weight' in old_key: 130 | weight = state_dict[old_key] 131 | new_weight = repeat(weight, 'd c h w -> d c t h w', t=tube_size) 132 | if extend_strategy == 'temporal_avg': 133 | new_weight = new_weight / tube_size 134 | elif extend_strategy == 'center_frame': 135 | new_weight.zero_() 136 | new_weight[:,:,tube_size//2,:,:] = weight 137 | state_dict[old_key] = new_weight 138 | continue 139 | 140 | # modify the key names of norm layers 141 | if attention_type == 'fact_encoder': 142 | new_key = old_key.replace('transformer_layers.layers', 143 | 'transformer_layers.0.layers') 144 | else: 145 | new_key = old_key 146 | 147 | if 'in_proj' in new_key: 148 | new_key = new_key.replace('in_proj_', 'qkv.') #in_proj_weight -> qkv.weight 149 | elif 'out_proj' in new_key: 150 | new_key = new_key.replace('out_proj', 'proj') 151 | 152 | if 'norms' in new_key: 153 | new_key = new_key.replace('norms.0', 'attentions.0.norm') 154 | new_key = new_key.replace('norms.1', 
'ffns.0.norm') 155 | 156 | state_dict[new_key] = state_dict.pop(old_key) 157 | 158 | old_state_dict_keys = list(state_dict.keys()) 159 | for old_key in old_state_dict_keys: 160 | # copy the parameters of space attention to time attention 161 | if attention_type == 'divided_space_time': 162 | if 'attentions.0' in old_key: 163 | new_key = old_key.replace('attentions.0', 164 | 'attentions.1') 165 | if copy_strategy == 'repeat': 166 | state_dict[new_key] = state_dict[old_key].clone() 167 | elif copy_strategy == 'set_zero': 168 | state_dict[new_key] = state_dict[old_key].clone().zero_() 169 | # copy the part of parameters of space attention to time attention 170 | elif attention_type == 'fact_encoder': 171 | pattern = re.compile(r'(?<=layers.)\d+') 172 | matchObj = pattern.findall(old_key) 173 | if len(matchObj) > 1 and int(matchObj[1]) < num_time_transformer_layers: 174 | new_key = old_key.replace('transformer_layers.0.layers', 175 | 'transformer_layers.1.layers') 176 | if copy_strategy == 'repeat': 177 | state_dict[new_key] = state_dict[old_key].clone() 178 | elif copy_strategy == 'set_zero': 179 | state_dict[new_key] = state_dict[old_key].clone().zero_() 180 | 181 | missing_keys,unexpected_keys = module.load_state_dict(state_dict, strict=False) 182 | #print(f'missing_keys:{missing_keys}\n unexpected_keys:{unexpected_keys}') 183 | print_on_rank_zero(f'missing_keys:{missing_keys}\n ' 184 | f'unexpected_keys:{unexpected_keys}') 185 | 186 | 187 | @torch.no_grad() 188 | def init_from_mae_pretrain_(module, 189 | pretrained, 190 | conv_type, 191 | attention_type, 192 | copy_strategy, 193 | extend_strategy='temporal_avg', 194 | tube_size=2, 195 | num_time_transformer_layers=4): 196 | 197 | if isinstance(pretrained, str): 198 | if torch.cuda.is_available(): 199 | state_dict = torch.load(pretrained) 200 | else: 201 | state_dict = torch.load(pretrained, map_location=torch.device('cpu')) 202 | 203 | if 'model' in state_dict: 204 | state_dict = state_dict['model'] 205 | 206 | # adjust to our module 207 | old_state_dict_keys = list(state_dict.keys()) 208 | for old_key in old_state_dict_keys: 209 | if 'decoder' in old_key: 210 | state_dict.pop(old_key) 211 | continue 212 | 213 | # extend the Conv2d params to Conv3d 214 | if 'encoder.patch_embed.proj' in old_key: 215 | new_key = old_key.replace('encoder.patch_embed.proj', 216 | 'patch_embed.projection') 217 | if conv_type == 'Conv3d' and 'weight' in old_key: 218 | weight = state_dict[old_key] 219 | new_weight = repeat(weight, 'd c h w -> d c t h w', t=tube_size) 220 | if extend_strategy == 'temporal_avg': 221 | new_weight = new_weight / tube_size 222 | elif extend_strategy == 'center_frame': 223 | new_weight.zero_() 224 | new_weight[:,:,tube_size//2,:,:] = weight 225 | state_dict.pop(old_key) 226 | state_dict[new_key] = new_weight 227 | else: 228 | state_dict[new_key] = state_dict.pop(old_key) 229 | continue 230 | 231 | # modify the key names of norm layers 232 | if attention_type == 'fact_encoder': 233 | new_key = old_key.replace('encoder.blocks', 234 | 'transformer_layers.0.layers') 235 | else: 236 | new_key = old_key.replace('encoder.blocks', 237 | 'transformer_layers.layers') 238 | 239 | if 'norm' in new_key: 240 | new_key = new_key.replace('norm1', 'attentions.0.norm') 241 | new_key = new_key.replace('norm2', 'ffns.0.norm') 242 | elif 'attn' in new_key: 243 | #new_key = new_key.replace('attn.qkv.weight', 244 | # 'attentions.0.attn.in_proj_weight') 245 | #new_key = new_key.replace('attn.proj', 246 | # 'attentions.0.attn.out_proj') 247 | if 'q_bias' in 
new_key: 248 | pattern = re.compile(r'(?<=blocks.)\d+') 249 | matchObj = pattern.findall(old_key) 250 | block_id = int(matchObj[0]) 251 | q_bias = state_dict[f'encoder.blocks.{block_id}.attn.q_bias'] 252 | v_bias = state_dict[f'encoder.blocks.{block_id}.attn.v_bias'] 253 | weight = torch.cat((q_bias, 254 | torch.zeros_like(q_bias, requires_grad=False), 255 | v_bias)) 256 | new_key = new_key.replace('attn.q_bias', 257 | #'attentions.0.attn.in_proj_bias') 258 | 'attentions.0.attn.qkv.bias') 259 | state_dict.pop(f'encoder.blocks.{block_id}.attn.q_bias') 260 | state_dict.pop(f'encoder.blocks.{block_id}.attn.v_bias') 261 | state_dict[new_key] = weight 262 | continue 263 | elif 'v_bias' in new_key: 264 | continue 265 | elif 'mlp' in new_key: 266 | new_key = new_key.replace('mlp.fc1', 'ffns.0.layers.0.0') 267 | new_key = new_key.replace('mlp.fc2', 'ffns.0.layers.1') 268 | 269 | if 'encoder.norm' in old_key: 270 | new_key = old_key.replace('encoder.norm', 271 | 'norm') 272 | 273 | state_dict[new_key] = state_dict.pop(old_key) 274 | 275 | # copy to new layer 276 | old_state_dict_keys = list(state_dict.keys()) 277 | for old_key in old_state_dict_keys: 278 | # copy the parameters of space attention to time attention 279 | if attention_type == 'divided_space_time': 280 | if 'attentions.0' in old_key: 281 | new_key = old_key.replace('attentions.0', 282 | 'attentions.1') 283 | if copy_strategy == 'repeat': 284 | state_dict[new_key] = state_dict[old_key].clone() 285 | elif copy_strategy == 'set_zero': 286 | state_dict[new_key] = state_dict[old_key].clone().zero_() 287 | # copy the part of parameters of space attention to time attention 288 | elif attention_type == 'fact_encoder': 289 | pattern = re.compile(r'(?<=layers.)\d+') 290 | matchObj = pattern.findall(old_key) 291 | if len(matchObj) > 1 and int(matchObj[1]) < num_time_transformer_layers: 292 | new_key = old_key.replace('transformer_layers.0.layers', 293 | 'transformer_layers.1.layers') 294 | if copy_strategy == 'repeat': 295 | state_dict[new_key] = state_dict[old_key].clone() 296 | elif copy_strategy == 'set_zero': 297 | state_dict[new_key] = state_dict[old_key].clone().zero_() 298 | 299 | missing_keys, unexpected_keys = module.load_state_dict(state_dict, strict=False) 300 | print_on_rank_zero(f'missing_keys:{missing_keys}\n ' 301 | f'unexpected_keys:{unexpected_keys}') 302 | 303 | 304 | def init_from_kinetics_pretrain_(module, pretrain_pth, init_module='transformer'): 305 | # init_module is accepted (and currently unused) so existing call sites such as 306 | # visualize_attention.py, which pass init_module='transformer', keep working. 307 | if torch.cuda.is_available(): 308 | state_dict = torch.load(pretrain_pth) 309 | else: 310 | state_dict = torch.load(pretrain_pth, map_location=torch.device('cpu')) 311 | if 'state_dict' in state_dict: 312 | state_dict = state_dict['state_dict'] 313 | 314 | replace_state_dict(state_dict) 315 | msg = module.load_state_dict(state_dict, strict=False) 316 | print_on_rank_zero(msg) 317 | return msg # expose (missing_keys, unexpected_keys) so callers can inspect them --------------------------------------------------------------------------------
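The Conv2D-to-Conv3D weight inflation inside `init_from_vit_pretrain_` and `init_from_mae_pretrain_` (the `extend_strategy` branch) is easy to miss among the key-renaming loops. The sketch below isolates just that step; it is not part of the repo, and the helper name `inflate_conv2d_weight` is invented purely for illustration, assuming only `torch` and `einops` as already used by `weight_init.py`.

```python
# Minimal sketch (not part of the repo): Conv2D -> Conv3D kernel inflation,
# mirroring the `extend_strategy` options used in weight_init.py.
import torch
from einops import repeat

def inflate_conv2d_weight(weight_2d, tube_size=2, extend_strategy='temporal_avg'):
    # weight_2d: (out_ch, in_ch, h, w) -> weight_3d: (out_ch, in_ch, tube_size, h, w)
    weight_3d = repeat(weight_2d, 'd c h w -> d c t h w', t=tube_size).clone()
    if extend_strategy == 'temporal_avg':
        # replicate along time and rescale, so a temporally constant clip produces
        # the same activations as the original 2D filter did on a single frame
        weight_3d = weight_3d / tube_size
    elif extend_strategy == 'center_frame':
        # zero everywhere except the centre temporal position, which keeps the 2D kernel
        weight_3d.zero_()
        weight_3d[:, :, tube_size // 2, :, :] = weight_2d
    return weight_3d

w2d = torch.randn(768, 3, 16, 16)        # a ViT-B/16 patch-embedding kernel
print(inflate_conv2d_weight(w2d).shape)  # torch.Size([768, 3, 2, 16, 16])
```

Both strategies leave the spatial pattern of each filter untouched; they differ only in how a static clip is treated at initialization, as noted in the comments above.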