├── .gitignore ├── README.md ├── checkpoint.py ├── data └── k400_class_mappings.json ├── figs ├── arch.png └── k400.png ├── main.py ├── model.py ├── scripts ├── eval_k400_vitb16_16f_dec4x768.sh ├── eval_k400_vitb16_32f_dec4x768.sh ├── eval_k400_vitb16_8f_dec4x768.sh ├── eval_k400_vitl14_16f_dec4x1024.sh ├── eval_k400_vitl14_32f_dec4x1024.sh ├── eval_k400_vitl14_8f_dec4x1024.sh ├── train_k400_vitb16_16f_dec4x768.sh ├── train_k400_vitb16_32f_dec4x768.sh ├── train_k400_vitb16_8f_dec4x768.sh ├── train_k400_vitl14_16f_dec4x1024.sh ├── train_k400_vitl14_32f_dec4x1024.sh └── train_k400_vitl14_8f_dec4x1024.sh ├── video_dataset ├── __init__.py ├── dataloader.py ├── dataset.py ├── rand_augment.py ├── random_erasing.py └── transform.py ├── vision_transformer.py └── weight_loaders.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.py[cod] 3 | *$py.class 4 | 5 | runs/ 6 | 7 | video_dataset/io_internal.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Frozen CLIP models are Efficient Video Learners 2 | 3 | This is the official implementation of the paper [Frozen CLIP models are Efficient Video Learners](https://arxiv.org/abs/2208.03550). 4 | 5 | ``` 6 | @article{lin2022frozen, 7 | title={Frozen CLIP Models are Efficient Video Learners}, 8 | author={Lin, Ziyi and Geng, Shijie and Zhang, Renrui and Gao, Peng and de Melo, Gerard and Wang, Xiaogang and Dai, Jifeng and Qiao, Yu and Li, Hongsheng}, 9 | journal={arXiv preprint arXiv:2208.03550}, 10 | year={2022} 11 | } 12 | ``` 13 | 14 | ## Introduction 15 | 16 | The overall architecture of the EVL framework includes a trainable Transformer decoder, trainable local temporal modules, and a pretrained, frozen image backbone 17 | (e.g., CLIP).
18 | 19 | *[Figure: overall architecture of the EVL framework (figs/arch.png)]* 20 | 21 | Using a frozen backbone significantly reduces training time: we trained a ViT-B/16 with 8 frames for 50 epochs in 60 GPU-hours (NVIDIA V100). 22 | 23 | Despite the small training compute and memory footprint, EVL models achieve high performance on Kinetics-400. A comparison with state-of-the-art methods 24 | is shown below: 25 | 26 | *[Figure: comparison with state-of-the-art methods on Kinetics-400 (figs/k400.png)]* 27 | 28 | ## Installation 29 | 30 | We tested the released code with the following conda environment: 31 | 32 | ``` 33 | conda create -n pt1.9.0cu11.1_official -c pytorch -c conda-forge pytorch=1.9.0=py3.9_cuda11.1_cudnn8.0.5_0 cudatoolkit torchvision av 34 | ``` 35 | 36 | ## Data Preparation 37 | 38 | We expect the `--train_list_path` and `--val_list_path` command line arguments to each point to a data list file in the following format: 39 | ``` 40 | 41 | 42 | ... 43 | 44 | ``` 45 | where `` points to a video file, and `` is an integer between `0` and `num_classes - 1`. 46 | `--num_classes` should also be set accordingly on the command line. 47 | 48 | Additionally, `` may be a relative path when `--data_root` is specified, in which case it is resolved 49 | relative to the directory passed as `--data_root`. 50 | 51 | The class mapping used by the released weights is provided in [Kinetics-400 class mappings](data/k400_class_mappings.json). 52 | 53 | ## Backbone Preparation 54 | 55 | CLIP weights need to be downloaded from the [CLIP official repo](https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/clip.py#L30) 56 | and passed to the `--backbone_path` command line argument. 57 | 58 | ## Script Usage 59 | 60 | Training and evaluation scripts are provided in the `scripts` folder. 61 | Scripts should be ready to run once the environment is set up and 62 | `--backbone_path`, `--train_list_path` and `--val_list_path` are replaced with your own paths. 63 | 64 | For other command line arguments, see the help message (`--help`).
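To make the data list format above concrete, here is a minimal sketch (not part of this repository) that writes and reads data list lines. It assumes that each line holds a video path and its integer label separated by a single space, that a class's label is its index in `data/k400_class_mappings.json`, and that relative paths are joined onto `--data_root`; the helper names `load_class_index`, `build_list_lines` and `parse_list_lines` are hypothetical.

```python
import json
import os
from typing import Dict, Iterable, List, Optional, Tuple


def load_class_index(mapping_json: str) -> Dict[str, int]:
    """Map each class name to its position in the JSON list (assumed to be its label)."""
    class_names: List[str] = json.loads(mapping_json)
    return {name: idx for idx, name in enumerate(class_names)}


def build_list_lines(samples: Iterable[Tuple[str, str]],
                     class_index: Dict[str, int]) -> List[str]:
    """Turn (video_path, class_name) pairs into assumed 'path label' lines."""
    lines = []
    for path, class_name in samples:
        # A single space is assumed as the separator, so reject paths containing whitespace.
        if any(ch.isspace() for ch in path):
            raise ValueError(f'path contains whitespace: {path!r}')
        lines.append(f'{path} {class_index[class_name]}')
    return lines


def parse_list_lines(lines: Iterable[str], num_classes: int,
                     data_root: Optional[str] = None) -> List[Tuple[str, int]]:
    """Parse list lines, check 0 <= label <= num_classes - 1, and resolve paths."""
    samples = []
    for line in lines:
        if not line.strip():
            continue  # skip blank lines
        path, label_str = line.split()  # raises ValueError unless exactly two fields
        label = int(label_str)
        if not 0 <= label < num_classes:
            raise ValueError(f'label {label} out of range for {num_classes} classes')
        if data_root is not None:
            # os.path.join keeps absolute paths unchanged and prefixes relative ones.
            path = os.path.join(data_root, path)
        samples.append((path, label))
    return samples


if __name__ == '__main__':
    # Stand-in for data/k400_class_mappings.json (its first three entries only).
    mapping = json.dumps(['abseiling', 'air drumming', 'answering questions'])
    index = load_class_index(mapping)
    lines = build_list_lines([('train/a.mp4', 'air drumming'),
                              ('train/b.mp4', 'abseiling')], index)
    print('\n'.join(lines))
    print(parse_list_lines(lines, num_classes=len(index), data_root='/data/k400'))
```

With the real mapping file, `load_class_index` returns 400 entries, which matches the default `--num_classes` of 400 in `main.py`. Rejecting whitespace in paths keeps the assumed single-space format unambiguous when the file is parsed back.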
65 | 66 | ## Kinetics-400 Main Results 67 | 68 | This is a re-implementation for open-source use. 69 | We are still re-running some models, and their scripts, weights and logs will be released later. 70 | In the following table we report the re-run accuracy, which may be slightly different from the original paper (typically +/-0.1%) 71 | 72 | | Backbone | Decoder Layers | #frames x stride | top-1 | top-5 | Script | Model | Log | 73 | | - | - | - | - | - | - | - | - | 74 | | ViT-B/16 | 4 | 8 x 16 | 82.8 | 95.8 | [script](scripts/train_k400_vitb16_8f_dec4x768.sh) | [google drive](https://drive.google.com/file/d/1DoGjvDdkJoSa9i-wq1lh6QoEZIa4xTB3/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1-9vgsXMpnWBI9MxQV7SSQhkPfLomoYY3/view?usp=sharing) | 75 | | ViT-B/16 | 4 | 16 x 16 | 83.7 | 96.2 | [script](scripts/train_k400_vitb16_16f_dec4x768.sh) | [google drive](https://drive.google.com/file/d/1dax4qUIOEI_QzYXv31J-87cDkonQetVQ/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1l2ivY28jUpwSmafQZvwtUo7tvm42i0PL/view?usp=sharing) | 76 | | ViT-B/16 | 4 | 32 x 8 | 84.3 | 96.6 | [script](scripts/train_k400_vitb16_32f_dec4x768.sh) | [google drive](https://drive.google.com/file/d/1fzFM5pD39Kfp8xRAJuWaXR9RALLmnoeU/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1X1ZOdSCxXVeMpNhr_bviNKlRfJa5SMD7/view?usp=sharing) | 77 | | ViT-L/14 | 4 | 8 x 16 | 86.3 | 97.2 | [script](scripts/train_k400_vitl14_8f_dec4x1024.sh) | [google drive](https://drive.google.com/file/d/1AkdF4CkOVW2uiycCVqCxS397oYxNISAI/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1OJFBmaE_tAwTzG-4i0CLQmhwGnN0psx1/view?usp=sharing) | 78 | | ViT-L/14 | 4 | 16 x 16 | 86.9 | 97.4 | [script](scripts/train_k400_vitl14_16f_dec4x1024.sh) | [google drive](https://drive.google.com/file/d/1CTV9geLD3HLWzByAQUOf_m0F_g2lE3rg/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1a2iC4tQvjWFMI3UrEv2chuHwVrF6p9YF/view?usp=sharing) | 79 | | ViT-L/14 | 4 
| 32 x 8 | 87.7 | 97.6 | [script](scripts/train_k400_vitl14_32f_dec4x1024.sh) | [google drive](https://drive.google.com/file/d/1zNFNCKwP5owakELlnTCD20cpVQBqgJrB/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1dK7qoz3McYrmfS09FfreXC-LjUM7l0u4/view?usp=sharing) | 80 | | ViT-L/14 (336px) | 4 | 32 x 8 | 87.7 | 97.8 | | | | 81 | 82 | ## Data Loading Speed 83 | 84 | As the training process is fast, video frames are consumed at a very high rate. 85 | For easier installation, the current version uses PyTorch-builtin data loaders. 86 | They are not very efficient and can become a bottleneck when using ViT-B as the backbone. 87 | We provide a `--dummy_dataset` option to bypass actual video decoding for training speed measurement. 88 | Model accuracy should not be affected by the choice of data loader. 89 | Our internal data loader is written purely in C++ and does not significantly bottleneck training on a machine with 2x Xeon Gold 6148 CPUs and 4x V100 GPUs. 90 | 91 | 92 | ## Acknowledgements 93 | 94 | The data loader code is modified from [PySlowFast](https://github.com/facebookresearch/SlowFast). Thanks for their awesome work! 95 | -------------------------------------------------------------------------------- /checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import os 5 | 6 | import torch 7 | import torch.distributed as dist 8 | 9 | 10 | def setup_arg_parser(parser: argparse.ArgumentParser): 11 | parser.add_argument('--checkpoint_dir', type=str, 12 | help='checkpoint output path') 13 | parser.add_argument('--auto_resume', action='store_true', 14 | help='auto resume from the latest checkpoint in checkpoint_dir') 15 | parser.add_argument('--resume_path', type=str, 16 | help='resume from manually specified checkpoint file, overriding auto_resume') 17 | parser.add_argument('--pretrain', type=str, 18 | help='path to pretrained weights.
Will NOT override auto_resume or resume_path, ' 19 | 'load optimizer state, or enforce strict matching of checkpoint and model weights.') 20 | 21 | 22 | def _find_autoresume_path(args: argparse.Namespace): 23 | print('Trying to auto resume from path:', args.checkpoint_dir) 24 | 25 | if args.checkpoint_dir is not None and os.path.isdir(args.checkpoint_dir): 26 | checkpoint_files = [x for x in os.listdir(args.checkpoint_dir) if x.startswith('checkpoint-') and x.endswith('.pth')] 27 | checkpoint_iters = [] 28 | for x in checkpoint_files: 29 | try: 30 | x = x[len('checkpoint-'): -len('.pth')] 31 | x = int(x) 32 | except ValueError: 33 | continue 34 | checkpoint_iters.append(x) 35 | else: 36 | checkpoint_iters = [] 37 | 38 | if len(checkpoint_iters) == 0: 39 | print('Did not find a valid checkpoint file.') 40 | else: 41 | checkpoint_iters.sort() 42 | args.resume_path = os.path.join(args.checkpoint_dir, 'checkpoint-%d.pth' % checkpoint_iters[-1]) 43 | print(f'Found {len(checkpoint_iters)} checkpoint file(s).') 44 | 45 | 46 | def resume_from_checkpoint( 47 | model: torch.nn.Module, 48 | optimizer: torch.optim.Optimizer, 49 | lr_sched: torch.optim.lr_scheduler._LRScheduler, 50 | loss_scaler: torch.cuda.amp.grad_scaler.GradScaler, 51 | args: argparse.Namespace, 52 | ) -> int: 53 | if args.pretrain is not None: 54 | print(f'Loading pretrain model: {args.pretrain}') 55 | ckpt = torch.load(args.pretrain, map_location='cpu') 56 | print(model.load_state_dict(ckpt['model'], strict=False)) 57 | 58 | # returns resume_step on successful resume, or 0 otherwise.
59 | if args.auto_resume and args.resume_path is None: 60 | _find_autoresume_path(args) 61 | 62 | if args.resume_path is None: 63 | print('Not resuming from a checkpoint.') 64 | return 0 65 | else: 66 | print(f'Resuming from checkpoint file {args.resume_path}') 67 | ckpt = torch.load(args.resume_path, map_location='cpu') 68 | model.load_state_dict(ckpt['model'], strict=True) 69 | if 'optimizer' in ckpt: 70 | optimizer.load_state_dict(ckpt['optimizer']) 71 | lr_sched.load_state_dict(ckpt['lr_sched']) 72 | loss_scaler.load_state_dict(ckpt['loss_scaler']) 73 | return ckpt['next_step'] 74 | else: 75 | print('Optimizer state is NOT found in checkpoint.') 76 | return 0 77 | 78 | 79 | def save_checkpoint( 80 | model: torch.nn.Module, 81 | optimizer: torch.optim.Optimizer, 82 | lr_sched: torch.optim.lr_scheduler._LRScheduler, 83 | loss_scaler: torch.cuda.amp.grad_scaler.GradScaler, 84 | next_step: int, 85 | args: argparse.Namespace, 86 | ): 87 | if args.checkpoint_dir is None: 88 | return 89 | 90 | if not os.path.isdir(args.checkpoint_dir): 91 | os.makedirs(args.checkpoint_dir) 92 | 93 | to_save = { 94 | 'model': model.state_dict(), 95 | 'optimizer': optimizer.state_dict(), 96 | 'lr_sched': lr_sched.state_dict(), 97 | 'loss_scaler': loss_scaler.state_dict(), 98 | 'next_step': next_step, 99 | } 100 | torch.save(to_save, os.path.join(args.checkpoint_dir, f'checkpoint-{next_step}.pth')) 101 | -------------------------------------------------------------------------------- /data/k400_class_mappings.json: -------------------------------------------------------------------------------- 1 | [ 2 | "abseiling", 3 | "air drumming", 4 | "answering questions", 5 | "applauding", 6 | "applying cream", 7 | "archery", 8 | "arm wrestling", 9 | "arranging flowers", 10 | "assembling computer", 11 | "auctioning", 12 | "baby waking up", 13 | "baking cookies", 14 | "balloon blowing", 15 | "bandaging", 16 | "barbequing", 17 | "bartending", 18 | "beatboxing", 19 | "bee keeping", 20 | "belly 
dancing", 21 | "bench pressing", 22 | "bending back", 23 | "bending metal", 24 | "biking through snow", 25 | "blasting sand", 26 | "blowing glass", 27 | "blowing leaves", 28 | "blowing nose", 29 | "blowing out candles", 30 | "bobsledding", 31 | "bookbinding", 32 | "bouncing on trampoline", 33 | "bowling", 34 | "braiding hair", 35 | "breading or breadcrumbing", 36 | "breakdancing", 37 | "brush painting", 38 | "brushing hair", 39 | "brushing teeth", 40 | "building cabinet", 41 | "building shed", 42 | "bungee jumping", 43 | "busking", 44 | "canoeing or kayaking", 45 | "capoeira", 46 | "carrying baby", 47 | "cartwheeling", 48 | "carving pumpkin", 49 | "catching fish", 50 | "catching or throwing baseball", 51 | "catching or throwing frisbee", 52 | "catching or throwing softball", 53 | "celebrating", 54 | "changing oil", 55 | "changing wheel", 56 | "checking tires", 57 | "cheerleading", 58 | "chopping wood", 59 | "clapping", 60 | "clay pottery making", 61 | "clean and jerk", 62 | "cleaning floor", 63 | "cleaning gutters", 64 | "cleaning pool", 65 | "cleaning shoes", 66 | "cleaning toilet", 67 | "cleaning windows", 68 | "climbing a rope", 69 | "climbing ladder", 70 | "climbing tree", 71 | "contact juggling", 72 | "cooking chicken", 73 | "cooking egg", 74 | "cooking on campfire", 75 | "cooking sausages", 76 | "counting money", 77 | "country line dancing", 78 | "cracking neck", 79 | "crawling baby", 80 | "crossing river", 81 | "crying", 82 | "curling hair", 83 | "cutting nails", 84 | "cutting pineapple", 85 | "cutting watermelon", 86 | "dancing ballet", 87 | "dancing charleston", 88 | "dancing gangnam style", 89 | "dancing macarena", 90 | "deadlifting", 91 | "decorating the christmas tree", 92 | "digging", 93 | "dining", 94 | "disc golfing", 95 | "diving cliff", 96 | "dodgeball", 97 | "doing aerobics", 98 | "doing laundry", 99 | "doing nails", 100 | "drawing", 101 | "dribbling basketball", 102 | "drinking", 103 | "drinking beer", 104 | "drinking shots", 105 | "driving car", 
106 | "driving tractor", 107 | "drop kicking", 108 | "drumming fingers", 109 | "dunking basketball", 110 | "dying hair", 111 | "eating burger", 112 | "eating cake", 113 | "eating carrots", 114 | "eating chips", 115 | "eating doughnuts", 116 | "eating hotdog", 117 | "eating ice cream", 118 | "eating spaghetti", 119 | "eating watermelon", 120 | "egg hunting", 121 | "exercising arm", 122 | "exercising with an exercise ball", 123 | "extinguishing fire", 124 | "faceplanting", 125 | "feeding birds", 126 | "feeding fish", 127 | "feeding goats", 128 | "filling eyebrows", 129 | "finger snapping", 130 | "fixing hair", 131 | "flipping pancake", 132 | "flying kite", 133 | "folding clothes", 134 | "folding napkins", 135 | "folding paper", 136 | "front raises", 137 | "frying vegetables", 138 | "garbage collecting", 139 | "gargling", 140 | "getting a haircut", 141 | "getting a tattoo", 142 | "giving or receiving award", 143 | "golf chipping", 144 | "golf driving", 145 | "golf putting", 146 | "grinding meat", 147 | "grooming dog", 148 | "grooming horse", 149 | "gymnastics tumbling", 150 | "hammer throw", 151 | "headbanging", 152 | "headbutting", 153 | "high jump", 154 | "high kick", 155 | "hitting baseball", 156 | "hockey stop", 157 | "holding snake", 158 | "hopscotch", 159 | "hoverboarding", 160 | "hugging", 161 | "hula hooping", 162 | "hurdling", 163 | "hurling (sport)", 164 | "ice climbing", 165 | "ice fishing", 166 | "ice skating", 167 | "ironing", 168 | "javelin throw", 169 | "jetskiing", 170 | "jogging", 171 | "juggling balls", 172 | "juggling fire", 173 | "juggling soccer ball", 174 | "jumping into pool", 175 | "jumpstyle dancing", 176 | "kicking field goal", 177 | "kicking soccer ball", 178 | "kissing", 179 | "kitesurfing", 180 | "knitting", 181 | "krumping", 182 | "laughing", 183 | "laying bricks", 184 | "long jump", 185 | "lunge", 186 | "making a cake", 187 | "making a sandwich", 188 | "making bed", 189 | "making jewelry", 190 | "making pizza", 191 | "making snowman", 
192 | "making sushi", 193 | "making tea", 194 | "marching", 195 | "massaging back", 196 | "massaging feet", 197 | "massaging legs", 198 | "massaging person's head", 199 | "milking cow", 200 | "mopping floor", 201 | "motorcycling", 202 | "moving furniture", 203 | "mowing lawn", 204 | "news anchoring", 205 | "opening bottle", 206 | "opening present", 207 | "paragliding", 208 | "parasailing", 209 | "parkour", 210 | "passing American football (in game)", 211 | "passing American football (not in game)", 212 | "peeling apples", 213 | "peeling potatoes", 214 | "petting animal (not cat)", 215 | "petting cat", 216 | "picking fruit", 217 | "planting trees", 218 | "plastering", 219 | "playing accordion", 220 | "playing badminton", 221 | "playing bagpipes", 222 | "playing basketball", 223 | "playing bass guitar", 224 | "playing cards", 225 | "playing cello", 226 | "playing chess", 227 | "playing clarinet", 228 | "playing controller", 229 | "playing cricket", 230 | "playing cymbals", 231 | "playing didgeridoo", 232 | "playing drums", 233 | "playing flute", 234 | "playing guitar", 235 | "playing harmonica", 236 | "playing harp", 237 | "playing ice hockey", 238 | "playing keyboard", 239 | "playing kickball", 240 | "playing monopoly", 241 | "playing organ", 242 | "playing paintball", 243 | "playing piano", 244 | "playing poker", 245 | "playing recorder", 246 | "playing saxophone", 247 | "playing squash or racquetball", 248 | "playing tennis", 249 | "playing trombone", 250 | "playing trumpet", 251 | "playing ukulele", 252 | "playing violin", 253 | "playing volleyball", 254 | "playing xylophone", 255 | "pole vault", 256 | "presenting weather forecast", 257 | "pull ups", 258 | "pumping fist", 259 | "pumping gas", 260 | "punching bag", 261 | "punching person (boxing)", 262 | "push up", 263 | "pushing car", 264 | "pushing cart", 265 | "pushing wheelchair", 266 | "reading book", 267 | "reading newspaper", 268 | "recording music", 269 | "riding a bike", 270 | "riding camel", 271 | 
"riding elephant", 272 | "riding mechanical bull", 273 | "riding mountain bike", 274 | "riding mule", 275 | "riding or walking with horse", 276 | "riding scooter", 277 | "riding unicycle", 278 | "ripping paper", 279 | "robot dancing", 280 | "rock climbing", 281 | "rock scissors paper", 282 | "roller skating", 283 | "running on treadmill", 284 | "sailing", 285 | "salsa dancing", 286 | "sanding floor", 287 | "scrambling eggs", 288 | "scuba diving", 289 | "setting table", 290 | "shaking hands", 291 | "shaking head", 292 | "sharpening knives", 293 | "sharpening pencil", 294 | "shaving head", 295 | "shaving legs", 296 | "shearing sheep", 297 | "shining shoes", 298 | "shooting basketball", 299 | "shooting goal (soccer)", 300 | "shot put", 301 | "shoveling snow", 302 | "shredding paper", 303 | "shuffling cards", 304 | "side kick", 305 | "sign language interpreting", 306 | "singing", 307 | "situp", 308 | "skateboarding", 309 | "ski jumping", 310 | "skiing (not slalom or crosscountry)", 311 | "skiing crosscountry", 312 | "skiing slalom", 313 | "skipping rope", 314 | "skydiving", 315 | "slacklining", 316 | "slapping", 317 | "sled dog racing", 318 | "smoking", 319 | "smoking hookah", 320 | "snatch weight lifting", 321 | "sneezing", 322 | "sniffing", 323 | "snorkeling", 324 | "snowboarding", 325 | "snowkiting", 326 | "snowmobiling", 327 | "somersaulting", 328 | "spinning poi", 329 | "spray painting", 330 | "spraying", 331 | "springboard diving", 332 | "squat", 333 | "sticking tongue out", 334 | "stomping grapes", 335 | "stretching arm", 336 | "stretching leg", 337 | "strumming guitar", 338 | "surfing crowd", 339 | "surfing water", 340 | "sweeping floor", 341 | "swimming backstroke", 342 | "swimming breast stroke", 343 | "swimming butterfly stroke", 344 | "swing dancing", 345 | "swinging legs", 346 | "swinging on something", 347 | "sword fighting", 348 | "tai chi", 349 | "taking a shower", 350 | "tango dancing", 351 | "tap dancing", 352 | "tapping guitar", 353 | "tapping pen", 
354 | "tasting beer", 355 | "tasting food", 356 | "testifying", 357 | "texting", 358 | "throwing axe", 359 | "throwing ball", 360 | "throwing discus", 361 | "tickling", 362 | "tobogganing", 363 | "tossing coin", 364 | "tossing salad", 365 | "training dog", 366 | "trapezing", 367 | "trimming or shaving beard", 368 | "trimming trees", 369 | "triple jump", 370 | "tying bow tie", 371 | "tying knot (not on a tie)", 372 | "tying tie", 373 | "unboxing", 374 | "unloading truck", 375 | "using computer", 376 | "using remote controller (not gaming)", 377 | "using segway", 378 | "vault", 379 | "waiting in line", 380 | "walking the dog", 381 | "washing dishes", 382 | "washing feet", 383 | "washing hair", 384 | "washing hands", 385 | "water skiing", 386 | "water sliding", 387 | "watering plants", 388 | "waxing back", 389 | "waxing chest", 390 | "waxing eyebrows", 391 | "waxing legs", 392 | "weaving basket", 393 | "welding", 394 | "whistling", 395 | "windsurfing", 396 | "wrapping present", 397 | "wrestling", 398 | "writing", 399 | "yawning", 400 | "yoga", 401 | "zumba" 402 | ] 403 | -------------------------------------------------------------------------------- /figs/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/efficient-video-recognition/b282e8b49e4ecd24f7357361719cf4fa5ab9c40d/figs/arch.png -------------------------------------------------------------------------------- /figs/k400.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/efficient-video-recognition/b282e8b49e4ecd24f7357361719cf4fa5ab9c40d/figs/k400.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | from datetime import datetime 5 | import builtins 6 | 7 | import torch 8 | 
import torch.distributed as dist 9 | 10 | import video_dataset 11 | import checkpoint 12 | from model import EVLTransformer 13 | from video_dataset import dataloader 14 | from weight_loaders import weight_loader_fn_dict 15 | from vision_transformer import vit_presets 16 | 17 | def setup_print(is_master: bool): 18 | """ 19 | This function disables printing when not in master process 20 | """ 21 | builtin_print = builtins.print 22 | 23 | def print(*args, **kwargs): 24 | force = kwargs.pop('force', False) 25 | if is_master or force: 26 | now = datetime.now().time() 27 | builtin_print('[{}] '.format(now), end='') # print with time stamp 28 | builtin_print(*args, **kwargs) 29 | 30 | builtins.print = print 31 | 32 | 33 | def main(): 34 | parser = argparse.ArgumentParser() 35 | 36 | video_dataset.setup_arg_parser(parser) 37 | checkpoint.setup_arg_parser(parser) 38 | 39 | parser.add_argument('--num_steps', type=int, 40 | help='number of training steps') 41 | parser.add_argument('--eval_only', action='store_true', 42 | help='run evaluation only') 43 | parser.add_argument('--save_freq', type=int, default=5000, 44 | help='save a checkpoint every N steps') 45 | parser.add_argument('--eval_freq', type=int, default=5000, 46 | help='evaluate every N steps') 47 | parser.add_argument('--print_freq', type=int, default=10, 48 | help='print log message every N steps') 49 | 50 | parser.add_argument('--backbone', type=str, choices=vit_presets.keys(), default='ViT-B/16-lnpre', 51 | help='the backbone variant used to generate image feature maps') 52 | parser.add_argument('--backbone_path', type=str, 53 | help='path to pretrained backbone weights') 54 | parser.add_argument('--backbone_type', type=str, default='clip', choices=weight_loader_fn_dict.keys(), 55 | help='type of backbone weights (used to determine how to convert state_dict from different pretraining codebase)') 56 | parser.add_argument('--finetune_backbone', action='store_true', 57 | help='finetune backbone weights') 58 | 
parser.add_argument('--decoder_num_layers', type=int, default=4, 59 | help='number of decoder layers') 60 | parser.add_argument('--decoder_qkv_dim', type=int, default=768, 61 | help='q (k, v) projection output dimensions in decoder attention layers') 62 | parser.add_argument('--decoder_num_heads', type=int, default=12, 63 | help='number of heads in decoder attention layers') 64 | parser.add_argument('--decoder_mlp_factor', type=float, default=4.0, 65 | help='expansion factor of feature dimension in the middle of decoder MLPs') 66 | parser.add_argument('--num_classes', type=int, default=400, 67 | help='number of classes') 68 | parser.add_argument('--cls_dropout', type=float, default=0.5, 69 | help='dropout rate applied before the final classification linear projection') 70 | parser.add_argument('--decoder_mlp_dropout', type=float, default=0.5, 71 | help='dropout rate applied in MLP layers in the decoder') 72 | parser.add_argument('--no_temporal_conv', action='store_false', dest='temporal_conv', 73 | help='disable temporal convolution on frame features') 74 | parser.add_argument('--no_temporal_pos_embed', action='store_false', dest='temporal_pos_embed', 75 | help='disable temporal position embeddings added to frame features') 76 | parser.add_argument('--no_temporal_cross_attention', action='store_false', dest='temporal_cross_attention', 77 | help='disable temporal cross attention on frame query and key features') 78 | parser.set_defaults(temporal_conv=True, temporal_pos_embed=True, temporal_cross_attention=True) 79 | 80 | parser.add_argument('--lr', type=float, default=4e-4, 81 | help='learning rate') 82 | parser.add_argument('--weight_decay', type=float, default=0.05, 83 | help='optimizer weight decay') 84 | parser.add_argument('--disable_fp16', action='store_false', dest='fp16', 85 | help='disable fp16 during training or inference') 86 | parser.set_defaults(fp16=True) 87 | 88 | parser.add_argument('--batch_split', type=int, default=1, 89 | help='optionally split 
the batch into smaller shards and forward/backward one shard ' 90 | 'at a time to avoid out-of-memory error.') 91 | 92 | args = parser.parse_args() 93 | 94 | dist.init_process_group('nccl') 95 | setup_print(dist.get_rank() == 0) 96 | cuda_device_id = dist.get_rank() % torch.cuda.device_count() 97 | torch.cuda.set_device(cuda_device_id) 98 | 99 | model = EVLTransformer( 100 | backbone_name=args.backbone, 101 | backbone_type=args.backbone_type, 102 | backbone_path=args.backbone_path, 103 | backbone_mode='finetune' if args.finetune_backbone else ('freeze_fp16' if args.fp16 else 'freeze_fp32'), 104 | decoder_num_layers=args.decoder_num_layers, 105 | decoder_qkv_dim=args.decoder_qkv_dim, 106 | decoder_num_heads=args.decoder_num_heads, 107 | decoder_mlp_factor=args.decoder_mlp_factor, 108 | num_classes=args.num_classes, 109 | enable_temporal_conv=args.temporal_conv, 110 | enable_temporal_pos_embed=args.temporal_pos_embed, 111 | enable_temporal_cross_attention=args.temporal_cross_attention, 112 | cls_dropout=args.cls_dropout, 113 | decoder_mlp_dropout=args.decoder_mlp_dropout, 114 | num_frames=args.num_frames, 115 | ) 116 | print(model) 117 | model.cuda() 118 | model = torch.nn.parallel.DistributedDataParallel( 119 | model, device_ids=[cuda_device_id], output_device=cuda_device_id, 120 | ) 121 | 122 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 123 | lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_steps) 124 | loss_scaler = torch.cuda.amp.grad_scaler.GradScaler(enabled=args.fp16) 125 | criterion = torch.nn.CrossEntropyLoss() 126 | 127 | resume_step = checkpoint.resume_from_checkpoint(model, optimizer, lr_sched, loss_scaler, args) 128 | 129 | val_loader = video_dataset.create_val_loader(args) 130 | if args.eval_only: 131 | print('Running in eval_only mode.') 132 | model.eval() 133 | evaluate(model, val_loader) 134 | return 135 | else: 136 | assert args.train_list_path is not None, 'Train list 
path must be specified if not in eval_only mode.' 137 | train_loader = video_dataset.create_train_loader(args, resume_step=resume_step) 138 | 139 | assert len(train_loader) == args.num_steps - resume_step 140 | batch_st, train_st = datetime.now(), datetime.now() 141 | for i, (data, labels) in enumerate(train_loader, resume_step): 142 | data, labels = data.cuda(), labels.cuda() 143 | data_ed = datetime.now() 144 | 145 | optimizer.zero_grad() 146 | 147 | assert data.size(0) % args.batch_split == 0 148 | split_size = data.size(0) // args.batch_split 149 | hit1, hit5, loss_value = 0, 0, 0 150 | for j in range(args.batch_split): 151 | data_slice = data[split_size * j: split_size * (j + 1)] 152 | labels_slice = labels[split_size * j: split_size * (j + 1)] 153 | 154 | with torch.cuda.amp.autocast(args.fp16): 155 | logits = model(data_slice) 156 | loss = criterion(logits, labels_slice) 157 | 158 | if labels.dtype == torch.long: # no mixup, can calculate accuracy 159 | hit1 += (logits.topk(1, dim=1)[1] == labels_slice.view(-1, 1)).sum().item() 160 | hit5 += (logits.topk(5, dim=1)[1] == labels_slice.view(-1, 1)).sum().item() 161 | loss_value += loss.item() / args.batch_split 162 | 163 | loss_scaler.scale(loss / args.batch_split).backward() 164 | 165 | loss_scaler.step(optimizer) 166 | loss_scaler.update() 167 | lr_sched.step() 168 | 169 | batch_ed = datetime.now() 170 | 171 | if i % args.print_freq == 0: 172 | sync_tensor = torch.Tensor([loss_value, hit1 / data.size(0), hit5 / data.size(0)]).cuda() 173 | dist.all_reduce(sync_tensor) 174 | sync_tensor = sync_tensor.cpu() / dist.get_world_size() 175 | loss_value, acc1, acc5 = sync_tensor.tolist() 176 | 177 | print( 178 | f'batch_time: {(batch_ed - batch_st).total_seconds():.3f} ' 179 | f'data_time: {(data_ed - batch_st).total_seconds():.3f} ' 180 | f'ETA: {(batch_ed - train_st) / (i - resume_step + 1) * (args.num_steps - i - 1)} | ' 181 | f'lr: {optimizer.param_groups[0]["lr"]:.6f} ' 182 | f'loss: {loss_value:.6f}' + ( 183 | 
f' acc1: {acc1 * 100:.2f}% acc5: {acc5 * 100:.2f}%' if labels.dtype == torch.long else '' 184 | ) 185 | ) 186 | 187 | if (i + 1) % args.eval_freq == 0: 188 | print('Start model evaluation at step', i + 1) 189 | model.eval() 190 | evaluate(model, val_loader) 191 | model.train() 192 | 193 | if (i + 1) % args.save_freq == 0: 194 | checkpoint.save_checkpoint(model, optimizer, lr_sched, loss_scaler, i + 1, args) 195 | 196 | batch_st = datetime.now() 197 | 198 | 199 | def evaluate(model: torch.nn.Module, loader: torch.utils.data.DataLoader): 200 | tot, hit1, hit5 = 0, 0, 0 201 | eval_st = datetime.now() 202 | for data, labels in loader: 203 | data, labels = data.cuda(), labels.cuda() 204 | assert data.size(0) == 1 205 | if data.ndim == 6: 206 | data = data[0] # now the first dimension is number of views 207 | 208 | with torch.no_grad(): 209 | logits = model(data) 210 | scores = logits.softmax(dim=-1).mean(dim=0) 211 | 212 | tot += 1 213 | hit1 += (scores.topk(1)[1] == labels).sum().item() 214 | hit5 += (scores.topk(5)[1] == labels).sum().item() 215 | 216 | if tot % 20 == 0: 217 | print(f'[Evaluation] num_samples: {tot} ' 218 | f'ETA: {(datetime.now() - eval_st) / tot * (len(loader) - tot)} ' 219 | f'cumulative_acc1: {hit1 / tot * 100.:.2f}% ' 220 | f'cumulative_acc5: {hit5 / tot * 100.:.2f}%') 221 | 222 | sync_tensor = torch.LongTensor([tot, hit1, hit5]).cuda() 223 | dist.all_reduce(sync_tensor) 224 | tot, hit1, hit5 = sync_tensor.cpu().tolist() 225 | 226 | print(f'Accuracy on validation set: top1={hit1 / tot * 100:.2f}%, top5={hit5 / tot * 100:.2f}%') 227 | 228 | 229 | if __name__ == '__main__': main() 230 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from typing import Dict, Iterable, List, Tuple 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 
from vision_transformer import QuickGELU, Attention 11 | from weight_loaders import weight_loader_fn_dict 12 | from vision_transformer import ( 13 | VisionTransformer2D, TransformerDecoderLayer, 14 | model_to_fp16, vit_presets, 15 | ) 16 | 17 | 18 | class TemporalCrossAttention(nn.Module): 19 | 20 | def __init__( 21 | self, 22 | spatial_size: Tuple[int, int] = (14, 14), 23 | feature_dim: int = 768, 24 | ): 25 | super().__init__() 26 | 27 | self.spatial_size = spatial_size 28 | 29 | w_size = np.prod([x * 2 - 1 for x in spatial_size]) 30 | self.w1 = nn.Parameter(torch.zeros([w_size, feature_dim])) 31 | self.w2 = nn.Parameter(torch.zeros([w_size, feature_dim])) 32 | 33 | idx_tensor = torch.zeros([np.prod(spatial_size) for _ in (0, 1)], dtype=torch.long) 34 | for q in range(np.prod(spatial_size)): 35 | qi, qj = q // spatial_size[1], q % spatial_size[1] 36 | for k in range(np.prod(spatial_size)): 37 | ki, kj = k // spatial_size[1], k % spatial_size[1] 38 | i_offs = qi - ki + spatial_size[0] - 1 39 | j_offs = qj - kj + spatial_size[1] - 1 40 | idx_tensor[q, k] = i_offs * (spatial_size[1] * 2 - 1) + j_offs 41 | self.idx_tensor = idx_tensor 42 | 43 | 44 | def forward_half(self, q: torch.Tensor, k: torch.Tensor, w: torch.Tensor) -> torch.Tensor: 45 | q, k = q[:, :, 1:], k[:, :, 1:] # remove cls token 46 | 47 | assert q.size() == k.size() 48 | assert q.size(2) == np.prod(self.spatial_size) 49 | 50 | attn = torch.einsum('ntqhd,ntkhd->ntqkh', q / (q.size(-1) ** 0.5), k) 51 | attn = attn.softmax(dim=-2).mean(dim=-1) # N, T, L, L (softmax over keys, averaged over heads) 52 | 53 | self.idx_tensor = self.idx_tensor.to(w.device) 54 | w_unroll = w[self.idx_tensor] # L, L, C 55 | ret = torch.einsum('ntqk,qkc->ntqc', attn, w_unroll) 56 | 57 | return ret 58 | 59 | 60 | def forward(self, q: torch.Tensor, k: torch.Tensor): 61 | N, T, L, H, D = q.size() 62 | assert L == np.prod(self.spatial_size) + 1 63 | 64 | ret = torch.zeros([N, T, L, self.w1.size(-1)], device=q.device) 65 | ret[:, 1:, 1:, :] += self.forward_half(q[:, 1:, :, :,
:], k[:, :-1, :, :, :], self.w1) 66 | ret[:, :-1, 1:, :] += self.forward_half(q[:, :-1, :, :, :], k[:, 1:, :, :, :], self.w2) 67 | 68 | return ret 69 | 70 | 71 | class EVLDecoder(nn.Module): 72 | 73 | def __init__( 74 | self, 75 | num_frames: int = 8, 76 | spatial_size: Tuple[int, int] = (14, 14), 77 | num_layers: int = 4, 78 | in_feature_dim: int = 768, 79 | qkv_dim: int = 768, 80 | num_heads: int = 12, 81 | mlp_factor: float = 4.0, 82 | enable_temporal_conv: bool = True, 83 | enable_temporal_pos_embed: bool = True, 84 | enable_temporal_cross_attention: bool = True, 85 | mlp_dropout: float = 0.5, 86 | ): 87 | super().__init__() 88 | 89 | self.enable_temporal_conv = enable_temporal_conv 90 | self.enable_temporal_pos_embed = enable_temporal_pos_embed 91 | self.enable_temporal_cross_attention = enable_temporal_cross_attention 92 | self.num_layers = num_layers 93 | 94 | self.decoder_layers = nn.ModuleList( 95 | [TransformerDecoderLayer(in_feature_dim, qkv_dim, num_heads, mlp_factor, mlp_dropout) for _ in range(num_layers)] 96 | ) 97 | 98 | if enable_temporal_conv: 99 | self.temporal_conv = nn.ModuleList( 100 | [nn.Conv1d(in_feature_dim, in_feature_dim, kernel_size=3, stride=1, padding=1, groups=in_feature_dim) for _ in range(num_layers)] 101 | ) 102 | if enable_temporal_pos_embed: 103 | self.temporal_pos_embed = nn.ParameterList( 104 | [nn.Parameter(torch.zeros([num_frames, in_feature_dim])) for _ in range(num_layers)] 105 | ) 106 | if enable_temporal_cross_attention: 107 | self.cross_attention = nn.ModuleList( 108 | [TemporalCrossAttention(spatial_size, in_feature_dim) for _ in range(num_layers)] 109 | ) 110 | 111 | self.cls_token = nn.Parameter(torch.zeros([in_feature_dim])) 112 | 113 | 114 | def _initialize_weights(self): 115 | nn.init.normal_(self.cls_token, std=0.02) 116 | 117 | 118 | def forward(self, in_features: List[Dict[str, torch.Tensor]]): 119 | N, T, L, C = in_features[0]['out'].size() 120 | assert len(in_features) == self.num_layers 121 | x = 
self.cls_token.view(1, 1, -1).repeat(N, 1, 1) 122 | 123 | for i in range(self.num_layers): 124 | frame_features = in_features[i]['out'] 125 | 126 | if self.enable_temporal_conv: 127 | feat = in_features[i]['out'] 128 | feat = feat.permute(0, 2, 3, 1).contiguous().flatten(0, 1) # N * L, C, T 129 | feat = self.temporal_conv[i](feat) 130 | feat = feat.view(N, L, C, T).permute(0, 3, 1, 2).contiguous() # N, T, L, C 131 | frame_features += feat 132 | 133 | if self.enable_temporal_pos_embed: 134 | frame_features += self.temporal_pos_embed[i].view(1, T, 1, C) 135 | 136 | if self.enable_temporal_cross_attention: 137 | frame_features += self.cross_attention[i](in_features[i]['q'], in_features[i]['k']) 138 | 139 | frame_features = frame_features.flatten(1, 2) # N, T * L, C 140 | 141 | x = self.decoder_layers[i](x, frame_features) 142 | 143 | return x 144 | 145 | 146 | class EVLTransformer(nn.Module): 147 | 148 | def __init__( 149 | self, 150 | num_frames: int = 8, 151 | backbone_name: str = 'ViT-B/16', 152 | backbone_type: str = 'clip', 153 | backbone_path: str = '', 154 | backbone_mode: str = 'freeze_fp16', 155 | decoder_num_layers: int = 4, 156 | decoder_qkv_dim: int = 768, 157 | decoder_num_heads: int = 12, 158 | decoder_mlp_factor: float = 4.0, 159 | num_classes: int = 400, 160 | enable_temporal_conv: bool = True, 161 | enable_temporal_pos_embed: bool = True, 162 | enable_temporal_cross_attention: bool = True, 163 | cls_dropout: float = 0.5, 164 | decoder_mlp_dropout: float = 0.5, 165 | ): 166 | super().__init__() 167 | 168 | self.decoder_num_layers = decoder_num_layers 169 | 170 | backbone_config = self._create_backbone(backbone_name, backbone_type, backbone_path, backbone_mode) 171 | backbone_feature_dim = backbone_config['feature_dim'] 172 | backbone_spatial_size = tuple(x // y for x, y in zip(backbone_config['input_size'], backbone_config['patch_size'])) 173 | 174 | self.decoder = EVLDecoder( 175 | num_frames=num_frames, 176 | spatial_size=backbone_spatial_size, 177 |
num_layers=decoder_num_layers, 178 | in_feature_dim=backbone_feature_dim, 179 | qkv_dim=decoder_qkv_dim, 180 | num_heads=decoder_num_heads, 181 | mlp_factor=decoder_mlp_factor, 182 | enable_temporal_conv=enable_temporal_conv, 183 | enable_temporal_pos_embed=enable_temporal_pos_embed, 184 | enable_temporal_cross_attention=enable_temporal_cross_attention, 185 | mlp_dropout=decoder_mlp_dropout, 186 | ) 187 | self.proj = nn.Sequential( 188 | nn.LayerNorm(backbone_feature_dim), 189 | nn.Dropout(cls_dropout), 190 | nn.Linear(backbone_feature_dim, num_classes), 191 | ) 192 | 193 | 194 | def _create_backbone( 195 | self, 196 | backbone_name: str, 197 | backbone_type: str, 198 | backbone_path: str, 199 | backbone_mode: str, 200 | ) -> dict: 201 | weight_loader_fn = weight_loader_fn_dict[backbone_type] 202 | state_dict = weight_loader_fn(backbone_path) 203 | 204 | backbone = VisionTransformer2D(return_all_features=True, **vit_presets[backbone_name]) 205 | backbone.load_state_dict(state_dict, strict=True) # weight_loader_fn is expected to strip unused parameters 206 | 207 | assert backbone_mode in ['finetune', 'freeze_fp16', 'freeze_fp32'] 208 | 209 | if backbone_mode == 'finetune': 210 | self.backbone = backbone 211 | else: 212 | backbone.eval().requires_grad_(False) 213 | if backbone_mode == 'freeze_fp16': 214 | model_to_fp16(backbone) 215 | self.backbone = [backbone] # avoid backbone parameter registration 216 | 217 | return vit_presets[backbone_name] 218 | 219 | 220 | def _get_backbone(self, x): 221 | if isinstance(self.backbone, list): 222 | # frozen backbone 223 | self.backbone[0] = self.backbone[0].to(x.device) 224 | return self.backbone[0] 225 | else: 226 | # finetuned backbone 227 | return self.backbone 228 | 229 | 230 | def forward(self, x: torch.Tensor): 231 | backbone = self._get_backbone(x) 232 | 233 | B, C, T, H, W = x.size() 234 | x = x.permute(0, 2, 1, 3, 4).flatten(0, 1) 235 | features = 
backbone(x)[-self.decoder_num_layers:] 236 | features = [ 237 | dict((k,
v.float().view(B, T, *v.size()[1:])) for k, v in x.items()) 238 | for x in features 239 | ] 240 | 241 | x = self.decoder(features) 242 | x = self.proj(x[:, 0, :]) 243 | 244 | return x -------------------------------------------------------------------------------- /scripts/eval_k400_vitb16_16f_dec4x768.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | python -u -m torch.distributed.run --nproc_per_node 4 \ 4 | main.py \ 5 | --num_steps 50000 \ 6 | --backbone "ViT-B/16-lnpre" \ 7 | --backbone_type clip \ 8 | --backbone_path /path/to/clip_models/ViT-B-16.pt \ 9 | --decoder_num_layers 4 \ 10 | --decoder_qkv_dim 768 \ 11 | --decoder_num_heads 12 \ 12 | --num_classes 400 \ 13 | --val_list_path /path/to/k400/val.txt \ 14 | --batch_size 256 \ 15 | --batch_split 1 \ 16 | --auto_augment rand-m7-n4-mstd0.5-inc1 \ 17 | --mean 0.48145466 0.4578275 0.40821073 \ 18 | --std 0.26862954 0.26130258 0.27577711 \ 19 | --num_workers 12 \ 20 | --num_frames 16 \ 21 | --sampling_rate 16 \ 22 | --num_spatial_views 3 \ 23 | --num_temporal_views 1 \ 24 | --resume_path /path/to/checkpoint_release/k400_vitb16_16f_dec4x768.pth \ 25 | --eval_only 26 | -------------------------------------------------------------------------------- /scripts/eval_k400_vitb16_32f_dec4x768.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | python -u -m torch.distributed.run --nproc_per_node 4 \ 4 | main.py \ 5 | --num_steps 50000 \ 6 | --backbone "ViT-B/16-lnpre" \ 7 | --backbone_type clip \ 8 | --backbone_path /path/to/clip_models/ViT-B-16.pt \ 9 | --decoder_num_layers 4 \ 10 | --decoder_qkv_dim 768 \ 11 | --decoder_num_heads 12 \ 12 | --num_classes 400 \ 13 | --val_list_path /path/to/k400/val.txt \ 14 | --batch_size 256 \ 15 | --batch_split 1 \ 16 | --auto_augment rand-m7-n4-mstd0.5-inc1 \ 17 | --mean 0.48145466 0.4578275 0.40821073 \ 18 | --std 0.26862954 0.26130258 0.27577711 \ 19 | 
--num_workers 12 \ 20 | --num_frames 32 \ 21 | --sampling_rate 8 \ 22 | --num_spatial_views 3 \ 23 | --num_temporal_views 1 \ 24 | --resume_path /path/to/checkpoint_release/k400_vitb16_32f_dec4x768.pth \ 25 | --eval_only 26 | -------------------------------------------------------------------------------- /scripts/eval_k400_vitb16_8f_dec4x768.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | python -u -m torch.distributed.run --nproc_per_node 4 \ 4 | main.py \ 5 | --num_steps 50000 \ 6 | --backbone "ViT-B/16-lnpre" \ 7 | --backbone_type clip \ 8 | --backbone_path /path/to/clip_models/ViT-B-16.pt \ 9 | --decoder_num_layers 4 \ 10 | --decoder_qkv_dim 768 \ 11 | --decoder_num_heads 12 \ 12 | --num_classes 400 \ 13 | --val_list_path /path/to/k400/val.txt \ 14 | --batch_size 256 \ 15 | --batch_split 1 \ 16 | --auto_augment rand-m7-n4-mstd0.5-inc1 \ 17 | --mean 0.48145466 0.4578275 0.40821073 \ 18 | --std 0.26862954 0.26130258 0.27577711 \ 19 | --num_workers 12 \ 20 | --num_frames 8 \ 21 | --sampling_rate 16 \ 22 | --num_spatial_views 1 \ 23 | --num_temporal_views 3 \ 24 | --resume_path /path/to/checkpoint_release/k400_vitb16_8f_dec4x768.pth \ 25 | --eval_only 26 | -------------------------------------------------------------------------------- /scripts/eval_k400_vitl14_16f_dec4x1024.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | python -u -m torch.distributed.run --nproc_per_node 4 \ 4 | main.py \ 5 | --num_steps 50000 \ 6 | --backbone "ViT-L/14-lnpre" \ 7 | --backbone_type clip \ 8 | --backbone_path /path/to/clip_models/ViT-L-14.pt \ 9 | --decoder_num_layers 4 \ 10 | --decoder_qkv_dim 1024 \ 11 | --decoder_num_heads 16 \ 12 | --num_classes 400 \ 13 | --val_list_path /path/to/k400/val.txt \ 14 | --batch_size 256 \ 15 | --batch_split 1 \ 16 | --auto_augment rand-m7-n4-mstd0.5-inc1 \ 17 | --mean 0.48145466 0.4578275 0.40821073 \ 18 | --std 
0.26862954 0.26130258 0.27577711 \ 19 | --num_workers 12 \ 20 | --num_frames 16 \ 21 | --sampling_rate 16 \ 22 | --num_spatial_views 3 \ 23 | --num_temporal_views 1 \ 24 | --resume_path /path/to/checkpoint_release/k400_vitl14_16f_dec4x1024.pth \ 25 | --eval_only 26 | -------------------------------------------------------------------------------- /scripts/eval_k400_vitl14_32f_dec4x1024.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | python -u -m torch.distributed.run --nproc_per_node 4 \ 4 | main.py \ 5 | --num_steps 50000 \ 6 | --backbone "ViT-L/14-lnpre" \ 7 | --backbone_type clip \ 8 | --backbone_path /path/to/clip_models/ViT-L-14.pt \ 9 | --decoder_num_layers 4 \ 10 | --decoder_qkv_dim 1024 \ 11 | --decoder_num_heads 16 \ 12 | --num_classes 400 \ 13 | --val_list_path /path/to/k400/val.txt \ 14 | --batch_size 256 \ 15 | --batch_split 1 \ 16 | --auto_augment rand-m7-n4-mstd0.5-inc1 \ 17 | --mean 0.48145466 0.4578275 0.40821073 \ 18 | --std 0.26862954 0.26130258 0.27577711 \ 19 | --num_workers 12 \ 20 | --num_frames 32 \ 21 | --sampling_rate 8 \ 22 | --num_spatial_views 3 \ 23 | --num_temporal_views 1 \ 24 | --resume_path /path/to/checkpoint_release/k400_vitl14_32f_dec4x1024.pth \ 25 | --eval_only 26 | -------------------------------------------------------------------------------- /scripts/eval_k400_vitl14_8f_dec4x1024.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | python -u -m torch.distributed.run --nproc_per_node 4 \ 4 | main.py \ 5 | --num_steps 50000 \ 6 | --backbone "ViT-L/14-lnpre" \ 7 | --backbone_type clip \ 8 | --backbone_path /path/to/clip_models/ViT-L-14.pt \ 9 | --decoder_num_layers 4 \ 10 | --decoder_qkv_dim 1024 \ 11 | --decoder_num_heads 16 \ 12 | --num_classes 400 \ 13 | --val_list_path /path/to/k400/val.txt \ 14 | --batch_size 256 \ 15 | --batch_split 1 \ 16 | --auto_augment rand-m7-n4-mstd0.5-inc1 \ 17 | --mean 
0.48145466 0.4578275 0.40821073 \ 18 | --std 0.26862954 0.26130258 0.27577711 \ 19 | --num_workers 12 \ 20 | --num_frames 8 \ 21 | --sampling_rate 16 \ 22 | --num_spatial_views 1 \ 23 | --num_temporal_views 3 \ 24 | --resume_path /path/to/checkpoint_release/k400_vitl14_8f_dec4x1024.pth \ 25 | --eval_only 26 | -------------------------------------------------------------------------------- /scripts/train_k400_vitb16_16f_dec4x768.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | exp_dir=runs/k400_vitb16_16f_dec4x768 4 | 5 | mkdir -p "${exp_dir}" 6 | python -u -m torch.distributed.run --nproc_per_node 8 \ 7 | main.py \ 8 | --num_steps 50000 \ 9 | --backbone "ViT-B/16-lnpre" \ 10 | --backbone_type clip \ 11 | --backbone_path /path/to/clip_models/ViT-B-16.pt \ 12 | --decoder_num_layers 4 \ 13 | --decoder_qkv_dim 768 \ 14 | --decoder_num_heads 12 \ 15 | --num_classes 400 \ 16 | --checkpoint_dir "${exp_dir}" \ 17 | --auto_resume \ 18 | --train_list_path /path/to/k400/train.txt \ 19 | --val_list_path /path/to/k400/val.txt \ 20 | --batch_size 256 \ 21 | --batch_split 1 \ 22 | --auto_augment rand-m7-n4-mstd0.5-inc1 \ 23 | --mean 0.48145466 0.4578275 0.40821073 \ 24 | --std 0.26862954 0.26130258 0.27577711 \ 25 | --num_workers 12 \ 26 | --num_frames 16 \ 27 | --sampling_rate 16 \ 28 | --num_spatial_views 3 \ 29 | --num_temporal_views 1 \ 30 | 2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log" 31 | -------------------------------------------------------------------------------- /scripts/train_k400_vitb16_32f_dec4x768.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | exp_dir=runs/k400_vitb16_32f_dec4x768 4 | 5 | mkdir -p "${exp_dir}" 6 | python -u -m torch.distributed.run --nproc_per_node 8 \ 7 | main.py \ 8 | --num_steps 50000 \ 9 | --backbone "ViT-B/16-lnpre" \ 10 | --backbone_type clip \ 11 | --backbone_path 
/path/to/clip_models/ViT-B-16.pt \ 12 | --decoder_num_layers 4 \ 13 | --decoder_qkv_dim 768 \ 14 | --decoder_num_heads 12 \ 15 | --num_classes 400 \ 16 | --checkpoint_dir "${exp_dir}" \ 17 | --auto_resume \ 18 | --train_list_path /path/to/k400/train.txt \ 19 | --val_list_path /path/to/k400/val.txt \ 20 | --batch_size 256 \ 21 | --batch_split 1 \ 22 | --auto_augment rand-m7-n4-mstd0.5-inc1 \ 23 | --mean 0.48145466 0.4578275 0.40821073 \ 24 | --std 0.26862954 0.26130258 0.27577711 \ 25 | --num_workers 12 \ 26 | --num_frames 32 \ 27 | --sampling_rate 8 \ 28 | --num_spatial_views 3 \ 29 | --num_temporal_views 1 \ 30 | 2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log" 31 | -------------------------------------------------------------------------------- /scripts/train_k400_vitb16_8f_dec4x768.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | exp_dir=runs/k400_vitb16_8f_dec4x768 4 | 5 | mkdir -p "${exp_dir}" 6 | python -u -m torch.distributed.run --nproc_per_node 8 \ 7 | main.py \ 8 | --num_steps 50000 \ 9 | --backbone "ViT-B/16-lnpre" \ 10 | --backbone_type clip \ 11 | --backbone_path /path/to/clip_models/ViT-B-16.pt \ 12 | --decoder_num_layers 4 \ 13 | --decoder_qkv_dim 768 \ 14 | --decoder_num_heads 12 \ 15 | --num_classes 400 \ 16 | --checkpoint_dir "${exp_dir}" \ 17 | --auto_resume \ 18 | --train_list_path /path/to/k400/train.txt \ 19 | --val_list_path /path/to/k400/val.txt \ 20 | --batch_size 256 \ 21 | --batch_split 1 \ 22 | --auto_augment rand-m7-n4-mstd0.5-inc1 \ 23 | --mean 0.48145466 0.4578275 0.40821073 \ 24 | --std 0.26862954 0.26130258 0.27577711 \ 25 | --num_workers 12 \ 26 | --num_frames 8 \ 27 | --sampling_rate 16 \ 28 | --num_spatial_views 1 \ 29 | --num_temporal_views 3 \ 30 | 2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log" 31 | -------------------------------------------------------------------------------- /scripts/train_k400_vitl14_16f_dec4x1024.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | exp_dir=runs/k400_vitl14_16f_dec4x1024 4 | 5 | mkdir -p "${exp_dir}" 6 | python -u -m torch.distributed.run --nproc_per_node 8 \ 7 | main.py \ 8 | --num_steps 50000 \ 9 | --backbone "ViT-L/14-lnpre" \ 10 | --backbone_type clip \ 11 | --backbone_path /path/to/clip_models/ViT-L-14.pt \ 12 | --decoder_num_layers 4 \ 13 | --decoder_qkv_dim 1024 \ 14 | --decoder_num_heads 16 \ 15 | --num_classes 400 \ 16 | --checkpoint_dir "${exp_dir}" \ 17 | --auto_resume \ 18 | --train_list_path /path/to/k400/train.txt \ 19 | --val_list_path /path/to/k400/val.txt \ 20 | --batch_size 256 \ 21 | --batch_split 2 \ 22 | --auto_augment rand-m7-n4-mstd0.5-inc1 \ 23 | --mean 0.48145466 0.4578275 0.40821073 \ 24 | --std 0.26862954 0.26130258 0.27577711 \ 25 | --num_workers 12 \ 26 | --num_frames 16 \ 27 | --sampling_rate 16 \ 28 | --num_spatial_views 3 \ 29 | --num_temporal_views 1 \ 30 | 2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log" 31 | -------------------------------------------------------------------------------- /scripts/train_k400_vitl14_32f_dec4x1024.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | exp_dir=runs/k400_vitl14_32f_dec4x1024 4 | 5 | mkdir -p "${exp_dir}" 6 | python -u -m torch.distributed.run --nproc_per_node 8 \ 7 | main.py \ 8 | --num_steps 50000 \ 9 | --backbone "ViT-L/14-lnpre" \ 10 | --backbone_type clip \ 11 | --backbone_path /path/to/clip_models/ViT-L-14.pt \ 12 | --decoder_num_layers 4 \ 13 | --decoder_qkv_dim 1024 \ 14 | --decoder_num_heads 16 \ 15 | --num_classes 400 \ 16 | --checkpoint_dir "${exp_dir}" \ 17 | --auto_resume \ 18 | --train_list_path /path/to/k400/train.txt \ 19 | --val_list_path /path/to/k400/val.txt \ 20 | --batch_size 256 \ 21 | --batch_split 4 \ 22 | --auto_augment rand-m7-n4-mstd0.5-inc1 \ 23 | --mean 0.48145466 0.4578275 0.40821073 \ 24 | --std 0.26862954 
0.26130258 0.27577711 \ 25 | --num_workers 12 \ 26 | --num_frames 32 \ 27 | --sampling_rate 8 \ 28 | --num_spatial_views 3 \ 29 | --num_temporal_views 1 \ 30 | 2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log" 31 | -------------------------------------------------------------------------------- /scripts/train_k400_vitl14_8f_dec4x1024.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | exp_dir=runs/k400_vitl14_8f_dec4x1024 4 | 5 | mkdir -p "${exp_dir}" 6 | python -u -m torch.distributed.run --nproc_per_node 8 \ 7 | main.py \ 8 | --num_steps 50000 \ 9 | --backbone "ViT-L/14-lnpre" \ 10 | --backbone_type clip \ 11 | --backbone_path /path/to/clip_models/ViT-L-14.pt \ 12 | --decoder_num_layers 4 \ 13 | --decoder_qkv_dim 1024 \ 14 | --decoder_num_heads 16 \ 15 | --num_classes 400 \ 16 | --checkpoint_dir "${exp_dir}" \ 17 | --auto_resume \ 18 | --train_list_path /path/to/k400/train.txt \ 19 | --val_list_path /path/to/k400/val.txt \ 20 | --batch_size 256 \ 21 | --batch_split 1 \ 22 | --auto_augment rand-m7-n4-mstd0.5-inc1 \ 23 | --mean 0.48145466 0.4578275 0.40821073 \ 24 | --std 0.26862954 0.26130258 0.27577711 \ 25 | --num_workers 12 \ 26 | --num_frames 8 \ 27 | --sampling_rate 16 \ 28 | --num_spatial_views 1 \ 29 | --num_temporal_views 3 \ 30 | 2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log" 31 | -------------------------------------------------------------------------------- /video_dataset/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from .dataloader import setup_arg_parser, create_train_loader, create_val_loader -------------------------------------------------------------------------------- /video_dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | from typing import Dict 5 | 6 | import torch 7 | 
import torch.distributed as dist 8 | 9 | from .dataset import VideoDataset, DummyDataset 10 | 11 | def setup_arg_parser(parser: argparse.ArgumentParser): 12 | parser.add_argument('--train_list_path', type=str, 13 | help='path to training data list') 14 | parser.add_argument('--val_list_path', type=str, 15 | help='path to validation data list') 16 | parser.add_argument('--train_data_root', type=str, 17 | help='training samples root directory') 18 | parser.add_argument('--val_data_root', type=str, 19 | help='validation samples root directory') 20 | parser.add_argument('--data_root', type=str, default='', 21 | help='training and validation samples root directory, may be overridden by --train_data_root or --val_data_root') 22 | 23 | parser.add_argument('--batch_size', type=int, 24 | help='total training batch size across all GPUs') 25 | 26 | parser.add_argument('--num_spatial_views', type=int, default=1, 27 | help='number of spatial crops used for testing (total views = num_spatial_views * num_temporal_views)') 28 | parser.add_argument('--num_temporal_views', type=int, default=3, 29 | help='number of temporal crops used for testing (total views = num_spatial_views * num_temporal_views)') 30 | parser.add_argument('--num_frames', type=int, default=8, 31 | help='number of frames used for each view') 32 | parser.add_argument('--sampling_rate', type=int, default=16, 33 | help='temporal stride for frame sampling, only valid when tsn_sampling is not enabled') 34 | parser.add_argument('--tsn_sampling', action='store_true', 35 | help='enable TSN-style sampling (i.e.
sample frames with dynamic stride to cover the whole video)') 36 | parser.add_argument('--spatial_size', type=int, default=224, 37 | help='frame height and width in pixels') 38 | 39 | parser.add_argument('--mean', type=float, nargs='+', 40 | help='pixel mean used to normalize the image') 41 | parser.add_argument('--std', type=float, nargs='+', 42 | help='pixel std used to normalize the image') 43 | 44 | parser.add_argument('--num_workers', type=int, default=10, 45 | help='number of DataLoader worker processes') 46 | 47 | parser.add_argument('--dummy_dataset', action='store_true', 48 | help='use fake datasets that generate all-zero inputs (for speed testing only)') 49 | 50 | parser.add_argument('--auto_augment', type=str, 51 | help='auto augment configuration') 52 | parser.add_argument('--interpolation', type=str, default='bicubic', 53 | help='interpolation mode') 54 | parser.add_argument('--no_mirror', action='store_false', dest='mirror', 55 | help='disable mirroring during training (commonly used for the Something-Something dataset)') 56 | parser.set_defaults(mirror=True) 57 | 58 | 59 | def _parse_mean_and_std(args: argparse.Namespace) -> Dict[str, torch.Tensor]: 60 | def parse_mean_or_std(arg, default_value): 61 | if arg is None: 62 | return torch.Tensor([default_value] * 3) 63 | elif len(arg) == 1: 64 | return torch.Tensor(arg * 3) 65 | elif len(arg) == 3: 66 | return torch.Tensor(arg) 67 | else: 68 | raise NotImplementedError() 69 | return { 70 | 'mean': parse_mean_or_std(args.mean, 0.45), 71 | 'std': parse_mean_or_std(args.std, 0.225), 72 | } 73 | 74 | 75 | def create_train_dataset(args: argparse.Namespace) -> torch.utils.data.Dataset: 76 | if args.dummy_dataset: 77 | return DummyDataset( 78 | list_path=args.train_list_path, 79 | num_frames=args.num_frames, 80 | num_views=1, 81 | spatial_size=args.spatial_size, 82 | ) 83 | 84 | return VideoDataset( 85 | list_path=args.train_list_path, 86 | data_root=args.train_data_root or 
args.data_root, 87 | num_spatial_views=1,
num_temporal_views=1, random_sample=True, 88 | auto_augment=args.auto_augment, 89 | interpolation=args.interpolation, 90 | mirror=args.mirror, 91 | num_frames=args.num_frames, 92 | sampling_rate=-1 if args.tsn_sampling else args.sampling_rate, 93 | spatial_size=args.spatial_size, 94 | **_parse_mean_and_std(args), 95 | ) 96 | 97 | 98 | def create_train_loader(args: argparse.Namespace, resume_step: int = 0) -> torch.utils.data.DataLoader: 99 | dataset = create_train_dataset(args) 100 | rank, world_size = (0, 1) if not dist.is_initialized() else (dist.get_rank(), dist.get_world_size()) 101 | 102 | assert args.batch_size % world_size == 0 103 | batch_size_per_gpu = args.batch_size // world_size 104 | 105 | # manually create a step-based sampler 106 | sampler = [] 107 | while len(sampler) * len(dataset) < args.num_steps * args.batch_size: 108 | g = torch.Generator() 109 | g.manual_seed(len(sampler)) 110 | indices = torch.randperm(len(dataset), generator=g) 111 | sampler.append(indices) 112 | sampler = torch.cat(sampler, dim=0)[:args.num_steps * args.batch_size].view(args.num_steps, args.batch_size) 113 | sampler = sampler[resume_step:, batch_size_per_gpu * rank: batch_size_per_gpu * (rank + 1)].flatten().tolist() 114 | 115 | loader = torch.utils.data.DataLoader( 116 | dataset, sampler=sampler, batch_size=batch_size_per_gpu, 117 | num_workers=args.num_workers, pin_memory=False, drop_last=True, 118 | ) 119 | 120 | return loader 121 | 122 | 123 | def create_val_dataset(args: argparse.Namespace) -> torch.utils.data.Dataset: 124 | if args.dummy_dataset: 125 | return DummyDataset( 126 | list_path=args.val_list_path, 127 | num_frames=args.num_frames, 128 | num_views=args.num_spatial_views * args.num_temporal_views, 129 | spatial_size=args.spatial_size, 130 | ) 131 | 132 | return VideoDataset( 133 | list_path=args.val_list_path, 134 | data_root=args.val_data_root or args.data_root, 135 | num_spatial_views=args.num_spatial_views, 136 | num_temporal_views=args.num_temporal_views, 
137 | random_sample=False, 138 | num_frames=args.num_frames, 139 | sampling_rate=-1 if args.tsn_sampling else args.sampling_rate, 140 | spatial_size=args.spatial_size, 141 | **_parse_mean_and_std(args), 142 | ) 143 | 144 | 145 | def create_val_loader(args: argparse.Namespace) -> torch.utils.data.DataLoader: 146 | dataset = create_val_dataset(args) 147 | rank, world_size = (0, 1) if not dist.is_initialized() else (dist.get_rank(), dist.get_world_size()) 148 | 149 | # sampler for distributed eval 150 | sampler = list(range(rank, len(dataset), world_size)) 151 | 152 | loader = torch.utils.data.DataLoader( 153 | dataset, sampler=sampler, batch_size=1, 154 | num_workers=args.num_workers, pin_memory=False, 155 | ) 156 | 157 | return loader 158 | -------------------------------------------------------------------------------- /video_dataset/dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os, sys 4 | from typing import Optional 5 | import av 6 | import io 7 | import numpy as np 8 | 9 | import torch 10 | from torchvision import transforms 11 | 12 | from .transform import create_random_augment, random_resized_crop 13 | 14 | class VideoDataset(torch.utils.data.Dataset): 15 | 16 | def __init__( 17 | self, list_path: str, data_root: str, 18 | num_spatial_views: int, num_temporal_views: int, random_sample: bool, 19 | num_frames: int, sampling_rate: int, spatial_size: int, 20 | mean: torch.Tensor, std: torch.Tensor, 21 | auto_augment: Optional[str] = None, interpolation: str = 'bicubic', 22 | mirror: bool = False, 23 | ): 24 | self.data_root = data_root 25 | self.interpolation = interpolation 26 | self.spatial_size = spatial_size 27 | 28 | self.mean, self.std = mean, std 29 | self.num_frames, self.sampling_rate = num_frames, sampling_rate 30 | 31 | if random_sample: 32 | assert num_spatial_views == 1 and num_temporal_views == 1 33 | self.random_sample = True 34 | self.mirror = 
mirror 35 |
self.auto_augment = auto_augment 36 | else: 37 | assert auto_augment is None and not mirror 38 | self.random_sample = False 39 | self.num_temporal_views = num_temporal_views 40 | self.num_spatial_views = num_spatial_views 41 | 42 | with open(list_path) as f: 43 | self.data_list = f.read().splitlines() 44 | 45 | 46 | def __len__(self): 47 | return len(self.data_list) 48 | 49 | 50 | def __getitem__(self, idx): 51 | line = self.data_list[idx] 52 | path, label = line.split(' ') 53 | path = os.path.join(self.data_root, path) 54 | label = int(label) 55 | 56 | container = av.open(path) 57 | frames = {} 58 | for frame in container.decode(video=0): 59 | frames[frame.pts] = frame 60 | container.close() 61 | frames = [frames[k] for k in sorted(frames.keys())] 62 | 63 | if self.random_sample: 64 | frame_idx = self._random_sample_frame_idx(len(frames)) 65 | frames = [frames[x].to_rgb().to_ndarray() for x in frame_idx] 66 | frames = torch.as_tensor(np.stack(frames)).float() / 255. 67 | 68 | if self.auto_augment is not None: 69 | aug_transform = create_random_augment( 70 | input_size=(frames.size(1), frames.size(2)), 71 | auto_augment=self.auto_augment, 72 | interpolation=self.interpolation, 73 | ) 74 | frames = frames.permute(0, 3, 1, 2) # T, C, H, W 75 | frames = [transforms.ToPILImage()(frames[i]) for i in range(frames.size(0))] 76 | frames = aug_transform(frames) 77 | frames = torch.stack([transforms.ToTensor()(img) for img in frames]) 78 | frames = frames.permute(0, 2, 3, 1) 79 | 80 | frames = (frames - self.mean) / self.std 81 | frames = frames.permute(3, 0, 1, 2) # C, T, H, W 82 | frames = random_resized_crop( 83 | frames, self.spatial_size, self.spatial_size, 84 | ) 85 | 86 | else: 87 | frames = [x.to_rgb().to_ndarray() for x in frames] 88 | frames = torch.as_tensor(np.stack(frames)) 89 | frames = frames.float() / 255. 
90 | 91 | frames = (frames - self.mean) / self.std 92 | frames = frames.permute(3, 0, 1, 2) # C, T, H, W 93 | 94 | if frames.size(-2) < frames.size(-1): 95 | new_width = frames.size(-1) * self.spatial_size // frames.size(-2) 96 | new_height = self.spatial_size 97 | else: 98 | new_height = frames.size(-2) * self.spatial_size // frames.size(-1) 99 | new_width = self.spatial_size 100 | frames = torch.nn.functional.interpolate( 101 | frames, size=(new_height, new_width), 102 | mode='bilinear', align_corners=False, 103 | ) 104 | 105 | frames = self._generate_spatial_crops(frames) 106 | frames = sum([self._generate_temporal_crops(x) for x in frames], []) 107 | if len(frames) > 1: 108 | frames = torch.stack(frames) 109 | 110 | return frames, label 111 | 112 | 113 | def _generate_temporal_crops(self, frames): 114 | seg_len = (self.num_frames - 1) * self.sampling_rate + 1 115 | if frames.size(1) < seg_len: 116 | frames = torch.cat([frames, frames[:, -1:].repeat(1, seg_len - frames.size(1), 1, 1)], dim=1) 117 | slide_len = frames.size(1) - seg_len 118 | 119 | crops = [] 120 | for i in range(self.num_temporal_views): 121 | if self.num_temporal_views == 1: 122 | st = slide_len // 2 123 | else: 124 | st = round(slide_len / (self.num_temporal_views - 1) * i) 125 | 126 | crops.append(frames[:, st: st + self.num_frames * self.sampling_rate: self.sampling_rate]) 127 | 128 | return crops 129 | 130 | 131 | def _generate_spatial_crops(self, frames): 132 | if self.num_spatial_views == 1: 133 | assert min(frames.size(-2), frames.size(-1)) >= self.spatial_size 134 | h_st = (frames.size(-2) - self.spatial_size) // 2 135 | w_st = (frames.size(-1) - self.spatial_size) // 2 136 | h_ed, w_ed = h_st + self.spatial_size, w_st + self.spatial_size 137 | return [frames[:, :, h_st: h_ed, w_st: w_ed]] 138 | 139 | elif self.num_spatial_views == 3: 140 | assert min(frames.size(-2), frames.size(-1)) == self.spatial_size 141 | crops = [] 142 | margin = max(frames.size(-2), frames.size(-1)) - 
self.spatial_size 143 | for st in (0, margin // 2, margin): 144 | ed = st + self.spatial_size 145 | if frames.size(-2) > frames.size(-1): 146 | crops.append(frames[:, :, st: ed, :]) 147 | else: 148 | crops.append(frames[:, :, :, st: ed]) 149 | return crops 150 | 151 | else: 152 | raise NotImplementedError() 153 | 154 | 155 | def _random_sample_frame_idx(self, len): 156 | frame_indices = [] 157 | 158 | if self.sampling_rate < 0: # tsn sample 159 | seg_size = (len - 1) / self.num_frames 160 | for i in range(self.num_frames): 161 | start, end = round(seg_size * i), round(seg_size * (i + 1)) 162 | frame_indices.append(np.random.randint(start, end + 1)) 163 | elif self.sampling_rate * (self.num_frames - 1) + 1 >= len: 164 | for i in range(self.num_frames): 165 | frame_indices.append(i * self.sampling_rate if i * self.sampling_rate < len else frame_indices[-1]) 166 | else: 167 | start = np.random.randint(len - self.sampling_rate * (self.num_frames - 1)) 168 | frame_indices = list(range(start, start + self.sampling_rate * self.num_frames, self.sampling_rate)) 169 | 170 | return frame_indices 171 | 172 | 173 | class DummyDataset(torch.utils.data.Dataset): 174 | 175 | def __init__(self, list_path: str, num_frames: int, num_views: int, spatial_size: int): 176 | with open(list_path) as f: 177 | self.len = len(f.read().splitlines()) 178 | self.num_frames = num_frames 179 | self.num_views = num_views 180 | self.spatial_size = spatial_size 181 | 182 | def __len__(self): 183 | return self.len 184 | 185 | def __getitem__(self, _): 186 | shape = [3, self.num_frames, self.spatial_size, self.spatial_size] 187 | if self.num_views != 1: 188 | shape = [self.num_views] + shape 189 | return torch.zeros(shape), 0 190 | -------------------------------------------------------------------------------- /video_dataset/rand_augment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Originates from: 
https://github.com/facebookresearch/SlowFast/blob/fee19d699c49a81f33b890c5ff592bbb11aa5c54/slowfast/datasets/rand_augment.py 3 | 4 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 5 | 6 | """ 7 | This implementation is based on 8 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py 9 | published under an Apache License 2.0. 10 | 11 | COMMENT FROM ORIGINAL: 12 | AutoAugment, RandAugment, and AugMix for PyTorch 13 | This code implements the searched ImageNet policies with various tweaks and 14 | improvements and does not include any of the search code. AA and RA 15 | Implementation adapted from: 16 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 17 | AugMix adapted from: 18 | https://github.com/google-research/augmix 19 | Papers: 20 | AutoAugment: Learning Augmentation Policies from Data 21 | https://arxiv.org/abs/1805.09501 22 | Learning Data Augmentation Strategies for Object Detection 23 | https://arxiv.org/abs/1906.11172 24 | RandAugment: Practical automated data augmentation... 25 | https://arxiv.org/abs/1909.13719 26 | AugMix: A Simple Data Processing Method to Improve Robustness and 27 | Uncertainty https://arxiv.org/abs/1912.02781 28 | 29 | Hacked together by / Copyright 2020 Ross Wightman 30 | """ 31 | 32 | import math 33 | import numpy as np 34 | import random 35 | import re 36 | import PIL 37 | from PIL import Image, ImageEnhance, ImageOps 38 | 39 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) 40 | 41 | _FILL = (128, 128, 128) 42 | 43 | # This signifies the max integer that the controller RNN could predict for the 44 | # augmentation scheme.
45 | _MAX_LEVEL = 10.0 46 | 47 | _HPARAMS_DEFAULT = { 48 | "translate_const": 250, 49 | "img_mean": _FILL, 50 | } 51 | 52 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 53 | 54 | 55 | def _interpolation(kwargs): 56 | interpolation = kwargs.pop("resample", Image.BILINEAR) 57 | if isinstance(interpolation, (list, tuple)): 58 | return random.choice(interpolation) 59 | else: 60 | return interpolation 61 | 62 | 63 | def _check_args_tf(kwargs): 64 | if "fillcolor" in kwargs and _PIL_VER < (5, 0): 65 | kwargs.pop("fillcolor") 66 | kwargs["resample"] = _interpolation(kwargs) 67 | 68 | 69 | def shear_x(img, factor, **kwargs): 70 | _check_args_tf(kwargs) 71 | return img.transform( 72 | img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs 73 | ) 74 | 75 | 76 | def shear_y(img, factor, **kwargs): 77 | _check_args_tf(kwargs) 78 | return img.transform( 79 | img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs 80 | ) 81 | 82 | 83 | def translate_x_rel(img, pct, **kwargs): 84 | pixels = pct * img.size[0] 85 | _check_args_tf(kwargs) 86 | return img.transform( 87 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs 88 | ) 89 | 90 | 91 | def translate_y_rel(img, pct, **kwargs): 92 | pixels = pct * img.size[1] 93 | _check_args_tf(kwargs) 94 | return img.transform( 95 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs 96 | ) 97 | 98 | 99 | def translate_x_abs(img, pixels, **kwargs): 100 | _check_args_tf(kwargs) 101 | return img.transform( 102 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs 103 | ) 104 | 105 | 106 | def translate_y_abs(img, pixels, **kwargs): 107 | _check_args_tf(kwargs) 108 | return img.transform( 109 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs 110 | ) 111 | 112 | 113 | def rotate(img, degrees, **kwargs): 114 | _check_args_tf(kwargs) 115 | if _PIL_VER >= (5, 2): 116 | return img.rotate(degrees, **kwargs) 117 | elif _PIL_VER >= (5, 0): 118 | w, h = img.size 119 | post_trans = (0, 0) 120 | rotn_center = (w 
/ 2.0, h / 2.0) 121 | angle = -math.radians(degrees) 122 | matrix = [ 123 | round(math.cos(angle), 15), 124 | round(math.sin(angle), 15), 125 | 0.0, 126 | round(-math.sin(angle), 15), 127 | round(math.cos(angle), 15), 128 | 0.0, 129 | ] 130 | 131 | def transform(x, y, matrix): 132 | (a, b, c, d, e, f) = matrix 133 | return a * x + b * y + c, d * x + e * y + f 134 | 135 | matrix[2], matrix[5] = transform( 136 | -rotn_center[0] - post_trans[0], 137 | -rotn_center[1] - post_trans[1], 138 | matrix, 139 | ) 140 | matrix[2] += rotn_center[0] 141 | matrix[5] += rotn_center[1] 142 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs) 143 | else: 144 | return img.rotate(degrees, resample=kwargs["resample"]) 145 | 146 | 147 | def auto_contrast(img, **__): 148 | return ImageOps.autocontrast(img) 149 | 150 | 151 | def invert(img, **__): 152 | return ImageOps.invert(img) 153 | 154 | 155 | def equalize(img, **__): 156 | return ImageOps.equalize(img) 157 | 158 | 159 | def solarize(img, thresh, **__): 160 | return ImageOps.solarize(img, thresh) 161 | 162 | 163 | def solarize_add(img, add, thresh=128, **__): 164 | lut = [] 165 | for i in range(256): 166 | if i < thresh: 167 | lut.append(min(255, i + add)) 168 | else: 169 | lut.append(i) 170 | if img.mode in ("L", "RGB"): 171 | if img.mode == "RGB" and len(lut) == 256: 172 | lut = lut + lut + lut 173 | return img.point(lut) 174 | else: 175 | return img 176 | 177 | 178 | def posterize(img, bits_to_keep, **__): 179 | if bits_to_keep >= 8: 180 | return img 181 | return ImageOps.posterize(img, bits_to_keep) 182 | 183 | 184 | def contrast(img, factor, **__): 185 | return ImageEnhance.Contrast(img).enhance(factor) 186 | 187 | 188 | def color(img, factor, **__): 189 | return ImageEnhance.Color(img).enhance(factor) 190 | 191 | 192 | def brightness(img, factor, **__): 193 | return ImageEnhance.Brightness(img).enhance(factor) 194 | 195 | 196 | def sharpness(img, factor, **__): 197 | return 
ImageEnhance.Sharpness(img).enhance(factor) 198 | 199 | 200 | def _randomly_negate(v): 201 | """With 50% prob, negate the value""" 202 | return -v if random.random() > 0.5 else v 203 | 204 | 205 | def _rotate_level_to_arg(level, _hparams): 206 | # range [-30, 30] 207 | level = (level / _MAX_LEVEL) * 30.0 208 | level = _randomly_negate(level) 209 | return (level,) 210 | 211 | 212 | def _enhance_level_to_arg(level, _hparams): 213 | # range [0.1, 1.9] 214 | return ((level / _MAX_LEVEL) * 1.8 + 0.1,) 215 | 216 | 217 | def _enhance_increasing_level_to_arg(level, _hparams): 218 | # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend 219 | # range [0.1, 1.9] 220 | level = (level / _MAX_LEVEL) * 0.9 221 | level = 1.0 + _randomly_negate(level) 222 | return (level,) 223 | 224 | 225 | def _shear_level_to_arg(level, _hparams): 226 | # range [-0.3, 0.3] 227 | level = (level / _MAX_LEVEL) * 0.3 228 | level = _randomly_negate(level) 229 | return (level,) 230 | 231 | 232 | def _translate_abs_level_to_arg(level, hparams): 233 | translate_const = hparams["translate_const"] 234 | level = (level / _MAX_LEVEL) * float(translate_const) 235 | level = _randomly_negate(level) 236 | return (level,) 237 | 238 | 239 | def _translate_rel_level_to_arg(level, hparams): 240 | # default range [-0.45, 0.45] 241 | translate_pct = hparams.get("translate_pct", 0.45) 242 | level = (level / _MAX_LEVEL) * translate_pct 243 | level = _randomly_negate(level) 244 | return (level,) 245 | 246 | 247 | def _posterize_level_to_arg(level, _hparams): 248 | # As per Tensorflow TPU EfficientNet impl 249 | # range [0, 4], 'keep 0 up to 4 MSB of original image' 250 | # intensity/severity of augmentation decreases with level 251 | return (int((level / _MAX_LEVEL) * 4),) 252 | 253 | 254 | def _posterize_increasing_level_to_arg(level, hparams): 255 | # As per Tensorflow models research and UDA impl 256 | # range [4, 0], 'keep 4 down to 0 MSB of original image', 257 | # 
intensity/severity of augmentation increases with level 258 | return (4 - _posterize_level_to_arg(level, hparams)[0],) 259 | 260 | 261 | def _posterize_original_level_to_arg(level, _hparams): 262 | # As per original AutoAugment paper description 263 | # range [4, 8], 'keep 4 up to 8 MSB of image' 264 | # intensity/severity of augmentation decreases with level 265 | return (int((level / _MAX_LEVEL) * 4) + 4,) 266 | 267 | 268 | def _solarize_level_to_arg(level, _hparams): 269 | # range [0, 256] 270 | # intensity/severity of augmentation decreases with level 271 | return (int((level / _MAX_LEVEL) * 256),) 272 | 273 | 274 | def _solarize_increasing_level_to_arg(level, _hparams): 275 | # range [0, 256] 276 | # intensity/severity of augmentation increases with level 277 | return (256 - _solarize_level_to_arg(level, _hparams)[0],) 278 | 279 | 280 | def _solarize_add_level_to_arg(level, _hparams): 281 | # range [0, 110] 282 | return (int((level / _MAX_LEVEL) * 110),) 283 | 284 | 285 | LEVEL_TO_ARG = { 286 | "AutoContrast": None, 287 | "Equalize": None, 288 | "Invert": None, 289 | "Rotate": _rotate_level_to_arg, 290 | # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers 291 | "Posterize": _posterize_level_to_arg, 292 | "PosterizeIncreasing": _posterize_increasing_level_to_arg, 293 | "PosterizeOriginal": _posterize_original_level_to_arg, 294 | "Solarize": _solarize_level_to_arg, 295 | "SolarizeIncreasing": _solarize_increasing_level_to_arg, 296 | "SolarizeAdd": _solarize_add_level_to_arg, 297 | "Color": _enhance_level_to_arg, 298 | "ColorIncreasing": _enhance_increasing_level_to_arg, 299 | "Contrast": _enhance_level_to_arg, 300 | "ContrastIncreasing": _enhance_increasing_level_to_arg, 301 | "Brightness": _enhance_level_to_arg, 302 | "BrightnessIncreasing": _enhance_increasing_level_to_arg, 303 | "Sharpness": _enhance_level_to_arg, 304 | "SharpnessIncreasing": _enhance_increasing_level_to_arg, 305 | "ShearX": 
_shear_level_to_arg, 306 | "ShearY": _shear_level_to_arg, 307 | "TranslateX": _translate_abs_level_to_arg, 308 | "TranslateY": _translate_abs_level_to_arg, 309 | "TranslateXRel": _translate_rel_level_to_arg, 310 | "TranslateYRel": _translate_rel_level_to_arg, 311 | } 312 | 313 | 314 | NAME_TO_OP = { 315 | "AutoContrast": auto_contrast, 316 | "Equalize": equalize, 317 | "Invert": invert, 318 | "Rotate": rotate, 319 | "Posterize": posterize, 320 | "PosterizeIncreasing": posterize, 321 | "PosterizeOriginal": posterize, 322 | "Solarize": solarize, 323 | "SolarizeIncreasing": solarize, 324 | "SolarizeAdd": solarize_add, 325 | "Color": color, 326 | "ColorIncreasing": color, 327 | "Contrast": contrast, 328 | "ContrastIncreasing": contrast, 329 | "Brightness": brightness, 330 | "BrightnessIncreasing": brightness, 331 | "Sharpness": sharpness, 332 | "SharpnessIncreasing": sharpness, 333 | "ShearX": shear_x, 334 | "ShearY": shear_y, 335 | "TranslateX": translate_x_abs, 336 | "TranslateY": translate_y_abs, 337 | "TranslateXRel": translate_x_rel, 338 | "TranslateYRel": translate_y_rel, 339 | } 340 | 341 | 342 | class AugmentOp: 343 | """ 344 | Apply for video. 345 | """ 346 | 347 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None): 348 | hparams = hparams or _HPARAMS_DEFAULT 349 | self.aug_fn = NAME_TO_OP[name] 350 | self.level_fn = LEVEL_TO_ARG[name] 351 | self.prob = prob 352 | self.magnitude = magnitude 353 | self.hparams = hparams.copy() 354 | self.kwargs = { 355 | "fillcolor": hparams["img_mean"] 356 | if "img_mean" in hparams 357 | else _FILL, 358 | "resample": hparams["interpolation"] 359 | if "interpolation" in hparams 360 | else _RANDOM_INTERPOLATION, 361 | } 362 | 363 | # If magnitude_std is > 0, we introduce some randomness 364 | # in the usually fixed policy and sample magnitude from a normal distribution 365 | # with mean `magnitude` and std-dev of `magnitude_std`. 366 | # NOTE This is my own hack, being tested, not in papers or reference impls. 
367 | self.magnitude_std = self.hparams.get("magnitude_std", 0) 368 | 369 | def __call__(self, img_list): 370 | if self.prob < 1.0 and random.random() > self.prob: 371 | return img_list 372 | magnitude = self.magnitude 373 | if self.magnitude_std and self.magnitude_std > 0: 374 | magnitude = random.gauss(magnitude, self.magnitude_std) 375 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range 376 | level_args = ( 377 | self.level_fn(magnitude, self.hparams) 378 | if self.level_fn is not None 379 | else () 380 | ) 381 | 382 | if isinstance(img_list, list): 383 | return [ 384 | self.aug_fn(img, *level_args, **self.kwargs) for img in img_list 385 | ] 386 | else: 387 | return self.aug_fn(img_list, *level_args, **self.kwargs) 388 | 389 | 390 | _RAND_TRANSFORMS = [ 391 | "AutoContrast", 392 | "Equalize", 393 | "Invert", 394 | "Rotate", 395 | "Posterize", 396 | "Solarize", 397 | "SolarizeAdd", 398 | "Color", 399 | "Contrast", 400 | "Brightness", 401 | "Sharpness", 402 | "ShearX", 403 | "ShearY", 404 | "TranslateXRel", 405 | "TranslateYRel", 406 | ] 407 | 408 | 409 | _RAND_INCREASING_TRANSFORMS = [ 410 | "AutoContrast", 411 | "Equalize", 412 | "Invert", 413 | "Rotate", 414 | "PosterizeIncreasing", 415 | "SolarizeIncreasing", 416 | "SolarizeAdd", 417 | "ColorIncreasing", 418 | "ContrastIncreasing", 419 | "BrightnessIncreasing", 420 | "SharpnessIncreasing", 421 | "ShearX", 422 | "ShearY", 423 | "TranslateXRel", 424 | "TranslateYRel", 425 | ] 426 | 427 | 428 | # These experimental weights are based loosely on the relative improvements mentioned in the paper. 429 | # They may not result in increased performance, but could likely be tuned to do so.
430 | _RAND_CHOICE_WEIGHTS_0 = { 431 | "Rotate": 0.3, 432 | "ShearX": 0.2, 433 | "ShearY": 0.2, 434 | "TranslateXRel": 0.1, 435 | "TranslateYRel": 0.1, 436 | "Color": 0.025, 437 | "Sharpness": 0.025, 438 | "AutoContrast": 0.025, 439 | "Solarize": 0.005, 440 | "SolarizeAdd": 0.005, 441 | "Contrast": 0.005, 442 | "Brightness": 0.005, 443 | "Equalize": 0.005, 444 | "Posterize": 0, 445 | "Invert": 0, 446 | } 447 | 448 | 449 | def _select_rand_weights(weight_idx=0, transforms=None): 450 | transforms = transforms or _RAND_TRANSFORMS 451 | assert weight_idx == 0 # only one set of weights currently 452 | rand_weights = _RAND_CHOICE_WEIGHTS_0 453 | probs = [rand_weights[k] for k in transforms] 454 | probs /= np.sum(probs) 455 | return probs 456 | 457 | 458 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None): 459 | hparams = hparams or _HPARAMS_DEFAULT 460 | transforms = transforms or _RAND_TRANSFORMS 461 | return [ 462 | AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) 463 | for name in transforms 464 | ] 465 | 466 | 467 | class RandAugment: 468 | def __init__(self, ops, num_layers=2, choice_weights=None): 469 | self.ops = ops 470 | self.num_layers = num_layers 471 | self.choice_weights = choice_weights 472 | 473 | def __call__(self, img): 474 | # no replacement when using weighted choice 475 | ops = np.random.choice( 476 | self.ops, 477 | self.num_layers, 478 | replace=self.choice_weights is None, 479 | p=self.choice_weights, 480 | ) 481 | for op in ops: 482 | img = op(img) 483 | return img 484 | 485 | 486 | def rand_augment_transform(config_str, hparams): 487 | """ 488 | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 489 | 490 | Create a RandAugment transform 491 | :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by 492 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). 
The remaining 493 | sections, which are not order specific, determine 494 | 'm' - integer magnitude of rand augment 495 | 'n' - integer num layers (number of transform ops selected per image) 496 | 'w' - integer probability weight index (index of a set of weights to influence choice of op) 497 | 'mstd' - float std deviation of magnitude noise applied 498 | 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) 499 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 500 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 501 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme 502 | :return: A PyTorch compatible Transform 503 | """ 504 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) 505 | num_layers = 2 # default to 2 ops per image 506 | weight_idx = None # default to no probability weights for op choice 507 | transforms = _RAND_TRANSFORMS 508 | config = config_str.split("-") 509 | assert config[0] == "rand" 510 | config = config[1:] 511 | for c in config: 512 | cs = re.split(r"(\d.*)", c) 513 | if len(cs) < 2: 514 | continue 515 | key, val = cs[:2] 516 | if key == "mstd": 517 | # noise param injected via hparams for now 518 | hparams.setdefault("magnitude_std", float(val)) 519 | elif key == "inc": 520 | if bool(int(val)): # parse as int so that 'inc0' disables it 521 | transforms = _RAND_INCREASING_TRANSFORMS 522 | elif key == "m": 523 | magnitude = int(val) 524 | elif key == "n": 525 | num_layers = int(val) 526 | elif key == "w": 527 | weight_idx = int(val) 528 | else: 529 | raise NotImplementedError(f"unknown RandAugment key: {key}") 530 | ra_ops = rand_augment_ops( 531 | magnitude=magnitude, hparams=hparams, transforms=transforms 532 | ) 533 | choice_weights = ( 534 | None if weight_idx is None else _select_rand_weights(weight_idx) 535 | ) 536 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) 537 |
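The config-string grammar documented in `rand_augment_transform` can be exercised on its own. The sketch below is a hypothetical helper, `parse_rand_config`, which is not part of this repository. It re-implements only the key-parsing loop and returns a plain dict instead of building `RandAugment` ops.

It differs from the original in a few deliberate ways:
- it parses the `inc` value as an integer, so `inc0` means "off";
- it raises on unknown keys;
- it records `mstd` directly rather than injecting it into `hparams` with `setdefault`.

```python
import re


def parse_rand_config(config_str):
    """Hypothetical helper: parse e.g. 'rand-m9-n3-mstd0.5' into a dict.

    Mirrors the key handling of rand_augment_transform() without building
    any augmentation ops. Defaults match that function's defaults.
    """
    sections = config_str.split("-")
    if sections[0] != "rand":
        raise ValueError("config string must start with 'rand'")
    settings = {
        "magnitude": 10.0,  # _MAX_LEVEL in rand_augment.py
        "num_layers": 2,
        "weight_idx": None,
        "magnitude_std": 0.0,
        "increasing": False,
    }
    for section in sections[1:]:
        # Split at the first digit: 'mstd0.5' -> ['mstd', '0.5', '']
        parts = re.split(r"(\d.*)", section)
        if len(parts) < 2:
            continue  # a section without a numeric value is ignored
        key, val = parts[:2]
        if key == "mstd":
            settings["magnitude_std"] = float(val)
        elif key == "inc":
            settings["increasing"] = bool(int(val))  # 'inc0' -> False
        elif key == "m":
            settings["magnitude"] = int(val)
        elif key == "n":
            settings["num_layers"] = int(val)
        elif key == "w":
            settings["weight_idx"] = int(val)
        else:
            raise NotImplementedError(f"unknown RandAugment key: {key!r}")
    return settings


print(parse_rand_config("rand-m9-n3-mstd0.5"))
# {'magnitude': 9, 'num_layers': 3, 'weight_idx': None, 'magnitude_std': 0.5, 'increasing': False}
```

Splitting on the first digit with a capturing group keeps the numeric tail as the value. This is why multi-letter keys such as `mstd` and `inc` can share one loop with the single-letter keys.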
-------------------------------------------------------------------------------- /video_dataset/random_erasing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Originates from: https://github.com/facebookresearch/SlowFast/blob/fee19d699c49a81f33b890c5ff592bbb11aa5c54/slowfast/datasets/random_erasing.py 3 | 4 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 5 | 6 | """ 7 | This implementation is based on 8 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py 9 | published under an Apache License 2.0. 10 | 11 | COMMENT FROM ORIGINAL: 12 | Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0 13 | Copyright Zhun Zhong & Liang Zheng 14 | Hacked together by / Copyright 2020 Ross Wightman 15 | """ 16 | import math 17 | import random 18 | import torch 19 | 20 | 21 | def _get_pixels( 22 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda" 23 | ): 24 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 25 | # paths, flip the order so normal is run on CPU if this becomes a problem 26 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 27 | if per_pixel: 28 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 29 | elif rand_color: 30 | return torch.empty( 31 | (patch_size[0], 1, 1), dtype=dtype, device=device 32 | ).normal_() 33 | else: 34 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 35 | 36 | 37 | class RandomErasing: 38 | """Randomly selects a rectangle region in an image and erases its pixels. 39 | 'Random Erasing Data Augmentation' by Zhong et al. 40 | See https://arxiv.org/pdf/1708.04896.pdf 41 | This variant of RandomErasing is intended to be applied to either a batch 42 | or single image tensor after it has been normalized by dataset mean and std.
43 | Args: 44 | probability: Probability that the Random Erasing operation will be performed. 45 | min_area: Minimum percentage of erased area wrt input image area. 46 | max_area: Maximum percentage of erased area wrt input image area. 47 | min_aspect: Minimum aspect ratio of erased area. 48 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 49 | 'const' - erase block is constant color of 0 for all channels 50 | 'rand' - erase block is same per-channel random (normal) color 51 | 'pixel' - erase block is per-pixel random (normal) color 52 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 53 | per-image count is randomly chosen between 1 and this value. 54 | """ 55 | 56 | def __init__( 57 | self, 58 | probability=0.5, 59 | min_area=0.02, 60 | max_area=1 / 3, 61 | min_aspect=0.3, 62 | max_aspect=None, 63 | mode="const", 64 | min_count=1, 65 | max_count=None, 66 | num_splits=0, 67 | device="cuda", 68 | cube=True, 69 | ): 70 | self.probability = probability 71 | self.min_area = min_area 72 | self.max_area = max_area 73 | max_aspect = max_aspect or 1 / min_aspect 74 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 75 | self.min_count = min_count 76 | self.max_count = max_count or min_count 77 | self.num_splits = num_splits 78 | mode = mode.lower() 79 | self.rand_color = False 80 | self.per_pixel = False 81 | self.cube = cube 82 | if mode == "rand": 83 | self.rand_color = True # per block random normal 84 | elif mode == "pixel": 85 | self.per_pixel = True # per pixel random normal 86 | else: 87 | assert not mode or mode == "const" 88 | self.device = device 89 | 90 | def _erase(self, img, chan, img_h, img_w, dtype): 91 | if random.random() > self.probability: 92 | return 93 | area = img_h * img_w 94 | count = ( 95 | self.min_count 96 | if self.min_count == self.max_count 97 | else random.randint(self.min_count, self.max_count) 98 | ) 99 | for _ in range(count): 100 | for _ in range(10): 101 | 
target_area = ( 102 | random.uniform(self.min_area, self.max_area) * area / count 103 | ) 104 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 105 | h = int(round(math.sqrt(target_area * aspect_ratio))) 106 | w = int(round(math.sqrt(target_area / aspect_ratio))) 107 | if w < img_w and h < img_h: 108 | top = random.randint(0, img_h - h) 109 | left = random.randint(0, img_w - w) 110 | img[:, top : top + h, left : left + w] = _get_pixels( 111 | self.per_pixel, 112 | self.rand_color, 113 | (chan, h, w), 114 | dtype=dtype, 115 | device=self.device, 116 | ) 117 | break 118 | 119 | def _erase_cube( 120 | self, 121 | img, 122 | batch_start, 123 | batch_size, 124 | chan, 125 | img_h, 126 | img_w, 127 | dtype, 128 | ): 129 | if random.random() > self.probability: 130 | return 131 | area = img_h * img_w 132 | count = ( 133 | self.min_count 134 | if self.min_count == self.max_count 135 | else random.randint(self.min_count, self.max_count) 136 | ) 137 | for _ in range(count): 138 | for _ in range(100): 139 | target_area = ( 140 | random.uniform(self.min_area, self.max_area) * area / count 141 | ) 142 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 143 | h = int(round(math.sqrt(target_area * aspect_ratio))) 144 | w = int(round(math.sqrt(target_area / aspect_ratio))) 145 | if w < img_w and h < img_h: 146 | top = random.randint(0, img_h - h) 147 | left = random.randint(0, img_w - w) 148 | for i in range(batch_start, batch_size): 149 | img_instance = img[i] 150 | img_instance[ 151 | :, top : top + h, left : left + w 152 | ] = _get_pixels( 153 | self.per_pixel, 154 | self.rand_color, 155 | (chan, h, w), 156 | dtype=dtype, 157 | device=self.device, 158 | ) 159 | break 160 | 161 | def __call__(self, input): 162 | if len(input.size()) == 3: 163 | self._erase(input, *input.size(), input.dtype) 164 | else: 165 | batch_size, chan, img_h, img_w = input.size() 166 | # skip first slice of batch if num_splits is set (for clean portion of samples) 167 | 
batch_start = ( 168 | batch_size // self.num_splits if self.num_splits > 1 else 0 169 | ) 170 | if self.cube: 171 | self._erase_cube( 172 | input, 173 | batch_start, 174 | batch_size, 175 | chan, 176 | img_h, 177 | img_w, 178 | input.dtype, 179 | ) 180 | else: 181 | for i in range(batch_start, batch_size): 182 | self._erase(input[i], chan, img_h, img_w, input.dtype) 183 | return input 184 | -------------------------------------------------------------------------------- /video_dataset/transform.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Originates from: https://github.com/facebookresearch/SlowFast/blob/fee19d699c49a81f33b890c5ff592bbb11aa5c54/slowfast/datasets/transform.py 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | 5 | import logging 6 | import math 7 | import numpy as np 8 | 9 | # import cv2 10 | import random 11 | import torch 12 | import torchvision as tv 13 | import torchvision.transforms.functional as F 14 | from PIL import Image, ImageFilter 15 | from torchvision import transforms 16 | 17 | from .rand_augment import rand_augment_transform 18 | from .random_erasing import RandomErasing 19 | 20 | _pil_interpolation_to_str = { 21 | Image.NEAREST: "PIL.Image.NEAREST", 22 | Image.BILINEAR: "PIL.Image.BILINEAR", 23 | Image.BICUBIC: "PIL.Image.BICUBIC", 24 | Image.LANCZOS: "PIL.Image.LANCZOS", 25 | Image.HAMMING: "PIL.Image.HAMMING", 26 | Image.BOX: "PIL.Image.BOX", 27 | } 28 | 29 | 30 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 31 | 32 | 33 | def _pil_interp(method): 34 | if method == "bicubic": 35 | return Image.BICUBIC 36 | elif method == "lanczos": 37 | return Image.LANCZOS 38 | elif method == "hamming": 39 | return Image.HAMMING 40 | else: 41 | return Image.BILINEAR 42 | 43 | 44 | logger = logging.getLogger(__name__) 45 | 46 | 47 | def random_short_side_scale_jitter( 48 | images, min_size, max_size, boxes=None,
inverse_uniform_sampling=False 49 | ): 50 | """ 51 | Perform a spatial short scale jittering on the given images and 52 | corresponding boxes. 53 | Args: 54 | images (tensor): images to perform scale jitter. Dimension is 55 | `num frames` x `channel` x `height` x `width`. 56 | min_size (int): the minimal size to scale the frames. 57 | max_size (int): the maximal size to scale the frames. 58 | boxes (ndarray): optional. Corresponding boxes to images. 59 | Dimension is `num boxes` x 4. 60 | inverse_uniform_sampling (bool): if True, sample uniformly in 61 | [1 / max_size, 1 / min_size] and take a reciprocal to get the 62 | size. If False, take a uniform sample from [min_size, max_size]. 63 | Returns: 64 | (tensor): the scaled images with dimension of 65 | `num frames` x `channel` x `new height` x `new width`. 66 | (ndarray or None): the scaled boxes with dimension of 67 | `num boxes` x 4. 68 | """ 69 | if inverse_uniform_sampling: 70 | size = int( 71 | round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size)) 72 | ) 73 | else: 74 | size = int(round(np.random.uniform(min_size, max_size))) 75 | 76 | height = images.shape[2] 77 | width = images.shape[3] 78 | if (width <= height and width == size) or ( 79 | height <= width and height == size 80 | ): 81 | return images, boxes 82 | new_width = size 83 | new_height = size 84 | if width < height: 85 | new_height = int(math.floor((float(height) / width) * size)) 86 | if boxes is not None: 87 | boxes = boxes * float(new_height) / height 88 | else: 89 | new_width = int(math.floor((float(width) / height) * size)) 90 | if boxes is not None: 91 | boxes = boxes * float(new_width) / width 92 | 93 | return ( 94 | torch.nn.functional.interpolate( 95 | images, 96 | size=(new_height, new_width), 97 | mode="bilinear", 98 | align_corners=False, 99 | ), 100 | boxes, 101 | ) 102 | 103 | 104 | def crop_boxes(boxes, x_offset, y_offset): 105 | """ 106 | Perform crop on the bounding boxes given the offsets.
107 | Args: 108 | boxes (ndarray or None): bounding boxes to perform crop. The dimension 109 | is `num boxes` x 4. 110 | x_offset (int): cropping offset in the x axis. 111 | y_offset (int): cropping offset in the y axis. 112 | Returns: 113 | cropped_boxes (ndarray or None): the cropped boxes with dimension of 114 | `num boxes` x 4. 115 | """ 116 | cropped_boxes = boxes.copy() 117 | cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset 118 | cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset 119 | 120 | return cropped_boxes 121 | 122 | 123 | def random_crop(images, size, boxes=None): 124 | """ 125 | Perform random spatial crop on the given images and corresponding boxes. 126 | Args: 127 | images (tensor): images to perform random crop. The dimension is 128 | `num frames` x `channel` x `height` x `width`. 129 | size (int): the size of height and width to crop on the image. 130 | boxes (ndarray or None): optional. Corresponding boxes to images. 131 | Dimension is `num boxes` x 4. 132 | Returns: 133 | cropped (tensor): cropped images with dimension of 134 | `num frames` x `channel` x `size` x `size`. 135 | cropped_boxes (ndarray or None): the cropped boxes with dimension of 136 | `num boxes` x 4. 137 | """ 138 | if images.shape[2] == size and images.shape[3] == size: 139 | return images, boxes 140 | height = images.shape[2] 141 | width = images.shape[3] 142 | y_offset = 0 143 | if height > size: 144 | y_offset = int(np.random.randint(0, height - size)) 145 | x_offset = 0 146 | if width > size: 147 | x_offset = int(np.random.randint(0, width - size)) 148 | cropped = images[ 149 | :, :, y_offset : y_offset + size, x_offset : x_offset + size 150 | ] 151 | 152 | cropped_boxes = ( 153 | crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None 154 | ) 155 | 156 | return cropped, cropped_boxes 157 | 158 | 159 | def horizontal_flip(prob, images, boxes=None): 160 | """ 161 | Perform horizontal flip on the given images and corresponding boxes.
162 | Args: 163 | prob (float): probability to flip the images. 164 | images (tensor): images to perform horizontal flip, the dimension is 165 | `num frames` x `channel` x `height` x `width`. 166 | boxes (ndarray or None): optional. Corresponding boxes to images. 167 | Dimension is `num boxes` x 4. 168 | Returns: 169 | images (tensor): images with dimension of 170 | `num frames` x `channel` x `height` x `width`. 171 | flipped_boxes (ndarray or None): the flipped boxes with dimension of 172 | `num boxes` x 4. 173 | """ 174 | if boxes is None: 175 | flipped_boxes = None 176 | else: 177 | flipped_boxes = boxes.copy() 178 | 179 | if np.random.uniform() < prob: 180 | images = images.flip((-1)) 181 | 182 | if len(images.shape) == 3: 183 | width = images.shape[2] 184 | elif len(images.shape) == 4: 185 | width = images.shape[3] 186 | else: 187 | raise NotImplementedError("Dimension is not supported") 188 | if boxes is not None: 189 | flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1 190 | 191 | return images, flipped_boxes 192 | 193 | 194 | def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): 195 | """ 196 | Perform uniform spatial sampling on the images and corresponding boxes. 197 | Args: 198 | images (tensor): images to perform uniform crop. The dimension is 199 | `num frames` x `channel` x `height` x `width`. 200 | size (int): size of height and width to crop the images. 201 | spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width 202 | is larger than height. Or 0, 1, or 2 for top, center, and bottom 203 | crop if height is larger than width. 204 | boxes (ndarray or None): optional. Corresponding boxes to images. 205 | Dimension is `num boxes` x 4. 206 | scale_size (int): optional. If not None, resize the images to scale_size before 207 | performing any crop. 208 | Returns: 209 | cropped (tensor): images with dimension of 210 | `num frames` x `channel` x `size` x `size`.
211 | cropped_boxes (ndarray or None): the cropped boxes with dimension of 212 | `num boxes` x 4. 213 | """ 214 | assert spatial_idx in [0, 1, 2] 215 | ndim = len(images.shape) 216 | if ndim == 3: 217 | images = images.unsqueeze(0) 218 | height = images.shape[2] 219 | width = images.shape[3] 220 | 221 | if scale_size is not None: 222 | if width <= height: 223 | width, height = scale_size, int(height / width * scale_size) 224 | else: 225 | width, height = int(width / height * scale_size), scale_size 226 | images = torch.nn.functional.interpolate( 227 | images, 228 | size=(height, width), 229 | mode="bilinear", 230 | align_corners=False, 231 | ) 232 | 233 | y_offset = int(math.ceil((height - size) / 2)) 234 | x_offset = int(math.ceil((width - size) / 2)) 235 | 236 | if height > width: 237 | if spatial_idx == 0: 238 | y_offset = 0 239 | elif spatial_idx == 2: 240 | y_offset = height - size 241 | else: 242 | if spatial_idx == 0: 243 | x_offset = 0 244 | elif spatial_idx == 2: 245 | x_offset = width - size 246 | cropped = images[ 247 | :, :, y_offset : y_offset + size, x_offset : x_offset + size 248 | ] 249 | cropped_boxes = ( 250 | crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None 251 | ) 252 | if ndim == 3: 253 | cropped = cropped.squeeze(0) 254 | return cropped, cropped_boxes 255 | 256 | 257 | def clip_boxes_to_image(boxes, height, width): 258 | """ 259 | Clip an array of boxes to an image with the given height and width. 260 | Args: 261 | boxes (ndarray): bounding boxes to perform clipping. 262 | Dimension is `num boxes` x 4. 263 | height (int): given image height. 264 | width (int): given image width. 265 | Returns: 266 | clipped_boxes (ndarray): the clipped boxes with dimension of 267 | `num boxes` x 4. 
268 | """ 269 | clipped_boxes = boxes.copy() 270 | clipped_boxes[:, [0, 2]] = np.minimum( 271 | width - 1.0, np.maximum(0.0, boxes[:, [0, 2]]) 272 | ) 273 | clipped_boxes[:, [1, 3]] = np.minimum( 274 | height - 1.0, np.maximum(0.0, boxes[:, [1, 3]]) 275 | ) 276 | return clipped_boxes 277 | 278 | 279 | def blend(images1, images2, alpha): 280 | """ 281 | Blend two images with a given weight alpha. 282 | Args: 283 | images1 (tensor): the first images to be blended, the dimension is 284 | `num frames` x `channel` x `height` x `width`. 285 | images2 (tensor): the second images to be blended, the dimension is 286 | `num frames` x `channel` x `height` x `width`. 287 | alpha (float): the blending weight. 288 | Returns: 289 | (tensor): blended images, the dimension is 290 | `num frames` x `channel` x `height` x `width`. 291 | """ 292 | return images1 * alpha + images2 * (1 - alpha) 293 | 294 | 295 | def grayscale(images): 296 | """ 297 | Get the grayscale for the input images. The channels of images should be 298 | in order BGR. 299 | Args: 300 | images (tensor): the input images for getting grayscale. Dimension is 301 | `num frames` x `channel` x `height` x `width`. 302 | Returns: 303 | img_gray (tensor): grayscale images, the dimension is 304 | `num frames` x `channel` x `height` x `width`. 305 | """ 306 | # R -> 0.299, G -> 0.587, B -> 0.114. 307 | img_gray = images.clone() 308 | gray_channel = ( 309 | 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0] 310 | ) 311 | img_gray[:, 0] = gray_channel 312 | img_gray[:, 1] = gray_channel 313 | img_gray[:, 2] = gray_channel 314 | return img_gray 315 | 316 | 317 | def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0): 318 | """ 319 | Perform color jittering on the input images. The channels of images 320 | should be in order BGR. 321 | Args: 322 | images (tensor): images to perform color jitter. Dimension is 323 | `num frames` x `channel` x `height` x `width`.
324 | img_brightness (float): jitter ratio for brightness. 325 | img_contrast (float): jitter ratio for contrast. 326 | img_saturation (float): jitter ratio for saturation. 327 | Returns: 328 | images (tensor): the jittered images, the dimension is 329 | `num frames` x `channel` x `height` x `width`. 330 | """ 331 | 332 | jitter = [] 333 | if img_brightness != 0: 334 | jitter.append("brightness") 335 | if img_contrast != 0: 336 | jitter.append("contrast") 337 | if img_saturation != 0: 338 | jitter.append("saturation") 339 | 340 | if len(jitter) > 0: 341 | order = np.random.permutation(np.arange(len(jitter))) 342 | for idx in range(0, len(jitter)): 343 | if jitter[order[idx]] == "brightness": 344 | images = brightness_jitter(img_brightness, images) 345 | elif jitter[order[idx]] == "contrast": 346 | images = contrast_jitter(img_contrast, images) 347 | elif jitter[order[idx]] == "saturation": 348 | images = saturation_jitter(img_saturation, images) 349 | return images 350 | 351 | 352 | def brightness_jitter(var, images): 353 | """ 354 | Perform brightness jittering on the input images. The channels of images 355 | should be in order BGR. 356 | Args: 357 | var (float): jitter ratio for brightness. 358 | images (tensor): images to perform brightness jitter. Dimension is 359 | `num frames` x `channel` x `height` x `width`. 360 | Returns: 361 | images (tensor): the jittered images, the dimension is 362 | `num frames` x `channel` x `height` x `width`. 363 | """ 364 | alpha = 1.0 + np.random.uniform(-var, var) 365 | 366 | img_bright = torch.zeros_like(images) 367 | images = blend(images, img_bright, alpha) 368 | return images 369 | 370 | 371 | def contrast_jitter(var, images): 372 | """ 373 | Perform contrast jittering on the input images. The channels of images 374 | should be in order BGR. 375 | Args: 376 | var (float): jitter ratio for contrast. 377 | images (tensor): images to perform contrast jitter. Dimension is 378 | `num frames` x `channel` x `height` x `width`.
379 | Returns: 380 | images (tensor): the jittered images, the dimension is 381 | `num frames` x `channel` x `height` x `width`. 382 | """ 383 | alpha = 1.0 + np.random.uniform(-var, var) 384 | 385 | img_gray = grayscale(images) 386 | img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True) 387 | images = blend(images, img_gray, alpha) 388 | return images 389 | 390 | 391 | def saturation_jitter(var, images): 392 | """ 393 | Perform saturation jittering on the input images. The channels of images 394 | should be in order BGR. 395 | Args: 396 | var (float): jitter ratio for saturation. 397 | images (tensor): images to perform saturation jitter. Dimension is 398 | `num frames` x `channel` x `height` x `width`. 399 | Returns: 400 | images (tensor): the jittered images, the dimension is 401 | `num frames` x `channel` x `height` x `width`. 402 | """ 403 | alpha = 1.0 + np.random.uniform(-var, var) 404 | img_gray = grayscale(images) 405 | images = blend(images, img_gray, alpha) 406 | 407 | return images 408 | 409 | 410 | def lighting_jitter(images, alphastd, eigval, eigvec): 411 | """ 412 | Perform AlexNet-style PCA jitter on the given images. 413 | Args: 414 | images (tensor): images to perform lighting jitter. Dimension is 415 | `num frames` x `channel` x `height` x `width`. 416 | alphastd (float): jitter ratio for PCA jitter. 417 | eigval (list): eigenvalues for PCA jitter. 418 | eigvec (list[list]): eigenvectors for PCA jitter. 419 | Returns: 420 | out_images (tensor): the jittered images, the dimension is 421 | `num frames` x `channel` x `height` x `width`. 422 | """ 423 | if alphastd == 0: 424 | return images 425 | # generate alpha1, alpha2, alpha3.
426 | alpha = np.random.normal(0, alphastd, size=(1, 3)) 427 | eig_vec = np.array(eigvec) 428 | eig_val = np.reshape(eigval, (1, 3)) 429 | rgb = np.sum( 430 | eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0), 431 | axis=1, 432 | ) 433 | out_images = torch.zeros_like(images) 434 | if len(images.shape) == 3: 435 | # C H W 436 | channel_dim = 0 437 | elif len(images.shape) == 4: 438 | # T C H W 439 | channel_dim = 1 440 | else: 441 | raise NotImplementedError(f"Unsupported dimension {len(images.shape)}") 442 | 443 | for idx in range(images.shape[channel_dim]): 444 | # C H W 445 | if len(images.shape) == 3: 446 | out_images[idx] = images[idx] + rgb[2 - idx] 447 | # T C H W 448 | elif len(images.shape) == 4: 449 | out_images[:, idx] = images[:, idx] + rgb[2 - idx] 450 | else: 451 | raise NotImplementedError( 452 | f"Unsupported dimension {len(images.shape)}" 453 | ) 454 | 455 | return out_images 456 | 457 | 458 | def color_normalization(images, mean, stddev): 459 | """ 460 | Perform color normalization on the given images. 461 | Args: 462 | images (tensor): images to perform color normalization. Dimension is 463 | `num frames` x `channel` x `height` x `width`. 464 | mean (list): mean values for normalization. 465 | stddev (list): standard deviations for normalization. 466 | 467 | Returns: 468 | out_images (tensor): the normalized images, the dimension is 469 | `num frames` x `channel` x `height` x `width`.
470 | """ 471 | if len(images.shape) == 3: 472 | assert ( 473 | len(mean) == images.shape[0] 474 | ), "channel mean not computed properly" 475 | assert ( 476 | len(stddev) == images.shape[0] 477 | ), "channel stddev not computed properly" 478 | elif len(images.shape) == 4: 479 | assert ( 480 | len(mean) == images.shape[1] 481 | ), "channel mean not computed properly" 482 | assert ( 483 | len(stddev) == images.shape[1] 484 | ), "channel stddev not computed properly" 485 | else: 486 | raise NotImplementedError(f"Unsupported dimension {len(images.shape)}") 487 | 488 | out_images = torch.zeros_like(images) 489 | for idx in range(len(mean)): 490 | # C H W 491 | if len(images.shape) == 3: 492 | out_images[idx] = (images[idx] - mean[idx]) / stddev[idx] 493 | elif len(images.shape) == 4: 494 | out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx] 495 | else: 496 | raise NotImplementedError( 497 | f"Unsupported dimension {len(images.shape)}" 498 | ) 499 | return out_images 500 | 501 | 502 | def _get_param_spatial_crop( 503 | scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False 504 | ): 505 | """ 506 | Given scale, ratio, height and width, return sampled coordinates of the videos. 
507 | """ 508 | for _ in range(num_repeat): 509 | area = height * width 510 | target_area = random.uniform(*scale) * area 511 | if log_scale: 512 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 513 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 514 | else: 515 | aspect_ratio = random.uniform(*ratio) 516 | 517 | w = int(round(math.sqrt(target_area * aspect_ratio))) 518 | h = int(round(math.sqrt(target_area / aspect_ratio))) 519 | 520 | if np.random.uniform() < 0.5 and switch_hw: 521 | w, h = h, w 522 | 523 | if 0 < w <= width and 0 < h <= height: 524 | i = random.randint(0, height - h) 525 | j = random.randint(0, width - w) 526 | return i, j, h, w 527 | 528 | # Fallback to central crop 529 | in_ratio = float(width) / float(height) 530 | if in_ratio < min(ratio): 531 | w = width 532 | h = int(round(w / min(ratio))) 533 | elif in_ratio > max(ratio): 534 | h = height 535 | w = int(round(h * max(ratio))) 536 | else: # whole image 537 | w = width 538 | h = height 539 | i = (height - h) // 2 540 | j = (width - w) // 2 541 | return i, j, h, w 542 | 543 | 544 | def random_resized_crop( 545 | images, 546 | target_height, 547 | target_width, 548 | scale=(0.08, 1.0), 549 | ratio=(3.0 / 4.0, 4.0 / 3.0), 550 | ): 551 | """ 552 | Crop the given images to random size and aspect ratio. A crop of random 553 | size (default: of 0.08 to 1.0) of the original size and a random aspect 554 | ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This 555 | crop is finally resized to given size. This is popularly used to train the 556 | Inception networks. 557 | 558 | Args: 559 | images: Images to perform resizing and cropping. 560 | target_height: Desired height after cropping. 561 | target_width: Desired width after cropping. 562 | scale: Scale range of Inception-style area based random resizing. 563 | ratio: Aspect ratio range of Inception-style area based random resizing. 
564 | """ 565 | 566 | height = images.shape[2] 567 | width = images.shape[3] 568 | 569 | i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) 570 | cropped = images[:, :, i : i + h, j : j + w] 571 | return torch.nn.functional.interpolate( 572 | cropped, 573 | size=(target_height, target_width), 574 | mode="bilinear", 575 | align_corners=False, 576 | ) 577 | 578 | 579 | def random_resized_crop_with_shift( 580 | images, 581 | target_height, 582 | target_width, 583 | scale=(0.8, 1.0), 584 | ratio=(3.0 / 4.0, 4.0 / 3.0), 585 | ): 586 | """ 587 | This is similar to random_resized_crop. However, it samples two different 588 | boxes (for cropping) for the first and last frame. It then linearly 589 | interpolates the two boxes for other frames. 590 | 591 | Args: 592 | images: Images to perform resizing and cropping. 593 | target_height: Desired height after cropping. 594 | target_width: Desired width after cropping. 595 | scale: Scale range of Inception-style area based random resizing. 596 | ratio: Aspect ratio range of Inception-style area based random resizing. 
597 | """ 598 | t = images.shape[1] 599 | height = images.shape[2] 600 | width = images.shape[3] 601 | 602 | i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) 603 | i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width) 604 | i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()] 605 | j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()] 606 | h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()] 607 | w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()] 608 | out = torch.zeros((3, t, target_height, target_width)) 609 | for ind in range(t): 610 | out[:, ind : ind + 1, :, :] = torch.nn.functional.interpolate( 611 | images[ 612 | :, 613 | ind : ind + 1, 614 | i_s[ind] : i_s[ind] + h_s[ind], 615 | j_s[ind] : j_s[ind] + w_s[ind], 616 | ], 617 | size=(target_height, target_width), 618 | mode="bilinear", 619 | align_corners=False, 620 | ) 621 | return out 622 | 623 | 624 | def create_random_augment( 625 | input_size, 626 | auto_augment=None, 627 | interpolation="bilinear", 628 | ): 629 | """ 630 | Get video randaug transform. 631 | 632 | Args: 633 | input_size: The size of the input video in tuple. 634 | auto_augment: Parameters for randaug. An example: 635 | "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number 636 | of operations to apply). 637 | interpolation: Interpolation method. 
638 | """ 639 | if isinstance(input_size, tuple): 640 | img_size = input_size[-2:] 641 | else: 642 | img_size = input_size 643 | 644 | if auto_augment: 645 | assert isinstance(auto_augment, str) 646 | if isinstance(img_size, tuple): 647 | img_size_min = min(img_size) 648 | else: 649 | img_size_min = img_size 650 | aa_params = {"translate_const": int(img_size_min * 0.45)} 651 | if interpolation and interpolation != "random": 652 | aa_params["interpolation"] = _pil_interp(interpolation) 653 | if auto_augment.startswith("rand"): 654 | return transforms.Compose( 655 | [rand_augment_transform(auto_augment, aa_params)] 656 | ) 657 | raise NotImplementedError 658 | 659 | 660 | def random_sized_crop_img( 661 | im, 662 | size, 663 | jitter_scale=(0.08, 1.0), 664 | jitter_aspect=(3.0 / 4.0, 4.0 / 3.0), 665 | max_iter=10, 666 | ): 667 | """ 668 | Performs Inception-style cropping (used for training). 669 | """ 670 | assert ( 671 | len(im.shape) == 3 672 | ), "Currently only single images are supported by random_sized_crop" 673 | h, w = im.shape[1:3] 674 | i, j, h, w = _get_param_spatial_crop( 675 | scale=jitter_scale, 676 | ratio=jitter_aspect, 677 | height=h, 678 | width=w, 679 | num_repeat=max_iter, 680 | log_scale=False, 681 | switch_hw=True, 682 | ) 683 | cropped = im[:, i : i + h, j : j + w] 684 | return torch.nn.functional.interpolate( 685 | cropped.unsqueeze(0), 686 | size=(size, size), 687 | mode="bilinear", 688 | align_corners=False, 689 | ).squeeze(0) 690 | 691 | 692 | # The following code is modified from the timm library; we will replace the following 693 | # contents with a dependency on PyTorchVideo. 694 | # https://github.com/facebookresearch/pytorchvideo 695 | class RandomResizedCropAndInterpolation: 696 | """Crop the given PIL Image to random size and aspect ratio with random interpolation. 697 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 698 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made.
This crop 699 | is finally resized to the given size. 700 | This is commonly used to train Inception networks. 701 | Args: 702 | size: expected output size of each edge 703 | scale: range of the cropped area as a fraction of the original area 704 | ratio: range of aspect ratios of the crop 705 | interpolation: interpolation method name, or "random". Default: "bilinear" 706 | """ 707 | 708 | def __init__( 709 | self, 710 | size, 711 | scale=(0.08, 1.0), 712 | ratio=(3.0 / 4.0, 4.0 / 3.0), 713 | interpolation="bilinear", 714 | ): 715 | if isinstance(size, tuple): 716 | self.size = size 717 | else: 718 | self.size = (size, size) 719 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 720 | print("range should be of kind (min, max)") 721 | 722 | if interpolation == "random": 723 | self.interpolation = _RANDOM_INTERPOLATION 724 | else: 725 | self.interpolation = _pil_interp(interpolation) 726 | self.scale = scale 727 | self.ratio = ratio 728 | 729 | @staticmethod 730 | def get_params(img, scale, ratio): 731 | """Get parameters for ``crop`` for a random sized crop. 732 | Args: 733 | img (PIL Image): Image to be cropped. 734 | scale (tuple): range of the cropped area as a fraction of the original area 735 | ratio (tuple): range of aspect ratios of the crop 736 | Returns: 737 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 738 | sized crop.
739 | """ 740 | area = img.size[0] * img.size[1] 741 | 742 | for _ in range(10): 743 | target_area = random.uniform(*scale) * area 744 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 745 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 746 | 747 | w = int(round(math.sqrt(target_area * aspect_ratio))) 748 | h = int(round(math.sqrt(target_area / aspect_ratio))) 749 | 750 | if w <= img.size[0] and h <= img.size[1]: 751 | i = random.randint(0, img.size[1] - h) 752 | j = random.randint(0, img.size[0] - w) 753 | return i, j, h, w 754 | 755 | # Fallback to central crop 756 | in_ratio = img.size[0] / img.size[1] 757 | if in_ratio < min(ratio): 758 | w = img.size[0] 759 | h = int(round(w / min(ratio))) 760 | elif in_ratio > max(ratio): 761 | h = img.size[1] 762 | w = int(round(h * max(ratio))) 763 | else: # whole image 764 | w = img.size[0] 765 | h = img.size[1] 766 | i = (img.size[1] - h) // 2 767 | j = (img.size[0] - w) // 2 768 | return i, j, h, w 769 | 770 | def __call__(self, img): 771 | """ 772 | Args: 773 | img (PIL Image): Image to be cropped and resized. 774 | Returns: 775 | PIL Image: Randomly cropped and resized image. 
776 | """ 777 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 778 | if isinstance(self.interpolation, (tuple, list)): 779 | interpolation = random.choice(self.interpolation) 780 | else: 781 | interpolation = self.interpolation 782 | return F.resized_crop(img, i, j, h, w, self.size, interpolation) 783 | 784 | def __repr__(self): 785 | if isinstance(self.interpolation, (tuple, list)): 786 | interpolate_str = " ".join( 787 | [_pil_interpolation_to_str[x] for x in self.interpolation] 788 | ) 789 | else: 790 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 791 | format_string = self.__class__.__name__ + "(size={0}".format(self.size) 792 | format_string += ", scale={0}".format( 793 | tuple(round(s, 4) for s in self.scale) 794 | ) 795 | format_string += ", ratio={0}".format( 796 | tuple(round(r, 4) for r in self.ratio) 797 | ) 798 | format_string += ", interpolation={0})".format(interpolate_str) 799 | return format_string 800 | -------------------------------------------------------------------------------- /vision_transformer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from collections import OrderedDict 4 | import numpy as np 5 | from typing import Tuple 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | ''' 12 | QuickGELU and LayerNorm w/ fp16 from official CLIP repo 13 | (https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py) 14 | ''' 15 | class QuickGELU(nn.Module): 16 | def forward(self, x: torch.Tensor): 17 | return x * torch.sigmoid(1.702 * x) 18 | 19 | class LayerNorm(nn.LayerNorm): 20 | """Subclass torch's LayerNorm to handle fp16.""" 21 | 22 | def forward(self, x: torch.Tensor): 23 | orig_type = x.dtype 24 | ret = super().forward(x.type(torch.float32)) 25 | return ret.type(orig_type) 26 | 27 | 28 | class Attention(nn.Module): 29 | ''' 30 | A generalized attention module with more 
flexibility. 31 | ''' 32 | 33 | def __init__( 34 | self, q_in_dim: int, k_in_dim: int, v_in_dim: int, 35 | qk_proj_dim: int, v_proj_dim: int, num_heads: int, out_dim: int, 36 | return_all_features: bool = False, 37 | ): 38 | super().__init__() 39 | 40 | self.q_proj = nn.Linear(q_in_dim, qk_proj_dim) 41 | self.k_proj = nn.Linear(k_in_dim, qk_proj_dim) 42 | self.v_proj = nn.Linear(v_in_dim, v_proj_dim) 43 | self.out_proj = nn.Linear(v_proj_dim, out_dim) 44 | 45 | self.num_heads = num_heads 46 | self.return_all_features = return_all_features 47 | assert qk_proj_dim % num_heads == 0 and v_proj_dim % num_heads == 0 48 | 49 | self._initialize_weights() 50 | 51 | 52 | def _initialize_weights(self): 53 | for m in (self.q_proj, self.k_proj, self.v_proj, self.out_proj): 54 | nn.init.xavier_uniform_(m.weight) 55 | nn.init.constant_(m.bias, 0.) 56 | 57 | 58 | def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): 59 | assert q.ndim == 3 and k.ndim == 3 and v.ndim == 3 60 | N = q.size(0); assert k.size(0) == N and v.size(0) == N 61 | Lq, Lkv = q.size(1), k.size(1); assert v.size(1) == Lkv 62 | 63 | q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v) 64 | 65 | H = self.num_heads 66 | Cqk, Cv = q.size(-1) // H, v.size(-1) // H 67 | 68 | q = q.view(N, Lq, H, Cqk) 69 | k = k.view(N, Lkv, H, Cqk) 70 | v = v.view(N, Lkv, H, Cv) 71 | 72 | aff = torch.einsum('nqhc,nkhc->nqkh', q / (Cqk ** 0.5), k) 73 | aff = aff.softmax(dim=-2) 74 | mix = torch.einsum('nqlh,nlhc->nqhc', aff, v) 75 | 76 | out = self.out_proj(mix.flatten(-2)) 77 | 78 | if self.return_all_features: 79 | return dict(q=q, k=k, v=v, aff=aff, out=out) 80 | else: 81 | return out 82 | 83 | 84 | class PatchEmbed2D(nn.Module): 85 | 86 | def __init__( 87 | self, 88 | patch_size: Tuple[int, int] = (16, 16), 89 | in_channels: int = 3, 90 | embed_dim: int = 768, 91 | ): 92 | super().__init__() 93 | 94 | self.patch_size = patch_size 95 | self.in_channels = in_channels 96 | 97 | self.proj = 
nn.Linear(np.prod(patch_size) * in_channels, embed_dim) 98 | 99 | 100 | def _initialize_weights(self, x): 101 | nn.init.kaiming_normal_(self.proj.weight, 0.) 102 | nn.init.constant_(self.proj.bias, 0.) 103 | 104 | 105 | def forward(self, x: torch.Tensor): 106 | B, C, H, W = x.size() 107 | pH, pW = self.patch_size 108 | 109 | assert C == self.in_channels and H % pH == 0 and W % pW == 0 110 | 111 | x = x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 1, 3, 5).flatten(3).flatten(1, 2) 112 | x = self.proj(x) 113 | 114 | return x 115 | 116 | class TransformerEncoderLayer(nn.Module): 117 | 118 | def __init__( 119 | self, 120 | in_feature_dim: int = 768, 121 | qkv_dim: int = 768, 122 | num_heads: int = 12, 123 | mlp_factor: float = 4.0, 124 | mlp_dropout: float = 0.0, 125 | act: nn.Module = QuickGELU, 126 | return_all_features: bool = False, 127 | ): 128 | super().__init__() 129 | 130 | self.return_all_features = return_all_features 131 | 132 | self.attn = Attention( 133 | q_in_dim=in_feature_dim, k_in_dim=in_feature_dim, v_in_dim=in_feature_dim, 134 | qk_proj_dim=qkv_dim, v_proj_dim=qkv_dim, num_heads=num_heads, out_dim=in_feature_dim, 135 | return_all_features=return_all_features, 136 | ) 137 | 138 | mlp_dim = round(mlp_factor * in_feature_dim) 139 | self.mlp = nn.Sequential(OrderedDict([ 140 | ('fc1', nn.Linear(in_feature_dim, mlp_dim)), 141 | ('act', act()), 142 | ('dropout', nn.Dropout(mlp_dropout)), 143 | ('fc2', nn.Linear(mlp_dim, in_feature_dim)), 144 | ])) 145 | 146 | self.norm1 = LayerNorm(in_feature_dim) 147 | self.norm2 = LayerNorm(in_feature_dim) 148 | 149 | self._initialize_weights() 150 | 151 | 152 | def _initialize_weights(self): 153 | for m in (self.mlp[0], self.mlp[-1]): 154 | nn.init.xavier_uniform_(m.weight) 155 | nn.init.normal_(m.bias, std=1e-6) 156 | 157 | 158 | def forward(self, x: torch.Tensor): 159 | if self.return_all_features: 160 | ret_dict = {} 161 | 162 | x_norm = self.norm1(x) 163 | attn_out = self.attn(x_norm, x_norm, x_norm) 164 | 
ret_dict['q'] = attn_out['q'] 165 | ret_dict['k'] = attn_out['k'] 166 | ret_dict['v'] = attn_out['v'] 167 | ret_dict['attn_out'] = attn_out['out'] 168 | x = x + attn_out['out'] 169 | 170 | x = x + self.mlp(self.norm2(x)) 171 | ret_dict['out'] = x 172 | 173 | return ret_dict 174 | 175 | else: 176 | x_norm = self.norm1(x) 177 | x = x + self.attn(x_norm, x_norm, x_norm) 178 | x = x + self.mlp(self.norm2(x)) 179 | 180 | return x 181 | 182 | 183 | class TransformerDecoderLayer(nn.Module): 184 | 185 | def __init__( 186 | self, 187 | in_feature_dim: int = 768, 188 | qkv_dim: int = 768, 189 | num_heads: int = 12, 190 | mlp_factor: float = 4.0, 191 | mlp_dropout: float = 0.0, 192 | act: nn.Module = QuickGELU, 193 | ): 194 | super().__init__() 195 | 196 | self.attn = Attention( 197 | q_in_dim=in_feature_dim, k_in_dim=in_feature_dim, v_in_dim=in_feature_dim, 198 | qk_proj_dim=qkv_dim, v_proj_dim=qkv_dim, num_heads=num_heads, out_dim=in_feature_dim, 199 | ) 200 | 201 | mlp_dim = round(mlp_factor * in_feature_dim) 202 | self.mlp = nn.Sequential(OrderedDict([ 203 | ('fc1', nn.Linear(in_feature_dim, mlp_dim)), 204 | ('act', act()), 205 | ('dropout', nn.Dropout(mlp_dropout)), 206 | ('fc2', nn.Linear(mlp_dim, in_feature_dim)), 207 | ])) 208 | 209 | self.norm1 = LayerNorm(in_feature_dim) 210 | self.norm2 = LayerNorm(in_feature_dim) 211 | self.norm3 = LayerNorm(in_feature_dim) 212 | 213 | self._initialize_weights() 214 | 215 | 216 | def _initialize_weights(self): 217 | for m in (self.mlp[0], self.mlp[-1]): 218 | nn.init.xavier_uniform_(m.weight) 219 | nn.init.normal_(m.bias, std=1e-6) 220 | 221 | 222 | def forward(self, x: torch.Tensor, y: torch.Tensor): 223 | y_norm = self.norm3(y) 224 | x = x + self.attn(self.norm1(x), y_norm, y_norm) 225 | x = x + self.mlp(self.norm2(x)) 226 | 227 | return x 228 | 229 | 230 | class VisionTransformer2D(nn.Module): 231 | 232 | def __init__( 233 | self, 234 | feature_dim: int = 768, 235 | input_size: Tuple[int, int] = (224, 224), 236 | patch_size: 
Tuple[int, int] = (16, 16), 237 | num_heads: int = 12, 238 | num_layers: int = 12, 239 | mlp_factor: float = 4.0, 240 | act: nn.Module = QuickGELU, 241 | return_all_features: bool = False, 242 | ln_pre: bool = False, 243 | ): 244 | super().__init__() 245 | 246 | self.return_all_features = return_all_features 247 | 248 | self.patch_embed = PatchEmbed2D(patch_size=patch_size, embed_dim=feature_dim) 249 | self.num_patches = np.prod([x // y for x, y in zip(input_size, patch_size)]) + 1 250 | 251 | self.cls_token = nn.Parameter(torch.zeros([feature_dim])) 252 | self.pos_embed = nn.Parameter(torch.zeros([self.num_patches, feature_dim])) 253 | 254 | self.blocks = nn.ModuleList([ 255 | TransformerEncoderLayer( 256 | in_feature_dim=feature_dim, qkv_dim=feature_dim, num_heads=num_heads, mlp_factor=mlp_factor, act=act, 257 | return_all_features=return_all_features, 258 | ) for _ in range(num_layers) 259 | ]) 260 | 261 | if ln_pre: 262 | self.ln_pre = LayerNorm(feature_dim) 263 | else: 264 | self.ln_pre = nn.Identity() 265 | 266 | self._initialize_weights() 267 | 268 | 269 | def _initialize_weights(self): 270 | nn.init.normal_(self.cls_token, std=0.02) 271 | nn.init.normal_(self.pos_embed, std=0.02) 272 | 273 | def forward(self, x: torch.Tensor): 274 | dtype = self.patch_embed.proj.weight.dtype 275 | x = x.to(dtype) 276 | 277 | x = self.patch_embed(x) 278 | x = torch.cat([self.cls_token.view(1, 1, -1).repeat(x.size(0), 1, 1), x], dim=1) 279 | x = x + self.pos_embed 280 | 281 | x = self.ln_pre(x) 282 | 283 | if self.return_all_features: 284 | all_features = [] 285 | for blk in self.blocks: 286 | x = blk(x) 287 | all_features.append(x) 288 | x = x['out'] 289 | return all_features 290 | 291 | else: 292 | for blk in self.blocks: 293 | x = blk(x) 294 | return x 295 | 296 | 297 | def model_to_fp16(model: VisionTransformer2D): 298 | def _module_to_fp16(m: nn.Module): 299 | if isinstance(m, (nn.Linear,)): 300 | m.half() 301 | model.apply(_module_to_fp16) 302 | 303 | 
model.pos_embed.data = model.pos_embed.data.half() 304 | model.cls_token.data = model.cls_token.data.half() 305 | 306 | 307 | vit_presets = { 308 | 'ViT-B/16-lnpre': dict( 309 | feature_dim=768, 310 | input_size=(224, 224), 311 | patch_size=(16, 16), 312 | num_heads=12, 313 | num_layers=12, 314 | mlp_factor=4.0, 315 | ln_pre=True, 316 | ), 317 | 'ViT-L/14-lnpre': dict( 318 | feature_dim=1024, 319 | input_size=(224, 224), 320 | patch_size=(14, 14), 321 | num_heads=16, 322 | num_layers=24, 323 | mlp_factor=4.0, 324 | ln_pre=True, 325 | ), 326 | } -------------------------------------------------------------------------------- /weight_loaders.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os, sys 4 | from typing import Dict 5 | 6 | import torch 7 | 8 | __all__ = ['weight_loader_fn_dict'] 9 | 10 | def load_weights_clip(load_path: str) -> Dict[str, torch.Tensor]: 11 | clip_model = torch.jit.load(load_path, map_location='cpu') 12 | clip_model = clip_model.visual 13 | src_state_dict = clip_model.state_dict() 14 | src_state_dict = dict((k, v.float()) for k, v in src_state_dict.items()) 15 | 16 | dst_state_dict = {} 17 | 18 | dst_state_dict['cls_token'] = src_state_dict['class_embedding'] 19 | dst_state_dict['pos_embed'] = src_state_dict['positional_embedding'] 20 | dst_state_dict['patch_embed.proj.weight'] = src_state_dict['conv1.weight'].flatten(1) 21 | dst_state_dict['patch_embed.proj.bias'] = torch.zeros([src_state_dict['conv1.weight'].size(0)]) 22 | 23 | dst_state_dict['ln_pre.weight'] = src_state_dict['ln_pre.weight'] 24 | dst_state_dict['ln_pre.bias'] = src_state_dict['ln_pre.bias'] 25 | 26 | block_idx = 0 27 | while True: 28 | src_prefix = 'transformer.resblocks.%d.' % block_idx 29 | dst_prefix = 'blocks.%d.' 
% block_idx 30 | 31 | src_block_state_dict = dict((k[len(src_prefix):], v) for k, v in src_state_dict.items() if k.startswith(src_prefix)) 32 | if len(src_block_state_dict) == 0: 33 | break 34 | 35 | dst_block_state_dict = {} 36 | feat_dim = src_block_state_dict['ln_1.weight'].size(0) 37 | 38 | for i, dst_name in enumerate(('q', 'k', 'v')): 39 | dst_block_state_dict['attn.%s_proj.weight' % dst_name] = src_block_state_dict['attn.in_proj_weight'][feat_dim * i: feat_dim * (i + 1)] 40 | dst_block_state_dict['attn.%s_proj.bias' % dst_name] = src_block_state_dict['attn.in_proj_bias'][feat_dim * i: feat_dim * (i + 1)] 41 | 42 | dst_block_state_dict['attn.out_proj.weight'] = src_block_state_dict['attn.out_proj.weight'] 43 | dst_block_state_dict['attn.out_proj.bias'] = src_block_state_dict['attn.out_proj.bias'] 44 | 45 | dst_block_state_dict['mlp.fc1.weight'] = src_block_state_dict['mlp.c_fc.weight'] 46 | dst_block_state_dict['mlp.fc1.bias'] = src_block_state_dict['mlp.c_fc.bias'] 47 | dst_block_state_dict['mlp.fc2.weight'] = src_block_state_dict['mlp.c_proj.weight'] 48 | dst_block_state_dict['mlp.fc2.bias'] = src_block_state_dict['mlp.c_proj.bias'] 49 | 50 | dst_block_state_dict['norm1.weight'] = src_block_state_dict['ln_1.weight'] 51 | dst_block_state_dict['norm1.bias'] = src_block_state_dict['ln_1.bias'] 52 | dst_block_state_dict['norm2.weight'] = src_block_state_dict['ln_2.weight'] 53 | dst_block_state_dict['norm2.bias'] = src_block_state_dict['ln_2.bias'] 54 | 55 | dst_state_dict.update(dict((dst_prefix + k, v) for k, v in dst_block_state_dict.items())) 56 | block_idx += 1 57 | 58 | return dst_state_dict 59 | 60 | 61 | weight_loader_fn_dict = { 62 | 'clip': load_weights_clip, 63 | } 64 | --------------------------------------------------------------------------------
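The box update in `horizontal_flip` mirrors only the x-coordinates. Assuming boxes are stored as `[x1, y1, x2, y2]` in pixel coordinates (the layout implied by indexing `[:, [0, 2]]`), a box on an image of width `W` maps to `[W - 1 - x2, y1, W - 1 - x1, y2]`; reading the old `x2` into the new `x1` keeps each box well-ordered. The NumPy sketch below isolates that step; `flip_boxes_x` is a hypothetical helper, not part of the repository.

```python
import numpy as np


def flip_boxes_x(boxes: np.ndarray, width: int) -> np.ndarray:
    """Mirror [x1, y1, x2, y2] boxes horizontally (hypothetical helper)."""
    flipped = boxes.copy()
    # New x1 is computed from old x2 and vice versa, so x1 <= x2 still holds.
    flipped[:, [0, 2]] = width - boxes[:, [2, 0]] - 1
    return flipped


boxes = np.array([[10.0, 5.0, 30.0, 20.0]])
flipped = flip_boxes_x(boxes, width=100)  # x-range [10, 30] becomes [69, 89]
# Flipping twice restores the original boxes.
assert np.array_equal(flip_boxes_x(flipped, width=100), boxes)
```

Flipping twice returning the input is a cheap sanity check for this kind of coordinate transform.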
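`uniform_crop` chooses one of three square windows along the longer side of the frame, selected by `spatial_idx`. The pure-Python sketch below reproduces only the offset arithmetic of that function (no resizing, no tensors) so the left/center/right and top/center/bottom cases are easy to inspect; `uniform_crop_offsets` is a hypothetical name, not a function in the repository.

```python
import math


def uniform_crop_offsets(height: int, width: int, size: int, spatial_idx: int):
    """Return (y_offset, x_offset) as computed in uniform_crop (hypothetical helper)."""
    assert spatial_idx in (0, 1, 2)
    # Default: centered crop; math.ceil mirrors the original code.
    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))
    if height > width:
        # Portrait frame: slide vertically (0 = top, 1 = center, 2 = bottom).
        if spatial_idx == 0:
            y_offset = 0
        elif spatial_idx == 2:
            y_offset = height - size
    else:
        # Landscape or square frame: slide horizontally (0 = left, 1 = center, 2 = right).
        if spatial_idx == 0:
            x_offset = 0
        elif spatial_idx == 2:
            x_offset = width - size
    return y_offset, x_offset


# A 224 x 298 landscape frame with size=224: windows start at x = 0, 37 and 74.
print([uniform_crop_offsets(224, 298, 224, idx) for idx in range(3)])
```

Because the center offset uses `math.ceil`, an odd margin leaves the extra pixel on the left (or top) side of the crop.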
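`Attention.forward` computes multi-head scaled dot-product attention with two `einsum` calls instead of folding heads into the batch dimension. The affinity tensor is laid out as `nqkh` (batch, query, key, head), which is why the softmax runs over `dim=-2`, the key axis. The NumPy sketch below omits the q/k/v/out projections and assumes the inputs are already split into heads. It checks the same two contractions against a plain per-head loop; the helper names are illustrative only.

```python
import numpy as np


def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multihead_mix(q, k, v):
    """q: (N, Lq, H, C), k: (N, Lk, H, C), v: (N, Lk, H, Cv) -> (N, Lq, H*Cv)."""
    c = q.shape[-1]
    aff = np.einsum("nqhc,nkhc->nqkh", q / np.sqrt(c), k)
    aff = softmax(aff, axis=-2)  # normalize over keys, like aff.softmax(dim=-2)
    mix = np.einsum("nqlh,nlhc->nqhc", aff, v)
    return mix.reshape(mix.shape[0], mix.shape[1], -1)  # same as flatten(-2)


def reference(q, k, v):
    """Per-head loop computing softmax(q k^T / sqrt(c)) v."""
    N, Lq, H, c = q.shape
    out = np.zeros((N, Lq, H, v.shape[-1]))
    for n in range(N):
        for h in range(H):
            scores = q[n, :, h] @ k[n, :, h].T / np.sqrt(c)
            out[n, :, h] = softmax(scores, axis=-1) @ v[n, :, h]
    return out.reshape(N, Lq, -1)


rng = np.random.default_rng(0)
q = rng.standard_normal((2, 3, 4, 8))
k = rng.standard_normal((2, 5, 4, 8))
v = rng.standard_normal((2, 5, 4, 6))
assert np.allclose(multihead_mix(q, k, v), reference(q, k, v))
```

Keeping the head axis last in `aff` lets one `einsum` handle all heads at once, at the cost of having to remember that the key axis is `-2` rather than `-1`.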
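`load_weights_clip` converts CLIP's packed attention projection into the separate `q_proj`, `k_proj` and `v_proj` layers of `Attention`. The packed `attn.in_proj_weight` has shape `(3 * D, D)` with the query, key and value rows stacked in that order, so slicing rows `[D * i : D * (i + 1)]` recovers the `i`-th projection. The NumPy sketch below uses random weights rather than a real CLIP checkpoint. It shows that applying the three slices separately matches the packed projection; `split_in_proj` is a hypothetical helper.

```python
import numpy as np


def split_in_proj(in_proj_weight, in_proj_bias):
    """Split a packed (3*D, D) q/k/v projection into three (D, D) parts (hypothetical helper)."""
    d = in_proj_weight.shape[1]
    parts = {}
    for i, name in enumerate(("q", "k", "v")):
        # Rows [d*i, d*(i+1)) belong to the i-th projection.
        parts[name] = (in_proj_weight[d * i: d * (i + 1)], in_proj_bias[d * i: d * (i + 1)])
    return parts


rng = np.random.default_rng(0)
D = 8
w_packed = rng.standard_normal((3 * D, D))
b_packed = rng.standard_normal(3 * D)
x = rng.standard_normal((5, D))

# Packed projection as a linear layer computes it: x @ W^T + b.
packed_out = x @ w_packed.T + b_packed
parts = split_in_proj(w_packed, b_packed)
separate = np.concatenate([x @ w.T + b for w, b in parts.values()], axis=-1)
assert np.allclose(packed_out, separate)
```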