├── figs ├── CAST.jpg └── ds_report.JPG ├── util_tools ├── masking_generator.py ├── functional.py ├── volume_transforms.py ├── random_erasing.py ├── optim_factory.py ├── transforms.py ├── mixup.py └── rand_augment.py ├── scripts ├── ssv2 │ └── ssv2.sh ├── ek100 │ └── ek100.sh └── kinetics │ └── k400.sh ├── dataset ├── datasets.py ├── ssv2.py └── epic.py ├── README.md ├── engine_for_onemodel.py ├── engine_for_compomodel.py ├── LICENSE └── models ├── videomae_modelling_finetune.py └── clip_modelling_finetune.py /figs/CAST.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KHU-VLL/CAST/HEAD/figs/CAST.jpg -------------------------------------------------------------------------------- /figs/ds_report.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KHU-VLL/CAST/HEAD/figs/ds_report.JPG -------------------------------------------------------------------------------- /util_tools/masking_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class TubeMaskingGenerator: 4 | def __init__(self, input_size, mask_ratio): 5 | self.frames, self.height, self.width = input_size 6 | self.num_patches_per_frame = self.height * self.width 7 | self.total_patches = self.frames * self.num_patches_per_frame 8 | self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame) 9 | self.total_masks = self.frames * self.num_masks_per_frame 10 | 11 | def __repr__(self): 12 | repr_str = "Maks: total patches {}, mask patches {}".format( 13 | self.total_patches, self.total_masks 14 | ) 15 | return repr_str 16 | 17 | def __call__(self): 18 | mask_per_frame = np.hstack([ 19 | np.zeros(self.num_patches_per_frame - self.num_masks_per_frame), 20 | np.ones(self.num_masks_per_frame), 21 | ]) 22 | np.random.shuffle(mask_per_frame) 23 | mask = np.tile(mask_per_frame, (self.frames,1)).flatten() 24 | return mask -------------------------------------------------------------------------------- /scripts/ssv2/ssv2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_PATH=YOUR_PATH 4 | VMAE_MODEL_PATH=YOUR_PATH 5 | CLIP_MODEL_PATH=YOUR_PATH 6 | 7 | OUTPUT_DIR=YOUR_PATH 8 | MASTER_NODE=$1 9 | OMP_NUM_THREADS=1 python -m torch.distributed.launch \ 10 | --nproc_per_node=$6 \ 11 | --master_port $3 --nnodes=$5 \ 12 | --node_rank=$2 --master_addr=${MASTER_NODE} \ 13 | YOUR_PATH/run_bidirection.py \ 14 | --data_set SSV2 \ 15 | --nb_classes 174 \ 16 | --vmae_model bidir_vit_base_patch16_224 \ 17 | --data_path ${DATA_PATH} \ 18 | --anno_path ${YOUR_PATH} \ 19 | --clip_finetune ${CLIP_MODEL_PATH} \ 20 | --vmae_finetune ${VMAE_MODEL_PATH} \ 21 | --log_dir ${YOUR_PAHT} \ 22 | --output_dir ${YOUR_PAHT} \ 23 | --batch_size 6 \ 24 | --input_size 224 \ 25 | --short_side_size 224 \ 26 | --save_ckpt_freq 25 \ 27 | --num_sample 1 \ 28 | --num_frames 16 \ 29 | --opt adamw \ 30 | --lr 1e-3 \ 31 | --opt_betas 0.9 0.999 \ 32 | --weight_decay 0.05 \ 33 | --epochs 50 \ 34 | --dist_eval \ 35 | --test_num_segment 2 \ 36 | --test_num_crop 3 \ 37 | --num_workers 8 \ 38 | --seed 0 \ 39 | --mixup_switch_prob 0 \ 40 | --mixup_prob 0.9 \ 41 | --reprob 0. \ 42 | --init_scale 1. 
\ 43 | --enable_deepspeed \ 44 | --warmup_epochs 5 \ 45 | --update_freq 4 \ 46 | --drop_path 0.3 47 | 48 | echo "Job finish" -------------------------------------------------------------------------------- /scripts/ek100/ek100.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_PATH=YOUR_PATH 4 | VMAE_MODEL_PATH=YOUR_PATH 5 | CLIP_MODEL_PATH=YOUR_PATH 6 | 7 | OUTPUT_DIR=YOUR_PATH 8 | MASTER_NODE=$1 9 | OMP_NUM_THREADS=1 python -m torch.distributed.launch \ 10 | --nproc_per_node=$6 \ 11 | --master_port $3 --nnodes=$5 \ 12 | --node_rank=$2 --master_addr=${MASTER_NODE} \ 13 | {YOUR_PATH}/run_bidirection_compo.py \ 14 | --data_set EPIC \ 15 | --nb_classes 300 \ 16 | --vmae_model compo_bidir_vit_base_patch16_224 \ 17 | --data_path ${DATA_PATH} \ 18 | --anno_path YOUR_PATH \ 19 | --clip_finetune ${CLIP_MODEL_PATH} \ 20 | --vmae_finetune ${VMAE_MODEL_PATH} \ 21 | --log_dir ${OUTPUT_DIR} \ 22 | --output_dir ${OUTPUT_DIR} \ 23 | --batch_size 6 \ 24 | --input_size 224 \ 25 | --short_side_size 224 \ 26 | --save_ckpt_freq 25 \ 27 | --num_sample 1 \ 28 | --num_frames 16 \ 29 | --opt adamw \ 30 | --lr 1e-3 \ 31 | --opt_betas 0.9 0.999 \ 32 | --weight_decay 0.05 \ 33 | --epochs 50 \ 34 | --dist_eval \ 35 | --test_num_segment 2 \ 36 | --test_num_crop 3 \ 37 | --num_workers 8 \ 38 | --drop_path 0.2 \ 39 | --mixup_switch_prob 0 \ 40 | --mixup_prob 0.9 \ 41 | --reprob 0. \ 42 | --init_scale 1. \ 43 | --update_freq 4 \ 44 | --seed 0 \ 45 | --enable_deepspeed \ 46 | --warmup_epochs 5 \ 47 | --composition \ 48 | 49 | echo "Job finish" -------------------------------------------------------------------------------- /scripts/kinetics/k400.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_PATH=YOUR_PATH 4 | VMAE_MODEL_PATH=YOUR_PATH 5 | CLIP_MODEL_PATH=YOUR_PATH 6 | 7 | OUTPUT_DIR=YOUR_PATH 8 | MASTER_NODE=$1 9 | OMP_NUM_THREADS=1 python -m torch.distributed.launch \ 10 | --nproc_per_node=$6 \ 11 | --master_port $3 --nnodes=$5 \ 12 | --node_rank=$2 --master_addr=${MASTER_NODE} \ 13 | YOUR_PATH/run_bidirection.py \ 14 | --data_set Kinetics-400 \ 15 | --nb_classes 400 \ 16 | --vmae_model bidir_vit_base_patch16_224 \ 17 | --anno_path ${ANNOTATION_PATH} \ 18 | --data_path ${DATA_PATH} \ 19 | --clip_finetune ${CLIP_MODEL_PATH} \ 20 | --vmae_finetune ${VMAE_MODEL_PATH} \ 21 | --log_dir ${YOUR_PATH} \ 22 | --output_dir ${YOUR_PATH} \ 23 | --batch_size 6 \ 24 | --input_size 224 \ 25 | --short_side_size 224 \ 26 | --save_ckpt_freq 25 \ 27 | --num_sample 1 \ 28 | --num_frames 16 \ 29 | --opt adamw \ 30 | --lr 1e-3 \ 31 | --opt_betas 0.9 0.999 \ 32 | --weight_decay 0.05 \ 33 | --epochs 70 \ 34 | --dist_eval \ 35 | --test_num_segment 5 \ 36 | --test_num_crop 3 \ 37 | --num_workers 8 \ 38 | --drop_path 0.2 \ 39 | --layer_decay 0.75 \ 40 | --mixup_switch_prob 0 \ 41 | --mixup_prob 0.5 \ 42 | --reprob 0. \ 43 | --init_scale 1. 
\ 44 | --update_freq 6 \ 45 | --seed 0 \ 46 | --enable_deepspeed \ 47 | --warmup_epochs 5 \ 48 | 49 | echo "Job finish" -------------------------------------------------------------------------------- /util_tools/functional.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import cv2 3 | import numpy as np 4 | import PIL 5 | import torch 6 | 7 | 8 | def _is_tensor_clip(clip): 9 | return torch.is_tensor(clip) and clip.ndimension() == 4 10 | 11 | 12 | def crop_clip(clip, min_h, min_w, h, w): 13 | if isinstance(clip[0], np.ndarray): 14 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] 15 | 16 | elif isinstance(clip[0], PIL.Image.Image): 17 | cropped = [ 18 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip 19 | ] 20 | else: 21 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 22 | 'but got list of {0}'.format(type(clip[0]))) 23 | return cropped 24 | 25 | 26 | def resize_clip(clip, size, interpolation='bilinear'): 27 | if isinstance(clip[0], np.ndarray): 28 | if isinstance(size, numbers.Number): 29 | im_h, im_w, im_c = clip[0].shape 30 | # Min spatial dim already matches minimal size 31 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 32 | and im_h == size): 33 | return clip 34 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 35 | size = (new_w, new_h) 36 | else: 37 | size = size[0], size[1] 38 | if interpolation == 'bilinear': 39 | np_inter = cv2.INTER_LINEAR 40 | else: 41 | np_inter = cv2.INTER_NEAREST 42 | scaled = [ 43 | cv2.resize(img, size, interpolation=np_inter) for img in clip 44 | ] 45 | elif isinstance(clip[0], PIL.Image.Image): 46 | if isinstance(size, numbers.Number): 47 | im_w, im_h = clip[0].size 48 | # Min spatial dim already matches minimal size 49 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 50 | and im_h == size): 51 | return clip 52 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 53 | size = (new_w, new_h) 54 | else: 55 | size = size[1], size[0] 56 | if interpolation == 'bilinear': 57 | pil_inter = PIL.Image.BILINEAR 58 | else: 59 | pil_inter = PIL.Image.NEAREST 60 | scaled = [img.resize(size, pil_inter) for img in clip] 61 | else: 62 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 63 | 'but got list of {0}'.format(type(clip[0]))) 64 | return scaled 65 | 66 | 67 | def get_resize_sizes(im_h, im_w, size): 68 | if im_w < im_h: 69 | ow = size 70 | oh = int(size * im_h / im_w) 71 | else: 72 | oh = size 73 | ow = int(size * im_w / im_h) 74 | return oh, ow 75 | 76 | 77 | def normalize(clip, mean, std, inplace=False): 78 | if not _is_tensor_clip(clip): 79 | raise TypeError('tensor is not a torch clip.') 80 | 81 | if not inplace: 82 | clip = clip.clone() 83 | 84 | dtype = clip.dtype 85 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) 86 | std = torch.as_tensor(std, dtype=dtype, device=clip.device) 87 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 88 | 89 | return clip 90 | -------------------------------------------------------------------------------- /util_tools/volume_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | 5 | 6 | def convert_img(img): 7 | """Converts (H, W, C) numpy.ndarray to (C, W, H) format 8 | """ 9 | if len(img.shape) == 3: 10 | img = img.transpose(2, 0, 1) 11 | if len(img.shape) == 2: 12 | img = np.expand_dims(img, 0) 13 | return img 14 | 15 | 16 | class ClipToTensor(object): 17 | 
"""Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 18 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 19 | """ 20 | 21 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 22 | self.channel_nb = channel_nb 23 | self.div_255 = div_255 24 | self.numpy = numpy 25 | 26 | def __call__(self, clip): 27 | """ 28 | Args: clip (list of numpy.ndarray): clip (list of images) 29 | to be converted to tensor. 30 | """ 31 | # Retrieve shape 32 | if isinstance(clip[0], np.ndarray): 33 | h, w, ch = clip[0].shape 34 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( 35 | ch) 36 | elif isinstance(clip[0], Image.Image): 37 | w, h = clip[0].size 38 | else: 39 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 40 | but got list of {0}'.format(type(clip[0]))) 41 | 42 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 43 | 44 | # Convert 45 | for img_idx, img in enumerate(clip): 46 | if isinstance(img, np.ndarray): 47 | pass 48 | elif isinstance(img, Image.Image): 49 | img = np.array(img, copy=False) 50 | else: 51 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 52 | but got list of {0}'.format(type(clip[0]))) 53 | img = convert_img(img) 54 | np_clip[:, img_idx, :, :] = img 55 | if self.numpy: 56 | if self.div_255: 57 | np_clip = np_clip / 255.0 58 | return np_clip 59 | 60 | else: 61 | tensor_clip = torch.from_numpy(np_clip) 62 | 63 | if not isinstance(tensor_clip, torch.FloatTensor): 64 | tensor_clip = tensor_clip.float() 65 | if self.div_255: 66 | tensor_clip = torch.div(tensor_clip, 255) 67 | return tensor_clip 68 | 69 | 70 | # Note this norms data to -1/1 71 | class ClipToTensor_K(object): 72 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 73 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 74 | """ 75 | 76 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 77 | self.channel_nb = channel_nb 78 | self.div_255 = div_255 79 | self.numpy = numpy 80 | 81 | def __call__(self, clip): 82 | """ 83 | Args: clip (list of numpy.ndarray): clip (list of images) 84 | to be converted to tensor. 
85 | """ 86 | # Retrieve shape 87 | if isinstance(clip[0], np.ndarray): 88 | h, w, ch = clip[0].shape 89 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( 90 | ch) 91 | elif isinstance(clip[0], Image.Image): 92 | w, h = clip[0].size 93 | else: 94 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 95 | but got list of {0}'.format(type(clip[0]))) 96 | 97 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 98 | 99 | # Convert 100 | for img_idx, img in enumerate(clip): 101 | if isinstance(img, np.ndarray): 102 | pass 103 | elif isinstance(img, Image.Image): 104 | img = np.array(img, copy=False) 105 | else: 106 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 107 | but got list of {0}'.format(type(clip[0]))) 108 | img = convert_img(img) 109 | np_clip[:, img_idx, :, :] = img 110 | if self.numpy: 111 | if self.div_255: 112 | np_clip = (np_clip - 127.5) / 127.5 113 | return np_clip 114 | 115 | else: 116 | tensor_clip = torch.from_numpy(np_clip) 117 | 118 | if not isinstance(tensor_clip, torch.FloatTensor): 119 | tensor_clip = tensor_clip.float() 120 | if self.div_255: 121 | tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) 122 | return tensor_clip 123 | 124 | 125 | class ToTensor(object): 126 | """Converts numpy array to tensor 127 | """ 128 | 129 | def __call__(self, array): 130 | array = np.load(array) 131 | tensor = torch.from_numpy(array) 132 | return tensor 133 | -------------------------------------------------------------------------------- /dataset/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torchvision import transforms 3 | from util_tools.transforms import * 4 | from util_tools.masking_generator import TubeMaskingGenerator 5 | from .kinetics import VideoClsDataset, VideoMAE 6 | from .ssv2 import SSVideoClsDataset 7 | from .epic import EpicVideoClsDataset 8 | 9 | 10 | class DataAugmentationForVideoMAE(object): 11 | def __init__(self, args): 12 | self.input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN 13 | self.input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD 14 | normalize = GroupNormalize(self.input_mean, self.input_std) 15 | self.train_augmentation = GroupMultiScaleCrop(args.input_size, [1, .875, .75, .66]) 16 | self.transform = transforms.Compose([ 17 | self.train_augmentation, 18 | Stack(roll=False), 19 | ToTorchFormatTensor(div=True), 20 | normalize, 21 | ]) 22 | if args.mask_type == 'tube': 23 | self.masked_position_generator = TubeMaskingGenerator( 24 | args.window_size, args.mask_ratio 25 | ) 26 | 27 | def __call__(self, images): 28 | process_data, _ = self.transform(images) 29 | return process_data, self.masked_position_generator() 30 | 31 | def __repr__(self): 32 | repr = "(DataAugmentationForVideoMAE,\n" 33 | repr += " transform = %s,\n" % str(self.transform) 34 | repr += " Masked position generator = %s,\n" % str(self.masked_position_generator) 35 | repr += ")" 36 | return repr 37 | 38 | 39 | def build_pretraining_dataset(args): 40 | transform = DataAugmentationForVideoMAE(args) 41 | dataset = VideoMAE( 42 | root=None, 43 | setting=args.data_path, 44 | video_ext='mp4', 45 | is_color=True, 46 | modality='rgb', 47 | new_length=args.num_frames, 48 | new_step=args.sampling_rate, 49 | transform=transform, 50 | temporal_jitter=False, 51 | video_loader=True, 52 | use_decord=True, 53 | lazy_init=False) 54 | print("Data Aug = %s" % str(transform)) 55 | return dataset 56 | 57 | 58 | def build_dataset(is_train, test_mode, args): 59 | if 
args.data_set == 'Kinetics-400': 60 | mode = None 61 | anno_path = args.anno_path 62 | if is_train is True: 63 | mode = 'train' 64 | anno_path = os.path.join(args.anno_path, 'train.csv') 65 | elif test_mode is True: 66 | mode = 'test' 67 | anno_path = os.path.join(args.anno_path, 'val.csv') 68 | else: 69 | mode = 'validation' 70 | anno_path = os.path.join(args.anno_path, 'val.csv') 71 | 72 | dataset = VideoClsDataset( 73 | anno_path=anno_path, 74 | data_path=args.data_path, 75 | mode=mode, 76 | clip_len=args.num_frames, 77 | frame_sample_rate=args.sampling_rate, 78 | num_segment=1, 79 | test_num_segment=args.test_num_segment, 80 | test_num_crop=args.test_num_crop, 81 | num_crop=1 if not test_mode else 3, 82 | keep_aspect_ratio=True, 83 | crop_size=args.input_size, 84 | short_side_size=args.short_side_size, 85 | new_height=256, 86 | new_width=320, 87 | args=args) 88 | nb_classes = 400 89 | 90 | elif args.data_set == 'SSV2': 91 | mode = None 92 | anno_path = None 93 | if is_train is True: 94 | mode = 'train' 95 | anno_path = os.path.join(args.anno_path, 'train.csv') 96 | elif test_mode is True: 97 | mode = 'test' 98 | anno_path = os.path.join(args.anno_path, 'val.csv') 99 | else: 100 | mode = 'validation' 101 | anno_path = os.path.join(args.anno_path, 'val.csv') 102 | 103 | dataset = SSVideoClsDataset( 104 | anno_path=anno_path, 105 | data_path=args.data_path, 106 | mode=mode, 107 | clip_len=1, 108 | num_segment=args.num_frames, 109 | test_num_segment=args.test_num_segment, 110 | test_num_crop=args.test_num_crop, 111 | num_crop=1 if not test_mode else 3, 112 | keep_aspect_ratio=True, 113 | crop_size=args.input_size, 114 | short_side_size=args.short_side_size, 115 | new_height=256, 116 | new_width=320, 117 | args=args) 118 | nb_classes = 174 119 | 120 | elif args.data_set == 'EPIC': 121 | mode = None 122 | anno_path = None 123 | if is_train is True: 124 | mode = 'train' 125 | anno_path = os.path.join(args.anno_path, 'train.csv') 126 | elif test_mode is True: 127 | mode = 'test' 128 | anno_path = os.path.join(args.anno_path, 'val.csv') 129 | else: 130 | mode = 'validation' 131 | anno_path = os.path.join(args.anno_path, 'val.csv') 132 | 133 | dataset = EpicVideoClsDataset( 134 | anno_path=anno_path, 135 | data_path=args.data_path, 136 | mode=mode, 137 | clip_len=1, 138 | num_segment=args.num_frames, 139 | test_num_segment=args.test_num_segment, 140 | test_num_crop=args.test_num_crop, 141 | num_crop=1 if not test_mode else 3, 142 | keep_aspect_ratio=True, 143 | crop_size=args.input_size, 144 | short_side_size=args.short_side_size, 145 | new_height=256, 146 | new_width=320, 147 | args=args) 148 | nb_classes = 300 149 | 150 | else: 151 | raise NotImplementedError() 152 | 153 | assert nb_classes == args.nb_classes 154 | print("Number of the class = %d" % args.nb_classes) 155 | 156 | return dataset, nb_classes -------------------------------------------------------------------------------- /util_tools/random_erasing.py: -------------------------------------------------------------------------------- 1 | """ 2 | This implementation is based on 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py 4 | pulished under an Apache License 2.0. 
5 | """ 6 | import math 7 | import random 8 | import torch 9 | 10 | 11 | def _get_pixels( 12 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda" 13 | ): 14 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 15 | # paths, flip the order so normal is run on CPU if this becomes a problem 16 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 17 | if per_pixel: 18 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 19 | elif rand_color: 20 | return torch.empty( 21 | (patch_size[0], 1, 1), dtype=dtype, device=device 22 | ).normal_() 23 | else: 24 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 25 | 26 | 27 | class RandomErasing: 28 | """Randomly selects a rectangle region in an image and erases its pixels. 29 | 'Random Erasing Data Augmentation' by Zhong et al. 30 | See https://arxiv.org/pdf/1708.04896.pdf 31 | This variant of RandomErasing is intended to be applied to either a batch 32 | or single image tensor after it has been normalized by dataset mean and std. 33 | Args: 34 | probability: Probability that the Random Erasing operation will be performed. 35 | min_area: Minimum percentage of erased area wrt input image area. 36 | max_area: Maximum percentage of erased area wrt input image area. 37 | min_aspect: Minimum aspect ratio of erased area. 38 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 39 | 'const' - erase block is constant color of 0 for all channels 40 | 'rand' - erase block is same per-channel random (normal) color 41 | 'pixel' - erase block is per-pixel random (normal) color 42 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 43 | per-image count is randomly chosen between 1 and this value. 
44 | """ 45 | 46 | def __init__( 47 | self, 48 | probability=0.5, 49 | min_area=0.02, 50 | max_area=1 / 3, 51 | min_aspect=0.3, 52 | max_aspect=None, 53 | mode="const", 54 | min_count=1, 55 | max_count=None, 56 | num_splits=0, 57 | device="cuda", 58 | cube=True, 59 | ): 60 | self.probability = probability 61 | self.min_area = min_area 62 | self.max_area = max_area 63 | max_aspect = max_aspect or 1 / min_aspect 64 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 65 | self.min_count = min_count 66 | self.max_count = max_count or min_count 67 | self.num_splits = num_splits 68 | mode = mode.lower() 69 | self.rand_color = False 70 | self.per_pixel = False 71 | self.cube = cube 72 | if mode == "rand": 73 | self.rand_color = True # per block random normal 74 | elif mode == "pixel": 75 | self.per_pixel = True # per pixel random normal 76 | else: 77 | assert not mode or mode == "const" 78 | self.device = device 79 | 80 | def _erase(self, img, chan, img_h, img_w, dtype): 81 | if random.random() > self.probability: 82 | return 83 | area = img_h * img_w 84 | count = ( 85 | self.min_count 86 | if self.min_count == self.max_count 87 | else random.randint(self.min_count, self.max_count) 88 | ) 89 | for _ in range(count): 90 | for _ in range(10): 91 | target_area = ( 92 | random.uniform(self.min_area, self.max_area) * area / count 93 | ) 94 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 95 | h = int(round(math.sqrt(target_area * aspect_ratio))) 96 | w = int(round(math.sqrt(target_area / aspect_ratio))) 97 | if w < img_w and h < img_h: 98 | top = random.randint(0, img_h - h) 99 | left = random.randint(0, img_w - w) 100 | img[:, top : top + h, left : left + w] = _get_pixels( 101 | self.per_pixel, 102 | self.rand_color, 103 | (chan, h, w), 104 | dtype=dtype, 105 | device=self.device, 106 | ) 107 | break 108 | 109 | def _erase_cube( 110 | self, 111 | img, 112 | batch_start, 113 | batch_size, 114 | chan, 115 | img_h, 116 | img_w, 117 | dtype, 118 | ): 119 | if random.random() > self.probability: 120 | return 121 | area = img_h * img_w 122 | count = ( 123 | self.min_count 124 | if self.min_count == self.max_count 125 | else random.randint(self.min_count, self.max_count) 126 | ) 127 | for _ in range(count): 128 | for _ in range(100): 129 | target_area = ( 130 | random.uniform(self.min_area, self.max_area) * area / count 131 | ) 132 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 133 | h = int(round(math.sqrt(target_area * aspect_ratio))) 134 | w = int(round(math.sqrt(target_area / aspect_ratio))) 135 | if w < img_w and h < img_h: 136 | top = random.randint(0, img_h - h) 137 | left = random.randint(0, img_w - w) 138 | for i in range(batch_start, batch_size): 139 | img_instance = img[i] 140 | img_instance[ 141 | :, top : top + h, left : left + w 142 | ] = _get_pixels( 143 | self.per_pixel, 144 | self.rand_color, 145 | (chan, h, w), 146 | dtype=dtype, 147 | device=self.device, 148 | ) 149 | break 150 | 151 | def __call__(self, input): 152 | if len(input.size()) == 3: 153 | self._erase(input, *input.size(), input.dtype) 154 | else: 155 | batch_size, chan, img_h, img_w = input.size() 156 | # skip first slice of batch if num_splits is set (for clean portion of samples) 157 | batch_start = ( 158 | batch_size // self.num_splits if self.num_splits > 1 else 0 159 | ) 160 | if self.cube: 161 | self._erase_cube( 162 | input, 163 | batch_start, 164 | batch_size, 165 | chan, 166 | img_h, 167 | img_w, 168 | input.dtype, 169 | ) 170 | else: 171 | for i in 
range(batch_start, batch_size): 172 | self._erase(input[i], chan, img_h, img_w, input.dtype) 173 | return input 174 | -------------------------------------------------------------------------------- /util_tools/optim_factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim as optim 3 | 4 | from timm.optim.adafactor import Adafactor 5 | from timm.optim.adahessian import Adahessian 6 | from timm.optim.adamp import AdamP 7 | from timm.optim.lookahead import Lookahead 8 | from timm.optim.nadam import Nadam 9 | #from timm.optim.novograd import NovoGrad 10 | from timm.optim.nvnovograd import NvNovoGrad 11 | from timm.optim.radam import RAdam 12 | from timm.optim.rmsprop_tf import RMSpropTF 13 | from timm.optim.sgdp import SGDP 14 | 15 | import json 16 | 17 | try: 18 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 19 | has_apex = True 20 | except ImportError: 21 | has_apex = False 22 | 23 | 24 | def get_num_layer_for_vit(var_name, num_max_layer): 25 | if var_name in ("cls_token", "mask_token", "pos_embed"): 26 | return 0 27 | elif var_name in ("visual.class_embedding", "visual.positional_embedding", "visual.temporal_posembed", "visual.ln_pre"): 28 | return 0 29 | elif var_name in ("clip_class_embedding", "clip_ln_pre.weight","clip_ln_pre.bias", "clip_conv1.weight", "clip_conv1.bias", "clip_positional_embedding"): 30 | return 0 31 | elif var_name.startswith("patch_embed"): 32 | return 0 33 | elif var_name.startswith("visual.conv1"): 34 | return 0 35 | elif var_name.startswith("rel_pos_bias"): 36 | return num_max_layer - 1 37 | elif var_name.startswith("blocks"): 38 | layer_id = int(var_name.split('.')[1]) 39 | return layer_id + 1 40 | elif var_name.startswith("visual.transformer"): 41 | layer_id = int(var_name.split('.')[3]) 42 | return layer_id + 1 43 | else: 44 | return num_max_layer - 1 45 | 46 | 47 | class LayerDecayValueAssigner(object): 48 | def __init__(self, values): 49 | self.values = values 50 | 51 | def get_scale(self, layer_id): 52 | return self.values[layer_id] 53 | 54 | def get_layer_id(self, var_name): 55 | return get_num_layer_for_vit(var_name, len(self.values)) 56 | 57 | 58 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 59 | parameter_group_names = {} 60 | parameter_group_vars = {} 61 | 62 | for name, param in model.named_parameters(): 63 | if not param.requires_grad: 64 | continue # frozen weights 65 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 66 | group_name = "no_decay" 67 | this_weight_decay = 0. 68 | elif any(s in name for s in skip_list): 69 | group_name = "no_decay" 70 | this_weight_decay = 0. 71 | else: 72 | group_name = "decay" 73 | this_weight_decay = weight_decay 74 | if get_num_layer is not None: 75 | layer_id = get_num_layer(name) 76 | group_name = "layer_%d_%s" % (layer_id, group_name) 77 | else: 78 | layer_id = None 79 | 80 | if group_name not in parameter_group_names: 81 | if get_layer_scale is not None: 82 | scale = get_layer_scale(layer_id) 83 | else: 84 | scale = 1. 
85 | 86 | parameter_group_names[group_name] = { 87 | "weight_decay": this_weight_decay, 88 | "params": [], 89 | "lr_scale": scale 90 | } 91 | parameter_group_vars[group_name] = { 92 | "weight_decay": this_weight_decay, 93 | "params": [], 94 | "lr_scale": scale 95 | } 96 | 97 | parameter_group_vars[group_name]["params"].append(param) 98 | parameter_group_names[group_name]["params"].append(name) 99 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 100 | return list(parameter_group_vars.values()) 101 | 102 | 103 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 104 | opt_lower = args.opt.lower() 105 | weight_decay = args.weight_decay 106 | if weight_decay and filter_bias_and_bn: 107 | skip = {} 108 | if skip_list is not None: 109 | skip = skip_list 110 | elif hasattr(model, 'no_weight_decay'): 111 | skip = model.no_weight_decay() 112 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 113 | weight_decay = 0. 114 | else: 115 | parameters = model.parameters() 116 | 117 | if 'fused' in opt_lower: 118 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 119 | 120 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 121 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 122 | opt_args['eps'] = args.opt_eps 123 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 124 | opt_args['betas'] = args.opt_betas 125 | 126 | print("optimizer settings:", opt_args) 127 | 128 | opt_split = opt_lower.split('_') 129 | opt_lower = opt_split[-1] 130 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 131 | opt_args.pop('eps', None) 132 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 133 | elif opt_lower == 'momentum': 134 | opt_args.pop('eps', None) 135 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 136 | elif opt_lower == 'adam': 137 | optimizer = optim.Adam(parameters, **opt_args) 138 | elif opt_lower == 'adamw': 139 | optimizer = optim.AdamW(parameters, **opt_args) 140 | elif opt_lower == 'nadam': 141 | optimizer = Nadam(parameters, **opt_args) 142 | elif opt_lower == 'radam': 143 | optimizer = RAdam(parameters, **opt_args) 144 | elif opt_lower == 'adamp': 145 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 146 | elif opt_lower == 'sgdp': 147 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 148 | elif opt_lower == 'adadelta': 149 | optimizer = optim.Adadelta(parameters, **opt_args) 150 | elif opt_lower == 'adafactor': 151 | if not args.lr: 152 | opt_args['lr'] = None 153 | optimizer = Adafactor(parameters, **opt_args) 154 | elif opt_lower == 'adahessian': 155 | optimizer = Adahessian(parameters, **opt_args) 156 | elif opt_lower == 'rmsprop': 157 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 158 | elif opt_lower == 'rmsproptf': 159 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 160 | elif opt_lower == 'novograd': 161 | optimizer = NovoGrad(parameters, **opt_args) 162 | elif opt_lower == 'nvnovograd': 163 | optimizer = NvNovoGrad(parameters, **opt_args) 164 | elif opt_lower == 'fusedsgd': 165 | opt_args.pop('eps', None) 166 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 167 | elif opt_lower == 'fusedmomentum': 168 | opt_args.pop('eps', None) 169 | optimizer = 
FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 170 | elif opt_lower == 'fusedadam': 171 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 172 | elif opt_lower == 'fusedadamw': 173 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 174 | elif opt_lower == 'fusedlamb': 175 | optimizer = FusedLAMB(parameters, **opt_args) 176 | elif opt_lower == 'fusednovograd': 177 | opt_args.setdefault('betas', (0.95, 0.98)) 178 | optimizer = FusedNovoGrad(parameters, **opt_args) 179 | else: 180 | assert False and "Invalid optimizer" 181 | raise ValueError 182 | 183 | if len(opt_split) > 1: 184 | if opt_split[0] == 'lookahead': 185 | optimizer = Lookahead(optimizer) 186 | 187 | return optimizer 188 | -------------------------------------------------------------------------------- /util_tools/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | import warnings 4 | import random 5 | import numpy as np 6 | import torchvision 7 | from PIL import Image, ImageOps 8 | import numbers 9 | 10 | 11 | class GroupRandomCrop(object): 12 | def __init__(self, size): 13 | if isinstance(size, numbers.Number): 14 | self.size = (int(size), int(size)) 15 | else: 16 | self.size = size 17 | 18 | def __call__(self, img_tuple): 19 | img_group, label = img_tuple 20 | 21 | w, h = img_group[0].size 22 | th, tw = self.size 23 | 24 | out_images = list() 25 | 26 | x1 = random.randint(0, w - tw) 27 | y1 = random.randint(0, h - th) 28 | 29 | for img in img_group: 30 | assert(img.size[0] == w and img.size[1] == h) 31 | if w == tw and h == th: 32 | out_images.append(img) 33 | else: 34 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 35 | 36 | return (out_images, label) 37 | 38 | 39 | class GroupCenterCrop(object): 40 | def __init__(self, size): 41 | self.worker = torchvision.transforms.CenterCrop(size) 42 | 43 | def __call__(self, img_tuple): 44 | img_group, label = img_tuple 45 | return ([self.worker(img) for img in img_group], label) 46 | 47 | 48 | class GroupNormalize(object): 49 | def __init__(self, mean, std): 50 | self.mean = mean 51 | self.std = std 52 | 53 | def __call__(self, tensor_tuple): 54 | tensor, label = tensor_tuple 55 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 56 | rep_std = self.std * (tensor.size()[0]//len(self.std)) 57 | 58 | # TODO: make efficient 59 | for t, m, s in zip(tensor, rep_mean, rep_std): 60 | t.sub_(m).div_(s) 61 | 62 | return (tensor,label) 63 | 64 | 65 | class GroupGrayScale(object): 66 | def __init__(self, size): 67 | self.worker = torchvision.transforms.Grayscale(size) 68 | 69 | def __call__(self, img_tuple): 70 | img_group, label = img_tuple 71 | return ([self.worker(img) for img in img_group], label) 72 | 73 | 74 | class GroupScale(object): 75 | """ Rescales the input PIL.Image to the given 'size'. 76 | 'size' will be the size of the smaller edge. 
77 | For example, if height > width, then image will be 78 | rescaled to (size * height / width, size) 79 | size: size of the smaller edge 80 | interpolation: Default: PIL.Image.BILINEAR 81 | """ 82 | 83 | def __init__(self, size, interpolation=Image.BILINEAR): 84 | self.worker = torchvision.transforms.Resize(size, interpolation) 85 | 86 | def __call__(self, img_tuple): 87 | img_group, label = img_tuple 88 | return ([self.worker(img) for img in img_group], label) 89 | 90 | 91 | class GroupMultiScaleCrop(object): 92 | 93 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 94 | self.scales = scales if scales is not None else [1, .875, .75, .66] 95 | self.max_distort = max_distort 96 | self.fix_crop = fix_crop 97 | self.more_fix_crop = more_fix_crop 98 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 99 | self.interpolation = Image.BILINEAR 100 | 101 | def __call__(self, img_tuple): 102 | img_group, label = img_tuple 103 | 104 | im_size = img_group[0].size 105 | 106 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 107 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 108 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group] 109 | return (ret_img_group, label) 110 | 111 | def _sample_crop_size(self, im_size): 112 | image_w, image_h = im_size[0], im_size[1] 113 | 114 | # find a crop size 115 | base_size = min(image_w, image_h) 116 | crop_sizes = [int(base_size * x) for x in self.scales] 117 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 118 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 119 | 120 | pairs = [] 121 | for i, h in enumerate(crop_h): 122 | for j, w in enumerate(crop_w): 123 | if abs(i - j) <= self.max_distort: 124 | pairs.append((w, h)) 125 | 126 | crop_pair = random.choice(pairs) 127 | if not self.fix_crop: 128 | w_offset = random.randint(0, image_w - crop_pair[0]) 129 | h_offset = random.randint(0, image_h - crop_pair[1]) 130 | else: 131 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 132 | 133 | return crop_pair[0], crop_pair[1], w_offset, h_offset 134 | 135 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 136 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 137 | return random.choice(offsets) 138 | 139 | @staticmethod 140 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 141 | w_step = (image_w - crop_w) // 4 142 | h_step = (image_h - crop_h) // 4 143 | 144 | ret = list() 145 | ret.append((0, 0)) # upper left 146 | ret.append((4 * w_step, 0)) # upper right 147 | ret.append((0, 4 * h_step)) # lower left 148 | ret.append((4 * w_step, 4 * h_step)) # lower right 149 | ret.append((2 * w_step, 2 * h_step)) # center 150 | 151 | if more_fix_crop: 152 | ret.append((0, 2 * h_step)) # center left 153 | ret.append((4 * w_step, 2 * h_step)) # center right 154 | ret.append((2 * w_step, 4 * h_step)) # lower center 155 | ret.append((2 * w_step, 0 * h_step)) # upper center 156 | 157 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 158 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 159 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 160 | ret.append((3 * w_step, 3 * h_step)) # lower righ quarter 161 | return ret 
162 | 163 | 164 | class Stack(object): 165 | 166 | def __init__(self, roll=False): 167 | self.roll = roll 168 | 169 | def __call__(self, img_tuple): 170 | img_group, label = img_tuple 171 | 172 | if img_group[0].mode == 'L': 173 | return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label) 174 | elif img_group[0].mode == 'RGB': 175 | if self.roll: 176 | return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label) 177 | else: 178 | return (np.concatenate(img_group, axis=2), label) 179 | 180 | 181 | class ToTorchFormatTensor(object): 182 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 183 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 184 | def __init__(self, div=True): 185 | self.div = div 186 | 187 | def __call__(self, pic_tuple): 188 | pic, label = pic_tuple 189 | 190 | if isinstance(pic, np.ndarray): 191 | # handle numpy array 192 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 193 | else: 194 | # handle PIL Image 195 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 196 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 197 | # put it from HWC to CHW format 198 | # yikes, this transpose takes 80% of the loading time/CPU 199 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 200 | return (img.float().div(255.) if self.div else img.float(), label) 201 | 202 | 203 | class IdentityTransform(object): 204 | 205 | def __call__(self, data): 206 | return data 207 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # CAST: Cross-Attention in Space and Time for Video Action Recognition [[NeurIPS 2023](https://neurips.cc/virtual/2023/poster/70748)][[Project Page](https://jong980812.github.io/CAST.github.io)][[Arxiv](https://arxiv.org/abs/2311.18825)] 3 | 4 | ![CAST Framework](figs/CAST.jpg) 5 |
6 | 7 | 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/cast-cross-attention-in-space-and-time-for-1/action-recognition-on-epic-kitchens-100)](https://paperswithcode.com/sota/action-recognition-on-epic-kitchens-100?p=cast-cross-attention-in-space-and-time-for-1) 9 | 10 | ![GitHub last commit](https://img.shields.io/github/last-commit/khuvll/CAST)
11 | ![Website Status](https://img.shields.io/website?url=https://jong980812.github.io/CAST.github.io/)
12 | ![GitHub issues](https://img.shields.io/github/issues-raw/khuvll/CAST) 13 | ![GitHub closed issue](https://img.shields.io/github/issues-closed/khuvll/CAST)
14 | 15 | 16 | # :wrench: Installation 17 | 18 | We conduct all the experiments with 16 NVIDIA GeForce RTX 3090 GPUs. 19 | First, install PyTorch 1.10.0+ and torchvision 0.11.0. 20 | 21 | ``` 22 | conda create -n vmae_1.10 python=3.8 ipykernel -y 23 | conda activate vmae_1.10 24 | conda install pytorch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 -c pytorch 25 | ``` 26 | Then, install timm, triton, DeepSpeed, and others. 27 | ``` 28 | pip install triton==1.0.0 29 | git clone https://github.com/microsoft/DeepSpeed 30 | cd DeepSpeed 31 | git checkout 3a3dfe66bb 32 | DS_BUILD_OPS=1 pip install . --global-option="build_ext" 33 | pip install TensorboardX decord einops scipy pandas requests 34 | ds_report 35 | ``` 36 | 37 | If you have successfully installed Deepspeed, after running the 'ds_report' command, you can see the following results. 38 | For other Deepspeed-related issues, please refer to the [DeepSpeed GitHub page](https://github.com/microsoft/DeepSpeed). 39 | 40 | ![DS_REPORT](figs/ds_report.JPG) 41 | 42 | # :file_folder: Data Preparation 43 | 44 | * We report experimental results on three standard datasets.([EPIC-KITCHENS-100](https://epic-kitchens.github.io/2023), [Something-Something-V2](https://developer.qualcomm.com/software/ai-datasets/something-something), [Kinetics400](https://deepmind.com/research/open-source/kinetics)) 45 | * We provide sample annotation files -> [annotations](./annotations/). 46 | 47 | ### EPIC-KITCHENS-100 48 | - The pre-processing of **EPIC-KITCHENS-100** can be summarized into 3 steps: 49 | 50 | 1. Download the dataset from [official website](https://github.com/epic-kitchens/epic-kitchens-download-scripts). 51 | 52 | 2. Preprocess the dataset by resizing the short edge of video to **256px**. You can refer to [MMAction2 Data Benchmark](https://github.com/open-mmlab/mmaction2). 53 | 54 | 3. Generate annotations needed for dataloader (",," in annotations). The annotation usually includes `train.csv`, `val.csv`. The format of `*.csv` file is like:
55 | 56 | 57 | ``` 58 | video_1,verb_1,noun_1 59 | video_2,verb_2,noun_2 60 | video_3,verb_3,noun_3 61 | ... 62 | video_N,verb_N,noun_N 63 | ``` 64 | 4. All video files are located inside the DATA_PATH. 65 | 66 | ### Something-Something-V2 67 | - The pre-processing of **Something-Something-V2** can be summarized into 3 steps: 68 | 69 | 1. Download the dataset from [official website](https://developer.qualcomm.com/software/ai-datasets/something-something). 70 | 71 | 2. Preprocess the dataset by changing the video extension from `webm` to `.mp4` with the **original** height of **240px**. You can refer to [MMAction2 Data Benchmark](https://github.com/open-mmlab/mmaction2). 72 | 73 | 3. Generate annotations needed for dataloader (" " in annotations). The annotation usually includes `train.csv`, `val.csv` and `test.csv`. The format of `*.csv` file is like: 74 | 75 | ``` 76 | video_1.mp4 label_1 77 | video_2.mp4 label_2 78 | video_3.mp4 label_3 79 | ... 80 | video_N.mp4 label_N 81 | ``` 82 | 4. All video files are located inside the DATA_PATH. 83 | ### Kinetics-400 84 | - The pre-processing of **Kinetics400** can be summarized into 3 steps: 85 | 86 | 1. Download the dataset from [official website](https://deepmind.com/research/open-source/kinetics) or [OpenDataLab](https://opendatalab.com/OpenMMLab/Kinetics-400). 87 | 88 | 2. Preprocess the dataset by resizing the short edge of video to **320px**. You can refer to [MMAction2 Data Benchmark](https://github.com/open-mmlab/mmaction2). 89 | 90 | 3. Generate annotations needed for dataloader (" " in annotations). The annotation usually includes `train.csv`, `val.csv` and `test.csv`. The format of `*.csv` file is like: 91 | 92 | ``` 93 | video_1.mp4 label_1 94 | video_2.mp4 label_2 95 | video_3.mp4 label_3 96 | ... 97 | video_N.mp4 label_N 98 | ``` 99 |
100 | 4. All video files should be splited into **DATA_PATH/train** and **DATA_PATH/val**. 101 | # Expert model preparation 102 | We use the pre-trained weights of spatial and temporal experts. The pretrained weight of the spatial expert (CLIP) uses the [official weight](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt). The pre-trained weight of the temporal expert (VideoMAE) uses the pre-trained weights from the three datasets EK100, K400, and SSV2. Of these, [K400](https://drive.google.com/file/d/1MzwteHH-1yuMnFb8vRBQDvngV1Zl-d3z/view?usp=sharing) and [SSV2](https://drive.google.com/file/d/1dt_59tBIyzdZd5Ecr22lTtzs_64MOZkT/view?usp=sharing) use the [official weights](https://github.com/potatowarriors/VideoMAE/blob/main/MODEL_ZOO.md), and [EK100](https://drive.google.com/file/d/16zDs_9ycAoz8AoEecrQXsC5vLZsaQoTw/view?usp=sharing) uses the weights we pre-trained ourselves. Put each downloaded expert weight into the VMAE_PATH and CLIP_PATH of the fine-tune script. 103 | 104 | 105 | # Fine-tuning CAST 106 | 107 | We provide the **off-the-shelf** scripts in the [scripts folder](scripts). 108 | 109 | - For example, to fine-tune CAST on **Kinetics400** with 16 GPUs (2 nodes x 8 GPUs) script. 110 | 111 | ```bash 112 | DATA_PATH=YOUR_PATH 113 | VMAE_MODEL_PATH=YOUR_PATH 114 | CLIP_MODEL_PATH=YOUR_PATH 115 | 116 | 117 | OMP_NUM_THREADS=1 python -m torch.distributed.launch \ 118 | --nproc_per_node=2 \ 119 | --master_port ${YOUR_NUMBER} --nnodes=8 \ 120 | --node_rank=${YOUR_NUMBER} --master_addr=${YOUR_NUMBER} \ 121 | YOUR_PATH/run_bidirection_compo.py \ 122 | --data_set Kinetics-400 \ 123 | --nb_classes 400 \ 124 | --vmae_model compo_bidir_vit_base_patch16_224 \ 125 | --anno_path ${ANNOTATION_PATH} 126 | --data_path ${DATA_PATH} \ 127 | --clip_finetune ${CLIP_MODEL_PATH} \ 128 | --vmae_finetune ${VMAE_MODEL_PATH} \ 129 | --log_dir ${YOUR_PATH} \ 130 | --output_dir ${YOUR_PATH} \ 131 | --batch_size 6 \ 132 | --input_size 224 \ 133 | --short_side_size 224 \ 134 | --save_ckpt_freq 25 \ 135 | --num_sample 1 \ 136 | --num_frames 16 \ 137 | --opt adamw \ 138 | --lr 1e-3 \ 139 | --opt_betas 0.9 0.999 \ 140 | --weight_decay 0.05 \ 141 | --epochs 70 \ 142 | --dist_eval \ 143 | --test_num_segment 5 \ 144 | --test_num_crop 3 \ 145 | --num_workers 8 \ 146 | --drop_path 0.2 \ 147 | --layer_decay 0.75 \ 148 | --mixup_switch_prob 0 \ 149 | --mixup_prob 0.5 \ 150 | --reprob 0. \ 151 | --init_scale 1. \ 152 | --update_freq 6 \ 153 | --seed 0 \ 154 | --enable_deepspeed \ 155 | --warmup_epochs 5 \ 156 | ``` 157 | # Evaluation 158 | Evaluation commands for the EK100. 159 | ``` 160 | python ./run_bidirection_compo.py --fine_tune {YOUR_FINETUNED_WEIGHT} --composition --eval 161 | ``` 162 | Evaluation commands for the SSV2, K400. 
163 | ``` 164 | python ./run_bidirection.py --fine_tune {YOUR_FINETUNED_WEIGHT} --eval 165 | ``` 166 | # Model Zoo 167 | 168 | ### EPIC-KITCHENS-100 169 | 170 | | Method | Spatial Expert | Temporal expert | Epoch | \#Frames x Clips x Crops | Fine-tune | Top-1 | 171 | | :------: | :------: | :------: | :---: | :-----: | :----------------------------------------------------------: | :---: | 172 | | CAST | [CLIP-B/16](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt) | [VideoMAE-B/16 (pre-trained on EK100)](https://drive.google.com/file/d/1DaxOctpEkmKTi873J1jzzz_Sl-0wRai7/view?usp=sharing) | 50 | 16x2x3 | [log](https://drive.google.com/file/d/1yry2Nd5BEaX3kZjNYDghGHubpei-by_9/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1pW5tMWG2N5zqQOOPcrawwQpIAhPFMoVx/view?usp=sharing)
| 49.3 | 173 | ### Something-Something V2 174 | 175 | | Method | Spatial Expert | Temporal expert | Epoch | \#Frames x Clips x Crops | Fine-tune | Top-1 | 176 | | :------: | :------: | :------: | :---: | :-----: | :----------------------------------------------------------: | :---: | 177 | | CAST | [CLIP-B/16](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt) | [VideoMAE-B/16 (pre-trained on SSV2)](https://drive.google.com/file/d/1dt_59tBIyzdZd5Ecr22lTtzs_64MOZkT/view?usp=sharing) | 50 | 16x2x3 | [log](https://drive.google.com/file/d/1wOjcXSen9B9R2CIQ8ge7HrRufr_MmMcN/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/1RrAVF4tlpZCPYNY49M0pKQfopP6p6Vir/view?usp=sharing)
| 71.6 | 178 | 179 | ### Kinetics-400 180 | 181 | | Method | Spatial Expert | Temporal expert | Epoch | \#Frames x Clips x Crops | Fine-tune | Top-1 | 182 | | :------: | :------: | :------: | :---: | :-----: | :----------------------------------------------------------: | :---: | 183 | | CAST | [CLIP-B/16](https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt) | [VideoMAE-B/16 (pre-trained on K400)](https://drive.google.com/file/d/1MzwteHH-1yuMnFb8vRBQDvngV1Zl-d3z/view?usp=sharing) | 70 | 16x5x3 | [log](https://drive.google.com/file/d/1Npw-GblhSGWVx0nU06ztjDLMFcazeCx6/view?usp=sharing)/[checkpoint](https://drive.google.com/file/d/16ndsBVNjRuJMRM40P0-a3Q1JVA8tNLK7/view?usp=sharing)
| 85.3 | 184 | 185 | 186 | ## Acknowledgements 187 | 188 | This project is built upon [VideoMAE](https://github.com/MCG-NJU/VideoMAE), [MAE](https://github.com/pengzhiliang/MAE-pytorch), [CLIP](https://github.com/openai/CLIP) and [BEiT](https://github.com/microsoft/unilm/tree/master/beit). Thanks to the contributors of these great codebases. 189 | 190 | ## License 191 | 192 | This project is under the CC-BY-NC 4.0 license. See [LICENSE](https://github.com/MCG-NJU/VideoMAE/blob/main/LICENSE) for details. 193 | 194 | ## Citation 195 | ``` 196 | @article{cast, 197 | title={CAST: Cross-Attention in Space and Time for Video Action Recognition}, 198 | author={Lee, Dongho and Lee, Jongseo and Choi, Jinwoo}, 199 | booktitle={NeurIPS}}, 200 | year={2023} 201 | ``` 202 | -------------------------------------------------------------------------------- /engine_for_onemodel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import math 4 | import sys 5 | from typing import Iterable, Optional 6 | import torch 7 | from util_tools.mixup import Mixup 8 | from timm.utils import accuracy, ModelEma 9 | import util_tools.utils as utils 10 | from scipy.special import softmax 11 | from einops import rearrange 12 | 13 | def cross_train_class_batch(model, samples, target, criterion): 14 | outputs = model(samples) 15 | loss = criterion(outputs, target) 16 | return loss, outputs 17 | 18 | 19 | def get_loss_scale_for_deepspeed(model): 20 | optimizer = model.optimizer 21 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale 22 | 23 | 24 | def train_one_epoch(args, model: torch.nn.Module, criterion: torch.nn.Module, 25 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 26 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 27 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, 28 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None, 29 | num_training_steps_per_epoch=None, update_freq=None): 30 | model.train(True) 31 | metric_logger = utils.MetricLogger(delimiter=" ") 32 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 33 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 34 | header = 'Epoch: [{}]'.format(epoch) 35 | print_freq = 10 36 | 37 | if loss_scaler is None: 38 | model.zero_grad() 39 | model.micro_steps = 0 40 | else: 41 | optimizer.zero_grad() 42 | 43 | for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 44 | step = data_iter_step // update_freq 45 | if step >= num_training_steps_per_epoch: 46 | continue 47 | it = start_steps + step # global training iteration 48 | # Update LR & WD for the first acc 49 | if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0: 50 | for i, param_group in enumerate(optimizer.param_groups): 51 | if lr_schedule_values is not None: 52 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 53 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 54 | param_group["weight_decay"] = wd_schedule_values[it] 55 | 56 | samples = samples.to(device, non_blocking=True) 57 | targets = targets.to(device, non_blocking=True) 58 | 59 | if mixup_fn is not None: 60 | samples, targets = mixup_fn(samples, targets) 61 | 62 | if loss_scaler is None: 63 | samples = samples.half() 64 | loss, 
output = cross_train_class_batch( 65 | model, samples, targets, criterion) 66 | else: 67 | with torch.cuda.amp.autocast(): 68 | samples = samples.half() 69 | loss, output = cross_train_class_batch( 70 | model, samples, targets, criterion) 71 | loss_value = loss.item() 72 | 73 | if not math.isfinite(loss_value): 74 | print("Loss is {}, stopping training".format(loss_value)) 75 | sys.exit(1) 76 | 77 | if loss_scaler is None: 78 | loss /= update_freq 79 | model.backward(loss) 80 | model.step() 81 | 82 | if (data_iter_step + 1) % update_freq == 0: 83 | # model.zero_grad() 84 | # Deepspeed will call step() & model.zero_grad() automatic 85 | if model_ema is not None: 86 | model_ema.update(model) 87 | grad_norm = None 88 | loss_scale_value = get_loss_scale_for_deepspeed(model) 89 | else: 90 | # this attribute is added by timm on one optimizer (adahessian) 91 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 92 | loss /= update_freq 93 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 94 | parameters=model.parameters(), create_graph=is_second_order, 95 | update_grad=(data_iter_step + 1) % update_freq == 0) 96 | if (data_iter_step + 1) % update_freq == 0: 97 | optimizer.zero_grad() 98 | if model_ema is not None: 99 | model_ema.update(model) 100 | loss_scale_value = loss_scaler.state_dict()["scale"] 101 | 102 | torch.cuda.synchronize() 103 | 104 | if mixup_fn is None: 105 | class_acc = (output.max(-1)[-1] == targets).float().mean() 106 | else: 107 | class_acc = None 108 | metric_logger.update(loss=loss_value) 109 | metric_logger.update(class_acc=class_acc) 110 | metric_logger.update(loss_scale=loss_scale_value) 111 | min_lr = 10. 112 | max_lr = 0. 113 | for group in optimizer.param_groups: 114 | min_lr = min(min_lr, group["lr"]) 115 | max_lr = max(max_lr, group["lr"]) 116 | 117 | metric_logger.update(lr=max_lr) 118 | metric_logger.update(min_lr=min_lr) 119 | weight_decay_value = None 120 | for group in optimizer.param_groups: 121 | if group["weight_decay"] > 0: 122 | weight_decay_value = group["weight_decay"] 123 | metric_logger.update(weight_decay=weight_decay_value) 124 | metric_logger.update(grad_norm=grad_norm) 125 | 126 | if log_writer is not None: 127 | log_writer.update(loss=loss_value, head="loss") 128 | log_writer.update(class_acc=class_acc, head="loss") 129 | log_writer.update(loss_scale=loss_scale_value, head="opt") 130 | log_writer.update(lr=max_lr, head="opt") 131 | log_writer.update(min_lr=min_lr, head="opt") 132 | log_writer.update(weight_decay=weight_decay_value, head="opt") 133 | log_writer.update(grad_norm=grad_norm, head="opt") 134 | 135 | log_writer.set_step() 136 | 137 | # gather the stats from all processes 138 | metric_logger.synchronize_between_processes() 139 | print("Averaged stats:", metric_logger) 140 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 141 | 142 | 143 | @torch.no_grad() 144 | def validation_one_epoch(args, data_loader, model, device): 145 | criterion = torch.nn.CrossEntropyLoss() 146 | 147 | metric_logger = utils.MetricLogger(delimiter=" ") 148 | header = 'Val:' 149 | 150 | # switch to evaluation mode 151 | model.eval() 152 | 153 | for batch in metric_logger.log_every(data_loader, 10, header): 154 | samples = batch[0] 155 | target = batch[1] 156 | batch_size = samples.shape[0] 157 | samples = samples.to(device, non_blocking=True) 158 | target = target.to(device, non_blocking=True) 159 | 160 | # compute output 161 | with torch.cuda.amp.autocast(): 162 | output = model(samples) 163 | 
loss = criterion(output, target) 164 | 165 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 166 | 167 | metric_logger.update(loss=loss.item()) 168 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 169 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 170 | # gather the stats from all processes 171 | metric_logger.synchronize_between_processes() 172 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 173 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 174 | 175 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 176 | 177 | 178 | 179 | @torch.no_grad() 180 | def final_test(args, data_loader, model, device, file): 181 | criterion = torch.nn.CrossEntropyLoss() 182 | 183 | metric_logger = utils.MetricLogger(delimiter=" ") 184 | header = 'Test:' 185 | 186 | # switch to evaluation mode 187 | model.eval() 188 | final_result = [] 189 | 190 | for batch in metric_logger.log_every(data_loader, 10, header): 191 | samples = batch[0] 192 | target = batch[1] 193 | ids = batch[2] 194 | chunk_nb = batch[3] 195 | split_nb = batch[4] 196 | batch_size = samples.shape[0] 197 | samples = samples.to(device, non_blocking=True) 198 | target = target.to(device, non_blocking=True) 199 | 200 | # compute output 201 | with torch.cuda.amp.autocast(): 202 | output = model(samples) 203 | loss = criterion(output, target) 204 | 205 | for i in range(output.size(0)): 206 | string = "{} {} {} {} {}\n".format(ids[i], \ 207 | str(output.data[i].cpu().numpy().tolist()), \ 208 | str(int(target[i].cpu().numpy())), \ 209 | str(int(chunk_nb[i].cpu().numpy())), \ 210 | str(int(split_nb[i].cpu().numpy()))) 211 | final_result.append(string) 212 | 213 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 214 | 215 | metric_logger.update(loss=loss.item()) 216 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 217 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 218 | 219 | if not os.path.exists(file): 220 | os.mknod(file) 221 | with open(file, 'w') as f: 222 | f.write("{}, {}\n".format(acc1, acc5)) 223 | for line in final_result: 224 | f.write(line) 225 | # gather the stats from all processes 226 | metric_logger.synchronize_between_processes() 227 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 228 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 229 | 230 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 231 | 232 | 233 | def merge(eval_path, num_tasks): 234 | dict_feats = {} 235 | dict_label = {} 236 | dict_pos = {} 237 | print("Reading individual output files") 238 | 239 | for x in range(num_tasks): 240 | file = os.path.join(eval_path, str(x) + '.txt') 241 | lines = open(file, 'r').readlines()[1:] 242 | for line in lines: 243 | line = line.strip() 244 | name = line.split('[')[0] 245 | label = line.split(']')[1].split(' ')[1] 246 | chunk_nb = line.split(']')[1].split(' ')[2] 247 | split_nb = line.split(']')[1].split(' ')[3] 248 | data = np.fromstring(line.split('[')[1].split(']')[0], dtype=np.float, sep=',') 249 | data = softmax(data) 250 | if not name in dict_feats: 251 | dict_feats[name] = [] 252 | dict_label[name] = 0 253 | dict_pos[name] = [] 254 | if chunk_nb + split_nb in dict_pos[name]: 255 | continue 256 | dict_feats[name].append(data) 257 | dict_pos[name].append(chunk_nb + split_nb) 258 | dict_label[name] = label 259 | print("Computing final 
results") 260 | 261 | input_lst = [] 262 | print(len(dict_feats)) 263 | for i, item in enumerate(dict_feats): 264 | input_lst.append([i, item, dict_feats[item], dict_label[item]]) 265 | from multiprocessing import Pool 266 | p = Pool(64) 267 | ans = p.map(compute_video, input_lst) 268 | top1 = [x[1] for x in ans] 269 | top5 = [x[2] for x in ans] 270 | pred = [x[0] for x in ans] 271 | label = [x[3] for x in ans] 272 | final_top1 ,final_top5 = np.mean(top1), np.mean(top5) 273 | return final_top1*100 ,final_top5*100 274 | 275 | def compute_video(lst): 276 | i, video_id, data, label = lst 277 | feat = [x for x in data] 278 | feat = np.mean(feat, axis=0) 279 | pred = np.argmax(feat) 280 | top1 = (int(pred) == int(label)) * 1.0 281 | top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0 282 | return [pred, top1, top5, int(label)] 283 | -------------------------------------------------------------------------------- /util_tools/mixup.py: -------------------------------------------------------------------------------- 1 | """ Mixup and Cutmix 2 | 3 | Papers: 4 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) 5 | 6 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) 7 | 8 | Code Reference: 9 | CutMix: https://github.com/clovaai/CutMix-PyTorch 10 | 11 | Hacked together by / Copyright 2019, Ross Wightman 12 | """ 13 | import numpy as np 14 | import torch 15 | 16 | 17 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): 18 | x = x.long().view(-1, 1) 19 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) 20 | 21 | 22 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): 23 | off_value = smoothing / num_classes 24 | on_value = 1. - smoothing + off_value 25 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) 26 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) 27 | return y1 * lam + y2 * (1. - lam) 28 | 29 | 30 | def rand_bbox(img_shape, lam, margin=0., count=None): 31 | """ Standard CutMix bounding-box 32 | Generates a random square bbox based on lambda value. This impl includes 33 | support for enforcing a border margin as percent of bbox dimensions. 34 | 35 | Args: 36 | img_shape (tuple): Image shape as tuple 37 | lam (float): Cutmix lambda value 38 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image) 39 | count (int): Number of bbox to generate 40 | """ 41 | ratio = np.sqrt(1 - lam) 42 | img_h, img_w = img_shape[-2:] 43 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) 44 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) 45 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) 46 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) 47 | yl = np.clip(cy - cut_h // 2, 0, img_h) 48 | yh = np.clip(cy + cut_h // 2, 0, img_h) 49 | xl = np.clip(cx - cut_w // 2, 0, img_w) 50 | xh = np.clip(cx + cut_w // 2, 0, img_w) 51 | return yl, yh, xl, xh 52 | 53 | 54 | def rand_bbox_minmax(img_shape, minmax, count=None): 55 | """ Min-Max CutMix bounding-box 56 | Inspired by Darknet cutmix impl, generates a random rectangular bbox 57 | based on min/max percent values applied to each dimension of the input image. 58 | 59 | Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max. 
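The merge()/compute_video() helpers in the single-head engine above aggregate multi-view test scores: every line written by final_test is parsed back, softmax-normalized, de-duplicated by (chunk_nb, split_nb), averaged per video, and scored with top-1/top-5. Note that np.fromstring(..., dtype=np.float) relies on the np.float alias removed in NumPy 1.24, so newer NumPy needs dtype=float there. The snippet below is a small self-contained sketch of just the aggregation step on synthetic per-view logits; the 2x3 view layout and class count are illustrative.

import numpy as np
from scipy.special import softmax

num_classes = 10
label = 7
# Synthetic logits for one video seen from 2 temporal chunks x 3 spatial crops.
views = [np.random.randn(num_classes) for _ in range(2 * 3)]

probs = [softmax(v) for v in views]   # per-view softmax, as in merge()
feat = np.mean(probs, axis=0)         # average over the de-duplicated views
pred = int(np.argmax(feat))           # compute_video(): top-1 prediction
top5 = np.argsort(-feat)[:5]
print(pred == label, label in top5)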
60 | 61 | Args: 62 | img_shape (tuple): Image shape as tuple 63 | minmax (tuple or list): Min and max bbox ratios (as percent of image size) 64 | count (int): Number of bbox to generate 65 | """ 66 | assert len(minmax) == 2 67 | img_h, img_w = img_shape[-2:] 68 | cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count) 69 | cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count) 70 | yl = np.random.randint(0, img_h - cut_h, size=count) 71 | xl = np.random.randint(0, img_w - cut_w, size=count) 72 | yu = yl + cut_h 73 | xu = xl + cut_w 74 | return yl, yu, xl, xu 75 | 76 | 77 | def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None): 78 | """ Generate bbox and apply lambda correction. 79 | """ 80 | if ratio_minmax is not None: 81 | yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count) 82 | else: 83 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) 84 | if correct_lam or ratio_minmax is not None: 85 | bbox_area = (yu - yl) * (xu - xl) 86 | lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1]) 87 | return (yl, yu, xl, xu), lam 88 | 89 | 90 | class Mixup: 91 | """ Mixup/Cutmix that applies different params to each element or whole batch 92 | 93 | Args: 94 | mixup_alpha (float): mixup alpha value, mixup is active if > 0. 95 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. 96 | cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None. 97 | prob (float): probability of applying mixup or cutmix per batch or element 98 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active 99 | mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element) 100 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders 101 | label_smoothing (float): apply label smoothing to the mixed target tensor 102 | num_classes (int): number of classes for target 103 | """ 104 | def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5, 105 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000, composition=False): 106 | self.mixup_alpha = mixup_alpha 107 | self.cutmix_alpha = cutmix_alpha 108 | self.cutmix_minmax = cutmix_minmax 109 | if self.cutmix_minmax is not None: 110 | assert len(self.cutmix_minmax) == 2 111 | # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe 112 | self.cutmix_alpha = 1.0 113 | self.mix_prob = prob 114 | self.switch_prob = switch_prob 115 | self.label_smoothing = label_smoothing 116 | self.num_classes = num_classes 117 | self.mode = mode 118 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix 119 | self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop) 120 | self.composition = composition 121 | 122 | def _params_per_elem(self, batch_size): 123 | lam = np.ones(batch_size, dtype=np.float32) 124 | use_cutmix = np.zeros(batch_size, dtype=np.bool) 125 | if self.mixup_enabled: 126 | if self.mixup_alpha > 0. 
and self.cutmix_alpha > 0.: 127 | use_cutmix = np.random.rand(batch_size) < self.switch_prob 128 | lam_mix = np.where( 129 | use_cutmix, 130 | np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size), 131 | np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)) 132 | elif self.mixup_alpha > 0.: 133 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size) 134 | elif self.cutmix_alpha > 0.: 135 | use_cutmix = np.ones(batch_size, dtype=np.bool) 136 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size) 137 | else: 138 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." 139 | lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam) 140 | return lam, use_cutmix 141 | 142 | def _params_per_batch(self): 143 | lam = 1. 144 | use_cutmix = False 145 | if self.mixup_enabled and np.random.rand() < self.mix_prob: 146 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: 147 | use_cutmix = np.random.rand() < self.switch_prob 148 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \ 149 | np.random.beta(self.mixup_alpha, self.mixup_alpha) 150 | elif self.mixup_alpha > 0.: 151 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha) 152 | elif self.cutmix_alpha > 0.: 153 | use_cutmix = True 154 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) 155 | else: 156 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." 157 | lam = float(lam_mix) 158 | return lam, use_cutmix 159 | 160 | def _mix_elem(self, x): 161 | batch_size = len(x) 162 | lam_batch, use_cutmix = self._params_per_elem(batch_size) 163 | x_orig = x.clone() # need to keep an unmodified original for mixing source 164 | for i in range(batch_size): 165 | j = batch_size - i - 1 166 | lam = lam_batch[i] 167 | if lam != 1.: 168 | if use_cutmix[i]: 169 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 170 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 171 | x[i][..., yl:yh, xl:xh] = x_orig[j][..., yl:yh, xl:xh] 172 | lam_batch[i] = lam 173 | else: 174 | x[i] = x[i] * lam + x_orig[j] * (1 - lam) 175 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) 176 | 177 | def _mix_pair(self, x): 178 | batch_size = len(x) 179 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) 180 | x_orig = x.clone() # need to keep an unmodified original for mixing source 181 | for i in range(batch_size // 2): 182 | j = batch_size - i - 1 183 | lam = lam_batch[i] 184 | if lam != 1.: 185 | if use_cutmix[i]: 186 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 187 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 188 | x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] 189 | x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh] 190 | lam_batch[i] = lam 191 | else: 192 | x[i] = x[i] * lam + x_orig[j] * (1 - lam) 193 | x[j] = x[j] * lam + x_orig[i] * (1 - lam) 194 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) 195 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) 196 | 197 | def _mix_batch(self, x): 198 | lam, use_cutmix = self._params_per_batch() 199 | if lam == 1.: 200 | return 1. 
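For reference, a minimal usage sketch of the Mixup class defined in this file in its per-batch mode, assuming the repo root is on PYTHONPATH so it imports as util_tools.mixup; the tensor shapes and hyper-parameters below are toy values, not the training configuration.

import torch
from util_tools.mixup import Mixup

mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, prob=0.9, switch_prob=0.5,
                 mode='batch', label_smoothing=0.1, num_classes=10)

x = torch.rand(4, 3, 16, 224, 224)     # B x C x T x H x W video clips
target = torch.randint(0, 10, (4,))
x, soft_target = mixup_fn(x, target)   # soft_target: (4, 10) mixed, smoothed labels
print(soft_target.sum(dim=1))          # each row still sums to 1.0

With composition=True (as in the EPIC script), __call__ instead returns (x, target_noun, target_verb), building one soft target per head from target[:, 0] and target[:, 1].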
201 | if use_cutmix: 202 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 203 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 204 | x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh] 205 | else: 206 | x_flipped = x.flip(0).mul_(1. - lam) 207 | x.mul_(lam).add_(x_flipped) 208 | return lam 209 | 210 | def __call__(self, x, target): 211 | assert len(x) % 2 == 0, 'Batch size should be even when using this' 212 | if self.mode == 'elem': 213 | lam = self._mix_elem(x) 214 | elif self.mode == 'pair': 215 | lam = self._mix_pair(x) 216 | else: 217 | lam = self._mix_batch(x) 218 | if self.composition: 219 | target_noun = mixup_target(target[:,0], 300, lam, self.label_smoothing, x.device) 220 | target_verb = mixup_target(target[:,1], 97, lam, self.label_smoothing, x.device) 221 | return x, target_noun, target_verb 222 | else: 223 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device) 224 | return x, target 225 | 226 | 227 | class FastCollateMixup(Mixup): 228 | """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch 229 | 230 | A Mixup impl that's performed while collating the batches. 231 | """ 232 | 233 | def _mix_elem_collate(self, output, batch, half=False): 234 | batch_size = len(batch) 235 | num_elem = batch_size // 2 if half else batch_size 236 | assert len(output) == num_elem 237 | lam_batch, use_cutmix = self._params_per_elem(num_elem) 238 | for i in range(num_elem): 239 | j = batch_size - i - 1 240 | lam = lam_batch[i] 241 | mixed = batch[i][0] 242 | if lam != 1.: 243 | if use_cutmix[i]: 244 | if not half: 245 | mixed = mixed.copy() 246 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 247 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 248 | mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] 249 | lam_batch[i] = lam 250 | else: 251 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) 252 | np.rint(mixed, out=mixed) 253 | output[i] += torch.from_numpy(mixed.astype(np.uint8)) 254 | if half: 255 | lam_batch = np.concatenate((lam_batch, np.ones(num_elem))) 256 | return torch.tensor(lam_batch).unsqueeze(1) 257 | 258 | def _mix_pair_collate(self, output, batch): 259 | batch_size = len(batch) 260 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) 261 | for i in range(batch_size // 2): 262 | j = batch_size - i - 1 263 | lam = lam_batch[i] 264 | mixed_i = batch[i][0] 265 | mixed_j = batch[j][0] 266 | assert 0 <= lam <= 1.0 267 | if lam < 1.: 268 | if use_cutmix[i]: 269 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 270 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 271 | patch_i = mixed_i[:, yl:yh, xl:xh].copy() 272 | mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh] 273 | mixed_j[:, yl:yh, xl:xh] = patch_i 274 | lam_batch[i] = lam 275 | else: 276 | mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam) 277 | mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam) 278 | mixed_i = mixed_temp 279 | np.rint(mixed_j, out=mixed_j) 280 | np.rint(mixed_i, out=mixed_i) 281 | output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) 282 | output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) 283 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) 284 | return torch.tensor(lam_batch).unsqueeze(1) 285 | 286 | def _mix_batch_collate(self, output, batch): 287 | batch_size = len(batch) 288 | lam, use_cutmix = 
self._params_per_batch() 289 | if use_cutmix: 290 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 291 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 292 | for i in range(batch_size): 293 | j = batch_size - i - 1 294 | mixed = batch[i][0] 295 | if lam != 1.: 296 | if use_cutmix: 297 | mixed = mixed.copy() # don't want to modify the original while iterating 298 | mixed[..., yl:yh, xl:xh] = batch[j][0][..., yl:yh, xl:xh] 299 | else: 300 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) 301 | np.rint(mixed, out=mixed) 302 | output[i] += torch.from_numpy(mixed.astype(np.uint8)) 303 | return lam 304 | 305 | def __call__(self, batch, _=None): 306 | batch_size = len(batch) 307 | assert batch_size % 2 == 0, 'Batch size should be even when using this' 308 | half = 'half' in self.mode 309 | if half: 310 | batch_size //= 2 311 | output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) 312 | if self.mode == 'elem' or self.mode == 'half': 313 | lam = self._mix_elem_collate(output, batch, half=half) 314 | elif self.mode == 'pair': 315 | lam = self._mix_pair_collate(output, batch) 316 | else: 317 | lam = self._mix_batch_collate(output, batch) 318 | target = torch.tensor([b[1] for b in batch], dtype=torch.int64) 319 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') 320 | target = target[:batch_size] 321 | return output, target 322 | 323 | -------------------------------------------------------------------------------- /dataset/ssv2.py: -------------------------------------------------------------------------------- 1 | import os 2 | from statistics import NormalDist 3 | import numpy as np 4 | import torch 5 | from torchvision import transforms 6 | from util_tools.random_erasing import RandomErasing 7 | import warnings 8 | from decord import VideoReader, cpu 9 | from torch.utils.data import Dataset 10 | import util_tools.video_transforms as video_transforms 11 | import util_tools.volume_transforms as volume_transforms 12 | 13 | 14 | class SSVideoClsDataset(Dataset): 15 | """Load your own video classification dataset.""" 16 | 17 | def __init__(self, anno_path, data_path, mode='train', clip_len=8, 18 | crop_size=224, short_side_size=256, new_height=256, 19 | new_width=340, keep_aspect_ratio=True, num_segment=1, 20 | num_crop=1, test_num_segment=10, test_num_crop=3, args=None): 21 | self.anno_path = anno_path 22 | self.data_path = data_path 23 | self.mode = mode 24 | self.clip_len = clip_len 25 | self.crop_size = crop_size 26 | self.short_side_size = short_side_size 27 | self.new_height = new_height 28 | self.new_width = new_width 29 | self.keep_aspect_ratio = keep_aspect_ratio 30 | self.num_segment = num_segment 31 | self.test_num_segment = test_num_segment 32 | self.num_crop = num_crop 33 | self.test_num_crop = test_num_crop 34 | self.args = args 35 | self.aug = False 36 | self.rand_erase = False 37 | if self.mode in ['train']: 38 | self.aug = True 39 | if self.args.reprob > 0: 40 | self.rand_erase = True 41 | if VideoReader is None: 42 | raise ImportError("Unable to import `decord` which is required to read videos.") 43 | 44 | import pandas as pd 45 | cleaned = pd.read_csv(self.anno_path, header=None, delimiter=' ') 46 | self.dataset_samples = list(cleaned.values[:, 0]) 47 | self.label_array = list(cleaned.values[:, 1]) 48 | 49 | if (mode == 'train'): 50 | pass 51 | 52 | elif (mode == 'validation'): 53 | self.data_transform = video_transforms.Compose([ 54 | 
video_transforms.Resize(self.short_side_size, interpolation='bilinear'), 55 | video_transforms.CenterCrop(size=(self.crop_size, self.crop_size)), 56 | volume_transforms.ClipToTensor(), 57 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406], 58 | std=[0.229, 0.224, 0.225]) 59 | ]) 60 | elif mode == 'test': 61 | self.data_resize = video_transforms.Compose([ 62 | video_transforms.Resize(size=(short_side_size), interpolation='bilinear') 63 | ]) 64 | self.data_transform = video_transforms.Compose([ 65 | volume_transforms.ClipToTensor(), 66 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406], 67 | std=[0.229, 0.224, 0.225]) 68 | ]) 69 | self.test_seg = [] 70 | self.test_dataset = [] 71 | self.test_label_array = [] 72 | for ck in range(self.test_num_segment): 73 | for cp in range(self.test_num_crop): 74 | for idx in range(len(self.label_array)): 75 | sample_label = self.label_array[idx] 76 | self.test_label_array.append(sample_label) 77 | self.test_dataset.append(self.dataset_samples[idx]) 78 | self.test_seg.append((ck, cp)) 79 | 80 | def __getitem__(self, index): 81 | if self.mode == 'train': 82 | args = self.args 83 | scale_t = 1 84 | 85 | sample = self.dataset_samples[index] 86 | sample = os.path.join(self.data_path,sample)# self.data_path + '/videos_train/' + sample 87 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) # T H W C 88 | if len(buffer) == 0: 89 | while len(buffer) == 0: 90 | warnings.warn("video {} not correctly loaded during training".format(sample)) 91 | index = np.random.randint(self.__len__()) 92 | sample = self.dataset_samples[index] 93 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) 94 | 95 | if args.num_sample > 1: 96 | frame_list = [] 97 | label_list = [] 98 | index_list = [] 99 | for _ in range(args.num_sample): 100 | new_frames = self._aug_frame(buffer, args) 101 | label = self.label_array[index] 102 | frame_list.append(new_frames) 103 | label_list.append(label) 104 | index_list.append(index) 105 | return frame_list, label_list, index_list, {} 106 | else: 107 | buffer = self._aug_frame(buffer, args) 108 | 109 | return buffer, self.label_array[index], index, {} 110 | 111 | elif self.mode == 'validation': 112 | sample = self.dataset_samples[index] 113 | sample = os.path.join(self.data_path,sample)# self.data_path + '/videos_train/' + sample 114 | buffer = self.loadvideo_decord(sample) 115 | if len(buffer) == 0: 116 | while len(buffer) == 0: 117 | warnings.warn("video {} not correctly loaded during validation".format(sample)) 118 | index = np.random.randint(self.__len__()) 119 | sample = self.dataset_samples[index] 120 | buffer = self.loadvideo_decord(sample) 121 | buffer = self.data_transform(buffer) 122 | return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0] 123 | 124 | elif self.mode == 'test': 125 | sample = self.test_dataset[index] 126 | sample = os.path.join(self.data_path,sample)# self.data_path + '/videos_train/' + sample 127 | chunk_nb, split_nb = self.test_seg[index] 128 | buffer = self.loadvideo_decord(sample) 129 | 130 | while len(buffer) == 0: 131 | warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\ 132 | str(self.test_dataset[index]), chunk_nb, split_nb)) 133 | index = np.random.randint(self.__len__()) 134 | sample = self.test_dataset[index] 135 | chunk_nb, split_nb = self.test_seg[index] 136 | buffer = self.loadvideo_decord(sample) 137 | 138 | buffer = self.data_resize(buffer) 139 | if isinstance(buffer, list): 140 | buffer = np.stack(buffer, 0) 141 | 142 | if 
self.test_num_crop == 1: 143 | spatial_step = 1.0 * (max( buffer.shape[1], buffer.shape[2]) - self.short_side_size) \ 144 | / (self.test_num_crop) 145 | else: 146 | spatial_step = 1.0 * (max( buffer.shape[1], buffer.shape[2]) - self.short_side_size) \ 147 | / (self.test_num_crop - 1) 148 | temporal_start = chunk_nb # 0/1 149 | spatial_start = int(split_nb * spatial_step) 150 | if buffer.shape[1] >= buffer.shape[2]: 151 | buffer = buffer[temporal_start::2, \ 152 | spatial_start:spatial_start + self.short_side_size, :, :] 153 | else: 154 | buffer = buffer[temporal_start::2, \ 155 | :, spatial_start:spatial_start + self.short_side_size, :] 156 | 157 | buffer = self.data_transform(buffer) 158 | return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \ 159 | chunk_nb, split_nb 160 | else: 161 | raise NameError('mode {} unkown'.format(self.mode)) 162 | 163 | def _aug_frame( 164 | self, 165 | buffer, 166 | args, 167 | ): 168 | 169 | aug_transform = video_transforms.create_random_augment( 170 | input_size=(self.crop_size, self.crop_size), 171 | auto_augment=args.aa, 172 | interpolation=args.train_interpolation, 173 | ) 174 | 175 | buffer = [ 176 | transforms.ToPILImage()(frame) for frame in buffer 177 | ] 178 | 179 | buffer = aug_transform(buffer) 180 | 181 | buffer = [transforms.ToTensor()(img) for img in buffer] 182 | buffer = torch.stack(buffer) # T C H W 183 | buffer = buffer.permute(0, 2, 3, 1) # T H W C 184 | 185 | # T H W C 186 | buffer = tensor_normalize( 187 | buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 188 | ) 189 | # T H W C -> C T H W. 190 | buffer = buffer.permute(3, 0, 1, 2) 191 | # Perform data augmentation. 192 | scl, asp = ( 193 | [0.08, 1.0], 194 | [0.75, 1.3333], 195 | ) 196 | 197 | buffer = spatial_sampling( 198 | buffer, 199 | spatial_idx=-1, 200 | min_scale=256, 201 | max_scale=320, 202 | crop_size=self.crop_size, 203 | random_horizontal_flip=False if args.data_set == 'SSV2' else True, 204 | inverse_uniform_sampling=False, 205 | aspect_ratio=asp, 206 | scale=scl, 207 | motion_shift=False 208 | ) 209 | 210 | if self.rand_erase: 211 | erase_transform = RandomErasing( 212 | args.reprob, 213 | mode=args.remode, 214 | max_count=args.recount, 215 | num_splits=args.recount, 216 | device="cpu", 217 | ) 218 | buffer = buffer.permute(1, 0, 2, 3) 219 | buffer = erase_transform(buffer) 220 | buffer = buffer.permute(1, 0, 2, 3) 221 | 222 | return buffer 223 | 224 | 225 | def loadvideo_decord(self, sample, sample_rate_scale=1): 226 | """Load video content using Decord""" 227 | fname = sample 228 | 229 | if not (os.path.exists(fname)): 230 | return [] 231 | 232 | # avoid hanging issue 233 | if os.path.getsize(fname) < 1 * 1024: 234 | print('SKIP: ', fname, " - ", os.path.getsize(fname)) 235 | return [] 236 | try: 237 | if self.keep_aspect_ratio: 238 | vr = VideoReader(fname, num_threads=1, ctx=cpu(0)) 239 | else: 240 | vr = VideoReader(fname, width=self.new_width, height=self.new_height, 241 | num_threads=1, ctx=cpu(0)) 242 | except: 243 | print("video cannot be loaded by decord: ", fname) 244 | return [] 245 | 246 | if self.mode == 'test': 247 | all_index = [] 248 | tick = len(vr) / float(self.num_segment) 249 | all_index = list(np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segment)] + 250 | [int(tick * x) for x in range(self.num_segment)])) 251 | while len(all_index) < (self.num_segment * self.test_num_segment): 252 | all_index.append(all_index[-1]) 253 | all_index = list(np.sort(np.array(all_index))) 254 | vr.seek(0) 255 | buffer = 
vr.get_batch(all_index).asnumpy() 256 | return buffer 257 | 258 | # handle temporal segments 259 | average_duration = len(vr) // self.num_segment 260 | all_index = [] 261 | if average_duration > 0: 262 | all_index += list(np.multiply(list(range(self.num_segment)), average_duration) + np.random.randint(average_duration, 263 | size=self.num_segment)) 264 | elif len(vr) > self.num_segment: 265 | all_index += list(np.sort(np.random.randint(len(vr), size=self.num_segment))) 266 | else: 267 | all_index += list(np.zeros((self.num_segment,))) 268 | all_index = list(np.array(all_index)) 269 | vr.seek(0) 270 | buffer = vr.get_batch(all_index).asnumpy() 271 | return buffer 272 | 273 | def __len__(self): 274 | if self.mode != 'test': 275 | return len(self.dataset_samples) 276 | else: 277 | return len(self.test_dataset) 278 | 279 | 280 | def spatial_sampling( 281 | frames, 282 | spatial_idx=-1, 283 | min_scale=256, 284 | max_scale=320, 285 | crop_size=224, 286 | random_horizontal_flip=True, 287 | inverse_uniform_sampling=False, 288 | aspect_ratio=None, 289 | scale=None, 290 | motion_shift=False, 291 | ): 292 | """ 293 | Perform spatial sampling on the given video frames. If spatial_idx is 294 | -1, perform random scale, random crop, and random flip on the given 295 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling 296 | with the given spatial_idx. 297 | Args: 298 | frames (tensor): frames of images sampled from the video. The 299 | dimension is `num frames` x `height` x `width` x `channel`. 300 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1, 301 | or 2, perform left, center, right crop if width is larger than 302 | height, and perform top, center, buttom crop if height is larger 303 | than width. 304 | min_scale (int): the minimal size of scaling. 305 | max_scale (int): the maximal size of scaling. 306 | crop_size (int): the size of height and width used to crop the 307 | frames. 308 | inverse_uniform_sampling (bool): if True, sample uniformly in 309 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the 310 | scale. If False, take a uniform sample from [min_scale, 311 | max_scale]. 312 | aspect_ratio (list): Aspect ratio range for resizing. 313 | scale (list): Scale range for resizing. 314 | motion_shift (bool): Whether to apply motion shift for resizing. 315 | Returns: 316 | frames (tensor): spatially sampled frames. 317 | """ 318 | assert spatial_idx in [-1, 0, 1, 2] 319 | if spatial_idx == -1: 320 | if aspect_ratio is None and scale is None: 321 | frames, _ = video_transforms.random_short_side_scale_jitter( 322 | images=frames, 323 | min_size=min_scale, 324 | max_size=max_scale, 325 | inverse_uniform_sampling=inverse_uniform_sampling, 326 | ) 327 | frames, _ = video_transforms.random_crop(frames, crop_size) 328 | else: 329 | transform_func = ( 330 | video_transforms.random_resized_crop_with_shift 331 | if motion_shift 332 | else video_transforms.random_resized_crop 333 | ) 334 | frames = transform_func( 335 | images=frames, 336 | target_height=crop_size, 337 | target_width=crop_size, 338 | scale=scale, 339 | ratio=aspect_ratio, 340 | ) 341 | if random_horizontal_flip: 342 | frames, _ = video_transforms.horizontal_flip(0.5, frames) 343 | else: 344 | # The testing is deterministic and no jitter should be performed. 345 | # min_scale, max_scale, and crop_size are expect to be the same. 
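A small numeric sketch of the train-time frame sampling in loadvideo_decord above: the video is split into num_segment equal chunks and one random frame is drawn from each chunk. The decord VideoReader is replaced here by a fake frame count; the numbers are illustrative.

import numpy as np

num_frames_in_video = 100   # stands in for len(vr)
num_segment = 16            # one frame per segment -> a 16-frame clip

average_duration = num_frames_in_video // num_segment                 # 6
offsets = np.multiply(list(range(num_segment)), average_duration)     # segment starts
jitter = np.random.randint(average_duration, size=num_segment)        # random frame per segment
all_index = offsets + jitter
print(all_index)   # 16 increasing frame indices, one inside each segment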
346 | assert len({min_scale, max_scale, crop_size}) == 1 347 | frames, _ = video_transforms.random_short_side_scale_jitter( 348 | frames, min_scale, max_scale 349 | ) 350 | frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx) 351 | return frames 352 | 353 | 354 | def tensor_normalize(tensor, mean, std): 355 | """ 356 | Normalize a given tensor by subtracting the mean and dividing the std. 357 | Args: 358 | tensor (tensor): tensor to normalize. 359 | mean (tensor or list): mean value to subtract. 360 | std (tensor or list): std to divide. 361 | """ 362 | if tensor.dtype == torch.uint8: 363 | tensor = tensor.float() 364 | tensor = tensor / 255.0 365 | if type(mean) == list: 366 | mean = torch.tensor(mean) 367 | if type(std) == list: 368 | std = torch.tensor(std) 369 | tensor = tensor - mean 370 | tensor = tensor / std 371 | return tensor -------------------------------------------------------------------------------- /dataset/epic.py: -------------------------------------------------------------------------------- 1 | import os 2 | from statistics import NormalDist 3 | import numpy as np 4 | import torch 5 | from torchvision import transforms 6 | from util_tools.random_erasing import RandomErasing 7 | import warnings 8 | from decord import VideoReader, cpu 9 | from torch.utils.data import Dataset 10 | import util_tools.video_transforms as video_transforms 11 | import util_tools.volume_transforms as volume_transforms 12 | 13 | class EpicVideoClsDataset(Dataset): 14 | 15 | def __init__(self, anno_path, data_path, mode='train', clip_len=8, 16 | crop_size=224, short_side_size=256, new_height=256, 17 | new_width=340, keep_aspect_ratio=True, num_segment=1, 18 | num_crop=1, test_num_segment=10, test_num_crop=3, args=None): 19 | self.anno_path = anno_path 20 | self.data_path = data_path 21 | self.mode = mode 22 | self.clip_len = clip_len 23 | self.crop_size = crop_size 24 | self.short_side_size = short_side_size 25 | self.new_height = new_height 26 | self.new_width = new_width 27 | self.keep_aspect_ratio = keep_aspect_ratio 28 | self.num_segment = num_segment 29 | self.test_num_segment = test_num_segment 30 | self.num_crop = num_crop 31 | self.test_num_crop = test_num_crop 32 | self.args = args 33 | self.aug = False 34 | self.rand_erase = False 35 | if self.mode in ['train']: 36 | self.aug = True 37 | if self.args.reprob > 0: 38 | self.rand_erase = True 39 | if VideoReader is None: 40 | raise ImportError("Unable to import `decord` which is required to read videos.") 41 | 42 | import pandas as pd 43 | cleaned = pd.read_csv(self.anno_path, header=None, delimiter=',') 44 | self.dataset_samples = list(cleaned.values[:, 0]) 45 | verb_label_array = list(cleaned.values[:, 1]) # verb 46 | noun_label_array = list(cleaned.values[:, 2]) # noun 47 | self.label_array = np.stack((noun_label_array, verb_label_array), axis=1) # label [noun, verb] sequence 48 | 49 | if (mode == 'train'): 50 | pass 51 | 52 | elif (mode == 'validation'): 53 | self.data_transform = video_transforms.Compose([ 54 | video_transforms.Resize(self.short_side_size, interpolation='bilinear'), 55 | video_transforms.CenterCrop(size=(self.crop_size, self.crop_size)), 56 | volume_transforms.ClipToTensor(), 57 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406], 58 | std=[0.229, 0.224, 0.225]) 59 | ]) 60 | elif (mode == 'test'): 61 | self.data_resize = video_transforms.Compose([ 62 | video_transforms.Resize(size=(short_side_size), interpolation='bilinear') 63 | ]) 64 | self.data_transform = video_transforms.Compose([ 65 | 
volume_transforms.ClipToTensor(), 66 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406], 67 | std=[0.229, 0.224, 0.225]) 68 | ]) 69 | self.test_seg = [] 70 | self.test_dataset = [] 71 | self.test_label_array = [] 72 | for ck in range(self.test_num_segment): 73 | for cp in range(self.test_num_crop): 74 | for idx in range(len(self.label_array)): 75 | sample_label = self.label_array[idx] 76 | self.test_label_array.append(sample_label) 77 | self.test_dataset.append(self.dataset_samples[idx]) 78 | self.test_seg.append((ck, cp)) 79 | 80 | def __getitem__(self, index): 81 | if self.mode == 'train': 82 | args = self.args 83 | scale_t = 1 84 | 85 | sample = self.dataset_samples[index] + '.mp4' 86 | sample = os.path.join(self.data_path, sample) 87 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) # T H W C 88 | if len(buffer) == 0: 89 | while len(buffer) == 0: 90 | warnings.warn("video {} not correctly loaded during training".format(sample)) 91 | index = np.random.randint(self.__len__()) 92 | sample = self.dataset_samples[index] 93 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) 94 | 95 | if args.num_sample > 1: 96 | frame_list = [] 97 | label_list = [] 98 | index_list = [] 99 | for _ in range(args.num_sample): 100 | new_frames = self._aug_frame(buffer, args) 101 | label = self.label_array[index] 102 | frame_list.append(new_frames) 103 | label_list.append(label) 104 | index_list.append(index) 105 | return frame_list, label_list, index_list, {} 106 | else: 107 | buffer = self._aug_frame(buffer, args) 108 | 109 | return buffer, self.label_array[index], index, {} 110 | 111 | elif self.mode == 'validation': 112 | sample = self.dataset_samples[index] + '.mp4' 113 | sample = os.path.join(self.data_path, sample) 114 | buffer = self.loadvideo_decord(sample) 115 | if len(buffer) == 0: 116 | while len(buffer) == 0: 117 | warnings.warn("video {} not correctly loaded during validation".format(sample)) 118 | index = np.random.randint(self.__len__()) 119 | sample = self.dataset_samples[index] 120 | buffer = self.loadvideo_decord(sample) 121 | buffer = self.data_transform(buffer) 122 | return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0] 123 | 124 | elif self.mode == 'test': 125 | sample = self.test_dataset[index] + '.mp4' 126 | sample = os.path.join(self.data_path, sample) 127 | chunk_nb, split_nb = self.test_seg[index] 128 | buffer = self.loadvideo_decord(sample) 129 | 130 | while len(buffer) == 0: 131 | warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\ 132 | str(self.test_dataset[index]), chunk_nb, split_nb)) 133 | index = np.random.randint(self.__len__()) 134 | sample = self.test_dataset[index] 135 | chunk_nb, split_nb = self.test_seg[index] 136 | buffer = self.loadvideo_decord(sample) 137 | 138 | buffer = self.data_resize(buffer) 139 | if isinstance(buffer, list): 140 | buffer = np.stack(buffer, 0) 141 | 142 | if self.test_num_crop == 1: 143 | spatial_step = 1.0 * (max( buffer.shape[1], buffer.shape[2]) - self.short_side_size) \ 144 | / (self.test_num_crop) 145 | else: 146 | spatial_step = 1.0 * (max( buffer.shape[1], buffer.shape[2]) - self.short_side_size) \ 147 | / (self.test_num_crop - 1) 148 | temporal_start = chunk_nb # 0/1 149 | spatial_start = int(split_nb * spatial_step) 150 | if buffer.shape[1] >= buffer.shape[2]: 151 | buffer = buffer[temporal_start::2, \ 152 | spatial_start:spatial_start + self.short_side_size, :, :] 153 | else: 154 | buffer = buffer[temporal_start::2, \ 155 | :, 
spatial_start:spatial_start + self.short_side_size, :] 156 | 157 | buffer = self.data_transform(buffer) 158 | return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \ 159 | chunk_nb, split_nb 160 | else: 161 | raise NameError('mode {} unkown'.format(self.mode)) 162 | 163 | 164 | 165 | def _aug_frame(self,buffer,args): 166 | 167 | aug_transform = video_transforms.create_random_augment( 168 | input_size=(self.crop_size, self.crop_size), 169 | auto_augment=args.aa, 170 | interpolation=args.train_interpolation, 171 | ) 172 | 173 | buffer = [ 174 | transforms.ToPILImage()(frame) for frame in buffer 175 | ] 176 | 177 | buffer = aug_transform(buffer) 178 | 179 | buffer = [transforms.ToTensor()(img) for img in buffer] 180 | buffer = torch.stack(buffer) # T C H W 181 | buffer = buffer.permute(0, 2, 3, 1) # T H W C 182 | 183 | # T H W C 184 | buffer = tensor_normalize( 185 | buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 186 | ) 187 | # T H W C -> C T H W. 188 | buffer = buffer.permute(3, 0, 1, 2) 189 | # Perform data augmentation. 190 | scl, asp = ( 191 | [0.08, 1.0], 192 | [0.75, 1.3333], 193 | ) 194 | 195 | buffer = spatial_sampling( 196 | buffer, 197 | spatial_idx=-1, 198 | min_scale=256, 199 | max_scale=320, 200 | crop_size=self.crop_size, 201 | random_horizontal_flip=False if args.data_set == 'SSV2' else True, 202 | inverse_uniform_sampling=False, 203 | aspect_ratio=asp, 204 | scale=scl, 205 | motion_shift=False 206 | ) 207 | 208 | if self.rand_erase: 209 | erase_transform = RandomErasing( 210 | args.reprob, 211 | mode=args.remode, 212 | max_count=args.recount, 213 | num_splits=args.recount, 214 | device="cpu", 215 | ) 216 | buffer = buffer.permute(1, 0, 2, 3) 217 | buffer = erase_transform(buffer) 218 | buffer = buffer.permute(1, 0, 2, 3) 219 | 220 | return buffer 221 | 222 | 223 | def loadvideo_decord(self, sample, sample_rate_scale=1): 224 | """Load video content using Decord""" 225 | fname = sample 226 | 227 | if not (os.path.exists(fname)): 228 | return [] 229 | 230 | # avoid hanging issue 231 | if os.path.getsize(fname) < 1 * 1024: 232 | print('SKIP: ', fname, " - ", os.path.getsize(fname)) 233 | return [] 234 | try: 235 | if self.keep_aspect_ratio: 236 | vr = VideoReader(fname, num_threads=1, ctx=cpu(0)) 237 | else: 238 | vr = VideoReader(fname, width=self.new_width, height=self.new_height, 239 | num_threads=1, ctx=cpu(0)) 240 | except: 241 | print("video cannot be loaded by decord: ", fname) 242 | return [] 243 | 244 | if self.mode == 'test': 245 | all_index = [] 246 | tick = len(vr) / float(self.num_segment) 247 | all_index = list(np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segment)] + 248 | [int(tick * x) for x in range(self.num_segment)])) 249 | while len(all_index) < (self.num_segment * self.test_num_segment): 250 | all_index.append(all_index[-1]) 251 | all_index = list(np.sort(np.array(all_index))) 252 | vr.seek(0) 253 | buffer = vr.get_batch(all_index).asnumpy() 254 | return buffer 255 | 256 | # handle temporal segments 257 | average_duration = len(vr) // self.num_segment 258 | all_index = [] 259 | if average_duration > 0: 260 | all_index += list(np.multiply(list(range(self.num_segment)), average_duration) + np.random.randint(average_duration, 261 | size=self.num_segment)) 262 | elif len(vr) > self.num_segment: 263 | all_index += list(np.sort(np.random.randint(len(vr), size=self.num_segment))) 264 | else: 265 | all_index += list(np.zeros((self.num_segment,))) 266 | all_index = list(np.array(all_index)) 267 | vr.seek(0) 268 | 
buffer = vr.get_batch(all_index).asnumpy() 269 | return buffer 270 | 271 | def __len__(self): 272 | if self.mode != 'test': 273 | return len(self.dataset_samples) 274 | else: 275 | return len(self.test_dataset) 276 | 277 | 278 | def spatial_sampling( 279 | frames, 280 | spatial_idx=-1, 281 | min_scale=256, 282 | max_scale=320, 283 | crop_size=224, 284 | random_horizontal_flip=True, 285 | inverse_uniform_sampling=False, 286 | aspect_ratio=None, 287 | scale=None, 288 | motion_shift=False, 289 | ): 290 | """ 291 | Perform spatial sampling on the given video frames. If spatial_idx is 292 | -1, perform random scale, random crop, and random flip on the given 293 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling 294 | with the given spatial_idx. 295 | Args: 296 | frames (tensor): frames of images sampled from the video. The 297 | dimension is `num frames` x `height` x `width` x `channel`. 298 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1, 299 | or 2, perform left, center, right crop if width is larger than 300 | height, and perform top, center, buttom crop if height is larger 301 | than width. 302 | min_scale (int): the minimal size of scaling. 303 | max_scale (int): the maximal size of scaling. 304 | crop_size (int): the size of height and width used to crop the 305 | frames. 306 | inverse_uniform_sampling (bool): if True, sample uniformly in 307 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the 308 | scale. If False, take a uniform sample from [min_scale, 309 | max_scale]. 310 | aspect_ratio (list): Aspect ratio range for resizing. 311 | scale (list): Scale range for resizing. 312 | motion_shift (bool): Whether to apply motion shift for resizing. 313 | Returns: 314 | frames (tensor): spatially sampled frames. 315 | """ 316 | assert spatial_idx in [-1, 0, 1, 2] 317 | if spatial_idx == -1: 318 | if aspect_ratio is None and scale is None: 319 | frames, _ = video_transforms.random_short_side_scale_jitter( 320 | images=frames, 321 | min_size=min_scale, 322 | max_size=max_scale, 323 | inverse_uniform_sampling=inverse_uniform_sampling, 324 | ) 325 | frames, _ = video_transforms.random_crop(frames, crop_size) 326 | else: 327 | transform_func = ( 328 | video_transforms.random_resized_crop_with_shift 329 | if motion_shift 330 | else video_transforms.random_resized_crop 331 | ) 332 | frames = transform_func( 333 | images=frames, 334 | target_height=crop_size, 335 | target_width=crop_size, 336 | scale=scale, 337 | ratio=aspect_ratio, 338 | ) 339 | if random_horizontal_flip: 340 | frames, _ = video_transforms.horizontal_flip(0.5, frames) 341 | else: 342 | # The testing is deterministic and no jitter should be performed. 343 | # min_scale, max_scale, and crop_size are expect to be the same. 344 | assert len({min_scale, max_scale, crop_size}) == 1 345 | frames, _ = video_transforms.random_short_side_scale_jitter( 346 | frames, min_scale, max_scale 347 | ) 348 | frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx) 349 | return frames 350 | 351 | 352 | def tensor_normalize(tensor, mean, std): 353 | """ 354 | Normalize a given tensor by subtracting the mean and dividing the std. 355 | Args: 356 | tensor (tensor): tensor to normalize. 357 | mean (tensor or list): mean value to subtract. 358 | std (tensor or list): std to divide. 
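A sketch of how the deterministic test-time views used in __getitem__ above are laid out: test_num_segment temporal chunks times test_num_crop spatial crops per video, with the crop window slid along the longer side in spatial_step increments. The 224x398 post-resize shape and the 2x3 view grid below are example values in the spirit of the EPIC script, not fixed constants.

test_num_segment, test_num_crop = 2, 3
short_side_size = 224
height, width = 224, 398   # clip shape after data_resize (short side scaled to 224)

views = [(ck, cp) for ck in range(test_num_segment) for cp in range(test_num_crop)]
spatial_step = (max(height, width) - short_side_size) / (test_num_crop - 1)
for chunk_nb, split_nb in views:
    spatial_start = int(split_nb * spatial_step)
    # temporal view: frames[chunk_nb::2]; spatial view: a 224-pixel-wide window
    print(chunk_nb, (spatial_start, spatial_start + short_side_size))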
359 | """ 360 | if tensor.dtype == torch.uint8: 361 | tensor = tensor.float() 362 | tensor = tensor / 255.0 363 | if type(mean) == list: 364 | mean = torch.tensor(mean) 365 | if type(std) == list: 366 | std = torch.tensor(std) 367 | tensor = tensor - mean 368 | tensor = tensor / std 369 | return tensor -------------------------------------------------------------------------------- /engine_for_compomodel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import math 4 | import sys 5 | from typing import Iterable, Optional 6 | import torch 7 | from util_tools.mixup import Mixup 8 | from timm.utils import accuracy, ModelEma 9 | import util_tools.utils as utils 10 | from scipy.special import softmax 11 | from einops import rearrange 12 | 13 | 14 | def composition_train_class_batch(model, samples, target_noun, target_verb, criterion): 15 | outputs_noun, outputs_verb = model(samples) 16 | loss_noun = criterion(outputs_noun, target_noun) 17 | loss_verb = criterion(outputs_verb, target_verb) 18 | total_loss = loss_noun + loss_verb 19 | return total_loss, loss_noun, loss_verb, outputs_noun, outputs_verb 20 | 21 | 22 | 23 | def get_loss_scale_for_deepspeed(model): 24 | optimizer = model.optimizer 25 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale 26 | 27 | 28 | def train_one_epoch(args, model: torch.nn.Module, criterion: torch.nn.Module, 29 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 30 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 31 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, 32 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None, 33 | num_training_steps_per_epoch=None, update_freq=None): 34 | model.train(True) 35 | metric_logger = utils.MetricLogger(delimiter=" ") 36 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 37 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 38 | header = 'Epoch: [{}]'.format(epoch) 39 | print_freq = 10 40 | 41 | if loss_scaler is None: 42 | model.zero_grad() 43 | model.micro_steps = 0 44 | else: 45 | optimizer.zero_grad() 46 | 47 | for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 48 | step = data_iter_step // update_freq 49 | if step >= num_training_steps_per_epoch: 50 | continue 51 | it = start_steps + step # global training iteration 52 | # Update LR & WD for the first acc 53 | if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0: 54 | for i, param_group in enumerate(optimizer.param_groups): 55 | if lr_schedule_values is not None: 56 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 57 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 58 | param_group["weight_decay"] = wd_schedule_values[it] 59 | 60 | samples = samples.to(device, non_blocking=True) 61 | targets = targets.to(device, non_blocking=True) 62 | 63 | if mixup_fn is not None: 64 | samples, target_noun, target_verb = mixup_fn(samples, targets) 65 | 66 | if loss_scaler is None: 67 | samples = samples.half() 68 | loss, loss_noun, loss_verb, outputs_noun, outputs_verb = composition_train_class_batch( 69 | model, samples, target_noun, target_verb, criterion) 70 | else: 71 | with torch.cuda.amp.autocast(): 72 | samples = samples.half() 73 | loss, 
loss_noun, loss_verb, outputs_noun, outputs_verb = composition_train_class_batch( 74 | model, samples, target_noun, target_verb, criterion) 75 | loss_value = loss.item() 76 | 77 | if not math.isfinite(loss_value): 78 | print("Loss is {}, stopping training".format(loss_value)) 79 | sys.exit(1) 80 | 81 | if loss_scaler is None: 82 | loss /= update_freq 83 | model.backward(loss) 84 | model.step() 85 | 86 | if (data_iter_step + 1) % update_freq == 0: 87 | # model.zero_grad() 88 | # Deepspeed will call step() & model.zero_grad() automatically 89 | if model_ema is not None: 90 | model_ema.update(model) 91 | grad_norm = None 92 | loss_scale_value = get_loss_scale_for_deepspeed(model) 93 | else: 94 | # this attribute is added by timm on one optimizer (adahessian) 95 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 96 | loss /= update_freq 97 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 98 | parameters=model.parameters(), create_graph=is_second_order, 99 | update_grad=(data_iter_step + 1) % update_freq == 0) 100 | if (data_iter_step + 1) % update_freq == 0: 101 | optimizer.zero_grad() 102 | if model_ema is not None: 103 | model_ema.update(model) 104 | loss_scale_value = loss_scaler.state_dict()["scale"] 105 | 106 | torch.cuda.synchronize() 107 | 108 | if mixup_fn is None: 109 | class_acc = None 110 | # class_acc = (output.max(-1)[-1] == targets).float().mean() 111 | else: 112 | class_acc = None 113 | metric_logger.update(loss=loss_value) 114 | metric_logger.update(loss_noun=loss_noun) 115 | metric_logger.update(loss_verb=loss_verb) 116 | metric_logger.update(class_acc=class_acc) 117 | metric_logger.update(loss_scale=loss_scale_value) 118 | min_lr = 10. 119 | max_lr = 0. 120 | for group in optimizer.param_groups: 121 | min_lr = min(min_lr, group["lr"]) 122 | max_lr = max(max_lr, group["lr"]) 123 | 124 | metric_logger.update(lr=max_lr) 125 | metric_logger.update(min_lr=min_lr) 126 | weight_decay_value = None 127 | for group in optimizer.param_groups: 128 | if group["weight_decay"] > 0: 129 | weight_decay_value = group["weight_decay"] 130 | metric_logger.update(weight_decay=weight_decay_value) 131 | metric_logger.update(grad_norm=grad_norm) 132 | 133 | if log_writer is not None: 134 | log_writer.update(loss=loss_value, head="loss") 135 | log_writer.update(class_acc=class_acc, head="loss") 136 | log_writer.update(loss_scale=loss_scale_value, head="opt") 137 | log_writer.update(lr=max_lr, head="opt") 138 | log_writer.update(min_lr=min_lr, head="opt") 139 | log_writer.update(weight_decay=weight_decay_value, head="opt") 140 | log_writer.update(grad_norm=grad_norm, head="opt") 141 | 142 | log_writer.set_step() 143 | 144 | # gather the stats from all processes 145 | metric_logger.synchronize_between_processes() 146 | print("Averaged stats:", metric_logger) 147 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 148 | 149 | 150 | @torch.no_grad() 151 | def validation_one_epoch(args, data_loader, model, device): 152 | criterion = torch.nn.CrossEntropyLoss() 153 | 154 | metric_logger = utils.MetricLogger(delimiter=" ") 155 | header = 'Val:' 156 | 157 | # switch to evaluation mode 158 | model.eval() 159 | 160 | for batch in metric_logger.log_every(data_loader, 10, header): 161 | samples = batch[0] 162 | target = batch[1] 163 | batch_size = samples.shape[0] 164 | samples = samples.to(device, non_blocking=True) 165 | target = target.to(device, non_blocking=True) 166 | action_target = (target[:,1] * 1000) + target[:,0] 167 | 168 | # compute output 169 | with
torch.cuda.amp.autocast(): 170 | output_noun, output_verb = model(samples) 171 | loss_noun = criterion(output_noun, target[:,0]) 172 | loss_verb = criterion(output_verb, target[:,1]) 173 | 174 | acc1_action, acc5_action = action_accuracy(output_noun, output_verb, action_target, topk=(1,5)) 175 | acc1_noun, acc5_noun = accuracy(output_noun, target[:,0], topk=(1, 5)) 176 | acc1_verb, acc5_verb = accuracy(output_verb, target[:,1], topk=(1, 5)) 177 | 178 | metric_logger.update(loss_noun=loss_noun.item()) 179 | metric_logger.update(loss_verb=loss_verb.item()) 180 | metric_logger.update(acc1_action=acc1_action.item()) 181 | metric_logger.update(acc1_noun=acc1_noun.item()) 182 | metric_logger.update(acc1_verb=acc1_verb.item()) 183 | metric_logger.update(acc5_noun=acc5_noun.item()) 184 | metric_logger.update(acc5_verb=acc5_verb.item()) 185 | metric_logger.meters['acc1_noun'].update(acc1_noun.item(), n=batch_size) 186 | metric_logger.meters['acc1_verb'].update(acc1_verb.item(), n=batch_size) 187 | metric_logger.meters['acc5_noun'].update(acc5_noun.item(), n=batch_size) 188 | metric_logger.meters['acc5_verb'].update(acc5_verb.item(), n=batch_size) 189 | # gather the stats from all processes 190 | metric_logger.synchronize_between_processes() 191 | print('* Acc_@1_action {top1_action.global_avg:.3f} Acc_@1_noun {top1_noun.global_avg:.3f} Acc_@1_verb {top1_verb.global_avg:.3f} Acc@5_noun {top5_noun.global_avg:.3f} Acc@5_verb {top5_verb.global_avg:.3f} loss_noun {losses_noun.global_avg:.3f} loss_verb {losses_verb.global_avg:.3f}' 192 | .format(top1_action=metric_logger.acc1_action, top1_noun=metric_logger.acc1_noun, top1_verb=metric_logger.acc1_verb, top5_noun=metric_logger.acc5_noun, top5_verb=metric_logger.acc5_verb, losses_noun=metric_logger.loss_noun, losses_verb=metric_logger.loss_verb)) 193 | 194 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 195 | 196 | 197 | 198 | @torch.no_grad() 199 | def final_test(args, data_loader, model, device, file): 200 | criterion = torch.nn.CrossEntropyLoss() 201 | 202 | metric_logger = utils.MetricLogger(delimiter=" ") 203 | header = 'Test:' 204 | 205 | # switch to evaluation mode 206 | model.eval() 207 | final_result = [] 208 | 209 | for batch in metric_logger.log_every(data_loader, 10, header): 210 | samples = batch[0] 211 | target = batch[1] 212 | ids = batch[2] 213 | chunk_nb = batch[3] 214 | split_nb = batch[4] 215 | batch_size = samples.shape[0] 216 | samples = samples.to(device, non_blocking=True) 217 | target = target.to(device, non_blocking=True) 218 | action_target = (target[:,1] * 1000) + target[:,0] 219 | 220 | # compute output 221 | with torch.cuda.amp.autocast(): 222 | output_noun, output_verb = model(samples) 223 | loss_noun = criterion(output_noun, target[:,0]) 224 | loss_verb = criterion(output_verb, target[:,1]) 225 | 226 | for i in range(output_noun.size(0)): 227 | string = "{} {} {} {} {} {} {} {}\n".format(ids[i], \ 228 | str(output_noun.data[i].cpu().numpy().tolist()), \ 229 | str(output_verb.data[i].cpu().numpy().tolist()), \ 230 | str(int(action_target[i].cpu().numpy())), \ 231 | str(int(target[i,0].cpu().numpy())), \ 232 | str(int(target[i,1].cpu().numpy())), \ 233 | str(int(chunk_nb[i].cpu().numpy())), \ 234 | str(int(split_nb[i].cpu().numpy()))) 235 | final_result.append(string) 236 | 237 | acc1_action, acc5_action = action_accuracy(output_noun, output_verb, action_target, topk=(1,5)) 238 | acc1_noun, acc5_noun = accuracy(output_noun, target[:,0], topk=(1, 5)) 239 | acc1_verb, acc5_verb = accuracy(output_verb, 
target[:,1], topk=(1, 5)) 240 | 241 | metric_logger.update(loss_noun=loss_noun.item()) 242 | metric_logger.update(loss_verb=loss_verb.item()) 243 | metric_logger.update(acc1_action=acc1_action.item()) 244 | metric_logger.update(acc1_noun=acc1_noun.item()) 245 | metric_logger.update(acc1_verb=acc1_verb.item()) 246 | metric_logger.update(acc5_noun=acc5_noun.item()) 247 | metric_logger.update(acc5_verb=acc5_verb.item()) 248 | metric_logger.meters['acc1_action'].update(acc1_action.item(), n=batch_size) 249 | metric_logger.meters['acc1_noun'].update(acc1_noun.item(), n=batch_size) 250 | metric_logger.meters['acc1_verb'].update(acc1_verb.item(), n=batch_size) 251 | metric_logger.meters['acc5_noun'].update(acc5_noun.item(), n=batch_size) 252 | metric_logger.meters['acc5_verb'].update(acc5_verb.item(), n=batch_size) 253 | 254 | if not os.path.exists(file): 255 | os.mknod(file) 256 | with open(file, 'w') as f: 257 | f.write("{}, {}\n".format(acc1_noun, acc5_noun)) 258 | for line in final_result: 259 | f.write(line) 260 | # gather the stats from all processes 261 | metric_logger.synchronize_between_processes() 262 | print('* Acc_@1_action {top1_action.global_avg:.3f} Acc_@1_noun {top1_noun.global_avg:.3f} Acc_@1_verb {top1_verb.global_avg:.3f} Acc@5_noun {top5_noun.global_avg:.3f} Acc@5_verb {top5_verb.global_avg:.3f} loss_noun {losses_noun.global_avg:.3f} loss_verb {losses_verb.global_avg:.3f}' 263 | .format(top1_action=metric_logger.acc1_action, top1_noun=metric_logger.acc1_noun, top1_verb=metric_logger.acc1_verb, top5_noun=metric_logger.acc5_noun, top5_verb=metric_logger.acc5_verb, losses_noun=metric_logger.loss_noun, losses_verb=metric_logger.loss_verb)) 264 | 265 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 266 | 267 | 268 | def merge(eval_path, num_tasks): 269 | dict_feats_noun = {} 270 | dict_feats_verb = {} 271 | dict_label = {} 272 | dict_action_label ={} 273 | dict_pos = {} 274 | print("Reading individual output files") 275 | 276 | for x in range(num_tasks): 277 | file = os.path.join(eval_path, str(x) + '.txt') 278 | lines = open(file, 'r').readlines()[1:] 279 | for line in lines: 280 | line = line.strip() 281 | name = line.split('[')[0] 282 | label_action = line.split(']')[2].split(' ')[1] 283 | label_noun = line.split(']')[2].split(' ')[2] 284 | label_verb = line.split(']')[2].split(' ')[3] 285 | chunk_nb = line.split(']')[2].split(' ')[4] 286 | split_nb = line.split(']')[2].split(' ')[5] 287 | data_noun = np.fromstring(line.split('[')[1].split(']')[0], dtype=np.float, sep=',') 288 | data_verb = np.fromstring(line.split('[')[2].split(']')[0], dtype=np.float, sep=',') 289 | data_noun = softmax(data_noun) 290 | data_verb = softmax(data_verb) 291 | 292 | if not name in dict_feats_noun: 293 | dict_feats_noun[name] = [] 294 | dict_feats_verb[name] = [] 295 | dict_label[name] = 0 296 | dict_action_label[name] = 0 297 | dict_pos[name] = [] 298 | if chunk_nb + split_nb in dict_pos[name]: 299 | continue 300 | dict_feats_noun[name].append(data_noun) 301 | dict_feats_verb[name].append(data_verb) 302 | dict_pos[name].append(chunk_nb + split_nb) 303 | dict_label[name] = (label_noun, label_verb) 304 | dict_action_label[name] = label_action 305 | print("Computing final results") 306 | 307 | input_lst = [] 308 | print(len(dict_feats_noun)) 309 | for i, item in enumerate(dict_feats_noun): 310 | input_lst.append([i, item, dict_feats_noun[item], dict_feats_verb[item], dict_label[item], dict_action_label[item]]) 311 | from multiprocessing import Pool 312 | p = Pool(8) 313 | 
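The action metric used above in validation/final_test and applied per video by compute_video and action_accuracy below counts an action as correct only when both heads match the ground truth, with the composed label encoded as verb * 1000 + noun (valid because there are fewer than 1000 noun classes). A small self-contained sketch with random logits; the 300-noun / 97-verb sizes follow EPIC-KITCHENS-100.

import numpy as np

noun_logits = np.random.randn(300)   # 300 noun classes
verb_logits = np.random.randn(97)    # 97 verb classes
label_noun, label_verb = 42, 3

pred_noun = int(np.argmax(noun_logits))
pred_verb = int(np.argmax(verb_logits))
pred_action = pred_verb * 1000 + pred_noun
label_action = label_verb * 1000 + label_noun

top1_action = float(pred_action == label_action)
top5_action = float(label_noun in np.argsort(-noun_logits)[:5]
                    and label_verb in np.argsort(-verb_logits)[:5])
print(top1_action, top5_action)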
ans = p.map(compute_video, input_lst) 314 | top1_action = [x[2] for x in ans] 315 | top5_action = [x[3] for x in ans] 316 | top1_noun = [x[4] for x in ans] 317 | top1_verb = [x[5] for x in ans] 318 | top5_noun = [x[6] for x in ans] 319 | top5_verb = [x[7] for x in ans] 320 | final_top1_noun ,final_top5_noun, final_top1_verb, final_top5_verb = np.mean(top1_noun), np.mean(top5_noun), np.mean(top1_verb), np.mean(top5_verb) 321 | final_top1_action, final_top5_action = np.mean(top1_action), np.mean(top5_action) 322 | return final_top1_action*100, final_top5_action*100, final_top1_noun*100 ,final_top5_noun*100, final_top1_verb*100, final_top5_verb*100 323 | 324 | def compute_video(lst): 325 | i, video_id, data_noun, data_verb, label, label_action = lst 326 | feat_noun = [x for x in data_noun] 327 | feat_verb = [x for x in data_verb] 328 | feat_noun = np.mean(feat_noun, axis=0) 329 | feat_verb = np.mean(feat_verb, axis=0) 330 | pred_noun = np.argmax(feat_noun) 331 | pred_verb = np.argmax(feat_verb) 332 | pred_action = (pred_verb * 1000) + pred_noun 333 | label_noun, label_verb = label 334 | top1_action = (int(pred_action) == int(label_action)) * 1.0 335 | top5_action = (int(label_noun) in np.argsort(-feat_noun)[:5] and int(label_verb) in np.argsort(-feat_verb)[:5]) * 1.0 336 | top1_noun = (int(pred_noun) == int(label_noun)) * 1.0 337 | top5_noun = (int(label_noun) in np.argsort(-feat_noun)[:5]) * 1.0 338 | top1_verb = (int(pred_verb) == int(label_verb)) * 1.0 339 | top5_verb = (int(label_verb) in np.argsort(-feat_verb)[:5]) * 1.0 340 | return [pred_noun, pred_verb, top1_action, top5_action, top1_noun, top1_verb, top5_noun, top5_verb] 341 | 342 | def action_accuracy(output_noun, output_verb, target, topk=(1,)): 343 | """Computes the accuracy over the k top predictions for the specified values of k""" 344 | maxk = max(topk) 345 | batch_size = target.size(0) 346 | _, pred_noun = output_noun.topk(maxk, 1, True, True) 347 | _, pred_verb = output_verb.topk(maxk, 1, True, True) 348 | pred = (pred_verb * 1000) + pred_noun 349 | pred = pred.t() 350 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 351 | return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk] 352 | -------------------------------------------------------------------------------- /util_tools/rand_augment.py: -------------------------------------------------------------------------------- 1 | """ 2 | This implementation is based on 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py 4 | pulished under an Apache License 2.0. 5 | 6 | COMMENT FROM ORIGINAL: 7 | AutoAugment, RandAugment, and AugMix for PyTorch 8 | This code implements the searched ImageNet policies with various tweaks and 9 | improvements and does not include any of the search code. AA and RA 10 | Implementation adapted from: 11 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 12 | AugMix adapted from: 13 | https://github.com/google-research/augmix 14 | Papers: 15 | AutoAugment: Learning Augmentation Policies from Data 16 | https://arxiv.org/abs/1805.09501 17 | Learning Data Augmentation Strategies for Object Detection 18 | https://arxiv.org/abs/1906.11172 19 | RandAugment: Practical automated data augmentation... 
20 | https://arxiv.org/abs/1909.13719 21 | AugMix: A Simple Data Processing Method to Improve Robustness and 22 | Uncertainty https://arxiv.org/abs/1912.02781 23 | 24 | Hacked together by / Copyright 2020 Ross Wightman 25 | """ 26 | 27 | import math 28 | import numpy as np 29 | import random 30 | import re 31 | import PIL 32 | from PIL import Image, ImageEnhance, ImageOps 33 | 34 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) 35 | 36 | _FILL = (128, 128, 128) 37 | 38 | # This signifies the max integer that the controller RNN could predict for the 39 | # augmentation scheme. 40 | _MAX_LEVEL = 10.0 41 | 42 | _HPARAMS_DEFAULT = { 43 | "translate_const": 250, 44 | "img_mean": _FILL, 45 | } 46 | 47 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 48 | 49 | 50 | def _interpolation(kwargs): 51 | interpolation = kwargs.pop("resample", Image.BILINEAR) 52 | if isinstance(interpolation, (list, tuple)): 53 | return random.choice(interpolation) 54 | else: 55 | return interpolation 56 | 57 | 58 | def _check_args_tf(kwargs): 59 | if "fillcolor" in kwargs and _PIL_VER < (5, 0): 60 | kwargs.pop("fillcolor") 61 | kwargs["resample"] = _interpolation(kwargs) 62 | 63 | 64 | def shear_x(img, factor, **kwargs): 65 | _check_args_tf(kwargs) 66 | return img.transform( 67 | img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs 68 | ) 69 | 70 | 71 | def shear_y(img, factor, **kwargs): 72 | _check_args_tf(kwargs) 73 | return img.transform( 74 | img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs 75 | ) 76 | 77 | 78 | def translate_x_rel(img, pct, **kwargs): 79 | pixels = pct * img.size[0] 80 | _check_args_tf(kwargs) 81 | return img.transform( 82 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs 83 | ) 84 | 85 | 86 | def translate_y_rel(img, pct, **kwargs): 87 | pixels = pct * img.size[1] 88 | _check_args_tf(kwargs) 89 | return img.transform( 90 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs 91 | ) 92 | 93 | 94 | def translate_x_abs(img, pixels, **kwargs): 95 | _check_args_tf(kwargs) 96 | return img.transform( 97 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs 98 | ) 99 | 100 | 101 | def translate_y_abs(img, pixels, **kwargs): 102 | _check_args_tf(kwargs) 103 | return img.transform( 104 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs 105 | ) 106 | 107 | 108 | def rotate(img, degrees, **kwargs): 109 | _check_args_tf(kwargs) 110 | if _PIL_VER >= (5, 2): 111 | return img.rotate(degrees, **kwargs) 112 | elif _PIL_VER >= (5, 0): 113 | w, h = img.size 114 | post_trans = (0, 0) 115 | rotn_center = (w / 2.0, h / 2.0) 116 | angle = -math.radians(degrees) 117 | matrix = [ 118 | round(math.cos(angle), 15), 119 | round(math.sin(angle), 15), 120 | 0.0, 121 | round(-math.sin(angle), 15), 122 | round(math.cos(angle), 15), 123 | 0.0, 124 | ] 125 | 126 | def transform(x, y, matrix): 127 | (a, b, c, d, e, f) = matrix 128 | return a * x + b * y + c, d * x + e * y + f 129 | 130 | matrix[2], matrix[5] = transform( 131 | -rotn_center[0] - post_trans[0], 132 | -rotn_center[1] - post_trans[1], 133 | matrix, 134 | ) 135 | matrix[2] += rotn_center[0] 136 | matrix[5] += rotn_center[1] 137 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs) 138 | else: 139 | return img.rotate(degrees, resample=kwargs["resample"]) 140 | 141 | 142 | def auto_contrast(img, **__): 143 | return ImageOps.autocontrast(img) 144 | 145 | 146 | def invert(img, **__): 147 | return ImageOps.invert(img) 148 | 149 | 150 | def equalize(img, **__): 151 | return 
ImageOps.equalize(img) 152 | 153 | 154 | def solarize(img, thresh, **__): 155 | return ImageOps.solarize(img, thresh) 156 | 157 | 158 | def solarize_add(img, add, thresh=128, **__): 159 | lut = [] 160 | for i in range(256): 161 | if i < thresh: 162 | lut.append(min(255, i + add)) 163 | else: 164 | lut.append(i) 165 | if img.mode in ("L", "RGB"): 166 | if img.mode == "RGB" and len(lut) == 256: 167 | lut = lut + lut + lut 168 | return img.point(lut) 169 | else: 170 | return img 171 | 172 | 173 | def posterize(img, bits_to_keep, **__): 174 | if bits_to_keep >= 8: 175 | return img 176 | return ImageOps.posterize(img, bits_to_keep) 177 | 178 | 179 | def contrast(img, factor, **__): 180 | return ImageEnhance.Contrast(img).enhance(factor) 181 | 182 | 183 | def color(img, factor, **__): 184 | return ImageEnhance.Color(img).enhance(factor) 185 | 186 | 187 | def brightness(img, factor, **__): 188 | return ImageEnhance.Brightness(img).enhance(factor) 189 | 190 | 191 | def sharpness(img, factor, **__): 192 | return ImageEnhance.Sharpness(img).enhance(factor) 193 | 194 | 195 | def _randomly_negate(v): 196 | """With 50% prob, negate the value""" 197 | return -v if random.random() > 0.5 else v 198 | 199 | 200 | def _rotate_level_to_arg(level, _hparams): 201 | # range [-30, 30] 202 | level = (level / _MAX_LEVEL) * 30.0 203 | level = _randomly_negate(level) 204 | return (level,) 205 | 206 | 207 | def _enhance_level_to_arg(level, _hparams): 208 | # range [0.1, 1.9] 209 | return ((level / _MAX_LEVEL) * 1.8 + 0.1,) 210 | 211 | 212 | def _enhance_increasing_level_to_arg(level, _hparams): 213 | # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend 214 | # range [0.1, 1.9] 215 | level = (level / _MAX_LEVEL) * 0.9 216 | level = 1.0 + _randomly_negate(level) 217 | return (level,) 218 | 219 | 220 | def _shear_level_to_arg(level, _hparams): 221 | # range [-0.3, 0.3] 222 | level = (level / _MAX_LEVEL) * 0.3 223 | level = _randomly_negate(level) 224 | return (level,) 225 | 226 | 227 | def _translate_abs_level_to_arg(level, hparams): 228 | translate_const = hparams["translate_const"] 229 | level = (level / _MAX_LEVEL) * float(translate_const) 230 | level = _randomly_negate(level) 231 | return (level,) 232 | 233 | 234 | def _translate_rel_level_to_arg(level, hparams): 235 | # default range [-0.45, 0.45] 236 | translate_pct = hparams.get("translate_pct", 0.45) 237 | level = (level / _MAX_LEVEL) * translate_pct 238 | level = _randomly_negate(level) 239 | return (level,) 240 | 241 | 242 | def _posterize_level_to_arg(level, _hparams): 243 | # As per Tensorflow TPU EfficientNet impl 244 | # range [0, 4], 'keep 0 up to 4 MSB of original image' 245 | # intensity/severity of augmentation decreases with level 246 | return (int((level / _MAX_LEVEL) * 4),) 247 | 248 | 249 | def _posterize_increasing_level_to_arg(level, hparams): 250 | # As per Tensorflow models research and UDA impl 251 | # range [4, 0], 'keep 4 down to 0 MSB of original image', 252 | # intensity/severity of augmentation increases with level 253 | return (4 - _posterize_level_to_arg(level, hparams)[0],) 254 | 255 | 256 | def _posterize_original_level_to_arg(level, _hparams): 257 | # As per original AutoAugment paper description 258 | # range [4, 8], 'keep 4 up to 8 MSB of image' 259 | # intensity/severity of augmentation decreases with level 260 | return (int((level / _MAX_LEVEL) * 4) + 4,) 261 | 262 | 263 | def _solarize_level_to_arg(level, _hparams): 264 | # range [0, 256] 265 | # intensity/severity of 
augmentation decreases with level 266 | return (int((level / _MAX_LEVEL) * 256),) 267 | 268 | 269 | def _solarize_increasing_level_to_arg(level, _hparams): 270 | # range [0, 256] 271 | # intensity/severity of augmentation increases with level 272 | return (256 - _solarize_level_to_arg(level, _hparams)[0],) 273 | 274 | 275 | def _solarize_add_level_to_arg(level, _hparams): 276 | # range [0, 110] 277 | return (int((level / _MAX_LEVEL) * 110),) 278 | 279 | 280 | LEVEL_TO_ARG = { 281 | "AutoContrast": None, 282 | "Equalize": None, 283 | "Invert": None, 284 | "Rotate": _rotate_level_to_arg, 285 | # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers 286 | "Posterize": _posterize_level_to_arg, 287 | "PosterizeIncreasing": _posterize_increasing_level_to_arg, 288 | "PosterizeOriginal": _posterize_original_level_to_arg, 289 | "Solarize": _solarize_level_to_arg, 290 | "SolarizeIncreasing": _solarize_increasing_level_to_arg, 291 | "SolarizeAdd": _solarize_add_level_to_arg, 292 | "Color": _enhance_level_to_arg, 293 | "ColorIncreasing": _enhance_increasing_level_to_arg, 294 | "Contrast": _enhance_level_to_arg, 295 | "ContrastIncreasing": _enhance_increasing_level_to_arg, 296 | "Brightness": _enhance_level_to_arg, 297 | "BrightnessIncreasing": _enhance_increasing_level_to_arg, 298 | "Sharpness": _enhance_level_to_arg, 299 | "SharpnessIncreasing": _enhance_increasing_level_to_arg, 300 | "ShearX": _shear_level_to_arg, 301 | "ShearY": _shear_level_to_arg, 302 | "TranslateX": _translate_abs_level_to_arg, 303 | "TranslateY": _translate_abs_level_to_arg, 304 | "TranslateXRel": _translate_rel_level_to_arg, 305 | "TranslateYRel": _translate_rel_level_to_arg, 306 | } 307 | 308 | 309 | NAME_TO_OP = { 310 | "AutoContrast": auto_contrast, 311 | "Equalize": equalize, 312 | "Invert": invert, 313 | "Rotate": rotate, 314 | "Posterize": posterize, 315 | "PosterizeIncreasing": posterize, 316 | "PosterizeOriginal": posterize, 317 | "Solarize": solarize, 318 | "SolarizeIncreasing": solarize, 319 | "SolarizeAdd": solarize_add, 320 | "Color": color, 321 | "ColorIncreasing": color, 322 | "Contrast": contrast, 323 | "ContrastIncreasing": contrast, 324 | "Brightness": brightness, 325 | "BrightnessIncreasing": brightness, 326 | "Sharpness": sharpness, 327 | "SharpnessIncreasing": sharpness, 328 | "ShearX": shear_x, 329 | "ShearY": shear_y, 330 | "TranslateX": translate_x_abs, 331 | "TranslateY": translate_y_abs, 332 | "TranslateXRel": translate_x_rel, 333 | "TranslateYRel": translate_y_rel, 334 | } 335 | 336 | 337 | class AugmentOp: 338 | """ 339 | Apply for video. 340 | """ 341 | 342 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None): 343 | hparams = hparams or _HPARAMS_DEFAULT 344 | self.aug_fn = NAME_TO_OP[name] 345 | self.level_fn = LEVEL_TO_ARG[name] 346 | self.prob = prob 347 | self.magnitude = magnitude 348 | self.hparams = hparams.copy() 349 | self.kwargs = { 350 | "fillcolor": hparams["img_mean"] 351 | if "img_mean" in hparams 352 | else _FILL, 353 | "resample": hparams["interpolation"] 354 | if "interpolation" in hparams 355 | else _RANDOM_INTERPOLATION, 356 | } 357 | 358 | # If magnitude_std is > 0, we introduce some randomness 359 | # in the usually fixed policy and sample magnitude from a normal distribution 360 | # with mean `magnitude` and std-dev of `magnitude_std`. 361 | # NOTE This is my own hack, being tested, not in papers or reference impls. 
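        # Concretely: when hparams carries magnitude_std > 0 (e.g. via an 'mstd0.5' token
        # in the config string parsed by rand_augment_transform below), __call__ replaces
        # the fixed magnitude with random.gauss(magnitude, magnitude_std) and clips the
        # sample back into [0, _MAX_LEVEL] before converting it to op arguments.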
362 | self.magnitude_std = self.hparams.get("magnitude_std", 0) 363 | 364 | def __call__(self, img_list): 365 | if self.prob < 1.0 and random.random() > self.prob: 366 | return img_list 367 | magnitude = self.magnitude 368 | if self.magnitude_std and self.magnitude_std > 0: 369 | magnitude = random.gauss(magnitude, self.magnitude_std) 370 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range 371 | level_args = ( 372 | self.level_fn(magnitude, self.hparams) 373 | if self.level_fn is not None 374 | else () 375 | ) 376 | 377 | if isinstance(img_list, list): 378 | return [ 379 | self.aug_fn(img, *level_args, **self.kwargs) for img in img_list 380 | ] 381 | else: 382 | return self.aug_fn(img_list, *level_args, **self.kwargs) 383 | 384 | 385 | _RAND_TRANSFORMS = [ 386 | "AutoContrast", 387 | "Equalize", 388 | "Invert", 389 | "Rotate", 390 | "Posterize", 391 | "Solarize", 392 | "SolarizeAdd", 393 | "Color", 394 | "Contrast", 395 | "Brightness", 396 | "Sharpness", 397 | "ShearX", 398 | "ShearY", 399 | "TranslateXRel", 400 | "TranslateYRel", 401 | ] 402 | 403 | 404 | _RAND_INCREASING_TRANSFORMS = [ 405 | "AutoContrast", 406 | "Equalize", 407 | "Invert", 408 | "Rotate", 409 | "PosterizeIncreasing", 410 | "SolarizeIncreasing", 411 | "SolarizeAdd", 412 | "ColorIncreasing", 413 | "ContrastIncreasing", 414 | "BrightnessIncreasing", 415 | "SharpnessIncreasing", 416 | "ShearX", 417 | "ShearY", 418 | "TranslateXRel", 419 | "TranslateYRel", 420 | ] 421 | 422 | 423 | # These experimental weights are based loosely on the relative improvements mentioned in paper. 424 | # They may not result in increased performance, but could likely be tuned to so. 425 | _RAND_CHOICE_WEIGHTS_0 = { 426 | "Rotate": 0.3, 427 | "ShearX": 0.2, 428 | "ShearY": 0.2, 429 | "TranslateXRel": 0.1, 430 | "TranslateYRel": 0.1, 431 | "Color": 0.025, 432 | "Sharpness": 0.025, 433 | "AutoContrast": 0.025, 434 | "Solarize": 0.005, 435 | "SolarizeAdd": 0.005, 436 | "Contrast": 0.005, 437 | "Brightness": 0.005, 438 | "Equalize": 0.005, 439 | "Posterize": 0, 440 | "Invert": 0, 441 | } 442 | 443 | 444 | def _select_rand_weights(weight_idx=0, transforms=None): 445 | transforms = transforms or _RAND_TRANSFORMS 446 | assert weight_idx == 0 # only one set of weights currently 447 | rand_weights = _RAND_CHOICE_WEIGHTS_0 448 | probs = [rand_weights[k] for k in transforms] 449 | probs /= np.sum(probs) 450 | return probs 451 | 452 | 453 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None): 454 | hparams = hparams or _HPARAMS_DEFAULT 455 | transforms = transforms or _RAND_TRANSFORMS 456 | return [ 457 | AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) 458 | for name in transforms 459 | ] 460 | 461 | 462 | class RandAugment: 463 | def __init__(self, ops, num_layers=2, choice_weights=None): 464 | self.ops = ops 465 | self.num_layers = num_layers 466 | self.choice_weights = choice_weights 467 | 468 | def __call__(self, img): 469 | # no replacement when using weighted choice 470 | ops = np.random.choice( 471 | self.ops, 472 | self.num_layers, 473 | replace=self.choice_weights is None, 474 | p=self.choice_weights, 475 | ) 476 | for op in ops: 477 | img = op(img) 478 | return img 479 | 480 | 481 | def rand_augment_transform(config_str, hparams): 482 | """ 483 | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 484 | 485 | Create a RandAugment transform 486 | :param config_str: String defining configuration of random augmentation. 
Consists of multiple sections separated by 487 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining 488 | sections, not order sepecific determine 489 | 'm' - integer magnitude of rand augment 490 | 'n' - integer num layers (number of transform ops selected per image) 491 | 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) 492 | 'mstd' - float std deviation of magnitude noise applied 493 | 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) 494 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 495 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 496 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme 497 | :return: A PyTorch compatible Transform 498 | """ 499 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) 500 | num_layers = 2 # default to 2 ops per image 501 | weight_idx = None # default to no probability weights for op choice 502 | transforms = _RAND_TRANSFORMS 503 | config = config_str.split("-") 504 | assert config[0] == "rand" 505 | config = config[1:] 506 | for c in config: 507 | cs = re.split(r"(\d.*)", c) 508 | if len(cs) < 2: 509 | continue 510 | key, val = cs[:2] 511 | if key == "mstd": 512 | # noise param injected via hparams for now 513 | hparams.setdefault("magnitude_std", float(val)) 514 | elif key == "inc": 515 | if bool(val): 516 | transforms = _RAND_INCREASING_TRANSFORMS 517 | elif key == "m": 518 | magnitude = int(val) 519 | elif key == "n": 520 | num_layers = int(val) 521 | elif key == "w": 522 | weight_idx = int(val) 523 | else: 524 | assert NotImplementedError 525 | ra_ops = rand_augment_ops( 526 | magnitude=magnitude, hparams=hparams, transforms=transforms 527 | ) 528 | choice_weights = ( 529 | None if weight_idx is None else _select_rand_weights(weight_idx) 530 | ) 531 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) 532 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 
24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. 
Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. 
For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. 
identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. 
Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 399 | 400 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /models/videomae_modelling_finetune.py: -------------------------------------------------------------------------------- 1 | # some codes from CLIP github(https://github.com/openai/CLIP), from VideoMAE github(https://github.com/MCG-NJU/VideoMAE) 2 | from functools import partial 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 8 | from timm.models.registry import register_model 9 | from collections import OrderedDict 10 | from einops import rearrange 11 | import random 12 | 13 | 14 | def _cfg(url='', **kwargs): 15 | return { 16 | 'url': url, 17 | 'num_classes': 400, 'input_size': (3, 224, 224), 'pool_size': None, 18 | 'crop_pct': .9, 'interpolation': 'bicubic', 19 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 20 | **kwargs 21 | } 22 | 23 | class DropPath(nn.Module): 24 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
25 | """ 26 | def __init__(self, drop_prob=None): 27 | super(DropPath, self).__init__() 28 | self.drop_prob = drop_prob 29 | 30 | def forward(self, x): 31 | return drop_path(x, self.drop_prob, self.training) 32 | 33 | def extra_repr(self) -> str: 34 | return 'p={}'.format(self.drop_prob) 35 | 36 | class Adapter(nn.Module): 37 | def __init__(self, dim, mlp_ratio=0.25, act_layer=nn.GELU, skip_connect=True): 38 | super().__init__() 39 | self.skip_connect = skip_connect 40 | down_dim = int(dim * mlp_ratio) 41 | self.act = act_layer() 42 | self.D_fc1 = nn.Linear(dim, down_dim) 43 | self.D_fc2 = nn.Linear(down_dim, dim) 44 | 45 | def forward(self, x): 46 | # x is (BT, HW+1, D) 47 | xs = self.D_fc1(x) 48 | xs = self.act(xs) 49 | xs = self.D_fc2(xs) 50 | if self.skip_connect: 51 | x = x + xs 52 | else: 53 | x = xs 54 | return x 55 | 56 | class LayerNorm(nn.LayerNorm): 57 | """Subclass torch's LayerNorm to handle fp16.""" 58 | def forward(self, x: torch.Tensor): 59 | orig_type = x.dtype 60 | if orig_type == torch.float16: 61 | ret = super().forward(x) 62 | elif orig_type == torch.float32: 63 | ret = super().forward(x.type(torch.float32)) 64 | return ret.type(orig_type) 65 | 66 | class QuickGELU(nn.Module): 67 | def forward(self, x: torch.Tensor): 68 | return x * torch.sigmoid(1.702 * x) 69 | 70 | class PatchEmbed(nn.Module): 71 | """ Image to Patch Embedding 72 | """ 73 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2): 74 | super().__init__() 75 | img_size = to_2tuple(img_size) 76 | patch_size = to_2tuple(patch_size) 77 | self.tubelet_size = int(tubelet_size) 78 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (num_frames // self.tubelet_size) 79 | self.img_size = img_size 80 | self.patch_size = patch_size 81 | self.num_patches = num_patches 82 | self.proj = nn.Conv3d(in_channels=in_chans, out_channels=embed_dim, 83 | kernel_size = (self.tubelet_size, patch_size[0],patch_size[1]), 84 | stride=(self.tubelet_size, patch_size[0], patch_size[1])) 85 | 86 | def forward(self, x, **kwargs): 87 | B, C, T, H, W = x.shape 88 | # FIXME look at relaxing size constraints 89 | assert H == self.img_size[0] and W == self.img_size[1], \ 90 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
91 | x = self.proj(x).flatten(2).transpose(1, 2) 92 | return x 93 | 94 | # sin-cos position encoding 95 | # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31 96 | def get_sinusoid_encoding_table(n_position, d_hid): 97 | ''' Sinusoid position encoding table ''' 98 | # TODO: make it with torch instead of numpy 99 | def get_position_angle_vec(position): 100 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 101 | 102 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 103 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 104 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 105 | 106 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 107 | 108 | class Mlp(nn.Module): 109 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 110 | super().__init__() 111 | out_features = out_features or in_features 112 | hidden_features = hidden_features or in_features 113 | self.fc1 = nn.Linear(in_features, hidden_features) 114 | self.act = act_layer() 115 | self.fc2 = nn.Linear(hidden_features, out_features) 116 | self.drop = nn.Dropout(drop) 117 | 118 | def forward(self, x): 119 | x = self.fc1(x) 120 | x = self.act(x) 121 | # x = self.drop(x) 122 | # commit this for the orignal BERT implement 123 | x = self.fc2(x) 124 | x = self.drop(x) 125 | return x 126 | 127 | # 기존 weight load편의성을 위해 Attention이름을 유지한다. 128 | class Attention(nn.Module): 129 | def __init__( 130 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., 131 | proj_drop=0., attn_head_dim=None): 132 | super().__init__() 133 | self.num_heads = num_heads 134 | head_dim = dim // num_heads 135 | if attn_head_dim is not None: 136 | head_dim = attn_head_dim 137 | all_head_dim = head_dim * self.num_heads 138 | self.scale = qk_scale or head_dim ** -0.5 139 | 140 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) 141 | if qkv_bias: 142 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) 143 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) 144 | else: 145 | self.q_bias = None 146 | self.v_bias = None 147 | 148 | self.attn_drop = nn.Dropout(attn_drop) 149 | self.proj = nn.Linear(all_head_dim, dim) 150 | self.proj_drop = nn.Dropout(proj_drop) 151 | 152 | def forward(self, x): 153 | B, N, C = x.shape 154 | qkv_bias = None 155 | if self.q_bias is not None: 156 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) 157 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 158 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 159 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 160 | s2t_q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 161 | 162 | s2t_q = s2t_q * self.scale 163 | attn = (s2t_q @ k.transpose(-2, -1)) 164 | 165 | 166 | attn = attn.softmax(dim=-1) 167 | attn = self.attn_drop(attn) 168 | 169 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 170 | x = self.proj(x) 171 | x = self.proj_drop(x) 172 | return x 173 | 174 | # spatial to temporal cross attention module. 
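# A minimal shape walk-through for the two cross-attention modules below (an
# illustrative sketch; the helper name is hypothetical and nothing in the repo calls
# it). It assumes the default layout of 8 tubelet frames and 197 CLIP tokens
# (1 CLS + 196 patches) per frame.
def _cross_attn_shape_demo(batch=2, frames=8, clip_tokens=197, dim=768):
    s_x = torch.zeros(clip_tokens, batch * frames, dim)        # CLIP path:     [n, (b t), d]
    t_x = torch.zeros(batch, frames * (clip_tokens - 1), dim)  # VideoMAE path: [b, (t n), d]
    # S2T: every temporal token attends over the spatial tokens of all frames at once
    s2t_kv = rearrange(s_x, 'n (b t) d -> b (t n) d', b=batch)
    assert s2t_kv.shape == (batch, frames * clip_tokens, dim)
    # T2S: each spatial patch token attends only along its own 8-frame temporal trajectory
    t2s_kv = rearrange(t_x, 'b (t n) d -> (b n) t d', t=frames)
    assert t2s_kv.shape == (batch * (clip_tokens - 1), frames, dim)
    return s2t_kv.shape, t2s_kv.shape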
175 | class CrossAttentionS2T(nn.Module): 176 | def __init__(self, dim: int, n_head: int, attn_mask: torch.Tensor = None): 177 | super().__init__() 178 | 179 | # add for cross-attn 180 | self.num_head = n_head 181 | head_dim = dim // self.num_head 182 | self.scale = head_dim ** -0.5 183 | all_head_dim = head_dim * self.num_head 184 | scale = dim ** -0.5 185 | self.space_time_pos = nn.Parameter(scale * torch.randn((197 * 8, dim))) 186 | 187 | #여기에 cross attn t2s module이 들어가야 한다. 188 | self.s2t_q = nn.Linear(dim, all_head_dim, bias=False) 189 | self.s2t_q_bias = nn.Parameter(torch.zeros(all_head_dim)) 190 | self.s2t_kv = nn.Linear(dim, all_head_dim * 2, bias=False) # 197 tokens(cls+patch) * num_frames 191 | self.s2t_kv_bias = nn.Parameter(torch.zeros(all_head_dim * 2)) 192 | 193 | self.t2s_proj = nn.Linear(all_head_dim, dim) 194 | 195 | self.attn_mask = attn_mask 196 | 197 | def s2t_cross_attn(self, s_x, t_x): # s_x=[n (b t) d], t_x=[b n d] 198 | B, _, _ = t_x.shape 199 | s_x = rearrange(s_x, 'n (b t) d -> b (t n) d', b=B) # batch -> token 200 | s_x = s_x + self.space_time_pos ## sapce time position encoding 201 | s2t_q_bias = self.s2t_q_bias 202 | s2t_kv_bias = self.s2t_kv_bias 203 | 204 | s2t_q = F.linear(input=t_x, weight=self.s2t_q.weight, bias=s2t_q_bias) 205 | s2t_q = rearrange(s2t_q, 'b n (h d) -> b h n d', h=self.num_head) 206 | s2t_kv = F.linear(input=s_x, weight=self.s2t_kv.weight, bias=s2t_kv_bias) 207 | s2t_kv = rearrange(s2t_kv, 'b n (e h d) -> e b h n d',e=2, h=self.num_head) 208 | s2t_k, s2t_v = s2t_kv[0], s2t_kv[1] 209 | 210 | s2t_q = s2t_q * self.scale 211 | s2t_attn = (s2t_q @ s2t_k.transpose(-2, -1)) 212 | 213 | s2t_attn = s2t_attn.softmax(dim=-1) 214 | 215 | t_x = (s2t_attn @ s2t_v) 216 | t_x = rearrange(t_x, 'b h n d -> b n (h d)') 217 | t_x = self.t2s_proj(t_x) 218 | return t_x 219 | 220 | def forward(self, s_x: torch.Tensor, t_x: torch.Tensor): 221 | return self.s2t_cross_attn(s_x, t_x) 222 | 223 | 224 | # this codes from CLIP github(https://github.com/openai/CLIP) 225 | class CrossAttentionT2S(nn.Module): # 이게 VMAE로 치면 blocks class다. 여기에 cross s2t_attn layer가 추가되어야 한다. 226 | def __init__(self, dim: int, n_head: int, attn_mask: torch.Tensor = None): 227 | super().__init__() 228 | 229 | # add for cross-attn 230 | self.num_head = n_head 231 | head_dim = dim // self.num_head 232 | self.scale = head_dim ** -0.5 233 | all_head_dim = head_dim * self.num_head 234 | 235 | #여기에 cross attn t2s module이 들어가야 한다. 
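        # (The Korean note above reads: "the cross-attention T2S module should go here.")
        # In t2s_cross_attn below, queries are built from the CLIP patch tokens and
        # keys/values from the VideoMAE temporal tokens, each projected to
        # n_head * head_dim = dim channels.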
236 | self.t2s_q = nn.Linear(dim, all_head_dim, bias=False) # 197 tokens(cls+patch) * num_frames 237 | self.t2s_q_bias = nn.Parameter(torch.zeros(all_head_dim)) 238 | self.t2s_kv = nn.Linear(dim, all_head_dim * 2, bias=False) 239 | self.t2s_kv_bias = nn.Parameter(torch.zeros(all_head_dim * 2)) 240 | 241 | self.t2s_proj = nn.Linear(all_head_dim, dim) 242 | 243 | self.attn_mask = attn_mask 244 | 245 | def t2s_cross_attn(self, s_x, t_x): # s_x=[n (b t) d], t_x=[b n d] 246 | B, _, _ = t_x.shape 247 | s_x_cls, s_x_pat = s_x[0, :, :], s_x[1:, :, :] 248 | s_x_pat = rearrange(s_x_pat, 'n (b t) d -> (b n) t d', b=B) # batch -> token 249 | t_x = rearrange(t_x, 'b (t n) d -> (b n) t d', t=8) 250 | t2s_q_bias = self.t2s_q_bias 251 | t2s_kv_bias = self.t2s_kv_bias 252 | 253 | t2s_q = F.linear(input=s_x_pat, weight=self.t2s_q.weight, bias=t2s_q_bias) 254 | t2s_q = rearrange(t2s_q, 'b t (h d) -> b h t d', h=self.num_head) 255 | t2s_kv = F.linear(input=t_x, weight=self.t2s_kv.weight, bias=t2s_kv_bias) 256 | t2s_kv = rearrange(t2s_kv, 'b t (e h d) -> e b h t d',e=2, h=self.num_head) 257 | t2s_k, t2s_v = t2s_kv[0], t2s_kv[1] 258 | 259 | t2s_q = t2s_q * self.scale 260 | t2s_attn = (t2s_q @ t2s_k.transpose(-2, -1)) 261 | 262 | t2s_attn = t2s_attn.softmax(dim=-1) 263 | 264 | s_x_pat = (t2s_attn @ t2s_v) 265 | s_x_pat = rearrange(s_x_pat, 'b h n d -> b n (h d)') 266 | s_x_pat = self.t2s_proj(s_x_pat) 267 | s_x_pat = rearrange(s_x_pat,'(b n) t d -> n (b t) d', b=B) 268 | s_x = torch.cat([s_x_cls.unsqueeze(0), s_x_pat], dim=0) 269 | return s_x 270 | 271 | def forward(self, s_x: torch.Tensor, t_x: torch.Tensor): 272 | return self.t2s_cross_attn(s_x, t_x) 273 | 274 | 275 | class Block(nn.Module): 276 | 277 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 278 | drop_path=0., init_values=None, num_layer=0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_head_dim=None,use_adapter=False): 279 | super().__init__() 280 | self.cross = None 281 | self.num_layer = num_layer 282 | self.num_heads = num_heads 283 | self.scale = 0.5 284 | mlp_hidden_dim = int(dim * mlp_ratio) 285 | self.act = act_layer() 286 | self.use_adapter=use_adapter 287 | ############################ VMAE MHSA ########################### 288 | self.norm1 = norm_layer(dim) 289 | self.attn = Attention( 290 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 291 | attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim) 292 | if self.use_adapter: 293 | self.T_Adapter = Adapter(dim, skip_connect=True)# base line 의 경우 True. 294 | ############################ VMAE FFN ############################### 295 | self.norm2 = norm_layer(dim) 296 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 297 | if self.use_adapter: 298 | self.T_MLP_Adapter = Adapter(dim, skip_connect=False) 299 | ####################################################################### 300 | ######################################################################################### 301 | 302 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 303 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 304 | 305 | def attention(self, x: torch.Tensor): 306 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 307 | return self.clip_attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 308 | 309 | def forward(self,t_x): 310 | B = t_x.shape[0] 311 | if self.use_adapter: 312 | t_x = t_x + self.T_Adapter(self.attn(self.norm1(t_x))) 313 | t_xn = self.norm2(t_x) 314 | t_x = t_x + self.mlp(t_xn) + self.drop_path(self.scale * self.T_MLP_Adapter(t_xn)) 315 | else: 316 | t_x = t_x + self.drop_path(self.attn(self.norm1(t_x))) 317 | t_x = t_x + self.drop_path(self.mlp(self.norm2(t_x))) 318 | 319 | 320 | ############################################################################ 321 | 322 | return t_x 323 | 324 | class STCrossTransformer(nn.Module): 325 | """ Vision Transformer with support for patch or hybrid CNN input stage 326 | """ 327 | def __init__(self, 328 | img_size=224, 329 | patch_size=16, 330 | in_chans=3, 331 | num_classes=1000, 332 | embed_dim=768, 333 | depth=12, 334 | num_heads=12, 335 | mlp_ratio=4., 336 | qkv_bias=False, 337 | qk_scale=None, 338 | drop_rate=0., 339 | attn_drop_rate=0., 340 | drop_path_rate=0., 341 | norm_layer=nn.LayerNorm, 342 | init_values=0., 343 | use_learnable_pos_emb=False, 344 | init_scale=0., 345 | all_frames=16, 346 | tubelet_size=2, 347 | use_mean_pooling=True, 348 | composition=False, 349 | pretrained_cfg = None, 350 | use_adapter=False, 351 | fusion_method=None 352 | ): 353 | super().__init__() 354 | self.num_classes = num_classes 355 | self.embed_dim = embed_dim # num_features for consistency with other models 356 | self.tubelet_size = tubelet_size 357 | self.composition = composition 358 | self.patch_embed = PatchEmbed( 359 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, num_frames=all_frames, tubelet_size=self.tubelet_size) 360 | num_patches = self.patch_embed.num_patches 361 | self.use_adapter=use_adapter 362 | scale = embed_dim ** -0.5 363 | 364 | if use_learnable_pos_emb: 365 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 366 | else: 367 | # sine-cosine positional embeddings is on the way 368 | self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim) 369 | 370 | self.pos_drop = nn.Dropout(p=drop_rate) 371 | 372 | 373 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 374 | self.blocks = nn.ModuleList([ 375 | Block( 376 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 377 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 378 | init_values=init_values, num_layer=i,use_adapter=self.use_adapter) 379 | for i in range(depth)]) 380 | 381 | # self.clip_ln_post = LayerNorm(embed_dim) 382 | self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) 383 | self.vmae_fc_norm = norm_layer(embed_dim) if use_mean_pooling else None 384 | 385 | if self.composition: 386 | self.head_verb = nn.Linear(embed_dim, 97) 387 | self.head_noun = nn.Linear(embed_dim, 300) 388 | else: 389 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 390 | 391 | if use_learnable_pos_emb: 392 | trunc_normal_(self.pos_embed, std=.02) 393 | 394 | self.apply(self._init_weights) 395 | if self.use_adapter: 396 | self._init_adpater_weight() 397 | 398 | if self.composition: 399 | trunc_normal_(self.head_noun.weight, std=.02) 400 | trunc_normal_(self.head_verb.weight, 
std=.02) 401 | self.head_verb.weight.data.mul_(init_scale) 402 | self.head_verb.bias.data.mul_(init_scale) 403 | self.head_noun.weight.data.mul_(init_scale) 404 | self.head_noun.bias.data.mul_(init_scale) 405 | else: 406 | trunc_normal_(self.head.weight, std=.02) 407 | self.head.weight.data.mul_(init_scale) 408 | self.head.bias.data.mul_(init_scale) 409 | 410 | def _init_weights(self, m): 411 | if isinstance(m, nn.Linear): 412 | trunc_normal_(m.weight, std=.02) 413 | if isinstance(m, nn.Linear) and m.bias is not None: 414 | nn.init.constant_(m.bias, 0) 415 | elif isinstance(m, nn.LayerNorm): 416 | nn.init.constant_(m.bias, 0) 417 | nn.init.constant_(m.weight, 1.0) 418 | 419 | def _init_adpater_weight(self): 420 | for n, m in self.blocks.named_modules(): 421 | if 'Adapter' in n: 422 | for n2, m2 in m.named_modules(): 423 | if 'D_fc2' in n2: 424 | if isinstance(m2, nn.Linear): 425 | nn.init.constant_(m2.weight, 0) 426 | nn.init.constant_(m2.bias, 0) 427 | elif 'up' in n: 428 | for n2, m2 in m.named_modules(): 429 | if isinstance(m2, nn.Linear): 430 | nn.init.constant_(m2.weight, 0) 431 | nn.init.constant_(m2.bias, 0) 432 | 433 | 434 | def get_num_layers(self): 435 | return len(self.blocks) 436 | 437 | @torch.jit.ignore 438 | def no_weight_decay(self): 439 | return {'clip_temporal_embedding','pos_embed'} 440 | 441 | def get_classifier(self): 442 | return self.head 443 | 444 | def reset_classifier(self, num_classes, global_pool=''): 445 | self.num_classes = num_classes 446 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 447 | 448 | def reset_fcnorm(self): 449 | self.vmae_fc_norm = nn.LayerNorm(self.embed_dim) 450 | 451 | def forward_features(self, x): 452 | B = x.shape[0] 453 | ######################## VMAE spatial path ######################### 454 | t_x = self.patch_embed(x) 455 | 456 | if self.pos_embed is not None: 457 | t_x = t_x + self.pos_embed.expand(B, -1, -1).type_as(t_x).to(t_x.device).clone().detach() 458 | t_x = self.pos_drop(t_x) 459 | ##################################################################### 460 | 461 | for blk in self.blocks: 462 | t_x = blk(t_x) 463 | t_x = self.vmae_fc_norm(t_x.mean(1)) # all patch avg pooling 464 | 465 | return t_x 466 | 467 | 468 | def forward(self, x): 469 | if self.composition: 470 | t_x = self.forward_features(x) 471 | noun = self.head_noun(t_x) 472 | verb = self.head_verb(t_x) 473 | return noun, verb 474 | else: 475 | x = self.forward_features(x) 476 | x = self.head(x) 477 | return x 478 | 479 | 480 | 481 | 482 | @register_model 483 | def compo_videomae_adapter_vit_base_patch16_224(pretrained=False, **kwargs): 484 | model = STCrossTransformer( 485 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 486 | norm_layer=partial(nn.LayerNorm, eps=1e-6), composition=True,use_adapter=True, **kwargs) 487 | #model.default_cfg = _cfg() 488 | return model 489 | @register_model 490 | def videomae_adapter_vit_base_patch16_224(pretrained=False, **kwargs): 491 | model = STCrossTransformer( 492 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 493 | norm_layer=partial(nn.LayerNorm, eps=1e-6), composition=False,use_adapter=True, **kwargs) 494 | #model.default_cfg = _cfg() 495 | return model 496 | 497 | 498 | @register_model 499 | def compo_videomae_vit_base_patch16_224(pretrained=False, **kwargs): 500 | model = STCrossTransformer( 501 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 502 | 
norm_layer=partial(nn.LayerNorm, eps=1e-6), composition=True,use_adapter=False, **kwargs) 503 | #model.default_cfg = _cfg() 504 | return model 505 | @register_model 506 | def videomae_vit_base_patch16_224(pretrained=False, **kwargs): 507 | model = STCrossTransformer( 508 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 509 | norm_layer=partial(nn.LayerNorm, eps=1e-6), composition=False,use_adapter=False, **kwargs) 510 | #model.default_cfg = _cfg() 511 | return model 512 | 513 | 514 | -------------------------------------------------------------------------------- /models/clip_modelling_finetune.py: -------------------------------------------------------------------------------- 1 | # some codes from CLIP github(https://github.com/openai/CLIP), from VideoMAE github(https://github.com/MCG-NJU/VideoMAE) 2 | from functools import partial 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 8 | from timm.models.registry import register_model 9 | from collections import OrderedDict 10 | from einops import rearrange 11 | import random 12 | 13 | 14 | def _cfg(url='', **kwargs): 15 | return { 16 | 'url': url, 17 | 'num_classes': 400, 'input_size': (3, 224, 224), 'pool_size': None, 18 | 'crop_pct': .9, 'interpolation': 'bicubic', 19 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 20 | **kwargs 21 | } 22 | 23 | class DropPath(nn.Module): 24 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 25 | """ 26 | def __init__(self, drop_prob=None): 27 | super(DropPath, self).__init__() 28 | self.drop_prob = drop_prob 29 | 30 | def forward(self, x): 31 | return drop_path(x, self.drop_prob, self.training) 32 | 33 | def extra_repr(self) -> str: 34 | return 'p={}'.format(self.drop_prob) 35 | 36 | class Adapter(nn.Module): 37 | def __init__(self, dim, mlp_ratio=0.25, act_layer=nn.GELU, skip_connect=True): 38 | super().__init__() 39 | self.skip_connect = skip_connect 40 | down_dim = int(dim * mlp_ratio) 41 | self.act = act_layer() 42 | self.D_fc1 = nn.Linear(dim, down_dim) 43 | self.D_fc2 = nn.Linear(down_dim, dim) 44 | 45 | def forward(self, x): 46 | # x is (BT, HW+1, D) 47 | xs = self.D_fc1(x) 48 | xs = self.act(xs) 49 | xs = self.D_fc2(xs) 50 | if self.skip_connect: 51 | x = x + xs 52 | else: 53 | x = xs 54 | return x 55 | 56 | class LayerNorm(nn.LayerNorm): 57 | """Subclass torch's LayerNorm to handle fp16.""" 58 | def forward(self, x: torch.Tensor): 59 | orig_type = x.dtype 60 | if orig_type == torch.float16: 61 | ret = super().forward(x) 62 | elif orig_type == torch.float32: 63 | ret = super().forward(x.type(torch.float32)) 64 | return ret.type(orig_type) 65 | 66 | class QuickGELU(nn.Module): 67 | def forward(self, x: torch.Tensor): 68 | return x * torch.sigmoid(1.702 * x) 69 | 70 | class PatchEmbed(nn.Module): 71 | """ Image to Patch Embedding 72 | """ 73 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2): 74 | super().__init__() 75 | img_size = to_2tuple(img_size) 76 | patch_size = to_2tuple(patch_size) 77 | self.tubelet_size = int(tubelet_size) 78 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (num_frames // self.tubelet_size) 79 | self.img_size = img_size 80 | self.patch_size = patch_size 81 | self.num_patches = num_patches 82 | self.proj = nn.Conv3d(in_channels=in_chans, out_channels=embed_dim, 83 | kernel_size = 
(self.tubelet_size, patch_size[0],patch_size[1]), 84 | stride=(self.tubelet_size, patch_size[0], patch_size[1])) 85 | 86 | def forward(self, x, **kwargs): 87 | B, C, T, H, W = x.shape 88 | # FIXME look at relaxing size constraints 89 | assert H == self.img_size[0] and W == self.img_size[1], \ 90 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 91 | x = self.proj(x).flatten(2).transpose(1, 2) 92 | return x 93 | 94 | # sin-cos position encoding 95 | # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31 96 | def get_sinusoid_encoding_table(n_position, d_hid): 97 | ''' Sinusoid position encoding table ''' 98 | # TODO: make it with torch instead of numpy 99 | def get_position_angle_vec(position): 100 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 101 | 102 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 103 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 104 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 105 | 106 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 107 | 108 | class Mlp(nn.Module): 109 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 110 | super().__init__() 111 | out_features = out_features or in_features 112 | hidden_features = hidden_features or in_features 113 | self.fc1 = nn.Linear(in_features, hidden_features) 114 | self.act = act_layer() 115 | self.fc2 = nn.Linear(hidden_features, out_features) 116 | self.drop = nn.Dropout(drop) 117 | 118 | def forward(self, x): 119 | x = self.fc1(x) 120 | x = self.act(x) 121 | # x = self.drop(x) 122 | # dropout after fc1 is commented out to follow the original BERT implementation 123 | x = self.fc2(x) 124 | x = self.drop(x) 125 | return x 126 | 127 | # Keep the Attention name for convenience when loading existing (pretrained) weights. 
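# Note on the Attention block that follows: the fused qkv projection is created without a bias term;
# instead, separate q_bias / v_bias parameters are stored (the key bias stays fixed at zero) and the
# three are concatenated at forward time. This mirrors the VideoMAE/BEiT attention layout, so keeping
# these names lets pretrained checkpoints load without any key remapping.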
128 | class Attention(nn.Module): 129 | def __init__( 130 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., 131 | proj_drop=0., attn_head_dim=None): 132 | super().__init__() 133 | self.num_heads = num_heads 134 | head_dim = dim // num_heads 135 | if attn_head_dim is not None: 136 | head_dim = attn_head_dim 137 | all_head_dim = head_dim * self.num_heads 138 | self.scale = qk_scale or head_dim ** -0.5 139 | 140 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) 141 | if qkv_bias: 142 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) 143 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) 144 | else: 145 | self.q_bias = None 146 | self.v_bias = None 147 | 148 | self.attn_drop = nn.Dropout(attn_drop) 149 | self.proj = nn.Linear(all_head_dim, dim) 150 | self.proj_drop = nn.Dropout(proj_drop) 151 | 152 | def forward(self, x): 153 | B, N, C = x.shape 154 | qkv_bias = None 155 | if self.q_bias is not None: 156 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) 157 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 158 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 159 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 160 | s2t_q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 161 | 162 | s2t_q = s2t_q * self.scale 163 | attn = (s2t_q @ k.transpose(-2, -1)) 164 | 165 | 166 | attn = attn.softmax(dim=-1) 167 | attn = self.attn_drop(attn) 168 | 169 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 170 | x = self.proj(x) 171 | x = self.proj_drop(x) 172 | return x 173 | 174 | # spatial to temporal cross attention module. 175 | class CrossAttentionS2T(nn.Module): 176 | def __init__(self, dim: int, n_head: int, attn_mask: torch.Tensor = None): 177 | super().__init__() 178 | 179 | # add for cross-attn 180 | self.num_head = n_head 181 | head_dim = dim // self.num_head 182 | self.scale = head_dim ** -0.5 183 | all_head_dim = head_dim * self.num_head 184 | scale = dim ** -0.5 185 | self.space_time_pos = nn.Parameter(scale * torch.randn((197 * 8, dim))) 186 | 187 | # The spatial-to-temporal (s2t) cross-attention projections go here. 
188 | self.s2t_q = nn.Linear(dim, all_head_dim, bias=False) 189 | self.s2t_q_bias = nn.Parameter(torch.zeros(all_head_dim)) 190 | self.s2t_kv = nn.Linear(dim, all_head_dim * 2, bias=False) # 197 tokens(cls+patch) * num_frames 191 | self.s2t_kv_bias = nn.Parameter(torch.zeros(all_head_dim * 2)) 192 | 193 | self.t2s_proj = nn.Linear(all_head_dim, dim) 194 | 195 | self.attn_mask = attn_mask 196 | 197 | def s2t_cross_attn(self, s_x, t_x): # s_x=[n (b t) d], t_x=[b n d] 198 | B, _, _ = t_x.shape 199 | s_x = rearrange(s_x, 'n (b t) d -> b (t n) d', b=B) # batch -> token 200 | s_x = s_x + self.space_time_pos ## space-time position encoding 201 | s2t_q_bias = self.s2t_q_bias 202 | s2t_kv_bias = self.s2t_kv_bias 203 | 204 | s2t_q = F.linear(input=t_x, weight=self.s2t_q.weight, bias=s2t_q_bias) 205 | s2t_q = rearrange(s2t_q, 'b n (h d) -> b h n d', h=self.num_head) 206 | s2t_kv = F.linear(input=s_x, weight=self.s2t_kv.weight, bias=s2t_kv_bias) 207 | s2t_kv = rearrange(s2t_kv, 'b n (e h d) -> e b h n d',e=2, h=self.num_head) 208 | s2t_k, s2t_v = s2t_kv[0], s2t_kv[1] 209 | 210 | s2t_q = s2t_q * self.scale 211 | s2t_attn = (s2t_q @ s2t_k.transpose(-2, -1)) 212 | 213 | s2t_attn = s2t_attn.softmax(dim=-1) 214 | 215 | t_x = (s2t_attn @ s2t_v) 216 | t_x = rearrange(t_x, 'b h n d -> b n (h d)') 217 | t_x = self.t2s_proj(t_x) 218 | return t_x 219 | 220 | def forward(self, s_x: torch.Tensor, t_x: torch.Tensor): 221 | return self.s2t_cross_attn(s_x, t_x) 222 | 223 | 224 | # this code is from the CLIP github (https://github.com/openai/CLIP) 225 | class CrossAttentionT2S(nn.Module): # In VMAE terms this corresponds to the blocks class; a cross s2t_attn layer has to be added here. 226 | def __init__(self, dim: int, n_head: int, attn_mask: torch.Tensor = None): 227 | super().__init__() 228 | 229 | # add for cross-attn 230 | self.num_head = n_head 231 | head_dim = dim // self.num_head 232 | self.scale = head_dim ** -0.5 233 | all_head_dim = head_dim * self.num_head 234 | 235 | # The temporal-to-spatial (t2s) cross-attention projections go here. 
236 | self.t2s_q = nn.Linear(dim, all_head_dim, bias=False) # 197 tokens(cls+patch) * num_frames 237 | self.t2s_q_bias = nn.Parameter(torch.zeros(all_head_dim)) 238 | self.t2s_kv = nn.Linear(dim, all_head_dim * 2, bias=False) 239 | self.t2s_kv_bias = nn.Parameter(torch.zeros(all_head_dim * 2)) 240 | 241 | self.t2s_proj = nn.Linear(all_head_dim, dim) 242 | 243 | self.attn_mask = attn_mask 244 | 245 | def t2s_cross_attn(self, s_x, t_x): # s_x=[n (b t) d], t_x=[b n d] 246 | B, _, _ = t_x.shape 247 | s_x_cls, s_x_pat = s_x[0, :, :], s_x[1:, :, :] 248 | s_x_pat = rearrange(s_x_pat, 'n (b t) d -> (b n) t d', b=B) # batch -> token 249 | t_x = rearrange(t_x, 'b (t n) d -> (b n) t d', t=8) 250 | t2s_q_bias = self.t2s_q_bias 251 | t2s_kv_bias = self.t2s_kv_bias 252 | 253 | t2s_q = F.linear(input=s_x_pat, weight=self.t2s_q.weight, bias=t2s_q_bias) 254 | t2s_q = rearrange(t2s_q, 'b t (h d) -> b h t d', h=self.num_head) 255 | t2s_kv = F.linear(input=t_x, weight=self.t2s_kv.weight, bias=t2s_kv_bias) 256 | t2s_kv = rearrange(t2s_kv, 'b t (e h d) -> e b h t d',e=2, h=self.num_head) 257 | t2s_k, t2s_v = t2s_kv[0], t2s_kv[1] 258 | 259 | t2s_q = t2s_q * self.scale 260 | t2s_attn = (t2s_q @ t2s_k.transpose(-2, -1)) 261 | 262 | t2s_attn = t2s_attn.softmax(dim=-1) 263 | 264 | s_x_pat = (t2s_attn @ t2s_v) 265 | s_x_pat = rearrange(s_x_pat, 'b h n d -> b n (h d)') 266 | s_x_pat = self.t2s_proj(s_x_pat) 267 | s_x_pat = rearrange(s_x_pat,'(b n) t d -> n (b t) d', b=B) 268 | s_x = torch.cat([s_x_cls.unsqueeze(0), s_x_pat], dim=0) 269 | return s_x 270 | 271 | def forward(self, s_x: torch.Tensor, t_x: torch.Tensor): 272 | return self.t2s_cross_attn(s_x, t_x) 273 | 274 | 275 | class Block(nn.Module): 276 | 277 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 278 | drop_path=0., init_values=None, num_layer=0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_head_dim=None,use_adapter=False): 279 | super().__init__() 280 | self.cross = None 281 | self.num_layer = num_layer 282 | self.num_heads = num_heads 283 | self.scale = 0.5 284 | mlp_hidden_dim = int(dim * mlp_ratio) 285 | self.act = act_layer() 286 | self.use_adapter=use_adapter 287 | self.num_frames=16 288 | ############################ AIM MHSA ########################### 289 | self.clip_ln_1 = LayerNorm(dim) 290 | if self.use_adapter: 291 | # self.time_attn = nn.MultiheadAttention(dim, num_heads) 292 | self.T_Adapter = Adapter(dim, skip_connect=False) 293 | self.clip_attn = nn.MultiheadAttention(dim, num_heads) 294 | if self.use_adapter: 295 | self.S_Adapter = Adapter(dim) 296 | 297 | 298 | 299 | ############################ AIM FFN ############################### 300 | self.clip_ln_2 = LayerNorm(dim) 301 | self.clip_mlp = nn.Sequential(OrderedDict([ 302 | ("c_fc", nn.Linear(dim, dim * 4)), 303 | ("gelu", QuickGELU()), 304 | ("c_proj", nn.Linear(dim * 4, dim)) 305 | ])) 306 | if self.use_adapter: 307 | self.S_MLP_Adapter = Adapter(dim, skip_connect=False) 308 | self.attn_mask = None 309 | 310 | 311 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 312 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 313 | 314 | def attention(self, x: torch.Tensor): 315 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 316 | return self.clip_attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 317 | def time_attention(self, x:torch.Tensor): 318 | return self.time_attn(x,x,x,need_weights=False,attn_mask=None)[0] 319 | 320 | def forward(self,s_x): 321 | n, bt, d = s_x.shape 322 | 323 | if self.use_adapter: 324 | ############################ AIM TIME ############################# 325 | xt = rearrange(s_x, 'n (b t) d -> t (b n) d', t=self.num_frames) 326 | xt = self.T_Adapter(self.attention(self.clip_ln_1(xt))) 327 | xt = rearrange(xt, 't (b n) d -> n (b t) d', n=n) 328 | ########################################################## 329 | s_x = s_x + self.drop_path(xt) # skip connection original + time attention result 330 | # AIM Space MHSA 331 | s_x = s_x + self.S_Adapter(self.attention(self.clip_ln_1(s_x))) # original space multi head self attention 332 | ############################ FFN Forward ################################## 333 | s_xn = self.clip_ln_2(s_x) 334 | s_x = s_x + self.clip_mlp(s_xn) + self.drop_path(self.scale * self.S_MLP_Adapter(s_xn)) 335 | ############################################################################ 336 | else: 337 | s_x = s_x + self.attention(self.clip_ln_1(s_x)) 338 | s_x = s_x + self.clip_mlp(self.clip_ln_2(s_x)) 339 | return s_x 340 | 341 | class STCrossTransformer(nn.Module): 342 | """ Vision Transformer with support for patch or hybrid CNN input stage 343 | """ 344 | def __init__(self, 345 | img_size=224, 346 | patch_size=16, 347 | in_chans=3, 348 | num_classes=1000, 349 | embed_dim=768, 350 | depth=12, 351 | num_heads=12, 352 | mlp_ratio=4., 353 | qkv_bias=False, 354 | qk_scale=None, 355 | drop_rate=0., 356 | attn_drop_rate=0., 357 | drop_path_rate=0., 358 | norm_layer=nn.LayerNorm, 359 | init_values=0., 360 | use_learnable_pos_emb=False, 361 | init_scale=0., 362 | all_frames=16, 363 | tubelet_size=2, 364 | use_mean_pooling=True, 365 | composition=False, 366 | pretrained_cfg = None, 367 | use_adapter=False, 368 | fusion_method=None, 369 | ): 370 | super().__init__() 371 | self.num_classes = num_classes 372 | self.embed_dim = embed_dim # num_features for consistency with other models 373 | self.tubelet_size = tubelet_size 374 | self.composition = composition 375 | self.use_adapter=use_adapter 376 | scale = embed_dim ** -0.5 377 | self.fusion_method=fusion_method 378 | self.clip_conv1 = nn.Conv2d(in_channels=3, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, bias=False) 379 | self.clip_class_embedding = nn.Parameter(scale * torch.randn(embed_dim)) 380 | self.clip_positional_embedding = nn.Parameter(scale * torch.randn((img_size // patch_size) ** 2 + 1, embed_dim)) 381 | if self.use_adapter: 382 | self.clip_temporal_embedding = nn.Parameter(torch.zeros(1, all_frames, embed_dim)) 383 | self.clip_ln_pre = LayerNorm(embed_dim) 384 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 385 | self.blocks = nn.ModuleList([ 386 | Block( 387 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 388 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 389 | init_values=init_values, num_layer=i,use_adapter=self.use_adapter) 390 | for i in range(depth)]) 391 | 392 | self.clip_ln_post = LayerNorm(embed_dim) 393 | 394 | if self.composition: 395 | 
self.head_verb = nn.Linear(embed_dim, 97) 396 | self.head_noun = nn.Linear(embed_dim, 300) 397 | else: 398 | self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 399 | 400 | if use_learnable_pos_emb: 401 | trunc_normal_(self.pos_embed, std=.02) 402 | 403 | self.apply(self._init_weights) 404 | if self.use_adapter: 405 | self._init_adpater_weight() 406 | 407 | if self.composition: 408 | trunc_normal_(self.head_noun.weight, std=.02) 409 | trunc_normal_(self.head_verb.weight, std=.02) 410 | self.head_verb.weight.data.mul_(init_scale) 411 | self.head_verb.bias.data.mul_(init_scale) 412 | self.head_noun.weight.data.mul_(init_scale) 413 | self.head_noun.bias.data.mul_(init_scale) 414 | else: 415 | trunc_normal_(self.head.weight, std=.02) 416 | # self.head.weight.data.mul_(init_scale) 417 | # self.head.bias.data.mul_(init_scale) 418 | 419 | def _init_weights(self, m): 420 | if isinstance(m, nn.Linear): 421 | trunc_normal_(m.weight, std=.02) 422 | if isinstance(m, nn.Linear) and m.bias is not None: 423 | nn.init.constant_(m.bias, 0) 424 | elif isinstance(m, nn.LayerNorm): 425 | nn.init.constant_(m.bias, 0) 426 | nn.init.constant_(m.weight, 1.0) 427 | 428 | def _init_adpater_weight(self): 429 | for n, m in self.blocks.named_modules(): 430 | if 'Adapter' in n: 431 | for n2, m2 in m.named_modules(): 432 | if 'D_fc2' in n2: 433 | if isinstance(m2, nn.Linear): 434 | nn.init.constant_(m2.weight, 0) 435 | nn.init.constant_(m2.bias, 0) 436 | elif 'up' in n: 437 | for n2, m2 in m.named_modules(): 438 | if isinstance(m2, nn.Linear): 439 | nn.init.constant_(m2.weight, 0) 440 | nn.init.constant_(m2.bias, 0) 441 | 442 | 443 | def get_num_layers(self): 444 | return len(self.blocks) 445 | 446 | @torch.jit.ignore 447 | def no_weight_decay(self): 448 | return {'clip_temporal_embedding','pos_embed'} 449 | 450 | def get_classifier(self): 451 | return self.head 452 | 453 | def reset_classifier(self, num_classes, global_pool=''): 454 | self.num_classes = num_classes 455 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 456 | 457 | def reset_fcnorm(self): 458 | self.vmae_fc_norm = nn.LayerNorm(self.embed_dim) 459 | 460 | def forward_features(self, x): 461 | B = x.shape[0] 462 | ######################## AIM spatial path ######################### 463 | s_x= x 464 | s_t = s_x.shape[2] 465 | s_x = rearrange(s_x, 'b c t h w -> (b t) c h w') 466 | s_x = self.clip_conv1(s_x) # shape = [*, embeddim, grid, grid] 467 | s_x = s_x.reshape(s_x.shape[0], s_x.shape[1], -1) # [*, embeddim, grid**2] 468 | s_x = s_x.permute(0, 2, 1) # shape[batch, patchnum, embeddim] 469 | s_x = torch.cat([self.clip_class_embedding.to(s_x.dtype) + torch.zeros(s_x.shape[0], 1, s_x.shape[-1], dtype=s_x.dtype, device=s_x.device), s_x], dim=1) 470 | s_x = s_x + self.clip_positional_embedding.to(s_x.dtype) 471 | n = s_x.shape[1] 472 | if self.use_adapter: 473 | s_x = rearrange(s_x, '(b t) n d -> (b n) t d', t=s_t) 474 | s_x = s_x + self.clip_temporal_embedding#(1,t,d) 475 | s_x = rearrange(s_x, '(b n) t d -> (b t) n d', n=n) 476 | s_x = self.clip_ln_pre(s_x) 477 | ##################################################################### ##################################################################### 478 | 479 | 480 | s_x = s_x.permute(1,0,2) 481 | for blk in self.blocks: 482 | s_x = blk(s_x) 483 | s_x = s_x.permute(1,0,2) 484 | 485 | s_x = rearrange(s_x, '(b t) n d -> b t n d', b=B) 486 | s_x = self.clip_ln_post(s_x[:,:,0,:].mean(1)) # all cls tokens avg pooling 487 | 488 | 489 | 
return s_x 490 | 491 | 492 | def forward(self, x): 493 | if self.composition: 494 | s_x = self.forward_features(x) 495 | noun = self.head_noun(s_x) 496 | verb = self.head_verb(s_x) 497 | return noun, verb 498 | else: 499 | x = self.forward_features(x) 500 | x = self.head(x) 501 | return x 502 | 503 | 504 | 505 | @register_model 506 | def compo_clip_vit_base_patch16_224(pretrained=False, **kwargs): 507 | model = STCrossTransformer( 508 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 509 | norm_layer=partial(nn.LayerNorm, eps=1e-6), composition=True,use_adapter=False,**kwargs) 510 | #model.default_cfg = _cfg() 511 | return model 512 | @register_model 513 | def clip_vit_base_patch16_224(pretrained=False, **kwargs): 514 | model = STCrossTransformer( 515 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 516 | norm_layer=partial(nn.LayerNorm, eps=1e-6), composition=False,use_adapter=False,**kwargs) 517 | #model.default_cfg = _cfg() 518 | return model 519 | 520 | --------------------------------------------------------------------------------
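For reference, a minimal usage sketch of the entry points registered above, assuming the file is importable as models.clip_modelling_finetune; the import path and the init_scale / drop_path_rate values below are illustrative rather than prescribed by the repo:

import torch
from models.clip_modelling_finetune import compo_clip_vit_base_patch16_224  # assumed import path

# Composition model: a shared CLIP spatial backbone with separate noun (300-way) and verb (97-way) heads.
model = compo_clip_vit_base_patch16_224(init_scale=1.0, drop_path_rate=0.2)
video = torch.randn(2, 3, 16, 224, 224)      # [B, C, T, H, W]: 16 frames at 224x224
noun_logits, verb_logits = model(video)
print(noun_logits.shape, verb_logits.shape)  # torch.Size([2, 300]) torch.Size([2, 97])

The single-head variants (clip_vit_base_patch16_224, and the videomae entry points in the previous file) instead return one logits tensor whose width is set by num_classes.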