├── README.md
├── data_trainer.py
├── data_transform.py
├── dataset.py
├── demo
│   ├── 9r8wpMS2iEk_000048_000058.mp4
│   ├── YABnJL_bDzw.mp4
│   ├── kinetics400_train_list_videos_25fps.txt
│   ├── kinetics400_val_list_videos_25fps.txt
│   └── log_arch_timesformer_lr5e-3_bs8_nw4_open.txt
├── k400_classmap.json
├── k600_classmap.json
├── mask_generator.py
├── mixup.py
├── model_pretrain.py
├── model_trainer.py
├── notebook
│   └── VideoTransformer_demo.ipynb
├── optimizer.py
├── requirements.txt
├── transformer.py
├── utils.py
├── video_transformer.py
├── visualize_attention.py
└── weight_init.py
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch implementation of Video Transformer Benchmarks
2 | This repository is mainly built upon [Pytorch](https://pytorch.org/) and [Pytorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/). We wish to maintain a collection of scalable video transformer benchmarks and to discuss recipes for training large video transformer models.
3 |
4 | We currently implement [TimeSformer](https://arxiv.org/abs/2102.05095), [ViViT](https://arxiv.org/abs/2103.15691) and [MaskFeat](https://arxiv.org/abs/2112.09133). We have pre-trained `TimeSformer-B`, `ViViT-B` and `MaskFeat` on [Kinetics400/600](https://deepmind.com/research/open-source/kinetics), but we still cannot guarantee the performance reported in the papers. However, we have found some relevant hyper-parameters which may help reach the target performance.
5 |
6 | ## Update
7 | 1. We have fixed several known issues, and the scripts can now pretrain `MViT-B` with `MaskFeat` or finetune `MViT-B`/`TimeSformer-B`/`ViViT-B` on K400.
8 | 2. We have reimplemented the HOG extraction and HOG prediction of [MaskFeat](https://arxiv.org/abs/2112.09133), which makes pre-training more efficient.
9 | 3. Note that anyone who wants to train `TimeSformer-B` or `ViViT-B` with the current repo needs to carefully adjust the learning rate and weight decay for better performance. For example, you can choose 0.005 for the peak learning rate and 0.0001 for the weight decay by default.
10 |
11 | ## Table of Contents
12 | 1. [Difference](#difference)
13 | 2. [TODO](#todo)
14 | 3. [Setup](#setup)
15 | 4. [Usage](#usage)
16 | 5. [Result](#result)
17 | 6. [Acknowledge](#acknowledge)
18 | 7. [Contribution](#contribution)
19 |
20 | ## Difference
21 | In order to share the basic divided spatial-temporal attention module across different video transformers, we make the following changes.
22 |
23 | ### 1. Position embedding
24 |
25 | We split the `position embedding` *R(nt\*h\*w×d)* described in the [ViViT](https://arxiv.org/abs/2103.15691) paper into *R(nh\*w×d)*
26 | and *R(nt×d)*, staying consistent with [TimeSformer](https://arxiv.org/abs/2102.05095).
27 |
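Below is a minimal sketch of this factorization (the names and sizes are illustrative assumptions, not the repo's actual modules): two learnable tables, one over spatial positions and one over temporal positions, replace the single joint table over all `nt*h*w` tokens.

```python
import torch
import torch.nn as nn

# Illustrative sketch: factorized position embeddings (assumed sizes for a
# 224x224 input with patch size 16 and 8 frames).
num_patches, num_frames, dim = 14 * 14, 8, 768

pos_embed_spatial = nn.Parameter(torch.zeros(1, num_patches + 1, dim))  # R((h*w + cls) x d)
pos_embed_temporal = nn.Parameter(torch.zeros(1, num_frames, dim))      # R(t x d)

# Tokens grouped per frame, shape (B*t, h*w + 1, d), receive the spatial table;
# tokens regrouped per spatial location, shape (B*h*w, t, d), receive the temporal one.
```
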
28 | ### 2. Class token
29 |
30 | To make it explicit whether the `class_token` participates in the module's forward computation, we only compute the interaction between the `class_token` and the `query` when the current layer is the last layer (except the `FFN`) of each transformer block, as sketched below.
31 |
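As a rough illustration (with hypothetical module names, and with the space/time reshaping of divided attention omitted), the sketch below shows the rule for one block: the `class_token` skips the temporal attention and only joins the token sequence for the last attention layer before the `FFN`.

```python
import torch
import torch.nn as nn

class DividedBlockSketch(nn.Module):
    """Rough sketch, not the repo's actual block: the class token only interacts
    with the queries in the last attention layer of the block."""
    def __init__(self, dim=768, heads=12):
        super().__init__()
        self.temporal_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.spatial_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, cls_tok, patch_tok):  # cls_tok: (B, 1, D), patch_tok: (B, N, D)
        # temporal attention over patch tokens only; the class token is left untouched
        patch_tok = patch_tok + self.temporal_attn(patch_tok, patch_tok, patch_tok)[0]
        # the class token joins the sequence only for this last attention layer
        tokens = torch.cat([cls_tok, patch_tok], dim=1)
        tokens = tokens + self.spatial_attn(tokens, tokens, tokens)[0]
        tokens = tokens + self.ffn(tokens)  # the FFN still runs over all tokens
        return tokens[:, :1], tokens[:, 1:]
```
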
32 | ### 3. Initialize from the pre-trained model
33 |
34 | * Tokenization: the token embedding filter can be either `Conv2D` or `Conv3D`. When initializing the `Conv3D` filters from `Conv2D` weights, one can either replicate the 2D weights along the temporal dimension and average them, or initialize all temporal positions with zeros except the center position `t/2` (see the sketch after this list).
35 | * Temporal `MSA` module weights: one can either copy the weights from the spatial `MSA` module or initialize all weights with zeros.
36 | * Initialize from the `MAE` pre-trained model provided by [ZhiLiang](https://github.com/pengzhiliang/MAE-pytorch), where the `class_token` that does not appear in the `MAE` pre-trained model is initialized from a truncated normal distribution.
37 | * The `ViT` pre-trained model used for initialization can be found [here](https://drive.google.com/file/d/1QjGpbR8K4Cf4TJaDc60liVhBvPtrc2v4/view?usp=sharing).
38 |
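For the tokenization item above, here is a small sketch (assumed shapes and a hypothetical helper name, not the repo's code) of the two `Conv2D`-to-`Conv3D` inflation schemes: averaging the replicated weights over time, or keeping the 2D weights only at the center temporal position.

```python
import torch

def inflate_conv2d_to_conv3d(w2d, t, mode='center'):
    """Hypothetical helper: inflate a Conv2D weight (D, C, h, w) to Conv3D (D, C, t, h, w)."""
    if mode == 'average':
        # replicate along time and divide by t, so a static clip gives the same response
        return w2d.unsqueeze(2).repeat(1, 1, t, 1, 1) / t
    # 'center': zeros everywhere except the center temporal position t // 2
    w3d = torch.zeros(w2d.shape[0], w2d.shape[1], t, w2d.shape[2], w2d.shape[3])
    w3d[:, :, t // 2] = w2d
    return w3d

w2d = torch.randn(768, 3, 16, 16)                # assumed ViT-B patch-embedding weight
print(inflate_conv2d_to_conv3d(w2d, t=2).shape)  # torch.Size([768, 3, 2, 16, 16])
```
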
39 | ## TODO
40 | - [√] add pre-trained weights for more `TimeSformer` and `ViViT` variants.
41 |   - A larger version and other operation types.
42 | - [√] add `linear probe` and `finetune recipe`.
43 |   - Make it possible to transfer the pre-trained model to downstream tasks.
44 | - [ ] add more scalable Video Transformer benchmarks.
45 |   - We will mainly focus on data-efficient models.
46 | - [ ] add more robust objective functions.
47 |   - Pre-train the model with the dominant self-supervised methods, e.g. [Masked Image Modeling](https://arxiv.org/abs/2111.06377).
48 |
49 | ## Setup
50 | ```shell
51 | pip install -r requirements.txt
52 | ```
53 |
54 | ## Usage
55 | ### Training
56 | ```shell
57 | # path to Kinetics400 train set and val set
58 | TRAIN_DATA_PATH='/path/to/Kinetics400/train_list.txt'
59 | VAL_DATA_PATH='/path/to/Kinetics400/val_list.txt'
60 | # path to root directory
61 | ROOT_DIR='/path/to/work_space'
62 | # path to pretrain weights
63 | PRETRAIN_WEIGHTS='/path/to/weights'
64 |
65 | # pretrain mvit using maskfeat
66 | python model_pretrain.py \
67 | -lr 8e-4 -epoch 300 -batch_size 16 -num_workers 8 -frame_interval 4 -num_frames 16 -num_class 400 \
68 | -root_dir $ROOT_DIR -train_data_path $TRAIN_DATA_PATH
69 |
70 | # finetune mvit with maskfeat pretrain weights
71 | python model_pretrain.py \
72 | -lr 0.005 -epoch 200 -batch_size 8 -num_workers 4 -num_frames 16 -frame_interval 4 -num_class 400 \
73 | -arch 'mvit' -optim_type 'adamw' -lr_schedule 'cosine' -objective 'supervised' -mixup True \
74 | -auto_augment 'rand_aug' -root_dir $ROOT_DIR -train_data_path $TRAIN_DATA_PATH \
75 | -val_data_path $VAL_DATA_PATH -pretrain_pth $PRETRAIN_WEIGHTS
76 |
77 | # finetune timesformer with imagenet pretrain weights
78 | python model_pretrain.py \
79 | -lr 0.005 -epoch 30 -batch_size 8 -num_workers 4 -num_frames 8 -frame_interval 32 -num_class 400 \
80 | -arch 'timesformer' -attention_type 'divided_space_time' -optim_type 'sgd' -lr_schedule 'cosine' \
81 | -objective 'supervised' -root_dir $ROOT_DIR -train_data_path $TRAIN_DATA_PATH \
82 | -val_data_path $VAL_DATA_PATH -pretrain_pth $PRETRAIN_WEIGHTS -weights_from 'imagenet'
83 |
84 | # finetune vivit with imagenet pretrain weights
85 | python model_pretrain.py \
86 | -lr 0.005 -epoch 30 -batch_size 8 -num_workers 4 -num_frames 16 -frame_interval 16 -num_class 400 \
87 | -arch 'vivit' -attention_type 'fact_encoder' -optim_type 'sgd' -lr_schedule 'cosine' \
88 | -objective 'supervised' -root_dir $ROOT_DIR -train_data_path $TRAIN_DATA_PATH \
89 | -val_data_path $VAL_DATA_PATH -pretrain_pth $PRETRAIN_WEIGHTS -weights_from 'imagenet'
90 |
91 | ```
92 | The minimal folder structure will look like the example below.
93 | ```
94 | root_dir
95 | ├── results
96 | │ ├── experiment_tag
97 | │ │ ├── ckpt
98 | │ │ ├── log
99 | ```
100 |
101 | ## Result
102 | ### Kinetics-400/600
103 |
104 | #### 1. Model Zoo
105 |
106 | | name | weights from | dataset | epochs | num frames | spatial crop | top1_acc | top5_acc | weight | log |
107 | |:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
108 | | TimeSformer-B | ImageNet-21K | K600 | 15e | 8 | 224 | 78.4 | 93.6 | [Google drive](https://drive.google.com/file/d/1-BSNROh35fiOIBcmtFNgWHEY_JC5UNDx/view?usp=sharing) or [BaiduYun](https://pan.baidu.com/s/1I5L41ZFHHSvFJttYt8F0Og)(code: yr4j) | [log](demo/log_arch_timesformer_lr5e-3_bs8_nw4_open.txt) |
109 | | ViViT-B | ImageNet-21K | K400 | 30e | 16 | 224 | 75.2 | 91.5 | [Google drive](https://drive.google.com/file/d/1-JVhSN3QHKUOLkXLWXWn5drdvKn0gPll/view?usp=sharing) | |
110 | | MaskFeat | from scratch | K400 | 100e | 16 | 224 | | | [Google drive](https://drive.google.com/file/d/1h3Q-267qV9kIcTT9Sct-zQzVvXljhyWW/view?usp=sharing) | |
111 |
112 | #### 1.1 Visualize
113 |
114 | For each column, we show the masked input (left), the HOG predictions (middle) and the original video frame (right).
115 |
116 |
117 |
118 |
119 | Here, we show the extracted attention map of a random frame sampled from the demo video.
120 |
121 |
122 |
123 |
124 |
125 |
126 | #### 2. Training Recipe (ablation study)
127 | #### 2.1 Acc
128 |
129 | | operation | top1_acc | top5_acc | top1_acc (three crop) |
130 | |:----|:----:|:----:|:----:|
131 | | base | 68.2 | 87.6 | - |
132 | | + `frame_interval` 4 -> 16 (span more time) | 72.9(+4.7) | 91.0(+3.4) | - |
133 | | + RandomCrop, flip (overcome overfit) | 75.7(+2.8) | 92.5(+1.5) | - |
134 | | + `batch size` 16 -> 8 (more iterations) | 75.8(+0.1) | 92.4(-0.1) | - |
135 | | + `frame_interval` 16 -> 24 (span more time) | 77.7(+1.9) | 93.3(+0.9) | 78.4 |
136 | | + `frame_interval` 24 -> 32 (span more time) | 78.4(+0.7) | 94.0(+0.7) | 79.1 |
137 |
138 | Tip: `frame_interval` and data augmentation matter a lot for the validation accuracy.
139 |
140 |
141 |
142 | #### 2.2 Time
143 |
144 | | operation | epoch_time |
145 | |:----|:----:|
146 | | base (start with DDP) | 9h+ |
147 | | + `speed up training recipes` | 1h+ |
148 | | + switch from `get_batch first` to `sample_Indice first` | 0.5h |
149 | | + `batch size` 16 -> 8 | 33.32m |
150 | | + `num_workers` 8 -> 4 | 35.52m |
151 | | + `frame_interval` 16 -> 24 | 44.35m |
152 |
153 | Tip: increasing `frame_interval` noticeably increases the epoch time.
154 |
155 | 1. `speed up training recipes`:
156 | * More GPU devices.
157 | * `pin_memory=True`.
158 | * Avoid CPU->GPU transfers during training (such as calling `.item()`, `.numpy()`, `.cpu()` on tensors, or logging to disk every step); see the sketch below.
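
A hedged sketch of the last point (an illustrative loop, not the repo's trainer): accumulate the loss as a tensor on the training device and read it back only once per epoch, since `.item()`/`.cpu()`/`.numpy()` on every step forces a device synchronization.

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
running_loss = torch.zeros((), device=device)
for step in range(100):                   # stand-in for the per-epoch training loop
    loss = torch.rand((), device=device)  # stand-in for the per-step loss tensor
    running_loss += loss.detach()         # stays on the device, no sync here
print(f'mean loss: {(running_loss / 100).item():.4f}')  # single transfer at epoch end
```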
159 |
160 | 2. `get_batch first` means that we first read all frames through the video reader and then take the target slice of frames, which largely slows down the data-loading speed.
161 |
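As a hedged illustration of the faster `sample indices first` strategy with [decord](https://github.com/dmlc/decord), using the repo's demo clip (the numbers are just examples): pick the target indices first and decode only those frames.

```python
import numpy as np
import decord

vr = decord.VideoReader('demo/YABnJL_bDzw.mp4', ctx=decord.cpu(0), num_threads=1)
num_frames, frame_interval = 16, 4
total = len(vr)
start = np.random.randint(0, max(1, total - num_frames * frame_interval))

# `get_batch first` (slow): decode every frame, then take the slice.
# frames = vr.get_batch(range(total)).asnumpy()[start:start + num_frames * frame_interval]

# `sample indices first` (fast): pick the 16 target indices, decode only those frames.
indices = np.linspace(start, start + num_frames * frame_interval - 1, num_frames, dtype=int)
indices = np.clip(indices, 0, total - 1)
clip = vr.get_batch(indices).asnumpy()  # (16, H, W, 3) uint8 frames
```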
162 |
163 |
164 |
165 | ## Acknowledge
166 | This repo is built on top of [Pytorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/), [pytorchvideo](https://github.com/facebookresearch/pytorchvideo/tree/9d0ca900f0427ed9b47b6182ad05f75c0e66274b), [skimage](https://github.com/scikit-image/scikit-image), [decord](https://github.com/dmlc/decord) and [kornia](https://github.com/kornia/kornia). I also learned many code designs from [MMaction2](https://github.com/open-mmlab/mmaction2). I thank the authors for releasing their code.
167 |
168 | ## Contribution
169 | I look forward to seeing ideas for improving this repo. Please feel free to report them in the issues or, even better, submit a pull request.
170 |
171 | And your star is my motivation, thank u~
172 |
--------------------------------------------------------------------------------
/data_trainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytorch_lightning as pl
3 | import torch
4 | import torch.nn as nn
5 | from torch.utils.data.dataloader import DataLoader
6 |
7 | from dataset import Kinetics
8 | import data_transform as T
9 |
10 | class Collator(object):
11 |
12 | def __init__(self, objective):
13 | self.objective = objective
14 |
15 | def collate(self, minibatch):
16 | image_list = []
17 | label_list = []
18 | mask_list = []
19 | marker_list = []
20 | for record in minibatch:
21 | image_list.append(record[0])
22 | label_list.append(record[1])
23 | if self.objective == 'mim':
24 | mask_list.append(record[2])
25 | marker_list.append(record[3])
26 | minibatch = []
27 | minibatch.append(torch.stack(image_list))
28 | if self.objective == 'mim':
29 | minibatch.append(torch.stack(label_list))
30 | minibatch.append(torch.stack(mask_list))
31 | minibatch.append(marker_list)
32 | else:
33 | label = np.stack(label_list)
34 | minibatch.append(torch.from_numpy(label))
35 |
36 | return minibatch
37 |
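# Collated batch layout (note added for clarity; shapes are the typical ones):
#   objective == 'supervised' -> [videos (B, T, C, H, W), labels (B,)]
#   objective == 'mim'        -> [videos, hog_targets, masks, cube_markers (python list)]
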
38 | class KineticsDataModule(pl.LightningDataModule):
39 | def __init__(self,
40 | configs,
41 | train_ann_path,
42 | val_ann_path=None,
43 | test_ann_path=None,
44 | ):
45 | super().__init__()
46 | self.train_ann_path = train_ann_path
47 | self.val_ann_path = val_ann_path
48 | self.test_ann_path = test_ann_path
49 | self.configs = configs
50 |
51 | def get_dataset(self, annotation_path, transform, temporal_sample):
52 | dataset = Kinetics(
53 | self.configs,
54 | annotation_path,
55 | transform=transform,
56 | temporal_sample=temporal_sample)
57 |
58 | return dataset
59 |
60 | def setup(self, stage):
61 | if self.configs.objective == 'mim':
62 | scale = (0.5, 1.0)
63 | color_jitter = None
64 | else:
65 | color_jitter = 0.4
66 | scale = None
67 |
68 | if self.configs.data_statics == 'imagenet':
69 | mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
70 | elif self.configs.data_statics == 'kinetics':
71 | mean, std = (0.45, 0.45, 0.45), (0.225, 0.225, 0.225)
72 | else:
73 | mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
74 |
75 | train_transform = T.create_video_transform(
76 | objective=self.configs.objective,
77 | input_size=self.configs.img_size,
78 | is_training=True,
79 | scale=scale,
80 | hflip=0.5,
81 | color_jitter=color_jitter,
82 | auto_augment=self.configs.auto_augment,
83 | interpolation='bicubic',
84 | mean=mean,
85 | std=std)
86 | train_temporal_sample = T.TemporalRandomCrop(
87 | self.configs.num_frames * self.configs.frame_interval)
88 |
89 | self.train_dataset = self.get_dataset(
90 | self.train_ann_path,
91 | train_transform,
92 | train_temporal_sample)
93 |
94 | if self.val_ann_path is not None:
95 | val_transform = T.create_video_transform(
96 | input_size=self.configs.img_size,
97 | is_training=False,
98 | interpolation='bicubic',
99 | mean=mean,
100 | std=std)
101 | val_temporal_sample = T.TemporalRandomCrop(
102 | self.configs.num_frames * self.configs.frame_interval)
103 | self.val_dataset = self.get_dataset(
104 | self.val_ann_path,
105 | val_transform,
106 | val_temporal_sample)
107 |
108 | if self.test_ann_path is not None:
109 | # need to update
110 | test_transform = T.Compose([
111 | T.Resize(scale_range=(-1, 256)),
112 | T.ThreeCrop(size=self.configs.img_size),
113 | T.ToTensor(),
114 | T.Normalize(mean, std),
115 | ])
116 | test_temporal_sample = T.TemporalRandomCrop(
117 | self.configs.num_frames * self.configs.frame_interval)
118 | self.test_dataset = self.get_dataset(
119 | self.test_ann_path,
120 | test_transform,
121 | test_temporal_sample)
122 |
123 | def train_dataloader(self):
124 | return DataLoader(
125 | self.train_dataset,
126 | batch_size=self.configs.batch_size,
127 | num_workers=self.configs.num_workers,
128 | collate_fn=Collator(self.configs.objective).collate,
129 | shuffle=True,
130 | drop_last=True,
131 | pin_memory=True
132 | )
133 |
134 | def val_dataloader(self):
135 | if self.val_ann_path is not None:
136 | return DataLoader(
137 | self.val_dataset,
138 | batch_size=self.configs.batch_size,
139 | num_workers=self.configs.num_workers,
140 | collate_fn=Collator(self.configs.objective).collate,
141 | shuffle=False,
142 | drop_last=False,
143 | )
144 |
145 | def test_dataloader(self):
146 | if self.test_ann_path is not None:
147 | return DataLoader(
148 | self.test_dataset,
149 | batch_size=self.configs.batch_size,
150 | num_workers=self.configs.num_workers,
151 | collate_fn=Collator(self.configs.objective).collate,
152 | shuffle=False,
153 | drop_last=False,
154 | )
--------------------------------------------------------------------------------
/data_transform.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Sequence
2 | import random
3 | import math
4 |
5 | from einops import rearrange
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from PIL import Image
10 | from torchvision import transforms
11 | from torchvision.transforms.functional import InterpolationMode
12 |
13 | DEFAULT_CROP_PCT = 0.875
14 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
15 | _torch_interpolation_to_str = {
16 | InterpolationMode.NEAREST: 'nearest',
17 | InterpolationMode.BILINEAR: 'bilinear',
18 | InterpolationMode.BICUBIC: 'bicubic',
19 | InterpolationMode.BOX: 'box',
20 | InterpolationMode.HAMMING: 'hamming',
21 | InterpolationMode.LANCZOS: 'lanczos',
22 | }
23 | _str_to_torch_interpolation = {b: a for a, b in _torch_interpolation_to_str.items()}
24 |
25 | def str_to_interp_mode(mode_str):
26 | return _str_to_torch_interpolation[mode_str]
27 |
28 | # ------------------------------------------------------------
29 | # ---------------------- Common ----------------------------
30 | # ------------------------------------------------------------
31 | class Compose(object):
32 | """Composes several transforms together.
33 |
34 | Args:
35 | transforms (list of transform objects): list of data transforms to compose.
36 | """
37 |
38 | def __init__(self, transforms):
39 | self.transforms = transforms
40 |
41 | def __call__(self, img):
42 | for t in self.transforms:
43 | img = t(img)
44 | return img
45 |
46 | def randomize_parameters(self):
47 | for t in self.transforms:
48 | if hasattr(t, 'randomize_parameters'):
49 | t.randomize_parameters()
50 |
51 |
52 | class ToTensor(object):
53 | """Convert a tensor to torch.FloatTensor in the range [0.0, 1.0].
54 |
55 | Args:
56 | norm_value (int): the max value of the input image tensor, default to 255.
57 | """
58 |
59 | def __init__(self, norm_value=255):
60 | self.norm_value = norm_value
61 |
62 | def __call__(self, pic):
63 | if isinstance(pic, torch.Tensor):
64 | return pic.float().div(self.norm_value)
65 |
66 | def randomize_parameters(self):
67 | pass
68 |
69 |
70 | # ------------------------------------------------------------
71 | # ------------------- Transformation -----------------------
72 | # ------------------------------------------------------------
73 | class RandomCrop(object):
74 | """Random crop a fixed size region in a given image.
75 |
76 | Args:
77 | size (int, Tuple[int]): Desired output size (out_h, out_w) of the crop
78 | """
79 |
80 | def __init__(self, size):
81 | if isinstance(size, tuple):
82 | if size[0] != size[1]:
83 | raise ValueError(f'crop size {size[0], size[1]}, must be equal.')
84 | else:
85 | self.size = size[0]
86 | else:
87 | self.size = size
88 |
89 | def __call__(self, imgs):
90 | # Crop size
91 | size = self.size
92 |
93 | # Location
94 | img_height, img_width = imgs.size(2), imgs.size(3)
95 | y_offset = int(self.y_jitter * (img_height - size))
96 | x_offset = int(self.x_jitter * (img_width - size))
97 |
98 | imgs = imgs[..., y_offset : y_offset + size, x_offset : x_offset + size]
99 | return imgs
100 |
101 | def __repr__(self):
102 | repr_str = (f'{self.__class__.__name__}('
103 | f'size={self.size})')
104 | return repr_str
105 |
106 | def randomize_parameters(self):
107 | self.x_jitter = random.random()
108 | self.y_jitter = random.random()
109 |
110 |
111 | class Resize(object):
112 | """Resize images to a specific size.
113 |
114 | Args:
115 | scale_range (Tuple[int]): If the first value equals -1, the second value
116 | serves as the short edge of the resized image; otherwise, the short edge of
117 | the resized image will be randomly chosen from
118 | [scale_range[0], scale_range[1]].
119 | """
120 |
121 | def __init__(self, scale_range):
122 | if not isinstance(scale_range, tuple):
123 | raise ValueError(f'Scale_range {scale_range}, must be tuple.')
124 | self.scale_range = scale_range
125 |
126 | def __call__(self, imgs):
127 | imgs = self._resize(imgs)
128 | return imgs
129 |
130 | def __repr__(self):
131 | repr_str = (f'{self.__class__.__name__}('
132 | f'scale_range={self.scale_range})')
133 | return repr_str
134 |
135 | def randomize_parameters(self):
136 | if self.scale_range[0] == -1:
137 | self._resize = transforms.Resize(self.scale_range[1])
138 | else:
139 | short_edge = np.random.randint(self.scale_range[0],
140 | self.scale_range[1]+1)
141 | self._resize = transforms.Resize(short_edge)
142 |
143 |
144 | class RandomResizedCrop:
145 | """Random crop that specifies the area and height-width ratio range.
146 |
147 | Args:
148 | area_range (Tuple[float]): The candidate area scales range of
149 | output cropped images. Default: (0.08, 1.0).
150 | aspect_ratio_range (Tuple[float]): The candidate aspect ratio range of
151 | output cropped images. Default: (3 / 4, 4 / 3).
152 | """
153 |
154 | def __init__(self,
155 | size,
156 | interpolation=3,
157 | scale=(0.08, 1.0),
158 | ratio=(3 / 4, 4 / 3)):
159 | self.size = size
160 | self.area_range = scale
161 | self.aspect_ratio_range = ratio
162 | self.interpolation = interpolation
163 |
164 | def __call__(self, imgs):
165 | """Performs the RandomResizedCrop augmentation.
166 |
167 | Args:
168 | imgs (torch.Tensor): Video frames of shape (T, C, H, W), cropped and
169 | resized consistently across all frames.
170 | """
171 | # version one- frame diverse
172 | #imgs = self._crop_imgs(imgs)
173 |
174 | # version two- frame consistent
175 | img_width = imgs.shape[-1]
176 | img_height = imgs.shape[-2]
177 | # crop size
178 | min_length = min(img_width, img_height)
179 | crop_size = int(min_length * self.scale)
180 | width = crop_size
181 | height = crop_size*self.ratio
182 |
183 | # location
184 | left = self.tl_x * (img_width - width)
185 | top = self.tl_y * (img_height - height)
186 |
187 | imgs = transforms.functional.resized_crop(
188 | imgs, int(top), int(left), int(height), int(width), self.size, interpolation=self.interpolation)
189 |
190 | return imgs
191 |
192 | def __repr__(self):
193 | repr_str = (f'{self.__class__.__name__}('
194 | f'area_range={self.area_range}, '
195 | f'aspect_ratio_range={self.aspect_ratio_range}, '
196 | f'size={self.size})')
197 | return repr_str
198 |
199 | def randomize_parameters(self):
200 | self.scale = random.uniform(self.area_range[0], self.area_range[1])
201 | self.ratio = random.uniform(self.aspect_ratio_range[0], self.aspect_ratio_range[1])
202 | '''
203 | # version one- frame diverse
204 | self._crop_imgs = transforms.RandomResizedCrop(
205 | self.size, scale=(scale, scale), ratio=(ratio, ratio))
206 | '''
207 | # version two- frame consistent
208 | self.tl_x = random.random()
209 | self.tl_y = random.random()
210 |
211 |
212 | class Flip(object):
213 | """Flip the input images with a probability.
214 |
215 | Args:
216 | flip_ratio (float): Probability of implementing flip. Default: 0.5.
217 | """
218 |
219 | def __init__(self,
220 | flip_ratio=0.5):
221 | self.flip_ratio = flip_ratio
222 |
223 | def __call__(self, imgs):
224 | imgs = self._flip(imgs)
225 | return imgs
226 |
227 | def __repr__(self):
228 | repr_str = (
229 | f'{self.__class__.__name__}('
230 | f'flip_ratio={self.flip_ratio})')
231 | return repr_str
232 |
233 | def randomize_parameters(self):
234 | p = random.random()
235 | if p < self.flip_ratio:
236 | self._flip = transforms.RandomHorizontalFlip(p=1)
237 | else:
238 | self._flip = transforms.RandomHorizontalFlip(p=0)
239 |
240 |
241 | class RandomGrayscale(object):
242 | """Randomly convert the input images to grayscale with a probability.
243 |
244 | Args:
245 | p (float): Probability of converting to grayscale. Default: 0.1.
246 | """
247 |
248 | def __init__(self,
249 | p=0.1):
250 | self.p = p
251 |
252 | def __call__(self, imgs):
253 | imgs = self._grayscale(imgs)
254 | return imgs
255 |
256 | def __repr__(self):
257 | repr_str = (
258 | f'{self.__class__.__name__}('
259 | f'p={self.p})')
260 | return repr_str
261 |
262 | def randomize_parameters(self):
263 | p = random.random()
264 | if p > self.p:
265 | self._grayscale = transforms.RandomGrayscale(p=0)
266 | else:
267 | self._grayscale = transforms.RandomGrayscale(p=1)
268 |
269 |
270 | class RandomApply(object):
271 | """Apply the given transform with a probability.
272 |
273 | Args:
274 | p (float): Probability of applying the transform. Default: 0.5.
275 | """
276 |
277 | def __init__(self,
278 | transform,
279 | p=0.5):
280 | self.p = p
281 | self.transform = transform
282 |
283 | def __call__(self, imgs):
284 | imgs = self._random_apply(imgs)
285 | return imgs
286 |
287 | def __repr__(self):
288 | repr_str = (
289 | f'{self.__class__.__name__}('
290 | f'p={self.p})')
291 | return repr_str
292 |
293 | def randomize_parameters(self):
294 | p = random.random()
295 | if p > self.p:
296 | self._random_apply = transforms.RandomApply(self.transform, p=0)
297 | else:
298 | self._random_apply = transforms.RandomApply(self.transform, p=1)
299 |
300 |
301 | class Normalize(object):
302 | """Normalize the images with the given mean and std value.
303 |
304 | Args:
305 | mean (Sequence[float]): Mean values of different channels.
306 | std (Sequence[float]): Std values of different channels.
307 | """
308 |
309 | def __init__(self, mean, std):
310 | if not isinstance(mean, Sequence):
311 | raise TypeError(
312 | f'Mean must be list, tuple or np.ndarray, but got {type(mean)}'
313 | )
314 |
315 | if not isinstance(std, Sequence):
316 | raise TypeError(
317 | f'Std must be list, tuple or np.ndarray, but got {type(std)}')
318 |
319 | self._normalize = transforms.Normalize(mean, std)
320 | self.mean = mean
321 | self.std = std
322 |
323 | #@profile
324 | def __call__(self, imgs):
325 | imgs = self._normalize(imgs)
326 | return imgs
327 |
328 | def __repr__(self):
329 | repr_str = (f'{self.__class__.__name__}('
330 | f'mean={self.mean}, '
331 | f'std={self.std})')
332 | return repr_str
333 |
334 | def randomize_parameters(self):
335 | pass
336 |
337 |
338 | class ColorJitter(object):
339 | """Randomly distort the brightness, contrast, saturation and hue of images.
340 |
341 | Note: The input images should be in RGB channel order.
342 |
343 | Args:
344 | brightness (float): the std values of brightness distortion.
345 | contrast (float): the std values of contrast distortion.
346 | saturation (float): the std values of saturation distortion.
347 | hue (float): the std values of hue distortion.
348 | """
349 |
350 | def __init__(self,
351 | brightness=0,
352 | contrast=0,
353 | saturation=0,
354 | hue=0):
355 | self.brightness = brightness
356 | self.contrast = contrast
357 | self.saturation = saturation
358 | self.hue = hue
359 |
360 | def __call__(self, imgs):
361 |
362 | if imgs.ndim == 3:
363 | imgs = rearrange(imgs, '(t c) h w -> t c h w', c=3)
364 | imgs = self._color_jit(imgs)
365 | imgs = rearrange(imgs, 't c h w -> (t c) h w')
366 | return imgs
367 |
368 | def __repr__(self):
369 | repr_str = (f'{self.__class__.__name__}('
370 | f'brightness={self.brightness}, '
371 | f'contrast={self.contrast}, '
372 | f'saturation={self.saturation}, '
373 | f'hue={self.hue})')
374 | return repr_str
375 |
376 | def randomize_parameters(self):
377 | brightness = random.uniform(max(0,1-self.brightness), 1+self.brightness)
378 | contrast = random.uniform(max(0,1-self.contrast), 1+self.contrast)
379 | saturation = random.uniform(max(0,1-self.saturation), 1+self.saturation)
380 | hue = random.uniform(-self.hue, self.hue)
381 |
382 | self._color_jit = transforms.ColorJitter(
383 | brightness=(brightness,brightness),
384 | contrast=(contrast,contrast),
385 | saturation=(saturation,saturation),
386 | hue=(hue,hue))
387 |
388 |
389 | class CenterCrop(object):
390 | """Crop the center area from images.
391 |
392 | Args:
393 | crop_size (int | tuple[int]): (w, h) of crop size.
394 | """
395 |
396 | def __init__(self, size):
397 | self.size = size
398 | self._center_crop = transforms.CenterCrop(size=size)
399 |
400 | def __call__(self, imgs):
401 | imgs = self._center_crop(imgs)
402 | return imgs
403 |
404 | def __repr__(self):
405 | repr_str = (f'{self.__class__.__name__}(size={self.size})')
406 | return repr_str
407 |
408 | def randomize_parameters(self):
409 | pass
410 |
411 |
412 | class ThreeCrop(object):
413 | """Crop three pre-defined regions (left, right and center) of the image.
414 |
415 | Args:
416 | size (int, Tuple[int]): Desired output size (out_h, out_w) of the crop
417 | """
418 |
419 | def __init__(self, size):
420 | if isinstance(size, tuple):
421 | if size[0] != size[1]:
422 | raise ValueError(f'crop size {size[0], size[1]}, must be equal.')
423 | else:
424 | self.size = size[0]
425 | else:
426 | self.size = size
427 |
428 | def __call__(self, imgs):
429 | # Crop size
430 | size = int(self.size)
431 | img_height, img_width = imgs.size(2), imgs.size(3)
432 | if size > img_height or size > img_width:
433 | msg = "Requested crop size {} is bigger than input size {}"
434 | raise ValueError(msg.format(size, (img_height, img_width)))
435 |
436 | # Location
437 | crops = []
438 | left_y_offset = (img_height - size) // 2
439 | left_x_offset = 0
440 | left = imgs[...,
441 | left_y_offset : left_y_offset + size,
442 | left_x_offset : left_x_offset + size]
443 | crops.append(left)
444 |
445 | right_y_offset = (img_height - size) // 2
446 | right_x_offset = img_width - size
447 | right = imgs[...,
448 | right_y_offset : right_y_offset + size,
449 | right_x_offset : right_x_offset + size]
450 | crops.append(right)
451 |
452 | center_y_offset = (img_height - size) // 2
453 | center_x_offset = (img_width - size) // 2
454 | center = imgs[...,
455 | center_y_offset : center_y_offset + size,
456 | center_x_offset : center_x_offset + size]
457 | crops.append(center)
458 |
459 | # (N_Crops T C H W)
460 | imgs = torch.stack(crops)
461 | return imgs
462 |
463 | def __repr__(self):
464 | repr_str = (f'{self.__class__.__name__}('
465 | f'size={self.size})')
466 | return repr_str
467 |
468 | def randomize_parameters(self):
469 | pass
470 |
471 |
472 | # ------------------------------------------------------------
473 | # --------------------- Sampling ---------------------------
474 | # ------------------------------------------------------------
475 | class TemporalRandomCrop(object):
476 | """Temporally crop the given frame indices at a random location.
477 |
478 | Args:
479 | size (int): Desired length of frames will be seen in the model.
480 | """
481 |
482 | def __init__(self, size):
483 | self.size = size
484 |
485 | def __call__(self, total_frames):
486 | rand_end = max(0, total_frames - self.size - 1)
487 | begin_index = random.randint(0, rand_end)
488 | end_index = min(begin_index + self.size, total_frames)
489 | return begin_index, end_index
490 |
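# Usage example (added for illustration; the numbers are hypothetical):
#   temporal_sample = TemporalRandomCrop(16 * 4)    # num_frames * frame_interval
#   begin, end = temporal_sample(total_frames=300)  # e.g. (121, 185): a 64-frame window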
491 |
492 | # ------------------------------------------------------------
493 | # --------------------- AdvancedAugment --------------------
494 | # ------------------------------------------------------------
495 | def transforms_train(img_size=224,
496 | scale=None,
497 | ratio=None,
498 | hflip=0.5,
499 | color_jitter=0.4,
500 | auto_augment=None,
501 | interpolation='bilinear',
502 | mean=IMAGENET_DEFAULT_MEAN,
503 | std=IMAGENET_DEFAULT_STD,
504 | objective='supervised'):
505 | """Build the training transforms in three stages: a primary stage with
506 | RandomResizedCrop and horizontal flip, a secondary stage with RandAugment or
507 | color jitter, and a final stage that converts to tensor and normalizes.
508 |
509 | For objective == 'mim', the primary+secondary stage and the final stage are
510 | returned separately as [pre_transform, post_transform]; otherwise a single
511 | Compose of all three stages is returned."""
512 | scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
513 | ratio = tuple(ratio or (3./4., 4./3.)) # default imagenet ratio range
514 | primary_tfl = [
515 | transforms.RandomResizedCrop(img_size, scale=scale, ratio=ratio, interpolation=str_to_interp_mode(interpolation))]
516 | if hflip > 0.:
517 | primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
518 |
519 | secondary_tfl = []
520 | if auto_augment:
521 | secondary_tfl += [transforms.autoaugment.RandAugment()]
522 | elif color_jitter is not None:
523 | # color jitter is enabled when not using AA
524 | if isinstance(color_jitter, (list, tuple)):
525 | # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
526 | # or 4 if also augmenting hue
527 | assert len(color_jitter) in (3, 4)
528 | else:
529 | # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
530 | color_jitter = (float(color_jitter),) * 3
531 | secondary_tfl += [transforms.ColorJitter(*color_jitter)]
532 |
533 | final_tfl = []
534 | final_tfl += [
535 | ToTensor(),
536 | transforms.Normalize(
537 | mean=torch.tensor(mean),
538 | std=torch.tensor(std))
539 | ]
540 | if objective == 'mim':
541 | return [Compose(primary_tfl + secondary_tfl), Compose(final_tfl)]
542 | else:
543 | return Compose(primary_tfl + secondary_tfl + final_tfl)
544 |
545 |
546 | def transforms_eval(img_size=224,
547 | crop_pct=None,
548 | interpolation='bilinear',
549 | mean=IMAGENET_DEFAULT_MEAN,
550 | std=IMAGENET_DEFAULT_STD):
551 | crop_pct = crop_pct or DEFAULT_CROP_PCT
552 |
553 | if isinstance(img_size, (tuple, list)):
554 | assert len(img_size) == 2
555 | if img_size[-1] == img_size[-2]:
556 | # fall-back to older behaviour so Resize scales to shortest edge if target is square
557 | scale_size = int(math.floor(img_size[0] / crop_pct))
558 | else:
559 | scale_size = tuple([int(x / crop_pct) for x in img_size])
560 | else:
561 | scale_size = int(math.floor(img_size / crop_pct))
562 |
563 | tfl = [
564 | transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
565 | transforms.CenterCrop(img_size),
566 | ]
567 | tfl += [
568 | ToTensor(),
569 | transforms.Normalize(
570 | mean=torch.tensor(mean),
571 | std=torch.tensor(std))
572 | ]
573 |
574 | return Compose(tfl)
575 |
576 |
577 | def create_video_transform(input_size=224,
578 | is_training=False,
579 | scale=None,
580 | ratio=None,
581 | hflip=0.5,
582 | color_jitter=0.4,
583 | auto_augment=None,
584 | interpolation='bilinear',
585 | mean=IMAGENET_DEFAULT_MEAN,
586 | std=IMAGENET_DEFAULT_STD,
587 | objective='supervised',
588 | crop_pct=None):
589 |
590 | if isinstance(input_size, (tuple, list)):
591 | img_size = input_size[-2:]
592 | else:
593 | img_size = input_size
594 |
595 | if is_training:
596 | transform = transforms_train(
597 | img_size,
598 | scale=scale,
599 | ratio=ratio,
600 | hflip=hflip,
601 | color_jitter=color_jitter,
602 | auto_augment=auto_augment,
603 | interpolation=interpolation,
604 | mean=mean,
605 | std=std,
606 | objective=objective)
607 | else:
608 | transform = transforms_eval(
609 | img_size,
610 | interpolation=interpolation,
611 | mean=mean,
612 | std=std,
613 | crop_pct=crop_pct)
614 |
615 | return transform
616 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 |
4 | import decord
5 | import numpy as np
6 | import torch
7 |
8 | from einops import rearrange
9 | from skimage.feature import hog
10 | from mask_generator import CubeMaskGenerator
11 |
12 | class_labels_map = None
13 | cls_sample_cnt = None
14 |
15 | def temporal_sampling(frames, start_idx, end_idx, num_samples):
16 | """
17 | Given the start and end frame index, sample num_samples frames between
18 | the start and end with equal interval.
19 | Args:
20 | frames (tensor): a tensor of video frames, dimension is
21 | `num video frames` x `channel` x `height` x `width`.
22 | start_idx (int): the index of the start frame.
23 | end_idx (int): the index of the end frame.
24 | num_samples (int): number of frames to sample.
25 | Returns:
26 | frames (tensor): a tensor of temporally sampled video frames, dimension is
27 | `num clip frames` x `channel` x `height` x `width`.
28 | """
29 | index = torch.linspace(start_idx, end_idx, num_samples)
30 | index = torch.clamp(index, 0, frames.shape[0] - 1).long()
31 | frames = torch.index_select(frames, 0, index)
32 | return frames
33 |
34 |
35 | def numpy2tensor(x):
36 | return torch.from_numpy(x)
37 |
38 |
39 | def extract_hog_features(image):
40 | hog_features_r = hog(image[:,:,0], orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1), block_norm='L2', feature_vector=False)
41 | hog_features_g = hog(image[:,:,1], orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1), block_norm='L2', feature_vector=False)
42 | hog_features_b = hog(image[:,:,2], orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1), block_norm='L2', feature_vector=False) #visualize=True
43 | hog_features = np.concatenate([hog_features_r,hog_features_g,hog_features_b], axis=-1)
44 | hog_features = rearrange(hog_features, '(ph dh) (pw dw) ch cw c -> ph pw (dh dw ch cw c)', ph=14, pw=14)
45 | return hog_features
46 |
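# Shape note (added for clarity): for a 224x224 frame, hog() with pixels_per_cell=(8, 8)
# and cells_per_block=(1, 1) returns a (28, 28, 1, 1, 9) array per channel; concatenating
# the 3 channels and grouping 2x2 neighbouring cells yields the (14, 14, 108) target used
# as the MaskFeat regression label for that frame.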
47 |
48 | def load_annotation_data(data_file_path):
49 | with open(data_file_path, 'r') as data_file:
50 | return json.load(data_file)
51 |
52 |
53 | def get_class_labels(num_class, anno_pth='./k400_classmap.json'):
54 | global class_labels_map, cls_sample_cnt
55 |
56 | if class_labels_map is not None:
57 | return class_labels_map, cls_sample_cnt
58 | else:
59 | cls_sample_cnt = {}
60 | class_labels_map = load_annotation_data(anno_pth)
61 | for cls in class_labels_map:
62 | cls_sample_cnt[cls] = 0
63 | return class_labels_map, cls_sample_cnt
64 |
65 |
66 | def load_annotations(ann_file, num_class, num_samples_per_cls):
67 | dataset = []
68 | class_to_idx, cls_sample_cnt = get_class_labels(num_class)
69 | with open(ann_file, 'r') as fin:
70 | for line in fin:
71 | line_split = line.strip().split('\t')
72 | sample = {}
73 | idx = 0
74 | # idx for frame_dir
75 | frame_dir = line_split[idx]
76 | sample['video'] = frame_dir
77 | idx += 1
78 |
79 | # idx for label[s]
80 | label = [x for x in line_split[idx:]]
81 | assert label, f'missing label in line: {line}'
82 | assert len(label) == 1
83 | class_name = label[0]
84 | class_index = int(class_to_idx[class_name])
85 |
86 | # choose a class subset of whole dataset
87 | if class_index < num_class:
88 | sample['label'] = class_index
89 | if cls_sample_cnt[class_name] < num_samples_per_cls:
90 | dataset.append(sample)
91 | cls_sample_cnt[class_name]+=1
92 |
93 | return dataset
94 |
95 |
96 | class DecordInit(object):
97 | """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader."""
98 |
99 | def __init__(self, num_threads=1, **kwargs):
100 | self.num_threads = num_threads
101 | self.ctx = decord.cpu(0)
102 | self.kwargs = kwargs
103 |
104 | def __call__(self, filename):
105 | """Perform the Decord initialization.
106 | Args:
107 | filename (str): Path of the video file to read.
108 | Returns: the initialized decord.VideoReader.
109 | """
110 | reader = decord.VideoReader(filename,
111 | ctx=self.ctx,
112 | num_threads=self.num_threads)
113 | return reader
114 |
115 | def __repr__(self):
116 | repr_str = (f'{self.__class__.__name__}('
117 | f'ctx={self.ctx},'
118 | f'num_threads={self.num_threads})')
119 | return repr_str
120 |
121 |
122 | class Kinetics(torch.utils.data.Dataset):
123 | """Load the Kinetics video files
124 |
125 | Args:
126 | annotation_path (string): Annotation file path.
127 | num_class (int): The number of classes.
128 | num_samples_per_cls (int): the max number of samples used per class.
129 | target_video_len (int): the number of video frames that will be loaded.
130 | transform (callable): Align different videos to a specified size.
131 | temporal_sample (callable): Sample the target length of a video.
132 | """
133 |
134 | def __init__(self,
135 | configs,
136 | annotation_path,
137 | transform=None,
138 | temporal_sample=None):
139 | self.configs = configs
140 | self.data = load_annotations(annotation_path, self.configs.num_class, self.configs.num_samples_per_cls)
141 |
142 | self.transform = transform
143 | self.temporal_sample = temporal_sample
144 | self.target_video_len = self.configs.num_frames
145 | self.objective = self.configs.objective
146 | self.v_decoder = DecordInit()
147 |
148 | # mask
149 | if self.objective == 'mim':
150 | self.mask_generator = CubeMaskGenerator(input_size=(self.target_video_len//2,14,14),min_num_patches=16)
151 |
152 | def __getitem__(self, index):
153 | while True:
154 | try:
155 | path = self.data[index]['video']
156 | v_reader = self.v_decoder(path)
157 | total_frames = len(v_reader)
158 |
159 | # Sampling video frames
160 | start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
161 | assert end_frame_ind-start_frame_ind >= self.target_video_len
162 | frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int)
163 | video = v_reader.get_batch(frame_indice).asnumpy()
164 | del v_reader
165 | break
166 | except Exception as e:
167 | print(e)
168 | index = random.randint(0, len(self.data) - 1)
169 |
170 | # Video align transform: T C H W
171 | with torch.no_grad():
172 | video = torch.from_numpy(video).permute(0,3,1,2)
173 | if self.transform is not None:
174 | if self.objective == 'mim':
175 | pre_transform, post_transform = self.transform
176 | video = pre_transform(video) # align shape
177 | else:
178 | video = self.transform(video)
179 |
180 | # Label (depends)
181 | if self.objective == 'mim':
182 | # old version
183 | '''
184 | mask, cube_marker = self.mask_generator() # T' H' W'
185 | label = np.stack(list(map(extract_hog_features, video.permute(0,2,3,1).numpy())), axis=0) # T H W C -> T H' W' C'
186 | '''
187 | # new version
188 | mask, cube_marker = self.mask_generator() # T' H' W'
189 | hog_inputs = video.permute(0,2,3,1).numpy()
190 | hog_features = np.zeros((self.target_video_len,14,14,2*2*3*9))
191 | # speed up the extraction of hog features
192 | for marker in cube_marker: # [[start, span]]
193 | start_frame, span_frame = marker
194 | center_index = start_frame*2 + span_frame*2//2 # fix the temporal stride to 2
195 | hog_features[center_index] = extract_hog_features(hog_inputs[center_index])
196 | label = hog_features
197 | else:
198 | label = self.data[index]['label']
199 |
200 | if self.objective == 'mim':
201 | if self.transform is not None:
202 | video = post_transform(video) # to tensor & norm
203 | return video, numpy2tensor(label), numpy2tensor(mask), cube_marker
204 | else:
205 | return video, label
206 |
207 | def __len__(self):
208 | return len(self.data)
209 |
210 |
211 | if __name__ == '__main__':
212 | # Unit test for loading video and computing time cost
213 | import data_transform as T
214 | import time
215 | path = './YABnJL_bDzw.mp4'
216 | color_jitter = 0.4
217 | auto_augment = 'rand-m9-mstd0.5-inc1'
218 | scale = None
219 | mean, std = (0.45, 0.45, 0.45), (0.225, 0.225, 0.225)
220 | transform = T.create_video_transform(
221 | input_size=224,
222 | is_training=True,
223 | scale=scale,
224 | hflip=0.5,
225 | color_jitter=color_jitter,
226 | auto_augment=auto_augment,
227 | interpolation='bicubic',
228 | mean=mean,
229 | std=std)
230 |
231 | v_decoder = DecordInit()
232 | v_reader = v_decoder(path)
233 | total_frames = len(v_reader)
234 | target_video_len = 16
235 | # Sampling video frames
236 | temporal_sample = T.TemporalRandomCrop(target_video_len*16)
237 | start_frame_ind, end_frame_ind = temporal_sample(total_frames)
238 | frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, target_video_len, dtype=int)
239 | video = v_reader.get_batch(frame_indice).asnumpy()
240 | del v_reader
241 |
242 | # Video align transform: T C H W
243 | with torch.no_grad():
244 | video = torch.from_numpy(video).permute(0,3,1,2)
245 | if transform is not None:
246 | video = transform(video)
247 |
248 | show_processed_image(video.permute(0,2,3,1), save_dir='./', mean=mean, std=std)
249 | '''
250 | mask_generator = CubeMaskGenerator(input_size=(8,14,14),min_num_patches=16)
251 | counts = 1
252 | while True:
253 | if counts > 100:
254 | break
255 | start_time = time.perf_counter()
256 | v_decoder = DecordInit()
257 | v_reader = v_decoder(path)
258 | # Sampling video frames
259 | total_frames = len(v_reader)
260 | align_transform = T.Compose([
261 | T.RandomResizedCrop(size=(224, 224), area_range=(0.5, 1.0), interpolation=3), #InterpolationMode.BICUBIC
262 | T.Flip(),
263 | ])
264 | temporal_sample = T.TemporalRandomCrop(16*4)
265 | start_frame_ind, end_frame_ind = temporal_sample(total_frames)
266 | frame_indice = np.linspace(0, end_frame_ind-start_frame_ind-1,
267 | 16, dtype=int)
268 | video = v_reader.get_batch(frame_indice).asnumpy()
269 | del v_reader
270 |
271 | # Video align transform: T C H W
272 | with torch.no_grad():
273 | video = torch.from_numpy(video).permute(0,3,1,2)
274 | align_transform.randomize_parameters()
275 | video = align_transform(video)
276 | #label = np.stack(list(map(extract_hog_features, video.permute(0,2,3,1).numpy())), axis=0) # T H W C -> T H' W' C'
277 | _, hog_image = hog(video.permute(0,2,3,1).numpy()[0][:,:,2], orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1), block_norm='L2', feature_vector=False, visualize=True)
278 | mask, cube_marker = mask_generator() # T' H' W'
279 | counts += 1
280 | print(f'{(time.perf_counter()-start_time):.3f}')
281 | print('finish')
282 | '''
283 | #_, hog_image = hog(video.permute(0,2,3,1).numpy()[0][:,:,2], orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1), block_norm='L2', feature_vector=False, visualize=True)
284 | #from skimage import io
285 | #io.imsave('./test_img_hog.jpg',hog_image)
286 | #show_processed_image(video.permute(0,2,3,1), save_dir='./')
287 |
--------------------------------------------------------------------------------
/demo/9r8wpMS2iEk_000048_000058.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mx-mark/VideoTransformer-pytorch/194cae69722eb5efad031c59f4ff03bc60633fa8/demo/9r8wpMS2iEk_000048_000058.mp4
--------------------------------------------------------------------------------
/demo/YABnJL_bDzw.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mx-mark/VideoTransformer-pytorch/194cae69722eb5efad031c59f4ff03bc60633fa8/demo/YABnJL_bDzw.mp4
--------------------------------------------------------------------------------
/k400_classmap.json:
--------------------------------------------------------------------------------
1 | {"abseiling": "0", "air_drumming": "1", "answering_questions": "2", "applauding": "3", "applying_cream": "4", "archery": "5", "arm_wrestling": "6", "arranging_flowers": "7", "assembling_computer": "8", "auctioning": "9", "baby_waking_up": "10", "baking_cookies": "11", "balloon_blowing": "12", "bandaging": "13", "barbequing": "14", "bartending": "15", "beatboxing": "16", "bee_keeping": "17", "belly_dancing": "18", "bench_pressing": "19", "bending_back": "20", "bending_metal": "21", "biking_through_snow": "22", "blasting_sand": "23", "blowing_glass": "24", "blowing_leaves": "25", "blowing_nose": "26", "blowing_out_candles": "27", "bobsledding": "28", "bookbinding": "29", "bouncing_on_trampoline": "30", "bowling": "31", "braiding_hair": "32", "breading_or_breadcrumbing": "33", "breakdancing": "34", "brush_painting": "35", "brushing_hair": "36", "brushing_teeth": "37", "building_cabinet": "38", "building_shed": "39", "bungee_jumping": "40", "busking": "41", "canoeing_or_kayaking": "42", "capoeira": "43", "carrying_baby": "44", "cartwheeling": "45", "carving_pumpkin": "46", "catching_fish": "47", "catching_or_throwing_baseball": "48", "catching_or_throwing_frisbee": "49", "catching_or_throwing_softball": "50", "celebrating": "51", "changing_oil": "52", "changing_wheel": "53", "checking_tires": "54", "cheerleading": "55", "chopping_wood": "56", "clapping": "57", "clay_pottery_making": "58", "clean_and_jerk": "59", "cleaning_floor": "60", "cleaning_gutters": "61", "cleaning_pool": "62", "cleaning_shoes": "63", "cleaning_toilet": "64", "cleaning_windows": "65", "climbing_a_rope": "66", "climbing_ladder": "67", "climbing_tree": "68", "contact_juggling": "69", "cooking_chicken": "70", "cooking_egg": "71", "cooking_on_campfire": "72", "cooking_sausages": "73", "counting_money": "74", "country_line_dancing": "75", "cracking_neck": "76", "crawling_baby": "77", "crossing_river": "78", "crying": "79", "curling_hair": "80", "cutting_nails": "81", "cutting_pineapple": "82", "cutting_watermelon": "83", "dancing_ballet": "84", "dancing_charleston": "85", "dancing_gangnam_style": "86", "dancing_macarena": "87", "deadlifting": "88", "decorating_the_christmas_tree": "89", "digging": "90", "dining": "91", "disc_golfing": "92", "diving_cliff": "93", "dodgeball": "94", "doing_aerobics": "95", "doing_laundry": "96", "doing_nails": "97", "drawing": "98", "dribbling_basketball": "99", "drinking": "100", "drinking_beer": "101", "drinking_shots": "102", "driving_car": "103", "driving_tractor": "104", "drop_kicking": "105", "drumming_fingers": "106", "dunking_basketball": "107", "dying_hair": "108", "eating_burger": "109", "eating_cake": "110", "eating_carrots": "111", "eating_chips": "112", "eating_doughnuts": "113", "eating_hotdog": "114", "eating_ice_cream": "115", "eating_spaghetti": "116", "eating_watermelon": "117", "egg_hunting": "118", "exercising_arm": "119", "exercising_with_an_exercise_ball": "120", "extinguishing_fire": "121", "faceplanting": "122", "feeding_birds": "123", "feeding_fish": "124", "feeding_goats": "125", "filling_eyebrows": "126", "finger_snapping": "127", "fixing_hair": "128", "flipping_pancake": "129", "flying_kite": "130", "folding_clothes": "131", "folding_napkins": "132", "folding_paper": "133", "front_raises": "134", "frying_vegetables": "135", "garbage_collecting": "136", "gargling": "137", "getting_a_haircut": "138", "getting_a_tattoo": "139", "giving_or_receiving_award": "140", "golf_chipping": "141", "golf_driving": "142", "golf_putting": "143", "grinding_meat": "144", 
"grooming_dog": "145", "grooming_horse": "146", "gymnastics_tumbling": "147", "hammer_throw": "148", "headbanging": "149", "headbutting": "150", "high_jump": "151", "high_kick": "152", "hitting_baseball": "153", "hockey_stop": "154", "holding_snake": "155", "hopscotch": "156", "hoverboarding": "157", "hugging": "158", "hula_hooping": "159", "hurdling": "160", "hurling_(sport)": "161", "ice_climbing": "162", "ice_fishing": "163", "ice_skating": "164", "ironing": "165", "javelin_throw": "166", "jetskiing": "167", "jogging": "168", "juggling_balls": "169", "juggling_fire": "170", "juggling_soccer_ball": "171", "jumping_into_pool": "172", "jumpstyle_dancing": "173", "kicking_field_goal": "174", "kicking_soccer_ball": "175", "kissing": "176", "kitesurfing": "177", "knitting": "178", "krumping": "179", "laughing": "180", "laying_bricks": "181", "long_jump": "182", "lunge": "183", "making_a_cake": "184", "making_a_sandwich": "185", "making_bed": "186", "making_jewelry": "187", "making_pizza": "188", "making_snowman": "189", "making_sushi": "190", "making_tea": "191", "marching": "192", "massaging_back": "193", "massaging_feet": "194", "massaging_legs": "195", "massaging_person's_head": "196", "milking_cow": "197", "mopping_floor": "198", "motorcycling": "199", "moving_furniture": "200", "mowing_lawn": "201", "news_anchoring": "202", "opening_bottle": "203", "opening_present": "204", "paragliding": "205", "parasailing": "206", "parkour": "207", "passing_American_football_(in_game)": "208", "passing_American_football_(not_in_game)": "209", "peeling_apples": "210", "peeling_potatoes": "211", "petting_animal_(not_cat)": "212", "petting_cat": "213", "picking_fruit": "214", "planting_trees": "215", "plastering": "216", "playing_accordion": "217", "playing_badminton": "218", "playing_bagpipes": "219", "playing_basketball": "220", "playing_bass_guitar": "221", "playing_cards": "222", "playing_cello": "223", "playing_chess": "224", "playing_clarinet": "225", "playing_controller": "226", "playing_cricket": "227", "playing_cymbals": "228", "playing_didgeridoo": "229", "playing_drums": "230", "playing_flute": "231", "playing_guitar": "232", "playing_harmonica": "233", "playing_harp": "234", "playing_ice_hockey": "235", "playing_keyboard": "236", "playing_kickball": "237", "playing_monopoly": "238", "playing_organ": "239", "playing_paintball": "240", "playing_piano": "241", "playing_poker": "242", "playing_recorder": "243", "playing_saxophone": "244", "playing_squash_or_racquetball": "245", "playing_tennis": "246", "playing_trombone": "247", "playing_trumpet": "248", "playing_ukulele": "249", "playing_violin": "250", "playing_volleyball": "251", "playing_xylophone": "252", "pole_vault": "253", "presenting_weather_forecast": "254", "pull_ups": "255", "pumping_fist": "256", "pumping_gas": "257", "punching_bag": "258", "punching_person_(boxing)": "259", "push_up": "260", "pushing_car": "261", "pushing_cart": "262", "pushing_wheelchair": "263", "reading_book": "264", "reading_newspaper": "265", "recording_music": "266", "riding_a_bike": "267", "riding_camel": "268", "riding_elephant": "269", "riding_mechanical_bull": "270", "riding_mountain_bike": "271", "riding_mule": "272", "riding_or_walking_with_horse": "273", "riding_scooter": "274", "riding_unicycle": "275", "ripping_paper": "276", "robot_dancing": "277", "rock_climbing": "278", "rock_scissors_paper": "279", "roller_skating": "280", "running_on_treadmill": "281", "sailing": "282", "salsa_dancing": "283", "sanding_floor": "284", "scrambling_eggs": "285", 
"scuba_diving": "286", "setting_table": "287", "shaking_hands": "288", "shaking_head": "289", "sharpening_knives": "290", "sharpening_pencil": "291", "shaving_head": "292", "shaving_legs": "293", "shearing_sheep": "294", "shining_shoes": "295", "shooting_basketball": "296", "shooting_goal_(soccer)": "297", "shot_put": "298", "shoveling_snow": "299", "shredding_paper": "300", "shuffling_cards": "301", "side_kick": "302", "sign_language_interpreting": "303", "singing": "304", "situp": "305", "skateboarding": "306", "ski_jumping": "307", "skiing_(not_slalom_or_crosscountry)": "308", "skiing_crosscountry": "309", "skiing_slalom": "310", "skipping_rope": "311", "skydiving": "312", "slacklining": "313", "slapping": "314", "sled_dog_racing": "315", "smoking": "316", "smoking_hookah": "317", "snatch_weight_lifting": "318", "sneezing": "319", "sniffing": "320", "snorkeling": "321", "snowboarding": "322", "snowkiting": "323", "snowmobiling": "324", "somersaulting": "325", "spinning_poi": "326", "spray_painting": "327", "spraying": "328", "springboard_diving": "329", "squat": "330", "sticking_tongue_out": "331", "stomping_grapes": "332", "stretching_arm": "333", "stretching_leg": "334", "strumming_guitar": "335", "surfing_crowd": "336", "surfing_water": "337", "sweeping_floor": "338", "swimming_backstroke": "339", "swimming_breast_stroke": "340", "swimming_butterfly_stroke": "341", "swing_dancing": "342", "swinging_legs": "343", "swinging_on_something": "344", "sword_fighting": "345", "tai_chi": "346", "taking_a_shower": "347", "tango_dancing": "348", "tap_dancing": "349", "tapping_guitar": "350", "tapping_pen": "351", "tasting_beer": "352", "tasting_food": "353", "testifying": "354", "texting": "355", "throwing_axe": "356", "throwing_ball": "357", "throwing_discus": "358", "tickling": "359", "tobogganing": "360", "tossing_coin": "361", "tossing_salad": "362", "training_dog": "363", "trapezing": "364", "trimming_or_shaving_beard": "365", "trimming_trees": "366", "triple_jump": "367", "tying_bow_tie": "368", "tying_knot_(not_on_a_tie)": "369", "tying_tie": "370", "unboxing": "371", "unloading_truck": "372", "using_computer": "373", "using_remote_controller_(not_gaming)": "374", "using_segway": "375", "vault": "376", "waiting_in_line": "377", "walking_the_dog": "378", "washing_dishes": "379", "washing_feet": "380", "washing_hair": "381", "washing_hands": "382", "water_skiing": "383", "water_sliding": "384", "watering_plants": "385", "waxing_back": "386", "waxing_chest": "387", "waxing_eyebrows": "388", "waxing_legs": "389", "weaving_basket": "390", "welding": "391", "whistling": "392", "windsurfing": "393", "wrapping_present": "394", "wrestling": "395", "writing": "396", "yawning": "397", "yoga": "398", "zumba": "399"}
--------------------------------------------------------------------------------
/k600_classmap.json:
--------------------------------------------------------------------------------
1 | {
2 | "arguing": 0,
3 | "throwing ball (not baseball or American football)": 1,
4 | "falling off chair": 2,
5 | "shooting basketball": 3,
6 | "burping": 4,
7 | "ice swimming": 5,
8 | "assembling computer": 6,
9 | "playing chess": 7,
10 | "yawning": 8,
11 | "tackling": 9,
12 | "using a sledge hammer": 10,
13 | "sneezing": 11,
14 | "putting in contact lenses": 12,
15 | "massaging back": 13,
16 | "playing paintball": 14,
17 | "looking at phone": 15,
18 | "massaging neck": 16,
19 | "making sushi": 17,
20 | "poking bellybutton": 18,
21 | "scrambling eggs": 19,
22 | "hammer throw": 20,
23 | "ice skating": 21,
24 | "playing rubiks cube": 22,
25 | "catching or throwing frisbee": 23,
26 | "auctioning": 24,
27 | "person collecting garbage": 25,
28 | "contorting": 26,
29 | "playing piano": 27,
30 | "ski jumping": 28,
31 | "cutting pineapple": 29,
32 | "riding unicycle": 30,
33 | "playing lute": 31,
34 | "blowing bubble gum": 32,
35 | "backflip (human)": 33,
36 | "spelunking": 34,
37 | "getting a piercing": 35,
38 | "parkour": 36,
39 | "frying vegetables": 37,
40 | "planting trees": 38,
41 | "long jump": 39,
42 | "setting table": 40,
43 | "using puppets": 41,
44 | "whistling": 42,
45 | "cooking egg": 43,
46 | "bench pressing": 44,
47 | "scrubbing face": 45,
48 | "combing hair": 46,
49 | "snowmobiling": 47,
50 | "ice climbing": 48,
51 | "drawing": 49,
52 | "massaging feet": 50,
53 | "headbanging": 51,
54 | "opening wine bottle": 52,
55 | "marching": 53,
56 | "feeding goats": 54,
57 | "jumping jacks": 55,
58 | "playing hand clapping games": 56,
59 | "brushing teeth": 57,
60 | "dumpster diving": 58,
61 | "dancing gangnam style": 59,
62 | "springboard diving": 60,
63 | "high kick": 61,
64 | "making tea": 62,
65 | "adjusting glasses": 63,
66 | "putting on mascara": 64,
67 | "slapping": 65,
68 | "juggling fire": 66,
69 | "situp": 67,
70 | "snatch weight lifting": 68,
71 | "using segway": 69,
72 | "motorcycling": 70,
73 | "base jumping": 71,
74 | "playing violin": 72,
75 | "wood burning (art)": 73,
76 | "pushing car": 74,
77 | "kitesurfing": 75,
78 | "acting in play": 76,
79 | "cleaning shoes": 77,
80 | "pinching": 78,
81 | "rolling pastry": 79,
82 | "eating carrots": 80,
83 | "jumping into pool": 81,
84 | "tasting beer": 82,
85 | "blowing leaves": 83,
86 | "playing squash or racquetball": 84,
87 | "directing traffic": 85,
88 | "drooling": 86,
89 | "catching fish": 87,
90 | "bungee jumping": 88,
91 | "geocaching": 89,
92 | "blowing glass": 90,
93 | "slacklining": 91,
94 | "bending metal": 92,
95 | "changing gear in car": 93,
96 | "cooking scallops": 94,
97 | "rock scissors paper": 95,
98 | "training dog": 96,
99 | "chiseling stone": 97,
100 | "running on treadmill": 98,
101 | "zumba": 99,
102 | "using a power drill": 100,
103 | "ripping paper": 101,
104 | "riding a bike": 102,
105 | "testifying": 103,
106 | "jetskiing": 104,
107 | "salsa dancing": 105,
108 | "bouncing on bouncy castle": 106,
109 | "windsurfing": 107,
110 | "playing blackjack": 108,
111 | "playing cricket": 109,
112 | "applying cream": 110,
113 | "drop kicking": 111,
114 | "waving hand": 112,
115 | "eating hotdog": 113,
116 | "hockey stop": 114,
117 | "paragliding": 115,
118 | "sharpening knives": 116,
119 | "assembling bicycle": 117,
120 | "playing polo": 118,
121 | "visiting the zoo": 119,
122 | "flying kite": 120,
123 | "lunge": 121,
124 | "shaping bread dough": 122,
125 | "playing didgeridoo": 123,
126 | "shining shoes": 124,
127 | "installing carpet": 125,
128 | "trapezing": 126,
129 | "smoking": 127,
130 | "sausage making": 128,
131 | "water sliding": 129,
132 | "wrestling": 130,
133 | "changing oil": 131,
134 | "dyeing eyebrows": 132,
135 | "bouncing on trampoline": 133,
136 | "clapping": 134,
137 | "playing tennis": 135,
138 | "busking": 136,
139 | "building sandcastle": 137,
140 | "cutting apple": 138,
141 | "getting a haircut": 139,
142 | "ironing": 140,
143 | "tiptoeing": 141,
144 | "pouring beer": 142,
145 | "riding or walking with horse": 143,
146 | "luge": 144,
147 | "playing pan pipes": 145,
148 | "making a cake": 146,
149 | "pushing wheelbarrow": 147,
150 | "yoga": 148,
151 | "passing american football (not in game)": 149,
152 | "somersaulting": 150,
153 | "smoking pipe": 151,
154 | "card throwing": 152,
155 | "riding snow blower": 153,
156 | "playing harmonica": 154,
157 | "ice fishing": 155,
158 | "building lego": 156,
159 | "playing scrabble": 157,
160 | "dribbling basketball": 158,
161 | "dancing charleston": 159,
162 | "golf putting": 160,
163 | "mushroom foraging": 161,
164 | "lighting fire": 162,
165 | "standing on hands": 163,
166 | "diving cliff": 164,
167 | "petting animal (not cat)": 165,
168 | "spinning poi": 166,
169 | "changing wheel (not on bike)": 167,
170 | "stretching arm": 168,
171 | "packing": 169,
172 | "mowing lawn": 170,
173 | "chewing gum": 171,
174 | "using a microscope": 172,
175 | "playing dominoes": 173,
176 | "falling off bike": 174,
177 | "krumping": 175,
178 | "twiddling fingers": 176,
179 | "shearing sheep": 177,
180 | "folding paper": 178,
181 | "playing marbles": 179,
182 | "holding snake": 180,
183 | "sweeping floor": 181,
184 | "smashing": 182,
185 | "cutting nails": 183,
186 | "kicking soccer ball": 184,
187 | "playing flute": 185,
188 | "using atm": 186,
189 | "playing cymbals": 187,
190 | "opening door": 188,
191 | "photobombing": 189,
192 | "fly tying": 190,
193 | "roasting marshmallows": 191,
194 | "playing ping pong": 192,
195 | "using inhaler": 193,
196 | "swimming backstroke": 194,
197 | "crossing eyes": 195,
198 | "kicking field goal": 196,
199 | "air drumming": 197,
200 | "throwing tantrum": 198,
201 | "biking through snow": 199,
202 | "waking up": 200,
203 | "playing controller": 201,
204 | "beatboxing": 202,
205 | "baking cookies": 203,
206 | "snorkeling": 204,
207 | "carrying baby": 205,
208 | "grooming horse": 206,
209 | "threading needle": 207,
210 | "tickling": 208,
211 | "plastering": 209,
212 | "making balloon shapes": 210,
213 | "drinking shots": 211,
214 | "passing American football (in game)": 212,
215 | "clean and jerk": 213,
216 | "building shed": 214,
217 | "eating spaghetti": 215,
218 | "sipping cup": 216,
219 | "walking the dog": 217,
220 | "steer roping": 218,
221 | "stretching leg": 219,
222 | "belly dancing": 220,
223 | "playing keyboard": 221,
224 | "playing netball": 222,
225 | "cleaning gutters": 223,
226 | "playing basketball": 224,
227 | "playing accordion": 225,
228 | "dining": 226,
229 | "dyeing hair": 227,
230 | "making jewelry": 228,
231 | "massaging legs": 229,
232 | "card stacking": 230,
233 | "clam digging": 231,
234 | "weaving basket": 232,
235 | "unboxing": 233,
236 | "news anchoring": 234,
237 | "using a paint roller": 235,
238 | "using circular saw": 236,
239 | "cracking back": 237,
240 | "wading through water": 238,
241 | "cheerleading": 239,
242 | "gospel singing in church": 240,
243 | "playing drums": 241,
244 | "checking tires": 242,
245 | "land sailing": 243,
246 | "tapping pen": 244,
247 | "laughing": 245,
248 | "lifting hat": 246,
249 | "capsizing": 247,
250 | "barbequing": 248,
251 | "tango dancing": 249,
252 | "swimming butterfly stroke": 250,
253 | "bathing dog": 251,
254 | "scuba diving": 252,
255 | "using remote controller (not gaming)": 253,
256 | "javelin throw": 254,
257 | "cutting orange": 255,
258 | "sign language interpreting": 256,
259 | "spray painting": 257,
260 | "bulldozing": 258,
261 | "laying concrete": 259,
262 | "walking through snow": 260,
263 | "applauding": 261,
264 | "sharpening pencil": 262,
265 | "putting on eyeliner": 263,
266 | "gymnastics tumbling": 264,
267 | "braiding hair": 265,
268 | "flint knapping": 266,
269 | "shaking head": 267,
270 | "making cheese": 268,
271 | "trimming trees": 269,
272 | "hula hooping": 270,
273 | "deadlifting": 271,
274 | "washing hands": 272,
275 | "tying necktie": 273,
276 | "sword fighting": 274,
277 | "playing gong": 275,
278 | "using a wrench": 276,
279 | "putting on foundation": 277,
280 | "pillow fight": 278,
281 | "crawling baby": 279,
282 | "moon walking": 280,
283 | "doing jigsaw puzzle": 281,
284 | "making horseshoes": 282,
285 | "bookbinding": 283,
286 | "stomping grapes": 284,
287 | "brush painting": 285,
288 | "playing darts": 286,
289 | "presenting weather forecast": 287,
290 | "throwing discus": 288,
291 | "sucking lolly": 289,
292 | "playing recorder": 290,
293 | "embroidering": 291,
294 | "climbing tree": 292,
295 | "petting cat": 293,
296 | "reading newspaper": 294,
297 | "contact juggling": 295,
298 | "riding elephant": 296,
299 | "picking fruit": 297,
300 | "sleeping": 298,
301 | "cartwheeling": 299,
302 | "eating watermelon": 300,
303 | "shooting goal (soccer)": 301,
304 | "jogging": 302,
305 | "head stand": 303,
306 | "yarn spinning": 304,
307 | "blowing nose": 305,
308 | "calligraphy": 306,
309 | "rope pushdown": 307,
310 | "cleaning toilet": 308,
311 | "making the bed": 309,
312 | "separating eggs": 310,
313 | "answering questions": 311,
314 | "laying bricks": 312,
315 | "opening present": 313,
316 | "robot dancing": 314,
317 | "shaking hands": 315,
318 | "weaving fabric": 316,
319 | "longboarding": 317,
320 | "washing dishes": 318,
321 | "giving or receiving award": 319,
322 | "curling (sport)": 320,
323 | "extinguishing fire": 321,
324 | "vacuuming floor": 322,
325 | "arranging flowers": 323,
326 | "mopping floor": 324,
327 | "shaving legs": 325,
328 | "pushing cart": 326,
329 | "cooking sausages (not on barbeque)": 327,
330 | "planing wood": 328,
331 | "throwing snowballs": 329,
332 | "swimming front crawl": 330,
333 | "pumping gas": 331,
334 | "winking": 332,
335 | "exercising with an exercise ball": 333,
336 | "fixing hair": 334,
337 | "feeding birds": 335,
338 | "surfing water": 336,
339 | "making bubbles": 337,
340 | "skiing mono": 338,
341 | "opening refrigerator": 339,
342 | "folding clothes": 340,
343 | "marriage proposal": 341,
344 | "shining flashlight": 342,
345 | "sawing wood": 343,
346 | "fencing (sport)": 344,
347 | "drumming fingers": 345,
348 | "playing xylophone": 346,
349 | "hitting baseball": 347,
350 | "fixing bicycle": 348,
351 | "sanding floor": 349,
352 | "swing dancing": 350,
353 | "moving furniture": 351,
354 | "side kick": 352,
355 | "snowkiting": 353,
356 | "opening bottle (not wine)": 354,
357 | "playing pinball": 355,
358 | "playing saxophone": 356,
359 | "texting": 357,
360 | "chopping wood": 358,
361 | "playing cello": 359,
362 | "scrapbooking": 360,
363 | "ironing hair": 361,
364 | "arm wrestling": 362,
365 | "playing poker": 363,
366 | "eating doughnuts": 364,
367 | "waxing legs": 365,
368 | "capoeira": 366,
369 | "playing bagpipes": 367,
370 | "shucking oysters": 368,
371 | "cosplaying": 369,
372 | "bending back": 370,
373 | "breakdancing": 371,
374 | "playing trombone": 372,
375 | "building cabinet": 373,
376 | "hurdling": 374,
377 | "mountain climber (exercise)": 375,
378 | "playing volleyball": 376,
379 | "bee keeping": 377,
380 | "wading through mud": 378,
381 | "hugging baby": 379,
382 | "abseiling": 380,
383 | "carving pumpkin": 381,
384 | "breading or breadcrumbing": 382,
385 | "pole vault": 383,
386 | "juggling soccer ball": 384,
387 | "counting money": 385,
388 | "dancing macarena": 386,
389 | "dunking basketball": 387,
390 | "roller skating": 388,
391 | "knitting": 389,
392 | "laying stone": 390,
393 | "tapping guitar": 391,
394 | "shuffling feet": 392,
395 | "breaking boards": 393,
396 | "waxing back": 394,
397 | "coloring in": 395,
398 | "chopping meat": 396,
399 | "archery": 397,
400 | "playing monopoly": 398,
401 | "swinging baseball bat": 399,
402 | "repairing puncture": 400,
403 | "making paper aeroplanes": 401,
404 | "cracking neck": 402,
405 | "tying knot (not on a tie)": 403,
406 | "wrapping present": 404,
407 | "using bagging machine": 405,
408 | "jaywalking": 406,
409 | "raising eyebrows": 407,
410 | "dancing ballet": 408,
411 | "smelling feet": 409,
412 | "celebrating": 410,
413 | "bodysurfing": 411,
414 | "bartending": 412,
415 | "shaving head": 413,
416 | "squat": 414,
417 | "chopping vegetables": 415,
418 | "skydiving": 416,
419 | "jumpstyle dancing": 417,
420 | "archaeological excavation": 418,
421 | "waiting in line": 419,
422 | "recording music": 420,
423 | "passing soccer ball": 421,
424 | "washing feet": 422,
425 | "waxing eyebrows": 423,
426 | "making snowman": 424,
427 | "hoverboarding": 425,
428 | "rock climbing": 426,
429 | "catching or throwing softball": 427,
430 | "smoking hookah": 428,
431 | "skateboarding": 429,
432 | "climbing a rope": 430,
433 | "watching tv": 431,
434 | "delivering mail": 432,
435 | "putting on sari": 433,
436 | "swimming breast stroke": 434,
437 | "playing beer pong": 435,
438 | "eating chips": 436,
439 | "playing laser tag": 437,
440 | "blowing out candles": 438,
441 | "sled dog racing": 439,
442 | "shopping": 440,
443 | "folding napkins": 441,
444 | "roasting pig": 442,
445 | "writing": 443,
446 | "hand washing clothes": 444,
447 | "playing trumpet": 445,
448 | "riding mechanical bull": 446,
449 | "carving ice": 447,
450 | "punching bag": 448,
451 | "shot put": 449,
452 | "sewing": 450,
453 | "crossing river": 451,
454 | "hugging (not baby)": 452,
455 | "tying shoe laces": 453,
456 | "throwing water balloon": 454,
457 | "putting on lipstick": 455,
458 | "playing bass guitar": 456,
459 | "attending conference": 457,
460 | "playing kickball": 458,
461 | "peeling potatoes": 459,
462 | "trimming shrubs": 460,
463 | "massaging person's head": 461,
464 | "throwing knife": 462,
465 | "skipping stone": 463,
466 | "inflating balloons": 464,
467 | "playing harp": 465,
468 | "lock picking": 466,
469 | "skiing crosscountry": 467,
470 | "washing hair": 468,
471 | "cooking on campfire": 469,
472 | "grinding meat": 470,
473 | "preparing salad": 471,
474 | "huddling": 472,
475 | "tobogganing": 473,
476 | "tasting wine": 474,
477 | "bowling": 475,
478 | "tightrope walking": 476,
479 | "home roasting coffee": 477,
480 | "curling hair": 478,
481 | "chiseling wood": 479,
482 | "water skiing": 480,
483 | "calculating": 481,
484 | "tagging graffiti": 482,
485 | "dodgeball": 483,
486 | "tying bow tie": 484,
487 | "square dancing": 485,
488 | "laying tiles": 486,
489 | "tai chi": 487,
490 | "tie dying": 488,
491 | "throwing axe": 489,
492 | "playing ocarina": 490,
493 | "cutting watermelon": 491,
494 | "eating ice cream": 492,
495 | "cleaning pool": 493,
496 | "making pizza": 494,
497 | "playing field hockey": 495,
498 | "shuffling cards": 496,
499 | "singing": 497,
500 | "driving tractor": 498,
501 | "pull ups": 499,
502 | "golf chipping": 500,
503 | "faceplanting": 501,
504 | "shoveling snow": 502,
505 | "sailing": 503,
506 | "playing ukulele": 504,
507 | "needle felting": 505,
508 | "milking cow": 506,
509 | "catching or throwing baseball": 507,
510 | "trimming or shaving beard": 508,
511 | "playing with trains": 509,
512 | "high jump": 510,
513 | "riding mule": 511,
514 | "bobsledding": 512,
515 | "cumbia": 513,
516 | "jumping bicycle": 514,
517 | "riding scooter": 515,
518 | "reading book": 516,
519 | "climbing ladder": 517,
520 | "getting a tattoo": 518,
521 | "mosh pit dancing": 519,
522 | "historical reenactment": 520,
523 | "docking boat": 521,
524 | "bottling": 522,
525 | "swinging on something": 523,
526 | "parasailing": 524,
527 | "doing nails": 525,
528 | "talking on cell phone": 526,
529 | "blowdrying hair": 527,
530 | "clay pottery making": 528,
531 | "eating burger": 529,
532 | "juggling balls": 530,
533 | "doing laundry": 531,
534 | "finger snapping": 532,
535 | "tasting food": 533,
536 | "playing clarinet": 534,
537 | "playing organ": 535,
538 | "decorating the christmas tree": 536,
539 | "grooming dog": 537,
540 | "eating cake": 538,
541 | "sticking tongue out": 539,
542 | "disc golfing": 540,
543 | "kissing": 541,
544 | "cleaning windows": 542,
545 | "making a sandwich": 543,
546 | "staring": 544,
547 | "front raises": 545,
548 | "hurling (sport)": 546,
549 | "leatherworking": 547,
550 | "playing maracas": 548,
551 | "brushing hair": 549,
552 | "playing badminton": 550,
553 | "photocopying": 551,
554 | "canoeing or kayaking": 552,
555 | "polishing metal": 553,
556 | "playing guitar": 554,
557 | "battle rope training": 555,
558 | "egg hunting": 556,
559 | "pirouetting": 557,
560 | "hopscotch": 558,
561 | "karaoke": 559,
562 | "golf driving": 560,
563 | "tossing coin": 561,
564 | "bandaging": 562,
565 | "alligator wrestling": 563,
566 | "gold panning": 564,
567 | "flipping pancake": 565,
568 | "skiing slalom": 566,
569 | "casting fishing line": 567,
570 | "licking": 568,
571 | "bull fighting": 569,
572 | "cracking knuckles": 570,
573 | "headbutting": 571,
574 | "triple jump": 572,
575 | "fidgeting": 573,
576 | "punching person (boxing)": 574,
577 | "push up": 575,
578 | "lawn mower racing": 576,
579 | "riding camel": 577,
580 | "skipping rope": 578,
581 | "sword swallowing": 579,
582 | "feeding fish": 580,
583 | "peeling apples": 581,
584 | "doing aerobics": 582,
585 | "country line dancing": 583,
586 | "driving car": 584,
587 | "breathing fire": 585,
588 | "watering plants": 586,
589 | "popping balloons": 587,
590 | "snowboarding": 588,
591 | "blasting sand": 589,
592 | "pushing wheelchair": 590,
593 | "waxing chest": 591,
594 | "playing ice hockey": 592,
595 | "putting on shoes": 593,
596 | "pumping fist": 594,
597 | "surfing crowd": 595,
598 | "crying": 596,
599 | "unloading truck": 597,
600 | "welding": 598,
601 | "tap dancing": 599
602 | }
--------------------------------------------------------------------------------
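
Note that the two class-map files above use different value types: `k400_classmap.json` stores the index as a string (e.g. `"zumba": "399"`), while `k600_classmap.json` stores plain integers. Below is a minimal sketch, not part of the repo, that loads either file into a uniform name-to-index mapping:

```python
import json

def load_classmap(path):
    # Both files map class name -> index, but k400_classmap.json stores the
    # index as a string; cast to int so both maps behave the same way.
    with open(path) as f:
        raw = json.load(f)
    return {name: int(idx) for name, idx in raw.items()}

classmap = load_classmap('k400_classmap.json')
idx_to_name = {idx: name for name, idx in classmap.items()}  # for decoding predictions
print(len(classmap), idx_to_name[399])  # 400 classes; index 399 is 'zumba'
```
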
/mask_generator.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import numpy as np
4 |
5 | class RandomMaskGenerator:
6 | def __init__(self, input_size=224, mask_ratio=0.6):
7 | if not isinstance(input_size, tuple):
8 | input_size = (input_size,) * 2
9 |
10 | self.height, self.width = input_size
11 |
12 | self.num_patches = self.height * self.width
13 | self.num_mask = int(mask_ratio * self.num_patches)
14 |
15 | def __call__(self):
16 | mask = np.hstack([
17 | np.zeros(self.num_patches - self.num_mask),
18 | np.ones(self.num_mask),
19 | ])
20 | np.random.shuffle(mask)
21 | return mask # flat 0/1 mask of length num_patches
22 |
23 | class CubeMaskGenerator:
24 | def __init__(
25 | self, input_size=(8,14,14), mask_ratio=0.4, min_num_patches=16, max_num_patches=None,
26 | min_aspect=0.3, max_aspect=None):
27 | self.temporal, self.height, self.width = input_size
28 |
29 | self.num_patches = self.height * self.width
30 | self.num_masking_patches = int(self.num_patches * mask_ratio)
31 | self.num_masking_frames = int(self.temporal * mask_ratio)
32 |
33 | self.min_num_patches = min_num_patches # smaller than max_num_patches
34 | self.max_num_patches = self.num_masking_patches if max_num_patches is None else max_num_patches
35 |
36 | max_aspect = max_aspect or 1 / min_aspect
37 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
38 |
39 | def __repr__(self):
40 | repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
41 | self.height, self.width, self.min_num_patches, self.max_num_patches,
42 | self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1])
43 | return repr_str
44 |
45 | def get_shape(self):
46 | return self.temporal, self.height, self.width
47 |
48 | def _mask(self, mask, max_mask_patches):
49 | delta = 0
50 | for attempt in range(10):
51 | target_area = random.uniform(self.min_num_patches, max_mask_patches)
52 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
53 | h = int(round(math.sqrt(target_area * aspect_ratio)))
54 | w = int(round(math.sqrt(target_area / aspect_ratio)))
55 | if w < self.width and h < self.height:
56 | top = random.randint(0, self.height - h)
57 | left = random.randint(0, self.width - w)
58 |
59 | num_masked = mask[top: top + h, left: left + w].sum()
60 | # Overlap
61 | if 0 < h * w - num_masked <= max_mask_patches:
62 | for i in range(top, top + h):
63 | for j in range(left, left + w):
64 | if mask[i, j] == 0:
65 | mask[i, j] = 1
66 | delta += 1
67 |
68 | if delta > 0:
69 | break
70 | return delta
71 |
72 | def __call__(self):
73 | time_marker = np.zeros(shape=self.temporal, dtype=np.int32)
74 | cube_mask = np.zeros(shape=self.get_shape(), dtype=np.int32)
75 | cube_marker = []
76 | temp_mask_count = 0
77 | while temp_mask_count < self.num_masking_frames:
78 | # generate 2D block-wise mask
79 | mask = np.zeros(shape=self.get_shape()[1:], dtype=np.int32)
80 | mask_count = 0
81 | while mask_count < self.num_masking_patches:
82 | max_mask_patches = self.num_masking_patches - mask_count
83 | max_mask_patches = min(max_mask_patches, self.max_num_patches)
84 |
85 | delta = self._mask(mask, max_mask_patches)
86 | if delta == 0:
87 | break
88 | else:
89 | mask_count += delta
90 | # assign to cube mask
91 | start_frame = random.randint(0, self.temporal)
92 | accumulate_frames = random.randint(1, self.num_masking_frames - temp_mask_count)
93 | mask_count = 0
94 | for i in range(start_frame, start_frame+accumulate_frames):
95 | if i > self.temporal-1:
96 | break
97 | if time_marker[i] == 0: # only update frames that are not yet masked
98 | time_marker[i] = 1
99 | cube_mask[i] = mask
100 | mask_count+=1
101 | else: # avoid overlapping the original mask
102 | break
103 | temp_mask_count += mask_count
104 | if mask_count > 0: # record start frame and span so the center frame can be located later
105 | cube_marker.append([start_frame, mask_count])
106 |
107 | return cube_mask, cube_marker
108 |
109 | if __name__ == '__main__':
110 | # Unit test for computing the cube mask and extracting HOG features
111 |
112 | '''
113 | mask_generator = CubeMaskGenerator(input_size=(8,14,14),min_num_patches=16)
114 | mask, cube_marker = mask_generator()
115 | print(mask)
116 | from einops import repeat
117 | mask = repeat(mask, 't h w -> t (h dh) (w dw)', dh=56//14, dw=56//14) # nearest-neighbor resize
118 | print(mask)
119 |
120 | #print(cube_marker)
121 | center_index = np.zeros(8).astype('bool')
122 | for marker in cube_marker:
123 | center_index[marker[0]+marker[1]//2] = 1
124 | mask[~center_index] = 0
125 | print(mask)
126 | #for i in cube_marker:
127 | #print(mask[i].sum()/(56*56))
128 | '''
129 |
130 | #'''
131 | from skimage.feature import hog
132 | from skimage import io
133 | from skimage import data
134 | from einops import rearrange
135 | import torch
136 |
137 | def extract_hog(image):
138 | hog_features_r = hog(image[:,:,0], orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1), block_norm='L2', feature_vector=False)
139 | hog_features_g = hog(image[:,:,1], orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1), block_norm='L2', feature_vector=False)
140 | hog_features_b, hog_image = hog(image[:,:,2], orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1), block_norm='L2', feature_vector=False, visualize=True)
141 | hog_features = np.concatenate([hog_features_r,hog_features_g,hog_features_b], axis=-1)
142 | hog_features = rearrange(hog_features, '(ph dh) (pw dw) ch cw c -> ph pw (dh dw ch cw c)', ph=14, pw=14)
143 | return hog_features
144 |
145 | images = np.zeros((2,224,224,3))
146 | image = data.astronaut()[:224,:224,:] # h w c
147 | image = torch.from_numpy(image).numpy()
148 | images[0] = image
149 | image = io.imread('./test_1.jpg')[:224,:224,:]
150 | images[1] = image
151 | hog_features = np.stack(list(map(extract_hog, images)), axis=0)
152 | print(hog_features.shape, np.min(hog_features), np.max(hog_features))
153 | #io.imsave('./test_img_hog.jpg',hog_image)
154 | #'''
--------------------------------------------------------------------------------
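
A minimal usage sketch for the two generators above (not part of the repo; the grid sizes simply follow the defaults in the class definitions):

```python
from mask_generator import RandomMaskGenerator, CubeMaskGenerator

# RandomMaskGenerator returns a flat 0/1 mask over a patch grid. Passing the
# patch-grid size (14x14 for a 224 image with 16x16 patches) keeps the mask
# length equal to the number of tokens.
rand_gen = RandomMaskGenerator(input_size=14, mask_ratio=0.6)
mask = rand_gen()
assert mask.shape == (14 * 14,) and mask.sum() == rand_gen.num_mask

# CubeMaskGenerator builds a block-wise spatial mask and copies it over a
# random span of frames; cube_marker records (start_frame, span) per block.
cube_gen = CubeMaskGenerator(input_size=(8, 14, 14), mask_ratio=0.4, min_num_patches=16)
cube_mask, cube_marker = cube_gen()
assert cube_mask.shape == (8, 14, 14)
print(cube_marker, cube_mask.sum(axis=(1, 2)))
```
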
/mixup.py:
--------------------------------------------------------------------------------
1 | """ Mixup and Cutmix
2 |
3 | Papers:
4 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
5 |
6 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
7 |
8 | Code Reference:
9 | CutMix: https://github.com/clovaai/CutMix-PyTorch
10 |
11 | Hacked together by / Copyright 2019, Ross Wightman
12 | """
13 | import numpy as np
14 | import torch
15 |
16 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
17 | x = x.long().view(-1, 1)
18 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
19 |
20 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
21 | off_value = smoothing / num_classes
22 | on_value = 1. - smoothing + off_value
23 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
24 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
25 | return y1 * lam + y2 * (1. - lam)
26 |
27 | def rand_bbox(img_shape, lam, margin=0., count=None):
28 | """ Standard CutMix bounding-box
29 | Generates a random square bbox based on lambda value. This impl includes
30 | support for enforcing a border margin as percent of bbox dimensions.
31 |
32 | Args:
33 | img_shape (tuple): Image shape as tuple
34 | lam (float): Cutmix lambda value
35 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
36 | count (int): Number of bbox to generate
37 | """
38 | ratio = np.sqrt(1 - lam)
39 | img_h, img_w = img_shape[-2:]
40 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
41 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
42 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
43 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
44 | yl = np.clip(cy - cut_h // 2, 0, img_h)
45 | yh = np.clip(cy + cut_h // 2, 0, img_h)
46 | xl = np.clip(cx - cut_w // 2, 0, img_w)
47 | xh = np.clip(cx + cut_w // 2, 0, img_w)
48 | return yl, yh, xl, xh
49 |
50 | def cutmix_bbox_and_lam(img_shape, lam, correct_lam=True, count=None):
51 | """ Generate bbox and apply lambda correction.
52 | """
53 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
54 | if correct_lam:
55 | bbox_area = (yu - yl) * (xu - xl)
56 | lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
57 | return (yl, yu, xl, xu), lam
58 |
59 | class Mixup:
60 | """ Mixup/Cutmix that applies different params to each element or whole batch
61 |
62 | Args:
63 | mixup_alpha (float): mixup alpha value, mixup is active if > 0.
64 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
65 | prob (float): probability of applying mixup or cutmix per batch or element
66 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active
67 | mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)
68 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
69 | label_smoothing (float): apply label smoothing to the mixed target tensor
70 | num_classes (int): number of classes for target
71 | """
72 | def __init__(self, mixup_alpha=0.8, cutmix_alpha=1.0, prob=1.0, switch_prob=0.5,
73 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
74 | self.mixup_alpha = mixup_alpha
75 | self.cutmix_alpha = cutmix_alpha
76 | self.mix_prob = prob
77 | self.switch_prob = switch_prob
78 | self.label_smoothing = label_smoothing
79 | self.num_classes = num_classes
80 | self.mode = mode
81 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
82 | self.mixup_enabled = True # set to false to disable mixing (intended to be set by the train loop)
83 |
84 | def _params_per_batch(self):
85 | lam = 1.
86 | use_cutmix = False
87 | if self.mixup_enabled and np.random.rand() < self.mix_prob:
88 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
89 | use_cutmix = np.random.rand() < self.switch_prob
90 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
91 | np.random.beta(self.mixup_alpha, self.mixup_alpha)
92 | elif self.mixup_alpha > 0.:
93 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
94 | elif self.cutmix_alpha > 0.:
95 | use_cutmix = True
96 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
97 | else:
98 | assert False, "One of mixup_alpha > 0. or cutmix_alpha > 0. is required."
99 | lam = float(lam_mix)
100 | return lam, use_cutmix
101 |
102 | def _mix_batch(self, x):
103 | lam, use_cutmix = self._params_per_batch()
104 | if lam == 1.:
105 | return 1.
106 | if use_cutmix:
107 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
108 | x.shape, lam, correct_lam=self.correct_lam)
109 | x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
110 | else:
111 | x_flipped = x.flip(0).mul_(1. - lam)
112 | x.mul_(lam).add_(x_flipped)
113 | return lam
114 |
115 | def __call__(self, x, target):
116 | assert len(x) % 2 == 0, 'Batch size should be even when using this' # supports image [B,C,H,W] and video [B,T,C,H,W] inputs
117 | need_reshape = False
118 | if x.ndim == 5:
119 | need_reshape = True
120 | b,t,c,h,w = x.shape
121 | x = x.view(b,t*c,h,w)
122 | lam = self._mix_batch(x)
123 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
124 | if need_reshape:
125 | x = x.view(b,t,c,h,w)
126 | return x, target
127 |
128 | if __name__ == '__main__':
129 | SEED = 0
130 | torch.random.manual_seed(SEED)
131 | np.random.seed(SEED)
132 | mixupfn = Mixup(num_classes=4)
133 | x = torch.rand(2,2,1,10,10)
134 | label = [0, 1]
135 | print(x, label)
136 | y = torch.from_numpy(np.array(label))
137 | x, y = mixupfn(x, y)
138 | print(x.shape, y.shape)
139 | print(x, y)
--------------------------------------------------------------------------------
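
A minimal sketch of how the `Mixup` wrapper above might be applied to a video batch (tensor sizes are illustrative; `model_trainer.py` pairs it with `timm`'s `SoftTargetCrossEntropy` because the returned labels are soft targets):

```python
import torch
from mixup import Mixup

mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, label_smoothing=0.1, num_classes=400)

# Video batch [B, T, C, H, W]; the batch size must be even because each sample
# is mixed with its counterpart in the flipped batch.
videos = torch.rand(4, 8, 3, 224, 224)
labels = torch.tensor([0, 1, 2, 3])
videos, soft_labels = mixup_fn(videos, labels)  # soft_labels: [4, 400] soft targets

# loss = SoftTargetCrossEntropy()(logits, soft_labels)  # as done in model_trainer.py
```
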
/model_pretrain.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import random
4 | import warnings
5 | import argparse
6 |
7 | import kornia.augmentation as K
8 | import numpy as np
9 | import pytorch_lightning as pl
10 | from pytorch_lightning.plugins import DDPPlugin
11 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
12 | import torch
13 | import torch.utils.data as data
14 |
15 | from data_trainer import KineticsDataModule
16 | from model_trainer import VideoTransformer
17 | import data_transform as T
18 | from utils import print_on_rank_zero
19 |
20 |
21 | def parse_args():
22 | parser = argparse.ArgumentParser(description='video transformer training/pretraining arguments')
23 | # Common
24 | parser.add_argument(
25 | '-epoch', type=int, required=True,
26 | help='the max epochs of training')
27 | parser.add_argument(
28 | '-batch_size', type=int, required=True,
29 | help='the batch size of data inputs')
30 | parser.add_argument(
31 | '-num_workers', type=int, default=4,
32 | help='the number of data-loading workers')
33 | parser.add_argument(
34 | '-resume', default=False, action='store_true')
35 | parser.add_argument(
36 | '-resume_from_checkpoint', type=str, default=None,
37 | help='the checkpoint path to resume training from')
38 | parser.add_argument(
39 | '-log_interval', type=int, default=30,
40 | help='the logging interval in steps')
41 | parser.add_argument(
42 | '-save_ckpt_freq', type=int, default=20,
43 | help='the interval (in epochs) between checkpoint saves')
44 | parser.add_argument(
45 | '-objective', type=str, default='mim',
46 | help='the learning objective from [mim, supervised]')
47 | parser.add_argument(
48 | '-eval_metrics', type=str, default='finetune',
49 | help='the eval metric chosen from [linear_prob, finetune]')
50 |
51 | # Environment
52 | parser.add_argument(
53 | '-gpus', nargs='+', type=int, default=-1,
54 | help='the available gpus in this experiment')
55 | parser.add_argument(
56 | '-root_dir', type=str, required=True,
57 | help='the path to root dir for work space')
58 |
59 | # Data
60 | parser.add_argument(
61 | '-num_class', type=int, required=True,
62 | help='the number of classes of the dataset used')
63 | parser.add_argument(
64 | '-num_samples_per_cls', type=int, default=10000,
65 | help='the number of samples per class')
66 | parser.add_argument(
67 | '-img_size', type=int, default=224,
68 | help='the size of processed image')
69 | parser.add_argument(
70 | '-num_frames', type=int, required=True,
71 | help='the number of sampled frames')
72 | parser.add_argument(
73 | '-frame_interval', type=int, required=True,
74 | help='the interval of frame sampling')
75 | parser.add_argument(
76 | '-data_statics', type=str, default='kinetics',
77 | help='choose the normalization statistics from [imagenet, kinetics]')
78 | parser.add_argument(
79 | '-train_data_path', type=str, required=True,
80 | help='the path to train set')
81 | parser.add_argument(
82 | '-val_data_path', type=str, default=None,
83 | help='the path to val set')
84 | parser.add_argument(
85 | '-test_data_path', type=str, default=None,
86 | help='the path to test set')
87 | parser.add_argument(
88 | '-multi_crop', type=bool, default=False,
89 | help="""Whether or not to use multi crop.""")
90 | parser.add_argument(
91 | '-mixup', type=bool, default=False,
92 | help="""Whether or not to use multi crop.""")
93 | parser.add_argument(
94 | '-auto_augment', type=str, default=None,
95 | help='the AutoAugment policy to use')
96 |
97 | # Model
98 | parser.add_argument(
99 | '-arch', type=str, default='timesformer',
100 | help='the chosen model arch from [timesformer, vivit, mvit]')
101 | parser.add_argument(
102 | '-attention_type', type=str, default='divided_space_time',
103 | help='the chosen attention type used in the model')
104 | parser.add_argument(
105 | '-pretrain_pth', type=str, default=None,
106 | help='the path to the pretrain weights')
107 | parser.add_argument(
108 | '-weights_from', type=str, default='imagenet',
109 | help='the source of the pretrained weights, from [imagenet, kinetics]')
110 |
111 | # Training/Optimization parameters
112 | parser.add_argument(
113 | '-seed', type=int, default=0,
114 | help='the random seed of the experiment')
115 | parser.add_argument(
116 | '-optim_type', type=str, default='adamw',
117 | help='the optimizer used in training')
118 | parser.add_argument(
119 | '-lr_schedule', type=str, default='cosine',
120 | help='the lr schedule used in training')
121 | parser.add_argument(
122 | '-lr', type=float, required=True,
123 | help='the initial learning rate')
124 | parser.add_argument(
125 | '-layer_decay', type=float, default=0.75,
126 | help='the value of layer_decay')
127 | parser.add_argument(
128 | '--min_lr', type=float, default=1e-6,
129 | help="""Target LR at the end of optimization. We use a cosine LR schedule with linear warmup.""")
130 | parser.add_argument(
131 | '-use_fp16', type=bool, default=True,
132 | help="""Whether or not to use half precision for training. Improves training time and memory requirements,
133 | but can provoke instability and slight decay of performance. We recommend disabling
134 | mixed precision if the loss is unstable, if reducing the patch size or if training with bigger ViTs.""")
135 | parser.add_argument(
136 | '-weight_decay', type=float, default=0.05,
137 | help="""Initial value of the weight decay. With ViT, a smaller value at the beginning of training works well.""")
138 | parser.add_argument(
139 | '-weight_decay_end', type=float, default=0.05,
140 | help="""Final value of the weight decay. We use a cosine schedule for WD and using a larger decay by
141 | the end of training improves performance for ViTs.""")
142 | parser.add_argument(
143 | '-clip_grad', type=float, default=0,
144 | help="""Maximal parameter gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can
145 | help optimization for larger ViT architectures. 0 for disabling.""")
146 | parser.add_argument(
147 | "-warmup_epochs", default=5, type=int,
148 | help="Number of epochs for the linear learning-rate warm up.")
149 |
150 | args = parser.parse_args()
151 |
152 | return args
153 |
154 | def single_run():
155 | args = parse_args()
156 | warnings.filterwarnings('ignore')
157 |
158 | # linear learning rate scaling (reference batch size of 256)
159 | if isinstance(args.gpus, int):
160 | num_gpus = torch.cuda.device_count()
161 | else:
162 | num_gpus = len(args.gpus)
163 | effective_batch_size = args.batch_size * num_gpus
164 | args.lr = args.lr * effective_batch_size / 256
165 |
166 | # Experiment Settings
167 | ROOT_DIR = args.root_dir
168 | exp_tag = (f'objective_{args.objective}_arch_{args.arch}_lr_{args.lr}_'
169 | f'optim_{args.optim_type}_lr_schedule_{args.lr_schedule}_'
170 | f'fp16_{args.use_fp16}_weight_decay_{args.weight_decay}_'
171 | f'weight_decay_end_{args.weight_decay_end}_warmup_epochs_{args.warmup_epochs}_'
172 | f'pretrain_{args.pretrain_pth}_weights_from_{args.weights_from}_seed_{args.seed}_'
173 | f'img_size_{args.img_size}_num_frames_{args.num_frames}_eval_metrics_{args.eval_metrics}_'
174 | f'frame_interval_{args.frame_interval}_mixup_{args.mixup}_'
175 | f'multi_crop_{args.multi_crop}_auto_augment_{args.auto_augment}_')
176 | ckpt_dir = os.path.join(ROOT_DIR, f'results/{exp_tag}/ckpt')
177 | log_dir = os.path.join(ROOT_DIR, f'results/{exp_tag}/log')
178 | os.makedirs(ckpt_dir, exist_ok=True)
179 | os.makedirs(log_dir, exist_ok=True)
180 |
181 | # Data
182 | do_eval = args.val_data_path is not None
183 | do_test = args.test_data_path is not None
184 |
185 | data_module = KineticsDataModule(configs=args,
186 | train_ann_path=args.train_data_path,
187 | val_ann_path=args.val_data_path,
188 | test_ann_path=args.test_data_path)
189 |
190 | # Resume from the last checkpoint
191 | if args.resume and not args.resume_from_checkpoint:
192 | args.resume_from_checkpoint = os.path.join(ckpt_dir, 'last_checkpoint.pth')
193 |
194 | # Trainer
195 | if args.arch == 'mvit' and args.objective == 'supervised':
196 | find_unused_parameters = True
197 | else:
198 | find_unused_parameters = False
199 |
200 | trainer = pl.Trainer(
201 | gpus=args.gpus,
202 | accelerator="ddp",
203 | precision=16,
204 | plugins=[DDPPlugin(find_unused_parameters=find_unused_parameters),],
205 | max_epochs=args.epoch,
206 | callbacks=[
207 | LearningRateMonitor(logging_interval='step'),
208 | ],
209 | resume_from_checkpoint=args.resume_from_checkpoint,
210 | check_val_every_n_epoch=1,
211 | log_every_n_steps=args.log_interval,
212 | progress_bar_refresh_rate=args.log_interval,
213 | flush_logs_every_n_steps=args.log_interval*5)
214 |
215 | # For reproducibility
216 | torch.random.manual_seed(args.seed)
217 | np.random.seed(args.seed)
218 | random.seed(args.seed)
219 | pl.seed_everything(args.seed, workers=True)
220 |
221 | # Model
222 | model = VideoTransformer(configs=args,
223 | trainer=trainer,
224 | ckpt_dir=ckpt_dir,
225 | do_eval=do_eval,
226 | do_test=do_test)
227 | print_on_rank_zero(args)
228 | timestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
229 | print_on_rank_zero(f'{timestamp} - INFO - Start running,')
230 | trainer.fit(model, data_module)
231 |
232 | if __name__ == '__main__':
233 | single_run()
--------------------------------------------------------------------------------
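
For reference, `single_run` rescales the requested learning rate linearly with the effective batch size, treating 256 as the reference batch. A small worked example, with illustrative values for the batch size and GPU count:

```python
# lr <- lr * (batch_size * num_gpus) / 256, as done in single_run()
base_lr = 5e-3     # value passed via -lr (illustrative)
batch_size = 8     # per-GPU batch size passed via -batch_size (illustrative)
num_gpus = 8       # e.g. len(args.gpus) when an explicit GPU list is given

effective_batch_size = batch_size * num_gpus       # 64
scaled_lr = base_lr * effective_batch_size / 256   # 5e-3 * 64 / 256 = 1.25e-3
print(scaled_lr)
```
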
/model_trainer.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import math
3 | import time
4 |
5 | import pytorch_lightning as pl
6 | import torch
7 | import torch.nn.functional as F
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import torchvision
11 | from torchmetrics import Accuracy
12 | from timm.loss import SoftTargetCrossEntropy
13 |
14 | import utils
15 | from mixup import Mixup
16 | from optimizer import build_optimizer
17 | from transformer import ClassificationHead
18 | from video_transformer import TimeSformer, ViViT, MaskFeat
19 |
20 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, base_lr, objective, min_lr=5e-5, last_epoch=-1):
21 | """ Create a schedule with a learning rate that decreases following the
22 | values of the cosine function between 0 and pi, after a warmup
23 | period during which it increases linearly from 0 to the base learning rate (floored at min_lr for the supervised objective).
24 | """
25 | # one scheduler step corresponds to one epoch here
26 | def lr_lambda(current_step):
27 | current_step += 1
28 | if current_step <= num_warmup_steps:
29 | return float(current_step) / float(max(1, num_warmup_steps)) # * base_lr
30 | progress = min(float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)), 1)
31 | if objective == 'mim':
32 | return 0.5 * (1. + math.cos(math.pi * progress))
33 | else:
34 | factor = 0.5 * (1. + math.cos(math.pi * progress))
35 | return factor*(1 - min_lr/base_lr) + min_lr/base_lr
36 |
37 | return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
38 |
39 | class VideoTransformer(pl.LightningModule):
40 |
41 | def __init__(self,
42 | configs,
43 | trainer,
44 | ckpt_dir,
45 | do_eval,
46 | do_test,
47 | n_crops=3):
48 | super().__init__()
49 | self.configs = configs
50 | self.trainer = trainer
51 |
52 | # build models
53 | if self.configs.objective =='mim':
54 | self.model = MaskFeat(pool_q_stride_size=[[1, 1, 2, 2], [3, 1, 2, 2]], feature_dim=2*2*2*3*9)
55 | else: # supervised
56 | # load pretrain weights from pretrained weight path and model.init_weights method
57 | if self.configs.arch == 'vivit':
58 | self.model = ViViT(
59 | pretrain_pth=self.configs.pretrain_pth,
60 | weights_from=self.configs.weights_from,
61 | img_size=self.configs.img_size,
62 | num_frames=self.configs.num_frames,
63 | attention_type=self.configs.attention_type)
64 | elif self.configs.arch == 'timesformer':
65 | self.model = TimeSformer(
66 | pretrain_pth=self.configs.pretrain_pth,
67 | weights_from=self.configs.weights_from,
68 | img_size=self.configs.img_size,
69 | num_frames=self.configs.num_frames,
70 | attention_type=self.configs.attention_type)
71 | else: # arch-mvit
72 | self.model = MaskFeat(
73 | pool_q_stride_size=[[1, 1, 2, 2], [3, 1, 2, 2]],
74 | feature_dim=2*2*2*3*9,
75 | pretrain_pth=self.configs.pretrain_pth,
76 | img_size=self.configs.img_size,
77 | num_frames=self.configs.num_frames)
78 | for name, param in self.model.decoder_pred.named_parameters():
79 | param.requires_grad = False
80 |
81 | self.cls_head = ClassificationHead(
82 | self.configs.num_class, self.model.embed_dims, eval_metrics=self.configs.eval_metrics)
83 |
84 | self.max_top1_acc = 0
85 | self.train_top1_acc = Accuracy()
86 | self.train_top5_acc = Accuracy(top_k=5)
87 | if self.configs.mixup:
88 | self.mixup_fn = Mixup(num_classes=self.configs.num_class)
89 | self.loss_fn = SoftTargetCrossEntropy()
90 | else:
91 | self.loss_fn = nn.CrossEntropyLoss()
92 |
93 | # common
94 | self.iteration = 0
95 | self.data_start = 0
96 | self.ckpt_dir = ckpt_dir
97 | self.do_eval = do_eval
98 | self.do_test = do_test
99 | if self.do_eval:
100 | self.val_top1_acc = Accuracy()
101 | self.val_top5_acc = Accuracy(top_k=5)
102 | if self.do_test:
103 | self.n_crops = n_crops
104 | self.test_top1_acc = Accuracy()
105 | self.test_top5_acc = Accuracy(top_k=5)
106 |
107 | @torch.jit.ignore
108 | def no_weight_decay_keywords(self):
109 | return {'pos_embed', 'cls_token', 'mask_token'}
110 |
111 | def configure_optimizers(self):
112 | # build optimizer
113 | is_pretrain = not (self.configs.objective == 'supervised')
114 | if self.configs.objective == 'supervised' and self.configs.eval_metrics == 'linear_prob':
115 | model = self.cls_head.module if hasattr(self.cls_head, 'module') else self.cls_head
116 | optimizer = build_optimizer(self.configs, model, is_pretrain=is_pretrain)
117 | else:
118 | optimizer = build_optimizer(self.configs, self, is_pretrain=is_pretrain)
119 |
120 | # lr schedule
121 | lr_scheduler = None
122 | lr_schedule = self.configs.lr_schedule
123 | if lr_schedule == 'multistep':
124 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
125 | milestones=[5, 11],
126 | gamma=0.1)
127 | elif lr_schedule == 'cosine':
128 | lr_scheduler = get_cosine_schedule_with_warmup(optimizer,
129 | num_warmup_steps=self.configs.warmup_epochs,
130 | num_training_steps=self.trainer.max_epochs,
131 | base_lr=self.configs.lr,
132 | min_lr=self.configs.min_lr,
133 | objective=self.configs.objective)
134 | return [optimizer], [lr_scheduler]
135 |
136 | def parse_batch(self, batch, train):
137 | if self.configs.objective == 'mim':
138 | inputs, labels, mask, cube_marker = batch
139 | return inputs, labels, mask, cube_marker
140 | else:
141 | inputs, labels = batch
142 | if self.configs.mixup and train:
143 | inputs, labels = self.mixup_fn(inputs, labels)
144 | return inputs, labels
145 |
146 | # epoch schedule
147 | def _get_momentum(self, base_value, final_value):
148 | return final_value - (final_value - base_value) * (math.cos(math.pi * self.trainer.current_epoch / self.trainer.max_epochs) + 1) / 2
149 |
150 | def _weight_decay_update(self):
151 | for i, param_group in enumerate(self.optimizers().optimizer.param_groups):
152 | if i == 1: # only the regularized group (the one with weight decay) is updated
153 | param_group["weight_decay"] = self._get_momentum(base_value=self.configs.weight_decay, final_value=self.configs.weight_decay_end)
154 |
155 | def clip_gradients(self, clip_grad, norm_type=2):
156 | layer_norm = []
157 | if self.configs.objective == 'supervised' and self.configs.eval_metrics == 'linear_prob':
158 | model_wo_ddp = self.cls_head.module if hasattr(self.cls_head, 'module') else self.cls_head
159 | else:
160 | model_wo_ddp = self.module if hasattr(self, 'module') else self
161 | for name, p in model_wo_ddp.named_parameters():
162 | if p.grad is not None:
163 | param_norm = torch.norm(p.grad.detach(), norm_type)
164 | layer_norm.append(param_norm)
165 | if clip_grad:
166 | clip_coef = clip_grad / (param_norm + 1e-6)
167 | if clip_coef < 1:
168 | p.grad.data.mul_(clip_coef)
169 | total_grad_norm = torch.norm(torch.stack(layer_norm), norm_type)
170 | return total_grad_norm
171 |
172 | def log_step_state(self, data_time, top1_acc=0, top5_acc=0):
173 | self.log("time",float(f'{time.perf_counter()-self.data_start:.3f}'),prog_bar=True)
174 | self.log("data_time", data_time, prog_bar=True)
175 | if self.configs.objective == 'supervised':
176 | self.log("top1_acc",top1_acc,on_step=True,on_epoch=False,prog_bar=True)
177 | self.log("top5_acc",top5_acc,on_step=True,on_epoch=False,prog_bar=True)
178 |
179 | return None
180 |
181 | def get_progress_bar_dict(self):
182 | # don't show the version number
183 | items = super().get_progress_bar_dict()
184 | items.pop("v_num", None)
185 |
186 | return items
187 |
188 | # Trainer Pipeline
189 | def training_step(self, batch, batch_idx):
190 | data_time = float(f'{time.perf_counter() - self.data_start:.3f}')
191 | if self.configs.objective == 'mim':
192 | inputs, labels, mask, cube_marker = self.parse_batch(batch, train=True)
193 | preds, loss = self.model(inputs, labels, mask, cube_marker)
194 | self.log_step_state(data_time)
195 | return {'loss': loss, 'data_time': data_time}
196 | else:
197 | inputs, labels = self.parse_batch(batch, train=True)
198 | if self.configs.eval_metrics == 'linear_prob':
199 | with torch.no_grad():
200 | self.model.eval()
201 | preds = self.model(inputs)
202 | else:
203 | if self.configs.arch == 'mvit':
204 | preds = self.model.forward_features(inputs)[:, 0]
205 | else:
206 | preds = self.model(inputs)
207 | preds = self.cls_head(preds)
208 | loss = self.loss_fn(preds, labels)
209 | if self.configs.mixup:
210 | top1_acc = self.train_top1_acc(preds.softmax(dim=-1), labels.argmax(-1))
211 | top5_acc = self.train_top5_acc(preds.softmax(dim=-1), labels.argmax(-1))
212 | else:
213 | top1_acc = self.train_top1_acc(preds.softmax(dim=-1), labels)
214 | top5_acc = self.train_top5_acc(preds.softmax(dim=-1), labels)
215 | self.log_step_state(data_time, top1_acc, top5_acc)
216 | return {'loss': loss, 'data_time': data_time}
217 |
218 | def on_after_backward(self):
219 | param_norms = self.clip_gradients(self.configs.clip_grad)
220 | self._weight_decay_update()
221 | # log learning dynamics
222 | lr = self.optimizers().optimizer.param_groups[0]['lr']
223 | self.log("lr",lr,on_step=True,on_epoch=False,prog_bar=True)
224 | self.log("grad_norm",param_norms,on_step=True,on_epoch=False,prog_bar=True)
225 |
226 | def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
227 | optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
228 |
229 | optimizer.step(closure=optimizer_closure)
230 | self.data_start = time.perf_counter()
231 | self.iteration += 1
232 |
233 | def training_epoch_end(self, outputs):
234 | timestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
235 | if self.configs.objective == 'supervised':
236 | mean_top1_acc = self.train_top1_acc.compute()
237 | mean_top5_acc = self.train_top5_acc.compute()
238 | self.print(f'{timestamp} - Evaluating mean ',
239 | f'top1_acc:{mean_top1_acc:.3f},',
240 | f'top5_acc:{mean_top5_acc:.3f} of current training epoch')
241 | self.train_top1_acc.reset()
242 | self.train_top5_acc.reset()
243 |
244 | # save last checkpoint
245 | save_path = osp.join(self.ckpt_dir, 'last_checkpoint.pth')
246 | self.trainer.save_checkpoint(save_path)
247 |
248 | if self.configs.objective != 'supervised' and (self.trainer.current_epoch+1) % self.configs.save_ckpt_freq == 0:
249 | save_path = osp.join(self.ckpt_dir,
250 | f'{timestamp}_'+
251 | f'ep_{self.trainer.current_epoch}.pth')
252 | self.trainer.save_checkpoint(save_path)
253 |
254 | def validation_step(self, batch, batch_indx):
255 | if self.do_eval:
256 | inputs, labels = self.parse_batch(batch, train=False)
257 | if self.configs.eval_metrics == 'linear_prob':
258 | with torch.no_grad():
259 | preds = self.model(inputs)
260 | else:
261 | if self.configs.arch == 'mvit':
262 | preds = self.model.forward_features(inputs)[:, 0]
263 | else:
264 | preds = self.model(inputs)
265 | preds = self.cls_head(preds)
266 |
267 | self.val_top1_acc(preds.softmax(dim=-1), labels)
268 | self.val_top5_acc(preds.softmax(dim=-1), labels)
269 | self.data_start = time.perf_counter()
270 |
271 | def validation_epoch_end(self, outputs):
272 | if self.do_eval:
273 | mean_top1_acc = self.val_top1_acc.compute()
274 | mean_top5_acc = self.val_top5_acc.compute()
275 | timestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
276 | self.print(f'{timestamp} - Evaluating mean ',
277 | f'top1_acc:{mean_top1_acc:.3f}, ',
278 | f'top5_acc:{mean_top5_acc:.3f} of current validation epoch')
279 | self.val_top1_acc.reset()
280 | self.val_top5_acc.reset()
281 |
282 | # save best checkpoint
283 | if mean_top1_acc > self.max_top1_acc:
284 | save_path = osp.join(self.ckpt_dir,
285 | f'{timestamp}_'+
286 | f'ep_{self.trainer.current_epoch}_'+
287 | f'top1_acc_{mean_top1_acc:.3f}.pth')
288 | self.trainer.save_checkpoint(save_path)
289 | self.max_top1_acc = mean_top1_acc
290 |
291 | def test_step(self, batch, batch_idx):
292 | if self.do_test:
293 | inputs, labels = self.parse_batch(batch, train=False)
294 | preds = self.cls_head(self.model(inputs))
295 | preds = preds.view(-1, self.n_crops, self.configs.num_class).mean(1)
296 |
297 | self.test_top1_acc(preds.softmax(dim=-1), labels)
298 | self.test_top5_acc(preds.softmax(dim=-1), labels)
299 | self.data_start = time.perf_counter()
300 |
301 | def test_epoch_end(self, outputs):
302 | if self.do_test:
303 | mean_top1_acc = self.test_top1_acc.compute()
304 | mean_top5_acc = self.test_top5_acc.compute()
305 | timestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
306 | self.print(f'{timestamp} - Evaluating mean ',
307 | f'top1_acc:{mean_top1_acc:.3f}, ',
308 | f'top5_acc:{mean_top5_acc:.3f} of current test epoch')
309 | self.test_top1_acc.reset()
310 | self.test_top5_acc.reset()
311 |
--------------------------------------------------------------------------------
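
A small sketch, not part of the repo, showing the learning-rate curve produced by `get_cosine_schedule_with_warmup` above (the dummy optimizer and epoch counts are illustrative; in `model_trainer.py` one scheduler step corresponds to one epoch):

```python
import torch
from model_trainer import get_cosine_schedule_with_warmup

# Dummy parameter/optimizer just to drive the scheduler.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-3)

scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=5, num_training_steps=30,
    base_lr=1e-3, objective='supervised', min_lr=1e-5)

lrs = []
for epoch in range(30):
    lrs.append(optimizer.param_groups[0]['lr'])
    optimizer.step()
    scheduler.step()

# Linear warmup over the first 5 epochs, then cosine decay towards min_lr.
print([f'{lr:.2e}' for lr in lrs])
```
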
/notebook/VideoTransformer_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# VideoTransformer\n",
8 | "\n",
9 | "TimeSformer(https://arxiv.org/abs/2102.05095), ViViT(https://arxiv.org/abs/2103.15691)\n",
10 | "\n",
11 | "Welcome to the demo notebook for VideoTransformer. We'll showcase the prediction result by the above pre-trained models."
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 | "## Preliminaries\n",
19 | "\n",
20 | "This section contains initial setup. Run it first."
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "metadata": {
27 | "colab": {
28 | "base_uri": "https://localhost:8080/"
29 | },
30 | "id": "Gv1qdarnNNVE",
31 | "outputId": "eaac44da-a976-4fda-b323-03dfa6790e21",
32 | "pycharm": {}
33 | },
34 | "outputs": [],
35 | "source": [
36 | "!pip3 install --user torch\n",
37 | "!pip3 install --user torchvision\n",
38 | "!pip3 install --user matplotlib\n",
39 | "!pip3 install --user decord\n",
40 | "!pip3 install --user einops\n",
41 | "!pip3 install --user scikit-image\n",
42 | "!pip3 install --user pytorch-lightning"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "metadata": {
49 | "colab": {
50 | "base_uri": "https://localhost:8080/"
51 | },
52 | "id": "tBBG_T32pzPH",
53 | "outputId": "76921b65-9e35-4260-fa0a-bbb484ac285c",
54 | "pycharm": {}
55 | },
56 | "outputs": [],
57 | "source": [
58 | "import torch\n",
59 | "import torch.nn as nn\n",
60 | "import numpy as np\n",
61 | "\n",
62 | "from einops import rearrange, reduce, repeat\n",
63 | "from IPython.display import display\n",
64 | "\n",
65 | "!git clone https://github.com/mx-mark/VideoTransformer-pytorch.git\n",
66 | "%cd VideoTransformer-pytorch\n",
67 | "\n",
68 | "import data_transform as T\n",
69 | "from dataset import DecordInit, load_annotation_data\n",
70 | "from transformer import PatchEmbed, TransformerContainer, ClassificationHead"
71 | ]
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "metadata": {},
76 | "source": [
77 | "### Note\n",
78 | "Please firstly dowload the weights and move to the current path `./VideoTransformer-pytorch/`\n",
79 | "1. TimeSformer-B pre-trained on K400 https://drive.google.com/file/d/1jLkS24jkpmakPi3e5J8KH3FOPv370zvo/view?usp=sharing\n",
80 | "2. ViViT-B pre-trained on K400 from https://drive.google.com/file/d/1-JVhSN3QHKUOLkXLWXWn5drdvKn0gPll/view?usp=sharing"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "metadata": {
86 | "id": "0yLko87R2_m4",
87 | "pycharm": {}
88 | },
89 | "source": [
90 | "## Video Transformer Model\n",
91 | "\n",
92 | "We here load the pretrained weights of the transformer model TimeSformer-B or ViViT-B."
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 19,
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "class TimeSformer(nn.Module):\n",
102 | " \"\"\"TimeSformer. A PyTorch impl of `Is Space-Time Attention All You Need for\n",
103 | " Video Understanding? `_\n",
104 | "\n",
105 | " Args:\n",
106 | " num_frames (int): Number of frames in the video.\n",
107 | " img_size (int | tuple): Size of input image.\n",
108 | " patch_size (int): Size of one patch.\n",
109 | " pretrained (str | None): Name of pretrained model. Default: None.\n",
110 | " embed_dims (int): Dimensions of embedding. Defaults to 768.\n",
111 | " num_heads (int): Number of parallel attention heads in\n",
112 | " TransformerCoder. Defaults to 12.\n",
113 | " num_transformer_layers (int): Number of transformer layers. Defaults to\n",
114 | " 12.\n",
115 | " in_channels (int): Channel num of input features. Defaults to 3.\n",
116 | " dropout_p (float): Probability of dropout layer. Defaults to 0.\n",
117 | " conv_type (str): Type of the convolution in PatchEmbed layer. Defaults to Conv2d.\n",
118 | " attention_type (str): Type of attentions in TransformerCoder. Choices\n",
119 | " are 'divided_space_time', 'space_only' and 'joint_space_time'.\n",
120 | " Defaults to 'divided_space_time'.\n",
121 | " norm_layer (dict): Config for norm layers. Defaults to nn.LayerNorm.\n",
122 | " return_cls_token (bool): Whether to use cls_token to predict class label.\n",
123 | " \"\"\"\n",
124 | " supported_attention_types = [\n",
125 | " 'divided_space_time', 'space_only', 'joint_space_time'\n",
126 | " ]\n",
127 | "\n",
128 | " def __init__(self,\n",
129 | " num_frames,\n",
130 | " img_size,\n",
131 | " patch_size,\n",
132 | " embed_dims=768,\n",
133 | " num_heads=12,\n",
134 | " num_transformer_layers=12,\n",
135 | " in_channels=3,\n",
136 | " conv_type='Conv2d',\n",
137 | " dropout_p=0.,\n",
138 | " attention_type='divided_space_time',\n",
139 | " norm_layer=nn.LayerNorm,\n",
140 | " return_cls_token=True,\n",
141 | " **kwargs):\n",
142 | " super().__init__()\n",
143 | " assert attention_type in self.supported_attention_types, (\n",
144 | " f'Unsupported Attention Type {attention_type}!')\n",
145 | "\n",
146 | " self.num_frames = num_frames\n",
147 | " self.embed_dims = embed_dims\n",
148 | " self.num_transformer_layers = num_transformer_layers\n",
149 | " self.attention_type = attention_type\n",
150 | " self.conv_type = conv_type\n",
151 | " self.return_cls_token = return_cls_token\n",
152 | "\n",
153 | " #tokenize & position embedding\n",
154 | " self.patch_embed = PatchEmbed(\n",
155 | " img_size=img_size,\n",
156 | " patch_size=patch_size,\n",
157 | " in_channels=in_channels,\n",
158 | " embed_dims=embed_dims,\n",
159 | " conv_type=conv_type)\n",
160 | " num_patches = self.patch_embed.num_patches\n",
161 | "\n",
162 | " # Divided Space Time Attention\n",
163 | " operator_order = ['time_attn','space_attn','ffn']\n",
164 | " container = TransformerContainer(\n",
165 | " num_transformer_layers=num_transformer_layers,\n",
166 | " embed_dims=embed_dims,\n",
167 | " num_heads=num_heads,\n",
168 | " num_frames=num_frames,\n",
169 | " norm_layer=norm_layer,\n",
170 | " hidden_channels=embed_dims*4,\n",
171 | " operator_order=operator_order)\n",
172 | "\n",
173 | " self.transformer_layers = container\n",
174 | " self.norm = norm_layer(embed_dims, eps=1e-6)\n",
175 | "\n",
176 | " self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dims))\n",
177 | " num_patches = num_patches + 1\n",
178 | "\n",
179 | " # spatial pos_emb\n",
180 | " self.pos_embed = nn.Parameter(torch.zeros(1,num_patches,embed_dims))\n",
181 | " self.drop_after_pos = nn.Dropout(p=dropout_p)\n",
182 | "\n",
183 | " # temporal pos_emb\n",
184 | " self.time_embed = nn.Parameter(torch.zeros(1,num_frames,embed_dims))\n",
185 | " self.drop_after_time = nn.Dropout(p=dropout_p)\n",
186 | "\n",
187 | " def interpolate_pos_encoding(self, x, w, h):\n",
188 | " npatch = x.shape[1] - 1\n",
189 | " N = self.pos_embed.shape[1] - 1\n",
190 | " if npatch == N and w == h:\n",
191 | " return self.pos_embed\n",
192 | " class_pos_embed = self.pos_embed[:, 0]\n",
193 | " patch_pos_embed = self.pos_embed[:, 1:]\n",
194 | " dim = x.shape[-1]\n",
195 | " w0 = w // self.patch_embed.patch_size[0]\n",
196 | " h0 = h // self.patch_embed.patch_size[0]\n",
197 | " # we add a small number to avoid floating point error in the interpolation\n",
198 | " # see discussion at https://github.com/facebookresearch/dino/issues/8\n",
199 | " w0, h0 = w0 + 0.1, h0 + 0.1\n",
200 | " patch_pos_embed = nn.functional.interpolate(\n",
201 | " patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),\n",
202 | " scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),\n",
203 | " mode='bicubic',\n",
204 | " )\n",
205 | " assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]\n",
206 | " patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)\n",
207 | " return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)\n",
208 | "\n",
209 | " def forward(self, x):\n",
210 | " #Tokenize\n",
211 | " b, t, c, h, w = x.shape\n",
212 | " x = self.patch_embed(x)\n",
213 | "\n",
214 | " # Add Position Embedding\n",
215 | " cls_tokens = repeat(self.cls_token, 'b ... -> (repeat b) ...', repeat=x.shape[0])\n",
216 | " x = torch.cat((cls_tokens, x), dim=1)\n",
217 | " x = x + self.interpolate_pos_encoding(x, w, h) #self.pos_embed\n",
218 | " x = self.drop_after_pos(x)\n",
219 | "\n",
220 | " # Add Time Embedding\n",
221 | " cls_tokens = x[:b, 0, :].unsqueeze(1)\n",
222 | " x = rearrange(x[:, 1:, :], '(b t) p d -> (b p) t d', b=b)\n",
223 | " x = x + self.time_embed\n",
224 | " x = rearrange(x, '(b p) t d -> b (p t) d', b=b)\n",
225 | " x = torch.cat((cls_tokens, x), dim=1)\n",
226 | " x = self.drop_after_time(x)\n",
227 | "\n",
228 | " # Video transformer forward\n",
229 | " x = self.transformer_layers(x)\n",
230 | "\n",
231 | " x = self.norm(x)\n",
232 | " # Return Class Token\n",
233 | " if self.return_cls_token:\n",
234 | " return x[:, 0]\n",
235 | " else:\n",
236 | " return x[:, 1:].mean(1)"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": 21,
242 | "metadata": {},
243 | "outputs": [],
244 | "source": [
245 | "class ViViT(nn.Module):\n",
246 | " \"\"\"ViViT. A PyTorch impl of `ViViT: A Video Vision Transformer`\n",
247 | " \n",
248 | "\n",
249 | " Args:\n",
250 | " num_frames (int): Number of frames in the video.\n",
251 | " img_size (int | tuple): Size of input image.\n",
252 | " patch_size (int): Size of one patch.\n",
253 | " pretrained (str | None): Name of pretrained model. Default: None.\n",
254 | " embed_dims (int): Dimensions of embedding. Defaults to 768.\n",
255 | " num_heads (int): Number of parallel attention heads. Defaults to 12.\n",
256 | " num_transformer_layers (int): Number of transformer layers. Defaults to 12.\n",
257 | " in_channels (int): Channel num of input features. Defaults to 3.\n",
258 | " dropout_p (float): Probability of dropout layer. Defaults to 0..\n",
259 | " tube_size (int): Dimension of the kernel size in Conv3d. Defaults to 2.\n",
260 | " conv_type (str): Type of the convolution in PatchEmbed layer. Defaults to Conv3d.\n",
261 | " attention_type (str): Type of attentions in TransformerCoder. Choices\n",
262 | " are 'divided_space_time', 'fact_encoder' and 'joint_space_time'.\n",
263 | " Defaults to 'fact_encoder'.\n",
264 | " norm_layer (dict): Config for norm layers. Defaults to nn.LayerNorm.\n",
265 | " copy_strategy (str): Copy or Initial to zero towards the new additional layer.\n",
266 | " extend_strategy (str): How to initialize the weights of Conv3d from pre-trained Conv2d.\n",
267 | " use_learnable_pos_emb (bool): Whether to use learnable position embeddings.\n",
268 | " return_cls_token (bool): Whether to use cls_token to predict class label.\n",
269 | " \"\"\"\n",
270 | " supported_attention_types = [\n",
271 | " 'fact_encoder', 'joint_space_time', 'divided_space_time'\n",
272 | " ]\n",
273 | "\n",
274 | " def __init__(self,\n",
275 | " num_frames,\n",
276 | " img_size,\n",
277 | " patch_size,\n",
278 | " embed_dims=768,\n",
279 | " num_heads=12,\n",
280 | " num_transformer_layers=12,\n",
281 | " in_channels=3,\n",
282 | " dropout_p=0.,\n",
283 | " tube_size=2,\n",
284 | " conv_type='Conv3d',\n",
285 | " attention_type='fact_encoder',\n",
286 | " norm_layer=nn.LayerNorm,\n",
287 | " return_cls_token=True,\n",
288 | " **kwargs):\n",
289 | " super().__init__()\n",
290 | " assert attention_type in self.supported_attention_types, (\n",
291 | " f'Unsupported Attention Type {attention_type}!')\n",
292 | "\n",
293 | " num_frames = num_frames//tube_size\n",
294 | " self.num_frames = num_frames\n",
295 | " self.embed_dims = embed_dims\n",
296 | " self.num_transformer_layers = num_transformer_layers\n",
297 | " self.attention_type = attention_type\n",
298 | " self.conv_type = conv_type\n",
299 | " self.tube_size = tube_size\n",
300 | " self.num_time_transformer_layers = 4\n",
301 | " self.return_cls_token = return_cls_token\n",
302 | "\n",
303 | " #tokenize & position embedding\n",
304 | " self.patch_embed = PatchEmbed(\n",
305 | " img_size=img_size,\n",
306 | " patch_size=patch_size,\n",
307 | " in_channels=in_channels,\n",
308 | " embed_dims=embed_dims,\n",
309 | " tube_size=tube_size,\n",
310 | " conv_type=conv_type)\n",
311 | " num_patches = self.patch_embed.num_patches\n",
312 | "\n",
313 | " # Divided Space Time Transformer Encoder - Model 2\n",
314 | " transformer_layers = nn.ModuleList([])\n",
315 | "\n",
316 | " spatial_transformer = TransformerContainer(\n",
317 | " num_transformer_layers=num_transformer_layers,\n",
318 | " embed_dims=embed_dims,\n",
319 | " num_heads=num_heads,\n",
320 | " num_frames=num_frames,\n",
321 | " norm_layer=norm_layer,\n",
322 | " hidden_channels=embed_dims*4,\n",
323 | " operator_order=['self_attn','ffn'])\n",
324 | "\n",
325 | " temporal_transformer = TransformerContainer(\n",
326 | " num_transformer_layers=self.num_time_transformer_layers,\n",
327 | " embed_dims=embed_dims,\n",
328 | " num_heads=num_heads,\n",
329 | " num_frames=num_frames,\n",
330 | " norm_layer=norm_layer,\n",
331 | " hidden_channels=embed_dims*4,\n",
332 | " operator_order=['self_attn','ffn'])\n",
333 | "\n",
334 | " transformer_layers.append(spatial_transformer)\n",
335 | " transformer_layers.append(temporal_transformer)\n",
336 | "\n",
337 | " self.transformer_layers = transformer_layers\n",
338 | " self.norm = norm_layer(embed_dims, eps=1e-6)\n",
339 | "\n",
340 | " self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dims))\n",
341 | " # whether to add one cls_token in temporal pos_enb\n",
342 | " num_frames = num_frames + 1\n",
343 | " num_patches = num_patches + 1\n",
344 | "\n",
345 | " self.pos_embed = nn.Parameter(torch.zeros(1,num_patches,embed_dims))\n",
346 | " self.time_embed = nn.Parameter(torch.zeros(1,num_frames,embed_dims))\n",
347 | " self.drop_after_pos = nn.Dropout(p=dropout_p)\n",
348 | " self.drop_after_time = nn.Dropout(p=dropout_p)\n",
349 | " \n",
350 | " def interpolate_pos_encoding(self, x, w, h):\n",
351 | " npatch = x.shape[1] - 1\n",
352 | " N = self.pos_embed.shape[1] - 1\n",
353 | " if npatch == N and w == h:\n",
354 | " return self.pos_embed\n",
355 | " class_pos_embed = self.pos_embed[:, 0]\n",
356 | " patch_pos_embed = self.pos_embed[:, 1:]\n",
357 | " dim = x.shape[-1]\n",
358 | " w0 = w // self.patch_embed.patch_size[0]\n",
359 | " h0 = h // self.patch_embed.patch_size[0]\n",
360 | " # we add a small number to avoid floating point error in the interpolation\n",
361 | " # see discussion at https://github.com/facebookresearch/dino/issues/8\n",
362 | " w0, h0 = w0 + 0.1, h0 + 0.1\n",
363 | " patch_pos_embed = nn.functional.interpolate(\n",
364 | " patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),\n",
365 | " scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),\n",
366 | " mode='bicubic',\n",
367 | " )\n",
368 | " assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]\n",
369 | " patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)\n",
370 | " return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)\n",
371 | "\n",
372 | " def forward(self, x):\n",
373 | " #Tokenize\n",
374 | " b, t, c, h, w = x.shape\n",
375 | " x = self.patch_embed(x)\n",
376 | "\n",
377 | " # Add Position Embedding\n",
378 | " cls_tokens = repeat(self.cls_token, 'b ... -> (repeat b) ...', repeat=x.shape[0])\n",
379 | " x = torch.cat((cls_tokens, x), dim=1)\n",
380 | " x = x + self.interpolate_pos_encoding(x, w, h)\n",
381 | " x = self.drop_after_pos(x)\n",
382 | "\n",
383 | " # fact encoder - CRNN style\n",
384 | " spatial_transformer, temporal_transformer, = *self.transformer_layers,\n",
385 | " x = spatial_transformer(x)\n",
386 | "\n",
387 | " # Add Time Embedding\n",
388 | " cls_tokens = x[:b, 0, :].unsqueeze(1)\n",
389 | " x = rearrange(x[:, 1:, :], '(b t) p d -> b t p d', b=b)\n",
390 | " x = reduce(x, 'b t p d -> b t d', 'mean')\n",
391 | " x = torch.cat((cls_tokens, x), dim=1)\n",
392 | " x = x + self.time_embed\n",
393 | " x = self.drop_after_time(x)\n",
394 | "\n",
395 | " x = temporal_transformer(x)\n",
396 | "\n",
397 | " x = self.norm(x)\n",
398 | " # Return Class Token\n",
399 | " if self.return_cls_token:\n",
400 | " return x[:, 0]\n",
401 | " else:\n",
402 | " return x[:, 1:].mean(1)"
403 | ]
404 | },
405 | {
406 | "cell_type": "code",
407 | "execution_count": 23,
408 | "metadata": {},
409 | "outputs": [],
410 | "source": [
411 | "def replace_state_dict(state_dict):\n",
412 | "\tfor old_key in list(state_dict.keys()):\n",
413 | "\t\tif old_key.startswith('model'):\n",
414 | "\t\t\tnew_key = old_key[6:]\n",
415 | "\t\t\tstate_dict[new_key] = state_dict.pop(old_key)\n",
416 | "\t\telse:\n",
417 | "\t\t\tnew_key = old_key[9:]\n",
418 | "\t\t\tstate_dict[new_key] = state_dict.pop(old_key)"
419 | ]
420 | },
421 | {
422 | "cell_type": "code",
423 | "execution_count": 25,
424 | "metadata": {},
425 | "outputs": [],
426 | "source": [
427 | "def init_from_pretrain_(module, pretrained, init_module):\n",
428 | " if torch.cuda.is_available():\n",
429 | " state_dict = torch.load(pretrained)\n",
430 | " else:\n",
431 | " state_dict = torch.load(pretrained, map_location=torch.device('cpu'))\n",
432 | " if init_module == 'transformer':\n",
433 | " replace_state_dict(state_dict)\n",
434 | " elif init_module == 'cls_head':\n",
435 | " replace_state_dict(state_dict)\n",
436 | " else:\n",
437 | " raise TypeError(f'pretrained weights do not include the {init_module} module')\n",
438 | " msg = module.load_state_dict(state_dict, strict=False)\n",
439 | " return msg"
440 | ]
441 | },
442 | {
443 | "cell_type": "code",
444 | "execution_count": null,
445 | "metadata": {
446 | "colab": {
447 | "base_uri": "https://localhost:8080/",
448 | "height": 179
449 | },
450 | "id": "u5J7lGPJ2bLJ",
451 | "outputId": "41c0568b-082f-4609-8a5e-9ff4b80ffd92",
452 | "pycharm": {}
453 | },
454 | "outputs": [],
455 | "source": [
456 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\") \n",
457 | "num_frames = 8\n",
458 | "frame_interval = 32\n",
459 | "num_class = 400\n",
460 | "arch = 'timesformer' # turn to vivit for initializing vivit model\n",
461 | "\n",
462 | "if arch == 'timesformer':\n",
463 | " pretrain_pth = './timesformer_k400.pth'\n",
464 | " model = TimeSformer(num_frames=num_frames,\n",
465 | " img_size=224,\n",
466 | " patch_size=16,\n",
467 | " embed_dims=768,\n",
468 | " in_channels=3,\n",
469 | " attention_type='divided_space_time',\n",
470 | " return_cls_token=True)\n",
471 | "elif arch == 'vivit':\n",
472 | " pretrain_pth = './vivit_model.pth'\n",
473 | " num_frames = num_frames * 2\n",
474 | " frame_interval = frame_interval // 2\n",
475 | " model = ViViT(num_frames=num_frames,\n",
476 | " img_size=224,\n",
477 | " patch_size=16,\n",
478 | " embed_dims=768,\n",
479 | " in_channels=3,\n",
480 | " attention_type='fact_encoder',\n",
481 | " return_cls_token=True)\n",
482 | "else:\n",
483 | " raise TypeError(f'not supported arch type {arch}, chosen in (timesformer, vivit)')\n",
484 | "\n",
485 | "cls_head = ClassificationHead(num_classes=num_class, in_channels=768)\n",
486 | "msg_trans = init_from_pretrain_(model, pretrain_pth, init_module='transformer')\n",
487 | "msg_cls = init_from_pretrain_(cls_head, pretrain_pth, init_module='cls_head')\n",
488 | "model.eval()\n",
489 | "cls_head.eval()\n",
490 | "model = model.to(device)\n",
491 | "cls_head = cls_head.to(device)\n",
492 | "print(f'load model finished, the missing key of transformer is:{msg_trans[0]}, cls is:{msg_cls[0]}')"
493 | ]
494 | },
495 | {
496 | "cell_type": "markdown",
497 | "metadata": {},
498 | "source": [
499 | "## Data preprocess\n",
500 | "\n",
501 | "Here we show the video demo and transform the video input for the model processing."
502 | ]
503 | },
504 | {
505 | "cell_type": "code",
506 | "execution_count": null,
507 | "metadata": {},
508 | "outputs": [],
509 | "source": [
510 | "from IPython.display import display, HTML\n",
511 | "\n",
512 | "video_path = './demo/YABnJL_bDzw.mp4'\n",
513 | "html_str = '''\n",
514 | "\n",
515 | "'''.format(video_path)\n",
516 | "display(HTML(html_str))"
517 | ]
518 | },
519 | {
520 | "cell_type": "code",
521 | "execution_count": 54,
522 | "metadata": {},
523 | "outputs": [],
524 | "source": [
525 | "# Prepare data preprocess\n",
526 | "mean, std = (0.45, 0.45, 0.45), (0.225, 0.225, 0.225)\n",
527 | "data_transform = T.Compose([\n",
528 | " T.Resize(scale_range=(-1, 256)),\n",
529 | " T.ThreeCrop(size=224),\n",
530 | " T.ToTensor(),\n",
531 | " T.Normalize(mean, std)\n",
532 | " ])\n",
533 | "temporal_sample = T.TemporalRandomCrop(num_frames*frame_interval)\n",
534 | "\n",
535 | "# Sampling video frames\n",
536 | "video_decoder = DecordInit()\n",
537 | "v_reader = video_decoder(video_path)\n",
538 | "total_frames = len(v_reader)\n",
539 | "start_frame_ind, end_frame_ind = temporal_sample(total_frames)\n",
540 | "if end_frame_ind-start_frame_ind < num_frames:\n",
541 | " raise ValueError(f'the total frames of the video {video_path} is less than {num_frames}')\n",
542 | "frame_indice = np.linspace(0, end_frame_ind-start_frame_ind-1, num_frames, dtype=int)\n",
543 | "video = v_reader.get_batch(frame_indice).asnumpy()\n",
544 | "del v_reader\n",
545 | "\n",
546 | "video = torch.from_numpy(video).permute(0,3,1,2) # Video transform: T C H W\n",
547 | "data_transform.randomize_parameters()\n",
548 | "video = data_transform(video)\n",
549 | "video = video.to(device)"
550 | ]
551 | },
552 | {
553 | "cell_type": "markdown",
554 | "metadata": {},
555 | "source": [
556 | "## Video Classification\n",
557 | "\n",
558 | "Here we use the pre-trained video transformer to classify the input video."
559 | ]
560 | },
561 | {
562 | "cell_type": "code",
563 | "execution_count": null,
564 | "metadata": {},
565 | "outputs": [],
566 | "source": [
567 | "# Predict class label\n",
568 | "with torch.no_grad():\n",
569 | " logits = model(video)\n",
570 | " output = cls_head(logits)\n",
571 | " output = output.view(3, 400).mean(0)\n",
572 | " cls_pred = output.argmax().item()\n",
573 | "\n",
574 | "class_map = './k400_classmap.json'\n",
575 | "class_map = load_annotation_data(class_map)\n",
576 | "for key, value in class_map.items():\n",
577 | " if int(value) == int(cls_pred):\n",
578 | " print(f'the shape of ouptut: {output.shape}, and the prediction is: {key}')\n",
579 | " break"
580 | ]
581 | }
582 | ],
583 | "metadata": {
584 | "accelerator": "GPU",
585 | "colab": {
586 | "collapsed_sections": [],
587 | "name": "iBOT_demo",
588 | "provenance": [],
589 | "toc_visible": true
590 | },
591 | "kernelspec": {
592 | "display_name": "Python 3.8 (XPython)",
593 | "language": "python",
594 | "name": "xpython"
595 | },
596 | "language_info": {
597 | "file_extension": ".py",
598 | "mimetype": "text/x-python",
599 | "name": "python",
600 | "version": "3.8.6"
601 | },
602 | "widgets": {
603 | "application/vnd.jupyter.widget-state+json": {
604 | "10cb415f29954cb0a511e263fcb6c69f": {
605 | "model_module": "@jupyter-widgets/controls",
606 | "model_module_version": "1.5.0",
607 | "model_name": "HTMLModel",
608 | "state": {
609 | "_dom_classes": [],
610 | "_model_module": "@jupyter-widgets/controls",
611 | "_model_module_version": "1.5.0",
612 | "_model_name": "HTMLModel",
613 | "_view_count": null,
614 | "_view_module": "@jupyter-widgets/controls",
615 | "_view_module_version": "1.5.0",
616 | "_view_name": "HTMLView",
617 | "description": "",
618 | "description_tooltip": null,
619 | "layout": "IPY_MODEL_ecfd9523d50f4f7f87083460dc0e3dd7",
620 | "placeholder": "",
621 | "style": "IPY_MODEL_6d68baa05ccc427890286e0e74452155",
622 | "value": " 327M/327M [00:09<00:00, 36.6MB/s]"
623 | }
624 | },
625 | "3cb57e69cf054551b1978ddd6be3553b": {
626 | "model_module": "@jupyter-widgets/controls",
627 | "model_module_version": "1.5.0",
628 | "model_name": "HTMLModel",
629 | "state": {
630 | "_dom_classes": [],
631 | "_model_module": "@jupyter-widgets/controls",
632 | "_model_module_version": "1.5.0",
633 | "_model_name": "HTMLModel",
634 | "_view_count": null,
635 | "_view_module": "@jupyter-widgets/controls",
636 | "_view_module_version": "1.5.0",
637 | "_view_name": "HTMLView",
638 | "description": "",
639 | "description_tooltip": null,
640 | "layout": "IPY_MODEL_5bbd072a4a424ef88d5354793942a84b",
641 | "placeholder": "",
642 | "style": "IPY_MODEL_6668b553344743e2a1450251c5a237e2",
643 | "value": "100%"
644 | }
645 | },
646 | "557451188ee74f3698a8b358ee70d29d": {
647 | "model_module": "@jupyter-widgets/controls",
648 | "model_module_version": "1.5.0",
649 | "model_name": "FloatProgressModel",
650 | "state": {
651 | "_dom_classes": [],
652 | "_model_module": "@jupyter-widgets/controls",
653 | "_model_module_version": "1.5.0",
654 | "_model_name": "FloatProgressModel",
655 | "_view_count": null,
656 | "_view_module": "@jupyter-widgets/controls",
657 | "_view_module_version": "1.5.0",
658 | "_view_name": "ProgressView",
659 | "bar_style": "success",
660 | "description": "",
661 | "description_tooltip": null,
662 | "layout": "IPY_MODEL_af53b2f6e65b43fb8b3df5c656b40c92",
663 | "max": 343279349,
664 | "min": 0,
665 | "orientation": "horizontal",
666 | "style": "IPY_MODEL_9e962b8959ca4e90b496ab36d764f073",
667 | "value": 343279349
668 | }
669 | },
670 | "5bbd072a4a424ef88d5354793942a84b": {
671 | "model_module": "@jupyter-widgets/base",
672 | "model_module_version": "1.2.0",
673 | "model_name": "LayoutModel",
674 | "state": {
675 | "_model_module": "@jupyter-widgets/base",
676 | "_model_module_version": "1.2.0",
677 | "_model_name": "LayoutModel",
678 | "_view_count": null,
679 | "_view_module": "@jupyter-widgets/base",
680 | "_view_module_version": "1.2.0",
681 | "_view_name": "LayoutView",
682 | "align_content": null,
683 | "align_items": null,
684 | "align_self": null,
685 | "border": null,
686 | "bottom": null,
687 | "display": null,
688 | "flex": null,
689 | "flex_flow": null,
690 | "grid_area": null,
691 | "grid_auto_columns": null,
692 | "grid_auto_flow": null,
693 | "grid_auto_rows": null,
694 | "grid_column": null,
695 | "grid_gap": null,
696 | "grid_row": null,
697 | "grid_template_areas": null,
698 | "grid_template_columns": null,
699 | "grid_template_rows": null,
700 | "height": null,
701 | "justify_content": null,
702 | "justify_items": null,
703 | "left": null,
704 | "margin": null,
705 | "max_height": null,
706 | "max_width": null,
707 | "min_height": null,
708 | "min_width": null,
709 | "object_fit": null,
710 | "object_position": null,
711 | "order": null,
712 | "overflow": null,
713 | "overflow_x": null,
714 | "overflow_y": null,
715 | "padding": null,
716 | "right": null,
717 | "top": null,
718 | "visibility": null,
719 | "width": null
720 | }
721 | },
722 | "6668b553344743e2a1450251c5a237e2": {
723 | "model_module": "@jupyter-widgets/controls",
724 | "model_module_version": "1.5.0",
725 | "model_name": "DescriptionStyleModel",
726 | "state": {
727 | "_model_module": "@jupyter-widgets/controls",
728 | "_model_module_version": "1.5.0",
729 | "_model_name": "DescriptionStyleModel",
730 | "_view_count": null,
731 | "_view_module": "@jupyter-widgets/base",
732 | "_view_module_version": "1.2.0",
733 | "_view_name": "StyleView",
734 | "description_width": ""
735 | }
736 | },
737 | "6d68baa05ccc427890286e0e74452155": {
738 | "model_module": "@jupyter-widgets/controls",
739 | "model_module_version": "1.5.0",
740 | "model_name": "DescriptionStyleModel",
741 | "state": {
742 | "_model_module": "@jupyter-widgets/controls",
743 | "_model_module_version": "1.5.0",
744 | "_model_name": "DescriptionStyleModel",
745 | "_view_count": null,
746 | "_view_module": "@jupyter-widgets/base",
747 | "_view_module_version": "1.2.0",
748 | "_view_name": "StyleView",
749 | "description_width": ""
750 | }
751 | },
752 | "8902dbcd84b4439a987e908be8c8e19e": {
753 | "model_module": "@jupyter-widgets/controls",
754 | "model_module_version": "1.5.0",
755 | "model_name": "HBoxModel",
756 | "state": {
757 | "_dom_classes": [],
758 | "_model_module": "@jupyter-widgets/controls",
759 | "_model_module_version": "1.5.0",
760 | "_model_name": "HBoxModel",
761 | "_view_count": null,
762 | "_view_module": "@jupyter-widgets/controls",
763 | "_view_module_version": "1.5.0",
764 | "_view_name": "HBoxView",
765 | "box_style": "",
766 | "children": [
767 | "IPY_MODEL_3cb57e69cf054551b1978ddd6be3553b",
768 | "IPY_MODEL_557451188ee74f3698a8b358ee70d29d",
769 | "IPY_MODEL_10cb415f29954cb0a511e263fcb6c69f"
770 | ],
771 | "layout": "IPY_MODEL_95c73030d06d4d578252f47989c208fc"
772 | }
773 | },
774 | "95c73030d06d4d578252f47989c208fc": {
775 | "model_module": "@jupyter-widgets/base",
776 | "model_module_version": "1.2.0",
777 | "model_name": "LayoutModel",
778 | "state": {
779 | "_model_module": "@jupyter-widgets/base",
780 | "_model_module_version": "1.2.0",
781 | "_model_name": "LayoutModel",
782 | "_view_count": null,
783 | "_view_module": "@jupyter-widgets/base",
784 | "_view_module_version": "1.2.0",
785 | "_view_name": "LayoutView",
786 | "align_content": null,
787 | "align_items": null,
788 | "align_self": null,
789 | "border": null,
790 | "bottom": null,
791 | "display": null,
792 | "flex": null,
793 | "flex_flow": null,
794 | "grid_area": null,
795 | "grid_auto_columns": null,
796 | "grid_auto_flow": null,
797 | "grid_auto_rows": null,
798 | "grid_column": null,
799 | "grid_gap": null,
800 | "grid_row": null,
801 | "grid_template_areas": null,
802 | "grid_template_columns": null,
803 | "grid_template_rows": null,
804 | "height": null,
805 | "justify_content": null,
806 | "justify_items": null,
807 | "left": null,
808 | "margin": null,
809 | "max_height": null,
810 | "max_width": null,
811 | "min_height": null,
812 | "min_width": null,
813 | "object_fit": null,
814 | "object_position": null,
815 | "order": null,
816 | "overflow": null,
817 | "overflow_x": null,
818 | "overflow_y": null,
819 | "padding": null,
820 | "right": null,
821 | "top": null,
822 | "visibility": null,
823 | "width": null
824 | }
825 | },
826 | "9e962b8959ca4e90b496ab36d764f073": {
827 | "model_module": "@jupyter-widgets/controls",
828 | "model_module_version": "1.5.0",
829 | "model_name": "ProgressStyleModel",
830 | "state": {
831 | "_model_module": "@jupyter-widgets/controls",
832 | "_model_module_version": "1.5.0",
833 | "_model_name": "ProgressStyleModel",
834 | "_view_count": null,
835 | "_view_module": "@jupyter-widgets/base",
836 | "_view_module_version": "1.2.0",
837 | "_view_name": "StyleView",
838 | "bar_color": null,
839 | "description_width": ""
840 | }
841 | },
842 | "af53b2f6e65b43fb8b3df5c656b40c92": {
843 | "model_module": "@jupyter-widgets/base",
844 | "model_module_version": "1.2.0",
845 | "model_name": "LayoutModel",
846 | "state": {
847 | "_model_module": "@jupyter-widgets/base",
848 | "_model_module_version": "1.2.0",
849 | "_model_name": "LayoutModel",
850 | "_view_count": null,
851 | "_view_module": "@jupyter-widgets/base",
852 | "_view_module_version": "1.2.0",
853 | "_view_name": "LayoutView",
854 | "align_content": null,
855 | "align_items": null,
856 | "align_self": null,
857 | "border": null,
858 | "bottom": null,
859 | "display": null,
860 | "flex": null,
861 | "flex_flow": null,
862 | "grid_area": null,
863 | "grid_auto_columns": null,
864 | "grid_auto_flow": null,
865 | "grid_auto_rows": null,
866 | "grid_column": null,
867 | "grid_gap": null,
868 | "grid_row": null,
869 | "grid_template_areas": null,
870 | "grid_template_columns": null,
871 | "grid_template_rows": null,
872 | "height": null,
873 | "justify_content": null,
874 | "justify_items": null,
875 | "left": null,
876 | "margin": null,
877 | "max_height": null,
878 | "max_width": null,
879 | "min_height": null,
880 | "min_width": null,
881 | "object_fit": null,
882 | "object_position": null,
883 | "order": null,
884 | "overflow": null,
885 | "overflow_x": null,
886 | "overflow_y": null,
887 | "padding": null,
888 | "right": null,
889 | "top": null,
890 | "visibility": null,
891 | "width": null
892 | }
893 | },
894 | "ecfd9523d50f4f7f87083460dc0e3dd7": {
895 | "model_module": "@jupyter-widgets/base",
896 | "model_module_version": "1.2.0",
897 | "model_name": "LayoutModel",
898 | "state": {
899 | "_model_module": "@jupyter-widgets/base",
900 | "_model_module_version": "1.2.0",
901 | "_model_name": "LayoutModel",
902 | "_view_count": null,
903 | "_view_module": "@jupyter-widgets/base",
904 | "_view_module_version": "1.2.0",
905 | "_view_name": "LayoutView",
906 | "align_content": null,
907 | "align_items": null,
908 | "align_self": null,
909 | "border": null,
910 | "bottom": null,
911 | "display": null,
912 | "flex": null,
913 | "flex_flow": null,
914 | "grid_area": null,
915 | "grid_auto_columns": null,
916 | "grid_auto_flow": null,
917 | "grid_auto_rows": null,
918 | "grid_column": null,
919 | "grid_gap": null,
920 | "grid_row": null,
921 | "grid_template_areas": null,
922 | "grid_template_columns": null,
923 | "grid_template_rows": null,
924 | "height": null,
925 | "justify_content": null,
926 | "justify_items": null,
927 | "left": null,
928 | "margin": null,
929 | "max_height": null,
930 | "max_width": null,
931 | "min_height": null,
932 | "min_width": null,
933 | "object_fit": null,
934 | "object_position": null,
935 | "order": null,
936 | "overflow": null,
937 | "overflow_x": null,
938 | "overflow_y": null,
939 | "padding": null,
940 | "right": null,
941 | "top": null,
942 | "visibility": null,
943 | "width": null
944 | }
945 | }
946 | }
947 | }
948 | },
949 | "nbformat": 4,
950 | "nbformat_minor": 4
951 | }
952 |
--------------------------------------------------------------------------------
/optimizer.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SimMIM
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # Modified by MX
7 | # --------------------------------------------------------
8 | from functools import partial
9 | from torch import optim as optim
10 |
11 | from utils import print_on_rank_zero
12 |
13 |
14 | def build_optimizer(hparams, model, is_pretrain):
15 | if is_pretrain:
16 | return build_pretrain_optimizer(hparams, model)
17 | else:
18 | return build_finetune_optimizer(hparams, model)
19 |
20 |
21 | def build_pretrain_optimizer(hparams, model):
22 | skip = {}
23 | skip_keywords = {}
24 | if hasattr(model, 'no_weight_decay'):
25 | skip = model.no_weight_decay()
26 | if hasattr(model, 'no_weight_decay_keywords'):
27 | skip_keywords = model.no_weight_decay_keywords()
28 |
29 | parameters = get_pretrain_param_groups(model, skip, skip_keywords)
30 |
31 | opt_lower = hparams.optim_type.lower()
32 | optimizer = None
33 | if opt_lower == 'sgd':
34 | optimizer = optim.SGD(parameters, momentum=0.9, nesterov=True,
35 | lr=hparams.lr, weight_decay=hparams.weight_decay)
36 | elif opt_lower == 'adamw':
37 | optimizer = optim.AdamW(parameters, betas=(0.9, 0.999),
38 | lr=hparams.lr, weight_decay=hparams.weight_decay)
39 |
40 | return optimizer
41 |
42 |
43 | def get_pretrain_param_groups(model, skip_list=(), skip_keywords=()):
44 | has_decay = []
45 | no_decay = []
46 | has_decay_name = []
47 | no_decay_name = []
48 |
49 | for name, param in model.named_parameters():
50 | if not param.requires_grad:
51 | continue
52 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
53 | check_keywords_in_name(name, skip_keywords):
54 | no_decay.append(param)
55 | no_decay_name.append(name)
56 | else:
57 | has_decay.append(param)
58 | has_decay_name.append(name)
59 |
60 | print_on_rank_zero(f'params_no_decay_name: {no_decay_name} \n params_decay_name: {has_decay_name}')
61 | return [{'params': no_decay, 'weight_decay': 0.},
62 | {'params': has_decay},]
63 |
64 |
65 | def build_finetune_optimizer(hparams, model):
66 | if hparams.arch == 'mvit':
67 | if hparams.layer_decay == 1:
68 | get_layer_func = None
69 | scales = None
70 | else:
71 | num_layers = 16
72 | get_layer_func = partial(get_mvit_layer, num_layers=num_layers + 2)
73 | scales = list(hparams.layer_decay ** i for i in reversed(range(num_layers + 2))) #layer_decay=1 disable
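		# scales[layer_id] = layer_decay ** (num_layers + 1 - layer_id): with layer_decay < 1,
		# patch_embed and early blocks get the smallest lr scale while the head keeps scale 1.0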
74 | else:
75 | return build_pretrain_optimizer(hparams, model)
76 |
77 | skip = {}
78 | skip_keywords = {}
79 | if hasattr(model, 'no_weight_decay'):
80 | skip = model.no_weight_decay()
81 | if hasattr(model, 'no_weight_decay_keywords'):
82 | skip_keywords = model.no_weight_decay_keywords()
83 |
84 | parameters = get_finetune_param_groups(
85 | model, hparams.lr, hparams.weight_decay,
86 | get_layer_func, scales, skip, skip_keywords)
87 |
88 | opt_lower = hparams.optim_type.lower()
89 | optimizer = None
90 | if opt_lower == 'sgd':
91 | optimizer = optim.SGD(parameters, momentum=0.9, nesterov=True,
92 | lr=hparams.lr, weight_decay=hparams.weight_decay)
93 | elif opt_lower == 'adamw':
94 | optimizer = optim.AdamW(parameters, betas=(0.9, 0.999),
95 | lr=hparams.lr, weight_decay=hparams.weight_decay)
96 |
97 | return optimizer
98 |
99 |
100 | def get_mvit_layer(name, num_layers):
101 | layer_name = name.replace('mvit.', '')
102 | layer_name = layer_name.replace('model.', '')
103 | 	if layer_name in ("mask_token",):
104 | return 0
105 | elif layer_name.startswith("patch_embed") or layer_name.startswith('cls_positional_encoding'):
106 | return 0
107 | elif layer_name.startswith("blocks"):
108 | layer_id = int(layer_name.split('.')[1])
109 | return layer_id + 1
110 | else:
111 | return num_layers - 1
112 |
113 |
114 | def get_finetune_param_groups(model, lr, weight_decay, get_layer_func, scales, skip_list=(), skip_keywords=()):
115 | parameter_group_names = {}
116 | parameter_group_vars = {}
117 |
118 | for name, param in model.named_parameters():
119 | if not param.requires_grad:
120 | continue
121 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
122 | check_keywords_in_name(name, skip_keywords):
123 | group_name = "no_decay"
124 | this_weight_decay = 0.
125 | else:
126 | group_name = "decay"
127 | this_weight_decay = weight_decay
128 | if get_layer_func is not None:
129 | layer_id = get_layer_func(name)
130 | group_name = "layer_%d_%s" % (layer_id, group_name)
131 | #print(name, group_name)
132 | else:
133 | layer_id = None
134 |
135 | if group_name not in parameter_group_names:
136 | if scales is not None:
137 | scale = scales[layer_id]
138 | else:
139 | scale = 1.
140 |
141 | parameter_group_names[group_name] = {
142 | "group_name": group_name,
143 | "weight_decay": this_weight_decay,
144 | "params": [],
145 | "lr": lr * scale,
146 | "lr_scale": scale,
147 | }
148 | parameter_group_vars[group_name] = {
149 | "group_name": group_name,
150 | "weight_decay": this_weight_decay,
151 | "params": [],
152 | "lr": lr * scale,
153 | "lr_scale": scale
154 | }
155 |
156 | parameter_group_vars[group_name]["params"].append(param)
157 | parameter_group_names[group_name]["params"].append(name)
158 | return list(parameter_group_vars.values())
159 |
160 |
161 | def check_keywords_in_name(name, keywords=()):
162 | isin = False
163 | for keyword in keywords:
164 | if keyword in name:
165 | isin = True
166 | return isin
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchmetrics==0.5.1
3 | torchvision
4 | pytorch-lightning==1.3.8
5 | pytorchvideo
6 | scikit-image
7 | decord==0.6.0
8 | einops==0.3.2
9 | kornia==0.6.1
10 | matplotlib==3.4.3
11 | timm==0.4.12
--------------------------------------------------------------------------------
/transformer.py:
--------------------------------------------------------------------------------
1 | from einops import rearrange, repeat, reduce
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.modules.utils import _pair
6 |
7 | from weight_init import trunc_normal_, constant_init_, kaiming_init_
8 |
9 |
10 | # sin-cos position encoding
11 | # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
12 | def get_sine_cosine_pos_emb(n_position, d_hid):
13 | ''' Sinusoid position encoding table '''
14 | # TODO: make it with torch instead of numpy
15 | def get_position_angle_vec(position):
16 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
17 |
18 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
19 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
20 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
21 |
22 | return torch.FloatTensor(sinusoid_table).unsqueeze(0)
23 |
24 |
25 | class DropPath(nn.Module):
26 |
27 | def __init__(self, dropout_p=None):
28 | super(DropPath, self).__init__()
29 | self.dropout_p = dropout_p
30 |
31 | def forward(self, x):
32 | return self.drop_path(x, self.dropout_p, self.training)
33 |
34 | def drop_path(self, x, dropout_p=0., training=False):
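        # Stochastic depth: during training, zero the whole residual branch for a random
        # subset of samples (with probability dropout_p) and rescale survivors by 1/keep_prob.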
35 | if dropout_p == 0. or not training:
36 | return x
37 | keep_prob = 1 - dropout_p
38 | shape = (x.shape[0],) + (1,) * (x.ndim - 1)
39 | random_tensor = keep_prob + torch.rand(shape).type_as(x)
40 | random_tensor.floor_() # binarize
41 | output = x.div(keep_prob) * random_tensor
42 | return output
43 |
44 |
45 | class ClassificationHead(nn.Module):
46 | """Classification head for Video Transformer.
47 |
48 | Args:
49 | num_classes (int): Number of classes to be classified.
50 | in_channels (int): Number of channels in input feature.
51 |         init_std (float): Std value for initialization. Defaults to 0.02.
52 | kwargs (dict, optional): Any keyword argument to be used to initialize
53 | the head.
54 | """
55 |
56 | def __init__(self,
57 | num_classes,
58 | in_channels,
59 | init_std=0.02,
60 | eval_metrics='finetune',
61 | **kwargs):
62 | super().__init__()
63 | self.init_std = init_std
64 | self.eval_metrics = eval_metrics
65 | self.cls_head = nn.Linear(in_channels, num_classes)
66 |
67 | self.init_weights(self.cls_head)
68 |
69 | def init_weights(self, module):
70 | if hasattr(module, 'weight') and module.weight is not None:
71 | if self.eval_metrics == 'finetune':
72 | trunc_normal_(module.weight, std=self.init_std)
73 | else:
74 | module.weight.data.normal_(mean=0.0, std=0.01)
75 | if hasattr(module, 'bias') and module.bias is not None:
76 | constant_init_(module.bias, constant_value=0)
77 |
78 | def forward(self, x):
79 | cls_score = self.cls_head(x)
80 | return cls_score
81 |
82 |
83 | class PatchEmbed(nn.Module):
84 | """Images to Patch Embedding.
85 |
86 | Args:
87 | img_size (int | tuple): Size of input image.
88 | patch_size (int): Size of one patch.
89 | tube_size (int): Size of temporal field of one 3D patch.
90 | in_channels (int): Channel num of input features. Defaults to 3.
91 | embed_dims (int): Dimensions of embedding. Defaults to 768.
92 | conv_type (str): Type for convolution layer. Defaults to 'Conv2d'.
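
    Example (illustrative, assuming a 224x224 input and the default embed_dims=768):
        >>> embed = PatchEmbed(img_size=224, patch_size=16, conv_type='Conv2d')
        >>> embed(torch.randn(2, 8, 3, 224, 224)).shape  # torch.Size([16, 196, 768])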
93 | """
94 |
95 | def __init__(self,
96 | img_size,
97 | patch_size,
98 | tube_size=2,
99 | in_channels=3,
100 | embed_dims=768,
101 | conv_type='Conv2d'):
102 | super().__init__()
103 | self.img_size = _pair(img_size)
104 | self.patch_size = _pair(patch_size)
105 |
106 | num_patches = \
107 | (self.img_size[1] // self.patch_size[1]) * \
108 | (self.img_size[0] // self.patch_size[0])
109 |         assert num_patches * self.patch_size[0] * self.patch_size[1] == \
110 |                self.img_size[0] * self.img_size[1], \
111 |                'The image size H*W must be divisible by patch size'
112 | self.num_patches = num_patches
113 |
114 | # Use conv layer to embed
115 | if conv_type == 'Conv2d':
116 | self.projection = nn.Conv2d(
117 | in_channels,
118 | embed_dims,
119 | kernel_size=patch_size,
120 | stride=patch_size)
121 | elif conv_type == 'Conv3d':
122 | self.projection = nn.Conv3d(
123 | in_channels,
124 | embed_dims,
125 | kernel_size=(tube_size,patch_size,patch_size),
126 | stride=(tube_size,patch_size,patch_size))
127 | else:
128 | raise TypeError(f'Unsupported conv layer type {conv_type}')
129 |
130 | self.init_weights(self.projection)
131 |
132 | def init_weights(self, module):
133 | if hasattr(module, 'weight') and module.weight is not None:
134 | kaiming_init_(module.weight, mode='fan_in', nonlinearity='relu')
135 | if hasattr(module, 'bias') and module.bias is not None:
136 | constant_init_(module.bias, constant_value=0)
137 |
138 | def forward(self, x):
139 | layer_type = type(self.projection)
140 | if layer_type == nn.Conv3d:
141 | x = rearrange(x, 'b t c h w -> b c t h w')
142 | x = self.projection(x)
143 | x = rearrange(x, 'b c t h w -> (b t) (h w) c')
144 | elif layer_type == nn.Conv2d:
145 | x = rearrange(x, 'b t c h w -> (b t) c h w')
146 | x = self.projection(x)
147 | x = rearrange(x, 'b c h w -> b (h w) c')
148 | else:
149 | raise TypeError(f'Unsupported conv layer type {layer_type}')
150 |
151 | return x
152 |
153 | class Attention(nn.Module):
154 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
155 | super().__init__()
156 | self.num_heads = num_heads
157 | head_dim = dim // num_heads
158 | self.scale = qk_scale or head_dim ** -0.5
159 |
160 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
161 | self.attn_drop = nn.Dropout(attn_drop)
162 | self.proj = nn.Linear(dim, dim)
163 | self.proj_drop = nn.Dropout(proj_drop)
164 |
165 | def forward(self, x):
166 | B, N, C = x.shape
167 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
168 | q, k, v = qkv[0], qkv[1], qkv[2]
169 |
170 | attn = (q @ k.transpose(-2, -1)) * self.scale
171 | attn = attn.softmax(dim=-1)
172 | attn = self.attn_drop(attn)
173 |
174 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
175 | x = self.proj(x)
176 | x = self.proj_drop(x)
177 | return x, attn
178 |
179 | class DividedTemporalAttentionWithPreNorm(nn.Module):
180 | """Temporal Attention in Divided Space Time Attention.
181 |     A wrapper around a multi-head attention layer.
182 |
183 | Args:
184 | embed_dims (int): Dimensions of embedding.
185 | num_heads (int): Number of parallel attention heads in
186 | TransformerCoder.
187 | num_frames (int): Number of frames in the video.
188 | use_cls_token (bool): Whether to perform MSA on cls_token.
189 | attn_drop (float): A Dropout layer on attn_output_weights. Defaults to
190 | 0..
191 | proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
192 | Defaults to 0..
193 | layer_drop (dict): The layer_drop used when adding the shortcut.
194 | Defaults to `dict(type=DropPath, dropout_p=0.1)`.
195 | norm_layer (class): Class name for normalization layer. Defaults to
196 | nn.LayerNorm.
197 | """
198 |
199 | def __init__(self,
200 | embed_dims,
201 | num_heads,
202 | num_frames,
203 | use_cls_token,
204 | attn_drop=0.,
205 | proj_drop=0.,
206 | layer_drop=dict(type=DropPath, dropout_p=0.1),
207 | norm_layer=nn.LayerNorm,
208 | **kwargs):
209 | super().__init__()
210 | self.embed_dims = embed_dims
211 | self.num_heads = num_heads
212 | self.num_frames = num_frames
213 | self.use_cls_token = use_cls_token
214 |
215 | self.norm = norm_layer(embed_dims)
216 | #self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
217 | # **kwargs)
218 | self.attn = Attention(embed_dims, num_heads, qkv_bias=True, attn_drop=attn_drop) # batch first
219 |
220 | self.proj_drop = nn.Dropout(proj_drop)
221 |         dropout_p = layer_drop.get('dropout_p', 0.)
222 |         layer_drop = layer_drop.get('type', None)  # .get avoids mutating the shared default dict
223 | self.layer_drop = layer_drop(dropout_p) if layer_drop else nn.Identity()
224 | if not use_cls_token:
225 | self.temporal_fc = nn.Linear(self.embed_dims, self.embed_dims)
226 | self.init_weights(self.temporal_fc)
227 |
228 | def init_weights(self, module):
229 | if hasattr(module, 'weight') and module.weight is not None:
230 | constant_init_(module.weight, constant_value=0)
231 | if hasattr(module, 'bias') and module.bias is not None:
232 | constant_init_(module.bias, constant_value=0)
233 |
234 | def forward(self, query, key=None, value=None, residual=None, return_attention=False, **kwargs):
235 | assert residual is None, (
236 | 'Always adding the shortcut in the forward function')
237 |
238 | cls_token = query[:, 0, :].unsqueeze(1)
239 | if self.use_cls_token:
240 | residual = query
241 | query = query[:, 1:, :]
242 | else:
243 | query = query[:, 1:, :]
244 | residual = query
245 |
246 | b, n, d = query.size()
247 | p, t = n // self.num_frames, self.num_frames
248 |
249 | # Pre-Process
250 | query = rearrange(query, 'b (p t) d -> (b p) t d', p=p, t=t)
251 | if self.use_cls_token:
252 | cls_token = repeat(cls_token, 'b n d -> b (p n) d', p=p)
253 | cls_token = rearrange(cls_token, 'b p d -> (b p) 1 d')
254 | query = torch.cat((cls_token, query), 1)
255 |
256 | # Forward MSA
257 | query = self.norm(query)
258 | #query = rearrange(query, 'b n d -> n b d')
259 | #attn_out = self.attn(query, query, query)[0]
260 | #attn_out = rearrange(attn_out, 'n b d -> b n d')
261 | attn_out, attn_weights = self.attn(query)
262 | if return_attention:
263 | return attn_weights
264 |
265 | attn_out = self.layer_drop(self.proj_drop(attn_out.contiguous()))
266 | if not self.use_cls_token:
267 | attn_out = self.temporal_fc(attn_out)
268 |
269 | # Post-Process
270 | if self.use_cls_token:
271 | cls_token, attn_out = attn_out[:, 0, :], attn_out[:, 1:, :]
272 | cls_token = rearrange(cls_token, '(b p) d -> b p d', b=b)
273 | cls_token = reduce(cls_token, 'b p d -> b 1 d', 'mean')
274 |
275 | attn_out = rearrange(attn_out, '(b p) t d -> b (p t) d', p=p, t=t)
276 | attn_out = torch.cat((cls_token, attn_out), 1)
277 | new_query = residual + attn_out
278 | else:
279 | attn_out = rearrange(attn_out, '(b p) t d -> b (p t) d', p=p, t=t)
280 | new_query = residual + attn_out
281 | new_query = torch.cat((cls_token, new_query), 1)
282 | return new_query
283 |
284 |
285 | class DividedSpatialAttentionWithPreNorm(nn.Module):
286 | """Spatial Attention in Divided Space Time Attention.
287 |     A wrapper around a multi-head attention layer.
288 |
289 | Args:
290 | embed_dims (int): Dimensions of embedding.
291 | num_heads (int): Number of parallel attention heads in
292 | TransformerCoder.
293 | num_frames (int): Number of frames in the video.
294 | use_cls_token (bool): Whether to perform MSA on cls_token.
295 | attn_drop (float): A Dropout layer on attn_output_weights. Defaults to
296 | 0..
297 | proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
298 | Defaults to 0..
299 | layer_drop (dict): The layer_drop used when adding the shortcut.
300 | Defaults to `dict(type=DropPath, dropout_p=0.1)`.
301 | norm_layer (class): Class name for normalization layer. Defaults to
302 | nn.LayerNorm.
303 | """
304 |
305 | def __init__(self,
306 | embed_dims,
307 | num_heads,
308 | num_frames,
309 | use_cls_token,
310 | attn_drop=0.,
311 | proj_drop=0.,
312 | layer_drop=dict(type=DropPath, dropout_p=0.1),
313 | norm_layer=nn.LayerNorm,
314 | **kwargs):
315 | super().__init__()
316 | self.embed_dims = embed_dims
317 | self.num_heads = num_heads
318 | self.num_frames = num_frames
319 | self.use_cls_token = use_cls_token
320 |
321 | self.norm = norm_layer(embed_dims)
322 | #self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
323 | # **kwargs)
324 | self.attn = Attention(embed_dims, num_heads, qkv_bias=True, attn_drop=attn_drop) # batch first
325 |
326 | self.proj_drop = nn.Dropout(proj_drop)
327 |         dropout_p = layer_drop.get('dropout_p', 0.)
328 |         layer_drop = layer_drop.get('type', None)
329 | self.layer_drop = layer_drop(dropout_p) if layer_drop else nn.Identity()
330 |
331 | self.init_weights()
332 |
333 | def init_weights(self):
334 | pass
335 |
336 | def forward(self, query, key=None, value=None, residual=None, return_attention=False, **kwargs):
337 | assert residual is None, (
338 | 'Always adding the shortcut in the forward function')
339 |
340 | cls_token = query[:, 0, :].unsqueeze(1)
341 | if self.use_cls_token:
342 | residual = query
343 | query = query[:, 1:, :]
344 | else:
345 | query = query[:, 1:, :]
346 | residual = query
347 |
348 | b, n, d = query.size()
349 | p, t = n // self.num_frames, self.num_frames
350 |
351 | # Pre-Process
352 | query = rearrange(query, 'b (p t) d -> (b t) p d', p=p, t=t)
353 | if self.use_cls_token:
354 | cls_token = repeat(cls_token, 'b n d -> b (t n) d', t=t)
355 | cls_token = rearrange(cls_token, 'b t d -> (b t) 1 d')
356 | query = torch.cat((cls_token, query), 1)
357 |
358 | # Forward MSA
359 | query = self.norm(query)
360 | #query = rearrange(query, 'b n d -> n b d')
361 | #attn_out = self.attn(query, query, query)[0]
362 | #attn_out = rearrange(attn_out, 'n b d -> b n d')
363 | attn_out, attn_weights = self.attn(query)
364 | if return_attention:
365 | return attn_weights
366 |
367 | attn_out = self.layer_drop(self.proj_drop(attn_out.contiguous()))
368 |
369 | # Post-Process
370 | if self.use_cls_token:
371 | cls_token, attn_out = attn_out[:, 0, :], attn_out[:, 1:, :]
372 | cls_token = rearrange(cls_token, '(b t) d -> b t d', b=b)
373 | cls_token = reduce(cls_token, 'b t d -> b 1 d', 'mean')
374 |
375 | attn_out = rearrange(attn_out, '(b t) p d -> b (p t) d', p=p, t=t)
376 | attn_out = torch.cat((cls_token, attn_out), 1)
377 | new_query = residual + attn_out
378 | else:
379 | attn_out = rearrange(attn_out, '(b t) p d -> b (p t) d', p=p, t=t)
380 | new_query = residual + attn_out
381 | new_query = torch.cat((cls_token, new_query), 1)
382 | return new_query
383 |
384 |
385 | class MultiheadAttentionWithPreNorm(nn.Module):
386 | """Implements MultiheadAttention with residual connection.
387 |
388 | Args:
389 | embed_dims (int): The embedding dimension.
390 | num_heads (int): Parallel attention heads.
391 | attn_drop (float): A Dropout layer on attn_output_weights.
392 | Default: 0.0.
393 | proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
394 | Default: 0.0.
395 | norm_layer (class): Class name for normalization layer. Defaults to
396 | nn.LayerNorm.
397 | layer_drop (obj:`ConfigDict`): The layer_drop used
398 | when adding the shortcut.
399 |         batch_first (bool): When True, key, query and value have shape
400 |             (batch, n, embed_dim), otherwise (n, batch, embed_dim).
401 |             Defaults to False.
402 | """
403 |
404 | def __init__(self,
405 | embed_dims,
406 | num_heads,
407 | attn_drop=0.,
408 | proj_drop=0.,
409 | norm_layer=nn.LayerNorm,
410 | layer_drop=dict(type=DropPath, dropout_p=0.),
411 | batch_first=False,
412 | **kwargs):
413 | super().__init__()
414 | self.embed_dims = embed_dims
415 | self.num_heads = num_heads
416 | #self.batch_first = batch_first
417 |
418 | self.norm = norm_layer(embed_dims)
419 | #self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
420 | # **kwargs)
421 | self.attn = Attention(embed_dims, num_heads, qkv_bias=True, attn_drop=attn_drop) # batch first
422 |
423 | self.proj_drop = nn.Dropout(proj_drop)
424 |         dropout_p = layer_drop.get('dropout_p', 0.)
425 |         layer_drop = layer_drop.get('type', None)
426 | self.layer_drop = layer_drop(dropout_p) if layer_drop else nn.Identity()
427 |
428 | def forward(self,
429 | query,
430 | key=None,
431 | value=None,
432 | residual=None,
433 | attn_mask=None,
434 | key_padding_mask=None,
435 | return_attention=False,
436 | **kwargs):
437 | residual = query
438 |
439 | query = self.norm(query)
440 | #if self.batch_first:
441 | # query = query.transpose(0, 1)
442 | #attn_out = self.attn(
443 | # query=query,
444 | # key=query,
445 | # value=query,
446 | # attn_mask=attn_mask,
447 | # key_padding_mask=key_padding_mask)[0]
448 | #attn_out = self.attn(query, query, query)[0]
449 | #if self.batch_first:
450 | # attn_out = attn_out.transpose(0, 1)
451 | attn_out, attn_weights = self.attn(query)
452 | if return_attention:
453 | return attn_weights
454 |
455 | new_query = residual + self.layer_drop(self.proj_drop(attn_out))
456 | return new_query
457 |
458 |
459 | class FFNWithPreNorm(nn.Module):
460 | """Implements feed-forward networks (FFNs) with residual connection.
461 |
462 | Args:
463 | embed_dims (int): The feature dimension. Same as
464 | `MultiheadAttention`. Defaults: 256.
465 | hidden_channels (int): The hidden dimension of FFNs.
466 | Defaults: 1024.
467 | num_layers (int, optional): The number of fully-connected layers in
468 | FFNs. Default: 2.
469 | act_layer (dict, optional): The activation layer for FFNs.
470 | Default: nn.GELU
471 | norm_layer (class): Class name for normalization layer. Defaults to
472 | nn.LayerNorm.
473 | dropout_p (float, optional): Probability of an element to be
474 | zeroed in FFN. Default 0.0.
475 | layer_drop (obj:`ConfigDict`): The layer_drop used
476 | when adding the shortcut.
477 | """
478 |
479 | def __init__(self,
480 | embed_dims=256,
481 | hidden_channels=1024,
482 | num_layers=2,
483 | act_layer=nn.GELU,
484 | norm_layer=nn.LayerNorm,
485 | dropout_p=0.,
486 | layer_drop=None,
487 | **kwargs):
488 | super().__init__()
489 | assert num_layers >= 2, 'num_layers should be no less ' \
490 | f'than 2. got {num_layers}.'
491 | self.embed_dims = embed_dims
492 | self.hidden_channels = hidden_channels
493 | self.num_layers = num_layers
494 |
495 | self.norm = norm_layer(embed_dims)
496 | layers = []
497 | in_channels = embed_dims
498 | for _ in range(num_layers - 1):
499 | layers.append(
500 | nn.Sequential(
501 | nn.Linear(in_channels, hidden_channels),
502 | act_layer(),
503 | nn.Dropout(dropout_p)))
504 | in_channels = hidden_channels
505 | layers.append(nn.Linear(hidden_channels, embed_dims))
506 | layers.append(nn.Dropout(dropout_p))
507 | self.layers = nn.ModuleList(layers)
508 |
509 | if layer_drop:
510 | dropout_p = layer_drop.pop('dropout_p')
511 | layer_drop= layer_drop.pop('type')
512 | self.layer_drop = layer_drop(dropout_p)
513 | else:
514 | self.layer_drop = nn.Identity()
515 |
516 | def forward(self, x):
517 | residual = x
518 |
519 | x = self.norm(x)
520 | for layer in self.layers:
521 | x = layer(x)
522 |
523 | return residual + self.layer_drop(x)
524 |
525 |
526 | class TransformerContainer(nn.Module):
527 |
528 | def __init__(self,
529 | num_transformer_layers,
530 | embed_dims,
531 | num_heads,
532 | num_frames,
533 | hidden_channels,
534 | operator_order,
535 | drop_path_rate=0.1,
536 | norm_layer=nn.LayerNorm,
537 | act_layer=nn.GELU,
538 | num_layers=2):
539 | super().__init__()
540 | self.layers = nn.ModuleList([])
541 | self.num_transformer_layers = num_transformer_layers
542 |
543 | dpr = np.linspace(0, drop_path_rate, num_transformer_layers)
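        # stochastic-depth rate grows linearly with depth: 0 for the first block,
        # drop_path_rate for the last block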
544 | for i in range(num_transformer_layers):
545 | self.layers.append(
546 | BasicTransformerBlock(
547 | embed_dims=embed_dims,
548 | num_heads=num_heads,
549 | num_frames=num_frames,
550 | hidden_channels=hidden_channels,
551 | operator_order=operator_order,
552 | norm_layer=norm_layer,
553 | act_layer=act_layer,
554 | num_layers=num_layers,
555 | dpr=dpr[i]))
556 |
557 | def forward(self, x, return_attention=False):
558 | layer_idx = 0
559 | for layer in self.layers:
560 | if layer_idx >= self.num_transformer_layers-1 and return_attention:
561 | x = layer(x, return_attention=True)
562 | else:
563 | x = layer(x)
564 | layer_idx += 1
565 | return x
566 |
567 |
568 | class BasicTransformerBlock(nn.Module):
569 |
570 | def __init__(self,
571 | embed_dims,
572 | num_heads,
573 | num_frames,
574 | hidden_channels,
575 | operator_order,
576 | norm_layer=nn.LayerNorm,
577 | act_layer=nn.GELU,
578 | num_layers=2,
579 | dpr=0,
580 | ):
581 |
582 | super().__init__()
583 | self.attentions = nn.ModuleList([])
584 | self.ffns = nn.ModuleList([])
585 |
586 | for i, operator in enumerate(operator_order):
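            # for the divided space-time operators, only the attention placed right before 'ffn'
            # (i == len(operator_order) - 2) is built with use_cls_token=True, so the cls_token
            # only interacts with the last attention layer of each block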
587 | if operator == 'self_attn':
588 | self.attentions.append(
589 | MultiheadAttentionWithPreNorm(
590 | embed_dims=embed_dims,
591 | num_heads=num_heads,
592 | batch_first=True,
593 | norm_layer=nn.LayerNorm,
594 | layer_drop=dict(type=DropPath, dropout_p=dpr)))
595 | elif operator == 'time_attn':
596 | self.attentions.append(
597 | DividedTemporalAttentionWithPreNorm(
598 | embed_dims=embed_dims,
599 | num_heads=num_heads,
600 | num_frames=num_frames,
601 | norm_layer=norm_layer,
602 | use_cls_token=(i==len(operator_order)-2),
603 | layer_drop=dict(type=DropPath, dropout_p=dpr)))
604 | elif operator == 'space_attn':
605 | self.attentions.append(
606 | DividedSpatialAttentionWithPreNorm(
607 | embed_dims=embed_dims,
608 | num_heads=num_heads,
609 | num_frames=num_frames,
610 | norm_layer=norm_layer,
611 | use_cls_token=(i==len(operator_order)-2),
612 | layer_drop=dict(type=DropPath, dropout_p=dpr)))
613 | elif operator == 'ffn':
614 | self.ffns.append(
615 | FFNWithPreNorm(
616 | embed_dims=embed_dims,
617 | hidden_channels=hidden_channels,
618 | num_layers=num_layers,
619 | act_layer=act_layer,
620 | norm_layer=norm_layer,
621 | layer_drop=dict(type=DropPath, dropout_p=dpr)))
622 | else:
623 | raise TypeError(f'Unsupported operator type {operator}')
624 |
625 | def forward(self, x, return_attention=False):
626 | attention_idx = 0
627 | for layer in self.attentions:
628 | if attention_idx >= len(self.attentions)-1 and return_attention:
629 | x = layer(x, return_attention=True)
630 | return x
631 | else:
632 | x = layer(x)
633 | attention_idx += 1
634 | for layer in self.ffns:
635 | x = layer(x)
636 | return x
637 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import os.path as osp
4 |
5 | import numpy as np
6 | import torch
7 | import matplotlib.pyplot as plt
8 | import torch.distributed as dist
9 | from pytorch_lightning.utilities.distributed import rank_zero_only
10 |
11 | @rank_zero_only
12 | def print_on_rank_zero(content):
13 | if is_main_process():
14 | print(content)
15 |
16 | def is_dist_avail_and_initialized():
17 | if not dist.is_available():
18 | return False
19 | if not dist.is_initialized():
20 | return False
21 | return True
22 |
23 | def get_world_size():
24 | if not is_dist_avail_and_initialized():
25 | return 1
26 | return dist.get_world_size()
27 |
28 | def get_rank():
29 | if not is_dist_avail_and_initialized():
30 | return 0
31 | return dist.get_rank()
32 |
33 | def is_main_process():
34 | return get_rank() == 0
35 |
36 | def timeit_wrapper(func, *args, **kwargs):
37 | start = time.perf_counter()
38 | func_return_val = func(*args, **kwargs)
39 | end = time.perf_counter()
40 | return func_return_val, float(f'{end - start:.4f}')
41 |
42 | def show_trainable_params(named_parameters):
43 | for name, param in named_parameters:
44 | print(name, param.size())
45 |
46 | def build_param_groups(model):
47 | params_no_decay = []
48 | params_has_decay = []
49 | params_no_decay_name = []
50 | params_decay_name = []
51 | for name, param in model.named_parameters():
52 | if not param.requires_grad:
53 | continue
54 |         if len(param.shape) == 1 or name.endswith('.bias'):
55 | params_no_decay.append(param)
56 | params_no_decay_name.append(name)
57 | else:
58 | params_has_decay.append(param)
59 | params_decay_name.append(name)
60 |
61 | param_groups = [
62 | {'params': params_no_decay, 'weight_decay': 0},
63 | {'params': params_has_decay},
64 | ]
65 | print_on_rank_zero(f'params_no_decay_name: {params_no_decay_name} \n params_decay_name: {params_decay_name}')
66 | return param_groups
67 |
68 |
69 | def denormalize(data, mean, std):
70 | """Denormalize an image/video tensor with mean and standard deviation.
71 |
72 | 	Args:
73 | 		data: Image tensor of size (H W C).
74 | 		mean: Mean for each channel.
75 | 		std: Standard deviation for each channel.
76 | 
77 | 	Return:
78 | 		Denormalized tensor with the same size as the input (H W C).
79 | """
80 | shape = data.shape
81 |
82 | if isinstance(mean, tuple):
83 | mean = np.array(mean, dtype=float)
84 | mean = torch.tensor(mean, device=data.device, dtype=data.dtype)
85 |
86 | if isinstance(std, tuple):
87 | std = np.array(std, dtype=float)
88 | std = torch.tensor(std, device=data.device, dtype=data.dtype)
89 |
90 | if mean.shape:
91 | mean = mean[None, :]
92 | if std.shape:
93 | std = std[None, :]
94 |
95 | out = (data.contiguous().view(-1, shape[-1]) * std) + mean
96 |
97 | return out.view(shape)
98 |
99 |
100 | def show_processed_image(imgs, save_dir, mean, std, index=0):
101 | """Plot the transformed images into figure and save to disk.
102 |
103 | Args:
104 | 		imgs: Image tensor of size (T H W C).
105 | 		save_dir: The path to save the images.
106 | 		index: The index of the current batch of clips.
107 | """
108 | os.makedirs(save_dir, exist_ok=True)
109 | if not isinstance(imgs[0], list):
110 | imgs = [imgs]
111 |
112 | num_show_clips = 5
113 | num_rows = len(imgs)
114 | num_cols = num_show_clips
115 | fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
116 | for row_idx, row in enumerate(imgs):
117 | row = row[:num_show_clips]
118 | for col_idx, img in enumerate(row):
119 | ax = axs[row_idx, col_idx]
120 | img = denormalize(img, mean, std).cpu().numpy()
121 | img = (img * 255).astype(np.uint8)
122 | #img = img.cpu().numpy().astype(np.uint8)
123 | ax.imshow(np.asarray(img))
124 | ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
125 |
126 | plt.tight_layout()
127 | filename = osp.join(save_dir, f'clip_transformed_b{index}.png')
128 | plt.savefig(filename)
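129 | 
130 | 
131 | if __name__ == '__main__':
132 | 	# Minimal usage sketch (illustration only, not one of the training entry points):
133 | 	# build_param_groups splits parameters into a no-weight-decay group and a default
134 | 	# group, and timeit_wrapper returns (output, elapsed_seconds) for any callable.
135 | 	import torch.nn as nn
136 | 
137 | 	toy_model = nn.Sequential(nn.Linear(8, 16), nn.LayerNorm(16), nn.Linear(16, 4))
138 | 	param_groups = build_param_groups(toy_model)
139 | 	optimizer = torch.optim.AdamW(param_groups, lr=5e-3, weight_decay=1e-4)
140 | 
141 | 	out, elapsed = timeit_wrapper(toy_model, torch.randn(2, 8))
142 | 	print_on_rank_zero(f'forward output shape: {out.shape}, elapsed: {elapsed}s')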
--------------------------------------------------------------------------------
/visualize_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Copy-paste from DINO library:
9 | https://github.com/facebookresearch/dino
10 | """
11 | import os
12 | import argparse
13 | import cv2
14 | import random
15 | import colorsys
16 | import matplotlib
17 | import matplotlib.pyplot as plt
18 | import torch
19 | import torch.nn as nn
20 | import torchvision
21 | import numpy as np
22 | from PIL import Image
23 | from skimage.measure import find_contours
24 | from matplotlib.patches import Polygon
25 | from torch.utils.data import DataLoader
26 |
27 | import utils
28 | from video_transformer import TimeSformer
29 | import data_transform as T
30 | from dataset import DecordInit
31 | import weight_init
32 |
33 | matplotlib.use('Agg')
34 |
35 | company_colors = [
36 | (0,160,215), # blue
37 | (220,55,60), # red
38 | (245,180,0), # yellow
39 | (10,120,190), # navy
40 | (40,150,100), # green
41 | (135,75,145), # purple
42 | ]
43 | company_colors = [(float(c[0]) / 255.0, float(c[1]) / 255.0, float(c[2]) / 255.0) for c in company_colors]
44 |
45 | def apply_mask2(image, mask, color, alpha=0.5):
46 | """Apply the given mask to the image.
47 | """
48 | 	t = 0.2
49 | mi = np.min(mask)
50 | ma = np.max(mask)
51 | mask = (mask - mi) / (ma - mi)
52 | for c in range(3):
53 | 		image[:, :, c] = image[:, :, c] * (1 - alpha * np.sqrt(mask) * (mask > t)) + alpha * np.sqrt(mask) * (mask > t) * color[c] * 255
54 | return image
55 |
56 | def random_colors(N, bright=True):
57 | """
58 | Generate random colors.
59 | """
60 | brightness = 1.0 if bright else 0.7
61 | hsv = [(i / N, 1, brightness) for i in range(N)]
62 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
63 | random.shuffle(colors)
64 | return colors
65 |
66 | def show_attn(img, attentions, w_featmap, h_featmap, frame_index, index=None):
67 |
68 | nh = attentions.shape[0] # number of head
69 |
70 | # we keep only the output patch attention
71 | attentions = attentions[:, 0, 1:].reshape(nh, -1)
72 |
73 | if args.threshold is not None:
74 | # we keep only a certain percentage of the mass
75 | val, idx = torch.sort(attentions)
76 | val /= torch.sum(val, dim=1, keepdim=True)
77 | cumval = torch.cumsum(val, dim=1)
78 | th_attn = cumval > (1 - args.threshold)
79 | idx2 = torch.argsort(idx)
80 | for head in range(nh):
81 | th_attn[head] = th_attn[head][idx2[head]]
82 | th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
83 | # interpolate
84 | th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].detach().cpu().numpy()
85 |
86 | attentions = attentions.reshape(nh, w_featmap, h_featmap)
87 | attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].detach().cpu().numpy()
88 |
89 | # save attentions heatmaps
90 | prefix = f'id{index}_' if index is not None else ''
91 | os.makedirs(args.output_dir, exist_ok=True)
92 | torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True, scale_each=True), os.path.join(args.output_dir, f"img{frame_index}" + ".png"))
93 | img = Image.open(os.path.join(args.output_dir, f"img{frame_index}" + ".png"))
94 |
95 | attns = Image.new('RGB', (attentions.shape[2] * nh, attentions.shape[1]))
96 | for j in range(nh):
97 | #fname = os.path.join(args.output_dir, prefix + "attn-head" + str(j) + ".png")
98 | fname = os.path.join(args.output_dir, "attn-head" + str(j) + ".png")
99 | plt.imsave(fname=fname, arr=attentions[j], format='png')
100 | attns.paste(Image.open(fname), (j * attentions.shape[2], 0))
101 |
102 | return attentions, th_attn, img, attns
103 |
104 | def show_attn_color(image, attentions, th_attn, index=None, head=[0,1,2,3,4,5]):
105 | M = image.max()
106 | m = image.min()
107 | span = 64
108 | image = ((image - m) / (M-m)) * span + (256 - span)
109 | image = image.mean(axis=2)
110 | image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
111 |
112 | for j in head:
113 | m = attentions[j]
114 | m *= th_attn[j]
115 | attentions[j] = m
116 | mask = np.stack([attentions[j] for j in head])
117 |
118 | blur = False
119 | contour = False
120 | alpha = 1
121 | figsize = tuple([i / 100 for i in args.image_size])
122 | fig = plt.figure(figsize=figsize, frameon=False, dpi=100)
123 | ax = plt.Axes(fig, [0., 0., 1., 1.])
124 | ax.set_axis_off()
125 | fig.add_axes(ax)
126 | ax = plt.gca()
127 |
128 | if len(mask.shape) == 3:
129 | N = mask.shape[0]
130 | else:
131 | N = 1
132 | mask = mask[None, :, :]
133 |
134 | # AJ
135 | for i in range(N):
136 | mask[i] = mask[i] * ( mask[i] == np.amax(mask, axis=0))
137 | a = np.cumsum(mask, axis=0)
138 | for i in range(N):
139 | mask[i] = mask[i] * (mask[i] == a[i])
140 |
141 | colors = company_colors[:N]
142 |
143 | # Show area outside image boundaries.
144 | height, width = image.shape[:2]
145 | margin = 0
146 | ax.set_ylim(height + margin, -margin)
147 | ax.set_xlim(-margin, width + margin)
148 | ax.axis('off')
149 | masked_image = 0.1*image.astype(np.uint32).copy()
150 | for i in range(N):
151 | color = colors[i]
152 | _mask = mask[i]
153 | if blur:
154 | _mask = cv2.blur(_mask,(10,10))
155 | # Mask
156 | masked_image = apply_mask2(masked_image, _mask, color, alpha)
157 | # Mask Polygon
158 | # Pad to ensure proper polygons for masks that touch image edges.
159 | if contour:
160 | padded_mask = np.zeros(
161 | (_mask.shape[0] + 2, _mask.shape[1] + 2))#, dtype=np.uint8)
162 | padded_mask[1:-1, 1:-1] = _mask
163 | contours = find_contours(padded_mask, 0.5)
164 | for verts in contours:
165 | # Subtract the padding and flip (y, x) to (x, y)
166 | verts = np.fliplr(verts) - 1
167 | p = Polygon(verts, facecolor="none", edgecolor=color)
168 | ax.add_patch(p)
169 | ax.imshow(masked_image.astype(np.uint8), aspect='auto')
170 | ax.axis('image')
171 | #fname = os.path.join(output_dir, 'bnw-{:04d}'.format(imid))
172 | prefix = f'id{index}_' if index is not None else ''
173 | fname = os.path.join(args.output_dir, "attn_color.png")
174 | fig.savefig(fname)
175 | attn_color = Image.open(fname)
176 |
177 | return attn_color
178 |
179 | if __name__ == '__main__':
180 | parser = argparse.ArgumentParser('Visualize Self-Attention maps')
181 | parser.add_argument('--arch', default='timesformer', type=str, choices=['timesformer'], help='Architecture.')
182 | parser.add_argument('--pretrained_weights', default='', type=str, help="""Path to pretrained
183 | weights to evaluate. Set to `download` to automatically load the pretrained DINO from url.
184 | Otherwise the model is randomly initialized""")
185 | parser.add_argument('--output_dir', default='./attention_map', help='Path where to save visualizations.')
186 | parser.add_argument("--threshold", type=float, default=0.6, help="""We visualize masks
187 | obtained by thresholding the self-attention maps to keep xx% of the mass.""")
188 | parser.add_argument("--patch_size", type=int, default=16, help="""patch size.""")
189 | parser.add_argument("--image_size", default=(224, 224), type=int, nargs="+", help="Resize image.")
190 | args = parser.parse_args()
191 |
192 | #utils.fix_random_seeds(0)
193 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
194 |
195 | # build model
196 | num_frames = 8
197 | frame_interval = 32
198 | num_class = 400
199 | 	arch = args.arch # set to 'vivit' to initialize a ViViT model instead
200 | if arch == 'timesformer':
201 | pretrain_pth = args.pretrained_weights #'./timesformer_k400.pth'
202 | model = TimeSformer(num_frames=num_frames,
203 | img_size=args.image_size,
204 | patch_size=16,
205 | embed_dims=768,
206 | in_channels=3,
207 | attention_type='divided_space_time',
208 | return_cls_token=True)
209 | else:
210 | 		raise TypeError(f'Unsupported arch type {arch}, choose from (timesformer, vivit)')
211 |
212 | 	msg_trans = weight_init.init_from_kinetics_pretrain_(model, pretrain_pth)
213 | model.eval()
214 | model = model.to(device)
215 | 	print(f'load model finished, missing keys of the transformer: {msg_trans.missing_keys}, unexpected keys: {msg_trans.unexpected_keys}')
216 |
217 | # build data
218 | video_path = './demo/YABnJL_bDzw.mp4'
219 | mean, std = (0.45, 0.45, 0.45), (0.225, 0.225, 0.225)
220 | data_transform = T.Compose([
221 | T.Resize(scale_range=(-1, 256)),
222 | T.CenterCrop(args.image_size),
223 | T.ToTensor(),
224 | T.Normalize(mean, std)
225 | ])
226 | temporal_sample = T.TemporalRandomCrop(num_frames*frame_interval)
227 |
228 | video_decoder = DecordInit()
229 | v_reader = video_decoder(video_path)
230 | total_frames = len(v_reader)
231 | start_frame_ind, end_frame_ind = temporal_sample(total_frames)
232 | if end_frame_ind-start_frame_ind < num_frames:
233 | 		raise ValueError(f'the total number of frames in {video_path} is less than {num_frames}')
234 | frame_indice = np.linspace(0, end_frame_ind-start_frame_ind-1, num_frames, dtype=int)
235 | video = v_reader.get_batch(frame_indice).asnumpy()
236 | del v_reader
237 |
238 | video = torch.from_numpy(video).permute(0,3,1,2) # Video transform: T C H W
239 | data_transform.randomize_parameters()
240 | video = data_transform(video)
241 | video = video.to(device)
242 |
243 | # extract the attention maps
244 | w_featmap = video.shape[-2] // args.patch_size
245 | h_featmap = video.shape[-1] // args.patch_size
246 | attentions = model.get_last_selfattention(video.unsqueeze(0).to(device)) #
247 | print(attentions.shape) # [8 12 197 197]
248 | for i,(frame, attention) in enumerate(zip(video, attentions)):
249 | # make the video frame divisible by the patch size
250 | attentions, th_attn, pic_i, pic_attn = show_attn(frame, attention, w_featmap, h_featmap, frame_index=i)
251 | pic_attn_color = show_attn_color(frame.permute(1, 2, 0).cpu().numpy(), attentions, th_attn)
252 | final_pic = Image.new('RGB', (pic_i.size[1] * 2 + pic_attn.size[0], pic_i.size[1]))
253 | final_pic.paste(pic_i, (0, 0))
254 | final_pic.paste(pic_attn_color, (pic_i.size[1], 0))
255 | final_pic.paste(pic_attn, (pic_i.size[1] * 2, 0))
256 | final_pic.save(os.path.join(args.output_dir, f"attn_img{i}.png"))
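257 | 
258 | # Example invocation (the checkpoint path is a placeholder for a locally downloaded weight file):
259 | #   python visualize_attention.py --arch timesformer --pretrained_weights ./timesformer_k400.pth \
260 | #       --output_dir ./attention_map --patch_size 16 --image_size 224 224 --threshold 0.6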
--------------------------------------------------------------------------------
/weight_init.py:
--------------------------------------------------------------------------------
1 | import math
2 | import re
3 | import warnings
4 |
5 | from einops import repeat
6 | import torch
7 | import torch.nn as nn
8 |
9 | from utils import print_on_rank_zero
10 |
11 |
12 | def show_state_dict(state_dict):
13 | for name, value in state_dict.items():
14 | print(name)
15 |
16 |
17 | def replace_state_dict(state_dict):
18 | for old_key in list(state_dict.keys()):
19 | if old_key.startswith('model'):
20 | new_key = old_key[6:] # skip 'model.'
21 | if 'in_proj' in new_key:
22 | new_key = new_key.replace('in_proj_', 'qkv.') #in_proj_weight -> qkv.weight
23 | elif 'out_proj' in new_key:
24 | new_key = new_key.replace('out_proj', 'proj')
25 | state_dict[new_key] = state_dict.pop(old_key)
26 | else: # cls_head
27 | new_key = old_key[9:]
28 | state_dict[new_key] = state_dict.pop(old_key)
29 |
30 |
31 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
32 | def norm_cdf(x):
33 | # Computes standard normal cumulative distribution function
34 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
35 |
36 | if (mean < a - 2 * std) or (mean > b + 2 * std):
37 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
38 | "The distribution of values may be incorrect.",
39 | stacklevel=2)
40 |
41 | with torch.no_grad():
42 | # Values are generated by using a truncated uniform distribution and
43 | # then using the inverse CDF for the normal distribution.
44 | # Get upper and lower cdf values
45 | l = norm_cdf((a - mean) / std)
46 | u = norm_cdf((b - mean) / std)
47 |
48 | # Uniformly fill tensor with values from [l, u], then translate to
49 | # [2l-1, 2u-1].
50 | tensor.uniform_(2 * l - 1, 2 * u - 1)
51 |
52 | # Use inverse cdf transform for normal distribution to get truncated
53 | # standard normal
54 | tensor.erfinv_()
55 |
56 | # Transform to proper mean, std
57 | tensor.mul_(std * math.sqrt(2.))
58 | tensor.add_(mean)
59 |
60 | # Clamp to ensure it's in the proper range
61 | tensor.clamp_(min=a, max=b)
62 | return tensor
63 |
64 |
65 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
66 | # type: (Tensor, float, float, float, float) -> Tensor
67 | r"""Fills the input Tensor with values drawn from a truncated
68 | normal distribution. The values are effectively drawn from the
69 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
70 | with values outside :math:`[a, b]` redrawn until they are within
71 | the bounds. The method used for generating the random values works
72 | best when :math:`a \leq \text{mean} \leq b`.
73 | Args:
74 | tensor: an n-dimensional `torch.Tensor`
75 | mean: the mean of the normal distribution
76 | std: the standard deviation of the normal distribution
77 | a: the minimum cutoff value
78 | b: the maximum cutoff value
79 | Examples:
80 | >>> w = torch.empty(3, 5)
81 | >>> nn.init.trunc_normal_(w)
82 | """
83 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
84 |
85 |
86 | @torch.no_grad()
87 | def constant_init_(tensor, constant_value=0):
88 | nn.init.constant_(tensor, constant_value)
89 |
90 |
91 | @torch.no_grad()
92 | def kaiming_init_(tensor,
93 | a=0,
94 | mode='fan_out',
95 | nonlinearity='relu',
96 | distribution='normal'):
97 | assert distribution in ['uniform', 'normal']
98 | if distribution == 'uniform':
99 | nn.init.kaiming_uniform_(
100 | tensor, a=a, mode=mode, nonlinearity=nonlinearity)
101 | else:
102 | nn.init.kaiming_normal_(
103 | tensor, a=a, mode=mode, nonlinearity=nonlinearity)
104 |
105 |
106 | @torch.no_grad()
107 | def init_from_vit_pretrain_(module,
108 | pretrained,
109 | conv_type,
110 | attention_type,
111 | copy_strategy,
112 | extend_strategy='temporal_avg',
113 | tube_size=2,
114 | num_time_transformer_layers=4):
115 |
116 | if isinstance(pretrained, str):
117 | if torch.cuda.is_available():
118 | state_dict = torch.load(pretrained)
119 | else:
120 | state_dict = torch.load(pretrained, map_location=torch.device('cpu'))
121 |
122 | if 'state_dict' in state_dict:
123 | state_dict = state_dict['state_dict']
124 |
125 | old_state_dict_keys = list(state_dict.keys())
126 | for old_key in old_state_dict_keys:
127 | # extend the Conv2d params to Conv3d
128 | if conv_type == 'Conv3d':
129 | if 'patch_embed.projection.weight' in old_key:
130 | weight = state_dict[old_key]
131 | new_weight = repeat(weight, 'd c h w -> d c t h w', t=tube_size)
132 | if extend_strategy == 'temporal_avg':
133 | new_weight = new_weight / tube_size
134 | elif extend_strategy == 'center_frame':
135 | new_weight.zero_()
136 | new_weight[:,:,tube_size//2,:,:] = weight
137 | state_dict[old_key] = new_weight
138 | continue
139 |
140 | # modify the key names of norm layers
141 | if attention_type == 'fact_encoder':
142 | new_key = old_key.replace('transformer_layers.layers',
143 | 'transformer_layers.0.layers')
144 | else:
145 | new_key = old_key
146 |
147 | if 'in_proj' in new_key:
148 | new_key = new_key.replace('in_proj_', 'qkv.') #in_proj_weight -> qkv.weight
149 | elif 'out_proj' in new_key:
150 | new_key = new_key.replace('out_proj', 'proj')
151 |
152 | if 'norms' in new_key:
153 | new_key = new_key.replace('norms.0', 'attentions.0.norm')
154 | new_key = new_key.replace('norms.1', 'ffns.0.norm')
155 |
156 | state_dict[new_key] = state_dict.pop(old_key)
157 |
158 | old_state_dict_keys = list(state_dict.keys())
159 | for old_key in old_state_dict_keys:
160 | # copy the parameters of space attention to time attention
161 | if attention_type == 'divided_space_time':
162 | if 'attentions.0' in old_key:
163 | new_key = old_key.replace('attentions.0',
164 | 'attentions.1')
165 | if copy_strategy == 'repeat':
166 | state_dict[new_key] = state_dict[old_key].clone()
167 | elif copy_strategy == 'set_zero':
168 | state_dict[new_key] = state_dict[old_key].clone().zero_()
169 | # copy the part of parameters of space attention to time attention
170 | elif attention_type == 'fact_encoder':
171 | pattern = re.compile(r'(?<=layers.)\d+')
172 | matchObj = pattern.findall(old_key)
173 | if len(matchObj) > 1 and int(matchObj[1]) < num_time_transformer_layers:
174 | new_key = old_key.replace('transformer_layers.0.layers',
175 | 'transformer_layers.1.layers')
176 | if copy_strategy == 'repeat':
177 | state_dict[new_key] = state_dict[old_key].clone()
178 | elif copy_strategy == 'set_zero':
179 | state_dict[new_key] = state_dict[old_key].clone().zero_()
180 |
181 | missing_keys,unexpected_keys = module.load_state_dict(state_dict, strict=False)
182 | #print(f'missing_keys:{missing_keys}\n unexpected_keys:{unexpected_keys}')
183 | print_on_rank_zero(f'missing_keys:{missing_keys}\n '
184 | f'unexpected_keys:{unexpected_keys}')
185 |
186 |
187 | @torch.no_grad()
188 | def init_from_mae_pretrain_(module,
189 | pretrained,
190 | conv_type,
191 | attention_type,
192 | copy_strategy,
193 | extend_strategy='temporal_avg',
194 | tube_size=2,
195 | num_time_transformer_layers=4):
196 |
197 | if isinstance(pretrained, str):
198 | if torch.cuda.is_available():
199 | state_dict = torch.load(pretrained)
200 | else:
201 | state_dict = torch.load(pretrained, map_location=torch.device('cpu'))
202 |
203 | if 'model' in state_dict:
204 | state_dict = state_dict['model']
205 |
206 | # adjust to our module
207 | old_state_dict_keys = list(state_dict.keys())
208 | for old_key in old_state_dict_keys:
209 | if 'decoder' in old_key:
210 | state_dict.pop(old_key)
211 | continue
212 |
213 | # extend the Conv2d params to Conv3d
214 | if 'encoder.patch_embed.proj' in old_key:
215 | new_key = old_key.replace('encoder.patch_embed.proj',
216 | 'patch_embed.projection')
217 | if conv_type == 'Conv3d' and 'weight' in old_key:
218 | weight = state_dict[old_key]
219 | new_weight = repeat(weight, 'd c h w -> d c t h w', t=tube_size)
220 | if extend_strategy == 'temporal_avg':
221 | new_weight = new_weight / tube_size
222 | elif extend_strategy == 'center_frame':
223 | new_weight.zero_()
224 | new_weight[:,:,tube_size//2,:,:] = weight
225 | state_dict.pop(old_key)
226 | state_dict[new_key] = new_weight
227 | else:
228 | state_dict[new_key] = state_dict.pop(old_key)
229 | continue
230 |
231 | # modify the key names of norm layers
232 | if attention_type == 'fact_encoder':
233 | new_key = old_key.replace('encoder.blocks',
234 | 'transformer_layers.0.layers')
235 | else:
236 | new_key = old_key.replace('encoder.blocks',
237 | 'transformer_layers.layers')
238 |
239 | if 'norm' in new_key:
240 | new_key = new_key.replace('norm1', 'attentions.0.norm')
241 | new_key = new_key.replace('norm2', 'ffns.0.norm')
242 | elif 'attn' in new_key:
243 | #new_key = new_key.replace('attn.qkv.weight',
244 | # 'attentions.0.attn.in_proj_weight')
245 | #new_key = new_key.replace('attn.proj',
246 | # 'attentions.0.attn.out_proj')
247 | if 'q_bias' in new_key:
248 | pattern = re.compile(r'(?<=blocks.)\d+')
249 | matchObj = pattern.findall(old_key)
250 | block_id = int(matchObj[0])
251 | q_bias = state_dict[f'encoder.blocks.{block_id}.attn.q_bias']
252 | v_bias = state_dict[f'encoder.blocks.{block_id}.attn.v_bias']
253 | weight = torch.cat((q_bias,
254 | torch.zeros_like(q_bias, requires_grad=False),
255 | v_bias))
256 | new_key = new_key.replace('attn.q_bias',
257 | #'attentions.0.attn.in_proj_bias')
258 | 'attentions.0.attn.qkv.bias')
259 | state_dict.pop(f'encoder.blocks.{block_id}.attn.q_bias')
260 | state_dict.pop(f'encoder.blocks.{block_id}.attn.v_bias')
261 | state_dict[new_key] = weight
262 | continue
263 | elif 'v_bias' in new_key:
264 | continue
265 | elif 'mlp' in new_key:
266 | new_key = new_key.replace('mlp.fc1', 'ffns.0.layers.0.0')
267 | new_key = new_key.replace('mlp.fc2', 'ffns.0.layers.1')
268 |
269 | if 'encoder.norm' in old_key:
270 | new_key = old_key.replace('encoder.norm',
271 | 'norm')
272 |
273 | state_dict[new_key] = state_dict.pop(old_key)
274 |
275 | # copy to new layer
276 | old_state_dict_keys = list(state_dict.keys())
277 | for old_key in old_state_dict_keys:
278 | # copy the parameters of space attention to time attention
279 | if attention_type == 'divided_space_time':
280 | if 'attentions.0' in old_key:
281 | new_key = old_key.replace('attentions.0',
282 | 'attentions.1')
283 | if copy_strategy == 'repeat':
284 | state_dict[new_key] = state_dict[old_key].clone()
285 | elif copy_strategy == 'set_zero':
286 | state_dict[new_key] = state_dict[old_key].clone().zero_()
287 | # copy the part of parameters of space attention to time attention
288 | elif attention_type == 'fact_encoder':
289 | pattern = re.compile(r'(?<=layers.)\d+')
290 | matchObj = pattern.findall(old_key)
291 | if len(matchObj) > 1 and int(matchObj[1]) < num_time_transformer_layers:
292 | new_key = old_key.replace('transformer_layers.0.layers',
293 | 'transformer_layers.1.layers')
294 | if copy_strategy == 'repeat':
295 | state_dict[new_key] = state_dict[old_key].clone()
296 | elif copy_strategy == 'set_zero':
297 | state_dict[new_key] = state_dict[old_key].clone().zero_()
298 |
299 | missing_keys,unexpected_keys = module.load_state_dict(state_dict, strict=False)
300 | #print(f'missing_keys:{missing_keys}\n unexpected_keys:{unexpected_keys}')
301 | print_on_rank_zero(f'missing_keys:{missing_keys}\n '
302 | f'unexpected_keys:{unexpected_keys}')
303 |
304 |
305 | def init_from_kinetics_pretrain_(module, pretrain_pth):
306 | if torch.cuda.is_available():
307 | state_dict = torch.load(pretrain_pth)
308 | else:
309 | state_dict = torch.load(pretrain_pth, map_location=torch.device('cpu'))
310 | if 'state_dict' in state_dict:
311 | state_dict = state_dict['state_dict']
312 |
313 | replace_state_dict(state_dict)
314 | msg = module.load_state_dict(state_dict, strict=False)
315 | 	print_on_rank_zero(msg)
316 | 	return msg
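317 | 
318 | 
319 | if __name__ == '__main__':
320 | 	# Minimal usage sketch: apply the init helpers defined above to a toy layer.
321 | 	layer = nn.Linear(768, 400)
322 | 	trunc_normal_(layer.weight, std=0.02)
323 | 	constant_init_(layer.bias, 0)
324 | 	print_on_rank_zero(f'weight std after trunc_normal_: {layer.weight.std().item():.4f}')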
--------------------------------------------------------------------------------