├── assets
│   ├── EVEREST_plot.PNG
│   ├── EVEREST_concept.PNG
│   ├── VideoMS_concept.PNG
│   └── icml2024_main_figure.pdf
├── requirements.txt
├── scripts
│   ├── hmdb51
│   │   ├── pretrain.sh
│   │   └── finetune.sh
│   └── ucf101
│       ├── pretrain.sh
│       └── finetune.sh
├── masking_generator.py
├── .gitignore
├── functional.py
├── README.md
├── volume_transforms.py
├── engine_for_pretraining.py
├── datasets.py
├── optim_factory.py
├── random_erasing.py
├── transforms.py
├── engine_for_finetuning.py
├── run_ms_pretraining.py
├── mixup.py
├── modeling_finetune.py
├── modeling_pretrain.py
├── rand_augment.py
├── utils.py
├── ucf.py
└── run_class_finetuning.py
/assets/EVEREST_plot.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunilhoho/EVEREST/HEAD/assets/EVEREST_plot.PNG
--------------------------------------------------------------------------------
/assets/EVEREST_concept.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunilhoho/EVEREST/HEAD/assets/EVEREST_concept.PNG
--------------------------------------------------------------------------------
/assets/VideoMS_concept.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunilhoho/EVEREST/HEAD/assets/VideoMS_concept.PNG
--------------------------------------------------------------------------------
/assets/icml2024_main_figure.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunilhoho/EVEREST/HEAD/assets/icml2024_main_figure.pdf
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | timm==0.4.12
2 | deepspeed==0.7.3
3 | tensorboardX
4 | decord
5 | einops
6 | opencv-python
7 | scipy
8 | pandas
--------------------------------------------------------------------------------
/scripts/hmdb51/pretrain.sh:
--------------------------------------------------------------------------------
1 | OUTPUT_DIR='output/hmdb51/pt'
2 | DATA_PATH='path_to_data/HMDB51/train.csv'
3 |
4 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
5 | --master_port 12320 run_ms_pretraining.py \
6 | --data_path ${DATA_PATH} \
7 | --mask_type motion-centric \
8 | --motion_centric_masking_ratio 0.7 \
9 | --mask_ratio 0.9 \
10 | --model pretrain_videoms_base_patch16_224 \
11 | --decoder_depth 4 \
12 | --lr 1e-3 \
13 | --batch_size 24 \
14 | --num_frames 16 \
15 | --sampling_rate 2 \
16 | --opt adamw \
17 | --opt_betas 0.9 0.95 \
18 | --warmup_epochs 40 \
19 | --epochs 4800 \
20 | --save_ckpt_freq 100 \
21 | --log_dir ${OUTPUT_DIR} \
22 | --output_dir ${OUTPUT_DIR}
23 |
--------------------------------------------------------------------------------
/scripts/ucf101/pretrain.sh:
--------------------------------------------------------------------------------
1 | OUTPUT_DIR='output/ucf101/pt'
2 | DATA_PATH='path_to_data/UCF101/train.csv'
3 |
4 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
5 | --master_port 12320 run_ms_pretraining.py \
6 | --data_path ${DATA_PATH} \
7 | --mask_type motion-centric \
8 | --motion_centric_masking_ratio 0.7 \
9 | --mask_ratio 0.9 \
10 | --model pretrain_videoms_base_patch16_224 \
11 | --decoder_depth 4 \
12 | --lr 1e-3 \
13 | --batch_size 24 \
14 | --num_frames 16 \
15 | --sampling_rate 4 \
16 | --opt adamw \
17 | --opt_betas 0.9 0.95 \
18 | --warmup_epochs 40 \
19 | --epochs 3200 \
20 | --save_ckpt_freq 100 \
21 | --log_dir ${OUTPUT_DIR} \
22 | --output_dir ${OUTPUT_DIR}
23 |
--------------------------------------------------------------------------------
/scripts/hmdb51/finetune.sh:
--------------------------------------------------------------------------------
1 | OUTPUT_DIR='output/hmdb51/ft'
2 | DATA_PATH='path_to_data/HMDB51/' # directory containing the HMDB51 annotation files (train.csv/val.csv/test.csv)
3 | MODEL_PATH='output/hmdb51/pt/checkpoint-4799.pth'
4 |
5 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
6 | --master_port 12320 run_class_finetuning.py \
7 | --model vit_base_patch16_224 \
8 | --data_path ${DATA_PATH} \
9 | --finetune ${MODEL_PATH} \
10 | --log_dir ${OUTPUT_DIR} \
11 | --output_dir ${OUTPUT_DIR} \
12 | --data_set HMDB51 \
13 | --nb_classes 51 \
14 | --batch_size 16 \
15 | --input_size 224 \
16 | --short_side_size 224 \
17 | --save_ckpt_freq 20 \
18 | --num_frames 16 \
19 | --sampling_rate 2 \
20 | --num_sample 1 \
21 | --opt adamw \
22 | --lr 1e-3 \
23 | --opt_betas 0.9 0.999 \
24 | --weight_decay 0.05 \
25 | --epochs 50 \
26 | --test_num_segment 10 \
27 | --test_num_crop 3 \
28 | --use_checkpoint \
29 | --dist_eval \
30 | --enable_deepspeed \
31 | --mcm \
32 | --mcm_ratio 0.4 \
33 | --mixup 0.0 \
34 | --cutmix 0.0
35 |
--------------------------------------------------------------------------------
/scripts/ucf101/finetune.sh:
--------------------------------------------------------------------------------
1 | OUTPUT_DIR='output/ucf101/ft'
2 | DATA_PATH='path_to_data/UCF101/' # directory containing the UCF101 annotation files (train.csv/val.csv/test.csv)
3 | MODEL_PATH='output/ucf101/pt/checkpoint-3199.pth'
4 |
5 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
6 | --master_port 12320 run_class_finetuning.py \
7 | --model vit_base_patch16_224 \
8 | --data_path ${DATA_PATH} \
9 | --finetune ${MODEL_PATH} \
10 | --log_dir ${OUTPUT_DIR} \
11 | --output_dir ${OUTPUT_DIR} \
12 | --data_set UCF101 \
13 | --nb_classes 101 \
14 | --batch_size 16 \
15 | --input_size 224 \
16 | --short_side_size 224 \
17 | --save_ckpt_freq 20 \
18 | --num_frames 16 \
19 | --sampling_rate 4 \
20 | --num_sample 1 \
21 | --opt adamw \
22 | --lr 1e-3 \
23 | --opt_betas 0.9 0.999 \
24 | --weight_decay 0.05 \
25 | --epochs 100 \
26 | --test_num_segment 5 \
27 | --test_num_crop 3 \
28 | --use_checkpoint \
29 | --dist_eval \
30 | --enable_deepspeed \
31 | --mcm \
32 | --mcm_ratio 0.4 \
33 | --mixup 0.0 \
34 | --cutmix 0.0
35 |
--------------------------------------------------------------------------------
/masking_generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class TubeMaskingGenerator:
4 | def __init__(self, input_size, mask_ratio):
5 | self.frames, self.height, self.width = input_size
6 | self.num_patches_per_frame = self.height * self.width
7 | self.total_patches = self.frames * self.num_patches_per_frame
8 | self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
9 | self.total_masks = self.frames * self.num_masks_per_frame
10 |
11 | def __repr__(self):
12 |         repr_str = "Masks: total patches {}, mask patches {}".format(
13 | self.total_patches, self.total_masks
14 | )
15 | return repr_str
16 |
17 | def __call__(self):
18 | mask_per_frame = np.hstack([
19 | np.zeros(self.num_patches_per_frame - self.num_masks_per_frame),
20 | np.ones(self.num_masks_per_frame),
21 | ])
22 | np.random.shuffle(mask_per_frame)
23 | mask = np.tile(mask_per_frame, (self.frames,1)).flatten()
24 | return mask
25 |
26 | class RandomMaskingGenerator:
27 | def __init__(self, input_size, mask_ratio):
28 | self.frames, self.height, self.width = input_size
29 | self.num_patches_per_frame = self.height * self.width
30 | self.total_patches = self.frames * self.num_patches_per_frame
31 | self.total_masks = int(mask_ratio * self.total_patches)
32 |
33 | def __repr__(self):
34 |         repr_str = "Masks: total patches {}, mask patches {}".format(
35 | self.total_patches, self.total_masks
36 | )
37 | return repr_str
38 |
39 | def __call__(self):
40 | mask_per_input = np.hstack([
41 | np.zeros(self.total_patches - self.total_masks),
42 | np.ones(self.total_masks),
43 | ])
44 | np.random.shuffle(mask_per_input)
45 | return mask_per_input
--------------------------------------------------------------------------------
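Usage note: `datasets.py` constructs these generators with `args.window_size` and `args.mask_ratio`. A minimal sketch, assuming a window size of (8, 14, 14), i.e. (frames // tubelet, H // patch, W // patch) for the 16-frame, 224x224, patch-16 setting used in the provided scripts:

```
from masking_generator import TubeMaskingGenerator

# Assumed window size: (frames // tubelet, H // patch, W // patch) = (8, 14, 14).
gen = TubeMaskingGenerator(input_size=(8, 14, 14), mask_ratio=0.9)
mask = gen()            # flat 0/1 numpy array, one entry per space-time patch
print(mask.shape)       # (1568,) -> 8 * 14 * 14 positions
print(int(mask.sum()))  # 1408    -> 8 * int(0.9 * 196) masked positions per clip
```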
/.gitignore:
--------------------------------------------------------------------------------
1 | output/
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | */__pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
--------------------------------------------------------------------------------
/functional.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import cv2
3 | import numpy as np
4 | import PIL
5 | import torch
6 |
7 |
8 | def _is_tensor_clip(clip):
9 | return torch.is_tensor(clip) and clip.ndimension() == 4
10 |
11 |
12 | def crop_clip(clip, min_h, min_w, h, w):
13 | if isinstance(clip[0], np.ndarray):
14 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
15 |
16 | elif isinstance(clip[0], PIL.Image.Image):
17 | cropped = [
18 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
19 | ]
20 | else:
21 |         raise TypeError('Expected numpy.ndarray or PIL.Image ' +
22 | 'but got list of {0}'.format(type(clip[0])))
23 | return cropped
24 |
25 |
26 | def resize_clip(clip, size, interpolation='bilinear'):
27 | if isinstance(clip[0], np.ndarray):
28 | if isinstance(size, numbers.Number):
29 | im_h, im_w, im_c = clip[0].shape
30 | # Min spatial dim already matches minimal size
31 | if (im_w <= im_h and im_w == size) or (im_h <= im_w
32 | and im_h == size):
33 | return clip
34 | new_h, new_w = get_resize_sizes(im_h, im_w, size)
35 | size = (new_w, new_h)
36 | else:
37 | size = size[0], size[1]
38 | if interpolation == 'bilinear':
39 | np_inter = cv2.INTER_LINEAR
40 | else:
41 | np_inter = cv2.INTER_NEAREST
42 | scaled = [
43 | cv2.resize(img, size, interpolation=np_inter) for img in clip
44 | ]
45 | elif isinstance(clip[0], PIL.Image.Image):
46 | if isinstance(size, numbers.Number):
47 | im_w, im_h = clip[0].size
48 | # Min spatial dim already matches minimal size
49 | if (im_w <= im_h and im_w == size) or (im_h <= im_w
50 | and im_h == size):
51 | return clip
52 | new_h, new_w = get_resize_sizes(im_h, im_w, size)
53 | size = (new_w, new_h)
54 | else:
55 | size = size[1], size[0]
56 | if interpolation == 'bilinear':
57 | pil_inter = PIL.Image.BILINEAR
58 | else:
59 | pil_inter = PIL.Image.NEAREST
60 | scaled = [img.resize(size, pil_inter) for img in clip]
61 | else:
62 |         raise TypeError('Expected numpy.ndarray or PIL.Image ' +
63 | 'but got list of {0}'.format(type(clip[0])))
64 | return scaled
65 |
66 |
67 | def get_resize_sizes(im_h, im_w, size):
68 | if im_w < im_h:
69 | ow = size
70 | oh = int(size * im_h / im_w)
71 | else:
72 | oh = size
73 | ow = int(size * im_w / im_h)
74 | return oh, ow
75 |
76 |
77 | def normalize(clip, mean, std, inplace=False):
78 | if not _is_tensor_clip(clip):
79 | raise TypeError('tensor is not a torch clip.')
80 |
81 | if not inplace:
82 | clip = clip.clone()
83 |
84 | dtype = clip.dtype
85 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device)
86 | std = torch.as_tensor(std, dtype=dtype, device=clip.device)
87 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
88 |
89 | return clip
90 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EVEREST: Efficient Masked Video Autoencoder by Removing Redundant Spatiotemporal Tokens [ICML 2024]
2 |
3 | This repository is the official PyTorch implementation of [EVEREST: Efficient Masked Video Autoencoder by Removing Redundant Spatiotemporal Tokens](https://arxiv.org/abs/2211.10636).
4 |
5 | **The new version of EVEREST will be released soon!!!** 🚨
6 |
7 |
8 |
9 |
10 |
11 | ## Abstract
12 |
13 | Masked Video Autoencoder (MVA) approaches have demonstrated their potential by significantly outperforming previous video representation learning methods. However, they waste an excessive amount of computation and memory predicting uninformative tokens/frames due to random masking strategies (e.g., requiring over 16 nodes with 128 NVIDIA A100 GPUs). To resolve this issue, we exploit the unequal information density among the patches in videos and propose EVEREST, a surprisingly efficient MVA approach for video representation learning that finds tokens containing rich motion features and discards uninformative ones during both pre-training and fine-tuning. We further present an information-intensive frame selection strategy that allows the model to focus on informative and causal frames with minimal redundancy. Our method significantly reduces the computation and memory requirements of MVA, enabling pre-training and fine-tuning on a single machine with 8 GPUs while achieving comparable performance to computation- and memory-heavy baselines on multiple benchmarks and the uncurated Ego4D dataset. We hope that our work contributes to reducing the barrier to further research on video understanding.
14 |
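As a rough, hypothetical illustration of the motion-centric idea (the actual mask is produced inside the pre-training model; see `engine_for_pretraining.py`, where the model returns `mc_target_mask`), the sketch below scores each space-time patch by how much its content changes from the previous tubelet and keeps only the most motion-rich fraction:

```
import torch

def motion_centric_keep_mask(video, patch_size=16, tubelet=2, keep_ratio=0.3):
    """Toy token selection: keep the patches whose content changes most over time.
    video: (B, C, T, H, W). Illustration only, not EVEREST's implementation."""
    B, C, T, H, W = video.shape
    cells = video.reshape(B, C, T // tubelet, tubelet,
                          H // patch_size, patch_size, W // patch_size, patch_size)
    cells = cells.mean(dim=(1, 3, 5, 7))                # (B, T', H', W') mean per patch cell
    motion = (cells[:, 1:] - cells[:, :-1]).abs()       # change w.r.t. previous tubelet
    motion = torch.cat([motion[:, :1], motion], dim=1)  # pad the first tubelet
    scores = motion.flatten(1)                          # (B, num_tokens)
    k = int(keep_ratio * scores.shape[1])
    keep = torch.zeros_like(scores, dtype=torch.bool)
    keep[torch.arange(B).unsqueeze(1), scores.topk(k, dim=1).indices] = True
    return keep                                         # True = token kept (visible)
```

The related knobs exposed by the pre-training scripts are `--mask_type motion-centric`, `--motion_centric_masking_ratio`, and `--mask_ratio`; their exact semantics are defined in `run_ms_pretraining.py` and `modeling_pretrain.py`.
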
15 | ## Results
16 |
17 |
18 | ![EVEREST results](assets/EVEREST_plot.PNG)
19 |
20 |
21 | ## Prerequisites
22 | EVEREST is built with **Python 3.7.12**, **torch 1.8.0**, and **torchvision 0.9.0**. Please use the following command to install the requirements:
23 | ```
24 | $ pip install -r requirements.txt
25 | ```
26 |
27 | ## Run
28 | 1. __UCF101__ experiment
29 | ```
30 | $ bash scripts/ucf101/pretrain.sh
31 | $ bash scripts/ucf101/finetune.sh
32 | ```
33 |
34 | 2. __HMDB51__ experiment
35 |
36 | ```
37 | $ bash scripts/hmdb51/pretrain.sh
38 | $ bash scripts/hmdb51/finetune.sh
39 | ```
40 |
41 | 3. __K400, SSv2, OSCC__ experiments will be released soon.
42 |
43 | ## Dataset
44 | 1. Download [UCF101](https://www.crcv.ucf.edu/data/UCF101.php) and [HMDB51](https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/) datasets from the official websites.
45 | 2. Make annotation files in `*.csv` format like this:
46 | ```
47 | path_to_video/video_0.avi 0
48 | path_to_video/video_1.avi 0
49 | ...
50 | path_to_video/video_N.avi 101
51 | ```
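A minimal sketch for generating such a file, assuming the videos are stored in one sub-folder per class (the folder layout, extensions, and output path are illustrative):
```
import csv
import os

def write_annotation(video_root, out_csv):
    """Write one 'path label' row per video, labeling classes by sorted folder order."""
    classes = sorted(d for d in os.listdir(video_root)
                     if os.path.isdir(os.path.join(video_root, d)))
    with open(out_csv, 'w', newline='') as f:
        writer = csv.writer(f, delimiter=' ')
        for label, cls in enumerate(classes):
            cls_dir = os.path.join(video_root, cls)
            for name in sorted(os.listdir(cls_dir)):
                if name.endswith(('.avi', '.mp4')):
                    writer.writerow([os.path.join(cls_dir, name), label])

# write_annotation('path_to_data/UCF101/videos', 'path_to_data/UCF101/train.csv')
```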
52 |
53 | ## Training Logs and Checkpoints
54 | ### UCF101
55 |
56 | | Backbone | \#Frame | Pre-train (3,200 epochs) | Fine-tune (100 epochs) | Top-1 | Top-5 |
57 | | :------: | :-----: | :----------------------------------------------------------: | :----------------------------------------------------------: | :---: | :---: |
58 | | ViT-B | 16x5x3 | [log](https://drive.google.com/file/d/1dupg3ultdh1qsijUAYSZm8-hW2SAspLT/view?usp=share_link) / [checkpoint](https://drive.google.com/file/d/1liGNGprKdfiOCArK-WMqIcfeOJ-AZKzr/view?usp=share_link) | [log](https://drive.google.com/file/d/1EMlHBPqTC1_QURiCiaOdwPeoXdL67Gql/view?usp=share_link) / [checkpoint](https://drive.google.com/file/d/1iGFUxYpzjb7zaajB0O0j1MzS6PzyzrQF/view?usp=share_link) | 93.7 | 98.9 |
59 |
60 | ## Contact
61 | - Sunil Hwang: sunilhoho@kaist.ac.kr
62 | - Jaehong Yoon: jaehong.yoon@kaist.ac.kr
63 |
64 | ## Acknowledgment
65 | The code is built upon [VideoMAE](https://github.com/MCG-NJU/VideoMAE).
66 |
67 | ## Citations
68 | ```
69 | @inproceedings{hwang2024everest,
70 | title={EVEREST: Efficient Masked Video Autoencoder by Removing Redundant Spatiotemporal Tokens},
71 | author={Hwang, Sunil and Yoon, Jaehong and Lee, Youngwan and Hwang, Sung Ju},
72 | booktitle={International Conference on Machine Learning},
73 | year={2024},
74 | }
75 | ```
--------------------------------------------------------------------------------
/volume_transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch
4 |
5 |
6 | def convert_img(img):
7 |     """Converts (H, W, C) numpy.ndarray to (C, H, W) format
8 | """
9 | if len(img.shape) == 3:
10 | img = img.transpose(2, 0, 1)
11 | if len(img.shape) == 2:
12 | img = np.expand_dims(img, 0)
13 | return img
14 |
15 |
16 | class ClipToTensor(object):
17 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
18 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
19 | """
20 |
21 | def __init__(self, channel_nb=3, div_255=True, numpy=False):
22 | self.channel_nb = channel_nb
23 | self.div_255 = div_255
24 | self.numpy = numpy
25 |
26 | def __call__(self, clip):
27 | """
28 | Args: clip (list of numpy.ndarray): clip (list of images)
29 | to be converted to tensor.
30 | """
31 | # Retrieve shape
32 | if isinstance(clip[0], np.ndarray):
33 | h, w, ch = clip[0].shape
34 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format(
35 | ch)
36 | elif isinstance(clip[0], Image.Image):
37 | w, h = clip[0].size
38 | else:
39 | raise TypeError('Expected numpy.ndarray or PIL.Image\
40 | but got list of {0}'.format(type(clip[0])))
41 |
42 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
43 |
44 | # Convert
45 | for img_idx, img in enumerate(clip):
46 | if isinstance(img, np.ndarray):
47 | pass
48 | elif isinstance(img, Image.Image):
49 | img = np.array(img, copy=False)
50 | else:
51 | raise TypeError('Expected numpy.ndarray or PIL.Image\
52 | but got list of {0}'.format(type(clip[0])))
53 | img = convert_img(img)
54 | np_clip[:, img_idx, :, :] = img
55 | if self.numpy:
56 | if self.div_255:
57 | np_clip = np_clip / 255.0
58 | return np_clip
59 |
60 | else:
61 | tensor_clip = torch.from_numpy(np_clip)
62 |
63 | if not isinstance(tensor_clip, torch.FloatTensor):
64 | tensor_clip = tensor_clip.float()
65 | if self.div_255:
66 | tensor_clip = torch.div(tensor_clip, 255)
67 | return tensor_clip
68 |
69 |
70 | # Note this norms data to -1/1
71 | class ClipToTensor_K(object):
72 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
73 |     to a torch.FloatTensor of shape (C x m x H x W) in the range [-1.0, 1.0]
74 | """
75 |
76 | def __init__(self, channel_nb=3, div_255=True, numpy=False):
77 | self.channel_nb = channel_nb
78 | self.div_255 = div_255
79 | self.numpy = numpy
80 |
81 | def __call__(self, clip):
82 | """
83 | Args: clip (list of numpy.ndarray): clip (list of images)
84 | to be converted to tensor.
85 | """
86 | # Retrieve shape
87 | if isinstance(clip[0], np.ndarray):
88 | h, w, ch = clip[0].shape
89 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format(
90 | ch)
91 | elif isinstance(clip[0], Image.Image):
92 | w, h = clip[0].size
93 | else:
94 | raise TypeError('Expected numpy.ndarray or PIL.Image\
95 | but got list of {0}'.format(type(clip[0])))
96 |
97 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
98 |
99 | # Convert
100 | for img_idx, img in enumerate(clip):
101 | if isinstance(img, np.ndarray):
102 | pass
103 | elif isinstance(img, Image.Image):
104 | img = np.array(img, copy=False)
105 | else:
106 | raise TypeError('Expected numpy.ndarray or PIL.Image\
107 | but got list of {0}'.format(type(clip[0])))
108 | img = convert_img(img)
109 | np_clip[:, img_idx, :, :] = img
110 | if self.numpy:
111 | if self.div_255:
112 | np_clip = (np_clip - 127.5) / 127.5
113 | return np_clip
114 |
115 | else:
116 | tensor_clip = torch.from_numpy(np_clip)
117 |
118 | if not isinstance(tensor_clip, torch.FloatTensor):
119 | tensor_clip = tensor_clip.float()
120 | if self.div_255:
121 | tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5)
122 | return tensor_clip
123 |
124 |
125 | class ToTensor(object):
126 | """Converts numpy array to tensor
127 | """
128 |
129 | def __call__(self, array):
130 | tensor = torch.from_numpy(array)
131 | return tensor
132 |
--------------------------------------------------------------------------------
/engine_for_pretraining.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | from typing import Iterable
4 | import torch
5 | import torch.nn as nn
6 | import utils
7 | from einops import rearrange
8 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
9 |
10 | def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer,
11 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, patch_size: int = 16,
12 | normlize_target: bool = True, log_writer=None, lr_scheduler=None, start_steps=None,
13 | lr_schedule_values=None, wd_schedule_values=None, mask_type='motion-centric'):
14 | model.train()
15 | metric_logger = utils.MetricLogger(delimiter=" ")
16 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
17 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
18 | header = 'Epoch: [{}]'.format(epoch)
19 | print_freq = 10
20 |
21 | loss_func = nn.MSELoss()
22 |
23 | for step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
24 | # assign learning rate & weight decay for each step
25 | it = start_steps + step # global training iteration
26 | if lr_schedule_values is not None or wd_schedule_values is not None:
27 | for i, param_group in enumerate(optimizer.param_groups):
28 | if lr_schedule_values is not None:
29 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
30 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
31 | param_group["weight_decay"] = wd_schedule_values[it]
32 |
33 | videos, bool_masked_pos = batch
34 | videos = videos.to(device, non_blocking=True)
35 | if mask_type != 'motion-centric':
36 | bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool)
37 |
38 | with torch.no_grad():
39 | # calculate the predict label
40 | mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None]
41 | std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None]
42 | unnorm_videos = videos * std + mean # in [0, 1]
43 |
44 | if normlize_target:
45 | videos_squeeze = rearrange(unnorm_videos, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c', p0=2, p1=patch_size, p2=patch_size)
46 | videos_norm = (videos_squeeze - videos_squeeze.mean(dim=-2, keepdim=True)
47 | ) / (videos_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
48 | # we find that the mean is about 0.48 and standard deviation is about 0.08.
49 | videos_patch = rearrange(videos_norm, 'b n p c -> b n (p c)')
50 | else:
51 | videos_patch = rearrange(unnorm_videos, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2 c)', p0=2, p1=patch_size, p2=patch_size)
52 |
53 | B, _, C = videos_patch.shape
54 | if mask_type != 'motion-centric':
55 | labels = videos_patch[bool_masked_pos].reshape(B, -1, C)
56 |
57 | with torch.cuda.amp.autocast():
58 | outputs, (_, mc_target_mask) = model(videos, bool_masked_pos)
59 | if mask_type == 'motion-centric':
60 | labels = videos_patch[~mc_target_mask].reshape(B, -1, C)
61 | loss = loss_func(input=outputs, target=labels)
62 |
63 | loss_value = loss.item()
64 |
65 | if not math.isfinite(loss_value):
66 | print("Loss is {}, stopping training".format(loss_value))
67 | sys.exit(1)
68 |
69 | optimizer.zero_grad()
70 | # this attribute is added by timm on one optimizer (adahessian)
71 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
72 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
73 | parameters=model.parameters(), create_graph=is_second_order)
74 | loss_scale_value = loss_scaler.state_dict()["scale"]
75 |
76 | torch.cuda.synchronize()
77 |
78 | metric_logger.update(loss=loss_value)
79 | metric_logger.update(loss_scale=loss_scale_value)
80 | min_lr = 10.
81 | max_lr = 0.
82 | for group in optimizer.param_groups:
83 | min_lr = min(min_lr, group["lr"])
84 | max_lr = max(max_lr, group["lr"])
85 |
86 | metric_logger.update(lr=max_lr)
87 | metric_logger.update(min_lr=min_lr)
88 | weight_decay_value = None
89 | for group in optimizer.param_groups:
90 | if group["weight_decay"] > 0:
91 | weight_decay_value = group["weight_decay"]
92 | metric_logger.update(weight_decay=weight_decay_value)
93 | metric_logger.update(grad_norm=grad_norm)
94 |
95 | if log_writer is not None:
96 | log_writer.update(loss=loss_value, head="loss")
97 | log_writer.update(loss_scale=loss_scale_value, head="opt")
98 | log_writer.update(lr=max_lr, head="opt")
99 | log_writer.update(min_lr=min_lr, head="opt")
100 | log_writer.update(weight_decay=weight_decay_value, head="opt")
101 | log_writer.update(grad_norm=grad_norm, head="opt")
102 | log_writer.set_step()
103 |
104 | if lr_scheduler is not None:
105 | lr_scheduler.step_update(start_steps + step)
106 | # gather the stats from all processes
107 | metric_logger.synchronize_between_processes()
108 | print("Averaged stats:", metric_logger)
109 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
110 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torchvision import transforms
3 | from transforms import *
4 | from masking_generator import TubeMaskingGenerator, RandomMaskingGenerator
5 | from ucf import VideoClsDataset, VideoMS
6 |
7 |
8 | class DataAugmentationForVideoMS(object):
9 | def __init__(self, args):
10 | self.input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN
11 | self.input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD
12 | normalize = GroupNormalize(self.input_mean, self.input_std)
13 | self.train_augmentation = GroupMultiScaleCrop(args.input_size, [1, .875, .75, .66])
14 | self.transform = transforms.Compose([
15 | self.train_augmentation,
16 | Stack(roll=False),
17 | ToTorchFormatTensor(div=True),
18 | normalize,
19 | ])
20 | self.mcm = False
21 | if args.mask_type == 'tube':
22 | self.masked_position_generator = TubeMaskingGenerator(
23 | args.window_size, args.mask_ratio
24 | )
25 | elif args.mask_type == 'random':
26 | self.masked_position_generator = RandomMaskingGenerator(
27 | args.window_size, args.mask_ratio
28 | )
29 | elif args.mask_type == 'motion-centric':
30 | self.mcm = True
31 | self.masked_position_generator = 'Motion-centric Masking'
32 |
33 | def __call__(self, images):
34 | process_data, _ = self.transform(images)
35 | if self.mcm:
36 | return process_data, 0
37 | return process_data, self.masked_position_generator()
38 |
39 | def __repr__(self):
40 | repr = "(DataAugmentationForVideoMS,\n"
41 | repr += " transform = %s,\n" % str(self.transform)
42 | repr += " Masked position generator = %s,\n" % str(self.masked_position_generator)
43 | repr += ")"
44 | return repr
45 |
46 |
47 | def build_pretraining_dataset(args):
48 | transform = DataAugmentationForVideoMS(args)
49 | dataset = VideoMS(
50 | root=None,
51 | setting=args.data_path,
52 | video_ext='mp4',
53 | is_color=True,
54 | modality='rgb',
55 | new_length=args.num_frames,
56 | new_step=args.sampling_rate,
57 | transform=transform,
58 | temporal_jitter=False,
59 | video_loader=True,
60 | use_decord=True,
61 | lazy_init=False)
62 | print("Data Aug = %s" % str(transform))
63 | return dataset
64 |
65 |
66 | def build_dataset(is_train, test_mode, args):
67 | if args.data_set == 'UCF101':
68 | mode = None
69 | anno_path = None
70 | if is_train is True:
71 | mode = 'train'
72 | anno_path = os.path.join(args.data_path, 'train.csv')
73 | elif test_mode is True:
74 | mode = 'test'
75 | anno_path = os.path.join(args.data_path, 'test.csv')
76 | else:
77 | mode = 'validation'
78 | anno_path = os.path.join(args.data_path, 'val.csv')
79 |
80 | dataset = VideoClsDataset(
81 | anno_path=anno_path,
82 | data_path='/',
83 | mode=mode,
84 | clip_len=args.num_frames,
85 | frame_sample_rate=args.sampling_rate,
86 | num_segment=1,
87 | test_num_segment=args.test_num_segment,
88 | test_num_crop=args.test_num_crop,
89 | num_crop=1 if not test_mode else 3,
90 | keep_aspect_ratio=True,
91 | crop_size=args.input_size,
92 | short_side_size=args.short_side_size,
93 | new_height=256,
94 | new_width=320,
95 | args=args)
96 | nb_classes = 101
97 |
98 | elif args.data_set == 'HMDB51':
99 | mode = None
100 | anno_path = None
101 | if is_train is True:
102 | mode = 'train'
103 | anno_path = os.path.join(args.data_path, 'train.csv')
104 | elif test_mode is True:
105 | mode = 'test'
106 | anno_path = os.path.join(args.data_path, 'test.csv')
107 | else:
108 | mode = 'validation'
109 | anno_path = os.path.join(args.data_path, 'val.csv')
110 |
111 | dataset = VideoClsDataset(
112 | anno_path=anno_path,
113 | data_path='/',
114 | mode=mode,
115 | clip_len=args.num_frames,
116 | frame_sample_rate=args.sampling_rate,
117 | num_segment=1,
118 | test_num_segment=args.test_num_segment,
119 | test_num_crop=args.test_num_crop,
120 | num_crop=1 if not test_mode else 3,
121 | keep_aspect_ratio=True,
122 | crop_size=args.input_size,
123 | short_side_size=args.short_side_size,
124 | new_height=256,
125 | new_width=320,
126 | args=args)
127 | nb_classes = 51
128 |
129 | elif args.data_set == 'OSCC':
130 | mode = None
131 | anno_path = None
132 | if is_train is True:
133 | mode = 'train'
134 | anno_path = os.path.join(args.data_path, 'train.csv')
135 | elif test_mode is True:
136 | mode = 'test'
137 | anno_path = os.path.join(args.data_path, 'test.csv')
138 | else:
139 | mode = 'validation'
140 | anno_path = os.path.join(args.data_path, 'val.csv')
141 |
142 | dataset = VideoClsDataset(
143 | anno_path=anno_path,
144 | data_path='/',
145 | mode=mode,
146 | clip_len=args.num_frames,
147 | frame_sample_rate=args.sampling_rate,
148 | num_segment=1,
149 | test_num_segment=args.test_num_segment,
150 | test_num_crop=args.test_num_crop,
151 | num_crop=1 if not test_mode else 3,
152 | keep_aspect_ratio=True,
153 | crop_size=args.input_size,
154 | short_side_size=args.short_side_size,
155 | new_height=256,
156 | new_width=320,
157 | args=args)
158 | nb_classes = 2
159 |
160 | else:
161 | raise NotImplementedError()
162 | assert nb_classes == args.nb_classes
163 |     print("Number of classes = %d" % args.nb_classes)
164 |
165 | return dataset, nb_classes
166 |
--------------------------------------------------------------------------------
/optim_factory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import optim as optim
3 |
4 | from timm.optim.adafactor import Adafactor
5 | from timm.optim.adahessian import Adahessian
6 | from timm.optim.adamp import AdamP
7 | from timm.optim.lookahead import Lookahead
8 | from timm.optim.nadam import Nadam
9 | from timm.optim.novograd import NovoGrad
10 | from timm.optim.nvnovograd import NvNovoGrad
11 | from timm.optim.radam import RAdam
12 | from timm.optim.rmsprop_tf import RMSpropTF
13 | from timm.optim.sgdp import SGDP
14 |
15 | import json
16 |
17 | try:
18 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
19 | has_apex = True
20 | except ImportError:
21 | has_apex = False
22 |
23 |
24 | def get_num_layer_for_vit(var_name, num_max_layer):
25 | if var_name in ("cls_token", "mask_token", "pos_embed"):
26 | return 0
27 | elif var_name.startswith("patch_embed"):
28 | return 0
29 | elif var_name.startswith("rel_pos_bias"):
30 | return num_max_layer - 1
31 | elif var_name.startswith("blocks"):
32 | layer_id = int(var_name.split('.')[1])
33 | return layer_id + 1
34 | else:
35 | return num_max_layer - 1
36 |
37 |
38 | class LayerDecayValueAssigner(object):
39 | def __init__(self, values):
40 | self.values = values
41 |
42 | def get_scale(self, layer_id):
43 | return self.values[layer_id]
44 |
45 | def get_layer_id(self, var_name):
46 | return get_num_layer_for_vit(var_name, len(self.values))
47 |
48 |
49 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
50 | parameter_group_names = {}
51 | parameter_group_vars = {}
52 |
53 | for name, param in model.named_parameters():
54 | if not param.requires_grad:
55 | continue # frozen weights
56 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
57 | group_name = "no_decay"
58 | this_weight_decay = 0.
59 | else:
60 | group_name = "decay"
61 | this_weight_decay = weight_decay
62 | if get_num_layer is not None:
63 | layer_id = get_num_layer(name)
64 | group_name = "layer_%d_%s" % (layer_id, group_name)
65 | else:
66 | layer_id = None
67 |
68 | if group_name not in parameter_group_names:
69 | if get_layer_scale is not None:
70 | scale = get_layer_scale(layer_id)
71 | else:
72 | scale = 1.
73 |
74 | parameter_group_names[group_name] = {
75 | "weight_decay": this_weight_decay,
76 | "params": [],
77 | "lr_scale": scale
78 | }
79 | parameter_group_vars[group_name] = {
80 | "weight_decay": this_weight_decay,
81 | "params": [],
82 | "lr_scale": scale
83 | }
84 |
85 | parameter_group_vars[group_name]["params"].append(param)
86 | parameter_group_names[group_name]["params"].append(name)
87 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
88 | return list(parameter_group_vars.values())
89 |
90 |
91 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
92 | opt_lower = args.opt.lower()
93 | weight_decay = args.weight_decay
94 | if weight_decay and filter_bias_and_bn:
95 | skip = {}
96 | if skip_list is not None:
97 | skip = skip_list
98 | elif hasattr(model, 'no_weight_decay'):
99 | skip = model.no_weight_decay()
100 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
101 | weight_decay = 0.
102 | else:
103 | parameters = model.parameters()
104 |
105 | if 'fused' in opt_lower:
106 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
107 |
108 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
109 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
110 | opt_args['eps'] = args.opt_eps
111 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
112 | opt_args['betas'] = args.opt_betas
113 |
114 | print("optimizer settings:", opt_args)
115 |
116 | opt_split = opt_lower.split('_')
117 | opt_lower = opt_split[-1]
118 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
119 | opt_args.pop('eps', None)
120 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
121 | elif opt_lower == 'momentum':
122 | opt_args.pop('eps', None)
123 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
124 | elif opt_lower == 'adam':
125 | optimizer = optim.Adam(parameters, **opt_args)
126 | elif opt_lower == 'adamw':
127 | optimizer = optim.AdamW(parameters, **opt_args)
128 | elif opt_lower == 'nadam':
129 | optimizer = Nadam(parameters, **opt_args)
130 | elif opt_lower == 'radam':
131 | optimizer = RAdam(parameters, **opt_args)
132 | elif opt_lower == 'adamp':
133 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
134 | elif opt_lower == 'sgdp':
135 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
136 | elif opt_lower == 'adadelta':
137 | optimizer = optim.Adadelta(parameters, **opt_args)
138 | elif opt_lower == 'adafactor':
139 | if not args.lr:
140 | opt_args['lr'] = None
141 | optimizer = Adafactor(parameters, **opt_args)
142 | elif opt_lower == 'adahessian':
143 | optimizer = Adahessian(parameters, **opt_args)
144 | elif opt_lower == 'rmsprop':
145 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
146 | elif opt_lower == 'rmsproptf':
147 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
148 | elif opt_lower == 'novograd':
149 | optimizer = NovoGrad(parameters, **opt_args)
150 | elif opt_lower == 'nvnovograd':
151 | optimizer = NvNovoGrad(parameters, **opt_args)
152 | elif opt_lower == 'fusedsgd':
153 | opt_args.pop('eps', None)
154 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
155 | elif opt_lower == 'fusedmomentum':
156 | opt_args.pop('eps', None)
157 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
158 | elif opt_lower == 'fusedadam':
159 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
160 | elif opt_lower == 'fusedadamw':
161 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
162 | elif opt_lower == 'fusedlamb':
163 | optimizer = FusedLAMB(parameters, **opt_args)
164 | elif opt_lower == 'fusednovograd':
165 | opt_args.setdefault('betas', (0.95, 0.98))
166 | optimizer = FusedNovoGrad(parameters, **opt_args)
167 | else:
168 |         assert False, "Invalid optimizer"
169 | raise ValueError
170 |
171 | if len(opt_split) > 1:
172 | if opt_split[0] == 'lookahead':
173 | optimizer = Lookahead(optimizer)
174 |
175 | return optimizer
176 |
--------------------------------------------------------------------------------
/random_erasing.py:
--------------------------------------------------------------------------------
1 | """
2 | This implementation is based on
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
4 | published under an Apache License 2.0.
5 | """
6 | import math
7 | import random
8 | import torch
9 |
10 |
11 | def _get_pixels(
12 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"
13 | ):
14 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
15 | # paths, flip the order so normal is run on CPU if this becomes a problem
16 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
17 | if per_pixel:
18 | return torch.empty(patch_size, dtype=dtype, device=device).normal_()
19 | elif rand_color:
20 | return torch.empty(
21 | (patch_size[0], 1, 1), dtype=dtype, device=device
22 | ).normal_()
23 | else:
24 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
25 |
26 |
27 | class RandomErasing:
28 | """Randomly selects a rectangle region in an image and erases its pixels.
29 | 'Random Erasing Data Augmentation' by Zhong et al.
30 | See https://arxiv.org/pdf/1708.04896.pdf
31 | This variant of RandomErasing is intended to be applied to either a batch
32 | or single image tensor after it has been normalized by dataset mean and std.
33 | Args:
34 | probability: Probability that the Random Erasing operation will be performed.
35 | min_area: Minimum percentage of erased area wrt input image area.
36 | max_area: Maximum percentage of erased area wrt input image area.
37 | min_aspect: Minimum aspect ratio of erased area.
38 | mode: pixel color mode, one of 'const', 'rand', or 'pixel'
39 | 'const' - erase block is constant color of 0 for all channels
40 | 'rand' - erase block is same per-channel random (normal) color
41 | 'pixel' - erase block is per-pixel random (normal) color
42 | max_count: maximum number of erasing blocks per image, area per box is scaled by count.
43 | per-image count is randomly chosen between 1 and this value.
44 | """
45 |
46 | def __init__(
47 | self,
48 | probability=0.5,
49 | min_area=0.02,
50 | max_area=1 / 3,
51 | min_aspect=0.3,
52 | max_aspect=None,
53 | mode="const",
54 | min_count=1,
55 | max_count=None,
56 | num_splits=0,
57 | device="cuda",
58 | cube=True,
59 | ):
60 | self.probability = probability
61 | self.min_area = min_area
62 | self.max_area = max_area
63 | max_aspect = max_aspect or 1 / min_aspect
64 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
65 | self.min_count = min_count
66 | self.max_count = max_count or min_count
67 | self.num_splits = num_splits
68 | mode = mode.lower()
69 | self.rand_color = False
70 | self.per_pixel = False
71 | self.cube = cube
72 | if mode == "rand":
73 | self.rand_color = True # per block random normal
74 | elif mode == "pixel":
75 | self.per_pixel = True # per pixel random normal
76 | else:
77 | assert not mode or mode == "const"
78 | self.device = device
79 |
80 | def _erase(self, img, chan, img_h, img_w, dtype):
81 | if random.random() > self.probability:
82 | return
83 | area = img_h * img_w
84 | count = (
85 | self.min_count
86 | if self.min_count == self.max_count
87 | else random.randint(self.min_count, self.max_count)
88 | )
89 | for _ in range(count):
90 | for _ in range(10):
91 | target_area = (
92 | random.uniform(self.min_area, self.max_area) * area / count
93 | )
94 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
95 | h = int(round(math.sqrt(target_area * aspect_ratio)))
96 | w = int(round(math.sqrt(target_area / aspect_ratio)))
97 | if w < img_w and h < img_h:
98 | top = random.randint(0, img_h - h)
99 | left = random.randint(0, img_w - w)
100 | img[:, top : top + h, left : left + w] = _get_pixels(
101 | self.per_pixel,
102 | self.rand_color,
103 | (chan, h, w),
104 | dtype=dtype,
105 | device=self.device,
106 | )
107 | break
108 |
109 | def _erase_cube(
110 | self,
111 | img,
112 | batch_start,
113 | batch_size,
114 | chan,
115 | img_h,
116 | img_w,
117 | dtype,
118 | ):
119 | if random.random() > self.probability:
120 | return
121 | area = img_h * img_w
122 | count = (
123 | self.min_count
124 | if self.min_count == self.max_count
125 | else random.randint(self.min_count, self.max_count)
126 | )
127 | for _ in range(count):
128 | for _ in range(100):
129 | target_area = (
130 | random.uniform(self.min_area, self.max_area) * area / count
131 | )
132 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
133 | h = int(round(math.sqrt(target_area * aspect_ratio)))
134 | w = int(round(math.sqrt(target_area / aspect_ratio)))
135 | if w < img_w and h < img_h:
136 | top = random.randint(0, img_h - h)
137 | left = random.randint(0, img_w - w)
138 | for i in range(batch_start, batch_size):
139 | img_instance = img[i]
140 | img_instance[
141 | :, top : top + h, left : left + w
142 | ] = _get_pixels(
143 | self.per_pixel,
144 | self.rand_color,
145 | (chan, h, w),
146 | dtype=dtype,
147 | device=self.device,
148 | )
149 | break
150 |
151 | def __call__(self, input):
152 | if len(input.size()) == 3:
153 | self._erase(input, *input.size(), input.dtype)
154 | else:
155 | batch_size, chan, img_h, img_w = input.size()
156 | # skip first slice of batch if num_splits is set (for clean portion of samples)
157 | batch_start = (
158 | batch_size // self.num_splits if self.num_splits > 1 else 0
159 | )
160 | if self.cube:
161 | self._erase_cube(
162 | input,
163 | batch_start,
164 | batch_size,
165 | chan,
166 | img_h,
167 | img_w,
168 | input.dtype,
169 | )
170 | else:
171 | for i in range(batch_start, batch_size):
172 | self._erase(input[i], chan, img_h, img_w, input.dtype)
173 | return input
174 |
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms.functional as F
3 | import warnings
4 | import random
5 | import numpy as np
6 | import torchvision
7 | from PIL import Image, ImageOps
8 | import numbers
9 |
10 |
11 | class GroupRandomCrop(object):
12 | def __init__(self, size):
13 | if isinstance(size, numbers.Number):
14 | self.size = (int(size), int(size))
15 | else:
16 | self.size = size
17 |
18 | def __call__(self, img_tuple):
19 | img_group, label = img_tuple
20 |
21 | w, h = img_group[0].size
22 | th, tw = self.size
23 |
24 | out_images = list()
25 |
26 | x1 = random.randint(0, w - tw)
27 | y1 = random.randint(0, h - th)
28 |
29 | for img in img_group:
30 | assert(img.size[0] == w and img.size[1] == h)
31 | if w == tw and h == th:
32 | out_images.append(img)
33 | else:
34 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
35 |
36 | return (out_images, label)
37 |
38 |
39 | class GroupCenterCrop(object):
40 | def __init__(self, size):
41 | self.worker = torchvision.transforms.CenterCrop(size)
42 |
43 | def __call__(self, img_tuple):
44 | img_group, label = img_tuple
45 | return ([self.worker(img) for img in img_group], label)
46 |
47 |
48 | class GroupNormalize(object):
49 | def __init__(self, mean, std):
50 | self.mean = mean
51 | self.std = std
52 |
53 | def __call__(self, tensor_tuple):
54 | tensor, label = tensor_tuple
55 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean))
56 | rep_std = self.std * (tensor.size()[0]//len(self.std))
57 |
58 | # TODO: make efficient
59 | for t, m, s in zip(tensor, rep_mean, rep_std):
60 | t.sub_(m).div_(s)
61 |
62 | return (tensor,label)
63 |
64 |
65 | class GroupGrayScale(object):
66 | def __init__(self, size):
67 | self.worker = torchvision.transforms.Grayscale(size)
68 |
69 | def __call__(self, img_tuple):
70 | img_group, label = img_tuple
71 | return ([self.worker(img) for img in img_group], label)
72 |
73 |
74 | class GroupScale(object):
75 | """ Rescales the input PIL.Image to the given 'size'.
76 | 'size' will be the size of the smaller edge.
77 | For example, if height > width, then image will be
78 | rescaled to (size * height / width, size)
79 | size: size of the smaller edge
80 | interpolation: Default: PIL.Image.BILINEAR
81 | """
82 |
83 | def __init__(self, size, interpolation=Image.BILINEAR):
84 | self.worker = torchvision.transforms.Resize(size, interpolation)
85 |
86 | def __call__(self, img_tuple):
87 | img_group, label = img_tuple
88 | return ([self.worker(img) for img in img_group], label)
89 |
90 |
91 | class GroupMultiScaleCrop(object):
92 |
93 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
94 |         self.scales = scales if scales is not None else [1, .875, .75, .66]
95 | self.max_distort = max_distort
96 | self.fix_crop = fix_crop
97 | self.more_fix_crop = more_fix_crop
98 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
99 | self.interpolation = Image.BILINEAR
100 |
101 | def __call__(self, img_tuple):
102 | img_group, label = img_tuple
103 |
104 | im_size = img_group[0].size
105 |
106 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
107 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
108 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group]
109 | return (ret_img_group, label)
110 |
111 | def _sample_crop_size(self, im_size):
112 | image_w, image_h = im_size[0], im_size[1]
113 |
114 | # find a crop size
115 | base_size = min(image_w, image_h)
116 | crop_sizes = [int(base_size * x) for x in self.scales]
117 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
118 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
119 |
120 | pairs = []
121 | for i, h in enumerate(crop_h):
122 | for j, w in enumerate(crop_w):
123 | if abs(i - j) <= self.max_distort:
124 | pairs.append((w, h))
125 |
126 | crop_pair = random.choice(pairs)
127 | if not self.fix_crop:
128 | w_offset = random.randint(0, image_w - crop_pair[0])
129 | h_offset = random.randint(0, image_h - crop_pair[1])
130 | else:
131 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
132 |
133 | return crop_pair[0], crop_pair[1], w_offset, h_offset
134 |
135 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
136 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
137 | return random.choice(offsets)
138 |
139 | @staticmethod
140 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
141 | w_step = (image_w - crop_w) // 4
142 | h_step = (image_h - crop_h) // 4
143 |
144 | ret = list()
145 | ret.append((0, 0)) # upper left
146 | ret.append((4 * w_step, 0)) # upper right
147 | ret.append((0, 4 * h_step)) # lower left
148 | ret.append((4 * w_step, 4 * h_step)) # lower right
149 | ret.append((2 * w_step, 2 * h_step)) # center
150 |
151 | if more_fix_crop:
152 | ret.append((0, 2 * h_step)) # center left
153 | ret.append((4 * w_step, 2 * h_step)) # center right
154 | ret.append((2 * w_step, 4 * h_step)) # lower center
155 | ret.append((2 * w_step, 0 * h_step)) # upper center
156 |
157 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter
158 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter
159 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter
160 |             ret.append((3 * w_step, 3 * h_step)) # lower right quarter
161 | return ret
162 |
163 |
164 | class Stack(object):
165 |
166 | def __init__(self, roll=False):
167 | self.roll = roll
168 |
169 | def __call__(self, img_tuple):
170 | img_group, label = img_tuple
171 |
172 | if img_group[0].mode == 'L':
173 | return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label)
174 | elif img_group[0].mode == 'RGB':
175 | if self.roll:
176 | return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label)
177 | else:
178 | return (np.concatenate(img_group, axis=2), label)
179 |
180 |
181 | class ToTorchFormatTensor(object):
182 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
183 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
184 | def __init__(self, div=True):
185 | self.div = div
186 |
187 | def __call__(self, pic_tuple):
188 | pic, label = pic_tuple
189 |
190 | if isinstance(pic, np.ndarray):
191 | # handle numpy array
192 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
193 | else:
194 | # handle PIL Image
195 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
196 | img = img.view(pic.size[1], pic.size[0], len(pic.mode))
197 | # put it from HWC to CHW format
198 | # yikes, this transpose takes 80% of the loading time/CPU
199 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
200 | return (img.float().div(255.) if self.div else img.float(), label)
201 |
202 |
203 | class IdentityTransform(object):
204 |
205 | def __call__(self, data):
206 | return data
207 |
--------------------------------------------------------------------------------
/engine_for_finetuning.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import math
4 | import sys
5 | from typing import Iterable, Optional
6 | import torch
7 | from mixup import Mixup
8 | from timm.utils import accuracy, ModelEma
9 | import utils
10 | from scipy.special import softmax
11 |
12 |
13 | def train_class_batch(model, samples, target, criterion):
14 | outputs = model(samples)
15 | loss = criterion(outputs, target)
16 | return loss, outputs
17 |
18 |
19 | def get_loss_scale_for_deepspeed(model):
20 | optimizer = model.optimizer
21 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale
22 |
23 |
24 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
25 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
26 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
27 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
28 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
29 | num_training_steps_per_epoch=None, update_freq=None):
30 | model.train(True)
31 | metric_logger = utils.MetricLogger(delimiter=" ")
32 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
33 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
34 | header = 'Epoch: [{}]'.format(epoch)
35 | print_freq = 10
36 |
37 | if loss_scaler is None:
38 | model.zero_grad()
39 | model.micro_steps = 0
40 | else:
41 | optimizer.zero_grad()
42 |
43 | for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
44 | step = data_iter_step // update_freq
45 | if step >= num_training_steps_per_epoch:
46 | continue
47 | it = start_steps + step # global training iteration
48 |         # Update LR & WD at the first step of each gradient accumulation cycle
49 |         if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
50 | for i, param_group in enumerate(optimizer.param_groups):
51 | if lr_schedule_values is not None:
52 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
53 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
54 | param_group["weight_decay"] = wd_schedule_values[it]
55 |
56 | samples = samples.to(device, non_blocking=True)
57 | targets = targets.to(device, non_blocking=True)
58 |
59 | if mixup_fn is not None:
60 | samples, targets = mixup_fn(samples, targets)
61 |
62 | if loss_scaler is None:
63 | samples = samples.half()
64 | loss, output = train_class_batch(
65 | model, samples, targets, criterion)
66 | else:
67 | with torch.cuda.amp.autocast():
68 | loss, output = train_class_batch(
69 | model, samples, targets, criterion)
70 |
71 | loss_value = loss.item()
72 |
73 | if not math.isfinite(loss_value):
74 | print("Loss is {}, stopping training".format(loss_value))
75 | sys.exit(1)
76 |
77 | if loss_scaler is None:
78 | loss /= update_freq
79 | model.backward(loss)
80 | model.step()
81 |
82 | if (data_iter_step + 1) % update_freq == 0:
83 | # model.zero_grad()
84 | # Deepspeed will call step() & model.zero_grad() automatic
85 | if model_ema is not None:
86 | model_ema.update(model)
87 | grad_norm = None
88 | loss_scale_value = get_loss_scale_for_deepspeed(model)
89 | else:
90 | # this attribute is added by timm on one optimizer (adahessian)
91 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
92 | loss /= update_freq
93 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
94 | parameters=model.parameters(), create_graph=is_second_order,
95 | update_grad=(data_iter_step + 1) % update_freq == 0)
96 | if (data_iter_step + 1) % update_freq == 0:
97 | optimizer.zero_grad()
98 | if model_ema is not None:
99 | model_ema.update(model)
100 | loss_scale_value = loss_scaler.state_dict()["scale"]
101 |
102 | torch.cuda.synchronize()
103 |
104 | if mixup_fn is None:
105 | class_acc = (output.max(-1)[-1] == targets).float().mean()
106 | else:
107 | class_acc = None
108 | metric_logger.update(loss=loss_value)
109 | metric_logger.update(class_acc=class_acc)
110 | metric_logger.update(loss_scale=loss_scale_value)
111 | min_lr = 10.
112 | max_lr = 0.
113 | for group in optimizer.param_groups:
114 | min_lr = min(min_lr, group["lr"])
115 | max_lr = max(max_lr, group["lr"])
116 |
117 | metric_logger.update(lr=max_lr)
118 | metric_logger.update(min_lr=min_lr)
119 | weight_decay_value = None
120 | for group in optimizer.param_groups:
121 | if group["weight_decay"] > 0:
122 | weight_decay_value = group["weight_decay"]
123 | metric_logger.update(weight_decay=weight_decay_value)
124 | metric_logger.update(grad_norm=grad_norm)
125 |
126 | if log_writer is not None:
127 | log_writer.update(loss=loss_value, head="loss")
128 | log_writer.update(class_acc=class_acc, head="loss")
129 | log_writer.update(loss_scale=loss_scale_value, head="opt")
130 | log_writer.update(lr=max_lr, head="opt")
131 | log_writer.update(min_lr=min_lr, head="opt")
132 | log_writer.update(weight_decay=weight_decay_value, head="opt")
133 | log_writer.update(grad_norm=grad_norm, head="opt")
134 |
135 | log_writer.set_step()
136 |
137 | # gather the stats from all processes
138 | metric_logger.synchronize_between_processes()
139 | print("Averaged stats:", metric_logger)
140 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
141 |
142 |
143 | @torch.no_grad()
144 | def validation_one_epoch(data_loader, model, device):
145 | criterion = torch.nn.CrossEntropyLoss()
146 |
147 | metric_logger = utils.MetricLogger(delimiter=" ")
148 | header = 'Val:'
149 |
150 | # switch to evaluation mode
151 | model.eval()
152 |
153 | for batch in metric_logger.log_every(data_loader, 10, header):
154 | videos = batch[0]
155 | target = batch[1]
156 | videos = videos.to(device, non_blocking=True)
157 | target = target.to(device, non_blocking=True)
158 |
159 | # compute output
160 | with torch.cuda.amp.autocast():
161 | output = model(videos)
162 | loss = criterion(output, target)
163 |
164 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
165 |
166 | batch_size = videos.shape[0]
167 | metric_logger.update(loss=loss.item())
168 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
169 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
170 | # gather the stats from all processes
171 | metric_logger.synchronize_between_processes()
172 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
173 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
174 |
175 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
176 |
177 |
178 |
179 | @torch.no_grad()
180 | def final_test(data_loader, model, device, file):
181 | criterion = torch.nn.CrossEntropyLoss()
182 |
183 | metric_logger = utils.MetricLogger(delimiter=" ")
184 | header = 'Test:'
185 |
186 | # switch to evaluation mode
187 | model.eval()
188 | final_result = []
189 |
190 | for batch in metric_logger.log_every(data_loader, 10, header):
191 | videos = batch[0]
192 | target = batch[1]
193 | ids = batch[2]
194 | chunk_nb = batch[3]
195 | split_nb = batch[4]
196 | videos = videos.to(device, non_blocking=True)
197 | target = target.to(device, non_blocking=True)
198 |
199 | # compute output
200 | with torch.cuda.amp.autocast():
201 | output = model(videos)
202 | loss = criterion(output, target)
203 |
204 | for i in range(output.size(0)):
205 | string = "{} {} {} {} {}\n".format(ids[i], \
206 | str(output.data[i].cpu().numpy().tolist()), \
207 | str(int(target[i].cpu().numpy())), \
208 | str(int(chunk_nb[i].cpu().numpy())), \
209 | str(int(split_nb[i].cpu().numpy())))
210 | final_result.append(string)
211 |
212 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
213 |
214 | batch_size = videos.shape[0]
215 | metric_logger.update(loss=loss.item())
216 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
217 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
218 |
219 | if not os.path.exists(file):
220 | os.mknod(file)
221 | with open(file, 'w') as f:
222 | f.write("{}, {}\n".format(acc1, acc5))
223 | for line in final_result:
224 | f.write(line)
225 | # gather the stats from all processes
226 | metric_logger.synchronize_between_processes()
227 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
228 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
229 |
230 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
231 |
232 |
233 | def merge(eval_path, num_tasks):
234 | dict_feats = {}
235 | dict_label = {}
236 | dict_pos = {}
237 | print("Reading individual output files")
238 |
239 | for x in range(num_tasks):
240 | file = os.path.join(eval_path, str(x) + '.txt')
241 | lines = open(file, 'r').readlines()[1:]
242 | for line in lines:
243 | line = line.strip()
244 | name = line.split('[')[0]
245 | label = line.split(']')[1].split(' ')[1]
246 | chunk_nb = line.split(']')[1].split(' ')[2]
247 | split_nb = line.split(']')[1].split(' ')[3]
248 | data = np.fromstring(line.split('[')[1].split(']')[0], dtype=float, sep=',')
249 | data = softmax(data)
250 | if not name in dict_feats:
251 | dict_feats[name] = []
252 | dict_label[name] = 0
253 | dict_pos[name] = []
254 | if chunk_nb + split_nb in dict_pos[name]:
255 | continue
256 | dict_feats[name].append(data)
257 | dict_pos[name].append(chunk_nb + split_nb)
258 | dict_label[name] = label
259 | print("Computing final results")
260 |
261 | input_lst = []
262 | print(len(dict_feats))
263 | for i, item in enumerate(dict_feats):
264 | input_lst.append([i, item, dict_feats[item], dict_label[item]])
265 | from multiprocessing import Pool
266 | p = Pool(64)
267 | ans = p.map(compute_video, input_lst)
268 | top1 = [x[1] for x in ans]
269 | top5 = [x[2] for x in ans]
270 | pred = [x[0] for x in ans]
271 | label = [x[3] for x in ans]
272 | final_top1, final_top5 = np.mean(top1), np.mean(top5)
273 | return final_top1 * 100, final_top5 * 100
274 |
275 | def compute_video(lst):
276 | i, video_id, data, label = lst
277 | feat = [x for x in data]
278 | feat = np.mean(feat, axis=0)
279 | pred = np.argmax(feat)
280 | top1 = (int(pred) == int(label)) * 1.0
281 | top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0
282 | return [pred, top1, top5, int(label)]
283 |
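A minimal, self-contained sketch (not part of the original file) of the per-video aggregation that merge() and compute_video() implement above: softmax scores from every temporal/spatial view of a clip are averaged before taking the top-1/top-5 decision. The helper name and toy logits below are illustrative only.

import numpy as np
from scipy.special import softmax

def aggregate_views(view_logits, label):
    # average the softmax of each view's logits, then score the video once
    probs = np.mean([softmax(np.asarray(v, dtype=float)) for v in view_logits], axis=0)
    pred = int(np.argmax(probs))
    top1 = float(pred == int(label))
    top5 = float(int(label) in np.argsort(-probs)[:5])
    return pred, top1, top5

# two views of a 5-class toy problem, ground-truth class 2
print(aggregate_views([[0.1, 0.2, 2.0, 0.3, 0.1], [0.0, 0.1, 1.5, 0.2, 0.0]], label=2))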
--------------------------------------------------------------------------------
/run_ms_pretraining.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import numpy as np
4 | import time
5 | import torch
6 | import torch.backends.cudnn as cudnn
7 | import json
8 | import os
9 | from pathlib import Path
10 | from timm.models import create_model
11 | from optim_factory import create_optimizer
12 | from datasets import build_pretraining_dataset
13 | from engine_for_pretraining import train_one_epoch
14 | from utils import NativeScalerWithGradNormCount as NativeScaler
15 | import utils
16 | import modeling_pretrain
17 |
18 |
19 | def get_args():
20 | parser = argparse.ArgumentParser('VideoMS pre-training script', add_help=False)
21 | parser.add_argument('--batch_size', default=64, type=int)
22 | parser.add_argument('--epochs', default=800, type=int)
23 | parser.add_argument('--save_ckpt_freq', default=50, type=int)
24 |
25 | # Model parameters
26 | parser.add_argument('--model', default='pretrain_videoms_base_patch16_224', type=str, metavar='MODEL',
27 | help='Name of model to train')
28 |
29 | parser.add_argument('--decoder_depth', default=4, type=int,
30 | help='depth of decoder')
31 |
32 | parser.add_argument('--mask_type', default='tube', choices=['random', 'tube', 'motion-centric'],
33 | type=str, help='masked strategy of video tokens/patches')
34 |
35 | parser.add_argument('--mask_ratio', default=0.75, type=float,
36 | help='ratio of the visual tokens/patches need be masked')
37 |
38 | parser.add_argument('--motion_centric_masking_ratio', default=0.7, type=float,
39 | help='ratio of the motion-centric masking')
40 |
41 | parser.add_argument('--input_size', default=224, type=int,
42 | help='videos input size for backbone')
43 |
44 | parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT',
45 | help='Drop path rate (default: 0.0)')
46 |
47 | parser.add_argument('--normlize_target', default=True, type=bool,
48 | help='normalize the target patch pixels')
49 |
50 | # Optimizer parameters
51 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
52 | help='Optimizer (default: "adamw")')
53 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
54 | help='Optimizer Epsilon (default: 1e-8)')
55 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
56 | help='Optimizer Betas (default: None, use opt default)')
57 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
58 | help='Clip gradient norm (default: None, no clipping)')
59 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
60 | help='SGD momentum (default: 0.9)')
61 | parser.add_argument('--weight_decay', type=float, default=0.05,
62 | help='weight decay (default: 0.05)')
63 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
64 | weight decay. We use a cosine schedule for WD.
65 | (Set the same value with args.weight_decay to keep weight decay no change)""")
66 |
67 | parser.add_argument('--lr', type=float, default=1.5e-4, metavar='LR',
68 | help='learning rate (default: 1.5e-4)')
69 | parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
70 | help='warmup learning rate (default: 1e-6)')
71 | parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR',
72 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
73 |
74 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
75 | help='epochs to warmup LR, if scheduler supports')
76 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
77 | help='steps to warmup LR (overrides warmup_epochs when set > 0)')
78 | parser.add_argument('--use_checkpoint', action='store_true')
79 | parser.set_defaults(use_checkpoint=False)
80 |
81 | # Augmentation parameters
82 | parser.add_argument('--color_jitter', type=float, default=0.0, metavar='PCT',
83 | help='Color jitter factor (default: 0.0)')
84 | parser.add_argument('--train_interpolation', type=str, default='bicubic',
85 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
86 |
87 | # Dataset parameters
88 | parser.add_argument('--data_path', default='/path/to/list_ucf-101', type=str,
89 | help='dataset path')
90 | parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true')
91 | parser.add_argument('--num_frames', type=int, default=16)
92 | parser.add_argument('--sampling_rate', type=int, default=4)
93 | parser.add_argument('--output_dir', default='',
94 | help='path where to save, empty for no saving')
95 | parser.add_argument('--log_dir', default=None,
96 | help='path where to tensorboard log')
97 | parser.add_argument('--device', default='cuda',
98 | help='device to use for training / testing')
99 | parser.add_argument('--seed', default=0, type=int)
100 | parser.add_argument('--resume', default='', help='resume from checkpoint')
101 | parser.add_argument('--auto_resume', action='store_true')
102 | parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
103 | parser.set_defaults(auto_resume=True)
104 |
105 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
106 | help='start epoch')
107 | parser.add_argument('--num_workers', default=10, type=int)
108 | parser.add_argument('--pin_mem', action='store_true',
109 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
110 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
111 | help='')
112 | parser.set_defaults(pin_mem=True)
113 |
114 | # distributed training parameters
115 | parser.add_argument('--world_size', default=1, type=int,
116 | help='number of distributed processes')
117 | parser.add_argument('--local_rank', default=-1, type=int)
118 | parser.add_argument('--dist_on_itp', action='store_true')
119 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
120 |
121 | return parser.parse_args()
122 |
123 |
124 | def get_model(args):
125 | print(f"Creating model: {args.model}")
126 | model = create_model(
127 | args.model,
128 | pretrained=False,
129 | drop_path_rate=args.drop_path,
130 | drop_block_rate=None,
131 | decoder_depth=args.decoder_depth,
132 | use_checkpoint=args.use_checkpoint,
133 | motion_centric_masking=args.mask_type=='motion-centric',
134 | motion_centric_masking_ratio=args.motion_centric_masking_ratio,
135 | masking_ratio=args.mask_ratio
136 | )
137 | return model
138 |
139 |
140 | def main(args):
141 | utils.init_distributed_mode(args)
142 |
143 | print(args)
144 |
145 | device = torch.device(args.device)
146 |
147 | # fix the seed for reproducibility
148 | seed = args.seed + utils.get_rank()
149 | torch.manual_seed(seed)
150 | np.random.seed(seed)
151 |
152 | cudnn.benchmark = True
153 |
154 | model = get_model(args)
155 | patch_size = model.encoder.patch_embed.patch_size
156 | print("Patch size = %s" % str(patch_size))
157 | args.window_size = (args.num_frames // 2, args.input_size // patch_size[0], args.input_size // patch_size[1])
158 | args.patch_size = patch_size
159 |
160 | # get dataset
161 | dataset_train = build_pretraining_dataset(args)
162 |
163 |
164 | num_tasks = utils.get_world_size()
165 | global_rank = utils.get_rank()
166 | sampler_rank = global_rank
167 | num_training_steps_per_epoch = len(dataset_train) // args.batch_size // num_tasks
168 |
169 | sampler_train = torch.utils.data.DistributedSampler(
170 | dataset_train, num_replicas=num_tasks, rank=sampler_rank, shuffle=True
171 | )
172 | print("Sampler_train = %s" % str(sampler_train))
173 |
174 |
175 | if global_rank == 0 and args.log_dir is not None:
176 | os.makedirs(args.log_dir, exist_ok=True)
177 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
178 | else:
179 | log_writer = None
180 |
181 | data_loader_train = torch.utils.data.DataLoader(
182 | dataset_train, sampler=sampler_train,
183 | batch_size=args.batch_size,
184 | num_workers=args.num_workers,
185 | pin_memory=args.pin_mem,
186 | drop_last=True,
187 | worker_init_fn=utils.seed_worker
188 | )
189 |
190 | model.to(device)
191 | model_without_ddp = model
192 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
193 |
194 | print("Model = %s" % str(model_without_ddp))
195 | print('number of params: {} M'.format(n_parameters / 1e6))
196 |
197 | total_batch_size = args.batch_size * utils.get_world_size()
198 |
199 | args.lr = args.lr * total_batch_size / 256
200 | args.min_lr = args.min_lr * total_batch_size / 256
201 | args.warmup_lr = args.warmup_lr * total_batch_size / 256
202 | print("LR = %.8f" % args.lr)
203 | print("Batch size = %d" % total_batch_size)
204 | print("Number of training steps = %d" % num_training_steps_per_epoch)
205 | print("Number of training examples per epoch = %d" % (total_batch_size * num_training_steps_per_epoch))
206 |
207 | if args.distributed:
208 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
209 | model_without_ddp = model.module
210 |
211 | optimizer = create_optimizer(
212 | args, model_without_ddp)
213 | loss_scaler = NativeScaler()
214 |
215 | print("Use step level LR & WD scheduler!")
216 | lr_schedule_values = utils.cosine_scheduler(
217 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
218 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
219 | )
220 | if args.weight_decay_end is None:
221 | args.weight_decay_end = args.weight_decay
222 | wd_schedule_values = utils.cosine_scheduler(
223 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
224 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))
225 |
226 | utils.auto_load_model(
227 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
228 | torch.cuda.empty_cache()
229 | print(f"Start training for {args.epochs} epochs")
230 | start_time = time.time()
231 | for epoch in range(args.start_epoch, args.epochs):
232 | if args.distributed:
233 | data_loader_train.sampler.set_epoch(epoch)
234 | if log_writer is not None:
235 | log_writer.set_step(epoch * num_training_steps_per_epoch)
236 | train_stats = train_one_epoch(
237 | model, data_loader_train,
238 | optimizer, device, epoch, loss_scaler,
239 | args.clip_grad, log_writer=log_writer,
240 | start_steps=epoch * num_training_steps_per_epoch,
241 | lr_schedule_values=lr_schedule_values,
242 | wd_schedule_values=wd_schedule_values,
243 | patch_size=patch_size[0],
244 | normlize_target=args.normlize_target,
245 | mask_type=args.mask_type
246 | )
247 | if args.output_dir:
248 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
249 | utils.save_model(
250 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
251 | loss_scaler=loss_scaler, epoch=epoch)
252 |
253 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
254 | 'epoch': epoch, 'n_parameters': n_parameters}
255 |
256 | if args.output_dir and utils.is_main_process():
257 | if log_writer is not None:
258 | log_writer.flush()
259 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
260 | f.write(json.dumps(log_stats) + "\n")
261 |
262 | total_time = time.time() - start_time
263 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
264 | print('Training time {}'.format(total_time_str))
265 |
266 |
267 | if __name__ == '__main__':
268 | opts = get_args()
269 | if opts.output_dir:
270 | Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
271 | main(opts)
272 |
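The learning-rate handling in main() follows the linear scaling rule: --lr, --min_lr and --warmup_lr are interpreted as rates per 256 samples of effective batch and rescaled by the total batch size. A hedged worked example using the parser defaults and an assumed 8-process run (world size is illustrative):

base_lr = 1.5e-4                     # --lr default
per_gpu_batch, world_size = 64, 8    # --batch_size default, assumed number of processes
total_batch_size = per_gpu_batch * world_size      # 512
effective_lr = base_lr * total_batch_size / 256    # 1.5e-4 * 2 = 3.0e-4
print(effective_lr)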
--------------------------------------------------------------------------------
/mixup.py:
--------------------------------------------------------------------------------
1 | """ Mixup and Cutmix
2 |
3 | Papers:
4 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
5 |
6 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
7 |
8 | Code Reference:
9 | CutMix: https://github.com/clovaai/CutMix-PyTorch
10 |
11 | Hacked together by / Copyright 2019, Ross Wightman
12 | """
13 | import numpy as np
14 | import torch
15 |
16 |
17 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
18 | x = x.long().view(-1, 1)
19 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
20 |
21 |
22 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
23 | off_value = smoothing / num_classes
24 | on_value = 1. - smoothing + off_value
25 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
26 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
27 | return y1 * lam + y2 * (1. - lam)
28 |
29 |
30 | def rand_bbox(img_shape, lam, margin=0., count=None):
31 | """ Standard CutMix bounding-box
32 | Generates a random square bbox based on lambda value. This impl includes
33 | support for enforcing a border margin as percent of bbox dimensions.
34 |
35 | Args:
36 | img_shape (tuple): Image shape as tuple
37 | lam (float): Cutmix lambda value
38 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
39 | count (int): Number of bbox to generate
40 | """
41 | ratio = np.sqrt(1 - lam)
42 | img_h, img_w = img_shape[-2:]
43 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
44 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
45 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
46 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
47 | yl = np.clip(cy - cut_h // 2, 0, img_h)
48 | yh = np.clip(cy + cut_h // 2, 0, img_h)
49 | xl = np.clip(cx - cut_w // 2, 0, img_w)
50 | xh = np.clip(cx + cut_w // 2, 0, img_w)
51 | return yl, yh, xl, xh
52 |
53 |
54 | def rand_bbox_minmax(img_shape, minmax, count=None):
55 | """ Min-Max CutMix bounding-box
56 | Inspired by Darknet cutmix impl, generates a random rectangular bbox
57 | based on min/max percent values applied to each dimension of the input image.
58 |
59 | Typical minmax values are in the .2-.3 range for min and the .8-.9 range for max.
60 |
61 | Args:
62 | img_shape (tuple): Image shape as tuple
63 | minmax (tuple or list): Min and max bbox ratios (as percent of image size)
64 | count (int): Number of bbox to generate
65 | """
66 | assert len(minmax) == 2
67 | img_h, img_w = img_shape[-2:]
68 | cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
69 | cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
70 | yl = np.random.randint(0, img_h - cut_h, size=count)
71 | xl = np.random.randint(0, img_w - cut_w, size=count)
72 | yu = yl + cut_h
73 | xu = xl + cut_w
74 | return yl, yu, xl, xu
75 |
76 |
77 | def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
78 | """ Generate bbox and apply lambda correction.
79 | """
80 | if ratio_minmax is not None:
81 | yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
82 | else:
83 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
84 | if correct_lam or ratio_minmax is not None:
85 | bbox_area = (yu - yl) * (xu - xl)
86 | lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
87 | return (yl, yu, xl, xu), lam
88 |
89 |
90 | class Mixup:
91 | """ Mixup/Cutmix that applies different params to each element or whole batch
92 |
93 | Args:
94 | mixup_alpha (float): mixup alpha value, mixup is active if > 0.
95 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
96 | cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
97 | prob (float): probability of applying mixup or cutmix per batch or element
98 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active
99 | mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element))
100 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
101 | label_smoothing (float): apply label smoothing to the mixed target tensor
102 | num_classes (int): number of classes for target
103 | """
104 | def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
105 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
106 | self.mixup_alpha = mixup_alpha
107 | self.cutmix_alpha = cutmix_alpha
108 | self.cutmix_minmax = cutmix_minmax
109 | if self.cutmix_minmax is not None:
110 | assert len(self.cutmix_minmax) == 2
111 | # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
112 | self.cutmix_alpha = 1.0
113 | self.mix_prob = prob
114 | self.switch_prob = switch_prob
115 | self.label_smoothing = label_smoothing
116 | self.num_classes = num_classes
117 | self.mode = mode
118 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
119 | self.mixup_enabled = True # set to false to disable mixing (intended to be set by train loop)
120 |
121 | def _params_per_elem(self, batch_size):
122 | lam = np.ones(batch_size, dtype=np.float32)
123 | use_cutmix = np.zeros(batch_size, dtype=bool)
124 | if self.mixup_enabled:
125 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
126 | use_cutmix = np.random.rand(batch_size) < self.switch_prob
127 | lam_mix = np.where(
128 | use_cutmix,
129 | np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
130 | np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
131 | elif self.mixup_alpha > 0.:
132 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
133 | elif self.cutmix_alpha > 0.:
134 | use_cutmix = np.ones(batch_size, dtype=bool)
135 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
136 | else:
137 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
138 | lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
139 | return lam, use_cutmix
140 |
141 | def _params_per_batch(self):
142 | lam = 1.
143 | use_cutmix = False
144 | if self.mixup_enabled and np.random.rand() < self.mix_prob:
145 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
146 | use_cutmix = np.random.rand() < self.switch_prob
147 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
148 | np.random.beta(self.mixup_alpha, self.mixup_alpha)
149 | elif self.mixup_alpha > 0.:
150 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
151 | elif self.cutmix_alpha > 0.:
152 | use_cutmix = True
153 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
154 | else:
155 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
156 | lam = float(lam_mix)
157 | return lam, use_cutmix
158 |
159 | def _mix_elem(self, x):
160 | batch_size = len(x)
161 | lam_batch, use_cutmix = self._params_per_elem(batch_size)
162 | x_orig = x.clone() # need to keep an unmodified original for mixing source
163 | for i in range(batch_size):
164 | j = batch_size - i - 1
165 | lam = lam_batch[i]
166 | if lam != 1.:
167 | if use_cutmix[i]:
168 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
169 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
170 | x[i][..., yl:yh, xl:xh] = x_orig[j][..., yl:yh, xl:xh]
171 | lam_batch[i] = lam
172 | else:
173 | x[i] = x[i] * lam + x_orig[j] * (1 - lam)
174 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
175 |
176 | def _mix_pair(self, x):
177 | batch_size = len(x)
178 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
179 | x_orig = x.clone() # need to keep an unmodified original for mixing source
180 | for i in range(batch_size // 2):
181 | j = batch_size - i - 1
182 | lam = lam_batch[i]
183 | if lam != 1.:
184 | if use_cutmix[i]:
185 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
186 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
187 | x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
188 | x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
189 | lam_batch[i] = lam
190 | else:
191 | x[i] = x[i] * lam + x_orig[j] * (1 - lam)
192 | x[j] = x[j] * lam + x_orig[i] * (1 - lam)
193 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
194 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
195 |
196 | def _mix_batch(self, x):
197 | lam, use_cutmix = self._params_per_batch()
198 | if lam == 1.:
199 | return 1.
200 | if use_cutmix:
201 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
202 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
203 | x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh]
204 | else:
205 | x_flipped = x.flip(0).mul_(1. - lam)
206 | x.mul_(lam).add_(x_flipped)
207 | return lam
208 |
209 | def __call__(self, x, target):
210 | assert len(x) % 2 == 0, 'Batch size should be even when using this'
211 | if self.mode == 'elem':
212 | lam = self._mix_elem(x)
213 | elif self.mode == 'pair':
214 | lam = self._mix_pair(x)
215 | else:
216 | lam = self._mix_batch(x)
217 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
218 | return x, target
219 |
220 |
221 | class FastCollateMixup(Mixup):
222 | """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
223 |
224 | A Mixup impl that's performed while collating the batches.
225 | """
226 |
227 | def _mix_elem_collate(self, output, batch, half=False):
228 | batch_size = len(batch)
229 | num_elem = batch_size // 2 if half else batch_size
230 | assert len(output) == num_elem
231 | lam_batch, use_cutmix = self._params_per_elem(num_elem)
232 | for i in range(num_elem):
233 | j = batch_size - i - 1
234 | lam = lam_batch[i]
235 | mixed = batch[i][0]
236 | if lam != 1.:
237 | if use_cutmix[i]:
238 | if not half:
239 | mixed = mixed.copy()
240 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
241 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
242 | mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
243 | lam_batch[i] = lam
244 | else:
245 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
246 | np.rint(mixed, out=mixed)
247 | output[i] += torch.from_numpy(mixed.astype(np.uint8))
248 | if half:
249 | lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
250 | return torch.tensor(lam_batch).unsqueeze(1)
251 |
252 | def _mix_pair_collate(self, output, batch):
253 | batch_size = len(batch)
254 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
255 | for i in range(batch_size // 2):
256 | j = batch_size - i - 1
257 | lam = lam_batch[i]
258 | mixed_i = batch[i][0]
259 | mixed_j = batch[j][0]
260 | assert 0 <= lam <= 1.0
261 | if lam < 1.:
262 | if use_cutmix[i]:
263 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
264 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
265 | patch_i = mixed_i[:, yl:yh, xl:xh].copy()
266 | mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
267 | mixed_j[:, yl:yh, xl:xh] = patch_i
268 | lam_batch[i] = lam
269 | else:
270 | mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
271 | mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
272 | mixed_i = mixed_temp
273 | np.rint(mixed_j, out=mixed_j)
274 | np.rint(mixed_i, out=mixed_i)
275 | output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
276 | output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
277 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
278 | return torch.tensor(lam_batch).unsqueeze(1)
279 |
280 | def _mix_batch_collate(self, output, batch):
281 | batch_size = len(batch)
282 | lam, use_cutmix = self._params_per_batch()
283 | if use_cutmix:
284 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
285 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
286 | for i in range(batch_size):
287 | j = batch_size - i - 1
288 | mixed = batch[i][0]
289 | if lam != 1.:
290 | if use_cutmix:
291 | mixed = mixed.copy() # don't want to modify the original while iterating
292 | mixed[..., yl:yh, xl:xh] = batch[j][0][..., yl:yh, xl:xh]
293 | else:
294 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
295 | np.rint(mixed, out=mixed)
296 | output[i] += torch.from_numpy(mixed.astype(np.uint8))
297 | return lam
298 |
299 | def __call__(self, batch, _=None):
300 | batch_size = len(batch)
301 | assert batch_size % 2 == 0, 'Batch size should be even when using this'
302 | half = 'half' in self.mode
303 | if half:
304 | batch_size //= 2
305 | output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
306 | if self.mode == 'elem' or self.mode == 'half':
307 | lam = self._mix_elem_collate(output, batch, half=half)
308 | elif self.mode == 'pair':
309 | lam = self._mix_pair_collate(output, batch)
310 | else:
311 | lam = self._mix_batch_collate(output, batch)
312 | target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
313 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
314 | target = target[:batch_size]
315 | return output, target
316 |
317 |
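A minimal usage sketch (not part of the original file) of the Mixup class above on a video batch; because rand_bbox only looks at the last two dimensions, the same code path handles (B, C, T, H, W) clips. The hyperparameter values are illustrative.

import torch
from mixup import Mixup

mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, switch_prob=0.5,
                 mode='batch', label_smoothing=0.1, num_classes=101)

videos = torch.randn(8, 3, 16, 224, 224)        # (B, C, T, H, W); batch size must be even
labels = torch.randint(0, 101, (8,))
videos, soft_labels = mixup_fn(videos, labels)  # soft_labels: (8, 101) smoothed soft targets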
--------------------------------------------------------------------------------
/modeling_finetune.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_
7 | from timm.models.registry import register_model
8 | import torch.utils.checkpoint as checkpoint
9 |
10 |
11 | def _cfg(url='', **kwargs):
12 | return {
13 | 'url': url,
14 | 'num_classes': 101, 'input_size': (3, 224, 224), 'pool_size': None,
15 | 'crop_pct': .9, 'interpolation': 'bicubic',
16 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
17 | **kwargs
18 | }
19 |
20 |
21 | class DropPath(nn.Module):
22 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
23 | """
24 | def __init__(self, drop_prob=None):
25 | super(DropPath, self).__init__()
26 | self.drop_prob = drop_prob
27 |
28 | def forward(self, x):
29 | return drop_path(x, self.drop_prob, self.training)
30 |
31 | def extra_repr(self) -> str:
32 | return 'p={}'.format(self.drop_prob)
33 |
34 |
35 | class Mlp(nn.Module):
36 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
37 | super().__init__()
38 | out_features = out_features or in_features
39 | hidden_features = hidden_features or in_features
40 | self.fc1 = nn.Linear(in_features, hidden_features)
41 | self.act = act_layer()
42 | self.fc2 = nn.Linear(hidden_features, out_features)
43 | self.drop = nn.Dropout(drop)
44 |
45 | def forward(self, x):
46 | x = self.fc1(x)
47 | x = self.act(x)
48 | # x = self.drop(x)
49 | # commented out here to match the original BERT implementation
50 | x = self.fc2(x)
51 | x = self.drop(x)
52 | return x
53 |
54 |
55 | class Attention(nn.Module):
56 | def __init__(
57 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
58 | proj_drop=0., attn_head_dim=None):
59 | super().__init__()
60 | self.num_heads = num_heads
61 | head_dim = dim // num_heads
62 | if attn_head_dim is not None:
63 | head_dim = attn_head_dim
64 | all_head_dim = head_dim * self.num_heads
65 | self.scale = qk_scale or head_dim ** -0.5
66 |
67 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
68 | if qkv_bias:
69 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
70 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
71 | else:
72 | self.q_bias = None
73 | self.v_bias = None
74 |
75 | self.attn_drop = nn.Dropout(attn_drop)
76 | self.proj = nn.Linear(all_head_dim, dim)
77 | self.proj_drop = nn.Dropout(proj_drop)
78 |
79 | def forward(self, x):
80 | B, N, C = x.shape
81 | qkv_bias = None
82 | if self.q_bias is not None:
83 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
84 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
85 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
86 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
87 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
88 |
89 | q = q * self.scale
90 | attn = (q @ k.transpose(-2, -1))
91 |
92 |
93 | attn = attn.softmax(dim=-1)
94 | attn = self.attn_drop(attn)
95 |
96 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
97 | x = self.proj(x)
98 | x = self.proj_drop(x)
99 | return x
100 |
101 |
102 | class Block(nn.Module):
103 |
104 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
105 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
106 | attn_head_dim=None):
107 | super().__init__()
108 | self.norm1 = norm_layer(dim)
109 | self.attn = Attention(
110 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
111 | attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
112 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
113 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
114 | self.norm2 = norm_layer(dim)
115 | mlp_hidden_dim = int(dim * mlp_ratio)
116 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
117 |
118 | if init_values is not None and init_values > 0:
119 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
120 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
121 | else:
122 | self.gamma_1, self.gamma_2 = None, None
123 |
124 | def forward(self, x):
125 | if self.gamma_1 is None:
126 | x = x + self.drop_path(self.attn(self.norm1(x)))
127 | x = x + self.drop_path(self.mlp(self.norm2(x)))
128 | else:
129 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
130 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
131 | return x
132 |
133 |
134 | class PatchEmbed(nn.Module):
135 | """ Image to Patch Embedding
136 | """
137 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
138 | super().__init__()
139 | img_size = to_2tuple(img_size)
140 | patch_size = to_2tuple(patch_size)
141 | self.tubelet_size = int(tubelet_size)
142 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
143 | self.img_size = img_size
144 | self.patch_size = patch_size
145 | self.num_patches = num_patches
146 | self.proj = nn.Conv3d(in_channels=in_chans, out_channels=embed_dim,
147 | kernel_size = (self.tubelet_size, patch_size[0],patch_size[1]),
148 | stride=(self.tubelet_size, patch_size[0], patch_size[1]))
149 |
150 | def forward(self, x, **kwargs):
151 | B, C, T, H, W = x.shape
152 | # FIXME look at relaxing size constraints
153 | assert H == self.img_size[0] and W == self.img_size[1], \
154 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
155 | x = self.proj(x).flatten(2).transpose(1, 2)
156 | return x
157 |
158 | # sin-cos position encoding
159 | # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
160 | def get_sinusoid_encoding_table(n_position, d_hid):
161 | ''' Sinusoid position encoding table '''
162 | # TODO: make it with torch instead of numpy
163 | def get_position_angle_vec(position):
164 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
165 |
166 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
167 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
168 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
169 |
170 | return torch.tensor(sinusoid_table,dtype=torch.float, requires_grad=False).unsqueeze(0)
171 |
172 |
173 | class VisionTransformer(nn.Module):
174 | """ Vision Transformer with support for patch or hybrid CNN input stage
175 | """
176 | def __init__(self,
177 | img_size=224,
178 | patch_size=16,
179 | in_chans=3,
180 | num_classes=1000,
181 | embed_dim=768,
182 | depth=12,
183 | num_heads=12,
184 | mlp_ratio=4.,
185 | qkv_bias=False,
186 | qk_scale=None,
187 | fc_drop_rate=0.,
188 | drop_rate=0.,
189 | attn_drop_rate=0.,
190 | drop_path_rate=0.,
191 | norm_layer=nn.LayerNorm,
192 | init_values=0.,
193 | use_learnable_pos_emb=False,
194 | init_scale=0.,
195 | all_frames=16,
196 | tubelet_size=2,
197 | use_checkpoint=False,
198 | use_mean_pooling=True,
199 | mcm=False,
200 | mcm_ratio=0.4):
201 | super().__init__()
202 | self.num_classes = num_classes
203 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
204 | self.tubelet_size = tubelet_size
205 | self.patch_embed = PatchEmbed(
206 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, num_frames=all_frames, tubelet_size=self.tubelet_size)
207 | num_patches = self.patch_embed.num_patches
208 | self.use_checkpoint = use_checkpoint
209 |
210 | if use_learnable_pos_emb:
211 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
212 | else:
213 | # fixed sine-cosine positional embeddings
214 | self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
215 |
216 | self.pos_drop = nn.Dropout(p=drop_rate)
217 |
218 |
219 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
220 | self.blocks = nn.ModuleList([
221 | Block(
222 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
223 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
224 | init_values=init_values)
225 | for i in range(depth)])
226 | self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
227 | self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
228 | self.fc_dropout = nn.Dropout(p=fc_drop_rate) if fc_drop_rate > 0 else nn.Identity()
229 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
230 |
231 | if use_learnable_pos_emb:
232 | trunc_normal_(self.pos_embed, std=.02)
233 |
234 | trunc_normal_(self.head.weight, std=.02)
235 | self.apply(self._init_weights)
236 |
237 | self.head.weight.data.mul_(init_scale)
238 | self.head.bias.data.mul_(init_scale)
239 |
240 | self.mcm = mcm
241 | self.mcm_ratio = mcm_ratio
242 |
243 | def _init_weights(self, m):
244 | if isinstance(m, nn.Linear):
245 | trunc_normal_(m.weight, std=.02)
246 | if isinstance(m, nn.Linear) and m.bias is not None:
247 | nn.init.constant_(m.bias, 0)
248 | elif isinstance(m, nn.LayerNorm):
249 | nn.init.constant_(m.bias, 0)
250 | nn.init.constant_(m.weight, 1.0)
251 |
252 | def get_num_layers(self):
253 | return len(self.blocks)
254 |
255 | @torch.jit.ignore
256 | def no_weight_decay(self):
257 | return {'pos_embed', 'cls_token'}
258 |
259 | def get_classifier(self):
260 | return self.head
261 |
262 | def reset_classifier(self, num_classes, global_pool=''):
263 | self.num_classes = num_classes
264 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
265 |
266 | def forward_features(self, x):
267 | N, _, T, _, _ = x.shape
268 | x = self.patch_embed(x)
269 | if self.mcm:
270 | mask = self.mcm_step(x, N, x.shape[1] // T * 2, T//2, x.shape[2])
271 | B, _, C = x.size()
272 |
273 | if self.pos_embed is not None:
274 | x = x + self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
275 | if self.mcm:
276 | x = x[~mask].reshape(B, -1, C)
277 | x = self.pos_drop(x)
278 |
279 | if self.use_checkpoint:
280 | for blk in self.blocks:
281 | x = checkpoint.checkpoint(blk, x)
282 | else:
283 | for blk in self.blocks:
284 | x = blk(x)
285 |
286 | x = self.norm(x)
287 | if self.fc_norm is not None:
288 | return self.fc_norm(x.mean(1))
289 | else:
290 | return x[:, 0]
291 |
292 | def mcm_step(self, x, N, L, T, D):
293 | patch_embed_vectors = x.detach().clone().reshape(shape=(N, T, L, D))
294 |
295 | distance = torch.norm(patch_embed_vectors[:, :-1, :, :] - patch_embed_vectors[:, 1:, :, :], p=2, dim=3)  # change between adjacent temporal slices
296 | importance = torch.cat((distance[:,0,:], distance.flatten(1)), dim=1)
297 |
298 | ids_sorted = torch.argsort(importance, dim=1, descending=True)
299 | num_compressed_tokens = int((1 - self.mcm_ratio) * (T * L))
300 |
301 | ids_restore = torch.argsort(ids_sorted, dim=1)
302 | # keep the first subset
303 | ids_keep = ids_sorted[:, :num_compressed_tokens]
304 |
305 | input_mask = torch.ones([N, T * L], device=x.device)
306 | input_mask[:, :num_compressed_tokens] = 0
307 |
308 | # unshuffle to get the binary mask
309 | input_mask = torch.gather(input_mask, dim=1, index=ids_restore)
310 |
311 | return input_mask.to(torch.bool)
312 |
313 | def forward(self, x):
314 | x = self.forward_features(x)
315 | x = self.head(self.fc_dropout(x))
316 | return x
317 |
318 |
319 | @register_model
320 | def vit_small_patch16_224(pretrained=False, **kwargs):
321 | model = VisionTransformer(
322 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
323 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
324 | model.default_cfg = _cfg()
325 | return model
326 |
327 |
328 | @register_model
329 | def vit_base_patch16_224(pretrained=False, **kwargs):
330 | model = VisionTransformer(
331 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
332 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
333 | model.default_cfg = _cfg()
334 | return model
335 |
336 |
337 | @register_model
338 | def vit_base_patch16_384(pretrained=False, **kwargs):
339 | model = VisionTransformer(
340 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
341 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
342 | model.default_cfg = _cfg()
343 | return model
344 |
345 |
346 | @register_model
347 | def vit_large_patch16_224(pretrained=False, **kwargs):
348 | model = VisionTransformer(
349 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
350 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
351 | model.default_cfg = _cfg()
352 | return model
353 |
354 |
355 | @register_model
356 | def vit_large_patch16_384(pretrained=False, **kwargs):
357 | model = VisionTransformer(
358 | img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
359 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
360 | model.default_cfg = _cfg()
361 | return model
362 |
363 |
364 | @register_model
365 | def vit_large_patch16_512(pretrained=False, **kwargs):
366 | model = VisionTransformer(
367 | img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
368 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
369 | model.default_cfg = _cfg()
370 | return model
371 |
372 |
373 | @register_model
374 | def vit_huge_patch16_224(pretrained=False, **kwargs):
375 | model = VisionTransformer(
376 | patch_size=16, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
377 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
378 | model.default_cfg = _cfg()
379 | return model
380 |
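A hedged usage sketch (not part of the original file) that builds the ViT-B/16 fine-tuning backbone registered above and runs a dummy 16-frame clip through it; the argument values are illustrative, not the training defaults.

import torch
from modeling_finetune import vit_base_patch16_224

model = vit_base_patch16_224(num_classes=101, all_frames=16, tubelet_size=2,
                             drop_path_rate=0.1, use_mean_pooling=True,
                             init_values=0., init_scale=0.001)
clip = torch.randn(1, 3, 16, 224, 224)   # (B, C, T, H, W)
logits = model(clip)                     # -> torch.Size([1, 101])
print(logits.shape)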
--------------------------------------------------------------------------------
/modeling_pretrain.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.checkpoint as checkpoint
6 | from functools import partial
7 |
8 | from modeling_finetune import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table
9 | from timm.models.registry import register_model
10 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_
11 |
12 |
13 |
14 | def trunc_normal_(tensor, mean=0., std=1.):
15 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
16 |
17 |
18 | __all__ = [
19 | 'pretrain_videoms_small_patch16_224',
20 | 'pretrain_videoms_base_patch16_224',
21 | 'pretrain_videoms_large_patch16_224',
22 | 'pretrain_videoms_huge_patch16_224',
23 | ]
24 |
25 |
26 | class PretrainVisionTransformerEncoder(nn.Module):
27 | """ Vision Transformer with support for patch or hybrid CNN input stage
28 | """
29 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
30 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
31 | drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, tubelet_size=2, use_checkpoint=False,
32 | use_learnable_pos_emb=False, mcm=False, mcm_ratio=0.7, masking_ratio=0.9):
33 | super().__init__()
34 | self.num_classes = num_classes
35 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
36 | self.patch_embed = PatchEmbed(
37 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,tubelet_size=tubelet_size)
38 | num_patches = self.patch_embed.num_patches
39 | self.use_checkpoint = use_checkpoint
40 |
41 |
42 | # TODO: Add the cls token
43 | if use_learnable_pos_emb:
44 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
45 | else:
46 | # sine-cosine positional embeddings
47 | self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
48 |
49 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
50 | self.blocks = nn.ModuleList([
51 | Block(
52 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
53 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
54 | init_values=init_values)
55 | for i in range(depth)])
56 | self.norm = norm_layer(embed_dim)
57 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
58 |
59 | if use_learnable_pos_emb:
60 | trunc_normal_(self.pos_embed, std=.02)
61 |
62 | self.apply(self._init_weights)
63 |
64 | self.mcm = mcm
65 | self.mcm_ratio = mcm_ratio
66 | self.masking_ratio = masking_ratio
67 |
68 |
69 | def _init_weights(self, m):
70 | if isinstance(m, nn.Linear):
71 | nn.init.xavier_uniform_(m.weight)
72 | if isinstance(m, nn.Linear) and m.bias is not None:
73 | nn.init.constant_(m.bias, 0)
74 | elif isinstance(m, nn.LayerNorm):
75 | nn.init.constant_(m.bias, 0)
76 | nn.init.constant_(m.weight, 1.0)
77 |
78 | def get_num_layers(self):
79 | return len(self.blocks)
80 |
81 | @torch.jit.ignore
82 | def no_weight_decay(self):
83 | return {'pos_embed', 'cls_token'}
84 |
85 | def get_classifier(self):
86 | return self.head
87 |
88 | def reset_classifier(self, num_classes, global_pool=''):
89 | self.num_classes = num_classes
90 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
91 |
92 | def forward_features(self, x, mask):
93 | N, _, T, _, _ = x.shape
94 | x = self.patch_embed(x)
95 | masks = (None, None)
96 | if self.mcm:
97 | mask, target_mask = self.mcm_step(x, N, x.shape[1] // T * 2, T // 2, x.shape[2])
98 | masks = (mask, target_mask)
99 |
100 | x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()
101 |
102 | B, _, C = x.shape
103 | x_vis = x[~mask].reshape(B, -1, C) # ~mask means visible
104 |
105 | if self.use_checkpoint:
106 | for blk in self.blocks:
107 | x_vis = checkpoint.checkpoint(blk, x_vis)
108 | else:
109 | for blk in self.blocks:
110 | x_vis = blk(x_vis)
111 |
112 | x_vis = self.norm(x_vis)
113 | return x_vis, masks
114 |
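# (Illustrative note, not part of the original file.) Token budget produced by
# mcm_step() below under the default ratios: a 16-frame, 224x224 clip with patch
# size 16 and tubelet size 2 yields T * L = 8 * 196 = 1568 tokens. With
# mcm_ratio=0.7 the int(0.3 * 1568) = 470 tokens whose embeddings change most
# across time form the compressed set; with masking_ratio=0.9 only
# int(0.1 * 1568) = 156 of them stay visible to the encoder (input_mask == 0),
# the next 470 - 156 = 314 become reconstruction targets (target_mask == 0),
# and the remaining static tokens are excluded from both input and loss.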
115 | def mcm_step(self, x, N, L, T, D):
116 | patch_embed_vectors = x.detach().clone().reshape(shape=(N, T, L, D))
117 |
118 | distance = torch.norm(patch_embed_vectors[:, :-1, :, :] - patch_embed_vectors[:, 1:, :, :], p=2, dim=3)  # change between adjacent temporal slices
119 | importance = torch.cat((distance[:,0,:], distance.flatten(1)), dim=1)
120 |
121 | ids_sorted = torch.argsort(importance, dim=1, descending=True)
122 | num_compressed_tokens = int((1 - self.mcm_ratio) * (T * L))
123 | num_input_tokens = int((1 - self.masking_ratio) * (T * L))
124 | noise = torch.rand(N, num_compressed_tokens, device=x.device) # noise in [0, 1]
125 | noise_id_shuffled = torch.argsort(noise, dim=1)
126 |
127 | ids_sorted[:,:num_compressed_tokens] = torch.gather(ids_sorted[:,:num_compressed_tokens], dim=1, index=noise_id_shuffled)
128 |
129 | ids_restore = torch.argsort(ids_sorted, dim=1)
130 |
131 | ids_keep = ids_sorted[:, :num_input_tokens]
132 | ids_recon = ids_sorted[:, num_input_tokens:num_compressed_tokens]
133 |
134 | input_mask = torch.ones([N, T * L], device=x.device)
135 | input_mask[:, :num_input_tokens] = 0
136 |
137 | target_mask = torch.ones([N, T * L], device=x.device)
138 | target_mask[:, num_input_tokens:num_compressed_tokens] = 0
139 |
140 |
141 | input_mask = torch.gather(input_mask, dim=1, index=ids_restore)
142 | target_mask = torch.gather(target_mask, dim=1, index=ids_restore)
143 |
144 | return input_mask.to(torch.bool), target_mask.to(torch.bool)
145 |
146 | def forward(self, x, mask):
147 | x, masks = self.forward_features(x, mask)
148 | x = self.head(x)
149 | return x, masks
150 |
151 | class PretrainVisionTransformerDecoder(nn.Module):
152 | """ Vision Transformer with support for patch or hybrid CNN input stage
153 | """
154 | def __init__(self, patch_size=16, num_classes=768, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
155 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
156 | norm_layer=nn.LayerNorm, init_values=None, num_patches=196, tubelet_size=2, use_checkpoint=False
157 | ):
158 | super().__init__()
159 | self.num_classes = num_classes
160 | assert num_classes == 3 * tubelet_size * patch_size ** 2
161 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
162 | self.patch_size = patch_size
163 | self.use_checkpoint = use_checkpoint
164 |
165 |
166 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
167 | self.blocks = nn.ModuleList([
168 | Block(
169 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
170 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
171 | init_values=init_values)
172 | for i in range(depth)])
173 | self.norm = norm_layer(embed_dim)
174 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
175 |
176 | self.apply(self._init_weights)
177 |
178 |
179 | def _init_weights(self, m):
180 | if isinstance(m, nn.Linear):
181 | nn.init.xavier_uniform_(m.weight)
182 | if isinstance(m, nn.Linear) and m.bias is not None:
183 | nn.init.constant_(m.bias, 0)
184 | elif isinstance(m, nn.LayerNorm):
185 | nn.init.constant_(m.bias, 0)
186 | nn.init.constant_(m.weight, 1.0)
187 |
188 | def get_num_layers(self):
189 | return len(self.blocks)
190 |
191 | @torch.jit.ignore
192 | def no_weight_decay(self):
193 | return {'pos_embed', 'cls_token'}
194 |
195 | def get_classifier(self):
196 | return self.head
197 |
198 | def reset_classifier(self, num_classes, global_pool=''):
199 | self.num_classes = num_classes
200 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
201 |
202 | def forward(self, x, return_token_num):
203 | if self.use_checkpoint:
204 | for blk in self.blocks:
205 | x = checkpoint.checkpoint(blk, x)
206 | else:
207 | for blk in self.blocks:
208 | x = blk(x)
209 |
210 | if return_token_num > 0:
211 | x = self.head(self.norm(x[:, -return_token_num:])) # only return the predicted pixels for the masked tokens
212 | else:
213 | x = self.head(self.norm(x))
214 |
215 | return x
216 |
217 | class PretrainVisionTransformer(nn.Module):
218 | """ Vision Transformer with support for patch or hybrid CNN input stage
219 | """
220 | def __init__(self,
221 | img_size=224,
222 | patch_size=16,
223 | encoder_in_chans=3,
224 | encoder_num_classes=0,
225 | encoder_embed_dim=768,
226 | encoder_depth=12,
227 | encoder_num_heads=12,
228 | decoder_num_classes=1536, # decoder_num_classes=768,
229 | decoder_embed_dim=512,
230 | decoder_depth=8,
231 | decoder_num_heads=8,
232 | mlp_ratio=4.,
233 | qkv_bias=False,
234 | qk_scale=None,
235 | drop_rate=0.,
236 | attn_drop_rate=0.,
237 | drop_path_rate=0.,
238 | norm_layer=nn.LayerNorm,
239 | init_values=0.,
240 | use_learnable_pos_emb=False,
241 | use_checkpoint=False,
242 | tubelet_size=2,
243 | num_classes=0, # avoid the error from create_fn in timm
244 | in_chans=0, # avoid the error from create_fn in timm
245 | motion_centric_masking=False,
246 | motion_centric_masking_ratio=0.7,
247 | masking_ratio=0.9
248 | ):
249 | super().__init__()
250 | self.encoder = PretrainVisionTransformerEncoder(
251 | img_size=img_size,
252 | patch_size=patch_size,
253 | in_chans=encoder_in_chans,
254 | num_classes=encoder_num_classes,
255 | embed_dim=encoder_embed_dim,
256 | depth=encoder_depth,
257 | num_heads=encoder_num_heads,
258 | mlp_ratio=mlp_ratio,
259 | qkv_bias=qkv_bias,
260 | qk_scale=qk_scale,
261 | drop_rate=drop_rate,
262 | attn_drop_rate=attn_drop_rate,
263 | drop_path_rate=drop_path_rate,
264 | norm_layer=norm_layer,
265 | init_values=init_values,
266 | tubelet_size=tubelet_size,
267 | use_checkpoint=use_checkpoint,
268 | use_learnable_pos_emb=use_learnable_pos_emb,
269 | mcm=motion_centric_masking,
270 | mcm_ratio=motion_centric_masking_ratio,
271 | masking_ratio=masking_ratio)
272 |
273 | self.decoder = PretrainVisionTransformerDecoder(
274 | patch_size=patch_size,
275 | num_patches=self.encoder.patch_embed.num_patches,
276 | num_classes=decoder_num_classes,
277 | embed_dim=decoder_embed_dim,
278 | depth=decoder_depth,
279 | num_heads=decoder_num_heads,
280 | mlp_ratio=mlp_ratio,
281 | qkv_bias=qkv_bias,
282 | qk_scale=qk_scale,
283 | drop_rate=drop_rate,
284 | attn_drop_rate=attn_drop_rate,
285 | drop_path_rate=drop_path_rate,
286 | norm_layer=norm_layer,
287 | init_values=init_values,
288 | tubelet_size=tubelet_size,
289 | use_checkpoint=use_checkpoint)
290 |
291 | self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=False)
292 |
293 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
294 |
295 | self.pos_embed = get_sinusoid_encoding_table(self.encoder.patch_embed.num_patches, decoder_embed_dim)
296 |
297 | trunc_normal_(self.mask_token, std=.02)
298 |
299 |
300 | def _init_weights(self, m):
301 | if isinstance(m, nn.Linear):
302 | nn.init.xavier_uniform_(m.weight)
303 | if isinstance(m, nn.Linear) and m.bias is not None:
304 | nn.init.constant_(m.bias, 0)
305 | elif isinstance(m, nn.LayerNorm):
306 | nn.init.constant_(m.bias, 0)
307 | nn.init.constant_(m.weight, 1.0)
308 |
309 | def get_num_layers(self):
310 |         return len(self.encoder.blocks)  # the transformer blocks live on the encoder
311 |
312 | @torch.jit.ignore
313 | def no_weight_decay(self):
314 | return {'pos_embed', 'cls_token', 'mask_token'}
315 |
316 | def forward(self, x, mask):
317 | _, _, T, _, _ = x.shape
318 | x_vis, masks = self.encoder(x, mask)
319 | x_vis = self.encoder_to_decoder(x_vis)
320 | B, N, C = x_vis.shape
321 | expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
322 | if masks[0] is None:
323 | pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
324 | pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C)
325 | x_full = torch.cat([x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1)
326 | x = self.decoder(x_full, pos_emd_mask.shape[1])
327 | else:
328 | pos_emd_vis = expand_pos_embed[~masks[0]].reshape(B, -1, C)
329 | pos_emd_mask = expand_pos_embed[~masks[1]].reshape(B, -1, C)
330 | x_full = torch.cat([x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1) # [B, N, C_d]
331 | x = self.decoder(x_full, pos_emd_mask.shape[1])
332 | return x, masks
333 |
334 | @register_model
335 | def pretrain_videoms_small_patch16_224(pretrained=False, **kwargs):
336 | model = PretrainVisionTransformer(
337 | img_size=224,
338 | patch_size=16,
339 | encoder_embed_dim=384,
340 | encoder_depth=12,
341 | encoder_num_heads=6,
342 | encoder_num_classes=0,
343 | decoder_num_classes=1536,
344 | decoder_embed_dim=192,
345 | decoder_num_heads=3,
346 | mlp_ratio=4,
347 | qkv_bias=True,
348 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
349 | **kwargs)
350 | model.default_cfg = _cfg()
351 | if pretrained:
352 | checkpoint = torch.load(
353 | kwargs["init_ckpt"], map_location="cpu"
354 | )
355 | model.load_state_dict(checkpoint["model"])
356 | return model
357 |
358 | @register_model
359 | def pretrain_videoms_base_patch16_224(pretrained=False, **kwargs):
360 | model = PretrainVisionTransformer(
361 | img_size=224,
362 | patch_size=16,
363 | encoder_embed_dim=768,
364 | encoder_depth=12,
365 | encoder_num_heads=12,
366 | encoder_num_classes=0,
367 | decoder_num_classes=1536,
368 | decoder_embed_dim=384,
369 | decoder_num_heads=6,
370 | mlp_ratio=4,
371 | qkv_bias=True,
372 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
373 | **kwargs)
374 | model.default_cfg = _cfg()
375 | if pretrained:
376 | checkpoint = torch.load(
377 | kwargs["init_ckpt"], map_location="cpu"
378 | )
379 | model.load_state_dict(checkpoint["model"])
380 | return model
381 |
382 | @register_model
383 | def pretrain_videoms_large_patch16_224(pretrained=False, **kwargs):
384 | model = PretrainVisionTransformer(
385 | img_size=224,
386 | patch_size=16,
387 | encoder_embed_dim=1024,
388 | encoder_depth=24,
389 | encoder_num_heads=16,
390 | encoder_num_classes=0,
391 | decoder_num_classes=1536,
392 | decoder_embed_dim=512,
393 | decoder_num_heads=8,
394 | mlp_ratio=4,
395 | qkv_bias=True,
396 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
397 | **kwargs)
398 | model.default_cfg = _cfg()
399 | if pretrained:
400 | checkpoint = torch.load(
401 | kwargs["init_ckpt"], map_location="cpu"
402 | )
403 | model.load_state_dict(checkpoint["model"])
404 | return model
405 |
406 | @register_model
407 | def pretrain_videoms_huge_patch16_224(pretrained=False, **kwargs):
408 | model = PretrainVisionTransformer(
409 | img_size=224,
410 | patch_size=16,
411 | encoder_embed_dim=1280,
412 | encoder_depth=32,
413 | encoder_num_heads=16,
414 | encoder_num_classes=0,
415 | decoder_num_classes=1536,
416 | decoder_embed_dim=640,
417 | decoder_num_heads=8,
418 | mlp_ratio=4,
419 | qkv_bias=True,
420 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
421 | **kwargs)
422 | model.default_cfg = _cfg()
423 | if pretrained:
424 | checkpoint = torch.load(
425 | kwargs["init_ckpt"], map_location="cpu"
426 | )
427 | model.load_state_dict(checkpoint["model"])
428 | return model
429 |
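
if __name__ == "__main__":
    # Illustrative smoke test added for clarity; it is a sketch under assumptions,
    # not part of the original training pipeline. With the default 16-frame,
    # 224x224, tubelet_size=2 setting the encoder sees (224/16)^2 * (16/2) = 1568
    # tube tokens per clip, and the decoder predicts decoder_num_classes =
    # 16*16*2*3 = 1536 pixel values per masked tube. During real pretraining the
    # boolean mask comes from masking_generator.py (a fixed number of masked tokens
    # per sample); a random ~90% mask is used here only to show tensor shapes.
    model = pretrain_videoms_base_patch16_224(pretrained=False)
    videos = torch.randn(1, 3, 16, 224, 224)   # [B, C, T, H, W]
    mask = torch.rand(1, 1568) < 0.9           # True marks a masked tube token
    preds, masks = model(videos, mask)
    print(preds.shape)                         # expected: [1, num_masked_tokens, 1536]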
--------------------------------------------------------------------------------
/rand_augment.py:
--------------------------------------------------------------------------------
1 | """
2 | This implementation is based on
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py
4 | published under the Apache License 2.0.
5 |
6 | COMMENT FROM ORIGINAL:
7 | AutoAugment, RandAugment, and AugMix for PyTorch
8 | This code implements the searched ImageNet policies with various tweaks and
9 | improvements and does not include any of the search code. AA and RA
10 | Implementation adapted from:
11 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
12 | AugMix adapted from:
13 | https://github.com/google-research/augmix
14 | Papers:
15 | AutoAugment: Learning Augmentation Policies from Data
16 | https://arxiv.org/abs/1805.09501
17 | Learning Data Augmentation Strategies for Object Detection
18 | https://arxiv.org/abs/1906.11172
19 | RandAugment: Practical automated data augmentation...
20 | https://arxiv.org/abs/1909.13719
21 | AugMix: A Simple Data Processing Method to Improve Robustness and
22 | Uncertainty https://arxiv.org/abs/1912.02781
23 |
24 | Hacked together by / Copyright 2020 Ross Wightman
25 | """
26 |
27 | import math
28 | import numpy as np
29 | import random
30 | import re
31 | import PIL
32 | from PIL import Image, ImageEnhance, ImageOps
33 |
34 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]])
35 |
36 | _FILL = (128, 128, 128)
37 |
38 | # This signifies the max integer that the controller RNN could predict for the
39 | # augmentation scheme.
40 | _MAX_LEVEL = 10.0
41 |
42 | _HPARAMS_DEFAULT = {
43 | "translate_const": 250,
44 | "img_mean": _FILL,
45 | }
46 |
47 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
48 |
49 |
50 | def _interpolation(kwargs):
51 | interpolation = kwargs.pop("resample", Image.BILINEAR)
52 | if isinstance(interpolation, (list, tuple)):
53 | return random.choice(interpolation)
54 | else:
55 | return interpolation
56 |
57 |
58 | def _check_args_tf(kwargs):
59 | if "fillcolor" in kwargs and _PIL_VER < (5, 0):
60 | kwargs.pop("fillcolor")
61 | kwargs["resample"] = _interpolation(kwargs)
62 |
63 |
64 | def shear_x(img, factor, **kwargs):
65 | _check_args_tf(kwargs)
66 | return img.transform(
67 | img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs
68 | )
69 |
70 |
71 | def shear_y(img, factor, **kwargs):
72 | _check_args_tf(kwargs)
73 | return img.transform(
74 | img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs
75 | )
76 |
77 |
78 | def translate_x_rel(img, pct, **kwargs):
79 | pixels = pct * img.size[0]
80 | _check_args_tf(kwargs)
81 | return img.transform(
82 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
83 | )
84 |
85 |
86 | def translate_y_rel(img, pct, **kwargs):
87 | pixels = pct * img.size[1]
88 | _check_args_tf(kwargs)
89 | return img.transform(
90 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
91 | )
92 |
93 |
94 | def translate_x_abs(img, pixels, **kwargs):
95 | _check_args_tf(kwargs)
96 | return img.transform(
97 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
98 | )
99 |
100 |
101 | def translate_y_abs(img, pixels, **kwargs):
102 | _check_args_tf(kwargs)
103 | return img.transform(
104 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
105 | )
106 |
107 |
108 | def rotate(img, degrees, **kwargs):
109 | _check_args_tf(kwargs)
110 | if _PIL_VER >= (5, 2):
111 | return img.rotate(degrees, **kwargs)
112 | elif _PIL_VER >= (5, 0):
113 | w, h = img.size
114 | post_trans = (0, 0)
115 | rotn_center = (w / 2.0, h / 2.0)
116 | angle = -math.radians(degrees)
117 | matrix = [
118 | round(math.cos(angle), 15),
119 | round(math.sin(angle), 15),
120 | 0.0,
121 | round(-math.sin(angle), 15),
122 | round(math.cos(angle), 15),
123 | 0.0,
124 | ]
125 |
126 | def transform(x, y, matrix):
127 | (a, b, c, d, e, f) = matrix
128 | return a * x + b * y + c, d * x + e * y + f
129 |
130 | matrix[2], matrix[5] = transform(
131 | -rotn_center[0] - post_trans[0],
132 | -rotn_center[1] - post_trans[1],
133 | matrix,
134 | )
135 | matrix[2] += rotn_center[0]
136 | matrix[5] += rotn_center[1]
137 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
138 | else:
139 | return img.rotate(degrees, resample=kwargs["resample"])
140 |
141 |
142 | def auto_contrast(img, **__):
143 | return ImageOps.autocontrast(img)
144 |
145 |
146 | def invert(img, **__):
147 | return ImageOps.invert(img)
148 |
149 |
150 | def equalize(img, **__):
151 | return ImageOps.equalize(img)
152 |
153 |
154 | def solarize(img, thresh, **__):
155 | return ImageOps.solarize(img, thresh)
156 |
157 |
158 | def solarize_add(img, add, thresh=128, **__):
159 | lut = []
160 | for i in range(256):
161 | if i < thresh:
162 | lut.append(min(255, i + add))
163 | else:
164 | lut.append(i)
165 | if img.mode in ("L", "RGB"):
166 | if img.mode == "RGB" and len(lut) == 256:
167 | lut = lut + lut + lut
168 | return img.point(lut)
169 | else:
170 | return img
171 |
172 |
173 | def posterize(img, bits_to_keep, **__):
174 | if bits_to_keep >= 8:
175 | return img
176 | return ImageOps.posterize(img, bits_to_keep)
177 |
178 |
179 | def contrast(img, factor, **__):
180 | return ImageEnhance.Contrast(img).enhance(factor)
181 |
182 |
183 | def color(img, factor, **__):
184 | return ImageEnhance.Color(img).enhance(factor)
185 |
186 |
187 | def brightness(img, factor, **__):
188 | return ImageEnhance.Brightness(img).enhance(factor)
189 |
190 |
191 | def sharpness(img, factor, **__):
192 | return ImageEnhance.Sharpness(img).enhance(factor)
193 |
194 |
195 | def _randomly_negate(v):
196 | """With 50% prob, negate the value"""
197 | return -v if random.random() > 0.5 else v
198 |
199 |
200 | def _rotate_level_to_arg(level, _hparams):
201 | # range [-30, 30]
202 | level = (level / _MAX_LEVEL) * 30.0
203 | level = _randomly_negate(level)
204 | return (level,)
205 |
206 |
207 | def _enhance_level_to_arg(level, _hparams):
208 | # range [0.1, 1.9]
209 | return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
210 |
211 |
212 | def _enhance_increasing_level_to_arg(level, _hparams):
213 | # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
214 | # range [0.1, 1.9]
215 | level = (level / _MAX_LEVEL) * 0.9
216 | level = 1.0 + _randomly_negate(level)
217 | return (level,)
218 |
219 |
220 | def _shear_level_to_arg(level, _hparams):
221 | # range [-0.3, 0.3]
222 | level = (level / _MAX_LEVEL) * 0.3
223 | level = _randomly_negate(level)
224 | return (level,)
225 |
226 |
227 | def _translate_abs_level_to_arg(level, hparams):
228 | translate_const = hparams["translate_const"]
229 | level = (level / _MAX_LEVEL) * float(translate_const)
230 | level = _randomly_negate(level)
231 | return (level,)
232 |
233 |
234 | def _translate_rel_level_to_arg(level, hparams):
235 | # default range [-0.45, 0.45]
236 | translate_pct = hparams.get("translate_pct", 0.45)
237 | level = (level / _MAX_LEVEL) * translate_pct
238 | level = _randomly_negate(level)
239 | return (level,)
240 |
241 |
242 | def _posterize_level_to_arg(level, _hparams):
243 | # As per Tensorflow TPU EfficientNet impl
244 | # range [0, 4], 'keep 0 up to 4 MSB of original image'
245 | # intensity/severity of augmentation decreases with level
246 | return (int((level / _MAX_LEVEL) * 4),)
247 |
248 |
249 | def _posterize_increasing_level_to_arg(level, hparams):
250 | # As per Tensorflow models research and UDA impl
251 | # range [4, 0], 'keep 4 down to 0 MSB of original image',
252 | # intensity/severity of augmentation increases with level
253 | return (4 - _posterize_level_to_arg(level, hparams)[0],)
254 |
255 |
256 | def _posterize_original_level_to_arg(level, _hparams):
257 | # As per original AutoAugment paper description
258 | # range [4, 8], 'keep 4 up to 8 MSB of image'
259 | # intensity/severity of augmentation decreases with level
260 | return (int((level / _MAX_LEVEL) * 4) + 4,)
261 |
262 |
263 | def _solarize_level_to_arg(level, _hparams):
264 | # range [0, 256]
265 | # intensity/severity of augmentation decreases with level
266 | return (int((level / _MAX_LEVEL) * 256),)
267 |
268 |
269 | def _solarize_increasing_level_to_arg(level, _hparams):
270 | # range [0, 256]
271 | # intensity/severity of augmentation increases with level
272 | return (256 - _solarize_level_to_arg(level, _hparams)[0],)
273 |
274 |
275 | def _solarize_add_level_to_arg(level, _hparams):
276 | # range [0, 110]
277 | return (int((level / _MAX_LEVEL) * 110),)
278 |
279 |
280 | LEVEL_TO_ARG = {
281 | "AutoContrast": None,
282 | "Equalize": None,
283 | "Invert": None,
284 | "Rotate": _rotate_level_to_arg,
285 | # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
286 | "Posterize": _posterize_level_to_arg,
287 | "PosterizeIncreasing": _posterize_increasing_level_to_arg,
288 | "PosterizeOriginal": _posterize_original_level_to_arg,
289 | "Solarize": _solarize_level_to_arg,
290 | "SolarizeIncreasing": _solarize_increasing_level_to_arg,
291 | "SolarizeAdd": _solarize_add_level_to_arg,
292 | "Color": _enhance_level_to_arg,
293 | "ColorIncreasing": _enhance_increasing_level_to_arg,
294 | "Contrast": _enhance_level_to_arg,
295 | "ContrastIncreasing": _enhance_increasing_level_to_arg,
296 | "Brightness": _enhance_level_to_arg,
297 | "BrightnessIncreasing": _enhance_increasing_level_to_arg,
298 | "Sharpness": _enhance_level_to_arg,
299 | "SharpnessIncreasing": _enhance_increasing_level_to_arg,
300 | "ShearX": _shear_level_to_arg,
301 | "ShearY": _shear_level_to_arg,
302 | "TranslateX": _translate_abs_level_to_arg,
303 | "TranslateY": _translate_abs_level_to_arg,
304 | "TranslateXRel": _translate_rel_level_to_arg,
305 | "TranslateYRel": _translate_rel_level_to_arg,
306 | }
307 |
308 |
309 | NAME_TO_OP = {
310 | "AutoContrast": auto_contrast,
311 | "Equalize": equalize,
312 | "Invert": invert,
313 | "Rotate": rotate,
314 | "Posterize": posterize,
315 | "PosterizeIncreasing": posterize,
316 | "PosterizeOriginal": posterize,
317 | "Solarize": solarize,
318 | "SolarizeIncreasing": solarize,
319 | "SolarizeAdd": solarize_add,
320 | "Color": color,
321 | "ColorIncreasing": color,
322 | "Contrast": contrast,
323 | "ContrastIncreasing": contrast,
324 | "Brightness": brightness,
325 | "BrightnessIncreasing": brightness,
326 | "Sharpness": sharpness,
327 | "SharpnessIncreasing": sharpness,
328 | "ShearX": shear_x,
329 | "ShearY": shear_y,
330 | "TranslateX": translate_x_abs,
331 | "TranslateY": translate_y_abs,
332 | "TranslateXRel": translate_x_rel,
333 | "TranslateYRel": translate_y_rel,
334 | }
335 |
336 |
337 | class AugmentOp:
338 | """
339 |     Apply a single augmentation op to a PIL image or to every frame of a video clip (a list of PIL images).
340 | """
341 |
342 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
343 | hparams = hparams or _HPARAMS_DEFAULT
344 | self.aug_fn = NAME_TO_OP[name]
345 | self.level_fn = LEVEL_TO_ARG[name]
346 | self.prob = prob
347 | self.magnitude = magnitude
348 | self.hparams = hparams.copy()
349 | self.kwargs = {
350 | "fillcolor": hparams["img_mean"]
351 | if "img_mean" in hparams
352 | else _FILL,
353 | "resample": hparams["interpolation"]
354 | if "interpolation" in hparams
355 | else _RANDOM_INTERPOLATION,
356 | }
357 |
358 | # If magnitude_std is > 0, we introduce some randomness
359 | # in the usually fixed policy and sample magnitude from a normal distribution
360 | # with mean `magnitude` and std-dev of `magnitude_std`.
361 | # NOTE This is my own hack, being tested, not in papers or reference impls.
362 | self.magnitude_std = self.hparams.get("magnitude_std", 0)
363 |
364 | def __call__(self, img_list):
365 | if self.prob < 1.0 and random.random() > self.prob:
366 | return img_list
367 | magnitude = self.magnitude
368 | if self.magnitude_std and self.magnitude_std > 0:
369 | magnitude = random.gauss(magnitude, self.magnitude_std)
370 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
371 | level_args = (
372 | self.level_fn(magnitude, self.hparams)
373 | if self.level_fn is not None
374 | else ()
375 | )
376 |
377 | if isinstance(img_list, list):
378 | return [
379 | self.aug_fn(img, *level_args, **self.kwargs) for img in img_list
380 | ]
381 | else:
382 | return self.aug_fn(img_list, *level_args, **self.kwargs)
383 |
384 |
385 | _RAND_TRANSFORMS = [
386 | "AutoContrast",
387 | "Equalize",
388 | "Invert",
389 | "Rotate",
390 | "Posterize",
391 | "Solarize",
392 | "SolarizeAdd",
393 | "Color",
394 | "Contrast",
395 | "Brightness",
396 | "Sharpness",
397 | "ShearX",
398 | "ShearY",
399 | "TranslateXRel",
400 | "TranslateYRel",
401 | ]
402 |
403 |
404 | _RAND_INCREASING_TRANSFORMS = [
405 | "AutoContrast",
406 | "Equalize",
407 | "Invert",
408 | "Rotate",
409 | "PosterizeIncreasing",
410 | "SolarizeIncreasing",
411 | "SolarizeAdd",
412 | "ColorIncreasing",
413 | "ContrastIncreasing",
414 | "BrightnessIncreasing",
415 | "SharpnessIncreasing",
416 | "ShearX",
417 | "ShearY",
418 | "TranslateXRel",
419 | "TranslateYRel",
420 | ]
421 |
422 |
423 | # These experimental weights are based loosely on the relative improvements mentioned in the paper.
424 | # They may not result in increased performance, but could likely be tuned to do so.
425 | _RAND_CHOICE_WEIGHTS_0 = {
426 | "Rotate": 0.3,
427 | "ShearX": 0.2,
428 | "ShearY": 0.2,
429 | "TranslateXRel": 0.1,
430 | "TranslateYRel": 0.1,
431 | "Color": 0.025,
432 | "Sharpness": 0.025,
433 | "AutoContrast": 0.025,
434 | "Solarize": 0.005,
435 | "SolarizeAdd": 0.005,
436 | "Contrast": 0.005,
437 | "Brightness": 0.005,
438 | "Equalize": 0.005,
439 | "Posterize": 0,
440 | "Invert": 0,
441 | }
442 |
443 |
444 | def _select_rand_weights(weight_idx=0, transforms=None):
445 | transforms = transforms or _RAND_TRANSFORMS
446 | assert weight_idx == 0 # only one set of weights currently
447 | rand_weights = _RAND_CHOICE_WEIGHTS_0
448 | probs = [rand_weights[k] for k in transforms]
449 | probs /= np.sum(probs)
450 | return probs
451 |
452 |
453 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
454 | hparams = hparams or _HPARAMS_DEFAULT
455 | transforms = transforms or _RAND_TRANSFORMS
456 | return [
457 | AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams)
458 | for name in transforms
459 | ]
460 |
461 |
462 | class RandAugment:
463 | def __init__(self, ops, num_layers=2, choice_weights=None):
464 | self.ops = ops
465 | self.num_layers = num_layers
466 | self.choice_weights = choice_weights
467 |
468 | def __call__(self, img):
469 | # no replacement when using weighted choice
470 | ops = np.random.choice(
471 | self.ops,
472 | self.num_layers,
473 | replace=self.choice_weights is None,
474 | p=self.choice_weights,
475 | )
476 | for op in ops:
477 | img = op(img)
478 | return img
479 |
480 |
481 | def rand_augment_transform(config_str, hparams):
482 | """
483 | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
484 |
485 | Create a RandAugment transform
486 | :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
487 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
488 | sections, not order specific, determine
489 | 'm' - integer magnitude of rand augment
490 | 'n' - integer num layers (number of transform ops selected per image)
491 |         'w' - integer probability weight index (index of a set of weights to influence choice of op)
492 | 'mstd' - float std deviation of magnitude noise applied
493 | 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
494 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
495 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
496 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
497 | :return: A PyTorch compatible Transform
498 | """
499 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
500 | num_layers = 2 # default to 2 ops per image
501 | weight_idx = None # default to no probability weights for op choice
502 | transforms = _RAND_TRANSFORMS
503 | config = config_str.split("-")
504 | assert config[0] == "rand"
505 | config = config[1:]
506 | for c in config:
507 | cs = re.split(r"(\d.*)", c)
508 | if len(cs) < 2:
509 | continue
510 | key, val = cs[:2]
511 | if key == "mstd":
512 | # noise param injected via hparams for now
513 | hparams.setdefault("magnitude_std", float(val))
514 | elif key == "inc":
515 | if bool(val):
516 | transforms = _RAND_INCREASING_TRANSFORMS
517 | elif key == "m":
518 | magnitude = int(val)
519 | elif key == "n":
520 | num_layers = int(val)
521 | elif key == "w":
522 | weight_idx = int(val)
523 | else:
524 |             raise NotImplementedError(f"Unknown RandAugment config section: {key}")
525 | ra_ops = rand_augment_ops(
526 | magnitude=magnitude, hparams=hparams, transforms=transforms
527 | )
528 | choice_weights = (
529 | None if weight_idx is None else _select_rand_weights(weight_idx)
530 | )
531 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
532 |
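
if __name__ == "__main__":
    # Minimal usage sketch added for clarity (the frame size, clip length, and
    # config string are assumptions, not values mandated by this repository).
    # The config string follows the format documented in rand_augment_transform:
    # magnitude 7, 4 ops per clip, magnitude noise 0.5, increasing-severity set.
    frames = [Image.new("RGB", (320, 240), color=(128, 128, 128)) for _ in range(16)]
    augment = rand_augment_transform("rand-m7-n4-mstd0.5-inc1", {"translate_const": 100})
    augmented = augment(frames)  # the same sampled ops are applied to every frame
    print(len(augmented), augmented[0].size)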
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import math
4 | import time
5 | import json
6 | from collections import defaultdict, deque
7 | import datetime
8 | import numpy as np
9 | from timm.utils import get_state_dict
10 | from torch.utils.data._utils.collate import default_collate
11 | from pathlib import Path
12 | import subprocess
13 | import torch
14 | import torch.distributed as dist
15 | from math import inf  # math.inf replaces torch._six.inf, which newer PyTorch no longer provides
16 | import random
17 |
18 | from tensorboardX import SummaryWriter
19 |
20 |
21 | class SmoothedValue(object):
22 | """Track a series of values and provide access to smoothed values over a
23 | window or the global series average.
24 | """
25 |
26 | def __init__(self, window_size=20, fmt=None):
27 | if fmt is None:
28 | fmt = "{median:.4f} ({global_avg:.4f})"
29 | self.deque = deque(maxlen=window_size)
30 | self.total = 0.0
31 | self.count = 0
32 | self.fmt = fmt
33 |
34 | def update(self, value, n=1):
35 | self.deque.append(value)
36 | self.count += n
37 | self.total += value * n
38 |
39 | def synchronize_between_processes(self):
40 | """
41 | Warning: does not synchronize the deque!
42 | """
43 | if not is_dist_avail_and_initialized():
44 | return
45 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
46 | dist.barrier()
47 | dist.all_reduce(t)
48 | t = t.tolist()
49 | self.count = int(t[0])
50 | self.total = t[1]
51 |
52 | @property
53 | def median(self):
54 | d = torch.tensor(list(self.deque))
55 | return d.median().item()
56 |
57 | @property
58 | def avg(self):
59 | d = torch.tensor(list(self.deque), dtype=torch.float32)
60 | return d.mean().item()
61 |
62 | @property
63 | def global_avg(self):
64 | return self.total / self.count
65 |
66 | @property
67 | def max(self):
68 | return max(self.deque)
69 |
70 | @property
71 | def value(self):
72 | return self.deque[-1]
73 |
74 | def __str__(self):
75 | return self.fmt.format(
76 | median=self.median,
77 | avg=self.avg,
78 | global_avg=self.global_avg,
79 | max=self.max,
80 | value=self.value)
81 |
82 |
83 | class MetricLogger(object):
84 | def __init__(self, delimiter="\t"):
85 | self.meters = defaultdict(SmoothedValue)
86 | self.delimiter = delimiter
87 |
88 | def update(self, **kwargs):
89 | for k, v in kwargs.items():
90 | if v is None:
91 | continue
92 | if isinstance(v, torch.Tensor):
93 | v = v.item()
94 | assert isinstance(v, (float, int))
95 | self.meters[k].update(v)
96 |
97 | def __getattr__(self, attr):
98 | if attr in self.meters:
99 | return self.meters[attr]
100 | if attr in self.__dict__:
101 | return self.__dict__[attr]
102 | raise AttributeError("'{}' object has no attribute '{}'".format(
103 | type(self).__name__, attr))
104 |
105 | def __str__(self):
106 | loss_str = []
107 | for name, meter in self.meters.items():
108 | loss_str.append(
109 | "{}: {}".format(name, str(meter))
110 | )
111 | return self.delimiter.join(loss_str)
112 |
113 | def synchronize_between_processes(self):
114 | for meter in self.meters.values():
115 | meter.synchronize_between_processes()
116 |
117 | def add_meter(self, name, meter):
118 | self.meters[name] = meter
119 |
120 | def log_every(self, iterable, print_freq, header=None):
121 | i = 0
122 | if not header:
123 | header = ''
124 | start_time = time.time()
125 | end = time.time()
126 | iter_time = SmoothedValue(fmt='{avg:.4f}')
127 | data_time = SmoothedValue(fmt='{avg:.4f}')
128 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
129 | log_msg = [
130 | header,
131 | '[{0' + space_fmt + '}/{1}]',
132 | 'eta: {eta}',
133 | '{meters}',
134 | 'time: {time}',
135 | 'data: {data}'
136 | ]
137 | if torch.cuda.is_available():
138 | log_msg.append('max mem: {memory:.0f}')
139 | log_msg = self.delimiter.join(log_msg)
140 | MB = 1024.0 * 1024.0
141 | for obj in iterable:
142 | data_time.update(time.time() - end)
143 | yield obj
144 | iter_time.update(time.time() - end)
145 | if i % print_freq == 0 or i == len(iterable) - 1:
146 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
147 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
148 | if torch.cuda.is_available():
149 | print(log_msg.format(
150 | i, len(iterable), eta=eta_string,
151 | meters=str(self),
152 | time=str(iter_time), data=str(data_time),
153 | memory=torch.cuda.max_memory_allocated() / MB))
154 | else:
155 | print(log_msg.format(
156 | i, len(iterable), eta=eta_string,
157 | meters=str(self),
158 | time=str(iter_time), data=str(data_time)))
159 | i += 1
160 | end = time.time()
161 | total_time = time.time() - start_time
162 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
163 | print('{} Total time: {} ({:.4f} s / it)'.format(
164 | header, total_time_str, total_time / len(iterable)))
165 |
166 |
167 | class TensorboardLogger(object):
168 | def __init__(self, log_dir):
169 | self.writer = SummaryWriter(logdir=log_dir)
170 | self.step = 0
171 |
172 | def set_step(self, step=None):
173 | if step is not None:
174 | self.step = step
175 | else:
176 | self.step += 1
177 |
178 | def update(self, head='scalar', step=None, **kwargs):
179 | for k, v in kwargs.items():
180 | if v is None:
181 | continue
182 | if isinstance(v, torch.Tensor):
183 | v = v.item()
184 | assert isinstance(v, (float, int))
185 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)
186 |
187 | def flush(self):
188 | self.writer.flush()
189 |
190 | def seed_worker(worker_id):
191 | worker_seed = torch.initial_seed() % 2**32
192 | np.random.seed(worker_seed)
193 | random.seed(worker_seed)
194 |
195 | def _load_checkpoint_for_ema(model_ema, checkpoint):
196 | """
197 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object
198 | """
199 | mem_file = io.BytesIO()
200 | torch.save(checkpoint, mem_file)
201 | mem_file.seek(0)
202 | model_ema._load_checkpoint(mem_file)
203 |
204 |
205 | def setup_for_distributed(is_master):
206 | """
207 |     This function disables printing when not in the master process
208 | """
209 | import builtins as __builtin__
210 | builtin_print = __builtin__.print
211 |
212 | def print(*args, **kwargs):
213 | force = kwargs.pop('force', False)
214 | if is_master or force:
215 | builtin_print(*args, **kwargs)
216 |
217 | __builtin__.print = print
218 |
219 |
220 | def is_dist_avail_and_initialized():
221 | if not dist.is_available():
222 | return False
223 | if not dist.is_initialized():
224 | return False
225 | return True
226 |
227 |
228 | def get_world_size():
229 | if not is_dist_avail_and_initialized():
230 | return 1
231 | return dist.get_world_size()
232 |
233 |
234 | def get_rank():
235 | if not is_dist_avail_and_initialized():
236 | return 0
237 | return dist.get_rank()
238 |
239 |
240 | def is_main_process():
241 | return get_rank() == 0
242 |
243 |
244 | def save_on_master(*args, **kwargs):
245 | if is_main_process():
246 | torch.save(*args, **kwargs)
247 |
248 |
249 | def init_distributed_mode(args):
250 | if args.dist_on_itp:
251 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
252 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
253 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
254 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
255 | os.environ['LOCAL_RANK'] = str(args.gpu)
256 | os.environ['RANK'] = str(args.rank)
257 | os.environ['WORLD_SIZE'] = str(args.world_size)
258 | elif 'SLURM_PROCID' in os.environ:
259 | args.rank = int(os.environ['SLURM_PROCID'])
260 | args.gpu = int(os.environ['SLURM_LOCALID'])
261 | args.world_size = int(os.environ['SLURM_NTASKS'])
262 | os.environ['RANK'] = str(args.rank)
263 | os.environ['LOCAL_RANK'] = str(args.gpu)
264 | os.environ['WORLD_SIZE'] = str(args.world_size)
265 |
266 | node_list = os.environ['SLURM_NODELIST']
267 | addr = subprocess.getoutput(
268 | f'scontrol show hostname {node_list} | head -n1')
269 | if 'MASTER_ADDR' not in os.environ:
270 | os.environ['MASTER_ADDR'] = addr
271 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
272 | args.rank = int(os.environ["RANK"])
273 | args.world_size = int(os.environ['WORLD_SIZE'])
274 | args.gpu = int(os.environ['LOCAL_RANK'])
275 | else:
276 | print('Not using distributed mode')
277 | args.distributed = False
278 | return
279 |
280 | args.distributed = True
281 |
282 | torch.cuda.set_device(args.gpu)
283 | args.dist_backend = 'nccl'
284 | print('| distributed init (rank {}): {}, gpu {}'.format(
285 | args.rank, args.dist_url, args.gpu), flush=True)
286 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
287 | world_size=args.world_size, rank=args.rank)
288 | torch.distributed.barrier()
289 | # assert torch.distributed.is_initialized()
290 | setup_for_distributed(args.rank == 0)
291 |
292 |
293 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
294 | missing_keys = []
295 | unexpected_keys = []
296 | error_msgs = []
297 | metadata = getattr(state_dict, '_metadata', None)
298 | state_dict = state_dict.copy()
299 | if metadata is not None:
300 | state_dict._metadata = metadata
301 |
302 | def load(module, prefix=''):
303 | local_metadata = {} if metadata is None else metadata.get(
304 | prefix[:-1], {})
305 | module._load_from_state_dict(
306 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
307 | for name, child in module._modules.items():
308 | if child is not None:
309 | load(child, prefix + name + '.')
310 |
311 | load(model, prefix=prefix)
312 |
313 | warn_missing_keys = []
314 | ignore_missing_keys = []
315 | for key in missing_keys:
316 | keep_flag = True
317 | for ignore_key in ignore_missing.split('|'):
318 | if ignore_key in key:
319 | keep_flag = False
320 | break
321 | if keep_flag:
322 | warn_missing_keys.append(key)
323 | else:
324 | ignore_missing_keys.append(key)
325 |
326 | missing_keys = warn_missing_keys
327 |
328 | if len(missing_keys) > 0:
329 | print("Weights of {} not initialized from pretrained model: {}".format(
330 | model.__class__.__name__, missing_keys))
331 | if len(unexpected_keys) > 0:
332 | print("Weights from pretrained model not used in {}: {}".format(
333 | model.__class__.__name__, unexpected_keys))
334 | if len(ignore_missing_keys) > 0:
335 | print("Ignored weights of {} not initialized from pretrained model: {}".format(
336 | model.__class__.__name__, ignore_missing_keys))
337 | if len(error_msgs) > 0:
338 | print('\n'.join(error_msgs))
339 |
340 |
341 | class NativeScalerWithGradNormCount:
342 | state_dict_key = "amp_scaler"
343 |
344 | def __init__(self):
345 | self._scaler = torch.cuda.amp.GradScaler()
346 |
347 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
348 | self._scaler.scale(loss).backward(create_graph=create_graph)
349 | if update_grad:
350 | if clip_grad is not None:
351 | assert parameters is not None
352 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
353 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
354 | else:
355 | self._scaler.unscale_(optimizer)
356 | norm = get_grad_norm_(parameters)
357 | self._scaler.step(optimizer)
358 | self._scaler.update()
359 | else:
360 | norm = None
361 | return norm
362 |
363 | def state_dict(self):
364 | return self._scaler.state_dict()
365 |
366 | def load_state_dict(self, state_dict):
367 | self._scaler.load_state_dict(state_dict)
368 |
369 |
370 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
371 | if isinstance(parameters, torch.Tensor):
372 | parameters = [parameters]
373 | parameters = [p for p in parameters if p.grad is not None]
374 | norm_type = float(norm_type)
375 | if len(parameters) == 0:
376 | return torch.tensor(0.)
377 | device = parameters[0].grad.device
378 | if norm_type == inf:
379 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
380 | else:
381 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
382 | return total_norm
383 |
384 |
385 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
386 | start_warmup_value=0, warmup_steps=-1):
387 | warmup_schedule = np.array([])
388 | warmup_iters = warmup_epochs * niter_per_ep
389 | if warmup_steps > 0:
390 | warmup_iters = warmup_steps
391 | print("Set warmup steps = %d" % warmup_iters)
392 | if warmup_epochs > 0:
393 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
394 |
395 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
396 | schedule = np.array(
397 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
398 |
399 | schedule = np.concatenate((warmup_schedule, schedule))
400 |
401 | assert len(schedule) == epochs * niter_per_ep
402 | return schedule
403 |
404 |
405 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
406 | output_dir = Path(args.output_dir)
407 | epoch_name = str(epoch)
408 | if loss_scaler is not None:
409 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
410 | for checkpoint_path in checkpoint_paths:
411 | to_save = {
412 | 'model': model_without_ddp.state_dict(),
413 | 'optimizer': optimizer.state_dict(),
414 | 'epoch': epoch,
415 | 'scaler': loss_scaler.state_dict(),
416 | 'args': args,
417 | }
418 |
419 | if model_ema is not None:
420 | to_save['model_ema'] = get_state_dict(model_ema)
421 |
422 | save_on_master(to_save, checkpoint_path)
423 | else:
424 | client_state = {'epoch': epoch}
425 | if model_ema is not None:
426 | client_state['model_ema'] = get_state_dict(model_ema)
427 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
428 |
429 |
430 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
431 | output_dir = Path(args.output_dir)
432 | if loss_scaler is not None:
433 | # torch.amp
434 | if args.auto_resume and len(args.resume) == 0:
435 | import glob
436 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
437 | latest_ckpt = -1
438 | for ckpt in all_checkpoints:
439 | t = ckpt.split('-')[-1].split('.')[0]
440 | if t.isdigit():
441 | latest_ckpt = max(int(t), latest_ckpt)
442 | if latest_ckpt >= 0:
443 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
444 | print("Auto resume checkpoint: %s" % args.resume)
445 |
446 | if args.resume:
447 | if args.resume.startswith('https'):
448 | checkpoint = torch.hub.load_state_dict_from_url(
449 | args.resume, map_location='cpu', check_hash=True)
450 | else:
451 | checkpoint = torch.load(args.resume, map_location='cpu')
452 | model_without_ddp.load_state_dict(checkpoint['model'])
453 | print("Resume checkpoint %s" % args.resume)
454 | if 'optimizer' in checkpoint and 'epoch' in checkpoint:
455 | optimizer.load_state_dict(checkpoint['optimizer'])
456 | args.start_epoch = checkpoint['epoch'] + 1
457 | if hasattr(args, 'model_ema') and args.model_ema:
458 | _load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
459 | if 'scaler' in checkpoint:
460 | loss_scaler.load_state_dict(checkpoint['scaler'])
461 | print("With optim & sched!")
462 | else:
463 |         # deepspeed, only supports '--auto_resume'.
464 | if args.auto_resume:
465 | import glob
466 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
467 | latest_ckpt = -1
468 | for ckpt in all_checkpoints:
469 | t = ckpt.split('-')[-1].split('.')[0]
470 | if t.isdigit():
471 | latest_ckpt = max(int(t), latest_ckpt)
472 | if latest_ckpt >= 0:
473 | args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt)
474 | print("Auto resume checkpoint: %d" % latest_ckpt)
475 | _, client_states = model.load_checkpoint(args.output_dir, tag='checkpoint-%d' % latest_ckpt)
476 | args.start_epoch = client_states['epoch'] + 1
477 | if model_ema is not None:
478 | if args.model_ema:
479 | _load_checkpoint_for_ema(model_ema, client_states['model_ema'])
480 |
481 |
482 | def create_ds_config(args):
483 | args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")
484 | with open(args.deepspeed_config, mode="w") as writer:
485 | ds_config = {
486 | "train_batch_size": args.batch_size * args.update_freq * get_world_size(),
487 | "train_micro_batch_size_per_gpu": args.batch_size,
488 | "steps_per_print": 1000,
489 | "optimizer": {
490 | "type": "Adam",
491 | "adam_w_mode": True,
492 | "params": {
493 | "lr": args.lr,
494 | "weight_decay": args.weight_decay,
495 | "bias_correction": True,
496 | "betas": [
497 | 0.9,
498 | 0.999
499 | ],
500 | "eps": 1e-8
501 | }
502 | },
503 | "fp16": {
504 | "enabled": True,
505 | "loss_scale": 0,
506 | "initial_scale_power": 7,
507 | "loss_scale_window": 128
508 | }
509 | }
510 |
511 | writer.write(json.dumps(ds_config, indent=2))
512 |
513 | def multiple_samples_collate(batch, fold=False):
514 | """
515 | Collate function for repeated augmentation. Each instance in the batch has
516 | more than one sample.
517 | Args:
518 | batch (tuple or list): data batch to collate.
519 | Returns:
520 | (tuple): collated data batch.
521 | """
522 | inputs, labels, video_idx, extra_data = zip(*batch)
523 | inputs = [item for sublist in inputs for item in sublist]
524 | labels = [item for sublist in labels for item in sublist]
525 | video_idx = [item for sublist in video_idx for item in sublist]
526 | inputs, labels, video_idx, extra_data = (
527 | default_collate(inputs),
528 | default_collate(labels),
529 | default_collate(video_idx),
530 | default_collate(extra_data),
531 | )
532 | if fold:
533 | return [inputs], labels, video_idx, extra_data
534 | else:
535 | return inputs, labels, video_idx, extra_data
536 |
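
if __name__ == "__main__":
    # Small illustrative check added for clarity; the values below are assumptions,
    # not the repository's actual training configuration. cosine_scheduler returns
    # one learning-rate value per optimizer step (linear warmup included), so a
    # training loop can index it with the global iteration counter.
    lr_schedule = cosine_scheduler(base_value=1e-3, final_value=1e-6,
                                   epochs=100, niter_per_ep=500, warmup_epochs=5)
    assert len(lr_schedule) == 100 * 500
    print(lr_schedule[0], lr_schedule[5 * 500], lr_schedule[-1])  # warmup start, peak, ~final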
--------------------------------------------------------------------------------
/ucf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from numpy.lib.function_base import disp
4 | import torch
5 | import decord
6 | from PIL import Image
7 | from torchvision import transforms
8 | from random_erasing import RandomErasing
9 | import warnings
10 | from decord import VideoReader, cpu
11 | from torch.utils.data import Dataset
12 | import video_transforms as video_transforms
13 | import volume_transforms as volume_transforms
14 |
15 | class VideoClsDataset(Dataset):
16 | """Load your own video classification dataset."""
17 |
18 | def __init__(self, anno_path, data_path, mode='train', clip_len=8,
19 | frame_sample_rate=2, crop_size=224, short_side_size=256,
20 | new_height=256, new_width=340, keep_aspect_ratio=True,
21 | num_segment=1, num_crop=1, test_num_segment=10, test_num_crop=3,args=None):
22 | self.anno_path = anno_path
23 | self.data_path = data_path
24 | self.mode = mode
25 | self.clip_len = clip_len
26 | self.frame_sample_rate = frame_sample_rate
27 | self.crop_size = crop_size
28 | self.short_side_size = short_side_size
29 | self.new_height = new_height
30 | self.new_width = new_width
31 | self.keep_aspect_ratio = keep_aspect_ratio
32 | self.num_segment = num_segment
33 | self.test_num_segment = test_num_segment
34 | self.num_crop = num_crop
35 | self.test_num_crop = test_num_crop
36 | self.args = args
37 | self.aug = False
38 | self.rand_erase = False
39 | if self.mode in ['train']:
40 | self.aug = True
41 | if self.args.reprob > 0:
42 | self.rand_erase = True
43 | if VideoReader is None:
44 | raise ImportError("Unable to import `decord` which is required to read videos.")
45 |
46 | import pandas as pd
47 | cleaned = pd.read_csv(self.anno_path, header=None, delimiter=' ')
48 | self.dataset_samples = list(cleaned.values[:, 0])
49 | self.label_array = list(cleaned.values[:, 1])
50 |
51 | if (mode == 'train'):
52 | pass
53 |
54 | elif (mode == 'validation'):
55 | self.data_transform = video_transforms.Compose([
56 | video_transforms.Resize(self.short_side_size, interpolation='bilinear'),
57 | video_transforms.CenterCrop(size=(self.crop_size, self.crop_size)),
58 | volume_transforms.ClipToTensor(),
59 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
60 | std=[0.229, 0.224, 0.225])
61 | ])
62 | elif mode == 'test':
63 | self.data_resize = video_transforms.Compose([
64 | video_transforms.Resize(size=(short_side_size), interpolation='bilinear')
65 | ])
66 | self.data_transform = video_transforms.Compose([
67 | volume_transforms.ClipToTensor(),
68 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
69 | std=[0.229, 0.224, 0.225])
70 | ])
71 | self.test_seg = []
72 | self.test_dataset = []
73 | self.test_label_array = []
74 | for ck in range(self.test_num_segment):
75 | for cp in range(self.test_num_crop):
76 | for idx in range(len(self.label_array)):
77 | sample_label = self.label_array[idx]
78 | self.test_label_array.append(sample_label)
79 | self.test_dataset.append(self.dataset_samples[idx])
80 | self.test_seg.append((ck, cp))
81 |
82 | def __getitem__(self, index):
83 | if self.mode == 'train':
84 | args = self.args
85 | scale_t = 1
86 |
87 | sample = self.dataset_samples[index]
88 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) # T H W C
89 | if len(buffer) == 0:
90 | while len(buffer) == 0:
91 | warnings.warn("video {} not correctly loaded during training".format(sample))
92 | index = np.random.randint(self.__len__())
93 | sample = self.dataset_samples[index]
94 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t)
95 |
96 | if args.num_sample > 1:
97 | frame_list = []
98 | label_list = []
99 | index_list = []
100 | for _ in range(args.num_sample):
101 | new_frames = self._aug_frame(buffer, args)
102 | label = self.label_array[index]
103 | frame_list.append(new_frames)
104 | label_list.append(label)
105 | index_list.append(index)
106 | return frame_list, label_list, index_list, {}
107 | else:
108 | buffer = self._aug_frame(buffer, args)
109 |
110 | return buffer, self.label_array[index], index, {}
111 |
112 | elif self.mode == 'validation':
113 | sample = self.dataset_samples[index]
114 | buffer = self.loadvideo_decord(sample)
115 | if len(buffer) == 0:
116 | while len(buffer) == 0:
117 | warnings.warn("video {} not correctly loaded during validation".format(sample))
118 | index = np.random.randint(self.__len__())
119 | sample = self.dataset_samples[index]
120 | buffer = self.loadvideo_decord(sample)
121 | buffer = self.data_transform(buffer)
122 | return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0]
123 |
124 | elif self.mode == 'test':
125 | sample = self.test_dataset[index]
126 | chunk_nb, split_nb = self.test_seg[index]
127 | buffer = self.loadvideo_decord(sample)
128 |
129 | while len(buffer) == 0:
130 | warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\
131 | str(self.test_dataset[index]), chunk_nb, split_nb))
132 | index = np.random.randint(self.__len__())
133 | sample = self.test_dataset[index]
134 | chunk_nb, split_nb = self.test_seg[index]
135 | buffer = self.loadvideo_decord(sample)
136 |
137 | buffer = self.data_resize(buffer)
138 | if isinstance(buffer, list):
139 | buffer = np.stack(buffer, 0)
140 |
141 | spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
142 | / (self.test_num_crop - 1)
143 | temporal_step = max(1.0 * (buffer.shape[0] - self.clip_len) \
144 | / (self.test_num_segment - 1), 0)
145 | temporal_start = int(chunk_nb * temporal_step)
146 | spatial_start = int(split_nb * spatial_step)
147 | if buffer.shape[1] >= buffer.shape[2]:
148 | buffer = buffer[temporal_start:temporal_start + self.clip_len, \
149 | spatial_start:spatial_start + self.short_side_size, :, :]
150 | else:
151 | buffer = buffer[temporal_start:temporal_start + self.clip_len, \
152 | :, spatial_start:spatial_start + self.short_side_size, :]
153 |
154 | buffer = self.data_transform(buffer)
155 | return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \
156 | chunk_nb, split_nb
157 | else:
158 |             raise NameError('mode {} unknown'.format(self.mode))
159 |
160 | def _aug_frame(
161 | self,
162 | buffer,
163 | args,
164 | ):
165 |
166 | aug_transform = video_transforms.create_random_augment(
167 | input_size=(self.crop_size, self.crop_size),
168 | auto_augment=args.aa,
169 | interpolation=args.train_interpolation,
170 | )
171 |
172 | buffer = [
173 | transforms.ToPILImage()(frame) for frame in buffer
174 | ]
175 |
176 | buffer = aug_transform(buffer)
177 |
178 | buffer = [transforms.ToTensor()(img) for img in buffer]
179 | buffer = torch.stack(buffer) # T C H W
180 | buffer = buffer.permute(0, 2, 3, 1) # T H W C
181 |
182 | # T H W C
183 | buffer = tensor_normalize(
184 | buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
185 | )
186 | # T H W C -> C T H W.
187 | buffer = buffer.permute(3, 0, 1, 2)
188 | # Perform data augmentation.
189 | scl, asp = (
190 | [0.08, 1.0],
191 | [0.75, 1.3333],
192 | )
193 |
194 | buffer = spatial_sampling(
195 | buffer,
196 | spatial_idx=-1,
197 | min_scale=256,
198 | max_scale=320,
199 | crop_size=self.crop_size,
200 | random_horizontal_flip=False if args.data_set == 'SSV2' else True ,
201 | inverse_uniform_sampling=False,
202 | aspect_ratio=asp,
203 | scale=scl,
204 | motion_shift=False
205 | )
206 |
207 | if self.rand_erase:
208 | erase_transform = RandomErasing(
209 | args.reprob,
210 | mode=args.remode,
211 | max_count=args.recount,
212 | num_splits=args.recount,
213 | device="cpu",
214 | )
215 | buffer = buffer.permute(1, 0, 2, 3)
216 | buffer = erase_transform(buffer)
217 | buffer = buffer.permute(1, 0, 2, 3)
218 |
219 | return buffer
220 |
221 |
222 | def loadvideo_decord(self, sample, sample_rate_scale=1):
223 | """Load video content using Decord"""
224 | fname = sample
225 |
226 | if not (os.path.exists(fname)):
227 | return []
228 |
229 | # avoid hanging issue
230 | if os.path.getsize(fname) < 1 * 1024:
231 | print('SKIP: ', fname, " - ", os.path.getsize(fname))
232 | return []
233 | try:
234 | if self.keep_aspect_ratio:
235 | vr = VideoReader(fname, num_threads=1, ctx=cpu(0))
236 | else:
237 | vr = VideoReader(fname, width=self.new_width, height=self.new_height,
238 | num_threads=1, ctx=cpu(0))
239 | except:
240 | print("video cannot be loaded by decord: ", fname)
241 | return []
242 |
243 | if self.mode == 'test':
244 | all_index = [x for x in range(0, len(vr), self.frame_sample_rate)]
245 | while len(all_index) < self.clip_len:
246 | all_index.append(all_index[-1])
247 | vr.seek(0)
248 | buffer = vr.get_batch(all_index).asnumpy()
249 | return buffer
250 |
251 | # handle temporal segments
252 | converted_len = int(self.clip_len * self.frame_sample_rate)
253 | seg_len = len(vr) // self.num_segment
254 |
255 | all_index = []
256 | for i in range(self.num_segment):
257 | if seg_len <= converted_len:
258 | index = np.linspace(0, seg_len, num=seg_len // self.frame_sample_rate)
259 | index = np.concatenate((index, np.ones(self.clip_len - seg_len // self.frame_sample_rate) * seg_len))
260 | index = np.clip(index, 0, seg_len - 1).astype(np.int64)
261 | else:
262 | end_idx = np.random.randint(converted_len, seg_len)
263 | str_idx = end_idx - converted_len
264 | index = np.linspace(str_idx, end_idx, num=self.clip_len)
265 | index = np.clip(index, str_idx, end_idx - 1).astype(np.int64)
266 | index = index + i*seg_len
267 | all_index.extend(list(index))
268 |
269 | all_index = all_index[::int(sample_rate_scale)]
270 | vr.seek(0)
271 | buffer = vr.get_batch(all_index).asnumpy()
272 | return buffer
273 |
274 | def __len__(self):
275 | if self.mode != 'test':
276 | return len(self.dataset_samples)
277 | else:
278 | return len(self.test_dataset)
279 |
280 |
281 | def spatial_sampling(
282 | frames,
283 | spatial_idx=-1,
284 | min_scale=256,
285 | max_scale=320,
286 | crop_size=224,
287 | random_horizontal_flip=True,
288 | inverse_uniform_sampling=False,
289 | aspect_ratio=None,
290 | scale=None,
291 | motion_shift=False,
292 | ):
293 | """
294 | Perform spatial sampling on the given video frames. If spatial_idx is
295 | -1, perform random scale, random crop, and random flip on the given
296 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
297 | with the given spatial_idx.
298 | Args:
299 | frames (tensor): frames of images sampled from the video. The
300 | dimension is `num frames` x `height` x `width` x `channel`.
301 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
302 | or 2, perform left, center, right crop if width is larger than
303 |             height, and perform top, center, bottom crop if height is larger
304 | than width.
305 | min_scale (int): the minimal size of scaling.
306 | max_scale (int): the maximal size of scaling.
307 | crop_size (int): the size of height and width used to crop the
308 | frames.
309 | inverse_uniform_sampling (bool): if True, sample uniformly in
310 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
311 | scale. If False, take a uniform sample from [min_scale,
312 | max_scale].
313 | aspect_ratio (list): Aspect ratio range for resizing.
314 | scale (list): Scale range for resizing.
315 | motion_shift (bool): Whether to apply motion shift for resizing.
316 | Returns:
317 | frames (tensor): spatially sampled frames.
318 | """
319 | assert spatial_idx in [-1, 0, 1, 2]
320 | if spatial_idx == -1:
321 | if aspect_ratio is None and scale is None:
322 | frames, _ = video_transforms.random_short_side_scale_jitter(
323 | images=frames,
324 | min_size=min_scale,
325 | max_size=max_scale,
326 | inverse_uniform_sampling=inverse_uniform_sampling,
327 | )
328 | frames, _ = video_transforms.random_crop(frames, crop_size)
329 | else:
330 | transform_func = (
331 | video_transforms.random_resized_crop_with_shift
332 | if motion_shift
333 | else video_transforms.random_resized_crop
334 | )
335 | frames = transform_func(
336 | images=frames,
337 | target_height=crop_size,
338 | target_width=crop_size,
339 | scale=scale,
340 | ratio=aspect_ratio,
341 | )
342 | if random_horizontal_flip:
343 | frames, _ = video_transforms.horizontal_flip(0.5, frames)
344 | else:
345 | # The testing is deterministic and no jitter should be performed.
346 |         # min_scale, max_scale, and crop_size are expected to be the same.
347 | assert len({min_scale, max_scale, crop_size}) == 1
348 | frames, _ = video_transforms.random_short_side_scale_jitter(
349 | frames, min_scale, max_scale
350 | )
351 | frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx)
352 | return frames
353 |
354 |
355 | def tensor_normalize(tensor, mean, std):
356 | """
357 | Normalize a given tensor by subtracting the mean and dividing the std.
358 | Args:
359 | tensor (tensor): tensor to normalize.
360 | mean (tensor or list): mean value to subtract.
361 | std (tensor or list): std to divide.
362 | """
363 | if tensor.dtype == torch.uint8:
364 | tensor = tensor.float()
365 | tensor = tensor / 255.0
366 | if type(mean) == list:
367 | mean = torch.tensor(mean)
368 | if type(std) == list:
369 | std = torch.tensor(std)
370 | tensor = tensor - mean
371 | tensor = tensor / std
372 | return tensor
373 |
374 |
375 | class VideoMS(torch.utils.data.Dataset):
376 | """Load your own video classification dataset.
377 | Parameters
378 | ----------
379 | root : str, required.
380 | Path to the root folder storing the dataset.
381 | setting : str, required.
382 | A text file describing the dataset, each line per video sample.
383 | There are three items in each line: (1) video path; (2) video length and (3) video label.
384 | train : bool, default True.
385 | Whether to load the training or validation set.
386 | test_mode : bool, default False.
387 | Whether to perform evaluation on the test set.
388 |         Usually a three-crop or ten-crop evaluation strategy is involved.
389 | name_pattern : str, default None.
390 | The naming pattern of the decoded video frames.
391 | For example, img_00012.jpg.
392 | video_ext : str, default 'mp4'.
393 |         If video_loader is set to True, please specify the video format accordingly.
394 | is_color : bool, default True.
395 | Whether the loaded image is color or grayscale.
396 | modality : str, default 'rgb'.
397 | Input modalities, we support only rgb video frames for now.
398 | Will add support for rgb difference image and optical flow image later.
399 | num_segments : int, default 1.
400 | Number of segments to evenly divide the video into clips.
401 | A useful technique to obtain global video-level information.
402 |         Limin Wang, et al., Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016.
403 | num_crop : int, default 1.
404 | Number of crops for each image. default is 1.
405 | Common choices are three crops and ten crops during evaluation.
406 | new_length : int, default 1.
407 | The length of input video clip. Default is a single image, but it can be multiple video frames.
408 | For example, new_length=16 means we will extract a video clip of consecutive 16 frames.
409 | new_step : int, default 1.
410 | Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames.
411 | new_step=2 means we will extract a video clip of every other frame.
412 | temporal_jitter : bool, default False.
413 | Whether to temporally jitter if new_step > 1.
414 | video_loader : bool, default False.
415 | Whether to use video loader to load data.
416 | use_decord : bool, default True.
417 | Whether to use Decord video loader to load data. Otherwise use mmcv video loader.
418 | transform : function, default None.
419 | A function that takes data and label and transforms them.
420 | data_aug : str, default 'v1'.
421 |         Different types of automatic data augmentation. Supports v1, v2, v3 and v4.
422 | lazy_init : bool, default False.
423 | If set to True, build a dataset instance without loading any dataset.
424 | """
425 | def __init__(self,
426 | root,
427 | setting,
428 | train=True,
429 | test_mode=False,
430 | name_pattern='img_%05d.jpg',
431 | video_ext='mp4',
432 | is_color=True,
433 | modality='rgb',
434 | num_segments=1,
435 | num_crop=1,
436 | new_length=1,
437 | new_step=1,
438 | transform=None,
439 | temporal_jitter=False,
440 | video_loader=False,
441 | use_decord=False,
442 | lazy_init=False):
443 |
444 | super(VideoMS, self).__init__()
445 | self.root = root
446 | self.setting = setting
447 | self.train = train
448 | self.test_mode = test_mode
449 | self.is_color = is_color
450 | self.modality = modality
451 | self.num_segments = num_segments
452 | self.num_crop = num_crop
453 | self.new_length = new_length
454 | self.new_step = new_step
455 | self.skip_length = self.new_length * self.new_step
456 | self.temporal_jitter = temporal_jitter
457 | self.name_pattern = name_pattern
458 | self.video_loader = video_loader
459 | self.video_ext = video_ext
460 | self.use_decord = use_decord
461 | self.transform = transform
462 | self.lazy_init = lazy_init
463 |
464 |
465 | if not self.lazy_init:
466 | self.clips = self._make_dataset(root, setting)
467 | if len(self.clips) == 0:
468 | raise(RuntimeError("Found 0 video clips in subfolders of: " + root + "\n"
469 | "Check your data directory (opt.data-dir)."))
470 |
471 | def __getitem__(self, index):
472 |
473 | directory, target = self.clips[index]
474 | if self.video_loader:
475 | if '.' in directory.split('/')[-1]:
476 | # data in the "setting" file already have extension, e.g., demo.mp4
477 | video_name = directory
478 | else:
479 | # data in the "setting" file do not have extension, e.g., demo
480 | # So we need to provide extension (i.e., .mp4) to complete the file name.
481 | video_name = '{}.{}'.format(directory, self.video_ext)
482 |
483 | decord_vr = decord.VideoReader(video_name, num_threads=1)
484 | duration = len(decord_vr)
485 |
486 | segment_indices, skip_offsets = self._sample_train_indices(duration)
487 |
488 | images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices, skip_offsets)
489 |
490 | process_data, mask = self.transform((images, None)) # T*C,H,W
491 | process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0,1) # T*C,H,W -> T,C,H,W -> C,T,H,W
492 |
493 | return (process_data, mask)
494 |
495 | def __len__(self):
496 | return len(self.clips)
497 |
498 | def _make_dataset(self, directory, setting):
499 | if not os.path.exists(setting):
500 | raise(RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting)))
501 | clips = []
502 | with open(setting) as split_f:
503 | data = split_f.readlines()
504 | for line in data:
505 | line_info = line.split(' ')
506 | # line format: video_path video_label (only the first two space-separated fields are used)
507 | if len(line_info) < 2:
508 | raise(RuntimeError('Video input format is not correct, missing one or more elements. %s' % line))
509 | clip_path = os.path.join(line_info[0])
510 | target = int(line_info[1])
511 | item = (clip_path, target)
512 | clips.append(item)
513 | return clips
514 |
515 | def _sample_train_indices(self, num_frames):
516 | average_duration = (num_frames - self.skip_length + 1) // self.num_segments
517 | if average_duration > 0:
518 | offsets = np.multiply(list(range(self.num_segments)),
519 | average_duration)
520 | offsets = offsets + np.random.randint(average_duration,
521 | size=self.num_segments)
522 | elif num_frames > max(self.num_segments, self.skip_length):
523 | offsets = np.sort(np.random.randint(
524 | num_frames - self.skip_length + 1,
525 | size=self.num_segments))
526 | else:
527 | offsets = np.zeros((self.num_segments,))
528 |
529 | if self.temporal_jitter:
530 | skip_offsets = np.random.randint(
531 | self.new_step, size=self.skip_length // self.new_step)
532 | else:
533 | skip_offsets = np.zeros(
534 | self.skip_length // self.new_step, dtype=int)
535 | return offsets + 1, skip_offsets
536 |
537 |
538 | def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets):
539 | sampled_list = []
540 | frame_id_list = []
541 | for seg_ind in indices:
542 | offset = int(seg_ind)
543 | for i, _ in enumerate(range(0, self.skip_length, self.new_step)):
544 | if offset + skip_offsets[i] <= duration:
545 | frame_id = offset + skip_offsets[i] - 1
546 | else:
547 | frame_id = offset - 1
548 | frame_id_list.append(frame_id)
549 | if offset + self.new_step < duration:
550 | offset += self.new_step
551 | try:
552 | video_data = video_reader.get_batch(frame_id_list).asnumpy()
553 | sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)]
554 | except Exception:
555 | raise RuntimeError('Error occurred in reading frames {} from video {} of duration {}.'.format(frame_id_list, directory, duration))
556 | return sampled_list
557 |
--------------------------------------------------------------------------------
/run_class_finetuning.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import numpy as np
4 | import time
5 | import torch
6 | import torch.backends.cudnn as cudnn
7 | import json
8 | import os
9 | from functools import partial
10 | from pathlib import Path
11 | from collections import OrderedDict
12 |
13 | from mixup import Mixup
14 | from timm.models import create_model
15 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
16 | from timm.utils import ModelEma
17 | from optim_factory import create_optimizer, get_parameter_groups, LayerDecayValueAssigner
18 |
19 | from datasets import build_dataset
20 | from engine_for_finetuning import train_one_epoch, validation_one_epoch, final_test, merge
21 | from utils import NativeScalerWithGradNormCount as NativeScaler
22 | from utils import multiple_samples_collate
23 | import utils
24 | import modeling_finetune
25 |
26 |
27 | def get_args():
28 | parser = argparse.ArgumentParser('VideoMS fine-tuning and evaluation script for video classification', add_help=False)
29 | parser.add_argument('--batch_size', default=64, type=int)
30 | parser.add_argument('--epochs', default=30, type=int)
31 | parser.add_argument('--update_freq', default=1, type=int)
32 | parser.add_argument('--save_ckpt_freq', default=100, type=int)
33 |
34 | # Model parameters
35 | parser.add_argument('--model', default='vit_base_patch16_224', type=str, metavar='MODEL',
36 | help='Name of model to train')
37 | parser.add_argument('--tubelet_size', type=int, default=2)
38 | parser.add_argument('--input_size', default=224, type=int,
39 | help='videos input size')
40 |
41 | parser.add_argument('--fc_drop_rate', type=float, default=0.0, metavar='PCT',
42 | help='Dropout rate (default: 0.)')
43 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
44 | help='Dropout rate (default: 0.)')
45 | parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
46 | help='Attention dropout rate (default: 0.)')
47 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
48 | help='Drop path rate (default: 0.1)')
49 |
50 | parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False)
51 | parser.add_argument('--model_ema', action='store_true', default=False)
52 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='Model EMA decay factor (default: 0.9999)')
53 | parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='Keep the EMA weights on CPU')
54 |
55 | # Optimizer parameters
56 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
57 | help='Optimizer (default: "adamw")')
58 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
59 | help='Optimizer Epsilon (default: 1e-8)')
60 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
61 | help='Optimizer Betas (default: None, use opt default)')
62 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
63 | help='Clip gradient norm (default: None, no clipping)')
64 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
65 | help='SGD momentum (default: 0.9)')
66 | parser.add_argument('--weight_decay', type=float, default=0.05,
67 | help='weight decay (default: 0.05)')
68 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
69 | weight decay. We use a cosine schedule for WD and using a larger decay by
70 | the end of training improves performance for ViTs.""")
71 |
72 | parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
73 | help='learning rate (default: 1e-3)')
74 | parser.add_argument('--layer_decay', type=float, default=0.75)
75 |
76 | parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
77 | help='warmup learning rate (default: 1e-6)')
78 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
79 | help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')
80 |
81 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
82 | help='epochs to warmup LR, if scheduler supports')
83 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
84 | help='num of steps to warmup LR, will override warmup_epochs if set > 0')
85 |
86 | # Augmentation parameters
87 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
88 | help='Color jitter factor (default: 0.4)')
89 | parser.add_argument('--num_sample', type=int, default=2,
90 | help='Number of repeated augmentation samples per clip (default: 2)')
91 | parser.add_argument('--aa', type=str, default='rand-m7-n4-mstd0.5-inc1', metavar='NAME',
92 | help='Use AutoAugment policy. "v0" or "original" (default: rand-m7-n4-mstd0.5-inc1)')
93 | parser.add_argument('--smoothing', type=float, default=0.1,
94 | help='Label smoothing (default: 0.1)')
95 | parser.add_argument('--train_interpolation', type=str, default='bicubic',
96 | help='Training interpolation (random, bilinear, bicubic; default: "bicubic")')
97 |
98 | # Evaluation parameters
99 | parser.add_argument('--crop_pct', type=float, default=None)
100 | parser.add_argument('--short_side_size', type=int, default=224)
101 | parser.add_argument('--test_num_segment', type=int, default=5)
102 | parser.add_argument('--test_num_crop', type=int, default=3)
103 |
104 | # Random Erase params
105 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
106 | help='Random erase prob (default: 0.25)')
107 | parser.add_argument('--remode', type=str, default='pixel',
108 | help='Random erase mode (default: "pixel")')
109 | parser.add_argument('--recount', type=int, default=1,
110 | help='Random erase count (default: 1)')
111 | parser.add_argument('--resplit', action='store_true', default=False,
112 | help='Do not random erase first (clean) augmentation split')
113 |
114 | # Mixup params
115 | parser.add_argument('--mixup', type=float, default=0.0,
116 | help='mixup alpha, mixup enabled if > 0.')
117 | parser.add_argument('--cutmix', type=float, default=0.0,
118 | help='cutmix alpha, cutmix enabled if > 0.')
119 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
120 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
121 | parser.add_argument('--mixup_prob', type=float, default=1.0,
122 | help='Probability of performing mixup or cutmix when either/both is enabled')
123 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
124 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
125 | parser.add_argument('--mixup_mode', type=str, default='batch',
126 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
127 |
128 | # Finetuning params
129 | parser.add_argument('--finetune', default='', help='finetune from checkpoint')
130 | parser.add_argument('--model_key', default='model|module', type=str)
131 | parser.add_argument('--model_prefix', default='', type=str)
132 | parser.add_argument('--init_scale', default=0.001, type=float)
133 | parser.add_argument('--use_checkpoint', action='store_true')
134 | parser.set_defaults(use_checkpoint=False)
135 | parser.add_argument('--use_mean_pooling', action='store_true')
136 | parser.set_defaults(use_mean_pooling=True)
137 | parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling')
138 |
139 | # Dataset parameters
140 | parser.add_argument('--data_path', default='/path/to/list_ucf101', type=str,
141 | help='dataset path')
142 | parser.add_argument('--eval_data_path', default=None, type=str,
143 | help='dataset path for evaluation')
144 | parser.add_argument('--nb_classes', default=101, type=int,
145 | help='number of the classification types')
146 | parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true')
147 | parser.add_argument('--num_segments', type=int, default=1)
148 | parser.add_argument('--num_frames', type=int, default=16)
149 | parser.add_argument('--sampling_rate', type=int, default=4)
150 | parser.add_argument('--data_set', default='UCF101', choices=['UCF101', 'HMDB51', 'OSCC'],
151 | type=str, help='dataset')
152 | parser.add_argument('--output_dir', default='',
153 | help='path where to save, empty for no saving')
154 | parser.add_argument('--log_dir', default=None,
155 | help='path where to tensorboard log')
156 | parser.add_argument('--device', default='cuda',
157 | help='device to use for training / testing')
158 | parser.add_argument('--seed', default=0, type=int)
159 | parser.add_argument('--resume', default='',
160 | help='resume from checkpoint')
161 | parser.add_argument('--auto_resume', action='store_true')
162 | parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
163 | parser.set_defaults(auto_resume=True)
164 |
165 | parser.add_argument('--save_ckpt', action='store_true')
166 | parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')
167 | parser.set_defaults(save_ckpt=True)
168 |
169 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
170 | help='start epoch')
171 | parser.add_argument('--eval', action='store_true',
172 | help='Perform evaluation only')
173 | parser.add_argument('--dist_eval', action='store_true', default=False,
174 | help='Enabling distributed evaluation')
175 | parser.add_argument('--num_workers', default=10, type=int)
176 | parser.add_argument('--pin_mem', action='store_true',
177 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
178 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
179 | parser.set_defaults(pin_mem=True)
180 |
181 | # distributed training parameters
182 | parser.add_argument('--world_size', default=1, type=int,
183 | help='number of distributed processes')
184 | parser.add_argument('--local_rank', default=-1, type=int)
185 | parser.add_argument('--dist_on_itp', action='store_true')
186 | parser.add_argument('--dist_url', default='env://',
187 | help='url used to set up distributed training')
188 |
189 | parser.add_argument('--enable_deepspeed', action='store_true', default=False)
190 |
191 | parser.add_argument('--mcm', action='store_true', default=False)
192 | parser.add_argument('--mcm_ratio', default=0.4, type=float)
193 |
194 | known_args, _ = parser.parse_known_args()
195 |
196 | if known_args.enable_deepspeed:
197 | try:
198 | import deepspeed
199 | from deepspeed import DeepSpeedConfig
200 | parser = deepspeed.add_config_arguments(parser)
201 | ds_init = deepspeed.initialize
202 | except ImportError:
203 | print("DeepSpeed is not installed. Please run 'pip install deepspeed'.")
204 | exit(0)
205 | else:
206 | ds_init = None
207 |
208 | return parser.parse_args(), ds_init
209 |
210 |
211 | def main(args, ds_init):
212 | utils.init_distributed_mode(args)
213 |
214 | if ds_init is not None:
215 | utils.create_ds_config(args)
216 |
217 | print(args)
218 |
219 | device = torch.device(args.device)
220 |
221 | # fix the seed for reproducibility
222 | seed = args.seed + utils.get_rank()
223 | torch.manual_seed(seed)
224 | np.random.seed(seed)
225 | # random.seed(seed)
226 |
227 | cudnn.benchmark = True
228 |
229 | dataset_train, args.nb_classes = build_dataset(is_train=True, test_mode=False, args=args)
230 | if args.disable_eval_during_finetuning:
231 | dataset_val = None
232 | else:
233 | dataset_val, _ = build_dataset(is_train=False, test_mode=False, args=args)
234 | dataset_test, _ = build_dataset(is_train=False, test_mode=True, args=args)
235 |
236 |
237 | num_tasks = utils.get_world_size()
238 | global_rank = utils.get_rank()
239 | sampler_train = torch.utils.data.DistributedSampler(
240 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
241 | )
242 | print("Sampler_train = %s" % str(sampler_train))
243 | if args.dist_eval:
244 | if len(dataset_val) % num_tasks != 0:
245 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
246 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
247 | 'equal num of samples per-process.')
248 | sampler_val = torch.utils.data.DistributedSampler(
249 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
250 | sampler_test = torch.utils.data.DistributedSampler(
251 | dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=False)
252 | else:
253 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
254 |
255 | if global_rank == 0 and args.log_dir is not None:
256 | os.makedirs(args.log_dir, exist_ok=True)
257 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
258 | else:
259 | log_writer = None
260 |
261 | if args.num_sample > 1:
262 | collate_func = partial(multiple_samples_collate, fold=False)
263 | else:
264 | collate_func = None
265 |
266 | data_loader_train = torch.utils.data.DataLoader(
267 | dataset_train, sampler=sampler_train,
268 | batch_size=args.batch_size,
269 | num_workers=args.num_workers,
270 | pin_memory=args.pin_mem,
271 | drop_last=True,
272 | collate_fn=collate_func,
273 | )
274 |
275 | if dataset_val is not None:
276 | data_loader_val = torch.utils.data.DataLoader(
277 | dataset_val, sampler=sampler_val,
278 | batch_size=int(1.5 * args.batch_size),
279 | num_workers=args.num_workers,
280 | pin_memory=args.pin_mem,
281 | drop_last=False
282 | )
283 | else:
284 | data_loader_val = None
285 |
286 | if dataset_test is not None:
287 | data_loader_test = torch.utils.data.DataLoader(
288 | dataset_test, sampler=sampler_test,
289 | batch_size=args.batch_size,
290 | num_workers=args.num_workers,
291 | pin_memory=args.pin_mem,
292 | drop_last=False
293 | )
294 | else:
295 | data_loader_test = None
296 |
297 | mixup_fn = None
298 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
299 | if mixup_active:
300 | print("Mixup is activated!")
301 | mixup_fn = Mixup(
302 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
303 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
304 | label_smoothing=args.smoothing, num_classes=args.nb_classes)
305 |
306 | model = create_model(
307 | args.model,
308 | pretrained=False,
309 | num_classes=args.nb_classes,
310 | all_frames=args.num_frames * args.num_segments,
311 | tubelet_size=args.tubelet_size,
312 | fc_drop_rate=args.fc_drop_rate,
313 | drop_rate=args.drop,
314 | drop_path_rate=args.drop_path,
315 | attn_drop_rate=args.attn_drop_rate,
316 | drop_block_rate=None,
317 | use_checkpoint=args.use_checkpoint,
318 | use_mean_pooling=args.use_mean_pooling,
319 | init_scale=args.init_scale,
320 | mcm=args.mcm,
321 | mcm_ratio=args.mcm_ratio
322 | )
323 |
324 | patch_size = model.patch_embed.patch_size
325 | print("Patch size = %s" % str(patch_size))
326 | args.window_size = (args.num_frames // 2, args.input_size // patch_size[0], args.input_size // patch_size[1])
327 | args.patch_size = patch_size
328 |
329 | if args.finetune:
330 | if args.finetune.startswith('https'):
331 | checkpoint = torch.hub.load_state_dict_from_url(
332 | args.finetune, map_location='cpu', check_hash=True)
333 | else:
334 | checkpoint = torch.load(args.finetune, map_location='cpu')
335 |
336 | print("Load ckpt from %s" % args.finetune)
337 | checkpoint_model = None
338 | for model_key in args.model_key.split('|'):
339 | if model_key in checkpoint:
340 | checkpoint_model = checkpoint[model_key]
341 | print("Load state_dict by model_key = %s" % model_key)
342 | break
343 | if checkpoint_model is None:
344 | checkpoint_model = checkpoint
345 | state_dict = model.state_dict()
346 | for k in ['head.weight', 'head.bias']:
347 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
348 | print(f"Removing key {k} from pretrained checkpoint")
349 | del checkpoint_model[k]
350 |
351 | all_keys = list(checkpoint_model.keys())
352 | new_dict = OrderedDict()
353 | for key in all_keys:
354 | if key.startswith('backbone.'):
355 | new_dict[key[9:]] = checkpoint_model[key]
356 | elif key.startswith('encoder.'):
357 | new_dict[key[8:]] = checkpoint_model[key]
358 | else:
359 | new_dict[key] = checkpoint_model[key]
360 | checkpoint_model = new_dict
361 |
362 | # interpolate position embedding
363 | if 'pos_embed' in checkpoint_model:
364 | pos_embed_checkpoint = checkpoint_model['pos_embed']
365 | embedding_size = pos_embed_checkpoint.shape[-1] # channel dim
366 | num_patches = model.patch_embed.num_patches #
367 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1
368 |
369 | # height (== width) for the checkpoint position embedding
370 | orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(args.num_frames // model.patch_embed.tubelet_size)) ** 0.5)
371 | # height (== width) for the new position embedding
372 | new_size = int((num_patches // (args.num_frames // model.patch_embed.tubelet_size) )** 0.5)
373 | # class_token and dist_token are kept unchanged
374 | if orig_size != new_size:
375 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
376 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
377 | # only the position tokens are interpolated
378 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
379 | # B, L, C -> BT, H, W, C -> BT, C, H, W
380 | pos_tokens = pos_tokens.reshape(-1, args.num_frames // model.patch_embed.tubelet_size, orig_size, orig_size, embedding_size)
381 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
382 | pos_tokens = torch.nn.functional.interpolate(
383 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
384 | # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
385 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, args.num_frames // model.patch_embed.tubelet_size, new_size, new_size, embedding_size)
386 | pos_tokens = pos_tokens.flatten(1, 3) # B, L, C
387 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
388 | checkpoint_model['pos_embed'] = new_pos_embed
389 |
390 | utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)
391 |
392 | model.to(device)
393 |
394 | model_ema = None
395 | if args.model_ema:
396 | model_ema = ModelEma(
397 | model,
398 | decay=args.model_ema_decay,
399 | device='cpu' if args.model_ema_force_cpu else '',
400 | resume='')
401 | print("Using EMA with decay = %.8f" % args.model_ema_decay)
402 |
403 | model_without_ddp = model
404 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
405 |
406 | print("Model = %s" % str(model_without_ddp))
407 | print('number of params:', n_parameters)
408 |
409 | total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
410 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size
411 | args.lr = args.lr * total_batch_size / 256
412 | args.min_lr = args.min_lr * total_batch_size / 256
413 | args.warmup_lr = args.warmup_lr * total_batch_size / 256
414 | print("LR = %.8f" % args.lr)
415 | print("Batch size = %d" % total_batch_size)
416 | print("Update frequency = %d" % args.update_freq)
417 | print("Number of training examples = %d" % len(dataset_train))
418 | print("Number of training steps per epoch = %d" % num_training_steps_per_epoch)
419 |
420 | num_layers = model_without_ddp.get_num_layers()
421 | if args.layer_decay < 1.0:
422 | assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))
423 | else:
424 | assigner = None
425 |
426 | if assigner is not None:
427 | print("Assigned values = %s" % str(assigner.values))
428 |
429 | skip_weight_decay_list = model.no_weight_decay()
430 | print("Skip weight decay list: ", skip_weight_decay_list)
431 |
432 | if args.enable_deepspeed:
433 | loss_scaler = None
434 | optimizer_params = get_parameter_groups(
435 | model, args.weight_decay, skip_weight_decay_list,
436 | assigner.get_layer_id if assigner is not None else None,
437 | assigner.get_scale if assigner is not None else None)
438 | model, optimizer, _, _ = ds_init(
439 | args=args, model=model, model_parameters=optimizer_params, dist_init_required=not args.distributed,
440 | )
441 |
442 | print("model.gradient_accumulation_steps() = %d" % model.gradient_accumulation_steps())
443 | assert model.gradient_accumulation_steps() == args.update_freq
444 | else:
445 | if args.distributed:
446 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
447 | model_without_ddp = model.module
448 |
449 | optimizer = create_optimizer(
450 | args, model_without_ddp, skip_list=skip_weight_decay_list,
451 | get_num_layer=assigner.get_layer_id if assigner is not None else None,
452 | get_layer_scale=assigner.get_scale if assigner is not None else None)
453 | loss_scaler = NativeScaler()
454 |
455 | print("Use step level LR scheduler!")
456 | lr_schedule_values = utils.cosine_scheduler(
457 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
458 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
459 | )
460 | if args.weight_decay_end is None:
461 | args.weight_decay_end = args.weight_decay
462 | wd_schedule_values = utils.cosine_scheduler(
463 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
464 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))
465 |
466 | if mixup_fn is not None:
467 | # smoothing is handled with mixup label transform
468 | criterion = SoftTargetCrossEntropy()
469 | elif args.smoothing > 0.:
470 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
471 | else:
472 | criterion = torch.nn.CrossEntropyLoss()
473 |
474 | print("criterion = %s" % str(criterion))
475 |
476 | utils.auto_load_model(
477 | args=args, model=model, model_without_ddp=model_without_ddp,
478 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)
479 |
480 | if args.eval:
481 | preds_file = os.path.join(args.output_dir, str(global_rank) + '.txt')
482 | test_stats = final_test(data_loader_test, model, device, preds_file)
483 | torch.distributed.barrier()
484 | if global_rank == 0:
485 | print("Start merging results...")
486 | final_top1, final_top5 = merge(args.output_dir, num_tasks)
487 | print(f"Accuracy of the network on the {len(dataset_test)} test videos: Top-1: {final_top1:.2f}%, Top-5: {final_top5:.2f}%")
488 | log_stats = {'Final top-1': final_top1,
489 | 'Final Top-5': final_top5}
490 | if args.output_dir and utils.is_main_process():
491 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
492 | f.write(json.dumps(log_stats) + "\n")
493 | exit(0)
494 |
495 |
496 | print(f"Start training for {args.epochs} epochs")
497 | start_time = time.time()
498 | max_accuracy = 0.0
499 | for epoch in range(args.start_epoch, args.epochs):
500 | if args.distributed:
501 | data_loader_train.sampler.set_epoch(epoch)
502 | if log_writer is not None:
503 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
504 | train_stats = train_one_epoch(
505 | model, criterion, data_loader_train, optimizer,
506 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn,
507 | log_writer=log_writer, start_steps=epoch * num_training_steps_per_epoch,
508 | lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,
509 | num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
510 | )
511 | if args.output_dir and args.save_ckpt:
512 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
513 | utils.save_model(
514 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
515 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema)
516 | if data_loader_val is not None:
517 | test_stats = validation_one_epoch(data_loader_val, model, device)
518 | print(f"Accuracy of the network on the {len(dataset_val)} val videos: {test_stats['acc1']:.1f}%")
519 | if max_accuracy < test_stats["acc1"]:
520 | max_accuracy = test_stats["acc1"]
521 | if args.output_dir and args.save_ckpt:
522 | utils.save_model(
523 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
524 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema)
525 |
526 | print(f'Max accuracy: {max_accuracy:.2f}%')
527 | if log_writer is not None:
528 | log_writer.update(val_acc1=test_stats['acc1'], head="perf", step=epoch)
529 | log_writer.update(val_acc5=test_stats['acc5'], head="perf", step=epoch)
530 | log_writer.update(val_loss=test_stats['loss'], head="perf", step=epoch)
531 |
532 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
533 | **{f'val_{k}': v for k, v in test_stats.items()},
534 | 'epoch': epoch,
535 | 'n_parameters': n_parameters}
536 | else:
537 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
538 | 'epoch': epoch,
539 | 'n_parameters': n_parameters}
540 | if args.output_dir and utils.is_main_process():
541 | if log_writer is not None:
542 | log_writer.flush()
543 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
544 | f.write(json.dumps(log_stats) + "\n")
545 |
546 | preds_file = os.path.join(args.output_dir, str(global_rank) + '.txt')
547 | test_stats = final_test(data_loader_test, model, device, preds_file)
548 | torch.distributed.barrier()
549 | if global_rank == 0:
550 | print("Start merging results...")
551 | final_top1, final_top5 = merge(args.output_dir, num_tasks)
552 | print(f"Accuracy of the network on the {len(dataset_test)} test videos: Top-1: {final_top1:.2f}%, Top-5: {final_top5:.2f}%")
553 | log_stats = {'Final top-1': final_top1,
554 | 'Final Top-5': final_top5}
555 | if args.output_dir and utils.is_main_process():
556 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
557 | f.write(json.dumps(log_stats) + "\n")
558 |
559 |
560 | total_time = time.time() - start_time
561 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
562 | print('Training time {}'.format(total_time_str))
563 |
564 |
565 | if __name__ == '__main__':
566 | opts, ds_init = get_args()
567 | if opts.output_dir:
568 | Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
569 | main(opts, ds_init)
570 |
--------------------------------------------------------------------------------