" in annotations). The annotation usually includes `train.csv`, `val.csv` and `test.csv` ( here `test.csv` is the same as `val.csv`). The format of `*.csv` file is like:
28 |
29 | ```
30 | dataset_root/video_1.mp4 label_1
31 | dataset_root/video_2.mp4 label_2
32 | dataset_root/video_3.mp4 label_3
33 | ...
34 | dataset_root/video_N.mp4 label_N
35 | ```
36 |
37 | ### Note:
38 |
39 | We use [decord](https://github.com/dmlc/decord) to decode the videos **on the fly** during both pre-training and fine-tuning phases.
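As a quick illustration of what on-the-fly decoding looks like, the snippet below parses one annotation line and decodes a few frames with decord (a minimal sketch only; the path, label and frame indices are placeholders, and the repo's dataloaders additionally handle sampling, transforms and masking):

```python
from decord import VideoReader, cpu
import numpy as np

# One annotation line: "<path to video> <label>" (space separated)
line = "dataset_root/video_1.mp4 17"          # placeholder path and label
video_path, label = line.rsplit(" ", 1)

vr = VideoReader(video_path, ctx=cpu(0))      # decodes frames lazily, on demand
indices = np.linspace(0, len(vr) - 1, num=16).astype(int)
frames = vr.get_batch(indices).asnumpy()      # (16, H, W, 3) uint8 array
print(frames.shape, "label:", int(label))
```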
--------------------------------------------------------------------------------
/functional.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import cv2
3 | import numpy as np
4 | import PIL
5 | import torch
6 |
7 |
8 | def _is_tensor_clip(clip):
9 | return torch.is_tensor(clip) and clip.ndimension() == 4
10 |
11 |
12 | def crop_clip(clip, min_h, min_w, h, w):
13 | if isinstance(clip[0], np.ndarray):
14 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
15 |
16 | elif isinstance(clip[0], PIL.Image.Image):
17 | cropped = [
18 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
19 | ]
20 | else:
21 |         raise TypeError('Expected numpy.ndarray or PIL.Image ' +
22 |                         'but got list of {0}'.format(type(clip[0])))
23 | return cropped
24 |
25 |
26 | def resize_clip(clip, size, interpolation='bilinear'):
27 | if isinstance(clip[0], np.ndarray):
28 | if isinstance(size, numbers.Number):
29 | im_h, im_w, im_c = clip[0].shape
30 | # Min spatial dim already matches minimal size
31 | if (im_w <= im_h and im_w == size) or (im_h <= im_w
32 | and im_h == size):
33 | return clip
34 | new_h, new_w = get_resize_sizes(im_h, im_w, size)
35 | size = (new_w, new_h)
36 | else:
37 | size = size[0], size[1]
38 | if interpolation == 'bilinear':
39 | np_inter = cv2.INTER_LINEAR
40 | else:
41 | np_inter = cv2.INTER_NEAREST
42 | scaled = [
43 | cv2.resize(img, size, interpolation=np_inter) for img in clip
44 | ]
45 | elif isinstance(clip[0], PIL.Image.Image):
46 | if isinstance(size, numbers.Number):
47 | im_w, im_h = clip[0].size
48 | # Min spatial dim already matches minimal size
49 | if (im_w <= im_h and im_w == size) or (im_h <= im_w
50 | and im_h == size):
51 | return clip
52 | new_h, new_w = get_resize_sizes(im_h, im_w, size)
53 | size = (new_w, new_h)
54 | else:
55 | size = size[1], size[0]
56 | if interpolation == 'bilinear':
57 | pil_inter = PIL.Image.BILINEAR
58 | else:
59 | pil_inter = PIL.Image.NEAREST
60 | scaled = [img.resize(size, pil_inter) for img in clip]
61 | else:
62 |         raise TypeError('Expected numpy.ndarray or PIL.Image ' +
63 |                         'but got list of {0}'.format(type(clip[0])))
64 | return scaled
65 |
66 |
67 | def get_resize_sizes(im_h, im_w, size):
68 | if im_w < im_h:
69 | ow = size
70 | oh = int(size * im_h / im_w)
71 | else:
72 | oh = size
73 | ow = int(size * im_w / im_h)
74 | return oh, ow
75 |
76 |
77 | def normalize(clip, mean, std, inplace=False):
78 | if not _is_tensor_clip(clip):
79 | raise TypeError('tensor is not a torch clip.')
80 |
81 | if not inplace:
82 | clip = clip.clone()
83 |
84 | dtype = clip.dtype
85 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device)
86 | std = torch.as_tensor(std, dtype=dtype, device=clip.device)
87 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
88 |
89 | return clip
90 |
--------------------------------------------------------------------------------
/volume_transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch
4 |
5 |
6 | def convert_img(img):
7 | """Converts (H, W, C) numpy.ndarray to (C, W, H) format
8 | """
9 | if len(img.shape) == 3:
10 | img = img.transpose(2, 0, 1)
11 | if len(img.shape) == 2:
12 | img = np.expand_dims(img, 0)
13 | return img
14 |
15 |
16 | class ClipToTensor(object):
17 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
18 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
19 | """
20 |
21 | def __init__(self, channel_nb=3, div_255=True, numpy=False):
22 | self.channel_nb = channel_nb
23 | self.div_255 = div_255
24 | self.numpy = numpy
25 |
26 | def __call__(self, clip):
27 | """
28 | Args: clip (list of numpy.ndarray): clip (list of images)
29 | to be converted to tensor.
30 | """
31 | # Retrieve shape
32 | if isinstance(clip[0], np.ndarray):
33 | h, w, ch = clip[0].shape
34 |             assert ch == self.channel_nb, 'Got {0} instead of {1} channels'.format(
35 |                 ch, self.channel_nb)
36 | elif isinstance(clip[0], Image.Image):
37 | w, h = clip[0].size
38 | else:
39 | raise TypeError('Expected numpy.ndarray or PIL.Image\
40 | but got list of {0}'.format(type(clip[0])))
41 |
42 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
43 |
44 | # Convert
45 | for img_idx, img in enumerate(clip):
46 | if isinstance(img, np.ndarray):
47 | pass
48 | elif isinstance(img, Image.Image):
49 | img = np.array(img, copy=False)
50 | else:
51 |                 raise TypeError('Expected numpy.ndarray or PIL.Image\
52 | but got list of {0}'.format(type(img)))
53 | img = convert_img(img)
54 | np_clip[:, img_idx, :, :] = img
55 | if self.numpy:
56 | if self.div_255:
57 | np_clip = np_clip / 255.0
58 | return np_clip
59 |
60 | else:
61 | tensor_clip = torch.from_numpy(np_clip)
62 |
63 | if not isinstance(tensor_clip, torch.FloatTensor):
64 | tensor_clip = tensor_clip.float()
65 | if self.div_255:
66 | tensor_clip = torch.div(tensor_clip, 255)
67 | return tensor_clip
68 |
69 |
70 | # Note this norms data to -1/1
71 | class ClipToTensor_K(object):
72 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
73 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
74 | """
75 |
76 | def __init__(self, channel_nb=3, div_255=True, numpy=False):
77 | self.channel_nb = channel_nb
78 | self.div_255 = div_255
79 | self.numpy = numpy
80 |
81 | def __call__(self, clip):
82 | """
83 | Args: clip (list of numpy.ndarray): clip (list of images)
84 | to be converted to tensor.
85 | """
86 | # Retrieve shape
87 | if isinstance(clip[0], np.ndarray):
88 | h, w, ch = clip[0].shape
89 |             assert ch == self.channel_nb, 'Got {0} instead of {1} channels'.format(
90 |                 ch, self.channel_nb)
91 | elif isinstance(clip[0], Image.Image):
92 | w, h = clip[0].size
93 | else:
94 | raise TypeError('Expected numpy.ndarray or PIL.Image\
95 | but got list of {0}'.format(type(clip[0])))
96 |
97 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
98 |
99 | # Convert
100 | for img_idx, img in enumerate(clip):
101 | if isinstance(img, np.ndarray):
102 | pass
103 | elif isinstance(img, Image.Image):
104 | img = np.array(img, copy=False)
105 | else:
106 |                 raise TypeError('Expected numpy.ndarray or PIL.Image\
107 | but got list of {0}'.format(type(img)))
108 | img = convert_img(img)
109 | np_clip[:, img_idx, :, :] = img
110 | if self.numpy:
111 | if self.div_255:
112 | np_clip = (np_clip - 127.5) / 127.5
113 | return np_clip
114 |
115 | else:
116 | tensor_clip = torch.from_numpy(np_clip)
117 |
118 | if not isinstance(tensor_clip, torch.FloatTensor):
119 | tensor_clip = tensor_clip.float()
120 | if self.div_255:
121 | tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5)
122 | return tensor_clip
123 |
124 |
125 | class ToTensor(object):
126 | """Converts numpy array to tensor
127 | """
128 |
129 | def __call__(self, array):
130 | tensor = torch.from_numpy(array)
131 | return tensor
132 |
--------------------------------------------------------------------------------
/PRETRAIN.md:
--------------------------------------------------------------------------------
1 | # Pre-training VideoMAE
2 |
3 | ## Original Implementation
4 |
5 | The implementation of our VideoMAE supports **multi-node distributed training**. We provide the **off-the-shelf** scripts in the [scripts folder](scripts).
6 |
7 | - For example, to pre-train VideoMAE ViT-Base on **Something-Something V2** with 64 GPUs (8 nodes x 8 GPUs), you can run
8 |
9 | ```bash
10 | OUTPUT_DIR='YOUR_PATH/ssv2_videomae_pretrain_base_patch16_224_frame_16x2_tube_mask_ratio_0.9_e800'
11 | DATA_PATH='YOUR_PATH/list_ssv2/train.csv'
12 |
13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
14 | --master_port 12320 --nnodes=8 \
15 | --node_rank=0 --master_addr=$ip_node_0 \
16 | run_mae_pretraining.py \
17 | --data_path ${DATA_PATH} \
18 | --mask_type tube \
19 | --mask_ratio 0.9 \
20 | --model pretrain_videomae_base_patch16_224 \
21 | --decoder_depth 4 \
22 | --batch_size 32 \
23 | --num_frames 16 \
24 | --sampling_rate 2 \
25 | --opt adamw \
26 | --opt_betas 0.9 0.95 \
27 | --warmup_epochs 40 \
28 | --save_ckpt_freq 20 \
29 | --epochs 801 \
30 | --log_dir ${OUTPUT_DIR} \
31 | --output_dir ${OUTPUT_DIR}
32 | ```
33 |
34 | on the first node. On the other nodes, run the same command with `--node_rank 1`, ..., `--node_rank 7`, respectively. `--master_addr` is set to the IP of node 0.
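If it helps, the per-node rank can be pulled out into an environment variable so that the identical command can be pasted on every machine (a sketch only, not an official script; `NODE_RANK` is assumed to be exported as 0 on the master node and 1-7 on the others):

```bash
NODE_RANK=${NODE_RANK:-0}              # export NODE_RANK=1 ... 7 on the other nodes
MASTER_ADDR=${MASTER_ADDR:-$ip_node_0}

OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
        --nnodes=8 --node_rank=${NODE_RANK} \
        --master_port 12320 --master_addr=${MASTER_ADDR} \
        run_mae_pretraining.py --data_path ${DATA_PATH} --mask_type tube \
        --mask_ratio 0.9 --model pretrain_videomae_base_patch16_224 \
        --decoder_depth 4 --batch_size 32 --num_frames 16 --sampling_rate 2 \
        --opt adamw --opt_betas 0.9 0.95 --warmup_epochs 40 --save_ckpt_freq 20 \
        --epochs 801 --log_dir ${OUTPUT_DIR} --output_dir ${OUTPUT_DIR}
```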
35 |
36 | - For example, to pre-train VideoMAE ViT-Base on **Kinetics400** with 64 GPUs (8 nodes x 8 GPUs), you can run
37 |
38 | ```bash
39 | OUTPUT_DIR='YOUR_PATH/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e800'
40 | DATA_PATH='YOUR_PATH/list_kinetics-400/train.csv'
41 |
42 | OMP_NUM_THREADS=1 python3 -m torch.distributed.launch --nproc_per_node=8 \
43 | --master_port 12320 --nnodes=8 \
44 | --node_rank=0 --master_addr=$your_node_0_ip \
45 | run_mae_pretraining.py \
46 | --data_path ${DATA_PATH} \
47 | --mask_type tube \
48 | --mask_ratio 0.9 \
49 | --model pretrain_videomae_base_patch16_224 \
50 | --decoder_depth 4 \
51 | --batch_size 32 \
52 | --num_frames 16 \
53 | --sampling_rate 4 \
54 | --opt adamw \
55 | --opt_betas 0.9 0.95 \
56 | --warmup_epochs 40 \
57 | --save_ckpt_freq 20 \
58 | --epochs 801 \
59 | --log_dir ${OUTPUT_DIR} \
60 | --output_dir ${OUTPUT_DIR}
61 | ```
62 |
63 | on the first node. On the other nodes, run the same command with `--node_rank 1`, ..., `--node_rank 7`, respectively. `--master_addr` is set to the IP of node 0.
64 |
65 | ### Note:
66 |
67 | - Here the batch size is 32 (`batch_size` per gpu) * 8 (`nodes`) * 8 (gpus per node) = 2048.
68 | - `lr` here is the base learning rate and is set to `1.5e-4` by default. The actual learning rate is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `actual lr` = `lr` * total batch size / 256 (see the sketch after these notes).
69 | - We observed an accidental interruption in the last epoch when conducting the experiments on V100 GPUs (torch 1.6.0). This interruption is caused by the learning-rate scheduler. We simply set `--epochs 801` to sidestep the issue :)
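A small sketch of the arithmetic behind the two notes above (variable names are illustrative):

```python
batch_size_per_gpu = 32
gpus_per_node = 8
num_nodes = 8
base_lr = 1.5e-4                       # default `lr`

total_batch_size = batch_size_per_gpu * gpus_per_node * num_nodes  # 2048
actual_lr = base_lr * total_batch_size / 256                       # 1.2e-3
print(total_batch_size, actual_lr)
```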
70 |
71 | ## Slurm
72 |
73 | To help the community reproduce our results on slurm clusters, we also provide the **off-the-shelf** script.
74 |
75 | For example, to pre-train VideoMAE ViT-Base on **Kinetics400** with 64 GPUs (8 nodes x 8 GPUs), you can run
76 |
77 | ```bash
78 | export MASTER_PORT=$((12000 + $RANDOM % 20000))
79 | export OMP_NUM_THREADS=1
80 |
81 | OUTPUT_DIR='YOUR_PATH/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e800'
82 | DATA_PATH='YOUR_PATH/list_kinetics-400/train.csv'
83 |
84 | JOB_NAME=$1
85 | PARTITION=${PARTITION:-"video"}
86 | # 8 for 1 node, 16 for 2 node, etc.
87 | GPUS=${GPUS:-64}
88 | GPUS_PER_NODE=${GPUS_PER_NODE:-8}
89 | CPUS_PER_TASK=${CPUS_PER_TASK:-8}
90 | SRUN_ARGS=${SRUN_ARGS:-""}
91 | PY_ARGS=${@:2}
92 |
93 | # batch_size can be adjusted according to the graphics card
94 | srun -p $PARTITION \
95 | --job-name=${JOB_NAME} \
96 | --gres=gpu:${GPUS_PER_NODE} \
97 | --ntasks=${GPUS} \
98 | --ntasks-per-node=${GPUS_PER_NODE} \
99 | --cpus-per-task=${CPUS_PER_TASK} \
100 | --kill-on-bad-exit=1 \
101 | ${SRUN_ARGS} \
102 | python -u run_mae_pretraining.py \
103 | --data_path ${DATA_PATH} \
104 | --mask_type tube \
105 | --mask_ratio 0.9 \
106 | --model pretrain_videomae_base_patch16_224 \
107 | --decoder_depth 4 \
108 | --batch_size 32 \
109 | --num_frames 16 \
110 | --sampling_rate 4 \
111 | --opt adamw \
112 | --opt_betas 0.9 0.95 \
113 | --warmup_epochs 40 \
114 | --save_ckpt_freq 20 \
115 | --epochs 801 \
116 | --log_dir ${OUTPUT_DIR} \
117 | --output_dir ${OUTPUT_DIR} \
118 | ${PY_ARGS}
119 | ```
120 |
121 |
--------------------------------------------------------------------------------
/engine_for_pretraining.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | from typing import Iterable
4 | import torch
5 | import torch.nn as nn
6 | import utils
7 | from einops import rearrange
8 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
9 |
10 | def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer,
11 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, patch_size: int = 16,
12 | normlize_target: bool = True, log_writer=None, lr_scheduler=None, start_steps=None,
13 | lr_schedule_values=None, wd_schedule_values=None):
14 | model.train()
15 | metric_logger = utils.MetricLogger(delimiter=" ")
16 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
17 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
18 | header = 'Epoch: [{}]'.format(epoch)
19 | print_freq = 10
20 |
21 | loss_func = nn.MSELoss()
22 |
23 | for step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
24 | # assign learning rate & weight decay for each step
25 | it = start_steps + step # global training iteration
26 | if lr_schedule_values is not None or wd_schedule_values is not None:
27 | for i, param_group in enumerate(optimizer.param_groups):
28 | if lr_schedule_values is not None:
29 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
30 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
31 | param_group["weight_decay"] = wd_schedule_values[it]
32 |
33 | videos, bool_masked_pos = batch
34 | videos = videos.to(device, non_blocking=True)
35 | bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool)
36 |
37 | with torch.no_grad():
38 |             # calculate the reconstruction targets (labels) for the masked patches
39 | mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None]
40 | std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None]
41 | unnorm_videos = videos * std + mean # in [0, 1]
42 |
43 | if normlize_target:
44 | videos_squeeze = rearrange(unnorm_videos, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c', p0=2, p1=patch_size, p2=patch_size)
45 | videos_norm = (videos_squeeze - videos_squeeze.mean(dim=-2, keepdim=True)
46 | ) / (videos_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
47 | # we find that the mean is about 0.48 and standard deviation is about 0.08.
48 | videos_patch = rearrange(videos_norm, 'b n p c -> b n (p c)')
49 | else:
50 | videos_patch = rearrange(unnorm_videos, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2 c)', p0=2, p1=patch_size, p2=patch_size)
51 |
52 | B, _, C = videos_patch.shape
53 | labels = videos_patch[bool_masked_pos].reshape(B, -1, C)
54 |
55 | with torch.cuda.amp.autocast():
56 | outputs = model(videos, bool_masked_pos)
57 | loss = loss_func(input=outputs, target=labels)
58 |
59 | loss_value = loss.item()
60 |
61 | if not math.isfinite(loss_value):
62 | print("Loss is {}, stopping training".format(loss_value))
63 | sys.exit(1)
64 |
65 | optimizer.zero_grad()
66 | # this attribute is added by timm on one optimizer (adahessian)
67 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
68 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
69 | parameters=model.parameters(), create_graph=is_second_order)
70 | loss_scale_value = loss_scaler.state_dict()["scale"]
71 |
72 | torch.cuda.synchronize()
73 |
74 | metric_logger.update(loss=loss_value)
75 | metric_logger.update(loss_scale=loss_scale_value)
76 | min_lr = 10.
77 | max_lr = 0.
78 | for group in optimizer.param_groups:
79 | min_lr = min(min_lr, group["lr"])
80 | max_lr = max(max_lr, group["lr"])
81 |
82 | metric_logger.update(lr=max_lr)
83 | metric_logger.update(min_lr=min_lr)
84 | weight_decay_value = None
85 | for group in optimizer.param_groups:
86 | if group["weight_decay"] > 0:
87 | weight_decay_value = group["weight_decay"]
88 | metric_logger.update(weight_decay=weight_decay_value)
89 | metric_logger.update(grad_norm=grad_norm)
90 |
91 | if log_writer is not None:
92 | log_writer.update(loss=loss_value, head="loss")
93 | log_writer.update(loss_scale=loss_scale_value, head="opt")
94 | log_writer.update(lr=max_lr, head="opt")
95 | log_writer.update(min_lr=min_lr, head="opt")
96 | log_writer.update(weight_decay=weight_decay_value, head="opt")
97 | log_writer.update(grad_norm=grad_norm, head="opt")
98 | log_writer.set_step()
99 |
100 | if lr_scheduler is not None:
101 | lr_scheduler.step_update(start_steps + step)
102 | # gather the stats from all processes
103 | metric_logger.synchronize_between_processes()
104 | print("Averaged stats:", metric_logger)
105 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
106 |
--------------------------------------------------------------------------------
/FINETUNE.md:
--------------------------------------------------------------------------------
1 | # Fine-tuning VideoMAE
2 |
3 | ## Original Implementation
4 |
5 | The implementation of our VideoMAE supports **multi-node distributed training**. We provide the **off-the-shelf** scripts in the [scripts folder](scripts).
6 |
7 | - For example, to fine-tune VideoMAE ViT-Base on **Something-Something V2** with 64 GPUs (8 nodes x 8 GPUs), you can run
8 |
9 | ```bash
10 | OUTPUT_DIR='YOUR_PATH/ssv2_videomae_pretrain_base_patch16_224_frame_16x2_tube_mask_ratio_0.9_e800/eval_lr_5e-4_epoch_50'
11 | DATA_PATH='YOUR_PATH/list_ssv2'
12 | MODEL_PATH='YOUR_PATH/ssv2_videomae_pretrain_base_patch16_224_frame_16x2_tube_mask_ratio_0.9_e800/checkpoint-799.pth'
13 |
14 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
15 | --master_port 12320 --nnodes=8 \
16 | --node_rank=0 --master_addr=$ip_node_0 \
17 | run_class_finetuning.py \
18 | --model vit_base_patch16_224 \
19 | --data_set SSV2 \
20 | --nb_classes 174 \
21 | --data_path ${DATA_PATH} \
22 | --finetune ${MODEL_PATH} \
23 | --log_dir ${OUTPUT_DIR} \
24 | --output_dir ${OUTPUT_DIR} \
25 | --batch_size 8 \
26 | --num_sample 1 \
27 | --input_size 224 \
28 | --short_side_size 224 \
29 | --save_ckpt_freq 10 \
30 | --num_frames 16 \
31 | --opt adamw \
32 | --lr 5e-4 \
33 | --opt_betas 0.9 0.999 \
34 | --weight_decay 0.05 \
35 | --epochs 50 \
36 | --dist_eval \
37 | --test_num_segment 2 \
38 | --test_num_crop 3 \
39 | --enable_deepspeed
40 | ```
41 |
42 | on the first node. On the other nodes, run the same command with `--node_rank 1`, ..., `--node_rank 7`, respectively. `--master_addr` is set to the IP of node 0.
43 |
44 | - For example, to fine-tune VideoMAE ViT-Base on **Kinetics400** with 64 GPUs (8 nodes x 8 GPUs), you can run
45 |
46 | ```bash
47 | OUTPUT_DIR='YOUR_PATH/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e800/eval_lr_1e-3_epoch_100'
48 | DATA_PATH='YOUR_PATH/list_kinetics-400'
49 | MODEL_PATH='YOUR_PATH/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e800/checkpoint-799.pth'
50 |
51 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
52 | --master_port 12320 --nnodes=8 \
53 | --node_rank=0 --master_addr=$ip_node_0 \
54 | run_class_finetuning.py \
55 | --model vit_base_patch16_224 \
56 | --data_set Kinetics-400 \
57 | --nb_classes 400 \
58 | --data_path ${DATA_PATH} \
59 | --finetune ${MODEL_PATH} \
60 | --log_dir ${OUTPUT_DIR} \
61 | --output_dir ${OUTPUT_DIR} \
62 | --batch_size 8 \
63 | --num_sample 1 \
64 | --input_size 224 \
65 | --short_side_size 224 \
66 | --save_ckpt_freq 10 \
67 | --num_frames 16 \
68 | --sampling_rate 4 \
69 | --opt adamw \
70 | --lr 1e-3 \
71 | --opt_betas 0.9 0.999 \
72 | --weight_decay 0.05 \
73 | --epochs 100 \
74 | --dist_eval \
75 | --test_num_segment 5 \
76 | --test_num_crop 3 \
77 | --enable_deepspeed
78 | ```
79 |
80 | on the first node. On the other nodes, run the same command with `--node_rank 1`, ..., `--node_rank 7`, respectively. `--master_addr` is set to the IP of node 0.
81 |
82 | ### Note:
83 |
84 | - We perform **I3D dense sampling** on **Kinetics400** and **uniform sampling** on **Something-Something V2**.
85 | - We do not use a `cls token` in our implementation; instead, we directly average the features of the last layer for video classification (see the sketch after these notes).
86 | - Here total batch size = (`batch_size` per gpu) x `nodes` x (gpus per node).
87 | - `lr` here is the base learning rate. The actual learning rate is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `actual lr` = `lr` * total batch size / 256.
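A minimal sketch of the classification head implied by the second note above (illustrative shapes only; the actual model code lives in the repo's model definitions):

```python
import torch
import torch.nn as nn

tokens = torch.randn(8, 1568, 768)   # last-layer tokens: (batch, num. tokens, dim)
head = nn.Linear(768, 400)           # e.g. 400 Kinetics-400 classes

video_feature = tokens.mean(dim=1)   # average all tokens instead of using a cls token
logits = head(video_feature)         # (8, 400)
print(logits.shape)
```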
88 |
89 | ## Slurm
90 |
91 | To help the community reproduce our results on slurm clusters, we also provide the **off-the-shelf** script.
92 |
93 | For example, to fine-tune VideoMAE ViT-Base on **Kinetics400** with 64 GPUs (8 nodes x 8 GPUs), you can run:
94 |
95 | ```bash
96 | export MASTER_PORT=$((12000 + $RANDOM % 20000))
97 | export OMP_NUM_THREADS=1
98 |
99 | OUTPUT_DIR='YOUR_PATH/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e800/eval_lr_1e-3_epoch_100'
100 | DATA_PATH='YOUR_PATH/list_kinetics-400'
101 | MODEL_PATH='YOUR_PATH/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e800/checkpoint-799.pth'
102 |
103 | JOB_NAME=$1
104 | PARTITION=${PARTITION:-"video"}
105 | # 8 for 1 node, 16 for 2 node, etc.
106 | GPUS=${GPUS:-64}
107 | GPUS_PER_NODE=${GPUS_PER_NODE:-8}
108 | CPUS_PER_TASK=${CPUS_PER_TASK:-8}
109 | SRUN_ARGS=${SRUN_ARGS:-""}
110 | PY_ARGS=${@:2}
111 |
112 | # batch_size can be adjusted according to the graphics card
113 | srun -p $PARTITION \
114 | --job-name=${JOB_NAME} \
115 | --gres=gpu:${GPUS_PER_NODE} \
116 | --ntasks=${GPUS} \
117 | --ntasks-per-node=${GPUS_PER_NODE} \
118 | --cpus-per-task=${CPUS_PER_TASK} \
119 | --kill-on-bad-exit=1 \
120 | ${SRUN_ARGS} \
121 | python -u run_class_finetuning.py \
122 | --model vit_base_patch16_224 \
123 | --data_set Kinetics-400 \
124 | --nb_classes 400 \
125 | --data_path ${DATA_PATH} \
126 | --finetune ${MODEL_PATH} \
127 | --log_dir ${OUTPUT_DIR} \
128 | --output_dir ${OUTPUT_DIR} \
129 | --batch_size 8 \
130 | --num_sample 1 \
131 | --input_size 224 \
132 | --short_side_size 224 \
133 | --save_ckpt_freq 10 \
134 | --num_frames 16 \
135 | --sampling_rate 4 \
136 | --opt adamw \
137 | --lr 1e-3 \
138 | --opt_betas 0.9 0.999 \
139 | --weight_decay 0.05 \
140 | --epochs 100 \
141 | --dist_eval \
142 | --test_num_segment 5 \
143 | --test_num_crop 3 \
144 | --enable_deepspeed \
145 | ${PY_ARGS}
146 | ```
147 |
148 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # APT: Attention Prompt Tuning
2 | > A Parameter-Efficient Adaptation of Pre-Trained Models for Action Recognition ...
3 |
4 | > [Wele Gedara Chaminda Bandara](https://github.com/wgcban), [Vishal M Patel](https://engineering.jhu.edu/vpatel36/team/vishalpatel/), Johns Hopkins University
5 |
6 | > Accepted at [FG'24](https://fg2024.ieee-biometrics.org)
7 |
8 | > [Paper (on ArXiv)](https://arxiv.org/abs/2403.06978)
9 |
10 | ## Overview of Proposed Method
11 |
12 |
13 |
14 |
15 |
16 | Comparison of our Attention Prompt Tuning (APT) for video action classification with other existing tuning methods: linear probing, adapter tuning, visual prompt tuning (VPT), and full fine-tuning.
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | Unlike VPT, Attention Prompt Tuning (APT) injects learnable prompts directly into the multi-head attention (MHA); a minimal sketch is given below.
25 |
26 |
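The following is an illustrative sketch of the idea only (prefix-style learnable tokens that extend the keys and values of self-attention); it is not the repo's exact implementation, which lives in `run_class_apt.py` and the model files and also supports options such as `--prompt_reparam`:

```python
import torch
import torch.nn as nn

class AttentionWithPrompts(nn.Module):
    """Illustrative only: learnable prompt tokens prepended to the key/value
    sequence of multi-head attention, so the patch-token queries can attend
    to them without changing the output sequence length."""

    def __init__(self, dim=768, num_heads=12, num_prompts=2):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.prompts = nn.Parameter(torch.zeros(1, num_prompts, dim))
        nn.init.trunc_normal_(self.prompts, std=0.02)

    def forward(self, x):                       # x: (B, N, C) patch tokens
        p = self.prompts.expand(x.size(0), -1, -1)
        kv = torch.cat([p, x], dim=1)           # prompts extend keys/values only
        out, _ = self.attn(x, kv, kv, need_weights=False)
        return out                              # still (B, N, C)

x = torch.randn(2, 196, 768)
print(AttentionWithPrompts()(x).shape)          # torch.Size([2, 196, 768])
```

In such a setting only the prompt parameters (and the classifier head) would be trained, which is what makes the approach parameter-efficient.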
27 | ## Getting Started
28 |
29 | ### Step 1: Conda Environment
30 |
31 | Setup the virtual conda environment using the `environment.yml`:
32 | ```
33 | conda env create -f environment.yml
34 | ```
35 |
36 | Then activate the conda environment:
37 | ```
38 | conda activate apt
39 | ```
40 |
41 | ### Step 2: Download the VideoMAE Pre-trained Models
42 |
43 | We use [VideoMAE](https://github.com/MCG-NJU/VideoMAE) models pre-trained on the [Kinetics-400](https://github.com/cvdfoundation/kinetics-dataset) dataset for our experiments.
44 |
45 | The pre-trained models for the ViT-Small and ViT-Base backbones can be downloaded from the links below:
46 |
47 | | Method | Extra Data | Backbone | Epoch | \#Frame | Pre-train |
48 | | :------: | :--------: | :------: | :---: | :-----: | :----------------------------------------------------------: |
49 | | VideoMAE | ***no*** | ViT-S | 1600 | 16x5x3 | [checkpoint](https://drive.google.com/file/d/1nU-H1u3eJ-VuyCveU7v-WIOcAVxs5Hww/view?usp=sharing) |
50 | | VideoMAE | ***no*** | ViT-B | 1600 | 16x5x3 | [checkpoint](https://drive.google.com/file/d/1tEhLyskjb755TJ65ptsrafUG2llSwQE1/view?usp=sharing) |
51 |
52 | If you need other pre-trained models, please refer to [MODEL_ZOO.md](https://github.com/wgcban/apt/blob/main/MODEL_ZOO.md).
53 |
54 | ### Step 3: Download the datasets
55 |
56 | We conduct experiments on three action recognition datasets: 1) UCF101, 2) HMDB51, and 3) Something-Something-V2.
57 |
58 | Please refer to [DATASET.md](https://github.com/wgcban/apt/blob/main/DATASET.md) for download links and pre-processing steps.
59 |
60 | ### Step 4: Attention Prompt Tuning
61 |
62 | We provide example scripts to run attention prompt tuning on the UCF101, HMDB51, and SSv2 datasets in the `scripts/` folder.
63 |
64 | Inside `scripts/` you can find two folders, which correspond to APT fine-tuning with the ViT-Small and ViT-Base architectures.
65 |
66 | To fine-tune with APT you just need to execute the `finetune.sh` file, which will launch the job with distributed training via `torch.distributed.launch`.
67 |
68 |
69 | For example, to fine-tune ViT-Base on SSv2 with APT, you may run:
70 | ```
71 | sh scripts/ssv2/vit_base/finetune.sh
72 | ```
73 |
74 | The `finetune.sh` looks like this:
75 |
76 | ```bash
77 | # APT on SSv2
78 | OUTPUT_DIR='experiments/APT/SSV2/ssv2_videomae_pretrain_base_patch16_224_frame_16x2_tube_mask_ratio_0.9_e2400/adam_mome9e-1_wd1e-5_lr5se-2_pl2_ps0_pe11_drop10'
79 | DATA_PATH='datasets/ss2/list_ssv2/'
80 | MODEL_PATH='experiments/pretrain/ssv2_videomae_pretrain_base_patch16_224_frame_16x2_tube_mask_ratio_0.9_e2400/checkpoint.pth'
81 |
82 | NCCL_P2P_DISABLE=1 OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1,3,4,5,6,7,8 python -m torch.distributed.launch --nproc_per_node=8 \
83 | run_class_apt.py \
84 | --model vit_base_patch16_224 \
85 | --transfer_type prompt \
86 | --prompt_start 0 \
87 | --prompt_end 11 \
88 | --prompt_num_tokens 2 \
89 | --prompt_dropout 0.1 \
90 | --data_set SSV2 \
91 | --nb_classes 174 \
92 | --data_path ${DATA_PATH} \
93 | --finetune ${MODEL_PATH} \
94 | --log_dir ${OUTPUT_DIR} \
95 | --output_dir ${OUTPUT_DIR} \
96 | --batch_size 8 \
97 | --batch_size_val 8 \
98 | --num_sample 2 \
99 | --input_size 224 \
100 | --short_side_size 224 \
101 | --save_ckpt_freq 10 \
102 | --num_frames 16 \
103 | --opt adamw \
104 | --lr 0.05 \
105 | --weight_decay 0.00001 \
106 | --epochs 100 \
107 | --warmup_epochs 10 \
108 | --test_num_segment 2 \
109 | --test_num_crop 3 \
110 | --dist_eval \
111 | --pin_mem \
112 | --enable_deepspeed \
113 | --prompt_reparam \
114 | --is_aa \
115 | --aa rand-m4-n2-mstd0.2-inc1
116 |
117 | ```
118 |
119 | Here,
120 |
121 | - `OUTPUT_DIR`: place where you want to save the results (i.e., logs and checkpoints)
122 | - `DATA_PATH`: path to where the dataset is stored
123 | - `MODEL_PATH`: path to the downloaded videomae pre-trained model
124 | - specify which GPUs (GPU ids) you want to use for fine-tuning in `CUDA_VISIBLE_DEVICES=`...
125 | - `nproc_per_node` is the number of GPUs used for fine-tuning
126 | - `model` is either ViT-Base (`vit_base_patch16_224`) or ViT-Small (`vit_small_patch16_224`)
127 | - `transfer_type` specifies which fine-tuning method to use: 'random' means random initialization, 'end2end' means full end-to-end fine-tuning, 'prompt' means APT (ours), and 'linear' means linear probing
128 | - `prompt_start` refers to the first transformer block where attention prompts are added. 0 means learnable prompts are added starting from the 1st transformer block of the ViT
129 | - `prompt_end` refers to the last transformer block where attention prompts are added. ViT-Base / ViT-Small has 12 transformer blocks, so 11 here means prompts are added up to the last transformer block (see the small sketch after this list)
130 | - `data_set` specifies the dataset
131 | - all the other parameters are hyperparameters related to APT fine-tuning
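A small sketch (a hypothetical helper, not part of the repo) of how `prompt_start` / `prompt_end` select the prompted blocks:

```python
def prompted_block_ids(prompt_start, prompt_end, depth=12):
    # depth=12 for vit_base_patch16_224 / vit_small_patch16_224
    return [i for i in range(depth) if prompt_start <= i <= prompt_end]

print(prompted_block_ids(0, 11))   # all 12 blocks receive attention prompts
```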
132 |
133 |
134 | ## ✏️ Citation
135 |
136 | If you find this project helpful, please feel free to leave a star and cite our paper:
137 |
138 | ```bibtex
139 | @misc{bandara2024attention,
140 | title={Attention Prompt Tuning: Parameter-efficient Adaptation of Pre-trained Models for Spatiotemporal Modeling},
141 | author={Wele Gedara Chaminda Bandara and Vishal M. Patel},
142 | year={2024},
143 | eprint={2403.06978},
144 | archivePrefix={arXiv},
145 | primaryClass={cs.CV}
146 | }
147 | ```
148 |
149 |
150 | ## ✏️ Disclaimer
151 |
152 | This repository is built on top of the VideoMAE codebase (https://github.com/MCG-NJU/VideoMAE), and we appreciate the authors of VideoMAE for making their codebase publicly available.
153 |
--------------------------------------------------------------------------------
/optim_factory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import optim as optim
3 |
4 | from timm.optim.adafactor import Adafactor
5 | from timm.optim.adahessian import Adahessian
6 | from timm.optim.adamp import AdamP
7 | from timm.optim.lookahead import Lookahead
8 | from timm.optim.nadam import Nadam
9 | from timm.optim.novograd import NovoGrad
10 | from timm.optim.nvnovograd import NvNovoGrad
11 | from timm.optim.radam import RAdam
12 | from timm.optim.rmsprop_tf import RMSpropTF
13 | from timm.optim.sgdp import SGDP
14 |
15 | import json
16 |
17 | try:
18 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
19 | has_apex = True
20 | except ImportError:
21 | has_apex = False
22 |
23 |
24 | def get_num_layer_for_vit(var_name, num_max_layer):
25 | if var_name in ("cls_token", "mask_token", "pos_embed"):
26 | return 0
27 | elif var_name.startswith("patch_embed"):
28 | return 0
29 | elif var_name.startswith("rel_pos_bias"):
30 | return num_max_layer - 1
31 | elif var_name.startswith("blocks"):
32 | layer_id = int(var_name.split('.')[1])
33 | return layer_id + 1
34 | else:
35 | return num_max_layer - 1
36 |
37 |
38 | class LayerDecayValueAssigner(object):
39 | def __init__(self, values):
40 | self.values = values
41 |
42 | def get_scale(self, layer_id):
43 | return self.values[layer_id]
44 |
45 | def get_layer_id(self, var_name):
46 | return get_num_layer_for_vit(var_name, len(self.values))
47 |
48 |
49 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
50 | parameter_group_names = {}
51 | parameter_group_vars = {}
52 |
53 | for name, param in model.named_parameters():
54 | if not param.requires_grad:
55 | continue # frozen weights
56 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
57 | group_name = "no_decay"
58 | this_weight_decay = 0.
59 | else:
60 | group_name = "decay"
61 | this_weight_decay = weight_decay
62 | if get_num_layer is not None:
63 | layer_id = get_num_layer(name)
64 | group_name = "layer_%d_%s" % (layer_id, group_name)
65 | else:
66 | layer_id = None
67 |
68 | if group_name not in parameter_group_names:
69 | if get_layer_scale is not None:
70 | scale = get_layer_scale(layer_id)
71 | else:
72 | scale = 1.
73 |
74 | parameter_group_names[group_name] = {
75 | "weight_decay": this_weight_decay,
76 | "params": [],
77 | "lr_scale": scale
78 | }
79 | parameter_group_vars[group_name] = {
80 | "weight_decay": this_weight_decay,
81 | "params": [],
82 | "lr_scale": scale
83 | }
84 |
85 | parameter_group_vars[group_name]["params"].append(param)
86 | parameter_group_names[group_name]["params"].append(name)
87 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
88 | return list(parameter_group_vars.values())
89 |
90 |
91 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
92 | opt_lower = args.opt.lower()
93 | weight_decay = args.weight_decay
94 | if weight_decay and filter_bias_and_bn:
95 | skip = {}
96 | if skip_list is not None:
97 | skip = skip_list
98 | elif hasattr(model, 'no_weight_decay'):
99 | skip = model.no_weight_decay()
100 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
101 | weight_decay = 0.
102 | else:
103 | parameters = model.parameters()
104 |
105 | if 'fused' in opt_lower:
106 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
107 |
108 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
109 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
110 | opt_args['eps'] = args.opt_eps
111 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
112 | opt_args['betas'] = args.opt_betas
113 |
114 | print("optimizer settings:", opt_args)
115 |
116 | opt_split = opt_lower.split('_')
117 | opt_lower = opt_split[-1]
118 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
119 | opt_args.pop('eps', None)
120 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
121 | elif opt_lower == 'momentum':
122 | opt_args.pop('eps', None)
123 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
124 | elif opt_lower == 'adam':
125 | optimizer = optim.Adam(parameters, **opt_args)
126 | elif opt_lower == 'adamw':
127 | optimizer = optim.AdamW(parameters, **opt_args)
128 | elif opt_lower == 'nadam':
129 | optimizer = Nadam(parameters, **opt_args)
130 | elif opt_lower == 'radam':
131 | optimizer = RAdam(parameters, **opt_args)
132 | elif opt_lower == 'adamp':
133 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
134 | elif opt_lower == 'sgdp':
135 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
136 | elif opt_lower == 'adadelta':
137 | optimizer = optim.Adadelta(parameters, **opt_args)
138 | elif opt_lower == 'adafactor':
139 | if not args.lr:
140 | opt_args['lr'] = None
141 | optimizer = Adafactor(parameters, **opt_args)
142 | elif opt_lower == 'adahessian':
143 | optimizer = Adahessian(parameters, **opt_args)
144 | elif opt_lower == 'rmsprop':
145 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
146 | elif opt_lower == 'rmsproptf':
147 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
148 | elif opt_lower == 'novograd':
149 | optimizer = NovoGrad(parameters, **opt_args)
150 | elif opt_lower == 'nvnovograd':
151 | optimizer = NvNovoGrad(parameters, **opt_args)
152 | elif opt_lower == 'fusedsgd':
153 | opt_args.pop('eps', None)
154 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
155 | elif opt_lower == 'fusedmomentum':
156 | opt_args.pop('eps', None)
157 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
158 | elif opt_lower == 'fusedadam':
159 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
160 | elif opt_lower == 'fusedadamw':
161 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
162 | elif opt_lower == 'fusedlamb':
163 | optimizer = FusedLAMB(parameters, **opt_args)
164 | elif opt_lower == 'fusednovograd':
165 | opt_args.setdefault('betas', (0.95, 0.98))
166 | optimizer = FusedNovoGrad(parameters, **opt_args)
167 | else:
168 |         assert False, "Invalid optimizer"
169 |         raise ValueError("Invalid optimizer: %s" % opt_lower)
170 |
171 | if len(opt_split) > 1:
172 | if opt_split[0] == 'lookahead':
173 | optimizer = Lookahead(optimizer)
174 |
175 | return optimizer
176 |
--------------------------------------------------------------------------------
/random_erasing.py:
--------------------------------------------------------------------------------
1 | """
2 | This implementation is based on
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
4 | published under an Apache License 2.0.
5 | """
6 | import math
7 | import random
8 | import torch
9 |
10 |
11 | def _get_pixels(
12 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"
13 | ):
14 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
15 | # paths, flip the order so normal is run on CPU if this becomes a problem
16 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
17 | if per_pixel:
18 | return torch.empty(patch_size, dtype=dtype, device=device).normal_()
19 | elif rand_color:
20 | return torch.empty(
21 | (patch_size[0], 1, 1), dtype=dtype, device=device
22 | ).normal_()
23 | else:
24 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
25 |
26 |
27 | class RandomErasing:
28 | """Randomly selects a rectangle region in an image and erases its pixels.
29 | 'Random Erasing Data Augmentation' by Zhong et al.
30 | See https://arxiv.org/pdf/1708.04896.pdf
31 | This variant of RandomErasing is intended to be applied to either a batch
32 | or single image tensor after it has been normalized by dataset mean and std.
33 | Args:
34 | probability: Probability that the Random Erasing operation will be performed.
35 | min_area: Minimum percentage of erased area wrt input image area.
36 | max_area: Maximum percentage of erased area wrt input image area.
37 | min_aspect: Minimum aspect ratio of erased area.
38 | mode: pixel color mode, one of 'const', 'rand', or 'pixel'
39 | 'const' - erase block is constant color of 0 for all channels
40 | 'rand' - erase block is same per-channel random (normal) color
41 | 'pixel' - erase block is per-pixel random (normal) color
42 | max_count: maximum number of erasing blocks per image, area per box is scaled by count.
43 | per-image count is randomly chosen between 1 and this value.
44 | """
45 |
46 | def __init__(
47 | self,
48 | probability=0.5,
49 | min_area=0.02,
50 | max_area=1 / 3,
51 | min_aspect=0.3,
52 | max_aspect=None,
53 | mode="const",
54 | min_count=1,
55 | max_count=None,
56 | num_splits=0,
57 | device="cuda",
58 | cube=True,
59 | ):
60 | self.probability = probability
61 | self.min_area = min_area
62 | self.max_area = max_area
63 | max_aspect = max_aspect or 1 / min_aspect
64 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
65 | self.min_count = min_count
66 | self.max_count = max_count or min_count
67 | self.num_splits = num_splits
68 | mode = mode.lower()
69 | self.rand_color = False
70 | self.per_pixel = False
71 | self.cube = cube
72 | if mode == "rand":
73 | self.rand_color = True # per block random normal
74 | elif mode == "pixel":
75 | self.per_pixel = True # per pixel random normal
76 | else:
77 | assert not mode or mode == "const"
78 | self.device = device
79 |
80 | def _erase(self, img, chan, img_h, img_w, dtype):
81 | if random.random() > self.probability:
82 | return
83 | area = img_h * img_w
84 | count = (
85 | self.min_count
86 | if self.min_count == self.max_count
87 | else random.randint(self.min_count, self.max_count)
88 | )
89 | for _ in range(count):
90 | for _ in range(10):
91 | target_area = (
92 | random.uniform(self.min_area, self.max_area) * area / count
93 | )
94 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
95 | h = int(round(math.sqrt(target_area * aspect_ratio)))
96 | w = int(round(math.sqrt(target_area / aspect_ratio)))
97 | if w < img_w and h < img_h:
98 | top = random.randint(0, img_h - h)
99 | left = random.randint(0, img_w - w)
100 | img[:, top : top + h, left : left + w] = _get_pixels(
101 | self.per_pixel,
102 | self.rand_color,
103 | (chan, h, w),
104 | dtype=dtype,
105 | device=self.device,
106 | )
107 | break
108 |
109 | def _erase_cube(
110 | self,
111 | img,
112 | batch_start,
113 | batch_size,
114 | chan,
115 | img_h,
116 | img_w,
117 | dtype,
118 | ):
119 | if random.random() > self.probability:
120 | return
121 | area = img_h * img_w
122 | count = (
123 | self.min_count
124 | if self.min_count == self.max_count
125 | else random.randint(self.min_count, self.max_count)
126 | )
127 | for _ in range(count):
128 | for _ in range(100):
129 | target_area = (
130 | random.uniform(self.min_area, self.max_area) * area / count
131 | )
132 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
133 | h = int(round(math.sqrt(target_area * aspect_ratio)))
134 | w = int(round(math.sqrt(target_area / aspect_ratio)))
135 | if w < img_w and h < img_h:
136 | top = random.randint(0, img_h - h)
137 | left = random.randint(0, img_w - w)
138 | for i in range(batch_start, batch_size):
139 | img_instance = img[i]
140 | img_instance[
141 | :, top : top + h, left : left + w
142 | ] = _get_pixels(
143 | self.per_pixel,
144 | self.rand_color,
145 | (chan, h, w),
146 | dtype=dtype,
147 | device=self.device,
148 | )
149 | break
150 |
151 | def __call__(self, input):
152 | if len(input.size()) == 3:
153 | self._erase(input, *input.size(), input.dtype)
154 | else:
155 | batch_size, chan, img_h, img_w = input.size()
156 | # skip first slice of batch if num_splits is set (for clean portion of samples)
157 | batch_start = (
158 | batch_size // self.num_splits if self.num_splits > 1 else 0
159 | )
160 | if self.cube:
161 | self._erase_cube(
162 | input,
163 | batch_start,
164 | batch_size,
165 | chan,
166 | img_h,
167 | img_w,
168 | input.dtype,
169 | )
170 | else:
171 | for i in range(batch_start, batch_size):
172 | self._erase(input[i], chan, img_h, img_w, input.dtype)
173 | return input
174 |
--------------------------------------------------------------------------------
/MODEL_ZOO.md:
--------------------------------------------------------------------------------
1 | # Pre-trained VideoMAE Models
2 |
3 | For all APT experiments, we use VideoMAE ViT models pre-trained on Kinetics-400.
4 |
5 | The following tables provide the different checkpoints.
6 |
7 | Note that we use the pre-trained checkpoints, not the fine-tuned ones.
8 |
9 | ### Kinetics-400
10 |
11 | | Method | Extra Data | Backbone | Epoch | \#Frame | Pre-train | Fine-tune | Top-1 | Top-5 |
12 | | :------: | :--------: | :------: | :---: | :-----: | :----------------------------------------------------------: | :----------------------------------------------------------: | :---: | :---: |
13 | | VideoMAE | ***no*** | ViT-S | 1600 | 16x5x3 | [script](scripts/kinetics/videomae_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_1600/pretrain.sh)/[log](https://drive.google.com/file/d/1fbmQtp3UUw9fro3MVkKCW62Ib_HlZvNz/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1nU-H1u3eJ-VuyCveU7v-WIOcAVxs5Hww/view?usp=sharing) | [script](scripts/kinetics/videomae_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_1600/finetune.sh)/[log](https://drive.google.com/file/d/1RuEvCT2OMKPax2gGB1gBsH6ItiXIPH-R/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1ygjLRm1kvs9mwGsP3lLxUExhRo6TWnrx/view?usp=sharing) | 79.0 | 93.8 |
14 | | VideoMAE | ***no*** | ViT-B | 800 | 16x5x3 | [script](scripts/kinetics/videomae_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_800/pretrain.sh)/[log](https://drive.google.com/file/d/1kP3_-465jCL7PRNFq1JcAghPo2BONRWY/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1JfrhN144Hdg7we213H1WxwR3lGYOlmIn/view?usp=sharing) | [script](scripts/kinetics/videomae_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_800/finetune.sh)/[log](https://drive.google.com/file/d/1JOJzhlCujgpsjjth0J49k5EwBNxy76xt/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/18EEgdXY9347yK3Yb28O-GxFMbk41F6Ne/view?usp=sharing) (w/o repeated aug) | 80.0 | 94.4 |
15 | | VideoMAE | ***no*** | ViT-B | 800 | 16x5x3 | same as above | TODO | 81.0 | 94.8 |
16 | | VideoMAE | ***no*** | ViT-B | 1600 | 16x5x3 | [script](scripts/kinetics/videomae_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_1600/pretrain.sh)/[log](https://drive.google.com/file/d/1ftVHzzCupEGV4bCHC5JWIUsEwOEeAQcg/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1tEhLyskjb755TJ65ptsrafUG2llSwQE1/view?usp=sharing) | [script](scripts/kinetics/videomae_vit_large_patch16_224_tubemasking_ratio_0.9_epoch_1600/finetune.sh)/[log](https://drive.google.com/file/d/1fYXtL2y2ZTMxDtTRqoUOe6leVmdVI5HH/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1MzwteHH-1yuMnFb8vRBQDvngV1Zl-d3z/view?usp=sharing) | 81.5 | 95.1 |
17 | | VideoMAE | ***no*** | ViT-L | 1600 | 16x5x3 | [script](scripts/kinetics/videomae_vit_large_patch16_224_tubemasking_ratio_0.9_epoch_1600/pretrain.sh)/[log](https://drive.google.com/file/d/1X7WBzn_yG4lDWuvBMBBgrtgqDLZVHrc2/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1qLOXWb_MGEvaI7tvuAe94CV7S2HXRwT3/view?usp=sharing) | [script](scripts/kinetics/videomae_vit_large_patch16_224_tubemasking_ratio_0.9_epoch_1600/finetune.sh)/[log](https://drive.google.com/file/d/1Doqx6zDQEMnMyPvDdz2knG385o0sZn3f/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1jX1CiqxSkCfc94y8FRW1YGHy-GNvHCuD/view?usp=sharing) | 85.2 | 96.8 |
18 | | VideoMAE | ***no*** | ViT-H | 1600 | 16x5x3 | [script](scripts/kinetics/videomae_vit_huge_patch16_224_tubemasking_ratio_0.9_epoch_1600/pretrain.sh)/[log](https://drive.google.com/file/d/1ZGOGk5_L7cqJ2UkrNQ7c_jcw1OUBqptl/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1AJQR1Rsi2N1pDn9tLyJ8DQrUREiBA1bO/view?usp=sharing) | [script](scripts/kinetics/videomae_vit_huge_patch16_224_tubemasking_ratio_0.9_epoch_1600/finetune.sh)/[log](https://drive.google.com/file/d/1NOUjO5wPrHZo4EUfklKvfGM3ScJVmGAK/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/104ouJZxSVPSAm0LwJXd6IzjdA_RGLqZi/view?usp=sharing) | 86.6 | 97.1 |
19 |
20 | ### Something-Something V2
21 |
22 | | Method | Extra Data | Backbone | Epoch | \#Frame | Pre-train | Fine-tune | Top-1 | Top-5 |
23 | | :------: | :--------: | :------: | :---: | :-----: | :----------------------------------------------------------: | :----------------------------------------------------------: | :---: | :---: |
24 | | VideoMAE | ***no*** | ViT-S | 2400 | 16x2x3 | [script](scripts/ssv2/videomae_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_2400/pretrain.sh)/[log](https://drive.google.com/file/d/129wqpAtwTCD-T1SQIX7q5nB9CEGchhw0/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1p_I1aaONOeUvRmRQw1UT3-L2H8XJClHu/view?usp=sharing) | [script](scripts/ssv2/videomae_vit_small_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune.sh)/[log](https://drive.google.com/file/d/17X9PcDSBB1Zb1blNqQP3vvnqOuMzJrGp/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1ajlMrT06jiiM-5YjNI2X_UFyzsuYbbtZ/view?usp=sharing) | 66.8 | 90.3 |
25 | | VideoMAE | ***no*** | ViT-B | 800 | 16x2x3 | [script](scripts/ssv2/videomae_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_800/pretrain.sh)/[log](https://drive.google.com/file/d/1eGS18rKvbgEJ3nbsXxokkMSwNGxxoX48/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/181hLvyrrPW2IOGA46fkxdJk0tNLIgdB2/view?usp=sharing) | [script](scripts/ssv2/videomae_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_800/finetune.sh)/[log](https://drive.google.com/file/d/1jYAHPcs7zt_QMPM2D_geEWoWrf3yHox8/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1xZCiaPF4w7lYmLt5o1D5tIZyDdLtJAvH/view?usp=sharing) (w/o repeated aug) | 69.6 | 92.0 |
26 | | VideoMAE | ***no*** | ViT-B | 2400 | 16x2x3 | [script](scripts/ssv2/videomae_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_2400/pretrain.sh)/[log](https://drive.google.com/file/d/148nURgfcIFBQd3IQH5YhJ9dTwNCc2jkU/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1I18dY_7rSalGL8fPWV82c0-foRUDzJJk/view?usp=sharing) | [script](scripts/ssv2/videomae_vit_base_patch16_224_tubemasking_ratio_0.9_epoch_2400/finetune.sh)/[log](https://drive.google.com/file/d/15TPBiUl_K2Q_9l6J41G_vf-2lovVLEHM/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1dt_59tBIyzdZd5Ecr22lTtzs_64MOZkT/view?usp=sharing) | 70.8 | 92.4 |
27 |
28 | ### UCF101
29 |
30 | | Method | Extra Data | Backbone | Epoch | \#Frame | Pre-train | Fine-tune | Top-1 | Top-5 |
31 | | :------: | :--------: | :------: | :---: | :-----: | :----------------------------------------------------------: | :----------------------------------------------------------: | :---: | :---: |
32 | | VideoMAE | ***no*** | ViT-B | 3200 | 16x5x3 | [script](scripts/ucf101/videomae_vit_base_patch16_224_tubemasking_ratio_0.75_epoch_3200/pretrain.sh)/[log](https://drive.google.com/file/d/1kZODk_dQgB-aW6oIwPYZxqZAG6YKNtXC/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1BHev4meNgKM0o_8DMRbuzAsKSP3IpQ3o/view?usp=sharing) | [script](scripts/ucf101/videomae_vit_base_patch16_224_tubemasking_ratio_0.75_epoch_3200/finetune.sh)/[log](https://drive.google.com/file/d/17Mq7rlM1TRgV4KKX7UIlmKw653RmwSqe/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1MSyon6fPpKz7oqD6WDGPFK4k_Rbyb6fw/view?usp=sharing) | 91.3 | 98.5 |
33 |
34 | ### Note:
35 |
36 | - We report the results of VideoMAE fine-tuned with `I3D dense sampling` on **Kinetics400** and `TSN uniform sampling` on **Something-Something V2**, respectively.
37 | - \#Frame = #input_frame x #clip x #crop.
38 | - \#input_frame means how many frames are fed to the model during the test phase.
39 | - \#crop means spatial crops (e.g., 3 for left/right/center crop).
40 | - \#clip means temporal clips (e.g., 5 means repeated temporal sampling of five clips with different start indices).
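- For example, `16x5x3` means 16 input frames per clip, 5 temporal clips and 3 spatial crops, i.e. 5 x 3 = 15 views of 16 frames each are evaluated per video at test time.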
41 |
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms.functional as F
3 | import warnings
4 | import random
5 | import numpy as np
6 | import torchvision
7 | from PIL import Image, ImageOps
8 | import numbers
9 |
10 |
11 | class GroupRandomCrop(object):
12 | def __init__(self, size):
13 | if isinstance(size, numbers.Number):
14 | self.size = (int(size), int(size))
15 | else:
16 | self.size = size
17 |
18 | def __call__(self, img_tuple):
19 | img_group, label = img_tuple
20 |
21 | w, h = img_group[0].size
22 | th, tw = self.size
23 |
24 | out_images = list()
25 |
26 | x1 = random.randint(0, w - tw)
27 | y1 = random.randint(0, h - th)
28 |
29 | for img in img_group:
30 | assert(img.size[0] == w and img.size[1] == h)
31 | if w == tw and h == th:
32 | out_images.append(img)
33 | else:
34 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
35 |
36 | return (out_images, label)
37 |
38 |
39 | class GroupCenterCrop(object):
40 | def __init__(self, size):
41 | self.worker = torchvision.transforms.CenterCrop(size)
42 |
43 | def __call__(self, img_tuple):
44 | img_group, label = img_tuple
45 | return ([self.worker(img) for img in img_group], label)
46 |
47 |
48 | class GroupNormalize(object):
49 | def __init__(self, mean, std):
50 | self.mean = mean
51 | self.std = std
52 |
53 | def __call__(self, tensor_tuple):
54 | tensor, label = tensor_tuple
55 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean))
56 | rep_std = self.std * (tensor.size()[0]//len(self.std))
57 |
58 | # TODO: make efficient
59 | for t, m, s in zip(tensor, rep_mean, rep_std):
60 | t.sub_(m).div_(s)
61 |
62 |         return (tensor, label)
63 |
64 |
65 | class GroupGrayScale(object):
66 | def __init__(self, size):
67 | self.worker = torchvision.transforms.Grayscale(size)
68 |
69 | def __call__(self, img_tuple):
70 | img_group, label = img_tuple
71 | return ([self.worker(img) for img in img_group], label)
72 |
73 |
74 | class GroupScale(object):
75 | """ Rescales the input PIL.Image to the given 'size'.
76 | 'size' will be the size of the smaller edge.
77 | For example, if height > width, then image will be
78 | rescaled to (size * height / width, size)
79 | size: size of the smaller edge
80 | interpolation: Default: PIL.Image.BILINEAR
81 | """
82 |
83 | def __init__(self, size, interpolation=Image.BILINEAR):
84 | self.worker = torchvision.transforms.Resize(size, interpolation)
85 |
86 | def __call__(self, img_tuple):
87 | img_group, label = img_tuple
88 | return ([self.worker(img) for img in img_group], label)
89 |
90 |
91 | class GroupMultiScaleCrop(object):
92 |
93 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
94 | self.scales = scales if scales is not None else [1, .875, .75, .66]
95 | self.max_distort = max_distort
96 | self.fix_crop = fix_crop
97 | self.more_fix_crop = more_fix_crop
98 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
99 | self.interpolation = Image.BILINEAR
100 |
101 | def __call__(self, img_tuple):
102 | img_group, label = img_tuple
103 |
104 | im_size = img_group[0].size
105 |
106 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
107 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
108 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group]
109 | return (ret_img_group, label)
110 |
111 | def _sample_crop_size(self, im_size):
112 | image_w, image_h = im_size[0], im_size[1]
113 |
114 | # find a crop size
115 | base_size = min(image_w, image_h)
116 | crop_sizes = [int(base_size * x) for x in self.scales]
117 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
118 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
119 |
120 | pairs = []
121 | for i, h in enumerate(crop_h):
122 | for j, w in enumerate(crop_w):
123 | if abs(i - j) <= self.max_distort:
124 | pairs.append((w, h))
125 |
126 | crop_pair = random.choice(pairs)
127 | if not self.fix_crop:
128 | w_offset = random.randint(0, image_w - crop_pair[0])
129 | h_offset = random.randint(0, image_h - crop_pair[1])
130 | else:
131 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
132 |
133 | return crop_pair[0], crop_pair[1], w_offset, h_offset
134 |
135 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
136 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
137 | return random.choice(offsets)
138 |
139 | @staticmethod
140 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
141 | w_step = (image_w - crop_w) // 4
142 | h_step = (image_h - crop_h) // 4
143 |
144 | ret = list()
145 | ret.append((0, 0)) # upper left
146 | ret.append((4 * w_step, 0)) # upper right
147 | ret.append((0, 4 * h_step)) # lower left
148 | ret.append((4 * w_step, 4 * h_step)) # lower right
149 | ret.append((2 * w_step, 2 * h_step)) # center
150 |
151 | if more_fix_crop:
152 | ret.append((0, 2 * h_step)) # center left
153 | ret.append((4 * w_step, 2 * h_step)) # center right
154 | ret.append((2 * w_step, 4 * h_step)) # lower center
155 | ret.append((2 * w_step, 0 * h_step)) # upper center
156 |
157 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter
158 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter
159 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter
160 |             ret.append((3 * w_step, 3 * h_step))  # lower right quarter
161 | return ret
162 |
163 |
164 | class Stack(object):
165 |
166 | def __init__(self, roll=False):
167 | self.roll = roll
168 |
169 | def __call__(self, img_tuple):
170 | img_group, label = img_tuple
171 |
172 | if img_group[0].mode == 'L':
173 | return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label)
174 | elif img_group[0].mode == 'RGB':
175 | if self.roll:
176 | return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label)
177 | else:
178 | return (np.concatenate(img_group, axis=2), label)
179 |
180 |
181 | class ToTorchFormatTensor(object):
182 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
183 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
184 | def __init__(self, div=True):
185 | self.div = div
186 |
187 | def __call__(self, pic_tuple):
188 | pic, label = pic_tuple
189 |
190 | if isinstance(pic, np.ndarray):
191 | # handle numpy array
192 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
193 | else:
194 | # handle PIL Image
195 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
196 | img = img.view(pic.size[1], pic.size[0], len(pic.mode))
197 | # put it from HWC to CHW format
198 | # yikes, this transpose takes 80% of the loading time/CPU
199 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
200 | return (img.float().div(255.) if self.div else img.float(), label)
201 |
202 |
203 | class IdentityTransform(object):
204 |
205 | def __call__(self, data):
206 | return data
207 |
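
The crop-pair enumeration in `_sample_crop_size` above is the part that is easiest to misread: crop widths and heights are only paired when their scale indices differ by at most `max_distort`. The standalone sketch below reproduces that enumeration outside the class for a hypothetical 320x256 frame, assuming the scale list used in `datasets.py` (`[1, .875, .75, .66]`), a 224x224 target size, and `max_distort=1` (the constructor is not shown in this excerpt, so these values are illustrative).

```python
# Illustrative re-implementation of the pair enumeration in
# GroupMultiScaleCrop._sample_crop_size; the values below are assumptions.
base_w, base_h = 320, 256            # hypothetical frame size
input_size = (224, 224)
scales = [1, .875, .75, .66]
max_distort = 1

base_size = min(base_w, base_h)
crop_sizes = [int(base_size * s) for s in scales]
# snap sizes within 3 px of the target input size to the target itself
crop_h = [input_size[1] if abs(x - input_size[1]) < 3 else x for x in crop_sizes]
crop_w = [input_size[0] if abs(x - input_size[0]) < 3 else x for x in crop_sizes]

pairs = [(w, h)
         for i, h in enumerate(crop_h)
         for j, w in enumerate(crop_w)
         if abs(i - j) <= max_distort]
print(pairs)  # candidate (crop_w, crop_h) pairs the transform samples from
```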
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torchvision import transforms
3 | from transforms import *
4 | from masking_generator import TubeMaskingGenerator
5 | from kinetics import VideoClsDataset, VideoMAE
6 | from ssv2 import SSVideoClsDataset
7 |
8 |
9 | class DataAugmentationForVideoMAE(object):
10 | def __init__(self, args):
11 | self.input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN
12 | self.input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD
13 | normalize = GroupNormalize(self.input_mean, self.input_std)
14 | self.train_augmentation = GroupMultiScaleCrop(args.input_size, [1, .875, .75, .66])
15 | self.transform = transforms.Compose([
16 | self.train_augmentation,
17 | Stack(roll=False),
18 | ToTorchFormatTensor(div=True),
19 | normalize,
20 | ])
21 | if args.mask_type == 'tube':
22 | self.masked_position_generator = TubeMaskingGenerator(
23 | args.window_size, args.mask_ratio
24 | )
25 |
26 | def __call__(self, images):
27 | process_data, _ = self.transform(images)
28 | return process_data, self.masked_position_generator()
29 |
30 | def __repr__(self):
31 | repr = "(DataAugmentationForVideoMAE,\n"
32 | repr += " transform = %s,\n" % str(self.transform)
33 | repr += " Masked position generator = %s,\n" % str(self.masked_position_generator)
34 | repr += ")"
35 | return repr
36 |
37 |
38 | def build_pretraining_dataset(args):
39 | transform = DataAugmentationForVideoMAE(args)
40 | dataset = VideoMAE(
41 | root=None,
42 | setting=args.data_path,
43 | video_ext='mp4',
44 | is_color=True,
45 | modality='rgb',
46 | new_length=args.num_frames,
47 | new_step=args.sampling_rate,
48 | transform=transform,
49 | temporal_jitter=False,
50 | video_loader=True,
51 | use_decord=True,
52 | lazy_init=False)
53 | print("Data Aug = %s" % str(transform))
54 | return dataset
55 |
56 |
57 | def build_dataset(is_train, test_mode, args):
58 | if args.data_set == 'Kinetics-400':
59 | mode = None
60 | anno_path = None
61 | if is_train is True:
62 | mode = 'train'
63 | anno_path = os.path.join(args.data_path, 'train.csv')
64 | elif test_mode is True:
65 | mode = 'test'
66 | anno_path = os.path.join(args.data_path, 'test.csv')
67 | else:
68 | mode = 'validation'
69 | anno_path = os.path.join(args.data_path, 'val.csv')
70 |
71 | dataset = VideoClsDataset(
72 | anno_path=anno_path,
73 | data_path='/',
74 | mode=mode,
75 | clip_len=args.num_frames,
76 | frame_sample_rate=args.sampling_rate,
77 | num_segment=1,
78 | test_num_segment=args.test_num_segment,
79 | test_num_crop=args.test_num_crop,
80 | num_crop=1 if not test_mode else 3,
81 | keep_aspect_ratio=True,
82 | crop_size=args.input_size,
83 | short_side_size=args.short_side_size,
84 | new_height=256,
85 | new_width=320,
86 | args=args)
87 | nb_classes = 400
88 |
89 | elif args.data_set == 'SSV2':
90 | mode = None
91 | anno_path = None
92 | if is_train is True:
93 | mode = 'train'
94 | anno_path = os.path.join(args.data_path, 'train.csv')
95 | elif test_mode is True:
96 | mode = 'test'
97 | anno_path = os.path.join(args.data_path, 'test.csv')
98 | else:
99 | mode = 'validation'
100 | anno_path = os.path.join(args.data_path, 'val.csv')
101 |
102 | dataset = SSVideoClsDataset(
103 | anno_path=anno_path,
104 | data_path='/',
105 | mode=mode,
106 | clip_len=1,
107 | num_segment=args.num_frames,
108 | test_num_segment=args.test_num_segment,
109 | test_num_crop=args.test_num_crop,
110 | num_crop=1 if not test_mode else 3,
111 | keep_aspect_ratio=True,
112 | crop_size=args.input_size,
113 | short_side_size=args.short_side_size,
114 | new_height=256,
115 | new_width=320,
116 | args=args)
117 | nb_classes = 174
118 |
119 | elif args.data_set == 'UCF101':
120 | mode = None
121 | anno_path = None
122 | if is_train is True:
123 | mode = 'train'
124 | anno_path = os.path.join(args.data_path, 'train.csv')
125 | elif test_mode is True:
126 | mode = 'test'
127 | anno_path = os.path.join(args.data_path, 'test.csv')
128 | else:
129 | mode = 'validation'
130 | anno_path = os.path.join(args.data_path, 'val.csv')
131 |
132 | dataset = VideoClsDataset(
133 | anno_path=anno_path,
134 | data_path='/',
135 | mode=mode,
136 | clip_len=args.num_frames,
137 | frame_sample_rate=args.sampling_rate,
138 | num_segment=1,
139 | test_num_segment=args.test_num_segment,
140 | test_num_crop=args.test_num_crop,
141 | num_crop=1 if not test_mode else 3,
142 | keep_aspect_ratio=True,
143 | crop_size=args.input_size,
144 | short_side_size=args.short_side_size,
145 | new_height=256,
146 | new_width=320,
147 | args=args)
148 | nb_classes = 101
149 |
150 | elif args.data_set == 'HMDB51':
151 | mode = None
152 | anno_path = None
153 | if is_train is True:
154 | mode = 'train'
155 | anno_path = os.path.join(args.data_path, 'train.csv')
156 | elif test_mode is True:
157 | mode = 'test'
158 | anno_path = os.path.join(args.data_path, 'test.csv')
159 | else:
160 | mode = 'validation'
161 | anno_path = os.path.join(args.data_path, 'val.csv')
162 |
163 | dataset = VideoClsDataset(
164 | anno_path=anno_path,
165 | data_path='/',
166 | mode=mode,
167 | clip_len=args.num_frames,
168 | frame_sample_rate=args.sampling_rate,
169 | num_segment=1,
170 | test_num_segment=args.test_num_segment,
171 | test_num_crop=args.test_num_crop,
172 | num_crop=1 if not test_mode else 3,
173 | keep_aspect_ratio=True,
174 | crop_size=args.input_size,
175 | short_side_size=args.short_side_size,
176 | new_height=256,
177 | new_width=320,
178 | args=args)
179 | nb_classes = 51
180 |
181 | elif args.data_set == 'ROCOG':
182 | mode = None
183 | anno_path = None
184 | if is_train is True:
185 | mode = 'train'
186 | anno_path = os.path.join(args.data_path, 'train.csv')
187 | elif test_mode is True:
188 | mode = 'test'
189 | anno_path = os.path.join(args.data_path, 'test.csv')
190 | else:
191 | mode = 'validation'
192 | anno_path = os.path.join(args.data_path, 'val.csv')
193 |
194 | dataset = VideoClsDataset(
195 | anno_path=anno_path,
196 | data_path='/',
197 | mode=mode,
198 | clip_len=args.num_frames,
199 | frame_sample_rate=args.sampling_rate,
200 | num_segment=1,
201 | test_num_segment=args.test_num_segment,
202 | test_num_crop=args.test_num_crop,
203 | num_crop=1 if not test_mode else 3,
204 | keep_aspect_ratio=True,
205 | crop_size=args.input_size,
206 | short_side_size=args.short_side_size,
207 | new_height=256,
208 | new_width=320,
209 | args=args)
210 | nb_classes = 7
211 | else:
212 | raise NotImplementedError()
213 | assert nb_classes == args.nb_classes
214 |     print("Number of classes = %d" % args.nb_classes)
215 |
216 | return dataset, nb_classes
217 |
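
As a reference for wiring this up, the hedged sketch below lists the argument fields that `DataAugmentationForVideoMAE` and `build_pretraining_dataset` actually read. The `window_size` of `(8, 14, 14)` assumes the default ViT-B/16 setup (16 frames, tubelet size 2, 16x16 patches at 224x224 resolution); the annotation path is a placeholder and must point to a real `train.csv` for the dataset to initialize.

```python
# Minimal sketch of the argument fields read by build_pretraining_dataset();
# the path and the ViT-B/16 patch geometry below are assumptions for illustration.
from argparse import Namespace
from datasets import build_pretraining_dataset

args = Namespace(
    data_path='/path/to/train.csv',   # placeholder annotation list (see README format)
    num_frames=16,
    sampling_rate=4,
    input_size=224,
    mask_type='tube',
    mask_ratio=0.75,
    # (frames // tubelet_size, H // patch, W // patch) for ViT-B/16 at 224x224
    window_size=(8, 14, 14),
)

dataset = build_pretraining_dataset(args)
# each item is expected to be (normalized clip tensor, boolean tube mask)
frames, mask = dataset[0]
```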
--------------------------------------------------------------------------------
/run_videomae_vis.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import argparse
3 | import numpy as np
4 | import torch
5 | import torch.backends.cudnn as cudnn
6 | from PIL import Image
7 | from pathlib import Path
8 | from timm.models import create_model
9 | import utils
10 | import modeling_pretrain
11 | from datasets import DataAugmentationForVideoMAE
12 | from torchvision.transforms import ToPILImage
13 | from einops import rearrange
14 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
15 | from decord import VideoReader, cpu
16 | from torchvision import transforms
17 | from transforms import *
18 | from masking_generator import TubeMaskingGenerator
19 |
20 | class DataAugmentationForVideoMAE(object):
21 | def __init__(self, args):
22 | self.input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN
23 | self.input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD
24 | normalize = GroupNormalize(self.input_mean, self.input_std)
25 | self.train_augmentation = GroupCenterCrop(args.input_size)
26 | self.transform = transforms.Compose([
27 | self.train_augmentation,
28 | Stack(roll=False),
29 | ToTorchFormatTensor(div=True),
30 | normalize,
31 | ])
32 | if args.mask_type == 'tube':
33 | self.masked_position_generator = TubeMaskingGenerator(
34 | args.window_size, args.mask_ratio
35 | )
36 |
37 | def __call__(self, images):
38 |         process_data, _ = self.transform(images)
39 | return process_data, self.masked_position_generator()
40 |
41 | def __repr__(self):
42 | repr = "(DataAugmentationForVideoMAE,\n"
43 | repr += " transform = %s,\n" % str(self.transform)
44 | repr += " Masked position generator = %s,\n" % str(self.masked_position_generator)
45 | repr += ")"
46 | return repr
47 |
48 | def get_args():
49 | parser = argparse.ArgumentParser('VideoMAE visualization reconstruction script', add_help=False)
50 | parser.add_argument('img_path', type=str, help='input video path')
51 | parser.add_argument('save_path', type=str, help='save video path')
52 | parser.add_argument('model_path', type=str, help='checkpoint path of model')
53 | parser.add_argument('--mask_type', default='random', choices=['random', 'tube'],
54 |                         type=str, help='masking strategy for video tokens/patches')
55 | parser.add_argument('--num_frames', type=int, default= 16)
56 | parser.add_argument('--sampling_rate', type=int, default= 4)
57 | parser.add_argument('--decoder_depth', default=4, type=int,
58 | help='depth of decoder')
59 | parser.add_argument('--input_size', default=224, type=int,
60 | help='videos input size for backbone')
61 | parser.add_argument('--device', default='cuda:0',
62 | help='device to use for training / testing')
63 | parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true')
64 | parser.add_argument('--mask_ratio', default=0.75, type=float,
65 |                         help='ratio of visual tokens/patches to be masked')
66 | # Model parameters
67 | parser.add_argument('--model', default='pretrain_videomae_base_patch16_224', type=str, metavar='MODEL',
68 | help='Name of model to vis')
69 | parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT',
70 |                         help='Drop path rate (default: 0.0)')
71 |
72 | return parser.parse_args()
73 |
74 |
75 | def get_model(args):
76 | print(f"Creating model: {args.model}")
77 | model = create_model(
78 | args.model,
79 | pretrained=False,
80 | drop_path_rate=args.drop_path,
81 | drop_block_rate=None,
82 | decoder_depth=args.decoder_depth
83 | )
84 |
85 | return model
86 |
87 |
88 | def main(args):
89 | print(args)
90 |
91 | device = torch.device(args.device)
92 | cudnn.benchmark = True
93 |
94 | model = get_model(args)
95 | patch_size = model.encoder.patch_embed.patch_size
96 | print("Patch size = %s" % str(patch_size))
97 | args.window_size = (args.num_frames // 2, args.input_size // patch_size[0], args.input_size // patch_size[1])
98 | args.patch_size = patch_size
99 |
100 | model.to(device)
101 | checkpoint = torch.load(args.model_path, map_location='cpu')
102 | model.load_state_dict(checkpoint['model'])
103 | model.eval()
104 |
105 | if args.save_path:
106 | Path(args.save_path).mkdir(parents=True, exist_ok=True)
107 |
108 | with open(args.img_path, 'rb') as f:
109 | vr = VideoReader(f, ctx=cpu(0))
110 | duration = len(vr)
111 | new_length = 1
112 | new_step = 1
113 | skip_length = new_length * new_step
114 | # frame_id_list = [1, 5, 9, 13, 17, 21, 25, 29, 33, 37, 41, 45, 49, 53, 57, 61]
115 |
116 |
117 | tmp = np.arange(0,32, 2) + 60
118 | frame_id_list = tmp.tolist()
119 | # average_duration = (duration - skip_length + 1) // args.num_frames
120 | # if average_duration > 0:
121 | # frame_id_list = np.multiply(list(range(args.num_frames)),
122 | # average_duration)
123 | # frame_id_list = frame_id_list + np.random.randint(average_duration,
124 | # size=args.num_frames)
125 |
126 | video_data = vr.get_batch(frame_id_list).asnumpy()
127 | print(video_data.shape)
128 | img = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)]
129 |
130 | transforms = DataAugmentationForVideoMAE(args)
131 | img, bool_masked_pos = transforms((img, None)) # T*C,H,W
132 | # print(img.shape)
133 | img = img.view((args.num_frames , 3) + img.size()[-2:]).transpose(0,1) # T*C,H,W -> T,C,H,W -> C,T,H,W
134 | # img = img.view(( -1 , args.num_frames) + img.size()[-2:])
135 | bool_masked_pos = torch.from_numpy(bool_masked_pos)
136 |
137 | with torch.no_grad():
138 | # img = img[None, :]
139 | # bool_masked_pos = bool_masked_pos[None, :]
140 | img = img.unsqueeze(0)
141 | print(img.shape)
142 | bool_masked_pos = bool_masked_pos.unsqueeze(0)
143 |
144 | img = img.to(device, non_blocking=True)
145 | bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool)
146 | outputs = model(img, bool_masked_pos)
147 |
148 | #save original video
149 | mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None]
150 | std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None]
151 | ori_img = img * std + mean # in [0, 1]
152 | imgs = [ToPILImage()(ori_img[0,:,vid,:,:].cpu()) for vid, _ in enumerate(frame_id_list) ]
153 | for id, im in enumerate(imgs):
154 | im.save(f"{args.save_path}/ori_img{id}.jpg")
155 |
156 | img_squeeze = rearrange(ori_img, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c', p0=2, p1=patch_size[0], p2=patch_size[0])
157 | img_norm = (img_squeeze - img_squeeze.mean(dim=-2, keepdim=True)) / (img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
158 | img_patch = rearrange(img_norm, 'b n p c -> b n (p c)')
159 | img_patch[bool_masked_pos] = outputs
160 |
161 | #make mask
162 | mask = torch.ones_like(img_patch)
163 | mask[bool_masked_pos] = 0
164 | mask = rearrange(mask, 'b n (p c) -> b n p c', c=3)
165 | mask = rearrange(mask, 'b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2) ', p0=2, p1=patch_size[0], p2=patch_size[1], h=14, w=14)
166 |
167 | #save reconstruction video
168 | rec_img = rearrange(img_patch, 'b n (p c) -> b n p c', c=3)
169 |     # NOTE: to visualize the reconstruction, each predicted patch is de-normalized with the mean and variance of the corresponding original patch.
170 | rec_img = rec_img * (img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6) + img_squeeze.mean(dim=-2, keepdim=True)
171 | rec_img = rearrange(rec_img, 'b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)', p0=2, p1=patch_size[0], p2=patch_size[1], h=14, w=14)
172 | imgs = [ ToPILImage()(rec_img[0, :, vid, :, :].cpu().clamp(0,0.996)) for vid, _ in enumerate(frame_id_list) ]
173 |
174 | for id, im in enumerate(imgs):
175 | im.save(f"{args.save_path}/rec_img{id}.jpg")
176 |
177 | #save masked video
178 | img_mask = rec_img * mask
179 | imgs = [ToPILImage()(img_mask[0, :, vid, :, :].cpu()) for vid, _ in enumerate(frame_id_list)]
180 | for id, im in enumerate(imgs):
181 | im.save(f"{args.save_path}/mask_img{id}.jpg")
182 |
183 | if __name__ == '__main__':
184 | opts = get_args()
185 | main(opts)
186 |
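
The de-normalization around the reconstruction above mirrors the per-patch normalization used to build the pre-training targets. The self-contained sketch below, on dummy data and assuming the default 16x16 patches with tubelet size 2, shows the same regrouping and per-patch standardization in isolation.

```python
# Sketch of the per-patch normalization used when comparing reconstructions:
# pixels are regrouped into (2, 16, 16) tubelets and standardized within each patch.
import torch
from einops import rearrange

patch_size = 16
video = torch.rand(1, 3, 16, 224, 224)  # (B, C, T, H, W), dummy data

patches = rearrange(video, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c',
                    p0=2, p1=patch_size, p2=patch_size)
mean = patches.mean(dim=-2, keepdim=True)
std = patches.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
normalized = (patches - mean) / std   # what the reconstruction loss is computed against
restored = normalized * std + mean    # what the script does before saving images
print(patches.shape)                  # torch.Size([1, 1568, 512, 3])
```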
--------------------------------------------------------------------------------
/engine_for_finetuning.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import math
4 | import sys
5 | from typing import Iterable, Optional
6 | import torch
7 | from mixup import Mixup
8 | from timm.utils import accuracy, ModelEma
9 | import utils
10 | from scipy.special import softmax
11 |
12 | def train_class_batch(model, samples, target, criterion):
13 | outputs = model(samples)
14 | loss = criterion(outputs, target)
15 | return loss, outputs
16 |
17 |
18 | def get_loss_scale_for_deepspeed(model):
19 | optimizer = model.optimizer
20 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale
21 |
22 |
23 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
24 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
25 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
26 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
27 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
28 | num_training_steps_per_epoch=None, update_freq=None):
29 | model.train(True)
30 | metric_logger = utils.MetricLogger(delimiter=" ")
31 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
32 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
33 | header = 'Epoch: [{}]'.format(epoch)
34 | print_freq = 10
35 |
36 | if loss_scaler is None:
37 | model.zero_grad()
38 | model.micro_steps = 0
39 | else:
40 | optimizer.zero_grad()
41 |
42 | for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
43 | step = data_iter_step // update_freq
44 | if step >= num_training_steps_per_epoch:
45 | continue
46 | it = start_steps + step # global training iteration
47 |         # update LR & WD at the first micro-step of each gradient accumulation cycle
48 |         if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
49 | for i, param_group in enumerate(optimizer.param_groups):
50 | if lr_schedule_values is not None:
51 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
52 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
53 | param_group["weight_decay"] = wd_schedule_values[it]
54 |
55 | samples = samples.to(device, non_blocking=True)
56 | targets = targets.to(device, non_blocking=True)
57 |
58 | if mixup_fn is not None:
59 | samples, targets = mixup_fn(samples, targets)
60 |
61 | if loss_scaler is None:
62 | samples = samples.half()
63 | loss, output = train_class_batch(
64 | model, samples, targets, criterion)
65 | else:
66 | with torch.cuda.amp.autocast():
67 | loss, output = train_class_batch(
68 | model, samples, targets, criterion)
69 |
70 | loss_value = loss.item()
71 |
72 | if not math.isfinite(loss_value):
73 | print("Loss is {}, stopping training".format(loss_value))
74 | sys.exit(1)
75 |
76 | if loss_scaler is None:
77 | loss /= update_freq
78 | model.backward(loss)
79 | model.step()
80 |
81 | if (data_iter_step + 1) % update_freq == 0:
82 | # model.zero_grad()
83 | # Deepspeed will call step() & model.zero_grad() automatic
84 | if model_ema is not None:
85 | model_ema.update(model)
86 | grad_norm = None
87 | loss_scale_value = get_loss_scale_for_deepspeed(model)
88 | else:
89 | # this attribute is added by timm on one optimizer (adahessian)
90 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
91 | loss /= update_freq
92 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
93 | parameters=model.parameters(), create_graph=is_second_order,
94 | update_grad=(data_iter_step + 1) % update_freq == 0)
95 | if (data_iter_step + 1) % update_freq == 0:
96 | optimizer.zero_grad()
97 | if model_ema is not None:
98 | model_ema.update(model)
99 | loss_scale_value = loss_scaler.state_dict()["scale"]
100 |
101 | torch.cuda.synchronize()
102 |
103 | if mixup_fn is None:
104 | class_acc = (output.max(-1)[-1] == targets).float().mean()
105 | else:
106 | class_acc = None
107 | metric_logger.update(loss=loss_value)
108 | metric_logger.update(class_acc=class_acc)
109 | metric_logger.update(loss_scale=loss_scale_value)
110 | min_lr = 10.
111 | max_lr = 0.
112 | for group in optimizer.param_groups:
113 | min_lr = min(min_lr, group["lr"])
114 | max_lr = max(max_lr, group["lr"])
115 |
116 | metric_logger.update(lr=max_lr)
117 | metric_logger.update(min_lr=min_lr)
118 | weight_decay_value = None
119 | for group in optimizer.param_groups:
120 | if group["weight_decay"] > 0:
121 | weight_decay_value = group["weight_decay"]
122 | metric_logger.update(weight_decay=weight_decay_value)
123 | metric_logger.update(grad_norm=grad_norm)
124 |
125 | if log_writer is not None:
126 | log_writer.update(loss=loss_value, head="loss")
127 | log_writer.update(class_acc=class_acc, head="loss")
128 | log_writer.update(loss_scale=loss_scale_value, head="opt")
129 | log_writer.update(lr=max_lr, head="opt")
130 | log_writer.update(min_lr=min_lr, head="opt")
131 | log_writer.update(weight_decay=weight_decay_value, head="opt")
132 | log_writer.update(grad_norm=grad_norm, head="opt")
133 |
134 | log_writer.set_step()
135 |
136 | # gather the stats from all processes
137 | metric_logger.synchronize_between_processes()
138 | print("Averaged stats:", metric_logger)
139 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
140 |
141 |
142 | @torch.no_grad()
143 | def validation_one_epoch(data_loader, model, device):
144 | criterion = torch.nn.CrossEntropyLoss()
145 |
146 | metric_logger = utils.MetricLogger(delimiter=" ")
147 | header = 'Val:'
148 |
149 | # switch to evaluation mode
150 | model.eval()
151 |
152 | for batch in metric_logger.log_every(data_loader, 10, header):
153 | videos = batch[0]
154 | target = batch[1]
155 | videos = videos.to(device, non_blocking=True)
156 | target = target.to(device, non_blocking=True)
157 |
158 | # compute output
159 | with torch.cuda.amp.autocast():
160 | output = model(videos)
161 | loss = criterion(output, target)
162 |
163 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
164 |
165 | batch_size = videos.shape[0]
166 | metric_logger.update(loss=loss.item())
167 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
168 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
169 | # gather the stats from all processes
170 | metric_logger.synchronize_between_processes()
171 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
172 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
173 |
174 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
175 |
176 |
177 |
178 | @torch.no_grad()
179 | def final_test(data_loader, model, device, file):
180 | criterion = torch.nn.CrossEntropyLoss()
181 |
182 | metric_logger = utils.MetricLogger(delimiter=" ")
183 | header = 'Test:'
184 |
185 | # switch to evaluation mode
186 | model.eval()
187 | final_result = []
188 |
189 | for batch in metric_logger.log_every(data_loader, 10, header):
190 | videos = batch[0]
191 | target = batch[1]
192 | ids = batch[2]
193 | chunk_nb = batch[3]
194 | split_nb = batch[4]
195 | videos = videos.to(device, non_blocking=True)
196 | target = target.to(device, non_blocking=True)
197 |
198 | # compute output
199 | with torch.cuda.amp.autocast():
200 | output = model(videos)
201 | loss = criterion(output, target)
202 |
203 | for i in range(output.size(0)):
204 | string = "{} {} {} {} {}\n".format(ids[i], \
205 | str(output.data[i].cpu().numpy().tolist()), \
206 | str(int(target[i].cpu().numpy())), \
207 | str(int(chunk_nb[i].cpu().numpy())), \
208 | str(int(split_nb[i].cpu().numpy())))
209 | final_result.append(string)
210 |
211 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
212 |
213 | batch_size = videos.shape[0]
214 | metric_logger.update(loss=loss.item())
215 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
216 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
217 |
218 | if not os.path.exists(file):
219 | os.mknod(file)
220 | with open(file, 'w') as f:
221 | f.write("{}, {}\n".format(acc1, acc5))
222 | for line in final_result:
223 | f.write(line)
224 | # gather the stats from all processes
225 | metric_logger.synchronize_between_processes()
226 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
227 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
228 |
229 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
230 |
231 |
232 | def merge(eval_path, num_tasks):
233 | dict_feats = {}
234 | dict_label = {}
235 | dict_pos = {}
236 | print("Reading individual output files")
237 |
238 | for x in range(num_tasks):
239 | file = os.path.join(eval_path, str(x) + '.txt')
240 | lines = open(file, 'r').readlines()[1:]
241 | for line in lines:
242 | line = line.strip()
243 | name = line.split('[')[0]
244 | label = line.split(']')[1].split(' ')[1]
245 | chunk_nb = line.split(']')[1].split(' ')[2]
246 | split_nb = line.split(']')[1].split(' ')[3]
247 | data = np.fromstring(line.split('[')[1].split(']')[0], dtype=float, sep=',')
248 | data = softmax(data)
249 |             if name not in dict_feats:
250 | dict_feats[name] = []
251 | dict_label[name] = 0
252 | dict_pos[name] = []
253 | if chunk_nb + split_nb in dict_pos[name]:
254 | continue
255 | dict_feats[name].append(data)
256 | dict_pos[name].append(chunk_nb + split_nb)
257 | dict_label[name] = label
258 | print("Computing final results")
259 |
260 | input_lst = []
261 | print(len(dict_feats))
262 | for i, item in enumerate(dict_feats):
263 | input_lst.append([i, item, dict_feats[item], dict_label[item]])
264 | from multiprocessing import Pool
265 | p = Pool(64)
266 | ans = p.map(compute_video, input_lst)
267 | top1 = [x[1] for x in ans]
268 | top5 = [x[2] for x in ans]
269 | pred = [x[0] for x in ans]
270 | label = [x[3] for x in ans]
271 |     final_top1, final_top5 = np.mean(top1), np.mean(top5)
272 |     return final_top1 * 100, final_top5 * 100
273 |
274 | def compute_video(lst):
275 | i, video_id, data, label = lst
276 | feat = [x for x in data]
277 | feat = np.mean(feat, axis=0)
278 | pred = np.argmax(feat)
279 | top1 = (int(pred) == int(label)) * 1.0
280 | top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0
281 | return [pred, top1, top5, int(label)]
282 |
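
For clarity, the snippet below re-creates the aggregation performed by `merge()` and `compute_video()` on dummy scores: every temporal/spatial test view contributes a softmaxed prediction, the views are averaged per video, and top-1/top-5 are taken from the mean. The view count and label here are hypothetical.

```python
# Sketch of the multi-view aggregation in merge()/compute_video(), on dummy numbers.
import numpy as np
from scipy.special import softmax

views = [np.random.randn(400) for _ in range(5 * 3)]   # e.g. 5 temporal x 3 spatial views
probs = [softmax(v) for v in views]

mean_prob = np.mean(probs, axis=0)                     # average the softmaxed views
pred = int(np.argmax(mean_prob))
label = 123                                            # hypothetical ground-truth class
top1 = float(pred == label)
top5 = float(label in np.argsort(-mean_prob)[:5])
print(pred, top1, top5)
```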
--------------------------------------------------------------------------------
/run_mae_pretraining.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import numpy as np
4 | import time
5 | import torch
6 | import torch.backends.cudnn as cudnn
7 | import json
8 | import os
9 | from pathlib import Path
10 | from timm.models import create_model
11 | from optim_factory import create_optimizer
12 | from datasets import build_pretraining_dataset
13 | from engine_for_pretraining import train_one_epoch
14 | from utils import NativeScalerWithGradNormCount as NativeScaler
15 | import utils
16 | import modeling_pretrain
17 |
18 |
19 | def get_args():
20 | parser = argparse.ArgumentParser('VideoMAE pre-training script', add_help=False)
21 | parser.add_argument('--batch_size', default=64, type=int)
22 | parser.add_argument('--epochs', default=800, type=int)
23 | parser.add_argument('--save_ckpt_freq', default=50, type=int)
24 |
25 | # Model parameters
26 | parser.add_argument('--model', default='pretrain_videomae_base_patch16_224', type=str, metavar='MODEL',
27 | help='Name of model to train')
28 |
29 | parser.add_argument('--decoder_depth', default=4, type=int,
30 | help='depth of decoder')
31 |
32 | parser.add_argument('--mask_type', default='tube', choices=['random', 'tube'],
33 |                         type=str, help='masking strategy for video tokens/patches')
34 |
35 | parser.add_argument('--mask_ratio', default=0.75, type=float,
36 |                         help='ratio of visual tokens/patches to be masked')
37 |
38 | parser.add_argument('--input_size', default=224, type=int,
39 | help='videos input size for backbone')
40 |
41 | parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT',
42 |                         help='Drop path rate (default: 0.0)')
43 |
44 | parser.add_argument('--normlize_target', default=True, type=bool,
45 |                         help='normalize the target patch pixels')
46 |
47 | # Optimizer parameters
48 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
49 |                         help='Optimizer (default: "adamw")')
50 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
51 | help='Optimizer Epsilon (default: 1e-8)')
52 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
53 | help='Optimizer Betas (default: None, use opt default)')
54 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
55 | help='Clip gradient norm (default: None, no clipping)')
56 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
57 | help='SGD momentum (default: 0.9)')
58 | parser.add_argument('--weight_decay', type=float, default=0.05,
59 | help='weight decay (default: 0.05)')
60 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
61 | weight decay. We use a cosine schedule for WD.
62 |         (Set the same value as args.weight_decay to keep the weight decay unchanged.)""")
63 |
64 | parser.add_argument('--lr', type=float, default=1.5e-4, metavar='LR',
65 | help='learning rate (default: 1.5e-4)')
66 | parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
67 | help='warmup learning rate (default: 1e-6)')
68 | parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR',
69 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
70 |
71 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
72 | help='epochs to warmup LR, if scheduler supports')
73 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
74 |                         help='steps to warmup LR; overrides --warmup_epochs if set > 0')
75 | parser.add_argument('--use_checkpoint', action='store_true')
76 | parser.set_defaults(use_checkpoint=False)
77 |
78 | # Augmentation parameters
79 | parser.add_argument('--color_jitter', type=float, default=0.0, metavar='PCT',
80 |                         help='Color jitter factor (default: 0.0)')
81 | parser.add_argument('--train_interpolation', type=str, default='bicubic',
82 |                         help='Training interpolation (random, bilinear, bicubic; default: "bicubic")')
83 |
84 | # Dataset parameters
85 | parser.add_argument('--data_path', default='/path/to/list_kinetics-400', type=str,
86 | help='dataset path')
87 | parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true')
88 | parser.add_argument('--num_frames', type=int, default= 16)
89 | parser.add_argument('--sampling_rate', type=int, default= 4)
90 | parser.add_argument('--output_dir', default='',
91 | help='path where to save, empty for no saving')
92 | parser.add_argument('--log_dir', default=None,
93 | help='path where to tensorboard log')
94 | parser.add_argument('--device', default='cuda',
95 | help='device to use for training / testing')
96 | parser.add_argument('--seed', default=0, type=int)
97 | parser.add_argument('--resume', default='', help='resume from checkpoint')
98 | parser.add_argument('--auto_resume', action='store_true')
99 | parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
100 | parser.set_defaults(auto_resume=True)
101 |
102 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
103 | help='start epoch')
104 | parser.add_argument('--num_workers', default=10, type=int)
105 | parser.add_argument('--pin_mem', action='store_true',
106 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
107 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
108 | help='')
109 | parser.set_defaults(pin_mem=True)
110 |
111 | # distributed training parameters
112 | parser.add_argument('--world_size', default=1, type=int,
113 | help='number of distributed processes')
114 | parser.add_argument('--local_rank', default=-1, type=int)
115 | parser.add_argument('--dist_on_itp', action='store_true')
116 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
117 |
118 | return parser.parse_args()
119 |
120 |
121 | def get_model(args):
122 | print(f"Creating model: {args.model}")
123 | model = create_model(
124 | args.model,
125 | pretrained=False,
126 | drop_path_rate=args.drop_path,
127 | drop_block_rate=None,
128 | decoder_depth=args.decoder_depth,
129 | use_checkpoint=args.use_checkpoint
130 | )
131 | return model
132 |
133 |
134 | def main(args):
135 | utils.init_distributed_mode(args)
136 |
137 | print(args)
138 |
139 | device = torch.device(args.device)
140 |
141 | # fix the seed for reproducibility
142 | seed = args.seed + utils.get_rank()
143 | torch.manual_seed(seed)
144 | np.random.seed(seed)
145 |
146 | cudnn.benchmark = True
147 |
148 | model = get_model(args)
149 | patch_size = model.encoder.patch_embed.patch_size
150 | print("Patch size = %s" % str(patch_size))
151 | args.window_size = (args.num_frames // 2, args.input_size // patch_size[0], args.input_size // patch_size[1])
152 | args.patch_size = patch_size
153 |
154 | # get dataset
155 | dataset_train = build_pretraining_dataset(args)
156 |
157 |
158 | num_tasks = utils.get_world_size()
159 | global_rank = utils.get_rank()
160 | sampler_rank = global_rank
161 | num_training_steps_per_epoch = len(dataset_train) // args.batch_size // num_tasks
162 |
163 | sampler_train = torch.utils.data.DistributedSampler(
164 | dataset_train, num_replicas=num_tasks, rank=sampler_rank, shuffle=True
165 | )
166 | print("Sampler_train = %s" % str(sampler_train))
167 |
168 |
169 | if global_rank == 0 and args.log_dir is not None:
170 | os.makedirs(args.log_dir, exist_ok=True)
171 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
172 | else:
173 | log_writer = None
174 |
175 | data_loader_train = torch.utils.data.DataLoader(
176 | dataset_train, sampler=sampler_train,
177 | batch_size=args.batch_size,
178 | num_workers=args.num_workers,
179 | pin_memory=args.pin_mem,
180 | drop_last=True,
181 | worker_init_fn=utils.seed_worker
182 | )
183 |
184 | model.to(device)
185 | model_without_ddp = model
186 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
187 |
188 | print("Model = %s" % str(model_without_ddp))
189 | print('number of params: {} M'.format(n_parameters / 1e6))
190 |
191 | total_batch_size = args.batch_size * utils.get_world_size()
192 |
193 | args.lr = args.lr * total_batch_size / 256
194 | args.min_lr = args.min_lr * total_batch_size / 256
195 | args.warmup_lr = args.warmup_lr * total_batch_size / 256
196 | print("LR = %.8f" % args.lr)
197 | print("Batch size = %d" % total_batch_size)
198 | print("Number of training steps = %d" % num_training_steps_per_epoch)
199 | print("Number of training examples per epoch = %d" % (total_batch_size * num_training_steps_per_epoch))
200 |
201 | if args.distributed:
202 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
203 | model_without_ddp = model.module
204 |
205 | optimizer = create_optimizer(
206 | args, model_without_ddp)
207 | loss_scaler = NativeScaler()
208 |
209 | print("Use step level LR & WD scheduler!")
210 | lr_schedule_values = utils.cosine_scheduler(
211 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
212 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
213 | )
214 | if args.weight_decay_end is None:
215 | args.weight_decay_end = args.weight_decay
216 | wd_schedule_values = utils.cosine_scheduler(
217 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
218 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))
219 |
220 | utils.auto_load_model(
221 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
222 | torch.cuda.empty_cache()
223 | print(f"Start training for {args.epochs} epochs")
224 | start_time = time.time()
225 | for epoch in range(args.start_epoch, args.epochs):
226 | if args.distributed:
227 | data_loader_train.sampler.set_epoch(epoch)
228 | if log_writer is not None:
229 | log_writer.set_step(epoch * num_training_steps_per_epoch)
230 | train_stats = train_one_epoch(
231 | model, data_loader_train,
232 | optimizer, device, epoch, loss_scaler,
233 | args.clip_grad, log_writer=log_writer,
234 | start_steps=epoch * num_training_steps_per_epoch,
235 | lr_schedule_values=lr_schedule_values,
236 | wd_schedule_values=wd_schedule_values,
237 | patch_size=patch_size[0],
238 | normlize_target=args.normlize_target,
239 | )
240 | if args.output_dir:
241 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
242 | utils.save_model(
243 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
244 | loss_scaler=loss_scaler, epoch=epoch)
245 |
246 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
247 | 'epoch': epoch, 'n_parameters': n_parameters}
248 |
249 | if args.output_dir and utils.is_main_process():
250 | if log_writer is not None:
251 | log_writer.flush()
252 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
253 | f.write(json.dumps(log_stats) + "\n")
254 |
255 | total_time = time.time() - start_time
256 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
257 | print('Training time {}'.format(total_time_str))
258 |
259 |
260 | if __name__ == '__main__':
261 | opts = get_args()
262 | if opts.output_dir:
263 | Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
264 | main(opts)
265 |
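
The learning-rate block in `main()` applies the usual linear scaling rule before building the cosine scheduler. A small worked example, assuming 8 GPUs at the default per-GPU batch size of 64, is sketched below.

```python
# Worked example of the linear LR scaling in main(): the base learning rate is
# scaled by total_batch_size / 256. The GPU count below is an assumption.
base_lr = 1.5e-4          # --lr default
batch_size = 64           # --batch_size (per process)
world_size = 8            # e.g. 8 GPUs

total_batch_size = batch_size * world_size            # 512
scaled_lr = base_lr * total_batch_size / 256          # 3.0e-4
print(f"effective LR = {scaled_lr:.2e}")
```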
--------------------------------------------------------------------------------
/modeling_finetune.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_
7 | from timm.models.registry import register_model
8 | import torch.utils.checkpoint as checkpoint
9 |
10 |
11 | def _cfg(url='', **kwargs):
12 | return {
13 | 'url': url,
14 | 'num_classes': 400, 'input_size': (3, 224, 224), 'pool_size': None,
15 | 'crop_pct': .9, 'interpolation': 'bicubic',
16 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
17 | **kwargs
18 | }
19 |
20 |
21 | class DropPath(nn.Module):
22 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
23 | """
24 | def __init__(self, drop_prob=None):
25 | super(DropPath, self).__init__()
26 | self.drop_prob = drop_prob
27 |
28 | def forward(self, x):
29 | return drop_path(x, self.drop_prob, self.training)
30 |
31 | def extra_repr(self) -> str:
32 | return 'p={}'.format(self.drop_prob)
33 |
34 |
35 | class Mlp(nn.Module):
36 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
37 | super().__init__()
38 | out_features = out_features or in_features
39 | hidden_features = hidden_features or in_features
40 | self.fc1 = nn.Linear(in_features, hidden_features)
41 | self.act = act_layer()
42 | self.fc2 = nn.Linear(hidden_features, out_features)
43 | self.drop = nn.Dropout(drop)
44 |
45 | def forward(self, x):
46 | x = self.fc1(x)
47 | x = self.act(x)
48 | # x = self.drop(x)
49 |         # commented out to match the original BERT implementation
50 | x = self.fc2(x)
51 | x = self.drop(x)
52 | return x
53 |
54 |
55 | class Attention(nn.Module):
56 | def __init__(
57 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
58 | proj_drop=0., attn_head_dim=None):
59 | super().__init__()
60 | self.num_heads = num_heads
61 | head_dim = dim // num_heads
62 | if attn_head_dim is not None:
63 | head_dim = attn_head_dim
64 | all_head_dim = head_dim * self.num_heads
65 | self.scale = qk_scale or head_dim ** -0.5
66 |
67 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
68 | if qkv_bias:
69 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
70 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
71 | else:
72 | self.q_bias = None
73 | self.v_bias = None
74 |
75 | self.attn_drop = nn.Dropout(attn_drop)
76 | self.proj = nn.Linear(all_head_dim, dim)
77 | self.proj_drop = nn.Dropout(proj_drop)
78 |
79 | def forward(self, x):
80 | B, N, C = x.shape
81 | qkv_bias = None
82 | if self.q_bias is not None:
83 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
84 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
85 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
86 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
87 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
88 |
89 | q = q * self.scale
90 | attn = (q @ k.transpose(-2, -1))
91 |
92 |
93 | attn = attn.softmax(dim=-1)
94 | attn = self.attn_drop(attn)
95 |
96 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
97 | x = self.proj(x)
98 | x = self.proj_drop(x)
99 | return x
100 |
101 |
102 | class Block(nn.Module):
103 |
104 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
105 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
106 | attn_head_dim=None):
107 | super().__init__()
108 | self.norm1 = norm_layer(dim)
109 | self.attn = Attention(
110 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
111 | attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
112 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
113 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
114 | self.norm2 = norm_layer(dim)
115 | mlp_hidden_dim = int(dim * mlp_ratio)
116 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
117 |
118 |         if init_values is not None and init_values > 0:
119 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
120 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
121 | else:
122 | self.gamma_1, self.gamma_2 = None, None
123 |
124 | def forward(self, x):
125 | if self.gamma_1 is None:
126 | x = x + self.drop_path(self.attn(self.norm1(x)))
127 | x = x + self.drop_path(self.mlp(self.norm2(x)))
128 | else:
129 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
130 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
131 | return x
132 |
133 |
134 | class PatchEmbed(nn.Module):
135 | """ Image to Patch Embedding
136 | """
137 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
138 | super().__init__()
139 | img_size = to_2tuple(img_size)
140 | patch_size = to_2tuple(patch_size)
141 | self.tubelet_size = int(tubelet_size)
142 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
143 | self.img_size = img_size
144 | self.patch_size = patch_size
145 | self.num_patches = num_patches
146 | self.proj = nn.Conv3d(in_channels=in_chans, out_channels=embed_dim,
147 | kernel_size = (self.tubelet_size, patch_size[0],patch_size[1]),
148 | stride=(self.tubelet_size, patch_size[0], patch_size[1]))
149 |
150 | def forward(self, x, **kwargs):
151 | B, C, T, H, W = x.shape
152 | # FIXME look at relaxing size constraints
153 | assert H == self.img_size[0] and W == self.img_size[1], \
154 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
155 | x = self.proj(x).flatten(2).transpose(1, 2)
156 | return x
157 |
158 | # sin-cos position encoding
159 | # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
160 | def get_sinusoid_encoding_table(n_position, d_hid):
161 | ''' Sinusoid position encoding table '''
162 | # TODO: make it with torch instead of numpy
163 | def get_position_angle_vec(position):
164 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
165 |
166 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
167 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
168 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
169 |
170 | return torch.tensor(sinusoid_table,dtype=torch.float, requires_grad=False).unsqueeze(0)
171 |
172 |
173 | class VisionTransformer(nn.Module):
174 | """ Vision Transformer with support for patch or hybrid CNN input stage
175 | """
176 | def __init__(self,
177 | img_size=224,
178 | patch_size=16,
179 | in_chans=3,
180 | num_classes=1000,
181 | embed_dim=768,
182 | depth=12,
183 | num_heads=12,
184 | mlp_ratio=4.,
185 | qkv_bias=False,
186 | qk_scale=None,
187 | fc_drop_rate=0.,
188 | drop_rate=0.,
189 | attn_drop_rate=0.,
190 | drop_path_rate=0.,
191 | norm_layer=nn.LayerNorm,
192 | init_values=0.,
193 | use_learnable_pos_emb=False,
194 | init_scale=0.,
195 | all_frames=16,
196 | tubelet_size=2,
197 | use_checkpoint=False,
198 | use_mean_pooling=True):
199 | super().__init__()
200 | self.num_classes = num_classes
201 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
202 | self.tubelet_size = tubelet_size
203 | self.patch_embed = PatchEmbed(
204 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, num_frames=all_frames, tubelet_size=self.tubelet_size)
205 | num_patches = self.patch_embed.num_patches
206 | self.use_checkpoint = use_checkpoint
207 |
208 | if use_learnable_pos_emb:
209 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
210 | else:
211 |             # fixed sine-cosine positional embeddings
212 | self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
213 |
214 | self.pos_drop = nn.Dropout(p=drop_rate)
215 |
216 |
217 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
218 | self.blocks = nn.ModuleList([
219 | Block(
220 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
221 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
222 | init_values=init_values)
223 | for i in range(depth)])
224 | self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
225 | self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
226 | self.fc_dropout = nn.Dropout(p=fc_drop_rate) if fc_drop_rate > 0 else nn.Identity()
227 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
228 |
229 | if use_learnable_pos_emb:
230 | trunc_normal_(self.pos_embed, std=.02)
231 |
232 | trunc_normal_(self.head.weight, std=.02)
233 | self.apply(self._init_weights)
234 |
235 | self.head.weight.data.mul_(init_scale)
236 | self.head.bias.data.mul_(init_scale)
237 |
238 | def _init_weights(self, m):
239 | if isinstance(m, nn.Linear):
240 | trunc_normal_(m.weight, std=.02)
241 | if isinstance(m, nn.Linear) and m.bias is not None:
242 | nn.init.constant_(m.bias, 0)
243 | elif isinstance(m, nn.LayerNorm):
244 | nn.init.constant_(m.bias, 0)
245 | nn.init.constant_(m.weight, 1.0)
246 |
247 | def get_num_layers(self):
248 | return len(self.blocks)
249 |
250 | @torch.jit.ignore
251 | def no_weight_decay(self):
252 | return {'pos_embed', 'cls_token'}
253 |
254 | def get_classifier(self):
255 | return self.head
256 |
257 | def reset_classifier(self, num_classes, global_pool=''):
258 | self.num_classes = num_classes
259 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
260 |
261 | def forward_features(self, x):
262 | x = self.patch_embed(x)
263 | B, _, _ = x.size()
264 |
265 | if self.pos_embed is not None:
266 | x = x + self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
267 | x = self.pos_drop(x)
268 |
269 | if self.use_checkpoint:
270 | for blk in self.blocks:
271 | x = checkpoint.checkpoint(blk, x)
272 | else:
273 | for blk in self.blocks:
274 | x = blk(x)
275 |
276 | x = self.norm(x)
277 | if self.fc_norm is not None:
278 | return self.fc_norm(x.mean(1))
279 | else:
280 | return x[:, 0]
281 |
282 | def forward(self, x):
283 | x = self.forward_features(x)
284 | x = self.head(self.fc_dropout(x))
285 | return x
286 |
287 |
288 | @register_model
289 | def vit_small_patch16_224(pretrained=False, **kwargs):
290 | model = VisionTransformer(
291 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
292 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
293 | model.default_cfg = _cfg()
294 | return model
295 |
296 |
297 | @register_model
298 | def vit_base_patch16_224(pretrained=False, **kwargs):
299 | model = VisionTransformer(
300 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
301 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
302 | model.default_cfg = _cfg()
303 | return model
304 |
305 |
306 | @register_model
307 | def vit_base_patch16_384(pretrained=False, **kwargs):
308 | model = VisionTransformer(
309 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
310 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
311 | model.default_cfg = _cfg()
312 | return model
313 |
314 |
315 | @register_model
316 | def vit_large_patch16_224(pretrained=False, **kwargs):
317 | model = VisionTransformer(
318 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
319 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
320 | model.default_cfg = _cfg()
321 | return model
322 |
323 |
324 | @register_model
325 | def vit_large_patch16_384(pretrained=False, **kwargs):
326 | model = VisionTransformer(
327 | img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
328 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
329 | model.default_cfg = _cfg()
330 | return model
331 |
332 |
333 | @register_model
334 | def vit_large_patch16_512(pretrained=False, **kwargs):
335 | model = VisionTransformer(
336 | img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
337 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
338 | model.default_cfg = _cfg()
339 | return model
340 |
341 |
342 | @register_model
343 | def vit_huge_patch16_224(pretrained=False, **kwargs):
344 | model = VisionTransformer(
345 | patch_size=16, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
346 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
347 | model.default_cfg = _cfg()
348 | return model
349 |
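
As a sanity check on the geometry defined by `PatchEmbed`, the short sketch below recomputes the token count for the default configuration (224x224 input, 16x16 patches, 16 frames, tubelet size 2); this is the sequence length that the sinusoid table and the transformer blocks operate on.

```python
# Token count produced by PatchEmbed for the default configuration.
img_size, patch_size = 224, 16
num_frames, tubelet_size = 16, 2

tokens_per_frame_pair = (img_size // patch_size) ** 2            # 14 * 14 = 196
num_patches = tokens_per_frame_pair * (num_frames // tubelet_size)
print(num_patches)   # 1568 tokens fed to the transformer blocks
```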
--------------------------------------------------------------------------------
/modeling_pretrain.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.checkpoint as checkpoint
6 | from functools import partial
7 |
8 | from modeling_finetune import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table
9 | from timm.models.registry import register_model
10 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_
11 |
12 |
13 |
14 | def trunc_normal_(tensor, mean=0., std=1.):
15 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
16 |
17 |
18 | __all__ = [
19 | 'pretrain_videomae_small_patch16_224',
20 | 'pretrain_videomae_base_patch16_224',
21 | 'pretrain_videomae_large_patch16_224',
22 | 'pretrain_videomae_huge_patch16_224',
23 | ]
24 |
25 |
26 | class PretrainVisionTransformerEncoder(nn.Module):
27 | """ Vision Transformer with support for patch or hybrid CNN input stage
28 | """
29 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
30 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
31 | drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, tubelet_size=2, use_checkpoint=False,
32 | use_learnable_pos_emb=False):
33 | super().__init__()
34 | self.num_classes = num_classes
35 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
36 | self.patch_embed = PatchEmbed(
37 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,tubelet_size=tubelet_size)
38 | num_patches = self.patch_embed.num_patches
39 | self.use_checkpoint = use_checkpoint
40 |
41 |
42 | # TODO: Add the cls token
43 | if use_learnable_pos_emb:
44 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
45 | else:
46 | # sine-cosine positional embeddings
47 | self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
48 |
49 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
50 | self.blocks = nn.ModuleList([
51 | Block(
52 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
53 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
54 | init_values=init_values)
55 | for i in range(depth)])
56 | self.norm = norm_layer(embed_dim)
57 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
58 |
59 | if use_learnable_pos_emb:
60 | trunc_normal_(self.pos_embed, std=.02)
61 |
62 | self.apply(self._init_weights)
63 |
64 |
65 | def _init_weights(self, m):
66 | if isinstance(m, nn.Linear):
67 | nn.init.xavier_uniform_(m.weight)
68 | if isinstance(m, nn.Linear) and m.bias is not None:
69 | nn.init.constant_(m.bias, 0)
70 | elif isinstance(m, nn.LayerNorm):
71 | nn.init.constant_(m.bias, 0)
72 | nn.init.constant_(m.weight, 1.0)
73 |
74 | def get_num_layers(self):
75 | return len(self.blocks)
76 |
77 | @torch.jit.ignore
78 | def no_weight_decay(self):
79 | return {'pos_embed', 'cls_token'}
80 |
81 | def get_classifier(self):
82 | return self.head
83 |
84 | def reset_classifier(self, num_classes, global_pool=''):
85 | self.num_classes = num_classes
86 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
87 |
88 | def forward_features(self, x, mask):
89 | _, _, T, _, _ = x.shape
90 | x = self.patch_embed(x)
91 |
92 | x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()
93 |
94 | B, _, C = x.shape
95 | x_vis = x[~mask].reshape(B, -1, C) # ~mask means visible
96 |
97 | if self.use_checkpoint:
98 | for blk in self.blocks:
99 | x_vis = checkpoint.checkpoint(blk, x_vis)
100 | else:
101 | for blk in self.blocks:
102 | x_vis = blk(x_vis)
103 |
104 | x_vis = self.norm(x_vis)
105 | return x_vis
106 |
107 | def forward(self, x, mask):
108 | x = self.forward_features(x, mask)
109 | x = self.head(x)
110 | return x
111 |
112 | class PretrainVisionTransformerDecoder(nn.Module):
113 | """ Vision Transformer with support for patch or hybrid CNN input stage
114 | """
115 | def __init__(self, patch_size=16, num_classes=768, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
116 | qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
117 | norm_layer=nn.LayerNorm, init_values=None, num_patches=196, tubelet_size=2, use_checkpoint=False
118 | ):
119 | super().__init__()
120 | self.num_classes = num_classes
121 | assert num_classes == 3 * tubelet_size * patch_size ** 2
122 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
123 | self.patch_size = patch_size
124 | self.use_checkpoint = use_checkpoint
125 |
126 |
127 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
128 | self.blocks = nn.ModuleList([
129 | Block(
130 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
131 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
132 | init_values=init_values)
133 | for i in range(depth)])
134 | self.norm = norm_layer(embed_dim)
135 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
136 |
137 | self.apply(self._init_weights)
138 |
139 |
140 | def _init_weights(self, m):
141 | if isinstance(m, nn.Linear):
142 | nn.init.xavier_uniform_(m.weight)
143 | if isinstance(m, nn.Linear) and m.bias is not None:
144 | nn.init.constant_(m.bias, 0)
145 | elif isinstance(m, nn.LayerNorm):
146 | nn.init.constant_(m.bias, 0)
147 | nn.init.constant_(m.weight, 1.0)
148 |
149 | def get_num_layers(self):
150 | return len(self.blocks)
151 |
152 | @torch.jit.ignore
153 | def no_weight_decay(self):
154 | return {'pos_embed', 'cls_token'}
155 |
156 | def get_classifier(self):
157 | return self.head
158 |
159 | def reset_classifier(self, num_classes, global_pool=''):
160 | self.num_classes = num_classes
161 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
162 |
163 | def forward(self, x, return_token_num):
164 | if self.use_checkpoint:
165 | for blk in self.blocks:
166 | x = checkpoint.checkpoint(blk, x)
167 | else:
168 | for blk in self.blocks:
169 | x = blk(x)
170 |
171 | if return_token_num > 0:
172 |             x = self.head(self.norm(x[:, -return_token_num:]))  # only return predicted pixels for the masked tokens
173 | else:
174 | x = self.head(self.norm(x))
175 |
176 | return x
177 |
178 | class PretrainVisionTransformer(nn.Module):
179 | """ Vision Transformer with support for patch or hybrid CNN input stage
180 | """
181 | def __init__(self,
182 | img_size=224,
183 | patch_size=16,
184 | encoder_in_chans=3,
185 | encoder_num_classes=0,
186 | encoder_embed_dim=768,
187 | encoder_depth=12,
188 | encoder_num_heads=12,
189 |                  decoder_num_classes=1536, # 3 * tubelet_size * patch_size**2; 768 would correspond to tubelet_size=1
190 | decoder_embed_dim=512,
191 | decoder_depth=8,
192 | decoder_num_heads=8,
193 | mlp_ratio=4.,
194 | qkv_bias=False,
195 | qk_scale=None,
196 | drop_rate=0.,
197 | attn_drop_rate=0.,
198 | drop_path_rate=0.,
199 | norm_layer=nn.LayerNorm,
200 | init_values=0.,
201 | use_learnable_pos_emb=False,
202 | use_checkpoint=False,
203 | tubelet_size=2,
204 | num_classes=0, # avoid the error from create_fn in timm
205 | in_chans=0, # avoid the error from create_fn in timm
206 | ):
207 | super().__init__()
208 | self.encoder = PretrainVisionTransformerEncoder(
209 | img_size=img_size,
210 | patch_size=patch_size,
211 | in_chans=encoder_in_chans,
212 | num_classes=encoder_num_classes,
213 | embed_dim=encoder_embed_dim,
214 | depth=encoder_depth,
215 | num_heads=encoder_num_heads,
216 | mlp_ratio=mlp_ratio,
217 | qkv_bias=qkv_bias,
218 | qk_scale=qk_scale,
219 | drop_rate=drop_rate,
220 | attn_drop_rate=attn_drop_rate,
221 | drop_path_rate=drop_path_rate,
222 | norm_layer=norm_layer,
223 | init_values=init_values,
224 | tubelet_size=tubelet_size,
225 | use_checkpoint=use_checkpoint,
226 | use_learnable_pos_emb=use_learnable_pos_emb)
227 |
228 | self.decoder = PretrainVisionTransformerDecoder(
229 | patch_size=patch_size,
230 | num_patches=self.encoder.patch_embed.num_patches,
231 | num_classes=decoder_num_classes,
232 | embed_dim=decoder_embed_dim,
233 | depth=decoder_depth,
234 | num_heads=decoder_num_heads,
235 | mlp_ratio=mlp_ratio,
236 | qkv_bias=qkv_bias,
237 | qk_scale=qk_scale,
238 | drop_rate=drop_rate,
239 | attn_drop_rate=attn_drop_rate,
240 | drop_path_rate=drop_path_rate,
241 | norm_layer=norm_layer,
242 | init_values=init_values,
243 | tubelet_size=tubelet_size,
244 | use_checkpoint=use_checkpoint)
245 |
246 | self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=False)
247 |
248 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
249 |
250 | self.pos_embed = get_sinusoid_encoding_table(self.encoder.patch_embed.num_patches, decoder_embed_dim)
251 |
252 | trunc_normal_(self.mask_token, std=.02)
253 |
254 |
255 | def _init_weights(self, m):
256 | if isinstance(m, nn.Linear):
257 | nn.init.xavier_uniform_(m.weight)
258 | if isinstance(m, nn.Linear) and m.bias is not None:
259 | nn.init.constant_(m.bias, 0)
260 | elif isinstance(m, nn.LayerNorm):
261 | nn.init.constant_(m.bias, 0)
262 | nn.init.constant_(m.weight, 1.0)
263 |
264 |     def get_num_layers(self):
265 |         return self.encoder.get_num_layers()  # the wrapper itself defines no `blocks`
266 |
267 | @torch.jit.ignore
268 | def no_weight_decay(self):
269 | return {'pos_embed', 'cls_token', 'mask_token'}
270 |
271 | def forward(self, x, mask):
272 | _, _, T, _, _ = x.shape
273 | x_vis = self.encoder(x, mask) # [B, N_vis, C_e]
274 | x_vis = self.encoder_to_decoder(x_vis) # [B, N_vis, C_d]
275 | B, N, C = x_vis.shape
276 |         # we don't unshuffle the visible tokens back to their original order;
277 |         # instead, the pos embedding is shuffled accordingly.
278 | expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
279 | pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
280 | pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C)
281 | x_full = torch.cat([x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1) # [B, N, C_d]
282 | x = self.decoder(x_full, pos_emd_mask.shape[1]) # [B, N_mask, 3 * 16 * 16]
283 |
284 | return x
285 |
286 | @register_model
287 | def pretrain_videomae_small_patch16_224(pretrained=False, **kwargs):
288 | model = PretrainVisionTransformer(
289 | img_size=224,
290 | patch_size=16,
291 | encoder_embed_dim=384,
292 | encoder_depth=12,
293 | encoder_num_heads=6,
294 | encoder_num_classes=0,
295 | decoder_num_classes=1536,
296 | decoder_embed_dim=192,
297 | decoder_num_heads=3,
298 | mlp_ratio=4,
299 | qkv_bias=True,
300 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
301 | **kwargs)
302 | model.default_cfg = _cfg()
303 | if pretrained:
304 | checkpoint = torch.load(
305 | kwargs["init_ckpt"], map_location="cpu"
306 | )
307 | model.load_state_dict(checkpoint["model"])
308 | return model
309 |
310 | @register_model
311 | def pretrain_videomae_base_patch16_224(pretrained=False, **kwargs):
312 | model = PretrainVisionTransformer(
313 | img_size=224,
314 | patch_size=16,
315 | encoder_embed_dim=768,
316 | encoder_depth=12,
317 | encoder_num_heads=12,
318 | encoder_num_classes=0,
319 | decoder_num_classes=1536,
320 | decoder_embed_dim=384,
321 | decoder_num_heads=6,
322 | mlp_ratio=4,
323 | qkv_bias=True,
324 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
325 | **kwargs)
326 | model.default_cfg = _cfg()
327 | if pretrained:
328 | checkpoint = torch.load(
329 | kwargs["init_ckpt"], map_location="cpu"
330 | )
331 | model.load_state_dict(checkpoint["model"])
332 | return model
333 |
334 | @register_model
335 | def pretrain_videomae_large_patch16_224(pretrained=False, **kwargs):
336 | model = PretrainVisionTransformer(
337 | img_size=224,
338 | patch_size=16,
339 | encoder_embed_dim=1024,
340 | encoder_depth=24,
341 | encoder_num_heads=16,
342 | encoder_num_classes=0,
343 | decoder_num_classes=1536,
344 | decoder_embed_dim=512,
345 | decoder_num_heads=8,
346 | mlp_ratio=4,
347 | qkv_bias=True,
348 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
349 | **kwargs)
350 | model.default_cfg = _cfg()
351 | if pretrained:
352 | checkpoint = torch.load(
353 | kwargs["init_ckpt"], map_location="cpu"
354 | )
355 | model.load_state_dict(checkpoint["model"])
356 | return model
357 |
358 | @register_model
359 | def pretrain_videomae_huge_patch16_224(pretrained=False, **kwargs):
360 | model = PretrainVisionTransformer(
361 | img_size=224,
362 | patch_size=16,
363 | encoder_embed_dim=1280,
364 | encoder_depth=32,
365 | encoder_num_heads=16,
366 | encoder_num_classes=0,
367 | decoder_num_classes=1536,
368 | decoder_embed_dim=640,
369 | decoder_num_heads=8,
370 | mlp_ratio=4,
371 | qkv_bias=True,
372 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
373 | **kwargs)
374 | model.default_cfg = _cfg()
375 | if pretrained:
376 | checkpoint = torch.load(
377 | kwargs["init_ckpt"], map_location="cpu"
378 | )
379 | model.load_state_dict(checkpoint["model"])
380 | return model
381 |
--------------------------------------------------------------------------------
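For orientation, here is a minimal sketch of driving one of the pretraining factories above end to end. The module name `modeling_pretrain`, the 16-frame clip length, and the 90% masking ratio are assumptions for illustration, not values fixed by this file.

```
import torch
from modeling_pretrain import pretrain_videomae_base_patch16_224  # module name assumed

model = pretrain_videomae_base_patch16_224(pretrained=False)

# With patch_size=16, tubelet_size=2 and (assumed) 16-frame 224x224 clips, the
# encoder sees (16 / 2) * (224 / 16) * (224 / 16) = 1568 tube tokens.
B = 2
num_patches = model.encoder.patch_embed.num_patches
videos = torch.randn(B, 3, 16, 224, 224)

# Boolean mask over tokens, True = masked. Every sample must mask the same number
# of tokens, otherwise the encoder's `x[~mask].reshape(B, -1, C)` is not rectangular.
mask = torch.zeros(B, num_patches, dtype=torch.bool)
mask[:, :int(0.9 * num_patches)] = True  # illustrative 90% masking ratio

with torch.no_grad():
    pred = model(videos, mask)  # [B, N_mask, 3 * tubelet_size * patch_size**2] = [B, N_mask, 1536]
print(pred.shape)
```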
/mixup.py:
--------------------------------------------------------------------------------
1 | """ Mixup and Cutmix
2 |
3 | Papers:
4 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
5 |
6 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
7 |
8 | Code Reference:
9 | CutMix: https://github.com/clovaai/CutMix-PyTorch
10 |
11 | Hacked together by / Copyright 2019, Ross Wightman
12 | """
13 | import numpy as np
14 | import torch
15 |
16 |
17 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
18 | x = x.long().view(-1, 1)
19 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
20 |
21 |
22 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
23 | off_value = smoothing / num_classes
24 | on_value = 1. - smoothing + off_value
25 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
26 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
27 | return y1 * lam + y2 * (1. - lam)
28 |
29 |
30 | def rand_bbox(img_shape, lam, margin=0., count=None):
31 | """ Standard CutMix bounding-box
32 | Generates a random square bbox based on lambda value. This impl includes
33 | support for enforcing a border margin as percent of bbox dimensions.
34 |
35 | Args:
36 | img_shape (tuple): Image shape as tuple
37 | lam (float): Cutmix lambda value
38 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
39 | count (int): Number of bbox to generate
40 | """
41 | ratio = np.sqrt(1 - lam)
42 | img_h, img_w = img_shape[-2:]
43 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
44 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
45 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
46 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
47 | yl = np.clip(cy - cut_h // 2, 0, img_h)
48 | yh = np.clip(cy + cut_h // 2, 0, img_h)
49 | xl = np.clip(cx - cut_w // 2, 0, img_w)
50 | xh = np.clip(cx + cut_w // 2, 0, img_w)
51 | return yl, yh, xl, xh
52 |
53 |
54 | def rand_bbox_minmax(img_shape, minmax, count=None):
55 | """ Min-Max CutMix bounding-box
56 | Inspired by Darknet cutmix impl, generates a random rectangular bbox
57 | based on min/max percent values applied to each dimension of the input image.
58 |
59 | Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.
60 |
61 | Args:
62 | img_shape (tuple): Image shape as tuple
63 | minmax (tuple or list): Min and max bbox ratios (as percent of image size)
64 | count (int): Number of bbox to generate
65 | """
66 | assert len(minmax) == 2
67 | img_h, img_w = img_shape[-2:]
68 | cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
69 | cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
70 | yl = np.random.randint(0, img_h - cut_h, size=count)
71 | xl = np.random.randint(0, img_w - cut_w, size=count)
72 | yu = yl + cut_h
73 | xu = xl + cut_w
74 | return yl, yu, xl, xu
75 |
76 |
77 | def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
78 | """ Generate bbox and apply lambda correction.
79 | """
80 | if ratio_minmax is not None:
81 | yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
82 | else:
83 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
84 | if correct_lam or ratio_minmax is not None:
85 | bbox_area = (yu - yl) * (xu - xl)
86 | lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
87 | return (yl, yu, xl, xu), lam
88 |
89 |
90 | class Mixup:
91 | """ Mixup/Cutmix that applies different params to each element or whole batch
92 |
93 | Args:
94 | mixup_alpha (float): mixup alpha value, mixup is active if > 0.
95 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
96 | cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
97 | prob (float): probability of applying mixup or cutmix per batch or element
98 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active
99 |         mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), or 'elem' (element))
100 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
101 | label_smoothing (float): apply label smoothing to the mixed target tensor
102 | num_classes (int): number of classes for target
103 | """
104 | def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
105 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
106 | self.mixup_alpha = mixup_alpha
107 | self.cutmix_alpha = cutmix_alpha
108 | self.cutmix_minmax = cutmix_minmax
109 | if self.cutmix_minmax is not None:
110 | assert len(self.cutmix_minmax) == 2
111 | # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
112 | self.cutmix_alpha = 1.0
113 | self.mix_prob = prob
114 | self.switch_prob = switch_prob
115 | self.label_smoothing = label_smoothing
116 | self.num_classes = num_classes
117 | self.mode = mode
118 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
119 |         self.mixup_enabled = True  # set to False to disable mixing (intended to be set by the train loop)
120 |
121 | def _params_per_elem(self, batch_size):
122 | lam = np.ones(batch_size, dtype=np.float32)
123 |         use_cutmix = np.zeros(batch_size, dtype=bool)  # np.bool was removed in NumPy >= 1.24
124 | if self.mixup_enabled:
125 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
126 | use_cutmix = np.random.rand(batch_size) < self.switch_prob
127 | lam_mix = np.where(
128 | use_cutmix,
129 | np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
130 | np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
131 | elif self.mixup_alpha > 0.:
132 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
133 | elif self.cutmix_alpha > 0.:
134 |                 use_cutmix = np.ones(batch_size, dtype=bool)
135 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
136 | else:
137 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
138 | lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
139 | return lam, use_cutmix
140 |
141 | def _params_per_batch(self):
142 | lam = 1.
143 | use_cutmix = False
144 | if self.mixup_enabled and np.random.rand() < self.mix_prob:
145 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
146 | use_cutmix = np.random.rand() < self.switch_prob
147 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
148 | np.random.beta(self.mixup_alpha, self.mixup_alpha)
149 | elif self.mixup_alpha > 0.:
150 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
151 | elif self.cutmix_alpha > 0.:
152 | use_cutmix = True
153 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
154 | else:
155 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
156 | lam = float(lam_mix)
157 | return lam, use_cutmix
158 |
159 | def _mix_elem(self, x):
160 | batch_size = len(x)
161 | lam_batch, use_cutmix = self._params_per_elem(batch_size)
162 | x_orig = x.clone() # need to keep an unmodified original for mixing source
163 | for i in range(batch_size):
164 | j = batch_size - i - 1
165 | lam = lam_batch[i]
166 | if lam != 1.:
167 | if use_cutmix[i]:
168 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
169 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
170 | x[i][..., yl:yh, xl:xh] = x_orig[j][..., yl:yh, xl:xh]
171 | lam_batch[i] = lam
172 | else:
173 | x[i] = x[i] * lam + x_orig[j] * (1 - lam)
174 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
175 |
176 | def _mix_pair(self, x):
177 | batch_size = len(x)
178 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
179 | x_orig = x.clone() # need to keep an unmodified original for mixing source
180 | for i in range(batch_size // 2):
181 | j = batch_size - i - 1
182 | lam = lam_batch[i]
183 | if lam != 1.:
184 | if use_cutmix[i]:
185 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
186 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
187 | x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
188 | x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
189 | lam_batch[i] = lam
190 | else:
191 | x[i] = x[i] * lam + x_orig[j] * (1 - lam)
192 | x[j] = x[j] * lam + x_orig[i] * (1 - lam)
193 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
194 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
195 |
196 | def _mix_batch(self, x):
197 | lam, use_cutmix = self._params_per_batch()
198 | if lam == 1.:
199 | return 1.
200 | if use_cutmix:
201 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
202 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
203 | x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh]
204 | else:
205 | x_flipped = x.flip(0).mul_(1. - lam)
206 | x.mul_(lam).add_(x_flipped)
207 | return lam
208 |
209 | def __call__(self, x, target):
210 | assert len(x) % 2 == 0, 'Batch size should be even when using this'
211 | if self.mode == 'elem':
212 | lam = self._mix_elem(x)
213 | elif self.mode == 'pair':
214 | lam = self._mix_pair(x)
215 | else:
216 | lam = self._mix_batch(x)
217 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
218 | return x, target
219 |
220 |
221 | class FastCollateMixup(Mixup):
222 | """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
223 |
224 | A Mixup impl that's performed while collating the batches.
225 | """
226 |
227 | def _mix_elem_collate(self, output, batch, half=False):
228 | batch_size = len(batch)
229 | num_elem = batch_size // 2 if half else batch_size
230 | assert len(output) == num_elem
231 | lam_batch, use_cutmix = self._params_per_elem(num_elem)
232 | for i in range(num_elem):
233 | j = batch_size - i - 1
234 | lam = lam_batch[i]
235 | mixed = batch[i][0]
236 | if lam != 1.:
237 | if use_cutmix[i]:
238 | if not half:
239 | mixed = mixed.copy()
240 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
241 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
242 | mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
243 | lam_batch[i] = lam
244 | else:
245 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
246 | np.rint(mixed, out=mixed)
247 | output[i] += torch.from_numpy(mixed.astype(np.uint8))
248 | if half:
249 | lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
250 | return torch.tensor(lam_batch).unsqueeze(1)
251 |
252 | def _mix_pair_collate(self, output, batch):
253 | batch_size = len(batch)
254 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
255 | for i in range(batch_size // 2):
256 | j = batch_size - i - 1
257 | lam = lam_batch[i]
258 | mixed_i = batch[i][0]
259 | mixed_j = batch[j][0]
260 | assert 0 <= lam <= 1.0
261 | if lam < 1.:
262 | if use_cutmix[i]:
263 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
264 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
265 | patch_i = mixed_i[:, yl:yh, xl:xh].copy()
266 | mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
267 | mixed_j[:, yl:yh, xl:xh] = patch_i
268 | lam_batch[i] = lam
269 | else:
270 | mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
271 | mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
272 | mixed_i = mixed_temp
273 | np.rint(mixed_j, out=mixed_j)
274 | np.rint(mixed_i, out=mixed_i)
275 | output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
276 | output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
277 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
278 | return torch.tensor(lam_batch).unsqueeze(1)
279 |
280 | def _mix_batch_collate(self, output, batch):
281 | batch_size = len(batch)
282 | lam, use_cutmix = self._params_per_batch()
283 | if use_cutmix:
284 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
285 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
286 | for i in range(batch_size):
287 | j = batch_size - i - 1
288 | mixed = batch[i][0]
289 | if lam != 1.:
290 | if use_cutmix:
291 | mixed = mixed.copy() # don't want to modify the original while iterating
292 | mixed[..., yl:yh, xl:xh] = batch[j][0][..., yl:yh, xl:xh]
293 | else:
294 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
295 | np.rint(mixed, out=mixed)
296 | output[i] += torch.from_numpy(mixed.astype(np.uint8))
297 | return lam
298 |
299 | def __call__(self, batch, _=None):
300 | batch_size = len(batch)
301 | assert batch_size % 2 == 0, 'Batch size should be even when using this'
302 | half = 'half' in self.mode
303 | if half:
304 | batch_size //= 2
305 | output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
306 | if self.mode == 'elem' or self.mode == 'half':
307 | lam = self._mix_elem_collate(output, batch, half=half)
308 | elif self.mode == 'pair':
309 | lam = self._mix_pair_collate(output, batch)
310 | else:
311 | lam = self._mix_batch_collate(output, batch)
312 | target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
313 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
314 | target = target[:batch_size]
315 | return output, target
316 |
317 |
--------------------------------------------------------------------------------
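A minimal sketch of how `Mixup` is typically attached to a classification training loop; the alpha values, class count (174 is the Something-Something V2 label count), and tensor shapes below are illustrative rather than prescribed by this file.

```
import torch
from mixup import Mixup  # module name taken from the file path above

mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
                 prob=1.0, switch_prob=0.5, mode='batch',
                 label_smoothing=0.1, num_classes=174)

videos = torch.randn(4, 3, 16, 224, 224)  # B C T H W; batch size must be even
targets = torch.randint(0, 174, (4,))

videos, targets = mixup_fn(videos, targets)
# The clips are mixed (CutMix boxes are cut over the last two dims, so the same
# box is applied to every frame) and the targets become soft [B, num_classes] labels.
print(targets.shape)  # torch.Size([4, 174])
```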
/ssv2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from torchvision import transforms
5 | from random_erasing import RandomErasing
6 | import warnings
7 | from decord import VideoReader, cpu
8 | from torch.utils.data import Dataset
9 | import video_transforms as video_transforms
10 | import volume_transforms as volume_transforms
11 |
12 |
13 | class SSVideoClsDataset(Dataset):
14 | """Load your own video classification dataset."""
15 |
16 | def __init__(self, anno_path, data_path, mode='train', clip_len=8,
17 | crop_size=224, short_side_size=256, new_height=256,
18 | new_width=340, keep_aspect_ratio=True, num_segment=1,
19 | num_crop=1, test_num_segment=10, test_num_crop=3, args=None):
20 | self.anno_path = anno_path
21 | self.data_path = data_path
22 | self.mode = mode
23 | self.clip_len = clip_len
24 | self.crop_size = crop_size
25 | self.short_side_size = short_side_size
26 | self.new_height = new_height
27 | self.new_width = new_width
28 | self.keep_aspect_ratio = keep_aspect_ratio
29 | self.num_segment = num_segment
30 | self.test_num_segment = test_num_segment
31 | self.num_crop = num_crop
32 | self.test_num_crop = test_num_crop
33 | self.args = args
34 | self.aug = False
35 | self.rand_erase = False
36 | if self.mode in ['train']:
37 | self.aug = True
38 | if self.args.reprob > 0:
39 | self.rand_erase = True
40 | if VideoReader is None:
41 | raise ImportError("Unable to import `decord` which is required to read videos.")
42 |
43 | import pandas as pd
44 | cleaned = pd.read_csv(self.anno_path, header=None, delimiter=' ')
45 | self.dataset_samples = list(cleaned.values[:, 0])
46 | self.label_array = list(cleaned.values[:, 1])
47 |
48 | if (mode == 'train'):
49 | pass
50 |
51 | elif (mode == 'validation'):
52 | self.data_transform = video_transforms.Compose([
53 | video_transforms.Resize(self.short_side_size, interpolation='bilinear'),
54 | video_transforms.CenterCrop(size=(self.crop_size, self.crop_size)),
55 | volume_transforms.ClipToTensor(),
56 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
57 | std=[0.229, 0.224, 0.225])
58 | ])
59 | elif mode == 'test':
60 | self.data_resize = video_transforms.Compose([
61 | video_transforms.Resize(size=(short_side_size), interpolation='bilinear')
62 | ])
63 | self.data_transform = video_transforms.Compose([
64 | volume_transforms.ClipToTensor(),
65 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
66 | std=[0.229, 0.224, 0.225])
67 | ])
68 | self.test_seg = []
69 | self.test_dataset = []
70 | self.test_label_array = []
71 | for ck in range(self.test_num_segment):
72 | for cp in range(self.test_num_crop):
73 | for idx in range(len(self.label_array)):
74 | sample_label = self.label_array[idx]
75 | self.test_label_array.append(sample_label)
76 | self.test_dataset.append(self.dataset_samples[idx])
77 | self.test_seg.append((ck, cp))
78 |
79 | def __getitem__(self, index):
80 | if self.mode == 'train':
81 | args = self.args
82 | scale_t = 1
83 |
84 | sample = self.dataset_samples[index]
85 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) # T H W C
86 | if len(buffer) == 0:
87 | while len(buffer) == 0:
88 | warnings.warn("video {} not correctly loaded during training".format(sample))
89 | index = np.random.randint(self.__len__())
90 | sample = self.dataset_samples[index]
91 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t)
92 |
93 | if args.num_sample > 1:
94 | frame_list = []
95 | label_list = []
96 | index_list = []
97 | for _ in range(args.num_sample):
98 | new_frames = self._aug_frame(buffer, args)
99 | label = self.label_array[index]
100 | frame_list.append(new_frames)
101 | label_list.append(label)
102 | index_list.append(index)
103 | return frame_list, label_list, index_list, {}
104 | else:
105 | buffer = self._aug_frame(buffer, args)
106 |
107 | return buffer, self.label_array[index], index, {}
108 |
109 | elif self.mode == 'validation':
110 | sample = self.dataset_samples[index]
111 | buffer = self.loadvideo_decord(sample)
112 | if len(buffer) == 0:
113 | while len(buffer) == 0:
114 | warnings.warn("video {} not correctly loaded during validation".format(sample))
115 | index = np.random.randint(self.__len__())
116 | sample = self.dataset_samples[index]
117 | buffer = self.loadvideo_decord(sample)
118 | buffer = self.data_transform(buffer)
119 | return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0]
120 |
121 | elif self.mode == 'test':
122 | sample = self.test_dataset[index]
123 | chunk_nb, split_nb = self.test_seg[index]
124 | buffer = self.loadvideo_decord(sample)
125 |
126 | while len(buffer) == 0:
127 | warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\
128 | str(self.test_dataset[index]), chunk_nb, split_nb))
129 | index = np.random.randint(self.__len__())
130 | sample = self.test_dataset[index]
131 | chunk_nb, split_nb = self.test_seg[index]
132 | buffer = self.loadvideo_decord(sample)
133 |
134 | buffer = self.data_resize(buffer)
135 | if isinstance(buffer, list):
136 | buffer = np.stack(buffer, 0)
137 |
138 | spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
139 | / (self.test_num_crop - 1)
140 | temporal_start = chunk_nb # 0/1
141 | spatial_start = int(split_nb * spatial_step)
142 | if buffer.shape[1] >= buffer.shape[2]:
143 | buffer = buffer[temporal_start::2, \
144 | spatial_start:spatial_start + self.short_side_size, :, :]
145 | else:
146 | buffer = buffer[temporal_start::2, \
147 | :, spatial_start:spatial_start + self.short_side_size, :]
148 |
149 | buffer = self.data_transform(buffer)
150 | return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \
151 | chunk_nb, split_nb
152 | else:
153 |             raise NameError('mode {} unknown'.format(self.mode))
154 |
155 | def _aug_frame(
156 | self,
157 | buffer,
158 | args,
159 | ):
160 |
161 | buffer = [
162 | transforms.ToPILImage()(frame) for frame in buffer
163 | ]
164 |
165 | if args.is_aa:
166 | aug_transform = video_transforms.create_random_augment(
167 | input_size=(self.crop_size, self.crop_size),
168 | auto_augment=args.aa,
169 | interpolation=args.train_interpolation,
170 | )
171 | buffer = aug_transform(buffer)
172 |
173 | buffer = [transforms.ToTensor()(img) for img in buffer]
174 | buffer = torch.stack(buffer) # T C H W
175 | buffer = buffer.permute(0, 2, 3, 1) # T H W C
176 |
177 | # T H W C
178 | buffer = tensor_normalize(
179 | buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
180 | )
181 | # T H W C -> C T H W.
182 | buffer = buffer.permute(3, 0, 1, 2)
183 | # Perform data augmentation.
184 | scl, asp = (
185 | [0.08, 1.0],
186 | [0.75, 1.3333],
187 | )
188 |
189 | buffer = spatial_sampling(
190 | buffer,
191 | spatial_idx=-1,
192 | min_scale=256,
193 | max_scale=320,
194 | crop_size=self.crop_size,
195 | random_horizontal_flip=False if args.data_set == 'SSV2' else True,
196 | inverse_uniform_sampling=False,
197 | aspect_ratio=asp,
198 | scale=scl,
199 | motion_shift=False
200 | )
201 |
202 | # No random erase for linear probing or prompting.
203 | # if self.rand_erase:
204 | # erase_transform = RandomErasing(
205 | # args.reprob,
206 | # mode=args.remode,
207 | # max_count=args.recount,
208 | # num_splits=args.recount,
209 | # device="cpu",
210 | # )
211 | # buffer = buffer.permute(1, 0, 2, 3)
212 | # buffer = erase_transform(buffer)
213 | # buffer = buffer.permute(1, 0, 2, 3)
214 |
215 | return buffer
216 |
217 |
218 | def loadvideo_decord(self, sample, sample_rate_scale=1):
219 | """Load video content using Decord"""
220 | fname = sample
221 |
222 | if not (os.path.exists(fname)):
223 | return []
224 |
225 | # avoid hanging issue
226 | if os.path.getsize(fname) < 1 * 1024:
227 | print('SKIP: ', fname, " - ", os.path.getsize(fname))
228 | return []
229 | try:
230 | if self.keep_aspect_ratio:
231 | vr = VideoReader(fname, num_threads=1, ctx=cpu(0))
232 | else:
233 | vr = VideoReader(fname, width=self.new_width, height=self.new_height,
234 | num_threads=1, ctx=cpu(0))
235 |         except Exception:
236 | print("video cannot be loaded by decord: ", fname)
237 | return []
238 |
239 | if self.mode == 'test':
240 | all_index = []
241 | tick = len(vr) / float(self.num_segment)
242 | all_index = list(np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segment)] +
243 | [int(tick * x) for x in range(self.num_segment)]))
244 | while len(all_index) < (self.num_segment * self.test_num_segment):
245 | all_index.append(all_index[-1])
246 | all_index = list(np.sort(np.array(all_index)))
247 | vr.seek(0)
248 | buffer = vr.get_batch(all_index).asnumpy()
249 | return buffer
250 |
251 | # handle temporal segments
252 | average_duration = len(vr) // self.num_segment
253 | all_index = []
254 | if average_duration > 0:
255 | all_index += list(np.multiply(list(range(self.num_segment)), average_duration) + np.random.randint(average_duration,
256 | size=self.num_segment))
257 | elif len(vr) > self.num_segment:
258 | all_index += list(np.sort(np.random.randint(len(vr), size=self.num_segment)))
259 | else:
260 | all_index += list(np.zeros((self.num_segment,)))
261 | all_index = list(np.array(all_index))
262 | vr.seek(0)
263 | buffer = vr.get_batch(all_index).asnumpy()
264 | return buffer
265 |
266 | def __len__(self):
267 | if self.mode != 'test':
268 | return len(self.dataset_samples)
269 | else:
270 | return len(self.test_dataset)
271 |
272 |
273 | def spatial_sampling(
274 | frames,
275 | spatial_idx=-1,
276 | min_scale=256,
277 | max_scale=320,
278 | crop_size=224,
279 | random_horizontal_flip=True,
280 | inverse_uniform_sampling=False,
281 | aspect_ratio=None,
282 | scale=None,
283 | motion_shift=False,
284 | ):
285 | """
286 | Perform spatial sampling on the given video frames. If spatial_idx is
287 | -1, perform random scale, random crop, and random flip on the given
288 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
289 | with the given spatial_idx.
290 | Args:
291 | frames (tensor): frames of images sampled from the video. The
292 | dimension is `num frames` x `height` x `width` x `channel`.
293 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
294 | or 2, perform left, center, right crop if width is larger than
295 | height, and perform top, center, buttom crop if height is larger
296 | than width.
297 | min_scale (int): the minimal size of scaling.
298 | max_scale (int): the maximal size of scaling.
299 | crop_size (int): the size of height and width used to crop the
300 | frames.
301 | inverse_uniform_sampling (bool): if True, sample uniformly in
302 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
303 | scale. If False, take a uniform sample from [min_scale,
304 | max_scale].
305 | aspect_ratio (list): Aspect ratio range for resizing.
306 | scale (list): Scale range for resizing.
307 | motion_shift (bool): Whether to apply motion shift for resizing.
308 | Returns:
309 | frames (tensor): spatially sampled frames.
310 | """
311 | assert spatial_idx in [-1, 0, 1, 2]
312 | if spatial_idx == -1:
313 | if aspect_ratio is None and scale is None:
314 | frames, _ = video_transforms.random_short_side_scale_jitter(
315 | images=frames,
316 | min_size=min_scale,
317 | max_size=max_scale,
318 | inverse_uniform_sampling=inverse_uniform_sampling,
319 | )
320 | frames, _ = video_transforms.random_crop(frames, crop_size)
321 | else:
322 | transform_func = (
323 | video_transforms.random_resized_crop_with_shift
324 | if motion_shift
325 | else video_transforms.random_resized_crop
326 | )
327 | frames = transform_func(
328 | images=frames,
329 | target_height=crop_size,
330 | target_width=crop_size,
331 | scale=scale,
332 | ratio=aspect_ratio,
333 | )
334 | if random_horizontal_flip:
335 | frames, _ = video_transforms.horizontal_flip(0.5, frames)
336 | else:
337 | # The testing is deterministic and no jitter should be performed.
338 |         # min_scale, max_scale, and crop_size are expected to be the same.
339 | assert len({min_scale, max_scale, crop_size}) == 1
340 | frames, _ = video_transforms.random_short_side_scale_jitter(
341 | frames, min_scale, max_scale
342 | )
343 | frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx)
344 | return frames
345 |
346 |
347 | def tensor_normalize(tensor, mean, std):
348 | """
349 | Normalize a given tensor by subtracting the mean and dividing the std.
350 | Args:
351 | tensor (tensor): tensor to normalize.
352 | mean (tensor or list): mean value to subtract.
353 | std (tensor or list): std to divide.
354 | """
355 | if tensor.dtype == torch.uint8:
356 | tensor = tensor.float()
357 | tensor = tensor / 255.0
358 | if type(mean) == list:
359 | mean = torch.tensor(mean)
360 | if type(std) == list:
361 | std = torch.tensor(std)
362 | tensor = tensor - mean
363 | tensor = tensor / std
364 | return tensor
365 |
--------------------------------------------------------------------------------
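A minimal sketch of constructing the loader for evaluation; the csv path is a placeholder, and `mode='validation'` is chosen here because that branch never touches the augmentation-related fields of `args`.

```
from ssv2 import SSVideoClsDataset  # module name taken from the file path above

val_dataset = SSVideoClsDataset(
    anno_path='path/to/val.csv',  # space-separated "video_path label" lines
    data_path='',                 # stored but not joined with the csv paths in this loader
    mode='validation',
    clip_len=1,
    num_segment=16,               # loadvideo_decord samples one frame per segment
    crop_size=224,
    short_side_size=224,
    args=None)

frames, label, video_id = val_dataset[0]
print(frames.shape, label, video_id)  # roughly [3, 16, 224, 224] after ClipToTensor
```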
/rand_augment.py:
--------------------------------------------------------------------------------
1 | """
2 | This implementation is based on
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py
4 | published under an Apache License 2.0.
5 |
6 | COMMENT FROM ORIGINAL:
7 | AutoAugment, RandAugment, and AugMix for PyTorch
8 | This code implements the searched ImageNet policies with various tweaks and
9 | improvements and does not include any of the search code. AA and RA
10 | Implementation adapted from:
11 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
12 | AugMix adapted from:
13 | https://github.com/google-research/augmix
14 | Papers:
15 | AutoAugment: Learning Augmentation Policies from Data
16 | https://arxiv.org/abs/1805.09501
17 | Learning Data Augmentation Strategies for Object Detection
18 | https://arxiv.org/abs/1906.11172
19 | RandAugment: Practical automated data augmentation...
20 | https://arxiv.org/abs/1909.13719
21 | AugMix: A Simple Data Processing Method to Improve Robustness and
22 | Uncertainty https://arxiv.org/abs/1912.02781
23 |
24 | Hacked together by / Copyright 2020 Ross Wightman
25 | """
26 |
27 | import math
28 | import numpy as np
29 | import random
30 | import re
31 | import PIL
32 | from PIL import Image, ImageEnhance, ImageOps
33 |
34 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]])
35 |
36 | _FILL = (128, 128, 128)
37 |
38 | # This signifies the max integer that the controller RNN could predict for the
39 | # augmentation scheme.
40 | _MAX_LEVEL = 10.0
41 |
42 | _HPARAMS_DEFAULT = {
43 | "translate_const": 250,
44 | "img_mean": _FILL,
45 | }
46 |
47 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
48 |
49 |
50 | def _interpolation(kwargs):
51 | interpolation = kwargs.pop("resample", Image.BILINEAR)
52 | if isinstance(interpolation, (list, tuple)):
53 | return random.choice(interpolation)
54 | else:
55 | return interpolation
56 |
57 |
58 | def _check_args_tf(kwargs):
59 | if "fillcolor" in kwargs and _PIL_VER < (5, 0):
60 | kwargs.pop("fillcolor")
61 | kwargs["resample"] = _interpolation(kwargs)
62 |
63 |
64 | def shear_x(img, factor, **kwargs):
65 | _check_args_tf(kwargs)
66 | return img.transform(
67 | img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs
68 | )
69 |
70 |
71 | def shear_y(img, factor, **kwargs):
72 | _check_args_tf(kwargs)
73 | return img.transform(
74 | img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs
75 | )
76 |
77 |
78 | def translate_x_rel(img, pct, **kwargs):
79 | pixels = pct * img.size[0]
80 | _check_args_tf(kwargs)
81 | return img.transform(
82 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
83 | )
84 |
85 |
86 | def translate_y_rel(img, pct, **kwargs):
87 | pixels = pct * img.size[1]
88 | _check_args_tf(kwargs)
89 | return img.transform(
90 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
91 | )
92 |
93 |
94 | def translate_x_abs(img, pixels, **kwargs):
95 | _check_args_tf(kwargs)
96 | return img.transform(
97 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
98 | )
99 |
100 |
101 | def translate_y_abs(img, pixels, **kwargs):
102 | _check_args_tf(kwargs)
103 | return img.transform(
104 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
105 | )
106 |
107 |
108 | def rotate(img, degrees, **kwargs):
109 | _check_args_tf(kwargs)
110 | if _PIL_VER >= (5, 2):
111 | return img.rotate(degrees, **kwargs)
112 | elif _PIL_VER >= (5, 0):
113 | w, h = img.size
114 | post_trans = (0, 0)
115 | rotn_center = (w / 2.0, h / 2.0)
116 | angle = -math.radians(degrees)
117 | matrix = [
118 | round(math.cos(angle), 15),
119 | round(math.sin(angle), 15),
120 | 0.0,
121 | round(-math.sin(angle), 15),
122 | round(math.cos(angle), 15),
123 | 0.0,
124 | ]
125 |
126 | def transform(x, y, matrix):
127 | (a, b, c, d, e, f) = matrix
128 | return a * x + b * y + c, d * x + e * y + f
129 |
130 | matrix[2], matrix[5] = transform(
131 | -rotn_center[0] - post_trans[0],
132 | -rotn_center[1] - post_trans[1],
133 | matrix,
134 | )
135 | matrix[2] += rotn_center[0]
136 | matrix[5] += rotn_center[1]
137 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
138 | else:
139 | return img.rotate(degrees, resample=kwargs["resample"])
140 |
141 |
142 | def auto_contrast(img, **__):
143 | return ImageOps.autocontrast(img)
144 |
145 |
146 | def invert(img, **__):
147 | return ImageOps.invert(img)
148 |
149 |
150 | def equalize(img, **__):
151 | return ImageOps.equalize(img)
152 |
153 |
154 | def solarize(img, thresh, **__):
155 | return ImageOps.solarize(img, thresh)
156 |
157 |
158 | def solarize_add(img, add, thresh=128, **__):
159 | lut = []
160 | for i in range(256):
161 | if i < thresh:
162 | lut.append(min(255, i + add))
163 | else:
164 | lut.append(i)
165 | if img.mode in ("L", "RGB"):
166 | if img.mode == "RGB" and len(lut) == 256:
167 | lut = lut + lut + lut
168 | return img.point(lut)
169 | else:
170 | return img
171 |
172 |
173 | def posterize(img, bits_to_keep, **__):
174 | if bits_to_keep >= 8:
175 | return img
176 | return ImageOps.posterize(img, bits_to_keep)
177 |
178 |
179 | def contrast(img, factor, **__):
180 | return ImageEnhance.Contrast(img).enhance(factor)
181 |
182 |
183 | def color(img, factor, **__):
184 | return ImageEnhance.Color(img).enhance(factor)
185 |
186 |
187 | def brightness(img, factor, **__):
188 | return ImageEnhance.Brightness(img).enhance(factor)
189 |
190 |
191 | def sharpness(img, factor, **__):
192 | return ImageEnhance.Sharpness(img).enhance(factor)
193 |
194 |
195 | def _randomly_negate(v):
196 | """With 50% prob, negate the value"""
197 | return -v if random.random() > 0.5 else v
198 |
199 |
200 | def _rotate_level_to_arg(level, _hparams):
201 | # range [-30, 30]
202 | level = (level / _MAX_LEVEL) * 30.0
203 | level = _randomly_negate(level)
204 | return (level,)
205 |
206 |
207 | def _enhance_level_to_arg(level, _hparams):
208 | # range [0.1, 1.9]
209 | return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
210 |
211 |
212 | def _enhance_increasing_level_to_arg(level, _hparams):
213 | # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
214 | # range [0.1, 1.9]
215 | level = (level / _MAX_LEVEL) * 0.9
216 | level = 1.0 + _randomly_negate(level)
217 | return (level,)
218 |
219 |
220 | def _shear_level_to_arg(level, _hparams):
221 | # range [-0.3, 0.3]
222 | level = (level / _MAX_LEVEL) * 0.3
223 | level = _randomly_negate(level)
224 | return (level,)
225 |
226 |
227 | def _translate_abs_level_to_arg(level, hparams):
228 | translate_const = hparams["translate_const"]
229 | level = (level / _MAX_LEVEL) * float(translate_const)
230 | level = _randomly_negate(level)
231 | return (level,)
232 |
233 |
234 | def _translate_rel_level_to_arg(level, hparams):
235 | # default range [-0.45, 0.45]
236 | translate_pct = hparams.get("translate_pct", 0.45)
237 | level = (level / _MAX_LEVEL) * translate_pct
238 | level = _randomly_negate(level)
239 | return (level,)
240 |
241 |
242 | def _posterize_level_to_arg(level, _hparams):
243 | # As per Tensorflow TPU EfficientNet impl
244 | # range [0, 4], 'keep 0 up to 4 MSB of original image'
245 | # intensity/severity of augmentation decreases with level
246 | return (int((level / _MAX_LEVEL) * 4),)
247 |
248 |
249 | def _posterize_increasing_level_to_arg(level, hparams):
250 | # As per Tensorflow models research and UDA impl
251 | # range [4, 0], 'keep 4 down to 0 MSB of original image',
252 | # intensity/severity of augmentation increases with level
253 | return (4 - _posterize_level_to_arg(level, hparams)[0],)
254 |
255 |
256 | def _posterize_original_level_to_arg(level, _hparams):
257 | # As per original AutoAugment paper description
258 | # range [4, 8], 'keep 4 up to 8 MSB of image'
259 | # intensity/severity of augmentation decreases with level
260 | return (int((level / _MAX_LEVEL) * 4) + 4,)
261 |
262 |
263 | def _solarize_level_to_arg(level, _hparams):
264 | # range [0, 256]
265 | # intensity/severity of augmentation decreases with level
266 | return (int((level / _MAX_LEVEL) * 256),)
267 |
268 |
269 | def _solarize_increasing_level_to_arg(level, _hparams):
270 | # range [0, 256]
271 | # intensity/severity of augmentation increases with level
272 | return (256 - _solarize_level_to_arg(level, _hparams)[0],)
273 |
274 |
275 | def _solarize_add_level_to_arg(level, _hparams):
276 | # range [0, 110]
277 | return (int((level / _MAX_LEVEL) * 110),)
278 |
279 |
280 | LEVEL_TO_ARG = {
281 | "AutoContrast": None,
282 | "Equalize": None,
283 | "Invert": None,
284 | "Rotate": _rotate_level_to_arg,
285 | # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
286 | "Posterize": _posterize_level_to_arg,
287 | "PosterizeIncreasing": _posterize_increasing_level_to_arg,
288 | "PosterizeOriginal": _posterize_original_level_to_arg,
289 | "Solarize": _solarize_level_to_arg,
290 | "SolarizeIncreasing": _solarize_increasing_level_to_arg,
291 | "SolarizeAdd": _solarize_add_level_to_arg,
292 | "Color": _enhance_level_to_arg,
293 | "ColorIncreasing": _enhance_increasing_level_to_arg,
294 | "Contrast": _enhance_level_to_arg,
295 | "ContrastIncreasing": _enhance_increasing_level_to_arg,
296 | "Brightness": _enhance_level_to_arg,
297 | "BrightnessIncreasing": _enhance_increasing_level_to_arg,
298 | "Sharpness": _enhance_level_to_arg,
299 | "SharpnessIncreasing": _enhance_increasing_level_to_arg,
300 | "ShearX": _shear_level_to_arg,
301 | "ShearY": _shear_level_to_arg,
302 | "TranslateX": _translate_abs_level_to_arg,
303 | "TranslateY": _translate_abs_level_to_arg,
304 | "TranslateXRel": _translate_rel_level_to_arg,
305 | "TranslateYRel": _translate_rel_level_to_arg,
306 | }
307 |
308 |
309 | NAME_TO_OP = {
310 | "AutoContrast": auto_contrast,
311 | "Equalize": equalize,
312 | "Invert": invert,
313 | "Rotate": rotate,
314 | "Posterize": posterize,
315 | "PosterizeIncreasing": posterize,
316 | "PosterizeOriginal": posterize,
317 | "Solarize": solarize,
318 | "SolarizeIncreasing": solarize,
319 | "SolarizeAdd": solarize_add,
320 | "Color": color,
321 | "ColorIncreasing": color,
322 | "Contrast": contrast,
323 | "ContrastIncreasing": contrast,
324 | "Brightness": brightness,
325 | "BrightnessIncreasing": brightness,
326 | "Sharpness": sharpness,
327 | "SharpnessIncreasing": sharpness,
328 | "ShearX": shear_x,
329 | "ShearY": shear_y,
330 | "TranslateX": translate_x_abs,
331 | "TranslateY": translate_y_abs,
332 | "TranslateXRel": translate_x_rel,
333 | "TranslateYRel": translate_y_rel,
334 | }
335 |
336 |
337 | class AugmentOp:
338 | """
339 |     Apply a single augmentation op to an image or to a list of video frames.
340 | """
341 |
342 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
343 | hparams = hparams or _HPARAMS_DEFAULT
344 | self.aug_fn = NAME_TO_OP[name]
345 | self.level_fn = LEVEL_TO_ARG[name]
346 | self.prob = prob
347 | self.magnitude = magnitude
348 | self.hparams = hparams.copy()
349 | self.kwargs = {
350 | "fillcolor": hparams["img_mean"]
351 | if "img_mean" in hparams
352 | else _FILL,
353 | "resample": hparams["interpolation"]
354 | if "interpolation" in hparams
355 | else _RANDOM_INTERPOLATION,
356 | }
357 |
358 | # If magnitude_std is > 0, we introduce some randomness
359 | # in the usually fixed policy and sample magnitude from a normal distribution
360 | # with mean `magnitude` and std-dev of `magnitude_std`.
361 | # NOTE This is my own hack, being tested, not in papers or reference impls.
362 | self.magnitude_std = self.hparams.get("magnitude_std", 0)
363 |
364 | def __call__(self, img_list):
365 | if self.prob < 1.0 and random.random() > self.prob:
366 | return img_list
367 | magnitude = self.magnitude
368 | if self.magnitude_std and self.magnitude_std > 0:
369 | magnitude = random.gauss(magnitude, self.magnitude_std)
370 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
371 | level_args = (
372 | self.level_fn(magnitude, self.hparams)
373 | if self.level_fn is not None
374 | else ()
375 | )
376 |
377 | if isinstance(img_list, list):
378 | return [
379 | self.aug_fn(img, *level_args, **self.kwargs) for img in img_list
380 | ]
381 | else:
382 | return self.aug_fn(img_list, *level_args, **self.kwargs)
383 |
384 |
385 | _RAND_TRANSFORMS = [
386 | "AutoContrast",
387 | "Equalize",
388 | "Invert",
389 | "Rotate",
390 | "Posterize",
391 | "Solarize",
392 | "SolarizeAdd",
393 | "Color",
394 | "Contrast",
395 | "Brightness",
396 | "Sharpness",
397 | "ShearX",
398 | "ShearY",
399 | "TranslateXRel",
400 | "TranslateYRel",
401 | ]
402 |
403 |
404 | _RAND_INCREASING_TRANSFORMS = [
405 | "AutoContrast",
406 | "Equalize",
407 | "Invert",
408 | "Rotate",
409 | "PosterizeIncreasing",
410 | "SolarizeIncreasing",
411 | "SolarizeAdd",
412 | "ColorIncreasing",
413 | "ContrastIncreasing",
414 | "BrightnessIncreasing",
415 | "SharpnessIncreasing",
416 | "ShearX",
417 | "ShearY",
418 | "TranslateXRel",
419 | "TranslateYRel",
420 | ]
421 |
422 |
423 | # These experimental weights are based loosely on the relative improvements mentioned in the paper.
424 | # They may not result in increased performance, but could likely be tuned to do so.
425 | _RAND_CHOICE_WEIGHTS_0 = {
426 | "Rotate": 0.3,
427 | "ShearX": 0.2,
428 | "ShearY": 0.2,
429 | "TranslateXRel": 0.1,
430 | "TranslateYRel": 0.1,
431 | "Color": 0.025,
432 | "Sharpness": 0.025,
433 | "AutoContrast": 0.025,
434 | "Solarize": 0.005,
435 | "SolarizeAdd": 0.005,
436 | "Contrast": 0.005,
437 | "Brightness": 0.005,
438 | "Equalize": 0.005,
439 | "Posterize": 0,
440 | "Invert": 0,
441 | }
442 |
443 |
444 | def _select_rand_weights(weight_idx=0, transforms=None):
445 | transforms = transforms or _RAND_TRANSFORMS
446 | assert weight_idx == 0 # only one set of weights currently
447 | rand_weights = _RAND_CHOICE_WEIGHTS_0
448 | probs = [rand_weights[k] for k in transforms]
449 | probs /= np.sum(probs)
450 | return probs
451 |
452 |
453 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
454 | hparams = hparams or _HPARAMS_DEFAULT
455 | transforms = transforms or _RAND_TRANSFORMS
456 | return [
457 | AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams)
458 | for name in transforms
459 | ]
460 |
461 |
462 | class RandAugment:
463 | def __init__(self, ops, num_layers=2, choice_weights=None):
464 | self.ops = ops
465 | self.num_layers = num_layers
466 | self.choice_weights = choice_weights
467 |
468 | def __call__(self, img):
469 | # no replacement when using weighted choice
470 | ops = np.random.choice(
471 | self.ops,
472 | self.num_layers,
473 | replace=self.choice_weights is None,
474 | p=self.choice_weights,
475 | )
476 | for op in ops:
477 | img = op(img)
478 | return img
479 |
480 |
481 | def rand_augment_transform(config_str, hparams):
482 | """
483 | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
484 |
485 | Create a RandAugment transform
486 | :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
487 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
488 |     sections, which are not order-specific, determine
489 | 'm' - integer magnitude of rand augment
490 | 'n' - integer num layers (number of transform ops selected per image)
491 |     'w' - integer probability weight index (index of a set of weights to influence choice of op)
492 | 'mstd' - float std deviation of magnitude noise applied
493 | 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
494 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
495 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
496 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
497 | :return: A PyTorch compatible Transform
498 | """
499 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
500 | num_layers = 2 # default to 2 ops per image
501 | weight_idx = None # default to no probability weights for op choice
502 | transforms = _RAND_TRANSFORMS
503 | config = config_str.split("-")
504 | assert config[0] == "rand"
505 | config = config[1:]
506 | for c in config:
507 | cs = re.split(r"(\d.*)", c)
508 | if len(cs) < 2:
509 | continue
510 | key, val = cs[:2]
511 | if key == "mstd":
512 | # noise param injected via hparams for now
513 | hparams.setdefault("magnitude_std", float(val))
514 | elif key == "inc":
515 | if bool(val):
516 | transforms = _RAND_INCREASING_TRANSFORMS
517 | elif key == "m":
518 | magnitude = int(val)
519 | elif key == "n":
520 | num_layers = int(val)
521 | elif key == "w":
522 | weight_idx = int(val)
523 | else:
524 |             raise NotImplementedError("unknown RandAugment config section: " + key)
525 | ra_ops = rand_augment_ops(
526 | magnitude=magnitude, hparams=hparams, transforms=transforms
527 | )
528 | choice_weights = (
529 | None if weight_idx is None else _select_rand_weights(weight_idx)
530 | )
531 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
532 |
--------------------------------------------------------------------------------
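A minimal sketch of applying this RandAugment policy directly to a list of frames (during fine-tuning it is normally reached indirectly through `video_transforms.create_random_augment`); the config string follows the grammar documented in `rand_augment_transform`, and the hparams mirror `_HPARAMS_DEFAULT`.

```
from PIL import Image
from rand_augment import rand_augment_transform  # module name taken from the file path above

# magnitude 7, 4 ops per clip, magnitude noise std 0.5, "increasing" op set
aug = rand_augment_transform(
    "rand-m7-n4-mstd0.5-inc1",
    hparams={"translate_const": 250, "img_mean": (128, 128, 128)},
)

# Each selected AugmentOp applies the same sampled op and magnitude to every frame
# in the list, keeping the augmentation temporally consistent across a clip.
frames = [Image.new("RGB", (224, 224)) for _ in range(16)]
augmented = aug(frames)
print(len(augmented), augmented[0].size)
```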