├── figs
│   ├── CAST.jpg
│   └── ds_report.JPG
├── util_tools
│   ├── masking_generator.py
│   ├── functional.py
│   ├── volume_transforms.py
│   ├── random_erasing.py
│   ├── optim_factory.py
│   ├── transforms.py
│   ├── mixup.py
│   └── rand_augment.py
├── scripts
│   ├── ssv2
│   │   └── ssv2.sh
│   ├── ek100
│   │   └── ek100.sh
│   └── kinetics
│       └── k400.sh
├── dataset
│   ├── datasets.py
│   ├── ssv2.py
│   └── epic.py
├── README.md
├── engine_for_onemodel.py
├── engine_for_compomodel.py
├── LICENSE
└── models
    ├── videomae_modelling_finetune.py
    └── clip_modelling_finetune.py
/figs/CAST.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KHU-VLL/CAST/HEAD/figs/CAST.jpg
--------------------------------------------------------------------------------
/figs/ds_report.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KHU-VLL/CAST/HEAD/figs/ds_report.JPG
--------------------------------------------------------------------------------
/util_tools/masking_generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class TubeMaskingGenerator:
4 | def __init__(self, input_size, mask_ratio):
5 | self.frames, self.height, self.width = input_size
6 | self.num_patches_per_frame = self.height * self.width
7 | self.total_patches = self.frames * self.num_patches_per_frame
8 | self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
9 | self.total_masks = self.frames * self.num_masks_per_frame
10 |
11 | def __repr__(self):
12 | repr_str = "Masks: total patches {}, mask patches {}".format(
13 | self.total_patches, self.total_masks
14 | )
15 | return repr_str
16 |
17 | def __call__(self):
18 | mask_per_frame = np.hstack([
19 | np.zeros(self.num_patches_per_frame - self.num_masks_per_frame),
20 | np.ones(self.num_masks_per_frame),
21 | ])
22 | np.random.shuffle(mask_per_frame)
23 | mask = np.tile(mask_per_frame, (self.frames,1)).flatten()
24 | return mask
--------------------------------------------------------------------------------
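A quick usage sketch for `TubeMaskingGenerator` (illustrative only; the `(8, 14, 14)` window corresponds to, for example, 16 frames with a tubelet size of 2 and 14x14 spatial patches of a 224px ViT-B/16 input, and the 0.9 mask ratio is just an example value):

```python
from util_tools.masking_generator import TubeMaskingGenerator

# Hypothetical window: 8 temporal tubes x 14 x 14 spatial patches, 90% of patches masked.
generator = TubeMaskingGenerator(input_size=(8, 14, 14), mask_ratio=0.9)
mask = generator()

print(mask.shape)       # (1568,) -> 8 * 14 * 14 positions, flattened
print(int(mask.sum()))  # 1408    -> 8 * int(0.9 * 196) masked positions, same pattern per tube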
/scripts/ssv2/ssv2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_PATH=YOUR_PATH
4 | VMAE_MODEL_PATH=YOUR_PATH
5 | CLIP_MODEL_PATH=YOUR_PATH
6 |
7 | OUTPUT_DIR=YOUR_PATH
8 | MASTER_NODE=$1
9 | OMP_NUM_THREADS=1 python -m torch.distributed.launch \
10 | --nproc_per_node=$6 \
11 | --master_port $3 --nnodes=$5 \
12 | --node_rank=$2 --master_addr=${MASTER_NODE} \
13 | YOUR_PATH/run_bidirection.py \
14 | --data_set SSV2 \
15 | --nb_classes 174 \
16 | --vmae_model bidir_vit_base_patch16_224 \
17 | --data_path ${DATA_PATH} \
18 | --anno_path ${YOUR_PATH} \
19 | --clip_finetune ${CLIP_MODEL_PATH} \
20 | --vmae_finetune ${VMAE_MODEL_PATH} \
21 | --log_dir ${OUTPUT_DIR} \
22 | --output_dir ${OUTPUT_DIR} \
23 | --batch_size 6 \
24 | --input_size 224 \
25 | --short_side_size 224 \
26 | --save_ckpt_freq 25 \
27 | --num_sample 1 \
28 | --num_frames 16 \
29 | --opt adamw \
30 | --lr 1e-3 \
31 | --opt_betas 0.9 0.999 \
32 | --weight_decay 0.05 \
33 | --epochs 50 \
34 | --dist_eval \
35 | --test_num_segment 2 \
36 | --test_num_crop 3 \
37 | --num_workers 8 \
38 | --seed 0 \
39 | --mixup_switch_prob 0 \
40 | --mixup_prob 0.9 \
41 | --reprob 0. \
42 | --init_scale 1. \
43 | --enable_deepspeed \
44 | --warmup_epochs 5 \
45 | --update_freq 4 \
46 | --drop_path 0.3
47 |
48 | echo "Job finished"
--------------------------------------------------------------------------------
/scripts/ek100/ek100.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_PATH=YOUR_PATH
4 | VMAE_MODEL_PATH=YOUR_PATH
5 | CLIP_MODEL_PATH=YOUR_PATH
6 |
7 | OUTPUT_DIR=YOUR_PATH
8 | MASTER_NODE=$1
9 | OMP_NUM_THREADS=1 python -m torch.distributed.launch \
10 | --nproc_per_node=$6 \
11 | --master_port $3 --nnodes=$5 \
12 | --node_rank=$2 --master_addr=${MASTER_NODE} \
13 | {YOUR_PATH}/run_bidirection_compo.py \
14 | --data_set EPIC \
15 | --nb_classes 300 \
16 | --vmae_model compo_bidir_vit_base_patch16_224 \
17 | --data_path ${DATA_PATH} \
18 | --anno_path YOUR_PATH \
19 | --clip_finetune ${CLIP_MODEL_PATH} \
20 | --vmae_finetune ${VMAE_MODEL_PATH} \
21 | --log_dir ${OUTPUT_DIR} \
22 | --output_dir ${OUTPUT_DIR} \
23 | --batch_size 6 \
24 | --input_size 224 \
25 | --short_side_size 224 \
26 | --save_ckpt_freq 25 \
27 | --num_sample 1 \
28 | --num_frames 16 \
29 | --opt adamw \
30 | --lr 1e-3 \
31 | --opt_betas 0.9 0.999 \
32 | --weight_decay 0.05 \
33 | --epochs 50 \
34 | --dist_eval \
35 | --test_num_segment 2 \
36 | --test_num_crop 3 \
37 | --num_workers 8 \
38 | --drop_path 0.2 \
39 | --mixup_switch_prob 0 \
40 | --mixup_prob 0.9 \
41 | --reprob 0. \
42 | --init_scale 1. \
43 | --update_freq 4 \
44 | --seed 0 \
45 | --enable_deepspeed \
46 | --warmup_epochs 5 \
47 | --composition
48 |
49 | echo "Job finished"
--------------------------------------------------------------------------------
/scripts/kinetics/k400.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DATA_PATH=YOUR_PATH
4 | VMAE_MODEL_PATH=YOUR_PATH
5 | CLIP_MODEL_PATH=YOUR_PATH
6 |
7 | OUTPUT_DIR=YOUR_PATH
8 | MASTER_NODE=$1
9 | OMP_NUM_THREADS=1 python -m torch.distributed.launch \
10 | --nproc_per_node=$6 \
11 | --master_port $3 --nnodes=$5 \
12 | --node_rank=$2 --master_addr=${MASTER_NODE} \
13 | YOUR_PATH/run_bidirection.py \
14 | --data_set Kinetics-400 \
15 | --nb_classes 400 \
16 | --vmae_model bidir_vit_base_patch16_224 \
17 | --anno_path ${ANNOTATION_PATH} \
18 | --data_path ${DATA_PATH} \
19 | --clip_finetune ${CLIP_MODEL_PATH} \
20 | --vmae_finetune ${VMAE_MODEL_PATH} \
21 | --log_dir ${OUTPUT_DIR} \
22 | --output_dir ${OUTPUT_DIR} \
23 | --batch_size 6 \
24 | --input_size 224 \
25 | --short_side_size 224 \
26 | --save_ckpt_freq 25 \
27 | --num_sample 1 \
28 | --num_frames 16 \
29 | --opt adamw \
30 | --lr 1e-3 \
31 | --opt_betas 0.9 0.999 \
32 | --weight_decay 0.05 \
33 | --epochs 70 \
34 | --dist_eval \
35 | --test_num_segment 5 \
36 | --test_num_crop 3 \
37 | --num_workers 8 \
38 | --drop_path 0.2 \
39 | --layer_decay 0.75 \
40 | --mixup_switch_prob 0 \
41 | --mixup_prob 0.5 \
42 | --reprob 0. \
43 | --init_scale 1. \
44 | --update_freq 6 \
45 | --seed 0 \
46 | --enable_deepspeed \
47 | --warmup_epochs 5
48 |
49 | echo "Job finished"
--------------------------------------------------------------------------------
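As a rough sanity check on the hyper-parameters above, the effective global batch size implied by these scripts is the per-GPU batch size times the gradient-accumulation steps (`--update_freq`) times the number of GPUs; the 16-GPU (2 nodes x 8 GPUs) setup is an assumption taken from the README:

```python
# Illustrative arithmetic only; the GPU count is an assumption taken from the README.
def effective_batch_size(per_gpu_batch: int, update_freq: int, num_gpus: int) -> int:
    return per_gpu_batch * update_freq * num_gpus

print(effective_batch_size(6, 4, 16))  # SSV2 / EK100 scripts: 384
print(effective_batch_size(6, 6, 16))  # K400 script: 576
```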
/util_tools/functional.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import cv2
3 | import numpy as np
4 | import PIL
5 | import torch
6 |
7 |
8 | def _is_tensor_clip(clip):
9 | return torch.is_tensor(clip) and clip.ndimension() == 4
10 |
11 |
12 | def crop_clip(clip, min_h, min_w, h, w):
13 | if isinstance(clip[0], np.ndarray):
14 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
15 |
16 | elif isinstance(clip[0], PIL.Image.Image):
17 | cropped = [
18 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
19 | ]
20 | else:
21 | raise TypeError('Expected numpy.ndarray or PIL.Image ' +
22 | 'but got list of {0}'.format(type(clip[0])))
23 | return cropped
24 |
25 |
26 | def resize_clip(clip, size, interpolation='bilinear'):
27 | if isinstance(clip[0], np.ndarray):
28 | if isinstance(size, numbers.Number):
29 | im_h, im_w, im_c = clip[0].shape
30 | # Min spatial dim already matches minimal size
31 | if (im_w <= im_h and im_w == size) or (im_h <= im_w
32 | and im_h == size):
33 | return clip
34 | new_h, new_w = get_resize_sizes(im_h, im_w, size)
35 | size = (new_w, new_h)
36 | else:
37 | size = size[0], size[1]
38 | if interpolation == 'bilinear':
39 | np_inter = cv2.INTER_LINEAR
40 | else:
41 | np_inter = cv2.INTER_NEAREST
42 | scaled = [
43 | cv2.resize(img, size, interpolation=np_inter) for img in clip
44 | ]
45 | elif isinstance(clip[0], PIL.Image.Image):
46 | if isinstance(size, numbers.Number):
47 | im_w, im_h = clip[0].size
48 | # Min spatial dim already matches minimal size
49 | if (im_w <= im_h and im_w == size) or (im_h <= im_w
50 | and im_h == size):
51 | return clip
52 | new_h, new_w = get_resize_sizes(im_h, im_w, size)
53 | size = (new_w, new_h)
54 | else:
55 | size = size[1], size[0]
56 | if interpolation == 'bilinear':
57 | pil_inter = PIL.Image.BILINEAR
58 | else:
59 | pil_inter = PIL.Image.NEAREST
60 | scaled = [img.resize(size, pil_inter) for img in clip]
61 | else:
62 | raise TypeError('Expected numpy.ndarray or PIL.Image ' +
63 | 'but got list of {0}'.format(type(clip[0])))
64 | return scaled
65 |
66 |
67 | def get_resize_sizes(im_h, im_w, size):
68 | if im_w < im_h:
69 | ow = size
70 | oh = int(size * im_h / im_w)
71 | else:
72 | oh = size
73 | ow = int(size * im_w / im_h)
74 | return oh, ow
75 |
76 |
77 | def normalize(clip, mean, std, inplace=False):
78 | if not _is_tensor_clip(clip):
79 | raise TypeError('tensor is not a torch clip.')
80 |
81 | if not inplace:
82 | clip = clip.clone()
83 |
84 | dtype = clip.dtype
85 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device)
86 | std = torch.as_tensor(std, dtype=dtype, device=clip.device)
87 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
88 |
89 | return clip
90 |
--------------------------------------------------------------------------------
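A small usage sketch for the clip helpers above (illustrative only; the frame size, target short side, and crop offsets are arbitrary example values):

```python
from PIL import Image

from util_tools import functional as clip_func

# Dummy 16-frame clip of 320x240 PIL images.
clip = [Image.new('RGB', (320, 240)) for _ in range(16)]

# Resize so the short side becomes 256 (aspect ratio preserved), then take a 224x224 crop.
resized = clip_func.resize_clip(clip, 256, interpolation='bilinear')
print(resized[0].size)  # (341, 256)

cropped = clip_func.crop_clip(resized, min_h=16, min_w=58, h=224, w=224)
print(cropped[0].size)  # (224, 224)
```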
/util_tools/volume_transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch
4 |
5 |
6 | def convert_img(img):
7 | """Converts (H, W, C) numpy.ndarray to (C, W, H) format
8 | """
9 | if len(img.shape) == 3:
10 | img = img.transpose(2, 0, 1)
11 | if len(img.shape) == 2:
12 | img = np.expand_dims(img, 0)
13 | return img
14 |
15 |
16 | class ClipToTensor(object):
17 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
18 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
19 | """
20 |
21 | def __init__(self, channel_nb=3, div_255=True, numpy=False):
22 | self.channel_nb = channel_nb
23 | self.div_255 = div_255
24 | self.numpy = numpy
25 |
26 | def __call__(self, clip):
27 | """
28 | Args: clip (list of numpy.ndarray): clip (list of images)
29 | to be converted to tensor.
30 | """
31 | # Retrieve shape
32 | if isinstance(clip[0], np.ndarray):
33 | h, w, ch = clip[0].shape
34 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format(
35 | ch)
36 | elif isinstance(clip[0], Image.Image):
37 | w, h = clip[0].size
38 | else:
39 | raise TypeError('Expected numpy.ndarray or PIL.Image '
40 | 'but got list of {0}'.format(type(clip[0])))
41 |
42 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
43 |
44 | # Convert
45 | for img_idx, img in enumerate(clip):
46 | if isinstance(img, np.ndarray):
47 | pass
48 | elif isinstance(img, Image.Image):
49 | img = np.array(img, copy=False)
50 | else:
51 | raise TypeError('Expected numpy.ndarray or PIL.Image '
52 | 'but got list of {0}'.format(type(img)))
53 | img = convert_img(img)
54 | np_clip[:, img_idx, :, :] = img
55 | if self.numpy:
56 | if self.div_255:
57 | np_clip = np_clip / 255.0
58 | return np_clip
59 |
60 | else:
61 | tensor_clip = torch.from_numpy(np_clip)
62 |
63 | if not isinstance(tensor_clip, torch.FloatTensor):
64 | tensor_clip = tensor_clip.float()
65 | if self.div_255:
66 | tensor_clip = torch.div(tensor_clip, 255)
67 | return tensor_clip
68 |
69 |
70 | # Note this norms data to -1/1
71 | class ClipToTensor_K(object):
72 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
73 | to a torch.FloatTensor of shape (C x m x H x W) in the range [-1.0, 1.0]
74 | """
75 |
76 | def __init__(self, channel_nb=3, div_255=True, numpy=False):
77 | self.channel_nb = channel_nb
78 | self.div_255 = div_255
79 | self.numpy = numpy
80 |
81 | def __call__(self, clip):
82 | """
83 | Args: clip (list of numpy.ndarray): clip (list of images)
84 | to be converted to tensor.
85 | """
86 | # Retrieve shape
87 | if isinstance(clip[0], np.ndarray):
88 | h, w, ch = clip[0].shape
89 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format(
90 | ch)
91 | elif isinstance(clip[0], Image.Image):
92 | w, h = clip[0].size
93 | else:
94 | raise TypeError('Expected numpy.ndarray or PIL.Image '
95 | 'but got list of {0}'.format(type(clip[0])))
96 |
97 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
98 |
99 | # Convert
100 | for img_idx, img in enumerate(clip):
101 | if isinstance(img, np.ndarray):
102 | pass
103 | elif isinstance(img, Image.Image):
104 | img = np.array(img, copy=False)
105 | else:
106 | raise TypeError('Expected numpy.ndarray or PIL.Image '
107 | 'but got list of {0}'.format(type(img)))
108 | img = convert_img(img)
109 | np_clip[:, img_idx, :, :] = img
110 | if self.numpy:
111 | if self.div_255:
112 | np_clip = (np_clip - 127.5) / 127.5
113 | return np_clip
114 |
115 | else:
116 | tensor_clip = torch.from_numpy(np_clip)
117 |
118 | if not isinstance(tensor_clip, torch.FloatTensor):
119 | tensor_clip = tensor_clip.float()
120 | if self.div_255:
121 | tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5)
122 | return tensor_clip
123 |
124 |
125 | class ToTensor(object):
126 | """Converts numpy array to tensor
127 | """
128 |
129 | def __call__(self, array):
130 | array = np.load(array)
131 | tensor = torch.from_numpy(array)
132 | return tensor
133 |
--------------------------------------------------------------------------------
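A usage sketch for `ClipToTensor` (illustrative only; the frame count and resolution are arbitrary):

```python
import numpy as np

from util_tools.volume_transforms import ClipToTensor

# 16 dummy H x W x C uint8 frames in [0, 255].
clip = [np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8) for _ in range(16)]

to_tensor = ClipToTensor(channel_nb=3, div_255=True)
tensor = to_tensor(clip)

print(tensor.shape, tensor.dtype)  # torch.Size([3, 16, 224, 224]) torch.float32
print(float(tensor.max()) <= 1.0)  # True (ClipToTensor_K would map to [-1, 1] instead)
```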
/dataset/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torchvision import transforms
3 | from util_tools.transforms import *
4 | from util_tools.masking_generator import TubeMaskingGenerator
5 | from .kinetics import VideoClsDataset, VideoMAE
6 | from .ssv2 import SSVideoClsDataset
7 | from .epic import EpicVideoClsDataset
8 |
9 |
10 | class DataAugmentationForVideoMAE(object):
11 | def __init__(self, args):
12 | self.input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN
13 | self.input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD
14 | normalize = GroupNormalize(self.input_mean, self.input_std)
15 | self.train_augmentation = GroupMultiScaleCrop(args.input_size, [1, .875, .75, .66])
16 | self.transform = transforms.Compose([
17 | self.train_augmentation,
18 | Stack(roll=False),
19 | ToTorchFormatTensor(div=True),
20 | normalize,
21 | ])
22 | if args.mask_type == 'tube':
23 | self.masked_position_generator = TubeMaskingGenerator(
24 | args.window_size, args.mask_ratio
25 | )
26 |
27 | def __call__(self, images):
28 | process_data, _ = self.transform(images)
29 | return process_data, self.masked_position_generator()
30 |
31 | def __repr__(self):
32 | repr = "(DataAugmentationForVideoMAE,\n"
33 | repr += " transform = %s,\n" % str(self.transform)
34 | repr += " Masked position generator = %s,\n" % str(self.masked_position_generator)
35 | repr += ")"
36 | return repr
37 |
38 |
39 | def build_pretraining_dataset(args):
40 | transform = DataAugmentationForVideoMAE(args)
41 | dataset = VideoMAE(
42 | root=None,
43 | setting=args.data_path,
44 | video_ext='mp4',
45 | is_color=True,
46 | modality='rgb',
47 | new_length=args.num_frames,
48 | new_step=args.sampling_rate,
49 | transform=transform,
50 | temporal_jitter=False,
51 | video_loader=True,
52 | use_decord=True,
53 | lazy_init=False)
54 | print("Data Aug = %s" % str(transform))
55 | return dataset
56 |
57 |
58 | def build_dataset(is_train, test_mode, args):
59 | if args.data_set == 'Kinetics-400':
60 | mode = None
61 | anno_path = args.anno_path
62 | if is_train is True:
63 | mode = 'train'
64 | anno_path = os.path.join(args.anno_path, 'train.csv')
65 | elif test_mode is True:
66 | mode = 'test'
67 | anno_path = os.path.join(args.anno_path, 'val.csv')
68 | else:
69 | mode = 'validation'
70 | anno_path = os.path.join(args.anno_path, 'val.csv')
71 |
72 | dataset = VideoClsDataset(
73 | anno_path=anno_path,
74 | data_path=args.data_path,
75 | mode=mode,
76 | clip_len=args.num_frames,
77 | frame_sample_rate=args.sampling_rate,
78 | num_segment=1,
79 | test_num_segment=args.test_num_segment,
80 | test_num_crop=args.test_num_crop,
81 | num_crop=1 if not test_mode else 3,
82 | keep_aspect_ratio=True,
83 | crop_size=args.input_size,
84 | short_side_size=args.short_side_size,
85 | new_height=256,
86 | new_width=320,
87 | args=args)
88 | nb_classes = 400
89 |
90 | elif args.data_set == 'SSV2':
91 | mode = None
92 | anno_path = None
93 | if is_train is True:
94 | mode = 'train'
95 | anno_path = os.path.join(args.anno_path, 'train.csv')
96 | elif test_mode is True:
97 | mode = 'test'
98 | anno_path = os.path.join(args.anno_path, 'val.csv')
99 | else:
100 | mode = 'validation'
101 | anno_path = os.path.join(args.anno_path, 'val.csv')
102 |
103 | dataset = SSVideoClsDataset(
104 | anno_path=anno_path,
105 | data_path=args.data_path,
106 | mode=mode,
107 | clip_len=1,
108 | num_segment=args.num_frames,
109 | test_num_segment=args.test_num_segment,
110 | test_num_crop=args.test_num_crop,
111 | num_crop=1 if not test_mode else 3,
112 | keep_aspect_ratio=True,
113 | crop_size=args.input_size,
114 | short_side_size=args.short_side_size,
115 | new_height=256,
116 | new_width=320,
117 | args=args)
118 | nb_classes = 174
119 |
120 | elif args.data_set == 'EPIC':
121 | mode = None
122 | anno_path = None
123 | if is_train is True:
124 | mode = 'train'
125 | anno_path = os.path.join(args.anno_path, 'train.csv')
126 | elif test_mode is True:
127 | mode = 'test'
128 | anno_path = os.path.join(args.anno_path, 'val.csv')
129 | else:
130 | mode = 'validation'
131 | anno_path = os.path.join(args.anno_path, 'val.csv')
132 |
133 | dataset = EpicVideoClsDataset(
134 | anno_path=anno_path,
135 | data_path=args.data_path,
136 | mode=mode,
137 | clip_len=1,
138 | num_segment=args.num_frames,
139 | test_num_segment=args.test_num_segment,
140 | test_num_crop=args.test_num_crop,
141 | num_crop=1 if not test_mode else 3,
142 | keep_aspect_ratio=True,
143 | crop_size=args.input_size,
144 | short_side_size=args.short_side_size,
145 | new_height=256,
146 | new_width=320,
147 | args=args)
148 | nb_classes = 300
149 |
150 | else:
151 | raise NotImplementedError()
152 |
153 | assert nb_classes == args.nb_classes
154 | print("Number of the class = %d" % args.nb_classes)
155 |
156 | return dataset, nb_classes
--------------------------------------------------------------------------------
/util_tools/random_erasing.py:
--------------------------------------------------------------------------------
1 | """
2 | This implementation is based on
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
4 | published under an Apache License 2.0.
5 | """
6 | import math
7 | import random
8 | import torch
9 |
10 |
11 | def _get_pixels(
12 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"
13 | ):
14 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
15 | # paths, flip the order so normal is run on CPU if this becomes a problem
16 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
17 | if per_pixel:
18 | return torch.empty(patch_size, dtype=dtype, device=device).normal_()
19 | elif rand_color:
20 | return torch.empty(
21 | (patch_size[0], 1, 1), dtype=dtype, device=device
22 | ).normal_()
23 | else:
24 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
25 |
26 |
27 | class RandomErasing:
28 | """Randomly selects a rectangle region in an image and erases its pixels.
29 | 'Random Erasing Data Augmentation' by Zhong et al.
30 | See https://arxiv.org/pdf/1708.04896.pdf
31 | This variant of RandomErasing is intended to be applied to either a batch
32 | or single image tensor after it has been normalized by dataset mean and std.
33 | Args:
34 | probability: Probability that the Random Erasing operation will be performed.
35 | min_area: Minimum percentage of erased area wrt input image area.
36 | max_area: Maximum percentage of erased area wrt input image area.
37 | min_aspect: Minimum aspect ratio of erased area.
38 | mode: pixel color mode, one of 'const', 'rand', or 'pixel'
39 | 'const' - erase block is constant color of 0 for all channels
40 | 'rand' - erase block is same per-channel random (normal) color
41 | 'pixel' - erase block is per-pixel random (normal) color
42 | max_count: maximum number of erasing blocks per image, area per box is scaled by count.
43 | per-image count is randomly chosen between 1 and this value.
44 | """
45 |
46 | def __init__(
47 | self,
48 | probability=0.5,
49 | min_area=0.02,
50 | max_area=1 / 3,
51 | min_aspect=0.3,
52 | max_aspect=None,
53 | mode="const",
54 | min_count=1,
55 | max_count=None,
56 | num_splits=0,
57 | device="cuda",
58 | cube=True,
59 | ):
60 | self.probability = probability
61 | self.min_area = min_area
62 | self.max_area = max_area
63 | max_aspect = max_aspect or 1 / min_aspect
64 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
65 | self.min_count = min_count
66 | self.max_count = max_count or min_count
67 | self.num_splits = num_splits
68 | mode = mode.lower()
69 | self.rand_color = False
70 | self.per_pixel = False
71 | self.cube = cube
72 | if mode == "rand":
73 | self.rand_color = True # per block random normal
74 | elif mode == "pixel":
75 | self.per_pixel = True # per pixel random normal
76 | else:
77 | assert not mode or mode == "const"
78 | self.device = device
79 |
80 | def _erase(self, img, chan, img_h, img_w, dtype):
81 | if random.random() > self.probability:
82 | return
83 | area = img_h * img_w
84 | count = (
85 | self.min_count
86 | if self.min_count == self.max_count
87 | else random.randint(self.min_count, self.max_count)
88 | )
89 | for _ in range(count):
90 | for _ in range(10):
91 | target_area = (
92 | random.uniform(self.min_area, self.max_area) * area / count
93 | )
94 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
95 | h = int(round(math.sqrt(target_area * aspect_ratio)))
96 | w = int(round(math.sqrt(target_area / aspect_ratio)))
97 | if w < img_w and h < img_h:
98 | top = random.randint(0, img_h - h)
99 | left = random.randint(0, img_w - w)
100 | img[:, top : top + h, left : left + w] = _get_pixels(
101 | self.per_pixel,
102 | self.rand_color,
103 | (chan, h, w),
104 | dtype=dtype,
105 | device=self.device,
106 | )
107 | break
108 |
109 | def _erase_cube(
110 | self,
111 | img,
112 | batch_start,
113 | batch_size,
114 | chan,
115 | img_h,
116 | img_w,
117 | dtype,
118 | ):
119 | if random.random() > self.probability:
120 | return
121 | area = img_h * img_w
122 | count = (
123 | self.min_count
124 | if self.min_count == self.max_count
125 | else random.randint(self.min_count, self.max_count)
126 | )
127 | for _ in range(count):
128 | for _ in range(100):
129 | target_area = (
130 | random.uniform(self.min_area, self.max_area) * area / count
131 | )
132 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
133 | h = int(round(math.sqrt(target_area * aspect_ratio)))
134 | w = int(round(math.sqrt(target_area / aspect_ratio)))
135 | if w < img_w and h < img_h:
136 | top = random.randint(0, img_h - h)
137 | left = random.randint(0, img_w - w)
138 | for i in range(batch_start, batch_size):
139 | img_instance = img[i]
140 | img_instance[
141 | :, top : top + h, left : left + w
142 | ] = _get_pixels(
143 | self.per_pixel,
144 | self.rand_color,
145 | (chan, h, w),
146 | dtype=dtype,
147 | device=self.device,
148 | )
149 | break
150 |
151 | def __call__(self, input):
152 | if len(input.size()) == 3:
153 | self._erase(input, *input.size(), input.dtype)
154 | else:
155 | batch_size, chan, img_h, img_w = input.size()
156 | # skip first slice of batch if num_splits is set (for clean portion of samples)
157 | batch_start = (
158 | batch_size // self.num_splits if self.num_splits > 1 else 0
159 | )
160 | if self.cube:
161 | self._erase_cube(
162 | input,
163 | batch_start,
164 | batch_size,
165 | chan,
166 | img_h,
167 | img_w,
168 | input.dtype,
169 | )
170 | else:
171 | for i in range(batch_start, batch_size):
172 | self._erase(input[i], chan, img_h, img_w, input.dtype)
173 | return input
174 |
--------------------------------------------------------------------------------
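A usage sketch for `RandomErasing` (illustrative only; the probability, mode, and tensor shape are example values, and the CPU device is chosen so the snippet runs without a GPU):

```python
import torch

from util_tools.random_erasing import RandomErasing

erase = RandomErasing(probability=0.25, mode='pixel', max_count=1,
                      num_splits=0, device='cpu', cube=True)

# A stack of 8 normalized frame tensors; cube=True erases the same region in every frame.
frames = torch.randn(8, 3, 224, 224)
frames = erase(frames)
print(frames.shape)  # torch.Size([8, 3, 224, 224])
```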
/util_tools/optim_factory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import optim as optim
3 |
4 | from timm.optim.adafactor import Adafactor
5 | from timm.optim.adahessian import Adahessian
6 | from timm.optim.adamp import AdamP
7 | from timm.optim.lookahead import Lookahead
8 | from timm.optim.nadam import Nadam
9 | #from timm.optim.novograd import NovoGrad
10 | from timm.optim.nvnovograd import NvNovoGrad
11 | from timm.optim.radam import RAdam
12 | from timm.optim.rmsprop_tf import RMSpropTF
13 | from timm.optim.sgdp import SGDP
14 |
15 | import json
16 |
17 | try:
18 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
19 | has_apex = True
20 | except ImportError:
21 | has_apex = False
22 |
23 |
24 | def get_num_layer_for_vit(var_name, num_max_layer):
25 | if var_name in ("cls_token", "mask_token", "pos_embed"):
26 | return 0
27 | elif var_name in ("visual.class_embedding", "visual.positional_embedding", "visual.temporal_posembed", "visual.ln_pre"):
28 | return 0
29 | elif var_name in ("clip_class_embedding", "clip_ln_pre.weight","clip_ln_pre.bias", "clip_conv1.weight", "clip_conv1.bias", "clip_positional_embedding"):
30 | return 0
31 | elif var_name.startswith("patch_embed"):
32 | return 0
33 | elif var_name.startswith("visual.conv1"):
34 | return 0
35 | elif var_name.startswith("rel_pos_bias"):
36 | return num_max_layer - 1
37 | elif var_name.startswith("blocks"):
38 | layer_id = int(var_name.split('.')[1])
39 | return layer_id + 1
40 | elif var_name.startswith("visual.transformer"):
41 | layer_id = int(var_name.split('.')[3])
42 | return layer_id + 1
43 | else:
44 | return num_max_layer - 1
45 |
46 |
47 | class LayerDecayValueAssigner(object):
48 | def __init__(self, values):
49 | self.values = values
50 |
51 | def get_scale(self, layer_id):
52 | return self.values[layer_id]
53 |
54 | def get_layer_id(self, var_name):
55 | return get_num_layer_for_vit(var_name, len(self.values))
56 |
57 |
58 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
59 | parameter_group_names = {}
60 | parameter_group_vars = {}
61 |
62 | for name, param in model.named_parameters():
63 | if not param.requires_grad:
64 | continue # frozen weights
65 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
66 | group_name = "no_decay"
67 | this_weight_decay = 0.
68 | elif any(s in name for s in skip_list):
69 | group_name = "no_decay"
70 | this_weight_decay = 0.
71 | else:
72 | group_name = "decay"
73 | this_weight_decay = weight_decay
74 | if get_num_layer is not None:
75 | layer_id = get_num_layer(name)
76 | group_name = "layer_%d_%s" % (layer_id, group_name)
77 | else:
78 | layer_id = None
79 |
80 | if group_name not in parameter_group_names:
81 | if get_layer_scale is not None:
82 | scale = get_layer_scale(layer_id)
83 | else:
84 | scale = 1.
85 |
86 | parameter_group_names[group_name] = {
87 | "weight_decay": this_weight_decay,
88 | "params": [],
89 | "lr_scale": scale
90 | }
91 | parameter_group_vars[group_name] = {
92 | "weight_decay": this_weight_decay,
93 | "params": [],
94 | "lr_scale": scale
95 | }
96 |
97 | parameter_group_vars[group_name]["params"].append(param)
98 | parameter_group_names[group_name]["params"].append(name)
99 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
100 | return list(parameter_group_vars.values())
101 |
102 |
103 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
104 | opt_lower = args.opt.lower()
105 | weight_decay = args.weight_decay
106 | if weight_decay and filter_bias_and_bn:
107 | skip = {}
108 | if skip_list is not None:
109 | skip = skip_list
110 | elif hasattr(model, 'no_weight_decay'):
111 | skip = model.no_weight_decay()
112 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
113 | weight_decay = 0.
114 | else:
115 | parameters = model.parameters()
116 |
117 | if 'fused' in opt_lower:
118 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
119 |
120 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
121 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
122 | opt_args['eps'] = args.opt_eps
123 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
124 | opt_args['betas'] = args.opt_betas
125 |
126 | print("optimizer settings:", opt_args)
127 |
128 | opt_split = opt_lower.split('_')
129 | opt_lower = opt_split[-1]
130 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
131 | opt_args.pop('eps', None)
132 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
133 | elif opt_lower == 'momentum':
134 | opt_args.pop('eps', None)
135 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
136 | elif opt_lower == 'adam':
137 | optimizer = optim.Adam(parameters, **opt_args)
138 | elif opt_lower == 'adamw':
139 | optimizer = optim.AdamW(parameters, **opt_args)
140 | elif opt_lower == 'nadam':
141 | optimizer = Nadam(parameters, **opt_args)
142 | elif opt_lower == 'radam':
143 | optimizer = RAdam(parameters, **opt_args)
144 | elif opt_lower == 'adamp':
145 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
146 | elif opt_lower == 'sgdp':
147 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
148 | elif opt_lower == 'adadelta':
149 | optimizer = optim.Adadelta(parameters, **opt_args)
150 | elif opt_lower == 'adafactor':
151 | if not args.lr:
152 | opt_args['lr'] = None
153 | optimizer = Adafactor(parameters, **opt_args)
154 | elif opt_lower == 'adahessian':
155 | optimizer = Adahessian(parameters, **opt_args)
156 | elif opt_lower == 'rmsprop':
157 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
158 | elif opt_lower == 'rmsproptf':
159 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
160 | elif opt_lower == 'novograd':
161 | raise ValueError("NovoGrad is unavailable (its import is commented out above); use 'nvnovograd' instead")
162 | elif opt_lower == 'nvnovograd':
163 | optimizer = NvNovoGrad(parameters, **opt_args)
164 | elif opt_lower == 'fusedsgd':
165 | opt_args.pop('eps', None)
166 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
167 | elif opt_lower == 'fusedmomentum':
168 | opt_args.pop('eps', None)
169 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
170 | elif opt_lower == 'fusedadam':
171 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
172 | elif opt_lower == 'fusedadamw':
173 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
174 | elif opt_lower == 'fusedlamb':
175 | optimizer = FusedLAMB(parameters, **opt_args)
176 | elif opt_lower == 'fusednovograd':
177 | opt_args.setdefault('betas', (0.95, 0.98))
178 | optimizer = FusedNovoGrad(parameters, **opt_args)
179 | else:
180 | assert False, "Invalid optimizer"
181 | raise ValueError
182 |
183 | if len(opt_split) > 1:
184 | if opt_split[0] == 'lookahead':
185 | optimizer = Lookahead(optimizer)
186 |
187 | return optimizer
188 |
--------------------------------------------------------------------------------
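A sketch of how `LayerDecayValueAssigner` maps parameter names to layer-wise learning-rate scales (illustrative only; the 12-block ViT-B depth, the 0.75 decay, and the `num_layers + 2` list construction follow the usual VideoMAE-style convention and are assumptions here, not taken from this file):

```python
from util_tools.optim_factory import LayerDecayValueAssigner

num_layers, decay = 12, 0.75
assigner = LayerDecayValueAssigner(
    [decay ** (num_layers + 1 - i) for i in range(num_layers + 2)])

print(assigner.get_layer_id('patch_embed.proj.weight'))   # 0  (embedding layer)
print(assigner.get_layer_id('blocks.11.mlp.fc1.weight'))  # 12 (last transformer block)
print(round(assigner.get_scale(12), 4))                   # 0.75 (one decay step below the head)
```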
/util_tools/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms.functional as F
3 | import warnings
4 | import random
5 | import numpy as np
6 | import torchvision
7 | from PIL import Image, ImageOps
8 | import numbers
9 |
10 |
11 | class GroupRandomCrop(object):
12 | def __init__(self, size):
13 | if isinstance(size, numbers.Number):
14 | self.size = (int(size), int(size))
15 | else:
16 | self.size = size
17 |
18 | def __call__(self, img_tuple):
19 | img_group, label = img_tuple
20 |
21 | w, h = img_group[0].size
22 | th, tw = self.size
23 |
24 | out_images = list()
25 |
26 | x1 = random.randint(0, w - tw)
27 | y1 = random.randint(0, h - th)
28 |
29 | for img in img_group:
30 | assert(img.size[0] == w and img.size[1] == h)
31 | if w == tw and h == th:
32 | out_images.append(img)
33 | else:
34 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
35 |
36 | return (out_images, label)
37 |
38 |
39 | class GroupCenterCrop(object):
40 | def __init__(self, size):
41 | self.worker = torchvision.transforms.CenterCrop(size)
42 |
43 | def __call__(self, img_tuple):
44 | img_group, label = img_tuple
45 | return ([self.worker(img) for img in img_group], label)
46 |
47 |
48 | class GroupNormalize(object):
49 | def __init__(self, mean, std):
50 | self.mean = mean
51 | self.std = std
52 |
53 | def __call__(self, tensor_tuple):
54 | tensor, label = tensor_tuple
55 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean))
56 | rep_std = self.std * (tensor.size()[0]//len(self.std))
57 |
58 | # TODO: make efficient
59 | for t, m, s in zip(tensor, rep_mean, rep_std):
60 | t.sub_(m).div_(s)
61 |
62 | return (tensor,label)
63 |
64 |
65 | class GroupGrayScale(object):
66 | def __init__(self, size):
67 | self.worker = torchvision.transforms.Grayscale(size)
68 |
69 | def __call__(self, img_tuple):
70 | img_group, label = img_tuple
71 | return ([self.worker(img) for img in img_group], label)
72 |
73 |
74 | class GroupScale(object):
75 | """ Rescales the input PIL.Image to the given 'size'.
76 | 'size' will be the size of the smaller edge.
77 | For example, if height > width, then image will be
78 | rescaled to (size * height / width, size)
79 | size: size of the smaller edge
80 | interpolation: Default: PIL.Image.BILINEAR
81 | """
82 |
83 | def __init__(self, size, interpolation=Image.BILINEAR):
84 | self.worker = torchvision.transforms.Resize(size, interpolation)
85 |
86 | def __call__(self, img_tuple):
87 | img_group, label = img_tuple
88 | return ([self.worker(img) for img in img_group], label)
89 |
90 |
91 | class GroupMultiScaleCrop(object):
92 |
93 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
94 | self.scales = scales if scales is not None else [1, .875, .75, .66]
95 | self.max_distort = max_distort
96 | self.fix_crop = fix_crop
97 | self.more_fix_crop = more_fix_crop
98 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
99 | self.interpolation = Image.BILINEAR
100 |
101 | def __call__(self, img_tuple):
102 | img_group, label = img_tuple
103 |
104 | im_size = img_group[0].size
105 |
106 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
107 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
108 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group]
109 | return (ret_img_group, label)
110 |
111 | def _sample_crop_size(self, im_size):
112 | image_w, image_h = im_size[0], im_size[1]
113 |
114 | # find a crop size
115 | base_size = min(image_w, image_h)
116 | crop_sizes = [int(base_size * x) for x in self.scales]
117 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
118 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
119 |
120 | pairs = []
121 | for i, h in enumerate(crop_h):
122 | for j, w in enumerate(crop_w):
123 | if abs(i - j) <= self.max_distort:
124 | pairs.append((w, h))
125 |
126 | crop_pair = random.choice(pairs)
127 | if not self.fix_crop:
128 | w_offset = random.randint(0, image_w - crop_pair[0])
129 | h_offset = random.randint(0, image_h - crop_pair[1])
130 | else:
131 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
132 |
133 | return crop_pair[0], crop_pair[1], w_offset, h_offset
134 |
135 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
136 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
137 | return random.choice(offsets)
138 |
139 | @staticmethod
140 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
141 | w_step = (image_w - crop_w) // 4
142 | h_step = (image_h - crop_h) // 4
143 |
144 | ret = list()
145 | ret.append((0, 0)) # upper left
146 | ret.append((4 * w_step, 0)) # upper right
147 | ret.append((0, 4 * h_step)) # lower left
148 | ret.append((4 * w_step, 4 * h_step)) # lower right
149 | ret.append((2 * w_step, 2 * h_step)) # center
150 |
151 | if more_fix_crop:
152 | ret.append((0, 2 * h_step)) # center left
153 | ret.append((4 * w_step, 2 * h_step)) # center right
154 | ret.append((2 * w_step, 4 * h_step)) # lower center
155 | ret.append((2 * w_step, 0 * h_step)) # upper center
156 |
157 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter
158 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter
159 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter
160 | ret.append((3 * w_step, 3 * h_step)) # lower right quarter
161 | return ret
162 |
163 |
164 | class Stack(object):
165 |
166 | def __init__(self, roll=False):
167 | self.roll = roll
168 |
169 | def __call__(self, img_tuple):
170 | img_group, label = img_tuple
171 |
172 | if img_group[0].mode == 'L':
173 | return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label)
174 | elif img_group[0].mode == 'RGB':
175 | if self.roll:
176 | return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label)
177 | else:
178 | return (np.concatenate(img_group, axis=2), label)
179 |
180 |
181 | class ToTorchFormatTensor(object):
182 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
183 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
184 | def __init__(self, div=True):
185 | self.div = div
186 |
187 | def __call__(self, pic_tuple):
188 | pic, label = pic_tuple
189 |
190 | if isinstance(pic, np.ndarray):
191 | # handle numpy array
192 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
193 | else:
194 | # handle PIL Image
195 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
196 | img = img.view(pic.size[1], pic.size[0], len(pic.mode))
197 | # put it from HWC to CHW format
198 | # yikes, this transpose takes 80% of the loading time/CPU
199 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
200 | return (img.float().div(255.) if self.div else img.float(), label)
201 |
202 |
203 | class IdentityTransform(object):
204 |
205 | def __call__(self, data):
206 | return data
207 |
--------------------------------------------------------------------------------
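A usage sketch of the group-transform pipeline above, composed the same way as in `dataset/datasets.py` (illustrative only; the dummy clip is a stand-in for decoded video frames):

```python
from PIL import Image
from torchvision import transforms

from util_tools.transforms import (GroupMultiScaleCrop, GroupNormalize, Stack,
                                    ToTorchFormatTensor)

pipeline = transforms.Compose([
    GroupMultiScaleCrop(224, [1, .875, .75, .66]),
    Stack(roll=False),
    ToTorchFormatTensor(div=True),
    GroupNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

clip = [Image.new('RGB', (320, 256)) for _ in range(16)]  # dummy 16-frame clip
tensor, label = pipeline((clip, 0))                       # transforms take (clip, label) tuples
print(tensor.shape)  # torch.Size([48, 224, 224]) -> 16 frames x 3 channels stacked on dim 0
```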
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # CAST: Cross-Attention in Space and Time for Video Action Recognition [[NeurIPS 2023](https://neurips.cc/virtual/2023/poster/70748)][[Project Page](https://jong980812.github.io/CAST.github.io)][[Arxiv](https://arxiv.org/abs/2311.18825)]
3 |
4 | 
5 |
6 |
7 |
8 | [](https://paperswithcode.com/sota/action-recognition-on-epic-kitchens-100?p=cast-cross-attention-in-space-and-time-for-1)
9 |
10 | 
11 | 
12 | 
13 | 
14 |
15 |
16 | # :wrench: Installation
17 |
18 | We conduct all the experiments with 16 NVIDIA GeForce RTX 3090 GPUs.
19 | First, install PyTorch 1.10.0+ and torchvision 0.11.0.
20 |
21 | ```
22 | conda create -n vmae_1.10 python=3.8 ipykernel -y
23 | conda activate vmae_1.10
24 | conda install pytorch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 -c pytorch
25 | ```
26 | Then, install timm, triton, DeepSpeed, and others.
27 | ```
28 | pip install triton==1.0.0
29 | git clone https://github.com/microsoft/DeepSpeed
30 | cd DeepSpeed
31 | git checkout 3a3dfe66bb
32 | DS_BUILD_OPS=1 pip install . --global-option="build_ext"
33 | pip install TensorboardX decord einops scipy pandas requests
34 | ds_report
35 | ```
36 |
37 | If DeepSpeed is installed successfully, running the `ds_report` command should produce output like the following.
38 | For other DeepSpeed-related issues, please refer to the [DeepSpeed GitHub page](https://github.com/microsoft/DeepSpeed).
39 |
40 | 
41 |
42 | # :file_folder: Data Preparation
43 |
44 | * We report experimental results on three standard datasets: [EPIC-KITCHENS-100](https://epic-kitchens.github.io/2023), [Something-Something-V2](https://developer.qualcomm.com/software/ai-datasets/something-something), and [Kinetics-400](https://deepmind.com/research/open-source/kinetics).
45 | * We provide sample annotation files -> [annotations](./annotations/).
46 |
47 | ### EPIC-KITCHENS-100
48 | - The pre-processing of **EPIC-KITCHENS-100** can be summarized into the following steps:
49 |
50 | 1. Download the dataset from [official website](https://github.com/epic-kitchens/epic-kitchens-download-scripts).
51 |
52 | 2. Preprocess the dataset by resizing the short edge of video to **256px**. You can refer to [MMAction2 Data Benchmark](https://github.com/open-mmlab/mmaction2).
53 |
54 | 3. Generate the comma-separated annotation files needed by the dataloader (see [annotations](./annotations/) for samples). The annotations usually include `train.csv` and `val.csv`. The format of each `*.csv` file is like:
55 |
56 |
57 | ```
58 | video_1,verb_1,noun_1
59 | video_2,verb_2,noun_2
60 | video_3,verb_3,noun_3
61 | ...
62 | video_N,verb_N,noun_N
63 | ```
64 | 4. All video files are located inside the DATA_PATH.
65 |
66 | ### Something-Something-V2
67 | - The pre-processing of **Something-Something-V2** can be summarized into the following steps:
68 |
69 | 1. Download the dataset from [official website](https://developer.qualcomm.com/software/ai-datasets/something-something).
70 |
71 | 2. Preprocess the dataset by converting the videos from `.webm` to `.mp4`, keeping the **original** height of **240px**. You can refer to [MMAction2 Data Benchmark](https://github.com/open-mmlab/mmaction2).
72 |
73 | 3. Generate the space-separated annotation files needed by the dataloader (see [annotations](./annotations/) for samples). The annotations usually include `train.csv`, `val.csv` and `test.csv`. The format of each `*.csv` file is like:
74 |
75 | ```
76 | video_1.mp4 label_1
77 | video_2.mp4 label_2
78 | video_3.mp4 label_3
79 | ...
80 | video_N.mp4 label_N
81 | ```
82 | 4. All video files are located inside the DATA_PATH.
83 | ### Kinetics-400
84 | - The pre-processing of **Kinetics-400** can be summarized into the following steps:
85 |
86 | 1. Download the dataset from [official website](https://deepmind.com/research/open-source/kinetics) or [OpenDataLab](https://opendatalab.com/OpenMMLab/Kinetics-400).
87 |
88 | 2. Preprocess the dataset by resizing the short edge of video to **320px**. You can refer to [MMAction2 Data Benchmark](https://github.com/open-mmlab/mmaction2).
89 |
90 | 3. Generate the space-separated annotation files needed by the dataloader (see [annotations](./annotations/) for samples). The annotations usually include `train.csv`, `val.csv` and `test.csv`. The format of each `*.csv` file is like:
91 |
92 | ```
93 | video_1.mp4 label_1
94 | video_2.mp4 label_2
95 | video_3.mp4 label_3
96 | ...
97 | video_N.mp4 label_N
98 | ```
99 |
100 | 4. All video files should be split into **DATA_PATH/train** and **DATA_PATH/val**.
101 | # Expert model preparation
102 | We use pre-trained weights for the spatial and temporal experts. For the spatial expert (CLIP), we use the [official ViT-B/16 weight](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt). For the temporal expert (VideoMAE), we use weights pre-trained on the three datasets EK100, K400, and SSV2: [K400](https://drive.google.com/file/d/1MzwteHH-1yuMnFb8vRBQDvngV1Zl-d3z/view?usp=sharing) and [SSV2](https://drive.google.com/file/d/1dt_59tBIyzdZd5Ecr22lTtzs_64MOZkT/view?usp=sharing) use the [official weights](https://github.com/potatowarriors/VideoMAE/blob/main/MODEL_ZOO.md), while [EK100](https://drive.google.com/file/d/16zDs_9ycAoz8AoEecrQXsC5vLZsaQoTw/view?usp=sharing) uses weights we pre-trained ourselves. Put each downloaded expert weight into the VMAE_MODEL_PATH and CLIP_MODEL_PATH of the fine-tuning script.
103 |
104 |
105 | # Fine-tuning CAST
106 |
107 | We provide the **off-the-shelf** scripts in the [scripts folder](scripts).
108 |
109 | - For example, the following script fine-tunes CAST on **Kinetics-400** with 16 GPUs (2 nodes x 8 GPUs).
110 |
111 | ```bash
112 | DATA_PATH=YOUR_PATH
113 | VMAE_MODEL_PATH=YOUR_PATH
114 | CLIP_MODEL_PATH=YOUR_PATH
115 |
116 |
117 | OMP_NUM_THREADS=1 python -m torch.distributed.launch \
118 | --nproc_per_node=8 \
119 | --master_port ${YOUR_NUMBER} --nnodes=2 \
120 | --node_rank=${YOUR_NUMBER} --master_addr=${YOUR_ADDRESS} \
121 | YOUR_PATH/run_bidirection_compo.py \
122 | --data_set Kinetics-400 \
123 | --nb_classes 400 \
124 | --vmae_model compo_bidir_vit_base_patch16_224 \
125 | --anno_path ${ANNOTATION_PATH} \
126 | --data_path ${DATA_PATH} \
127 | --clip_finetune ${CLIP_MODEL_PATH} \
128 | --vmae_finetune ${VMAE_MODEL_PATH} \
129 | --log_dir ${YOUR_PATH} \
130 | --output_dir ${YOUR_PATH} \
131 | --batch_size 6 \
132 | --input_size 224 \
133 | --short_side_size 224 \
134 | --save_ckpt_freq 25 \
135 | --num_sample 1 \
136 | --num_frames 16 \
137 | --opt adamw \
138 | --lr 1e-3 \
139 | --opt_betas 0.9 0.999 \
140 | --weight_decay 0.05 \
141 | --epochs 70 \
142 | --dist_eval \
143 | --test_num_segment 5 \
144 | --test_num_crop 3 \
145 | --num_workers 8 \
146 | --drop_path 0.2 \
147 | --layer_decay 0.75 \
148 | --mixup_switch_prob 0 \
149 | --mixup_prob 0.5 \
150 | --reprob 0. \
151 | --init_scale 1. \
152 | --update_freq 6 \
153 | --seed 0 \
154 | --enable_deepspeed \
155 | --warmup_epochs 5
156 | ```
157 | # Evaluation
158 | Evaluation command for EK100:
159 | ```
160 | python ./run_bidirection_compo.py --fine_tune {YOUR_FINETUNED_WEIGHT} --composition --eval
161 | ```
162 | Evaluation command for SSV2 and K400:
163 | ```
164 | python ./run_bidirection.py --fine_tune {YOUR_FINETUNED_WEIGHT} --eval
165 | ```
166 | # Model Zoo
167 |
168 | ### EPIC-KITCHENS-100
169 |
170 | | Method | Spatial Expert | Temporal expert | Epoch | \#Frames x Clips x Crops | Fine-tune | Top-1 |
171 | | :------: | :------: | :------: | :---: | :-----: | :----------------------------------------------------------: | :---: |
172 | | CAST | [CLIP-B/16](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt) | [VideoMAE-B/16 (pre-trained on EK100)](https://drive.google.com/file/d/1DaxOctpEkmKTi873J1jzzz_Sl-0wRai7/view?usp=sharing) | 50 | 16x2x3 | [log](https://drive.google.com/file/d/1yry2Nd5BEaX3kZjNYDghGHubpei-by_9/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1pW5tMWG2N5zqQOOPcrawwQpIAhPFMoVx/view?usp=sharing) | 49.3 |
173 | ### Something-Something V2
174 |
175 | | Method | Spatial Expert | Temporal expert | Epoch | \#Frames x Clips x Crops | Fine-tune | Top-1 |
176 | | :------: | :------: | :------: | :---: | :-----: | :----------------------------------------------------------: | :---: |
177 | | CAST | [CLIP-B/16](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt) | [VideoMAE-B/16 (pre-trained on SSV2)](https://drive.google.com/file/d/1dt_59tBIyzdZd5Ecr22lTtzs_64MOZkT/view?usp=sharing) | 50 | 16x2x3 | [log](https://drive.google.com/file/d/1wOjcXSen9B9R2CIQ8ge7HrRufr_MmMcN/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1RrAVF4tlpZCPYNY49M0pKQfopP6p6Vir/view?usp=sharing) | 71.6 |
178 |
179 | ### Kinetics-400
180 |
181 | | Method | Spatial Expert | Temporal expert | Epoch | \#Frames x Clips x Crops | Fine-tune | Top-1 |
182 | | :------: | :------: | :------: | :---: | :-----: | :----------------------------------------------------------: | :---: |
183 | | CAST | [CLIP-B/16](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt) | [VideoMAE-B/16 (pre-trained on K400)](https://drive.google.com/file/d/1MzwteHH-1yuMnFb8vRBQDvngV1Zl-d3z/view?usp=sharing) | 70 | 16x5x3 | [log](https://drive.google.com/file/d/1Npw-GblhSGWVx0nU06ztjDLMFcazeCx6/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/16ndsBVNjRuJMRM40P0-a3Q1JVA8tNLK7/view?usp=sharing) | 85.3 |
184 |
185 |
186 | ## Acknowledgements
187 |
188 | This project is built upon [VideoMAE](https://github.com/MCG-NJU/VideoMAE), [MAE](https://github.com/pengzhiliang/MAE-pytorch), [CLIP](https://github.com/openai/CLIP) and [BEiT](https://github.com/microsoft/unilm/tree/master/beit). Thanks to the contributors of these great codebases.
189 |
190 | ## License
191 |
192 | This project is under the CC-BY-NC 4.0 license. See [LICENSE](https://github.com/MCG-NJU/VideoMAE/blob/main/LICENSE) for details.
193 |
194 | ## Citation
195 | ```
196 | @inproceedings{cast,
197 | title={CAST: Cross-Attention in Space and Time for Video Action Recognition},
198 | author={Lee, Dongho and Lee, Jongseo and Choi, Jinwoo},
199 | booktitle={NeurIPS},
200 | year={2023}}
201 | ```
202 |
--------------------------------------------------------------------------------
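For reference, a minimal sketch of reading the two annotation formats described in the README above (the file paths and column meanings are illustrative assumptions, not the project's actual loader code):

```python
import csv

# EPIC-KITCHENS-100: comma-separated "video,verb,noun" lines.
with open('epic/train.csv') as f:
    for video_id, verb, noun in csv.reader(f):
        print(video_id, verb, noun)

# SSV2 / Kinetics-400: space-separated "video.mp4 label" lines.
with open('ssv2/train.csv') as f:
    for line in f:
        video_name, label = line.strip().rsplit(' ', 1)
        print(video_name, label)
```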
/engine_for_onemodel.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import math
4 | import sys
5 | from typing import Iterable, Optional
6 | import torch
7 | from util_tools.mixup import Mixup
8 | from timm.utils import accuracy, ModelEma
9 | import util_tools.utils as utils
10 | from scipy.special import softmax
11 | from einops import rearrange
12 |
13 | def cross_train_class_batch(model, samples, target, criterion):
14 | outputs = model(samples)
15 | loss = criterion(outputs, target)
16 | return loss, outputs
17 |
18 |
19 | def get_loss_scale_for_deepspeed(model):
20 | optimizer = model.optimizer
21 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale
22 |
23 |
24 | def train_one_epoch(args, model: torch.nn.Module, criterion: torch.nn.Module,
25 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
26 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
27 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
28 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
29 | num_training_steps_per_epoch=None, update_freq=None):
30 | model.train(True)
31 | metric_logger = utils.MetricLogger(delimiter=" ")
32 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
33 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
34 | header = 'Epoch: [{}]'.format(epoch)
35 | print_freq = 10
36 |
37 | if loss_scaler is None:
38 | model.zero_grad()
39 | model.micro_steps = 0
40 | else:
41 | optimizer.zero_grad()
42 |
43 | for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
44 | step = data_iter_step // update_freq
45 | if step >= num_training_steps_per_epoch:
46 | continue
47 | it = start_steps + step # global training iteration
48 | # Update LR & WD at the start of each gradient-accumulation cycle
49 | if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
50 | for i, param_group in enumerate(optimizer.param_groups):
51 | if lr_schedule_values is not None:
52 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
53 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
54 | param_group["weight_decay"] = wd_schedule_values[it]
55 |
56 | samples = samples.to(device, non_blocking=True)
57 | targets = targets.to(device, non_blocking=True)
58 |
59 | if mixup_fn is not None:
60 | samples, targets = mixup_fn(samples, targets)
61 |
62 | if loss_scaler is None:
63 | samples = samples.half()
64 | loss, output = cross_train_class_batch(
65 | model, samples, targets, criterion)
66 | else:
67 | with torch.cuda.amp.autocast():
68 | samples = samples.half()
69 | loss, output = cross_train_class_batch(
70 | model, samples, targets, criterion)
71 | loss_value = loss.item()
72 |
73 | if not math.isfinite(loss_value):
74 | print("Loss is {}, stopping training".format(loss_value))
75 | sys.exit(1)
76 |
77 | if loss_scaler is None:
78 | loss /= update_freq
79 | model.backward(loss)
80 | model.step()
81 |
82 | if (data_iter_step + 1) % update_freq == 0:
83 | # model.zero_grad()
84 | # Deepspeed will call step() & model.zero_grad() automatic
85 | if model_ema is not None:
86 | model_ema.update(model)
87 | grad_norm = None
88 | loss_scale_value = get_loss_scale_for_deepspeed(model)
89 | else:
90 | # this attribute is added by timm on one optimizer (adahessian)
91 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
92 | loss /= update_freq
93 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
94 | parameters=model.parameters(), create_graph=is_second_order,
95 | update_grad=(data_iter_step + 1) % update_freq == 0)
96 | if (data_iter_step + 1) % update_freq == 0:
97 | optimizer.zero_grad()
98 | if model_ema is not None:
99 | model_ema.update(model)
100 | loss_scale_value = loss_scaler.state_dict()["scale"]
101 |
102 | torch.cuda.synchronize()
103 |
104 | if mixup_fn is None:
105 | class_acc = (output.max(-1)[-1] == targets).float().mean()
106 | else:
107 | class_acc = None
108 | metric_logger.update(loss=loss_value)
109 | metric_logger.update(class_acc=class_acc)
110 | metric_logger.update(loss_scale=loss_scale_value)
111 | min_lr = 10.
112 | max_lr = 0.
113 | for group in optimizer.param_groups:
114 | min_lr = min(min_lr, group["lr"])
115 | max_lr = max(max_lr, group["lr"])
116 |
117 | metric_logger.update(lr=max_lr)
118 | metric_logger.update(min_lr=min_lr)
119 | weight_decay_value = None
120 | for group in optimizer.param_groups:
121 | if group["weight_decay"] > 0:
122 | weight_decay_value = group["weight_decay"]
123 | metric_logger.update(weight_decay=weight_decay_value)
124 | metric_logger.update(grad_norm=grad_norm)
125 |
126 | if log_writer is not None:
127 | log_writer.update(loss=loss_value, head="loss")
128 | log_writer.update(class_acc=class_acc, head="loss")
129 | log_writer.update(loss_scale=loss_scale_value, head="opt")
130 | log_writer.update(lr=max_lr, head="opt")
131 | log_writer.update(min_lr=min_lr, head="opt")
132 | log_writer.update(weight_decay=weight_decay_value, head="opt")
133 | log_writer.update(grad_norm=grad_norm, head="opt")
134 |
135 | log_writer.set_step()
136 |
137 | # gather the stats from all processes
138 | metric_logger.synchronize_between_processes()
139 | print("Averaged stats:", metric_logger)
140 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
141 |
142 |
143 | @torch.no_grad()
144 | def validation_one_epoch(args, data_loader, model, device):
145 | criterion = torch.nn.CrossEntropyLoss()
146 |
147 | metric_logger = utils.MetricLogger(delimiter=" ")
148 | header = 'Val:'
149 |
150 | # switch to evaluation mode
151 | model.eval()
152 |
153 | for batch in metric_logger.log_every(data_loader, 10, header):
154 | samples = batch[0]
155 | target = batch[1]
156 | batch_size = samples.shape[0]
157 | samples = samples.to(device, non_blocking=True)
158 | target = target.to(device, non_blocking=True)
159 |
160 | # compute output
161 | with torch.cuda.amp.autocast():
162 | output = model(samples)
163 | loss = criterion(output, target)
164 |
165 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
166 |
167 | metric_logger.update(loss=loss.item())
168 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
169 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
170 | # gather the stats from all processes
171 | metric_logger.synchronize_between_processes()
172 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
173 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
174 |
175 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
176 |
177 |
178 |
179 | @torch.no_grad()
180 | def final_test(args, data_loader, model, device, file):
181 | criterion = torch.nn.CrossEntropyLoss()
182 |
183 | metric_logger = utils.MetricLogger(delimiter=" ")
184 | header = 'Test:'
185 |
186 | # switch to evaluation mode
187 | model.eval()
188 | final_result = []
189 |
190 | for batch in metric_logger.log_every(data_loader, 10, header):
191 | samples = batch[0]
192 | target = batch[1]
193 | ids = batch[2]
194 | chunk_nb = batch[3]
195 | split_nb = batch[4]
196 | batch_size = samples.shape[0]
197 | samples = samples.to(device, non_blocking=True)
198 | target = target.to(device, non_blocking=True)
199 |
200 | # compute output
201 | with torch.cuda.amp.autocast():
202 | output = model(samples)
203 | loss = criterion(output, target)
204 |
205 | for i in range(output.size(0)):
206 | string = "{} {} {} {} {}\n".format(ids[i], \
207 | str(output.data[i].cpu().numpy().tolist()), \
208 | str(int(target[i].cpu().numpy())), \
209 | str(int(chunk_nb[i].cpu().numpy())), \
210 | str(int(split_nb[i].cpu().numpy())))
211 | final_result.append(string)
212 |
213 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
214 |
215 | metric_logger.update(loss=loss.item())
216 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
217 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
218 |
219 | if not os.path.exists(file):
220 | os.mknod(file)
221 | with open(file, 'w') as f:
222 | f.write("{}, {}\n".format(acc1, acc5))
223 | for line in final_result:
224 | f.write(line)
225 | # gather the stats from all processes
226 | metric_logger.synchronize_between_processes()
227 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
228 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
229 |
230 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
231 |
232 |
233 | def merge(eval_path, num_tasks):
234 | dict_feats = {}
235 | dict_label = {}
236 | dict_pos = {}
237 | print("Reading individual output files")
238 |
239 | for x in range(num_tasks):
240 | file = os.path.join(eval_path, str(x) + '.txt')
241 | lines = open(file, 'r').readlines()[1:]
242 | for line in lines:
243 | line = line.strip()
244 | name = line.split('[')[0]
245 | label = line.split(']')[1].split(' ')[1]
246 | chunk_nb = line.split(']')[1].split(' ')[2]
247 | split_nb = line.split(']')[1].split(' ')[3]
248 | data = np.fromstring(line.split('[')[1].split(']')[0], dtype=float, sep=',')
249 | data = softmax(data)
250 | if not name in dict_feats:
251 | dict_feats[name] = []
252 | dict_label[name] = 0
253 | dict_pos[name] = []
254 | if chunk_nb + split_nb in dict_pos[name]:
255 | continue
256 | dict_feats[name].append(data)
257 | dict_pos[name].append(chunk_nb + split_nb)
258 | dict_label[name] = label
259 | print("Computing final results")
260 |
261 | input_lst = []
262 | print(len(dict_feats))
263 | for i, item in enumerate(dict_feats):
264 | input_lst.append([i, item, dict_feats[item], dict_label[item]])
265 | from multiprocessing import Pool
266 | p = Pool(64)
267 | ans = p.map(compute_video, input_lst)
268 | top1 = [x[1] for x in ans]
269 | top5 = [x[2] for x in ans]
270 | pred = [x[0] for x in ans]
271 | label = [x[3] for x in ans]
272 | final_top1, final_top5 = np.mean(top1), np.mean(top5)
273 | return final_top1 * 100, final_top5 * 100
274 |
275 | def compute_video(lst):
276 | i, video_id, data, label = lst
277 | feat = [x for x in data]
278 | feat = np.mean(feat, axis=0)
279 | pred = np.argmax(feat)
280 | top1 = (int(pred) == int(label)) * 1.0
281 | top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0
282 | return [pred, top1, top5, int(label)]
283 |
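# ---------------------------------------------------------------------
# Editor's sketch (illustration only, not part of the original file):
# final_test() writes one record per view in the form
# "<id> [<logits>] <label> <chunk_nb> <split_nb>"; merge() and
# compute_video() above softmax each view, average the scores of all
# (chunk, crop) views of a video, and take the argmax. In miniature:
import numpy as np
from scipy.special import softmax

views = [np.array([2.0, 0.5, 0.1]),   # logits from view 1 of one video
         np.array([1.5, 1.4, 0.2])]   # logits from view 2 of the same video
video_score = np.mean([softmax(v) for v in views], axis=0)
video_pred = int(np.argmax(video_score))   # -> 0, the class with the highest averaged score
# ---------------------------------------------------------------------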
--------------------------------------------------------------------------------
/util_tools/mixup.py:
--------------------------------------------------------------------------------
1 | """ Mixup and Cutmix
2 |
3 | Papers:
4 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
5 |
6 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
7 |
8 | Code Reference:
9 | CutMix: https://github.com/clovaai/CutMix-PyTorch
10 |
11 | Hacked together by / Copyright 2019, Ross Wightman
12 | """
13 | import numpy as np
14 | import torch
15 |
16 |
17 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
18 | x = x.long().view(-1, 1)
19 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
20 |
21 |
22 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
23 | off_value = smoothing / num_classes
24 | on_value = 1. - smoothing + off_value
25 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
26 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
27 | return y1 * lam + y2 * (1. - lam)
28 |
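# Editor's note (worked example, not in the original file): with
# num_classes=4, smoothing=0.1 and lam=0.7, the smoothed one-hot values
# are off = 0.1 / 4 = 0.025 and on = 1 - 0.1 + 0.025 = 0.925, so for
# targets [0, 2] the first row of
#   mixup_target(torch.tensor([0, 2]), 4, lam=0.7, smoothing=0.1, device='cpu')
# is 0.7 * [0.925, 0.025, 0.025, 0.025] + 0.3 * [0.025, 0.025, 0.925, 0.025]
#   = [0.655, 0.025, 0.295, 0.025].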
29 |
30 | def rand_bbox(img_shape, lam, margin=0., count=None):
31 | """ Standard CutMix bounding-box
32 | Generates a random square bbox based on lambda value. This impl includes
33 | support for enforcing a border margin as percent of bbox dimensions.
34 |
35 | Args:
36 | img_shape (tuple): Image shape as tuple
37 | lam (float): Cutmix lambda value
38 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
39 | count (int): Number of bbox to generate
40 | """
41 | ratio = np.sqrt(1 - lam)
42 | img_h, img_w = img_shape[-2:]
43 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
44 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
45 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
46 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
47 | yl = np.clip(cy - cut_h // 2, 0, img_h)
48 | yh = np.clip(cy + cut_h // 2, 0, img_h)
49 | xl = np.clip(cx - cut_w // 2, 0, img_w)
50 | xh = np.clip(cx + cut_w // 2, 0, img_w)
51 | return yl, yh, xl, xh
52 |
53 |
54 | def rand_bbox_minmax(img_shape, minmax, count=None):
55 | """ Min-Max CutMix bounding-box
56 | Inspired by Darknet cutmix impl, generates a random rectangular bbox
57 | based on min/max percent values applied to each dimension of the input image.
58 |
59 | Typical defaults for minmax are in the .2-.3 range for min and the .8-.9 range for max.
60 |
61 | Args:
62 | img_shape (tuple): Image shape as tuple
63 | minmax (tuple or list): Min and max bbox ratios (as percent of image size)
64 | count (int): Number of bbox to generate
65 | """
66 | assert len(minmax) == 2
67 | img_h, img_w = img_shape[-2:]
68 | cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
69 | cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
70 | yl = np.random.randint(0, img_h - cut_h, size=count)
71 | xl = np.random.randint(0, img_w - cut_w, size=count)
72 | yu = yl + cut_h
73 | xu = xl + cut_w
74 | return yl, yu, xl, xu
75 |
76 |
77 | def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
78 | """ Generate bbox and apply lambda correction.
79 | """
80 | if ratio_minmax is not None:
81 | yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
82 | else:
83 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
84 | if correct_lam or ratio_minmax is not None:
85 | bbox_area = (yu - yl) * (xu - xl)
86 | lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
87 | return (yl, yu, xl, xu), lam
88 |
89 |
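# Editor's note (worked example, not in the original file): for a
# 224 x 224 input and a requested lam = 0.75, rand_bbox() cuts a
# sqrt(1 - 0.75) = 0.5 fraction per side, i.e. a 112 x 112 box.
# If the box lies fully inside the image the corrected lam is
# 1 - (112 * 112) / (224 * 224) = 0.75 (unchanged), but if clipping at
# a border shrinks it to, say, 56 x 112, correct_lam updates it to
# 1 - (56 * 112) / (224 * 224) = 0.875, so the soft targets keep
# matching the pixels that were actually pasted.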
90 | class Mixup:
91 | """ Mixup/Cutmix that applies different params to each element or whole batch
92 |
93 | Args:
94 | mixup_alpha (float): mixup alpha value, mixup is active if > 0.
95 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
96 | cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
97 | prob (float): probability of applying mixup or cutmix per batch or element
98 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active
99 | mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element))
100 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
101 | label_smoothing (float): apply label smoothing to the mixed target tensor
102 | num_classes (int): number of classes for target
103 | """
104 | def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
105 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000, composition=False):
106 | self.mixup_alpha = mixup_alpha
107 | self.cutmix_alpha = cutmix_alpha
108 | self.cutmix_minmax = cutmix_minmax
109 | if self.cutmix_minmax is not None:
110 | assert len(self.cutmix_minmax) == 2
111 | # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
112 | self.cutmix_alpha = 1.0
113 | self.mix_prob = prob
114 | self.switch_prob = switch_prob
115 | self.label_smoothing = label_smoothing
116 | self.num_classes = num_classes
117 | self.mode = mode
118 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
119 | self.mixup_enabled = True # set to false to disable mixing (intended to be set by train loop)
120 | self.composition = composition
121 |
122 | def _params_per_elem(self, batch_size):
123 | lam = np.ones(batch_size, dtype=np.float32)
124 | use_cutmix = np.zeros(batch_size, dtype=bool)
125 | if self.mixup_enabled:
126 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
127 | use_cutmix = np.random.rand(batch_size) < self.switch_prob
128 | lam_mix = np.where(
129 | use_cutmix,
130 | np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
131 | np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
132 | elif self.mixup_alpha > 0.:
133 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
134 | elif self.cutmix_alpha > 0.:
135 | use_cutmix = np.ones(batch_size, dtype=bool)
136 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
137 | else:
138 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
139 | lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
140 | return lam, use_cutmix
141 |
142 | def _params_per_batch(self):
143 | lam = 1.
144 | use_cutmix = False
145 | if self.mixup_enabled and np.random.rand() < self.mix_prob:
146 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
147 | use_cutmix = np.random.rand() < self.switch_prob
148 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
149 | np.random.beta(self.mixup_alpha, self.mixup_alpha)
150 | elif self.mixup_alpha > 0.:
151 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
152 | elif self.cutmix_alpha > 0.:
153 | use_cutmix = True
154 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
155 | else:
156 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
157 | lam = float(lam_mix)
158 | return lam, use_cutmix
159 |
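# Editor's note (illustration, not in the original file): with e.g.
# mixup_alpha=0.8, cutmix_alpha=1.0 and switch_prob=0.5, a batch that
# mixes (probability mix_prob) uses CutMix half of the time with
# lam ~ Beta(1, 1) (uniform on [0, 1]) and Mixup otherwise with
# lam ~ Beta(0.8, 0.8) (U-shaped, favouring values near 0 or 1);
# with probability 1 - mix_prob the batch is left unchanged (lam = 1).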
160 | def _mix_elem(self, x):
161 | batch_size = len(x)
162 | lam_batch, use_cutmix = self._params_per_elem(batch_size)
163 | x_orig = x.clone() # need to keep an unmodified original for mixing source
164 | for i in range(batch_size):
165 | j = batch_size - i - 1
166 | lam = lam_batch[i]
167 | if lam != 1.:
168 | if use_cutmix[i]:
169 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
170 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
171 | x[i][..., yl:yh, xl:xh] = x_orig[j][..., yl:yh, xl:xh]
172 | lam_batch[i] = lam
173 | else:
174 | x[i] = x[i] * lam + x_orig[j] * (1 - lam)
175 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
176 |
177 | def _mix_pair(self, x):
178 | batch_size = len(x)
179 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
180 | x_orig = x.clone() # need to keep an unmodified original for mixing source
181 | for i in range(batch_size // 2):
182 | j = batch_size - i - 1
183 | lam = lam_batch[i]
184 | if lam != 1.:
185 | if use_cutmix[i]:
186 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
187 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
188 | x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
189 | x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
190 | lam_batch[i] = lam
191 | else:
192 | x[i] = x[i] * lam + x_orig[j] * (1 - lam)
193 | x[j] = x[j] * lam + x_orig[i] * (1 - lam)
194 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
195 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
196 |
197 | def _mix_batch(self, x):
198 | lam, use_cutmix = self._params_per_batch()
199 | if lam == 1.:
200 | return 1.
201 | if use_cutmix:
202 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
203 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
204 | x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh]
205 | else:
206 | x_flipped = x.flip(0).mul_(1. - lam)
207 | x.mul_(lam).add_(x_flipped)
208 | return lam
209 |
210 | def __call__(self, x, target):
211 | assert len(x) % 2 == 0, 'Batch size should be even when using this'
212 | if self.mode == 'elem':
213 | lam = self._mix_elem(x)
214 | elif self.mode == 'pair':
215 | lam = self._mix_pair(x)
216 | else:
217 | lam = self._mix_batch(x)
218 | if self.composition:
219 | target_noun = mixup_target(target[:,0], 300, lam, self.label_smoothing, x.device)
220 | target_verb = mixup_target(target[:,1], 97, lam, self.label_smoothing, x.device)
221 | return x, target_noun, target_verb
222 | else:
223 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
224 | return x, target
225 |
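# Editor's sketch (hypothetical usage, not part of the original file;
# the hyper-parameter values below are placeholders):
#   mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, prob=0.9,
#                    switch_prob=0.5, label_smoothing=0.1, num_classes=174)
#   samples, targets = mixup_fn(samples, targets)   # soft targets, shape (B, 174)
# With composition=True the call instead returns
#   samples, target_noun, target_verb = mixup_fn(samples, targets)
# where targets holds (B, 2) [noun, verb] pairs and the 300-noun / 97-verb
# class counts are the ones hard-coded in __call__ above.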
226 |
227 | class FastCollateMixup(Mixup):
228 | """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
229 |
230 | A Mixup impl that's performed while collating the batches.
231 | """
232 |
233 | def _mix_elem_collate(self, output, batch, half=False):
234 | batch_size = len(batch)
235 | num_elem = batch_size // 2 if half else batch_size
236 | assert len(output) == num_elem
237 | lam_batch, use_cutmix = self._params_per_elem(num_elem)
238 | for i in range(num_elem):
239 | j = batch_size - i - 1
240 | lam = lam_batch[i]
241 | mixed = batch[i][0]
242 | if lam != 1.:
243 | if use_cutmix[i]:
244 | if not half:
245 | mixed = mixed.copy()
246 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
247 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
248 | mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
249 | lam_batch[i] = lam
250 | else:
251 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
252 | np.rint(mixed, out=mixed)
253 | output[i] += torch.from_numpy(mixed.astype(np.uint8))
254 | if half:
255 | lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
256 | return torch.tensor(lam_batch).unsqueeze(1)
257 |
258 | def _mix_pair_collate(self, output, batch):
259 | batch_size = len(batch)
260 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
261 | for i in range(batch_size // 2):
262 | j = batch_size - i - 1
263 | lam = lam_batch[i]
264 | mixed_i = batch[i][0]
265 | mixed_j = batch[j][0]
266 | assert 0 <= lam <= 1.0
267 | if lam < 1.:
268 | if use_cutmix[i]:
269 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
270 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
271 | patch_i = mixed_i[:, yl:yh, xl:xh].copy()
272 | mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
273 | mixed_j[:, yl:yh, xl:xh] = patch_i
274 | lam_batch[i] = lam
275 | else:
276 | mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
277 | mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
278 | mixed_i = mixed_temp
279 | np.rint(mixed_j, out=mixed_j)
280 | np.rint(mixed_i, out=mixed_i)
281 | output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
282 | output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
283 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
284 | return torch.tensor(lam_batch).unsqueeze(1)
285 |
286 | def _mix_batch_collate(self, output, batch):
287 | batch_size = len(batch)
288 | lam, use_cutmix = self._params_per_batch()
289 | if use_cutmix:
290 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
291 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
292 | for i in range(batch_size):
293 | j = batch_size - i - 1
294 | mixed = batch[i][0]
295 | if lam != 1.:
296 | if use_cutmix:
297 | mixed = mixed.copy() # don't want to modify the original while iterating
298 | mixed[..., yl:yh, xl:xh] = batch[j][0][..., yl:yh, xl:xh]
299 | else:
300 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
301 | np.rint(mixed, out=mixed)
302 | output[i] += torch.from_numpy(mixed.astype(np.uint8))
303 | return lam
304 |
305 | def __call__(self, batch, _=None):
306 | batch_size = len(batch)
307 | assert batch_size % 2 == 0, 'Batch size should be even when using this'
308 | half = 'half' in self.mode
309 | if half:
310 | batch_size //= 2
311 | output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
312 | if self.mode == 'elem' or self.mode == 'half':
313 | lam = self._mix_elem_collate(output, batch, half=half)
314 | elif self.mode == 'pair':
315 | lam = self._mix_pair_collate(output, batch)
316 | else:
317 | lam = self._mix_batch_collate(output, batch)
318 | target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
319 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
320 | target = target[:batch_size]
321 | return output, target
322 |
323 |
--------------------------------------------------------------------------------
/dataset/ssv2.py:
--------------------------------------------------------------------------------
1 | import os
2 | from statistics import NormalDist
3 | import numpy as np
4 | import torch
5 | from torchvision import transforms
6 | from util_tools.random_erasing import RandomErasing
7 | import warnings
8 | from decord import VideoReader, cpu
9 | from torch.utils.data import Dataset
10 | import util_tools.video_transforms as video_transforms
11 | import util_tools.volume_transforms as volume_transforms
12 |
13 |
14 | class SSVideoClsDataset(Dataset):
15 | """Load your own video classification dataset."""
16 |
17 | def __init__(self, anno_path, data_path, mode='train', clip_len=8,
18 | crop_size=224, short_side_size=256, new_height=256,
19 | new_width=340, keep_aspect_ratio=True, num_segment=1,
20 | num_crop=1, test_num_segment=10, test_num_crop=3, args=None):
21 | self.anno_path = anno_path
22 | self.data_path = data_path
23 | self.mode = mode
24 | self.clip_len = clip_len
25 | self.crop_size = crop_size
26 | self.short_side_size = short_side_size
27 | self.new_height = new_height
28 | self.new_width = new_width
29 | self.keep_aspect_ratio = keep_aspect_ratio
30 | self.num_segment = num_segment
31 | self.test_num_segment = test_num_segment
32 | self.num_crop = num_crop
33 | self.test_num_crop = test_num_crop
34 | self.args = args
35 | self.aug = False
36 | self.rand_erase = False
37 | if self.mode in ['train']:
38 | self.aug = True
39 | if self.args.reprob > 0:
40 | self.rand_erase = True
41 | if VideoReader is None:
42 | raise ImportError("Unable to import `decord` which is required to read videos.")
43 |
44 | import pandas as pd
45 | cleaned = pd.read_csv(self.anno_path, header=None, delimiter=' ')
46 | self.dataset_samples = list(cleaned.values[:, 0])
47 | self.label_array = list(cleaned.values[:, 1])
48 |
49 | if (mode == 'train'):
50 | pass
51 |
52 | elif (mode == 'validation'):
53 | self.data_transform = video_transforms.Compose([
54 | video_transforms.Resize(self.short_side_size, interpolation='bilinear'),
55 | video_transforms.CenterCrop(size=(self.crop_size, self.crop_size)),
56 | volume_transforms.ClipToTensor(),
57 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
58 | std=[0.229, 0.224, 0.225])
59 | ])
60 | elif mode == 'test':
61 | self.data_resize = video_transforms.Compose([
62 | video_transforms.Resize(size=(short_side_size), interpolation='bilinear')
63 | ])
64 | self.data_transform = video_transforms.Compose([
65 | volume_transforms.ClipToTensor(),
66 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
67 | std=[0.229, 0.224, 0.225])
68 | ])
69 | self.test_seg = []
70 | self.test_dataset = []
71 | self.test_label_array = []
72 | for ck in range(self.test_num_segment):
73 | for cp in range(self.test_num_crop):
74 | for idx in range(len(self.label_array)):
75 | sample_label = self.label_array[idx]
76 | self.test_label_array.append(sample_label)
77 | self.test_dataset.append(self.dataset_samples[idx])
78 | self.test_seg.append((ck, cp))
79 |
80 | def __getitem__(self, index):
81 | if self.mode == 'train':
82 | args = self.args
83 | scale_t = 1
84 |
85 | sample = self.dataset_samples[index]
86 | sample = os.path.join(self.data_path,sample)# self.data_path + '/videos_train/' + sample
87 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) # T H W C
88 | if len(buffer) == 0:
89 | while len(buffer) == 0:
90 | warnings.warn("video {} not correctly loaded during training".format(sample))
91 | index = np.random.randint(self.__len__())
92 | sample = os.path.join(self.data_path, self.dataset_samples[index])
93 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t)
94 |
95 | if args.num_sample > 1:
96 | frame_list = []
97 | label_list = []
98 | index_list = []
99 | for _ in range(args.num_sample):
100 | new_frames = self._aug_frame(buffer, args)
101 | label = self.label_array[index]
102 | frame_list.append(new_frames)
103 | label_list.append(label)
104 | index_list.append(index)
105 | return frame_list, label_list, index_list, {}
106 | else:
107 | buffer = self._aug_frame(buffer, args)
108 |
109 | return buffer, self.label_array[index], index, {}
110 |
111 | elif self.mode == 'validation':
112 | sample = self.dataset_samples[index]
113 | sample = os.path.join(self.data_path,sample)# self.data_path + '/videos_train/' + sample
114 | buffer = self.loadvideo_decord(sample)
115 | if len(buffer) == 0:
116 | while len(buffer) == 0:
117 | warnings.warn("video {} not correctly loaded during validation".format(sample))
118 | index = np.random.randint(self.__len__())
119 | sample = os.path.join(self.data_path, self.dataset_samples[index])
120 | buffer = self.loadvideo_decord(sample)
121 | buffer = self.data_transform(buffer)
122 | return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0]
123 |
124 | elif self.mode == 'test':
125 | sample = self.test_dataset[index]
126 | sample = os.path.join(self.data_path,sample)# self.data_path + '/videos_train/' + sample
127 | chunk_nb, split_nb = self.test_seg[index]
128 | buffer = self.loadvideo_decord(sample)
129 |
130 | while len(buffer) == 0:
131 | warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\
132 | str(self.test_dataset[index]), chunk_nb, split_nb))
133 | index = np.random.randint(self.__len__())
134 | sample = os.path.join(self.data_path, self.test_dataset[index])
135 | chunk_nb, split_nb = self.test_seg[index]
136 | buffer = self.loadvideo_decord(sample)
137 |
138 | buffer = self.data_resize(buffer)
139 | if isinstance(buffer, list):
140 | buffer = np.stack(buffer, 0)
141 |
142 | if self.test_num_crop == 1:
143 | spatial_step = 1.0 * (max( buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
144 | / (self.test_num_crop)
145 | else:
146 | spatial_step = 1.0 * (max( buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
147 | / (self.test_num_crop - 1)
148 | temporal_start = chunk_nb # 0/1
149 | spatial_start = int(split_nb * spatial_step)
150 | if buffer.shape[1] >= buffer.shape[2]:
151 | buffer = buffer[temporal_start::2, \
152 | spatial_start:spatial_start + self.short_side_size, :, :]
153 | else:
154 | buffer = buffer[temporal_start::2, \
155 | :, spatial_start:spatial_start + self.short_side_size, :]
156 |
157 | buffer = self.data_transform(buffer)
158 | return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \
159 | chunk_nb, split_nb
160 | else:
161 | raise NameError('mode {} unknown'.format(self.mode))
162 |
163 | def _aug_frame(
164 | self,
165 | buffer,
166 | args,
167 | ):
168 |
169 | aug_transform = video_transforms.create_random_augment(
170 | input_size=(self.crop_size, self.crop_size),
171 | auto_augment=args.aa,
172 | interpolation=args.train_interpolation,
173 | )
174 |
175 | buffer = [
176 | transforms.ToPILImage()(frame) for frame in buffer
177 | ]
178 |
179 | buffer = aug_transform(buffer)
180 |
181 | buffer = [transforms.ToTensor()(img) for img in buffer]
182 | buffer = torch.stack(buffer) # T C H W
183 | buffer = buffer.permute(0, 2, 3, 1) # T H W C
184 |
185 | # T H W C
186 | buffer = tensor_normalize(
187 | buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
188 | )
189 | # T H W C -> C T H W.
190 | buffer = buffer.permute(3, 0, 1, 2)
191 | # Perform data augmentation.
192 | scl, asp = (
193 | [0.08, 1.0],
194 | [0.75, 1.3333],
195 | )
196 |
197 | buffer = spatial_sampling(
198 | buffer,
199 | spatial_idx=-1,
200 | min_scale=256,
201 | max_scale=320,
202 | crop_size=self.crop_size,
203 | random_horizontal_flip=False if args.data_set == 'SSV2' else True,
204 | inverse_uniform_sampling=False,
205 | aspect_ratio=asp,
206 | scale=scl,
207 | motion_shift=False
208 | )
209 |
210 | if self.rand_erase:
211 | erase_transform = RandomErasing(
212 | args.reprob,
213 | mode=args.remode,
214 | max_count=args.recount,
215 | num_splits=args.recount,
216 | device="cpu",
217 | )
218 | buffer = buffer.permute(1, 0, 2, 3)
219 | buffer = erase_transform(buffer)
220 | buffer = buffer.permute(1, 0, 2, 3)
221 |
222 | return buffer
223 |
224 |
225 | def loadvideo_decord(self, sample, sample_rate_scale=1):
226 | """Load video content using Decord"""
227 | fname = sample
228 |
229 | if not (os.path.exists(fname)):
230 | return []
231 |
232 | # avoid hanging issue
233 | if os.path.getsize(fname) < 1 * 1024:
234 | print('SKIP: ', fname, " - ", os.path.getsize(fname))
235 | return []
236 | try:
237 | if self.keep_aspect_ratio:
238 | vr = VideoReader(fname, num_threads=1, ctx=cpu(0))
239 | else:
240 | vr = VideoReader(fname, width=self.new_width, height=self.new_height,
241 | num_threads=1, ctx=cpu(0))
242 | except Exception:
243 | print("video cannot be loaded by decord: ", fname)
244 | return []
245 |
246 | if self.mode == 'test':
247 | all_index = []
248 | tick = len(vr) / float(self.num_segment)
249 | all_index = list(np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segment)] +
250 | [int(tick * x) for x in range(self.num_segment)]))
251 | while len(all_index) < (self.num_segment * self.test_num_segment):
252 | all_index.append(all_index[-1])
253 | all_index = list(np.sort(np.array(all_index)))
254 | vr.seek(0)
255 | buffer = vr.get_batch(all_index).asnumpy()
256 | return buffer
257 |
258 | # handle temporal segments
259 | average_duration = len(vr) // self.num_segment
260 | all_index = []
261 | if average_duration > 0:
262 | all_index += list(np.multiply(list(range(self.num_segment)), average_duration) + np.random.randint(average_duration,
263 | size=self.num_segment))
264 | elif len(vr) > self.num_segment:
265 | all_index += list(np.sort(np.random.randint(len(vr), size=self.num_segment)))
266 | else:
267 | all_index += list(np.zeros((self.num_segment,)))
268 | all_index = list(np.array(all_index))
269 | vr.seek(0)
270 | buffer = vr.get_batch(all_index).asnumpy()
271 | return buffer
272 |
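# Editor's note (worked example, not in the original file): in train
# mode loadvideo_decord() draws one frame from each of num_segment
# equal chunks (TSN-style sampling). For a 160-frame video and
# num_segment=16: average_duration = 160 // 16 = 10, the chunk starts
# are [0, 10, ..., 150], and np.random.randint(10, size=16) adds an
# independent offset inside each chunk, so every chunk contributes one
# randomly placed frame. Videos with fewer than num_segment frames
# fall back to repeating index 0.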
273 | def __len__(self):
274 | if self.mode != 'test':
275 | return len(self.dataset_samples)
276 | else:
277 | return len(self.test_dataset)
278 |
279 |
280 | def spatial_sampling(
281 | frames,
282 | spatial_idx=-1,
283 | min_scale=256,
284 | max_scale=320,
285 | crop_size=224,
286 | random_horizontal_flip=True,
287 | inverse_uniform_sampling=False,
288 | aspect_ratio=None,
289 | scale=None,
290 | motion_shift=False,
291 | ):
292 | """
293 | Perform spatial sampling on the given video frames. If spatial_idx is
294 | -1, perform random scale, random crop, and random flip on the given
295 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
296 | with the given spatial_idx.
297 | Args:
298 | frames (tensor): frames of images sampled from the video. The
299 | dimension is `num frames` x `height` x `width` x `channel`.
300 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
301 | or 2, perform left, center, right crop if width is larger than
302 | height, and perform top, center, bottom crop if height is larger
303 | than width.
304 | min_scale (int): the minimal size of scaling.
305 | max_scale (int): the maximal size of scaling.
306 | crop_size (int): the size of height and width used to crop the
307 | frames.
308 | inverse_uniform_sampling (bool): if True, sample uniformly in
309 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
310 | scale. If False, take a uniform sample from [min_scale,
311 | max_scale].
312 | aspect_ratio (list): Aspect ratio range for resizing.
313 | scale (list): Scale range for resizing.
314 | motion_shift (bool): Whether to apply motion shift for resizing.
315 | Returns:
316 | frames (tensor): spatially sampled frames.
317 | """
318 | assert spatial_idx in [-1, 0, 1, 2]
319 | if spatial_idx == -1:
320 | if aspect_ratio is None and scale is None:
321 | frames, _ = video_transforms.random_short_side_scale_jitter(
322 | images=frames,
323 | min_size=min_scale,
324 | max_size=max_scale,
325 | inverse_uniform_sampling=inverse_uniform_sampling,
326 | )
327 | frames, _ = video_transforms.random_crop(frames, crop_size)
328 | else:
329 | transform_func = (
330 | video_transforms.random_resized_crop_with_shift
331 | if motion_shift
332 | else video_transforms.random_resized_crop
333 | )
334 | frames = transform_func(
335 | images=frames,
336 | target_height=crop_size,
337 | target_width=crop_size,
338 | scale=scale,
339 | ratio=aspect_ratio,
340 | )
341 | if random_horizontal_flip:
342 | frames, _ = video_transforms.horizontal_flip(0.5, frames)
343 | else:
344 | # The testing is deterministic and no jitter should be performed.
345 | # min_scale, max_scale, and crop_size are expected to be the same.
346 | assert len({min_scale, max_scale, crop_size}) == 1
347 | frames, _ = video_transforms.random_short_side_scale_jitter(
348 | frames, min_scale, max_scale
349 | )
350 | frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx)
351 | return frames
352 |
353 |
354 | def tensor_normalize(tensor, mean, std):
355 | """
356 | Normalize a given tensor by subtracting the mean and dividing the std.
357 | Args:
358 | tensor (tensor): tensor to normalize.
359 | mean (tensor or list): mean value to subtract.
360 | std (tensor or list): std to divide.
361 | """
362 | if tensor.dtype == torch.uint8:
363 | tensor = tensor.float()
364 | tensor = tensor / 255.0
365 | if type(mean) == list:
366 | mean = torch.tensor(mean)
367 | if type(std) == list:
368 | std = torch.tensor(std)
369 | tensor = tensor - mean
370 | tensor = tensor / std
371 | return tensor
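# Editor's note (worked example, not in the original file): for a uint8
# clip, tensor_normalize() first rescales to [0, 1] and then standardises
# per channel; the 3-element mean/std lists broadcast over the trailing
# channel axis of the T x H x W x C tensor. E.g. a red-channel value of
# 128 becomes (128 / 255 - 0.485) / 0.229 ≈ 0.074.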
--------------------------------------------------------------------------------
/dataset/epic.py:
--------------------------------------------------------------------------------
1 | import os
2 | from statistics import NormalDist
3 | import numpy as np
4 | import torch
5 | from torchvision import transforms
6 | from util_tools.random_erasing import RandomErasing
7 | import warnings
8 | from decord import VideoReader, cpu
9 | from torch.utils.data import Dataset
10 | import util_tools.video_transforms as video_transforms
11 | import util_tools.volume_transforms as volume_transforms
12 |
13 | class EpicVideoClsDataset(Dataset):
14 |
15 | def __init__(self, anno_path, data_path, mode='train', clip_len=8,
16 | crop_size=224, short_side_size=256, new_height=256,
17 | new_width=340, keep_aspect_ratio=True, num_segment=1,
18 | num_crop=1, test_num_segment=10, test_num_crop=3, args=None):
19 | self.anno_path = anno_path
20 | self.data_path = data_path
21 | self.mode = mode
22 | self.clip_len = clip_len
23 | self.crop_size = crop_size
24 | self.short_side_size = short_side_size
25 | self.new_height = new_height
26 | self.new_width = new_width
27 | self.keep_aspect_ratio = keep_aspect_ratio
28 | self.num_segment = num_segment
29 | self.test_num_segment = test_num_segment
30 | self.num_crop = num_crop
31 | self.test_num_crop = test_num_crop
32 | self.args = args
33 | self.aug = False
34 | self.rand_erase = False
35 | if self.mode in ['train']:
36 | self.aug = True
37 | if self.args.reprob > 0:
38 | self.rand_erase = True
39 | if VideoReader is None:
40 | raise ImportError("Unable to import `decord` which is required to read videos.")
41 |
42 | import pandas as pd
43 | cleaned = pd.read_csv(self.anno_path, header=None, delimiter=',')
44 | self.dataset_samples = list(cleaned.values[:, 0])
45 | verb_label_array = list(cleaned.values[:, 1]) # verb
46 | noun_label_array = list(cleaned.values[:, 2]) # noun
47 | self.label_array = np.stack((noun_label_array, verb_label_array), axis=1) # label [noun, verb] sequence
48 |
49 | if (mode == 'train'):
50 | pass
51 |
52 | elif (mode == 'validation'):
53 | self.data_transform = video_transforms.Compose([
54 | video_transforms.Resize(self.short_side_size, interpolation='bilinear'),
55 | video_transforms.CenterCrop(size=(self.crop_size, self.crop_size)),
56 | volume_transforms.ClipToTensor(),
57 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
58 | std=[0.229, 0.224, 0.225])
59 | ])
60 | elif (mode == 'test'):
61 | self.data_resize = video_transforms.Compose([
62 | video_transforms.Resize(size=(short_side_size), interpolation='bilinear')
63 | ])
64 | self.data_transform = video_transforms.Compose([
65 | volume_transforms.ClipToTensor(),
66 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
67 | std=[0.229, 0.224, 0.225])
68 | ])
69 | self.test_seg = []
70 | self.test_dataset = []
71 | self.test_label_array = []
72 | for ck in range(self.test_num_segment):
73 | for cp in range(self.test_num_crop):
74 | for idx in range(len(self.label_array)):
75 | sample_label = self.label_array[idx]
76 | self.test_label_array.append(sample_label)
77 | self.test_dataset.append(self.dataset_samples[idx])
78 | self.test_seg.append((ck, cp))
79 |
80 | def __getitem__(self, index):
81 | if self.mode == 'train':
82 | args = self.args
83 | scale_t = 1
84 |
85 | sample = self.dataset_samples[index] + '.mp4'
86 | sample = os.path.join(self.data_path, sample)
87 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) # T H W C
88 | if len(buffer) == 0:
89 | while len(buffer) == 0:
90 | warnings.warn("video {} not correctly loaded during training".format(sample))
91 | index = np.random.randint(self.__len__())
92 | sample = os.path.join(self.data_path, self.dataset_samples[index] + '.mp4')
93 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t)
94 |
95 | if args.num_sample > 1:
96 | frame_list = []
97 | label_list = []
98 | index_list = []
99 | for _ in range(args.num_sample):
100 | new_frames = self._aug_frame(buffer, args)
101 | label = self.label_array[index]
102 | frame_list.append(new_frames)
103 | label_list.append(label)
104 | index_list.append(index)
105 | return frame_list, label_list, index_list, {}
106 | else:
107 | buffer = self._aug_frame(buffer, args)
108 |
109 | return buffer, self.label_array[index], index, {}
110 |
111 | elif self.mode == 'validation':
112 | sample = self.dataset_samples[index] + '.mp4'
113 | sample = os.path.join(self.data_path, sample)
114 | buffer = self.loadvideo_decord(sample)
115 | if len(buffer) == 0:
116 | while len(buffer) == 0:
117 | warnings.warn("video {} not correctly loaded during validation".format(sample))
118 | index = np.random.randint(self.__len__())
119 | sample = os.path.join(self.data_path, self.dataset_samples[index] + '.mp4')
120 | buffer = self.loadvideo_decord(sample)
121 | buffer = self.data_transform(buffer)
122 | return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0]
123 |
124 | elif self.mode == 'test':
125 | sample = self.test_dataset[index] + '.mp4'
126 | sample = os.path.join(self.data_path, sample)
127 | chunk_nb, split_nb = self.test_seg[index]
128 | buffer = self.loadvideo_decord(sample)
129 |
130 | while len(buffer) == 0:
131 | warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\
132 | str(self.test_dataset[index]), chunk_nb, split_nb))
133 | index = np.random.randint(self.__len__())
134 | sample = os.path.join(self.data_path, self.test_dataset[index] + '.mp4')
135 | chunk_nb, split_nb = self.test_seg[index]
136 | buffer = self.loadvideo_decord(sample)
137 |
138 | buffer = self.data_resize(buffer)
139 | if isinstance(buffer, list):
140 | buffer = np.stack(buffer, 0)
141 |
142 | if self.test_num_crop == 1:
143 | spatial_step = 1.0 * (max( buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
144 | / (self.test_num_crop)
145 | else:
146 | spatial_step = 1.0 * (max( buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
147 | / (self.test_num_crop - 1)
148 | temporal_start = chunk_nb # 0/1
149 | spatial_start = int(split_nb * spatial_step)
150 | if buffer.shape[1] >= buffer.shape[2]:
151 | buffer = buffer[temporal_start::2, \
152 | spatial_start:spatial_start + self.short_side_size, :, :]
153 | else:
154 | buffer = buffer[temporal_start::2, \
155 | :, spatial_start:spatial_start + self.short_side_size, :]
156 |
157 | buffer = self.data_transform(buffer)
158 | return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \
159 | chunk_nb, split_nb
160 | else:
161 | raise NameError('mode {} unknown'.format(self.mode))
162 |
163 |
164 |
165 | def _aug_frame(self, buffer, args):
166 |
167 | aug_transform = video_transforms.create_random_augment(
168 | input_size=(self.crop_size, self.crop_size),
169 | auto_augment=args.aa,
170 | interpolation=args.train_interpolation,
171 | )
172 |
173 | buffer = [
174 | transforms.ToPILImage()(frame) for frame in buffer
175 | ]
176 |
177 | buffer = aug_transform(buffer)
178 |
179 | buffer = [transforms.ToTensor()(img) for img in buffer]
180 | buffer = torch.stack(buffer) # T C H W
181 | buffer = buffer.permute(0, 2, 3, 1) # T H W C
182 |
183 | # T H W C
184 | buffer = tensor_normalize(
185 | buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
186 | )
187 | # T H W C -> C T H W.
188 | buffer = buffer.permute(3, 0, 1, 2)
189 | # Perform data augmentation.
190 | scl, asp = (
191 | [0.08, 1.0],
192 | [0.75, 1.3333],
193 | )
194 |
195 | buffer = spatial_sampling(
196 | buffer,
197 | spatial_idx=-1,
198 | min_scale=256,
199 | max_scale=320,
200 | crop_size=self.crop_size,
201 | random_horizontal_flip=False if args.data_set == 'SSV2' else True,
202 | inverse_uniform_sampling=False,
203 | aspect_ratio=asp,
204 | scale=scl,
205 | motion_shift=False
206 | )
207 |
208 | if self.rand_erase:
209 | erase_transform = RandomErasing(
210 | args.reprob,
211 | mode=args.remode,
212 | max_count=args.recount,
213 | num_splits=args.recount,
214 | device="cpu",
215 | )
216 | buffer = buffer.permute(1, 0, 2, 3)
217 | buffer = erase_transform(buffer)
218 | buffer = buffer.permute(1, 0, 2, 3)
219 |
220 | return buffer
221 |
222 |
223 | def loadvideo_decord(self, sample, sample_rate_scale=1):
224 | """Load video content using Decord"""
225 | fname = sample
226 |
227 | if not (os.path.exists(fname)):
228 | return []
229 |
230 | # avoid hanging issue
231 | if os.path.getsize(fname) < 1 * 1024:
232 | print('SKIP: ', fname, " - ", os.path.getsize(fname))
233 | return []
234 | try:
235 | if self.keep_aspect_ratio:
236 | vr = VideoReader(fname, num_threads=1, ctx=cpu(0))
237 | else:
238 | vr = VideoReader(fname, width=self.new_width, height=self.new_height,
239 | num_threads=1, ctx=cpu(0))
240 | except Exception:
241 | print("video cannot be loaded by decord: ", fname)
242 | return []
243 |
244 | if self.mode == 'test':
245 | all_index = []
246 | tick = len(vr) / float(self.num_segment)
247 | all_index = list(np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segment)] +
248 | [int(tick * x) for x in range(self.num_segment)]))
249 | while len(all_index) < (self.num_segment * self.test_num_segment):
250 | all_index.append(all_index[-1])
251 | all_index = list(np.sort(np.array(all_index)))
252 | vr.seek(0)
253 | buffer = vr.get_batch(all_index).asnumpy()
254 | return buffer
255 |
256 | # handle temporal segments
257 | average_duration = len(vr) // self.num_segment
258 | all_index = []
259 | if average_duration > 0:
260 | all_index += list(np.multiply(list(range(self.num_segment)), average_duration) + np.random.randint(average_duration,
261 | size=self.num_segment))
262 | elif len(vr) > self.num_segment:
263 | all_index += list(np.sort(np.random.randint(len(vr), size=self.num_segment)))
264 | else:
265 | all_index += list(np.zeros((self.num_segment,)))
266 | all_index = list(np.array(all_index))
267 | vr.seek(0)
268 | buffer = vr.get_batch(all_index).asnumpy()
269 | return buffer
270 |
271 | def __len__(self):
272 | if self.mode != 'test':
273 | return len(self.dataset_samples)
274 | else:
275 | return len(self.test_dataset)
276 |
277 |
278 | def spatial_sampling(
279 | frames,
280 | spatial_idx=-1,
281 | min_scale=256,
282 | max_scale=320,
283 | crop_size=224,
284 | random_horizontal_flip=True,
285 | inverse_uniform_sampling=False,
286 | aspect_ratio=None,
287 | scale=None,
288 | motion_shift=False,
289 | ):
290 | """
291 | Perform spatial sampling on the given video frames. If spatial_idx is
292 | -1, perform random scale, random crop, and random flip on the given
293 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
294 | with the given spatial_idx.
295 | Args:
296 | frames (tensor): frames of images sampled from the video. The
297 | dimension is `num frames` x `height` x `width` x `channel`.
298 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
299 | or 2, perform left, center, right crop if width is larger than
300 | height, and perform top, center, bottom crop if height is larger
301 | than width.
302 | min_scale (int): the minimal size of scaling.
303 | max_scale (int): the maximal size of scaling.
304 | crop_size (int): the size of height and width used to crop the
305 | frames.
306 | inverse_uniform_sampling (bool): if True, sample uniformly in
307 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
308 | scale. If False, take a uniform sample from [min_scale,
309 | max_scale].
310 | aspect_ratio (list): Aspect ratio range for resizing.
311 | scale (list): Scale range for resizing.
312 | motion_shift (bool): Whether to apply motion shift for resizing.
313 | Returns:
314 | frames (tensor): spatially sampled frames.
315 | """
316 | assert spatial_idx in [-1, 0, 1, 2]
317 | if spatial_idx == -1:
318 | if aspect_ratio is None and scale is None:
319 | frames, _ = video_transforms.random_short_side_scale_jitter(
320 | images=frames,
321 | min_size=min_scale,
322 | max_size=max_scale,
323 | inverse_uniform_sampling=inverse_uniform_sampling,
324 | )
325 | frames, _ = video_transforms.random_crop(frames, crop_size)
326 | else:
327 | transform_func = (
328 | video_transforms.random_resized_crop_with_shift
329 | if motion_shift
330 | else video_transforms.random_resized_crop
331 | )
332 | frames = transform_func(
333 | images=frames,
334 | target_height=crop_size,
335 | target_width=crop_size,
336 | scale=scale,
337 | ratio=aspect_ratio,
338 | )
339 | if random_horizontal_flip:
340 | frames, _ = video_transforms.horizontal_flip(0.5, frames)
341 | else:
342 | # The testing is deterministic and no jitter should be performed.
343 | # min_scale, max_scale, and crop_size are expect to be the same.
344 | assert len({min_scale, max_scale, crop_size}) == 1
345 | frames, _ = video_transforms.random_short_side_scale_jitter(
346 | frames, min_scale, max_scale
347 | )
348 | frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx)
349 | return frames
350 |
351 |
352 | def tensor_normalize(tensor, mean, std):
353 | """
354 | Normalize a given tensor by subtracting the mean and dividing the std.
355 | Args:
356 | tensor (tensor): tensor to normalize.
357 | mean (tensor or list): mean value to subtract.
358 | std (tensor or list): std to divide.
359 | """
360 | if tensor.dtype == torch.uint8:
361 | tensor = tensor.float()
362 | tensor = tensor / 255.0
363 | if type(mean) == list:
364 | mean = torch.tensor(mean)
365 | if type(std) == list:
366 | std = torch.tensor(std)
367 | tensor = tensor - mean
368 | tensor = tensor / std
369 | return tensor
--------------------------------------------------------------------------------
/engine_for_compomodel.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import math
4 | import sys
5 | from typing import Iterable, Optional
6 | import torch
7 | from util_tools.mixup import Mixup
8 | from timm.utils import accuracy, ModelEma
9 | import util_tools.utils as utils
10 | from scipy.special import softmax
11 | from einops import rearrange
12 |
13 |
14 | def composition_train_class_batch(model, samples, target_noun, target_verb, criterion):
15 | outputs_noun, outputs_verb = model(samples)
16 | loss_noun = criterion(outputs_noun, target_noun)
17 | loss_verb = criterion(outputs_verb, target_verb)
18 | total_loss = loss_noun + loss_verb
19 | return total_loss, loss_noun, loss_verb, outputs_noun, outputs_verb
20 |
21 |
22 |
23 | def get_loss_scale_for_deepspeed(model):
24 | optimizer = model.optimizer
25 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale
26 |
27 |
28 | def train_one_epoch(args, model: torch.nn.Module, criterion: torch.nn.Module,
29 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
30 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
31 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
32 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
33 | num_training_steps_per_epoch=None, update_freq=None):
34 | model.train(True)
35 | metric_logger = utils.MetricLogger(delimiter=" ")
36 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
37 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
38 | header = 'Epoch: [{}]'.format(epoch)
39 | print_freq = 10
40 |
41 | if loss_scaler is None:
42 | model.zero_grad()
43 | model.micro_steps = 0
44 | else:
45 | optimizer.zero_grad()
46 |
47 | for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
48 | step = data_iter_step // update_freq
49 | if step >= num_training_steps_per_epoch:
50 | continue
51 | it = start_steps + step # global training iteration
52 | # Update LR & WD only on the first step of each gradient-accumulation cycle
53 | if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
54 | for i, param_group in enumerate(optimizer.param_groups):
55 | if lr_schedule_values is not None:
56 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
57 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
58 | param_group["weight_decay"] = wd_schedule_values[it]
59 |
60 | samples = samples.to(device, non_blocking=True)
61 | targets = targets.to(device, non_blocking=True)
62 |
63 | if mixup_fn is not None:
64 | samples, target_noun, target_verb = mixup_fn(samples, targets)
65 |
66 | if loss_scaler is None:
67 | samples = samples.half()
68 | loss, loss_noun, loss_verb, outputs_noun, outputs_verb = composition_train_class_batch(
69 | model, samples, target_noun, target_verb, criterion)
70 | else:
71 | with torch.cuda.amp.autocast():
72 | samples = samples.half()
73 | loss, loss_noun, loss_verb, outputs_noun, outputs_verb = composition_train_class_batch(
74 | model, samples, target_noun, target_verb, criterion)
75 | loss_value = loss.item()
76 |
77 | if not math.isfinite(loss_value):
78 | print("Loss is {}, stopping training".format(loss_value))
79 | sys.exit(1)
80 |
81 | if loss_scaler is None:
82 | loss /= update_freq
83 | model.backward(loss)
84 | model.step()
85 |
86 | if (data_iter_step + 1) % update_freq == 0:
87 | # model.zero_grad()
88 | # Deepspeed will call step() & model.zero_grad() automatic
89 | if model_ema is not None:
90 | model_ema.update(model)
91 | grad_norm = None
92 | loss_scale_value = get_loss_scale_for_deepspeed(model)
93 | else:
94 | # this attribute is added by timm on one optimizer (adahessian)
95 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
96 | loss /= update_freq
97 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
98 | parameters=model.parameters(), create_graph=is_second_order,
99 | update_grad=(data_iter_step + 1) % update_freq == 0)
100 | if (data_iter_step + 1) % update_freq == 0:
101 | optimizer.zero_grad()
102 | if model_ema is not None:
103 | model_ema.update(model)
104 | loss_scale_value = loss_scaler.state_dict()["scale"]
105 |
106 | torch.cuda.synchronize()
107 |
108 | if mixup_fn is None:
109 | class_acc = None  # not computed for the compositional (noun/verb) heads
110 | # class_acc = (output.max(-1)[-1] == targets).float().mean()
111 | else:
112 | class_acc = None
113 | metric_logger.update(loss=loss_value)
114 | metric_logger.update(loss_noun=loss_noun)
115 | metric_logger.update(loss_verb=loss_verb)
116 | metric_logger.update(class_acc=class_acc)
117 | metric_logger.update(loss_scale=loss_scale_value)
118 | min_lr = 10.
119 | max_lr = 0.
120 | for group in optimizer.param_groups:
121 | min_lr = min(min_lr, group["lr"])
122 | max_lr = max(max_lr, group["lr"])
123 |
124 | metric_logger.update(lr=max_lr)
125 | metric_logger.update(min_lr=min_lr)
126 | weight_decay_value = None
127 | for group in optimizer.param_groups:
128 | if group["weight_decay"] > 0:
129 | weight_decay_value = group["weight_decay"]
130 | metric_logger.update(weight_decay=weight_decay_value)
131 | metric_logger.update(grad_norm=grad_norm)
132 |
133 | if log_writer is not None:
134 | log_writer.update(loss=loss_value, head="loss")
135 | log_writer.update(class_acc=class_acc, head="loss")
136 | log_writer.update(loss_scale=loss_scale_value, head="opt")
137 | log_writer.update(lr=max_lr, head="opt")
138 | log_writer.update(min_lr=min_lr, head="opt")
139 | log_writer.update(weight_decay=weight_decay_value, head="opt")
140 | log_writer.update(grad_norm=grad_norm, head="opt")
141 |
142 | log_writer.set_step()
143 |
144 | # gather the stats from all processes
145 | metric_logger.synchronize_between_processes()
146 | print("Averaged stats:", metric_logger)
147 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
148 |
149 |
150 | @torch.no_grad()
151 | def validation_one_epoch(args, data_loader, model, device):
152 | criterion = torch.nn.CrossEntropyLoss()
153 |
154 | metric_logger = utils.MetricLogger(delimiter=" ")
155 | header = 'Val:'
156 |
157 | # switch to evaluation mode
158 | model.eval()
159 |
160 | for batch in metric_logger.log_every(data_loader, 10, header):
161 | samples = batch[0]
162 | target = batch[1]
163 | batch_size = samples.shape[0]
164 | samples = samples.to(device, non_blocking=True)
165 | target = target.to(device, non_blocking=True)
166 | action_target = (target[:,1] * 1000) + target[:,0]
167 |
168 | # compute output
169 | with torch.cuda.amp.autocast():
170 | output_noun, output_verb = model(samples)
171 | loss_noun = criterion(output_noun, target[:,0])
172 | loss_verb = criterion(output_verb, target[:,1])
173 |
174 | acc1_action, acc5_action = action_accuracy(output_noun, output_verb, action_target, topk=(1,5))
175 | acc1_noun, acc5_noun = accuracy(output_noun, target[:,0], topk=(1, 5))
176 | acc1_verb, acc5_verb = accuracy(output_verb, target[:,1], topk=(1, 5))
177 |
178 | metric_logger.update(loss_noun=loss_noun.item())
179 | metric_logger.update(loss_verb=loss_verb.item())
180 | metric_logger.update(acc1_action=acc1_action.item())
181 | metric_logger.update(acc1_noun=acc1_noun.item())
182 | metric_logger.update(acc1_verb=acc1_verb.item())
183 | metric_logger.update(acc5_noun=acc5_noun.item())
184 | metric_logger.update(acc5_verb=acc5_verb.item())
185 | metric_logger.meters['acc1_noun'].update(acc1_noun.item(), n=batch_size)
186 | metric_logger.meters['acc1_verb'].update(acc1_verb.item(), n=batch_size)
187 | metric_logger.meters['acc5_noun'].update(acc5_noun.item(), n=batch_size)
188 | metric_logger.meters['acc5_verb'].update(acc5_verb.item(), n=batch_size)
189 | # gather the stats from all processes
190 | metric_logger.synchronize_between_processes()
191 | print('* Acc_@1_action {top1_action.global_avg:.3f} Acc_@1_noun {top1_noun.global_avg:.3f} Acc_@1_verb {top1_verb.global_avg:.3f} Acc@5_noun {top5_noun.global_avg:.3f} Acc@5_verb {top5_verb.global_avg:.3f} loss_noun {losses_noun.global_avg:.3f} loss_verb {losses_verb.global_avg:.3f}'
192 | .format(top1_action=metric_logger.acc1_action, top1_noun=metric_logger.acc1_noun, top1_verb=metric_logger.acc1_verb, top5_noun=metric_logger.acc5_noun, top5_verb=metric_logger.acc5_verb, losses_noun=metric_logger.loss_noun, losses_verb=metric_logger.loss_verb))
193 |
194 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
195 |
196 |
197 |
198 | @torch.no_grad()
199 | def final_test(args, data_loader, model, device, file):
200 | criterion = torch.nn.CrossEntropyLoss()
201 |
202 | metric_logger = utils.MetricLogger(delimiter=" ")
203 | header = 'Test:'
204 |
205 | # switch to evaluation mode
206 | model.eval()
207 | final_result = []
208 |
209 | for batch in metric_logger.log_every(data_loader, 10, header):
210 | samples = batch[0]
211 | target = batch[1]
212 | ids = batch[2]
213 | chunk_nb = batch[3]
214 | split_nb = batch[4]
215 | batch_size = samples.shape[0]
216 | samples = samples.to(device, non_blocking=True)
217 | target = target.to(device, non_blocking=True)
218 | action_target = (target[:,1] * 1000) + target[:,0]
219 |
220 | # compute output
221 | with torch.cuda.amp.autocast():
222 | output_noun, output_verb = model(samples)
223 | loss_noun = criterion(output_noun, target[:,0])
224 | loss_verb = criterion(output_verb, target[:,1])
225 |
226 | for i in range(output_noun.size(0)):
227 | string = "{} {} {} {} {} {} {} {}\n".format(ids[i], \
228 | str(output_noun.data[i].cpu().numpy().tolist()), \
229 | str(output_verb.data[i].cpu().numpy().tolist()), \
230 | str(int(action_target[i].cpu().numpy())), \
231 | str(int(target[i,0].cpu().numpy())), \
232 | str(int(target[i,1].cpu().numpy())), \
233 | str(int(chunk_nb[i].cpu().numpy())), \
234 | str(int(split_nb[i].cpu().numpy())))
235 | final_result.append(string)
236 |
237 | acc1_action, acc5_action = action_accuracy(output_noun, output_verb, action_target, topk=(1,5))
238 | acc1_noun, acc5_noun = accuracy(output_noun, target[:,0], topk=(1, 5))
239 | acc1_verb, acc5_verb = accuracy(output_verb, target[:,1], topk=(1, 5))
240 |
241 | metric_logger.update(loss_noun=loss_noun.item())
242 | metric_logger.update(loss_verb=loss_verb.item())
243 | metric_logger.update(acc1_action=acc1_action.item())
244 | metric_logger.update(acc1_noun=acc1_noun.item())
245 | metric_logger.update(acc1_verb=acc1_verb.item())
246 | metric_logger.update(acc5_noun=acc5_noun.item())
247 | metric_logger.update(acc5_verb=acc5_verb.item())
248 | metric_logger.meters['acc1_action'].update(acc1_action.item(), n=batch_size)
249 | metric_logger.meters['acc1_noun'].update(acc1_noun.item(), n=batch_size)
250 | metric_logger.meters['acc1_verb'].update(acc1_verb.item(), n=batch_size)
251 | metric_logger.meters['acc5_noun'].update(acc5_noun.item(), n=batch_size)
252 | metric_logger.meters['acc5_verb'].update(acc5_verb.item(), n=batch_size)
253 |
254 | if not os.path.exists(file):
255 | os.mknod(file)
256 | with open(file, 'w') as f:
257 | f.write("{}, {}\n".format(acc1_noun, acc5_noun))
258 | for line in final_result:
259 | f.write(line)
260 | # gather the stats from all processes
261 | metric_logger.synchronize_between_processes()
262 | print('* Acc_@1_action {top1_action.global_avg:.3f} Acc_@1_noun {top1_noun.global_avg:.3f} Acc_@1_verb {top1_verb.global_avg:.3f} Acc@5_noun {top5_noun.global_avg:.3f} Acc@5_verb {top5_verb.global_avg:.3f} loss_noun {losses_noun.global_avg:.3f} loss_verb {losses_verb.global_avg:.3f}'
263 | .format(top1_action=metric_logger.acc1_action, top1_noun=metric_logger.acc1_noun, top1_verb=metric_logger.acc1_verb, top5_noun=metric_logger.acc5_noun, top5_verb=metric_logger.acc5_verb, losses_noun=metric_logger.loss_noun, losses_verb=metric_logger.loss_verb))
264 |
265 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
266 |
267 |
268 | def merge(eval_path, num_tasks):
269 | dict_feats_noun = {}
270 | dict_feats_verb = {}
271 | dict_label = {}
272 | dict_action_label = {}
273 | dict_pos = {}
274 | print("Reading individual output files")
275 |
276 | for x in range(num_tasks):
277 | file = os.path.join(eval_path, str(x) + '.txt')
278 | lines = open(file, 'r').readlines()[1:]
279 | for line in lines:
280 | line = line.strip()
281 | name = line.split('[')[0]
282 | label_action = line.split(']')[2].split(' ')[1]
283 | label_noun = line.split(']')[2].split(' ')[2]
284 | label_verb = line.split(']')[2].split(' ')[3]
285 | chunk_nb = line.split(']')[2].split(' ')[4]
286 | split_nb = line.split(']')[2].split(' ')[5]
287 |             data_noun = np.fromstring(line.split('[')[1].split(']')[0], dtype=float, sep=',')
288 |             data_verb = np.fromstring(line.split('[')[2].split(']')[0], dtype=float, sep=',')
289 | data_noun = softmax(data_noun)
290 | data_verb = softmax(data_verb)
291 |
292 | if not name in dict_feats_noun:
293 | dict_feats_noun[name] = []
294 | dict_feats_verb[name] = []
295 | dict_label[name] = 0
296 | dict_action_label[name] = 0
297 | dict_pos[name] = []
298 | if chunk_nb + split_nb in dict_pos[name]:
299 | continue
300 | dict_feats_noun[name].append(data_noun)
301 | dict_feats_verb[name].append(data_verb)
302 | dict_pos[name].append(chunk_nb + split_nb)
303 | dict_label[name] = (label_noun, label_verb)
304 | dict_action_label[name] = label_action
305 | print("Computing final results")
306 |
307 | input_lst = []
308 | print(len(dict_feats_noun))
309 | for i, item in enumerate(dict_feats_noun):
310 | input_lst.append([i, item, dict_feats_noun[item], dict_feats_verb[item], dict_label[item], dict_action_label[item]])
311 | from multiprocessing import Pool
312 | p = Pool(8)
313 | ans = p.map(compute_video, input_lst)
314 | top1_action = [x[2] for x in ans]
315 | top5_action = [x[3] for x in ans]
316 | top1_noun = [x[4] for x in ans]
317 | top1_verb = [x[5] for x in ans]
318 | top5_noun = [x[6] for x in ans]
319 | top5_verb = [x[7] for x in ans]
320 | final_top1_noun ,final_top5_noun, final_top1_verb, final_top5_verb = np.mean(top1_noun), np.mean(top5_noun), np.mean(top1_verb), np.mean(top5_verb)
321 | final_top1_action, final_top5_action = np.mean(top1_action), np.mean(top5_action)
322 | return final_top1_action*100, final_top5_action*100, final_top1_noun*100 ,final_top5_noun*100, final_top1_verb*100, final_top5_verb*100
323 |
324 | def compute_video(lst):
325 | i, video_id, data_noun, data_verb, label, label_action = lst
326 | feat_noun = [x for x in data_noun]
327 | feat_verb = [x for x in data_verb]
328 | feat_noun = np.mean(feat_noun, axis=0)
329 | feat_verb = np.mean(feat_verb, axis=0)
330 | pred_noun = np.argmax(feat_noun)
331 | pred_verb = np.argmax(feat_verb)
332 | pred_action = (pred_verb * 1000) + pred_noun
333 | label_noun, label_verb = label
334 | top1_action = (int(pred_action) == int(label_action)) * 1.0
335 | top5_action = (int(label_noun) in np.argsort(-feat_noun)[:5] and int(label_verb) in np.argsort(-feat_verb)[:5]) * 1.0
336 | top1_noun = (int(pred_noun) == int(label_noun)) * 1.0
337 | top5_noun = (int(label_noun) in np.argsort(-feat_noun)[:5]) * 1.0
338 | top1_verb = (int(pred_verb) == int(label_verb)) * 1.0
339 | top5_verb = (int(label_verb) in np.argsort(-feat_verb)[:5]) * 1.0
340 | return [pred_noun, pred_verb, top1_action, top5_action, top1_noun, top1_verb, top5_noun, top5_verb]
341 |
342 | def action_accuracy(output_noun, output_verb, target, topk=(1,)):
343 | """Computes the accuracy over the k top predictions for the specified values of k"""
344 | maxk = max(topk)
345 | batch_size = target.size(0)
346 | _, pred_noun = output_noun.topk(maxk, 1, True, True)
347 | _, pred_verb = output_verb.topk(maxk, 1, True, True)
348 | pred = (pred_verb * 1000) + pred_noun
349 | pred = pred.t()
350 | correct = pred.eq(target.reshape(1, -1).expand_as(pred))
351 | return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
352 |
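353 | # A minimal, self-contained sketch of the (verb * 1000 + noun) action-id encoding used by
354 | # `action_target` and `action_accuracy` above. The logits, labels, and class counts below are
355 | # toy assumptions chosen only to illustrate the expected shapes; this is not part of the
356 | # evaluation pipeline itself.
357 | if __name__ == '__main__':
358 |     toy_noun_logits = torch.tensor([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]])  # (B=2, 3 noun classes)
359 |     toy_verb_logits = torch.tensor([[0.4, 0.1, 3.0], [0.2, 2.5, 0.3]])  # (B=2, 3 verb classes)
360 |     toy_target = torch.tensor([[1, 2], [0, 1]])                         # columns: (noun, verb)
361 |     toy_action = (toy_target[:, 1] * 1000) + toy_target[:, 0]           # tensor([2001, 1000])
362 |     top1 = action_accuracy(toy_noun_logits, toy_verb_logits, toy_action, topk=(1,))[0]
363 |     print(top1)  # tensor(100.) -- both toy predictions match their composed action ids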
--------------------------------------------------------------------------------
/util_tools/rand_augment.py:
--------------------------------------------------------------------------------
1 | """
2 | This implementation is based on
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py
4 | published under an Apache License 2.0.
5 |
6 | COMMENT FROM ORIGINAL:
7 | AutoAugment, RandAugment, and AugMix for PyTorch
8 | This code implements the searched ImageNet policies with various tweaks and
9 | improvements and does not include any of the search code. AA and RA
10 | Implementation adapted from:
11 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
12 | AugMix adapted from:
13 | https://github.com/google-research/augmix
14 | Papers:
15 | AutoAugment: Learning Augmentation Policies from Data
16 | https://arxiv.org/abs/1805.09501
17 | Learning Data Augmentation Strategies for Object Detection
18 | https://arxiv.org/abs/1906.11172
19 | RandAugment: Practical automated data augmentation...
20 | https://arxiv.org/abs/1909.13719
21 | AugMix: A Simple Data Processing Method to Improve Robustness and
22 | Uncertainty https://arxiv.org/abs/1912.02781
23 |
24 | Hacked together by / Copyright 2020 Ross Wightman
25 | """
26 |
27 | import math
28 | import numpy as np
29 | import random
30 | import re
31 | import PIL
32 | from PIL import Image, ImageEnhance, ImageOps
33 |
34 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]])
35 |
36 | _FILL = (128, 128, 128)
37 |
38 | # This signifies the max integer that the controller RNN could predict for the
39 | # augmentation scheme.
40 | _MAX_LEVEL = 10.0
41 |
42 | _HPARAMS_DEFAULT = {
43 | "translate_const": 250,
44 | "img_mean": _FILL,
45 | }
46 |
47 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
48 |
49 |
50 | def _interpolation(kwargs):
51 | interpolation = kwargs.pop("resample", Image.BILINEAR)
52 | if isinstance(interpolation, (list, tuple)):
53 | return random.choice(interpolation)
54 | else:
55 | return interpolation
56 |
57 |
58 | def _check_args_tf(kwargs):
59 | if "fillcolor" in kwargs and _PIL_VER < (5, 0):
60 | kwargs.pop("fillcolor")
61 | kwargs["resample"] = _interpolation(kwargs)
62 |
63 |
64 | def shear_x(img, factor, **kwargs):
65 | _check_args_tf(kwargs)
66 | return img.transform(
67 | img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs
68 | )
69 |
70 |
71 | def shear_y(img, factor, **kwargs):
72 | _check_args_tf(kwargs)
73 | return img.transform(
74 | img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs
75 | )
76 |
77 |
78 | def translate_x_rel(img, pct, **kwargs):
79 | pixels = pct * img.size[0]
80 | _check_args_tf(kwargs)
81 | return img.transform(
82 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
83 | )
84 |
85 |
86 | def translate_y_rel(img, pct, **kwargs):
87 | pixels = pct * img.size[1]
88 | _check_args_tf(kwargs)
89 | return img.transform(
90 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
91 | )
92 |
93 |
94 | def translate_x_abs(img, pixels, **kwargs):
95 | _check_args_tf(kwargs)
96 | return img.transform(
97 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
98 | )
99 |
100 |
101 | def translate_y_abs(img, pixels, **kwargs):
102 | _check_args_tf(kwargs)
103 | return img.transform(
104 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
105 | )
106 |
107 |
108 | def rotate(img, degrees, **kwargs):
109 | _check_args_tf(kwargs)
110 | if _PIL_VER >= (5, 2):
111 | return img.rotate(degrees, **kwargs)
112 | elif _PIL_VER >= (5, 0):
113 | w, h = img.size
114 | post_trans = (0, 0)
115 | rotn_center = (w / 2.0, h / 2.0)
116 | angle = -math.radians(degrees)
117 | matrix = [
118 | round(math.cos(angle), 15),
119 | round(math.sin(angle), 15),
120 | 0.0,
121 | round(-math.sin(angle), 15),
122 | round(math.cos(angle), 15),
123 | 0.0,
124 | ]
125 |
126 | def transform(x, y, matrix):
127 | (a, b, c, d, e, f) = matrix
128 | return a * x + b * y + c, d * x + e * y + f
129 |
130 | matrix[2], matrix[5] = transform(
131 | -rotn_center[0] - post_trans[0],
132 | -rotn_center[1] - post_trans[1],
133 | matrix,
134 | )
135 | matrix[2] += rotn_center[0]
136 | matrix[5] += rotn_center[1]
137 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
138 | else:
139 | return img.rotate(degrees, resample=kwargs["resample"])
140 |
141 |
142 | def auto_contrast(img, **__):
143 | return ImageOps.autocontrast(img)
144 |
145 |
146 | def invert(img, **__):
147 | return ImageOps.invert(img)
148 |
149 |
150 | def equalize(img, **__):
151 | return ImageOps.equalize(img)
152 |
153 |
154 | def solarize(img, thresh, **__):
155 | return ImageOps.solarize(img, thresh)
156 |
157 |
158 | def solarize_add(img, add, thresh=128, **__):
159 | lut = []
160 | for i in range(256):
161 | if i < thresh:
162 | lut.append(min(255, i + add))
163 | else:
164 | lut.append(i)
165 | if img.mode in ("L", "RGB"):
166 | if img.mode == "RGB" and len(lut) == 256:
167 | lut = lut + lut + lut
168 | return img.point(lut)
169 | else:
170 | return img
171 |
172 |
173 | def posterize(img, bits_to_keep, **__):
174 | if bits_to_keep >= 8:
175 | return img
176 | return ImageOps.posterize(img, bits_to_keep)
177 |
178 |
179 | def contrast(img, factor, **__):
180 | return ImageEnhance.Contrast(img).enhance(factor)
181 |
182 |
183 | def color(img, factor, **__):
184 | return ImageEnhance.Color(img).enhance(factor)
185 |
186 |
187 | def brightness(img, factor, **__):
188 | return ImageEnhance.Brightness(img).enhance(factor)
189 |
190 |
191 | def sharpness(img, factor, **__):
192 | return ImageEnhance.Sharpness(img).enhance(factor)
193 |
194 |
195 | def _randomly_negate(v):
196 | """With 50% prob, negate the value"""
197 | return -v if random.random() > 0.5 else v
198 |
199 |
200 | def _rotate_level_to_arg(level, _hparams):
201 | # range [-30, 30]
202 | level = (level / _MAX_LEVEL) * 30.0
203 | level = _randomly_negate(level)
204 | return (level,)
205 |
206 |
207 | def _enhance_level_to_arg(level, _hparams):
208 | # range [0.1, 1.9]
209 | return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
210 |
211 |
212 | def _enhance_increasing_level_to_arg(level, _hparams):
213 | # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
214 | # range [0.1, 1.9]
215 | level = (level / _MAX_LEVEL) * 0.9
216 | level = 1.0 + _randomly_negate(level)
217 | return (level,)
218 |
219 |
220 | def _shear_level_to_arg(level, _hparams):
221 | # range [-0.3, 0.3]
222 | level = (level / _MAX_LEVEL) * 0.3
223 | level = _randomly_negate(level)
224 | return (level,)
225 |
226 |
227 | def _translate_abs_level_to_arg(level, hparams):
228 | translate_const = hparams["translate_const"]
229 | level = (level / _MAX_LEVEL) * float(translate_const)
230 | level = _randomly_negate(level)
231 | return (level,)
232 |
233 |
234 | def _translate_rel_level_to_arg(level, hparams):
235 | # default range [-0.45, 0.45]
236 | translate_pct = hparams.get("translate_pct", 0.45)
237 | level = (level / _MAX_LEVEL) * translate_pct
238 | level = _randomly_negate(level)
239 | return (level,)
240 |
241 |
242 | def _posterize_level_to_arg(level, _hparams):
243 | # As per Tensorflow TPU EfficientNet impl
244 | # range [0, 4], 'keep 0 up to 4 MSB of original image'
245 | # intensity/severity of augmentation decreases with level
246 | return (int((level / _MAX_LEVEL) * 4),)
247 |
248 |
249 | def _posterize_increasing_level_to_arg(level, hparams):
250 | # As per Tensorflow models research and UDA impl
251 | # range [4, 0], 'keep 4 down to 0 MSB of original image',
252 | # intensity/severity of augmentation increases with level
253 | return (4 - _posterize_level_to_arg(level, hparams)[0],)
254 |
255 |
256 | def _posterize_original_level_to_arg(level, _hparams):
257 | # As per original AutoAugment paper description
258 | # range [4, 8], 'keep 4 up to 8 MSB of image'
259 | # intensity/severity of augmentation decreases with level
260 | return (int((level / _MAX_LEVEL) * 4) + 4,)
261 |
262 |
263 | def _solarize_level_to_arg(level, _hparams):
264 | # range [0, 256]
265 | # intensity/severity of augmentation decreases with level
266 | return (int((level / _MAX_LEVEL) * 256),)
267 |
268 |
269 | def _solarize_increasing_level_to_arg(level, _hparams):
270 | # range [0, 256]
271 | # intensity/severity of augmentation increases with level
272 | return (256 - _solarize_level_to_arg(level, _hparams)[0],)
273 |
274 |
275 | def _solarize_add_level_to_arg(level, _hparams):
276 | # range [0, 110]
277 | return (int((level / _MAX_LEVEL) * 110),)
278 |
279 |
280 | LEVEL_TO_ARG = {
281 | "AutoContrast": None,
282 | "Equalize": None,
283 | "Invert": None,
284 | "Rotate": _rotate_level_to_arg,
285 | # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
286 | "Posterize": _posterize_level_to_arg,
287 | "PosterizeIncreasing": _posterize_increasing_level_to_arg,
288 | "PosterizeOriginal": _posterize_original_level_to_arg,
289 | "Solarize": _solarize_level_to_arg,
290 | "SolarizeIncreasing": _solarize_increasing_level_to_arg,
291 | "SolarizeAdd": _solarize_add_level_to_arg,
292 | "Color": _enhance_level_to_arg,
293 | "ColorIncreasing": _enhance_increasing_level_to_arg,
294 | "Contrast": _enhance_level_to_arg,
295 | "ContrastIncreasing": _enhance_increasing_level_to_arg,
296 | "Brightness": _enhance_level_to_arg,
297 | "BrightnessIncreasing": _enhance_increasing_level_to_arg,
298 | "Sharpness": _enhance_level_to_arg,
299 | "SharpnessIncreasing": _enhance_increasing_level_to_arg,
300 | "ShearX": _shear_level_to_arg,
301 | "ShearY": _shear_level_to_arg,
302 | "TranslateX": _translate_abs_level_to_arg,
303 | "TranslateY": _translate_abs_level_to_arg,
304 | "TranslateXRel": _translate_rel_level_to_arg,
305 | "TranslateYRel": _translate_rel_level_to_arg,
306 | }
307 |
308 |
309 | NAME_TO_OP = {
310 | "AutoContrast": auto_contrast,
311 | "Equalize": equalize,
312 | "Invert": invert,
313 | "Rotate": rotate,
314 | "Posterize": posterize,
315 | "PosterizeIncreasing": posterize,
316 | "PosterizeOriginal": posterize,
317 | "Solarize": solarize,
318 | "SolarizeIncreasing": solarize,
319 | "SolarizeAdd": solarize_add,
320 | "Color": color,
321 | "ColorIncreasing": color,
322 | "Contrast": contrast,
323 | "ContrastIncreasing": contrast,
324 | "Brightness": brightness,
325 | "BrightnessIncreasing": brightness,
326 | "Sharpness": sharpness,
327 | "SharpnessIncreasing": sharpness,
328 | "ShearX": shear_x,
329 | "ShearY": shear_y,
330 | "TranslateX": translate_x_abs,
331 | "TranslateY": translate_y_abs,
332 | "TranslateXRel": translate_x_rel,
333 | "TranslateYRel": translate_y_rel,
334 | }
335 |
336 |
337 | class AugmentOp:
338 | """
339 |     Apply a single augmentation op to a PIL Image or to a list of video frames (PIL Images).
340 | """
341 |
342 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
343 | hparams = hparams or _HPARAMS_DEFAULT
344 | self.aug_fn = NAME_TO_OP[name]
345 | self.level_fn = LEVEL_TO_ARG[name]
346 | self.prob = prob
347 | self.magnitude = magnitude
348 | self.hparams = hparams.copy()
349 | self.kwargs = {
350 | "fillcolor": hparams["img_mean"]
351 | if "img_mean" in hparams
352 | else _FILL,
353 | "resample": hparams["interpolation"]
354 | if "interpolation" in hparams
355 | else _RANDOM_INTERPOLATION,
356 | }
357 |
358 | # If magnitude_std is > 0, we introduce some randomness
359 | # in the usually fixed policy and sample magnitude from a normal distribution
360 | # with mean `magnitude` and std-dev of `magnitude_std`.
361 | # NOTE This is my own hack, being tested, not in papers or reference impls.
362 | self.magnitude_std = self.hparams.get("magnitude_std", 0)
363 |
364 | def __call__(self, img_list):
365 | if self.prob < 1.0 and random.random() > self.prob:
366 | return img_list
367 | magnitude = self.magnitude
368 | if self.magnitude_std and self.magnitude_std > 0:
369 | magnitude = random.gauss(magnitude, self.magnitude_std)
370 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
371 | level_args = (
372 | self.level_fn(magnitude, self.hparams)
373 | if self.level_fn is not None
374 | else ()
375 | )
376 |
377 | if isinstance(img_list, list):
378 | return [
379 | self.aug_fn(img, *level_args, **self.kwargs) for img in img_list
380 | ]
381 | else:
382 | return self.aug_fn(img_list, *level_args, **self.kwargs)
383 |
384 |
385 | _RAND_TRANSFORMS = [
386 | "AutoContrast",
387 | "Equalize",
388 | "Invert",
389 | "Rotate",
390 | "Posterize",
391 | "Solarize",
392 | "SolarizeAdd",
393 | "Color",
394 | "Contrast",
395 | "Brightness",
396 | "Sharpness",
397 | "ShearX",
398 | "ShearY",
399 | "TranslateXRel",
400 | "TranslateYRel",
401 | ]
402 |
403 |
404 | _RAND_INCREASING_TRANSFORMS = [
405 | "AutoContrast",
406 | "Equalize",
407 | "Invert",
408 | "Rotate",
409 | "PosterizeIncreasing",
410 | "SolarizeIncreasing",
411 | "SolarizeAdd",
412 | "ColorIncreasing",
413 | "ContrastIncreasing",
414 | "BrightnessIncreasing",
415 | "SharpnessIncreasing",
416 | "ShearX",
417 | "ShearY",
418 | "TranslateXRel",
419 | "TranslateYRel",
420 | ]
421 |
422 |
423 | # These experimental weights are based loosely on the relative improvements mentioned in the paper.
424 | # They may not result in increased performance, but could likely be tuned to do so.
425 | _RAND_CHOICE_WEIGHTS_0 = {
426 | "Rotate": 0.3,
427 | "ShearX": 0.2,
428 | "ShearY": 0.2,
429 | "TranslateXRel": 0.1,
430 | "TranslateYRel": 0.1,
431 | "Color": 0.025,
432 | "Sharpness": 0.025,
433 | "AutoContrast": 0.025,
434 | "Solarize": 0.005,
435 | "SolarizeAdd": 0.005,
436 | "Contrast": 0.005,
437 | "Brightness": 0.005,
438 | "Equalize": 0.005,
439 | "Posterize": 0,
440 | "Invert": 0,
441 | }
442 |
443 |
444 | def _select_rand_weights(weight_idx=0, transforms=None):
445 | transforms = transforms or _RAND_TRANSFORMS
446 | assert weight_idx == 0 # only one set of weights currently
447 | rand_weights = _RAND_CHOICE_WEIGHTS_0
448 | probs = [rand_weights[k] for k in transforms]
449 | probs /= np.sum(probs)
450 | return probs
451 |
452 |
453 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
454 | hparams = hparams or _HPARAMS_DEFAULT
455 | transforms = transforms or _RAND_TRANSFORMS
456 | return [
457 | AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams)
458 | for name in transforms
459 | ]
460 |
461 |
462 | class RandAugment:
463 | def __init__(self, ops, num_layers=2, choice_weights=None):
464 | self.ops = ops
465 | self.num_layers = num_layers
466 | self.choice_weights = choice_weights
467 |
468 | def __call__(self, img):
469 | # no replacement when using weighted choice
470 | ops = np.random.choice(
471 | self.ops,
472 | self.num_layers,
473 | replace=self.choice_weights is None,
474 | p=self.choice_weights,
475 | )
476 | for op in ops:
477 | img = op(img)
478 | return img
479 |
480 |
481 | def rand_augment_transform(config_str, hparams):
482 | """
483 | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
484 |
485 | Create a RandAugment transform
486 | :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
487 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
488 |     sections, which are not order specific, determine
489 | 'm' - integer magnitude of rand augment
490 | 'n' - integer num layers (number of transform ops selected per image)
491 |         'w' - integer probability weight index (index of a set of weights to influence choice of op)
492 | 'mstd' - float std deviation of magnitude noise applied
493 | 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
494 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
495 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
496 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
497 | :return: A PyTorch compatible Transform
498 | """
499 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
500 | num_layers = 2 # default to 2 ops per image
501 | weight_idx = None # default to no probability weights for op choice
502 | transforms = _RAND_TRANSFORMS
503 | config = config_str.split("-")
504 | assert config[0] == "rand"
505 | config = config[1:]
506 | for c in config:
507 | cs = re.split(r"(\d.*)", c)
508 | if len(cs) < 2:
509 | continue
510 | key, val = cs[:2]
511 | if key == "mstd":
512 | # noise param injected via hparams for now
513 | hparams.setdefault("magnitude_std", float(val))
514 | elif key == "inc":
515 |             if int(val):  # parse as an int flag so 'inc0' disables this (bool('0') would be True)
516 | transforms = _RAND_INCREASING_TRANSFORMS
517 | elif key == "m":
518 | magnitude = int(val)
519 | elif key == "n":
520 | num_layers = int(val)
521 | elif key == "w":
522 | weight_idx = int(val)
523 | else:
524 |             raise NotImplementedError("Unknown RandAugment config section: " + key)
525 | ra_ops = rand_augment_ops(
526 | magnitude=magnitude, hparams=hparams, transforms=transforms
527 | )
528 | choice_weights = (
529 | None if weight_idx is None else _select_rand_weights(weight_idx)
530 | )
531 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
532 |
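533 | # A minimal usage sketch of `rand_augment_transform`. The config string, frame count, and
534 | # frame size below are illustrative assumptions, and the hparams simply reuse this module's
535 | # `_HPARAMS_DEFAULT`; nothing here is prescribed by the original training scripts.
536 | if __name__ == "__main__":
537 |     frames = [Image.new("RGB", (224, 224), color=_FILL) for _ in range(4)]  # fake video clip
538 |     hparams = dict(_HPARAMS_DEFAULT)  # copied so rand_augment_transform can setdefault into it
539 |     augment = rand_augment_transform("rand-m9-n2-mstd0.5-inc1", hparams)
540 |     augmented = augment(frames)  # each sampled AugmentOp is applied to every frame in the list
541 |     print(len(augmented), augmented[0].size)  # 4 (224, 224)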
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Attribution-NonCommercial 4.0 International
3 |
4 | =======================================================================
5 |
6 | Creative Commons Corporation ("Creative Commons") is not a law firm and
7 | does not provide legal services or legal advice. Distribution of
8 | Creative Commons public licenses does not create a lawyer-client or
9 | other relationship. Creative Commons makes its licenses and related
10 | information available on an "as-is" basis. Creative Commons gives no
11 | warranties regarding its licenses, any material licensed under their
12 | terms and conditions, or any related information. Creative Commons
13 | disclaims all liability for damages resulting from their use to the
14 | fullest extent possible.
15 |
16 | Using Creative Commons Public Licenses
17 |
18 | Creative Commons public licenses provide a standard set of terms and
19 | conditions that creators and other rights holders may use to share
20 | original works of authorship and other material subject to copyright
21 | and certain other rights specified in the public license below. The
22 | following considerations are for informational purposes only, are not
23 | exhaustive, and do not form part of our licenses.
24 |
25 | Considerations for licensors: Our public licenses are
26 | intended for use by those authorized to give the public
27 | permission to use material in ways otherwise restricted by
28 | copyright and certain other rights. Our licenses are
29 | irrevocable. Licensors should read and understand the terms
30 | and conditions of the license they choose before applying it.
31 | Licensors should also secure all rights necessary before
32 | applying our licenses so that the public can reuse the
33 | material as expected. Licensors should clearly mark any
34 | material not subject to the license. This includes other CC-
35 | licensed material, or material used under an exception or
36 | limitation to copyright. More considerations for licensors:
37 | wiki.creativecommons.org/Considerations_for_licensors
38 |
39 | Considerations for the public: By using one of our public
40 | licenses, a licensor grants the public permission to use the
41 | licensed material under specified terms and conditions. If
42 | the licensor's permission is not necessary for any reason--for
43 | example, because of any applicable exception or limitation to
44 | copyright--then that use is not regulated by the license. Our
45 | licenses grant only permissions under copyright and certain
46 | other rights that a licensor has authority to grant. Use of
47 | the licensed material may still be restricted for other
48 | reasons, including because others have copyright or other
49 | rights in the material. A licensor may make special requests,
50 | such as asking that all changes be marked or described.
51 | Although not required by our licenses, you are encouraged to
52 | respect those requests where reasonable. More_considerations
53 | for the public:
54 | wiki.creativecommons.org/Considerations_for_licensees
55 |
56 | =======================================================================
57 |
58 | Creative Commons Attribution-NonCommercial 4.0 International Public
59 | License
60 |
61 | By exercising the Licensed Rights (defined below), You accept and agree
62 | to be bound by the terms and conditions of this Creative Commons
63 | Attribution-NonCommercial 4.0 International Public License ("Public
64 | License"). To the extent this Public License may be interpreted as a
65 | contract, You are granted the Licensed Rights in consideration of Your
66 | acceptance of these terms and conditions, and the Licensor grants You
67 | such rights in consideration of benefits the Licensor receives from
68 | making the Licensed Material available under these terms and
69 | conditions.
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. Copyright and Similar Rights means copyright and/or similar rights
88 | closely related to copyright including, without limitation,
89 | performance, broadcast, sound recording, and Sui Generis Database
90 | Rights, without regard to how the rights are labeled or
91 | categorized. For purposes of this Public License, the rights
92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
93 | Rights.
94 | d. Effective Technological Measures means those measures that, in the
95 | absence of proper authority, may not be circumvented under laws
96 | fulfilling obligations under Article 11 of the WIPO Copyright
97 | Treaty adopted on December 20, 1996, and/or similar international
98 | agreements.
99 |
100 | e. Exceptions and Limitations means fair use, fair dealing, and/or
101 | any other exception or limitation to Copyright and Similar Rights
102 | that applies to Your use of the Licensed Material.
103 |
104 | f. Licensed Material means the artistic or literary work, database,
105 | or other material to which the Licensor applied this Public
106 | License.
107 |
108 | g. Licensed Rights means the rights granted to You subject to the
109 | terms and conditions of this Public License, which are limited to
110 | all Copyright and Similar Rights that apply to Your use of the
111 | Licensed Material and that the Licensor has authority to license.
112 |
113 | h. Licensor means the individual(s) or entity(ies) granting rights
114 | under this Public License.
115 |
116 | i. NonCommercial means not primarily intended for or directed towards
117 | commercial advantage or monetary compensation. For purposes of
118 | this Public License, the exchange of the Licensed Material for
119 | other material subject to Copyright and Similar Rights by digital
120 | file-sharing or similar means is NonCommercial provided there is
121 | no payment of monetary compensation in connection with the
122 | exchange.
123 |
124 | j. Share means to provide material to the public by any means or
125 | process that requires permission under the Licensed Rights, such
126 | as reproduction, public display, public performance, distribution,
127 | dissemination, communication, or importation, and to make material
128 | available to the public including in ways that members of the
129 | public may access the material from a place and at a time
130 | individually chosen by them.
131 |
132 | k. Sui Generis Database Rights means rights other than copyright
133 | resulting from Directive 96/9/EC of the European Parliament and of
134 | the Council of 11 March 1996 on the legal protection of databases,
135 | as amended and/or succeeded, as well as other essentially
136 | equivalent rights anywhere in the world.
137 |
138 | l. You means the individual or entity exercising the Licensed Rights
139 | under this Public License. Your has a corresponding meaning.
140 |
141 | Section 2 -- Scope.
142 |
143 | a. License grant.
144 |
145 | 1. Subject to the terms and conditions of this Public License,
146 | the Licensor hereby grants You a worldwide, royalty-free,
147 | non-sublicensable, non-exclusive, irrevocable license to
148 | exercise the Licensed Rights in the Licensed Material to:
149 |
150 | a. reproduce and Share the Licensed Material, in whole or
151 | in part, for NonCommercial purposes only; and
152 |
153 | b. produce, reproduce, and Share Adapted Material for
154 | NonCommercial purposes only.
155 |
156 | 2. Exceptions and Limitations. For the avoidance of doubt, where
157 | Exceptions and Limitations apply to Your use, this Public
158 | License does not apply, and You do not need to comply with
159 | its terms and conditions.
160 |
161 | 3. Term. The term of this Public License is specified in Section
162 | 6(a).
163 |
164 | 4. Media and formats; technical modifications allowed. The
165 | Licensor authorizes You to exercise the Licensed Rights in
166 | all media and formats whether now known or hereafter created,
167 | and to make technical modifications necessary to do so. The
168 | Licensor waives and/or agrees not to assert any right or
169 | authority to forbid You from making technical modifications
170 | necessary to exercise the Licensed Rights, including
171 | technical modifications necessary to circumvent Effective
172 | Technological Measures. For purposes of this Public License,
173 | simply making modifications authorized by this Section 2(a)
174 | (4) never produces Adapted Material.
175 |
176 | 5. Downstream recipients.
177 |
178 | a. Offer from the Licensor -- Licensed Material. Every
179 | recipient of the Licensed Material automatically
180 | receives an offer from the Licensor to exercise the
181 | Licensed Rights under the terms and conditions of this
182 | Public License.
183 |
184 | b. No downstream restrictions. You may not offer or impose
185 | any additional or different terms or conditions on, or
186 | apply any Effective Technological Measures to, the
187 | Licensed Material if doing so restricts exercise of the
188 | Licensed Rights by any recipient of the Licensed
189 | Material.
190 |
191 | 6. No endorsement. Nothing in this Public License constitutes or
192 | may be construed as permission to assert or imply that You
193 | are, or that Your use of the Licensed Material is, connected
194 | with, or sponsored, endorsed, or granted official status by,
195 | the Licensor or others designated to receive attribution as
196 | provided in Section 3(a)(1)(A)(i).
197 |
198 | b. Other rights.
199 |
200 | 1. Moral rights, such as the right of integrity, are not
201 | licensed under this Public License, nor are publicity,
202 | privacy, and/or other similar personality rights; however, to
203 | the extent possible, the Licensor waives and/or agrees not to
204 | assert any such rights held by the Licensor to the limited
205 | extent necessary to allow You to exercise the Licensed
206 | Rights, but not otherwise.
207 |
208 | 2. Patent and trademark rights are not licensed under this
209 | Public License.
210 |
211 | 3. To the extent possible, the Licensor waives any right to
212 | collect royalties from You for the exercise of the Licensed
213 | Rights, whether directly or through a collecting society
214 | under any voluntary or waivable statutory or compulsory
215 | licensing scheme. In all other cases the Licensor expressly
216 | reserves any right to collect such royalties, including when
217 | the Licensed Material is used other than for NonCommercial
218 | purposes.
219 |
220 | Section 3 -- License Conditions.
221 |
222 | Your exercise of the Licensed Rights is expressly made subject to the
223 | following conditions.
224 |
225 | a. Attribution.
226 |
227 | 1. If You Share the Licensed Material (including in modified
228 | form), You must:
229 |
230 | a. retain the following if it is supplied by the Licensor
231 | with the Licensed Material:
232 |
233 | i. identification of the creator(s) of the Licensed
234 | Material and any others designated to receive
235 | attribution, in any reasonable manner requested by
236 | the Licensor (including by pseudonym if
237 | designated);
238 |
239 | ii. a copyright notice;
240 |
241 | iii. a notice that refers to this Public License;
242 |
243 | iv. a notice that refers to the disclaimer of
244 | warranties;
245 |
246 | v. a URI or hyperlink to the Licensed Material to the
247 | extent reasonably practicable;
248 |
249 | b. indicate if You modified the Licensed Material and
250 | retain an indication of any previous modifications; and
251 |
252 | c. indicate the Licensed Material is licensed under this
253 | Public License, and include the text of, or the URI or
254 | hyperlink to, this Public License.
255 |
256 | 2. You may satisfy the conditions in Section 3(a)(1) in any
257 | reasonable manner based on the medium, means, and context in
258 | which You Share the Licensed Material. For example, it may be
259 | reasonable to satisfy the conditions by providing a URI or
260 | hyperlink to a resource that includes the required
261 | information.
262 |
263 | 3. If requested by the Licensor, You must remove any of the
264 | information required by Section 3(a)(1)(A) to the extent
265 | reasonably practicable.
266 |
267 | 4. If You Share Adapted Material You produce, the Adapter's
268 | License You apply must not prevent recipients of the Adapted
269 | Material from complying with this Public License.
270 |
271 | Section 4 -- Sui Generis Database Rights.
272 |
273 | Where the Licensed Rights include Sui Generis Database Rights that
274 | apply to Your use of the Licensed Material:
275 |
276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
277 | to extract, reuse, reproduce, and Share all or a substantial
278 | portion of the contents of the database for NonCommercial purposes
279 | only;
280 |
281 | b. if You include all or a substantial portion of the database
282 | contents in a database in which You have Sui Generis Database
283 | Rights, then the database in which You have Sui Generis Database
284 | Rights (but not its individual contents) is Adapted Material; and
285 |
286 | c. You must comply with the conditions in Section 3(a) if You Share
287 | all or a substantial portion of the contents of the database.
288 |
289 | For the avoidance of doubt, this Section 4 supplements and does not
290 | replace Your obligations under this Public License where the Licensed
291 | Rights include other Copyright and Similar Rights.
292 |
293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
294 |
295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
305 |
306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
315 |
316 | c. The disclaimer of warranties and limitation of liability provided
317 | above shall be interpreted in a manner that, to the extent
318 | possible, most closely approximates an absolute disclaimer and
319 | waiver of all liability.
320 |
321 | Section 6 -- Term and Termination.
322 |
323 | a. This Public License applies for the term of the Copyright and
324 | Similar Rights licensed here. However, if You fail to comply with
325 | this Public License, then Your rights under this Public License
326 | terminate automatically.
327 |
328 | b. Where Your right to use the Licensed Material has terminated under
329 | Section 6(a), it reinstates:
330 |
331 | 1. automatically as of the date the violation is cured, provided
332 | it is cured within 30 days of Your discovery of the
333 | violation; or
334 |
335 | 2. upon express reinstatement by the Licensor.
336 |
337 | For the avoidance of doubt, this Section 6(b) does not affect any
338 | right the Licensor may have to seek remedies for Your violations
339 | of this Public License.
340 |
341 | c. For the avoidance of doubt, the Licensor may also offer the
342 | Licensed Material under separate terms or conditions or stop
343 | distributing the Licensed Material at any time; however, doing so
344 | will not terminate this Public License.
345 |
346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
347 | License.
348 |
349 | Section 7 -- Other Terms and Conditions.
350 |
351 | a. The Licensor shall not be bound by any additional or different
352 | terms or conditions communicated by You unless expressly agreed.
353 |
354 | b. Any arrangements, understandings, or agreements regarding the
355 | Licensed Material not stated herein are separate from and
356 | independent of the terms and conditions of this Public License.
357 |
358 | Section 8 -- Interpretation.
359 |
360 | a. For the avoidance of doubt, this Public License does not, and
361 | shall not be interpreted to, reduce, limit, restrict, or impose
362 | conditions on any use of the Licensed Material that could lawfully
363 | be made without permission under this Public License.
364 |
365 | b. To the extent possible, if any provision of this Public License is
366 | deemed unenforceable, it shall be automatically reformed to the
367 | minimum extent necessary to make it enforceable. If the provision
368 | cannot be reformed, it shall be severed from this Public License
369 | without affecting the enforceability of the remaining terms and
370 | conditions.
371 |
372 | c. No term or condition of this Public License will be waived and no
373 | failure to comply consented to unless expressly agreed to by the
374 | Licensor.
375 |
376 | d. Nothing in this Public License constitutes or may be interpreted
377 | as a limitation upon, or waiver of, any privileges and immunities
378 | that apply to the Licensor or You, including from the legal
379 | processes of any jurisdiction or authority.
380 |
381 | =======================================================================
382 |
383 | Creative Commons is not a party to its public
384 | licenses. Notwithstanding, Creative Commons may elect to apply one of
385 | its public licenses to material it publishes and in those instances
386 | will be considered the “Licensor.” The text of the Creative Commons
387 | public licenses is dedicated to the public domain under the CC0 Public
388 | Domain Dedication. Except for the limited purpose of indicating that
389 | material is shared under a Creative Commons public license or as
390 | otherwise permitted by the Creative Commons policies published at
391 | creativecommons.org/policies, Creative Commons does not authorize the
392 | use of the trademark "Creative Commons" or any other trademark or logo
393 | of Creative Commons without its prior written consent including,
394 | without limitation, in connection with any unauthorized modifications
395 | to any of its public licenses or any other arrangements,
396 | understandings, or agreements concerning use of licensed material. For
397 | the avoidance of doubt, this paragraph does not form part of the
398 | public licenses.
399 |
400 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/models/videomae_modelling_finetune.py:
--------------------------------------------------------------------------------
1 | # Some code is adapted from the CLIP GitHub repo (https://github.com/openai/CLIP) and the VideoMAE GitHub repo (https://github.com/MCG-NJU/VideoMAE)
2 | from functools import partial
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_
8 | from timm.models.registry import register_model
9 | from collections import OrderedDict
10 | from einops import rearrange
11 | import random
12 |
13 |
14 | def _cfg(url='', **kwargs):
15 | return {
16 | 'url': url,
17 | 'num_classes': 400, 'input_size': (3, 224, 224), 'pool_size': None,
18 | 'crop_pct': .9, 'interpolation': 'bicubic',
19 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
20 | **kwargs
21 | }
22 |
23 | class DropPath(nn.Module):
24 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
25 | """
26 | def __init__(self, drop_prob=None):
27 | super(DropPath, self).__init__()
28 | self.drop_prob = drop_prob
29 |
30 | def forward(self, x):
31 | return drop_path(x, self.drop_prob, self.training)
32 |
33 | def extra_repr(self) -> str:
34 | return 'p={}'.format(self.drop_prob)
35 |
36 | class Adapter(nn.Module):
37 | def __init__(self, dim, mlp_ratio=0.25, act_layer=nn.GELU, skip_connect=True):
38 | super().__init__()
39 | self.skip_connect = skip_connect
40 | down_dim = int(dim * mlp_ratio)
41 | self.act = act_layer()
42 | self.D_fc1 = nn.Linear(dim, down_dim)
43 | self.D_fc2 = nn.Linear(down_dim, dim)
44 |
45 | def forward(self, x):
46 | # x is (BT, HW+1, D)
47 | xs = self.D_fc1(x)
48 | xs = self.act(xs)
49 | xs = self.D_fc2(xs)
50 | if self.skip_connect:
51 | x = x + xs
52 | else:
53 | x = xs
54 | return x
55 |
56 | class LayerNorm(nn.LayerNorm):
57 | """Subclass torch's LayerNorm to handle fp16."""
58 | def forward(self, x: torch.Tensor):
59 | orig_type = x.dtype
60 | if orig_type == torch.float16:
61 | ret = super().forward(x)
62 |         else:  # fall back to an fp32 computation for any other dtype (e.g. bfloat16), then cast back
63 |             ret = super().forward(x.type(torch.float32))
64 | return ret.type(orig_type)
65 |
66 | class QuickGELU(nn.Module):
67 | def forward(self, x: torch.Tensor):
68 | return x * torch.sigmoid(1.702 * x)
69 |
70 | class PatchEmbed(nn.Module):
71 | """ Image to Patch Embedding
72 | """
73 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
74 | super().__init__()
75 | img_size = to_2tuple(img_size)
76 | patch_size = to_2tuple(patch_size)
77 | self.tubelet_size = int(tubelet_size)
78 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
79 | self.img_size = img_size
80 | self.patch_size = patch_size
81 | self.num_patches = num_patches
82 | self.proj = nn.Conv3d(in_channels=in_chans, out_channels=embed_dim,
83 | kernel_size = (self.tubelet_size, patch_size[0],patch_size[1]),
84 | stride=(self.tubelet_size, patch_size[0], patch_size[1]))
85 |
86 | def forward(self, x, **kwargs):
87 | B, C, T, H, W = x.shape
88 | # FIXME look at relaxing size constraints
89 | assert H == self.img_size[0] and W == self.img_size[1], \
90 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
91 | x = self.proj(x).flatten(2).transpose(1, 2)
92 | return x
93 |
94 | # sin-cos position encoding
95 | # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
96 | def get_sinusoid_encoding_table(n_position, d_hid):
97 | ''' Sinusoid position encoding table '''
98 | # TODO: make it with torch instead of numpy
99 | def get_position_angle_vec(position):
100 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
101 |
102 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
103 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
104 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
105 |
106 | return torch.FloatTensor(sinusoid_table).unsqueeze(0)
107 |
108 | class Mlp(nn.Module):
109 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
110 | super().__init__()
111 | out_features = out_features or in_features
112 | hidden_features = hidden_features or in_features
113 | self.fc1 = nn.Linear(in_features, hidden_features)
114 | self.act = act_layer()
115 | self.fc2 = nn.Linear(hidden_features, out_features)
116 | self.drop = nn.Dropout(drop)
117 |
118 | def forward(self, x):
119 | x = self.fc1(x)
120 | x = self.act(x)
121 | # x = self.drop(x)
122 |         # the original BERT implementation applies dropout here as well
123 | x = self.fc2(x)
124 | x = self.drop(x)
125 | return x
126 |
127 | # Keep the class name Attention so existing pretrained weights load without key remapping.
128 | class Attention(nn.Module):
129 | def __init__(
130 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
131 | proj_drop=0., attn_head_dim=None):
132 | super().__init__()
133 | self.num_heads = num_heads
134 | head_dim = dim // num_heads
135 | if attn_head_dim is not None:
136 | head_dim = attn_head_dim
137 | all_head_dim = head_dim * self.num_heads
138 | self.scale = qk_scale or head_dim ** -0.5
139 |
140 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
141 | if qkv_bias:
142 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
143 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
144 | else:
145 | self.q_bias = None
146 | self.v_bias = None
147 |
148 | self.attn_drop = nn.Dropout(attn_drop)
149 | self.proj = nn.Linear(all_head_dim, dim)
150 | self.proj_drop = nn.Dropout(proj_drop)
151 |
152 | def forward(self, x):
153 | B, N, C = x.shape
154 | qkv_bias = None
155 | if self.q_bias is not None:
156 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
157 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
158 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
159 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
160 | s2t_q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
161 |
162 | s2t_q = s2t_q * self.scale
163 | attn = (s2t_q @ k.transpose(-2, -1))
164 |
165 |
166 | attn = attn.softmax(dim=-1)
167 | attn = self.attn_drop(attn)
168 |
169 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
170 | x = self.proj(x)
171 | x = self.proj_drop(x)
172 | return x
173 |
174 | # spatial to temporal cross attention module.
175 | class CrossAttentionS2T(nn.Module):
176 | def __init__(self, dim: int, n_head: int, attn_mask: torch.Tensor = None):
177 | super().__init__()
178 |
179 | # add for cross-attn
180 | self.num_head = n_head
181 | head_dim = dim // self.num_head
182 | self.scale = head_dim ** -0.5
183 | all_head_dim = head_dim * self.num_head
184 | scale = dim ** -0.5
185 | self.space_time_pos = nn.Parameter(scale * torch.randn((197 * 8, dim)))
186 |
187 |         # the cross-attention (q/kv) projection layers for this module go here
188 | self.s2t_q = nn.Linear(dim, all_head_dim, bias=False)
189 | self.s2t_q_bias = nn.Parameter(torch.zeros(all_head_dim))
190 | self.s2t_kv = nn.Linear(dim, all_head_dim * 2, bias=False) # 197 tokens(cls+patch) * num_frames
191 | self.s2t_kv_bias = nn.Parameter(torch.zeros(all_head_dim * 2))
192 |
193 | self.t2s_proj = nn.Linear(all_head_dim, dim)
194 |
195 | self.attn_mask = attn_mask
196 |
197 | def s2t_cross_attn(self, s_x, t_x): # s_x=[n (b t) d], t_x=[b n d]
198 | B, _, _ = t_x.shape
199 | s_x = rearrange(s_x, 'n (b t) d -> b (t n) d', b=B) # batch -> token
200 |         s_x = s_x + self.space_time_pos  # space-time position encoding
201 | s2t_q_bias = self.s2t_q_bias
202 | s2t_kv_bias = self.s2t_kv_bias
203 |
204 | s2t_q = F.linear(input=t_x, weight=self.s2t_q.weight, bias=s2t_q_bias)
205 | s2t_q = rearrange(s2t_q, 'b n (h d) -> b h n d', h=self.num_head)
206 | s2t_kv = F.linear(input=s_x, weight=self.s2t_kv.weight, bias=s2t_kv_bias)
207 | s2t_kv = rearrange(s2t_kv, 'b n (e h d) -> e b h n d',e=2, h=self.num_head)
208 | s2t_k, s2t_v = s2t_kv[0], s2t_kv[1]
209 |
210 | s2t_q = s2t_q * self.scale
211 | s2t_attn = (s2t_q @ s2t_k.transpose(-2, -1))
212 |
213 | s2t_attn = s2t_attn.softmax(dim=-1)
214 |
215 | t_x = (s2t_attn @ s2t_v)
216 | t_x = rearrange(t_x, 'b h n d -> b n (h d)')
217 | t_x = self.t2s_proj(t_x)
218 | return t_x
219 |
220 | def forward(self, s_x: torch.Tensor, t_x: torch.Tensor):
221 | return self.s2t_cross_attn(s_x, t_x)
222 |
223 |
224 | # this code is adapted from the CLIP GitHub repo (https://github.com/openai/CLIP)
225 | class CrossAttentionT2S(nn.Module):  # in VMAE terms this plays the role of the Block class; the cross-attention layer is added here
226 | def __init__(self, dim: int, n_head: int, attn_mask: torch.Tensor = None):
227 | super().__init__()
228 |
229 | # add for cross-attn
230 | self.num_head = n_head
231 | head_dim = dim // self.num_head
232 | self.scale = head_dim ** -0.5
233 | all_head_dim = head_dim * self.num_head
234 |
235 |         # the cross-attention (t2s) q/kv projection layers go here
236 | self.t2s_q = nn.Linear(dim, all_head_dim, bias=False) # 197 tokens(cls+patch) * num_frames
237 | self.t2s_q_bias = nn.Parameter(torch.zeros(all_head_dim))
238 | self.t2s_kv = nn.Linear(dim, all_head_dim * 2, bias=False)
239 | self.t2s_kv_bias = nn.Parameter(torch.zeros(all_head_dim * 2))
240 |
241 | self.t2s_proj = nn.Linear(all_head_dim, dim)
242 |
243 | self.attn_mask = attn_mask
244 |
245 | def t2s_cross_attn(self, s_x, t_x): # s_x=[n (b t) d], t_x=[b n d]
246 | B, _, _ = t_x.shape
247 | s_x_cls, s_x_pat = s_x[0, :, :], s_x[1:, :, :]
248 | s_x_pat = rearrange(s_x_pat, 'n (b t) d -> (b n) t d', b=B) # batch -> token
249 | t_x = rearrange(t_x, 'b (t n) d -> (b n) t d', t=8)
250 | t2s_q_bias = self.t2s_q_bias
251 | t2s_kv_bias = self.t2s_kv_bias
252 |
253 | t2s_q = F.linear(input=s_x_pat, weight=self.t2s_q.weight, bias=t2s_q_bias)
254 | t2s_q = rearrange(t2s_q, 'b t (h d) -> b h t d', h=self.num_head)
255 | t2s_kv = F.linear(input=t_x, weight=self.t2s_kv.weight, bias=t2s_kv_bias)
256 | t2s_kv = rearrange(t2s_kv, 'b t (e h d) -> e b h t d',e=2, h=self.num_head)
257 | t2s_k, t2s_v = t2s_kv[0], t2s_kv[1]
258 |
259 | t2s_q = t2s_q * self.scale
260 | t2s_attn = (t2s_q @ t2s_k.transpose(-2, -1))
261 |
262 | t2s_attn = t2s_attn.softmax(dim=-1)
263 |
264 | s_x_pat = (t2s_attn @ t2s_v)
265 | s_x_pat = rearrange(s_x_pat, 'b h n d -> b n (h d)')
266 | s_x_pat = self.t2s_proj(s_x_pat)
267 | s_x_pat = rearrange(s_x_pat,'(b n) t d -> n (b t) d', b=B)
268 | s_x = torch.cat([s_x_cls.unsqueeze(0), s_x_pat], dim=0)
269 | return s_x
270 |
271 | def forward(self, s_x: torch.Tensor, t_x: torch.Tensor):
272 | return self.t2s_cross_attn(s_x, t_x)
273 |
274 |
275 | class Block(nn.Module):
276 |
277 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
278 | drop_path=0., init_values=None, num_layer=0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_head_dim=None,use_adapter=False):
279 | super().__init__()
280 | self.cross = None
281 | self.num_layer = num_layer
282 | self.num_heads = num_heads
283 | self.scale = 0.5
284 | mlp_hidden_dim = int(dim * mlp_ratio)
285 | self.act = act_layer()
286 | self.use_adapter=use_adapter
287 | ############################ VMAE MHSA ###########################
288 | self.norm1 = norm_layer(dim)
289 | self.attn = Attention(
290 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
291 | attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
292 | if self.use_adapter:
293 |             self.T_Adapter = Adapter(dim, skip_connect=True)  # skip_connect=True for the baseline
294 | ############################ VMAE FFN ###############################
295 | self.norm2 = norm_layer(dim)
296 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
297 | if self.use_adapter:
298 | self.T_MLP_Adapter = Adapter(dim, skip_connect=False)
299 | #######################################################################
300 | #########################################################################################
301 |
302 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
303 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
304 |
305 |     def attention(self, x: torch.Tensor):  # NOTE: unused here; relies on self.clip_attn and self.attn_mask, which this class does not define
306 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
307 | return self.clip_attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
308 |
309 | def forward(self,t_x):
310 | B = t_x.shape[0]
311 | if self.use_adapter:
312 | t_x = t_x + self.T_Adapter(self.attn(self.norm1(t_x)))
313 | t_xn = self.norm2(t_x)
314 | t_x = t_x + self.mlp(t_xn) + self.drop_path(self.scale * self.T_MLP_Adapter(t_xn))
315 | else:
316 | t_x = t_x + self.drop_path(self.attn(self.norm1(t_x)))
317 | t_x = t_x + self.drop_path(self.mlp(self.norm2(t_x)))
318 |
319 |
320 | ############################################################################
321 |
322 | return t_x
323 |
324 | class STCrossTransformer(nn.Module):
325 | """ Vision Transformer with support for patch or hybrid CNN input stage
326 | """
327 | def __init__(self,
328 | img_size=224,
329 | patch_size=16,
330 | in_chans=3,
331 | num_classes=1000,
332 | embed_dim=768,
333 | depth=12,
334 | num_heads=12,
335 | mlp_ratio=4.,
336 | qkv_bias=False,
337 | qk_scale=None,
338 | drop_rate=0.,
339 | attn_drop_rate=0.,
340 | drop_path_rate=0.,
341 | norm_layer=nn.LayerNorm,
342 | init_values=0.,
343 | use_learnable_pos_emb=False,
344 | init_scale=0.,
345 | all_frames=16,
346 | tubelet_size=2,
347 | use_mean_pooling=True,
348 | composition=False,
349 | pretrained_cfg = None,
350 | use_adapter=False,
351 | fusion_method=None
352 | ):
353 | super().__init__()
354 | self.num_classes = num_classes
355 | self.embed_dim = embed_dim # num_features for consistency with other models
356 | self.tubelet_size = tubelet_size
357 | self.composition = composition
358 | self.patch_embed = PatchEmbed(
359 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, num_frames=all_frames, tubelet_size=self.tubelet_size)
360 | num_patches = self.patch_embed.num_patches
361 | self.use_adapter=use_adapter
362 | scale = embed_dim ** -0.5
363 |
364 | if use_learnable_pos_emb:
365 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
366 | else:
367 |             # use fixed sine-cosine positional embeddings
368 | self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
369 |
370 | self.pos_drop = nn.Dropout(p=drop_rate)
371 |
372 |
373 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
374 | self.blocks = nn.ModuleList([
375 | Block(
376 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
377 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
378 | init_values=init_values, num_layer=i,use_adapter=self.use_adapter)
379 | for i in range(depth)])
380 |
381 | # self.clip_ln_post = LayerNorm(embed_dim)
382 | self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
383 | self.vmae_fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
384 |
385 | if self.composition:
386 | self.head_verb = nn.Linear(embed_dim, 97)
387 | self.head_noun = nn.Linear(embed_dim, 300)
388 | else:
389 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
390 |
391 | if use_learnable_pos_emb:
392 | trunc_normal_(self.pos_embed, std=.02)
393 |
394 | self.apply(self._init_weights)
395 | if self.use_adapter:
396 | self._init_adpater_weight()
397 |
398 | if self.composition:
399 | trunc_normal_(self.head_noun.weight, std=.02)
400 | trunc_normal_(self.head_verb.weight, std=.02)
401 | self.head_verb.weight.data.mul_(init_scale)
402 | self.head_verb.bias.data.mul_(init_scale)
403 | self.head_noun.weight.data.mul_(init_scale)
404 | self.head_noun.bias.data.mul_(init_scale)
405 | else:
406 | trunc_normal_(self.head.weight, std=.02)
407 | self.head.weight.data.mul_(init_scale)
408 | self.head.bias.data.mul_(init_scale)
409 |
410 | def _init_weights(self, m):
411 | if isinstance(m, nn.Linear):
412 | trunc_normal_(m.weight, std=.02)
413 | if isinstance(m, nn.Linear) and m.bias is not None:
414 | nn.init.constant_(m.bias, 0)
415 | elif isinstance(m, nn.LayerNorm):
416 | nn.init.constant_(m.bias, 0)
417 | nn.init.constant_(m.weight, 1.0)
418 |
419 | def _init_adpater_weight(self):
420 | for n, m in self.blocks.named_modules():
421 | if 'Adapter' in n:
422 | for n2, m2 in m.named_modules():
423 | if 'D_fc2' in n2:
424 | if isinstance(m2, nn.Linear):
425 | nn.init.constant_(m2.weight, 0)
426 | nn.init.constant_(m2.bias, 0)
427 | elif 'up' in n:
428 | for n2, m2 in m.named_modules():
429 | if isinstance(m2, nn.Linear):
430 | nn.init.constant_(m2.weight, 0)
431 | nn.init.constant_(m2.bias, 0)
432 |
433 |
434 | def get_num_layers(self):
435 | return len(self.blocks)
436 |
437 | @torch.jit.ignore
438 | def no_weight_decay(self):
439 | return {'clip_temporal_embedding','pos_embed'}
440 |
441 | def get_classifier(self):
442 | return self.head
443 |
444 | def reset_classifier(self, num_classes, global_pool=''):
445 | self.num_classes = num_classes
446 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
447 |
448 | def reset_fcnorm(self):
449 | self.vmae_fc_norm = nn.LayerNorm(self.embed_dim)
450 |
451 | def forward_features(self, x):
452 | B = x.shape[0]
453 |         ######################## VMAE temporal path ########################
454 | t_x = self.patch_embed(x)
455 |
456 | if self.pos_embed is not None:
457 | t_x = t_x + self.pos_embed.expand(B, -1, -1).type_as(t_x).to(t_x.device).clone().detach()
458 | t_x = self.pos_drop(t_x)
459 | #####################################################################
460 |
461 | for blk in self.blocks:
462 | t_x = blk(t_x)
463 | t_x = self.vmae_fc_norm(t_x.mean(1)) # all patch avg pooling
464 |
465 | return t_x
466 |
467 |
468 | def forward(self, x):
469 | if self.composition:
470 | t_x = self.forward_features(x)
471 | noun = self.head_noun(t_x)
472 | verb = self.head_verb(t_x)
473 | return noun, verb
474 | else:
475 | x = self.forward_features(x)
476 | x = self.head(x)
477 | return x
478 |
479 |
480 |
481 |
482 | @register_model
483 | def compo_videomae_adapter_vit_base_patch16_224(pretrained=False, **kwargs):
484 |     model = STCrossTransformer(
485 |         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
486 |         norm_layer=partial(nn.LayerNorm, eps=1e-6), composition=True, use_adapter=True, **kwargs)
487 |     # model.default_cfg = _cfg()
488 |     return model
489 | 
490 | 
491 | @register_model
492 | def videomae_adapter_vit_base_patch16_224(pretrained=False, **kwargs):
493 |     model = STCrossTransformer(
494 |         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
495 |         norm_layer=partial(nn.LayerNorm, eps=1e-6), composition=False, use_adapter=True, **kwargs)
496 |     # model.default_cfg = _cfg()
497 |     return model
498 | 
499 | 
500 | @register_model
501 | def compo_videomae_vit_base_patch16_224(pretrained=False, **kwargs):
502 |     model = STCrossTransformer(
503 |         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
504 |         norm_layer=partial(nn.LayerNorm, eps=1e-6), composition=True, use_adapter=False, **kwargs)
505 |     # model.default_cfg = _cfg()
506 |     return model
507 | 
508 | 
509 | @register_model
510 | def videomae_vit_base_patch16_224(pretrained=False, **kwargs):
511 |     model = STCrossTransformer(
512 |         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
513 |         norm_layer=partial(nn.LayerNorm, eps=1e-6), composition=False, use_adapter=False, **kwargs)
514 |     # model.default_cfg = _cfg()
515 |     return model
516 | 
517 | 
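518 | if __name__ == "__main__":
519 |     # Hedged smoke test (added for illustration; not part of the original
520 |     # repository): build the plain backbone registered above and push one
521 |     # random 16-frame 224x224 clip through it.
522 |     model = videomae_vit_base_patch16_224(num_classes=174)
523 |     dummy_clip = torch.rand(1, 3, 16, 224, 224)  # (B, C, T, H, W)
524 |     logits = model(dummy_clip)
525 |     print(logits.shape)  # expected: torch.Size([1, 174])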
--------------------------------------------------------------------------------
/models/clip_modelling_finetune.py:
--------------------------------------------------------------------------------
1 | # parts of this code are adapted from the CLIP repo (https://github.com/openai/CLIP) and the VideoMAE repo (https://github.com/MCG-NJU/VideoMAE)
2 | from functools import partial
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_
8 | from timm.models.registry import register_model
9 | from collections import OrderedDict
10 | from einops import rearrange
11 | import random
12 |
13 |
14 | def _cfg(url='', **kwargs):
15 | return {
16 | 'url': url,
17 | 'num_classes': 400, 'input_size': (3, 224, 224), 'pool_size': None,
18 | 'crop_pct': .9, 'interpolation': 'bicubic',
19 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
20 | **kwargs
21 | }
22 |
23 | class DropPath(nn.Module):
24 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
25 | """
26 | def __init__(self, drop_prob=None):
27 | super(DropPath, self).__init__()
28 | self.drop_prob = drop_prob
29 |
30 | def forward(self, x):
31 | return drop_path(x, self.drop_prob, self.training)
32 |
33 | def extra_repr(self) -> str:
34 | return 'p={}'.format(self.drop_prob)
35 |
36 | class Adapter(nn.Module):
37 | def __init__(self, dim, mlp_ratio=0.25, act_layer=nn.GELU, skip_connect=True):
38 | super().__init__()
39 | self.skip_connect = skip_connect
40 | down_dim = int(dim * mlp_ratio)
41 | self.act = act_layer()
42 | self.D_fc1 = nn.Linear(dim, down_dim)
43 | self.D_fc2 = nn.Linear(down_dim, dim)
44 |
45 | def forward(self, x):
46 | # x is (BT, HW+1, D)
47 | xs = self.D_fc1(x)
48 | xs = self.act(xs)
49 | xs = self.D_fc2(xs)
50 | if self.skip_connect:
51 | x = x + xs
52 | else:
53 | x = xs
54 | return x
55 |
56 | class LayerNorm(nn.LayerNorm):
57 | """Subclass torch's LayerNorm to handle fp16."""
58 |     def forward(self, x: torch.Tensor):
59 |         orig_type = x.dtype
60 |         if orig_type == torch.float16:
61 |             ret = super().forward(x)
62 |         else:  # upcast any other dtype (e.g. fp32/bf16) so that ret is always defined
63 |             ret = super().forward(x.type(torch.float32))
64 |         return ret.type(orig_type)
65 |
66 | class QuickGELU(nn.Module):
67 | def forward(self, x: torch.Tensor):
68 | return x * torch.sigmoid(1.702 * x)
69 |
70 | class PatchEmbed(nn.Module):
71 | """ Image to Patch Embedding
72 | """
73 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
74 | super().__init__()
75 | img_size = to_2tuple(img_size)
76 | patch_size = to_2tuple(patch_size)
77 | self.tubelet_size = int(tubelet_size)
78 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
79 | self.img_size = img_size
80 | self.patch_size = patch_size
81 | self.num_patches = num_patches
82 | self.proj = nn.Conv3d(in_channels=in_chans, out_channels=embed_dim,
83 | kernel_size = (self.tubelet_size, patch_size[0],patch_size[1]),
84 | stride=(self.tubelet_size, patch_size[0], patch_size[1]))
85 |
86 | def forward(self, x, **kwargs):
87 | B, C, T, H, W = x.shape
88 | # FIXME look at relaxing size constraints
89 | assert H == self.img_size[0] and W == self.img_size[1], \
90 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
91 | x = self.proj(x).flatten(2).transpose(1, 2)
92 | return x
93 |
94 | # sin-cos position encoding
95 | # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
96 | def get_sinusoid_encoding_table(n_position, d_hid):
97 | ''' Sinusoid position encoding table '''
98 | # TODO: make it with torch instead of numpy
99 | def get_position_angle_vec(position):
100 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
101 |
102 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
103 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
104 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
105 |
106 | return torch.FloatTensor(sinusoid_table).unsqueeze(0)
107 |
108 | class Mlp(nn.Module):
109 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
110 | super().__init__()
111 | out_features = out_features or in_features
112 | hidden_features = hidden_features or in_features
113 | self.fc1 = nn.Linear(in_features, hidden_features)
114 | self.act = act_layer()
115 | self.fc2 = nn.Linear(hidden_features, out_features)
116 | self.drop = nn.Dropout(drop)
117 |
118 | def forward(self, x):
119 | x = self.fc1(x)
120 | x = self.act(x)
121 |         # x = self.drop(x)
122 |         # dropout here is commented out to match the original BERT implementation
123 | x = self.fc2(x)
124 | x = self.drop(x)
125 | return x
126 |
127 | # keep the name Attention for compatibility when loading existing pretrained weights
128 | class Attention(nn.Module):
129 | def __init__(
130 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
131 | proj_drop=0., attn_head_dim=None):
132 | super().__init__()
133 | self.num_heads = num_heads
134 | head_dim = dim // num_heads
135 | if attn_head_dim is not None:
136 | head_dim = attn_head_dim
137 | all_head_dim = head_dim * self.num_heads
138 | self.scale = qk_scale or head_dim ** -0.5
139 |
140 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
141 | if qkv_bias:
142 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
143 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
144 | else:
145 | self.q_bias = None
146 | self.v_bias = None
147 |
148 | self.attn_drop = nn.Dropout(attn_drop)
149 | self.proj = nn.Linear(all_head_dim, dim)
150 | self.proj_drop = nn.Dropout(proj_drop)
151 |
152 | def forward(self, x):
153 | B, N, C = x.shape
154 | qkv_bias = None
155 | if self.q_bias is not None:
156 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
157 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
158 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
159 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
160 | s2t_q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
161 |
162 | s2t_q = s2t_q * self.scale
163 | attn = (s2t_q @ k.transpose(-2, -1))
164 |
165 |
166 | attn = attn.softmax(dim=-1)
167 | attn = self.attn_drop(attn)
168 |
169 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
170 | x = self.proj(x)
171 | x = self.proj_drop(x)
172 | return x
173 |
174 | # spatial to temporal cross attention module.
175 | class CrossAttentionS2T(nn.Module):
176 | def __init__(self, dim: int, n_head: int, attn_mask: torch.Tensor = None):
177 | super().__init__()
178 |
179 | # add for cross-attn
180 | self.num_head = n_head
181 | head_dim = dim // self.num_head
182 | self.scale = head_dim ** -0.5
183 | all_head_dim = head_dim * self.num_head
184 | scale = dim ** -0.5
185 | self.space_time_pos = nn.Parameter(scale * torch.randn((197 * 8, dim)))
186 |
187 |         # the cross-attention projection layers go here
188 | self.s2t_q = nn.Linear(dim, all_head_dim, bias=False)
189 | self.s2t_q_bias = nn.Parameter(torch.zeros(all_head_dim))
190 | self.s2t_kv = nn.Linear(dim, all_head_dim * 2, bias=False) # 197 tokens(cls+patch) * num_frames
191 | self.s2t_kv_bias = nn.Parameter(torch.zeros(all_head_dim * 2))
192 |
193 | self.t2s_proj = nn.Linear(all_head_dim, dim)
194 |
195 | self.attn_mask = attn_mask
196 |
197 | def s2t_cross_attn(self, s_x, t_x): # s_x=[n (b t) d], t_x=[b n d]
198 | B, _, _ = t_x.shape
199 | s_x = rearrange(s_x, 'n (b t) d -> b (t n) d', b=B) # batch -> token
200 |         s_x = s_x + self.space_time_pos  # space-time position encoding
201 | s2t_q_bias = self.s2t_q_bias
202 | s2t_kv_bias = self.s2t_kv_bias
203 |
204 | s2t_q = F.linear(input=t_x, weight=self.s2t_q.weight, bias=s2t_q_bias)
205 | s2t_q = rearrange(s2t_q, 'b n (h d) -> b h n d', h=self.num_head)
206 | s2t_kv = F.linear(input=s_x, weight=self.s2t_kv.weight, bias=s2t_kv_bias)
207 | s2t_kv = rearrange(s2t_kv, 'b n (e h d) -> e b h n d',e=2, h=self.num_head)
208 | s2t_k, s2t_v = s2t_kv[0], s2t_kv[1]
209 |
210 | s2t_q = s2t_q * self.scale
211 | s2t_attn = (s2t_q @ s2t_k.transpose(-2, -1))
212 |
213 | s2t_attn = s2t_attn.softmax(dim=-1)
214 |
215 | t_x = (s2t_attn @ s2t_v)
216 | t_x = rearrange(t_x, 'b h n d -> b n (h d)')
217 | t_x = self.t2s_proj(t_x)
218 | return t_x
219 |
220 | def forward(self, s_x: torch.Tensor, t_x: torch.Tensor):
221 | return self.s2t_cross_attn(s_x, t_x)
222 |
223 |
224 | # this code is adapted from the CLIP repo (https://github.com/openai/CLIP)
225 | class CrossAttentionT2S(nn.Module):  # in VideoMAE terms this corresponds to the blocks class; the cross-attention layer is added here
226 | def __init__(self, dim: int, n_head: int, attn_mask: torch.Tensor = None):
227 | super().__init__()
228 |
229 | # add for cross-attn
230 | self.num_head = n_head
231 | head_dim = dim // self.num_head
232 | self.scale = head_dim ** -0.5
233 | all_head_dim = head_dim * self.num_head
234 |
235 |         # the cross-attention projection layers go here
236 | self.t2s_q = nn.Linear(dim, all_head_dim, bias=False) # 197 tokens(cls+patch) * num_frames
237 | self.t2s_q_bias = nn.Parameter(torch.zeros(all_head_dim))
238 | self.t2s_kv = nn.Linear(dim, all_head_dim * 2, bias=False)
239 | self.t2s_kv_bias = nn.Parameter(torch.zeros(all_head_dim * 2))
240 |
241 | self.t2s_proj = nn.Linear(all_head_dim, dim)
242 |
243 | self.attn_mask = attn_mask
244 |
245 | def t2s_cross_attn(self, s_x, t_x): # s_x=[n (b t) d], t_x=[b n d]
246 | B, _, _ = t_x.shape
247 | s_x_cls, s_x_pat = s_x[0, :, :], s_x[1:, :, :]
248 | s_x_pat = rearrange(s_x_pat, 'n (b t) d -> (b n) t d', b=B) # batch -> token
249 | t_x = rearrange(t_x, 'b (t n) d -> (b n) t d', t=8)
250 | t2s_q_bias = self.t2s_q_bias
251 | t2s_kv_bias = self.t2s_kv_bias
252 |
253 | t2s_q = F.linear(input=s_x_pat, weight=self.t2s_q.weight, bias=t2s_q_bias)
254 | t2s_q = rearrange(t2s_q, 'b t (h d) -> b h t d', h=self.num_head)
255 | t2s_kv = F.linear(input=t_x, weight=self.t2s_kv.weight, bias=t2s_kv_bias)
256 | t2s_kv = rearrange(t2s_kv, 'b t (e h d) -> e b h t d',e=2, h=self.num_head)
257 | t2s_k, t2s_v = t2s_kv[0], t2s_kv[1]
258 |
259 | t2s_q = t2s_q * self.scale
260 | t2s_attn = (t2s_q @ t2s_k.transpose(-2, -1))
261 |
262 | t2s_attn = t2s_attn.softmax(dim=-1)
263 |
264 | s_x_pat = (t2s_attn @ t2s_v)
265 | s_x_pat = rearrange(s_x_pat, 'b h n d -> b n (h d)')
266 | s_x_pat = self.t2s_proj(s_x_pat)
267 | s_x_pat = rearrange(s_x_pat,'(b n) t d -> n (b t) d', b=B)
268 | s_x = torch.cat([s_x_cls.unsqueeze(0), s_x_pat], dim=0)
269 | return s_x
270 |
271 | def forward(self, s_x: torch.Tensor, t_x: torch.Tensor):
272 | return self.t2s_cross_attn(s_x, t_x)
273 |
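# NOTE (annotation added for clarity; not part of the original code): the two
# cross-attention modules above follow the layouts given in their inline
# comments: s_x are the spatial (CLIP) tokens, [n, b*t, d] with n = 197
# (CLS + 14x14 patches); t_x are the temporal (VideoMAE) tokens. The
# space_time_pos table in CrossAttentionS2T is sized 197*8, and
# CrossAttentionT2S hard-codes t=8 in its rearrange, i.e. both assume
# 8 tubelet time steps (16 frames with tubelet_size=2). Neither module is
# called by the Block class defined below in this file.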
274 |
275 | class Block(nn.Module):
276 |
277 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
278 | drop_path=0., init_values=None, num_layer=0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_head_dim=None,use_adapter=False):
279 | super().__init__()
280 | self.cross = None
281 | self.num_layer = num_layer
282 | self.num_heads = num_heads
283 | self.scale = 0.5
284 | mlp_hidden_dim = int(dim * mlp_ratio)
285 | self.act = act_layer()
286 | self.use_adapter=use_adapter
287 | self.num_frames=16
288 | ############################ AIM MHSA ###########################
289 | self.clip_ln_1 = LayerNorm(dim)
290 | if self.use_adapter:
291 | # self.time_attn = nn.MultiheadAttention(dim, num_heads)
292 | self.T_Adapter = Adapter(dim, skip_connect=False)
293 | self.clip_attn = nn.MultiheadAttention(dim, num_heads)
294 | if self.use_adapter:
295 | self.S_Adapter = Adapter(dim)
296 |
297 |
298 |
299 | ############################ AIM FFN ###############################
300 | self.clip_ln_2 = LayerNorm(dim)
301 | self.clip_mlp = nn.Sequential(OrderedDict([
302 | ("c_fc", nn.Linear(dim, dim * 4)),
303 | ("gelu", QuickGELU()),
304 | ("c_proj", nn.Linear(dim * 4, dim))
305 | ]))
306 | if self.use_adapter:
307 | self.S_MLP_Adapter = Adapter(dim, skip_connect=False)
308 | self.attn_mask = None
309 |
310 |
311 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
312 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
313 |
314 | def attention(self, x: torch.Tensor):
315 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
316 | return self.clip_attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
317 |     def time_attention(self, x: torch.Tensor):  # appears unused; self.time_attn is commented out in __init__ above
318 |         return self.time_attn(x, x, x, need_weights=False, attn_mask=None)[0]
319 |
320 | def forward(self,s_x):
321 | n, bt, d = s_x.shape
322 |
323 | if self.use_adapter:
324 | ############################ AIM TIME #############################
325 | xt = rearrange(s_x, 'n (b t) d -> t (b n) d', t=self.num_frames)
326 | xt = self.T_Adapter(self.attention(self.clip_ln_1(xt)))
327 | xt = rearrange(xt, 't (b n) d -> n (b t) d', n=n)
328 | ##########################################################
329 | s_x = s_x + self.drop_path(xt) # skip connection original + time attention result
330 | # AIM Space MHSA
331 | s_x = s_x + self.S_Adapter(self.attention(self.clip_ln_1(s_x))) # original space multi head self attention
332 | ############################ FFN Forward ##################################
333 | s_xn = self.clip_ln_2(s_x)
334 | s_x = s_x + self.clip_mlp(s_xn) + self.drop_path(self.scale * self.S_MLP_Adapter(s_xn))
335 | ############################################################################
336 | else:
337 | s_x = s_x + self.attention(self.clip_ln_1(s_x))
338 | s_x = s_x + self.clip_mlp(self.clip_ln_2(s_x))
339 | return s_x
340 |
341 | class STCrossTransformer(nn.Module):
342 | """ Vision Transformer with support for patch or hybrid CNN input stage
343 | """
344 | def __init__(self,
345 | img_size=224,
346 | patch_size=16,
347 | in_chans=3,
348 | num_classes=1000,
349 | embed_dim=768,
350 | depth=12,
351 | num_heads=12,
352 | mlp_ratio=4.,
353 | qkv_bias=False,
354 | qk_scale=None,
355 | drop_rate=0.,
356 | attn_drop_rate=0.,
357 | drop_path_rate=0.,
358 | norm_layer=nn.LayerNorm,
359 | init_values=0.,
360 | use_learnable_pos_emb=False,
361 | init_scale=0.,
362 | all_frames=16,
363 | tubelet_size=2,
364 | use_mean_pooling=True,
365 | composition=False,
366 | pretrained_cfg = None,
367 | use_adapter=False,
368 | fusion_method=None,
369 | ):
370 | super().__init__()
371 | self.num_classes = num_classes
372 | self.embed_dim = embed_dim # num_features for consistency with other models
373 | self.tubelet_size = tubelet_size
374 | self.composition = composition
375 | self.use_adapter=use_adapter
376 | scale = embed_dim ** -0.5
377 | self.fusion_method=fusion_method
378 | self.clip_conv1 = nn.Conv2d(in_channels=3, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)
379 | self.clip_class_embedding = nn.Parameter(scale * torch.randn(embed_dim))
380 | self.clip_positional_embedding = nn.Parameter(scale * torch.randn((img_size // patch_size) ** 2 + 1, embed_dim))
381 | if self.use_adapter:
382 | self.clip_temporal_embedding = nn.Parameter(torch.zeros(1, all_frames, embed_dim))
383 | self.clip_ln_pre = LayerNorm(embed_dim)
384 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
385 | self.blocks = nn.ModuleList([
386 | Block(
387 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
388 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
389 | init_values=init_values, num_layer=i,use_adapter=self.use_adapter)
390 | for i in range(depth)])
391 |
392 | self.clip_ln_post = LayerNorm(embed_dim)
393 |
394 | if self.composition:
395 | self.head_verb = nn.Linear(embed_dim, 97)
396 | self.head_noun = nn.Linear(embed_dim, 300)
397 | else:
398 | self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
399 |
400 |         if use_learnable_pos_emb:  # NOTE: self.pos_embed is never created in this CLIP-side model, so keep this False
401 |             trunc_normal_(self.pos_embed, std=.02)
402 |
403 | self.apply(self._init_weights)
404 | if self.use_adapter:
405 | self._init_adpater_weight()
406 |
407 | if self.composition:
408 | trunc_normal_(self.head_noun.weight, std=.02)
409 | trunc_normal_(self.head_verb.weight, std=.02)
410 | self.head_verb.weight.data.mul_(init_scale)
411 | self.head_verb.bias.data.mul_(init_scale)
412 | self.head_noun.weight.data.mul_(init_scale)
413 | self.head_noun.bias.data.mul_(init_scale)
414 | else:
415 | trunc_normal_(self.head.weight, std=.02)
416 | # self.head.weight.data.mul_(init_scale)
417 | # self.head.bias.data.mul_(init_scale)
418 |
419 | def _init_weights(self, m):
420 | if isinstance(m, nn.Linear):
421 | trunc_normal_(m.weight, std=.02)
422 | if isinstance(m, nn.Linear) and m.bias is not None:
423 | nn.init.constant_(m.bias, 0)
424 | elif isinstance(m, nn.LayerNorm):
425 | nn.init.constant_(m.bias, 0)
426 | nn.init.constant_(m.weight, 1.0)
427 |
428 | def _init_adpater_weight(self):
429 | for n, m in self.blocks.named_modules():
430 | if 'Adapter' in n:
431 | for n2, m2 in m.named_modules():
432 | if 'D_fc2' in n2:
433 | if isinstance(m2, nn.Linear):
434 | nn.init.constant_(m2.weight, 0)
435 | nn.init.constant_(m2.bias, 0)
436 | elif 'up' in n:
437 | for n2, m2 in m.named_modules():
438 | if isinstance(m2, nn.Linear):
439 | nn.init.constant_(m2.weight, 0)
440 | nn.init.constant_(m2.bias, 0)
441 |
442 |
443 | def get_num_layers(self):
444 | return len(self.blocks)
445 |
446 | @torch.jit.ignore
447 | def no_weight_decay(self):
448 | return {'clip_temporal_embedding','pos_embed'}
449 |
450 | def get_classifier(self):
451 | return self.head
452 |
453 | def reset_classifier(self, num_classes, global_pool=''):
454 | self.num_classes = num_classes
455 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
456 |
457 | def reset_fcnorm(self):
458 | self.vmae_fc_norm = nn.LayerNorm(self.embed_dim)
459 |
460 | def forward_features(self, x):
461 | B = x.shape[0]
462 | ######################## AIM spatial path #########################
463 | s_x= x
464 | s_t = s_x.shape[2]
465 | s_x = rearrange(s_x, 'b c t h w -> (b t) c h w')
466 | s_x = self.clip_conv1(s_x) # shape = [*, embeddim, grid, grid]
467 | s_x = s_x.reshape(s_x.shape[0], s_x.shape[1], -1) # [*, embeddim, grid**2]
468 | s_x = s_x.permute(0, 2, 1) # shape[batch, patchnum, embeddim]
469 | s_x = torch.cat([self.clip_class_embedding.to(s_x.dtype) + torch.zeros(s_x.shape[0], 1, s_x.shape[-1], dtype=s_x.dtype, device=s_x.device), s_x], dim=1)
470 | s_x = s_x + self.clip_positional_embedding.to(s_x.dtype)
471 | n = s_x.shape[1]
472 | if self.use_adapter:
473 | s_x = rearrange(s_x, '(b t) n d -> (b n) t d', t=s_t)
474 | s_x = s_x + self.clip_temporal_embedding#(1,t,d)
475 | s_x = rearrange(s_x, '(b n) t d -> (b t) n d', n=n)
476 | s_x = self.clip_ln_pre(s_x)
477 |         #####################################################################
478 |
479 |
480 | s_x = s_x.permute(1,0,2)
481 | for blk in self.blocks:
482 | s_x = blk(s_x)
483 | s_x = s_x.permute(1,0,2)
484 |
485 | s_x = rearrange(s_x, '(b t) n d -> b t n d', b=B)
486 | s_x = self.clip_ln_post(s_x[:,:,0,:].mean(1)) # all cls tokens avg pooling
487 |
488 |
489 | return s_x
490 |
491 |
492 | def forward(self, x):
493 | if self.composition:
494 | s_x = self.forward_features(x)
495 | noun = self.head_noun(s_x)
496 | verb = self.head_verb(s_x)
497 | return noun, verb
498 | else:
499 | x = self.forward_features(x)
500 | x = self.head(x)
501 | return x
502 |
503 |
504 |
505 | @register_model
506 | def compo_clip_vit_base_patch16_224(pretrained=False, **kwargs):
507 |     model = STCrossTransformer(
508 |         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
509 |         norm_layer=partial(nn.LayerNorm, eps=1e-6), composition=True, use_adapter=False, **kwargs)
510 |     # model.default_cfg = _cfg()
511 |     return model
512 | 
513 | 
514 | @register_model
515 | def clip_vit_base_patch16_224(pretrained=False, **kwargs):
516 |     model = STCrossTransformer(
517 |         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
518 |         norm_layer=partial(nn.LayerNorm, eps=1e-6), composition=False, use_adapter=False, **kwargs)
519 |     # model.default_cfg = _cfg()
520 |     return model
521 | 
522 | 
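523 | if __name__ == "__main__":
524 |     # Hedged smoke test (added for illustration; not part of the original
525 |     # repository): run one random 16-frame 224x224 clip through the
526 |     # CLIP-style spatial path registered above and check the output shape.
527 |     model = clip_vit_base_patch16_224(num_classes=400)
528 |     dummy_clip = torch.rand(1, 3, 16, 224, 224)  # (B, C, T, H, W)
529 |     logits = model(dummy_clip)
530 |     print(logits.shape)  # expected: torch.Size([1, 400])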
--------------------------------------------------------------------------------