├── assets ├── EVEREST_plot.PNG ├── EVEREST_concept.PNG ├── VideoMS_concept.PNG └── icml2024_main_figure.pdf ├── requirements.txt ├── scripts ├── hmdb51 │ ├── pretrain.sh │ └── finetune.sh └── ucf101 │ ├── pretrain.sh │ └── finetune.sh ├── masking_generator.py ├── .gitignore ├── functional.py ├── README.md ├── volume_transforms.py ├── engine_for_pretraining.py ├── datasets.py ├── optim_factory.py ├── random_erasing.py ├── transforms.py ├── engine_for_finetuning.py ├── run_ms_pretraining.py ├── mixup.py ├── modeling_finetune.py ├── modeling_pretrain.py ├── rand_augment.py ├── utils.py ├── ucf.py └── run_class_finetuning.py /assets/EVEREST_plot.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunilhoho/EVEREST/HEAD/assets/EVEREST_plot.PNG -------------------------------------------------------------------------------- /assets/EVEREST_concept.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunilhoho/EVEREST/HEAD/assets/EVEREST_concept.PNG -------------------------------------------------------------------------------- /assets/VideoMS_concept.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunilhoho/EVEREST/HEAD/assets/VideoMS_concept.PNG -------------------------------------------------------------------------------- /assets/icml2024_main_figure.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunilhoho/EVEREST/HEAD/assets/icml2024_main_figure.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.4.12 2 | deepspeed==0.7.3 3 | tensorboardX 4 | decord 5 | einops 6 | opencv-python 7 | scipy 8 | pandas -------------------------------------------------------------------------------- /scripts/hmdb51/pretrain.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR='output/hmdb51/pt' 2 | DATA_PATH='path_to_data/HMDB51/train.csv' 3 | 4 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 5 | --master_port 12320 run_ms_pretraining.py \ 6 | --data_path ${DATA_PATH} \ 7 | --mask_type motion-centric \ 8 | --motion_centric_masking_ratio 0.7 \ 9 | --mask_ratio 0.9 \ 10 | --model pretrain_videoms_base_patch16_224 \ 11 | --decoder_depth 4 \ 12 | --lr 1e-3 \ 13 | --batch_size 24 \ 14 | --num_frames 16 \ 15 | --sampling_rate 2 \ 16 | --opt adamw \ 17 | --opt_betas 0.9 0.95 \ 18 | --warmup_epochs 40 \ 19 | --epochs 4800 \ 20 | --save_ckpt_freq 100 \ 21 | --log_dir ${OUTPUT_DIR} \ 22 | --output_dir ${OUTPUT_DIR} 23 | -------------------------------------------------------------------------------- /scripts/ucf101/pretrain.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR='output/ucf101/pt' 2 | DATA_PATH='path_to_data/UCF101/train.csv' 3 | 4 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 5 | --master_port 12320 run_ms_pretraining.py \ 6 | --data_path ${DATA_PATH} \ 7 | --mask_type motion-centric \ 8 | --motion_centric_masking_ratio 0.7 \ 9 | --mask_ratio 0.9 \ 10 | --model pretrain_videoms_base_patch16_224 \ 11 | --decoder_depth 4 \ 12 | --lr 1e-3 \ 13 | --batch_size 24 \ 14 | --num_frames 16 \ 15 | --sampling_rate 4 \ 16 | --opt adamw \ 17 | 
--opt_betas 0.9 0.95 \ 18 | --warmup_epochs 40 \ 19 | --epochs 3200 \ 20 | --save_ckpt_freq 100 \ 21 | --log_dir ${OUTPUT_DIR} \ 22 | --output_dir ${OUTPUT_DIR} 23 | -------------------------------------------------------------------------------- /scripts/hmdb51/finetune.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR='output/hmdb51/ft' 2 | DATA_PATH='path_to_data/HMDB51/' # path to HMDB51 annotation file (train.csv/val.csv/test.csv) 3 | MODEL_PATH='output/hmdb51/pt/checkpoint-4799.pth' 4 | 5 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 6 | --master_port 12320 run_class_finetuning.py \ 7 | --model vit_base_patch16_224 \ 8 | --data_path ${DATA_PATH} \ 9 | --finetune ${MODEL_PATH} \ 10 | --log_dir ${OUTPUT_DIR} \ 11 | --output_dir ${OUTPUT_DIR} \ 12 | --data_set HMDB51 \ 13 | --nb_classes 51 \ 14 | --batch_size 16 \ 15 | --input_size 224 \ 16 | --short_side_size 224 \ 17 | --save_ckpt_freq 20 \ 18 | --num_frames 16 \ 19 | --sampling_rate 2 \ 20 | --num_sample 1 \ 21 | --opt adamw \ 22 | --lr 1e-3 \ 23 | --opt_betas 0.9 0.999 \ 24 | --weight_decay 0.05 \ 25 | --epochs 50 \ 26 | --test_num_segment 10 \ 27 | --test_num_crop 3 \ 28 | --use_checkpoint \ 29 | --dist_eval \ 30 | --enable_deepspeed \ 31 | --mcm \ 32 | --mcm_ratio 0.4 \ 33 | --mixup 0.0 \ 34 | --cutmix 0.0 35 | -------------------------------------------------------------------------------- /scripts/ucf101/finetune.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR='output/ucf101/ft' 2 | DATA_PATH='path_to_data/UCF101/' # path to UCF101 annotation file (train.csv/val.csv/test.csv) 3 | MODEL_PATH='output/ucf101/pt/checkpoint-3199.pth' 4 | 5 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 6 | --master_port 12320 run_class_finetuning.py \ 7 | --model vit_base_patch16_224 \ 8 | --data_path ${DATA_PATH} \ 9 | --finetune ${MODEL_PATH} \ 10 | --log_dir ${OUTPUT_DIR} \ 11 | --output_dir ${OUTPUT_DIR} \ 12 | --data_set UCF101 \ 13 | --nb_classes 101 \ 14 | --batch_size 16 \ 15 | --input_size 224 \ 16 | --short_side_size 224 \ 17 | --save_ckpt_freq 20 \ 18 | --num_frames 16 \ 19 | --sampling_rate 4 \ 20 | --num_sample 1 \ 21 | --opt adamw \ 22 | --lr 1e-3 \ 23 | --opt_betas 0.9 0.999 \ 24 | --weight_decay 0.05 \ 25 | --epochs 100 \ 26 | --test_num_segment 5 \ 27 | --test_num_crop 3 \ 28 | --use_checkpoint \ 29 | --dist_eval \ 30 | --enable_deepspeed \ 31 | --mcm \ 32 | --mcm_ratio 0.4 \ 33 | --mixup 0.0 \ 34 | --cutmix 0.0 35 | -------------------------------------------------------------------------------- /masking_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class TubeMaskingGenerator: 4 | def __init__(self, input_size, mask_ratio): 5 | self.frames, self.height, self.width = input_size 6 | self.num_patches_per_frame = self.height * self.width 7 | self.total_patches = self.frames * self.num_patches_per_frame 8 | self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame) 9 | self.total_masks = self.frames * self.num_masks_per_frame 10 | 11 | def __repr__(self): 12 | repr_str = "Maks: total patches {}, mask patches {}".format( 13 | self.total_patches, self.total_masks 14 | ) 15 | return repr_str 16 | 17 | def __call__(self): 18 | mask_per_frame = np.hstack([ 19 | np.zeros(self.num_patches_per_frame - self.num_masks_per_frame), 20 | np.ones(self.num_masks_per_frame), 21 | ]) 22 | 
np.random.shuffle(mask_per_frame) 23 | mask = np.tile(mask_per_frame, (self.frames,1)).flatten() 24 | return mask 25 | 26 | class RandomMaskingGenerator: 27 | def __init__(self, input_size, mask_ratio): 28 | self.frames, self.height, self.width = input_size 29 | self.num_patches_per_frame = self.height * self.width 30 | self.total_patches = self.frames * self.num_patches_per_frame 31 | self.total_masks = int(mask_ratio * self.total_patches) 32 | 33 | def __repr__(self): 34 | repr_str = "Maks: total patches {}, mask patches {}".format( 35 | self.total_patches, self.total_masks 36 | ) 37 | return repr_str 38 | 39 | def __call__(self): 40 | mask_per_input = np.hstack([ 41 | np.zeros(self.total_patches - self.total_masks), 42 | np.ones(self.total_masks), 43 | ]) 44 | np.random.shuffle(mask_per_input) 45 | return mask_per_input -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | output/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | */__pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ -------------------------------------------------------------------------------- /functional.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import cv2 3 | import numpy as np 4 | import PIL 5 | import torch 6 | 7 | 8 | def _is_tensor_clip(clip): 9 | return torch.is_tensor(clip) and clip.ndimension() == 4 10 | 11 | 12 | def crop_clip(clip, min_h, min_w, h, w): 13 | if isinstance(clip[0], np.ndarray): 14 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] 15 | 16 | elif isinstance(clip[0], PIL.Image.Image): 17 | cropped = [ 18 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip 19 | ] 20 | else: 21 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 22 | 'but got list of {0}'.format(type(clip[0]))) 23 | return cropped 24 | 25 | 26 | def resize_clip(clip, size, interpolation='bilinear'): 27 | if isinstance(clip[0], np.ndarray): 28 | if isinstance(size, numbers.Number): 29 | im_h, im_w, im_c = clip[0].shape 30 | # Min spatial dim already matches minimal size 31 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 32 | and im_h == size): 33 | return clip 34 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 35 | size = (new_w, new_h) 36 | else: 37 | size = size[0], size[1] 38 | if interpolation == 'bilinear': 39 | np_inter = cv2.INTER_LINEAR 40 | else: 41 | np_inter = cv2.INTER_NEAREST 42 | scaled = [ 43 | cv2.resize(img, size, interpolation=np_inter) for img in clip 44 | ] 45 | elif isinstance(clip[0], PIL.Image.Image): 46 | if isinstance(size, numbers.Number): 47 | im_w, im_h = clip[0].size 48 | # Min spatial dim already matches minimal size 49 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 50 | and im_h == size): 51 | return clip 52 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 53 | size = (new_w, new_h) 54 | else: 55 | size = size[1], size[0] 56 | if interpolation == 'bilinear': 57 | pil_inter = PIL.Image.BILINEAR 58 | else: 59 | pil_inter = PIL.Image.NEAREST 60 | scaled = [img.resize(size, pil_inter) for img in clip] 61 | else: 62 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 63 | 'but got list of {0}'.format(type(clip[0]))) 64 | return scaled 65 | 66 | 67 | def get_resize_sizes(im_h, im_w, size): 68 | if im_w < im_h: 69 | ow = size 70 | oh = int(size * im_h / im_w) 71 | else: 72 | oh = size 73 | ow = int(size * im_w / im_h) 74 | return oh, ow 75 | 76 | 77 | def normalize(clip, mean, std, inplace=False): 78 | if not _is_tensor_clip(clip): 79 | raise TypeError('tensor is not a torch clip.') 80 | 81 | if not inplace: 82 | clip = clip.clone() 83 | 84 | dtype = clip.dtype 85 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) 86 | std = torch.as_tensor(std, dtype=dtype, device=clip.device) 87 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 88 | 89 | return clip 90 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # EVEREST: Efficient Masked Video Autoencoder by Removing Redundant Spatiotemporal Tokens [ICML 2024] 2 | 3 | This repository is an official PyTorch implementation of [EVEREST: Efficient Masked Video Autoencoder by Removing Redundant Spatiotemporal Tokens](https://arxiv.org/abs/2211.10636). 4 | 5 | **The new version of EVEREST will be updated soon!!!** 🚨 6 | 7 |

8 | 9 |

10 | 11 | ## Abstract 12 | 13 | Masked Video Autoencoder (MVA) approaches have demonstrated their potential by significantly outperforming previous video representation learning methods. However, they waste an excessive amount of computation and memory in predicting uninformative tokens/frames due to random masking strategies, requiring excessive computing resources for training (e.g., over 16 nodes with 128 NVIDIA A100 GPUs). To resolve this issue, we exploit the unequal information density among the patches in videos and propose EVEREST, a surprisingly efficient MVA approach for video representation learning that finds tokens containing rich motion features and discards uninformative ones during both pre-training and fine-tuning. We further present an information-intensive frame selection strategy that allows the model to focus on informative and causal frames with minimal redundancy. Our method significantly reduces the computation and memory requirements of MVA, enabling pre-training and fine-tuning on a single machine with 8 GPUs while achieving comparable performance to computation- and memory-heavy baselines on multiple benchmarks and the uncurated Ego4D dataset. We hope that our work contributes to reducing the barrier to further research on video understanding. 14 | 15 | ## Results 16 | 17 |

18 | 19 |

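For intuition, the motion-centric masking described in the abstract can be sketched roughly as follows. This is a simplified illustration, not the actual implementation in `modeling_pretrain.py` / `run_ms_pretraining.py`; the function name, `keep_ratio`, and the exact scoring rule are all illustrative. The idea is to score each spatiotemporal token by how much its pixels change inside its 2-frame tube and keep only the most dynamic tokens visible:

```python
import torch

def motion_centric_mask(videos, patch_size=16, tube=2, keep_ratio=0.1):
    """Toy sketch only -- not the repo's actual masking code.

    videos: (B, C, T, H, W) float tensor. Returns a bool mask of shape
    (B, T//tube * H//patch_size * W//patch_size) where True means "masked out".
    """
    B, C, T, H, W = videos.shape
    # Split the clip into (tube x patch_size x patch_size) token cubes.
    cubes = videos.reshape(B, C, T // tube, tube,
                           H // patch_size, patch_size,
                           W // patch_size, patch_size)
    # Motion score: mean absolute frame difference inside each cube.
    motion = (cubes[:, :, :, 1:] - cubes[:, :, :, :-1]).abs()
    score = motion.mean(dim=(1, 3, 5, 7)).flatten(1)   # (B, num_tokens)
    # Keep the top-k most dynamic tokens visible, mask out everything else.
    num_keep = max(1, int(keep_ratio * score.shape[1]))
    keep_idx = score.topk(num_keep, dim=1).indices
    mask = torch.ones_like(score, dtype=torch.bool)
    mask.scatter_(1, keep_idx, False)
    return mask
```

In the released scripts, the corresponding knobs are `--mask_ratio 0.9` and `--motion_centric_masking_ratio 0.7` (see `scripts/ucf101/pretrain.sh` and `scripts/hmdb51/pretrain.sh`).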
20 | 21 | ## Prerequisites 22 | EVEREST is built in **Python 3.7.12**, **torch 1.8.0** and **torchvision 0.9.0**. Please use the following command to install the requirements: 23 | ``` 24 | $ pip install -r requirements.txt 25 | ``` 26 | 27 | ## Run 28 | 1. __UCF101__ experiment 29 | ``` 30 | $ bash scripts/ucf101/pretrain.sh 31 | $ bash scripts/ucf101/finetune.sh 32 | ``` 33 | 34 | 2. __HMDB51__ experiment 35 | 36 | ``` 37 | $ bash scripts/hmdb51/pretrain.sh 38 | $ bash scripts/hmdb51/finetune.sh 39 | ``` 40 | 41 | 3. __K400, SSv2, OSCC__ experiment will be released soon. 42 | 43 | ## Dataset 44 | 1. Download [UCF101](https://www.crcv.ucf.edu/data/UCF101.php) and [HMDB51](https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/) datasets from the official websites. 45 | 2. Make annotation files in `*.csv` format like this: 46 | ``` 47 | path_to_video/video_0.avi 0 48 | path_to_video/video_1.avi 0 49 | ... 50 | path_to_video/video_N.avi 101 51 | ``` 52 | 53 | ## Training Logs and Checkpoints 54 | ### UCF101 55 | 56 | | Backbone | \#Frame | Pre-train (3,200 epochs) | Fine-tune (100 epochs) | Top-1 | Top-5 | 57 | | :------: | :-----: | :----------------------------------------------------------: | :----------------------------------------------------------: | :---: | :---: | 58 | | ViT-B | 16x5x3 | [log](https://drive.google.com/file/d/1dupg3ultdh1qsijUAYSZm8-hW2SAspLT/view?usp=share_link) / [checkpoint](https://drive.google.com/file/d/1liGNGprKdfiOCArK-WMqIcfeOJ-AZKzr/view?usp=share_link) | [log](https://drive.google.com/file/d/1EMlHBPqTC1_QURiCiaOdwPeoXdL67Gql/view?usp=share_link) / [checkpoint](https://drive.google.com/file/d/1iGFUxYpzjb7zaajB0O0j1MzS6PzyzrQF/view?usp=share_link) | 93.7 | 98.9 | 59 | 60 | ## Contact 61 | Sunil Hwang: sunilhoho@kaist.ac.kr 62 | Jaehong Yoon: jaehong.yoon@kaist.ac.kr 63 | 64 | ## Acknowledgment 65 | The code is built upon [VidoeMAE](https://github.com/MCG-NJU/VideoMAE). 66 | 67 | ## Citations 68 | ``` 69 | @inproceedings{hwang2024everest, 70 | title={EVEREST: Efficient Masked Video Autoencoder by Removing Redundant Spatiotemporal Tokens}, 71 | author={Hwang, Sunil and Yoon, Jaehong and Lee, Youngwan and Hwang, Sung Ju}, 72 | booktitle={International Conference on Machine Learning}, 73 | year={2024}, 74 | } 75 | ``` -------------------------------------------------------------------------------- /volume_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | 5 | 6 | def convert_img(img): 7 | """Converts (H, W, C) numpy.ndarray to (C, W, H) format 8 | """ 9 | if len(img.shape) == 3: 10 | img = img.transpose(2, 0, 1) 11 | if len(img.shape) == 2: 12 | img = np.expand_dims(img, 0) 13 | return img 14 | 15 | 16 | class ClipToTensor(object): 17 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 18 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 19 | """ 20 | 21 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 22 | self.channel_nb = channel_nb 23 | self.div_255 = div_255 24 | self.numpy = numpy 25 | 26 | def __call__(self, clip): 27 | """ 28 | Args: clip (list of numpy.ndarray): clip (list of images) 29 | to be converted to tensor. 
30 | """ 31 | # Retrieve shape 32 | if isinstance(clip[0], np.ndarray): 33 | h, w, ch = clip[0].shape 34 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( 35 | ch) 36 | elif isinstance(clip[0], Image.Image): 37 | w, h = clip[0].size 38 | else: 39 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 40 | but got list of {0}'.format(type(clip[0]))) 41 | 42 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 43 | 44 | # Convert 45 | for img_idx, img in enumerate(clip): 46 | if isinstance(img, np.ndarray): 47 | pass 48 | elif isinstance(img, Image.Image): 49 | img = np.array(img, copy=False) 50 | else: 51 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 52 | but got list of {0}'.format(type(clip[0]))) 53 | img = convert_img(img) 54 | np_clip[:, img_idx, :, :] = img 55 | if self.numpy: 56 | if self.div_255: 57 | np_clip = np_clip / 255.0 58 | return np_clip 59 | 60 | else: 61 | tensor_clip = torch.from_numpy(np_clip) 62 | 63 | if not isinstance(tensor_clip, torch.FloatTensor): 64 | tensor_clip = tensor_clip.float() 65 | if self.div_255: 66 | tensor_clip = torch.div(tensor_clip, 255) 67 | return tensor_clip 68 | 69 | 70 | # Note this norms data to -1/1 71 | class ClipToTensor_K(object): 72 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 73 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 74 | """ 75 | 76 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 77 | self.channel_nb = channel_nb 78 | self.div_255 = div_255 79 | self.numpy = numpy 80 | 81 | def __call__(self, clip): 82 | """ 83 | Args: clip (list of numpy.ndarray): clip (list of images) 84 | to be converted to tensor. 85 | """ 86 | # Retrieve shape 87 | if isinstance(clip[0], np.ndarray): 88 | h, w, ch = clip[0].shape 89 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( 90 | ch) 91 | elif isinstance(clip[0], Image.Image): 92 | w, h = clip[0].size 93 | else: 94 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 95 | but got list of {0}'.format(type(clip[0]))) 96 | 97 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 98 | 99 | # Convert 100 | for img_idx, img in enumerate(clip): 101 | if isinstance(img, np.ndarray): 102 | pass 103 | elif isinstance(img, Image.Image): 104 | img = np.array(img, copy=False) 105 | else: 106 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 107 | but got list of {0}'.format(type(clip[0]))) 108 | img = convert_img(img) 109 | np_clip[:, img_idx, :, :] = img 110 | if self.numpy: 111 | if self.div_255: 112 | np_clip = (np_clip - 127.5) / 127.5 113 | return np_clip 114 | 115 | else: 116 | tensor_clip = torch.from_numpy(np_clip) 117 | 118 | if not isinstance(tensor_clip, torch.FloatTensor): 119 | tensor_clip = tensor_clip.float() 120 | if self.div_255: 121 | tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) 122 | return tensor_clip 123 | 124 | 125 | class ToTensor(object): 126 | """Converts numpy array to tensor 127 | """ 128 | 129 | def __call__(self, array): 130 | tensor = torch.from_numpy(array) 131 | return tensor 132 | -------------------------------------------------------------------------------- /engine_for_pretraining.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | from typing import Iterable 4 | import torch 5 | import torch.nn as nn 6 | import utils 7 | from einops import rearrange 8 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 9 | 10 | 
def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, 11 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, patch_size: int = 16, 12 | normlize_target: bool = True, log_writer=None, lr_scheduler=None, start_steps=None, 13 | lr_schedule_values=None, wd_schedule_values=None, mask_type='motion-centric'): 14 | model.train() 15 | metric_logger = utils.MetricLogger(delimiter=" ") 16 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 17 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 18 | header = 'Epoch: [{}]'.format(epoch) 19 | print_freq = 10 20 | 21 | loss_func = nn.MSELoss() 22 | 23 | for step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 24 | # assign learning rate & weight decay for each step 25 | it = start_steps + step # global training iteration 26 | if lr_schedule_values is not None or wd_schedule_values is not None: 27 | for i, param_group in enumerate(optimizer.param_groups): 28 | if lr_schedule_values is not None: 29 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 30 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 31 | param_group["weight_decay"] = wd_schedule_values[it] 32 | 33 | videos, bool_masked_pos = batch 34 | videos = videos.to(device, non_blocking=True) 35 | if mask_type != 'motion-centric': 36 | bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool) 37 | 38 | with torch.no_grad(): 39 | # calculate the predict label 40 | mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None] 41 | std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None] 42 | unnorm_videos = videos * std + mean # in [0, 1] 43 | 44 | if normlize_target: 45 | videos_squeeze = rearrange(unnorm_videos, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c', p0=2, p1=patch_size, p2=patch_size) 46 | videos_norm = (videos_squeeze - videos_squeeze.mean(dim=-2, keepdim=True) 47 | ) / (videos_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6) 48 | # we find that the mean is about 0.48 and standard deviation is about 0.08. 
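                # Each token corresponds to one (2, patch_size, patch_size) pixel cube;
                # after the flatten below, videos_patch has shape
                # (B, num_tokens, 2 * patch_size * patch_size * C).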
49 | videos_patch = rearrange(videos_norm, 'b n p c -> b n (p c)') 50 | else: 51 | videos_patch = rearrange(unnorm_videos, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2 c)', p0=2, p1=patch_size, p2=patch_size) 52 | 53 | B, _, C = videos_patch.shape 54 | if mask_type != 'motion-centric': 55 | labels = videos_patch[bool_masked_pos].reshape(B, -1, C) 56 | 57 | with torch.cuda.amp.autocast(): 58 | outputs, (_, mc_target_mask) = model(videos, bool_masked_pos) 59 | if mask_type == 'motion-centric': 60 | labels = videos_patch[~mc_target_mask].reshape(B, -1, C) 61 | loss = loss_func(input=outputs, target=labels) 62 | 63 | loss_value = loss.item() 64 | 65 | if not math.isfinite(loss_value): 66 | print("Loss is {}, stopping training".format(loss_value)) 67 | sys.exit(1) 68 | 69 | optimizer.zero_grad() 70 | # this attribute is added by timm on one optimizer (adahessian) 71 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 72 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 73 | parameters=model.parameters(), create_graph=is_second_order) 74 | loss_scale_value = loss_scaler.state_dict()["scale"] 75 | 76 | torch.cuda.synchronize() 77 | 78 | metric_logger.update(loss=loss_value) 79 | metric_logger.update(loss_scale=loss_scale_value) 80 | min_lr = 10. 81 | max_lr = 0. 82 | for group in optimizer.param_groups: 83 | min_lr = min(min_lr, group["lr"]) 84 | max_lr = max(max_lr, group["lr"]) 85 | 86 | metric_logger.update(lr=max_lr) 87 | metric_logger.update(min_lr=min_lr) 88 | weight_decay_value = None 89 | for group in optimizer.param_groups: 90 | if group["weight_decay"] > 0: 91 | weight_decay_value = group["weight_decay"] 92 | metric_logger.update(weight_decay=weight_decay_value) 93 | metric_logger.update(grad_norm=grad_norm) 94 | 95 | if log_writer is not None: 96 | log_writer.update(loss=loss_value, head="loss") 97 | log_writer.update(loss_scale=loss_scale_value, head="opt") 98 | log_writer.update(lr=max_lr, head="opt") 99 | log_writer.update(min_lr=min_lr, head="opt") 100 | log_writer.update(weight_decay=weight_decay_value, head="opt") 101 | log_writer.update(grad_norm=grad_norm, head="opt") 102 | log_writer.set_step() 103 | 104 | if lr_scheduler is not None: 105 | lr_scheduler.step_update(start_steps + step) 106 | # gather the stats from all processes 107 | metric_logger.synchronize_between_processes() 108 | print("Averaged stats:", metric_logger) 109 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 110 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torchvision import transforms 3 | from transforms import * 4 | from masking_generator import TubeMaskingGenerator, RandomMaskingGenerator 5 | from ucf import VideoClsDataset, VideoMS 6 | 7 | 8 | class DataAugmentationForVideoMS(object): 9 | def __init__(self, args): 10 | self.input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN 11 | self.input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD 12 | normalize = GroupNormalize(self.input_mean, self.input_std) 13 | self.train_augmentation = GroupMultiScaleCrop(args.input_size, [1, .875, .75, .66]) 14 | self.transform = transforms.Compose([ 15 | self.train_augmentation, 16 | Stack(roll=False), 17 | ToTorchFormatTensor(div=True), 18 | normalize, 19 | ]) 20 | self.mcm = False 21 | if args.mask_type == 'tube': 22 | self.masked_position_generator = TubeMaskingGenerator( 23 | 
args.window_size, args.mask_ratio 24 | ) 25 | elif args.mask_type == 'random': 26 | self.masked_position_generator = RandomMaskingGenerator( 27 | args.window_size, args.mask_ratio 28 | ) 29 | elif args.mask_type == 'motion-centric': 30 | self.mcm = True 31 | self.masked_position_generator = 'Motion-centric Masking' 32 | 33 | def __call__(self, images): 34 | process_data, _ = self.transform(images) 35 | if self.mcm: 36 | return process_data, 0 37 | return process_data, self.masked_position_generator() 38 | 39 | def __repr__(self): 40 | repr = "(DataAugmentationForVideoMS,\n" 41 | repr += " transform = %s,\n" % str(self.transform) 42 | repr += " Masked position generator = %s,\n" % str(self.masked_position_generator) 43 | repr += ")" 44 | return repr 45 | 46 | 47 | def build_pretraining_dataset(args): 48 | transform = DataAugmentationForVideoMS(args) 49 | dataset = VideoMS( 50 | root=None, 51 | setting=args.data_path, 52 | video_ext='mp4', 53 | is_color=True, 54 | modality='rgb', 55 | new_length=args.num_frames, 56 | new_step=args.sampling_rate, 57 | transform=transform, 58 | temporal_jitter=False, 59 | video_loader=True, 60 | use_decord=True, 61 | lazy_init=False) 62 | print("Data Aug = %s" % str(transform)) 63 | return dataset 64 | 65 | 66 | def build_dataset(is_train, test_mode, args): 67 | if args.data_set == 'UCF101': 68 | mode = None 69 | anno_path = None 70 | if is_train is True: 71 | mode = 'train' 72 | anno_path = os.path.join(args.data_path, 'train.csv') 73 | elif test_mode is True: 74 | mode = 'test' 75 | anno_path = os.path.join(args.data_path, 'test.csv') 76 | else: 77 | mode = 'validation' 78 | anno_path = os.path.join(args.data_path, 'val.csv') 79 | 80 | dataset = VideoClsDataset( 81 | anno_path=anno_path, 82 | data_path='/', 83 | mode=mode, 84 | clip_len=args.num_frames, 85 | frame_sample_rate=args.sampling_rate, 86 | num_segment=1, 87 | test_num_segment=args.test_num_segment, 88 | test_num_crop=args.test_num_crop, 89 | num_crop=1 if not test_mode else 3, 90 | keep_aspect_ratio=True, 91 | crop_size=args.input_size, 92 | short_side_size=args.short_side_size, 93 | new_height=256, 94 | new_width=320, 95 | args=args) 96 | nb_classes = 101 97 | 98 | elif args.data_set == 'HMDB51': 99 | mode = None 100 | anno_path = None 101 | if is_train is True: 102 | mode = 'train' 103 | anno_path = os.path.join(args.data_path, 'train.csv') 104 | elif test_mode is True: 105 | mode = 'test' 106 | anno_path = os.path.join(args.data_path, 'test.csv') 107 | else: 108 | mode = 'validation' 109 | anno_path = os.path.join(args.data_path, 'val.csv') 110 | 111 | dataset = VideoClsDataset( 112 | anno_path=anno_path, 113 | data_path='/', 114 | mode=mode, 115 | clip_len=args.num_frames, 116 | frame_sample_rate=args.sampling_rate, 117 | num_segment=1, 118 | test_num_segment=args.test_num_segment, 119 | test_num_crop=args.test_num_crop, 120 | num_crop=1 if not test_mode else 3, 121 | keep_aspect_ratio=True, 122 | crop_size=args.input_size, 123 | short_side_size=args.short_side_size, 124 | new_height=256, 125 | new_width=320, 126 | args=args) 127 | nb_classes = 51 128 | 129 | elif args.data_set == 'OSCC': 130 | mode = None 131 | anno_path = None 132 | if is_train is True: 133 | mode = 'train' 134 | anno_path = os.path.join(args.data_path, 'train.csv') 135 | elif test_mode is True: 136 | mode = 'test' 137 | anno_path = os.path.join(args.data_path, 'test.csv') 138 | else: 139 | mode = 'validation' 140 | anno_path = os.path.join(args.data_path, 'val.csv') 141 | 142 | dataset = VideoClsDataset( 143 | 
anno_path=anno_path, 144 | data_path='/', 145 | mode=mode, 146 | clip_len=args.num_frames, 147 | frame_sample_rate=args.sampling_rate, 148 | num_segment=1, 149 | test_num_segment=args.test_num_segment, 150 | test_num_crop=args.test_num_crop, 151 | num_crop=1 if not test_mode else 3, 152 | keep_aspect_ratio=True, 153 | crop_size=args.input_size, 154 | short_side_size=args.short_side_size, 155 | new_height=256, 156 | new_width=320, 157 | args=args) 158 | nb_classes = 2 159 | 160 | else: 161 | raise NotImplementedError() 162 | assert nb_classes == args.nb_classes 163 | print("Number of the class = %d" % args.nb_classes) 164 | 165 | return dataset, nb_classes 166 | -------------------------------------------------------------------------------- /optim_factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim as optim 3 | 4 | from timm.optim.adafactor import Adafactor 5 | from timm.optim.adahessian import Adahessian 6 | from timm.optim.adamp import AdamP 7 | from timm.optim.lookahead import Lookahead 8 | from timm.optim.nadam import Nadam 9 | from timm.optim.novograd import NovoGrad 10 | from timm.optim.nvnovograd import NvNovoGrad 11 | from timm.optim.radam import RAdam 12 | from timm.optim.rmsprop_tf import RMSpropTF 13 | from timm.optim.sgdp import SGDP 14 | 15 | import json 16 | 17 | try: 18 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 19 | has_apex = True 20 | except ImportError: 21 | has_apex = False 22 | 23 | 24 | def get_num_layer_for_vit(var_name, num_max_layer): 25 | if var_name in ("cls_token", "mask_token", "pos_embed"): 26 | return 0 27 | elif var_name.startswith("patch_embed"): 28 | return 0 29 | elif var_name.startswith("rel_pos_bias"): 30 | return num_max_layer - 1 31 | elif var_name.startswith("blocks"): 32 | layer_id = int(var_name.split('.')[1]) 33 | return layer_id + 1 34 | else: 35 | return num_max_layer - 1 36 | 37 | 38 | class LayerDecayValueAssigner(object): 39 | def __init__(self, values): 40 | self.values = values 41 | 42 | def get_scale(self, layer_id): 43 | return self.values[layer_id] 44 | 45 | def get_layer_id(self, var_name): 46 | return get_num_layer_for_vit(var_name, len(self.values)) 47 | 48 | 49 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 50 | parameter_group_names = {} 51 | parameter_group_vars = {} 52 | 53 | for name, param in model.named_parameters(): 54 | if not param.requires_grad: 55 | continue # frozen weights 56 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 57 | group_name = "no_decay" 58 | this_weight_decay = 0. 59 | else: 60 | group_name = "decay" 61 | this_weight_decay = weight_decay 62 | if get_num_layer is not None: 63 | layer_id = get_num_layer(name) 64 | group_name = "layer_%d_%s" % (layer_id, group_name) 65 | else: 66 | layer_id = None 67 | 68 | if group_name not in parameter_group_names: 69 | if get_layer_scale is not None: 70 | scale = get_layer_scale(layer_id) 71 | else: 72 | scale = 1. 
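            # Each parameter group records an `lr_scale`; the training engines later set
            # param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"], which is
            # how the layer-wise lr decay from LayerDecayValueAssigner takes effect.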
73 | 74 | parameter_group_names[group_name] = { 75 | "weight_decay": this_weight_decay, 76 | "params": [], 77 | "lr_scale": scale 78 | } 79 | parameter_group_vars[group_name] = { 80 | "weight_decay": this_weight_decay, 81 | "params": [], 82 | "lr_scale": scale 83 | } 84 | 85 | parameter_group_vars[group_name]["params"].append(param) 86 | parameter_group_names[group_name]["params"].append(name) 87 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 88 | return list(parameter_group_vars.values()) 89 | 90 | 91 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 92 | opt_lower = args.opt.lower() 93 | weight_decay = args.weight_decay 94 | if weight_decay and filter_bias_and_bn: 95 | skip = {} 96 | if skip_list is not None: 97 | skip = skip_list 98 | elif hasattr(model, 'no_weight_decay'): 99 | skip = model.no_weight_decay() 100 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 101 | weight_decay = 0. 102 | else: 103 | parameters = model.parameters() 104 | 105 | if 'fused' in opt_lower: 106 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 107 | 108 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 109 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 110 | opt_args['eps'] = args.opt_eps 111 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 112 | opt_args['betas'] = args.opt_betas 113 | 114 | print("optimizer settings:", opt_args) 115 | 116 | opt_split = opt_lower.split('_') 117 | opt_lower = opt_split[-1] 118 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 119 | opt_args.pop('eps', None) 120 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 121 | elif opt_lower == 'momentum': 122 | opt_args.pop('eps', None) 123 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 124 | elif opt_lower == 'adam': 125 | optimizer = optim.Adam(parameters, **opt_args) 126 | elif opt_lower == 'adamw': 127 | optimizer = optim.AdamW(parameters, **opt_args) 128 | elif opt_lower == 'nadam': 129 | optimizer = Nadam(parameters, **opt_args) 130 | elif opt_lower == 'radam': 131 | optimizer = RAdam(parameters, **opt_args) 132 | elif opt_lower == 'adamp': 133 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 134 | elif opt_lower == 'sgdp': 135 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 136 | elif opt_lower == 'adadelta': 137 | optimizer = optim.Adadelta(parameters, **opt_args) 138 | elif opt_lower == 'adafactor': 139 | if not args.lr: 140 | opt_args['lr'] = None 141 | optimizer = Adafactor(parameters, **opt_args) 142 | elif opt_lower == 'adahessian': 143 | optimizer = Adahessian(parameters, **opt_args) 144 | elif opt_lower == 'rmsprop': 145 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 146 | elif opt_lower == 'rmsproptf': 147 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 148 | elif opt_lower == 'novograd': 149 | optimizer = NovoGrad(parameters, **opt_args) 150 | elif opt_lower == 'nvnovograd': 151 | optimizer = NvNovoGrad(parameters, **opt_args) 152 | elif opt_lower == 'fusedsgd': 153 | opt_args.pop('eps', None) 154 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 155 | elif opt_lower == 'fusedmomentum': 156 | opt_args.pop('eps', None) 157 | optimizer = 
FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 158 | elif opt_lower == 'fusedadam': 159 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 160 | elif opt_lower == 'fusedadamw': 161 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 162 | elif opt_lower == 'fusedlamb': 163 | optimizer = FusedLAMB(parameters, **opt_args) 164 | elif opt_lower == 'fusednovograd': 165 | opt_args.setdefault('betas', (0.95, 0.98)) 166 | optimizer = FusedNovoGrad(parameters, **opt_args) 167 | else: 168 | assert False and "Invalid optimizer" 169 | raise ValueError 170 | 171 | if len(opt_split) > 1: 172 | if opt_split[0] == 'lookahead': 173 | optimizer = Lookahead(optimizer) 174 | 175 | return optimizer 176 | -------------------------------------------------------------------------------- /random_erasing.py: -------------------------------------------------------------------------------- 1 | """ 2 | This implementation is based on 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py 4 | pulished under an Apache License 2.0. 5 | """ 6 | import math 7 | import random 8 | import torch 9 | 10 | 11 | def _get_pixels( 12 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda" 13 | ): 14 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 15 | # paths, flip the order so normal is run on CPU if this becomes a problem 16 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 17 | if per_pixel: 18 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 19 | elif rand_color: 20 | return torch.empty( 21 | (patch_size[0], 1, 1), dtype=dtype, device=device 22 | ).normal_() 23 | else: 24 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 25 | 26 | 27 | class RandomErasing: 28 | """Randomly selects a rectangle region in an image and erases its pixels. 29 | 'Random Erasing Data Augmentation' by Zhong et al. 30 | See https://arxiv.org/pdf/1708.04896.pdf 31 | This variant of RandomErasing is intended to be applied to either a batch 32 | or single image tensor after it has been normalized by dataset mean and std. 33 | Args: 34 | probability: Probability that the Random Erasing operation will be performed. 35 | min_area: Minimum percentage of erased area wrt input image area. 36 | max_area: Maximum percentage of erased area wrt input image area. 37 | min_aspect: Minimum aspect ratio of erased area. 38 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 39 | 'const' - erase block is constant color of 0 for all channels 40 | 'rand' - erase block is same per-channel random (normal) color 41 | 'pixel' - erase block is per-pixel random (normal) color 42 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 43 | per-image count is randomly chosen between 1 and this value. 
44 | """ 45 | 46 | def __init__( 47 | self, 48 | probability=0.5, 49 | min_area=0.02, 50 | max_area=1 / 3, 51 | min_aspect=0.3, 52 | max_aspect=None, 53 | mode="const", 54 | min_count=1, 55 | max_count=None, 56 | num_splits=0, 57 | device="cuda", 58 | cube=True, 59 | ): 60 | self.probability = probability 61 | self.min_area = min_area 62 | self.max_area = max_area 63 | max_aspect = max_aspect or 1 / min_aspect 64 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 65 | self.min_count = min_count 66 | self.max_count = max_count or min_count 67 | self.num_splits = num_splits 68 | mode = mode.lower() 69 | self.rand_color = False 70 | self.per_pixel = False 71 | self.cube = cube 72 | if mode == "rand": 73 | self.rand_color = True # per block random normal 74 | elif mode == "pixel": 75 | self.per_pixel = True # per pixel random normal 76 | else: 77 | assert not mode or mode == "const" 78 | self.device = device 79 | 80 | def _erase(self, img, chan, img_h, img_w, dtype): 81 | if random.random() > self.probability: 82 | return 83 | area = img_h * img_w 84 | count = ( 85 | self.min_count 86 | if self.min_count == self.max_count 87 | else random.randint(self.min_count, self.max_count) 88 | ) 89 | for _ in range(count): 90 | for _ in range(10): 91 | target_area = ( 92 | random.uniform(self.min_area, self.max_area) * area / count 93 | ) 94 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 95 | h = int(round(math.sqrt(target_area * aspect_ratio))) 96 | w = int(round(math.sqrt(target_area / aspect_ratio))) 97 | if w < img_w and h < img_h: 98 | top = random.randint(0, img_h - h) 99 | left = random.randint(0, img_w - w) 100 | img[:, top : top + h, left : left + w] = _get_pixels( 101 | self.per_pixel, 102 | self.rand_color, 103 | (chan, h, w), 104 | dtype=dtype, 105 | device=self.device, 106 | ) 107 | break 108 | 109 | def _erase_cube( 110 | self, 111 | img, 112 | batch_start, 113 | batch_size, 114 | chan, 115 | img_h, 116 | img_w, 117 | dtype, 118 | ): 119 | if random.random() > self.probability: 120 | return 121 | area = img_h * img_w 122 | count = ( 123 | self.min_count 124 | if self.min_count == self.max_count 125 | else random.randint(self.min_count, self.max_count) 126 | ) 127 | for _ in range(count): 128 | for _ in range(100): 129 | target_area = ( 130 | random.uniform(self.min_area, self.max_area) * area / count 131 | ) 132 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 133 | h = int(round(math.sqrt(target_area * aspect_ratio))) 134 | w = int(round(math.sqrt(target_area / aspect_ratio))) 135 | if w < img_w and h < img_h: 136 | top = random.randint(0, img_h - h) 137 | left = random.randint(0, img_w - w) 138 | for i in range(batch_start, batch_size): 139 | img_instance = img[i] 140 | img_instance[ 141 | :, top : top + h, left : left + w 142 | ] = _get_pixels( 143 | self.per_pixel, 144 | self.rand_color, 145 | (chan, h, w), 146 | dtype=dtype, 147 | device=self.device, 148 | ) 149 | break 150 | 151 | def __call__(self, input): 152 | if len(input.size()) == 3: 153 | self._erase(input, *input.size(), input.dtype) 154 | else: 155 | batch_size, chan, img_h, img_w = input.size() 156 | # skip first slice of batch if num_splits is set (for clean portion of samples) 157 | batch_start = ( 158 | batch_size // self.num_splits if self.num_splits > 1 else 0 159 | ) 160 | if self.cube: 161 | self._erase_cube( 162 | input, 163 | batch_start, 164 | batch_size, 165 | chan, 166 | img_h, 167 | img_w, 168 | input.dtype, 169 | ) 170 | else: 171 | for i in 
range(batch_start, batch_size): 172 | self._erase(input[i], chan, img_h, img_w, input.dtype) 173 | return input 174 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | import warnings 4 | import random 5 | import numpy as np 6 | import torchvision 7 | from PIL import Image, ImageOps 8 | import numbers 9 | 10 | 11 | class GroupRandomCrop(object): 12 | def __init__(self, size): 13 | if isinstance(size, numbers.Number): 14 | self.size = (int(size), int(size)) 15 | else: 16 | self.size = size 17 | 18 | def __call__(self, img_tuple): 19 | img_group, label = img_tuple 20 | 21 | w, h = img_group[0].size 22 | th, tw = self.size 23 | 24 | out_images = list() 25 | 26 | x1 = random.randint(0, w - tw) 27 | y1 = random.randint(0, h - th) 28 | 29 | for img in img_group: 30 | assert(img.size[0] == w and img.size[1] == h) 31 | if w == tw and h == th: 32 | out_images.append(img) 33 | else: 34 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 35 | 36 | return (out_images, label) 37 | 38 | 39 | class GroupCenterCrop(object): 40 | def __init__(self, size): 41 | self.worker = torchvision.transforms.CenterCrop(size) 42 | 43 | def __call__(self, img_tuple): 44 | img_group, label = img_tuple 45 | return ([self.worker(img) for img in img_group], label) 46 | 47 | 48 | class GroupNormalize(object): 49 | def __init__(self, mean, std): 50 | self.mean = mean 51 | self.std = std 52 | 53 | def __call__(self, tensor_tuple): 54 | tensor, label = tensor_tuple 55 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 56 | rep_std = self.std * (tensor.size()[0]//len(self.std)) 57 | 58 | # TODO: make efficient 59 | for t, m, s in zip(tensor, rep_mean, rep_std): 60 | t.sub_(m).div_(s) 61 | 62 | return (tensor,label) 63 | 64 | 65 | class GroupGrayScale(object): 66 | def __init__(self, size): 67 | self.worker = torchvision.transforms.Grayscale(size) 68 | 69 | def __call__(self, img_tuple): 70 | img_group, label = img_tuple 71 | return ([self.worker(img) for img in img_group], label) 72 | 73 | 74 | class GroupScale(object): 75 | """ Rescales the input PIL.Image to the given 'size'. 76 | 'size' will be the size of the smaller edge. 
77 | For example, if height > width, then image will be 78 | rescaled to (size * height / width, size) 79 | size: size of the smaller edge 80 | interpolation: Default: PIL.Image.BILINEAR 81 | """ 82 | 83 | def __init__(self, size, interpolation=Image.BILINEAR): 84 | self.worker = torchvision.transforms.Resize(size, interpolation) 85 | 86 | def __call__(self, img_tuple): 87 | img_group, label = img_tuple 88 | return ([self.worker(img) for img in img_group], label) 89 | 90 | 91 | class GroupMultiScaleCrop(object): 92 | 93 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 94 | self.scales = scales if scales is not None else [1, 875, .75, .66] 95 | self.max_distort = max_distort 96 | self.fix_crop = fix_crop 97 | self.more_fix_crop = more_fix_crop 98 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 99 | self.interpolation = Image.BILINEAR 100 | 101 | def __call__(self, img_tuple): 102 | img_group, label = img_tuple 103 | 104 | im_size = img_group[0].size 105 | 106 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 107 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 108 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group] 109 | return (ret_img_group, label) 110 | 111 | def _sample_crop_size(self, im_size): 112 | image_w, image_h = im_size[0], im_size[1] 113 | 114 | # find a crop size 115 | base_size = min(image_w, image_h) 116 | crop_sizes = [int(base_size * x) for x in self.scales] 117 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 118 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 119 | 120 | pairs = [] 121 | for i, h in enumerate(crop_h): 122 | for j, w in enumerate(crop_w): 123 | if abs(i - j) <= self.max_distort: 124 | pairs.append((w, h)) 125 | 126 | crop_pair = random.choice(pairs) 127 | if not self.fix_crop: 128 | w_offset = random.randint(0, image_w - crop_pair[0]) 129 | h_offset = random.randint(0, image_h - crop_pair[1]) 130 | else: 131 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 132 | 133 | return crop_pair[0], crop_pair[1], w_offset, h_offset 134 | 135 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 136 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 137 | return random.choice(offsets) 138 | 139 | @staticmethod 140 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 141 | w_step = (image_w - crop_w) // 4 142 | h_step = (image_h - crop_h) // 4 143 | 144 | ret = list() 145 | ret.append((0, 0)) # upper left 146 | ret.append((4 * w_step, 0)) # upper right 147 | ret.append((0, 4 * h_step)) # lower left 148 | ret.append((4 * w_step, 4 * h_step)) # lower right 149 | ret.append((2 * w_step, 2 * h_step)) # center 150 | 151 | if more_fix_crop: 152 | ret.append((0, 2 * h_step)) # center left 153 | ret.append((4 * w_step, 2 * h_step)) # center right 154 | ret.append((2 * w_step, 4 * h_step)) # lower center 155 | ret.append((2 * w_step, 0 * h_step)) # upper center 156 | 157 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 158 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 159 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 160 | ret.append((3 * w_step, 3 * h_step)) # lower righ quarter 161 | return ret 
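# The group transforms above and below all take and return an (img_group, label)
# tuple, so they can be chained with torchvision's Compose. For example, the
# pre-training pipeline in datasets.DataAugmentationForVideoMS uses:
#   transforms.Compose([GroupMultiScaleCrop(input_size, [1, .875, .75, .66]),
#                       Stack(roll=False),
#                       ToTorchFormatTensor(div=True),
#                       GroupNormalize(mean, std)])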
162 | 163 | 164 | class Stack(object): 165 | 166 | def __init__(self, roll=False): 167 | self.roll = roll 168 | 169 | def __call__(self, img_tuple): 170 | img_group, label = img_tuple 171 | 172 | if img_group[0].mode == 'L': 173 | return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label) 174 | elif img_group[0].mode == 'RGB': 175 | if self.roll: 176 | return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label) 177 | else: 178 | return (np.concatenate(img_group, axis=2), label) 179 | 180 | 181 | class ToTorchFormatTensor(object): 182 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 183 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 184 | def __init__(self, div=True): 185 | self.div = div 186 | 187 | def __call__(self, pic_tuple): 188 | pic, label = pic_tuple 189 | 190 | if isinstance(pic, np.ndarray): 191 | # handle numpy array 192 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 193 | else: 194 | # handle PIL Image 195 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 196 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 197 | # put it from HWC to CHW format 198 | # yikes, this transpose takes 80% of the loading time/CPU 199 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 200 | return (img.float().div(255.) if self.div else img.float(), label) 201 | 202 | 203 | class IdentityTransform(object): 204 | 205 | def __call__(self, data): 206 | return data 207 | -------------------------------------------------------------------------------- /engine_for_finetuning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import math 4 | import sys 5 | from typing import Iterable, Optional 6 | import torch 7 | from mixup import Mixup 8 | from timm.utils import accuracy, ModelEma 9 | import utils 10 | from scipy.special import softmax 11 | 12 | 13 | def train_class_batch(model, samples, target, criterion): 14 | outputs = model(samples) 15 | loss = criterion(outputs, target) 16 | return loss, outputs 17 | 18 | 19 | def get_loss_scale_for_deepspeed(model): 20 | optimizer = model.optimizer 21 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale 22 | 23 | 24 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 25 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 26 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 27 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, 28 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None, 29 | num_training_steps_per_epoch=None, update_freq=None): 30 | model.train(True) 31 | metric_logger = utils.MetricLogger(delimiter=" ") 32 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 33 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 34 | header = 'Epoch: [{}]'.format(epoch) 35 | print_freq = 10 36 | 37 | if loss_scaler is None: 38 | model.zero_grad() 39 | model.micro_steps = 0 40 | else: 41 | optimizer.zero_grad() 42 | 43 | for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 44 | step = data_iter_step // update_freq 45 | if step >= num_training_steps_per_epoch: 46 | continue 47 | it = start_steps + step # global training iteration 48 | # Update LR & WD for the first acc 
49 | if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0: 50 | for i, param_group in enumerate(optimizer.param_groups): 51 | if lr_schedule_values is not None: 52 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 53 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 54 | param_group["weight_decay"] = wd_schedule_values[it] 55 | 56 | samples = samples.to(device, non_blocking=True) 57 | targets = targets.to(device, non_blocking=True) 58 | 59 | if mixup_fn is not None: 60 | samples, targets = mixup_fn(samples, targets) 61 | 62 | if loss_scaler is None: 63 | samples = samples.half() 64 | loss, output = train_class_batch( 65 | model, samples, targets, criterion) 66 | else: 67 | with torch.cuda.amp.autocast(): 68 | loss, output = train_class_batch( 69 | model, samples, targets, criterion) 70 | 71 | loss_value = loss.item() 72 | 73 | if not math.isfinite(loss_value): 74 | print("Loss is {}, stopping training".format(loss_value)) 75 | sys.exit(1) 76 | 77 | if loss_scaler is None: 78 | loss /= update_freq 79 | model.backward(loss) 80 | model.step() 81 | 82 | if (data_iter_step + 1) % update_freq == 0: 83 | # model.zero_grad() 84 | # Deepspeed will call step() & model.zero_grad() automatic 85 | if model_ema is not None: 86 | model_ema.update(model) 87 | grad_norm = None 88 | loss_scale_value = get_loss_scale_for_deepspeed(model) 89 | else: 90 | # this attribute is added by timm on one optimizer (adahessian) 91 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 92 | loss /= update_freq 93 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 94 | parameters=model.parameters(), create_graph=is_second_order, 95 | update_grad=(data_iter_step + 1) % update_freq == 0) 96 | if (data_iter_step + 1) % update_freq == 0: 97 | optimizer.zero_grad() 98 | if model_ema is not None: 99 | model_ema.update(model) 100 | loss_scale_value = loss_scaler.state_dict()["scale"] 101 | 102 | torch.cuda.synchronize() 103 | 104 | if mixup_fn is None: 105 | class_acc = (output.max(-1)[-1] == targets).float().mean() 106 | else: 107 | class_acc = None 108 | metric_logger.update(loss=loss_value) 109 | metric_logger.update(class_acc=class_acc) 110 | metric_logger.update(loss_scale=loss_scale_value) 111 | min_lr = 10. 112 | max_lr = 0. 
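        # Track the smallest and largest lr across parameter groups for logging;
        # with layer-wise lr decay each group runs at a different effective lr.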
113 | for group in optimizer.param_groups: 114 | min_lr = min(min_lr, group["lr"]) 115 | max_lr = max(max_lr, group["lr"]) 116 | 117 | metric_logger.update(lr=max_lr) 118 | metric_logger.update(min_lr=min_lr) 119 | weight_decay_value = None 120 | for group in optimizer.param_groups: 121 | if group["weight_decay"] > 0: 122 | weight_decay_value = group["weight_decay"] 123 | metric_logger.update(weight_decay=weight_decay_value) 124 | metric_logger.update(grad_norm=grad_norm) 125 | 126 | if log_writer is not None: 127 | log_writer.update(loss=loss_value, head="loss") 128 | log_writer.update(class_acc=class_acc, head="loss") 129 | log_writer.update(loss_scale=loss_scale_value, head="opt") 130 | log_writer.update(lr=max_lr, head="opt") 131 | log_writer.update(min_lr=min_lr, head="opt") 132 | log_writer.update(weight_decay=weight_decay_value, head="opt") 133 | log_writer.update(grad_norm=grad_norm, head="opt") 134 | 135 | log_writer.set_step() 136 | 137 | # gather the stats from all processes 138 | metric_logger.synchronize_between_processes() 139 | print("Averaged stats:", metric_logger) 140 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 141 | 142 | 143 | @torch.no_grad() 144 | def validation_one_epoch(data_loader, model, device): 145 | criterion = torch.nn.CrossEntropyLoss() 146 | 147 | metric_logger = utils.MetricLogger(delimiter=" ") 148 | header = 'Val:' 149 | 150 | # switch to evaluation mode 151 | model.eval() 152 | 153 | for batch in metric_logger.log_every(data_loader, 10, header): 154 | videos = batch[0] 155 | target = batch[1] 156 | videos = videos.to(device, non_blocking=True) 157 | target = target.to(device, non_blocking=True) 158 | 159 | # compute output 160 | with torch.cuda.amp.autocast(): 161 | output = model(videos) 162 | loss = criterion(output, target) 163 | 164 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 165 | 166 | batch_size = videos.shape[0] 167 | metric_logger.update(loss=loss.item()) 168 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 169 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 170 | # gather the stats from all processes 171 | metric_logger.synchronize_between_processes() 172 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 173 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 174 | 175 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 176 | 177 | 178 | 179 | @torch.no_grad() 180 | def final_test(data_loader, model, device, file): 181 | criterion = torch.nn.CrossEntropyLoss() 182 | 183 | metric_logger = utils.MetricLogger(delimiter=" ") 184 | header = 'Test:' 185 | 186 | # switch to evaluation mode 187 | model.eval() 188 | final_result = [] 189 | 190 | for batch in metric_logger.log_every(data_loader, 10, header): 191 | videos = batch[0] 192 | target = batch[1] 193 | ids = batch[2] 194 | chunk_nb = batch[3] 195 | split_nb = batch[4] 196 | videos = videos.to(device, non_blocking=True) 197 | target = target.to(device, non_blocking=True) 198 | 199 | # compute output 200 | with torch.cuda.amp.autocast(): 201 | output = model(videos) 202 | loss = criterion(output, target) 203 | 204 | for i in range(output.size(0)): 205 | string = "{} {} {} {} {}\n".format(ids[i], \ 206 | str(output.data[i].cpu().numpy().tolist()), \ 207 | str(int(target[i].cpu().numpy())), \ 208 | str(int(chunk_nb[i].cpu().numpy())), \ 209 | str(int(split_nb[i].cpu().numpy()))) 210 | final_result.append(string) 211 | 212 
| acc1, acc5 = accuracy(output, target, topk=(1, 5)) 213 | 214 | batch_size = videos.shape[0] 215 | metric_logger.update(loss=loss.item()) 216 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 217 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 218 | 219 | if not os.path.exists(file): 220 | os.mknod(file) 221 | with open(file, 'w') as f: 222 | f.write("{}, {}\n".format(acc1, acc5)) 223 | for line in final_result: 224 | f.write(line) 225 | # gather the stats from all processes 226 | metric_logger.synchronize_between_processes() 227 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 228 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 229 | 230 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 231 | 232 | 233 | def merge(eval_path, num_tasks): 234 | dict_feats = {} 235 | dict_label = {} 236 | dict_pos = {} 237 | print("Reading individual output files") 238 | 239 | for x in range(num_tasks): 240 | file = os.path.join(eval_path, str(x) + '.txt') 241 | lines = open(file, 'r').readlines()[1:] 242 | for line in lines: 243 | line = line.strip() 244 | name = line.split('[')[0] 245 | label = line.split(']')[1].split(' ')[1] 246 | chunk_nb = line.split(']')[1].split(' ')[2] 247 | split_nb = line.split(']')[1].split(' ')[3] 248 | data = np.fromstring(line.split('[')[1].split(']')[0], dtype=float, sep=',') 249 | data = softmax(data) 250 | if name not in dict_feats: 251 | dict_feats[name] = [] 252 | dict_label[name] = 0 253 | dict_pos[name] = [] 254 | if chunk_nb + split_nb in dict_pos[name]: 255 | continue 256 | dict_feats[name].append(data) 257 | dict_pos[name].append(chunk_nb + split_nb) 258 | dict_label[name] = label 259 | print("Computing final results") 260 | 261 | input_lst = [] 262 | print(len(dict_feats)) 263 | for i, item in enumerate(dict_feats): 264 | input_lst.append([i, item, dict_feats[item], dict_label[item]]) 265 | from multiprocessing import Pool 266 | p = Pool(64) 267 | ans = p.map(compute_video, input_lst) 268 | top1 = [x[1] for x in ans] 269 | top5 = [x[2] for x in ans] 270 | pred = [x[0] for x in ans] 271 | label = [x[3] for x in ans] 272 | final_top1, final_top5 = np.mean(top1), np.mean(top5) 273 | return final_top1 * 100, final_top5 * 100 274 | 275 | def compute_video(lst): 276 | i, video_id, data, label = lst 277 | feat = [x for x in data] 278 | feat = np.mean(feat, axis=0) 279 | pred = np.argmax(feat) 280 | top1 = (int(pred) == int(label)) * 1.0 281 | top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0 282 | return [pred, top1, top5, int(label)] 283 | -------------------------------------------------------------------------------- /run_ms_pretraining.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import numpy as np 4 | import time 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import json 8 | import os 9 | from pathlib import Path 10 | from timm.models import create_model 11 | from optim_factory import create_optimizer 12 | from datasets import build_pretraining_dataset 13 | from engine_for_pretraining import train_one_epoch 14 | from utils import NativeScalerWithGradNormCount as NativeScaler 15 | import utils 16 | import modeling_pretrain 17 | 18 | 19 | def get_args(): 20 | parser = argparse.ArgumentParser('VideoMS pre-training script', add_help=False) 21 | parser.add_argument('--batch_size', default=64, type=int) 22 |
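# Note: --batch_size is the per-GPU batch size; main() below rescales --lr,
# --min_lr and --warmup_lr linearly by total_batch_size / 256. As a rough,
# purely illustrative sketch (assuming 8 processes): 8 GPUs x 64 clips = 512
# clips per step, so the default base lr of 1.5e-4 becomes 1.5e-4 * 512 / 256 = 3e-4.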
parser.add_argument('--epochs', default=800, type=int) 23 | parser.add_argument('--save_ckpt_freq', default=50, type=int) 24 | 25 | # Model parameters 26 | parser.add_argument('--model', default='pretrain_videoms_base_patch16_224', type=str, metavar='MODEL', 27 | help='Name of model to train') 28 | 29 | parser.add_argument('--decoder_depth', default=4, type=int, 30 | help='depth of decoder') 31 | 32 | parser.add_argument('--mask_type', default='tube', choices=['random', 'tube', 'motion-centric'], 33 | type=str, help='masked strategy of video tokens/patches') 34 | 35 | parser.add_argument('--mask_ratio', default=0.75, type=float, 36 | help='ratio of the visual tokens/patches need be masked') 37 | 38 | parser.add_argument('--motion_centric_masking_ratio', default=0.7, type=float, 39 | help='ratio of the motion-centric masking') 40 | 41 | parser.add_argument('--input_size', default=224, type=int, 42 | help='videos input size for backbone') 43 | 44 | parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT', 45 | help='Drop path rate (default: 0.1)') 46 | 47 | parser.add_argument('--normlize_target', default=True, type=bool, 48 | help='normalized the target patch pixels') 49 | 50 | # Optimizer parameters 51 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 52 | help='Optimizer (default: "adamw"') 53 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', 54 | help='Optimizer Epsilon (default: 1e-8)') 55 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', 56 | help='Optimizer Betas (default: None, use opt default)') 57 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 58 | help='Clip gradient norm (default: None, no clipping)') 59 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 60 | help='SGD momentum (default: 0.9)') 61 | parser.add_argument('--weight_decay', type=float, default=0.05, 62 | help='weight decay (default: 0.05)') 63 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the 64 | weight decay. We use a cosine schedule for WD. 
65 | (Set the same value with args.weight_decay to keep weight decay no change)""") 66 | 67 | parser.add_argument('--lr', type=float, default=1.5e-4, metavar='LR', 68 | help='learning rate (default: 1.5e-4)') 69 | parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR', 70 | help='warmup learning rate (default: 1e-6)') 71 | parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR', 72 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 73 | 74 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 75 | help='epochs to warmup LR, if scheduler supports') 76 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', 77 | help='epochs to warmup LR, if scheduler supports') 78 | parser.add_argument('--use_checkpoint', action='store_true') 79 | parser.set_defaults(use_checkpoint=False) 80 | 81 | # Augmentation parameters 82 | parser.add_argument('--color_jitter', type=float, default=0.0, metavar='PCT', 83 | help='Color jitter factor (default: 0.4)') 84 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 85 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 86 | 87 | # Dataset parameters 88 | parser.add_argument('--data_path', default='/path/to/list_ucf-101', type=str, 89 | help='dataset path') 90 | parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true') 91 | parser.add_argument('--num_frames', type=int, default= 16) 92 | parser.add_argument('--sampling_rate', type=int, default= 4) 93 | parser.add_argument('--output_dir', default='', 94 | help='path where to save, empty for no saving') 95 | parser.add_argument('--log_dir', default=None, 96 | help='path where to tensorboard log') 97 | parser.add_argument('--device', default='cuda', 98 | help='device to use for training / testing') 99 | parser.add_argument('--seed', default=0, type=int) 100 | parser.add_argument('--resume', default='', help='resume from checkpoint') 101 | parser.add_argument('--auto_resume', action='store_true') 102 | parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume') 103 | parser.set_defaults(auto_resume=True) 104 | 105 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 106 | help='start epoch') 107 | parser.add_argument('--num_workers', default=10, type=int) 108 | parser.add_argument('--pin_mem', action='store_true', 109 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 110 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem', 111 | help='') 112 | parser.set_defaults(pin_mem=True) 113 | 114 | # distributed training parameters 115 | parser.add_argument('--world_size', default=1, type=int, 116 | help='number of distributed processes') 117 | parser.add_argument('--local_rank', default=-1, type=int) 118 | parser.add_argument('--dist_on_itp', action='store_true') 119 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 120 | 121 | return parser.parse_args() 122 | 123 | 124 | def get_model(args): 125 | print(f"Creating model: {args.model}") 126 | model = create_model( 127 | args.model, 128 | pretrained=False, 129 | drop_path_rate=args.drop_path, 130 | drop_block_rate=None, 131 | decoder_depth=args.decoder_depth, 132 | use_checkpoint=args.use_checkpoint, 133 | motion_centric_masking=args.mask_type=='motion-centric', 134 | motion_centric_masking_ratio=args.motion_centric_masking_ratio, 135 | 
masking_ratio=args.mask_ratio 136 | ) 137 | return model 138 | 139 | 140 | def main(args): 141 | utils.init_distributed_mode(args) 142 | 143 | print(args) 144 | 145 | device = torch.device(args.device) 146 | 147 | # fix the seed for reproducibility 148 | seed = args.seed + utils.get_rank() 149 | torch.manual_seed(seed) 150 | np.random.seed(seed) 151 | 152 | cudnn.benchmark = True 153 | 154 | model = get_model(args) 155 | patch_size = model.encoder.patch_embed.patch_size 156 | print("Patch size = %s" % str(patch_size)) 157 | args.window_size = (args.num_frames // 2, args.input_size // patch_size[0], args.input_size // patch_size[1]) 158 | args.patch_size = patch_size 159 | 160 | # get dataset 161 | dataset_train = build_pretraining_dataset(args) 162 | 163 | 164 | num_tasks = utils.get_world_size() 165 | global_rank = utils.get_rank() 166 | sampler_rank = global_rank 167 | num_training_steps_per_epoch = len(dataset_train) // args.batch_size // num_tasks 168 | 169 | sampler_train = torch.utils.data.DistributedSampler( 170 | dataset_train, num_replicas=num_tasks, rank=sampler_rank, shuffle=True 171 | ) 172 | print("Sampler_train = %s" % str(sampler_train)) 173 | 174 | 175 | if global_rank == 0 and args.log_dir is not None: 176 | os.makedirs(args.log_dir, exist_ok=True) 177 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir) 178 | else: 179 | log_writer = None 180 | 181 | data_loader_train = torch.utils.data.DataLoader( 182 | dataset_train, sampler=sampler_train, 183 | batch_size=args.batch_size, 184 | num_workers=args.num_workers, 185 | pin_memory=args.pin_mem, 186 | drop_last=True, 187 | worker_init_fn=utils.seed_worker 188 | ) 189 | 190 | model.to(device) 191 | model_without_ddp = model 192 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 193 | 194 | print("Model = %s" % str(model_without_ddp)) 195 | print('number of params: {} M'.format(n_parameters / 1e6)) 196 | 197 | total_batch_size = args.batch_size * utils.get_world_size() 198 | 199 | args.lr = args.lr * total_batch_size / 256 200 | args.min_lr = args.min_lr * total_batch_size / 256 201 | args.warmup_lr = args.warmup_lr * total_batch_size / 256 202 | print("LR = %.8f" % args.lr) 203 | print("Batch size = %d" % total_batch_size) 204 | print("Number of training steps = %d" % num_training_steps_per_epoch) 205 | print("Number of training examples per epoch = %d" % (total_batch_size * num_training_steps_per_epoch)) 206 | 207 | if args.distributed: 208 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 209 | model_without_ddp = model.module 210 | 211 | optimizer = create_optimizer( 212 | args, model_without_ddp) 213 | loss_scaler = NativeScaler() 214 | 215 | print("Use step level LR & WD scheduler!") 216 | lr_schedule_values = utils.cosine_scheduler( 217 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch, 218 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps, 219 | ) 220 | if args.weight_decay_end is None: 221 | args.weight_decay_end = args.weight_decay 222 | wd_schedule_values = utils.cosine_scheduler( 223 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch) 224 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values))) 225 | 226 | utils.auto_load_model( 227 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 228 | torch.cuda.empty_cache() 229 | print(f"Start training for {args.epochs} 
epochs") 230 | start_time = time.time() 231 | for epoch in range(args.start_epoch, args.epochs): 232 | if args.distributed: 233 | data_loader_train.sampler.set_epoch(epoch) 234 | if log_writer is not None: 235 | log_writer.set_step(epoch * num_training_steps_per_epoch) 236 | train_stats = train_one_epoch( 237 | model, data_loader_train, 238 | optimizer, device, epoch, loss_scaler, 239 | args.clip_grad, log_writer=log_writer, 240 | start_steps=epoch * num_training_steps_per_epoch, 241 | lr_schedule_values=lr_schedule_values, 242 | wd_schedule_values=wd_schedule_values, 243 | patch_size=patch_size[0], 244 | normlize_target=args.normlize_target, 245 | mask_type=args.mask_type 246 | ) 247 | if args.output_dir: 248 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: 249 | utils.save_model( 250 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 251 | loss_scaler=loss_scaler, epoch=epoch) 252 | 253 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 254 | 'epoch': epoch, 'n_parameters': n_parameters} 255 | 256 | if args.output_dir and utils.is_main_process(): 257 | if log_writer is not None: 258 | log_writer.flush() 259 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 260 | f.write(json.dumps(log_stats) + "\n") 261 | 262 | total_time = time.time() - start_time 263 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 264 | print('Training time {}'.format(total_time_str)) 265 | 266 | 267 | if __name__ == '__main__': 268 | opts = get_args() 269 | if opts.output_dir: 270 | Path(opts.output_dir).mkdir(parents=True, exist_ok=True) 271 | main(opts) 272 | -------------------------------------------------------------------------------- /mixup.py: -------------------------------------------------------------------------------- 1 | """ Mixup and Cutmix 2 | 3 | Papers: 4 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) 5 | 6 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) 7 | 8 | Code Reference: 9 | CutMix: https://github.com/clovaai/CutMix-PyTorch 10 | 11 | Hacked together by / Copyright 2019, Ross Wightman 12 | """ 13 | import numpy as np 14 | import torch 15 | 16 | 17 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): 18 | x = x.long().view(-1, 1) 19 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) 20 | 21 | 22 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): 23 | off_value = smoothing / num_classes 24 | on_value = 1. - smoothing + off_value 25 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) 26 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) 27 | return y1 * lam + y2 * (1. - lam) 28 | 29 | 30 | def rand_bbox(img_shape, lam, margin=0., count=None): 31 | """ Standard CutMix bounding-box 32 | Generates a random square bbox based on lambda value. This impl includes 33 | support for enforcing a border margin as percent of bbox dimensions. 
34 | 35 | Args: 36 | img_shape (tuple): Image shape as tuple 37 | lam (float): Cutmix lambda value 38 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image) 39 | count (int): Number of bbox to generate 40 | """ 41 | ratio = np.sqrt(1 - lam) 42 | img_h, img_w = img_shape[-2:] 43 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) 44 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) 45 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) 46 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) 47 | yl = np.clip(cy - cut_h // 2, 0, img_h) 48 | yh = np.clip(cy + cut_h // 2, 0, img_h) 49 | xl = np.clip(cx - cut_w // 2, 0, img_w) 50 | xh = np.clip(cx + cut_w // 2, 0, img_w) 51 | return yl, yh, xl, xh 52 | 53 | 54 | def rand_bbox_minmax(img_shape, minmax, count=None): 55 | """ Min-Max CutMix bounding-box 56 | Inspired by Darknet cutmix impl, generates a random rectangular bbox 57 | based on min/max percent values applied to each dimension of the input image. 58 | 59 | Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max. 60 | 61 | Args: 62 | img_shape (tuple): Image shape as tuple 63 | minmax (tuple or list): Min and max bbox ratios (as percent of image size) 64 | count (int): Number of bbox to generate 65 | """ 66 | assert len(minmax) == 2 67 | img_h, img_w = img_shape[-2:] 68 | cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count) 69 | cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count) 70 | yl = np.random.randint(0, img_h - cut_h, size=count) 71 | xl = np.random.randint(0, img_w - cut_w, size=count) 72 | yu = yl + cut_h 73 | xu = xl + cut_w 74 | return yl, yu, xl, xu 75 | 76 | 77 | def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None): 78 | """ Generate bbox and apply lambda correction. 79 | """ 80 | if ratio_minmax is not None: 81 | yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count) 82 | else: 83 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) 84 | if correct_lam or ratio_minmax is not None: 85 | bbox_area = (yu - yl) * (xu - xl) 86 | lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1]) 87 | return (yl, yu, xl, xu), lam 88 | 89 | 90 | class Mixup: 91 | """ Mixup/Cutmix that applies different params to each element or whole batch 92 | 93 | Args: 94 | mixup_alpha (float): mixup alpha value, mixup is active if > 0. 95 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. 96 | cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None. 
97 | prob (float): probability of applying mixup or cutmix per batch or element 98 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active 99 | mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)) 100 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders 101 | label_smoothing (float): apply label smoothing to the mixed target tensor 102 | num_classes (int): number of classes for target 103 | """ 104 | def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5, 105 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000): 106 | self.mixup_alpha = mixup_alpha 107 | self.cutmix_alpha = cutmix_alpha 108 | self.cutmix_minmax = cutmix_minmax 109 | if self.cutmix_minmax is not None: 110 | assert len(self.cutmix_minmax) == 2 111 | # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe 112 | self.cutmix_alpha = 1.0 113 | self.mix_prob = prob 114 | self.switch_prob = switch_prob 115 | self.label_smoothing = label_smoothing 116 | self.num_classes = num_classes 117 | self.mode = mode 118 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix 119 | self.mixup_enabled = True # set to False to disable mixing (intended to be set by train loop) 120 | 121 | def _params_per_elem(self, batch_size): 122 | lam = np.ones(batch_size, dtype=np.float32) 123 | use_cutmix = np.zeros(batch_size, dtype=bool) 124 | if self.mixup_enabled: 125 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: 126 | use_cutmix = np.random.rand(batch_size) < self.switch_prob 127 | lam_mix = np.where( 128 | use_cutmix, 129 | np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size), 130 | np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)) 131 | elif self.mixup_alpha > 0.: 132 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size) 133 | elif self.cutmix_alpha > 0.: 134 | use_cutmix = np.ones(batch_size, dtype=bool) 135 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size) 136 | else: 137 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." 138 | lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam) 139 | return lam, use_cutmix 140 | 141 | def _params_per_batch(self): 142 | lam = 1. 143 | use_cutmix = False 144 | if self.mixup_enabled and np.random.rand() < self.mix_prob: 145 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: 146 | use_cutmix = np.random.rand() < self.switch_prob 147 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \ 148 | np.random.beta(self.mixup_alpha, self.mixup_alpha) 149 | elif self.mixup_alpha > 0.: 150 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha) 151 | elif self.cutmix_alpha > 0.: 152 | use_cutmix = True 153 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) 154 | else: 155 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
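# Note on the Beta draws above: with the default mixup_alpha == 1.0 the
# Beta(1, 1) sample is uniform on [0, 1]; alphas below 1 push lam toward the
# extremes (one source dominates each mix), while alphas above 1 concentrate
# lam around 0.5. The resulting lam is later consumed by mixup_target() to
# build the soft labels.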
156 | lam = float(lam_mix) 157 | return lam, use_cutmix 158 | 159 | def _mix_elem(self, x): 160 | batch_size = len(x) 161 | lam_batch, use_cutmix = self._params_per_elem(batch_size) 162 | x_orig = x.clone() # need to keep an unmodified original for mixing source 163 | for i in range(batch_size): 164 | j = batch_size - i - 1 165 | lam = lam_batch[i] 166 | if lam != 1.: 167 | if use_cutmix[i]: 168 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 169 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 170 | x[i][..., yl:yh, xl:xh] = x_orig[j][..., yl:yh, xl:xh] 171 | lam_batch[i] = lam 172 | else: 173 | x[i] = x[i] * lam + x_orig[j] * (1 - lam) 174 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) 175 | 176 | def _mix_pair(self, x): 177 | batch_size = len(x) 178 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) 179 | x_orig = x.clone() # need to keep an unmodified original for mixing source 180 | for i in range(batch_size // 2): 181 | j = batch_size - i - 1 182 | lam = lam_batch[i] 183 | if lam != 1.: 184 | if use_cutmix[i]: 185 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 186 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 187 | x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] 188 | x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh] 189 | lam_batch[i] = lam 190 | else: 191 | x[i] = x[i] * lam + x_orig[j] * (1 - lam) 192 | x[j] = x[j] * lam + x_orig[i] * (1 - lam) 193 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) 194 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) 195 | 196 | def _mix_batch(self, x): 197 | lam, use_cutmix = self._params_per_batch() 198 | if lam == 1.: 199 | return 1. 200 | if use_cutmix: 201 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 202 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 203 | x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh] 204 | else: 205 | x_flipped = x.flip(0).mul_(1. - lam) 206 | x.mul_(lam).add_(x_flipped) 207 | return lam 208 | 209 | def __call__(self, x, target): 210 | assert len(x) % 2 == 0, 'Batch size should be even when using this' 211 | if self.mode == 'elem': 212 | lam = self._mix_elem(x) 213 | elif self.mode == 'pair': 214 | lam = self._mix_pair(x) 215 | else: 216 | lam = self._mix_batch(x) 217 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device) 218 | return x, target 219 | 220 | 221 | class FastCollateMixup(Mixup): 222 | """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch 223 | 224 | A Mixup impl that's performed while collating the batches. 
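    It is typically installed as the DataLoader's collate_fn: it receives a list of
    (uint8 image array, label) samples, writes the mixed pixels into a uint8 output
    tensor, and returns that batch together with the (optionally smoothed) soft
    targets built on the CPU.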
225 | """ 226 | 227 | def _mix_elem_collate(self, output, batch, half=False): 228 | batch_size = len(batch) 229 | num_elem = batch_size // 2 if half else batch_size 230 | assert len(output) == num_elem 231 | lam_batch, use_cutmix = self._params_per_elem(num_elem) 232 | for i in range(num_elem): 233 | j = batch_size - i - 1 234 | lam = lam_batch[i] 235 | mixed = batch[i][0] 236 | if lam != 1.: 237 | if use_cutmix[i]: 238 | if not half: 239 | mixed = mixed.copy() 240 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 241 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 242 | mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] 243 | lam_batch[i] = lam 244 | else: 245 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) 246 | np.rint(mixed, out=mixed) 247 | output[i] += torch.from_numpy(mixed.astype(np.uint8)) 248 | if half: 249 | lam_batch = np.concatenate((lam_batch, np.ones(num_elem))) 250 | return torch.tensor(lam_batch).unsqueeze(1) 251 | 252 | def _mix_pair_collate(self, output, batch): 253 | batch_size = len(batch) 254 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) 255 | for i in range(batch_size // 2): 256 | j = batch_size - i - 1 257 | lam = lam_batch[i] 258 | mixed_i = batch[i][0] 259 | mixed_j = batch[j][0] 260 | assert 0 <= lam <= 1.0 261 | if lam < 1.: 262 | if use_cutmix[i]: 263 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 264 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 265 | patch_i = mixed_i[:, yl:yh, xl:xh].copy() 266 | mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh] 267 | mixed_j[:, yl:yh, xl:xh] = patch_i 268 | lam_batch[i] = lam 269 | else: 270 | mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam) 271 | mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam) 272 | mixed_i = mixed_temp 273 | np.rint(mixed_j, out=mixed_j) 274 | np.rint(mixed_i, out=mixed_i) 275 | output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) 276 | output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) 277 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) 278 | return torch.tensor(lam_batch).unsqueeze(1) 279 | 280 | def _mix_batch_collate(self, output, batch): 281 | batch_size = len(batch) 282 | lam, use_cutmix = self._params_per_batch() 283 | if use_cutmix: 284 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 285 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 286 | for i in range(batch_size): 287 | j = batch_size - i - 1 288 | mixed = batch[i][0] 289 | if lam != 1.: 290 | if use_cutmix: 291 | mixed = mixed.copy() # don't want to modify the original while iterating 292 | mixed[..., yl:yh, xl:xh] = batch[j][0][..., yl:yh, xl:xh] 293 | else: 294 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) 295 | np.rint(mixed, out=mixed) 296 | output[i] += torch.from_numpy(mixed.astype(np.uint8)) 297 | return lam 298 | 299 | def __call__(self, batch, _=None): 300 | batch_size = len(batch) 301 | assert batch_size % 2 == 0, 'Batch size should be even when using this' 302 | half = 'half' in self.mode 303 | if half: 304 | batch_size //= 2 305 | output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) 306 | if self.mode == 'elem' or self.mode == 'half': 307 | lam = self._mix_elem_collate(output, batch, half=half) 308 | elif self.mode == 'pair': 309 | lam = self._mix_pair_collate(output, batch) 310 | else: 311 | lam = 
self._mix_batch_collate(output, batch) 312 | target = torch.tensor([b[1] for b in batch], dtype=torch.int64) 313 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') 314 | target = target[:batch_size] 315 | return output, target 316 | 317 | -------------------------------------------------------------------------------- /modeling_finetune.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 7 | from timm.models.registry import register_model 8 | import torch.utils.checkpoint as checkpoint 9 | 10 | 11 | def _cfg(url='', **kwargs): 12 | return { 13 | 'url': url, 14 | 'num_classes': 101, 'input_size': (3, 224, 224), 'pool_size': None, 15 | 'crop_pct': .9, 'interpolation': 'bicubic', 16 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 17 | **kwargs 18 | } 19 | 20 | 21 | class DropPath(nn.Module): 22 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 23 | """ 24 | def __init__(self, drop_prob=None): 25 | super(DropPath, self).__init__() 26 | self.drop_prob = drop_prob 27 | 28 | def forward(self, x): 29 | return drop_path(x, self.drop_prob, self.training) 30 | 31 | def extra_repr(self) -> str: 32 | return 'p={}'.format(self.drop_prob) 33 | 34 | 35 | class Mlp(nn.Module): 36 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 37 | super().__init__() 38 | out_features = out_features or in_features 39 | hidden_features = hidden_features or in_features 40 | self.fc1 = nn.Linear(in_features, hidden_features) 41 | self.act = act_layer() 42 | self.fc2 = nn.Linear(hidden_features, out_features) 43 | self.drop = nn.Dropout(drop) 44 | 45 | def forward(self, x): 46 | x = self.fc1(x) 47 | x = self.act(x) 48 | # x = self.drop(x) 49 | # commit this for the orignal BERT implement 50 | x = self.fc2(x) 51 | x = self.drop(x) 52 | return x 53 | 54 | 55 | class Attention(nn.Module): 56 | def __init__( 57 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., 58 | proj_drop=0., attn_head_dim=None): 59 | super().__init__() 60 | self.num_heads = num_heads 61 | head_dim = dim // num_heads 62 | if attn_head_dim is not None: 63 | head_dim = attn_head_dim 64 | all_head_dim = head_dim * self.num_heads 65 | self.scale = qk_scale or head_dim ** -0.5 66 | 67 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) 68 | if qkv_bias: 69 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) 70 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) 71 | else: 72 | self.q_bias = None 73 | self.v_bias = None 74 | 75 | self.attn_drop = nn.Dropout(attn_drop) 76 | self.proj = nn.Linear(all_head_dim, dim) 77 | self.proj_drop = nn.Dropout(proj_drop) 78 | 79 | def forward(self, x): 80 | B, N, C = x.shape 81 | qkv_bias = None 82 | if self.q_bias is not None: 83 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) 84 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 85 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 86 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 87 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 88 | 89 | q = q * self.scale 90 | attn = (q @ 
k.transpose(-2, -1)) 91 | 92 | 93 | attn = attn.softmax(dim=-1) 94 | attn = self.attn_drop(attn) 95 | 96 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 97 | x = self.proj(x) 98 | x = self.proj_drop(x) 99 | return x 100 | 101 | 102 | class Block(nn.Module): 103 | 104 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 105 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, 106 | attn_head_dim=None): 107 | super().__init__() 108 | self.norm1 = norm_layer(dim) 109 | self.attn = Attention( 110 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 111 | attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim) 112 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 113 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 114 | self.norm2 = norm_layer(dim) 115 | mlp_hidden_dim = int(dim * mlp_ratio) 116 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 117 | 118 | if init_values > 0: 119 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 120 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 121 | else: 122 | self.gamma_1, self.gamma_2 = None, None 123 | 124 | def forward(self, x): 125 | if self.gamma_1 is None: 126 | x = x + self.drop_path(self.attn(self.norm1(x))) 127 | x = x + self.drop_path(self.mlp(self.norm2(x))) 128 | else: 129 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) 130 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 131 | return x 132 | 133 | 134 | class PatchEmbed(nn.Module): 135 | """ Image to Patch Embedding 136 | """ 137 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2): 138 | super().__init__() 139 | img_size = to_2tuple(img_size) 140 | patch_size = to_2tuple(patch_size) 141 | self.tubelet_size = int(tubelet_size) 142 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (num_frames // self.tubelet_size) 143 | self.img_size = img_size 144 | self.patch_size = patch_size 145 | self.num_patches = num_patches 146 | self.proj = nn.Conv3d(in_channels=in_chans, out_channels=embed_dim, 147 | kernel_size = (self.tubelet_size, patch_size[0],patch_size[1]), 148 | stride=(self.tubelet_size, patch_size[0], patch_size[1])) 149 | 150 | def forward(self, x, **kwargs): 151 | B, C, T, H, W = x.shape 152 | # FIXME look at relaxing size constraints 153 | assert H == self.img_size[0] and W == self.img_size[1], \ 154 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
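# Shape sketch with the defaults (224x224 input, 16x16 patches, tubelet_size 2,
# 16 frames, embed_dim 768): the Conv3d below maps (B, 3, 16, 224, 224) to
# (B, 768, 8, 14, 14); flatten(2) + transpose(1, 2) then yields the token
# sequence of shape (B, 8 * 14 * 14, 768) = (B, 1568, 768).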
155 | x = self.proj(x).flatten(2).transpose(1, 2) 156 | return x 157 | 158 | # sin-cos position encoding 159 | # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31 160 | def get_sinusoid_encoding_table(n_position, d_hid): 161 | ''' Sinusoid position encoding table ''' 162 | # TODO: make it with torch instead of numpy 163 | def get_position_angle_vec(position): 164 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 165 | 166 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 167 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 168 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 169 | 170 | return torch.tensor(sinusoid_table,dtype=torch.float, requires_grad=False).unsqueeze(0) 171 | 172 | 173 | class VisionTransformer(nn.Module): 174 | """ Vision Transformer with support for patch or hybrid CNN input stage 175 | """ 176 | def __init__(self, 177 | img_size=224, 178 | patch_size=16, 179 | in_chans=3, 180 | num_classes=1000, 181 | embed_dim=768, 182 | depth=12, 183 | num_heads=12, 184 | mlp_ratio=4., 185 | qkv_bias=False, 186 | qk_scale=None, 187 | fc_drop_rate=0., 188 | drop_rate=0., 189 | attn_drop_rate=0., 190 | drop_path_rate=0., 191 | norm_layer=nn.LayerNorm, 192 | init_values=0., 193 | use_learnable_pos_emb=False, 194 | init_scale=0., 195 | all_frames=16, 196 | tubelet_size=2, 197 | use_checkpoint=False, 198 | use_mean_pooling=True, 199 | mcm=False, 200 | mcm_ratio=0.4): 201 | super().__init__() 202 | self.num_classes = num_classes 203 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 204 | self.tubelet_size = tubelet_size 205 | self.patch_embed = PatchEmbed( 206 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, num_frames=all_frames, tubelet_size=self.tubelet_size) 207 | num_patches = self.patch_embed.num_patches 208 | self.use_checkpoint = use_checkpoint 209 | 210 | if use_learnable_pos_emb: 211 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 212 | else: 213 | # sine-cosine positional embeddings is on the way 214 | self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim) 215 | 216 | self.pos_drop = nn.Dropout(p=drop_rate) 217 | 218 | 219 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 220 | self.blocks = nn.ModuleList([ 221 | Block( 222 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 223 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 224 | init_values=init_values) 225 | for i in range(depth)]) 226 | self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) 227 | self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None 228 | self.fc_dropout = nn.Dropout(p=fc_drop_rate) if fc_drop_rate > 0 else nn.Identity() 229 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 230 | 231 | if use_learnable_pos_emb: 232 | trunc_normal_(self.pos_embed, std=.02) 233 | 234 | trunc_normal_(self.head.weight, std=.02) 235 | self.apply(self._init_weights) 236 | 237 | self.head.weight.data.mul_(init_scale) 238 | self.head.bias.data.mul_(init_scale) 239 | 240 | self.mcm = mcm 241 | self.mcm_ratio = mcm_ratio 242 | 243 | def _init_weights(self, m): 244 | if isinstance(m, nn.Linear): 245 | trunc_normal_(m.weight, std=.02) 246 | if 
isinstance(m, nn.Linear) and m.bias is not None: 247 | nn.init.constant_(m.bias, 0) 248 | elif isinstance(m, nn.LayerNorm): 249 | nn.init.constant_(m.bias, 0) 250 | nn.init.constant_(m.weight, 1.0) 251 | 252 | def get_num_layers(self): 253 | return len(self.blocks) 254 | 255 | @torch.jit.ignore 256 | def no_weight_decay(self): 257 | return {'pos_embed', 'cls_token'} 258 | 259 | def get_classifier(self): 260 | return self.head 261 | 262 | def reset_classifier(self, num_classes, global_pool=''): 263 | self.num_classes = num_classes 264 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 265 | 266 | def forward_features(self, x): 267 | N, _, T, _, _ = x.shape 268 | x = self.patch_embed(x) 269 | if self.mcm: 270 | mask = self.mcm_step(x, N, x.shape[1] // T * 2, T//2, x.shape[2]) 271 | B, _, C = x.size() 272 | 273 | if self.pos_embed is not None: 274 | x = x + self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach() 275 | if self.mcm: 276 | x = x[~mask].reshape(B, -1, C) 277 | x = self.pos_drop(x) 278 | 279 | if self.use_checkpoint: 280 | for blk in self.blocks: 281 | x = checkpoint.checkpoint(blk, x) 282 | else: 283 | for blk in self.blocks: 284 | x = blk(x) 285 | 286 | x = self.norm(x) 287 | if self.fc_norm is not None: 288 | return self.fc_norm(x.mean(1)) 289 | else: 290 | return x[:, 0] 291 | 292 | def mcm_step(self, x, N, L, T, D): 293 | patch_embed_vectors = x.detach().clone().reshape(shape=(N, T, L, D)) 294 | 295 | distance = torch.norm(patch_embed_vectors[:,:7,:,:] - patch_embed_vectors[:,1:,:,:], p=2, dim=3) 296 | importance = torch.cat((distance[:,0,:], distance.flatten(1)), dim=1) 297 | 298 | ids_sorted = torch.argsort(importance, dim=1, descending=True) 299 | num_compressed_tokens = int((1 - self.mcm_ratio) * (T * L)) 300 | 301 | ids_restore = torch.argsort(ids_sorted, dim=1) 302 | # keep the first subset 303 | ids_keep = ids_sorted[:, :num_compressed_tokens] 304 | 305 | input_mask = torch.ones([N, T * L], device=x.device) 306 | input_mask[:, :num_compressed_tokens] = 0 307 | 308 | # unshuffle to get the binary mask 309 | input_mask = torch.gather(input_mask, dim=1, index=ids_restore) 310 | 311 | return input_mask.to(torch.bool) 312 | 313 | def forward(self, x): 314 | x = self.forward_features(x) 315 | x = self.head(self.fc_dropout(x)) 316 | return x 317 | 318 | 319 | @register_model 320 | def vit_small_patch16_224(pretrained=False, **kwargs): 321 | model = VisionTransformer( 322 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 323 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 324 | model.default_cfg = _cfg() 325 | return model 326 | 327 | 328 | @register_model 329 | def vit_base_patch16_224(pretrained=False, **kwargs): 330 | model = VisionTransformer( 331 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 332 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 333 | model.default_cfg = _cfg() 334 | return model 335 | 336 | 337 | @register_model 338 | def vit_base_patch16_384(pretrained=False, **kwargs): 339 | model = VisionTransformer( 340 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 341 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 342 | model.default_cfg = _cfg() 343 | return model 344 | 345 | 346 | @register_model 347 | def vit_large_patch16_224(pretrained=False, **kwargs): 348 | model = VisionTransformer( 349 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, 
qkv_bias=True, 350 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 351 | model.default_cfg = _cfg() 352 | return model 353 | 354 | 355 | @register_model 356 | def vit_large_patch16_384(pretrained=False, **kwargs): 357 | model = VisionTransformer( 358 | img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 359 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 360 | model.default_cfg = _cfg() 361 | return model 362 | 363 | 364 | @register_model 365 | def vit_large_patch16_512(pretrained=False, **kwargs): 366 | model = VisionTransformer( 367 | img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 368 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 369 | model.default_cfg = _cfg() 370 | return model 371 | 372 | 373 | @register_model 374 | def vit_huge_patch16_224(pretrained=False, **kwargs): 375 | model = VisionTransformer( 376 | patch_size=16, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 377 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 378 | model.default_cfg = _cfg() 379 | return model 380 | -------------------------------------------------------------------------------- /modeling_pretrain.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.checkpoint as checkpoint 6 | from functools import partial 7 | 8 | from modeling_finetune import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table 9 | from timm.models.registry import register_model 10 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_ 11 | 12 | 13 | 14 | def trunc_normal_(tensor, mean=0., std=1.): 15 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) 16 | 17 | 18 | __all__ = [ 19 | 'pretrain_videoms_small_patch16_224', 20 | 'pretrain_videoms_base_patch16_224', 21 | 'pretrain_videoms_large_patch16_224', 22 | 'pretrain_videoms_huge_patch16_224', 23 | ] 24 | 25 | 26 | class PretrainVisionTransformerEncoder(nn.Module): 27 | """ Vision Transformer with support for patch or hybrid CNN input stage 28 | """ 29 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 30 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 31 | drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, tubelet_size=2, use_checkpoint=False, 32 | use_learnable_pos_emb=False, mcm=False, mcm_ratio=0.7, masking_ratio=0.9): 33 | super().__init__() 34 | self.num_classes = num_classes 35 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 36 | self.patch_embed = PatchEmbed( 37 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,tubelet_size=tubelet_size) 38 | num_patches = self.patch_embed.num_patches 39 | self.use_checkpoint = use_checkpoint 40 | 41 | 42 | # TODO: Add the cls token 43 | if use_learnable_pos_emb: 44 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 45 | else: 46 | # sine-cosine positional embeddings 47 | self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim) 48 | 49 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 50 | self.blocks = nn.ModuleList([ 51 | Block( 52 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 53 | drop=drop_rate, 
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 54 | init_values=init_values) 55 | for i in range(depth)]) 56 | self.norm = norm_layer(embed_dim) 57 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 58 | 59 | if use_learnable_pos_emb: 60 | trunc_normal_(self.pos_embed, std=.02) 61 | 62 | self.apply(self._init_weights) 63 | 64 | self.mcm = mcm 65 | self.mcm_ratio = mcm_ratio 66 | self.masking_ratio = masking_ratio 67 | 68 | 69 | def _init_weights(self, m): 70 | if isinstance(m, nn.Linear): 71 | nn.init.xavier_uniform_(m.weight) 72 | if isinstance(m, nn.Linear) and m.bias is not None: 73 | nn.init.constant_(m.bias, 0) 74 | elif isinstance(m, nn.LayerNorm): 75 | nn.init.constant_(m.bias, 0) 76 | nn.init.constant_(m.weight, 1.0) 77 | 78 | def get_num_layers(self): 79 | return len(self.blocks) 80 | 81 | @torch.jit.ignore 82 | def no_weight_decay(self): 83 | return {'pos_embed', 'cls_token'} 84 | 85 | def get_classifier(self): 86 | return self.head 87 | 88 | def reset_classifier(self, num_classes, global_pool=''): 89 | self.num_classes = num_classes 90 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 91 | 92 | def forward_features(self, x, mask): 93 | N, _, T, _, _ = x.shape 94 | x = self.patch_embed(x) 95 | masks = (None, None) 96 | if self.mcm: 97 | mask, target_mask = self.mcm_step(x, N, x.shape[1] // T * 2, T // 2, x.shape[2]) 98 | masks = (mask, target_mask) 99 | 100 | x = x + self.pos_embed.type_as(x).to(x.device).clone().detach() 101 | 102 | B, _, C = x.shape 103 | x_vis = x[~mask].reshape(B, -1, C) # ~mask means visible 104 | 105 | if self.use_checkpoint: 106 | for blk in self.blocks: 107 | x_vis = checkpoint.checkpoint(blk, x_vis) 108 | else: 109 | for blk in self.blocks: 110 | x_vis = blk(x_vis) 111 | 112 | x_vis = self.norm(x_vis) 113 | return x_vis, masks 114 | 115 | def mcm_step(self, x, N, L, T, D): 116 | patch_embed_vectors = x.detach().clone().reshape(shape=(N, T, L, D)) 117 | 118 | distance = torch.norm(patch_embed_vectors[:,:7,:,:] - patch_embed_vectors[:,1:,:,:], p=2, dim=3) 119 | importance = torch.cat((distance[:,0,:], distance.flatten(1)), dim=1) 120 | 121 | ids_sorted = torch.argsort(importance, dim=1, descending=True) 122 | num_compressed_tokens = int((1 - self.mcm_ratio) * (T * L)) 123 | num_input_tokens = int((1 - self.masking_ratio) * (T * L)) 124 | noise = torch.rand(N, num_compressed_tokens, device=x.device) # noise in [0, 1] 125 | noise_id_shuffled = torch.argsort(noise, dim=1) 126 | 127 | ids_sorted[:,:num_compressed_tokens] = torch.gather(ids_sorted[:,:num_compressed_tokens], dim=1, index=noise_id_shuffled) 128 | 129 | ids_restore = torch.argsort(ids_sorted, dim=1) 130 | 131 | ids_keep = ids_sorted[:, :num_input_tokens] 132 | ids_recon = ids_sorted[:, num_input_tokens:num_compressed_tokens] 133 | 134 | input_mask = torch.ones([N, T * L], device=x.device) 135 | input_mask[:, :num_input_tokens] = 0 136 | 137 | target_mask = torch.ones([N, T * L], device=x.device) 138 | target_mask[:, num_input_tokens:num_compressed_tokens] = 0 139 | 140 | 141 | input_mask = torch.gather(input_mask, dim=1, index=ids_restore) 142 | target_mask = torch.gather(target_mask, dim=1, index=ids_restore) 143 | 144 | return input_mask.to(torch.bool), target_mask.to(torch.bool) 145 | 146 | def forward(self, x, mask): 147 | x, masks = self.forward_features(x, mask) 148 | x = self.head(x) 149 | return x, masks 150 | 151 | class PretrainVisionTransformerDecoder(nn.Module): 152 | """ Vision Transformer 
with support for patch or hybrid CNN input stage 153 | """ 154 | def __init__(self, patch_size=16, num_classes=768, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., 155 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., 156 | norm_layer=nn.LayerNorm, init_values=None, num_patches=196, tubelet_size=2, use_checkpoint=False 157 | ): 158 | super().__init__() 159 | self.num_classes = num_classes 160 | assert num_classes == 3 * tubelet_size * patch_size ** 2 161 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 162 | self.patch_size = patch_size 163 | self.use_checkpoint = use_checkpoint 164 | 165 | 166 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 167 | self.blocks = nn.ModuleList([ 168 | Block( 169 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 170 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 171 | init_values=init_values) 172 | for i in range(depth)]) 173 | self.norm = norm_layer(embed_dim) 174 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 175 | 176 | self.apply(self._init_weights) 177 | 178 | 179 | def _init_weights(self, m): 180 | if isinstance(m, nn.Linear): 181 | nn.init.xavier_uniform_(m.weight) 182 | if isinstance(m, nn.Linear) and m.bias is not None: 183 | nn.init.constant_(m.bias, 0) 184 | elif isinstance(m, nn.LayerNorm): 185 | nn.init.constant_(m.bias, 0) 186 | nn.init.constant_(m.weight, 1.0) 187 | 188 | def get_num_layers(self): 189 | return len(self.blocks) 190 | 191 | @torch.jit.ignore 192 | def no_weight_decay(self): 193 | return {'pos_embed', 'cls_token'} 194 | 195 | def get_classifier(self): 196 | return self.head 197 | 198 | def reset_classifier(self, num_classes, global_pool=''): 199 | self.num_classes = num_classes 200 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 201 | 202 | def forward(self, x, return_token_num): 203 | if self.use_checkpoint: 204 | for blk in self.blocks: 205 | x = checkpoint.checkpoint(blk, x) 206 | else: 207 | for blk in self.blocks: 208 | x = blk(x) 209 | 210 | if return_token_num > 0: 211 | x = self.head(self.norm(x[:, -return_token_num:])) # only return the mask tokens predict pixels 212 | else: 213 | x = self.head(self.norm(x)) 214 | 215 | return x 216 | 217 | class PretrainVisionTransformer(nn.Module): 218 | """ Vision Transformer with support for patch or hybrid CNN input stage 219 | """ 220 | def __init__(self, 221 | img_size=224, 222 | patch_size=16, 223 | encoder_in_chans=3, 224 | encoder_num_classes=0, 225 | encoder_embed_dim=768, 226 | encoder_depth=12, 227 | encoder_num_heads=12, 228 | decoder_num_classes=1536, # decoder_num_classes=768, 229 | decoder_embed_dim=512, 230 | decoder_depth=8, 231 | decoder_num_heads=8, 232 | mlp_ratio=4., 233 | qkv_bias=False, 234 | qk_scale=None, 235 | drop_rate=0., 236 | attn_drop_rate=0., 237 | drop_path_rate=0., 238 | norm_layer=nn.LayerNorm, 239 | init_values=0., 240 | use_learnable_pos_emb=False, 241 | use_checkpoint=False, 242 | tubelet_size=2, 243 | num_classes=0, # avoid the error from create_fn in timm 244 | in_chans=0, # avoid the error from create_fn in timm 245 | motion_centric_masking=False, 246 | motion_centric_masking_ratio=0.7, 247 | masking_ratio=0.9 248 | ): 249 | super().__init__() 250 | self.encoder = PretrainVisionTransformerEncoder( 251 | img_size=img_size, 252 | 
patch_size=patch_size, 253 | in_chans=encoder_in_chans, 254 | num_classes=encoder_num_classes, 255 | embed_dim=encoder_embed_dim, 256 | depth=encoder_depth, 257 | num_heads=encoder_num_heads, 258 | mlp_ratio=mlp_ratio, 259 | qkv_bias=qkv_bias, 260 | qk_scale=qk_scale, 261 | drop_rate=drop_rate, 262 | attn_drop_rate=attn_drop_rate, 263 | drop_path_rate=drop_path_rate, 264 | norm_layer=norm_layer, 265 | init_values=init_values, 266 | tubelet_size=tubelet_size, 267 | use_checkpoint=use_checkpoint, 268 | use_learnable_pos_emb=use_learnable_pos_emb, 269 | mcm=motion_centric_masking, 270 | mcm_ratio=motion_centric_masking_ratio, 271 | masking_ratio=masking_ratio) 272 | 273 | self.decoder = PretrainVisionTransformerDecoder( 274 | patch_size=patch_size, 275 | num_patches=self.encoder.patch_embed.num_patches, 276 | num_classes=decoder_num_classes, 277 | embed_dim=decoder_embed_dim, 278 | depth=decoder_depth, 279 | num_heads=decoder_num_heads, 280 | mlp_ratio=mlp_ratio, 281 | qkv_bias=qkv_bias, 282 | qk_scale=qk_scale, 283 | drop_rate=drop_rate, 284 | attn_drop_rate=attn_drop_rate, 285 | drop_path_rate=drop_path_rate, 286 | norm_layer=norm_layer, 287 | init_values=init_values, 288 | tubelet_size=tubelet_size, 289 | use_checkpoint=use_checkpoint) 290 | 291 | self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=False) 292 | 293 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 294 | 295 | self.pos_embed = get_sinusoid_encoding_table(self.encoder.patch_embed.num_patches, decoder_embed_dim) 296 | 297 | trunc_normal_(self.mask_token, std=.02) 298 | 299 | 300 | def _init_weights(self, m): 301 | if isinstance(m, nn.Linear): 302 | nn.init.xavier_uniform_(m.weight) 303 | if isinstance(m, nn.Linear) and m.bias is not None: 304 | nn.init.constant_(m.bias, 0) 305 | elif isinstance(m, nn.LayerNorm): 306 | nn.init.constant_(m.bias, 0) 307 | nn.init.constant_(m.weight, 1.0) 308 | 309 | def get_num_layers(self): 310 | return len(self.blocks) 311 | 312 | @torch.jit.ignore 313 | def no_weight_decay(self): 314 | return {'pos_embed', 'cls_token', 'mask_token'} 315 | 316 | def forward(self, x, mask): 317 | _, _, T, _, _ = x.shape 318 | x_vis, masks = self.encoder(x, mask) 319 | x_vis = self.encoder_to_decoder(x_vis) 320 | B, N, C = x_vis.shape 321 | expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach() 322 | if masks[0] is None: 323 | pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C) 324 | pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C) 325 | x_full = torch.cat([x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1) 326 | x = self.decoder(x_full, pos_emd_mask.shape[1]) 327 | else: 328 | pos_emd_vis = expand_pos_embed[~masks[0]].reshape(B, -1, C) 329 | pos_emd_mask = expand_pos_embed[~masks[1]].reshape(B, -1, C) 330 | x_full = torch.cat([x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1) # [B, N, C_d] 331 | x = self.decoder(x_full, pos_emd_mask.shape[1]) 332 | return x, masks 333 | 334 | @register_model 335 | def pretrain_videoms_small_patch16_224(pretrained=False, **kwargs): 336 | model = PretrainVisionTransformer( 337 | img_size=224, 338 | patch_size=16, 339 | encoder_embed_dim=384, 340 | encoder_depth=12, 341 | encoder_num_heads=6, 342 | encoder_num_classes=0, 343 | decoder_num_classes=1536, 344 | decoder_embed_dim=192, 345 | decoder_num_heads=3, 346 | mlp_ratio=4, 347 | qkv_bias=True, 348 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 349 | **kwargs) 350 | model.default_cfg = _cfg() 351 | if pretrained: 
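# When pretrained=True, this factory expects kwargs['init_ckpt'] to point at a
# checkpoint file; torch.load below maps it to CPU and the checkpoint's 'model'
# state dict is loaded into the freshly built model. The base/large/huge
# factories that follow use the same pattern.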
352 | checkpoint = torch.load( 353 | kwargs["init_ckpt"], map_location="cpu" 354 | ) 355 | model.load_state_dict(checkpoint["model"]) 356 | return model 357 | 358 | @register_model 359 | def pretrain_videoms_base_patch16_224(pretrained=False, **kwargs): 360 | model = PretrainVisionTransformer( 361 | img_size=224, 362 | patch_size=16, 363 | encoder_embed_dim=768, 364 | encoder_depth=12, 365 | encoder_num_heads=12, 366 | encoder_num_classes=0, 367 | decoder_num_classes=1536, 368 | decoder_embed_dim=384, 369 | decoder_num_heads=6, 370 | mlp_ratio=4, 371 | qkv_bias=True, 372 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 373 | **kwargs) 374 | model.default_cfg = _cfg() 375 | if pretrained: 376 | checkpoint = torch.load( 377 | kwargs["init_ckpt"], map_location="cpu" 378 | ) 379 | model.load_state_dict(checkpoint["model"]) 380 | return model 381 | 382 | @register_model 383 | def pretrain_videoms_large_patch16_224(pretrained=False, **kwargs): 384 | model = PretrainVisionTransformer( 385 | img_size=224, 386 | patch_size=16, 387 | encoder_embed_dim=1024, 388 | encoder_depth=24, 389 | encoder_num_heads=16, 390 | encoder_num_classes=0, 391 | decoder_num_classes=1536, 392 | decoder_embed_dim=512, 393 | decoder_num_heads=8, 394 | mlp_ratio=4, 395 | qkv_bias=True, 396 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 397 | **kwargs) 398 | model.default_cfg = _cfg() 399 | if pretrained: 400 | checkpoint = torch.load( 401 | kwargs["init_ckpt"], map_location="cpu" 402 | ) 403 | model.load_state_dict(checkpoint["model"]) 404 | return model 405 | 406 | @register_model 407 | def pretrain_videoms_huge_patch16_224(pretrained=False, **kwargs): 408 | model = PretrainVisionTransformer( 409 | img_size=224, 410 | patch_size=16, 411 | encoder_embed_dim=1280, 412 | encoder_depth=32, 413 | encoder_num_heads=16, 414 | encoder_num_classes=0, 415 | decoder_num_classes=1536, 416 | decoder_embed_dim=640, 417 | decoder_num_heads=8, 418 | mlp_ratio=4, 419 | qkv_bias=True, 420 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 421 | **kwargs) 422 | model.default_cfg = _cfg() 423 | if pretrained: 424 | checkpoint = torch.load( 425 | kwargs["init_ckpt"], map_location="cpu" 426 | ) 427 | model.load_state_dict(checkpoint["model"]) 428 | return model 429 | -------------------------------------------------------------------------------- /rand_augment.py: -------------------------------------------------------------------------------- 1 | """ 2 | This implementation is based on 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py 4 | pulished under an Apache License 2.0. 5 | 6 | COMMENT FROM ORIGINAL: 7 | AutoAugment, RandAugment, and AugMix for PyTorch 8 | This code implements the searched ImageNet policies with various tweaks and 9 | improvements and does not include any of the search code. AA and RA 10 | Implementation adapted from: 11 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 12 | AugMix adapted from: 13 | https://github.com/google-research/augmix 14 | Papers: 15 | AutoAugment: Learning Augmentation Policies from Data 16 | https://arxiv.org/abs/1805.09501 17 | Learning Data Augmentation Strategies for Object Detection 18 | https://arxiv.org/abs/1906.11172 19 | RandAugment: Practical automated data augmentation... 
20 | https://arxiv.org/abs/1909.13719 21 | AugMix: A Simple Data Processing Method to Improve Robustness and 22 | Uncertainty https://arxiv.org/abs/1912.02781 23 | 24 | Hacked together by / Copyright 2020 Ross Wightman 25 | """ 26 | 27 | import math 28 | import numpy as np 29 | import random 30 | import re 31 | import PIL 32 | from PIL import Image, ImageEnhance, ImageOps 33 | 34 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) 35 | 36 | _FILL = (128, 128, 128) 37 | 38 | # This signifies the max integer that the controller RNN could predict for the 39 | # augmentation scheme. 40 | _MAX_LEVEL = 10.0 41 | 42 | _HPARAMS_DEFAULT = { 43 | "translate_const": 250, 44 | "img_mean": _FILL, 45 | } 46 | 47 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 48 | 49 | 50 | def _interpolation(kwargs): 51 | interpolation = kwargs.pop("resample", Image.BILINEAR) 52 | if isinstance(interpolation, (list, tuple)): 53 | return random.choice(interpolation) 54 | else: 55 | return interpolation 56 | 57 | 58 | def _check_args_tf(kwargs): 59 | if "fillcolor" in kwargs and _PIL_VER < (5, 0): 60 | kwargs.pop("fillcolor") 61 | kwargs["resample"] = _interpolation(kwargs) 62 | 63 | 64 | def shear_x(img, factor, **kwargs): 65 | _check_args_tf(kwargs) 66 | return img.transform( 67 | img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs 68 | ) 69 | 70 | 71 | def shear_y(img, factor, **kwargs): 72 | _check_args_tf(kwargs) 73 | return img.transform( 74 | img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs 75 | ) 76 | 77 | 78 | def translate_x_rel(img, pct, **kwargs): 79 | pixels = pct * img.size[0] 80 | _check_args_tf(kwargs) 81 | return img.transform( 82 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs 83 | ) 84 | 85 | 86 | def translate_y_rel(img, pct, **kwargs): 87 | pixels = pct * img.size[1] 88 | _check_args_tf(kwargs) 89 | return img.transform( 90 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs 91 | ) 92 | 93 | 94 | def translate_x_abs(img, pixels, **kwargs): 95 | _check_args_tf(kwargs) 96 | return img.transform( 97 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs 98 | ) 99 | 100 | 101 | def translate_y_abs(img, pixels, **kwargs): 102 | _check_args_tf(kwargs) 103 | return img.transform( 104 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs 105 | ) 106 | 107 | 108 | def rotate(img, degrees, **kwargs): 109 | _check_args_tf(kwargs) 110 | if _PIL_VER >= (5, 2): 111 | return img.rotate(degrees, **kwargs) 112 | elif _PIL_VER >= (5, 0): 113 | w, h = img.size 114 | post_trans = (0, 0) 115 | rotn_center = (w / 2.0, h / 2.0) 116 | angle = -math.radians(degrees) 117 | matrix = [ 118 | round(math.cos(angle), 15), 119 | round(math.sin(angle), 15), 120 | 0.0, 121 | round(-math.sin(angle), 15), 122 | round(math.cos(angle), 15), 123 | 0.0, 124 | ] 125 | 126 | def transform(x, y, matrix): 127 | (a, b, c, d, e, f) = matrix 128 | return a * x + b * y + c, d * x + e * y + f 129 | 130 | matrix[2], matrix[5] = transform( 131 | -rotn_center[0] - post_trans[0], 132 | -rotn_center[1] - post_trans[1], 133 | matrix, 134 | ) 135 | matrix[2] += rotn_center[0] 136 | matrix[5] += rotn_center[1] 137 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs) 138 | else: 139 | return img.rotate(degrees, resample=kwargs["resample"]) 140 | 141 | 142 | def auto_contrast(img, **__): 143 | return ImageOps.autocontrast(img) 144 | 145 | 146 | def invert(img, **__): 147 | return ImageOps.invert(img) 148 | 149 | 150 | def equalize(img, **__): 151 | return 
ImageOps.equalize(img) 152 | 153 | 154 | def solarize(img, thresh, **__): 155 | return ImageOps.solarize(img, thresh) 156 | 157 | 158 | def solarize_add(img, add, thresh=128, **__): 159 | lut = [] 160 | for i in range(256): 161 | if i < thresh: 162 | lut.append(min(255, i + add)) 163 | else: 164 | lut.append(i) 165 | if img.mode in ("L", "RGB"): 166 | if img.mode == "RGB" and len(lut) == 256: 167 | lut = lut + lut + lut 168 | return img.point(lut) 169 | else: 170 | return img 171 | 172 | 173 | def posterize(img, bits_to_keep, **__): 174 | if bits_to_keep >= 8: 175 | return img 176 | return ImageOps.posterize(img, bits_to_keep) 177 | 178 | 179 | def contrast(img, factor, **__): 180 | return ImageEnhance.Contrast(img).enhance(factor) 181 | 182 | 183 | def color(img, factor, **__): 184 | return ImageEnhance.Color(img).enhance(factor) 185 | 186 | 187 | def brightness(img, factor, **__): 188 | return ImageEnhance.Brightness(img).enhance(factor) 189 | 190 | 191 | def sharpness(img, factor, **__): 192 | return ImageEnhance.Sharpness(img).enhance(factor) 193 | 194 | 195 | def _randomly_negate(v): 196 | """With 50% prob, negate the value""" 197 | return -v if random.random() > 0.5 else v 198 | 199 | 200 | def _rotate_level_to_arg(level, _hparams): 201 | # range [-30, 30] 202 | level = (level / _MAX_LEVEL) * 30.0 203 | level = _randomly_negate(level) 204 | return (level,) 205 | 206 | 207 | def _enhance_level_to_arg(level, _hparams): 208 | # range [0.1, 1.9] 209 | return ((level / _MAX_LEVEL) * 1.8 + 0.1,) 210 | 211 | 212 | def _enhance_increasing_level_to_arg(level, _hparams): 213 | # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend 214 | # range [0.1, 1.9] 215 | level = (level / _MAX_LEVEL) * 0.9 216 | level = 1.0 + _randomly_negate(level) 217 | return (level,) 218 | 219 | 220 | def _shear_level_to_arg(level, _hparams): 221 | # range [-0.3, 0.3] 222 | level = (level / _MAX_LEVEL) * 0.3 223 | level = _randomly_negate(level) 224 | return (level,) 225 | 226 | 227 | def _translate_abs_level_to_arg(level, hparams): 228 | translate_const = hparams["translate_const"] 229 | level = (level / _MAX_LEVEL) * float(translate_const) 230 | level = _randomly_negate(level) 231 | return (level,) 232 | 233 | 234 | def _translate_rel_level_to_arg(level, hparams): 235 | # default range [-0.45, 0.45] 236 | translate_pct = hparams.get("translate_pct", 0.45) 237 | level = (level / _MAX_LEVEL) * translate_pct 238 | level = _randomly_negate(level) 239 | return (level,) 240 | 241 | 242 | def _posterize_level_to_arg(level, _hparams): 243 | # As per Tensorflow TPU EfficientNet impl 244 | # range [0, 4], 'keep 0 up to 4 MSB of original image' 245 | # intensity/severity of augmentation decreases with level 246 | return (int((level / _MAX_LEVEL) * 4),) 247 | 248 | 249 | def _posterize_increasing_level_to_arg(level, hparams): 250 | # As per Tensorflow models research and UDA impl 251 | # range [4, 0], 'keep 4 down to 0 MSB of original image', 252 | # intensity/severity of augmentation increases with level 253 | return (4 - _posterize_level_to_arg(level, hparams)[0],) 254 | 255 | 256 | def _posterize_original_level_to_arg(level, _hparams): 257 | # As per original AutoAugment paper description 258 | # range [4, 8], 'keep 4 up to 8 MSB of image' 259 | # intensity/severity of augmentation decreases with level 260 | return (int((level / _MAX_LEVEL) * 4) + 4,) 261 | 262 | 263 | def _solarize_level_to_arg(level, _hparams): 264 | # range [0, 256] 265 | # intensity/severity of 
augmentation decreases with level 266 | return (int((level / _MAX_LEVEL) * 256),) 267 | 268 | 269 | def _solarize_increasing_level_to_arg(level, _hparams): 270 | # range [0, 256] 271 | # intensity/severity of augmentation increases with level 272 | return (256 - _solarize_level_to_arg(level, _hparams)[0],) 273 | 274 | 275 | def _solarize_add_level_to_arg(level, _hparams): 276 | # range [0, 110] 277 | return (int((level / _MAX_LEVEL) * 110),) 278 | 279 | 280 | LEVEL_TO_ARG = { 281 | "AutoContrast": None, 282 | "Equalize": None, 283 | "Invert": None, 284 | "Rotate": _rotate_level_to_arg, 285 | # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers 286 | "Posterize": _posterize_level_to_arg, 287 | "PosterizeIncreasing": _posterize_increasing_level_to_arg, 288 | "PosterizeOriginal": _posterize_original_level_to_arg, 289 | "Solarize": _solarize_level_to_arg, 290 | "SolarizeIncreasing": _solarize_increasing_level_to_arg, 291 | "SolarizeAdd": _solarize_add_level_to_arg, 292 | "Color": _enhance_level_to_arg, 293 | "ColorIncreasing": _enhance_increasing_level_to_arg, 294 | "Contrast": _enhance_level_to_arg, 295 | "ContrastIncreasing": _enhance_increasing_level_to_arg, 296 | "Brightness": _enhance_level_to_arg, 297 | "BrightnessIncreasing": _enhance_increasing_level_to_arg, 298 | "Sharpness": _enhance_level_to_arg, 299 | "SharpnessIncreasing": _enhance_increasing_level_to_arg, 300 | "ShearX": _shear_level_to_arg, 301 | "ShearY": _shear_level_to_arg, 302 | "TranslateX": _translate_abs_level_to_arg, 303 | "TranslateY": _translate_abs_level_to_arg, 304 | "TranslateXRel": _translate_rel_level_to_arg, 305 | "TranslateYRel": _translate_rel_level_to_arg, 306 | } 307 | 308 | 309 | NAME_TO_OP = { 310 | "AutoContrast": auto_contrast, 311 | "Equalize": equalize, 312 | "Invert": invert, 313 | "Rotate": rotate, 314 | "Posterize": posterize, 315 | "PosterizeIncreasing": posterize, 316 | "PosterizeOriginal": posterize, 317 | "Solarize": solarize, 318 | "SolarizeIncreasing": solarize, 319 | "SolarizeAdd": solarize_add, 320 | "Color": color, 321 | "ColorIncreasing": color, 322 | "Contrast": contrast, 323 | "ContrastIncreasing": contrast, 324 | "Brightness": brightness, 325 | "BrightnessIncreasing": brightness, 326 | "Sharpness": sharpness, 327 | "SharpnessIncreasing": sharpness, 328 | "ShearX": shear_x, 329 | "ShearY": shear_y, 330 | "TranslateX": translate_x_abs, 331 | "TranslateY": translate_y_abs, 332 | "TranslateXRel": translate_x_rel, 333 | "TranslateYRel": translate_y_rel, 334 | } 335 | 336 | 337 | class AugmentOp: 338 | """ 339 | Apply for video. 340 | """ 341 | 342 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None): 343 | hparams = hparams or _HPARAMS_DEFAULT 344 | self.aug_fn = NAME_TO_OP[name] 345 | self.level_fn = LEVEL_TO_ARG[name] 346 | self.prob = prob 347 | self.magnitude = magnitude 348 | self.hparams = hparams.copy() 349 | self.kwargs = { 350 | "fillcolor": hparams["img_mean"] 351 | if "img_mean" in hparams 352 | else _FILL, 353 | "resample": hparams["interpolation"] 354 | if "interpolation" in hparams 355 | else _RANDOM_INTERPOLATION, 356 | } 357 | 358 | # If magnitude_std is > 0, we introduce some randomness 359 | # in the usually fixed policy and sample magnitude from a normal distribution 360 | # with mean `magnitude` and std-dev of `magnitude_std`. 361 | # NOTE This is my own hack, being tested, not in papers or reference impls. 
362 | self.magnitude_std = self.hparams.get("magnitude_std", 0) 363 | 364 | def __call__(self, img_list): 365 | if self.prob < 1.0 and random.random() > self.prob: 366 | return img_list 367 | magnitude = self.magnitude 368 | if self.magnitude_std and self.magnitude_std > 0: 369 | magnitude = random.gauss(magnitude, self.magnitude_std) 370 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range 371 | level_args = ( 372 | self.level_fn(magnitude, self.hparams) 373 | if self.level_fn is not None 374 | else () 375 | ) 376 | 377 | if isinstance(img_list, list): 378 | return [ 379 | self.aug_fn(img, *level_args, **self.kwargs) for img in img_list 380 | ] 381 | else: 382 | return self.aug_fn(img_list, *level_args, **self.kwargs) 383 | 384 | 385 | _RAND_TRANSFORMS = [ 386 | "AutoContrast", 387 | "Equalize", 388 | "Invert", 389 | "Rotate", 390 | "Posterize", 391 | "Solarize", 392 | "SolarizeAdd", 393 | "Color", 394 | "Contrast", 395 | "Brightness", 396 | "Sharpness", 397 | "ShearX", 398 | "ShearY", 399 | "TranslateXRel", 400 | "TranslateYRel", 401 | ] 402 | 403 | 404 | _RAND_INCREASING_TRANSFORMS = [ 405 | "AutoContrast", 406 | "Equalize", 407 | "Invert", 408 | "Rotate", 409 | "PosterizeIncreasing", 410 | "SolarizeIncreasing", 411 | "SolarizeAdd", 412 | "ColorIncreasing", 413 | "ContrastIncreasing", 414 | "BrightnessIncreasing", 415 | "SharpnessIncreasing", 416 | "ShearX", 417 | "ShearY", 418 | "TranslateXRel", 419 | "TranslateYRel", 420 | ] 421 | 422 | 423 | # These experimental weights are based loosely on the relative improvements mentioned in paper. 424 | # They may not result in increased performance, but could likely be tuned to so. 425 | _RAND_CHOICE_WEIGHTS_0 = { 426 | "Rotate": 0.3, 427 | "ShearX": 0.2, 428 | "ShearY": 0.2, 429 | "TranslateXRel": 0.1, 430 | "TranslateYRel": 0.1, 431 | "Color": 0.025, 432 | "Sharpness": 0.025, 433 | "AutoContrast": 0.025, 434 | "Solarize": 0.005, 435 | "SolarizeAdd": 0.005, 436 | "Contrast": 0.005, 437 | "Brightness": 0.005, 438 | "Equalize": 0.005, 439 | "Posterize": 0, 440 | "Invert": 0, 441 | } 442 | 443 | 444 | def _select_rand_weights(weight_idx=0, transforms=None): 445 | transforms = transforms or _RAND_TRANSFORMS 446 | assert weight_idx == 0 # only one set of weights currently 447 | rand_weights = _RAND_CHOICE_WEIGHTS_0 448 | probs = [rand_weights[k] for k in transforms] 449 | probs /= np.sum(probs) 450 | return probs 451 | 452 | 453 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None): 454 | hparams = hparams or _HPARAMS_DEFAULT 455 | transforms = transforms or _RAND_TRANSFORMS 456 | return [ 457 | AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) 458 | for name in transforms 459 | ] 460 | 461 | 462 | class RandAugment: 463 | def __init__(self, ops, num_layers=2, choice_weights=None): 464 | self.ops = ops 465 | self.num_layers = num_layers 466 | self.choice_weights = choice_weights 467 | 468 | def __call__(self, img): 469 | # no replacement when using weighted choice 470 | ops = np.random.choice( 471 | self.ops, 472 | self.num_layers, 473 | replace=self.choice_weights is None, 474 | p=self.choice_weights, 475 | ) 476 | for op in ops: 477 | img = op(img) 478 | return img 479 | 480 | 481 | def rand_augment_transform(config_str, hparams): 482 | """ 483 | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 484 | 485 | Create a RandAugment transform 486 | :param config_str: String defining configuration of random augmentation. 
Consists of multiple sections separated by 487 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining 488 | sections, which are not order specific, determine 489 | 'm' - integer magnitude of rand augment 490 | 'n' - integer num layers (number of transform ops selected per image) 491 | 'w' - integer probability weight index (index of a set of weights to influence choice of op) 492 | 'mstd' - float std deviation of magnitude noise applied 493 | 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) 494 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 495 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 496 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme 497 | :return: A PyTorch compatible Transform 498 | """ 499 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) 500 | num_layers = 2 # default to 2 ops per image 501 | weight_idx = None # default to no probability weights for op choice 502 | transforms = _RAND_TRANSFORMS 503 | config = config_str.split("-") 504 | assert config[0] == "rand" 505 | config = config[1:] 506 | for c in config: 507 | cs = re.split(r"(\d.*)", c) 508 | if len(cs) < 2: 509 | continue 510 | key, val = cs[:2] 511 | if key == "mstd": 512 | # noise param injected via hparams for now 513 | hparams.setdefault("magnitude_std", float(val)) 514 | elif key == "inc": 515 | if bool(val): 516 | transforms = _RAND_INCREASING_TRANSFORMS 517 | elif key == "m": 518 | magnitude = int(val) 519 | elif key == "n": 520 | num_layers = int(val) 521 | elif key == "w": 522 | weight_idx = int(val) 523 | else: 524 | raise NotImplementedError 525 | ra_ops = rand_augment_ops( 526 | magnitude=magnitude, hparams=hparams, transforms=transforms 527 | ) 528 | choice_weights = ( 529 | None if weight_idx is None else _select_rand_weights(weight_idx) 530 | ) 531 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) 532 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import math 4 | import time 5 | import json 6 | from collections import defaultdict, deque 7 | import datetime 8 | import numpy as np 9 | from timm.utils import get_state_dict 10 | from torch.utils.data._utils.collate import default_collate 11 | from pathlib import Path 12 | import subprocess 13 | import torch 14 | import torch.distributed as dist 15 | from torch._six import inf 16 | import random 17 | 18 | from tensorboardX import SummaryWriter 19 | 20 | 21 | class SmoothedValue(object): 22 | """Track a series of values and provide access to smoothed values over a 23 | window or the global series average. 24 | """ 25 | 26 | def __init__(self, window_size=20, fmt=None): 27 | if fmt is None: 28 | fmt = "{median:.4f} ({global_avg:.4f})" 29 | self.deque = deque(maxlen=window_size) 30 | self.total = 0.0 31 | self.count = 0 32 | self.fmt = fmt 33 | 34 | def update(self, value, n=1): 35 | self.deque.append(value) 36 | self.count += n 37 | self.total += value * n 38 | 39 | def synchronize_between_processes(self): 40 | """ 41 | Warning: does not synchronize the deque!
42 | """ 43 | if not is_dist_avail_and_initialized(): 44 | return 45 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 46 | dist.barrier() 47 | dist.all_reduce(t) 48 | t = t.tolist() 49 | self.count = int(t[0]) 50 | self.total = t[1] 51 | 52 | @property 53 | def median(self): 54 | d = torch.tensor(list(self.deque)) 55 | return d.median().item() 56 | 57 | @property 58 | def avg(self): 59 | d = torch.tensor(list(self.deque), dtype=torch.float32) 60 | return d.mean().item() 61 | 62 | @property 63 | def global_avg(self): 64 | return self.total / self.count 65 | 66 | @property 67 | def max(self): 68 | return max(self.deque) 69 | 70 | @property 71 | def value(self): 72 | return self.deque[-1] 73 | 74 | def __str__(self): 75 | return self.fmt.format( 76 | median=self.median, 77 | avg=self.avg, 78 | global_avg=self.global_avg, 79 | max=self.max, 80 | value=self.value) 81 | 82 | 83 | class MetricLogger(object): 84 | def __init__(self, delimiter="\t"): 85 | self.meters = defaultdict(SmoothedValue) 86 | self.delimiter = delimiter 87 | 88 | def update(self, **kwargs): 89 | for k, v in kwargs.items(): 90 | if v is None: 91 | continue 92 | if isinstance(v, torch.Tensor): 93 | v = v.item() 94 | assert isinstance(v, (float, int)) 95 | self.meters[k].update(v) 96 | 97 | def __getattr__(self, attr): 98 | if attr in self.meters: 99 | return self.meters[attr] 100 | if attr in self.__dict__: 101 | return self.__dict__[attr] 102 | raise AttributeError("'{}' object has no attribute '{}'".format( 103 | type(self).__name__, attr)) 104 | 105 | def __str__(self): 106 | loss_str = [] 107 | for name, meter in self.meters.items(): 108 | loss_str.append( 109 | "{}: {}".format(name, str(meter)) 110 | ) 111 | return self.delimiter.join(loss_str) 112 | 113 | def synchronize_between_processes(self): 114 | for meter in self.meters.values(): 115 | meter.synchronize_between_processes() 116 | 117 | def add_meter(self, name, meter): 118 | self.meters[name] = meter 119 | 120 | def log_every(self, iterable, print_freq, header=None): 121 | i = 0 122 | if not header: 123 | header = '' 124 | start_time = time.time() 125 | end = time.time() 126 | iter_time = SmoothedValue(fmt='{avg:.4f}') 127 | data_time = SmoothedValue(fmt='{avg:.4f}') 128 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 129 | log_msg = [ 130 | header, 131 | '[{0' + space_fmt + '}/{1}]', 132 | 'eta: {eta}', 133 | '{meters}', 134 | 'time: {time}', 135 | 'data: {data}' 136 | ] 137 | if torch.cuda.is_available(): 138 | log_msg.append('max mem: {memory:.0f}') 139 | log_msg = self.delimiter.join(log_msg) 140 | MB = 1024.0 * 1024.0 141 | for obj in iterable: 142 | data_time.update(time.time() - end) 143 | yield obj 144 | iter_time.update(time.time() - end) 145 | if i % print_freq == 0 or i == len(iterable) - 1: 146 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 147 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 148 | if torch.cuda.is_available(): 149 | print(log_msg.format( 150 | i, len(iterable), eta=eta_string, 151 | meters=str(self), 152 | time=str(iter_time), data=str(data_time), 153 | memory=torch.cuda.max_memory_allocated() / MB)) 154 | else: 155 | print(log_msg.format( 156 | i, len(iterable), eta=eta_string, 157 | meters=str(self), 158 | time=str(iter_time), data=str(data_time))) 159 | i += 1 160 | end = time.time() 161 | total_time = time.time() - start_time 162 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 163 | print('{} Total time: {} ({:.4f} s / it)'.format( 164 | header, 
total_time_str, total_time / len(iterable))) 165 | 166 | 167 | class TensorboardLogger(object): 168 | def __init__(self, log_dir): 169 | self.writer = SummaryWriter(logdir=log_dir) 170 | self.step = 0 171 | 172 | def set_step(self, step=None): 173 | if step is not None: 174 | self.step = step 175 | else: 176 | self.step += 1 177 | 178 | def update(self, head='scalar', step=None, **kwargs): 179 | for k, v in kwargs.items(): 180 | if v is None: 181 | continue 182 | if isinstance(v, torch.Tensor): 183 | v = v.item() 184 | assert isinstance(v, (float, int)) 185 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step) 186 | 187 | def flush(self): 188 | self.writer.flush() 189 | 190 | def seed_worker(worker_id): 191 | worker_seed = torch.initial_seed() % 2**32 192 | np.random.seed(worker_seed) 193 | random.seed(worker_seed) 194 | 195 | def _load_checkpoint_for_ema(model_ema, checkpoint): 196 | """ 197 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 198 | """ 199 | mem_file = io.BytesIO() 200 | torch.save(checkpoint, mem_file) 201 | mem_file.seek(0) 202 | model_ema._load_checkpoint(mem_file) 203 | 204 | 205 | def setup_for_distributed(is_master): 206 | """ 207 | This function disables printing when not in master process 208 | """ 209 | import builtins as __builtin__ 210 | builtin_print = __builtin__.print 211 | 212 | def print(*args, **kwargs): 213 | force = kwargs.pop('force', False) 214 | if is_master or force: 215 | builtin_print(*args, **kwargs) 216 | 217 | __builtin__.print = print 218 | 219 | 220 | def is_dist_avail_and_initialized(): 221 | if not dist.is_available(): 222 | return False 223 | if not dist.is_initialized(): 224 | return False 225 | return True 226 | 227 | 228 | def get_world_size(): 229 | if not is_dist_avail_and_initialized(): 230 | return 1 231 | return dist.get_world_size() 232 | 233 | 234 | def get_rank(): 235 | if not is_dist_avail_and_initialized(): 236 | return 0 237 | return dist.get_rank() 238 | 239 | 240 | def is_main_process(): 241 | return get_rank() == 0 242 | 243 | 244 | def save_on_master(*args, **kwargs): 245 | if is_main_process(): 246 | torch.save(*args, **kwargs) 247 | 248 | 249 | def init_distributed_mode(args): 250 | if args.dist_on_itp: 251 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 252 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 253 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 254 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 255 | os.environ['LOCAL_RANK'] = str(args.gpu) 256 | os.environ['RANK'] = str(args.rank) 257 | os.environ['WORLD_SIZE'] = str(args.world_size) 258 | elif 'SLURM_PROCID' in os.environ: 259 | args.rank = int(os.environ['SLURM_PROCID']) 260 | args.gpu = int(os.environ['SLURM_LOCALID']) 261 | args.world_size = int(os.environ['SLURM_NTASKS']) 262 | os.environ['RANK'] = str(args.rank) 263 | os.environ['LOCAL_RANK'] = str(args.gpu) 264 | os.environ['WORLD_SIZE'] = str(args.world_size) 265 | 266 | node_list = os.environ['SLURM_NODELIST'] 267 | addr = subprocess.getoutput( 268 | f'scontrol show hostname {node_list} | head -n1') 269 | if 'MASTER_ADDR' not in os.environ: 270 | os.environ['MASTER_ADDR'] = addr 271 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 272 | args.rank = int(os.environ["RANK"]) 273 | args.world_size = int(os.environ['WORLD_SIZE']) 274 | args.gpu = int(os.environ['LOCAL_RANK']) 275 | else: 276 | print('Not using distributed mode') 277 | args.distributed = False 278 | 
return 279 | 280 | args.distributed = True 281 | 282 | torch.cuda.set_device(args.gpu) 283 | args.dist_backend = 'nccl' 284 | print('| distributed init (rank {}): {}, gpu {}'.format( 285 | args.rank, args.dist_url, args.gpu), flush=True) 286 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 287 | world_size=args.world_size, rank=args.rank) 288 | torch.distributed.barrier() 289 | # assert torch.distributed.is_initialized() 290 | setup_for_distributed(args.rank == 0) 291 | 292 | 293 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): 294 | missing_keys = [] 295 | unexpected_keys = [] 296 | error_msgs = [] 297 | metadata = getattr(state_dict, '_metadata', None) 298 | state_dict = state_dict.copy() 299 | if metadata is not None: 300 | state_dict._metadata = metadata 301 | 302 | def load(module, prefix=''): 303 | local_metadata = {} if metadata is None else metadata.get( 304 | prefix[:-1], {}) 305 | module._load_from_state_dict( 306 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 307 | for name, child in module._modules.items(): 308 | if child is not None: 309 | load(child, prefix + name + '.') 310 | 311 | load(model, prefix=prefix) 312 | 313 | warn_missing_keys = [] 314 | ignore_missing_keys = [] 315 | for key in missing_keys: 316 | keep_flag = True 317 | for ignore_key in ignore_missing.split('|'): 318 | if ignore_key in key: 319 | keep_flag = False 320 | break 321 | if keep_flag: 322 | warn_missing_keys.append(key) 323 | else: 324 | ignore_missing_keys.append(key) 325 | 326 | missing_keys = warn_missing_keys 327 | 328 | if len(missing_keys) > 0: 329 | print("Weights of {} not initialized from pretrained model: {}".format( 330 | model.__class__.__name__, missing_keys)) 331 | if len(unexpected_keys) > 0: 332 | print("Weights from pretrained model not used in {}: {}".format( 333 | model.__class__.__name__, unexpected_keys)) 334 | if len(ignore_missing_keys) > 0: 335 | print("Ignored weights of {} not initialized from pretrained model: {}".format( 336 | model.__class__.__name__, ignore_missing_keys)) 337 | if len(error_msgs) > 0: 338 | print('\n'.join(error_msgs)) 339 | 340 | 341 | class NativeScalerWithGradNormCount: 342 | state_dict_key = "amp_scaler" 343 | 344 | def __init__(self): 345 | self._scaler = torch.cuda.amp.GradScaler() 346 | 347 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 348 | self._scaler.scale(loss).backward(create_graph=create_graph) 349 | if update_grad: 350 | if clip_grad is not None: 351 | assert parameters is not None 352 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 353 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 354 | else: 355 | self._scaler.unscale_(optimizer) 356 | norm = get_grad_norm_(parameters) 357 | self._scaler.step(optimizer) 358 | self._scaler.update() 359 | else: 360 | norm = None 361 | return norm 362 | 363 | def state_dict(self): 364 | return self._scaler.state_dict() 365 | 366 | def load_state_dict(self, state_dict): 367 | self._scaler.load_state_dict(state_dict) 368 | 369 | 370 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 371 | if isinstance(parameters, torch.Tensor): 372 | parameters = [parameters] 373 | parameters = [p for p in parameters if p.grad is not None] 374 | norm_type = float(norm_type) 375 | if len(parameters) == 0: 376 | return torch.tensor(0.) 
377 | device = parameters[0].grad.device 378 | if norm_type == inf: 379 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 380 | else: 381 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 382 | return total_norm 383 | 384 | 385 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, 386 | start_warmup_value=0, warmup_steps=-1): 387 | warmup_schedule = np.array([]) 388 | warmup_iters = warmup_epochs * niter_per_ep 389 | if warmup_steps > 0: 390 | warmup_iters = warmup_steps 391 | print("Set warmup steps = %d" % warmup_iters) 392 | if warmup_epochs > 0: 393 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 394 | 395 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 396 | schedule = np.array( 397 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters]) 398 | 399 | schedule = np.concatenate((warmup_schedule, schedule)) 400 | 401 | assert len(schedule) == epochs * niter_per_ep 402 | return schedule 403 | 404 | 405 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 406 | output_dir = Path(args.output_dir) 407 | epoch_name = str(epoch) 408 | if loss_scaler is not None: 409 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 410 | for checkpoint_path in checkpoint_paths: 411 | to_save = { 412 | 'model': model_without_ddp.state_dict(), 413 | 'optimizer': optimizer.state_dict(), 414 | 'epoch': epoch, 415 | 'scaler': loss_scaler.state_dict(), 416 | 'args': args, 417 | } 418 | 419 | if model_ema is not None: 420 | to_save['model_ema'] = get_state_dict(model_ema) 421 | 422 | save_on_master(to_save, checkpoint_path) 423 | else: 424 | client_state = {'epoch': epoch} 425 | if model_ema is not None: 426 | client_state['model_ema'] = get_state_dict(model_ema) 427 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 428 | 429 | 430 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 431 | output_dir = Path(args.output_dir) 432 | if loss_scaler is not None: 433 | # torch.amp 434 | if args.auto_resume and len(args.resume) == 0: 435 | import glob 436 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) 437 | latest_ckpt = -1 438 | for ckpt in all_checkpoints: 439 | t = ckpt.split('-')[-1].split('.')[0] 440 | if t.isdigit(): 441 | latest_ckpt = max(int(t), latest_ckpt) 442 | if latest_ckpt >= 0: 443 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) 444 | print("Auto resume checkpoint: %s" % args.resume) 445 | 446 | if args.resume: 447 | if args.resume.startswith('https'): 448 | checkpoint = torch.hub.load_state_dict_from_url( 449 | args.resume, map_location='cpu', check_hash=True) 450 | else: 451 | checkpoint = torch.load(args.resume, map_location='cpu') 452 | model_without_ddp.load_state_dict(checkpoint['model']) 453 | print("Resume checkpoint %s" % args.resume) 454 | if 'optimizer' in checkpoint and 'epoch' in checkpoint: 455 | optimizer.load_state_dict(checkpoint['optimizer']) 456 | args.start_epoch = checkpoint['epoch'] + 1 457 | if hasattr(args, 'model_ema') and args.model_ema: 458 | _load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 459 | if 'scaler' in checkpoint: 460 | loss_scaler.load_state_dict(checkpoint['scaler']) 461 | print("With optim & sched!") 462 | else: 
463 | # deepspeed, only support '--auto_resume'. 464 | if args.auto_resume: 465 | import glob 466 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*')) 467 | latest_ckpt = -1 468 | for ckpt in all_checkpoints: 469 | t = ckpt.split('-')[-1].split('.')[0] 470 | if t.isdigit(): 471 | latest_ckpt = max(int(t), latest_ckpt) 472 | if latest_ckpt >= 0: 473 | args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt) 474 | print("Auto resume checkpoint: %d" % latest_ckpt) 475 | _, client_states = model.load_checkpoint(args.output_dir, tag='checkpoint-%d' % latest_ckpt) 476 | args.start_epoch = client_states['epoch'] + 1 477 | if model_ema is not None: 478 | if args.model_ema: 479 | _load_checkpoint_for_ema(model_ema, client_states['model_ema']) 480 | 481 | 482 | def create_ds_config(args): 483 | args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json") 484 | with open(args.deepspeed_config, mode="w") as writer: 485 | ds_config = { 486 | "train_batch_size": args.batch_size * args.update_freq * get_world_size(), 487 | "train_micro_batch_size_per_gpu": args.batch_size, 488 | "steps_per_print": 1000, 489 | "optimizer": { 490 | "type": "Adam", 491 | "adam_w_mode": True, 492 | "params": { 493 | "lr": args.lr, 494 | "weight_decay": args.weight_decay, 495 | "bias_correction": True, 496 | "betas": [ 497 | 0.9, 498 | 0.999 499 | ], 500 | "eps": 1e-8 501 | } 502 | }, 503 | "fp16": { 504 | "enabled": True, 505 | "loss_scale": 0, 506 | "initial_scale_power": 7, 507 | "loss_scale_window": 128 508 | } 509 | } 510 | 511 | writer.write(json.dumps(ds_config, indent=2)) 512 | 513 | def multiple_samples_collate(batch, fold=False): 514 | """ 515 | Collate function for repeated augmentation. Each instance in the batch has 516 | more than one sample. 517 | Args: 518 | batch (tuple or list): data batch to collate. 519 | Returns: 520 | (tuple): collated data batch. 
521 | """ 522 | inputs, labels, video_idx, extra_data = zip(*batch) 523 | inputs = [item for sublist in inputs for item in sublist] 524 | labels = [item for sublist in labels for item in sublist] 525 | video_idx = [item for sublist in video_idx for item in sublist] 526 | inputs, labels, video_idx, extra_data = ( 527 | default_collate(inputs), 528 | default_collate(labels), 529 | default_collate(video_idx), 530 | default_collate(extra_data), 531 | ) 532 | if fold: 533 | return [inputs], labels, video_idx, extra_data 534 | else: 535 | return inputs, labels, video_idx, extra_data 536 | -------------------------------------------------------------------------------- /ucf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from numpy.lib.function_base import disp 4 | import torch 5 | import decord 6 | from PIL import Image 7 | from torchvision import transforms 8 | from random_erasing import RandomErasing 9 | import warnings 10 | from decord import VideoReader, cpu 11 | from torch.utils.data import Dataset 12 | import video_transforms as video_transforms 13 | import volume_transforms as volume_transforms 14 | 15 | class VideoClsDataset(Dataset): 16 | """Load your own video classification dataset.""" 17 | 18 | def __init__(self, anno_path, data_path, mode='train', clip_len=8, 19 | frame_sample_rate=2, crop_size=224, short_side_size=256, 20 | new_height=256, new_width=340, keep_aspect_ratio=True, 21 | num_segment=1, num_crop=1, test_num_segment=10, test_num_crop=3,args=None): 22 | self.anno_path = anno_path 23 | self.data_path = data_path 24 | self.mode = mode 25 | self.clip_len = clip_len 26 | self.frame_sample_rate = frame_sample_rate 27 | self.crop_size = crop_size 28 | self.short_side_size = short_side_size 29 | self.new_height = new_height 30 | self.new_width = new_width 31 | self.keep_aspect_ratio = keep_aspect_ratio 32 | self.num_segment = num_segment 33 | self.test_num_segment = test_num_segment 34 | self.num_crop = num_crop 35 | self.test_num_crop = test_num_crop 36 | self.args = args 37 | self.aug = False 38 | self.rand_erase = False 39 | if self.mode in ['train']: 40 | self.aug = True 41 | if self.args.reprob > 0: 42 | self.rand_erase = True 43 | if VideoReader is None: 44 | raise ImportError("Unable to import `decord` which is required to read videos.") 45 | 46 | import pandas as pd 47 | cleaned = pd.read_csv(self.anno_path, header=None, delimiter=' ') 48 | self.dataset_samples = list(cleaned.values[:, 0]) 49 | self.label_array = list(cleaned.values[:, 1]) 50 | 51 | if (mode == 'train'): 52 | pass 53 | 54 | elif (mode == 'validation'): 55 | self.data_transform = video_transforms.Compose([ 56 | video_transforms.Resize(self.short_side_size, interpolation='bilinear'), 57 | video_transforms.CenterCrop(size=(self.crop_size, self.crop_size)), 58 | volume_transforms.ClipToTensor(), 59 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406], 60 | std=[0.229, 0.224, 0.225]) 61 | ]) 62 | elif mode == 'test': 63 | self.data_resize = video_transforms.Compose([ 64 | video_transforms.Resize(size=(short_side_size), interpolation='bilinear') 65 | ]) 66 | self.data_transform = video_transforms.Compose([ 67 | volume_transforms.ClipToTensor(), 68 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406], 69 | std=[0.229, 0.224, 0.225]) 70 | ]) 71 | self.test_seg = [] 72 | self.test_dataset = [] 73 | self.test_label_array = [] 74 | for ck in range(self.test_num_segment): 75 | for cp in range(self.test_num_crop): 76 | for idx in 
range(len(self.label_array)): 77 | sample_label = self.label_array[idx] 78 | self.test_label_array.append(sample_label) 79 | self.test_dataset.append(self.dataset_samples[idx]) 80 | self.test_seg.append((ck, cp)) 81 | 82 | def __getitem__(self, index): 83 | if self.mode == 'train': 84 | args = self.args 85 | scale_t = 1 86 | 87 | sample = self.dataset_samples[index] 88 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) # T H W C 89 | if len(buffer) == 0: 90 | while len(buffer) == 0: 91 | warnings.warn("video {} not correctly loaded during training".format(sample)) 92 | index = np.random.randint(self.__len__()) 93 | sample = self.dataset_samples[index] 94 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) 95 | 96 | if args.num_sample > 1: 97 | frame_list = [] 98 | label_list = [] 99 | index_list = [] 100 | for _ in range(args.num_sample): 101 | new_frames = self._aug_frame(buffer, args) 102 | label = self.label_array[index] 103 | frame_list.append(new_frames) 104 | label_list.append(label) 105 | index_list.append(index) 106 | return frame_list, label_list, index_list, {} 107 | else: 108 | buffer = self._aug_frame(buffer, args) 109 | 110 | return buffer, self.label_array[index], index, {} 111 | 112 | elif self.mode == 'validation': 113 | sample = self.dataset_samples[index] 114 | buffer = self.loadvideo_decord(sample) 115 | if len(buffer) == 0: 116 | while len(buffer) == 0: 117 | warnings.warn("video {} not correctly loaded during validation".format(sample)) 118 | index = np.random.randint(self.__len__()) 119 | sample = self.dataset_samples[index] 120 | buffer = self.loadvideo_decord(sample) 121 | buffer = self.data_transform(buffer) 122 | return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0] 123 | 124 | elif self.mode == 'test': 125 | sample = self.test_dataset[index] 126 | chunk_nb, split_nb = self.test_seg[index] 127 | buffer = self.loadvideo_decord(sample) 128 | 129 | while len(buffer) == 0: 130 | warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\ 131 | str(self.test_dataset[index]), chunk_nb, split_nb)) 132 | index = np.random.randint(self.__len__()) 133 | sample = self.test_dataset[index] 134 | chunk_nb, split_nb = self.test_seg[index] 135 | buffer = self.loadvideo_decord(sample) 136 | 137 | buffer = self.data_resize(buffer) 138 | if isinstance(buffer, list): 139 | buffer = np.stack(buffer, 0) 140 | 141 | spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \ 142 | / (self.test_num_crop - 1) 143 | temporal_step = max(1.0 * (buffer.shape[0] - self.clip_len) \ 144 | / (self.test_num_segment - 1), 0) 145 | temporal_start = int(chunk_nb * temporal_step) 146 | spatial_start = int(split_nb * spatial_step) 147 | if buffer.shape[1] >= buffer.shape[2]: 148 | buffer = buffer[temporal_start:temporal_start + self.clip_len, \ 149 | spatial_start:spatial_start + self.short_side_size, :, :] 150 | else: 151 | buffer = buffer[temporal_start:temporal_start + self.clip_len, \ 152 | :, spatial_start:spatial_start + self.short_side_size, :] 153 | 154 | buffer = self.data_transform(buffer) 155 | return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \ 156 | chunk_nb, split_nb 157 | else: 158 | raise NameError('mode {} unkown'.format(self.mode)) 159 | 160 | def _aug_frame( 161 | self, 162 | buffer, 163 | args, 164 | ): 165 | 166 | aug_transform = video_transforms.create_random_augment( 167 | input_size=(self.crop_size, self.crop_size), 168 | auto_augment=args.aa, 
169 | interpolation=args.train_interpolation, 170 | ) 171 | 172 | buffer = [ 173 | transforms.ToPILImage()(frame) for frame in buffer 174 | ] 175 | 176 | buffer = aug_transform(buffer) 177 | 178 | buffer = [transforms.ToTensor()(img) for img in buffer] 179 | buffer = torch.stack(buffer) # T C H W 180 | buffer = buffer.permute(0, 2, 3, 1) # T H W C 181 | 182 | # T H W C 183 | buffer = tensor_normalize( 184 | buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 185 | ) 186 | # T H W C -> C T H W. 187 | buffer = buffer.permute(3, 0, 1, 2) 188 | # Perform data augmentation. 189 | scl, asp = ( 190 | [0.08, 1.0], 191 | [0.75, 1.3333], 192 | ) 193 | 194 | buffer = spatial_sampling( 195 | buffer, 196 | spatial_idx=-1, 197 | min_scale=256, 198 | max_scale=320, 199 | crop_size=self.crop_size, 200 | random_horizontal_flip=False if args.data_set == 'SSV2' else True , 201 | inverse_uniform_sampling=False, 202 | aspect_ratio=asp, 203 | scale=scl, 204 | motion_shift=False 205 | ) 206 | 207 | if self.rand_erase: 208 | erase_transform = RandomErasing( 209 | args.reprob, 210 | mode=args.remode, 211 | max_count=args.recount, 212 | num_splits=args.recount, 213 | device="cpu", 214 | ) 215 | buffer = buffer.permute(1, 0, 2, 3) 216 | buffer = erase_transform(buffer) 217 | buffer = buffer.permute(1, 0, 2, 3) 218 | 219 | return buffer 220 | 221 | 222 | def loadvideo_decord(self, sample, sample_rate_scale=1): 223 | """Load video content using Decord""" 224 | fname = sample 225 | 226 | if not (os.path.exists(fname)): 227 | return [] 228 | 229 | # avoid hanging issue 230 | if os.path.getsize(fname) < 1 * 1024: 231 | print('SKIP: ', fname, " - ", os.path.getsize(fname)) 232 | return [] 233 | try: 234 | if self.keep_aspect_ratio: 235 | vr = VideoReader(fname, num_threads=1, ctx=cpu(0)) 236 | else: 237 | vr = VideoReader(fname, width=self.new_width, height=self.new_height, 238 | num_threads=1, ctx=cpu(0)) 239 | except: 240 | print("video cannot be loaded by decord: ", fname) 241 | return [] 242 | 243 | if self.mode == 'test': 244 | all_index = [x for x in range(0, len(vr), self.frame_sample_rate)] 245 | while len(all_index) < self.clip_len: 246 | all_index.append(all_index[-1]) 247 | vr.seek(0) 248 | buffer = vr.get_batch(all_index).asnumpy() 249 | return buffer 250 | 251 | # handle temporal segments 252 | converted_len = int(self.clip_len * self.frame_sample_rate) 253 | seg_len = len(vr) // self.num_segment 254 | 255 | all_index = [] 256 | for i in range(self.num_segment): 257 | if seg_len <= converted_len: 258 | index = np.linspace(0, seg_len, num=seg_len // self.frame_sample_rate) 259 | index = np.concatenate((index, np.ones(self.clip_len - seg_len // self.frame_sample_rate) * seg_len)) 260 | index = np.clip(index, 0, seg_len - 1).astype(np.int64) 261 | else: 262 | end_idx = np.random.randint(converted_len, seg_len) 263 | str_idx = end_idx - converted_len 264 | index = np.linspace(str_idx, end_idx, num=self.clip_len) 265 | index = np.clip(index, str_idx, end_idx - 1).astype(np.int64) 266 | index = index + i*seg_len 267 | all_index.extend(list(index)) 268 | 269 | all_index = all_index[::int(sample_rate_scale)] 270 | vr.seek(0) 271 | buffer = vr.get_batch(all_index).asnumpy() 272 | return buffer 273 | 274 | def __len__(self): 275 | if self.mode != 'test': 276 | return len(self.dataset_samples) 277 | else: 278 | return len(self.test_dataset) 279 | 280 | 281 | def spatial_sampling( 282 | frames, 283 | spatial_idx=-1, 284 | min_scale=256, 285 | max_scale=320, 286 | crop_size=224, 287 | random_horizontal_flip=True, 288 | 
inverse_uniform_sampling=False, 289 | aspect_ratio=None, 290 | scale=None, 291 | motion_shift=False, 292 | ): 293 | """ 294 | Perform spatial sampling on the given video frames. If spatial_idx is 295 | -1, perform random scale, random crop, and random flip on the given 296 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling 297 | with the given spatial_idx. 298 | Args: 299 | frames (tensor): frames of images sampled from the video. The 300 | dimension is `num frames` x `height` x `width` x `channel`. 301 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1, 302 | or 2, perform left, center, right crop if width is larger than 303 | height, and perform top, center, buttom crop if height is larger 304 | than width. 305 | min_scale (int): the minimal size of scaling. 306 | max_scale (int): the maximal size of scaling. 307 | crop_size (int): the size of height and width used to crop the 308 | frames. 309 | inverse_uniform_sampling (bool): if True, sample uniformly in 310 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the 311 | scale. If False, take a uniform sample from [min_scale, 312 | max_scale]. 313 | aspect_ratio (list): Aspect ratio range for resizing. 314 | scale (list): Scale range for resizing. 315 | motion_shift (bool): Whether to apply motion shift for resizing. 316 | Returns: 317 | frames (tensor): spatially sampled frames. 318 | """ 319 | assert spatial_idx in [-1, 0, 1, 2] 320 | if spatial_idx == -1: 321 | if aspect_ratio is None and scale is None: 322 | frames, _ = video_transforms.random_short_side_scale_jitter( 323 | images=frames, 324 | min_size=min_scale, 325 | max_size=max_scale, 326 | inverse_uniform_sampling=inverse_uniform_sampling, 327 | ) 328 | frames, _ = video_transforms.random_crop(frames, crop_size) 329 | else: 330 | transform_func = ( 331 | video_transforms.random_resized_crop_with_shift 332 | if motion_shift 333 | else video_transforms.random_resized_crop 334 | ) 335 | frames = transform_func( 336 | images=frames, 337 | target_height=crop_size, 338 | target_width=crop_size, 339 | scale=scale, 340 | ratio=aspect_ratio, 341 | ) 342 | if random_horizontal_flip: 343 | frames, _ = video_transforms.horizontal_flip(0.5, frames) 344 | else: 345 | # The testing is deterministic and no jitter should be performed. 346 | # min_scale, max_scale, and crop_size are expect to be the same. 347 | assert len({min_scale, max_scale, crop_size}) == 1 348 | frames, _ = video_transforms.random_short_side_scale_jitter( 349 | frames, min_scale, max_scale 350 | ) 351 | frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx) 352 | return frames 353 | 354 | 355 | def tensor_normalize(tensor, mean, std): 356 | """ 357 | Normalize a given tensor by subtracting the mean and dividing the std. 358 | Args: 359 | tensor (tensor): tensor to normalize. 360 | mean (tensor or list): mean value to subtract. 361 | std (tensor or list): std to divide. 362 | """ 363 | if tensor.dtype == torch.uint8: 364 | tensor = tensor.float() 365 | tensor = tensor / 255.0 366 | if type(mean) == list: 367 | mean = torch.tensor(mean) 368 | if type(std) == list: 369 | std = torch.tensor(std) 370 | tensor = tensor - mean 371 | tensor = tensor / std 372 | return tensor 373 | 374 | 375 | class VideoMS(torch.utils.data.Dataset): 376 | """Load your own video classification dataset. 377 | Parameters 378 | ---------- 379 | root : str, required. 380 | Path to the root folder storing the dataset. 381 | setting : str, required. 
382 | A text file describing the dataset, with one line per video sample. 383 | There are three items in each line: (1) video path; (2) video length and (3) video label. 384 | train : bool, default True. 385 | Whether to load the training or validation set. 386 | test_mode : bool, default False. 387 | Whether to perform evaluation on the test set. 388 | Usually a three-crop or ten-crop evaluation strategy is involved. 389 | name_pattern : str, default None. 390 | The naming pattern of the decoded video frames. 391 | For example, img_00012.jpg. 392 | video_ext : str, default 'mp4'. 393 | If video_loader is set to True, please specify the video format accordingly. 394 | is_color : bool, default True. 395 | Whether the loaded image is color or grayscale. 396 | modality : str, default 'rgb'. 397 | Input modalities; we support only rgb video frames for now. 398 | Will add support for rgb difference image and optical flow image later. 399 | num_segments : int, default 1. 400 | Number of segments to evenly divide the video into clips. 401 | A useful technique to obtain global video-level information. 402 | Limin Wang et al., Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016. 403 | num_crop : int, default 1. 404 | Number of crops for each image. Default is 1. 405 | Common choices are three crops and ten crops during evaluation. 406 | new_length : int, default 1. 407 | The length of the input video clip. Default is a single image, but it can be multiple video frames. 408 | For example, new_length=16 means we will extract a video clip of 16 consecutive frames. 409 | new_step : int, default 1. 410 | Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames. 411 | new_step=2 means we will extract a video clip of every other frame. 412 | temporal_jitter : bool, default False. 413 | Whether to temporally jitter if new_step > 1. 414 | video_loader : bool, default False. 415 | Whether to use video loader to load data. 416 | use_decord : bool, default True. 417 | Whether to use the Decord video loader to load data. Otherwise use the mmcv video loader. 418 | transform : function, default None. 419 | A function that takes data and label and transforms them. 420 | data_aug : str, default 'v1'. 421 | Different types of automatic data augmentation. Supports v1, v2, v3 and v4. 422 | lazy_init : bool, default False. 423 | If set to True, build a dataset instance without loading any dataset.
424 | """ 425 | def __init__(self, 426 | root, 427 | setting, 428 | train=True, 429 | test_mode=False, 430 | name_pattern='img_%05d.jpg', 431 | video_ext='mp4', 432 | is_color=True, 433 | modality='rgb', 434 | num_segments=1, 435 | num_crop=1, 436 | new_length=1, 437 | new_step=1, 438 | transform=None, 439 | temporal_jitter=False, 440 | video_loader=False, 441 | use_decord=False, 442 | lazy_init=False): 443 | 444 | super(VideoMS, self).__init__() 445 | self.root = root 446 | self.setting = setting 447 | self.train = train 448 | self.test_mode = test_mode 449 | self.is_color = is_color 450 | self.modality = modality 451 | self.num_segments = num_segments 452 | self.num_crop = num_crop 453 | self.new_length = new_length 454 | self.new_step = new_step 455 | self.skip_length = self.new_length * self.new_step 456 | self.temporal_jitter = temporal_jitter 457 | self.name_pattern = name_pattern 458 | self.video_loader = video_loader 459 | self.video_ext = video_ext 460 | self.use_decord = use_decord 461 | self.transform = transform 462 | self.lazy_init = lazy_init 463 | 464 | 465 | if not self.lazy_init: 466 | self.clips = self._make_dataset(root, setting) 467 | if len(self.clips) == 0: 468 | raise(RuntimeError("Found 0 video clips in subfolders of: " + root + "\n" 469 | "Check your data directory (opt.data-dir).")) 470 | 471 | def __getitem__(self, index): 472 | 473 | directory, target = self.clips[index] 474 | if self.video_loader: 475 | if '.' in directory.split('/')[-1]: 476 | # data in the "setting" file already have extension, e.g., demo.mp4 477 | video_name = directory 478 | else: 479 | # data in the "setting" file do not have extension, e.g., demo 480 | # So we need to provide extension (i.e., .mp4) to complete the file name. 481 | video_name = '{}.{}'.format(directory, self.video_ext) 482 | 483 | decord_vr = decord.VideoReader(video_name, num_threads=1) 484 | duration = len(decord_vr) 485 | 486 | segment_indices, skip_offsets = self._sample_train_indices(duration) 487 | 488 | images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices, skip_offsets) 489 | 490 | process_data, mask = self.transform((images, None)) # T*C,H,W 491 | process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0,1) # T*C,H,W -> T,C,H,W -> C,T,H,W 492 | 493 | return (process_data, mask) 494 | 495 | def __len__(self): 496 | return len(self.clips) 497 | 498 | def _make_dataset(self, directory, setting): 499 | if not os.path.exists(setting): 500 | raise(RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting))) 501 | clips = [] 502 | with open(setting) as split_f: 503 | data = split_f.readlines() 504 | for line in data: 505 | line_info = line.split(' ') 506 | # line format: video_path, video_duration, video_label 507 | if len(line_info) < 2: 508 | raise(RuntimeError('Video input format is not correct, missing one or more element. 
%s' % line)) 509 | clip_path = os.path.join(line_info[0]) 510 | target = int(line_info[1]) 511 | item = (clip_path, target) 512 | clips.append(item) 513 | return clips 514 | 515 | def _sample_train_indices(self, num_frames): 516 | average_duration = (num_frames - self.skip_length + 1) // self.num_segments 517 | if average_duration > 0: 518 | offsets = np.multiply(list(range(self.num_segments)), 519 | average_duration) 520 | offsets = offsets + np.random.randint(average_duration, 521 | size=self.num_segments) 522 | elif num_frames > max(self.num_segments, self.skip_length): 523 | offsets = np.sort(np.random.randint( 524 | num_frames - self.skip_length + 1, 525 | size=self.num_segments)) 526 | else: 527 | offsets = np.zeros((self.num_segments,)) 528 | 529 | if self.temporal_jitter: 530 | skip_offsets = np.random.randint( 531 | self.new_step, size=self.skip_length // self.new_step) 532 | else: 533 | skip_offsets = np.zeros( 534 | self.skip_length // self.new_step, dtype=int) 535 | return offsets + 1, skip_offsets 536 | 537 | 538 | def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets): 539 | sampled_list = [] 540 | frame_id_list = [] 541 | for seg_ind in indices: 542 | offset = int(seg_ind) 543 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)): 544 | if offset + skip_offsets[i] <= duration: 545 | frame_id = offset + skip_offsets[i] - 1 546 | else: 547 | frame_id = offset - 1 548 | frame_id_list.append(frame_id) 549 | if offset + self.new_step < duration: 550 | offset += self.new_step 551 | try: 552 | video_data = video_reader.get_batch(frame_id_list).asnumpy() 553 | sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)] 554 | except: 555 | raise RuntimeError('Error occured in reading frames {} from video {} of duration {}.'.format(frame_id_list, directory, duration)) 556 | return sampled_list 557 | -------------------------------------------------------------------------------- /run_class_finetuning.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import numpy as np 4 | import time 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import json 8 | import os 9 | from functools import partial 10 | from pathlib import Path 11 | from collections import OrderedDict 12 | 13 | from mixup import Mixup 14 | from timm.models import create_model 15 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 16 | from timm.utils import ModelEma 17 | from optim_factory import create_optimizer, get_parameter_groups, LayerDecayValueAssigner 18 | 19 | from datasets import build_dataset 20 | from engine_for_finetuning import train_one_epoch, validation_one_epoch, final_test, merge 21 | from utils import NativeScalerWithGradNormCount as NativeScaler 22 | from utils import multiple_samples_collate 23 | import utils 24 | import modeling_finetune 25 | 26 | 27 | def get_args(): 28 | parser = argparse.ArgumentParser('VideoMS fine-tuning and evaluation script for video classification', add_help=False) 29 | parser.add_argument('--batch_size', default=64, type=int) 30 | parser.add_argument('--epochs', default=30, type=int) 31 | parser.add_argument('--update_freq', default=1, type=int) 32 | parser.add_argument('--save_ckpt_freq', default=100, type=int) 33 | 34 | # Model parameters 35 | parser.add_argument('--model', default='vit_base_patch16_224', type=str, metavar='MODEL', 36 | help='Name of model 
to train') 37 | parser.add_argument('--tubelet_size', type=int, default= 2) 38 | parser.add_argument('--input_size', default=224, type=int, 39 | help='videos input size') 40 | 41 | parser.add_argument('--fc_drop_rate', type=float, default=0.0, metavar='PCT', 42 | help='Dropout rate (default: 0.)') 43 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 44 | help='Dropout rate (default: 0.)') 45 | parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT', 46 | help='Attention dropout rate (default: 0.)') 47 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', 48 | help='Drop path rate (default: 0.1)') 49 | 50 | parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False) 51 | parser.add_argument('--model_ema', action='store_true', default=False) 52 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='') 53 | parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='') 54 | 55 | # Optimizer parameters 56 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 57 | help='Optimizer (default: "adamw"') 58 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', 59 | help='Optimizer Epsilon (default: 1e-8)') 60 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', 61 | help='Optimizer Betas (default: None, use opt default)') 62 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 63 | help='Clip gradient norm (default: None, no clipping)') 64 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 65 | help='SGD momentum (default: 0.9)') 66 | parser.add_argument('--weight_decay', type=float, default=0.05, 67 | help='weight decay (default: 0.05)') 68 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the 69 | weight decay. We use a cosine schedule for WD and using a larger decay by 70 | the end of training improves performance for ViTs.""") 71 | 72 | parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', 73 | help='learning rate (default: 1e-3)') 74 | parser.add_argument('--layer_decay', type=float, default=0.75) 75 | 76 | parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR', 77 | help='warmup learning rate (default: 1e-6)') 78 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 79 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 80 | 81 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', 82 | help='epochs to warmup LR, if scheduler supports') 83 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', 84 | help='num of steps to warmup LR, will overload warmup_epochs if set > 0') 85 | 86 | # Augmentation parameters 87 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', 88 | help='Color jitter factor (default: 0.4)') 89 | parser.add_argument('--num_sample', type=int, default=2, 90 | help='Repeated_aug (default: 2)') 91 | parser.add_argument('--aa', type=str, default='rand-m7-n4-mstd0.5-inc1', metavar='NAME', 92 | help='Use AutoAugment policy. "v0" or "original". 
" + "(default: rand-m7-n4-mstd0.5-inc1)'), 93 | parser.add_argument('--smoothing', type=float, default=0.1, 94 | help='Label smoothing (default: 0.1)') 95 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 96 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 97 | 98 | # Evaluation parameters 99 | parser.add_argument('--crop_pct', type=float, default=None) 100 | parser.add_argument('--short_side_size', type=int, default=224) 101 | parser.add_argument('--test_num_segment', type=int, default=5) 102 | parser.add_argument('--test_num_crop', type=int, default=3) 103 | 104 | # Random Erase params 105 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 106 | help='Random erase prob (default: 0.25)') 107 | parser.add_argument('--remode', type=str, default='pixel', 108 | help='Random erase mode (default: "pixel")') 109 | parser.add_argument('--recount', type=int, default=1, 110 | help='Random erase count (default: 1)') 111 | parser.add_argument('--resplit', action='store_true', default=False, 112 | help='Do not random erase first (clean) augmentation split') 113 | 114 | # Mixup params 115 | parser.add_argument('--mixup', type=float, default=0.0, 116 | help='mixup alpha, mixup enabled if > 0.') 117 | parser.add_argument('--cutmix', type=float, default=0.0, 118 | help='cutmix alpha, cutmix enabled if > 0.') 119 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 120 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 121 | parser.add_argument('--mixup_prob', type=float, default=1.0, 122 | help='Probability of performing mixup or cutmix when either/both is enabled') 123 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 124 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 125 | parser.add_argument('--mixup_mode', type=str, default='batch', 126 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 127 | 128 | # Finetuning params 129 | parser.add_argument('--finetune', default='', help='finetune from checkpoint') 130 | parser.add_argument('--model_key', default='model|module', type=str) 131 | parser.add_argument('--model_prefix', default='', type=str) 132 | parser.add_argument('--init_scale', default=0.001, type=float) 133 | parser.add_argument('--use_checkpoint', action='store_true') 134 | parser.set_defaults(use_checkpoint=False) 135 | parser.add_argument('--use_mean_pooling', action='store_true') 136 | parser.set_defaults(use_mean_pooling=True) 137 | parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling') 138 | 139 | # Dataset parameters 140 | parser.add_argument('--data_path', default='/path/to/list_ucf101', type=str, 141 | help='dataset path') 142 | parser.add_argument('--eval_data_path', default=None, type=str, 143 | help='dataset path for evaluation') 144 | parser.add_argument('--nb_classes', default=101, type=int, 145 | help='number of the classification types') 146 | parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true') 147 | parser.add_argument('--num_segments', type=int, default= 1) 148 | parser.add_argument('--num_frames', type=int, default= 16) 149 | parser.add_argument('--sampling_rate', type=int, default= 4) 150 | parser.add_argument('--data_set', default='UCF101', choices=['UCF101', 'HMDB51', 'OSCC'], 151 | type=str, help='dataset') 152 | parser.add_argument('--output_dir', default='', 153 | help='path where to save, empty for no saving') 154 | parser.add_argument('--log_dir', default=None, 155 | help='path where to tensorboard log') 156 | parser.add_argument('--device', default='cuda', 157 | help='device to use for training / testing') 158 | parser.add_argument('--seed', default=0, type=int) 159 | parser.add_argument('--resume', default='', 160 | help='resume from checkpoint') 161 | parser.add_argument('--auto_resume', action='store_true') 162 | parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume') 163 | parser.set_defaults(auto_resume=True) 164 | 165 | parser.add_argument('--save_ckpt', action='store_true') 166 | parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt') 167 | parser.set_defaults(save_ckpt=True) 168 | 169 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 170 | help='start epoch') 171 | parser.add_argument('--eval', action='store_true', 172 | help='Perform evaluation only') 173 | parser.add_argument('--dist_eval', action='store_true', default=False, 174 | help='Enabling distributed evaluation') 175 | parser.add_argument('--num_workers', default=10, type=int) 176 | parser.add_argument('--pin_mem', action='store_true', 177 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 178 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 179 | parser.set_defaults(pin_mem=True) 180 | 181 | # distributed training parameters 182 | parser.add_argument('--world_size', default=1, type=int, 183 | help='number of distributed processes') 184 | parser.add_argument('--local_rank', default=-1, type=int) 185 | parser.add_argument('--dist_on_itp', action='store_true') 186 | parser.add_argument('--dist_url', default='env://', 187 | help='url used to set up distributed training') 188 | 189 | parser.add_argument('--enable_deepspeed', action='store_true', default=False) 190 | 191 | parser.add_argument('--mcm', action='store_true', default=False) 192 | 
parser.add_argument('--mcm_ratio', default=0.4, type=float) 193 | 194 | known_args, _ = parser.parse_known_args() 195 | 196 | if known_args.enable_deepspeed: 197 | try: 198 | import deepspeed 199 | from deepspeed import DeepSpeedConfig 200 | parser = deepspeed.add_config_arguments(parser) 201 | ds_init = deepspeed.initialize 202 | except: 203 | print("Please 'pip install deepspeed'") 204 | exit(0) 205 | else: 206 | ds_init = None 207 | 208 | return parser.parse_args(), ds_init 209 | 210 | 211 | def main(args, ds_init): 212 | utils.init_distributed_mode(args) 213 | 214 | if ds_init is not None: 215 | utils.create_ds_config(args) 216 | 217 | print(args) 218 | 219 | device = torch.device(args.device) 220 | 221 | # fix the seed for reproducibility 222 | seed = args.seed + utils.get_rank() 223 | torch.manual_seed(seed) 224 | np.random.seed(seed) 225 | # random.seed(seed) 226 | 227 | cudnn.benchmark = True 228 | 229 | dataset_train, args.nb_classes = build_dataset(is_train=True, test_mode=False, args=args) 230 | if args.disable_eval_during_finetuning: 231 | dataset_val = None 232 | else: 233 | dataset_val, _ = build_dataset(is_train=False, test_mode=False, args=args) 234 | dataset_test, _ = build_dataset(is_train=False, test_mode=True, args=args) 235 | 236 | 237 | num_tasks = utils.get_world_size() 238 | global_rank = utils.get_rank() 239 | sampler_train = torch.utils.data.DistributedSampler( 240 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 241 | ) 242 | print("Sampler_train = %s" % str(sampler_train)) 243 | if args.dist_eval: 244 | if len(dataset_val) % num_tasks != 0: 245 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 246 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 247 | 'equal num of samples per-process.') 248 | sampler_val = torch.utils.data.DistributedSampler( 249 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 250 | sampler_test = torch.utils.data.DistributedSampler( 251 | dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=False) 252 | else: 253 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 254 | 255 | if global_rank == 0 and args.log_dir is not None: 256 | os.makedirs(args.log_dir, exist_ok=True) 257 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir) 258 | else: 259 | log_writer = None 260 | 261 | if args.num_sample > 1: 262 | collate_func = partial(multiple_samples_collate, fold=False) 263 | else: 264 | collate_func = None 265 | 266 | data_loader_train = torch.utils.data.DataLoader( 267 | dataset_train, sampler=sampler_train, 268 | batch_size=args.batch_size, 269 | num_workers=args.num_workers, 270 | pin_memory=args.pin_mem, 271 | drop_last=True, 272 | collate_fn=collate_func, 273 | ) 274 | 275 | if dataset_val is not None: 276 | data_loader_val = torch.utils.data.DataLoader( 277 | dataset_val, sampler=sampler_val, 278 | batch_size=int(1.5 * args.batch_size), 279 | num_workers=args.num_workers, 280 | pin_memory=args.pin_mem, 281 | drop_last=False 282 | ) 283 | else: 284 | data_loader_val = None 285 | 286 | if dataset_test is not None: 287 | data_loader_test = torch.utils.data.DataLoader( 288 | dataset_test, sampler=sampler_test, 289 | batch_size=args.batch_size, 290 | num_workers=args.num_workers, 291 | pin_memory=args.pin_mem, 292 | drop_last=False 293 | ) 294 | else: 295 | data_loader_test = None 296 | 297 | mixup_fn = None 298 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 299 | if mixup_active: 300 | print("Mixup is activated!") 301 | mixup_fn = Mixup( 302 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 303 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 304 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 305 | 306 | model = create_model( 307 | args.model, 308 | pretrained=False, 309 | num_classes=args.nb_classes, 310 | all_frames=args.num_frames * args.num_segments, 311 | tubelet_size=args.tubelet_size, 312 | fc_drop_rate=args.fc_drop_rate, 313 | drop_rate=args.drop, 314 | drop_path_rate=args.drop_path, 315 | attn_drop_rate=args.attn_drop_rate, 316 | drop_block_rate=None, 317 | use_checkpoint=args.use_checkpoint, 318 | use_mean_pooling=args.use_mean_pooling, 319 | init_scale=args.init_scale, 320 | mcm=args.mcm, 321 | mcm_ratio=args.mcm_ratio 322 | ) 323 | 324 | patch_size = model.patch_embed.patch_size 325 | print("Patch size = %s" % str(patch_size)) 326 | args.window_size = (args.num_frames // 2, args.input_size // patch_size[0], args.input_size // patch_size[1]) 327 | args.patch_size = patch_size 328 | 329 | if args.finetune: 330 | if args.finetune.startswith('https'): 331 | checkpoint = torch.hub.load_state_dict_from_url( 332 | args.finetune, map_location='cpu', check_hash=True) 333 | else: 334 | checkpoint = torch.load(args.finetune, map_location='cpu') 335 | 336 | print("Load ckpt from %s" % args.finetune) 337 | checkpoint_model = None 338 | for model_key in args.model_key.split('|'): 339 | if model_key in checkpoint: 340 | checkpoint_model = checkpoint[model_key] 341 | print("Load state_dict by model_key = %s" % model_key) 342 | break 343 | if checkpoint_model is None: 344 | checkpoint_model = checkpoint 345 | state_dict = model.state_dict() 346 | for k in ['head.weight', 'head.bias']: 347 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 348 | print(f"Removing key {k} from pretrained checkpoint") 349 | del checkpoint_model[k] 350 | 351 | all_keys = list(checkpoint_model.keys()) 352 | new_dict = OrderedDict() 353 | for key in all_keys: 354 | if key.startswith('backbone.'): 355 | new_dict[key[9:]] = checkpoint_model[key] 356 | elif key.startswith('encoder.'): 357 | new_dict[key[8:]] = checkpoint_model[key] 358 | else: 359 | new_dict[key] = checkpoint_model[key] 360 | checkpoint_model = new_dict 361 | 362 | # interpolate position embedding 363 | if 'pos_embed' in checkpoint_model: 364 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 365 | embedding_size = pos_embed_checkpoint.shape[-1] # channel dim 366 | num_patches = model.patch_embed.num_patches # 367 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1 368 | 369 | # height (== width) for the checkpoint position embedding 370 | orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(args.num_frames // model.patch_embed.tubelet_size)) ** 0.5) 371 | # height (== width) for the new position embedding 372 | new_size = int((num_patches // (args.num_frames // model.patch_embed.tubelet_size) )** 0.5) 373 | # class_token and dist_token are kept unchanged 374 | if orig_size != new_size: 375 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 376 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 377 | # only the position tokens are interpolated 378 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 379 | # B, L, C -> BT, H, W, C -> BT, C, H, W 380 | pos_tokens = 
pos_tokens.reshape(-1, args.num_frames // model.patch_embed.tubelet_size, orig_size, orig_size, embedding_size) 381 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 382 | pos_tokens = torch.nn.functional.interpolate( 383 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 384 | # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C 385 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, args.num_frames // model.patch_embed.tubelet_size, new_size, new_size, embedding_size) 386 | pos_tokens = pos_tokens.flatten(1, 3) # B, L, C 387 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 388 | checkpoint_model['pos_embed'] = new_pos_embed 389 | 390 | utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix) 391 | 392 | model.to(device) 393 | 394 | model_ema = None 395 | if args.model_ema: 396 | model_ema = ModelEma( 397 | model, 398 | decay=args.model_ema_decay, 399 | device='cpu' if args.model_ema_force_cpu else '', 400 | resume='') 401 | print("Using EMA with decay = %.8f" % args.model_ema_decay) 402 | 403 | model_without_ddp = model 404 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 405 | 406 | print("Model = %s" % str(model_without_ddp)) 407 | print('number of params:', n_parameters) 408 | 409 | total_batch_size = args.batch_size * args.update_freq * utils.get_world_size() 410 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size 411 | args.lr = args.lr * total_batch_size / 256 412 | args.min_lr = args.min_lr * total_batch_size / 256 413 | args.warmup_lr = args.warmup_lr * total_batch_size / 256 414 | print("LR = %.8f" % args.lr) 415 | print("Batch size = %d" % total_batch_size) 416 | print("Update frequency = %d" % args.update_freq) 417 | print("Number of training examples = %d" % len(dataset_train)) 418 | print("Number of training steps per epoch = %d" % num_training_steps_per_epoch) 419 | 420 | num_layers = model_without_ddp.get_num_layers() 421 | if args.layer_decay < 1.0: 422 | assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))) 423 | else: 424 | assigner = None 425 | 426 | if assigner is not None: 427 | print("Assigned values = %s" % str(assigner.values)) 428 | 429 | skip_weight_decay_list = model.no_weight_decay() 430 | print("Skip weight decay list: ", skip_weight_decay_list) 431 | 432 | if args.enable_deepspeed: 433 | loss_scaler = None 434 | optimizer_params = get_parameter_groups( 435 | model, args.weight_decay, skip_weight_decay_list, 436 | assigner.get_layer_id if assigner is not None else None, 437 | assigner.get_scale if assigner is not None else None) 438 | model, optimizer, _, _ = ds_init( 439 | args=args, model=model, model_parameters=optimizer_params, dist_init_required=not args.distributed, 440 | ) 441 | 442 | print("model.gradient_accumulation_steps() = %d" % model.gradient_accumulation_steps()) 443 | assert model.gradient_accumulation_steps() == args.update_freq 444 | else: 445 | if args.distributed: 446 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 447 | model_without_ddp = model.module 448 | 449 | optimizer = create_optimizer( 450 | args, model_without_ddp, skip_list=skip_weight_decay_list, 451 | get_num_layer=assigner.get_layer_id if assigner is not None else None, 452 | get_layer_scale=assigner.get_scale if assigner is not None else None) 453 | loss_scaler = NativeScaler() 454 | 455 | print("Use step level LR scheduler!")
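# For reference: args.lr / args.min_lr / args.warmup_lr above are scaled linearly with the
# effective batch size (value * total_batch_size / 256), and utils.cosine_scheduler below turns
# them into one value per optimizer step. The standalone sketch that follows is NOT the
# repository's utils.cosine_scheduler; the helper name cosine_with_warmup and its edge-case
# handling are illustrative assumptions that only mirror the warmup-then-cosine shape used here.

import numpy as np

def cosine_with_warmup(base_value, final_value, epochs, steps_per_epoch,
                       warmup_epochs=0, warmup_steps=-1, start_warmup_value=0.0):
    """Per-step schedule: linear warmup, then cosine decay from base_value to final_value."""
    warmup_iters = warmup_epochs * steps_per_epoch
    if warmup_steps > 0:  # an explicit step count overrides warmup_epochs
        warmup_iters = warmup_steps
    warmup = np.linspace(start_warmup_value, base_value, warmup_iters)

    decay_iters = np.arange(epochs * steps_per_epoch - warmup_iters)
    cosine = final_value + 0.5 * (base_value - final_value) * (
        1 + np.cos(np.pi * decay_iters / len(decay_iters)))

    schedule = np.concatenate((warmup, cosine))
    assert len(schedule) == epochs * steps_per_epoch
    return schedule

# Example: 50 epochs of 100 steps each with 5 warmup epochs; the peak sits at the end of warmup.
lr_schedule = cosine_with_warmup(1e-3, 1e-6, epochs=50, steps_per_epoch=100, warmup_epochs=5)
print(lr_schedule[0], lr_schedule.max(), lr_schedule[-1])  # ~0.0, 1e-3, ~1e-6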
456 | lr_schedule_values = utils.cosine_scheduler( 457 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch, 458 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps, 459 | ) 460 | if args.weight_decay_end is None: 461 | args.weight_decay_end = args.weight_decay 462 | wd_schedule_values = utils.cosine_scheduler( 463 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch) 464 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values))) 465 | 466 | if mixup_fn is not None: 467 | # smoothing is handled with mixup label transform 468 | criterion = SoftTargetCrossEntropy() 469 | elif args.smoothing > 0.: 470 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 471 | else: 472 | criterion = torch.nn.CrossEntropyLoss() 473 | 474 | print("criterion = %s" % str(criterion)) 475 | 476 | utils.auto_load_model( 477 | args=args, model=model, model_without_ddp=model_without_ddp, 478 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema) 479 | 480 | if args.eval: 481 | preds_file = os.path.join(args.output_dir, str(global_rank) + '.txt') 482 | test_stats = final_test(data_loader_test, model, device, preds_file) 483 | torch.distributed.barrier() 484 | if global_rank == 0: 485 | print("Start merging results...") 486 | final_top1 ,final_top5 = merge(args.output_dir, num_tasks) 487 | print(f"Accuracy of the network on the {len(dataset_test)} test videos: Top-1: {final_top1:.2f}%, Top-5: {final_top5:.2f}%") 488 | log_stats = {'Final top-1': final_top1, 489 | 'Final Top-5': final_top5} 490 | if args.output_dir and utils.is_main_process(): 491 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 492 | f.write(json.dumps(log_stats) + "\n") 493 | exit(0) 494 | 495 | 496 | print(f"Start training for {args.epochs} epochs") 497 | start_time = time.time() 498 | max_accuracy = 0.0 499 | for epoch in range(args.start_epoch, args.epochs): 500 | if args.distributed: 501 | data_loader_train.sampler.set_epoch(epoch) 502 | if log_writer is not None: 503 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq) 504 | train_stats = train_one_epoch( 505 | model, criterion, data_loader_train, optimizer, 506 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, 507 | log_writer=log_writer, start_steps=epoch * num_training_steps_per_epoch, 508 | lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values, 509 | num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq, 510 | ) 511 | if args.output_dir and args.save_ckpt: 512 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: 513 | utils.save_model( 514 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 515 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema) 516 | if data_loader_val is not None: 517 | test_stats = validation_one_epoch(data_loader_val, model, device) 518 | print(f"Accuracy of the network on the {len(dataset_val)} val videos: {test_stats['acc1']:.1f}%") 519 | if max_accuracy < test_stats["acc1"]: 520 | max_accuracy = test_stats["acc1"] 521 | if args.output_dir and args.save_ckpt: 522 | utils.save_model( 523 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 524 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema) 525 | 526 | print(f'Max accuracy: {max_accuracy:.2f}%') 527 | if log_writer is not None: 528 | 
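# (rank 0 only, since log_writer is created only for global_rank == 0 with a log_dir set)
# the per-epoch validation accuracy and loss are pushed to TensorBoard under the "perf" head: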
log_writer.update(val_acc1=test_stats['acc1'], head="perf", step=epoch) 529 | log_writer.update(val_acc5=test_stats['acc5'], head="perf", step=epoch) 530 | log_writer.update(val_loss=test_stats['loss'], head="perf", step=epoch) 531 | 532 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 533 | **{f'val_{k}': v for k, v in test_stats.items()}, 534 | 'epoch': epoch, 535 | 'n_parameters': n_parameters} 536 | else: 537 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 538 | 'epoch': epoch, 539 | 'n_parameters': n_parameters} 540 | if args.output_dir and utils.is_main_process(): 541 | if log_writer is not None: 542 | log_writer.flush() 543 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 544 | f.write(json.dumps(log_stats) + "\n") 545 | 546 | preds_file = os.path.join(args.output_dir, str(global_rank) + '.txt') 547 | test_stats = final_test(data_loader_test, model, device, preds_file) 548 | torch.distributed.barrier() 549 | if global_rank == 0: 550 | print("Start merging results...") 551 | final_top1 ,final_top5 = merge(args.output_dir, num_tasks) 552 | print(f"Accuracy of the network on the {len(dataset_test)} test videos: Top-1: {final_top1:.2f}%, Top-5: {final_top5:.2f}%") 553 | log_stats = {'Final top-1': final_top1, 554 | 'Final Top-5': final_top5} 555 | if args.output_dir and utils.is_main_process(): 556 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 557 | f.write(json.dumps(log_stats) + "\n") 558 | 559 | 560 | total_time = time.time() - start_time 561 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 562 | print('Training time {}'.format(total_time_str)) 563 | 564 | 565 | if __name__ == '__main__': 566 | opts, ds_init = get_args() 567 | if opts.output_dir: 568 | Path(opts.output_dir).mkdir(parents=True, exist_ok=True) 569 | main(opts, ds_init) 570 | --------------------------------------------------------------------------------
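For completeness, the segment-based frame sampling implemented in ucf.py above (_sample_train_indices followed by _video_TSN_decord_batch_loader) can be summarized by the standalone sketch below. The function name sample_frame_ids and the default argument values are illustrative assumptions, not part of the repository; the index arithmetic mirrors the code above and returns 0-based frame ids clamped to the video length.

import numpy as np

def sample_frame_ids(num_frames, num_segments=1, skip_length=32, new_step=2, temporal_jitter=False):
    """Pick num_segments clip start offsets, then take every new_step-th frame within each clip."""
    average_duration = (num_frames - skip_length + 1) // num_segments
    if average_duration > 0:
        # evenly spaced segments with a random shift inside each segment
        offsets = np.arange(num_segments) * average_duration \
                  + np.random.randint(average_duration, size=num_segments)
    elif num_frames > max(num_segments, skip_length):
        offsets = np.sort(np.random.randint(num_frames - skip_length + 1, size=num_segments))
    else:
        offsets = np.zeros(num_segments, dtype=int)
    offsets = offsets + 1  # 1-based, as in _sample_train_indices

    steps = skip_length // new_step
    skip_offsets = (np.random.randint(new_step, size=steps) if temporal_jitter
                    else np.zeros(steps, dtype=int))

    frame_ids = []
    for offset in offsets.astype(int):
        for i in range(steps):
            # clamp to the video length, then convert to a 0-based frame id
            frame_id = offset + skip_offsets[i] - 1 if offset + skip_offsets[i] <= num_frames else offset - 1
            frame_ids.append(int(frame_id))
            if offset + new_step < num_frames:
                offset += new_step
    return frame_ids

# Example: one 16-frame clip from a 300-frame video, sampled every 2nd frame (skip_length = 16 * 2).
print(sample_frame_ids(num_frames=300))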