├── .gitignore
├── README.md
├── checkpoint.py
├── data
│   └── k400_class_mappings.json
├── figs
│   ├── arch.png
│   └── k400.png
├── main.py
├── model.py
├── scripts
│   ├── eval_k400_vitb16_16f_dec4x768.sh
│   ├── eval_k400_vitb16_32f_dec4x768.sh
│   ├── eval_k400_vitb16_8f_dec4x768.sh
│   ├── eval_k400_vitl14_16f_dec4x1024.sh
│   ├── eval_k400_vitl14_32f_dec4x1024.sh
│   ├── eval_k400_vitl14_8f_dec4x1024.sh
│   ├── train_k400_vitb16_16f_dec4x768.sh
│   ├── train_k400_vitb16_32f_dec4x768.sh
│   ├── train_k400_vitb16_8f_dec4x768.sh
│   ├── train_k400_vitl14_16f_dec4x1024.sh
│   ├── train_k400_vitl14_32f_dec4x1024.sh
│   └── train_k400_vitl14_8f_dec4x1024.sh
├── video_dataset
│   ├── __init__.py
│   ├── dataloader.py
│   ├── dataset.py
│   ├── rand_augment.py
│   ├── random_erasing.py
│   └── transform.py
├── vision_transformer.py
└── weight_loaders.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.py[cod]
3 | *$py.class
4 |
5 | runs/
6 |
7 | video_dataset/io_internal.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Frozen CLIP models are Efficient Video Learners
2 |
3 | This is the official implementation of the paper [Frozen CLIP models are Efficient Video Learners](https://arxiv.org/abs/2208.03550)
4 |
5 | ```
6 | @article{lin2022frozen,
7 | title={Frozen CLIP Models are Efficient Video Learners},
8 | author={Lin, Ziyi and Geng, Shijie and Zhang, Renrui and Gao, Peng and de Melo, Gerard and Wang, Xiaogang and Dai, Jifeng and Qiao, Yu and Li, Hongsheng},
9 | journal={arXiv preprint arXiv:2208.03550},
10 | year={2022}
11 | }
12 | ```
13 |
14 | ## Introduction
15 |
16 | The overall architecture of the EVL framework includes a trainable Transformer decoder, trainable local temporal modules, and a pretrained, fixed image backbone
17 | (we use CLIP, for instance).
18 |
19 |
20 | ![Overall architecture of the EVL framework](figs/arch.png)
21 | Using a fixed backbone significantly saves training time, and we managed to train a ViT-B/16 with 8 frames for 50 epochs in 60 GPU-hours (NVIDIA V100).
22 |
23 | Despite the small training computation and memory consumption, EVL models achieve high performance on Kinetics-400. A comparison with state-of-the-art methods
24 | is as follows.
25 |
26 |
27 | ![Comparison with state-of-the-art methods on Kinetics-400](figs/k400.png)
28 | ## Installation
29 |
30 | We tested the released code with the following conda environment
31 |
32 | ```
33 | conda create -n pt1.9.0cu11.1_official -c pytorch -c conda-forge pytorch=1.9.0=py3.9_cuda11.1_cudnn8.0.5_0 cudatoolkit torchvision av
34 | ```
35 |
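As a quick, optional sanity check after creating the environment (this snippet is illustrative and not part of the repository), you can verify that PyTorch, torchvision and PyAV import correctly and that CUDA is visible:

```
import av
import torch
import torchvision

# Print the installed versions and confirm that a CUDA device is available.
print('torch', torch.__version__, '| cuda', torch.version.cuda, '| cuda available:', torch.cuda.is_available())
print('torchvision', torchvision.__version__, '| av', av.__version__)
```
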
36 | ## Data Preparation
37 |
38 | We expect the `--train_list_path` and `--val_list_path` command line arguments to point to data list files of the following format
39 | ```
40 | <path_1> <label_1>
41 | <path_2> <label_2>
42 | ...
43 | <path_n> <label_n>
44 | ```
45 | where `<path_i>` points to a video file, and `<label_i>` is an integer between `0` and `num_classes - 1`.
46 | `--num_classes` should also be specified as a command line argument.
47 |
48 | Additionally, `<path_i>` may be a relative path when `--data_root` is specified, and the actual path will be
49 | resolved relative to the path passed as `--data_root`.
50 |
51 | The class mappings used by the open-source weights are provided in [Kinetics-400 class mappings](data/k400_class_mappings.json).
52 |
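As an illustration only, the following sketch generates such a list file under the assumption of a hypothetical directory layout with one subfolder per class, named exactly as in the JSON mapping above; `DATA_ROOT` and the output file name are placeholders, not part of this repository.

```
import json
import os

# Hypothetical layout: <DATA_ROOT>/<class name>/<video file>
DATA_ROOT = '/path/to/k400/train_videos'        # assumed location of the videos
MAPPING_PATH = 'data/k400_class_mappings.json'  # class names shipped with this repo

with open(MAPPING_PATH) as f:
    label_of = {name: i for i, name in enumerate(json.load(f))}

with open('train.txt', 'w') as out:
    for class_name, label in label_of.items():
        class_dir = os.path.join(DATA_ROOT, class_name)
        if not os.path.isdir(class_dir):
            continue
        for fname in sorted(os.listdir(class_dir)):
            # one "<path> <label>" pair per line, as expected by --train_list_path
            out.write(f'{os.path.join(class_dir, fname)} {label}\n')
```
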
53 | ## Backbone Preparation
54 |
55 | CLIP weights need to be downloaded from [CLIP official repo](https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/clip.py#L30)
56 | and passed to the `--backbone_path` command line argument.
57 |
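The released CLIP `.pt` files are TorchScript archives. The conversion into a plain state dict for this codebase is handled by `weight_loaders.py`; the snippet below is only an optional sanity check, using a placeholder path, to confirm the downloaded file is intact.

```
import torch

BACKBONE_PATH = '/path/to/clip_models/ViT-B-16.pt'  # placeholder; use your own --backbone_path

# Official CLIP releases are TorchScript archives, so torch.jit.load recovers the module.
clip_model = torch.jit.load(BACKBONE_PATH, map_location='cpu')
state_dict = clip_model.state_dict()

# The image backbone parameters live under the 'visual.' prefix.
visual_keys = [k for k in state_dict if k.startswith('visual.')]
print(f'{len(visual_keys)} visual parameters, e.g. {visual_keys[:3]}')
```
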
58 | ## Script Usage
59 |
60 | Training and evaluation scripts are provided in the scripts folder.
61 | Scripts should be ready to run once the environment is set up and
62 | `--backbone_path`, `--train_list_path` and `--val_list_path` are replaced with your own paths.
63 |
64 | For other command line arguments, please see the help message (`python main.py -h`) for usage.
65 |
66 | ## Kinetics-400 Main Results
67 |
68 | This is a re-implementation for open-source use.
69 | We are still re-running some models, and their scripts, weights and logs will be released later.
70 | In the following table we report the re-run accuracy, which may differ slightly from the numbers in the original paper (typically by +/-0.1%).
71 |
72 | | Backbone | Decoder Layers | #frames x stride | top-1 | top-5 | Script | Model | Log |
73 | | - | - | - | - | - | - | - | - |
74 | | ViT-B/16 | 4 | 8 x 16 | 82.8 | 95.8 | [script](scripts/train_k400_vitb16_8f_dec4x768.sh) | [google drive](https://drive.google.com/file/d/1DoGjvDdkJoSa9i-wq1lh6QoEZIa4xTB3/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1-9vgsXMpnWBI9MxQV7SSQhkPfLomoYY3/view?usp=sharing) |
75 | | ViT-B/16 | 4 | 16 x 16 | 83.7 | 96.2 | [script](scripts/train_k400_vitb16_16f_dec4x768.sh) | [google drive](https://drive.google.com/file/d/1dax4qUIOEI_QzYXv31J-87cDkonQetVQ/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1l2ivY28jUpwSmafQZvwtUo7tvm42i0PL/view?usp=sharing) |
76 | | ViT-B/16 | 4 | 32 x 8 | 84.3 | 96.6 | [script](scripts/train_k400_vitb16_32f_dec4x768.sh) | [google drive](https://drive.google.com/file/d/1fzFM5pD39Kfp8xRAJuWaXR9RALLmnoeU/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1X1ZOdSCxXVeMpNhr_bviNKlRfJa5SMD7/view?usp=sharing) |
77 | | ViT-L/14 | 4 | 8 x 16 | 86.3 | 97.2 | [script](scripts/train_k400_vitl14_8f_dec4x1024.sh) | [google drive](https://drive.google.com/file/d/1AkdF4CkOVW2uiycCVqCxS397oYxNISAI/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1OJFBmaE_tAwTzG-4i0CLQmhwGnN0psx1/view?usp=sharing) |
78 | | ViT-L/14 | 4 | 16 x 16 | 86.9 | 97.4 | [script](scripts/train_k400_vitl14_16f_dec4x1024.sh) | [google drive](https://drive.google.com/file/d/1CTV9geLD3HLWzByAQUOf_m0F_g2lE3rg/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1a2iC4tQvjWFMI3UrEv2chuHwVrF6p9YF/view?usp=sharing) |
79 | | ViT-L/14 | 4 | 32 x 8 | 87.7 | 97.6 | [script](scripts/train_k400_vitl14_32f_dec4x1024.sh) | [google drive](https://drive.google.com/file/d/1zNFNCKwP5owakELlnTCD20cpVQBqgJrB/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1dK7qoz3McYrmfS09FfreXC-LjUM7l0u4/view?usp=sharing) |
80 | | ViT-L/14 (336px) | 4 | 32 x 8 | 87.7 | 97.8 | | | |
81 |
82 | ## Data Loading Speed
83 |
84 | As the training process is fast, video frames are consumed at a very high rate.
85 | For easier installation, the current version uses PyTorch's built-in data loaders.
86 | They are not very efficient and can become a bottleneck when using ViT-B as the backbone.
87 | We provide a `--dummy_dataset` option to bypass actual video decoding for training speed measurement.
88 | The model accuracy should not be affected.
89 | Our internal data loader is written in pure C++ and does not significantly bottleneck training on a machine with 2x Xeon Gold 6148 CPUs and 4x V100 GPUs.
90 |
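If you want to check whether data loading limits throughput on your own machine, a minimal, loader-agnostic timing sketch (the helper name and batch count below are arbitrary) could look like this:

```
import itertools
import time

import torch


def loader_throughput(loader: torch.utils.data.DataLoader, num_batches: int = 100) -> float:
    """Return the number of samples per second delivered by `loader`."""
    it = iter(loader)
    next(it)  # warm up worker processes before timing
    samples = 0
    start = time.perf_counter()
    for data, _labels in itertools.islice(it, num_batches):
        samples += data.size(0)
    return samples / (time.perf_counter() - start)
```

Comparing this number against your measured training steps per second indicates whether the loader or the model is the bottleneck.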
91 |
92 | ## Acknowledgements
93 |
94 | The data loader code is modified from [PySlowFast](https://github.com/facebookresearch/SlowFast). Thanks for their awesome work!
95 |
--------------------------------------------------------------------------------
/checkpoint.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import argparse
4 | import os
5 |
6 | import torch
7 | import torch.distributed as dist
8 |
9 |
10 | def setup_arg_parser(parser: argparse.ArgumentParser):
11 | parser.add_argument('--checkpoint_dir', type=str,
12 | help='checkpoint output path')
13 | parser.add_argument('--auto_resume', action='store_true',
14 | help='auto resume from the last checkpoint from checkpoint_dir')
15 | parser.add_argument('--resume_path', type=str,
16 | help='resume from manually specified checkpoint file, overriding auto_resume')
17 | parser.add_argument('--pretrain', type=str,
18 |                         help='path to pretrained weights. will NOT override auto_resume or resume_path, '
19 |                              'load optimizer state, or enforce strict matching of checkpoint and model weights.')
20 |
21 |
22 | def _find_autoresume_path(args: argparse.Namespace):
23 | print('Trying to auto resume from path:', args.checkpoint_dir)
24 |
25 | if os.path.isdir(args.checkpoint_dir):
26 | checkpoint_files = [x for x in os.listdir(args.checkpoint_dir) if x.startswith('checkpoint-') and x.endswith('.pth')]
27 | checkpoint_iters = []
28 | for x in checkpoint_files:
29 | try:
30 | x = x[len('checkpoint-'): -len('.pth')]
31 | x = int(x)
32 | except ValueError:
33 | continue
34 | checkpoint_iters.append(x)
35 | else:
36 | checkpoint_iters = []
37 |
38 | if len(checkpoint_iters) == 0:
39 | print('Did not find a valid checkpoint file.')
40 | else:
41 | checkpoint_iters.sort()
42 | args.resume_path = os.path.join(args.checkpoint_dir, 'checkpoint-%d.pth' % checkpoint_iters[-1])
43 | print(f'Found {len(checkpoint_iters)} checkpoint file(s).')
44 |
45 |
46 | def resume_from_checkpoint(
47 | model: torch.nn.Module,
48 | optimizer: torch.optim.Optimizer,
49 | lr_sched: torch.optim.lr_scheduler._LRScheduler,
50 | loss_scaler: torch.cuda.amp.grad_scaler.GradScaler,
51 | args: argparse.Namespace,
52 | ) -> int:
53 | if args.pretrain is not None:
54 | print(f'Loading pretrain model: {args.pretrain}')
55 | ckpt = torch.load(args.pretrain, map_location='cpu')
56 | print(model.load_state_dict(ckpt['model'], strict=False))
57 |
58 | # returns resume_step on successful resume, or 0 otherwise.
59 | if args.auto_resume and args.resume_path is None:
60 | _find_autoresume_path(args)
61 |
62 | if args.resume_path is None:
63 | print('Not resuming from a checkpoint.')
64 | return 0
65 | else:
66 | print(f'Resuming from checkpoint file {args.resume_path}')
67 | ckpt = torch.load(args.resume_path, map_location='cpu')
68 | model.load_state_dict(ckpt['model'], strict=True)
69 | if 'optimizer' in ckpt:
70 | optimizer.load_state_dict(ckpt['optimizer'])
71 | lr_sched.load_state_dict(ckpt['lr_sched'])
72 | loss_scaler.load_state_dict(ckpt['loss_scaler'])
73 | return ckpt['next_step']
74 | else:
75 | print('Optimizer state is NOT found in checkpoint.')
76 | return 0
77 |
78 |
79 | def save_checkpoint(
80 | model: torch.nn.Module,
81 | optimizer: torch.optim.Optimizer,
82 | lr_sched: torch.optim.lr_scheduler._LRScheduler,
83 | loss_scaler: torch.cuda.amp.grad_scaler.GradScaler,
84 | next_step: int,
85 | args: argparse.Namespace,
86 | ):
87 | if args.checkpoint_dir is None:
88 | return
89 |
90 | if not os.path.isdir(args.checkpoint_dir):
91 | os.makedirs(args.checkpoint_dir)
92 |
93 | to_save = {
94 | 'model': model.state_dict(),
95 | 'optimizer': optimizer.state_dict(),
96 | 'lr_sched': lr_sched.state_dict(),
97 | 'loss_scaler': loss_scaler.state_dict(),
98 | 'next_step': next_step,
99 | }
100 | torch.save(to_save, os.path.join(args.checkpoint_dir, f'checkpoint-{next_step}.pth'))
101 |
--------------------------------------------------------------------------------
/data/k400_class_mappings.json:
--------------------------------------------------------------------------------
1 | [
2 | "abseiling",
3 | "air drumming",
4 | "answering questions",
5 | "applauding",
6 | "applying cream",
7 | "archery",
8 | "arm wrestling",
9 | "arranging flowers",
10 | "assembling computer",
11 | "auctioning",
12 | "baby waking up",
13 | "baking cookies",
14 | "balloon blowing",
15 | "bandaging",
16 | "barbequing",
17 | "bartending",
18 | "beatboxing",
19 | "bee keeping",
20 | "belly dancing",
21 | "bench pressing",
22 | "bending back",
23 | "bending metal",
24 | "biking through snow",
25 | "blasting sand",
26 | "blowing glass",
27 | "blowing leaves",
28 | "blowing nose",
29 | "blowing out candles",
30 | "bobsledding",
31 | "bookbinding",
32 | "bouncing on trampoline",
33 | "bowling",
34 | "braiding hair",
35 | "breading or breadcrumbing",
36 | "breakdancing",
37 | "brush painting",
38 | "brushing hair",
39 | "brushing teeth",
40 | "building cabinet",
41 | "building shed",
42 | "bungee jumping",
43 | "busking",
44 | "canoeing or kayaking",
45 | "capoeira",
46 | "carrying baby",
47 | "cartwheeling",
48 | "carving pumpkin",
49 | "catching fish",
50 | "catching or throwing baseball",
51 | "catching or throwing frisbee",
52 | "catching or throwing softball",
53 | "celebrating",
54 | "changing oil",
55 | "changing wheel",
56 | "checking tires",
57 | "cheerleading",
58 | "chopping wood",
59 | "clapping",
60 | "clay pottery making",
61 | "clean and jerk",
62 | "cleaning floor",
63 | "cleaning gutters",
64 | "cleaning pool",
65 | "cleaning shoes",
66 | "cleaning toilet",
67 | "cleaning windows",
68 | "climbing a rope",
69 | "climbing ladder",
70 | "climbing tree",
71 | "contact juggling",
72 | "cooking chicken",
73 | "cooking egg",
74 | "cooking on campfire",
75 | "cooking sausages",
76 | "counting money",
77 | "country line dancing",
78 | "cracking neck",
79 | "crawling baby",
80 | "crossing river",
81 | "crying",
82 | "curling hair",
83 | "cutting nails",
84 | "cutting pineapple",
85 | "cutting watermelon",
86 | "dancing ballet",
87 | "dancing charleston",
88 | "dancing gangnam style",
89 | "dancing macarena",
90 | "deadlifting",
91 | "decorating the christmas tree",
92 | "digging",
93 | "dining",
94 | "disc golfing",
95 | "diving cliff",
96 | "dodgeball",
97 | "doing aerobics",
98 | "doing laundry",
99 | "doing nails",
100 | "drawing",
101 | "dribbling basketball",
102 | "drinking",
103 | "drinking beer",
104 | "drinking shots",
105 | "driving car",
106 | "driving tractor",
107 | "drop kicking",
108 | "drumming fingers",
109 | "dunking basketball",
110 | "dying hair",
111 | "eating burger",
112 | "eating cake",
113 | "eating carrots",
114 | "eating chips",
115 | "eating doughnuts",
116 | "eating hotdog",
117 | "eating ice cream",
118 | "eating spaghetti",
119 | "eating watermelon",
120 | "egg hunting",
121 | "exercising arm",
122 | "exercising with an exercise ball",
123 | "extinguishing fire",
124 | "faceplanting",
125 | "feeding birds",
126 | "feeding fish",
127 | "feeding goats",
128 | "filling eyebrows",
129 | "finger snapping",
130 | "fixing hair",
131 | "flipping pancake",
132 | "flying kite",
133 | "folding clothes",
134 | "folding napkins",
135 | "folding paper",
136 | "front raises",
137 | "frying vegetables",
138 | "garbage collecting",
139 | "gargling",
140 | "getting a haircut",
141 | "getting a tattoo",
142 | "giving or receiving award",
143 | "golf chipping",
144 | "golf driving",
145 | "golf putting",
146 | "grinding meat",
147 | "grooming dog",
148 | "grooming horse",
149 | "gymnastics tumbling",
150 | "hammer throw",
151 | "headbanging",
152 | "headbutting",
153 | "high jump",
154 | "high kick",
155 | "hitting baseball",
156 | "hockey stop",
157 | "holding snake",
158 | "hopscotch",
159 | "hoverboarding",
160 | "hugging",
161 | "hula hooping",
162 | "hurdling",
163 | "hurling (sport)",
164 | "ice climbing",
165 | "ice fishing",
166 | "ice skating",
167 | "ironing",
168 | "javelin throw",
169 | "jetskiing",
170 | "jogging",
171 | "juggling balls",
172 | "juggling fire",
173 | "juggling soccer ball",
174 | "jumping into pool",
175 | "jumpstyle dancing",
176 | "kicking field goal",
177 | "kicking soccer ball",
178 | "kissing",
179 | "kitesurfing",
180 | "knitting",
181 | "krumping",
182 | "laughing",
183 | "laying bricks",
184 | "long jump",
185 | "lunge",
186 | "making a cake",
187 | "making a sandwich",
188 | "making bed",
189 | "making jewelry",
190 | "making pizza",
191 | "making snowman",
192 | "making sushi",
193 | "making tea",
194 | "marching",
195 | "massaging back",
196 | "massaging feet",
197 | "massaging legs",
198 | "massaging person's head",
199 | "milking cow",
200 | "mopping floor",
201 | "motorcycling",
202 | "moving furniture",
203 | "mowing lawn",
204 | "news anchoring",
205 | "opening bottle",
206 | "opening present",
207 | "paragliding",
208 | "parasailing",
209 | "parkour",
210 | "passing American football (in game)",
211 | "passing American football (not in game)",
212 | "peeling apples",
213 | "peeling potatoes",
214 | "petting animal (not cat)",
215 | "petting cat",
216 | "picking fruit",
217 | "planting trees",
218 | "plastering",
219 | "playing accordion",
220 | "playing badminton",
221 | "playing bagpipes",
222 | "playing basketball",
223 | "playing bass guitar",
224 | "playing cards",
225 | "playing cello",
226 | "playing chess",
227 | "playing clarinet",
228 | "playing controller",
229 | "playing cricket",
230 | "playing cymbals",
231 | "playing didgeridoo",
232 | "playing drums",
233 | "playing flute",
234 | "playing guitar",
235 | "playing harmonica",
236 | "playing harp",
237 | "playing ice hockey",
238 | "playing keyboard",
239 | "playing kickball",
240 | "playing monopoly",
241 | "playing organ",
242 | "playing paintball",
243 | "playing piano",
244 | "playing poker",
245 | "playing recorder",
246 | "playing saxophone",
247 | "playing squash or racquetball",
248 | "playing tennis",
249 | "playing trombone",
250 | "playing trumpet",
251 | "playing ukulele",
252 | "playing violin",
253 | "playing volleyball",
254 | "playing xylophone",
255 | "pole vault",
256 | "presenting weather forecast",
257 | "pull ups",
258 | "pumping fist",
259 | "pumping gas",
260 | "punching bag",
261 | "punching person (boxing)",
262 | "push up",
263 | "pushing car",
264 | "pushing cart",
265 | "pushing wheelchair",
266 | "reading book",
267 | "reading newspaper",
268 | "recording music",
269 | "riding a bike",
270 | "riding camel",
271 | "riding elephant",
272 | "riding mechanical bull",
273 | "riding mountain bike",
274 | "riding mule",
275 | "riding or walking with horse",
276 | "riding scooter",
277 | "riding unicycle",
278 | "ripping paper",
279 | "robot dancing",
280 | "rock climbing",
281 | "rock scissors paper",
282 | "roller skating",
283 | "running on treadmill",
284 | "sailing",
285 | "salsa dancing",
286 | "sanding floor",
287 | "scrambling eggs",
288 | "scuba diving",
289 | "setting table",
290 | "shaking hands",
291 | "shaking head",
292 | "sharpening knives",
293 | "sharpening pencil",
294 | "shaving head",
295 | "shaving legs",
296 | "shearing sheep",
297 | "shining shoes",
298 | "shooting basketball",
299 | "shooting goal (soccer)",
300 | "shot put",
301 | "shoveling snow",
302 | "shredding paper",
303 | "shuffling cards",
304 | "side kick",
305 | "sign language interpreting",
306 | "singing",
307 | "situp",
308 | "skateboarding",
309 | "ski jumping",
310 | "skiing (not slalom or crosscountry)",
311 | "skiing crosscountry",
312 | "skiing slalom",
313 | "skipping rope",
314 | "skydiving",
315 | "slacklining",
316 | "slapping",
317 | "sled dog racing",
318 | "smoking",
319 | "smoking hookah",
320 | "snatch weight lifting",
321 | "sneezing",
322 | "sniffing",
323 | "snorkeling",
324 | "snowboarding",
325 | "snowkiting",
326 | "snowmobiling",
327 | "somersaulting",
328 | "spinning poi",
329 | "spray painting",
330 | "spraying",
331 | "springboard diving",
332 | "squat",
333 | "sticking tongue out",
334 | "stomping grapes",
335 | "stretching arm",
336 | "stretching leg",
337 | "strumming guitar",
338 | "surfing crowd",
339 | "surfing water",
340 | "sweeping floor",
341 | "swimming backstroke",
342 | "swimming breast stroke",
343 | "swimming butterfly stroke",
344 | "swing dancing",
345 | "swinging legs",
346 | "swinging on something",
347 | "sword fighting",
348 | "tai chi",
349 | "taking a shower",
350 | "tango dancing",
351 | "tap dancing",
352 | "tapping guitar",
353 | "tapping pen",
354 | "tasting beer",
355 | "tasting food",
356 | "testifying",
357 | "texting",
358 | "throwing axe",
359 | "throwing ball",
360 | "throwing discus",
361 | "tickling",
362 | "tobogganing",
363 | "tossing coin",
364 | "tossing salad",
365 | "training dog",
366 | "trapezing",
367 | "trimming or shaving beard",
368 | "trimming trees",
369 | "triple jump",
370 | "tying bow tie",
371 | "tying knot (not on a tie)",
372 | "tying tie",
373 | "unboxing",
374 | "unloading truck",
375 | "using computer",
376 | "using remote controller (not gaming)",
377 | "using segway",
378 | "vault",
379 | "waiting in line",
380 | "walking the dog",
381 | "washing dishes",
382 | "washing feet",
383 | "washing hair",
384 | "washing hands",
385 | "water skiing",
386 | "water sliding",
387 | "watering plants",
388 | "waxing back",
389 | "waxing chest",
390 | "waxing eyebrows",
391 | "waxing legs",
392 | "weaving basket",
393 | "welding",
394 | "whistling",
395 | "windsurfing",
396 | "wrapping present",
397 | "wrestling",
398 | "writing",
399 | "yawning",
400 | "yoga",
401 | "zumba"
402 | ]
403 |
--------------------------------------------------------------------------------
/figs/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/efficient-video-recognition/b282e8b49e4ecd24f7357361719cf4fa5ab9c40d/figs/arch.png
--------------------------------------------------------------------------------
/figs/k400.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/efficient-video-recognition/b282e8b49e4ecd24f7357361719cf4fa5ab9c40d/figs/k400.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import argparse
4 | from datetime import datetime
5 | import builtins
6 |
7 | import torch
8 | import torch.distributed as dist
9 |
10 | import video_dataset
11 | import checkpoint
12 | from model import EVLTransformer
13 | from video_dataset import dataloader
14 | from weight_loaders import weight_loader_fn_dict
15 | from vision_transformer import vit_presets
16 |
17 | def setup_print(is_master: bool):
18 | """
19 | This function disables printing when not in master process
20 | """
21 | builtin_print = builtins.print
22 |
23 | def print(*args, **kwargs):
24 | force = kwargs.pop('force', False)
25 | if is_master or force:
26 | now = datetime.now().time()
27 | builtin_print('[{}] '.format(now), end='') # print with time stamp
28 | builtin_print(*args, **kwargs)
29 |
30 | builtins.print = print
31 |
32 |
33 | def main():
34 | parser = argparse.ArgumentParser()
35 |
36 | video_dataset.setup_arg_parser(parser)
37 | checkpoint.setup_arg_parser(parser)
38 |
39 | parser.add_argument('--num_steps', type=int,
40 | help='number of training steps')
41 | parser.add_argument('--eval_only', action='store_true',
42 | help='run evaluation only')
43 | parser.add_argument('--save_freq', type=int, default=5000,
44 | help='save a checkpoint every N steps')
45 | parser.add_argument('--eval_freq', type=int, default=5000,
46 | help='evaluate every N steps')
47 | parser.add_argument('--print_freq', type=int, default=10,
48 | help='print log message every N steps')
49 |
50 | parser.add_argument('--backbone', type=str, choices=vit_presets.keys(), default='ViT-B/16-lnpre',
51 | help='the backbone variant used to generate image feature maps')
52 | parser.add_argument('--backbone_path', type=str,
53 | help='path to pretrained backbone weights')
54 | parser.add_argument('--backbone_type', type=str, default='clip', choices=weight_loader_fn_dict.keys(),
55 | help='type of backbone weights (used to determine how to convert state_dict from different pretraining codebase)')
56 | parser.add_argument('--finetune_backbone', action='store_true',
57 | help='finetune backbone weights')
58 | parser.add_argument('--decoder_num_layers', type=int, default=4,
59 | help='number of decoder layers')
60 | parser.add_argument('--decoder_qkv_dim', type=int, default=768,
61 | help='q (k, v) projection output dimensions in decoder attention layers')
62 | parser.add_argument('--decoder_num_heads', type=int, default=12,
63 | help='number of heads in decoder attention layers')
64 | parser.add_argument('--decoder_mlp_factor', type=float, default=4.0,
65 | help='expansion factor of feature dimension in the middle of decoder MLPs')
66 | parser.add_argument('--num_classes', type=int, default=400,
67 | help='number of classes')
68 | parser.add_argument('--cls_dropout', type=float, default=0.5,
69 | help='dropout rate applied before the final classification linear projection')
70 | parser.add_argument('--decoder_mlp_dropout', type=float, default=0.5,
71 | help='dropout rate applied in MLP layers in the decoder')
72 | parser.add_argument('--no_temporal_conv', action='store_false', dest='temporal_conv',
73 | help='disable temporal convolution on frame features')
74 | parser.add_argument('--no_temporal_pos_embed', action='store_false', dest='temporal_pos_embed',
75 | help='disable temporal position embeddings added to frame features')
76 | parser.add_argument('--no_temporal_cross_attention', action='store_false', dest='temporal_cross_attention',
77 | help='disable temporal cross attention on frame query and key features')
78 | parser.set_defaults(temporal_conv=True, temporal_pos_embed=True, temporal_cross_attention=True)
79 |
80 | parser.add_argument('--lr', type=float, default=4e-4,
81 | help='learning rate')
82 | parser.add_argument('--weight_decay', type=float, default=0.05,
83 | help='optimizer weight decay')
84 | parser.add_argument('--disable_fp16', action='store_false', dest='fp16',
85 | help='disable fp16 during training or inference')
86 | parser.set_defaults(fp16=True)
87 |
88 | parser.add_argument('--batch_split', type=int, default=1,
89 | help='optionally split the batch into smaller shards and forward/backward one shard '
90 | 'at a time to avoid out-of-memory error.')
91 |
92 | args = parser.parse_args()
93 |
94 | dist.init_process_group('nccl')
95 | setup_print(dist.get_rank() == 0)
96 | cuda_device_id = dist.get_rank() % torch.cuda.device_count()
97 | torch.cuda.set_device(cuda_device_id)
98 |
99 | model = EVLTransformer(
100 | backbone_name=args.backbone,
101 | backbone_type=args.backbone_type,
102 | backbone_path=args.backbone_path,
103 | backbone_mode='finetune' if args.finetune_backbone else ('freeze_fp16' if args.fp16 else 'freeze_fp32'),
104 | decoder_num_layers=args.decoder_num_layers,
105 | decoder_qkv_dim=args.decoder_qkv_dim,
106 | decoder_num_heads=args.decoder_num_heads,
107 | decoder_mlp_factor=args.decoder_mlp_factor,
108 | num_classes=args.num_classes,
109 | enable_temporal_conv=args.temporal_conv,
110 | enable_temporal_pos_embed=args.temporal_pos_embed,
111 | enable_temporal_cross_attention=args.temporal_cross_attention,
112 | cls_dropout=args.cls_dropout,
113 | decoder_mlp_dropout=args.decoder_mlp_dropout,
114 | num_frames=args.num_frames,
115 | )
116 | print(model)
117 | model.cuda()
118 | model = torch.nn.parallel.DistributedDataParallel(
119 | model, device_ids=[cuda_device_id], output_device=cuda_device_id,
120 | )
121 |
122 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
123 | lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_steps)
124 | loss_scaler = torch.cuda.amp.grad_scaler.GradScaler(enabled=args.fp16)
125 | criterion = torch.nn.CrossEntropyLoss()
126 |
127 | resume_step = checkpoint.resume_from_checkpoint(model, optimizer, lr_sched, loss_scaler, args)
128 |
129 | val_loader = video_dataset.create_val_loader(args)
130 | if args.eval_only:
131 | print('Running in eval_only mode.')
132 | model.eval()
133 | evaluate(model, val_loader)
134 | return
135 | else:
136 | assert args.train_list_path is not None, 'Train list path must be specified if not in eval_only mode.'
137 | train_loader = video_dataset.create_train_loader(args, resume_step=resume_step)
138 |
139 | assert len(train_loader) == args.num_steps - resume_step
140 | batch_st, train_st = datetime.now(), datetime.now()
141 | for i, (data, labels) in enumerate(train_loader, resume_step):
142 | data, labels = data.cuda(), labels.cuda()
143 | data_ed = datetime.now()
144 |
145 | optimizer.zero_grad()
146 |
147 | assert data.size(0) % args.batch_split == 0
148 | split_size = data.size(0) // args.batch_split
149 | hit1, hit5, loss_value = 0, 0, 0
150 | for j in range(args.batch_split):
151 | data_slice = data[split_size * j: split_size * (j + 1)]
152 | labels_slice = labels[split_size * j: split_size * (j + 1)]
153 |
154 | with torch.cuda.amp.autocast(args.fp16):
155 | logits = model(data_slice)
156 | loss = criterion(logits, labels_slice)
157 |
158 | if labels.dtype == torch.long: # no mixup, can calculate accuracy
159 | hit1 += (logits.topk(1, dim=1)[1] == labels_slice.view(-1, 1)).sum().item()
160 | hit5 += (logits.topk(5, dim=1)[1] == labels_slice.view(-1, 1)).sum().item()
161 | loss_value += loss.item() / args.batch_split
162 |
163 | loss_scaler.scale(loss / args.batch_split).backward()
164 |
165 | loss_scaler.step(optimizer)
166 | loss_scaler.update()
167 | lr_sched.step()
168 |
169 | batch_ed = datetime.now()
170 |
171 | if i % args.print_freq == 0:
172 | sync_tensor = torch.Tensor([loss_value, hit1 / data.size(0), hit5 / data.size(0)]).cuda()
173 | dist.all_reduce(sync_tensor)
174 | sync_tensor = sync_tensor.cpu() / dist.get_world_size()
175 | loss_value, acc1, acc5 = sync_tensor.tolist()
176 |
177 | print(
178 | f'batch_time: {(batch_ed - batch_st).total_seconds():.3f} '
179 | f'data_time: {(data_ed - batch_st).total_seconds():.3f} '
180 | f'ETA: {(batch_ed - train_st) / (i - resume_step + 1) * (args.num_steps - i - 1)} | '
181 | f'lr: {optimizer.param_groups[0]["lr"]:.6f} '
182 | f'loss: {loss_value:.6f}' + (
183 | f' acc1: {acc1 * 100:.2f}% acc5: {acc5 * 100:.2f}%' if labels.dtype == torch.long else ''
184 | )
185 | )
186 |
187 | if (i + 1) % args.eval_freq == 0:
188 | print('Start model evaluation at step', i + 1)
189 | model.eval()
190 | evaluate(model, val_loader)
191 | model.train()
192 |
193 | if (i + 1) % args.save_freq == 0:
194 | checkpoint.save_checkpoint(model, optimizer, lr_sched, loss_scaler, i + 1, args)
195 |
196 | batch_st = datetime.now()
197 |
198 |
199 | def evaluate(model: torch.nn.Module, loader: torch.utils.data.DataLoader):
200 | tot, hit1, hit5 = 0, 0, 0
201 | eval_st = datetime.now()
202 | for data, labels in loader:
203 | data, labels = data.cuda(), labels.cuda()
204 | assert data.size(0) == 1
205 | if data.ndim == 6:
206 | data = data[0] # now the first dimension is number of views
207 |
208 | with torch.no_grad():
209 | logits = model(data)
210 | scores = logits.softmax(dim=-1).mean(dim=0)
211 |
212 | tot += 1
213 | hit1 += (scores.topk(1)[1] == labels).sum().item()
214 | hit5 += (scores.topk(5)[1] == labels).sum().item()
215 |
216 | if tot % 20 == 0:
217 | print(f'[Evaluation] num_samples: {tot} '
218 | f'ETA: {(datetime.now() - eval_st) / tot * (len(loader) - tot)} '
219 | f'cumulative_acc1: {hit1 / tot * 100.:.2f}% '
220 | f'cumulative_acc5: {hit5 / tot * 100.:.2f}%')
221 |
222 | sync_tensor = torch.LongTensor([tot, hit1, hit5]).cuda()
223 | dist.all_reduce(sync_tensor)
224 | tot, hit1, hit5 = sync_tensor.cpu().tolist()
225 |
226 | print(f'Accuracy on validation set: top1={hit1 / tot * 100:.2f}%, top5={hit5 / tot * 100:.2f}%')
227 |
228 |
229 | if __name__ == '__main__': main()
230 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from typing import Dict, Iterable, List, Tuple
4 | import numpy as np
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | from vision_transformer import QuickGELU, Attention
11 | from weight_loaders import weight_loader_fn_dict
12 | from vision_transformer import (
13 | VisionTransformer2D, TransformerDecoderLayer,
14 | model_to_fp16, vit_presets,
15 | )
16 |
17 |
18 | class TemporalCrossAttention(nn.Module):
19 |
20 | def __init__(
21 | self,
22 | spatial_size: Tuple[int, int] = (14, 14),
23 | feature_dim: int = 768,
24 | ):
25 | super().__init__()
26 |
27 | self.spatial_size = spatial_size
28 |
29 | w_size = np.prod([x * 2 - 1 for x in spatial_size])
30 | self.w1 = nn.Parameter(torch.zeros([w_size, feature_dim]))
31 | self.w2 = nn.Parameter(torch.zeros([w_size, feature_dim]))
32 |
33 | idx_tensor = torch.zeros([np.prod(spatial_size) for _ in (0, 1)], dtype=torch.long)
34 | for q in range(np.prod(spatial_size)):
35 | qi, qj = q // spatial_size[1], q % spatial_size[1]
36 | for k in range(np.prod(spatial_size)):
37 | ki, kj = k // spatial_size[1], k % spatial_size[1]
38 | i_offs = qi - ki + spatial_size[0] - 1
39 | j_offs = qj - kj + spatial_size[1] - 1
40 | idx_tensor[q, k] = i_offs * (spatial_size[1] * 2 - 1) + j_offs
41 | self.idx_tensor = idx_tensor
42 |
43 |
44 | def forward_half(self, q: torch.Tensor, k: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
45 | q, k = q[:, :, 1:], k[:, :, 1:] # remove cls token
46 |
47 | assert q.size() == k.size()
48 | assert q.size(2) == np.prod(self.spatial_size)
49 |
50 | attn = torch.einsum('ntqhd,ntkhd->ntqkh', q / (q.size(-1) ** 0.5), k)
51 |         attn = attn.softmax(dim=-2).mean(dim=-1)  # N, T, L, L
52 |
53 | self.idx_tensor = self.idx_tensor.to(w.device)
54 | w_unroll = w[self.idx_tensor] # L, L, C
55 | ret = torch.einsum('ntqk,qkc->ntqc', attn, w_unroll)
56 |
57 | return ret
58 |
59 |
60 | def forward(self, q: torch.Tensor, k: torch.Tensor):
61 | N, T, L, H, D = q.size()
62 | assert L == np.prod(self.spatial_size) + 1
63 |
64 | ret = torch.zeros([N, T, L, self.w1.size(-1)], device='cuda')
65 | ret[:, 1:, 1:, :] += self.forward_half(q[:, 1:, :, :, :], k[:, :-1, :, :, :], self.w1)
66 | ret[:, :-1, 1:, :] += self.forward_half(q[:, :-1, :, :, :], k[:, 1:, :, :, :], self.w2)
67 |
68 | return ret
69 |
70 |
71 | class EVLDecoder(nn.Module):
72 |
73 | def __init__(
74 | self,
75 | num_frames: int = 8,
76 | spatial_size: Tuple[int, int] = (14, 14),
77 | num_layers: int = 4,
78 | in_feature_dim: int = 768,
79 | qkv_dim: int = 768,
80 | num_heads: int = 12,
81 | mlp_factor: float = 4.0,
82 | enable_temporal_conv: bool = True,
83 | enable_temporal_pos_embed: bool = True,
84 | enable_temporal_cross_attention: bool = True,
85 | mlp_dropout: float = 0.5,
86 | ):
87 | super().__init__()
88 |
89 | self.enable_temporal_conv = enable_temporal_conv
90 | self.enable_temporal_pos_embed = enable_temporal_pos_embed
91 | self.enable_temporal_cross_attention = enable_temporal_cross_attention
92 | self.num_layers = num_layers
93 |
94 | self.decoder_layers = nn.ModuleList(
95 | [TransformerDecoderLayer(in_feature_dim, qkv_dim, num_heads, mlp_factor, mlp_dropout) for _ in range(num_layers)]
96 | )
97 |
98 | if enable_temporal_conv:
99 | self.temporal_conv = nn.ModuleList(
100 | [nn.Conv1d(in_feature_dim, in_feature_dim, kernel_size=3, stride=1, padding=1, groups=in_feature_dim) for _ in range(num_layers)]
101 | )
102 | if enable_temporal_pos_embed:
103 | self.temporal_pos_embed = nn.ParameterList(
104 | [nn.Parameter(torch.zeros([num_frames, in_feature_dim])) for _ in range(num_layers)]
105 | )
106 | if enable_temporal_cross_attention:
107 | self.cross_attention = nn.ModuleList(
108 | [TemporalCrossAttention(spatial_size, in_feature_dim) for _ in range(num_layers)]
109 | )
110 |
111 | self.cls_token = nn.Parameter(torch.zeros([in_feature_dim]))
112 |
113 |
114 | def _initialize_weights(self):
115 | nn.init.normal_(self.cls_token, std=0.02)
116 |
117 |
118 | def forward(self, in_features: List[Dict[str, torch.Tensor]]):
119 | N, T, L, C = in_features[0]['out'].size()
120 | assert len(in_features) == self.num_layers
121 | x = self.cls_token.view(1, 1, -1).repeat(N, 1, 1)
122 |
123 | for i in range(self.num_layers):
124 | frame_features = in_features[i]['out']
125 |
126 | if self.enable_temporal_conv:
127 | feat = in_features[i]['out']
128 | feat = feat.permute(0, 2, 3, 1).contiguous().flatten(0, 1) # N * L, C, T
129 | feat = self.temporal_conv[i](feat)
130 | feat = feat.view(N, L, C, T).permute(0, 3, 1, 2).contiguous() # N, T, L, C
131 | frame_features += feat
132 |
133 | if self.enable_temporal_pos_embed:
134 | frame_features += self.temporal_pos_embed[i].view(1, T, 1, C)
135 |
136 | if self.enable_temporal_cross_attention:
137 | frame_features += self.cross_attention[i](in_features[i]['q'], in_features[i]['k'])
138 |
139 | frame_features = frame_features.flatten(1, 2) # N, T * L, C
140 |
141 | x = self.decoder_layers[i](x, frame_features)
142 |
143 | return x
144 |
145 |
146 | class EVLTransformer(nn.Module):
147 |
148 | def __init__(
149 | self,
150 | num_frames: int = 8,
151 | backbone_name: str = 'ViT-B/16',
152 | backbone_type: str = 'clip',
153 | backbone_path: str = '',
154 |         backbone_mode: str = 'freeze_fp16',
155 | decoder_num_layers: int = 4,
156 | decoder_qkv_dim: int = 768,
157 | decoder_num_heads: int = 12,
158 | decoder_mlp_factor: float = 4.0,
159 | num_classes: int = 400,
160 | enable_temporal_conv: bool = True,
161 | enable_temporal_pos_embed: bool = True,
162 | enable_temporal_cross_attention: bool = True,
163 | cls_dropout: float = 0.5,
164 | decoder_mlp_dropout: float = 0.5,
165 | ):
166 | super().__init__()
167 |
168 | self.decoder_num_layers = decoder_num_layers
169 |
170 | backbone_config = self._create_backbone(backbone_name, backbone_type, backbone_path, backbone_mode)
171 | backbone_feature_dim = backbone_config['feature_dim']
172 | backbone_spatial_size = tuple(x // y for x, y in zip(backbone_config['input_size'], backbone_config['patch_size']))
173 |
174 | self.decoder = EVLDecoder(
175 | num_frames=num_frames,
176 | spatial_size=backbone_spatial_size,
177 | num_layers=decoder_num_layers,
178 | in_feature_dim=backbone_feature_dim,
179 | qkv_dim=decoder_qkv_dim,
180 | num_heads=decoder_num_heads,
181 | mlp_factor=decoder_mlp_factor,
182 | enable_temporal_conv=enable_temporal_conv,
183 | enable_temporal_pos_embed=enable_temporal_pos_embed,
184 | enable_temporal_cross_attention=enable_temporal_cross_attention,
185 | mlp_dropout=decoder_mlp_dropout,
186 | )
187 | self.proj = nn.Sequential(
188 | nn.LayerNorm(backbone_feature_dim),
189 | nn.Dropout(cls_dropout),
190 | nn.Linear(backbone_feature_dim, num_classes),
191 | )
192 |
193 |
194 | def _create_backbone(
195 | self,
196 | backbone_name: str,
197 | backbone_type: str,
198 | backbone_path: str,
199 | backbone_mode: str,
200 | ) -> dict:
201 | weight_loader_fn = weight_loader_fn_dict[backbone_type]
202 | state_dict = weight_loader_fn(backbone_path)
203 |
204 | backbone = VisionTransformer2D(return_all_features=True, **vit_presets[backbone_name])
205 | backbone.load_state_dict(state_dict, strict=True) # weight_loader_fn is expected to strip unused parameters
206 |
207 | assert backbone_mode in ['finetune', 'freeze_fp16', 'freeze_fp32']
208 |
209 | if backbone_mode == 'finetune':
210 | self.backbone = backbone
211 | else:
212 | backbone.eval().requires_grad_(False)
213 | if backbone_mode == 'freeze_fp16':
214 | model_to_fp16(backbone)
215 | self.backbone = [backbone] # avoid backbone parameter registration
216 |
217 | return vit_presets[backbone_name]
218 |
219 |
220 | def _get_backbone(self, x):
221 | if isinstance(self.backbone, list):
222 | # freeze backbone
223 | self.backbone[0] = self.backbone[0].to(x.device)
224 | return self.backbone[0]
225 | else:
226 |             # finetune backbone
227 | return self.backbone
228 |
229 |
230 | def forward(self, x: torch.Tensor):
231 | backbone = self._get_backbone(x)
232 |
233 | B, C, T, H, W = x.size()
234 | x = x.permute(0, 2, 1, 3, 4).flatten(0, 1)
235 | features = backbone(x)[-self.decoder_num_layers:]
236 | features = [
237 | dict((k, v.float().view(B, T, *v.size()[1:])) for k, v in x.items())
238 | for x in features
239 | ]
240 |
241 | x = self.decoder(features)
242 | x = self.proj(x[:, 0, :])
243 |
244 | return x
--------------------------------------------------------------------------------
/scripts/eval_k400_vitb16_16f_dec4x768.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | python -u -m torch.distributed.run --nproc_per_node 4 \
4 | main.py \
5 | --num_steps 50000 \
6 | --backbone "ViT-B/16-lnpre" \
7 | --backbone_type clip \
8 | --backbone_path /path/to/clip_models/ViT-B-16.pt \
9 | --decoder_num_layers 4 \
10 | --decoder_qkv_dim 768 \
11 | --decoder_num_heads 12 \
12 | --num_classes 400 \
13 | --val_list_path /path/to/k400/val.txt \
14 | --batch_size 256 \
15 | --batch_split 1 \
16 | --auto_augment rand-m7-n4-mstd0.5-inc1 \
17 | --mean 0.48145466 0.4578275 0.40821073 \
18 | --std 0.26862954 0.26130258 0.27577711 \
19 | --num_workers 12 \
20 | --num_frames 16 \
21 | --sampling_rate 16 \
22 | --num_spatial_views 3 \
23 | --num_temporal_views 1 \
24 | --resume_path /path/to/checkpoint_release/k400_vitb16_16f_dec4x768.pth \
25 | --eval_only
26 |
--------------------------------------------------------------------------------
/scripts/eval_k400_vitb16_32f_dec4x768.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | python -u -m torch.distributed.run --nproc_per_node 4 \
4 | main.py \
5 | --num_steps 50000 \
6 | --backbone "ViT-B/16-lnpre" \
7 | --backbone_type clip \
8 | --backbone_path /path/to/clip_models/ViT-B-16.pt \
9 | --decoder_num_layers 4 \
10 | --decoder_qkv_dim 768 \
11 | --decoder_num_heads 12 \
12 | --num_classes 400 \
13 | --val_list_path /path/to/k400/val.txt \
14 | --batch_size 256 \
15 | --batch_split 1 \
16 | --auto_augment rand-m7-n4-mstd0.5-inc1 \
17 | --mean 0.48145466 0.4578275 0.40821073 \
18 | --std 0.26862954 0.26130258 0.27577711 \
19 | --num_workers 12 \
20 | --num_frames 32 \
21 | --sampling_rate 8 \
22 | --num_spatial_views 3 \
23 | --num_temporal_views 1 \
24 | --resume_path /path/to/checkpoint_release/k400_vitb16_32f_dec4x768.pth \
25 | --eval_only
26 |
--------------------------------------------------------------------------------
/scripts/eval_k400_vitb16_8f_dec4x768.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | python -u -m torch.distributed.run --nproc_per_node 4 \
4 | main.py \
5 | --num_steps 50000 \
6 | --backbone "ViT-B/16-lnpre" \
7 | --backbone_type clip \
8 | --backbone_path /path/to/clip_models/ViT-B-16.pt \
9 | --decoder_num_layers 4 \
10 | --decoder_qkv_dim 768 \
11 | --decoder_num_heads 12 \
12 | --num_classes 400 \
13 | --val_list_path /path/to/k400/val.txt \
14 | --batch_size 256 \
15 | --batch_split 1 \
16 | --auto_augment rand-m7-n4-mstd0.5-inc1 \
17 | --mean 0.48145466 0.4578275 0.40821073 \
18 | --std 0.26862954 0.26130258 0.27577711 \
19 | --num_workers 12 \
20 | --num_frames 8 \
21 | --sampling_rate 16 \
22 | --num_spatial_views 1 \
23 | --num_temporal_views 3 \
24 | --resume_path /path/to/checkpoint_release/k400_vitb16_8f_dec4x768.pth \
25 | --eval_only
26 |
--------------------------------------------------------------------------------
/scripts/eval_k400_vitl14_16f_dec4x1024.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | python -u -m torch.distributed.run --nproc_per_node 4 \
4 | main.py \
5 | --num_steps 50000 \
6 | --backbone "ViT-L/14-lnpre" \
7 | --backbone_type clip \
8 | --backbone_path /path/to/clip_models/ViT-L-14.pt \
9 | --decoder_num_layers 4 \
10 | --decoder_qkv_dim 1024 \
11 | --decoder_num_heads 16 \
12 | --num_classes 400 \
13 | --val_list_path /path/to/k400/val.txt \
14 | --batch_size 256 \
15 | --batch_split 1 \
16 | --auto_augment rand-m7-n4-mstd0.5-inc1 \
17 | --mean 0.48145466 0.4578275 0.40821073 \
18 | --std 0.26862954 0.26130258 0.27577711 \
19 | --num_workers 12 \
20 | --num_frames 16 \
21 | --sampling_rate 16 \
22 | --num_spatial_views 3 \
23 | --num_temporal_views 1 \
24 | --resume_path /path/to/checkpoint_release/k400_vitl14_16f_dec4x1024.pth \
25 | --eval_only
26 |
--------------------------------------------------------------------------------
/scripts/eval_k400_vitl14_32f_dec4x1024.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | python -u -m torch.distributed.run --nproc_per_node 4 \
4 | main.py \
5 | --num_steps 50000 \
6 | --backbone "ViT-L/14-lnpre" \
7 | --backbone_type clip \
8 | --backbone_path /path/to/clip_models/ViT-L-14.pt \
9 | --decoder_num_layers 4 \
10 | --decoder_qkv_dim 1024 \
11 | --decoder_num_heads 16 \
12 | --num_classes 400 \
13 | --val_list_path /path/to/k400/val.txt \
14 | --batch_size 256 \
15 | --batch_split 1 \
16 | --auto_augment rand-m7-n4-mstd0.5-inc1 \
17 | --mean 0.48145466 0.4578275 0.40821073 \
18 | --std 0.26862954 0.26130258 0.27577711 \
19 | --num_workers 12 \
20 | --num_frames 32 \
21 | --sampling_rate 8 \
22 | --num_spatial_views 3 \
23 | --num_temporal_views 1 \
24 | --resume_path /path/to/checkpoint_release/k400_vitl14_32f_dec4x1024.pth \
25 | --eval_only
26 |
--------------------------------------------------------------------------------
/scripts/eval_k400_vitl14_8f_dec4x1024.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | python -u -m torch.distributed.run --nproc_per_node 4 \
4 | main.py \
5 | --num_steps 50000 \
6 | --backbone "ViT-L/14-lnpre" \
7 | --backbone_type clip \
8 | --backbone_path /path/to/clip_models/ViT-L-14.pt \
9 | --decoder_num_layers 4 \
10 | --decoder_qkv_dim 1024 \
11 | --decoder_num_heads 16 \
12 | --num_classes 400 \
13 | --val_list_path /path/to/k400/val.txt \
14 | --batch_size 256 \
15 | --batch_split 1 \
16 | --auto_augment rand-m7-n4-mstd0.5-inc1 \
17 | --mean 0.48145466 0.4578275 0.40821073 \
18 | --std 0.26862954 0.26130258 0.27577711 \
19 | --num_workers 12 \
20 | --num_frames 8 \
21 | --sampling_rate 16 \
22 | --num_spatial_views 1 \
23 | --num_temporal_views 3 \
24 | --resume_path /path/to/checkpoint_release/k400_vitl14_8f_dec4x1024.pth \
25 | --eval_only
26 |
--------------------------------------------------------------------------------
/scripts/train_k400_vitb16_16f_dec4x768.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | exp_dir=runs/k400_vitb16_16f_dec4x768
4 |
5 | mkdir -p "${exp_dir}"
6 | python -u -m torch.distributed.run --nproc_per_node 8 \
7 | main.py \
8 | --num_steps 50000 \
9 | --backbone "ViT-B/16-lnpre" \
10 | --backbone_type clip \
11 | --backbone_path /path/to/clip_models/ViT-B-16.pt \
12 | --decoder_num_layers 4 \
13 | --decoder_qkv_dim 768 \
14 | --decoder_num_heads 12 \
15 | --num_classes 400 \
16 | --checkpoint_dir "${exp_dir}" \
17 | --auto_resume \
18 | --train_list_path /path/to/k400/train.txt \
19 | --val_list_path /path/to/k400/val.txt \
20 | --batch_size 256 \
21 | --batch_split 1 \
22 | --auto_augment rand-m7-n4-mstd0.5-inc1 \
23 | --mean 0.48145466 0.4578275 0.40821073 \
24 | --std 0.26862954 0.26130258 0.27577711 \
25 | --num_workers 12 \
26 | --num_frames 16 \
27 | --sampling_rate 16 \
28 | --num_spatial_views 3 \
29 | --num_temporal_views 1 \
30 | 2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log"
31 |
--------------------------------------------------------------------------------
/scripts/train_k400_vitb16_32f_dec4x768.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | exp_dir=runs/k400_vitb16_32f_dec4x768
4 |
5 | mkdir -p "${exp_dir}"
6 | python -u -m torch.distributed.run --nproc_per_node 8 \
7 | main.py \
8 | --num_steps 50000 \
9 | --backbone "ViT-B/16-lnpre" \
10 | --backbone_type clip \
11 | --backbone_path /path/to/clip_models/ViT-B-16.pt \
12 | --decoder_num_layers 4 \
13 | --decoder_qkv_dim 768 \
14 | --decoder_num_heads 12 \
15 | --num_classes 400 \
16 | --checkpoint_dir "${exp_dir}" \
17 | --auto_resume \
18 | --train_list_path /path/to/k400/train.txt \
19 | --val_list_path /path/to/k400/val.txt \
20 | --batch_size 256 \
21 | --batch_split 1 \
22 | --auto_augment rand-m7-n4-mstd0.5-inc1 \
23 | --mean 0.48145466 0.4578275 0.40821073 \
24 | --std 0.26862954 0.26130258 0.27577711 \
25 | --num_workers 12 \
26 | --num_frames 32 \
27 | --sampling_rate 8 \
28 | --num_spatial_views 3 \
29 | --num_temporal_views 1 \
30 | 2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log"
31 |
--------------------------------------------------------------------------------
/scripts/train_k400_vitb16_8f_dec4x768.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | exp_dir=runs/k400_vitb16_8f_dec4x768
4 |
5 | mkdir -p "${exp_dir}"
6 | python -u -m torch.distributed.run --nproc_per_node 8 \
7 | main.py \
8 | --num_steps 50000 \
9 | --backbone "ViT-B/16-lnpre" \
10 | --backbone_type clip \
11 | --backbone_path /path/to/clip_models/ViT-B-16.pt \
12 | --decoder_num_layers 4 \
13 | --decoder_qkv_dim 768 \
14 | --decoder_num_heads 12 \
15 | --num_classes 400 \
16 | --checkpoint_dir "${exp_dir}" \
17 | --auto_resume \
18 | --train_list_path /path/to/k400/train.txt \
19 | --val_list_path /path/to/k400/val.txt \
20 | --batch_size 256 \
21 | --batch_split 1 \
22 | --auto_augment rand-m7-n4-mstd0.5-inc1 \
23 | --mean 0.48145466 0.4578275 0.40821073 \
24 | --std 0.26862954 0.26130258 0.27577711 \
25 | --num_workers 12 \
26 | --num_frames 8 \
27 | --sampling_rate 16 \
28 | --num_spatial_views 1 \
29 | --num_temporal_views 3 \
30 | 2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log"
31 |
--------------------------------------------------------------------------------
/scripts/train_k400_vitl14_16f_dec4x1024.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | exp_dir=runs/k400_vitl14_16f_dec4x1024
4 |
5 | mkdir -p "${exp_dir}"
6 | python -u -m torch.distributed.run --nproc_per_node 8 \
7 | main.py \
8 | --num_steps 50000 \
9 | --backbone "ViT-L/14-lnpre" \
10 | --backbone_type clip \
11 | --backbone_path /path/to/clip_models/ViT-L-14.pt \
12 | --decoder_num_layers 4 \
13 | --decoder_qkv_dim 1024 \
14 | --decoder_num_heads 16 \
15 | --num_classes 400 \
16 | --checkpoint_dir "${exp_dir}" \
17 | --auto_resume \
18 | --train_list_path /path/to/k400/train.txt \
19 | --val_list_path /path/to/k400/val.txt \
20 | --batch_size 256 \
21 | --batch_split 2 \
22 | --auto_augment rand-m7-n4-mstd0.5-inc1 \
23 | --mean 0.48145466 0.4578275 0.40821073 \
24 | --std 0.26862954 0.26130258 0.27577711 \
25 | --num_workers 12 \
26 | --num_frames 16 \
27 | --sampling_rate 16 \
28 | --num_spatial_views 3 \
29 | --num_temporal_views 1 \
30 | 2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log"
31 |
--------------------------------------------------------------------------------
/scripts/train_k400_vitl14_32f_dec4x1024.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | exp_dir=runs/k400_vitl14_32f_dec4x1024
4 |
5 | mkdir -p "${exp_dir}"
6 | python -u -m torch.distributed.run --nproc_per_node 8 \
7 | main.py \
8 | --num_steps 50000 \
9 | --backbone "ViT-L/14-lnpre" \
10 | --backbone_type clip \
11 | --backbone_path /path/to/clip_models/ViT-L-14.pt \
12 | --decoder_num_layers 4 \
13 | --decoder_qkv_dim 1024 \
14 | --decoder_num_heads 16 \
15 | --num_classes 400 \
16 | --checkpoint_dir "${exp_dir}" \
17 | --auto_resume \
18 | --train_list_path /path/to/k400/train.txt \
19 | --val_list_path /path/to/k400/val.txt \
20 | --batch_size 256 \
21 | --batch_split 4 \
22 | --auto_augment rand-m7-n4-mstd0.5-inc1 \
23 | --mean 0.48145466 0.4578275 0.40821073 \
24 | --std 0.26862954 0.26130258 0.27577711 \
25 | --num_workers 12 \
26 | --num_frames 32 \
27 | --sampling_rate 8 \
28 | --num_spatial_views 3 \
29 | --num_temporal_views 1 \
30 | 2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log"
31 |
--------------------------------------------------------------------------------
/scripts/train_k400_vitl14_8f_dec4x1024.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | exp_dir=runs/k400_vitl14_8f_dec4x1024
4 |
5 | mkdir -p "${exp_dir}"
6 | python -u -m torch.distributed.run --nproc_per_node 8 \
7 | main.py \
8 | --num_steps 50000 \
9 | --backbone "ViT-L/14-lnpre" \
10 | --backbone_type clip \
11 | --backbone_path /path/to/clip_models/ViT-L-14.pt \
12 | --decoder_num_layers 4 \
13 | --decoder_qkv_dim 1024 \
14 | --decoder_num_heads 16 \
15 | --num_classes 400 \
16 | --checkpoint_dir "${exp_dir}" \
17 | --auto_resume \
18 | --train_list_path /path/to/k400/train.txt \
19 | --val_list_path /path/to/k400/val.txt \
20 | --batch_size 256 \
21 | --batch_split 1 \
22 | --auto_augment rand-m7-n4-mstd0.5-inc1 \
23 | --mean 0.48145466 0.4578275 0.40821073 \
24 | --std 0.26862954 0.26130258 0.27577711 \
25 | --num_workers 12 \
26 | --num_frames 8 \
27 | --sampling_rate 16 \
28 | --num_spatial_views 1 \
29 | --num_temporal_views 3 \
30 | 2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log"
31 |
--------------------------------------------------------------------------------
/video_dataset/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from .dataloader import setup_arg_parser, create_train_loader, create_val_loader
--------------------------------------------------------------------------------
/video_dataset/dataloader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import argparse
4 | from typing import Dict
5 |
6 | import torch
7 | import torch.distributed as dist
8 |
9 | from .dataset import VideoDataset, DummyDataset
10 |
11 | def setup_arg_parser(parser: argparse.ArgumentParser):
12 | parser.add_argument('--train_list_path', type=str,
13 | help='path to training data list')
14 | parser.add_argument('--val_list_path', type=str,
15 | help='path to validation data list')
16 | parser.add_argument('--train_data_root', type=str,
17 | help='training samples root directory')
18 | parser.add_argument('--val_data_root', type=str,
19 | help='validation samples root directory')
20 | parser.add_argument('--data_root', type=str, default='',
21 |                         help='training and validation samples root directory; may be overridden by --train_data_root or --val_data_root')
22 |
23 | parser.add_argument('--batch_size', type=int,
24 |                         help='total training batch size across all GPUs')
25 |
26 | parser.add_argument('--num_spatial_views', type=int, default=1,
27 | help='number of spatial crops used for testing (total views = num_spatial_views * num_temporal_views)')
28 | parser.add_argument('--num_temporal_views', type=int, default=3,
29 | help='number of temporal crops used for testing (total views = num_spatial_views * num_temporal_views)')
30 | parser.add_argument('--num_frames', type=int, default=8,
31 | help='number of frames used for each view')
32 | parser.add_argument('--sampling_rate', type=int, default=16,
33 | help='temporal stride for frame sampling, only valid when tsn_sampling is not enabled')
34 | parser.add_argument('--tsn_sampling', action='store_true',
35 | help='enable TSN-style sampling (i.e. sample frames with dynamic stride to cover the whole video)')
36 | parser.add_argument('--spatial_size', type=int, default=224,
37 | help='frame height and width in pixels')
38 |
39 | parser.add_argument('--mean', type=float, nargs='+',
40 | help='pixel mean used to normalize the image.')
41 | parser.add_argument('--std', type=float, nargs='+',
42 | help='pixel std used to normalize the image')
43 |
44 | parser.add_argument('--num_workers', type=int, default=10,
45 | help='number of DataLoader worker threads')
46 |
47 | parser.add_argument('--dummy_dataset', action='store_true',
48 |                         help='use a fake dataset that generates all-zero data (for speed testing only)')
49 |
50 | parser.add_argument('--auto_augment', type=str,
51 | help='auto augment configuration')
52 | parser.add_argument('--interpolation', type=str, default='bicubic',
53 | help='interpolation mode')
54 | parser.add_argument('--no_mirror', action='store_false', dest='mirror',
55 | help='disable mirror for training (frequently used for the something-something dataset)')
56 | parser.set_defaults(mirror=True)
57 |
58 |
59 | def _parse_mean_and_std(args: argparse.Namespace) -> Dict[str, torch.Tensor]:
60 | def parse_mean_or_std(arg, default_value):
61 | if arg is None:
62 | return torch.Tensor([default_value] * 3)
63 | elif len(arg) == 1:
64 | return torch.Tensor(arg * 3)
65 | elif len(arg) == 3:
66 | return torch.Tensor(arg)
67 | else:
68 | raise NotImplementedError()
69 | return {
70 | 'mean': parse_mean_or_std(args.mean, 0.45),
71 | 'std': parse_mean_or_std(args.std, 0.225),
72 | }
73 |
74 |
75 | def create_train_dataset(args: argparse.Namespace) -> torch.utils.data.Dataset:
76 | if args.dummy_dataset:
77 | return DummyDataset(
78 | list_path=args.train_list_path,
79 | num_frames=args.num_frames,
80 | num_views=1,
81 | spatial_size=args.spatial_size,
82 | )
83 |
84 | return VideoDataset(
85 | list_path=args.train_list_path,
86 | data_root=args.train_data_root or args.data_root,
87 | num_spatial_views=1, num_temporal_views=1, random_sample=True,
88 | auto_augment=args.auto_augment,
89 | interpolation=args.interpolation,
90 | mirror=args.mirror,
91 | num_frames=args.num_frames,
92 | sampling_rate=-1 if args.tsn_sampling else args.sampling_rate,
93 | spatial_size=args.spatial_size,
94 | **_parse_mean_and_std(args),
95 | )
96 |
97 |
98 | def create_train_loader(args: argparse.Namespace, resume_step: int = 0) -> torch.utils.data.DataLoader:
99 | dataset = create_train_dataset(args)
100 | rank, world_size = (0, 1) if not dist.is_initialized() else (dist.get_rank(), dist.get_world_size())
101 |
102 | assert args.batch_size % world_size == 0
103 | batch_size_per_gpu = args.batch_size // world_size
104 |
105 | # manually create a step-based sampler
106 | sampler = []
107 | while len(sampler) * len(dataset) < args.num_steps * args.batch_size:
108 | g = torch.Generator()
109 | g.manual_seed(len(sampler))
110 | indices = torch.randperm(len(dataset), generator=g)
111 | sampler.append(indices)
112 | sampler = torch.cat(sampler, dim=0)[:args.num_steps * args.batch_size].view(args.num_steps, args.batch_size)
113 | sampler = sampler[resume_step:, batch_size_per_gpu * rank: batch_size_per_gpu * (rank + 1)].flatten().tolist()
114 |
115 | loader = torch.utils.data.DataLoader(
116 | dataset, sampler=sampler, batch_size=batch_size_per_gpu,
117 | num_workers=args.num_workers, pin_memory=False, drop_last=True,
118 | )
119 |
120 | return loader
121 |
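# Sketch of what the step-based sampler above produces (added comment, with
# illustrative numbers): suppose len(dataset) = 1000, num_steps = 300,
# batch_size = 8 and world_size = 2. The loop concatenates epoch-wise
# permutations (each seeded with its epoch index, so every rank builds the
# identical sequence) until at least 300 * 8 = 2400 indices exist, reshapes
# them to (300, 8), and then rank 0 keeps columns 0-3 of every remaining row
# while rank 1 keeps columns 4-7; resume_step simply drops the rows that were
# already consumed before a restart.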
122 |
123 | def create_val_dataset(args: argparse.Namespace) -> torch.utils.data.Dataset:
124 | if args.dummy_dataset:
125 | return DummyDataset(
126 | list_path=args.val_list_path,
127 | num_frames=args.num_frames,
128 | num_views=args.num_spatial_views * args.num_temporal_views,
129 | spatial_size=args.spatial_size,
130 | )
131 |
132 | return VideoDataset(
133 | list_path=args.val_list_path,
134 | data_root=args.val_data_root or args.data_root,
135 | num_spatial_views=args.num_spatial_views,
136 | num_temporal_views=args.num_temporal_views,
137 | random_sample=False,
138 | num_frames=args.num_frames,
139 | sampling_rate=-1 if args.tsn_sampling else args.sampling_rate,
140 | spatial_size=args.spatial_size,
141 | **_parse_mean_and_std(args),
142 | )
143 |
144 |
145 | def create_val_loader(args: argparse.Namespace) -> torch.utils.data.DataLoader:
146 | dataset = create_val_dataset(args)
147 | rank, world_size = (0, 1) if not dist.is_initialized() else (dist.get_rank(), dist.get_world_size())
148 |
149 | # sampler for distributed eval
150 | sampler = list(range(rank, len(dataset), world_size))
151 |
152 | loader = torch.utils.data.DataLoader(
153 | dataset, sampler=sampler, batch_size=1,
154 | num_workers=args.num_workers, pin_memory=False,
155 | )
156 |
157 | return loader
158 |
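# Evaluation sampling note (added comment): each rank processes samples
# rank, rank + world_size, rank + 2 * world_size, ...; batch_size is fixed to
# 1 because every dataset item already stacks all of its
# num_spatial_views * num_temporal_views clips along a leading dimension, and
# the per-rank predictions are presumably aggregated by the caller.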
--------------------------------------------------------------------------------
/video_dataset/dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os, sys
4 | from typing import Optional
5 | import av
6 | import io
7 | import numpy as np
8 |
9 | import torch
10 | from torchvision import transforms
11 |
12 | from .transform import create_random_augment, random_resized_crop
13 |
14 | class VideoDataset(torch.utils.data.Dataset):
15 |
16 | def __init__(
17 | self, list_path: str, data_root: str,
18 | num_spatial_views: int, num_temporal_views: int, random_sample: bool,
19 | num_frames: int, sampling_rate: int, spatial_size: int,
20 | mean: torch.Tensor, std: torch.Tensor,
21 | auto_augment: Optional[str] = None, interpolation: str = 'bicubic',
22 | mirror: bool = False,
23 | ):
24 | self.data_root = data_root
25 | self.interpolation = interpolation
26 | self.spatial_size = spatial_size
27 |
28 | self.mean, self.std = mean, std
29 | self.num_frames, self.sampling_rate = num_frames, sampling_rate
30 |
31 | if random_sample:
32 | assert num_spatial_views == 1 and num_temporal_views == 1
33 | self.random_sample = True
34 | self.mirror = mirror
35 | self.auto_augment = auto_augment
36 | else:
37 | assert auto_augment is None and not mirror
38 | self.random_sample = False
39 | self.num_temporal_views = num_temporal_views
40 | self.num_spatial_views = num_spatial_views
41 |
42 | with open(list_path) as f:
43 | self.data_list = f.read().splitlines()
44 |
45 |
46 | def __len__(self):
47 | return len(self.data_list)
48 |
49 |
50 | def __getitem__(self, idx):
51 | line = self.data_list[idx]
52 | path, label = line.split(' ')
53 | path = os.path.join(self.data_root, path)
54 | label = int(label)
55 |
56 | container = av.open(path)
57 | frames = {}
58 | for frame in container.decode(video=0):
59 | frames[frame.pts] = frame
60 | container.close()
61 | frames = [frames[k] for k in sorted(frames.keys())]
62 |
63 | if self.random_sample:
64 | frame_idx = self._random_sample_frame_idx(len(frames))
65 | frames = [frames[x].to_rgb().to_ndarray() for x in frame_idx]
66 | frames = torch.as_tensor(np.stack(frames)).float() / 255.
67 |
68 | if self.auto_augment is not None:
69 | aug_transform = create_random_augment(
70 | input_size=(frames.size(1), frames.size(2)),
71 | auto_augment=self.auto_augment,
72 | interpolation=self.interpolation,
73 | )
74 | frames = frames.permute(0, 3, 1, 2) # T, C, H, W
75 | frames = [transforms.ToPILImage()(frames[i]) for i in range(frames.size(0))]
76 | frames = aug_transform(frames)
77 | frames = torch.stack([transforms.ToTensor()(img) for img in frames])
78 | frames = frames.permute(0, 2, 3, 1)
79 |
80 | frames = (frames - self.mean) / self.std
81 | frames = frames.permute(3, 0, 1, 2) # C, T, H, W
82 | frames = random_resized_crop(
83 | frames, self.spatial_size, self.spatial_size,
84 | )
85 |
86 | else:
87 | frames = [x.to_rgb().to_ndarray() for x in frames]
88 | frames = torch.as_tensor(np.stack(frames))
89 | frames = frames.float() / 255.
90 |
91 | frames = (frames - self.mean) / self.std
92 | frames = frames.permute(3, 0, 1, 2) # C, T, H, W
93 |
94 | if frames.size(-2) < frames.size(-1):
95 | new_width = frames.size(-1) * self.spatial_size // frames.size(-2)
96 | new_height = self.spatial_size
97 | else:
98 | new_height = frames.size(-2) * self.spatial_size // frames.size(-1)
99 | new_width = self.spatial_size
100 | frames = torch.nn.functional.interpolate(
101 | frames, size=(new_height, new_width),
102 | mode='bilinear', align_corners=False,
103 | )
104 |
105 | frames = self._generate_spatial_crops(frames)
106 | frames = sum([self._generate_temporal_crops(x) for x in frames], [])
107 | if len(frames) > 1:
108 | frames = torch.stack(frames)
109 |
110 | return frames, label
111 |
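# Output shape note (added comment): with random_sample=True (training), a
# single clip of shape (C, T, H, W) is returned; otherwise the
# num_spatial_views * num_temporal_views evaluation clips are stacked into a
# (V, C, T, H, W) tensor (the stack is skipped when only one view is
# requested).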
112 |
113 | def _generate_temporal_crops(self, frames):
114 | seg_len = (self.num_frames - 1) * self.sampling_rate + 1
115 | if frames.size(1) < seg_len:
116 | frames = torch.cat([frames, frames[:, -1:].repeat(1, seg_len - frames.size(1), 1, 1)], dim=1)
117 | slide_len = frames.size(1) - seg_len
118 |
119 | crops = []
120 | for i in range(self.num_temporal_views):
121 | if self.num_temporal_views == 1:
122 | st = slide_len // 2
123 | else:
124 | st = round(slide_len / (self.num_temporal_views - 1) * i)
125 |
126 | crops.append(frames[:, st: st + self.num_frames * self.sampling_rate: self.sampling_rate])
127 |
128 | return crops
129 |
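# Worked example for the temporal crops above (added comment): with
# num_frames=8 and sampling_rate=16, each view spans
# seg_len = 7 * 16 + 1 = 113 decoded frames (shorter videos are padded by
# repeating the last frame). With num_temporal_views=3, the views start at
# offsets 0, round(slide_len / 2) and slide_len, i.e. the clips are anchored
# at the beginning, middle and end of the video.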
130 |
131 | def _generate_spatial_crops(self, frames):
132 | if self.num_spatial_views == 1:
133 | assert min(frames.size(-2), frames.size(-1)) >= self.spatial_size
134 | h_st = (frames.size(-2) - self.spatial_size) // 2
135 | w_st = (frames.size(-1) - self.spatial_size) // 2
136 | h_ed, w_ed = h_st + self.spatial_size, w_st + self.spatial_size
137 | return [frames[:, :, h_st: h_ed, w_st: w_ed]]
138 |
139 | elif self.num_spatial_views == 3:
140 | assert min(frames.size(-2), frames.size(-1)) == self.spatial_size
141 | crops = []
142 | margin = max(frames.size(-2), frames.size(-1)) - self.spatial_size
143 | for st in (0, margin // 2, margin):
144 | ed = st + self.spatial_size
145 | if frames.size(-2) > frames.size(-1):
146 | crops.append(frames[:, :, st: ed, :])
147 | else:
148 | crops.append(frames[:, :, :, st: ed])
149 | return crops
150 |
151 | else:
152 | raise NotImplementedError()
153 |
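# Worked example for the spatial crops above (added comment): for 3-view
# evaluation the short side must already equal spatial_size, and the three
# crops slide along the longer side at offsets 0, margin // 2 and margin,
# giving left/center/right crops for landscape frames and top/center/bottom
# crops for portrait frames.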
154 |
155 | def _random_sample_frame_idx(self, num_video_frames):
156 | frame_indices = []
157 |
158 | if self.sampling_rate < 0: # TSN-style sampling
159 | seg_size = (num_video_frames - 1) / self.num_frames
160 | for i in range(self.num_frames):
161 | start, end = round(seg_size * i), round(seg_size * (i + 1))
162 | frame_indices.append(np.random.randint(start, end + 1))
163 | elif self.sampling_rate * (self.num_frames - 1) + 1 >= num_video_frames:
164 | for i in range(self.num_frames):
165 | frame_indices.append(i * self.sampling_rate if i * self.sampling_rate < num_video_frames else frame_indices[-1])
166 | else:
167 | start = np.random.randint(num_video_frames - self.sampling_rate * (self.num_frames - 1))
168 | frame_indices = list(range(start, start + self.sampling_rate * self.num_frames, self.sampling_rate))
169 |
170 | return frame_indices
171 |
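# Worked example for the sampling above (added comment): in TSN mode
# (sampling_rate < 0), a video with 100 decoded frames and num_frames=8 is
# split into segments of seg_size = 99 / 8 = 12.375 frames and one index is
# drawn uniformly from each segment, e.g. from [0, 12], [12, 25], [25, 37]
# and so on. In fixed-stride mode a random start is drawn and every
# sampling_rate-th frame is taken; if the video is too short, the last valid
# index is repeated instead.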
172 |
173 | class DummyDataset(torch.utils.data.Dataset):
174 |
175 | def __init__(self, list_path: str, num_frames: int, num_views: int, spatial_size: int):
176 | with open(list_path) as f:
177 | self.len = len(f.read().splitlines())
178 | self.num_frames = num_frames
179 | self.num_views = num_views
180 | self.spatial_size = spatial_size
181 |
182 | def __len__(self):
183 | return self.len
184 |
185 | def __getitem__(self, _):
186 | shape = [3, self.num_frames, self.spatial_size, self.spatial_size]
187 | if self.num_views != 1:
188 | shape = [self.num_views] + shape
189 | return torch.zeros(shape), 0
190 |
--------------------------------------------------------------------------------
/video_dataset/rand_augment.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Originates from: https://github.com/facebookresearch/SlowFast/blob/fee19d699c49a81f33b890c5ff592bbb11aa5c54/slowfast/datasets/rand_augment.py
3 |
4 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
5 |
6 | """
7 | This implementation is based on
8 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py
9 | published under an Apache License 2.0.
10 |
11 | COMMENT FROM ORIGINAL:
12 | AutoAugment, RandAugment, and AugMix for PyTorch
13 | This code implements the searched ImageNet policies with various tweaks and
14 | improvements and does not include any of the search code. AA and RA
15 | Implementation adapted from:
16 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
17 | AugMix adapted from:
18 | https://github.com/google-research/augmix
19 | Papers:
20 | AutoAugment: Learning Augmentation Policies from Data
21 | https://arxiv.org/abs/1805.09501
22 | Learning Data Augmentation Strategies for Object Detection
23 | https://arxiv.org/abs/1906.11172
24 | RandAugment: Practical automated data augmentation...
25 | https://arxiv.org/abs/1909.13719
26 | AugMix: A Simple Data Processing Method to Improve Robustness and
27 | Uncertainty https://arxiv.org/abs/1912.02781
28 |
29 | Hacked together by / Copyright 2020 Ross Wightman
30 | """
31 |
32 | import math
33 | import numpy as np
34 | import random
35 | import re
36 | import PIL
37 | from PIL import Image, ImageEnhance, ImageOps
38 |
39 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]])
40 |
41 | _FILL = (128, 128, 128)
42 |
43 | # This signifies the max integer that the controller RNN could predict for the
44 | # augmentation scheme.
45 | _MAX_LEVEL = 10.0
46 |
47 | _HPARAMS_DEFAULT = {
48 | "translate_const": 250,
49 | "img_mean": _FILL,
50 | }
51 |
52 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
53 |
54 |
55 | def _interpolation(kwargs):
56 | interpolation = kwargs.pop("resample", Image.BILINEAR)
57 | if isinstance(interpolation, (list, tuple)):
58 | return random.choice(interpolation)
59 | else:
60 | return interpolation
61 |
62 |
63 | def _check_args_tf(kwargs):
64 | if "fillcolor" in kwargs and _PIL_VER < (5, 0):
65 | kwargs.pop("fillcolor")
66 | kwargs["resample"] = _interpolation(kwargs)
67 |
68 |
69 | def shear_x(img, factor, **kwargs):
70 | _check_args_tf(kwargs)
71 | return img.transform(
72 | img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs
73 | )
74 |
75 |
76 | def shear_y(img, factor, **kwargs):
77 | _check_args_tf(kwargs)
78 | return img.transform(
79 | img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs
80 | )
81 |
82 |
83 | def translate_x_rel(img, pct, **kwargs):
84 | pixels = pct * img.size[0]
85 | _check_args_tf(kwargs)
86 | return img.transform(
87 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
88 | )
89 |
90 |
91 | def translate_y_rel(img, pct, **kwargs):
92 | pixels = pct * img.size[1]
93 | _check_args_tf(kwargs)
94 | return img.transform(
95 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
96 | )
97 |
98 |
99 | def translate_x_abs(img, pixels, **kwargs):
100 | _check_args_tf(kwargs)
101 | return img.transform(
102 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
103 | )
104 |
105 |
106 | def translate_y_abs(img, pixels, **kwargs):
107 | _check_args_tf(kwargs)
108 | return img.transform(
109 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
110 | )
111 |
112 |
113 | def rotate(img, degrees, **kwargs):
114 | _check_args_tf(kwargs)
115 | if _PIL_VER >= (5, 2):
116 | return img.rotate(degrees, **kwargs)
117 | elif _PIL_VER >= (5, 0):
118 | w, h = img.size
119 | post_trans = (0, 0)
120 | rotn_center = (w / 2.0, h / 2.0)
121 | angle = -math.radians(degrees)
122 | matrix = [
123 | round(math.cos(angle), 15),
124 | round(math.sin(angle), 15),
125 | 0.0,
126 | round(-math.sin(angle), 15),
127 | round(math.cos(angle), 15),
128 | 0.0,
129 | ]
130 |
131 | def transform(x, y, matrix):
132 | (a, b, c, d, e, f) = matrix
133 | return a * x + b * y + c, d * x + e * y + f
134 |
135 | matrix[2], matrix[5] = transform(
136 | -rotn_center[0] - post_trans[0],
137 | -rotn_center[1] - post_trans[1],
138 | matrix,
139 | )
140 | matrix[2] += rotn_center[0]
141 | matrix[5] += rotn_center[1]
142 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
143 | else:
144 | return img.rotate(degrees, resample=kwargs["resample"])
145 |
146 |
147 | def auto_contrast(img, **__):
148 | return ImageOps.autocontrast(img)
149 |
150 |
151 | def invert(img, **__):
152 | return ImageOps.invert(img)
153 |
154 |
155 | def equalize(img, **__):
156 | return ImageOps.equalize(img)
157 |
158 |
159 | def solarize(img, thresh, **__):
160 | return ImageOps.solarize(img, thresh)
161 |
162 |
163 | def solarize_add(img, add, thresh=128, **__):
164 | lut = []
165 | for i in range(256):
166 | if i < thresh:
167 | lut.append(min(255, i + add))
168 | else:
169 | lut.append(i)
170 | if img.mode in ("L", "RGB"):
171 | if img.mode == "RGB" and len(lut) == 256:
172 | lut = lut + lut + lut
173 | return img.point(lut)
174 | else:
175 | return img
176 |
177 |
178 | def posterize(img, bits_to_keep, **__):
179 | if bits_to_keep >= 8:
180 | return img
181 | return ImageOps.posterize(img, bits_to_keep)
182 |
183 |
184 | def contrast(img, factor, **__):
185 | return ImageEnhance.Contrast(img).enhance(factor)
186 |
187 |
188 | def color(img, factor, **__):
189 | return ImageEnhance.Color(img).enhance(factor)
190 |
191 |
192 | def brightness(img, factor, **__):
193 | return ImageEnhance.Brightness(img).enhance(factor)
194 |
195 |
196 | def sharpness(img, factor, **__):
197 | return ImageEnhance.Sharpness(img).enhance(factor)
198 |
199 |
200 | def _randomly_negate(v):
201 | """With 50% prob, negate the value"""
202 | return -v if random.random() > 0.5 else v
203 |
204 |
205 | def _rotate_level_to_arg(level, _hparams):
206 | # range [-30, 30]
207 | level = (level / _MAX_LEVEL) * 30.0
208 | level = _randomly_negate(level)
209 | return (level,)
210 |
211 |
212 | def _enhance_level_to_arg(level, _hparams):
213 | # range [0.1, 1.9]
214 | return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
215 |
216 |
217 | def _enhance_increasing_level_to_arg(level, _hparams):
218 | # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
219 | # range [0.1, 1.9]
220 | level = (level / _MAX_LEVEL) * 0.9
221 | level = 1.0 + _randomly_negate(level)
222 | return (level,)
223 |
224 |
225 | def _shear_level_to_arg(level, _hparams):
226 | # range [-0.3, 0.3]
227 | level = (level / _MAX_LEVEL) * 0.3
228 | level = _randomly_negate(level)
229 | return (level,)
230 |
231 |
232 | def _translate_abs_level_to_arg(level, hparams):
233 | translate_const = hparams["translate_const"]
234 | level = (level / _MAX_LEVEL) * float(translate_const)
235 | level = _randomly_negate(level)
236 | return (level,)
237 |
238 |
239 | def _translate_rel_level_to_arg(level, hparams):
240 | # default range [-0.45, 0.45]
241 | translate_pct = hparams.get("translate_pct", 0.45)
242 | level = (level / _MAX_LEVEL) * translate_pct
243 | level = _randomly_negate(level)
244 | return (level,)
245 |
246 |
247 | def _posterize_level_to_arg(level, _hparams):
248 | # As per Tensorflow TPU EfficientNet impl
249 | # range [0, 4], 'keep 0 up to 4 MSB of original image'
250 | # intensity/severity of augmentation decreases with level
251 | return (int((level / _MAX_LEVEL) * 4),)
252 |
253 |
254 | def _posterize_increasing_level_to_arg(level, hparams):
255 | # As per Tensorflow models research and UDA impl
256 | # range [4, 0], 'keep 4 down to 0 MSB of original image',
257 | # intensity/severity of augmentation increases with level
258 | return (4 - _posterize_level_to_arg(level, hparams)[0],)
259 |
260 |
261 | def _posterize_original_level_to_arg(level, _hparams):
262 | # As per original AutoAugment paper description
263 | # range [4, 8], 'keep 4 up to 8 MSB of image'
264 | # intensity/severity of augmentation decreases with level
265 | return (int((level / _MAX_LEVEL) * 4) + 4,)
266 |
267 |
268 | def _solarize_level_to_arg(level, _hparams):
269 | # range [0, 256]
270 | # intensity/severity of augmentation decreases with level
271 | return (int((level / _MAX_LEVEL) * 256),)
272 |
273 |
274 | def _solarize_increasing_level_to_arg(level, _hparams):
275 | # range [0, 256]
276 | # intensity/severity of augmentation increases with level
277 | return (256 - _solarize_level_to_arg(level, _hparams)[0],)
278 |
279 |
280 | def _solarize_add_level_to_arg(level, _hparams):
281 | # range [0, 110]
282 | return (int((level / _MAX_LEVEL) * 110),)
283 |
284 |
285 | LEVEL_TO_ARG = {
286 | "AutoContrast": None,
287 | "Equalize": None,
288 | "Invert": None,
289 | "Rotate": _rotate_level_to_arg,
290 | # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
291 | "Posterize": _posterize_level_to_arg,
292 | "PosterizeIncreasing": _posterize_increasing_level_to_arg,
293 | "PosterizeOriginal": _posterize_original_level_to_arg,
294 | "Solarize": _solarize_level_to_arg,
295 | "SolarizeIncreasing": _solarize_increasing_level_to_arg,
296 | "SolarizeAdd": _solarize_add_level_to_arg,
297 | "Color": _enhance_level_to_arg,
298 | "ColorIncreasing": _enhance_increasing_level_to_arg,
299 | "Contrast": _enhance_level_to_arg,
300 | "ContrastIncreasing": _enhance_increasing_level_to_arg,
301 | "Brightness": _enhance_level_to_arg,
302 | "BrightnessIncreasing": _enhance_increasing_level_to_arg,
303 | "Sharpness": _enhance_level_to_arg,
304 | "SharpnessIncreasing": _enhance_increasing_level_to_arg,
305 | "ShearX": _shear_level_to_arg,
306 | "ShearY": _shear_level_to_arg,
307 | "TranslateX": _translate_abs_level_to_arg,
308 | "TranslateY": _translate_abs_level_to_arg,
309 | "TranslateXRel": _translate_rel_level_to_arg,
310 | "TranslateYRel": _translate_rel_level_to_arg,
311 | }
312 |
313 |
314 | NAME_TO_OP = {
315 | "AutoContrast": auto_contrast,
316 | "Equalize": equalize,
317 | "Invert": invert,
318 | "Rotate": rotate,
319 | "Posterize": posterize,
320 | "PosterizeIncreasing": posterize,
321 | "PosterizeOriginal": posterize,
322 | "Solarize": solarize,
323 | "SolarizeIncreasing": solarize,
324 | "SolarizeAdd": solarize_add,
325 | "Color": color,
326 | "ColorIncreasing": color,
327 | "Contrast": contrast,
328 | "ContrastIncreasing": contrast,
329 | "Brightness": brightness,
330 | "BrightnessIncreasing": brightness,
331 | "Sharpness": sharpness,
332 | "SharpnessIncreasing": sharpness,
333 | "ShearX": shear_x,
334 | "ShearY": shear_y,
335 | "TranslateX": translate_x_abs,
336 | "TranslateY": translate_y_abs,
337 | "TranslateXRel": translate_x_rel,
338 | "TranslateYRel": translate_y_rel,
339 | }
340 |
341 |
342 | class AugmentOp:
343 | """
344 | Apply a single augmentation op to one image or to a list of video frames.
345 | """
346 |
347 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
348 | hparams = hparams or _HPARAMS_DEFAULT
349 | self.aug_fn = NAME_TO_OP[name]
350 | self.level_fn = LEVEL_TO_ARG[name]
351 | self.prob = prob
352 | self.magnitude = magnitude
353 | self.hparams = hparams.copy()
354 | self.kwargs = {
355 | "fillcolor": hparams["img_mean"]
356 | if "img_mean" in hparams
357 | else _FILL,
358 | "resample": hparams["interpolation"]
359 | if "interpolation" in hparams
360 | else _RANDOM_INTERPOLATION,
361 | }
362 |
363 | # If magnitude_std is > 0, we introduce some randomness
364 | # in the usually fixed policy and sample magnitude from a normal distribution
365 | # with mean `magnitude` and std-dev of `magnitude_std`.
366 | # NOTE This is my own hack, being tested, not in papers or reference impls.
367 | self.magnitude_std = self.hparams.get("magnitude_std", 0)
368 |
369 | def __call__(self, img_list):
370 | if self.prob < 1.0 and random.random() > self.prob:
371 | return img_list
372 | magnitude = self.magnitude
373 | if self.magnitude_std and self.magnitude_std > 0:
374 | magnitude = random.gauss(magnitude, self.magnitude_std)
375 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
376 | level_args = (
377 | self.level_fn(magnitude, self.hparams)
378 | if self.level_fn is not None
379 | else ()
380 | )
381 |
382 | if isinstance(img_list, list):
383 | return [
384 | self.aug_fn(img, *level_args, **self.kwargs) for img in img_list
385 | ]
386 | else:
387 | return self.aug_fn(img_list, *level_args, **self.kwargs)
388 |
389 |
390 | _RAND_TRANSFORMS = [
391 | "AutoContrast",
392 | "Equalize",
393 | "Invert",
394 | "Rotate",
395 | "Posterize",
396 | "Solarize",
397 | "SolarizeAdd",
398 | "Color",
399 | "Contrast",
400 | "Brightness",
401 | "Sharpness",
402 | "ShearX",
403 | "ShearY",
404 | "TranslateXRel",
405 | "TranslateYRel",
406 | ]
407 |
408 |
409 | _RAND_INCREASING_TRANSFORMS = [
410 | "AutoContrast",
411 | "Equalize",
412 | "Invert",
413 | "Rotate",
414 | "PosterizeIncreasing",
415 | "SolarizeIncreasing",
416 | "SolarizeAdd",
417 | "ColorIncreasing",
418 | "ContrastIncreasing",
419 | "BrightnessIncreasing",
420 | "SharpnessIncreasing",
421 | "ShearX",
422 | "ShearY",
423 | "TranslateXRel",
424 | "TranslateYRel",
425 | ]
426 |
427 |
428 | # These experimental weights are based loosely on the relative improvements mentioned in the paper.
429 | # They may not result in increased performance, but could likely be tuned to do so.
430 | _RAND_CHOICE_WEIGHTS_0 = {
431 | "Rotate": 0.3,
432 | "ShearX": 0.2,
433 | "ShearY": 0.2,
434 | "TranslateXRel": 0.1,
435 | "TranslateYRel": 0.1,
436 | "Color": 0.025,
437 | "Sharpness": 0.025,
438 | "AutoContrast": 0.025,
439 | "Solarize": 0.005,
440 | "SolarizeAdd": 0.005,
441 | "Contrast": 0.005,
442 | "Brightness": 0.005,
443 | "Equalize": 0.005,
444 | "Posterize": 0,
445 | "Invert": 0,
446 | }
447 |
448 |
449 | def _select_rand_weights(weight_idx=0, transforms=None):
450 | transforms = transforms or _RAND_TRANSFORMS
451 | assert weight_idx == 0 # only one set of weights currently
452 | rand_weights = _RAND_CHOICE_WEIGHTS_0
453 | probs = [rand_weights[k] for k in transforms]
454 | probs /= np.sum(probs)
455 | return probs
456 |
457 |
458 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
459 | hparams = hparams or _HPARAMS_DEFAULT
460 | transforms = transforms or _RAND_TRANSFORMS
461 | return [
462 | AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams)
463 | for name in transforms
464 | ]
465 |
466 |
467 | class RandAugment:
468 | def __init__(self, ops, num_layers=2, choice_weights=None):
469 | self.ops = ops
470 | self.num_layers = num_layers
471 | self.choice_weights = choice_weights
472 |
473 | def __call__(self, img):
474 | # no replacement when using weighted choice
475 | ops = np.random.choice(
476 | self.ops,
477 | self.num_layers,
478 | replace=self.choice_weights is None,
479 | p=self.choice_weights,
480 | )
481 | for op in ops:
482 | img = op(img)
483 | return img
484 |
485 |
486 | def rand_augment_transform(config_str, hparams):
487 | """
488 | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
489 |
490 | Create a RandAugment transform
491 | :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
492 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
493 | sections, which are not order-specific, determine
494 | 'm' - integer magnitude of rand augment
495 | 'n' - integer num layers (number of transform ops selected per image)
496 | 'w' - integer probability weight index (index of a set of weights to influence choice of op)
497 | 'mstd' - float std deviation of magnitude noise applied
498 | 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
499 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
500 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
501 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
502 | :return: A PyTorch compatible Transform
503 | """
504 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
505 | num_layers = 2 # default to 2 ops per image
506 | weight_idx = None # default to no probability weights for op choice
507 | transforms = _RAND_TRANSFORMS
508 | config = config_str.split("-")
509 | assert config[0] == "rand"
510 | config = config[1:]
511 | for c in config:
512 | cs = re.split(r"(\d.*)", c)
513 | if len(cs) < 2:
514 | continue
515 | key, val = cs[:2]
516 | if key == "mstd":
517 | # noise param injected via hparams for now
518 | hparams.setdefault("magnitude_std", float(val))
519 | elif key == "inc":
520 | if int(val):  # 'inc1' enables the increasing transform set, 'inc0' keeps the default
521 | transforms = _RAND_INCREASING_TRANSFORMS
522 | elif key == "m":
523 | magnitude = int(val)
524 | elif key == "n":
525 | num_layers = int(val)
526 | elif key == "w":
527 | weight_idx = int(val)
528 | else:
529 | raise NotImplementedError(f'unknown RandAugment config key: {key}')
530 | ra_ops = rand_augment_ops(
531 | magnitude=magnitude, hparams=hparams, transforms=transforms
532 | )
533 | choice_weights = (
534 | None if weight_idx is None else _select_rand_weights(weight_idx)
535 | )
536 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
537 |
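# Illustrative usage (added sketch, not part of the original implementation;
# the config string and hparams values below are arbitrary examples):
if __name__ == "__main__":
    example_hparams = {"translate_const": 100, "img_mean": _FILL}
    example_augment = rand_augment_transform("rand-m9-n2-mstd0.5-inc1", example_hparams)
    example_frames = [Image.new("RGB", (224, 224), _FILL) for _ in range(8)]
    # RandAugment samples num_layers ops per call and AugmentOp applies each op
    # to every frame in the list, so all frames of a clip share one augmentation.
    example_frames = example_augment(example_frames)
    print(len(example_frames), example_frames[0].size)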
--------------------------------------------------------------------------------
/video_dataset/random_erasing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Originates from: https://github.com/facebookresearch/SlowFast/blob/fee19d699c49a81f33b890c5ff592bbb11aa5c54/slowfast/datasets/random_erasing.py
3 |
4 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
5 |
6 | """
7 | This implementation is based on
8 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
9 | published under an Apache License 2.0.
10 |
11 | COMMENT FROM ORIGINAL:
12 | Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
13 | Copyright Zhun Zhong & Liang Zheng
14 | Hacked together by / Copyright 2020 Ross Wightman
15 | """
16 | import math
17 | import random
18 | import torch
19 |
20 |
21 | def _get_pixels(
22 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"
23 | ):
24 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
25 | # paths, flip the order so normal is run on CPU if this becomes a problem
26 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
27 | if per_pixel:
28 | return torch.empty(patch_size, dtype=dtype, device=device).normal_()
29 | elif rand_color:
30 | return torch.empty(
31 | (patch_size[0], 1, 1), dtype=dtype, device=device
32 | ).normal_()
33 | else:
34 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
35 |
36 |
37 | class RandomErasing:
38 | """Randomly selects a rectangle region in an image and erases its pixels.
39 | 'Random Erasing Data Augmentation' by Zhong et al.
40 | See https://arxiv.org/pdf/1708.04896.pdf
41 | This variant of RandomErasing is intended to be applied to either a batch
42 | or single image tensor after it has been normalized by dataset mean and std.
43 | Args:
44 | probability: Probability that the Random Erasing operation will be performed.
45 | min_area: Minimum percentage of erased area wrt input image area.
46 | max_area: Maximum percentage of erased area wrt input image area.
47 | min_aspect: Minimum aspect ratio of erased area.
48 | mode: pixel color mode, one of 'const', 'rand', or 'pixel'
49 | 'const' - erase block is constant color of 0 for all channels
50 | 'rand' - erase block is same per-channel random (normal) color
51 | 'pixel' - erase block is per-pixel random (normal) color
52 | max_count: maximum number of erasing blocks per image, area per box is scaled by count.
53 | per-image count is randomly chosen between 1 and this value.
54 | """
55 |
56 | def __init__(
57 | self,
58 | probability=0.5,
59 | min_area=0.02,
60 | max_area=1 / 3,
61 | min_aspect=0.3,
62 | max_aspect=None,
63 | mode="const",
64 | min_count=1,
65 | max_count=None,
66 | num_splits=0,
67 | device="cuda",
68 | cube=True,
69 | ):
70 | self.probability = probability
71 | self.min_area = min_area
72 | self.max_area = max_area
73 | max_aspect = max_aspect or 1 / min_aspect
74 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
75 | self.min_count = min_count
76 | self.max_count = max_count or min_count
77 | self.num_splits = num_splits
78 | mode = mode.lower()
79 | self.rand_color = False
80 | self.per_pixel = False
81 | self.cube = cube
82 | if mode == "rand":
83 | self.rand_color = True # per block random normal
84 | elif mode == "pixel":
85 | self.per_pixel = True # per pixel random normal
86 | else:
87 | assert not mode or mode == "const"
88 | self.device = device
89 |
90 | def _erase(self, img, chan, img_h, img_w, dtype):
91 | if random.random() > self.probability:
92 | return
93 | area = img_h * img_w
94 | count = (
95 | self.min_count
96 | if self.min_count == self.max_count
97 | else random.randint(self.min_count, self.max_count)
98 | )
99 | for _ in range(count):
100 | for _ in range(10):
101 | target_area = (
102 | random.uniform(self.min_area, self.max_area) * area / count
103 | )
104 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
105 | h = int(round(math.sqrt(target_area * aspect_ratio)))
106 | w = int(round(math.sqrt(target_area / aspect_ratio)))
107 | if w < img_w and h < img_h:
108 | top = random.randint(0, img_h - h)
109 | left = random.randint(0, img_w - w)
110 | img[:, top : top + h, left : left + w] = _get_pixels(
111 | self.per_pixel,
112 | self.rand_color,
113 | (chan, h, w),
114 | dtype=dtype,
115 | device=self.device,
116 | )
117 | break
118 |
119 | def _erase_cube(
120 | self,
121 | img,
122 | batch_start,
123 | batch_size,
124 | chan,
125 | img_h,
126 | img_w,
127 | dtype,
128 | ):
129 | if random.random() > self.probability:
130 | return
131 | area = img_h * img_w
132 | count = (
133 | self.min_count
134 | if self.min_count == self.max_count
135 | else random.randint(self.min_count, self.max_count)
136 | )
137 | for _ in range(count):
138 | for _ in range(100):
139 | target_area = (
140 | random.uniform(self.min_area, self.max_area) * area / count
141 | )
142 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
143 | h = int(round(math.sqrt(target_area * aspect_ratio)))
144 | w = int(round(math.sqrt(target_area / aspect_ratio)))
145 | if w < img_w and h < img_h:
146 | top = random.randint(0, img_h - h)
147 | left = random.randint(0, img_w - w)
148 | for i in range(batch_start, batch_size):
149 | img_instance = img[i]
150 | img_instance[
151 | :, top : top + h, left : left + w
152 | ] = _get_pixels(
153 | self.per_pixel,
154 | self.rand_color,
155 | (chan, h, w),
156 | dtype=dtype,
157 | device=self.device,
158 | )
159 | break
160 |
161 | def __call__(self, input):
162 | if len(input.size()) == 3:
163 | self._erase(input, *input.size(), input.dtype)
164 | else:
165 | batch_size, chan, img_h, img_w = input.size()
166 | # skip first slice of batch if num_splits is set (for clean portion of samples)
167 | batch_start = (
168 | batch_size // self.num_splits if self.num_splits > 1 else 0
169 | )
170 | if self.cube:
171 | self._erase_cube(
172 | input,
173 | batch_start,
174 | batch_size,
175 | chan,
176 | img_h,
177 | img_w,
178 | input.dtype,
179 | )
180 | else:
181 | for i in range(batch_start, batch_size):
182 | self._erase(input[i], chan, img_h, img_w, input.dtype)
183 | return input
184 |
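# Illustrative usage (added sketch, not part of the original file; all
# argument values are arbitrary examples): with cube=True a noise block is
# erased at the same spatial location in every frame of a normalized
# (T, C, H, W) clip, here on the CPU.
if __name__ == "__main__":
    eraser = RandomErasing(probability=1.0, mode="pixel", device="cpu", cube=True)
    example_clip = torch.randn(8, 3, 224, 224)
    example_clip = eraser(example_clip)
    print(example_clip.shape)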
--------------------------------------------------------------------------------
/video_dataset/transform.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Originates from: https://github.com/facebookresearch/SlowFast/blob/fee19d699c49a81f33b890c5ff592bbb11aa5c54/slowfast/datasets/transform.py
3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
4 |
5 | import logging
6 | import math
7 | import numpy as np
8 |
9 | # import cv2
10 | import random
11 | import torch
12 | import torchvision as tv
13 | import torchvision.transforms.functional as F
14 | from PIL import Image, ImageFilter
15 | from torchvision import transforms
16 |
17 | from .rand_augment import rand_augment_transform
18 | from .random_erasing import RandomErasing
19 |
20 | _pil_interpolation_to_str = {
21 | Image.NEAREST: "PIL.Image.NEAREST",
22 | Image.BILINEAR: "PIL.Image.BILINEAR",
23 | Image.BICUBIC: "PIL.Image.BICUBIC",
24 | Image.LANCZOS: "PIL.Image.LANCZOS",
25 | Image.HAMMING: "PIL.Image.HAMMING",
26 | Image.BOX: "PIL.Image.BOX",
27 | }
28 |
29 |
30 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
31 |
32 |
33 | def _pil_interp(method):
34 | if method == "bicubic":
35 | return Image.BICUBIC
36 | elif method == "lanczos":
37 | return Image.LANCZOS
38 | elif method == "hamming":
39 | return Image.HAMMING
40 | else:
41 | return Image.BILINEAR
42 |
43 |
44 | logger = logging.getLogger(__name__)
45 |
46 |
47 | def random_short_side_scale_jitter(
48 | images, min_size, max_size, boxes=None, inverse_uniform_sampling=False
49 | ):
50 | """
51 | Perform a spatial short scale jittering on the given images and
52 | corresponding boxes.
53 | Args:
54 | images (tensor): images to perform scale jitter. Dimension is
55 | `num frames` x `channel` x `height` x `width`.
56 | min_size (int): the minimal size to scale the frames.
57 | max_size (int): the maximal size to scale the frames.
58 | boxes (ndarray): optional. Corresponding boxes to images.
59 | Dimension is `num boxes` x 4.
60 | inverse_uniform_sampling (bool): if True, sample uniformly in
61 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
62 | scale. If False, take a uniform sample from [min_scale, max_scale].
63 | Returns:
64 | (tensor): the scaled images with dimension of
65 | `num frames` x `channel` x `new height` x `new width`.
66 | (ndarray or None): the scaled boxes with dimension of
67 | `num boxes` x 4.
68 | """
69 | if inverse_uniform_sampling:
70 | size = int(
71 | round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size))
72 | )
73 | else:
74 | size = int(round(np.random.uniform(min_size, max_size)))
75 |
76 | height = images.shape[2]
77 | width = images.shape[3]
78 | if (width <= height and width == size) or (
79 | height <= width and height == size
80 | ):
81 | return images, boxes
82 | new_width = size
83 | new_height = size
84 | if width < height:
85 | new_height = int(math.floor((float(height) / width) * size))
86 | if boxes is not None:
87 | boxes = boxes * float(new_height) / height
88 | else:
89 | new_width = int(math.floor((float(width) / height) * size))
90 | if boxes is not None:
91 | boxes = boxes * float(new_width) / width
92 |
93 | return (
94 | torch.nn.functional.interpolate(
95 | images,
96 | size=(new_height, new_width),
97 | mode="bilinear",
98 | align_corners=False,
99 | ),
100 | boxes,
101 | )
102 |
103 |
104 | def crop_boxes(boxes, x_offset, y_offset):
105 | """
106 | Perform crop on the bounding boxes given the offsets.
107 | Args:
108 | boxes (ndarray or None): bounding boxes to perform crop. The dimension
109 | is `num boxes` x 4.
110 | x_offset (int): cropping offset in the x axis.
111 | y_offset (int): cropping offset in the y axis.
112 | Returns:
113 | cropped_boxes (ndarray or None): the cropped boxes with dimension of
114 | `num boxes` x 4.
115 | """
116 | cropped_boxes = boxes.copy()
117 | cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
118 | cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
119 |
120 | return cropped_boxes
121 |
122 |
123 | def random_crop(images, size, boxes=None):
124 | """
125 | Perform random spatial crop on the given images and corresponding boxes.
126 | Args:
127 | images (tensor): images to perform random crop. The dimension is
128 | `num frames` x `channel` x `height` x `width`.
129 | size (int): the size of height and width to crop on the image.
130 | boxes (ndarray or None): optional. Corresponding boxes to images.
131 | Dimension is `num boxes` x 4.
132 | Returns:
133 | cropped (tensor): cropped images with dimension of
134 | `num frames` x `channel` x `size` x `size`.
135 | cropped_boxes (ndarray or None): the cropped boxes with dimension of
136 | `num boxes` x 4.
137 | """
138 | if images.shape[2] == size and images.shape[3] == size:
139 | return images, boxes
140 | height = images.shape[2]
141 | width = images.shape[3]
142 | y_offset = 0
143 | if height > size:
144 | y_offset = int(np.random.randint(0, height - size))
145 | x_offset = 0
146 | if width > size:
147 | x_offset = int(np.random.randint(0, width - size))
148 | cropped = images[
149 | :, :, y_offset : y_offset + size, x_offset : x_offset + size
150 | ]
151 |
152 | cropped_boxes = (
153 | crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
154 | )
155 |
156 | return cropped, cropped_boxes
157 |
158 |
159 | def horizontal_flip(prob, images, boxes=None):
160 | """
161 | Perform horizontal flip on the given images and corresponding boxes.
162 | Args:
163 | prob (float): probability of flipping the images.
164 | images (tensor): images to perform horizontal flip, the dimension is
165 | `num frames` x `channel` x `height` x `width`.
166 | boxes (ndarray or None): optional. Corresponding boxes to images.
167 | Dimension is `num boxes` x 4.
168 | Returns:
169 | images (tensor): images with dimension of
170 | `num frames` x `channel` x `height` x `width`.
171 | flipped_boxes (ndarray or None): the flipped boxes with dimension of
172 | `num boxes` x 4.
173 | """
174 | if boxes is None:
175 | flipped_boxes = None
176 | else:
177 | flipped_boxes = boxes.copy()
178 |
179 | if np.random.uniform() < prob:
180 | images = images.flip((-1))
181 |
182 | if len(images.shape) == 3:
183 | width = images.shape[2]
184 | elif len(images.shape) == 4:
185 | width = images.shape[3]
186 | else:
187 | raise NotImplementedError("Dimension does not supported")
188 | if boxes is not None:
189 | flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1
190 |
191 | return images, flipped_boxes
192 |
193 |
194 | def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
195 | """
196 | Perform uniform spatial sampling on the images and corresponding boxes.
197 | Args:
198 | images (tensor): images to perform uniform crop. The dimension is
199 | `num frames` x `channel` x `height` x `width`.
200 | size (int): size of height and width to crop the images.
201 | spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
202 | is larger than height. Or 0, 1, or 2 for top, center, and bottom
203 | crop if height is larger than width.
204 | boxes (ndarray or None): optional. Corresponding boxes to images.
205 | Dimension is `num boxes` x 4.
206 | scale_size (int): optional. If not None, resize the images to scale_size before
207 | performing any crop.
208 | Returns:
209 | cropped (tensor): images with dimension of
210 | `num frames` x `channel` x `size` x `size`.
211 | cropped_boxes (ndarray or None): the cropped boxes with dimension of
212 | `num boxes` x 4.
213 | """
214 | assert spatial_idx in [0, 1, 2]
215 | ndim = len(images.shape)
216 | if ndim == 3:
217 | images = images.unsqueeze(0)
218 | height = images.shape[2]
219 | width = images.shape[3]
220 |
221 | if scale_size is not None:
222 | if width <= height:
223 | width, height = scale_size, int(height / width * scale_size)
224 | else:
225 | width, height = int(width / height * scale_size), scale_size
226 | images = torch.nn.functional.interpolate(
227 | images,
228 | size=(height, width),
229 | mode="bilinear",
230 | align_corners=False,
231 | )
232 |
233 | y_offset = int(math.ceil((height - size) / 2))
234 | x_offset = int(math.ceil((width - size) / 2))
235 |
236 | if height > width:
237 | if spatial_idx == 0:
238 | y_offset = 0
239 | elif spatial_idx == 2:
240 | y_offset = height - size
241 | else:
242 | if spatial_idx == 0:
243 | x_offset = 0
244 | elif spatial_idx == 2:
245 | x_offset = width - size
246 | cropped = images[
247 | :, :, y_offset : y_offset + size, x_offset : x_offset + size
248 | ]
249 | cropped_boxes = (
250 | crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
251 | )
252 | if ndim == 3:
253 | cropped = cropped.squeeze(0)
254 | return cropped, cropped_boxes
255 |
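# Worked example for uniform_crop above (added comment): for a clip of
# spatial size 224 x 400 (H x W) and size=224, spatial_idx 0 / 1 / 2 keep the
# left / center / right 224 x 224 window (x_offset 0 / 88 / 176); for
# portrait clips the three indices select top / center / bottom instead.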
256 |
257 | def clip_boxes_to_image(boxes, height, width):
258 | """
259 | Clip an array of boxes to an image with the given height and width.
260 | Args:
261 | boxes (ndarray): bounding boxes to perform clipping.
262 | Dimension is `num boxes` x 4.
263 | height (int): given image height.
264 | width (int): given image width.
265 | Returns:
266 | clipped_boxes (ndarray): the clipped boxes with dimension of
267 | `num boxes` x 4.
268 | """
269 | clipped_boxes = boxes.copy()
270 | clipped_boxes[:, [0, 2]] = np.minimum(
271 | width - 1.0, np.maximum(0.0, boxes[:, [0, 2]])
272 | )
273 | clipped_boxes[:, [1, 3]] = np.minimum(
274 | height - 1.0, np.maximum(0.0, boxes[:, [1, 3]])
275 | )
276 | return clipped_boxes
277 |
278 |
279 | def blend(images1, images2, alpha):
280 | """
281 | Blend two images with a given weight alpha.
282 | Args:
283 | images1 (tensor): the first images to be blended, the dimension is
284 | `num frames` x `channel` x `height` x `width`.
285 | images2 (tensor): the second images to be blended, the dimension is
286 | `num frames` x `channel` x `height` x `width`.
287 | alpha (float): the blending weight.
288 | Returns:
289 | (tensor): blended images, the dimension is
290 | `num frames` x `channel` x `height` x `width`.
291 | """
292 | return images1 * alpha + images2 * (1 - alpha)
293 |
294 |
295 | def grayscale(images):
296 | """
297 | Get the grayscale for the input images. The channels of images should be
298 | in order BGR.
299 | Args:
300 | images (tensor): the input images for getting grayscale. Dimension is
301 | `num frames` x `channel` x `height` x `width`.
302 | Returns:
303 | img_gray (tensor): grayscale images, the dimension is
304 | `num frames` x `channel` x `height` x `width`.
305 | """
306 | # R -> 0.299, G -> 0.587, B -> 0.114.
307 | img_gray = images.clone()  # clone instead of torch.tensor(tensor) to avoid the copy-construct warning
308 | gray_channel = (
309 | 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0]
310 | )
311 | img_gray[:, 0] = gray_channel
312 | img_gray[:, 1] = gray_channel
313 | img_gray[:, 2] = gray_channel
314 | return img_gray
315 |
316 |
317 | def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0):
318 | """
319 | Perform color jittering on the input images. The channels of images
320 | should be in order BGR.
321 | Args:
322 | images (tensor): images to perform color jitter. Dimension is
323 | `num frames` x `channel` x `height` x `width`.
324 | img_brightness (float): jitter ratio for brightness.
325 | img_contrast (float): jitter ratio for contrast.
326 | img_saturation (float): jitter ratio for saturation.
327 | Returns:
328 | images (tensor): the jittered images, the dimension is
329 | `num frames` x `channel` x `height` x `width`.
330 | """
331 |
332 | jitter = []
333 | if img_brightness != 0:
334 | jitter.append("brightness")
335 | if img_contrast != 0:
336 | jitter.append("contrast")
337 | if img_saturation != 0:
338 | jitter.append("saturation")
339 |
340 | if len(jitter) > 0:
341 | order = np.random.permutation(np.arange(len(jitter)))
342 | for idx in range(0, len(jitter)):
343 | if jitter[order[idx]] == "brightness":
344 | images = brightness_jitter(img_brightness, images)
345 | elif jitter[order[idx]] == "contrast":
346 | images = contrast_jitter(img_contrast, images)
347 | elif jitter[order[idx]] == "saturation":
348 | images = saturation_jitter(img_saturation, images)
349 | return images
350 |
351 |
352 | def brightness_jitter(var, images):
353 | """
354 | Perform brightness jittering on the input images. The channels of images
355 | should be in order BGR.
356 | Args:
357 | var (float): jitter ratio for brightness.
358 | images (tensor): images to perform color jitter. Dimension is
359 | `num frames` x `channel` x `height` x `width`.
360 | Returns:
361 | images (tensor): the jittered images, the dimension is
362 | `num frames` x `channel` x `height` x `width`.
363 | """
364 | alpha = 1.0 + np.random.uniform(-var, var)
365 |
366 | img_bright = torch.zeros(images.shape)
367 | images = blend(images, img_bright, alpha)
368 | return images
369 |
370 |
371 | def contrast_jitter(var, images):
372 | """
373 | Perform contrast jittering on the input images. The channels of images
374 | should be in order BGR.
375 | Args:
376 | var (float): jitter ratio for contrast.
377 | images (tensor): images to perform color jitter. Dimension is
378 | `num frames` x `channel` x `height` x `width`.
379 | Returns:
380 | images (tensor): the jittered images, the dimension is
381 | `num frames` x `channel` x `height` x `width`.
382 | """
383 | alpha = 1.0 + np.random.uniform(-var, var)
384 |
385 | img_gray = grayscale(images)
386 | img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True)
387 | images = blend(images, img_gray, alpha)
388 | return images
389 |
390 |
391 | def saturation_jitter(var, images):
392 | """
393 | Perform saturation jittering on the input images. The channels of images
394 | should be in order BGR.
395 | Args:
396 | var (float): jitter ratio for saturation.
397 | images (tensor): images to perform color jitter. Dimension is
398 | `num frames` x `channel` x `height` x `width`.
399 | Returns:
400 | images (tensor): the jittered images, the dimension is
401 | `num frames` x `channel` x `height` x `width`.
402 | """
403 | alpha = 1.0 + np.random.uniform(-var, var)
404 | img_gray = grayscale(images)
405 | images = blend(images, img_gray, alpha)
406 |
407 | return images
408 |
409 |
410 | def lighting_jitter(images, alphastd, eigval, eigvec):
411 | """
412 | Perform AlexNet-style PCA jitter on the given images.
413 | Args:
414 | images (tensor): images to perform lighting jitter. Dimension is
415 | `num frames` x `channel` x `height` x `width`.
416 | alphastd (float): jitter ratio for PCA jitter.
417 | eigval (list): eigenvalues for PCA jitter.
418 | eigvec (list[list]): eigenvectors for PCA jitter.
419 | Returns:
420 | out_images (tensor): the jittered images, the dimension is
421 | `num frames` x `channel` x `height` x `width`.
422 | """
423 | if alphastd == 0:
424 | return images
425 | # generate alpha1, alpha2, alpha3.
426 | alpha = np.random.normal(0, alphastd, size=(1, 3))
427 | eig_vec = np.array(eigvec)
428 | eig_val = np.reshape(eigval, (1, 3))
429 | rgb = np.sum(
430 | eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0),
431 | axis=1,
432 | )
433 | out_images = torch.zeros_like(images)
434 | if len(images.shape) == 3:
435 | # C H W
436 | channel_dim = 0
437 | elif len(images.shape) == 4:
438 | # T C H W
439 | channel_dim = 1
440 | else:
441 | raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
442 |
443 | for idx in range(images.shape[channel_dim]):
444 | # C H W
445 | if len(images.shape) == 3:
446 | out_images[idx] = images[idx] + rgb[2 - idx]
447 | # T C H W
448 | elif len(images.shape) == 4:
449 | out_images[:, idx] = images[:, idx] + rgb[2 - idx]
450 | else:
451 | raise NotImplementedError(
452 | f"Unsupported dimension {len(images.shape)}"
453 | )
454 |
455 | return out_images
456 |
457 |
458 | def color_normalization(images, mean, stddev):
459 | """
460 | Perform color normalization on the given images.
461 | Args:
462 | images (tensor): images to perform color normalization. Dimension is
463 | `num frames` x `channel` x `height` x `width`.
464 | mean (list): mean values for normalization.
465 | stddev (list): standard deviations for normalization.
466 |
467 | Returns:
468 | out_images (tensor): the normalized images, the dimension is
469 | `num frames` x `channel` x `height` x `width`.
470 | """
471 | if len(images.shape) == 3:
472 | assert (
473 | len(mean) == images.shape[0]
474 | ), "channel mean not computed properly"
475 | assert (
476 | len(stddev) == images.shape[0]
477 | ), "channel stddev not computed properly"
478 | elif len(images.shape) == 4:
479 | assert (
480 | len(mean) == images.shape[1]
481 | ), "channel mean not computed properly"
482 | assert (
483 | len(stddev) == images.shape[1]
484 | ), "channel stddev not computed properly"
485 | else:
486 | raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
487 |
488 | out_images = torch.zeros_like(images)
489 | for idx in range(len(mean)):
490 | # C H W
491 | if len(images.shape) == 3:
492 | out_images[idx] = (images[idx] - mean[idx]) / stddev[idx]
493 | elif len(images.shape) == 4:
494 | out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx]
495 | else:
496 | raise NotImplementedError(
497 | f"Unsupported dimension {len(images.shape)}"
498 | )
499 | return out_images
500 |
501 |
502 | def _get_param_spatial_crop(
503 | scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False
504 | ):
505 | """
506 | Given scale, ratio, height and width, return sampled coordinates of the videos.
507 | """
508 | for _ in range(num_repeat):
509 | area = height * width
510 | target_area = random.uniform(*scale) * area
511 | if log_scale:
512 | log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
513 | aspect_ratio = math.exp(random.uniform(*log_ratio))
514 | else:
515 | aspect_ratio = random.uniform(*ratio)
516 |
517 | w = int(round(math.sqrt(target_area * aspect_ratio)))
518 | h = int(round(math.sqrt(target_area / aspect_ratio)))
519 |
520 | if np.random.uniform() < 0.5 and switch_hw:
521 | w, h = h, w
522 |
523 | if 0 < w <= width and 0 < h <= height:
524 | i = random.randint(0, height - h)
525 | j = random.randint(0, width - w)
526 | return i, j, h, w
527 |
528 | # Fallback to central crop
529 | in_ratio = float(width) / float(height)
530 | if in_ratio < min(ratio):
531 | w = width
532 | h = int(round(w / min(ratio)))
533 | elif in_ratio > max(ratio):
534 | h = height
535 | w = int(round(h * max(ratio)))
536 | else: # whole image
537 | w = width
538 | h = height
539 | i = (height - h) // 2
540 | j = (width - w) // 2
541 | return i, j, h, w
542 |
543 |
544 | def random_resized_crop(
545 | images,
546 | target_height,
547 | target_width,
548 | scale=(0.08, 1.0),
549 | ratio=(3.0 / 4.0, 4.0 / 3.0),
550 | ):
551 | """
552 | Crop the given images to random size and aspect ratio. A crop of random
553 | size (default: of 0.08 to 1.0) of the original size and a random aspect
554 | ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This
555 | crop is finally resized to given size. This is popularly used to train the
556 | Inception networks.
557 |
558 | Args:
559 | images: Images to perform resizing and cropping.
560 | target_height: Desired height after cropping.
561 | target_width: Desired width after cropping.
562 | scale: Scale range of Inception-style area based random resizing.
563 | ratio: Aspect ratio range of Inception-style area based random resizing.
564 | """
565 |
566 | height = images.shape[2]
567 | width = images.shape[3]
568 |
569 | i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
570 | cropped = images[:, :, i : i + h, j : j + w]
571 | return torch.nn.functional.interpolate(
572 | cropped,
573 | size=(target_height, target_width),
574 | mode="bilinear",
575 | align_corners=False,
576 | )
577 |
578 |
579 | def random_resized_crop_with_shift(
580 | images,
581 | target_height,
582 | target_width,
583 | scale=(0.8, 1.0),
584 | ratio=(3.0 / 4.0, 4.0 / 3.0),
585 | ):
586 | """
587 | This is similar to random_resized_crop. However, it samples two different
588 | boxes (for cropping) for the first and last frame. It then linearly
589 | interpolates the two boxes for other frames.
590 |
591 | Args:
592 | images: Images to perform resizing and cropping.
593 | target_height: Desired height after cropping.
594 | target_width: Desired width after cropping.
595 | scale: Scale range of Inception-style area based random resizing.
596 | ratio: Aspect ratio range of Inception-style area based random resizing.
597 | """
598 | t = images.shape[1]
599 | height = images.shape[2]
600 | width = images.shape[3]
601 |
602 | i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
603 | i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width)
604 | i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()]
605 | j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()]
606 | h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()]
607 | w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()]
608 | out = torch.zeros((3, t, target_height, target_width))
609 | for ind in range(t):
610 | out[:, ind : ind + 1, :, :] = torch.nn.functional.interpolate(
611 | images[
612 | :,
613 | ind : ind + 1,
614 | i_s[ind] : i_s[ind] + h_s[ind],
615 | j_s[ind] : j_s[ind] + w_s[ind],
616 | ],
617 | size=(target_height, target_width),
618 | mode="bilinear",
619 | align_corners=False,
620 | )
621 | return out
622 |
623 |
624 | def create_random_augment(
625 | input_size,
626 | auto_augment=None,
627 | interpolation="bilinear",
628 | ):
629 | """
630 | Get video randaug transform.
631 |
632 | Args:
633 | input_size: The size of the input video in tuple.
634 | auto_augment: Parameters for randaug. An example:
635 | "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number
636 | of operations to apply).
637 | interpolation: Interpolation method.
638 | """
639 | if isinstance(input_size, tuple):
640 | img_size = input_size[-2:]
641 | else:
642 | img_size = input_size
643 |
644 | if auto_augment:
645 | assert isinstance(auto_augment, str)
646 | if isinstance(img_size, tuple):
647 | img_size_min = min(img_size)
648 | else:
649 | img_size_min = img_size
650 | aa_params = {"translate_const": int(img_size_min * 0.45)}
651 | if interpolation and interpolation != "random":
652 | aa_params["interpolation"] = _pil_interp(interpolation)
653 | if auto_augment.startswith("rand"):
654 | return transforms.Compose(
655 | [rand_augment_transform(auto_augment, aa_params)]
656 | )
657 | raise NotImplementedError
658 |
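# Illustrative usage (added sketch; the config string follows the docstring
# example above, and `clip` is assumed to be a (T, C, H, W) float tensor with
# values in [0, 1]):
#
#     aug = create_random_augment((224, 224), "rand-m7-n4-mstd0.5-inc1", "bicubic")
#     pil_frames = [transforms.ToPILImage()(frame) for frame in clip]
#     pil_frames = aug(pil_frames)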
659 |
660 | def random_sized_crop_img(
661 | im,
662 | size,
663 | jitter_scale=(0.08, 1.0),
664 | jitter_aspect=(3.0 / 4.0, 4.0 / 3.0),
665 | max_iter=10,
666 | ):
667 | """
668 | Performs Inception-style cropping (used for training).
669 | """
670 | assert (
671 | len(im.shape) == 3
672 | ), "Currently only support image for random_sized_crop"
673 | h, w = im.shape[1:3]
674 | i, j, h, w = _get_param_spatial_crop(
675 | scale=jitter_scale,
676 | ratio=jitter_aspect,
677 | height=h,
678 | width=w,
679 | num_repeat=max_iter,
680 | log_scale=False,
681 | switch_hw=True,
682 | )
683 | cropped = im[:, i : i + h, j : j + w]
684 | return torch.nn.functional.interpolate(
685 | cropped.unsqueeze(0),
686 | size=(size, size),
687 | mode="bilinear",
688 | align_corners=False,
689 | ).squeeze(0)
690 |
691 |
692 | # The following code is adapted from the timm library; we will replace the
693 | # following contents with a dependency on PyTorchVideo.
694 | # https://github.com/facebookresearch/pytorchvideo
695 | class RandomResizedCropAndInterpolation:
696 | """Crop the given PIL Image to random size and aspect ratio with random interpolation.
697 |     A crop of random size (default: 0.08 to 1.0 of the original area) and of random
698 |     aspect ratio (default: 3/4 to 4/3 of the original aspect ratio) is made. This crop
699 |     is finally resized to the given size.
700 | This is popularly used to train the Inception networks.
701 | Args:
702 | size: expected output size of each edge
703 | scale: range of size of the origin size cropped
704 | ratio: range of aspect ratio of the origin aspect ratio cropped
705 | interpolation: Default: PIL.Image.BILINEAR
706 | """
707 |
708 | def __init__(
709 | self,
710 | size,
711 | scale=(0.08, 1.0),
712 | ratio=(3.0 / 4.0, 4.0 / 3.0),
713 | interpolation="bilinear",
714 | ):
715 | if isinstance(size, tuple):
716 | self.size = size
717 | else:
718 | self.size = (size, size)
719 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
720 | print("range should be of kind (min, max)")
721 |
722 | if interpolation == "random":
723 | self.interpolation = _RANDOM_INTERPOLATION
724 | else:
725 | self.interpolation = _pil_interp(interpolation)
726 | self.scale = scale
727 | self.ratio = ratio
728 |
729 | @staticmethod
730 | def get_params(img, scale, ratio):
731 | """Get parameters for ``crop`` for a random sized crop.
732 | Args:
733 | img (PIL Image): Image to be cropped.
734 | scale (tuple): range of size of the origin size cropped
735 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
736 | Returns:
737 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random
738 | sized crop.
739 | """
740 | area = img.size[0] * img.size[1]
741 |
742 | for _ in range(10):
743 | target_area = random.uniform(*scale) * area
744 | log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
745 | aspect_ratio = math.exp(random.uniform(*log_ratio))
746 |
747 | w = int(round(math.sqrt(target_area * aspect_ratio)))
748 | h = int(round(math.sqrt(target_area / aspect_ratio)))
749 |
750 | if w <= img.size[0] and h <= img.size[1]:
751 | i = random.randint(0, img.size[1] - h)
752 | j = random.randint(0, img.size[0] - w)
753 | return i, j, h, w
754 |
755 | # Fallback to central crop
756 | in_ratio = img.size[0] / img.size[1]
757 | if in_ratio < min(ratio):
758 | w = img.size[0]
759 | h = int(round(w / min(ratio)))
760 | elif in_ratio > max(ratio):
761 | h = img.size[1]
762 | w = int(round(h * max(ratio)))
763 | else: # whole image
764 | w = img.size[0]
765 | h = img.size[1]
766 | i = (img.size[1] - h) // 2
767 | j = (img.size[0] - w) // 2
768 | return i, j, h, w
769 |
770 | def __call__(self, img):
771 | """
772 | Args:
773 | img (PIL Image): Image to be cropped and resized.
774 | Returns:
775 | PIL Image: Randomly cropped and resized image.
776 | """
777 | i, j, h, w = self.get_params(img, self.scale, self.ratio)
778 | if isinstance(self.interpolation, (tuple, list)):
779 | interpolation = random.choice(self.interpolation)
780 | else:
781 | interpolation = self.interpolation
782 | return F.resized_crop(img, i, j, h, w, self.size, interpolation)
783 |
784 | def __repr__(self):
785 | if isinstance(self.interpolation, (tuple, list)):
786 | interpolate_str = " ".join(
787 | [_pil_interpolation_to_str[x] for x in self.interpolation]
788 | )
789 | else:
790 | interpolate_str = _pil_interpolation_to_str[self.interpolation]
791 | format_string = self.__class__.__name__ + "(size={0}".format(self.size)
792 | format_string += ", scale={0}".format(
793 | tuple(round(s, 4) for s in self.scale)
794 | )
795 | format_string += ", ratio={0}".format(
796 | tuple(round(r, 4) for r in self.ratio)
797 | )
798 | format_string += ", interpolation={0})".format(interpolate_str)
799 | return format_string
800 |
--------------------------------------------------------------------------------
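
The transforms above can be exercised on dummy data as a quick sanity check. The following is an illustrative sketch rather than repository code; it assumes the repository root is on `PYTHONPATH` and that Pillow and torchvision are installed.

```
# Illustrative usage only (not part of the repository).
import torch
from PIL import Image

from video_dataset.transform import (
    random_sized_crop_img,
    RandomResizedCropAndInterpolation,
)

# Inception-style crop of a single (C, H, W) float tensor, resized to 224x224.
img_tensor = torch.rand(3, 256, 320)
print(random_sized_crop_img(img_tensor, size=224).shape)  # torch.Size([3, 224, 224])

# The PIL-based crop-and-resize class defined above.
pil_img = Image.new("RGB", (320, 256))
crop_op = RandomResizedCropAndInterpolation(224, interpolation="bilinear")
print(crop_op)                 # repr shows size, scale, ratio and interpolation
print(crop_op(pil_img).size)   # (224, 224)
```
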
/vision_transformer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from collections import OrderedDict
4 | import numpy as np
5 | from typing import Tuple
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | '''
12 | QuickGELU and LayerNorm w/ fp16 from official CLIP repo
13 | (https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py)
14 | '''
15 | class QuickGELU(nn.Module):
16 | def forward(self, x: torch.Tensor):
17 | return x * torch.sigmoid(1.702 * x)
18 |
19 | class LayerNorm(nn.LayerNorm):
20 | """Subclass torch's LayerNorm to handle fp16."""
21 |
22 | def forward(self, x: torch.Tensor):
23 | orig_type = x.dtype
24 | ret = super().forward(x.type(torch.float32))
25 | return ret.type(orig_type)
26 |
27 |
28 | class Attention(nn.Module):
29 | '''
30 |     A generalized multi-head attention module with configurable input and projection dimensions for queries, keys and values.
31 | '''
32 |
33 | def __init__(
34 | self, q_in_dim: int, k_in_dim: int, v_in_dim: int,
35 | qk_proj_dim: int, v_proj_dim: int, num_heads: int, out_dim: int,
36 | return_all_features: bool = False,
37 | ):
38 | super().__init__()
39 |
40 | self.q_proj = nn.Linear(q_in_dim, qk_proj_dim)
41 | self.k_proj = nn.Linear(k_in_dim, qk_proj_dim)
42 | self.v_proj = nn.Linear(v_in_dim, v_proj_dim)
43 | self.out_proj = nn.Linear(v_proj_dim, out_dim)
44 |
45 | self.num_heads = num_heads
46 | self.return_all_features = return_all_features
47 | assert qk_proj_dim % num_heads == 0 and v_proj_dim % num_heads == 0
48 |
49 | self._initialize_weights()
50 |
51 |
52 | def _initialize_weights(self):
53 | for m in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):
54 | nn.init.xavier_uniform_(m.weight)
55 | nn.init.constant_(m.bias, 0.)
56 |
57 |
58 | def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
59 | assert q.ndim == 3 and k.ndim == 3 and v.ndim == 3
60 | N = q.size(0); assert k.size(0) == N and v.size(0) == N
61 | Lq, Lkv = q.size(1), k.size(1); assert v.size(1) == Lkv
62 |
63 | q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
64 |
65 | H = self.num_heads
66 | Cqk, Cv = q.size(-1) // H, v.size(-1) // H
67 |
68 | q = q.view(N, Lq, H, Cqk)
69 | k = k.view(N, Lkv, H, Cqk)
70 | v = v.view(N, Lkv, H, Cv)
71 |
72 |         aff = torch.einsum('nqhc,nkhc->nqkh', q / (Cqk ** 0.5), k)  # scaled dot-product attention logits
73 |         aff = aff.softmax(dim=-2)  # softmax over the key dimension
74 |         mix = torch.einsum('nqlh,nlhc->nqhc', aff, v)  # attention-weighted sum of values
75 |
76 | out = self.out_proj(mix.flatten(-2))
77 |
78 | if self.return_all_features:
79 | return dict(q=q, k=k, v=v, aff=aff, out=out)
80 | else:
81 | return out
82 |
83 |
84 | class PatchEmbed2D(nn.Module):
85 |
86 | def __init__(
87 | self,
88 | patch_size: Tuple[int, int] = (16, 16),
89 | in_channels: int = 3,
90 | embed_dim: int = 768,
91 | ):
92 | super().__init__()
93 |
94 | self.patch_size = patch_size
95 | self.in_channels = in_channels
96 |
97 | self.proj = nn.Linear(np.prod(patch_size) * in_channels, embed_dim)
98 |
99 |
100 |     def _initialize_weights(self):  # note: never called in __init__; backbone weights are loaded from CLIP in practice
101 | nn.init.kaiming_normal_(self.proj.weight, 0.)
102 | nn.init.constant_(self.proj.bias, 0.)
103 |
104 |
105 | def forward(self, x: torch.Tensor):
106 | B, C, H, W = x.size()
107 | pH, pW = self.patch_size
108 |
109 | assert C == self.in_channels and H % pH == 0 and W % pW == 0
110 |
111 |         x = x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 1, 3, 5).flatten(3).flatten(1, 2)  # (B, C, H, W) -> (B, num_patches, C * pH * pW)
112 | x = self.proj(x)
113 |
114 | return x
115 |
116 | class TransformerEncoderLayer(nn.Module):
117 |
118 | def __init__(
119 | self,
120 | in_feature_dim: int = 768,
121 | qkv_dim: int = 768,
122 | num_heads: int = 12,
123 | mlp_factor: float = 4.0,
124 | mlp_dropout: float = 0.0,
125 | act: nn.Module = QuickGELU,
126 | return_all_features: bool = False,
127 | ):
128 | super().__init__()
129 |
130 | self.return_all_features = return_all_features
131 |
132 | self.attn = Attention(
133 | q_in_dim=in_feature_dim, k_in_dim=in_feature_dim, v_in_dim=in_feature_dim,
134 | qk_proj_dim=qkv_dim, v_proj_dim=qkv_dim, num_heads=num_heads, out_dim=in_feature_dim,
135 | return_all_features=return_all_features,
136 | )
137 |
138 | mlp_dim = round(mlp_factor * in_feature_dim)
139 | self.mlp = nn.Sequential(OrderedDict([
140 | ('fc1', nn.Linear(in_feature_dim, mlp_dim)),
141 | ('act', act()),
142 | ('dropout', nn.Dropout(mlp_dropout)),
143 | ('fc2', nn.Linear(mlp_dim, in_feature_dim)),
144 | ]))
145 |
146 | self.norm1 = LayerNorm(in_feature_dim)
147 | self.norm2 = LayerNorm(in_feature_dim)
148 |
149 | self._initialize_weights()
150 |
151 |
152 | def _initialize_weights(self):
153 | for m in (self.mlp[0], self.mlp[-1]):
154 | nn.init.xavier_uniform_(m.weight)
155 | nn.init.normal_(m.bias, std=1e-6)
156 |
157 |
158 | def forward(self, x: torch.Tensor):
159 | if self.return_all_features:
160 | ret_dict = {}
161 |
162 | x_norm = self.norm1(x)
163 | attn_out = self.attn(x_norm, x_norm, x_norm)
164 | ret_dict['q'] = attn_out['q']
165 | ret_dict['k'] = attn_out['k']
166 | ret_dict['v'] = attn_out['v']
167 | ret_dict['attn_out'] = attn_out['out']
168 | x = x + attn_out['out']
169 |
170 | x = x + self.mlp(self.norm2(x))
171 | ret_dict['out'] = x
172 |
173 | return ret_dict
174 |
175 | else:
176 | x_norm = self.norm1(x)
177 | x = x + self.attn(x_norm, x_norm, x_norm)
178 | x = x + self.mlp(self.norm2(x))
179 |
180 | return x
181 |
182 |
183 | class TransformerDecoderLayer(nn.Module):
184 |
185 | def __init__(
186 | self,
187 | in_feature_dim: int = 768,
188 | qkv_dim: int = 768,
189 | num_heads: int = 12,
190 | mlp_factor: float = 4.0,
191 | mlp_dropout: float = 0.0,
192 | act: nn.Module = QuickGELU,
193 | ):
194 | super().__init__()
195 |
196 | self.attn = Attention(
197 | q_in_dim=in_feature_dim, k_in_dim=in_feature_dim, v_in_dim=in_feature_dim,
198 | qk_proj_dim=qkv_dim, v_proj_dim=qkv_dim, num_heads=num_heads, out_dim=in_feature_dim,
199 | )
200 |
201 | mlp_dim = round(mlp_factor * in_feature_dim)
202 | self.mlp = nn.Sequential(OrderedDict([
203 | ('fc1', nn.Linear(in_feature_dim, mlp_dim)),
204 | ('act', act()),
205 | ('dropout', nn.Dropout(mlp_dropout)),
206 | ('fc2', nn.Linear(mlp_dim, in_feature_dim)),
207 | ]))
208 |
209 | self.norm1 = LayerNorm(in_feature_dim)
210 | self.norm2 = LayerNorm(in_feature_dim)
211 | self.norm3 = LayerNorm(in_feature_dim)
212 |
213 | self._initialize_weights()
214 |
215 |
216 | def _initialize_weights(self):
217 | for m in (self.mlp[0], self.mlp[-1]):
218 | nn.init.xavier_uniform_(m.weight)
219 | nn.init.normal_(m.bias, std=1e-6)
220 |
221 |
222 | def forward(self, x: torch.Tensor, y: torch.Tensor):
223 | y_norm = self.norm3(y)
224 | x = x + self.attn(self.norm1(x), y_norm, y_norm)
225 | x = x + self.mlp(self.norm2(x))
226 |
227 | return x
228 |
229 |
230 | class VisionTransformer2D(nn.Module):
231 |
232 | def __init__(
233 | self,
234 | feature_dim: int = 768,
235 | input_size: Tuple[int, int] = (224, 224),
236 | patch_size: Tuple[int, int] = (16, 16),
237 | num_heads: int = 12,
238 | num_layers: int = 12,
239 | mlp_factor: float = 4.0,
240 | act: nn.Module = QuickGELU,
241 | return_all_features: bool = False,
242 | ln_pre: bool = False,
243 | ):
244 | super().__init__()
245 |
246 | self.return_all_features = return_all_features
247 |
248 | self.patch_embed = PatchEmbed2D(patch_size=patch_size, embed_dim=feature_dim)
249 | self.num_patches = np.prod([x // y for x, y in zip(input_size, patch_size)]) + 1
250 |
251 | self.cls_token = nn.Parameter(torch.zeros([feature_dim]))
252 | self.pos_embed = nn.Parameter(torch.zeros([self.num_patches, feature_dim]))
253 |
254 | self.blocks = nn.ModuleList([
255 | TransformerEncoderLayer(
256 | in_feature_dim=feature_dim, qkv_dim=feature_dim, num_heads=num_heads, mlp_factor=mlp_factor, act=act,
257 | return_all_features=return_all_features,
258 | ) for _ in range(num_layers)
259 | ])
260 |
261 | if ln_pre:
262 | self.ln_pre = LayerNorm(feature_dim)
263 | else:
264 | self.ln_pre = nn.Identity()
265 |
266 | self._initialize_weights()
267 |
268 |
269 | def _initialize_weights(self):
270 | nn.init.normal_(self.cls_token, std=0.02)
271 | nn.init.normal_(self.pos_embed, std=0.02)
272 |
273 | def forward(self, x: torch.Tensor):
274 | dtype = self.patch_embed.proj.weight.dtype
275 | x = x.to(dtype)
276 |
277 | x = self.patch_embed(x)
278 | x = torch.cat([self.cls_token.view(1, 1, -1).repeat(x.size(0), 1, 1), x], dim=1)
279 | x = x + self.pos_embed
280 |
281 | x = self.ln_pre(x)
282 |
283 | if self.return_all_features:
284 | all_features = []
285 | for blk in self.blocks:
286 | x = blk(x)
287 | all_features.append(x)
288 | x = x['out']
289 | return all_features
290 |
291 | else:
292 | for blk in self.blocks:
293 | x = blk(x)
294 | return x
295 |
296 |
297 | def model_to_fp16(model: VisionTransformer2D):
298 | def _module_to_fp16(m: nn.Module):
299 |         if isinstance(m, (nn.Linear,)):  # only Linear layers are converted; LayerNorm keeps fp32 params
300 | m.half()
301 | model.apply(_module_to_fp16)
302 |
303 | model.pos_embed.data = model.pos_embed.data.half()
304 | model.cls_token.data = model.cls_token.data.half()
305 |
306 |
307 | vit_presets = {
308 | 'ViT-B/16-lnpre': dict(
309 | feature_dim=768,
310 | input_size=(224, 224),
311 | patch_size=(16, 16),
312 | num_heads=12,
313 | num_layers=12,
314 | mlp_factor=4.0,
315 | ln_pre=True,
316 | ),
317 | 'ViT-L/14-lnpre': dict(
318 | feature_dim=1024,
319 | input_size=(224, 224),
320 | patch_size=(14, 14),
321 | num_heads=16,
322 | num_layers=24,
323 | mlp_factor=4.0,
324 | ln_pre=True,
325 | ),
326 | }
--------------------------------------------------------------------------------
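
For a quick shape check, the backbone above can be instantiated directly from `vit_presets` and probed with random input. The following is an illustrative sketch rather than repository code; it assumes the repository root is on `PYTHONPATH`.

```
# Illustrative usage only (not part of the repository).
import torch
from vision_transformer import VisionTransformer2D, vit_presets

model = VisionTransformer2D(return_all_features=True, **vit_presets['ViT-B/16-lnpre'])
model.eval()

# A single 224x224 RGB frame; 16x16 patches give 14 * 14 + 1 = 197 tokens.
frames = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    features = model(frames)  # one dict per encoder layer when return_all_features=True

print(len(features))              # 12
print(features[-1]['out'].shape)  # torch.Size([1, 197, 768])
```

With `return_all_features=True`, each layer returns its queries, keys, values and outputs, which is how intermediate backbone features can be exposed to a decoder.
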
/weight_loaders.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os, sys
4 | from typing import Dict
5 |
6 | import torch
7 |
8 | __all__ = ['weight_loader_fn_dict']
9 |
10 | def load_weights_clip(load_path: str) -> Dict[str, torch.Tensor]:
11 | clip_model = torch.jit.load(load_path, map_location='cpu')
12 | clip_model = clip_model.visual
13 | src_state_dict = clip_model.state_dict()
14 | src_state_dict = dict((k, v.float()) for k, v in src_state_dict.items())
15 |
16 | dst_state_dict = {}
17 |
18 | dst_state_dict['cls_token'] = src_state_dict['class_embedding']
19 | dst_state_dict['pos_embed'] = src_state_dict['positional_embedding']
20 |     dst_state_dict['patch_embed.proj.weight'] = src_state_dict['conv1.weight'].flatten(1)  # flatten CLIP's conv patch embedding into a linear weight
21 |     dst_state_dict['patch_embed.proj.bias'] = torch.zeros([src_state_dict['conv1.weight'].size(0)])  # CLIP's patch embedding conv has no bias
22 |
23 | dst_state_dict['ln_pre.weight'] = src_state_dict['ln_pre.weight']
24 | dst_state_dict['ln_pre.bias'] = src_state_dict['ln_pre.bias']
25 |
26 | block_idx = 0
27 | while True:
28 | src_prefix = 'transformer.resblocks.%d.' % block_idx
29 | dst_prefix = 'blocks.%d.' % block_idx
30 |
31 | src_block_state_dict = dict((k[len(src_prefix):], v) for k, v in src_state_dict.items() if k.startswith(src_prefix))
32 | if len(src_block_state_dict) == 0:
33 | break
34 |
35 | dst_block_state_dict = {}
36 | feat_dim = src_block_state_dict['ln_1.weight'].size(0)
37 |
38 |         for i, dst_name in enumerate(('q', 'k', 'v')):  # CLIP packs q, k, v into a single in_proj; split it into thirds
39 | dst_block_state_dict['attn.%s_proj.weight' % dst_name] = src_block_state_dict['attn.in_proj_weight'][feat_dim * i: feat_dim * (i + 1)]
40 | dst_block_state_dict['attn.%s_proj.bias' % dst_name] = src_block_state_dict['attn.in_proj_bias'][feat_dim * i: feat_dim * (i + 1)]
41 |
42 | dst_block_state_dict['attn.out_proj.weight'] = src_block_state_dict['attn.out_proj.weight']
43 | dst_block_state_dict['attn.out_proj.bias'] = src_block_state_dict['attn.out_proj.bias']
44 |
45 | dst_block_state_dict['mlp.fc1.weight'] = src_block_state_dict['mlp.c_fc.weight']
46 | dst_block_state_dict['mlp.fc1.bias'] = src_block_state_dict['mlp.c_fc.bias']
47 | dst_block_state_dict['mlp.fc2.weight'] = src_block_state_dict['mlp.c_proj.weight']
48 | dst_block_state_dict['mlp.fc2.bias'] = src_block_state_dict['mlp.c_proj.bias']
49 |
50 | dst_block_state_dict['norm1.weight'] = src_block_state_dict['ln_1.weight']
51 | dst_block_state_dict['norm1.bias'] = src_block_state_dict['ln_1.bias']
52 | dst_block_state_dict['norm2.weight'] = src_block_state_dict['ln_2.weight']
53 | dst_block_state_dict['norm2.bias'] = src_block_state_dict['ln_2.bias']
54 |
55 | dst_state_dict.update(dict((dst_prefix + k, v) for k, v in dst_block_state_dict.items()))
56 | block_idx += 1
57 |
58 | return dst_state_dict
59 |
60 |
61 | weight_loader_fn_dict = {
62 | 'clip': load_weights_clip,
63 | }
64 |
--------------------------------------------------------------------------------
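
To see the key remapping in action, the loader can be combined with the backbone from `vision_transformer.py`. The following is an illustrative sketch rather than repository code; the checkpoint path is a placeholder for a CLIP weight file obtained as described in the README.

```
# Illustrative usage only (not part of the repository). The checkpoint path is
# a placeholder for a CLIP ViT-B/16 weight file (see "Backbone Preparation").
from vision_transformer import VisionTransformer2D, vit_presets
from weight_loaders import weight_loader_fn_dict

backbone_path = '/path/to/ViT-B-16.pt'  # placeholder

state_dict = weight_loader_fn_dict['clip'](backbone_path)
model = VisionTransformer2D(return_all_features=True, **vit_presets['ViT-B/16-lnpre'])

# strict=False keeps the sketch tolerant of minor key-set differences;
# the remapped checkpoint is expected to cover all backbone parameters.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print('missing:', missing, 'unexpected:', unexpected)
```
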