├── gitignore ├── figs ├── apt.png ├── apt-intro.jpg └── apt-method.jpg ├── scripts ├── .DS_Store ├── ssv2 │ ├── videomae_vpt_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_2400 │ │ ├── finetune_viu3.sh │ │ └── finetune_viu2.sh │ └── videomae_vpt_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_2400 │ │ ├── finetune_viu3.sh │ │ └── finetune_viu2.sh ├── ucf101 │ ├── videomae_vpt_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_2400 │ │ ├── linearprobe_viu3.sh │ │ ├── finetune_viu3.sh │ │ └── finetune_viu2.sh │ └── videomae_vpt_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_2400 │ │ ├── finetune_viu3.sh │ │ └── finetune_viu2.sh ├── hmdb51 │ ├── videomae_vpt_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_2400 │ │ ├── finetune_viu3.sh │ │ └── finetune_viu2.sh │ └── videomae_vpt_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_2400 │ │ ├── finetune_viu3.sh │ │ └── finetune_viu2.sh └── kinetics │ ├── videomae_vpt_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_2400 │ ├── finetune_viu3.sh │ └── finetune_viu2.sh │ └── videomae_vpt_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_2400 │ ├── finetune_viu3.sh │ └── finetune_viu2.sh ├── vis.sh ├── INSTALL.md ├── masking_generator.py ├── environment.yml ├── final_args.json ├── DATASET.md ├── functional.py ├── volume_transforms.py ├── PRETRAIN.md ├── engine_for_pretraining.py ├── FINETUNE.md ├── README.md ├── optim_factory.py ├── random_erasing.py ├── MODEL_ZOO.md ├── transforms.py ├── datasets.py ├── run_videomae_vis.py ├── engine_for_finetuning.py ├── run_mae_pretraining.py ├── modeling_finetune.py ├── modeling_pretrain.py ├── mixup.py ├── ssv2.py └── rand_augment.py /gitignore: -------------------------------------------------------------------------------- 1 | experiments 2 | datasets 3 | *.pyc -------------------------------------------------------------------------------- /figs/apt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wgcban/apt/HEAD/figs/apt.png -------------------------------------------------------------------------------- /figs/apt-intro.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wgcban/apt/HEAD/figs/apt-intro.jpg -------------------------------------------------------------------------------- /scripts/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wgcban/apt/HEAD/scripts/.DS_Store -------------------------------------------------------------------------------- /figs/apt-method.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wgcban/apt/HEAD/figs/apt-method.jpg -------------------------------------------------------------------------------- /vis.sh: -------------------------------------------------------------------------------- 1 | # Set the path to save video 2 | OUTPUT_DIR='TODO/VideoMAE/demo/vis_k400_1_0.9' 3 | # path to video for visualization 4 | VIDEO_PATH='TODO/TODO.mp4' 5 | # path to pretrain model 6 | MODEL_PATH='TODO/videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint-1599.pth' 7 | 8 | python3 run_videomae_vis.py \ 9 | --mask_ratio 0.9 \ 10 | --mask_type tube \ 11 | --decoder_depth 4 \ 12 | --model pretrain_videomae_base_patch16_224 \ 13 | ${VIDEO_PATH} ${OUTPUT_DIR} ${MODEL_PATH} -------------------------------------------------------------------------------- /INSTALL.md: 
-------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | The codebase is mainly built with the following libraries: 4 | 5 | - Python 3.6 or higher 6 | 7 | - [PyTorch](https://pytorch.org/) and [torchvision](https://github.com/pytorch/vision).
8 | We have successfully reproduced the main results under the two settings below:
9 | Tesla **A100** (40G): CUDA 11.1 + PyTorch 1.8.0 + torchvision 0.9.0
10 | Tesla **V100** (32G): CUDA 10.1 + PyTorch 1.6.0 + torchvision 0.7.0 11 | 12 | - [timm==0.4.8/0.4.12](https://github.com/rwightman/pytorch-image-models) 13 | 14 | - [deepspeed==0.5.8](https://github.com/microsoft/DeepSpeed) 15 | 16 | `DS_BUILD_OPS=1 pip install deepspeed` 17 | 18 | - [TensorboardX](https://github.com/lanpa/tensorboardX) 19 | 20 | - [decord](https://github.com/dmlc/decord) 21 | 22 | - [einops](https://github.com/arogozhnikov/einops) 23 | -------------------------------------------------------------------------------- /scripts/ssv2/videomae_vpt_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune_viu3.sh: -------------------------------------------------------------------------------- 1 | # APT on SSv2 2 | OUTPUT_DIR='experiments/APT/SSV2/ssv2_videomae_pretrain_small_patch16_224_frame_16x2_tube_mask_ratio_0.9_e2400/adam_mome9e-1_wd1e-5_lr5e-2_pl2_ps0_pe11_drop10' 3 | DATA_PATH='datasets/ss2/list_ssv2/' 4 | MODEL_PATH='experiments/pretrain/ssv2_videomae_pretrain_small_patch16_224_frame_16x2_tube_mask_ratio_0.9_e2400/checkpoint.pth' 5 | 6 | NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5,6 python -m torch.distributed.launch --nproc_per_node=6 \ 7 | run_class_apt.py \ 8 | --model vit_small_patch16_224 \ 9 | --transfer_type prompt \ 10 | --prompt_start 0 \ 11 | --prompt_end 11 \ 12 | --prompt_num_tokens 400 \ 13 | --prompt_dropout 0.1 \ 14 | --data_set SSV2 \ 15 | --nb_classes 174 \ 16 | --data_path ${DATA_PATH} \ 17 | --finetune ${MODEL_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 16 \ 21 | --batch_size_val 16 \ 22 | --num_sample 2 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 10 \ 26 | --num_frames 16 \ 27 | --opt adamw \ 28 | --lr 0.05 \ 29 | --weight_decay 0.00001 \ 30 | --epochs 100 \ 31 | --warmup_epochs 10 \ 32 | --test_num_segment 2 \ 33 | --test_num_crop 3 \ 34 | --dist_eval \ 35 | --pin_mem \ 36 | --enable_deepspeed \ 37 | --is_aa \ 38 | --aa rand-m4-n2-mstd0.2-inc1 -------------------------------------------------------------------------------- /scripts/ucf101/videomae_vpt_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_2400/linearprobe_viu3.sh: -------------------------------------------------------------------------------- 1 | # APT on UCF101 2 | OUTPUT_DIR='experiments/LinearProbe/UCF101/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/adam_mome9e-1_wd1e-5_lr5e-2' 3 | DATA_PATH='datasets/ucf101/lists' 4 | MODEL_PATH='experiments/pretrain/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint.pth' 5 | 6 | NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 \ 7 | run_class_apt.py \ 8 | --model vit_base_patch16_224 \ 9 | --transfer_type linear \ 10 | --prompt_start 0 \ 11 | --prompt_end 11 \ 12 | --prompt_num_tokens 5 \ 13 | --prompt_dropout 0.1 \ 14 | --data_set UCF101 \ 15 | --nb_classes 101 \ 16 | --data_path ${DATA_PATH} \ 17 | --finetune ${MODEL_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 8 \ 21 | --batch_size_val 8 \ 22 | --num_sample 2 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 10 \ 26 | --num_frames 16 \ 27 | --opt adamw \ 28 | --lr 0.01 \ 29 | --weight_decay 0.00001 \ 30 | --epochs 100 \ 31 | --warmup_epochs 10 \ 32 | --test_num_segment 5 \ 33 | --test_num_crop 3 \ 34 | --dist_eval \ 35 | --pin_mem \ 36 | --enable_deepspeed \ 37 | 
--prompt_reparam \ 38 | --is_aa \ 39 | --aa rand-m4-n2-mstd0.2-inc1 -------------------------------------------------------------------------------- /scripts/ssv2/videomae_vpt_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune_viu3.sh: -------------------------------------------------------------------------------- 1 | # APT on SSv2 2 | OUTPUT_DIR='experiments/APT/SSV2/ssv2_videomae_pretrain_base_patch16_224_frame_16x2_tube_mask_ratio_0.9_e2400/adam_mome9e-1_wd1e-5_lr5e-2_pl2_ps0_pe11_drop10' 3 | DATA_PATH='datasets/ss2/list_ssv2/' 4 | MODEL_PATH='experiments/pretrain/ssv2_videomae_pretrain_base_patch16_224_frame_16x2_tube_mask_ratio_0.9_e2400/checkpoint.pth' 5 | 6 | NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5,6 python -m torch.distributed.launch --nproc_per_node=6 \ 7 | run_class_apt.py \ 8 | --model vit_base_patch16_224 \ 9 | --transfer_type prompt \ 10 | --prompt_start 0 \ 11 | --prompt_end 11 \ 12 | --prompt_num_tokens 2 \ 13 | --prompt_dropout 0.1 \ 14 | --data_set SSV2 \ 15 | --nb_classes 174 \ 16 | --data_path ${DATA_PATH} \ 17 | --finetune ${MODEL_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 8 \ 21 | --batch_size_val 8 \ 22 | --num_sample 2 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 10 \ 26 | --num_frames 16 \ 27 | --opt adamw \ 28 | --lr 0.05 \ 29 | --weight_decay 0.00001 \ 30 | --epochs 100 \ 31 | --warmup_epochs 10 \ 32 | --test_num_segment 2 \ 33 | --test_num_crop 3 \ 34 | --dist_eval \ 35 | --pin_mem \ 36 | --enable_deepspeed \ 37 | --prompt_reparam \ 38 | --is_aa \ 39 | --aa rand-m4-n2-mstd0.2-inc1 -------------------------------------------------------------------------------- /scripts/ssv2/videomae_vpt_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune_viu2.sh: -------------------------------------------------------------------------------- 1 | # APT on SSv2 2 | OUTPUT_DIR='experiments/APT/SSV2/ssv2_videomae_pretrain_base_patch16_224_frame_16x2_tube_mask_ratio_0.9_e2400/adam_mome9e-1_wd1e-5_lr5se-2_pl2_ps0_pe11_drop10' 3 | DATA_PATH='datasets/ss2/list_ssv2/' 4 | MODEL_PATH='experiments/pretrain/ssv2_videomae_pretrain_base_patch16_224_frame_16x2_tube_mask_ratio_0.9_e2400/checkpoint.pth' 5 | 6 | NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5,6,7,8 python -m torch.distributed.launch --nproc_per_node=8 \ 7 | run_class_apt.py \ 8 | --model vit_base_patch16_224 \ 9 | --transfer_type prompt \ 10 | --prompt_start 0 \ 11 | --prompt_end 11 \ 12 | --prompt_num_tokens 2 \ 13 | --prompt_dropout 0.1 \ 14 | --data_set SSV2 \ 15 | --nb_classes 174 \ 16 | --data_path ${DATA_PATH} \ 17 | --finetune ${MODEL_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 8 \ 21 | --batch_size_val 8 \ 22 | --num_sample 2 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 10 \ 26 | --num_frames 16 \ 27 | --opt adamw \ 28 | --lr 0.05 \ 29 | --weight_decay 0.00001 \ 30 | --epochs 100 \ 31 | --warmup_epochs 10 \ 32 | --test_num_segment 2 \ 33 | --test_num_crop 3 \ 34 | --dist_eval \ 35 | --pin_mem \ 36 | --enable_deepspeed \ 37 | --prompt_reparam \ 38 | --is_aa \ 39 | --aa rand-m4-n2-mstd0.2-inc1 -------------------------------------------------------------------------------- /scripts/hmdb51/videomae_vpt_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune_viu3.sh: -------------------------------------------------------------------------------- 1 | # APT on HMDB51 2 | 
OUTPUT_DIR='experiments/APT/HMDB51/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/adam_mome9e-1_wd1e-5_lr5e-2_pl200_ps0_pe11_drop10' 3 | DATA_PATH='datasets/hmdb51/lists/' 4 | MODEL_PATH='experiments/pretrain/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint.pth' 5 | 6 | NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5 python -m torch.distributed.launch --nproc_per_node=5 \ 7 | run_class_apt.py \ 8 | --model vit_base_patch16_224 \ 9 | --transfer_type prompt \ 10 | --prompt_start 0 \ 11 | --prompt_end 11 \ 12 | --prompt_num_tokens 100 \ 13 | --prompt_dropout 0.1 \ 14 | --data_set HMDB51 \ 15 | --nb_classes 51 \ 16 | --data_path ${DATA_PATH} \ 17 | --finetune ${MODEL_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 8 \ 21 | --batch_size_val 8 \ 22 | --num_sample 2 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 10 \ 26 | --num_frames 16 \ 27 | --opt adamw \ 28 | --lr 0.05 \ 29 | --weight_decay 0.00001 \ 30 | --epochs 100 \ 31 | --warmup_epochs 10 \ 32 | --test_num_segment 10 \ 33 | --test_num_crop 3 \ 34 | --dist_eval \ 35 | --pin_mem \ 36 | --enable_deepspeed \ 37 | --prompt_reparam \ 38 | --is_aa \ 39 | --aa rand-m4-n2-mstd0.2-inc1 -------------------------------------------------------------------------------- /scripts/ucf101/videomae_vpt_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune_viu3.sh: -------------------------------------------------------------------------------- 1 | # APT on UCF101 2 | OUTPUT_DIR='experiments/APT/UCF101/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/adam_mome9e-1_wd1e-5_lr5e-2_pl1_ps0_pe11_drop10' 3 | DATA_PATH='datasets/ucf101/lists' 4 | MODEL_PATH='experiments/pretrain/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint.pth' 5 | 6 | NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5,6 python -m torch.distributed.launch --nproc_per_node=6 \ 7 | run_class_apt.py \ 8 | --model vit_base_patch16_224 \ 9 | --transfer_type prompt \ 10 | --prompt_start 0 \ 11 | --prompt_end 11 \ 12 | --prompt_num_tokens 1 \ 13 | --prompt_dropout 0.1 \ 14 | --data_set UCF101 \ 15 | --nb_classes 101 \ 16 | --data_path ${DATA_PATH} \ 17 | --finetune ${MODEL_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 8 \ 21 | --batch_size_val 8 \ 22 | --num_sample 2 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 10 \ 26 | --num_frames 16 \ 27 | --opt adamw \ 28 | --lr 0.05 \ 29 | --weight_decay 0.00001 \ 30 | --epochs 100 \ 31 | --warmup_epochs 10 \ 32 | --test_num_segment 5 \ 33 | --test_num_crop 3 \ 34 | --dist_eval \ 35 | --pin_mem \ 36 | --enable_deepspeed \ 37 | --prompt_reparam \ 38 | --is_aa \ 39 | --aa rand-m4-n2-mstd0.2-inc1 -------------------------------------------------------------------------------- /scripts/hmdb51/videomae_vpt_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune_viu2.sh: -------------------------------------------------------------------------------- 1 | # APT on HMDB51 2 | OUTPUT_DIR='experiments/APT/HMDB51/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/adam_mome9e-1_wd1e-5_lr5e-2_pl400_ps0_pe11_drop10' 3 | DATA_PATH='datasets/hmdb51/lists/' 4 | MODEL_PATH='experiments/pretrain/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint.pth' 5 | 6 | NCCL_P2P_DISABLE=1 
OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5,6,7,8 python -m torch.distributed.launch --nproc_per_node=8 \ 7 | run_class_apt.py \ 8 | --model vit_base_patch16_224 \ 9 | --transfer_type prompt \ 10 | --prompt_start 0 \ 11 | --prompt_end 11 \ 12 | --prompt_num_tokens 400 \ 13 | --prompt_dropout 0.1 \ 14 | --data_set HMDB51 \ 15 | --nb_classes 51 \ 16 | --data_path ${DATA_PATH} \ 17 | --finetune ${MODEL_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 8 \ 21 | --batch_size_val 8 \ 22 | --num_sample 2 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 10 \ 26 | --num_frames 16 \ 27 | --opt adamw \ 28 | --lr 0.05 \ 29 | --weight_decay 0.00001 \ 30 | --epochs 100 \ 31 | --warmup_epochs 10 \ 32 | --test_num_segment 10 \ 33 | --test_num_crop 3 \ 34 | --dist_eval \ 35 | --pin_mem \ 36 | --enable_deepspeed \ 37 | --prompt_reparam \ 38 | --is_aa \ 39 | --aa rand-m4-n2-mstd0.2-inc1 -------------------------------------------------------------------------------- /scripts/hmdb51/videomae_vpt_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune_viu3.sh: -------------------------------------------------------------------------------- 1 | # APT on HMDB51 2 | OUTPUT_DIR='experiments/APT/HMDB51/k400_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/adam_mome9e-1_wd1e-5_lr1e-2_pl400_ps0_pe11_drop10' 3 | DATA_PATH='datasets/hmdb51/lists/' 4 | MODEL_PATH='experiments/pretrain/k400_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint.pth' 5 | 6 | NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5,6 python -m torch.distributed.launch --nproc_per_node=6 \ 7 | run_class_apt.py \ 8 | --model vit_small_patch16_224 \ 9 | --transfer_type prompt \ 10 | --prompt_start 0 \ 11 | --prompt_end 11 \ 12 | --prompt_num_tokens 400 \ 13 | --prompt_dropout 0.1 \ 14 | --data_set HMDB51 \ 15 | --nb_classes 51 \ 16 | --data_path ${DATA_PATH} \ 17 | --finetune ${MODEL_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 8 \ 21 | --batch_size_val 8 \ 22 | --num_sample 2 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 10 \ 26 | --num_frames 16 \ 27 | --opt adamw \ 28 | --lr 0.01 \ 29 | --weight_decay 0.00001 \ 30 | --epochs 100 \ 31 | --warmup_epochs 10 \ 32 | --test_num_segment 10 \ 33 | --test_num_crop 3 \ 34 | --dist_eval \ 35 | --pin_mem \ 36 | --enable_deepspeed \ 37 | --prompt_reparam \ 38 | --is_aa \ 39 | --aa rand-m4-n2-mstd0.2-inc1 -------------------------------------------------------------------------------- /scripts/ssv2/videomae_vpt_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune_viu2.sh: -------------------------------------------------------------------------------- 1 | # APT on SSv2 2 | OUTPUT_DIR='experiments/APT/SSV2/ssv2_videomae_pretrain_small_patch16_224_frame_16x2_tube_mask_ratio_0.9_e2400/adam_mome9e-1_wd1e-5_lr5e-2_pl2_ps0_pe11_drop10' 3 | DATA_PATH='datasets/ss2/list_ssv2/' 4 | MODEL_PATH='experiments/pretrain/ssv2_videomae_pretrain_small_patch16_224_frame_16x2_tube_mask_ratio_0.9_e2400/checkpoint.pth' 5 | 6 | NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5,6,7,8 python -m torch.distributed.launch --nproc_per_node=8 \ 7 | run_class_apt.py \ 8 | --model vit_small_patch16_224 \ 9 | --transfer_type prompt \ 10 | --prompt_start 0 \ 11 | --prompt_end 11 \ 12 | --prompt_num_tokens 400 \ 13 | --prompt_dropout 0.1 \ 14 | --data_set SSV2 \ 15 | --nb_classes 174 \ 16 | 
--data_path ${DATA_PATH} \ 17 | --finetune ${MODEL_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 16 \ 21 | --batch_size_val 16 \ 22 | --num_sample 2 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 10 \ 26 | --num_frames 16 \ 27 | --opt adamw \ 28 | --lr 0.05 \ 29 | --weight_decay 0.00001 \ 30 | --epochs 100 \ 31 | --warmup_epochs 10 \ 32 | --test_num_segment 2 \ 33 | --test_num_crop 3 \ 34 | --dist_eval \ 35 | --pin_mem \ 36 | --enable_deepspeed \ 37 | --prompt_reparam \ 38 | --is_aa \ 39 | --aa rand-m4-n2-mstd0.2-inc1 -------------------------------------------------------------------------------- /scripts/ucf101/videomae_vpt_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune_viu2.sh: -------------------------------------------------------------------------------- 1 | # APT on UCF101 2 | OUTPUT_DIR='experiments/APT/UCF101/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/adam_mome9e-1_wd1e-5_lr5e-2_pl400_ps0_pe11_drop10' 3 | DATA_PATH='datasets/ucf101/lists' 4 | MODEL_PATH='experiments/pretrain/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint.pth' 5 | 6 | NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5,6,7,8 python -m torch.distributed.launch --nproc_per_node=8 \ 7 | run_class_apt.py \ 8 | --model vit_base_patch16_224 \ 9 | --transfer_type prompt \ 10 | --prompt_start 0 \ 11 | --prompt_end 11 \ 12 | --prompt_num_tokens 400 \ 13 | --prompt_dropout 0.1 \ 14 | --data_set UCF101 \ 15 | --nb_classes 101 \ 16 | --data_path ${DATA_PATH} \ 17 | --finetune ${MODEL_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 8 \ 21 | --batch_size_val 8 \ 22 | --num_sample 2 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 10 \ 26 | --num_frames 16 \ 27 | --opt adamw \ 28 | --lr 0.05 \ 29 | --weight_decay 0.00001 \ 30 | --epochs 50 \ 31 | --warmup_epochs 10 \ 32 | --test_num_segment 5 \ 33 | --test_num_crop 3 \ 34 | --dist_eval \ 35 | --pin_mem \ 36 | --enable_deepspeed \ 37 | --prompt_reparam \ 38 | --is_aa \ 39 | --aa rand-m4-n2-mstd0.2-inc1 -------------------------------------------------------------------------------- /scripts/ucf101/videomae_vpt_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune_viu3.sh: -------------------------------------------------------------------------------- 1 | # APT on UCF101 2 | OUTPUT_DIR='experiments/APT/UCF101/k400_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/adam_mome9e-1_wd1e-5_lr5e-2_pl400_ps0_pe11_drop10' 3 | DATA_PATH='datasets/ucf101/lists/' 4 | MODEL_PATH='experiments/pretrain/k400_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint.pth' 5 | 6 | NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5,6 python -m torch.distributed.launch --nproc_per_node=6 \ 7 | run_class_apt.py \ 8 | --model vit_small_patch16_224 \ 9 | --transfer_type prompt \ 10 | --prompt_start 0 \ 11 | --prompt_end 11 \ 12 | --prompt_num_tokens 400 \ 13 | --prompt_dropout 0.1 \ 14 | --data_set UCF101 \ 15 | --nb_classes 101 \ 16 | --data_path ${DATA_PATH} \ 17 | --finetune ${MODEL_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 8 \ 21 | --batch_size_val 8 \ 22 | --num_sample 2 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 10 \ 26 | --num_frames 16 \ 27 | --opt adamw \ 28 | --lr 0.05 \ 29 | --weight_decay 
0.00001 \ 30 | --epochs 100 \ 31 | --warmup_epochs 10 \ 32 | --test_num_segment 5 \ 33 | --test_num_crop 3 \ 34 | --dist_eval \ 35 | --pin_mem \ 36 | --enable_deepspeed \ 37 | --prompt_reparam \ 38 | --is_aa \ 39 | --aa rand-m4-n2-mstd0.2-inc1 -------------------------------------------------------------------------------- /scripts/hmdb51/videomae_vpt_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune_viu2.sh: -------------------------------------------------------------------------------- 1 | # APT on HMDB51 2 | OUTPUT_DIR='experiments/APT/HMDB51/k400_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/adam_mome9e-1_wd1e-5_lr5e-2_pl400_ps0_pe11_drop10' 3 | DATA_PATH='datasets/hmdb51/lists/' 4 | MODEL_PATH='experiments/pretrain/k400_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint.pth' 5 | 6 | NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5,6,7,8 python -m torch.distributed.launch --nproc_per_node=8 \ 7 | run_class_apt.py \ 8 | --model vit_small_patch16_224 \ 9 | --transfer_type prompt \ 10 | --prompt_start 0 \ 11 | --prompt_end 11 \ 12 | --prompt_num_tokens 400 \ 13 | --prompt_dropout 0.1 \ 14 | --data_set HMDB51 \ 15 | --nb_classes 51 \ 16 | --data_path ${DATA_PATH} \ 17 | --finetune ${MODEL_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 8 \ 21 | --batch_size_val 8 \ 22 | --num_sample 2 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 10 \ 26 | --num_frames 16 \ 27 | --opt adamw \ 28 | --lr 0.05 \ 29 | --weight_decay 0.00001 \ 30 | --epochs 100 \ 31 | --warmup_epochs 10 \ 32 | --test_num_segment 10 \ 33 | --test_num_crop 3 \ 34 | --dist_eval \ 35 | --pin_mem \ 36 | --enable_deepspeed \ 37 | --prompt_reparam \ 38 | --is_aa \ 39 | --aa rand-m4-n2-mstd0.2-inc1 -------------------------------------------------------------------------------- /scripts/ucf101/videomae_vpt_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune_viu2.sh: -------------------------------------------------------------------------------- 1 | # APT on UCF101 2 | OUTPUT_DIR='experiments/APT/UCF101/k400_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/adam_mome9e-1_wd1e-5_lr5e-2_pl400_ps0_pe11_drop10' 3 | DATA_PATH='datasets/ucf101/lists/' 4 | MODEL_PATH='experiments/pretrain/k400_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint.pth' 5 | 6 | NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5,6,7,8 python -m torch.distributed.launch --nproc_per_node=8 \ 7 | run_class_apt.py \ 8 | --model vit_small_patch16_224 \ 9 | --transfer_type prompt \ 10 | --prompt_start 0 \ 11 | --prompt_end 11 \ 12 | --prompt_num_tokens 400 \ 13 | --prompt_dropout 0.1 \ 14 | --data_set UCF101 \ 15 | --nb_classes 101 \ 16 | --data_path ${DATA_PATH} \ 17 | --finetune ${MODEL_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 8 \ 21 | --batch_size_val 8 \ 22 | --num_sample 2 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 10 \ 26 | --num_frames 16 \ 27 | --opt adamw \ 28 | --lr 0.05 \ 29 | --weight_decay 0.00001 \ 30 | --epochs 100 \ 31 | --warmup_epochs 10 \ 32 | --test_num_segment 5 \ 33 | --test_num_crop 3 \ 34 | --dist_eval \ 35 | --pin_mem \ 36 | --enable_deepspeed \ 37 | --prompt_reparam \ 38 | --is_aa \ 39 | --aa rand-m4-n2-mstd0.2-inc1 -------------------------------------------------------------------------------- 
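All of the APT fine-tuning scripts above (and the Kinetics-400 ones below) follow the same launch pattern: set `OUTPUT_DIR`, `DATA_PATH`, and `MODEL_PATH` at the top, keep `CUDA_VISIBLE_DEVICES` consistent with `--nproc_per_node` for the local machine, and execute the file with bash. A minimal usage sketch (the script path is taken from the tree above; everything else is illustrative):

```bash
# Illustrative invocation only: edit the path variables inside the script first,
# and match CUDA_VISIBLE_DEVICES / --nproc_per_node to the GPUs actually available.
bash scripts/ucf101/videomae_vpt_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune_viu3.sh
```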
/scripts/kinetics/videomae_vpt_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune_viu3.sh: -------------------------------------------------------------------------------- 1 | # APT on Kinetics-400 2 | OUTPUT_DIR='experiments/APT/K400/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/adam_mome9e-1_wd1e-5_lr5e-2_pl400_ps0_pe11_drop10' 3 | DATA_PATH='datasets/ss2/list_ssv2' 4 | MODEL_PATH='experiments/pretrain/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint.pth' 5 | 6 | NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5,6 python -m torch.distributed.launch --nproc_per_node=6 \ 7 | run_class_apt.py \ 8 | --model vit_base_patch16_224 \ 9 | --transfer_type prompt \ 10 | --prompt_start 0 \ 11 | --prompt_end 11 \ 12 | --prompt_num_tokens 400 \ 13 | --prompt_dropout 0.1 \ 14 | --data_set Kinetics-400 \ 15 | --nb_classes 400 \ 16 | --data_path ${DATA_PATH} \ 17 | --finetune ${MODEL_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 8 \ 21 | --batch_size_val 8 \ 22 | --num_sample 2 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 10 \ 26 | --num_frames 16 \ 27 | --opt adamw \ 28 | --lr 0.05 \ 29 | --weight_decay 0.00001 \ 30 | --epochs 100 \ 31 | --warmup_epochs 10 \ 32 | --test_num_segment 5 \ 33 | --test_num_crop 3 \ 34 | --dist_eval \ 35 | --pin_mem \ 36 | --enable_deepspeed \ 37 | --prompt_reparam \ 38 | --is_aa \ 39 | --aa rand-m4-n2-mstd0.2-inc1 -------------------------------------------------------------------------------- /scripts/kinetics/videomae_vpt_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune_viu2.sh: -------------------------------------------------------------------------------- 1 | # APT on Kinetics-400 2 | OUTPUT_DIR='experiments/APT/K400/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/adam_mome9e-1_wd1e-5_lr5e-2_pl400_ps0_pe11_drop10' 3 | DATA_PATH='datasets/ss2/list_ssv2' 4 | MODEL_PATH='experiments/pretrain/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint.pth' 5 | 6 | NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5,6,7,8 python -m torch.distributed.launch --nproc_per_node=8 \ 7 | run_class_apt.py \ 8 | --model vit_base_patch16_224 \ 9 | --transfer_type prompt \ 10 | --prompt_start 0 \ 11 | --prompt_end 11 \ 12 | --prompt_num_tokens 400 \ 13 | --prompt_dropout 0.1 \ 14 | --data_set Kinetics-400 \ 15 | --nb_classes 400 \ 16 | --data_path ${DATA_PATH} \ 17 | --finetune ${MODEL_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 8 \ 21 | --batch_size_val 8 \ 22 | --num_sample 2 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 10 \ 26 | --num_frames 16 \ 27 | --opt adamw \ 28 | --lr 0.05 \ 29 | --weight_decay 0.00001 \ 30 | --epochs 100 \ 31 | --warmup_epochs 10 \ 32 | --test_num_segment 5 \ 33 | --test_num_crop 3 \ 34 | --dist_eval \ 35 | --pin_mem \ 36 | --enable_deepspeed \ 37 | --prompt_reparam \ 38 | --is_aa \ 39 | --aa rand-m4-n2-mstd0.2-inc1 -------------------------------------------------------------------------------- /scripts/kinetics/videomae_vpt_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune_viu3.sh: -------------------------------------------------------------------------------- 1 | # APT on Kinetics-400 2 | 
OUTPUT_DIR='experiments/APT/K400/k400_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/adam_mome9e-1_wd1e-5_lr5e-2_pl400_ps0_pe11_drop10' 3 | DATA_PATH='datasets/ss2/list_ssv2/' 4 | MODEL_PATH='experiments/pretrain/k400_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint.pth' 5 | 6 | NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5,6 python -m torch.distributed.launch --nproc_per_node=6 \ 7 | run_class_apt.py \ 8 | --model vit_small_patch16_224 \ 9 | --transfer_type prompt \ 10 | --prompt_start 0 \ 11 | --prompt_end 11 \ 12 | --prompt_num_tokens 400 \ 13 | --prompt_dropout 0.1 \ 14 | --data_set Kinetics-400 \ 15 | --nb_classes 400 \ 16 | --data_path ${DATA_PATH} \ 17 | --finetune ${MODEL_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 8 \ 21 | --batch_size_val 8 \ 22 | --num_sample 2 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 10 \ 26 | --num_frames 16 \ 27 | --opt adamw \ 28 | --lr 0.05 \ 29 | --weight_decay 0.00001 \ 30 | --epochs 100 \ 31 | --warmup_epochs 10 \ 32 | --test_num_segment 5 \ 33 | --test_num_crop 3 \ 34 | --dist_eval \ 35 | --pin_mem \ 36 | --enable_deepspeed \ 37 | --prompt_reparam \ 38 | --is_aa \ 39 | --aa rand-m4-n2-mstd0.2-inc1 -------------------------------------------------------------------------------- /scripts/kinetics/videomae_vpt_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune_viu2.sh: -------------------------------------------------------------------------------- 1 | # APT on Kinetics-400 2 | OUTPUT_DIR='experiments/APT/K400/k400_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/adam_mome9e-1_wd1e-5_lr5e-2_pl400_ps0_pe11_drop10' 3 | DATA_PATH='datasets/ss2/list_ssv2/' 4 | MODEL_PATH='experiments/pretrain/k400_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint.pth' 5 | 6 | NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5,6,7,8 python -m torch.distributed.launch --nproc_per_node=8 \ 7 | run_class_apt.py \ 8 | --model vit_small_patch16_224 \ 9 | --transfer_type prompt \ 10 | --prompt_start 0 \ 11 | --prompt_end 11 \ 12 | --prompt_num_tokens 400 \ 13 | --prompt_dropout 0.1 \ 14 | --data_set Kinetics-400 \ 15 | --nb_classes 400 \ 16 | --data_path ${DATA_PATH} \ 17 | --finetune ${MODEL_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 8 \ 21 | --batch_size_val 8 \ 22 | --num_sample 2 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 10 \ 26 | --num_frames 16 \ 27 | --opt adamw \ 28 | --lr 0.05 \ 29 | --weight_decay 0.00001 \ 30 | --epochs 100 \ 31 | --warmup_epochs 10 \ 32 | --test_num_segment 5 \ 33 | --test_num_crop 3 \ 34 | --dist_eval \ 35 | --pin_mem \ 36 | --enable_deepspeed \ 37 | --prompt_reparam \ 38 | --is_aa \ 39 | --aa rand-m4-n2-mstd0.2-inc1 -------------------------------------------------------------------------------- /masking_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Tube masking generator 4 | class TubeMaskingGenerator: 5 | def __init__(self, input_size, mask_ratio): 6 | self.frames, self.height, self.width = input_size 7 | self.num_patches_per_frame = self.height * self.width 8 | self.total_patches = self.frames * self.num_patches_per_frame 9 | self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame) 10 | self.total_masks = self.frames * 
self.num_masks_per_frame 11 | 12 | def __repr__(self): 13 | repr_str = "Maks: total patches {}, mask patches {}".format( 14 | self.total_patches, self.total_masks 15 | ) 16 | return repr_str 17 | 18 | def __call__(self): 19 | mask_per_frame = np.hstack([ 20 | np.zeros(self.num_patches_per_frame - self.num_masks_per_frame), 21 | np.ones(self.num_masks_per_frame), 22 | ]) 23 | np.random.shuffle(mask_per_frame) 24 | mask = np.tile(mask_per_frame, (self.frames,1)).flatten() 25 | return mask 26 | 27 | # Random masking generator 28 | class RandomMaskingGenerator: 29 | def __init__(self, input_size, mask_ratio): 30 | self.frames, self.height, self.width = input_size 31 | self.num_patches_per_frame = self.height * self.width 32 | self.total_patches = self.frames * self.num_patches_per_frame 33 | self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame) 34 | self.total_masks = self.frames * self.num_masks_per_frame 35 | 36 | def __repr__(self): 37 | repr_str = "Maks: total patches {}, mask patches {}".format( 38 | self.total_patches, self.total_masks 39 | ) 40 | return repr_str 41 | 42 | def __call__(self): 43 | mask = np.hstack([ 44 | np.zeros(self.total_patches - self.total_masks), 45 | np.ones(self.total_masks), 46 | ]) 47 | np.random.shuffle(mask) 48 | return mask -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: apt 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - ca-certificates=2022.12.7=ha878542_0 9 | - certifi=2022.12.7=pyhd8ed1ab_0 10 | - ld_impl_linux-64=2.38=h1181459_1 11 | - libblas=3.9.0=15_linux64_openblas 12 | - libcblas=3.9.0=15_linux64_openblas 13 | - libffi=3.4.2=h6a678d5_6 14 | - libgcc-ng=11.2.0=h1234567_1 15 | - libgfortran-ng=12.2.0=h69a702a_19 16 | - libgfortran5=12.2.0=h337968e_19 17 | - libgomp=11.2.0=h1234567_1 18 | - liblapack=3.9.0=15_linux64_openblas 19 | - libopenblas=0.3.20=pthreads_h78a6416_0 20 | - libprotobuf=3.15.8=h780b84a_1 21 | - libstdcxx-ng=11.2.0=h1234567_1 22 | - ncurses=6.4=h6a678d5_0 23 | - openssl=1.1.1s=h7f8727e_0 24 | - pip=22.3.1=py39h06a4308_0 25 | - python=3.9.16=h7a1cb2a_0 26 | - python_abi=3.9=2_cp39 27 | - readline=8.2=h5eee18b_0 28 | - setuptools=65.6.3=py39h06a4308_0 29 | - six=1.16.0=pyh6c4a22f_0 30 | - sqlite=3.40.1=h5082296_0 31 | - tensorboardx=2.5.1=pyhd8ed1ab_0 32 | - tk=8.6.12=h1ccaba5_0 33 | - tzdata=2022g=h04d1e81_0 34 | - wheel=0.37.1=pyhd3eb1b0_0 35 | - xz=5.2.10=h5eee18b_1 36 | - zlib=1.2.13=h5eee18b_0 37 | - pip: 38 | - charset-normalizer==3.0.1 39 | - click==8.1.3 40 | - cupy-cuda11x==11.5.0 41 | - decord==0.6.0 42 | - deepspeed==0.8.0 43 | - einops==0.6.0 44 | - fastrlock==0.8.1 45 | - hjson==3.1.0 46 | - idna==3.4 47 | - libaio==0.9.1 48 | - ninja==1.11.1 49 | - numpy==1.24.1 50 | - opencv-python==4.7.0.68 51 | - packaging==23.0 52 | - pandas==1.5.3 53 | - patool==1.12 54 | - pillow==9.4.0 55 | - protobuf==3.20.1 56 | - psutil==5.9.4 57 | - py-cpuinfo==9.0.0 58 | - pydantic==1.10.4 59 | - python-dateutil==2.8.2 60 | - pytz==2022.7.1 61 | - requests==2.28.2 62 | - scipy==1.10.0 63 | - timm==0.4.12 64 | - torch==1.13.1 65 | - torchpq==0.3.0.5 66 | - torchvision==0.14.1 67 | - tqdm==4.64.1 68 | - triton==1.0.0 69 | - typing-extensions==4.4.0 70 | - unp==0.3 71 | - unrar==0.4 72 | - urllib3==1.26.14 -------------------------------------------------------------------------------- /final_args.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 16, 3 | "batch_size_val": 14, 4 | "epochs": 100, 5 | "update_freq": 1, 6 | "save_ckpt_freq": 10, 7 | "model": "vit_base_patch16_224", 8 | "tubelet_size": 2, 9 | "input_size": 224, 10 | "fc_drop_rate": 0.0, 11 | "drop": 0.0, 12 | "attn_drop_rate": 0.0, 13 | "drop_path": 0.0, 14 | "disable_eval_during_finetuning": false, 15 | "model_ema": false, 16 | "model_ema_decay": 0.9999, 17 | "model_ema_force_cpu": false, 18 | "opt": "sgd", 19 | "opt_eps": 1e-08, 20 | "opt_betas": null, 21 | "clip_grad": null, 22 | "momentum": 0.9, 23 | "weight_decay": 0.001, 24 | "weight_decay_end": null, 25 | "lr": 0.1, 26 | "layer_decay": 1.0, 27 | "warmup_lr": 1e-06, 28 | "min_lr": 1e-06, 29 | "warmup_epochs": 10, 30 | "warmup_steps": -1, 31 | "is_aa": false, 32 | "color_jitter": 0.4, 33 | "num_sample": 2, 34 | "aa": "rand-m7-n4-mstd0.5-inc1", 35 | "smoothing": 0.1, 36 | "train_interpolation": "bicubic", 37 | "crop_pct": null, 38 | "short_side_size": 224, 39 | "test_num_segment": 5, 40 | "test_num_crop": 3, 41 | "reprob": 0.0, 42 | "remode": "pixel", 43 | "recount": 1, 44 | "resplit": false, 45 | "mixup": 0.0, 46 | "cutmix": 0.0, 47 | "cutmix_minmax": null, 48 | "mixup_prob": 0.0, 49 | "mixup_switch_prob": 0.0, 50 | "mixup_mode": "batch", 51 | "finetune": "", 52 | "model_key": "model|module", 53 | "model_prefix": "", 54 | "init_scale": 0.001, 55 | "use_checkpoint": false, 56 | "use_mean_pooling": true, 57 | "prompt_num_tokens": 5, 58 | "prompt_start": null, 59 | "prompt_end": null, 60 | "prompt_dropout": 0.0, 61 | "prompt_init": "random", 62 | "transfer_type": "prompt", 63 | "prompt_reparam": false, 64 | "data_path": "/path/to/list_kinetics-400", 65 | "eval_data_path": null, 66 | "nb_classes": 400, 67 | "imagenet_default_mean_and_std": true, 68 | "num_segments": 1, 69 | "num_frames": 16, 70 | "sampling_rate": 4, 71 | "data_set": "Kinetics-400", 72 | "output_dir": "", 73 | "log_dir": null, 74 | "device": "cuda", 75 | "seed": 0, 76 | "resume": "", 77 | "auto_resume": true, 78 | "save_ckpt": true, 79 | "start_epoch": 0, 80 | "eval": false, 81 | "dist_eval": false, 82 | "num_workers": 10, 83 | "pin_mem": true, 84 | "world_size": 3, 85 | "local_rank": 0, 86 | "dist_on_itp": false, 87 | "dist_url": "env://", 88 | "enable_deepspeed": false, 89 | "rank": 0, 90 | "gpu": 0, 91 | "distributed": true, 92 | "dist_backend": "nccl" 93 | } -------------------------------------------------------------------------------- /DATASET.md: -------------------------------------------------------------------------------- 1 | # Data Preparation 2 | 3 | We have successfully pre-trained and fine-tuned our VideoMAE on [Kinetics400](https://deepmind.com/research/open-source/kinetics), [Something-Something-V2](https://developer.qualcomm.com/software/ai-datasets/something-something), [UCF101](https://www.crcv.ucf.edu/data/UCF101.php) and [HMDB51](https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/) with this codebase. 4 | 5 | - The pre-processing of **Something-Something-V2** can be summarized into 3 steps: 6 | 7 | 1. Download the dataset from [official website](https://developer.qualcomm.com/software/ai-datasets/something-something). 8 | 9 | 2. Preprocess the dataset by changing the video extension from `webm` to `.mp4` with the **original** height of **240px**. We provide the pre-processing [script](scripts/data/data_clean.py) here. 10 | 11 | 3. Generate annotations needed for dataloader (" " in annotations). 
The annotation usually includes `train.csv`, `val.csv` and `test.csv` ( here `test.csv` is the same as `val.csv`). The format of `*.csv` file is like: 12 | 13 | ``` 14 | dataset_root/video_1.mp4 label_1 15 | dataset_root/video_2.mp4 label_2 16 | dataset_root/video_3.mp4 label_3 17 | ... 18 | dataset_root/video_N.mp4 label_N 19 | ``` 20 | 21 | - The pre-processing of **Kinetics400** can be summarized into 3 steps: 22 | 23 | 1. Download the dataset from [official website](https://deepmind.com/research/open-source/kinetics). 24 | 25 | 2. Preprocess the dataset by resizing the short edge of video to **320px**. You can refer to [MMAction2 Data Benchmark](https://github.com/open-mmlab/mmaction2) for [TSN](https://github.com/open-mmlab/mmaction2/tree/master/configs/recognition/tsn#kinetics-400-data-benchmark-8-gpus-resnet50-imagenet-pretrain-3-segments) and [SlowOnly](https://github.com/open-mmlab/mmaction2/tree/master/configs/recognition/slowonly#kinetics-400-data-benchmark). 26 | 27 | 3. Generate annotations needed for dataloader (" " in annotations). The annotation usually includes `train.csv`, `val.csv` and `test.csv` ( here `test.csv` is the same as `val.csv`). The format of `*.csv` file is like: 28 | 29 | ``` 30 | dataset_root/video_1.mp4 label_1 31 | dataset_root/video_2.mp4 label_2 32 | dataset_root/video_3.mp4 label_3 33 | ... 34 | dataset_root/video_N.mp4 label_N 35 | ``` 36 | 37 | ### Note: 38 | 39 | We use [decord](https://github.com/dmlc/decord) to decode the videos **on the fly** during both pre-training and fine-tuning phases. -------------------------------------------------------------------------------- /functional.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import cv2 3 | import numpy as np 4 | import PIL 5 | import torch 6 | 7 | 8 | def _is_tensor_clip(clip): 9 | return torch.is_tensor(clip) and clip.ndimension() == 4 10 | 11 | 12 | def crop_clip(clip, min_h, min_w, h, w): 13 | if isinstance(clip[0], np.ndarray): 14 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] 15 | 16 | elif isinstance(clip[0], PIL.Image.Image): 17 | cropped = [ 18 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip 19 | ] 20 | else: 21 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 22 | 'but got list of {0}'.format(type(clip[0]))) 23 | return cropped 24 | 25 | 26 | def resize_clip(clip, size, interpolation='bilinear'): 27 | if isinstance(clip[0], np.ndarray): 28 | if isinstance(size, numbers.Number): 29 | im_h, im_w, im_c = clip[0].shape 30 | # Min spatial dim already matches minimal size 31 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 32 | and im_h == size): 33 | return clip 34 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 35 | size = (new_w, new_h) 36 | else: 37 | size = size[0], size[1] 38 | if interpolation == 'bilinear': 39 | np_inter = cv2.INTER_LINEAR 40 | else: 41 | np_inter = cv2.INTER_NEAREST 42 | scaled = [ 43 | cv2.resize(img, size, interpolation=np_inter) for img in clip 44 | ] 45 | elif isinstance(clip[0], PIL.Image.Image): 46 | if isinstance(size, numbers.Number): 47 | im_w, im_h = clip[0].size 48 | # Min spatial dim already matches minimal size 49 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 50 | and im_h == size): 51 | return clip 52 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 53 | size = (new_w, new_h) 54 | else: 55 | size = size[1], size[0] 56 | if interpolation == 'bilinear': 57 | pil_inter = PIL.Image.BILINEAR 58 | else: 59 | pil_inter = 
PIL.Image.NEAREST 60 | scaled = [img.resize(size, pil_inter) for img in clip] 61 | else: 62 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 63 | 'but got list of {0}'.format(type(clip[0]))) 64 | return scaled 65 | 66 | 67 | def get_resize_sizes(im_h, im_w, size): 68 | if im_w < im_h: 69 | ow = size 70 | oh = int(size * im_h / im_w) 71 | else: 72 | oh = size 73 | ow = int(size * im_w / im_h) 74 | return oh, ow 75 | 76 | 77 | def normalize(clip, mean, std, inplace=False): 78 | if not _is_tensor_clip(clip): 79 | raise TypeError('tensor is not a torch clip.') 80 | 81 | if not inplace: 82 | clip = clip.clone() 83 | 84 | dtype = clip.dtype 85 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) 86 | std = torch.as_tensor(std, dtype=dtype, device=clip.device) 87 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 88 | 89 | return clip 90 | -------------------------------------------------------------------------------- /volume_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | 5 | 6 | def convert_img(img): 7 | """Converts (H, W, C) numpy.ndarray to (C, W, H) format 8 | """ 9 | if len(img.shape) == 3: 10 | img = img.transpose(2, 0, 1) 11 | if len(img.shape) == 2: 12 | img = np.expand_dims(img, 0) 13 | return img 14 | 15 | 16 | class ClipToTensor(object): 17 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 18 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 19 | """ 20 | 21 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 22 | self.channel_nb = channel_nb 23 | self.div_255 = div_255 24 | self.numpy = numpy 25 | 26 | def __call__(self, clip): 27 | """ 28 | Args: clip (list of numpy.ndarray): clip (list of images) 29 | to be converted to tensor. 
30 | """ 31 | # Retrieve shape 32 | if isinstance(clip[0], np.ndarray): 33 | h, w, ch = clip[0].shape 34 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( 35 | ch) 36 | elif isinstance(clip[0], Image.Image): 37 | w, h = clip[0].size 38 | else: 39 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 40 | but got list of {0}'.format(type(clip[0]))) 41 | 42 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 43 | 44 | # Convert 45 | for img_idx, img in enumerate(clip): 46 | if isinstance(img, np.ndarray): 47 | pass 48 | elif isinstance(img, Image.Image): 49 | img = np.array(img, copy=False) 50 | else: 51 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 52 | but got list of {0}'.format(type(clip[0]))) 53 | img = convert_img(img) 54 | np_clip[:, img_idx, :, :] = img 55 | if self.numpy: 56 | if self.div_255: 57 | np_clip = np_clip / 255.0 58 | return np_clip 59 | 60 | else: 61 | tensor_clip = torch.from_numpy(np_clip) 62 | 63 | if not isinstance(tensor_clip, torch.FloatTensor): 64 | tensor_clip = tensor_clip.float() 65 | if self.div_255: 66 | tensor_clip = torch.div(tensor_clip, 255) 67 | return tensor_clip 68 | 69 | 70 | # Note this norms data to -1/1 71 | class ClipToTensor_K(object): 72 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 73 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 74 | """ 75 | 76 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 77 | self.channel_nb = channel_nb 78 | self.div_255 = div_255 79 | self.numpy = numpy 80 | 81 | def __call__(self, clip): 82 | """ 83 | Args: clip (list of numpy.ndarray): clip (list of images) 84 | to be converted to tensor. 85 | """ 86 | # Retrieve shape 87 | if isinstance(clip[0], np.ndarray): 88 | h, w, ch = clip[0].shape 89 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( 90 | ch) 91 | elif isinstance(clip[0], Image.Image): 92 | w, h = clip[0].size 93 | else: 94 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 95 | but got list of {0}'.format(type(clip[0]))) 96 | 97 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 98 | 99 | # Convert 100 | for img_idx, img in enumerate(clip): 101 | if isinstance(img, np.ndarray): 102 | pass 103 | elif isinstance(img, Image.Image): 104 | img = np.array(img, copy=False) 105 | else: 106 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 107 | but got list of {0}'.format(type(clip[0]))) 108 | img = convert_img(img) 109 | np_clip[:, img_idx, :, :] = img 110 | if self.numpy: 111 | if self.div_255: 112 | np_clip = (np_clip - 127.5) / 127.5 113 | return np_clip 114 | 115 | else: 116 | tensor_clip = torch.from_numpy(np_clip) 117 | 118 | if not isinstance(tensor_clip, torch.FloatTensor): 119 | tensor_clip = tensor_clip.float() 120 | if self.div_255: 121 | tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) 122 | return tensor_clip 123 | 124 | 125 | class ToTensor(object): 126 | """Converts numpy array to tensor 127 | """ 128 | 129 | def __call__(self, array): 130 | tensor = torch.from_numpy(array) 131 | return tensor 132 | -------------------------------------------------------------------------------- /PRETRAIN.md: -------------------------------------------------------------------------------- 1 | # Pre-training VideoMAE 2 | 3 | ## Original Implementation 4 | 5 | The implementation of our VideoMAE supports **multi-node distributed training**. We provide the **off-the-shelf** scripts in the [scripts folder](scripts). 
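If only one machine is available, the same entry point can be launched in a single-node configuration. The sketch below is an illustrative reduction of the 8-node SSv2 command in the first example that follows: the `run_mae_pretraining.py` arguments are unchanged and only the launcher flags differ, but with 1 node x 8 GPUs the total batch size drops from 2048 to 256, which (by the linear scaling rule in the Note further below) also changes the actual learning rate.

```bash
# Single-node sketch (1 node x 8 GPUs); OUTPUT_DIR and DATA_PATH set as in the SSv2 example below
OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
        --nnodes=1 --node_rank=0 --master_port 12320 \
        run_mae_pretraining.py \
        --data_path ${DATA_PATH} \
        --mask_type tube \
        --mask_ratio 0.9 \
        --model pretrain_videomae_base_patch16_224 \
        --decoder_depth 4 \
        --batch_size 32 \
        --num_frames 16 \
        --sampling_rate 2 \
        --opt adamw \
        --opt_betas 0.9 0.95 \
        --warmup_epochs 40 \
        --save_ckpt_freq 20 \
        --epochs 801 \
        --log_dir ${OUTPUT_DIR} \
        --output_dir ${OUTPUT_DIR}
```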
6 | 7 | - For example, to pre-train VideoMAE ViT-Base on **Something-Something V2** with 64 GPUs (8 nodes x 8 GPUs), you can run 8 | 9 | ```bash 10 | OUTPUT_DIR='YOUR_PATH/ssv2_videomae_pretrain_base_patch16_224_frame_16x2_tube_mask_ratio_0.9_e800' 11 | DATA_PATH='YOUR_PATH/list_ssv2/train.csv' 12 | 13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 14 | --master_port 12320 --nnodes=8 \ 15 | --node_rank=0 --master_addr=$ip_node_0 \ 16 | run_mae_pretraining.py \ 17 | --data_path ${DATA_PATH} \ 18 | --mask_type tube \ 19 | --mask_ratio 0.9 \ 20 | --model pretrain_videomae_base_patch16_224 \ 21 | --decoder_depth 4 \ 22 | --batch_size 32 \ 23 | --num_frames 16 \ 24 | --sampling_rate 2 \ 25 | --opt adamw \ 26 | --opt_betas 0.9 0.95 \ 27 | --warmup_epochs 40 \ 28 | --save_ckpt_freq 20 \ 29 | --epochs 801 \ 30 | --log_dir ${OUTPUT_DIR} \ 31 | --output_dir ${OUTPUT_DIR} 32 | ``` 33 | 34 | on the first node. On other nodes, run the same command with `--node_rank 1`, ..., `--node_rank 7` respectively. `--master_addr` is set as the ip of the node 0. 35 | 36 | - For example, to pre-train VideoMAE ViT-Base on **Kinetics400** with 64 GPUs (8 nodes x 8 GPUs), you can run 37 | 38 | ```bash 39 | OUTPUT_DIR='YOUR_PATH/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e800' 40 | DATA_PATH='YOUR_PATH/list_kinetics-400/train.csv' 41 | 42 | OMP_NUM_THREADS=1 python3 -m torch.distributed.launch --nproc_per_node=8 \ 43 | --master_port 12320 --nnodes=8 \ 44 | --node_rank=0 --master_addr=$your_node_0_ip \ 45 | run_mae_pretraining.py \ 46 | --data_path ${DATA_PATH} \ 47 | --mask_type tube \ 48 | --mask_ratio 0.9 \ 49 | --model pretrain_videomae_base_patch16_224 \ 50 | --decoder_depth 4 \ 51 | --batch_size 32 \ 52 | --num_frames 16 \ 53 | --sampling_rate 4 \ 54 | --opt adamw \ 55 | --opt_betas 0.9 0.95 \ 56 | --warmup_epochs 40 \ 57 | --save_ckpt_freq 20 \ 58 | --epochs 801 \ 59 | --log_dir ${OUTPUT_DIR} \ 60 | --output_dir ${OUTPUT_DIR} 61 | ``` 62 | 63 | on the first node. On other nodes, run the same command with `--node_rank 1`, ..., `--node_rank 7` respectively. `--master_addr` is set as the ip of the node 0. 64 | 65 | ### Note: 66 | 67 | - Here the batch size is 32 (`batch_size` per gpu) * 8 (`nodes`) * 8 (gpus per node) = 2048. 68 | - `lr` here is the base learning rate and is set to `1.5e-4` as default. The ` actual lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `` actual lr`` = `lr` * total batch size / 256. 69 | - We have observed accidental interrupt in the last epoch when conduct the experiment on V100 GPUs (torch 1.6.0). This interrupt is caused by the scheduler of learning rate. We naively set `--epochs 801` to walk away from issue :) 70 | 71 | ## Slurm 72 | 73 | To help the community to reproduce our results on slurm cluster, we also provide the the **off-the-shelf** script. 74 | 75 | For example, to pre-train VideoMAE ViT-Base on **Kinetics400** with 64 GPUs (8 nodes x 8 GPUs), you can run 76 | 77 | ```bash 78 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 79 | export OMP_NUM_THREADS=1 80 | 81 | OUTPUT_DIR='YOUR_PATH/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e800' 82 | DATA_PATH='YOUR_PATH/list_kinetics-400/train.csv' 83 | 84 | JOB_NAME=$1 85 | PARTITION=${PARTITION:-"video"} 86 | # 8 for 1 node, 16 for 2 node, etc. 
87 | GPUS=${GPUS:-64} 88 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 89 | CPUS_PER_TASK=${CPUS_PER_TASK:-8} 90 | SRUN_ARGS=${SRUN_ARGS:-""} 91 | PY_ARGS=${@:2} 92 | 93 | # batch_size can be adjusted according to the graphics card 94 | srun -p $PARTITION \ 95 | --job-name=${JOB_NAME} \ 96 | --gres=gpu:${GPUS_PER_NODE} \ 97 | --ntasks=${GPUS} \ 98 | --ntasks-per-node=${GPUS_PER_NODE} \ 99 | --cpus-per-task=${CPUS_PER_TASK} \ 100 | --kill-on-bad-exit=1 \ 101 | ${SRUN_ARGS} \ 102 | python -u run_mae_pretraining.py \ 103 | --data_path ${DATA_PATH} \ 104 | --mask_type tube \ 105 | --mask_ratio 0.9 \ 106 | --model pretrain_videomae_base_patch16_224 \ 107 | --decoder_depth 4 \ 108 | --batch_size 32 \ 109 | --num_frames 16 \ 110 | --sampling_rate 4 \ 111 | --opt adamw \ 112 | --opt_betas 0.9 0.95 \ 113 | --warmup_epochs 40 \ 114 | --save_ckpt_freq 20 \ 115 | --epochs 801 \ 116 | --log_dir ${OUTPUT_DIR} \ 117 | --output_dir ${OUTPUT_DIR} \ 118 | ${PY_ARGS} 119 | ``` 120 | 121 | -------------------------------------------------------------------------------- /engine_for_pretraining.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | from typing import Iterable 4 | import torch 5 | import torch.nn as nn 6 | import utils 7 | from einops import rearrange 8 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 9 | 10 | def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, 11 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, patch_size: int = 16, 12 | normlize_target: bool = True, log_writer=None, lr_scheduler=None, start_steps=None, 13 | lr_schedule_values=None, wd_schedule_values=None): 14 | model.train() 15 | metric_logger = utils.MetricLogger(delimiter=" ") 16 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 17 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 18 | header = 'Epoch: [{}]'.format(epoch) 19 | print_freq = 10 20 | 21 | loss_func = nn.MSELoss() 22 | 23 | for step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 24 | # assign learning rate & weight decay for each step 25 | it = start_steps + step # global training iteration 26 | if lr_schedule_values is not None or wd_schedule_values is not None: 27 | for i, param_group in enumerate(optimizer.param_groups): 28 | if lr_schedule_values is not None: 29 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 30 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 31 | param_group["weight_decay"] = wd_schedule_values[it] 32 | 33 | videos, bool_masked_pos = batch 34 | videos = videos.to(device, non_blocking=True) 35 | bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool) 36 | 37 | with torch.no_grad(): 38 | # calculate the predict label 39 | mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None] 40 | std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None] 41 | unnorm_videos = videos * std + mean # in [0, 1] 42 | 43 | if normlize_target: 44 | videos_squeeze = rearrange(unnorm_videos, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c', p0=2, p1=patch_size, p2=patch_size) 45 | videos_norm = (videos_squeeze - videos_squeeze.mean(dim=-2, keepdim=True) 46 | ) / (videos_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6) 47 | # we find that the mean is about 
0.48 and standard deviation is about 0.08. 48 | videos_patch = rearrange(videos_norm, 'b n p c -> b n (p c)') 49 | else: 50 | videos_patch = rearrange(unnorm_videos, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2 c)', p0=2, p1=patch_size, p2=patch_size) 51 | 52 | B, _, C = videos_patch.shape 53 | labels = videos_patch[bool_masked_pos].reshape(B, -1, C) 54 | 55 | with torch.cuda.amp.autocast(): 56 | outputs = model(videos, bool_masked_pos) 57 | loss = loss_func(input=outputs, target=labels) 58 | 59 | loss_value = loss.item() 60 | 61 | if not math.isfinite(loss_value): 62 | print("Loss is {}, stopping training".format(loss_value)) 63 | sys.exit(1) 64 | 65 | optimizer.zero_grad() 66 | # this attribute is added by timm on one optimizer (adahessian) 67 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 68 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 69 | parameters=model.parameters(), create_graph=is_second_order) 70 | loss_scale_value = loss_scaler.state_dict()["scale"] 71 | 72 | torch.cuda.synchronize() 73 | 74 | metric_logger.update(loss=loss_value) 75 | metric_logger.update(loss_scale=loss_scale_value) 76 | min_lr = 10. 77 | max_lr = 0. 78 | for group in optimizer.param_groups: 79 | min_lr = min(min_lr, group["lr"]) 80 | max_lr = max(max_lr, group["lr"]) 81 | 82 | metric_logger.update(lr=max_lr) 83 | metric_logger.update(min_lr=min_lr) 84 | weight_decay_value = None 85 | for group in optimizer.param_groups: 86 | if group["weight_decay"] > 0: 87 | weight_decay_value = group["weight_decay"] 88 | metric_logger.update(weight_decay=weight_decay_value) 89 | metric_logger.update(grad_norm=grad_norm) 90 | 91 | if log_writer is not None: 92 | log_writer.update(loss=loss_value, head="loss") 93 | log_writer.update(loss_scale=loss_scale_value, head="opt") 94 | log_writer.update(lr=max_lr, head="opt") 95 | log_writer.update(min_lr=min_lr, head="opt") 96 | log_writer.update(weight_decay=weight_decay_value, head="opt") 97 | log_writer.update(grad_norm=grad_norm, head="opt") 98 | log_writer.set_step() 99 | 100 | if lr_scheduler is not None: 101 | lr_scheduler.step_update(start_steps + step) 102 | # gather the stats from all processes 103 | metric_logger.synchronize_between_processes() 104 | print("Averaged stats:", metric_logger) 105 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 106 | -------------------------------------------------------------------------------- /FINETUNE.md: -------------------------------------------------------------------------------- 1 | # Fine-tuning VideoMAE 2 | 3 | ## Original Implementation 4 | 5 | The implementation of our VideoMAE supports **multi-node distributed training**. We provide the **off-the-shelf** scripts in the [scripts folder](scripts). 
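Throughout this page the total batch size is (`batch_size` per GPU) x nodes x (GPUs per node), and the actual learning rate follows the linear scaling rule spelled out in the Note further below. As a worked instance for the first example that follows (`--batch_size 8`, 8 nodes x 8 GPUs, base `lr` 5e-4): total batch size = 8 x 8 x 8 = 512, so actual lr = 5e-4 x 512 / 256 = 1e-3.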
6 | 7 | - For example, to fine-tune VideoMAE ViT-Base on **Something-Something V2** with 64 GPUs (8 nodes x 8 GPUs), you can run 8 | 9 | ```bash 10 | OUTPUT_DIR='YOUR_PATH/ssv2_videomae_pretrain_base_patch16_224_frame_16x2_tube_mask_ratio_0.9_e800/eval_lr_5e-4_epoch_50' 11 | DATA_PATH='YOUR_PATH/list_ssv2' 12 | MODEL_PATH='YOUR_PATH/ssv2_videomae_pretrain_base_patch16_224_frame_16x2_tube_mask_ratio_0.9_e800/checkpoint-799.pth' 13 | 14 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 15 | --master_port 12320 --nnodes=8 \ 16 | --node_rank=0 --master_addr=$ip_node_0 \ 17 | run_class_finetuning.py \ 18 | --model vit_base_patch16_224 \ 19 | --data_set SSV2 \ 20 | --nb_classes 174 \ 21 | --data_path ${DATA_PATH} \ 22 | --finetune ${MODEL_PATH} \ 23 | --log_dir ${OUTPUT_DIR} \ 24 | --output_dir ${OUTPUT_DIR} \ 25 | --batch_size 8 \ 26 | --num_sample 1 \ 27 | --input_size 224 \ 28 | --short_side_size 224 \ 29 | --save_ckpt_freq 10 \ 30 | --num_frames 16 \ 31 | --opt adamw \ 32 | --lr 5e-4 \ 33 | --opt_betas 0.9 0.999 \ 34 | --weight_decay 0.05 \ 35 | --epochs 50 \ 36 | --dist_eval \ 37 | --test_num_segment 2 \ 38 | --test_num_crop 3 \ 39 | --enable_deepspeed 40 | ``` 41 | 42 | on the first node. On other nodes, run the same command with `--node_rank 1`, ..., `--node_rank 7` respectively. `--master_addr` is set as the ip of the node 0. 43 | 44 | - For example, to fine-tune VideoMAE ViT-Base on **Kinetics400** with 64 GPUs (8 nodes x 8 GPUs), you can run 45 | 46 | ```bash 47 | OUTPUT_DIR='YOUR_PATH/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e800/eval_lr_1e-3_epoch_100' 48 | DATA_PATH='YOUR_PATH/list_kinetics-400' 49 | MODEL_PATH='YOUR_PATH/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e800/checkpoint-799.pth' 50 | 51 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 52 | --master_port 12320 --nnodes=8 \ 53 | --node_rank=0 --master_addr=$ip_node_0 \ 54 | run_class_finetuning.py \ 55 | --model vit_base_patch16_224 \ 56 | --data_set Kinetics-400 \ 57 | --nb_classes 400 \ 58 | --data_path ${DATA_PATH} \ 59 | --finetune ${MODEL_PATH} \ 60 | --log_dir ${OUTPUT_DIR} \ 61 | --output_dir ${OUTPUT_DIR} \ 62 | --batch_size 8 \ 63 | --num_sample 1 \ 64 | --input_size 224 \ 65 | --short_side_size 224 \ 66 | --save_ckpt_freq 10 \ 67 | --num_frames 16 \ 68 | --sampling_rate 4 \ 69 | --opt adamw \ 70 | --lr 1e-3 \ 71 | --opt_betas 0.9 0.999 \ 72 | --weight_decay 0.05 \ 73 | --epochs 100 \ 74 | --dist_eval \ 75 | --test_num_segment 5 \ 76 | --test_num_crop 3 \ 77 | --enable_deepspeed 78 | ``` 79 | 80 | on the first node. On other nodes, run the same command with `--node_rank 1`, ..., `--node_rank 7` respectively. `--master_addr` is set as the ip of the node 0. 81 | 82 | ### Note: 83 | 84 | - We perform the **I3D dense sampling** on **Kinetics400** and **uniform sampling** on **Something-Something V2**, respectively. 85 | - We didn't use `cls token` in our implementation, and directly average the feature of last layer for video classification. 86 | - Here total batch size = (`batch_size` per gpu) x `nodes` x (gpus per node). 87 | - `lr` here is the base learning rate. The ` actual lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `` actual lr`` = `lr` * total batch size / 256. 88 | 89 | ## Slurm 90 | 91 | To help the community to reproduce our results on slurm cluster, we also provide the the **off-the-shelf** script. 
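Whichever launcher you use, the linear scaling rule from the note above determines the effective learning rate. Below is a quick sanity-check sketch (variable names are illustrative), using the values from the Kinetics-400 fine-tuning command in this document (base `lr` of 1e-3, `batch_size` of 8 per GPU, 64 GPUs):

```python
# Sketch of the linear scaling rule quoted above: actual lr = lr * total batch size / 256.
base_lr = 1e-3          # --lr in the Kinetics-400 fine-tuning command
batch_size_per_gpu = 8  # --batch_size
num_gpus = 64           # 8 nodes x 8 GPUs per node

total_batch_size = batch_size_per_gpu * num_gpus   # 512
actual_lr = base_lr * total_batch_size / 256       # 2e-3
print(total_batch_size, actual_lr)
```

If you change the number of GPUs or `--batch_size`, the effective learning rate changes accordingly.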
92 | 93 | For example, to fine-tune VideoMAE ViT-Base on **Kinetics400** with 64 GPUs (8 nodes x 8 GPUs), you can run: 94 | 95 | ```bash 96 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 97 | export OMP_NUM_THREADS=1 98 | 99 | OUTPUT_DIR='YOUR_PATH/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e800/eval_lr_1e-3_epoch_100' 100 | DATA_PATH='YOUR_PATH/list_kinetics-400' 101 | MODEL_PATH='YOUR_PATH/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e800/checkpoint-799.pth' 102 | 103 | JOB_NAME=$1 104 | PARTITION=${PARTITION:-"video"} 105 | # 8 for 1 node, 16 for 2 node, etc. 106 | GPUS=${GPUS:-64} 107 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 108 | CPUS_PER_TASK=${CPUS_PER_TASK:-8} 109 | SRUN_ARGS=${SRUN_ARGS:-""} 110 | PY_ARGS=${@:2} 111 | 112 | # batch_size can be adjusted according to the graphics card 113 | srun -p $PARTITION \ 114 | --job-name=${JOB_NAME} \ 115 | --gres=gpu:${GPUS_PER_NODE} \ 116 | --ntasks=${GPUS} \ 117 | --ntasks-per-node=${GPUS_PER_NODE} \ 118 | --cpus-per-task=${CPUS_PER_TASK} \ 119 | --kill-on-bad-exit=1 \ 120 | ${SRUN_ARGS} \ 121 | python -u run_class_finetuning.py \ 122 | --model vit_base_patch16_224 \ 123 | --data_set Kinetics-400 \ 124 | --nb_classes 400 \ 125 | --data_path ${DATA_PATH} \ 126 | --finetune ${MODEL_PATH} \ 127 | --log_dir ${OUTPUT_DIR} \ 128 | --output_dir ${OUTPUT_DIR} \ 129 | --batch_size 8 \ 130 | --num_sample 1 \ 131 | --input_size 224 \ 132 | --short_side_size 224 \ 133 | --save_ckpt_freq 10 \ 134 | --num_frames 16 \ 135 | --sampling_rate 4 \ 136 | --opt adamw \ 137 | --lr 1e-3 \ 138 | --opt_betas 0.9 0.999 \ 139 | --weight_decay 0.05 \ 140 | --epochs 100 \ 141 | --dist_eval \ 142 | --test_num_segment 5 \ 143 | --test_num_crop 3 \ 144 | --enable_deepspeed \ 145 | ${PY_ARGS} 146 | ``` 147 | 148 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # APT: Attention Prompt Tuning 2 | > A Parameter-Efficient Adaptation of Pre-Trained Models for Action Recognition ...
3 | 4 | > [Wele Gedara Chaminda Bandara](https://github.com/wgcban), [Vishal M Patel](https://engineering.jhu.edu/vpatel36/team/vishalpatel/)
Johns Hopkins University 5 | 6 | > Accepted at [FG'24](https://fg2024.ieee-biometrics.org) 7 | 8 | > [Paper (on ArXiv)](https://arxiv.org/abs/2403.06978)
9 | 10 | ## Overview of Proposed Method 11 | 12 |

13 | *(figure)* 14 | 15 | 16 | Comparison of our Attention Prompt Tuning (APT) for video action classification with other existing tuning methods: linear probing, adapter tuning, visual prompt tuning (VPT), and full fine-tuning. 17 | 18 | 19 | *(figure)* 20 | 21 | 22 | 23 | 24 | Attention Prompt Tuning (APT) injects learnable prompts directly into the MHA, unlike VPT. 25 |
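To make the caption above concrete, here is a minimal, hypothetical PyTorch sketch of the general idea of prompting inside multi-head attention: a handful of learnable prompt tokens are prepended to the keys/values while the backbone weights stay frozen. This is only an illustration (all names below are made up); the actual APT formulation is implemented in `run_class_apt.py` and described in the paper.

```python
import torch
import torch.nn as nn

class AttentionWithPrompts(nn.Module):
    """Sketch only: frozen multi-head attention plus a few learnable prompt
    tokens that are prepended to the keys/values (queries stay untouched)."""
    def __init__(self, dim=768, num_heads=12, num_prompts=2):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        for p in self.attn.parameters():       # backbone attention stays frozen
            p.requires_grad = False
        self.prompts = nn.Parameter(torch.zeros(1, num_prompts, dim))  # the only trainable part
        nn.init.trunc_normal_(self.prompts, std=0.02)

    def forward(self, x):                      # x: (B, N, dim) patch tokens
        p = self.prompts.expand(x.size(0), -1, -1)
        kv = torch.cat([p, x], dim=1)          # every token can also attend to the prompts
        out, _ = self.attn(query=x, key=kv, value=kv)
        return out                             # (B, N, dim): no extra tokens are emitted
```

Because only the prompt tokens (and the classification head) receive gradients, the number of tuned parameters is tiny compared to full fine-tuning, which is the point of parameter-efficient adaptation.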

26 | 27 | ## Getting Started 28 | 29 | ### Step 1: Conda Environment 30 | 31 | Set up the conda environment using the provided `environment.yml`: 32 | ``` 33 | conda env create -f environment.yml 34 | ``` 35 | 36 | Then activate the conda environment: 37 | ``` 38 | conda activate apt 39 | ``` 40 | 41 | ### Step 2: Download the VideoMAE Pre-trained Models 42 | 43 | We use [VideoMAE](https://github.com/MCG-NJU/VideoMAE) pre-trained on the [Kinetics-400](https://github.com/cvdfoundation/kinetics-dataset) dataset for our experiments. 44 | 45 | The pre-trained models for the ViT-Small and ViT-Base backbones can be downloaded from the links below: 46 | 47 | | Method | Extra Data | Backbone | Epoch | \#Frame | Pre-train | 48 | | :------: | :--------: | :------: | :---: | :-----: | :----------------------------------------------------------: | 49 | | VideoMAE | ***no*** | ViT-S | 1600 | 16x5x3 | [checkpoint](https://drive.google.com/file/d/1nU-H1u3eJ-VuyCveU7v-WIOcAVxs5Hww/view?usp=sharing) | 50 | | VideoMAE | ***no*** | ViT-B | 1600 | 16x5x3 | [checkpoint](https://drive.google.com/file/d/1tEhLyskjb755TJ65ptsrafUG2llSwQE1/view?usp=sharing) | 51 | 52 | If you need other pre-trained models, please refer to [MODEL_ZOO.md](https://github.com/wgcban/apt/blob/main/MODEL_ZOO.md). 53 | 54 | ### Step 3: Download the datasets 55 | 56 | We conduct experiments on three action recognition datasets: 1) UCF101, 2) HMDB51, and 3) Something-Something-V2. 57 | 58 | Please refer to [DATASET.md](https://github.com/wgcban/apt/blob/main/DATASET.md) for download links and pre-processing steps. 59 | 60 | ### Step 4: Attention Prompt Tuning 61 | 62 | We provide example scripts to run attention prompt tuning on the UCF101, HMDB51, and SSv2 datasets in the `scripts/` folder. 63 | 64 | Inside `scripts/` you will find two folders, corresponding to APT fine-tuning with the ViT-Small and ViT-Base architectures.
65 | 66 | To fine-tune with APT, you just need to execute the corresponding `finetune.sh` file, which launches the job with distributed training. 67 | 68 | 69 | For example, to fine-tune ViT-Base on SSv2 with APT, you may run: 70 | ``` 71 | sh scripts/ssv2/vit_base/finetune.sh 72 | ``` 73 | 74 | The `finetune.sh` script looks like this: 75 | 76 | ```bash 77 | # APT on SSv2 78 | OUTPUT_DIR='experiments/APT/SSV2/ssv2_videomae_pretrain_base_patch16_224_frame_16x2_tube_mask_ratio_0.9_e2400/adam_mome9e-1_wd1e-5_lr5se-2_pl2_ps0_pe11_drop10' 79 | DATA_PATH='datasets/ss2/list_ssv2/' 80 | MODEL_PATH='experiments/pretrain/ssv2_videomae_pretrain_base_patch16_224_frame_16x2_tube_mask_ratio_0.9_e2400/checkpoint.pth' 81 | 82 | NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5,6,7,8 python -m torch.distributed.launch --nproc_per_node=8 \ 83 | run_class_apt.py \ 84 | --model vit_base_patch16_224 \ 85 | --transfer_type prompt \ 86 | --prompt_start 0 \ 87 | --prompt_end 11 \ 88 | --prompt_num_tokens 2 \ 89 | --prompt_dropout 0.1 \ 90 | --data_set SSV2 \ 91 | --nb_classes 174 \ 92 | --data_path ${DATA_PATH} \ 93 | --finetune ${MODEL_PATH} \ 94 | --log_dir ${OUTPUT_DIR} \ 95 | --output_dir ${OUTPUT_DIR} \ 96 | --batch_size 8 \ 97 | --batch_size_val 8 \ 98 | --num_sample 2 \ 99 | --input_size 224 \ 100 | --short_side_size 224 \ 101 | --save_ckpt_freq 10 \ 102 | --num_frames 16 \ 103 | --opt adamw \ 104 | --lr 0.05 \ 105 | --weight_decay 0.00001 \ 106 | --epochs 100 \ 107 | --warmup_epochs 10 \ 108 | --test_num_segment 2 \ 109 | --test_num_crop 3 \ 110 | --dist_eval \ 111 | --pin_mem \ 112 | --enable_deepspeed \ 113 | --prompt_reparam \ 114 | --is_aa \ 115 | --aa rand-m4-n2-mstd0.2-inc1 116 | 117 | ``` 118 | 119 | Here, 120 | 121 | - `OUTPUT_DIR`: where you want to save the results (i.e., logs and checkpoints) 122 | - `DATA_PATH`: path to where the dataset is stored 123 | - `MODEL_PATH`: path to the downloaded VideoMAE pre-trained model 124 | - specify which GPUs (GPU ids) you want to use for fine-tuning in `CUDA_VISIBLE_DEVICES=`... 125 | - `nproc_per_node` is the number of GPUs used for fine-tuning 126 | - `model` is either ViT-Base (`vit_base_patch16_224`) or ViT-Small (`vit_small_patch16_224`) 127 | - `transfer_type` specifies which fine-tuning method to use: 'random' means random initialization, 'end2end' means full end-to-end fine-tuning, 'prompt' means APT (ours), and 'linear' means linear probing 128 | - `prompt_start` is the first transformer block to which attention prompts are added; 0 means learnable prompts are added starting from the 1st transformer block of the ViT 129 | - `prompt_end` is the last transformer block to which attention prompts are added; ViT-Base / ViT-Small have 12 transformer blocks, so 11 means prompts are added up to the last transformer block 130 | - `data_set` specifies the dataset 131 | - all the other parameters are hyperparameters related to APT fine-tuning. 132 | 133 | 134 | ## ✏️ Citation 135 | 136 | If you find this project helpful, please feel free to leave a star and cite our paper: 137 | 138 | ```bibtex 139 | @misc{bandara2024attention, 140 | title={Attention Prompt Tuning: Parameter-efficient Adaptation of Pre-trained Models for Spatiotemporal Modeling}, 141 | author={Wele Gedara Chaminda Bandara and Vishal M.
Patel}, 142 | year={2024}, 143 | eprint={2403.06978}, 144 | archivePrefix={arXiv}, 145 | primaryClass={cs.CV} 146 | } 147 | ``` 148 | 149 | 150 | ## ✏️ Disclaimer 151 | 152 | This repository is built on top of the VideoMAE codebase (https://github.com/MCG-NJU/VideoMAE), and we are grateful to the authors of VideoMAE for making their codebase publicly available. 153 | -------------------------------------------------------------------------------- /optim_factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim as optim 3 | 4 | from timm.optim.adafactor import Adafactor 5 | from timm.optim.adahessian import Adahessian 6 | from timm.optim.adamp import AdamP 7 | from timm.optim.lookahead import Lookahead 8 | from timm.optim.nadam import Nadam 9 | from timm.optim.novograd import NovoGrad 10 | from timm.optim.nvnovograd import NvNovoGrad 11 | from timm.optim.radam import RAdam 12 | from timm.optim.rmsprop_tf import RMSpropTF 13 | from timm.optim.sgdp import SGDP 14 | 15 | import json 16 | 17 | try: 18 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 19 | has_apex = True 20 | except ImportError: 21 | has_apex = False 22 | 23 | 24 | def get_num_layer_for_vit(var_name, num_max_layer): 25 | if var_name in ("cls_token", "mask_token", "pos_embed"): 26 | return 0 27 | elif var_name.startswith("patch_embed"): 28 | return 0 29 | elif var_name.startswith("rel_pos_bias"): 30 | return num_max_layer - 1 31 | elif var_name.startswith("blocks"): 32 | layer_id = int(var_name.split('.')[1]) 33 | return layer_id + 1 34 | else: 35 | return num_max_layer - 1 36 | 37 | 38 | class LayerDecayValueAssigner(object): 39 | def __init__(self, values): 40 | self.values = values 41 | 42 | def get_scale(self, layer_id): 43 | return self.values[layer_id] 44 | 45 | def get_layer_id(self, var_name): 46 | return get_num_layer_for_vit(var_name, len(self.values)) 47 | 48 | 49 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 50 | parameter_group_names = {} 51 | parameter_group_vars = {} 52 | 53 | for name, param in model.named_parameters(): 54 | if not param.requires_grad: 55 | continue # frozen weights 56 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 57 | group_name = "no_decay" 58 | this_weight_decay = 0. 59 | else: 60 | group_name = "decay" 61 | this_weight_decay = weight_decay 62 | if get_num_layer is not None: 63 | layer_id = get_num_layer(name) 64 | group_name = "layer_%d_%s" % (layer_id, group_name) 65 | else: 66 | layer_id = None 67 | 68 | if group_name not in parameter_group_names: 69 | if get_layer_scale is not None: 70 | scale = get_layer_scale(layer_id) 71 | else: 72 | scale = 1.
73 | 74 | parameter_group_names[group_name] = { 75 | "weight_decay": this_weight_decay, 76 | "params": [], 77 | "lr_scale": scale 78 | } 79 | parameter_group_vars[group_name] = { 80 | "weight_decay": this_weight_decay, 81 | "params": [], 82 | "lr_scale": scale 83 | } 84 | 85 | parameter_group_vars[group_name]["params"].append(param) 86 | parameter_group_names[group_name]["params"].append(name) 87 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 88 | return list(parameter_group_vars.values()) 89 | 90 | 91 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 92 | opt_lower = args.opt.lower() 93 | weight_decay = args.weight_decay 94 | if weight_decay and filter_bias_and_bn: 95 | skip = {} 96 | if skip_list is not None: 97 | skip = skip_list 98 | elif hasattr(model, 'no_weight_decay'): 99 | skip = model.no_weight_decay() 100 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 101 | weight_decay = 0. 102 | else: 103 | parameters = model.parameters() 104 | 105 | if 'fused' in opt_lower: 106 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 107 | 108 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 109 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 110 | opt_args['eps'] = args.opt_eps 111 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 112 | opt_args['betas'] = args.opt_betas 113 | 114 | print("optimizer settings:", opt_args) 115 | 116 | opt_split = opt_lower.split('_') 117 | opt_lower = opt_split[-1] 118 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 119 | opt_args.pop('eps', None) 120 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 121 | elif opt_lower == 'momentum': 122 | opt_args.pop('eps', None) 123 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 124 | elif opt_lower == 'adam': 125 | optimizer = optim.Adam(parameters, **opt_args) 126 | elif opt_lower == 'adamw': 127 | optimizer = optim.AdamW(parameters, **opt_args) 128 | elif opt_lower == 'nadam': 129 | optimizer = Nadam(parameters, **opt_args) 130 | elif opt_lower == 'radam': 131 | optimizer = RAdam(parameters, **opt_args) 132 | elif opt_lower == 'adamp': 133 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 134 | elif opt_lower == 'sgdp': 135 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 136 | elif opt_lower == 'adadelta': 137 | optimizer = optim.Adadelta(parameters, **opt_args) 138 | elif opt_lower == 'adafactor': 139 | if not args.lr: 140 | opt_args['lr'] = None 141 | optimizer = Adafactor(parameters, **opt_args) 142 | elif opt_lower == 'adahessian': 143 | optimizer = Adahessian(parameters, **opt_args) 144 | elif opt_lower == 'rmsprop': 145 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 146 | elif opt_lower == 'rmsproptf': 147 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 148 | elif opt_lower == 'novograd': 149 | optimizer = NovoGrad(parameters, **opt_args) 150 | elif opt_lower == 'nvnovograd': 151 | optimizer = NvNovoGrad(parameters, **opt_args) 152 | elif opt_lower == 'fusedsgd': 153 | opt_args.pop('eps', None) 154 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 155 | elif opt_lower == 'fusedmomentum': 156 | opt_args.pop('eps', None) 157 | optimizer = 
FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 158 | elif opt_lower == 'fusedadam': 159 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 160 | elif opt_lower == 'fusedadamw': 161 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 162 | elif opt_lower == 'fusedlamb': 163 | optimizer = FusedLAMB(parameters, **opt_args) 164 | elif opt_lower == 'fusednovograd': 165 | opt_args.setdefault('betas', (0.95, 0.98)) 166 | optimizer = FusedNovoGrad(parameters, **opt_args) 167 | else: 168 | assert False and "Invalid optimizer" 169 | raise ValueError 170 | 171 | if len(opt_split) > 1: 172 | if opt_split[0] == 'lookahead': 173 | optimizer = Lookahead(optimizer) 174 | 175 | return optimizer 176 | -------------------------------------------------------------------------------- /random_erasing.py: -------------------------------------------------------------------------------- 1 | """ 2 | This implementation is based on 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py 4 | pulished under an Apache License 2.0. 5 | """ 6 | import math 7 | import random 8 | import torch 9 | 10 | 11 | def _get_pixels( 12 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda" 13 | ): 14 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 15 | # paths, flip the order so normal is run on CPU if this becomes a problem 16 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 17 | if per_pixel: 18 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 19 | elif rand_color: 20 | return torch.empty( 21 | (patch_size[0], 1, 1), dtype=dtype, device=device 22 | ).normal_() 23 | else: 24 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 25 | 26 | 27 | class RandomErasing: 28 | """Randomly selects a rectangle region in an image and erases its pixels. 29 | 'Random Erasing Data Augmentation' by Zhong et al. 30 | See https://arxiv.org/pdf/1708.04896.pdf 31 | This variant of RandomErasing is intended to be applied to either a batch 32 | or single image tensor after it has been normalized by dataset mean and std. 33 | Args: 34 | probability: Probability that the Random Erasing operation will be performed. 35 | min_area: Minimum percentage of erased area wrt input image area. 36 | max_area: Maximum percentage of erased area wrt input image area. 37 | min_aspect: Minimum aspect ratio of erased area. 38 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 39 | 'const' - erase block is constant color of 0 for all channels 40 | 'rand' - erase block is same per-channel random (normal) color 41 | 'pixel' - erase block is per-pixel random (normal) color 42 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 43 | per-image count is randomly chosen between 1 and this value. 
44 | """ 45 | 46 | def __init__( 47 | self, 48 | probability=0.5, 49 | min_area=0.02, 50 | max_area=1 / 3, 51 | min_aspect=0.3, 52 | max_aspect=None, 53 | mode="const", 54 | min_count=1, 55 | max_count=None, 56 | num_splits=0, 57 | device="cuda", 58 | cube=True, 59 | ): 60 | self.probability = probability 61 | self.min_area = min_area 62 | self.max_area = max_area 63 | max_aspect = max_aspect or 1 / min_aspect 64 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 65 | self.min_count = min_count 66 | self.max_count = max_count or min_count 67 | self.num_splits = num_splits 68 | mode = mode.lower() 69 | self.rand_color = False 70 | self.per_pixel = False 71 | self.cube = cube 72 | if mode == "rand": 73 | self.rand_color = True # per block random normal 74 | elif mode == "pixel": 75 | self.per_pixel = True # per pixel random normal 76 | else: 77 | assert not mode or mode == "const" 78 | self.device = device 79 | 80 | def _erase(self, img, chan, img_h, img_w, dtype): 81 | if random.random() > self.probability: 82 | return 83 | area = img_h * img_w 84 | count = ( 85 | self.min_count 86 | if self.min_count == self.max_count 87 | else random.randint(self.min_count, self.max_count) 88 | ) 89 | for _ in range(count): 90 | for _ in range(10): 91 | target_area = ( 92 | random.uniform(self.min_area, self.max_area) * area / count 93 | ) 94 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 95 | h = int(round(math.sqrt(target_area * aspect_ratio))) 96 | w = int(round(math.sqrt(target_area / aspect_ratio))) 97 | if w < img_w and h < img_h: 98 | top = random.randint(0, img_h - h) 99 | left = random.randint(0, img_w - w) 100 | img[:, top : top + h, left : left + w] = _get_pixels( 101 | self.per_pixel, 102 | self.rand_color, 103 | (chan, h, w), 104 | dtype=dtype, 105 | device=self.device, 106 | ) 107 | break 108 | 109 | def _erase_cube( 110 | self, 111 | img, 112 | batch_start, 113 | batch_size, 114 | chan, 115 | img_h, 116 | img_w, 117 | dtype, 118 | ): 119 | if random.random() > self.probability: 120 | return 121 | area = img_h * img_w 122 | count = ( 123 | self.min_count 124 | if self.min_count == self.max_count 125 | else random.randint(self.min_count, self.max_count) 126 | ) 127 | for _ in range(count): 128 | for _ in range(100): 129 | target_area = ( 130 | random.uniform(self.min_area, self.max_area) * area / count 131 | ) 132 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 133 | h = int(round(math.sqrt(target_area * aspect_ratio))) 134 | w = int(round(math.sqrt(target_area / aspect_ratio))) 135 | if w < img_w and h < img_h: 136 | top = random.randint(0, img_h - h) 137 | left = random.randint(0, img_w - w) 138 | for i in range(batch_start, batch_size): 139 | img_instance = img[i] 140 | img_instance[ 141 | :, top : top + h, left : left + w 142 | ] = _get_pixels( 143 | self.per_pixel, 144 | self.rand_color, 145 | (chan, h, w), 146 | dtype=dtype, 147 | device=self.device, 148 | ) 149 | break 150 | 151 | def __call__(self, input): 152 | if len(input.size()) == 3: 153 | self._erase(input, *input.size(), input.dtype) 154 | else: 155 | batch_size, chan, img_h, img_w = input.size() 156 | # skip first slice of batch if num_splits is set (for clean portion of samples) 157 | batch_start = ( 158 | batch_size // self.num_splits if self.num_splits > 1 else 0 159 | ) 160 | if self.cube: 161 | self._erase_cube( 162 | input, 163 | batch_start, 164 | batch_size, 165 | chan, 166 | img_h, 167 | img_w, 168 | input.dtype, 169 | ) 170 | else: 171 | for i in 
range(batch_start, batch_size): 172 | self._erase(input[i], chan, img_h, img_w, input.dtype) 173 | return input 174 | -------------------------------------------------------------------------------- /MODEL_ZOO.md: -------------------------------------------------------------------------------- 1 | # Pre-trained VideoMAE Models 2 | 3 | For all experiments on APT, we use VideoMAE pre-trained ViT models on Kinetics-400. 4 | 5 | The following table provide different checkpoints. 6 | 7 | Note that we use pre-trained checkpoint. Not the fine-tuned one. 8 | 9 | ### Kinetics-400 10 | 11 | | Method | Extra Data | Backbone | Epoch | \#Frame | Pre-train | Fine-tune | Top-1 | Top-5 | 12 | | :------: | :--------: | :------: | :---: | :-----: | :----------------------------------------------------------: | :----------------------------------------------------------: | :---: | :---: | 13 | | VideoMAE | ***no*** | ViT-S | 1600 | 16x5x3 | [script](scripts/kinetics/videomae_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_1600/pretrain.sh)/[log](https://drive.google.com/file/d/1fbmQtp3UUw9fro3MVkKCW62Ib_HlZvNz/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1nU-H1u3eJ-VuyCveU7v-WIOcAVxs5Hww/view?usp=sharing) | [script](scripts/kinetics/videomae_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_1600/finetune.sh)/[log](https://drive.google.com/file/d/1RuEvCT2OMKPax2gGB1gBsH6ItiXIPH-R/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1ygjLRm1kvs9mwGsP3lLxUExhRo6TWnrx/view?usp=sharing) | 79.0 | 93.8 | 14 | | VideoMAE | ***no*** | ViT-B | 800 | 16x5x3 | [script](scripts/kinetics/videomae_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_800/pretrain.sh)/[log](https://drive.google.com/file/d/1kP3_-465jCL7PRNFq1JcAghPo2BONRWY/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1JfrhN144Hdg7we213H1WxwR3lGYOlmIn/view?usp=sharing) | [script](scripts/kinetics/videomae_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_800/finetune.sh)/[log](https://drive.google.com/file/d/1JOJzhlCujgpsjjth0J49k5EwBNxy76xt/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/18EEgdXY9347yK3Yb28O-GxFMbk41F6Ne/view?usp=sharing)
(w/o repeated aug) | 80.0 | 94.4 | 15 | | VideoMAE | ***no*** | ViT-B | 800 | 16x5x3 | same as above | TODO | 81.0 | 94.8 | 16 | | VideoMAE | ***no*** | ViT-B | 1600 | 16x5x3 | [script](scripts/kinetics/videomae_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_1600/pretrain.sh)/[log](https://drive.google.com/file/d/1ftVHzzCupEGV4bCHC5JWIUsEwOEeAQcg/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1tEhLyskjb755TJ65ptsrafUG2llSwQE1/view?usp=sharing) | [script](scripts/kinetics/videomae_vit_large_patch16_224_tubemasking_ratio_0.9_epoch_1600/finetune.sh)/[log](https://drive.google.com/file/d/1fYXtL2y2ZTMxDtTRqoUOe6leVmdVI5HH/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1MzwteHH-1yuMnFb8vRBQDvngV1Zl-d3z/view?usp=sharing) | 81.5 | 95.1 | 17 | | VideoMAE | ***no*** | ViT-L | 1600 | 16x5x3 | [script](scripts/kinetics/videomae_vit_large_patch16_224_tubemasking_ratio_0.9_epoch_1600/pretrain.sh)/[log](https://drive.google.com/file/d/1X7WBzn_yG4lDWuvBMBBgrtgqDLZVHrc2/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1qLOXWb_MGEvaI7tvuAe94CV7S2HXRwT3/view?usp=sharing) | [script](scripts/kinetics/videomae_vit_large_patch16_224_tubemasking_ratio_0.9_epoch_1600/finetune.sh)/[log](https://drive.google.com/file/d/1Doqx6zDQEMnMyPvDdz2knG385o0sZn3f/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1jX1CiqxSkCfc94y8FRW1YGHy-GNvHCuD/view?usp=sharing) | 85.2 | 96.8 | 18 | | VideoMAE | ***no*** | ViT-H | 1600 | 16x5x3 | [script](scripts/kinetics/videomae_vit_huge_patch16_224_tubemasking_ratio_0.9_epoch_1600/pretrain.sh)/[log](https://drive.google.com/file/d/1ZGOGk5_L7cqJ2UkrNQ7c_jcw1OUBqptl/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1AJQR1Rsi2N1pDn9tLyJ8DQrUREiBA1bO/view?usp=sharing) | [script](scripts/kinetics/videomae_vit_huge_patch16_224_tubemasking_ratio_0.9_epoch_1600/finetune.sh)/[log](https://drive.google.com/file/d/1NOUjO5wPrHZo4EUfklKvfGM3ScJVmGAK/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/104ouJZxSVPSAm0LwJXd6IzjdA_RGLqZi/view?usp=sharing) | 86.6 | 97.1 | 19 | 20 | ### Something-Something V2 21 | 22 | | Method | Extra Data | Backbone | Epoch | \#Frame | Pre-train | Fine-tune | Top-1 | Top-5 | 23 | | :------: | :--------: | :------: | :---: | :-----: | :----------------------------------------------------------: | :----------------------------------------------------------: | :---: | :---: | 24 | | VideoMAE | ***no*** | ViT-S | 2400 | 16x2x3 | [script](scripts/ssv2/videomae_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_2400/pretrain.sh)/[log](https://drive.google.com/file/d/129wqpAtwTCD-T1SQIX7q5nB9CEGchhw0/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1p_I1aaONOeUvRmRQw1UT3-L2H8XJClHu/view?usp=sharing) | [script](scripts/ssv2/videomae_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune.sh)/[log](https://drive.google.com/file/d/17X9PcDSBB1Zb1blNqQP3vvnqOuMzJrGp/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1ajlMrT06jiiM-5YjNI2X_UFyzsuYbbtZ/view?usp=sharing) | 66.8 | 90.3 | 25 | | VideoMAE | ***no*** | ViT-B | 800 | 16x2x3 | [script](scripts/ssv2/videomae_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_800/pretrain.sh)/[log](https://drive.google.com/file/d/1eGS18rKvbgEJ3nbsXxokkMSwNGxxoX48/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/181hLvyrrPW2IOGA46fkxdJk0tNLIgdB2/view?usp=sharing) | 
[script](scripts/ssv2/videomae_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_800/finetune.sh)/[log](https://drive.google.com/file/d/1jYAHPcs7zt_QMPM2D_geEWoWrf3yHox8/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1xZCiaPF4w7lYmLt5o1D5tIZyDdLtJAvH/view?usp=sharing)
(w/o repeated aug) | 69.6 | 92.0 | 26 | | VideoMAE | ***no*** | ViT-B | 2400 | 16x2x3 | [script](scripts/ssv2/videomae_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_2400/pretrain.sh)/[log](https://drive.google.com/file/d/148nURgfcIFBQd3IQH5YhJ9dTwNCc2jkU/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1I18dY_7rSalGL8fPWV82c0-foRUDzJJk/view?usp=sharing) | [script](scripts/ssv2/videomae_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune.sh)/[log](https://drive.google.com/file/d/15TPBiUl_K2Q_9l6J41G_vf-2lovVLEHM/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1dt_59tBIyzdZd5Ecr22lTtzs_64MOZkT/view?usp=sharing) | 70.8 | 92.4 | 27 | 28 | ### UCF101 29 | 30 | | Method | Extra Data | Backbone | Epoch | \#Frame | Pre-train | Fine-tune | Top-1 | Top-5 | 31 | | :------: | :--------: | :------: | :---: | :-----: | :----------------------------------------------------------: | :----------------------------------------------------------: | :---: | :---: | 32 | | VideoMAE | ***no*** | ViT-B | 3200 | 16x5x3 | [script](scripts/ucf101/videomae_vit_base_patch16_224_tubemasking_ratio_0.75_epoch_3200/pretrain.sh)/[log](https://drive.google.com/file/d/1kZODk_dQgB-aW6oIwPYZxqZAG6YKNtXC/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1BHev4meNgKM0o_8DMRbuzAsKSP3IpQ3o/view?usp=sharing) | [script](scripts/ucf101/videomae_vit_base_patch16_224_tubemasking_ratio_0.75_epoch_3200/finetune.sh)/[log](https://drive.google.com/file/d/17Mq7rlM1TRgV4KKX7UIlmKw653RmwSqe/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1MSyon6fPpKz7oqD6WDGPFK4k_Rbyb6fw/view?usp=sharing) | 91.3 | 98.5 | 33 | 34 | ### Note: 35 | 36 | - We report the results of VideoMAE finetuned with `I3D dense sampling` on **Kinetics400** and `TSN uniform sampling` on **Something-Something V2**, respectively. 37 | - \#Frame = #input_frame x #clip x #crop. 38 | - \#input_frame means how many frames are input for model during the test phase. 39 | - \#crop means spatial crops (e.g., 3 for left/right/center crop). 40 | - \#clip means temporal clips (e.g., 5 means repeted temporal sampling five clips with different start indices). 
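The `#Frame = #input_frame x #clip x #crop` convention above can be made concrete with a tiny, hypothetical helper (protocol values taken from the tables above):

```python
# Illustrates the "#Frame = #input_frame x #clip x #crop" convention used in the tables.
def test_views(input_frames: int, clips: int, crops: int):
    views = clips * crops           # clip/crop combinations scored per video at test time
    frames = input_frames * views   # total frames processed per video (the "#Frame" column)
    return views, frames

print(test_views(16, 5, 3))  # Kinetics-400 protocol 16x5x3 -> (15, 240)
print(test_views(16, 2, 3))  # Something-Something V2 protocol 16x2x3 -> (6, 96)
```

In the fine-tuning scripts, the clip and crop counts correspond to `--test_num_segment` and `--test_num_crop`, respectively.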
41 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | import warnings 4 | import random 5 | import numpy as np 6 | import torchvision 7 | from PIL import Image, ImageOps 8 | import numbers 9 | 10 | 11 | class GroupRandomCrop(object): 12 | def __init__(self, size): 13 | if isinstance(size, numbers.Number): 14 | self.size = (int(size), int(size)) 15 | else: 16 | self.size = size 17 | 18 | def __call__(self, img_tuple): 19 | img_group, label = img_tuple 20 | 21 | w, h = img_group[0].size 22 | th, tw = self.size 23 | 24 | out_images = list() 25 | 26 | x1 = random.randint(0, w - tw) 27 | y1 = random.randint(0, h - th) 28 | 29 | for img in img_group: 30 | assert(img.size[0] == w and img.size[1] == h) 31 | if w == tw and h == th: 32 | out_images.append(img) 33 | else: 34 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 35 | 36 | return (out_images, label) 37 | 38 | 39 | class GroupCenterCrop(object): 40 | def __init__(self, size): 41 | self.worker = torchvision.transforms.CenterCrop(size) 42 | 43 | def __call__(self, img_tuple): 44 | img_group, label = img_tuple 45 | return ([self.worker(img) for img in img_group], label) 46 | 47 | 48 | class GroupNormalize(object): 49 | def __init__(self, mean, std): 50 | self.mean = mean 51 | self.std = std 52 | 53 | def __call__(self, tensor_tuple): 54 | tensor, label = tensor_tuple 55 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 56 | rep_std = self.std * (tensor.size()[0]//len(self.std)) 57 | 58 | # TODO: make efficient 59 | for t, m, s in zip(tensor, rep_mean, rep_std): 60 | t.sub_(m).div_(s) 61 | 62 | return (tensor,label) 63 | 64 | 65 | class GroupGrayScale(object): 66 | def __init__(self, size): 67 | self.worker = torchvision.transforms.Grayscale(size) 68 | 69 | def __call__(self, img_tuple): 70 | img_group, label = img_tuple 71 | return ([self.worker(img) for img in img_group], label) 72 | 73 | 74 | class GroupScale(object): 75 | """ Rescales the input PIL.Image to the given 'size'. 76 | 'size' will be the size of the smaller edge. 
77 | For example, if height > width, then image will be 78 | rescaled to (size * height / width, size) 79 | size: size of the smaller edge 80 | interpolation: Default: PIL.Image.BILINEAR 81 | """ 82 | 83 | def __init__(self, size, interpolation=Image.BILINEAR): 84 | self.worker = torchvision.transforms.Resize(size, interpolation) 85 | 86 | def __call__(self, img_tuple): 87 | img_group, label = img_tuple 88 | return ([self.worker(img) for img in img_group], label) 89 | 90 | 91 | class GroupMultiScaleCrop(object): 92 | 93 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 94 | self.scales = scales if scales is not None else [1, .875, .75, .66] 95 | self.max_distort = max_distort 96 | self.fix_crop = fix_crop 97 | self.more_fix_crop = more_fix_crop 98 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 99 | self.interpolation = Image.BILINEAR 100 | 101 | def __call__(self, img_tuple): 102 | img_group, label = img_tuple 103 | 104 | im_size = img_group[0].size 105 | 106 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 107 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 108 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group] 109 | return (ret_img_group, label) 110 | 111 | def _sample_crop_size(self, im_size): 112 | image_w, image_h = im_size[0], im_size[1] 113 | 114 | # find a crop size 115 | base_size = min(image_w, image_h) 116 | crop_sizes = [int(base_size * x) for x in self.scales] 117 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 118 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 119 | 120 | pairs = [] 121 | for i, h in enumerate(crop_h): 122 | for j, w in enumerate(crop_w): 123 | if abs(i - j) <= self.max_distort: 124 | pairs.append((w, h)) 125 | 126 | crop_pair = random.choice(pairs) 127 | if not self.fix_crop: 128 | w_offset = random.randint(0, image_w - crop_pair[0]) 129 | h_offset = random.randint(0, image_h - crop_pair[1]) 130 | else: 131 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 132 | 133 | return crop_pair[0], crop_pair[1], w_offset, h_offset 134 | 135 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 136 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 137 | return random.choice(offsets) 138 | 139 | @staticmethod 140 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 141 | w_step = (image_w - crop_w) // 4 142 | h_step = (image_h - crop_h) // 4 143 | 144 | ret = list() 145 | ret.append((0, 0)) # upper left 146 | ret.append((4 * w_step, 0)) # upper right 147 | ret.append((0, 4 * h_step)) # lower left 148 | ret.append((4 * w_step, 4 * h_step)) # lower right 149 | ret.append((2 * w_step, 2 * h_step)) # center 150 | 151 | if more_fix_crop: 152 | ret.append((0, 2 * h_step)) # center left 153 | ret.append((4 * w_step, 2 * h_step)) # center right 154 | ret.append((2 * w_step, 4 * h_step)) # lower center 155 | ret.append((2 * w_step, 0 * h_step)) # upper center 156 | 157 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 158 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 159 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 160 | ret.append((3 * w_step, 3 * h_step)) # lower righ quarter 161 | return ret 
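# Illustrative usage sketch (not part of the original file): these group transforms
# operate on (list_of_PIL_frames, label) tuples and are composed for pre-training
# roughly as in DataAugmentationForVideoMAE (see datasets.py later in this repo):
#
#   transform = torchvision.transforms.Compose([
#       GroupMultiScaleCrop(224, [1, .875, .75, .66]),
#       Stack(roll=False),
#       ToTorchFormatTensor(div=True),
#       GroupNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
#   ])
#   frames_tensor, label = transform((pil_frames, label))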
162 | 163 | 164 | class Stack(object): 165 | 166 | def __init__(self, roll=False): 167 | self.roll = roll 168 | 169 | def __call__(self, img_tuple): 170 | img_group, label = img_tuple 171 | 172 | if img_group[0].mode == 'L': 173 | return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label) 174 | elif img_group[0].mode == 'RGB': 175 | if self.roll: 176 | return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label) 177 | else: 178 | return (np.concatenate(img_group, axis=2), label) 179 | 180 | 181 | class ToTorchFormatTensor(object): 182 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 183 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 184 | def __init__(self, div=True): 185 | self.div = div 186 | 187 | def __call__(self, pic_tuple): 188 | pic, label = pic_tuple 189 | 190 | if isinstance(pic, np.ndarray): 191 | # handle numpy array 192 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 193 | else: 194 | # handle PIL Image 195 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 196 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 197 | # put it from HWC to CHW format 198 | # yikes, this transpose takes 80% of the loading time/CPU 199 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 200 | return (img.float().div(255.) if self.div else img.float(), label) 201 | 202 | 203 | class IdentityTransform(object): 204 | 205 | def __call__(self, data): 206 | return data 207 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torchvision import transforms 3 | from transforms import * 4 | from masking_generator import TubeMaskingGenerator 5 | from kinetics import VideoClsDataset, VideoMAE 6 | from ssv2 import SSVideoClsDataset 7 | 8 | 9 | class DataAugmentationForVideoMAE(object): 10 | def __init__(self, args): 11 | self.input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN 12 | self.input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD 13 | normalize = GroupNormalize(self.input_mean, self.input_std) 14 | self.train_augmentation = GroupMultiScaleCrop(args.input_size, [1, .875, .75, .66]) 15 | self.transform = transforms.Compose([ 16 | self.train_augmentation, 17 | Stack(roll=False), 18 | ToTorchFormatTensor(div=True), 19 | normalize, 20 | ]) 21 | if args.mask_type == 'tube': 22 | self.masked_position_generator = TubeMaskingGenerator( 23 | args.window_size, args.mask_ratio 24 | ) 25 | 26 | def __call__(self, images): 27 | process_data, _ = self.transform(images) 28 | return process_data, self.masked_position_generator() 29 | 30 | def __repr__(self): 31 | repr = "(DataAugmentationForVideoMAE,\n" 32 | repr += " transform = %s,\n" % str(self.transform) 33 | repr += " Masked position generator = %s,\n" % str(self.masked_position_generator) 34 | repr += ")" 35 | return repr 36 | 37 | 38 | def build_pretraining_dataset(args): 39 | transform = DataAugmentationForVideoMAE(args) 40 | dataset = VideoMAE( 41 | root=None, 42 | setting=args.data_path, 43 | video_ext='mp4', 44 | is_color=True, 45 | modality='rgb', 46 | new_length=args.num_frames, 47 | new_step=args.sampling_rate, 48 | transform=transform, 49 | temporal_jitter=False, 50 | video_loader=True, 51 | use_decord=True, 52 | lazy_init=False) 53 | print("Data Aug = %s" % str(transform)) 54 | return dataset 55 | 56 | 57 | def build_dataset(is_train, test_mode, 
args): 58 | if args.data_set == 'Kinetics-400': 59 | mode = None 60 | anno_path = None 61 | if is_train is True: 62 | mode = 'train' 63 | anno_path = os.path.join(args.data_path, 'train.csv') 64 | elif test_mode is True: 65 | mode = 'test' 66 | anno_path = os.path.join(args.data_path, 'test.csv') 67 | else: 68 | mode = 'validation' 69 | anno_path = os.path.join(args.data_path, 'val.csv') 70 | 71 | dataset = VideoClsDataset( 72 | anno_path=anno_path, 73 | data_path='/', 74 | mode=mode, 75 | clip_len=args.num_frames, 76 | frame_sample_rate=args.sampling_rate, 77 | num_segment=1, 78 | test_num_segment=args.test_num_segment, 79 | test_num_crop=args.test_num_crop, 80 | num_crop=1 if not test_mode else 3, 81 | keep_aspect_ratio=True, 82 | crop_size=args.input_size, 83 | short_side_size=args.short_side_size, 84 | new_height=256, 85 | new_width=320, 86 | args=args) 87 | nb_classes = 400 88 | 89 | elif args.data_set == 'SSV2': 90 | mode = None 91 | anno_path = None 92 | if is_train is True: 93 | mode = 'train' 94 | anno_path = os.path.join(args.data_path, 'train.csv') 95 | elif test_mode is True: 96 | mode = 'test' 97 | anno_path = os.path.join(args.data_path, 'test.csv') 98 | else: 99 | mode = 'validation' 100 | anno_path = os.path.join(args.data_path, 'val.csv') 101 | 102 | dataset = SSVideoClsDataset( 103 | anno_path=anno_path, 104 | data_path='/', 105 | mode=mode, 106 | clip_len=1, 107 | num_segment=args.num_frames, 108 | test_num_segment=args.test_num_segment, 109 | test_num_crop=args.test_num_crop, 110 | num_crop=1 if not test_mode else 3, 111 | keep_aspect_ratio=True, 112 | crop_size=args.input_size, 113 | short_side_size=args.short_side_size, 114 | new_height=256, 115 | new_width=320, 116 | args=args) 117 | nb_classes = 174 118 | 119 | elif args.data_set == 'UCF101': 120 | mode = None 121 | anno_path = None 122 | if is_train is True: 123 | mode = 'train' 124 | anno_path = os.path.join(args.data_path, 'train.csv') 125 | elif test_mode is True: 126 | mode = 'test' 127 | anno_path = os.path.join(args.data_path, 'test.csv') 128 | else: 129 | mode = 'validation' 130 | anno_path = os.path.join(args.data_path, 'val.csv') 131 | 132 | dataset = VideoClsDataset( 133 | anno_path=anno_path, 134 | data_path='/', 135 | mode=mode, 136 | clip_len=args.num_frames, 137 | frame_sample_rate=args.sampling_rate, 138 | num_segment=1, 139 | test_num_segment=args.test_num_segment, 140 | test_num_crop=args.test_num_crop, 141 | num_crop=1 if not test_mode else 3, 142 | keep_aspect_ratio=True, 143 | crop_size=args.input_size, 144 | short_side_size=args.short_side_size, 145 | new_height=256, 146 | new_width=320, 147 | args=args) 148 | nb_classes = 101 149 | 150 | elif args.data_set == 'HMDB51': 151 | mode = None 152 | anno_path = None 153 | if is_train is True: 154 | mode = 'train' 155 | anno_path = os.path.join(args.data_path, 'train.csv') 156 | elif test_mode is True: 157 | mode = 'test' 158 | anno_path = os.path.join(args.data_path, 'test.csv') 159 | else: 160 | mode = 'validation' 161 | anno_path = os.path.join(args.data_path, 'val.csv') 162 | 163 | dataset = VideoClsDataset( 164 | anno_path=anno_path, 165 | data_path='/', 166 | mode=mode, 167 | clip_len=args.num_frames, 168 | frame_sample_rate=args.sampling_rate, 169 | num_segment=1, 170 | test_num_segment=args.test_num_segment, 171 | test_num_crop=args.test_num_crop, 172 | num_crop=1 if not test_mode else 3, 173 | keep_aspect_ratio=True, 174 | crop_size=args.input_size, 175 | short_side_size=args.short_side_size, 176 | new_height=256, 177 | new_width=320, 178 | 
args=args) 179 | nb_classes = 51 180 | 181 | elif args.data_set == 'ROCOG': 182 | mode = None 183 | anno_path = None 184 | if is_train is True: 185 | mode = 'train' 186 | anno_path = os.path.join(args.data_path, 'train.csv') 187 | elif test_mode is True: 188 | mode = 'test' 189 | anno_path = os.path.join(args.data_path, 'test.csv') 190 | else: 191 | mode = 'validation' 192 | anno_path = os.path.join(args.data_path, 'val.csv') 193 | 194 | dataset = VideoClsDataset( 195 | anno_path=anno_path, 196 | data_path='/', 197 | mode=mode, 198 | clip_len=args.num_frames, 199 | frame_sample_rate=args.sampling_rate, 200 | num_segment=1, 201 | test_num_segment=args.test_num_segment, 202 | test_num_crop=args.test_num_crop, 203 | num_crop=1 if not test_mode else 3, 204 | keep_aspect_ratio=True, 205 | crop_size=args.input_size, 206 | short_side_size=args.short_side_size, 207 | new_height=256, 208 | new_width=320, 209 | args=args) 210 | nb_classes = 7 211 | else: 212 | raise NotImplementedError() 213 | assert nb_classes == args.nb_classes 214 | print("Number of the class = %d" % args.nb_classes) 215 | 216 | return dataset, nb_classes 217 | -------------------------------------------------------------------------------- /run_videomae_vis.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | import numpy as np 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | from PIL import Image 7 | from pathlib import Path 8 | from timm.models import create_model 9 | import utils 10 | import modeling_pretrain 11 | from datasets import DataAugmentationForVideoMAE 12 | from torchvision.transforms import ToPILImage 13 | from einops import rearrange 14 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 15 | from decord import VideoReader, cpu 16 | from torchvision import transforms 17 | from transforms import * 18 | from masking_generator import TubeMaskingGenerator 19 | 20 | class DataAugmentationForVideoMAE(object): 21 | def __init__(self, args): 22 | self.input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN 23 | self.input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD 24 | normalize = GroupNormalize(self.input_mean, self.input_std) 25 | self.train_augmentation = GroupCenterCrop(args.input_size) 26 | self.transform = transforms.Compose([ 27 | self.train_augmentation, 28 | Stack(roll=False), 29 | ToTorchFormatTensor(div=True), 30 | normalize, 31 | ]) 32 | if args.mask_type == 'tube': 33 | self.masked_position_generator = TubeMaskingGenerator( 34 | args.window_size, args.mask_ratio 35 | ) 36 | 37 | def __call__(self, images): 38 | process_data , _ = self.transform(images) 39 | return process_data, self.masked_position_generator() 40 | 41 | def __repr__(self): 42 | repr = "(DataAugmentationForVideoMAE,\n" 43 | repr += " transform = %s,\n" % str(self.transform) 44 | repr += " Masked position generator = %s,\n" % str(self.masked_position_generator) 45 | repr += ")" 46 | return repr 47 | 48 | def get_args(): 49 | parser = argparse.ArgumentParser('VideoMAE visualization reconstruction script', add_help=False) 50 | parser.add_argument('img_path', type=str, help='input video path') 51 | parser.add_argument('save_path', type=str, help='save video path') 52 | parser.add_argument('model_path', type=str, help='checkpoint path of model') 53 | parser.add_argument('--mask_type', default='random', choices=['random', 'tube'], 54 | type=str, help='masked strategy of video tokens/patches') 55 | 
parser.add_argument('--num_frames', type=int, default= 16) 56 | parser.add_argument('--sampling_rate', type=int, default= 4) 57 | parser.add_argument('--decoder_depth', default=4, type=int, 58 | help='depth of decoder') 59 | parser.add_argument('--input_size', default=224, type=int, 60 | help='videos input size for backbone') 61 | parser.add_argument('--device', default='cuda:0', 62 | help='device to use for training / testing') 63 | parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true') 64 | parser.add_argument('--mask_ratio', default=0.75, type=float, 65 | help='ratio of the visual tokens/patches need be masked') 66 | # Model parameters 67 | parser.add_argument('--model', default='pretrain_videomae_base_patch16_224', type=str, metavar='MODEL', 68 | help='Name of model to vis') 69 | parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT', 70 | help='Drop path rate (default: 0.1)') 71 | 72 | return parser.parse_args() 73 | 74 | 75 | def get_model(args): 76 | print(f"Creating model: {args.model}") 77 | model = create_model( 78 | args.model, 79 | pretrained=False, 80 | drop_path_rate=args.drop_path, 81 | drop_block_rate=None, 82 | decoder_depth=args.decoder_depth 83 | ) 84 | 85 | return model 86 | 87 | 88 | def main(args): 89 | print(args) 90 | 91 | device = torch.device(args.device) 92 | cudnn.benchmark = True 93 | 94 | model = get_model(args) 95 | patch_size = model.encoder.patch_embed.patch_size 96 | print("Patch size = %s" % str(patch_size)) 97 | args.window_size = (args.num_frames // 2, args.input_size // patch_size[0], args.input_size // patch_size[1]) 98 | args.patch_size = patch_size 99 | 100 | model.to(device) 101 | checkpoint = torch.load(args.model_path, map_location='cpu') 102 | model.load_state_dict(checkpoint['model']) 103 | model.eval() 104 | 105 | if args.save_path: 106 | Path(args.save_path).mkdir(parents=True, exist_ok=True) 107 | 108 | with open(args.img_path, 'rb') as f: 109 | vr = VideoReader(f, ctx=cpu(0)) 110 | duration = len(vr) 111 | new_length = 1 112 | new_step = 1 113 | skip_length = new_length * new_step 114 | # frame_id_list = [1, 5, 9, 13, 17, 21, 25, 29, 33, 37, 41, 45, 49, 53, 57, 61] 115 | 116 | 117 | tmp = np.arange(0,32, 2) + 60 118 | frame_id_list = tmp.tolist() 119 | # average_duration = (duration - skip_length + 1) // args.num_frames 120 | # if average_duration > 0: 121 | # frame_id_list = np.multiply(list(range(args.num_frames)), 122 | # average_duration) 123 | # frame_id_list = frame_id_list + np.random.randint(average_duration, 124 | # size=args.num_frames) 125 | 126 | video_data = vr.get_batch(frame_id_list).asnumpy() 127 | print(video_data.shape) 128 | img = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)] 129 | 130 | transforms = DataAugmentationForVideoMAE(args) 131 | img, bool_masked_pos = transforms((img, None)) # T*C,H,W 132 | # print(img.shape) 133 | img = img.view((args.num_frames , 3) + img.size()[-2:]).transpose(0,1) # T*C,H,W -> T,C,H,W -> C,T,H,W 134 | # img = img.view(( -1 , args.num_frames) + img.size()[-2:]) 135 | bool_masked_pos = torch.from_numpy(bool_masked_pos) 136 | 137 | with torch.no_grad(): 138 | # img = img[None, :] 139 | # bool_masked_pos = bool_masked_pos[None, :] 140 | img = img.unsqueeze(0) 141 | print(img.shape) 142 | bool_masked_pos = bool_masked_pos.unsqueeze(0) 143 | 144 | img = img.to(device, non_blocking=True) 145 | bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool) 146 | 
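# forward pass: the pre-trained model predicts the (per-patch normalized) pixels of the masked patches only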
outputs = model(img, bool_masked_pos) 147 | 148 | #save original video 149 | mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None] 150 | std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None] 151 | ori_img = img * std + mean # in [0, 1] 152 | imgs = [ToPILImage()(ori_img[0,:,vid,:,:].cpu()) for vid, _ in enumerate(frame_id_list) ] 153 | for id, im in enumerate(imgs): 154 | im.save(f"{args.save_path}/ori_img{id}.jpg") 155 | 156 | img_squeeze = rearrange(ori_img, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c', p0=2, p1=patch_size[0], p2=patch_size[0]) 157 | img_norm = (img_squeeze - img_squeeze.mean(dim=-2, keepdim=True)) / (img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6) 158 | img_patch = rearrange(img_norm, 'b n p c -> b n (p c)') 159 | img_patch[bool_masked_pos] = outputs 160 | 161 | #make mask 162 | mask = torch.ones_like(img_patch) 163 | mask[bool_masked_pos] = 0 164 | mask = rearrange(mask, 'b n (p c) -> b n p c', c=3) 165 | mask = rearrange(mask, 'b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2) ', p0=2, p1=patch_size[0], p2=patch_size[1], h=14, w=14) 166 | 167 | #save reconstruction video 168 | rec_img = rearrange(img_patch, 'b n (p c) -> b n p c', c=3) 169 | # Notice: To visualize the reconstruction video, we add the predict and the original mean and var of each patch. 170 | rec_img = rec_img * (img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6) + img_squeeze.mean(dim=-2, keepdim=True) 171 | rec_img = rearrange(rec_img, 'b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)', p0=2, p1=patch_size[0], p2=patch_size[1], h=14, w=14) 172 | imgs = [ ToPILImage()(rec_img[0, :, vid, :, :].cpu().clamp(0,0.996)) for vid, _ in enumerate(frame_id_list) ] 173 | 174 | for id, im in enumerate(imgs): 175 | im.save(f"{args.save_path}/rec_img{id}.jpg") 176 | 177 | #save masked video 178 | img_mask = rec_img * mask 179 | imgs = [ToPILImage()(img_mask[0, :, vid, :, :].cpu()) for vid, _ in enumerate(frame_id_list)] 180 | for id, im in enumerate(imgs): 181 | im.save(f"{args.save_path}/mask_img{id}.jpg") 182 | 183 | if __name__ == '__main__': 184 | opts = get_args() 185 | main(opts) 186 | -------------------------------------------------------------------------------- /engine_for_finetuning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import math 4 | import sys 5 | from typing import Iterable, Optional 6 | import torch 7 | from mixup import Mixup 8 | from timm.utils import accuracy, ModelEma 9 | import utils 10 | from scipy.special import softmax 11 | 12 | def train_class_batch(model, samples, target, criterion): 13 | outputs = model(samples) 14 | loss = criterion(outputs, target) 15 | return loss, outputs 16 | 17 | 18 | def get_loss_scale_for_deepspeed(model): 19 | optimizer = model.optimizer 20 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale 21 | 22 | 23 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 24 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 25 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 26 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, 27 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None, 28 | num_training_steps_per_epoch=None, update_freq=None): 29 | model.train(True) 30 | metric_logger = utils.MetricLogger(delimiter=" ") 31 | 
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 32 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 33 | header = 'Epoch: [{}]'.format(epoch) 34 | print_freq = 10 35 | 36 | if loss_scaler is None: 37 | model.zero_grad() 38 | model.micro_steps = 0 39 | else: 40 | optimizer.zero_grad() 41 | 42 | for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 43 | step = data_iter_step // update_freq 44 | if step >= num_training_steps_per_epoch: 45 | continue 46 | it = start_steps + step # global training iteration 47 | # Update LR & WD for the first acc 48 | if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0: 49 | for i, param_group in enumerate(optimizer.param_groups): 50 | if lr_schedule_values is not None: 51 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 52 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 53 | param_group["weight_decay"] = wd_schedule_values[it] 54 | 55 | samples = samples.to(device, non_blocking=True) 56 | targets = targets.to(device, non_blocking=True) 57 | 58 | if mixup_fn is not None: 59 | samples, targets = mixup_fn(samples, targets) 60 | 61 | if loss_scaler is None: 62 | samples = samples.half() 63 | loss, output = train_class_batch( 64 | model, samples, targets, criterion) 65 | else: 66 | with torch.cuda.amp.autocast(): 67 | loss, output = train_class_batch( 68 | model, samples, targets, criterion) 69 | 70 | loss_value = loss.item() 71 | 72 | if not math.isfinite(loss_value): 73 | print("Loss is {}, stopping training".format(loss_value)) 74 | sys.exit(1) 75 | 76 | if loss_scaler is None: 77 | loss /= update_freq 78 | model.backward(loss) 79 | model.step() 80 | 81 | if (data_iter_step + 1) % update_freq == 0: 82 | # model.zero_grad() 83 | # Deepspeed will call step() & model.zero_grad() automatic 84 | if model_ema is not None: 85 | model_ema.update(model) 86 | grad_norm = None 87 | loss_scale_value = get_loss_scale_for_deepspeed(model) 88 | else: 89 | # this attribute is added by timm on one optimizer (adahessian) 90 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 91 | loss /= update_freq 92 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 93 | parameters=model.parameters(), create_graph=is_second_order, 94 | update_grad=(data_iter_step + 1) % update_freq == 0) 95 | if (data_iter_step + 1) % update_freq == 0: 96 | optimizer.zero_grad() 97 | if model_ema is not None: 98 | model_ema.update(model) 99 | loss_scale_value = loss_scaler.state_dict()["scale"] 100 | 101 | torch.cuda.synchronize() 102 | 103 | if mixup_fn is None: 104 | class_acc = (output.max(-1)[-1] == targets).float().mean() 105 | else: 106 | class_acc = None 107 | metric_logger.update(loss=loss_value) 108 | metric_logger.update(class_acc=class_acc) 109 | metric_logger.update(loss_scale=loss_scale_value) 110 | min_lr = 10. 111 | max_lr = 0. 
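# scan all param groups to log the smallest and largest current learning rate (layer-wise lr_scale makes them differ)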
112 | for group in optimizer.param_groups: 113 | min_lr = min(min_lr, group["lr"]) 114 | max_lr = max(max_lr, group["lr"]) 115 | 116 | metric_logger.update(lr=max_lr) 117 | metric_logger.update(min_lr=min_lr) 118 | weight_decay_value = None 119 | for group in optimizer.param_groups: 120 | if group["weight_decay"] > 0: 121 | weight_decay_value = group["weight_decay"] 122 | metric_logger.update(weight_decay=weight_decay_value) 123 | metric_logger.update(grad_norm=grad_norm) 124 | 125 | if log_writer is not None: 126 | log_writer.update(loss=loss_value, head="loss") 127 | log_writer.update(class_acc=class_acc, head="loss") 128 | log_writer.update(loss_scale=loss_scale_value, head="opt") 129 | log_writer.update(lr=max_lr, head="opt") 130 | log_writer.update(min_lr=min_lr, head="opt") 131 | log_writer.update(weight_decay=weight_decay_value, head="opt") 132 | log_writer.update(grad_norm=grad_norm, head="opt") 133 | 134 | log_writer.set_step() 135 | 136 | # gather the stats from all processes 137 | metric_logger.synchronize_between_processes() 138 | print("Averaged stats:", metric_logger) 139 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 140 | 141 | 142 | @torch.no_grad() 143 | def validation_one_epoch(data_loader, model, device): 144 | criterion = torch.nn.CrossEntropyLoss() 145 | 146 | metric_logger = utils.MetricLogger(delimiter=" ") 147 | header = 'Val:' 148 | 149 | # switch to evaluation mode 150 | model.eval() 151 | 152 | for batch in metric_logger.log_every(data_loader, 10, header): 153 | videos = batch[0] 154 | target = batch[1] 155 | videos = videos.to(device, non_blocking=True) 156 | target = target.to(device, non_blocking=True) 157 | 158 | # compute output 159 | with torch.cuda.amp.autocast(): 160 | output = model(videos) 161 | loss = criterion(output, target) 162 | 163 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 164 | 165 | batch_size = videos.shape[0] 166 | metric_logger.update(loss=loss.item()) 167 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 168 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 169 | # gather the stats from all processes 170 | metric_logger.synchronize_between_processes() 171 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 172 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 173 | 174 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 175 | 176 | 177 | 178 | @torch.no_grad() 179 | def final_test(data_loader, model, device, file): 180 | criterion = torch.nn.CrossEntropyLoss() 181 | 182 | metric_logger = utils.MetricLogger(delimiter=" ") 183 | header = 'Test:' 184 | 185 | # switch to evaluation mode 186 | model.eval() 187 | final_result = [] 188 | 189 | for batch in metric_logger.log_every(data_loader, 10, header): 190 | videos = batch[0] 191 | target = batch[1] 192 | ids = batch[2] 193 | chunk_nb = batch[3] 194 | split_nb = batch[4] 195 | videos = videos.to(device, non_blocking=True) 196 | target = target.to(device, non_blocking=True) 197 | 198 | # compute output 199 | with torch.cuda.amp.autocast(): 200 | output = model(videos) 201 | loss = criterion(output, target) 202 | 203 | for i in range(output.size(0)): 204 | string = "{} {} {} {} {}\n".format(ids[i], \ 205 | str(output.data[i].cpu().numpy().tolist()), \ 206 | str(int(target[i].cpu().numpy())), \ 207 | str(int(chunk_nb[i].cpu().numpy())), \ 208 | str(int(split_nb[i].cpu().numpy()))) 209 | final_result.append(string) 210 | 211 
| acc1, acc5 = accuracy(output, target, topk=(1, 5)) 212 | 213 | batch_size = videos.shape[0] 214 | metric_logger.update(loss=loss.item()) 215 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 216 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 217 | 218 | if not os.path.exists(file): 219 | os.mknod(file) 220 | with open(file, 'w') as f: 221 | f.write("{}, {}\n".format(acc1, acc5)) 222 | for line in final_result: 223 | f.write(line) 224 | # gather the stats from all processes 225 | metric_logger.synchronize_between_processes() 226 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 227 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 228 | 229 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 230 | 231 | 232 | def merge(eval_path, num_tasks): 233 | dict_feats = {} 234 | dict_label = {} 235 | dict_pos = {} 236 | print("Reading individual output files") 237 | 238 | for x in range(num_tasks): 239 | file = os.path.join(eval_path, str(x) + '.txt') 240 | lines = open(file, 'r').readlines()[1:] 241 | for line in lines: 242 | line = line.strip() 243 | name = line.split('[')[0] 244 | label = line.split(']')[1].split(' ')[1] 245 | chunk_nb = line.split(']')[1].split(' ')[2] 246 | split_nb = line.split(']')[1].split(' ')[3] 247 | data = np.fromstring(line.split('[')[1].split(']')[0], dtype=float, sep=',') 248 | data = softmax(data) 249 | if not name in dict_feats: 250 | dict_feats[name] = [] 251 | dict_label[name] = 0 252 | dict_pos[name] = [] 253 | if chunk_nb + split_nb in dict_pos[name]: 254 | continue 255 | dict_feats[name].append(data) 256 | dict_pos[name].append(chunk_nb + split_nb) 257 | dict_label[name] = label 258 | print("Computing final results") 259 | 260 | input_lst = [] 261 | print(len(dict_feats)) 262 | for i, item in enumerate(dict_feats): 263 | input_lst.append([i, item, dict_feats[item], dict_label[item]]) 264 | from multiprocessing import Pool 265 | p = Pool(64) 266 | ans = p.map(compute_video, input_lst) 267 | top1 = [x[1] for x in ans] 268 | top5 = [x[2] for x in ans] 269 | pred = [x[0] for x in ans] 270 | label = [x[3] for x in ans] 271 | final_top1 ,final_top5 = np.mean(top1), np.mean(top5) 272 | return final_top1*100 ,final_top5*100 273 | 274 | def compute_video(lst): 275 | i, video_id, data, label = lst 276 | feat = [x for x in data] 277 | feat = np.mean(feat, axis=0) 278 | pred = np.argmax(feat) 279 | top1 = (int(pred) == int(label)) * 1.0 280 | top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0 281 | return [pred, top1, top5, int(label)] 282 | -------------------------------------------------------------------------------- /run_mae_pretraining.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import numpy as np 4 | import time 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import json 8 | import os 9 | from pathlib import Path 10 | from timm.models import create_model 11 | from optim_factory import create_optimizer 12 | from datasets import build_pretraining_dataset 13 | from engine_for_pretraining import train_one_epoch 14 | from utils import NativeScalerWithGradNormCount as NativeScaler 15 | import utils 16 | import modeling_pretrain 17 | 18 | 19 | def get_args(): 20 | parser = argparse.ArgumentParser('VideoMAE pre-training script', add_help=False) 21 | parser.add_argument('--batch_size', default=64, type=int) 22 | 
parser.add_argument('--epochs', default=800, type=int) 23 | parser.add_argument('--save_ckpt_freq', default=50, type=int) 24 | 25 | # Model parameters 26 | parser.add_argument('--model', default='pretrain_videomae_base_patch16_224', type=str, metavar='MODEL', 27 | help='Name of model to train') 28 | 29 | parser.add_argument('--decoder_depth', default=4, type=int, 30 | help='depth of decoder') 31 | 32 | parser.add_argument('--mask_type', default='tube', choices=['random', 'tube'], 33 | type=str, help='masked strategy of video tokens/patches') 34 | 35 | parser.add_argument('--mask_ratio', default=0.75, type=float, 36 | help='ratio of the visual tokens/patches need be masked') 37 | 38 | parser.add_argument('--input_size', default=224, type=int, 39 | help='videos input size for backbone') 40 | 41 | parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT', 42 | help='Drop path rate (default: 0.1)') 43 | 44 | parser.add_argument('--normlize_target', default=True, type=bool, 45 | help='normalized the target patch pixels') 46 | 47 | # Optimizer parameters 48 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 49 | help='Optimizer (default: "adamw"') 50 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', 51 | help='Optimizer Epsilon (default: 1e-8)') 52 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', 53 | help='Optimizer Betas (default: None, use opt default)') 54 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 55 | help='Clip gradient norm (default: None, no clipping)') 56 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 57 | help='SGD momentum (default: 0.9)') 58 | parser.add_argument('--weight_decay', type=float, default=0.05, 59 | help='weight decay (default: 0.05)') 60 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the 61 | weight decay. We use a cosine schedule for WD. 
62 | (Set the same value with args.weight_decay to keep weight decay no change)""") 63 | 64 | parser.add_argument('--lr', type=float, default=1.5e-4, metavar='LR', 65 | help='learning rate (default: 1.5e-4)') 66 | parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR', 67 | help='warmup learning rate (default: 1e-6)') 68 | parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR', 69 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 70 | 71 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 72 | help='epochs to warmup LR, if scheduler supports') 73 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', 74 | help='epochs to warmup LR, if scheduler supports') 75 | parser.add_argument('--use_checkpoint', action='store_true') 76 | parser.set_defaults(use_checkpoint=False) 77 | 78 | # Augmentation parameters 79 | parser.add_argument('--color_jitter', type=float, default=0.0, metavar='PCT', 80 | help='Color jitter factor (default: 0.4)') 81 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 82 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 83 | 84 | # Dataset parameters 85 | parser.add_argument('--data_path', default='/path/to/list_kinetics-400', type=str, 86 | help='dataset path') 87 | parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true') 88 | parser.add_argument('--num_frames', type=int, default= 16) 89 | parser.add_argument('--sampling_rate', type=int, default= 4) 90 | parser.add_argument('--output_dir', default='', 91 | help='path where to save, empty for no saving') 92 | parser.add_argument('--log_dir', default=None, 93 | help='path where to tensorboard log') 94 | parser.add_argument('--device', default='cuda', 95 | help='device to use for training / testing') 96 | parser.add_argument('--seed', default=0, type=int) 97 | parser.add_argument('--resume', default='', help='resume from checkpoint') 98 | parser.add_argument('--auto_resume', action='store_true') 99 | parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume') 100 | parser.set_defaults(auto_resume=True) 101 | 102 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 103 | help='start epoch') 104 | parser.add_argument('--num_workers', default=10, type=int) 105 | parser.add_argument('--pin_mem', action='store_true', 106 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 107 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem', 108 | help='') 109 | parser.set_defaults(pin_mem=True) 110 | 111 | # distributed training parameters 112 | parser.add_argument('--world_size', default=1, type=int, 113 | help='number of distributed processes') 114 | parser.add_argument('--local_rank', default=-1, type=int) 115 | parser.add_argument('--dist_on_itp', action='store_true') 116 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 117 | 118 | return parser.parse_args() 119 | 120 | 121 | def get_model(args): 122 | print(f"Creating model: {args.model}") 123 | model = create_model( 124 | args.model, 125 | pretrained=False, 126 | drop_path_rate=args.drop_path, 127 | drop_block_rate=None, 128 | decoder_depth=args.decoder_depth, 129 | use_checkpoint=args.use_checkpoint 130 | ) 131 | return model 132 | 133 | 134 | def main(args): 135 | utils.init_distributed_mode(args) 136 | 137 | print(args) 138 | 139 | device = 
torch.device(args.device) 140 | 141 | # fix the seed for reproducibility 142 | seed = args.seed + utils.get_rank() 143 | torch.manual_seed(seed) 144 | np.random.seed(seed) 145 | 146 | cudnn.benchmark = True 147 | 148 | model = get_model(args) 149 | patch_size = model.encoder.patch_embed.patch_size 150 | print("Patch size = %s" % str(patch_size)) 151 | args.window_size = (args.num_frames // 2, args.input_size // patch_size[0], args.input_size // patch_size[1]) 152 | args.patch_size = patch_size 153 | 154 | # get dataset 155 | dataset_train = build_pretraining_dataset(args) 156 | 157 | 158 | num_tasks = utils.get_world_size() 159 | global_rank = utils.get_rank() 160 | sampler_rank = global_rank 161 | num_training_steps_per_epoch = len(dataset_train) // args.batch_size // num_tasks 162 | 163 | sampler_train = torch.utils.data.DistributedSampler( 164 | dataset_train, num_replicas=num_tasks, rank=sampler_rank, shuffle=True 165 | ) 166 | print("Sampler_train = %s" % str(sampler_train)) 167 | 168 | 169 | if global_rank == 0 and args.log_dir is not None: 170 | os.makedirs(args.log_dir, exist_ok=True) 171 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir) 172 | else: 173 | log_writer = None 174 | 175 | data_loader_train = torch.utils.data.DataLoader( 176 | dataset_train, sampler=sampler_train, 177 | batch_size=args.batch_size, 178 | num_workers=args.num_workers, 179 | pin_memory=args.pin_mem, 180 | drop_last=True, 181 | worker_init_fn=utils.seed_worker 182 | ) 183 | 184 | model.to(device) 185 | model_without_ddp = model 186 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 187 | 188 | print("Model = %s" % str(model_without_ddp)) 189 | print('number of params: {} M'.format(n_parameters / 1e6)) 190 | 191 | total_batch_size = args.batch_size * utils.get_world_size() 192 | 193 | args.lr = args.lr * total_batch_size / 256 194 | args.min_lr = args.min_lr * total_batch_size / 256 195 | args.warmup_lr = args.warmup_lr * total_batch_size / 256 196 | print("LR = %.8f" % args.lr) 197 | print("Batch size = %d" % total_batch_size) 198 | print("Number of training steps = %d" % num_training_steps_per_epoch) 199 | print("Number of training examples per epoch = %d" % (total_batch_size * num_training_steps_per_epoch)) 200 | 201 | if args.distributed: 202 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 203 | model_without_ddp = model.module 204 | 205 | optimizer = create_optimizer( 206 | args, model_without_ddp) 207 | loss_scaler = NativeScaler() 208 | 209 | print("Use step level LR & WD scheduler!") 210 | lr_schedule_values = utils.cosine_scheduler( 211 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch, 212 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps, 213 | ) 214 | if args.weight_decay_end is None: 215 | args.weight_decay_end = args.weight_decay 216 | wd_schedule_values = utils.cosine_scheduler( 217 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch) 218 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values))) 219 | 220 | utils.auto_load_model( 221 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 222 | torch.cuda.empty_cache() 223 | print(f"Start training for {args.epochs} epochs") 224 | start_time = time.time() 225 | for epoch in range(args.start_epoch, args.epochs): 226 | if args.distributed: 227 | data_loader_train.sampler.set_epoch(epoch) 228 | 
if log_writer is not None: 229 | log_writer.set_step(epoch * num_training_steps_per_epoch) 230 | train_stats = train_one_epoch( 231 | model, data_loader_train, 232 | optimizer, device, epoch, loss_scaler, 233 | args.clip_grad, log_writer=log_writer, 234 | start_steps=epoch * num_training_steps_per_epoch, 235 | lr_schedule_values=lr_schedule_values, 236 | wd_schedule_values=wd_schedule_values, 237 | patch_size=patch_size[0], 238 | normlize_target=args.normlize_target, 239 | ) 240 | if args.output_dir: 241 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: 242 | utils.save_model( 243 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 244 | loss_scaler=loss_scaler, epoch=epoch) 245 | 246 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 247 | 'epoch': epoch, 'n_parameters': n_parameters} 248 | 249 | if args.output_dir and utils.is_main_process(): 250 | if log_writer is not None: 251 | log_writer.flush() 252 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 253 | f.write(json.dumps(log_stats) + "\n") 254 | 255 | total_time = time.time() - start_time 256 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 257 | print('Training time {}'.format(total_time_str)) 258 | 259 | 260 | if __name__ == '__main__': 261 | opts = get_args() 262 | if opts.output_dir: 263 | Path(opts.output_dir).mkdir(parents=True, exist_ok=True) 264 | main(opts) 265 | -------------------------------------------------------------------------------- /modeling_finetune.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 7 | from timm.models.registry import register_model 8 | import torch.utils.checkpoint as checkpoint 9 | 10 | 11 | def _cfg(url='', **kwargs): 12 | return { 13 | 'url': url, 14 | 'num_classes': 400, 'input_size': (3, 224, 224), 'pool_size': None, 15 | 'crop_pct': .9, 'interpolation': 'bicubic', 16 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 17 | **kwargs 18 | } 19 | 20 | 21 | class DropPath(nn.Module): 22 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
23 | """ 24 | def __init__(self, drop_prob=None): 25 | super(DropPath, self).__init__() 26 | self.drop_prob = drop_prob 27 | 28 | def forward(self, x): 29 | return drop_path(x, self.drop_prob, self.training) 30 | 31 | def extra_repr(self) -> str: 32 | return 'p={}'.format(self.drop_prob) 33 | 34 | 35 | class Mlp(nn.Module): 36 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 37 | super().__init__() 38 | out_features = out_features or in_features 39 | hidden_features = hidden_features or in_features 40 | self.fc1 = nn.Linear(in_features, hidden_features) 41 | self.act = act_layer() 42 | self.fc2 = nn.Linear(hidden_features, out_features) 43 | self.drop = nn.Dropout(drop) 44 | 45 | def forward(self, x): 46 | x = self.fc1(x) 47 | x = self.act(x) 48 | # x = self.drop(x) 49 | # commit this for the orignal BERT implement 50 | x = self.fc2(x) 51 | x = self.drop(x) 52 | return x 53 | 54 | 55 | class Attention(nn.Module): 56 | def __init__( 57 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., 58 | proj_drop=0., attn_head_dim=None): 59 | super().__init__() 60 | self.num_heads = num_heads 61 | head_dim = dim // num_heads 62 | if attn_head_dim is not None: 63 | head_dim = attn_head_dim 64 | all_head_dim = head_dim * self.num_heads 65 | self.scale = qk_scale or head_dim ** -0.5 66 | 67 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) 68 | if qkv_bias: 69 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) 70 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) 71 | else: 72 | self.q_bias = None 73 | self.v_bias = None 74 | 75 | self.attn_drop = nn.Dropout(attn_drop) 76 | self.proj = nn.Linear(all_head_dim, dim) 77 | self.proj_drop = nn.Dropout(proj_drop) 78 | 79 | def forward(self, x): 80 | B, N, C = x.shape 81 | qkv_bias = None 82 | if self.q_bias is not None: 83 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) 84 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 85 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 86 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 87 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 88 | 89 | q = q * self.scale 90 | attn = (q @ k.transpose(-2, -1)) 91 | 92 | 93 | attn = attn.softmax(dim=-1) 94 | attn = self.attn_drop(attn) 95 | 96 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 97 | x = self.proj(x) 98 | x = self.proj_drop(x) 99 | return x 100 | 101 | 102 | class Block(nn.Module): 103 | 104 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 105 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, 106 | attn_head_dim=None): 107 | super().__init__() 108 | self.norm1 = norm_layer(dim) 109 | self.attn = Attention( 110 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 111 | attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim) 112 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 113 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 114 | self.norm2 = norm_layer(dim) 115 | mlp_hidden_dim = int(dim * mlp_ratio) 116 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 117 | 118 | if init_values > 0: 119 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 120 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 121 | else: 122 | self.gamma_1, self.gamma_2 = None, None 123 | 124 | def forward(self, x): 125 | if self.gamma_1 is None: 126 | x = x + self.drop_path(self.attn(self.norm1(x))) 127 | x = x + self.drop_path(self.mlp(self.norm2(x))) 128 | else: 129 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) 130 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 131 | return x 132 | 133 | 134 | class PatchEmbed(nn.Module): 135 | """ Image to Patch Embedding 136 | """ 137 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2): 138 | super().__init__() 139 | img_size = to_2tuple(img_size) 140 | patch_size = to_2tuple(patch_size) 141 | self.tubelet_size = int(tubelet_size) 142 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (num_frames // self.tubelet_size) 143 | self.img_size = img_size 144 | self.patch_size = patch_size 145 | self.num_patches = num_patches 146 | self.proj = nn.Conv3d(in_channels=in_chans, out_channels=embed_dim, 147 | kernel_size = (self.tubelet_size, patch_size[0],patch_size[1]), 148 | stride=(self.tubelet_size, patch_size[0], patch_size[1])) 149 | 150 | def forward(self, x, **kwargs): 151 | B, C, T, H, W = x.shape 152 | # FIXME look at relaxing size constraints 153 | assert H == self.img_size[0] and W == self.img_size[1], \ 154 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
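# Added note: self.proj is a Conv3d with kernel and stride (tubelet_size, P, P),
# so the line below maps (B, C, T, H, W) -> (B, embed_dim, T//tubelet_size, H//P, W//P);
# flatten(2) + transpose(1, 2) then yields (B, num_patches, embed_dim) token embeddings.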
155 | x = self.proj(x).flatten(2).transpose(1, 2) 156 | return x 157 | 158 | # sin-cos position encoding 159 | # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31 160 | def get_sinusoid_encoding_table(n_position, d_hid): 161 | ''' Sinusoid position encoding table ''' 162 | # TODO: make it with torch instead of numpy 163 | def get_position_angle_vec(position): 164 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 165 | 166 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 167 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 168 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 169 | 170 | return torch.tensor(sinusoid_table,dtype=torch.float, requires_grad=False).unsqueeze(0) 171 | 172 | 173 | class VisionTransformer(nn.Module): 174 | """ Vision Transformer with support for patch or hybrid CNN input stage 175 | """ 176 | def __init__(self, 177 | img_size=224, 178 | patch_size=16, 179 | in_chans=3, 180 | num_classes=1000, 181 | embed_dim=768, 182 | depth=12, 183 | num_heads=12, 184 | mlp_ratio=4., 185 | qkv_bias=False, 186 | qk_scale=None, 187 | fc_drop_rate=0., 188 | drop_rate=0., 189 | attn_drop_rate=0., 190 | drop_path_rate=0., 191 | norm_layer=nn.LayerNorm, 192 | init_values=0., 193 | use_learnable_pos_emb=False, 194 | init_scale=0., 195 | all_frames=16, 196 | tubelet_size=2, 197 | use_checkpoint=False, 198 | use_mean_pooling=True): 199 | super().__init__() 200 | self.num_classes = num_classes 201 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 202 | self.tubelet_size = tubelet_size 203 | self.patch_embed = PatchEmbed( 204 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, num_frames=all_frames, tubelet_size=self.tubelet_size) 205 | num_patches = self.patch_embed.num_patches 206 | self.use_checkpoint = use_checkpoint 207 | 208 | if use_learnable_pos_emb: 209 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 210 | else: 211 | # sine-cosine positional embeddings is on the way 212 | self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim) 213 | 214 | self.pos_drop = nn.Dropout(p=drop_rate) 215 | 216 | 217 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 218 | self.blocks = nn.ModuleList([ 219 | Block( 220 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 221 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 222 | init_values=init_values) 223 | for i in range(depth)]) 224 | self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) 225 | self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None 226 | self.fc_dropout = nn.Dropout(p=fc_drop_rate) if fc_drop_rate > 0 else nn.Identity() 227 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 228 | 229 | if use_learnable_pos_emb: 230 | trunc_normal_(self.pos_embed, std=.02) 231 | 232 | trunc_normal_(self.head.weight, std=.02) 233 | self.apply(self._init_weights) 234 | 235 | self.head.weight.data.mul_(init_scale) 236 | self.head.bias.data.mul_(init_scale) 237 | 238 | def _init_weights(self, m): 239 | if isinstance(m, nn.Linear): 240 | trunc_normal_(m.weight, std=.02) 241 | if isinstance(m, nn.Linear) and m.bias is not None: 242 | nn.init.constant_(m.bias, 0) 243 | elif 
isinstance(m, nn.LayerNorm): 244 | nn.init.constant_(m.bias, 0) 245 | nn.init.constant_(m.weight, 1.0) 246 | 247 | def get_num_layers(self): 248 | return len(self.blocks) 249 | 250 | @torch.jit.ignore 251 | def no_weight_decay(self): 252 | return {'pos_embed', 'cls_token'} 253 | 254 | def get_classifier(self): 255 | return self.head 256 | 257 | def reset_classifier(self, num_classes, global_pool=''): 258 | self.num_classes = num_classes 259 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 260 | 261 | def forward_features(self, x): 262 | x = self.patch_embed(x) 263 | B, _, _ = x.size() 264 | 265 | if self.pos_embed is not None: 266 | x = x + self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach() 267 | x = self.pos_drop(x) 268 | 269 | if self.use_checkpoint: 270 | for blk in self.blocks: 271 | x = checkpoint.checkpoint(blk, x) 272 | else: 273 | for blk in self.blocks: 274 | x = blk(x) 275 | 276 | x = self.norm(x) 277 | if self.fc_norm is not None: 278 | return self.fc_norm(x.mean(1)) 279 | else: 280 | return x[:, 0] 281 | 282 | def forward(self, x): 283 | x = self.forward_features(x) 284 | x = self.head(self.fc_dropout(x)) 285 | return x 286 | 287 | 288 | @register_model 289 | def vit_small_patch16_224(pretrained=False, **kwargs): 290 | model = VisionTransformer( 291 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 292 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 293 | model.default_cfg = _cfg() 294 | return model 295 | 296 | 297 | @register_model 298 | def vit_base_patch16_224(pretrained=False, **kwargs): 299 | model = VisionTransformer( 300 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 301 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 302 | model.default_cfg = _cfg() 303 | return model 304 | 305 | 306 | @register_model 307 | def vit_base_patch16_384(pretrained=False, **kwargs): 308 | model = VisionTransformer( 309 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 310 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 311 | model.default_cfg = _cfg() 312 | return model 313 | 314 | 315 | @register_model 316 | def vit_large_patch16_224(pretrained=False, **kwargs): 317 | model = VisionTransformer( 318 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 319 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 320 | model.default_cfg = _cfg() 321 | return model 322 | 323 | 324 | @register_model 325 | def vit_large_patch16_384(pretrained=False, **kwargs): 326 | model = VisionTransformer( 327 | img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 328 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 329 | model.default_cfg = _cfg() 330 | return model 331 | 332 | 333 | @register_model 334 | def vit_large_patch16_512(pretrained=False, **kwargs): 335 | model = VisionTransformer( 336 | img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 337 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 338 | model.default_cfg = _cfg() 339 | return model 340 | 341 | 342 | @register_model 343 | def vit_huge_patch16_224(pretrained=False, **kwargs): 344 | model = VisionTransformer( 345 | patch_size=16, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 346 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 347 | model.default_cfg = _cfg() 348 | return model 
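# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original file).
# The @register_model builders above are normally resolved through
# timm.models.create_model, but they can also be called directly. The values
# below (174 classes as in SSV2, 16-frame clips) are example settings only.
if __name__ == '__main__':
    model = vit_base_patch16_224(num_classes=174, all_frames=16, tubelet_size=2)
    videos = torch.randn(2, 3, 16, 224, 224)  # (B, C, T, H, W) dummy clip batch
    logits = model(videos)
    print(logits.shape)  # torch.Size([2, 174])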
349 | -------------------------------------------------------------------------------- /modeling_pretrain.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.checkpoint as checkpoint 6 | from functools import partial 7 | 8 | from modeling_finetune import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table 9 | from timm.models.registry import register_model 10 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_ 11 | 12 | 13 | 14 | def trunc_normal_(tensor, mean=0., std=1.): 15 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) 16 | 17 | 18 | __all__ = [ 19 | 'pretrain_videomae_small_patch16_224', 20 | 'pretrain_videomae_base_patch16_224', 21 | 'pretrain_videomae_large_patch16_224', 22 | 'pretrain_videomae_huge_patch16_224', 23 | ] 24 | 25 | 26 | class PretrainVisionTransformerEncoder(nn.Module): 27 | """ Vision Transformer with support for patch or hybrid CNN input stage 28 | """ 29 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 30 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 31 | drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, tubelet_size=2, use_checkpoint=False, 32 | use_learnable_pos_emb=False): 33 | super().__init__() 34 | self.num_classes = num_classes 35 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 36 | self.patch_embed = PatchEmbed( 37 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,tubelet_size=tubelet_size) 38 | num_patches = self.patch_embed.num_patches 39 | self.use_checkpoint = use_checkpoint 40 | 41 | 42 | # TODO: Add the cls token 43 | if use_learnable_pos_emb: 44 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 45 | else: 46 | # sine-cosine positional embeddings 47 | self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim) 48 | 49 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 50 | self.blocks = nn.ModuleList([ 51 | Block( 52 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 53 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 54 | init_values=init_values) 55 | for i in range(depth)]) 56 | self.norm = norm_layer(embed_dim) 57 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 58 | 59 | if use_learnable_pos_emb: 60 | trunc_normal_(self.pos_embed, std=.02) 61 | 62 | self.apply(self._init_weights) 63 | 64 | 65 | def _init_weights(self, m): 66 | if isinstance(m, nn.Linear): 67 | nn.init.xavier_uniform_(m.weight) 68 | if isinstance(m, nn.Linear) and m.bias is not None: 69 | nn.init.constant_(m.bias, 0) 70 | elif isinstance(m, nn.LayerNorm): 71 | nn.init.constant_(m.bias, 0) 72 | nn.init.constant_(m.weight, 1.0) 73 | 74 | def get_num_layers(self): 75 | return len(self.blocks) 76 | 77 | @torch.jit.ignore 78 | def no_weight_decay(self): 79 | return {'pos_embed', 'cls_token'} 80 | 81 | def get_classifier(self): 82 | return self.head 83 | 84 | def reset_classifier(self, num_classes, global_pool=''): 85 | self.num_classes = num_classes 86 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 87 | 88 | def forward_features(self, x, mask): 89 | _, _, T, _, _ = 
x.shape 90 | x = self.patch_embed(x) 91 | 92 | x = x + self.pos_embed.type_as(x).to(x.device).clone().detach() 93 | 94 | B, _, C = x.shape 95 | x_vis = x[~mask].reshape(B, -1, C) # ~mask means visible 96 | 97 | if self.use_checkpoint: 98 | for blk in self.blocks: 99 | x_vis = checkpoint.checkpoint(blk, x_vis) 100 | else: 101 | for blk in self.blocks: 102 | x_vis = blk(x_vis) 103 | 104 | x_vis = self.norm(x_vis) 105 | return x_vis 106 | 107 | def forward(self, x, mask): 108 | x = self.forward_features(x, mask) 109 | x = self.head(x) 110 | return x 111 | 112 | class PretrainVisionTransformerDecoder(nn.Module): 113 | """ Vision Transformer with support for patch or hybrid CNN input stage 114 | """ 115 | def __init__(self, patch_size=16, num_classes=768, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., 116 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., 117 | norm_layer=nn.LayerNorm, init_values=None, num_patches=196, tubelet_size=2, use_checkpoint=False 118 | ): 119 | super().__init__() 120 | self.num_classes = num_classes 121 | assert num_classes == 3 * tubelet_size * patch_size ** 2 122 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 123 | self.patch_size = patch_size 124 | self.use_checkpoint = use_checkpoint 125 | 126 | 127 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 128 | self.blocks = nn.ModuleList([ 129 | Block( 130 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 131 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 132 | init_values=init_values) 133 | for i in range(depth)]) 134 | self.norm = norm_layer(embed_dim) 135 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 136 | 137 | self.apply(self._init_weights) 138 | 139 | 140 | def _init_weights(self, m): 141 | if isinstance(m, nn.Linear): 142 | nn.init.xavier_uniform_(m.weight) 143 | if isinstance(m, nn.Linear) and m.bias is not None: 144 | nn.init.constant_(m.bias, 0) 145 | elif isinstance(m, nn.LayerNorm): 146 | nn.init.constant_(m.bias, 0) 147 | nn.init.constant_(m.weight, 1.0) 148 | 149 | def get_num_layers(self): 150 | return len(self.blocks) 151 | 152 | @torch.jit.ignore 153 | def no_weight_decay(self): 154 | return {'pos_embed', 'cls_token'} 155 | 156 | def get_classifier(self): 157 | return self.head 158 | 159 | def reset_classifier(self, num_classes, global_pool=''): 160 | self.num_classes = num_classes 161 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 162 | 163 | def forward(self, x, return_token_num): 164 | if self.use_checkpoint: 165 | for blk in self.blocks: 166 | x = checkpoint.checkpoint(blk, x) 167 | else: 168 | for blk in self.blocks: 169 | x = blk(x) 170 | 171 | if return_token_num > 0: 172 | x = self.head(self.norm(x[:, -return_token_num:])) # only return the mask tokens predict pixels 173 | else: 174 | x = self.head(self.norm(x)) 175 | 176 | return x 177 | 178 | class PretrainVisionTransformer(nn.Module): 179 | """ Vision Transformer with support for patch or hybrid CNN input stage 180 | """ 181 | def __init__(self, 182 | img_size=224, 183 | patch_size=16, 184 | encoder_in_chans=3, 185 | encoder_num_classes=0, 186 | encoder_embed_dim=768, 187 | encoder_depth=12, 188 | encoder_num_heads=12, 189 | decoder_num_classes=1536, # decoder_num_classes=768, 190 | decoder_embed_dim=512, 191 | decoder_depth=8, 192 | 
decoder_num_heads=8, 193 | mlp_ratio=4., 194 | qkv_bias=False, 195 | qk_scale=None, 196 | drop_rate=0., 197 | attn_drop_rate=0., 198 | drop_path_rate=0., 199 | norm_layer=nn.LayerNorm, 200 | init_values=0., 201 | use_learnable_pos_emb=False, 202 | use_checkpoint=False, 203 | tubelet_size=2, 204 | num_classes=0, # avoid the error from create_fn in timm 205 | in_chans=0, # avoid the error from create_fn in timm 206 | ): 207 | super().__init__() 208 | self.encoder = PretrainVisionTransformerEncoder( 209 | img_size=img_size, 210 | patch_size=patch_size, 211 | in_chans=encoder_in_chans, 212 | num_classes=encoder_num_classes, 213 | embed_dim=encoder_embed_dim, 214 | depth=encoder_depth, 215 | num_heads=encoder_num_heads, 216 | mlp_ratio=mlp_ratio, 217 | qkv_bias=qkv_bias, 218 | qk_scale=qk_scale, 219 | drop_rate=drop_rate, 220 | attn_drop_rate=attn_drop_rate, 221 | drop_path_rate=drop_path_rate, 222 | norm_layer=norm_layer, 223 | init_values=init_values, 224 | tubelet_size=tubelet_size, 225 | use_checkpoint=use_checkpoint, 226 | use_learnable_pos_emb=use_learnable_pos_emb) 227 | 228 | self.decoder = PretrainVisionTransformerDecoder( 229 | patch_size=patch_size, 230 | num_patches=self.encoder.patch_embed.num_patches, 231 | num_classes=decoder_num_classes, 232 | embed_dim=decoder_embed_dim, 233 | depth=decoder_depth, 234 | num_heads=decoder_num_heads, 235 | mlp_ratio=mlp_ratio, 236 | qkv_bias=qkv_bias, 237 | qk_scale=qk_scale, 238 | drop_rate=drop_rate, 239 | attn_drop_rate=attn_drop_rate, 240 | drop_path_rate=drop_path_rate, 241 | norm_layer=norm_layer, 242 | init_values=init_values, 243 | tubelet_size=tubelet_size, 244 | use_checkpoint=use_checkpoint) 245 | 246 | self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=False) 247 | 248 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 249 | 250 | self.pos_embed = get_sinusoid_encoding_table(self.encoder.patch_embed.num_patches, decoder_embed_dim) 251 | 252 | trunc_normal_(self.mask_token, std=.02) 253 | 254 | 255 | def _init_weights(self, m): 256 | if isinstance(m, nn.Linear): 257 | nn.init.xavier_uniform_(m.weight) 258 | if isinstance(m, nn.Linear) and m.bias is not None: 259 | nn.init.constant_(m.bias, 0) 260 | elif isinstance(m, nn.LayerNorm): 261 | nn.init.constant_(m.bias, 0) 262 | nn.init.constant_(m.weight, 1.0) 263 | 264 | def get_num_layers(self): 265 | return len(self.blocks) 266 | 267 | @torch.jit.ignore 268 | def no_weight_decay(self): 269 | return {'pos_embed', 'cls_token', 'mask_token'} 270 | 271 | def forward(self, x, mask): 272 | _, _, T, _, _ = x.shape 273 | x_vis = self.encoder(x, mask) # [B, N_vis, C_e] 274 | x_vis = self.encoder_to_decoder(x_vis) # [B, N_vis, C_d] 275 | B, N, C = x_vis.shape 276 | # we don't unshuffle the correct visible token order, 277 | # but shuffle the pos embedding accorddingly. 
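# Added note: x_vis keeps the visible tokens in whatever order the boolean mask
# produced; rather than restoring the canonical token order, the fixed sin-cos
# positional embeddings are split with the same mask below, so each visible token
# and each copy of the shared mask_token is paired with the positional embedding
# of the patch it actually corresponds to before being passed to the decoder.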
278 | expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach() 279 | pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C) 280 | pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C) 281 | x_full = torch.cat([x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1) # [B, N, C_d] 282 | x = self.decoder(x_full, pos_emd_mask.shape[1]) # [B, N_mask, 3 * 16 * 16] 283 | 284 | return x 285 | 286 | @register_model 287 | def pretrain_videomae_small_patch16_224(pretrained=False, **kwargs): 288 | model = PretrainVisionTransformer( 289 | img_size=224, 290 | patch_size=16, 291 | encoder_embed_dim=384, 292 | encoder_depth=12, 293 | encoder_num_heads=6, 294 | encoder_num_classes=0, 295 | decoder_num_classes=1536, 296 | decoder_embed_dim=192, 297 | decoder_num_heads=3, 298 | mlp_ratio=4, 299 | qkv_bias=True, 300 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 301 | **kwargs) 302 | model.default_cfg = _cfg() 303 | if pretrained: 304 | checkpoint = torch.load( 305 | kwargs["init_ckpt"], map_location="cpu" 306 | ) 307 | model.load_state_dict(checkpoint["model"]) 308 | return model 309 | 310 | @register_model 311 | def pretrain_videomae_base_patch16_224(pretrained=False, **kwargs): 312 | model = PretrainVisionTransformer( 313 | img_size=224, 314 | patch_size=16, 315 | encoder_embed_dim=768, 316 | encoder_depth=12, 317 | encoder_num_heads=12, 318 | encoder_num_classes=0, 319 | decoder_num_classes=1536, 320 | decoder_embed_dim=384, 321 | decoder_num_heads=6, 322 | mlp_ratio=4, 323 | qkv_bias=True, 324 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 325 | **kwargs) 326 | model.default_cfg = _cfg() 327 | if pretrained: 328 | checkpoint = torch.load( 329 | kwargs["init_ckpt"], map_location="cpu" 330 | ) 331 | model.load_state_dict(checkpoint["model"]) 332 | return model 333 | 334 | @register_model 335 | def pretrain_videomae_large_patch16_224(pretrained=False, **kwargs): 336 | model = PretrainVisionTransformer( 337 | img_size=224, 338 | patch_size=16, 339 | encoder_embed_dim=1024, 340 | encoder_depth=24, 341 | encoder_num_heads=16, 342 | encoder_num_classes=0, 343 | decoder_num_classes=1536, 344 | decoder_embed_dim=512, 345 | decoder_num_heads=8, 346 | mlp_ratio=4, 347 | qkv_bias=True, 348 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 349 | **kwargs) 350 | model.default_cfg = _cfg() 351 | if pretrained: 352 | checkpoint = torch.load( 353 | kwargs["init_ckpt"], map_location="cpu" 354 | ) 355 | model.load_state_dict(checkpoint["model"]) 356 | return model 357 | 358 | @register_model 359 | def pretrain_videomae_huge_patch16_224(pretrained=False, **kwargs): 360 | model = PretrainVisionTransformer( 361 | img_size=224, 362 | patch_size=16, 363 | encoder_embed_dim=1280, 364 | encoder_depth=32, 365 | encoder_num_heads=16, 366 | encoder_num_classes=0, 367 | decoder_num_classes=1536, 368 | decoder_embed_dim=640, 369 | decoder_num_heads=8, 370 | mlp_ratio=4, 371 | qkv_bias=True, 372 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 373 | **kwargs) 374 | model.default_cfg = _cfg() 375 | if pretrained: 376 | checkpoint = torch.load( 377 | kwargs["init_ckpt"], map_location="cpu" 378 | ) 379 | model.load_state_dict(checkpoint["model"]) 380 | return model 381 | -------------------------------------------------------------------------------- /mixup.py: -------------------------------------------------------------------------------- 1 | """ Mixup and Cutmix 2 | 3 | Papers: 4 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) 5 | 6 | CutMix: Regularization Strategy to 
Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) 7 | 8 | Code Reference: 9 | CutMix: https://github.com/clovaai/CutMix-PyTorch 10 | 11 | Hacked together by / Copyright 2019, Ross Wightman 12 | """ 13 | import numpy as np 14 | import torch 15 | 16 | 17 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): 18 | x = x.long().view(-1, 1) 19 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) 20 | 21 | 22 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): 23 | off_value = smoothing / num_classes 24 | on_value = 1. - smoothing + off_value 25 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) 26 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) 27 | return y1 * lam + y2 * (1. - lam) 28 | 29 | 30 | def rand_bbox(img_shape, lam, margin=0., count=None): 31 | """ Standard CutMix bounding-box 32 | Generates a random square bbox based on lambda value. This impl includes 33 | support for enforcing a border margin as percent of bbox dimensions. 34 | 35 | Args: 36 | img_shape (tuple): Image shape as tuple 37 | lam (float): Cutmix lambda value 38 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image) 39 | count (int): Number of bbox to generate 40 | """ 41 | ratio = np.sqrt(1 - lam) 42 | img_h, img_w = img_shape[-2:] 43 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) 44 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) 45 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) 46 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) 47 | yl = np.clip(cy - cut_h // 2, 0, img_h) 48 | yh = np.clip(cy + cut_h // 2, 0, img_h) 49 | xl = np.clip(cx - cut_w // 2, 0, img_w) 50 | xh = np.clip(cx + cut_w // 2, 0, img_w) 51 | return yl, yh, xl, xh 52 | 53 | 54 | def rand_bbox_minmax(img_shape, minmax, count=None): 55 | """ Min-Max CutMix bounding-box 56 | Inspired by Darknet cutmix impl, generates a random rectangular bbox 57 | based on min/max percent values applied to each dimension of the input image. 58 | 59 | Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max. 60 | 61 | Args: 62 | img_shape (tuple): Image shape as tuple 63 | minmax (tuple or list): Min and max bbox ratios (as percent of image size) 64 | count (int): Number of bbox to generate 65 | """ 66 | assert len(minmax) == 2 67 | img_h, img_w = img_shape[-2:] 68 | cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count) 69 | cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count) 70 | yl = np.random.randint(0, img_h - cut_h, size=count) 71 | xl = np.random.randint(0, img_w - cut_w, size=count) 72 | yu = yl + cut_h 73 | xu = xl + cut_w 74 | return yl, yu, xl, xu 75 | 76 | 77 | def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None): 78 | """ Generate bbox and apply lambda correction. 79 | """ 80 | if ratio_minmax is not None: 81 | yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count) 82 | else: 83 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) 84 | if correct_lam or ratio_minmax is not None: 85 | bbox_area = (yu - yl) * (xu - xl) 86 | lam = 1. 
- bbox_area / float(img_shape[-2] * img_shape[-1]) 87 | return (yl, yu, xl, xu), lam 88 | 89 | 90 | class Mixup: 91 | """ Mixup/Cutmix that applies different params to each element or whole batch 92 | 93 | Args: 94 | mixup_alpha (float): mixup alpha value, mixup is active if > 0. 95 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. 96 | cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None. 97 | prob (float): probability of applying mixup or cutmix per batch or element 98 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active 99 | mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element) 100 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders 101 | label_smoothing (float): apply label smoothing to the mixed target tensor 102 | num_classes (int): number of classes for target 103 | """ 104 | def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5, 105 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000): 106 | self.mixup_alpha = mixup_alpha 107 | self.cutmix_alpha = cutmix_alpha 108 | self.cutmix_minmax = cutmix_minmax 109 | if self.cutmix_minmax is not None: 110 | assert len(self.cutmix_minmax) == 2 111 | # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe 112 | self.cutmix_alpha = 1.0 113 | self.mix_prob = prob 114 | self.switch_prob = switch_prob 115 | self.label_smoothing = label_smoothing 116 | self.num_classes = num_classes 117 | self.mode = mode 118 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix 119 | self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop) 120 | 121 | def _params_per_elem(self, batch_size): 122 | lam = np.ones(batch_size, dtype=np.float32) 123 | use_cutmix = np.zeros(batch_size, dtype=np.bool) 124 | if self.mixup_enabled: 125 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: 126 | use_cutmix = np.random.rand(batch_size) < self.switch_prob 127 | lam_mix = np.where( 128 | use_cutmix, 129 | np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size), 130 | np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)) 131 | elif self.mixup_alpha > 0.: 132 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size) 133 | elif self.cutmix_alpha > 0.: 134 | use_cutmix = np.ones(batch_size, dtype=np.bool) 135 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size) 136 | else: 137 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." 138 | lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam) 139 | return lam, use_cutmix 140 | 141 | def _params_per_batch(self): 142 | lam = 1. 143 | use_cutmix = False 144 | if self.mixup_enabled and np.random.rand() < self.mix_prob: 145 | if self.mixup_alpha > 0. 
and self.cutmix_alpha > 0.: 146 | use_cutmix = np.random.rand() < self.switch_prob 147 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \ 148 | np.random.beta(self.mixup_alpha, self.mixup_alpha) 149 | elif self.mixup_alpha > 0.: 150 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha) 151 | elif self.cutmix_alpha > 0.: 152 | use_cutmix = True 153 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) 154 | else: 155 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." 156 | lam = float(lam_mix) 157 | return lam, use_cutmix 158 | 159 | def _mix_elem(self, x): 160 | batch_size = len(x) 161 | lam_batch, use_cutmix = self._params_per_elem(batch_size) 162 | x_orig = x.clone() # need to keep an unmodified original for mixing source 163 | for i in range(batch_size): 164 | j = batch_size - i - 1 165 | lam = lam_batch[i] 166 | if lam != 1.: 167 | if use_cutmix[i]: 168 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 169 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 170 | x[i][..., yl:yh, xl:xh] = x_orig[j][..., yl:yh, xl:xh] 171 | lam_batch[i] = lam 172 | else: 173 | x[i] = x[i] * lam + x_orig[j] * (1 - lam) 174 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) 175 | 176 | def _mix_pair(self, x): 177 | batch_size = len(x) 178 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) 179 | x_orig = x.clone() # need to keep an unmodified original for mixing source 180 | for i in range(batch_size // 2): 181 | j = batch_size - i - 1 182 | lam = lam_batch[i] 183 | if lam != 1.: 184 | if use_cutmix[i]: 185 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 186 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 187 | x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] 188 | x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh] 189 | lam_batch[i] = lam 190 | else: 191 | x[i] = x[i] * lam + x_orig[j] * (1 - lam) 192 | x[j] = x[j] * lam + x_orig[i] * (1 - lam) 193 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) 194 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) 195 | 196 | def _mix_batch(self, x): 197 | lam, use_cutmix = self._params_per_batch() 198 | if lam == 1.: 199 | return 1. 200 | if use_cutmix: 201 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 202 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 203 | x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh] 204 | else: 205 | x_flipped = x.flip(0).mul_(1. - lam) 206 | x.mul_(lam).add_(x_flipped) 207 | return lam 208 | 209 | def __call__(self, x, target): 210 | assert len(x) % 2 == 0, 'Batch size should be even when using this' 211 | if self.mode == 'elem': 212 | lam = self._mix_elem(x) 213 | elif self.mode == 'pair': 214 | lam = self._mix_pair(x) 215 | else: 216 | lam = self._mix_batch(x) 217 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device) 218 | return x, target 219 | 220 | 221 | class FastCollateMixup(Mixup): 222 | """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch 223 | 224 | A Mixup impl that's performed while collating the batches. 
225 | """ 226 | 227 | def _mix_elem_collate(self, output, batch, half=False): 228 | batch_size = len(batch) 229 | num_elem = batch_size // 2 if half else batch_size 230 | assert len(output) == num_elem 231 | lam_batch, use_cutmix = self._params_per_elem(num_elem) 232 | for i in range(num_elem): 233 | j = batch_size - i - 1 234 | lam = lam_batch[i] 235 | mixed = batch[i][0] 236 | if lam != 1.: 237 | if use_cutmix[i]: 238 | if not half: 239 | mixed = mixed.copy() 240 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 241 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 242 | mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] 243 | lam_batch[i] = lam 244 | else: 245 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) 246 | np.rint(mixed, out=mixed) 247 | output[i] += torch.from_numpy(mixed.astype(np.uint8)) 248 | if half: 249 | lam_batch = np.concatenate((lam_batch, np.ones(num_elem))) 250 | return torch.tensor(lam_batch).unsqueeze(1) 251 | 252 | def _mix_pair_collate(self, output, batch): 253 | batch_size = len(batch) 254 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) 255 | for i in range(batch_size // 2): 256 | j = batch_size - i - 1 257 | lam = lam_batch[i] 258 | mixed_i = batch[i][0] 259 | mixed_j = batch[j][0] 260 | assert 0 <= lam <= 1.0 261 | if lam < 1.: 262 | if use_cutmix[i]: 263 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 264 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 265 | patch_i = mixed_i[:, yl:yh, xl:xh].copy() 266 | mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh] 267 | mixed_j[:, yl:yh, xl:xh] = patch_i 268 | lam_batch[i] = lam 269 | else: 270 | mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam) 271 | mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam) 272 | mixed_i = mixed_temp 273 | np.rint(mixed_j, out=mixed_j) 274 | np.rint(mixed_i, out=mixed_i) 275 | output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) 276 | output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) 277 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) 278 | return torch.tensor(lam_batch).unsqueeze(1) 279 | 280 | def _mix_batch_collate(self, output, batch): 281 | batch_size = len(batch) 282 | lam, use_cutmix = self._params_per_batch() 283 | if use_cutmix: 284 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 285 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 286 | for i in range(batch_size): 287 | j = batch_size - i - 1 288 | mixed = batch[i][0] 289 | if lam != 1.: 290 | if use_cutmix: 291 | mixed = mixed.copy() # don't want to modify the original while iterating 292 | mixed[..., yl:yh, xl:xh] = batch[j][0][..., yl:yh, xl:xh] 293 | else: 294 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) 295 | np.rint(mixed, out=mixed) 296 | output[i] += torch.from_numpy(mixed.astype(np.uint8)) 297 | return lam 298 | 299 | def __call__(self, batch, _=None): 300 | batch_size = len(batch) 301 | assert batch_size % 2 == 0, 'Batch size should be even when using this' 302 | half = 'half' in self.mode 303 | if half: 304 | batch_size //= 2 305 | output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) 306 | if self.mode == 'elem' or self.mode == 'half': 307 | lam = self._mix_elem_collate(output, batch, half=half) 308 | elif self.mode == 'pair': 309 | lam = self._mix_pair_collate(output, batch) 310 | else: 311 | lam = 
self._mix_batch_collate(output, batch) 312 | target = torch.tensor([b[1] for b in batch], dtype=torch.int64) 313 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') 314 | target = target[:batch_size] 315 | return output, target 316 | 317 | -------------------------------------------------------------------------------- /ssv2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torchvision import transforms 5 | from random_erasing import RandomErasing 6 | import warnings 7 | from decord import VideoReader, cpu 8 | from torch.utils.data import Dataset 9 | import video_transforms as video_transforms 10 | import volume_transforms as volume_transforms 11 | 12 | 13 | class SSVideoClsDataset(Dataset): 14 | """Load your own video classification dataset.""" 15 | 16 | def __init__(self, anno_path, data_path, mode='train', clip_len=8, 17 | crop_size=224, short_side_size=256, new_height=256, 18 | new_width=340, keep_aspect_ratio=True, num_segment=1, 19 | num_crop=1, test_num_segment=10, test_num_crop=3, args=None): 20 | self.anno_path = anno_path 21 | self.data_path = data_path 22 | self.mode = mode 23 | self.clip_len = clip_len 24 | self.crop_size = crop_size 25 | self.short_side_size = short_side_size 26 | self.new_height = new_height 27 | self.new_width = new_width 28 | self.keep_aspect_ratio = keep_aspect_ratio 29 | self.num_segment = num_segment 30 | self.test_num_segment = test_num_segment 31 | self.num_crop = num_crop 32 | self.test_num_crop = test_num_crop 33 | self.args = args 34 | self.aug = False 35 | self.rand_erase = False 36 | if self.mode in ['train']: 37 | self.aug = True 38 | if self.args.reprob > 0: 39 | self.rand_erase = True 40 | if VideoReader is None: 41 | raise ImportError("Unable to import `decord` which is required to read videos.") 42 | 43 | import pandas as pd 44 | cleaned = pd.read_csv(self.anno_path, header=None, delimiter=' ') 45 | self.dataset_samples = list(cleaned.values[:, 0]) 46 | self.label_array = list(cleaned.values[:, 1]) 47 | 48 | if (mode == 'train'): 49 | pass 50 | 51 | elif (mode == 'validation'): 52 | self.data_transform = video_transforms.Compose([ 53 | video_transforms.Resize(self.short_side_size, interpolation='bilinear'), 54 | video_transforms.CenterCrop(size=(self.crop_size, self.crop_size)), 55 | volume_transforms.ClipToTensor(), 56 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406], 57 | std=[0.229, 0.224, 0.225]) 58 | ]) 59 | elif mode == 'test': 60 | self.data_resize = video_transforms.Compose([ 61 | video_transforms.Resize(size=(short_side_size), interpolation='bilinear') 62 | ]) 63 | self.data_transform = video_transforms.Compose([ 64 | volume_transforms.ClipToTensor(), 65 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406], 66 | std=[0.229, 0.224, 0.225]) 67 | ]) 68 | self.test_seg = [] 69 | self.test_dataset = [] 70 | self.test_label_array = [] 71 | for ck in range(self.test_num_segment): 72 | for cp in range(self.test_num_crop): 73 | for idx in range(len(self.label_array)): 74 | sample_label = self.label_array[idx] 75 | self.test_label_array.append(sample_label) 76 | self.test_dataset.append(self.dataset_samples[idx]) 77 | self.test_seg.append((ck, cp)) 78 | 79 | def __getitem__(self, index): 80 | if self.mode == 'train': 81 | args = self.args 82 | scale_t = 1 83 | 84 | sample = self.dataset_samples[index] 85 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) # T H W C 86 | if 
len(buffer) == 0: 87 | while len(buffer) == 0: 88 | warnings.warn("video {} not correctly loaded during training".format(sample)) 89 | index = np.random.randint(self.__len__()) 90 | sample = self.dataset_samples[index] 91 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) 92 | 93 | if args.num_sample > 1: 94 | frame_list = [] 95 | label_list = [] 96 | index_list = [] 97 | for _ in range(args.num_sample): 98 | new_frames = self._aug_frame(buffer, args) 99 | label = self.label_array[index] 100 | frame_list.append(new_frames) 101 | label_list.append(label) 102 | index_list.append(index) 103 | return frame_list, label_list, index_list, {} 104 | else: 105 | buffer = self._aug_frame(buffer, args) 106 | 107 | return buffer, self.label_array[index], index, {} 108 | 109 | elif self.mode == 'validation': 110 | sample = self.dataset_samples[index] 111 | buffer = self.loadvideo_decord(sample) 112 | if len(buffer) == 0: 113 | while len(buffer) == 0: 114 | warnings.warn("video {} not correctly loaded during validation".format(sample)) 115 | index = np.random.randint(self.__len__()) 116 | sample = self.dataset_samples[index] 117 | buffer = self.loadvideo_decord(sample) 118 | buffer = self.data_transform(buffer) 119 | return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0] 120 | 121 | elif self.mode == 'test': 122 | sample = self.test_dataset[index] 123 | chunk_nb, split_nb = self.test_seg[index] 124 | buffer = self.loadvideo_decord(sample) 125 | 126 | while len(buffer) == 0: 127 | warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\ 128 | str(self.test_dataset[index]), chunk_nb, split_nb)) 129 | index = np.random.randint(self.__len__()) 130 | sample = self.test_dataset[index] 131 | chunk_nb, split_nb = self.test_seg[index] 132 | buffer = self.loadvideo_decord(sample) 133 | 134 | buffer = self.data_resize(buffer) 135 | if isinstance(buffer, list): 136 | buffer = np.stack(buffer, 0) 137 | 138 | spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \ 139 | / (self.test_num_crop - 1) 140 | temporal_start = chunk_nb # 0/1 141 | spatial_start = int(split_nb * spatial_step) 142 | if buffer.shape[1] >= buffer.shape[2]: 143 | buffer = buffer[temporal_start::2, \ 144 | spatial_start:spatial_start + self.short_side_size, :, :] 145 | else: 146 | buffer = buffer[temporal_start::2, \ 147 | :, spatial_start:spatial_start + self.short_side_size, :] 148 | 149 | buffer = self.data_transform(buffer) 150 | return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \ 151 | chunk_nb, split_nb 152 | else: 153 | raise NameError('mode {} unknown'.format(self.mode)) 154 | 155 | def _aug_frame( 156 | self, 157 | buffer, 158 | args, 159 | ): 160 | 161 | buffer = [ 162 | transforms.ToPILImage()(frame) for frame in buffer 163 | ] 164 | 165 | if args.is_aa: 166 | aug_transform = video_transforms.create_random_augment( 167 | input_size=(self.crop_size, self.crop_size), 168 | auto_augment=args.aa, 169 | interpolation=args.train_interpolation, 170 | ) 171 | buffer = aug_transform(buffer) 172 | 173 | buffer = [transforms.ToTensor()(img) for img in buffer] 174 | buffer = torch.stack(buffer) # T C H W 175 | buffer = buffer.permute(0, 2, 3, 1) # T H W C 176 | 177 | # T H W C 178 | buffer = tensor_normalize( 179 | buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 180 | ) 181 | # T H W C -> C T H W. 182 | buffer = buffer.permute(3, 0, 1, 2) 183 | # Perform data augmentation.
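        # Scale (`scl`, relative-area range) and aspect-ratio (`asp`, 3/4 to 4/3) jitter
        # ranges fed to the random resized crop inside spatial_sampling() below.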
184 | scl, asp = ( 185 | [0.08, 1.0], 186 | [0.75, 1.3333], 187 | ) 188 | 189 | buffer = spatial_sampling( 190 | buffer, 191 | spatial_idx=-1, 192 | min_scale=256, 193 | max_scale=320, 194 | crop_size=self.crop_size, 195 | random_horizontal_flip=False if args.data_set == 'SSV2' else True, 196 | inverse_uniform_sampling=False, 197 | aspect_ratio=asp, 198 | scale=scl, 199 | motion_shift=False 200 | ) 201 | 202 | # No random erase for linear probing or prompting. 203 | # if self.rand_erase: 204 | # erase_transform = RandomErasing( 205 | # args.reprob, 206 | # mode=args.remode, 207 | # max_count=args.recount, 208 | # num_splits=args.recount, 209 | # device="cpu", 210 | # ) 211 | # buffer = buffer.permute(1, 0, 2, 3) 212 | # buffer = erase_transform(buffer) 213 | # buffer = buffer.permute(1, 0, 2, 3) 214 | 215 | return buffer 216 | 217 | 218 | def loadvideo_decord(self, sample, sample_rate_scale=1): 219 | """Load video content using Decord""" 220 | fname = sample 221 | 222 | if not (os.path.exists(fname)): 223 | return [] 224 | 225 | # avoid hanging issue 226 | if os.path.getsize(fname) < 1 * 1024: 227 | print('SKIP: ', fname, " - ", os.path.getsize(fname)) 228 | return [] 229 | try: 230 | if self.keep_aspect_ratio: 231 | vr = VideoReader(fname, num_threads=1, ctx=cpu(0)) 232 | else: 233 | vr = VideoReader(fname, width=self.new_width, height=self.new_height, 234 | num_threads=1, ctx=cpu(0)) 235 | except: 236 | print("video cannot be loaded by decord: ", fname) 237 | return [] 238 | 239 | if self.mode == 'test': 240 | all_index = [] 241 | tick = len(vr) / float(self.num_segment) 242 | all_index = list(np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segment)] + 243 | [int(tick * x) for x in range(self.num_segment)])) 244 | while len(all_index) < (self.num_segment * self.test_num_segment): 245 | all_index.append(all_index[-1]) 246 | all_index = list(np.sort(np.array(all_index))) 247 | vr.seek(0) 248 | buffer = vr.get_batch(all_index).asnumpy() 249 | return buffer 250 | 251 | # handle temporal segments 252 | average_duration = len(vr) // self.num_segment 253 | all_index = [] 254 | if average_duration > 0: 255 | all_index += list(np.multiply(list(range(self.num_segment)), average_duration) + np.random.randint(average_duration, 256 | size=self.num_segment)) 257 | elif len(vr) > self.num_segment: 258 | all_index += list(np.sort(np.random.randint(len(vr), size=self.num_segment))) 259 | else: 260 | all_index += list(np.zeros((self.num_segment,))) 261 | all_index = list(np.array(all_index)) 262 | vr.seek(0) 263 | buffer = vr.get_batch(all_index).asnumpy() 264 | return buffer 265 | 266 | def __len__(self): 267 | if self.mode != 'test': 268 | return len(self.dataset_samples) 269 | else: 270 | return len(self.test_dataset) 271 | 272 | 273 | def spatial_sampling( 274 | frames, 275 | spatial_idx=-1, 276 | min_scale=256, 277 | max_scale=320, 278 | crop_size=224, 279 | random_horizontal_flip=True, 280 | inverse_uniform_sampling=False, 281 | aspect_ratio=None, 282 | scale=None, 283 | motion_shift=False, 284 | ): 285 | """ 286 | Perform spatial sampling on the given video frames. If spatial_idx is 287 | -1, perform random scale, random crop, and random flip on the given 288 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling 289 | with the given spatial_idx. 290 | Args: 291 | frames (tensor): frames of images sampled from the video. The 292 | dimension is `num frames` x `height` x `width` x `channel`. 293 | spatial_idx (int): if -1, perform random spatial sampling. 
If 0, 1, 294 | or 2, perform left, center, right crop if width is larger than 295 | height, and perform top, center, bottom crop if height is larger 296 | than width. 297 | min_scale (int): the minimal size of scaling. 298 | max_scale (int): the maximal size of scaling. 299 | crop_size (int): the size of height and width used to crop the 300 | frames. 301 | inverse_uniform_sampling (bool): if True, sample uniformly in 302 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the 303 | scale. If False, take a uniform sample from [min_scale, 304 | max_scale]. 305 | aspect_ratio (list): Aspect ratio range for resizing. 306 | scale (list): Scale range for resizing. 307 | motion_shift (bool): Whether to apply motion shift for resizing. 308 | Returns: 309 | frames (tensor): spatially sampled frames. 310 | """ 311 | assert spatial_idx in [-1, 0, 1, 2] 312 | if spatial_idx == -1: 313 | if aspect_ratio is None and scale is None: 314 | frames, _ = video_transforms.random_short_side_scale_jitter( 315 | images=frames, 316 | min_size=min_scale, 317 | max_size=max_scale, 318 | inverse_uniform_sampling=inverse_uniform_sampling, 319 | ) 320 | frames, _ = video_transforms.random_crop(frames, crop_size) 321 | else: 322 | transform_func = ( 323 | video_transforms.random_resized_crop_with_shift 324 | if motion_shift 325 | else video_transforms.random_resized_crop 326 | ) 327 | frames = transform_func( 328 | images=frames, 329 | target_height=crop_size, 330 | target_width=crop_size, 331 | scale=scale, 332 | ratio=aspect_ratio, 333 | ) 334 | if random_horizontal_flip: 335 | frames, _ = video_transforms.horizontal_flip(0.5, frames) 336 | else: 337 | # The testing is deterministic and no jitter should be performed. 338 | # min_scale, max_scale, and crop_size are expected to be the same. 339 | assert len({min_scale, max_scale, crop_size}) == 1 340 | frames, _ = video_transforms.random_short_side_scale_jitter( 341 | frames, min_scale, max_scale 342 | ) 343 | frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx) 344 | return frames 345 | 346 | 347 | def tensor_normalize(tensor, mean, std): 348 | """ 349 | Normalize a given tensor by subtracting the mean and dividing the std. 350 | Args: 351 | tensor (tensor): tensor to normalize. 352 | mean (tensor or list): mean value to subtract. 353 | std (tensor or list): std to divide. 354 | """ 355 | if tensor.dtype == torch.uint8: 356 | tensor = tensor.float() 357 | tensor = tensor / 255.0 358 | if type(mean) == list: 359 | mean = torch.tensor(mean) 360 | if type(std) == list: 361 | std = torch.tensor(std) 362 | tensor = tensor - mean 363 | tensor = tensor / std 364 | return tensor 365 | -------------------------------------------------------------------------------- /rand_augment.py: -------------------------------------------------------------------------------- 1 | """ 2 | This implementation is based on 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py 4 | published under an Apache License 2.0. 5 | 6 | COMMENT FROM ORIGINAL: 7 | AutoAugment, RandAugment, and AugMix for PyTorch 8 | This code implements the searched ImageNet policies with various tweaks and 9 | improvements and does not include any of the search code.
AA and RA 10 | Implementation adapted from: 11 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 12 | AugMix adapted from: 13 | https://github.com/google-research/augmix 14 | Papers: 15 | AutoAugment: Learning Augmentation Policies from Data 16 | https://arxiv.org/abs/1805.09501 17 | Learning Data Augmentation Strategies for Object Detection 18 | https://arxiv.org/abs/1906.11172 19 | RandAugment: Practical automated data augmentation... 20 | https://arxiv.org/abs/1909.13719 21 | AugMix: A Simple Data Processing Method to Improve Robustness and 22 | Uncertainty https://arxiv.org/abs/1912.02781 23 | 24 | Hacked together by / Copyright 2020 Ross Wightman 25 | """ 26 | 27 | import math 28 | import numpy as np 29 | import random 30 | import re 31 | import PIL 32 | from PIL import Image, ImageEnhance, ImageOps 33 | 34 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) 35 | 36 | _FILL = (128, 128, 128) 37 | 38 | # This signifies the max integer that the controller RNN could predict for the 39 | # augmentation scheme. 40 | _MAX_LEVEL = 10.0 41 | 42 | _HPARAMS_DEFAULT = { 43 | "translate_const": 250, 44 | "img_mean": _FILL, 45 | } 46 | 47 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 48 | 49 | 50 | def _interpolation(kwargs): 51 | interpolation = kwargs.pop("resample", Image.BILINEAR) 52 | if isinstance(interpolation, (list, tuple)): 53 | return random.choice(interpolation) 54 | else: 55 | return interpolation 56 | 57 | 58 | def _check_args_tf(kwargs): 59 | if "fillcolor" in kwargs and _PIL_VER < (5, 0): 60 | kwargs.pop("fillcolor") 61 | kwargs["resample"] = _interpolation(kwargs) 62 | 63 | 64 | def shear_x(img, factor, **kwargs): 65 | _check_args_tf(kwargs) 66 | return img.transform( 67 | img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs 68 | ) 69 | 70 | 71 | def shear_y(img, factor, **kwargs): 72 | _check_args_tf(kwargs) 73 | return img.transform( 74 | img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs 75 | ) 76 | 77 | 78 | def translate_x_rel(img, pct, **kwargs): 79 | pixels = pct * img.size[0] 80 | _check_args_tf(kwargs) 81 | return img.transform( 82 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs 83 | ) 84 | 85 | 86 | def translate_y_rel(img, pct, **kwargs): 87 | pixels = pct * img.size[1] 88 | _check_args_tf(kwargs) 89 | return img.transform( 90 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs 91 | ) 92 | 93 | 94 | def translate_x_abs(img, pixels, **kwargs): 95 | _check_args_tf(kwargs) 96 | return img.transform( 97 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs 98 | ) 99 | 100 | 101 | def translate_y_abs(img, pixels, **kwargs): 102 | _check_args_tf(kwargs) 103 | return img.transform( 104 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs 105 | ) 106 | 107 | 108 | def rotate(img, degrees, **kwargs): 109 | _check_args_tf(kwargs) 110 | if _PIL_VER >= (5, 2): 111 | return img.rotate(degrees, **kwargs) 112 | elif _PIL_VER >= (5, 0): 113 | w, h = img.size 114 | post_trans = (0, 0) 115 | rotn_center = (w / 2.0, h / 2.0) 116 | angle = -math.radians(degrees) 117 | matrix = [ 118 | round(math.cos(angle), 15), 119 | round(math.sin(angle), 15), 120 | 0.0, 121 | round(-math.sin(angle), 15), 122 | round(math.cos(angle), 15), 123 | 0.0, 124 | ] 125 | 126 | def transform(x, y, matrix): 127 | (a, b, c, d, e, f) = matrix 128 | return a * x + b * y + c, d * x + e * y + f 129 | 130 | matrix[2], matrix[5] = transform( 131 | -rotn_center[0] - post_trans[0], 132 | -rotn_center[1] - 
post_trans[1], 133 | matrix, 134 | ) 135 | matrix[2] += rotn_center[0] 136 | matrix[5] += rotn_center[1] 137 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs) 138 | else: 139 | return img.rotate(degrees, resample=kwargs["resample"]) 140 | 141 | 142 | def auto_contrast(img, **__): 143 | return ImageOps.autocontrast(img) 144 | 145 | 146 | def invert(img, **__): 147 | return ImageOps.invert(img) 148 | 149 | 150 | def equalize(img, **__): 151 | return ImageOps.equalize(img) 152 | 153 | 154 | def solarize(img, thresh, **__): 155 | return ImageOps.solarize(img, thresh) 156 | 157 | 158 | def solarize_add(img, add, thresh=128, **__): 159 | lut = [] 160 | for i in range(256): 161 | if i < thresh: 162 | lut.append(min(255, i + add)) 163 | else: 164 | lut.append(i) 165 | if img.mode in ("L", "RGB"): 166 | if img.mode == "RGB" and len(lut) == 256: 167 | lut = lut + lut + lut 168 | return img.point(lut) 169 | else: 170 | return img 171 | 172 | 173 | def posterize(img, bits_to_keep, **__): 174 | if bits_to_keep >= 8: 175 | return img 176 | return ImageOps.posterize(img, bits_to_keep) 177 | 178 | 179 | def contrast(img, factor, **__): 180 | return ImageEnhance.Contrast(img).enhance(factor) 181 | 182 | 183 | def color(img, factor, **__): 184 | return ImageEnhance.Color(img).enhance(factor) 185 | 186 | 187 | def brightness(img, factor, **__): 188 | return ImageEnhance.Brightness(img).enhance(factor) 189 | 190 | 191 | def sharpness(img, factor, **__): 192 | return ImageEnhance.Sharpness(img).enhance(factor) 193 | 194 | 195 | def _randomly_negate(v): 196 | """With 50% prob, negate the value""" 197 | return -v if random.random() > 0.5 else v 198 | 199 | 200 | def _rotate_level_to_arg(level, _hparams): 201 | # range [-30, 30] 202 | level = (level / _MAX_LEVEL) * 30.0 203 | level = _randomly_negate(level) 204 | return (level,) 205 | 206 | 207 | def _enhance_level_to_arg(level, _hparams): 208 | # range [0.1, 1.9] 209 | return ((level / _MAX_LEVEL) * 1.8 + 0.1,) 210 | 211 | 212 | def _enhance_increasing_level_to_arg(level, _hparams): 213 | # the 'no change' level is 1.0, moving away from that towards 0. 
or 2.0 increases the enhancement blend 214 | # range [0.1, 1.9] 215 | level = (level / _MAX_LEVEL) * 0.9 216 | level = 1.0 + _randomly_negate(level) 217 | return (level,) 218 | 219 | 220 | def _shear_level_to_arg(level, _hparams): 221 | # range [-0.3, 0.3] 222 | level = (level / _MAX_LEVEL) * 0.3 223 | level = _randomly_negate(level) 224 | return (level,) 225 | 226 | 227 | def _translate_abs_level_to_arg(level, hparams): 228 | translate_const = hparams["translate_const"] 229 | level = (level / _MAX_LEVEL) * float(translate_const) 230 | level = _randomly_negate(level) 231 | return (level,) 232 | 233 | 234 | def _translate_rel_level_to_arg(level, hparams): 235 | # default range [-0.45, 0.45] 236 | translate_pct = hparams.get("translate_pct", 0.45) 237 | level = (level / _MAX_LEVEL) * translate_pct 238 | level = _randomly_negate(level) 239 | return (level,) 240 | 241 | 242 | def _posterize_level_to_arg(level, _hparams): 243 | # As per Tensorflow TPU EfficientNet impl 244 | # range [0, 4], 'keep 0 up to 4 MSB of original image' 245 | # intensity/severity of augmentation decreases with level 246 | return (int((level / _MAX_LEVEL) * 4),) 247 | 248 | 249 | def _posterize_increasing_level_to_arg(level, hparams): 250 | # As per Tensorflow models research and UDA impl 251 | # range [4, 0], 'keep 4 down to 0 MSB of original image', 252 | # intensity/severity of augmentation increases with level 253 | return (4 - _posterize_level_to_arg(level, hparams)[0],) 254 | 255 | 256 | def _posterize_original_level_to_arg(level, _hparams): 257 | # As per original AutoAugment paper description 258 | # range [4, 8], 'keep 4 up to 8 MSB of image' 259 | # intensity/severity of augmentation decreases with level 260 | return (int((level / _MAX_LEVEL) * 4) + 4,) 261 | 262 | 263 | def _solarize_level_to_arg(level, _hparams): 264 | # range [0, 256] 265 | # intensity/severity of augmentation decreases with level 266 | return (int((level / _MAX_LEVEL) * 256),) 267 | 268 | 269 | def _solarize_increasing_level_to_arg(level, _hparams): 270 | # range [0, 256] 271 | # intensity/severity of augmentation increases with level 272 | return (256 - _solarize_level_to_arg(level, _hparams)[0],) 273 | 274 | 275 | def _solarize_add_level_to_arg(level, _hparams): 276 | # range [0, 110] 277 | return (int((level / _MAX_LEVEL) * 110),) 278 | 279 | 280 | LEVEL_TO_ARG = { 281 | "AutoContrast": None, 282 | "Equalize": None, 283 | "Invert": None, 284 | "Rotate": _rotate_level_to_arg, 285 | # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers 286 | "Posterize": _posterize_level_to_arg, 287 | "PosterizeIncreasing": _posterize_increasing_level_to_arg, 288 | "PosterizeOriginal": _posterize_original_level_to_arg, 289 | "Solarize": _solarize_level_to_arg, 290 | "SolarizeIncreasing": _solarize_increasing_level_to_arg, 291 | "SolarizeAdd": _solarize_add_level_to_arg, 292 | "Color": _enhance_level_to_arg, 293 | "ColorIncreasing": _enhance_increasing_level_to_arg, 294 | "Contrast": _enhance_level_to_arg, 295 | "ContrastIncreasing": _enhance_increasing_level_to_arg, 296 | "Brightness": _enhance_level_to_arg, 297 | "BrightnessIncreasing": _enhance_increasing_level_to_arg, 298 | "Sharpness": _enhance_level_to_arg, 299 | "SharpnessIncreasing": _enhance_increasing_level_to_arg, 300 | "ShearX": _shear_level_to_arg, 301 | "ShearY": _shear_level_to_arg, 302 | "TranslateX": _translate_abs_level_to_arg, 303 | "TranslateY": _translate_abs_level_to_arg, 304 | "TranslateXRel": _translate_rel_level_to_arg, 305 | 
"TranslateYRel": _translate_rel_level_to_arg, 306 | } 307 | 308 | 309 | NAME_TO_OP = { 310 | "AutoContrast": auto_contrast, 311 | "Equalize": equalize, 312 | "Invert": invert, 313 | "Rotate": rotate, 314 | "Posterize": posterize, 315 | "PosterizeIncreasing": posterize, 316 | "PosterizeOriginal": posterize, 317 | "Solarize": solarize, 318 | "SolarizeIncreasing": solarize, 319 | "SolarizeAdd": solarize_add, 320 | "Color": color, 321 | "ColorIncreasing": color, 322 | "Contrast": contrast, 323 | "ContrastIncreasing": contrast, 324 | "Brightness": brightness, 325 | "BrightnessIncreasing": brightness, 326 | "Sharpness": sharpness, 327 | "SharpnessIncreasing": sharpness, 328 | "ShearX": shear_x, 329 | "ShearY": shear_y, 330 | "TranslateX": translate_x_abs, 331 | "TranslateY": translate_y_abs, 332 | "TranslateXRel": translate_x_rel, 333 | "TranslateYRel": translate_y_rel, 334 | } 335 | 336 | 337 | class AugmentOp: 338 | """ 339 | Apply for video. 340 | """ 341 | 342 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None): 343 | hparams = hparams or _HPARAMS_DEFAULT 344 | self.aug_fn = NAME_TO_OP[name] 345 | self.level_fn = LEVEL_TO_ARG[name] 346 | self.prob = prob 347 | self.magnitude = magnitude 348 | self.hparams = hparams.copy() 349 | self.kwargs = { 350 | "fillcolor": hparams["img_mean"] 351 | if "img_mean" in hparams 352 | else _FILL, 353 | "resample": hparams["interpolation"] 354 | if "interpolation" in hparams 355 | else _RANDOM_INTERPOLATION, 356 | } 357 | 358 | # If magnitude_std is > 0, we introduce some randomness 359 | # in the usually fixed policy and sample magnitude from a normal distribution 360 | # with mean `magnitude` and std-dev of `magnitude_std`. 361 | # NOTE This is my own hack, being tested, not in papers or reference impls. 362 | self.magnitude_std = self.hparams.get("magnitude_std", 0) 363 | 364 | def __call__(self, img_list): 365 | if self.prob < 1.0 and random.random() > self.prob: 366 | return img_list 367 | magnitude = self.magnitude 368 | if self.magnitude_std and self.magnitude_std > 0: 369 | magnitude = random.gauss(magnitude, self.magnitude_std) 370 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range 371 | level_args = ( 372 | self.level_fn(magnitude, self.hparams) 373 | if self.level_fn is not None 374 | else () 375 | ) 376 | 377 | if isinstance(img_list, list): 378 | return [ 379 | self.aug_fn(img, *level_args, **self.kwargs) for img in img_list 380 | ] 381 | else: 382 | return self.aug_fn(img_list, *level_args, **self.kwargs) 383 | 384 | 385 | _RAND_TRANSFORMS = [ 386 | "AutoContrast", 387 | "Equalize", 388 | "Invert", 389 | "Rotate", 390 | "Posterize", 391 | "Solarize", 392 | "SolarizeAdd", 393 | "Color", 394 | "Contrast", 395 | "Brightness", 396 | "Sharpness", 397 | "ShearX", 398 | "ShearY", 399 | "TranslateXRel", 400 | "TranslateYRel", 401 | ] 402 | 403 | 404 | _RAND_INCREASING_TRANSFORMS = [ 405 | "AutoContrast", 406 | "Equalize", 407 | "Invert", 408 | "Rotate", 409 | "PosterizeIncreasing", 410 | "SolarizeIncreasing", 411 | "SolarizeAdd", 412 | "ColorIncreasing", 413 | "ContrastIncreasing", 414 | "BrightnessIncreasing", 415 | "SharpnessIncreasing", 416 | "ShearX", 417 | "ShearY", 418 | "TranslateXRel", 419 | "TranslateYRel", 420 | ] 421 | 422 | 423 | # These experimental weights are based loosely on the relative improvements mentioned in paper. 424 | # They may not result in increased performance, but could likely be tuned to so. 
425 | _RAND_CHOICE_WEIGHTS_0 = { 426 | "Rotate": 0.3, 427 | "ShearX": 0.2, 428 | "ShearY": 0.2, 429 | "TranslateXRel": 0.1, 430 | "TranslateYRel": 0.1, 431 | "Color": 0.025, 432 | "Sharpness": 0.025, 433 | "AutoContrast": 0.025, 434 | "Solarize": 0.005, 435 | "SolarizeAdd": 0.005, 436 | "Contrast": 0.005, 437 | "Brightness": 0.005, 438 | "Equalize": 0.005, 439 | "Posterize": 0, 440 | "Invert": 0, 441 | } 442 | 443 | 444 | def _select_rand_weights(weight_idx=0, transforms=None): 445 | transforms = transforms or _RAND_TRANSFORMS 446 | assert weight_idx == 0 # only one set of weights currently 447 | rand_weights = _RAND_CHOICE_WEIGHTS_0 448 | probs = [rand_weights[k] for k in transforms] 449 | probs /= np.sum(probs) 450 | return probs 451 | 452 | 453 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None): 454 | hparams = hparams or _HPARAMS_DEFAULT 455 | transforms = transforms or _RAND_TRANSFORMS 456 | return [ 457 | AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) 458 | for name in transforms 459 | ] 460 | 461 | 462 | class RandAugment: 463 | def __init__(self, ops, num_layers=2, choice_weights=None): 464 | self.ops = ops 465 | self.num_layers = num_layers 466 | self.choice_weights = choice_weights 467 | 468 | def __call__(self, img): 469 | # no replacement when using weighted choice 470 | ops = np.random.choice( 471 | self.ops, 472 | self.num_layers, 473 | replace=self.choice_weights is None, 474 | p=self.choice_weights, 475 | ) 476 | for op in ops: 477 | img = op(img) 478 | return img 479 | 480 | 481 | def rand_augment_transform(config_str, hparams): 482 | """ 483 | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 484 | 485 | Create a RandAugment transform 486 | :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by 487 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). 
The remaining 488 | sections, which are not order specific, determine 489 | 'm' - integer magnitude of rand augment 490 | 'n' - integer num layers (number of transform ops selected per image) 491 | 'w' - integer probability weight index (index of a set of weights to influence choice of op) 492 | 'mstd' - float std deviation of magnitude noise applied 493 | 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) 494 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 495 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 496 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme 497 | :return: A PyTorch compatible Transform 498 | """ 499 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) 500 | num_layers = 2 # default to 2 ops per image 501 | weight_idx = None # default to no probability weights for op choice 502 | transforms = _RAND_TRANSFORMS 503 | config = config_str.split("-") 504 | assert config[0] == "rand" 505 | config = config[1:] 506 | for c in config: 507 | cs = re.split(r"(\d.*)", c) 508 | if len(cs) < 2: 509 | continue 510 | key, val = cs[:2] 511 | if key == "mstd": 512 | # noise param injected via hparams for now 513 | hparams.setdefault("magnitude_std", float(val)) 514 | elif key == "inc": 515 | if bool(val): 516 | transforms = _RAND_INCREASING_TRANSFORMS 517 | elif key == "m": 518 | magnitude = int(val) 519 | elif key == "n": 520 | num_layers = int(val) 521 | elif key == "w": 522 | weight_idx = int(val) 523 | else: 524 | raise NotImplementedError 525 | ra_ops = rand_augment_ops( 526 | magnitude=magnitude, hparams=hparams, transforms=transforms 527 | ) 528 | choice_weights = ( 529 | None if weight_idx is None else _select_rand_weights(weight_idx) 530 | ) 531 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) 532 | --------------------------------------------------------------------------------
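Usage sketch: in training, ssv2.py builds its clip augmentation through video_transforms.create_random_augment(auto_augment=args.aa, ...), which is the usual entry point into this module. The standalone sketch below drives rand_augment_transform directly; the config string, hparams values, and frame count/size are placeholder assumptions for illustration, not the repository's training defaults.

from PIL import Image

from rand_augment import rand_augment_transform  # assumes rand_augment.py is importable from the working directory

# 'rand-m7-n4-mstd0.5-inc1': magnitude 7, 4 ops per clip, magnitude noise std 0.5,
# and the "increasing severity" transform set (_RAND_INCREASING_TRANSFORMS).
ra = rand_augment_transform(
    config_str="rand-m7-n4-mstd0.5-inc1",
    hparams={"translate_const": 100, "img_mean": (124, 116, 104)},  # placeholder hparams
)

# AugmentOp applies the sampled op and magnitude to every frame in the list, so the
# augmentation stays temporally consistent across the clip.
clip = [Image.new("RGB", (224, 224)) for _ in range(16)]  # stand-in for decoded frames
clip = ra(clip)

Because RandAugment.__call__ draws its ops once per call and AugmentOp maps each op over the whole frame list, all frames of a clip receive identical transforms, which is the property the video pipeline relies on.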