├── README.md ├── dataset ├── __init__.py ├── kinetics.py ├── prepare_dataset.py ├── something_something.py ├── transform.py └── transforms.py ├── main.py ├── models ├── SmallBig.py ├── __init__.py └── blocks.py ├── scripts └── kinetics.sh ├── tools ├── __init__.py └── tools.py └── train_val.py /README.md: -------------------------------------------------------------------------------- 1 | # SmallBigNet 2 | 3 | 4 | This repo is the official implementation of our paper ["SmallBigNet: Integrating Core and Contextual Views for Video Classification (CVPR2020)"](https://arxiv.org/abs/2006.14582). 5 | 6 | ## Citation 7 | 8 | 9 | ``` 10 | @inproceedings{li2020smallbignet, 11 | title={Smallbignet: Integrating core and contextual views for video classification}, 12 | author={Li, Xianhang and Wang, Yali and Zhou, Zhipeng and Qiao, Yu}, 13 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 14 | year={2020} 15 | } 16 | 17 | ``` 18 | 19 | 20 | 21 | ## Usage 22 | 23 | ### Data Preparation 24 | First, please follow [mmaction2](https://github.com/open-mmlab/mmaction2/blob/master/tools/data/kinetics/README.md) to prepare the data. Note that our codebase only supports **RGB** frames, so you may need to decode the videos into frames offline and store them on an SSD. 25 | If you need the Kinetics-400 dataset, please feel free to email me. 26 | (Tip: if you want to decode videos online, we highly recommend using mmaction2. Our idea is simple, so only a few lines of code need to change in [resnet3d.py](https://github.com/open-mmlab/mmaction2/blob/master/mmaction/models/backbones/resnet3d.py).) 27 | 28 | 29 | ### K400 Training Scripts 30 | 31 | After you prepare the dataset, edit the parameters in ``scripts/kinetics.sh``. 32 | 33 | `` 34 | --half indicates using mixed precision 35 | `` 36 | 37 | `` 38 | --root_path the path where you store the whole dataset (RGB frames) 39 | `` 40 | 41 | `` 42 | --train_list_file the training list file (video_name num_frames label) 43 | `` 44 | 45 | 46 | `` 47 | --val_list_file the validation list file (video_name num_frames label) 48 | `` 49 | 50 | `` 51 | --model_name [res50, slowonly50, slowonly50_extra, smallbig50_no_extra, smallbig50_extra, smallbig101_no_extra] 52 | `` 53 | 54 | `` 55 | --image_tmpl the filename format of the stored RGB frames, e.g. img_{:05d}.jpg 56 | `` 57 | 58 | 59 | ---------------- 60 | 61 | If you have any questions about the code or data, please contact us directly.
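The list files referenced above are plain text, one line per video in the form ``video_name num_frames label``. If you extract frames yourself, a minimal sketch for generating such a file could look like the following (the directory layout, the ``label_map`` dict, and the paths in the usage line are assumptions for illustration, not part of this repo):

```python
import os

def build_list_file(frames_root, label_map, out_path, ext='.jpg'):
    """Write one `video_name num_frames label` line per frame folder under frames_root."""
    with open(out_path, 'w') as out:
        for name in sorted(os.listdir(frames_root)):
            frame_dir = os.path.join(frames_root, name)
            if not os.path.isdir(frame_dir):
                continue
            num_frames = len([f for f in os.listdir(frame_dir) if f.endswith(ext)])
            out.write('{} {} {}\n'.format(name, num_frames, label_map[name]))

# e.g. build_list_file('/dataset/kinetics/RGB_train', train_labels, '/dataset/kinetics/train.txt')
```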
62 | 63 | 64 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xhl-video/SmallBigNet/9e6d9ea4b61a0efb87893ab830463f56d0c5c8b4/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/kinetics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch.utils.data as data 4 | 5 | import os 6 | import os.path 7 | import numpy as np 8 | import torchvision 9 | from numpy.random import randint 10 | 11 | import torch 12 | 13 | from dataset import transforms 14 | from dataset.transform import * 15 | 16 | 17 | class VideoRecord(object): 18 | def __init__( 19 | self, 20 | row, 21 | root_path, 22 | phase='Train', 23 | copy_id=0, 24 | crop=0, 25 | vid=0): 26 | self._data = row 27 | self.crop_pos = crop 28 | 29 | self.phase = phase 30 | self.copy_id = copy_id 31 | self.vid = vid 32 | self._root_path = root_path 33 | 34 | @property 35 | def path(self): 36 | if self.phase == 'Train': 37 | return os.path.join( 38 | self._root_path, self._data[0].replace( 39 | 'RGB_train/', '')) 40 | else: 41 | return os.path.join( 42 | self._root_path, 43 | self._data[0].replace( 44 | 'RGB_val/', 45 | '')) 46 | 47 | @property 48 | def num_frames(self): 49 | return int(self._data[1]) 50 | 51 | @property 52 | def label(self): 53 | return int(self._data[2]) 54 | 55 | 56 | class Kinetics(data.Dataset): 57 | def __init__(self, 58 | root_path, 59 | list_file, 60 | t_length=8, 61 | t_stride=8, 62 | num_clips=1, 63 | image_tmpl='img_{:05d}.jpg', 64 | transform=None, 65 | crop_num=1, 66 | style="Dense", 67 | phase="Train", 68 | seed=1): 69 | """ 70 | :style: Dense, for 2D and 3D model, and Sparse for TSN model 71 | :phase: Train, Val, Test 72 | """ 73 | 74 | self.root_path = root_path 75 | self.list_file = list_file 76 | self.crop_num = crop_num 77 | self.t_length = t_length 78 | self.t_stride = t_stride 79 | self.num_clips = num_clips 80 | self.image_tmpl = image_tmpl 81 | self.transform = transform 82 | self.video_frame = {} 83 | self.rng = np.random.RandomState(seed) 84 | 85 | assert(style in ("Dense", "UnevenDense") 86 | ), "Only support Dense and UnevenDense" 87 | self.style = style 88 | self.phase = phase 89 | 90 | assert(t_length > 0), "Length of time must be bigger than zero." 91 | assert(t_stride > 0), "Stride of time must be bigger than zero." 
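A quick sanity check on the sampling arguments above: one clip of ``t_length`` frames taken every ``t_stride`` frames spans ``(t_length - 1) * t_stride + 1`` source frames (57 with the default 8 x 8 setting); ``dense_sampler`` below falls back to a smaller stride for shorter videos. A small helper for checking a list file against that span (a sketch; the path in the usage line is only an example):

```python
def clips_fit(list_file, t_length=8, t_stride=8):
    """Count how many `name num_frames label` entries can hold one full clip."""
    span = (t_length - 1) * t_stride + 1               # e.g. 57 frames for the 8 x 8 default
    ok = total = 0
    for line in open(list_file):
        num_frames = int(line.strip().split(' ')[1])   # same parsing as _parse_list below
        total += 1
        ok += num_frames >= span
    return ok, total

# e.g. ok, total = clips_fit('/dataset/kinetics/train.txt')
```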
92 | 93 | self._parse_list() 94 | 95 | def _load_image(self, directory, idx): 96 | from PIL import Image 97 | if os.path.exists( 98 | os.path.join( 99 | directory, 100 | self.image_tmpl.format(idx))): 101 | 102 | cv_img = Image.open( 103 | os.path.join( 104 | directory, 105 | self.image_tmpl.format(idx))).convert('RGB') 106 | else: 107 | while True: 108 | idx += 1 109 | if os.path.exists( 110 | os.path.join( 111 | directory, 112 | self.image_tmpl.format(idx))): 113 | cv_img = Image.open( 114 | os.path.join( 115 | directory, 116 | self.image_tmpl.format(idx))).convert('RGB') 117 | break 118 | return [cv_img] 119 | 120 | def _parse_list(self): 121 | self.video_list = [] 122 | if self.phase == 'Fntest': 123 | vid = 0 124 | for x in open(self.list_file): 125 | idx = 0 126 | for j in range(self.crop_num): 127 | 128 | for i in range(self.num_clips): 129 | data = x.strip().split(' ')[0] 130 | name = data.split('/')[-1].split('.')[0][0:11] 131 | path = self.root_path 132 | if os.path.exists(os.path.join(path, name)): 133 | self.video_list.append(VideoRecord([name, x.strip().split(' ')[1], int(x.strip( 134 | ).split(' ')[2])], self.root_path, phase='Val', copy_id=idx, crop=j, vid=vid)) 135 | idx += 1 136 | vid += 1 137 | 138 | elif self.phase == 'Val': 139 | for x in open(self.list_file): 140 | data = x.strip().split(' ')[0] 141 | name = data.split('/')[-1].split('.')[0][0:11] 142 | # name = os.path.join('val', name[0:11]) 143 | path = self.root_path 144 | if os.path.exists(os.path.join(path, name)): 145 | self.video_list.append(VideoRecord([name, x.strip().split(' ')[1], int( 146 | x.strip().split(' ')[2])], self.root_path, phase='Val', )) 147 | 148 | elif self.phase == 'Train': 149 | for x in open(self.list_file): 150 | idx = 0 151 | for i in range(self.num_clips): 152 | data = x.strip().split(' ')[0] 153 | name = data.split('/')[-1].split('.')[0][0:11] 154 | #name = os.path.join('train', name[0:11]) 155 | path = self.root_path 156 | # print(name) 157 | if os.path.exists(os.path.join(path, name)): 158 | self.video_list.append(VideoRecord([name, x.strip().split(' ')[1], int( 159 | x.strip().split(' ')[2])], self.root_path, phase='Train')) 160 | idx += 1 161 | self.rng.shuffle(self.video_list) 162 | 163 | @staticmethod 164 | def dense_sampler(num_frames, length, stride=1): 165 | t_length = length 166 | t_stride = stride 167 | offset = 0 168 | average_duration = num_frames - (t_length - 1) * t_stride - 1 169 | if average_duration >= 0: 170 | offset = randint(average_duration + 1) 171 | elif num_frames > t_length: 172 | while (t_stride - 1 > 0): 173 | t_stride -= 1 174 | average_duration = num_frames - (t_length - 1) * t_stride - 1 175 | if average_duration >= 0: 176 | offset = randint(average_duration + 1) 177 | break 178 | assert (t_stride >= 1), "temporal stride must be bigger than zero." 
179 | else: 180 | t_stride = 1 181 | # sampling 182 | samples = [] 183 | for i in range(t_length): 184 | samples.append(offset + i * t_stride + 1) 185 | return samples 186 | 187 | def _sample_indices(self, record): 188 | """ 189 | :param record: VideoRecord 190 | :return: list 191 | """ 192 | if self.style == "Dense": 193 | frames = [] 194 | average_duration = record.num_frames / self.num_clips 195 | offsets = [average_duration * i for i in range(self.num_clips)] 196 | for i in range(self.num_clips): 197 | samples = self.dense_sampler( 198 | average_duration, self.t_length, self.t_stride) 199 | samples = [int(sample + offsets[i]) for sample in samples] 200 | frames.append(samples) 201 | return {"dense": frames[0]} 202 | 203 | def _get_val_indices(self, record): 204 | """ 205 | get indices in val phase 206 | """ 207 | valid_offset_range = record.num_frames - \ 208 | (self.t_length - 1) * self.t_stride 209 | if valid_offset_range > 0: 210 | offset = randint(valid_offset_range) 211 | else: 212 | offset = valid_offset_range 213 | if offset < 0: 214 | offset = 0 215 | samples = [] 216 | for i in range(self.t_length): 217 | samples.append(offset + i * self.t_stride + 1) 218 | return {"dense": samples} 219 | 220 | def _get_test_index(self, record): 221 | sample_op = max(1, record.num_frames - self.t_length * self.t_stride) 222 | t_stride = self.t_stride 223 | start_list = np.linspace( 224 | 0, sample_op - 1, num=self.num_clips, dtype=int) 225 | offsets = [] 226 | for start in start_list.tolist(): 227 | offsets.append([(idx * t_stride + start) % 228 | record.num_frames for idx in range(self.t_length)]) 229 | frame = np.array(offsets) + 1 230 | return {"dense": frame.tolist()} 231 | 232 | def __getitem__(self, index): 233 | record = self.video_list[index] 234 | 235 | if self.phase == "Train": 236 | indices = self._sample_indices(record) 237 | return self.get(record, indices, self.phase, index) 238 | elif self.phase == "Val": 239 | indices = self._get_val_indices(record) 240 | return self.get(record, indices, self.phase, index) 241 | elif self.phase == "Fntest": 242 | 243 | indices = self._get_test_indices(record) 244 | 245 | idx = record.copy_id % self.num_clips 246 | 247 | indices['dense'] = indices['dense'][idx] 248 | 249 | return self.get(record, indices, self.phase, index) 250 | else: 251 | raise TypeError("Unsuported phase {}".format(self.phase)) 252 | 253 | def get(self, record, indices, phase, index): 254 | # dense process data 255 | def dense_process_data(index): 256 | images = list() 257 | for ind in indices['dense']: 258 | ptr = int(ind) 259 | 260 | if ptr <= record.num_frames: 261 | imgs = self._load_image(record.path, ptr) 262 | else: 263 | imgs = self._load_image(record.path, record.num_frames) 264 | images.extend(imgs) 265 | 266 | if self.phase == 'Fntest': 267 | 268 | images = [np.asarray(im) for im in images] 269 | clip_input = np.concatenate(images, axis=2) 270 | 271 | self.t = transforms.Compose([ 272 | transforms.Resize(256)]) 273 | clip_input = self.t(clip_input) 274 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 275 | std=[0.229, 0.224, 0.225]) 276 | if record.crop_pos == 0: 277 | self.transform = transforms.Compose([ 278 | 279 | transforms.CenterCrop((256, 256)), 280 | 281 | transforms.ToTensor(), 282 | normalize, 283 | ]) 284 | elif record.crop_pos == 1: 285 | self.transform = transforms.Compose([ 286 | 287 | transforms.CornerCrop2((256, 256)), 288 | 289 | transforms.ToTensor(), 290 | normalize, 291 | ]) 292 | elif record.crop_pos == 2: 293 | self.transform = 
transforms.Compose([ 294 | transforms.CornerCrop1((256, 256)), 295 | 296 | transforms.ToTensor(), 297 | normalize, 298 | ]) 299 | 300 | return self.transform(clip_input) 301 | 302 | if self.phase == 'Train': 303 | return self.transform(images) 304 | if self.phase == 'Val': 305 | return self.transform(images) 306 | 307 | if phase == "Train": 308 | if self.style == "Dense": 309 | process_data = dense_process_data(index) 310 | elif phase in ("Val", "Test"): 311 | process_data = dense_process_data(index) 312 | return process_data, record.label # , indices 313 | else: 314 | process_data = dense_process_data(index) 315 | return process_data, record.label # ,record.vid 316 | 317 | return process_data, record.label # , indices 318 | 319 | def __len__(self): 320 | return len(self.video_list) 321 | 322 | 323 | if __name__ == "__main__": 324 | parser = argparse.ArgumentParser(description='SmallBig Training') 325 | parser.add_argument('--batch_size', default=1, type=int, 326 | help="Total batch size for training.") 327 | parser.add_argument('--t_length', default=8, type=int, 328 | help="Total length of sampling frames.") 329 | parser.add_argument('--t_stride', default=8, type=int, 330 | help="Temporal stride between each frame.") 331 | parser.add_argument('--num_clips', default=1, type=int, 332 | help="Total number of clips for training or testing.") 333 | parser.add_argument( 334 | '--crop_num', 335 | default=1, 336 | type=int, 337 | help="Total number of crops for each frame during full-resolution testing.") 338 | parser.add_argument('--image_tmpl', default='image_{:06d}.jpg', type=str, 339 | help="The name format of each frames you saved.") 340 | parser.add_argument('--seed', default=0, type=int, 341 | help="Random Seed") 342 | parser.add_argument( 343 | '--phase', 344 | default='Val', 345 | choices=[ 346 | "Train", 347 | "Val", 348 | "Fntest"], 349 | help="Different phases have different sampling methods.") 350 | parser.add_argument('--root_path', default='/dataset/kinetics', type=str, 351 | help='root path for accessing your image data') 352 | parser.add_argument( 353 | '--list_file', 354 | default='/dataset/kinetics/val.txt', 355 | type=str, 356 | help='path for your data list(txt)') 357 | 358 | args = parser.parse_args() 359 | 360 | transform = torchvision.transforms.Compose([ 361 | GroupMultiScaleCrop(input_size=224, scales=[1, .875, .75, .66]), 362 | GroupRandomHorizontalFlip(), 363 | Stack(mode='3D'), 364 | ToTorchFormatTensor(), 365 | GroupNormalize(), 366 | ]) 367 | dataset = Kinetics( 368 | root_path=args.root_path, 369 | list_file=args.list_file, 370 | t_length=args.t_length, 371 | t_stride=args.t_stride, 372 | crop_num=args.crop_num, 373 | num_clips=args.num_clips, 374 | image_tmpl=args.image_tmpl, 375 | transform=transform, 376 | phase=args.phase, 377 | seed=args.seed) 378 | loader = torch.utils.data.DataLoader( 379 | dataset, 380 | batch_size=args.batch_size, 381 | shuffle=True, 382 | drop_last=True, 383 | num_workers=8, 384 | pin_memory=True) 385 | 386 | for ind, (data, label) in enumerate(loader): 387 | label = label.cuda(non_blocking=True) 388 | 389 | -------------------------------------------------------------------------------- /dataset/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | import torchvision.transforms as transforms 6 | import torchvision.datasets as datasets 7 | from dataset.kinetics import Kinetics 8 | from dataset.something_something import Someting_something 9 | from 
dataset.transform import * 10 | from dataset.transforms import Lighting, To_3DTensor 11 | from tools.tools import is_main_process 12 | 13 | __imagenet_pca = { 14 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), 15 | 'eigvec': torch.Tensor([ 16 | [-0.5675, 0.7192, 0.4009], 17 | [-0.5808, -0.0045, -0.8140], 18 | [-0.5836, -0.6948, 0.4203], 19 | ]) 20 | } 21 | 22 | def get_dataloader(args): 23 | if args.dataset == 'kinetics': 24 | train_transform = torchvision.transforms.Compose([ 25 | GroupMultiScaleCrop(input_size=224, scales=[1, .875, .75, .66]), 26 | GroupRandomHorizontalFlip(), 27 | Stack(mode='3D'), 28 | ToTorchFormatTensor(), 29 | GroupNormalize(), 30 | ]) 31 | train_dataset = Kinetics( 32 | root_path=args.root_path, 33 | list_file=args.train_list_file, 34 | t_length=args.t_length, 35 | t_stride=args.t_stride, 36 | crop_num=args.crop_num, 37 | num_clips=args.num_clips, 38 | image_tmpl=args.image_tmpl, 39 | transform=train_transform, 40 | phase='Train', 41 | seed=args.seed) 42 | val_transform = torchvision.transforms.Compose([ 43 | GroupScale(256), 44 | GroupCenterCrop(224), 45 | Stack(mode='3D'), 46 | ToTorchFormatTensor(), 47 | GroupNormalize(), 48 | ]) 49 | val_dataset = Kinetics( 50 | root_path=args.root_path, 51 | list_file=args.val_list_file, 52 | t_length=args.t_length, 53 | t_stride=args.t_stride, 54 | crop_num=args.crop_num, 55 | num_clips=args.num_clips, 56 | image_tmpl=args.image_tmpl, 57 | transform=val_transform, 58 | phase='Val', 59 | seed=args.seed) 60 | elif args.dataset == 'something': 61 | train_transform = torchvision.transforms.Compose([ 62 | GroupMultiScaleCrop(input_size=224, scales=[1, .875, .75, .66]), 63 | Stack(mode='3D'), 64 | ToTorchFormatTensor(), 65 | GroupNormalize(), 66 | ]) 67 | train_dataset = Someting_something( 68 | root_path=args.root_path, 69 | list_file=args.train_list_file, 70 | t_length=args.t_length, 71 | t_stride=args.t_stride, 72 | crop_num=args.crop_num, 73 | num_clips=args.num_clips, 74 | image_tmpl=args.image_tmpl, 75 | transform=train_transform, 76 | phase='Train', 77 | seed=args.seed) 78 | val_transform = torchvision.transforms.Compose([ 79 | GroupScale(256), 80 | GroupCenterCrop(224), 81 | Stack(mode='3D'), 82 | ToTorchFormatTensor(), 83 | GroupNormalize(), 84 | ]) 85 | val_dataset = Someting_something( 86 | root_path=args.root_path, 87 | list_file=args.val_list_file, 88 | t_length=args.t_length, 89 | t_stride=args.t_stride, 90 | crop_num=args.crop_num, 91 | num_clips=args.num_clips, 92 | image_tmpl=args.image_tmpl, 93 | transform=val_transform, 94 | phase='Val', 95 | seed=args.seed) 96 | elif args.dataset == 'imagenet': 97 | traindir = os.path.join(args.root_path, 'train') 98 | valdir = os.path.join(args.root_path, 'val') 99 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 100 | std=[0.229, 0.224, 0.225]) 101 | train_dataset = datasets.ImageFolder( 102 | traindir, 103 | transforms.Compose([ 104 | transforms.RandomResizedCrop(224), 105 | transforms.RandomHorizontalFlip(), 106 | transforms.ColorJitter(.4, .4, .4), 107 | transforms.ToTensor(), 108 | Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']), 109 | normalize, 110 | To_3DTensor() 111 | ])) 112 | val_dataset = datasets.ImageFolder( 113 | valdir, 114 | transforms.Compose([ 115 | transforms.Resize(256), 116 | transforms.CenterCrop(224), 117 | transforms.ToTensor(), 118 | normalize, 119 | To_3DTensor() 120 | ])) 121 | 122 | if args.distribute: 123 | train_sampler = torch.utils.data.distributed.DistributedSampler( 124 | train_dataset, 125 | shuffle=True, 
126 | num_replicas=torch.distributed.get_world_size(), 127 | rank=args.local_rank) 128 | 129 | val_sampler = torch.utils.data.distributed.DistributedSampler( 130 | val_dataset, 131 | shuffle=False, 132 | num_replicas=torch.distributed.get_world_size(), 133 | rank=args.local_rank) 134 | 135 | batch_size = args.batch_size // torch.distributed.get_world_size() 136 | 137 | else: 138 | train_sampler = None 139 | val_sampler = None 140 | batch_size = args.batch_size 141 | 142 | train_loader = torch.utils.data.DataLoader( 143 | train_dataset, 144 | batch_size=batch_size, 145 | shuffle=not args.distribute, 146 | drop_last=True, 147 | num_workers=8, 148 | sampler=train_sampler, 149 | pin_memory=True) 150 | 151 | val_loader = torch.utils.data.DataLoader( 152 | val_dataset, 153 | batch_size=batch_size, 154 | shuffle=False, 155 | drop_last=False, 156 | sampler=val_sampler, 157 | num_workers=8, 158 | pin_memory=True) 159 | 160 | dataloaders = {'train': train_loader, 'val': val_loader} 161 | samplers = {'train': train_sampler, 'val': val_sampler} 162 | dataset_sizes = {x: len(dataloaders[x].dataset) for x in ['train', 'val']} 163 | if is_main_process(): 164 | print(args.dataset, 'has the size: ', dataset_sizes) 165 | return dataloaders, dataset_sizes, samplers 166 | 167 | 168 | if __name__ == "__main__": 169 | parser = argparse.ArgumentParser(description='SmallBig Training') 170 | parser.add_argument('--batch_size', default=1, type=int, 171 | help="Total batch size for training.") 172 | parser.add_argument('--t_length', default=8, type=int, 173 | help="Total length of sampling frames.") 174 | parser.add_argument('--t_stride', default=8, type=int, 175 | help="Temporal stride between each frame.") 176 | parser.add_argument('--num_clips', default=1, type=int, 177 | help="Total number of clips for training or testing.") 178 | parser.add_argument( 179 | '--crop_num', 180 | default=1, 181 | type=int, 182 | help="Total number of crops for each frame during full-resolution testing.") 183 | parser.add_argument('--image_tmpl', default='image_{:06d}.jpg', type=str, 184 | help="The name format of each frames you saved.") 185 | parser.add_argument('--seed', default=0, type=int, 186 | help="Random Seed") 187 | parser.add_argument( 188 | '--dataset', 189 | default='kinetics', 190 | choices=[ 191 | "kinetics", 192 | "something"], 193 | help="Choose dataset for training and validation") 194 | parser.add_argument( 195 | '--phase', 196 | default='Val', 197 | choices=[ 198 | "Train", 199 | "Val", 200 | "Fntest"], 201 | help="Different phases have different sampling methods.") 202 | parser.add_argument('--root_path', default='/dataset/kinetics', type=str, 203 | help='root path for accessing your image data') 204 | parser.add_argument( 205 | '--val_list_file', 206 | default='/dataset/kinetics/val.txt', 207 | type=str, 208 | help='path for your data list(txt)') 209 | parser.add_argument( 210 | '--train_list_file', 211 | default='/dataset/kinetics/train.txt', 212 | type=str, 213 | help='path for your data list(txt)') 214 | 215 | parser.add_argument( 216 | '--local_rank', 217 | type=int, 218 | default=0, 219 | help='node rank for distributed training') 220 | parser.add_argument('--distribute', action='store_true') 221 | 222 | args = parser.parse_args() 223 | 224 | dataloaders, dataset_sizes, samplers = get_dataloader(args) 225 | for k, v in dataset_sizes.items(): 226 | print(k, v) 227 | -------------------------------------------------------------------------------- /dataset/something_something.py: 
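One note on ``get_dataloader`` above: it returns the samplers alongside the loaders because, when ``--distribute`` is set, ``DistributedSampler`` only reshuffles across epochs if ``set_epoch`` is called before each pass. A sketch of how a training loop might consume the returned dict (the epoch count and loop body are placeholders, not this repo's actual training code):

```python
dataloaders, dataset_sizes, samplers = get_dataloader(args)

for epoch in range(num_epochs):                 # num_epochs is assumed to be defined elsewhere
    if samplers['train'] is not None:           # None unless --distribute is used
        samplers['train'].set_epoch(epoch)      # gives each rank a new shuffle this epoch
    for clips, labels in dataloaders['train']:
        pass                                    # forward / backward / optimizer step
```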
-------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import cv2 4 | import torch.utils.data as data 5 | 6 | import os 7 | import os.path 8 | import numpy as np 9 | import torchvision 10 | from numpy.random import randint 11 | 12 | import torch 13 | 14 | from dataset import transforms 15 | from dataset.transform import * 16 | 17 | 18 | class VideoRecord(object): 19 | def __init__( 20 | self, 21 | row, 22 | root_path, 23 | phase='Train', 24 | copy_id=0, 25 | crop=0, 26 | vid=0): 27 | self._data = row 28 | self.crop_pos = crop 29 | 30 | self.phase = phase 31 | self.copy_id = copy_id 32 | self.vid = vid 33 | self._root_path = root_path 34 | 35 | @property 36 | def path(self): 37 | if self.phase == 'Train': 38 | return os.path.join( 39 | self._root_path, self._data[0].replace( 40 | 'RGB_train/', '')) 41 | else: 42 | return os.path.join( 43 | self._root_path, 44 | self._data[0].replace( 45 | 'RGB_val/', 46 | '')) 47 | 48 | @property 49 | def num_frames(self): 50 | return int(self._data[1]) 51 | 52 | @property 53 | def label(self): 54 | return int(self._data[2]) 55 | 56 | 57 | class Someting_something(data.Dataset): 58 | def __init__(self, 59 | root_path, 60 | list_file, 61 | t_length=32, 62 | t_stride=2, 63 | num_clips=10, 64 | image_tmpl='img_{:05d}.jpg', 65 | transform=None, 66 | crop_num=1, 67 | style="Dense", 68 | phase="Train", 69 | seed=1): 70 | """ 71 | :style: Dense, for 2D and 3D model, and Sparse for TSN model 72 | :phase: Train, Val, Test 73 | """ 74 | 75 | self.root_path = root_path 76 | self.list_file = list_file 77 | self.crop_num = crop_num 78 | self.t_length = t_length 79 | self.t_stride = t_stride 80 | self.image_tmpl = image_tmpl 81 | self.transform = transform 82 | self.n_times = num_clips 83 | self.rng = np.random.RandomState(seed) 84 | assert(style in ("Dense", "UnevenDense") 85 | ), "Only support Dense and UnevenDense" 86 | self.style = style 87 | self.phase = phase 88 | 89 | assert(t_length > 0), "Length of time must be bigger than zero." 90 | assert(t_stride > 0), "Stride of time must be bigger than zero." 
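Unlike the Kinetics loader above, the Something-Something sampler defined below is TSN-style: it splits each video into ``t_length`` equal segments and draws one frame per segment, so ``t_stride`` is effectively unused during training. For example, with ``num_frames=100`` and ``t_length=8`` the segment length is 12, and the sampled 1-indexed frames are ``[0, 12, 24, ..., 84] + r + 1`` with each ``r`` drawn independently from ``[0, 12)``. A standalone restatement of that logic, for illustration only (not the code path the loader actually calls):

```python
import numpy as np
from numpy.random import randint

def tsn_train_indices(num_frames, t_length):
    """One random frame per equal-length segment, 1-indexed (mirrors _sample_indices below)."""
    seg = num_frames // t_length
    if seg > 0:
        offsets = np.arange(t_length) * seg + randint(seg, size=t_length)
    elif num_frames > t_length:
        offsets = np.sort(randint(num_frames, size=t_length))
    else:
        offsets = np.zeros(t_length, dtype=int)
    return (offsets + 1).tolist()

# e.g. tsn_train_indices(100, 8) might return [9, 17, 30, 43, 50, 66, 74, 91]
```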
91 | 92 | self._parse_list() 93 | 94 | def _load_image(self, directory, idx): 95 | from PIL import Image 96 | if os.path.exists( 97 | os.path.join( 98 | directory, 99 | self.image_tmpl.format(idx))): 100 | 101 | cv_img = Image.open( 102 | os.path.join( 103 | directory, 104 | self.image_tmpl.format(idx))).convert('RGB') 105 | else: 106 | print( 107 | 'no frames at ', 108 | os.path.join( 109 | directory, 110 | self.image_tmpl.format(idx))) 111 | while True: 112 | idx += 1 113 | if os.path.exists( 114 | os.path.join( 115 | directory, 116 | self.image_tmpl.format(idx))): 117 | cv_img = Image.open( 118 | os.path.join( 119 | directory, 120 | self.image_tmpl.format(idx))).convert('RGB') 121 | break 122 | return [cv_img] 123 | 124 | def _parse_list(self): 125 | self.video_list = [] 126 | 127 | if self.phase == 'Fntest': 128 | vid = 0 129 | for x in open(self.list_file): 130 | idx = 0 131 | for i in range(self.n_times): 132 | for j in range(self.crop_num): 133 | data = x.strip().split(' ')[0] 134 | name = data.split('/')[-1].split('.')[0] 135 | path = self.root_path 136 | if os.path.exists(os.path.join(path, name)): 137 | self.video_list.append(VideoRecord([name, x.split(' ')[1], x.split( 138 | ' ')[2]], self.root_path, phase='Val', copy_id=i, crop=j, vid=vid)) 139 | idx += 1 140 | vid += 1 141 | 142 | elif self.phase == 'Val': 143 | for x in open(self.list_file): 144 | data = x.strip().split(' ')[0] 145 | name = data.split('/')[-1].split('.')[0] 146 | path = self.root_path 147 | 148 | if os.path.exists(os.path.join(path, name)): 149 | self.video_list.append(VideoRecord( 150 | [name, x.split(' ')[1], x.split(' ')[2]], self.root_path, )) 151 | 152 | else: 153 | for x in open(self.list_file): 154 | data = x.strip().split(' ')[0] 155 | name = data.split('/')[-1].split('.')[0] 156 | path = self.root_path 157 | if os.path.exists(os.path.join(path, name)): 158 | self.video_list.append(VideoRecord( 159 | [name, x.split(' ')[1], x.split(' ')[2]], self.root_path, )) 160 | self.rng.shuffle(self.video_list) 161 | 162 | def _sample_indices(self, record): 163 | """ 164 | :param record: VideoRecord 165 | :return: list 166 | """ 167 | if self.style == "Dense": 168 | 169 | average_duration = record.num_frames // self.t_length 170 | 171 | if average_duration > 0: 172 | 173 | offsets = np.multiply(list(range(self.t_length)), 174 | average_duration) + randint(average_duration, 175 | size=self.t_length) 176 | elif record.num_frames > self.t_length: 177 | 178 | offsets = np.sort( 179 | randint( 180 | record.num_frames, 181 | size=self.t_length)) 182 | 183 | else: 184 | 185 | offsets = np.zeros((self.t_length,)) 186 | 187 | return {"dense": offsets + 1} 188 | 189 | def _get_val_indices(self, record): 190 | """ 191 | get indices in val phase 192 | """ 193 | if record.num_frames > self.t_length - 1: 194 | 195 | tick = (record.num_frames - 1) / float(self.t_length) 196 | 197 | offsets = np.array([int(tick / 2.0 + tick * x) 198 | for x in range(self.t_length)]) 199 | 200 | else: 201 | 202 | offsets = np.zeros((self.t_length,)) 203 | 204 | return {"dense": offsets + 1} 205 | 206 | def _get_test_indices(self, record): 207 | 208 | tick = (record.num_frames) / float(self.t_length) 209 | 210 | offsets = [[int(tick / 2.0 + tick * x) + 1for x in range(self.t_length)], 211 | [int(tick * x) + 1for x in range(self.t_length)]] 212 | return {"dense": offsets} 213 | 214 | def __getitem__(self, index): 215 | record = self.video_list[index] 216 | 217 | if self.phase == "Train": 218 | indices = self._sample_indices(record) 219 | 220 | 
return self.get(record, indices, self.phase, index) 221 | elif self.phase == "Val": 222 | indices = self._get_val_indices(record) 223 | 224 | return self.get(record, indices, self.phase, index) 225 | elif self.phase == "Fntest": 226 | 227 | indices = self._get_test_indices(record) 228 | idx = record.copy_id % self.n_times 229 | indices['dense'] = indices['dense'][idx] 230 | 231 | return self.get(record, indices, self.phase, index) 232 | else: 233 | raise TypeError("Unsuported phase {}".format(self.phase)) 234 | 235 | def get(self, record, indices, phase, index): 236 | # dense process data 237 | def dense_process_data(index): 238 | images = list() 239 | for ind in indices['dense']: 240 | ptr = int(ind) 241 | 242 | if ptr <= record.num_frames: 243 | imgs = self._load_image(record.path, ptr) 244 | else: 245 | imgs = self._load_image(record.path, record.num_frames) 246 | images.extend(imgs) 247 | 248 | if self.phase == 'Fntest': 249 | 250 | images = [np.asarray(im) for im in images] 251 | clip_input = np.concatenate(images, axis=2) 252 | 253 | self.t = transforms.Compose([ 254 | transforms.Resize(256)]) 255 | clip_input = self.t(clip_input) 256 | 257 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 258 | std=[0.229, 0.224, 0.225]) 259 | 260 | if record.crop_pos == 0: 261 | self.transform = transforms.Compose([ 262 | 263 | transforms.CenterCrop((256, 256)), 264 | 265 | transforms.ToTensor(), 266 | normalize, 267 | ]) 268 | elif record.crop_pos == 1: 269 | self.transform = transforms.Compose([ 270 | 271 | transforms.CornerCrop2((256, 256),), 272 | 273 | transforms.ToTensor(), 274 | normalize, 275 | ]) 276 | elif record.crop_pos == 2: 277 | self.transform = transforms.Compose([ 278 | transforms.CornerCrop1((256, 256)), 279 | transforms.ToTensor(), 280 | normalize, 281 | ]) 282 | 283 | return self.transform(clip_input) 284 | 285 | return self.transform(images) 286 | 287 | if phase == "Train": 288 | if self.style == "Dense": 289 | process_data = dense_process_data(index) 290 | 291 | elif phase in ("Val", "Test"): 292 | process_data = dense_process_data(index) 293 | return process_data, record.label, indices 294 | else: 295 | process_data = dense_process_data(index) 296 | return process_data, record.label, indices 297 | 298 | return process_data, record.label, indices 299 | 300 | def __len__(self): 301 | return len(self.video_list) 302 | 303 | 304 | if __name__ == "__main__": 305 | parser = argparse.ArgumentParser(description='SmallBig Training') 306 | parser.add_argument('--batch_size', default=1, type=int, 307 | help="Total batch size for training.") 308 | parser.add_argument('--t_length', default=8, type=int, 309 | help="Total length of sampling frames.") 310 | parser.add_argument('--t_stride', default=8, type=int, 311 | help="Temporal stride between each frame.") 312 | parser.add_argument('--num_clips', default=2, type=int, 313 | help="Total number of clips for training or testing.") 314 | parser.add_argument( 315 | '--crop_num', 316 | default=1, 317 | type=int, 318 | help="Total number of crops for each frame during full-resolution testing.") 319 | parser.add_argument('--image_tmpl', default='{:05d}.jpg', type=str, 320 | help="The name format of each frames you saved.") 321 | parser.add_argument('--seed', default=0, type=int, 322 | help="Random Seed") 323 | parser.add_argument( 324 | '--phase', 325 | default='Fntest', 326 | choices=[ 327 | "Train", 328 | "Val", 329 | "Fntest"], 330 | help="Different phases have different sampling methods.") 331 | parser.add_argument('--root_path', 
default='/dataset/sthv1', type=str, 332 | help='root path for accessing your image data') 333 | parser.add_argument( 334 | '--list_file', 335 | default='/dataset/sthv1_val.txt', 336 | type=str, 337 | help='path for your data list(txt)') 338 | 339 | args = parser.parse_args() 340 | 341 | transform = torchvision.transforms.Compose([ 342 | GroupMultiScaleCrop(input_size=224, scales=[1, .875, .75, .66]), 343 | Stack(mode='3D'), 344 | ToTorchFormatTensor(), 345 | GroupNormalize(), 346 | ]) 347 | dataset = Someting_something( 348 | root_path=args.root_path, 349 | list_file=args.list_file, 350 | t_length=args.t_length, 351 | t_stride=args.t_stride, 352 | crop_num=args.crop_num, 353 | num_clips=args.num_clips, 354 | image_tmpl=args.image_tmpl, 355 | transform=transform, 356 | phase=args.phase, 357 | seed=args.seed) 358 | loader = torch.utils.data.DataLoader( 359 | dataset, 360 | batch_size=args.batch_size, 361 | shuffle=False, 362 | drop_last=True, 363 | num_workers=8, 364 | pin_memory=True) 365 | 366 | for ind, (data, label, indices) in enumerate(loader): 367 | label = label.cuda(non_blocking=True) 368 | print(indices) 369 | -------------------------------------------------------------------------------- /dataset/transform.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | 10 | class GroupRandomCrop(object): 11 | def __init__(self, size): 12 | if isinstance(size, numbers.Number): 13 | self.size = (int(size), int(size)) 14 | else: 15 | self.size = size 16 | 17 | def __call__(self, img_group): 18 | 19 | w, h = img_group[0].size 20 | th, tw = self.size 21 | 22 | out_images = list() 23 | 24 | x1 = random.randint(0, w - tw) 25 | y1 = random.randint(0, h - th) 26 | 27 | for img in img_group: 28 | assert (img.size[0] == w and img.size[1] == h) 29 | if w == tw and h == th: 30 | out_images.append(img) 31 | else: 32 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 33 | 34 | return out_images 35 | 36 | 37 | class GroupCenterCrop(object): 38 | def __init__(self, size): 39 | self.worker = torchvision.transforms.CenterCrop(size) 40 | 41 | def __call__(self, img_group): 42 | return [self.worker(img) for img in img_group] 43 | 44 | 45 | class GroupRandomHorizontalFlip(object): 46 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 47 | There is no need to define an init function. 48 | """ 49 | 50 | def __call__(self, img_group): 51 | v = random.random() 52 | if v < 0.5: 53 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 54 | return ret 55 | else: 56 | return img_group 57 | 58 | 59 | class GroupNormalize(object): 60 | def __init__(self, 61 | mean=[0.485, 0.456, 0.406], 62 | std=[0.229, 0.224, 0.225]): 63 | self.mean = mean 64 | self.std = std 65 | 66 | def __call__(self, tensor): 67 | rep_mean = self.mean * (tensor.size()[0] // len(self.mean)) 68 | rep_std = self.std * (tensor.size()[0] // len(self.std)) 69 | 70 | # TODO: make efficient 71 | for t, m, s in zip(tensor, rep_mean, rep_std): 72 | t.sub_(m).div_(s) 73 | 74 | return tensor 75 | 76 | 77 | class GroupScale(object): 78 | """ Rescales the input PIL.Image to the given 'size'. 79 | 'size' will be the size of the smaller edge. 
80 | For example, if height > width, then image will be 81 | rescaled to (size * height / width, size) 82 | size: size of the smaller edge 83 | interpolation: Default: PIL.Image.BILINEAR 84 | """ 85 | 86 | def __init__(self, size, interpolation=Image.BILINEAR): 87 | self.worker = torchvision.transforms.Resize(size, interpolation) 88 | 89 | def __call__(self, img_group): 90 | return [self.worker(img) for img in img_group] 91 | 92 | 93 | class GroupRandomScale(object): 94 | """ Rescales the input PIL.Image to the given 'size'. 95 | 'size' will be the size of the smaller edge. 96 | For example, if height > width, then image will be 97 | rescaled to (size * height / width, size) 98 | size: size of the smaller edge 99 | interpolation: Default: PIL.Image.BILINEAR 100 | """ 101 | 102 | def __init__( 103 | self, 104 | smallest_size=256, 105 | largest_size=320, 106 | interpolation=Image.BILINEAR): 107 | self.smallest_size = smallest_size 108 | self.largest_size = largest_size 109 | self.interpolation = interpolation 110 | 111 | def __call__(self, img_group): 112 | size = random.randint(self.smallest_size, self.largest_size) 113 | # print(size) 114 | self.worker = torchvision.transforms.Resize(size, self.interpolation) 115 | return [self.worker(img) for img in img_group] 116 | 117 | 118 | class GroupOverSample(object): 119 | def __init__(self, crop_size, scale_size=None): 120 | self.crop_size = crop_size if not isinstance( 121 | crop_size, int) else (crop_size, crop_size) 122 | 123 | if scale_size is not None: 124 | self.scale_worker = GroupScale(scale_size) 125 | else: 126 | self.scale_worker = None 127 | 128 | def __call__(self, img_group): 129 | 130 | if self.scale_worker is not None: 131 | img_group = self.scale_worker(img_group) 132 | 133 | image_w, image_h = img_group[0].size 134 | crop_w, crop_h = self.crop_size 135 | 136 | offsets = GroupMultiScaleCrop.fill_fix_offset( 137 | False, image_w, image_h, crop_w, crop_h) 138 | oversample_group = list() 139 | for o_w, o_h in offsets: 140 | normal_group = list() 141 | flip_group = list() 142 | for i, img in enumerate(img_group): 143 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 144 | normal_group.append(crop) 145 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 146 | flip_group.append(flip_crop) 147 | 148 | oversample_group.extend(normal_group) 149 | oversample_group.extend(flip_group) 150 | return oversample_group 151 | 152 | 153 | class GroupOverSampleKaiming(object): 154 | def __init__(self, crop_size, scale_size=None): 155 | self.crop_size = crop_size if not isinstance( 156 | crop_size, int) else (crop_size, crop_size) 157 | 158 | if scale_size is not None: 159 | self.scale_worker = GroupScale(scale_size) 160 | else: 161 | self.scale_worker = None 162 | 163 | def __call__(self, img_group): 164 | 165 | if self.scale_worker is not None: 166 | img_group = self.scale_worker(img_group) 167 | 168 | image_w, image_h = img_group[0].size 169 | crop_w, crop_h = self.crop_size 170 | 171 | offsets = self.fill_fix_offset(image_w, image_h, crop_w, crop_h) 172 | oversample_group = list() 173 | for o_w, o_h in offsets: 174 | normal_group = list() 175 | # flip_group = list() 176 | for i, img in enumerate(img_group): 177 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 178 | normal_group.append(crop) 179 | # flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 180 | # flip_group.append(flip_crop) 181 | 182 | oversample_group.extend(normal_group) 183 | # oversample_group.extend(flip_group) 184 | return oversample_group 185 | 186 
| def fill_fix_offset(self, image_w, image_h, crop_w, crop_h): 187 | # assert(crop_h == image_h), "In Kaiming mode, crop_h should equal to image_h" 188 | ret = list() 189 | if image_w == 256: 190 | h_step = (image_h - crop_h) // 4 191 | ret.append((0, 0)) # upper 192 | ret.append((0, 4 * h_step)) # down 193 | ret.append((0, 2 * h_step)) # center 194 | elif image_h == 256: 195 | w_step = (image_w - crop_w) // 4 196 | ret.append((0, 0)) # left 197 | ret.append((4 * w_step, 0)) # right 198 | ret.append((2 * w_step, 0)) # center 199 | else: 200 | raise ValueError( 201 | "Either image_w or image_h should be equal to 256") 202 | 203 | return ret 204 | 205 | 206 | class GroupMultiScaleCrop(object): 207 | 208 | def __init__( 209 | self, 210 | input_size, 211 | scales=None, 212 | max_distort=1, 213 | fix_crop=True, 214 | more_fix_crop=True): 215 | self.input_size = input_size if not isinstance(input_size, int) else [ 216 | input_size, input_size] 217 | self.scales = scales if scales is not None else [1, .875, .75, .66] 218 | self.max_distort = max_distort 219 | self.fix_crop = fix_crop 220 | self.more_fix_crop = more_fix_crop 221 | self.interpolation = Image.BILINEAR 222 | 223 | def __call__(self, img_group): 224 | 225 | im_size = img_group[0].size 226 | 227 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 228 | crop_img_group = [ 229 | img.crop( 230 | (offset_w, 231 | offset_h, 232 | offset_w + 233 | crop_w, 234 | offset_h + 235 | crop_h)) for img in img_group] 236 | ret_img_group = [ 237 | img.resize( 238 | (self.input_size[0], 239 | self.input_size[1]), 240 | self.interpolation) for img in crop_img_group] 241 | return ret_img_group 242 | 243 | def _sample_crop_size(self, im_size): 244 | image_w, image_h = im_size[0], im_size[1] 245 | 246 | # find a crop size 247 | base_size = min(image_w, image_h) 248 | crop_sizes = [int(base_size * x) for x in self.scales] 249 | crop_h = [ 250 | self.input_size[1] if abs( 251 | x - self.input_size[1]) < 3 else x for x in crop_sizes] 252 | crop_w = [ 253 | self.input_size[0] if abs( 254 | x - self.input_size[0]) < 3 else x for x in crop_sizes] 255 | 256 | pairs = [] 257 | for i, h in enumerate(crop_h): 258 | for j, w in enumerate(crop_w): 259 | if abs(i - j) <= self.max_distort: 260 | pairs.append((w, h)) 261 | 262 | crop_pair = random.choice(pairs) 263 | if not self.fix_crop: 264 | w_offset = random.randint(0, image_w - crop_pair[0]) 265 | h_offset = random.randint(0, image_h - crop_pair[1]) 266 | else: 267 | w_offset, h_offset = self._sample_fix_offset( 268 | image_w, image_h, crop_pair[0], crop_pair[1]) 269 | 270 | return crop_pair[0], crop_pair[1], w_offset, h_offset 271 | 272 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 273 | offsets = self.fill_fix_offset( 274 | self.more_fix_crop, image_w, image_h, crop_w, crop_h) 275 | return random.choice(offsets) 276 | 277 | @staticmethod 278 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 279 | w_step = (image_w - crop_w) // 4 280 | h_step = (image_h - crop_h) // 4 281 | 282 | ret = list() 283 | ret.append((0, 0)) # upper left 284 | ret.append((4 * w_step, 0)) # upper right 285 | ret.append((0, 4 * h_step)) # lower left 286 | ret.append((4 * w_step, 4 * h_step)) # lower right 287 | ret.append((2 * w_step, 2 * h_step)) # center 288 | 289 | if more_fix_crop: 290 | ret.append((0, 2 * h_step)) # center left 291 | ret.append((4 * w_step, 2 * h_step)) # center right 292 | ret.append((2 * w_step, 4 * h_step)) # lower center 293 | ret.append((2 * w_step, 0 * 
h_step)) # upper center 294 | 295 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 296 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 297 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 298 | ret.append((3 * w_step, 3 * h_step)) # lower righ quarter 299 | 300 | return ret 301 | 302 | 303 | class GroupRandomSizedCrop(object): 304 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 305 | and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio 306 | This is popularly used to train the Inception networks 307 | size: size of the smaller edge 308 | interpolation: Default: PIL.Image.BILINEAR 309 | """ 310 | 311 | def __init__(self, size, interpolation=Image.BILINEAR): 312 | self.size = size 313 | self.interpolation = interpolation 314 | 315 | def __call__(self, img_group): 316 | for attempt in range(10): 317 | area = img_group[0].size[0] * img_group[0].size[1] 318 | target_area = random.uniform(0.08, 1.0) * area 319 | aspect_ratio = random.uniform(3. / 4, 4. / 3) 320 | 321 | w = int(round(math.sqrt(target_area * aspect_ratio))) 322 | h = int(round(math.sqrt(target_area / aspect_ratio))) 323 | 324 | if random.random() < 0.5: 325 | w, h = h, w 326 | 327 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 328 | x1 = random.randint(0, img_group[0].size[0] - w) 329 | y1 = random.randint(0, img_group[0].size[1] - h) 330 | found = True 331 | break 332 | else: 333 | found = False 334 | x1 = 0 335 | y1 = 0 336 | 337 | if found: 338 | out_group = list() 339 | for img in img_group: 340 | img = img.crop((x1, y1, x1 + w, y1 + h)) 341 | assert (img.size == (w, h)) 342 | out_group.append( 343 | img.resize( 344 | (self.size, self.size), self.interpolation)) 345 | return out_group 346 | else: 347 | # Fallback 348 | scale = GroupScale(self.size, interpolation=self.interpolation) 349 | crop = GroupRandomCrop(self.size) 350 | return crop(scale(img_group)) 351 | 352 | 353 | class Stack(object): 354 | 355 | def __init__(self, mode="3D"): 356 | """Support modes: ["3D", "TSN", "2D", "TSN+3D"] 357 | """ 358 | assert (mode in ["3D", "TSN+2D", "2D", "TSN+3D"] 359 | ), "Unsupported mode: {}".format() 360 | self.mode = mode 361 | 362 | def __call__(self, img_group): 363 | """Only support RGB mode now 364 | img_group: list([h, w, c]) 365 | """ 366 | assert (img_group[0].mode == 'RGB'), "Must read images in RGB mode." 367 | if "3D" in self.mode: 368 | imgs = np.concatenate([np.array(img)[np.newaxis, ...] 369 | for img in img_group], axis=0) 370 | imgs = torch.from_numpy(imgs).permute(3, 0, 1, 2).contiguous() 371 | elif "2D" in self.mode: 372 | imgs = np.concatenate([np.array(img) for img in img_group], axis=2) 373 | imgs = torch.from_numpy(imgs).permute(2, 0, 1).contiguous() 374 | else: 375 | raise Exception("Unsupported mode.") 376 | return imgs 377 | 378 | 379 | class ToTorchFormatTensor(object): 380 | """ Converts a torch.Tensor in the range [0, 255] 381 | to a torch.FloatTensor in the range [0.0, 1.0] """ 382 | 383 | def __init__(self, div=True): 384 | self.div = div 385 | 386 | def __call__(self, imgs): 387 | assert (isinstance(imgs, torch.Tensor)), "pic must be torch.Tensor." 
388 | return imgs.float().div(255) if self.div else img.float() 389 | 390 | 391 | class IdentityTransform(object): 392 | 393 | def __call__(self, data): 394 | return data 395 | 396 | 397 | if __name__ == "__main__": 398 | trans = torchvision.transforms.Compose([ 399 | GroupMultiScaleCrop(input_size=224, scales=[1, .875, .75, .66]), 400 | Stack(mode="2D"), 401 | ToTorchFormatTensor(), 402 | GroupNormalize()] 403 | ) 404 | 405 | im = Image.open('/home/leizhou/CVPR2019/vid_cls/lena.png') 406 | 407 | color_group = [im] 408 | rst = trans(color_group) 409 | -------------------------------------------------------------------------------- /dataset/transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import torch 5 | 6 | 7 | class Compose(object): 8 | """Composes several video_transforms together. 9 | Args: 10 | transforms (List[Transform]): list of transforms to compose. 11 | Example: 12 | >>> video_transforms.Compose([ 13 | >>> video_transforms.CenterCrop(10), 14 | >>> video_transforms.ToTensor(), 15 | >>> ]) 16 | """ 17 | 18 | def __init__(self, transforms, aug_seed=0): 19 | self.transforms = transforms 20 | 21 | for i, t in enumerate(self.transforms): 22 | t.set_random_state(seed=(aug_seed + i)) 23 | 24 | def __call__(self, data): 25 | for t in self.transforms: 26 | data = t(data) 27 | return data 28 | 29 | 30 | class Transform(object): 31 | """basse class for all transformation""" 32 | 33 | def set_random_state(self, seed=None): 34 | self.rng = np.random.RandomState(seed) 35 | 36 | 37 | #################################### 38 | # Customized Transformations 39 | #################################### 40 | 41 | class Normalize(Transform): 42 | """Given mean: (R, G, B) and std: (R, G, B), 43 | will normalize each channel of the torch.*Tensor, i.e. 44 | channel = (channel - mean) / std 45 | """ 46 | 47 | def __init__(self, mean, std): 48 | self.mean = mean 49 | self.std = std 50 | 51 | def __call__(self, tensor): 52 | for t, m, s in zip(tensor, self.mean, self.std): 53 | t.sub_(m).div_(s) 54 | return tensor 55 | 56 | 57 | class Resize(Transform): 58 | """ Rescales the input numpy array to the given 'size'. 59 | 'size' will be the size of the smaller edge. 60 | For example, if height > width, then image will be 61 | rescaled to (size * height / width, size) 62 | size: size of the smaller edge 63 | interpolation: Default: cv2.INTER_LINEAR 64 | """ 65 | 66 | def __init__(self, size, interpolation=cv2.INTER_LINEAR): 67 | self.size = size # [w, h] 68 | self.interpolation = interpolation 69 | 70 | def __call__(self, data): 71 | h, w, c = data.shape 72 | 73 | if isinstance(self.size, int): 74 | slen = self.size 75 | if min(w, h) == slen: 76 | return data 77 | if w < h: 78 | new_w = self.size 79 | new_h = int(self.size * h / w) 80 | else: 81 | new_w = int(self.size * w / h) 82 | new_h = self.size 83 | else: 84 | new_w = self.size[0] 85 | new_h = self.size[1] 86 | 87 | if (h != new_h) or (w != new_w): 88 | scaled_data = cv2.resize(data, (new_w, new_h), self.interpolation) 89 | else: 90 | scaled_data = data 91 | 92 | return scaled_data 93 | 94 | 95 | class RandomScale_nonlocal(Transform): 96 | """ Rescales the input numpy array to the given 'size'. 97 | 'size' will be the size of the smaller edge. 
98 | For example, if height > width, then image will be 99 | rescaled to (size * height / width, size) 100 | size: size of the smaller edge 101 | interpolation: Default: cv2.INTER_LINEAR 102 | """ 103 | 104 | def __init__(self, 105 | 106 | slen=[224, 288], 107 | interpolation=cv2.INTER_LINEAR): 108 | 109 | self.slen = slen # [min factor, max factor] 110 | 111 | self.interpolation = interpolation 112 | self.rng = np.random.RandomState(0) 113 | 114 | def __call__(self, data): 115 | random_slen = self.rng.uniform(self.slen[0], self.slen[1]) 116 | resize = Resize(int(random_slen)) 117 | scaled_data = resize(data) 118 | return scaled_data 119 | 120 | 121 | class RandomScale(Transform): 122 | """ Rescales the input numpy array to the given 'size'. 123 | 'size' will be the size of the smaller edge. 124 | For example, if height > width, then image will be 125 | rescaled to (size * height / width, size) 126 | size: size of the smaller edge 127 | interpolation: Default: cv2.INTER_LINEAR 128 | """ 129 | 130 | def __init__(self, make_square=False, 131 | aspect_ratio=[1.0, 1.0], 132 | slen=[224, 288], 133 | interpolation=cv2.INTER_LINEAR): 134 | # assert slen[1] >= slen[0], \ 135 | # "slen ({}) should be in increase order".format(scale) 136 | # assert aspect_ratio[1] >= aspect_ratio[0], \ 137 | # "aspect_ratio ({}) should be in increase order".format(aspect_ratio) 138 | self.slen = slen # [min factor, max factor] 139 | self.aspect_ratio = aspect_ratio 140 | self.make_square = make_square 141 | self.interpolation = interpolation 142 | self.rng = np.random.RandomState(0) 143 | 144 | def __call__(self, data): 145 | h, w, c = data.shape 146 | new_w = w 147 | new_h = h if not self.make_square else w 148 | if self.aspect_ratio: 149 | random_aspect_ratio = self.rng.uniform( 150 | self.aspect_ratio[0], self.aspect_ratio[1]) 151 | if self.rng.rand() > 0.5: 152 | random_aspect_ratio = 1.0 / random_aspect_ratio 153 | new_w *= random_aspect_ratio 154 | new_h /= random_aspect_ratio 155 | resize_factor = self.rng.uniform( 156 | self.slen[0], self.slen[1]) / min(new_w, new_h) 157 | new_w *= resize_factor 158 | new_h *= resize_factor 159 | scaled_data = cv2.resize( 160 | data, (int(new_w + 1), int(new_h + 1)), self.interpolation) 161 | return scaled_data 162 | 163 | 164 | class CornerCrop1(Transform): 165 | """Crops the given numpy array at the center to have a region of 166 | the given size. size can be a tuple (target_height, target_width) 167 | or an integer, in which case the target will be of a square shape (size, size) 168 | """ 169 | 170 | def __init__(self, size): 171 | if isinstance(size, int): 172 | self.size = (size, size) 173 | else: 174 | self.size = size 175 | 176 | def __call__(self, data): 177 | h, w, c = data.shape 178 | th, tw = self.size 179 | x1 = int(round((w - tw)) / 4) 180 | y1 = int(round((h - th)) / 4) 181 | x1 = 0 182 | y1 = 0 183 | # if x1==0 and y1!=0: 184 | # y1=int((y1*3)/4)) 185 | # if 186 | cropped_data = data[y1:(y1 + th), x1:(x1 + tw), :] 187 | return cropped_data 188 | 189 | 190 | class CornerCrop2(Transform): 191 | """Crops the given numpy array at the center to have a region of 192 | the given size. 
size can be a tuple (target_height, target_width) 193 | or an integer, in which case the target will be of a square shape (size, size) 194 | """ 195 | 196 | def __init__(self, size): 197 | if isinstance(size, int): 198 | self.size = (size, size) 199 | else: 200 | self.size = size 201 | 202 | def __call__(self, data): 203 | h, w, c = data.shape 204 | th, tw = self.size 205 | x1 = int(round((w - tw))) 206 | y1 = int(round((h - th))) 207 | #x1=int(round((w - tw))) 208 | #y1=int(round((w - tw))) 209 | cropped_data = data[y1:(y1 + th), x1:(x1 + tw), :] 210 | return cropped_data 211 | 212 | 213 | class CenterCrop(Transform): 214 | """Crops the given numpy array at the center to have a region of 215 | the given size. size can be a tuple (target_height, target_width) 216 | or an integer, in which case the target will be of a square shape (size, size) 217 | """ 218 | 219 | def __init__(self, size): 220 | if isinstance(size, int): 221 | self.size = (size, size) 222 | else: 223 | self.size = size 224 | 225 | def __call__(self, data): 226 | h, w, c = data.shape 227 | th, tw = self.size 228 | x1 = int(round((w - tw) / 2.)) 229 | y1 = int(round((h - th) / 2.)) 230 | cropped_data = data[y1:(y1 + th), x1:(x1 + tw), :] 231 | return cropped_data 232 | 233 | 234 | class GroupCrop(Transform): 235 | """Crops the given numpy array at the center to have a region of 236 | the given size. size can be a tuple (target_height, target_width) 237 | or an integer, in which case the target will be of a square shape (size, size) 238 | """ 239 | 240 | def __init__(self, size): 241 | if isinstance(size, int): 242 | self.size = (size, size) 243 | else: 244 | self.size = size 245 | 246 | def __call__(self, data, crop_time=3): 247 | h, w, c = data.shape 248 | th, tw = self.size 249 | img = [] 250 | 251 | x1 = [np.random.randint(0, w - tw)for i in range(crop_time)] 252 | y1 = [0 for i in range(crop_time)] 253 | 254 | for i in range(crop_time): 255 | 256 | cropped_data = data[y1[i]:(y1[i] + th), x1[i]:(x1[i] + tw), :] 257 | img.append(cropped_data) 258 | return np.concatenate(img, axis=2) 259 | 260 | 261 | class Crop(Transform): 262 | """Crops the given numpy array at the random location to have a region of 263 | the given size. size can be a tuple (target_height, target_width) 264 | or an integer, in which case the target will be of a square shape (size, size) 265 | """ 266 | 267 | def __init__(self, size, crop): 268 | 269 | self.size = size 270 | self.crop = crop 271 | self.rng = np.random.RandomState(0) 272 | 273 | def __call__(self, data): 274 | h, w, c = data.shape 275 | tw = self.size[0] 276 | th = self.size[1] 277 | x1 = self.crop[1] 278 | y1 = self.crop[0] 279 | cropped_data = data[y1:(y1 + th), x1:(x1 + tw), :] 280 | return cropped_data 281 | 282 | 283 | class RandomCrop(Transform): 284 | """Crops the given numpy array at the random location to have a region of 285 | the given size. 
size can be a tuple (target_height, target_width) 286 | or an integer, in which case the target will be of a square shape (size, size) 287 | """ 288 | 289 | def __init__(self, size): 290 | 291 | self.size = size 292 | self.rng = np.random.RandomState(0) 293 | 294 | def __call__(self, data): 295 | h, w, c = data.shape 296 | tw = self.size[0] 297 | th = self.size[1] 298 | # p=w-tw 299 | # q=h-th 300 | # if p!=0: 301 | # x1 = self.rng.choice(range(w - tw)) 302 | # y1 = 0 303 | # elif q!=0: 304 | # x1 = 0 305 | # y1 = self.rng.choice(range(h - th)) 306 | # elif p==0 and q==0: 307 | # x1=0 308 | # y1=0 309 | if tw < w and th < h: 310 | x1 = self.rng.choice(range(w - tw)) 311 | y1 = self.rng.choice(range(h - th)) 312 | # cropped_data = data[y1:(y1+th), x1:(x1+tw), :] 313 | # else: 314 | 315 | # resize=Resize([th,tw]) 316 | # cropped_data = resize(data) 317 | cropped_data = data[y1:(y1 + th), x1:(x1 + tw), :] 318 | 319 | return cropped_data 320 | 321 | 322 | class RandomHorizontalFlip(Transform): 323 | """Randomly horizontally flips the given numpy array with a probability of 0.5 324 | """ 325 | 326 | def __init__(self): 327 | self.rng = np.random.RandomState(0) 328 | 329 | def __call__(self, data): 330 | if self.rng.rand() < 0.5: 331 | data = np.fliplr(data) 332 | data = np.ascontiguousarray(data) 333 | return data 334 | 335 | 336 | class RandomVerticalFlip(Transform): 337 | """Randomly vertically flips the given numpy array with a probability of 0.5 338 | """ 339 | 340 | def __init__(self): 341 | self.rng = np.random.RandomState(0) 342 | 343 | def __call__(self, data): 344 | if self.rng.rand() < 0.5: 345 | data = np.flipud(data) 346 | data = np.ascontiguousarray(data) 347 | return data 348 | 349 | 350 | class RandomRGB(Transform): 351 | def __init__(self, vars=[10, 10, 10]): 352 | self.vars = vars 353 | self.rng = np.random.RandomState(0) 354 | 355 | def __call__(self, data): 356 | h, w, c = data.shape 357 | 358 | random_vars = [int(round(self.rng.uniform(-x, x))) for x in self.vars] 359 | 360 | base = len(random_vars) 361 | augmented_data = np.zeros(data.shape) 362 | for ic in range(0, c): 363 | var = random_vars[ic % base] 364 | augmented_data[:, :, ic] = np.minimum( 365 | np.maximum(data[:, :, ic] + var, 0), 255) 366 | return augmented_data 367 | 368 | 369 | class RandomHLS(Transform): 370 | def __init__(self, vars=[15, 35, 25]): 371 | self.vars = vars 372 | self.rng = np.random.RandomState(0) 373 | 374 | def __call__(self, data): 375 | h, w, c = data.shape 376 | assert c % 3 == 0, "input channel = %d, illegal" % c 377 | 378 | random_vars = [int(round(self.rng.uniform(-x, x))) for x in self.vars] 379 | 380 | base = len(random_vars) 381 | augmented_data = np.zeros(data.shape, ) 382 | 383 | for i_im in range(0, int(c / 3)): 384 | augmented_data[:, 385 | :, 386 | 3 * i_im:(3 * i_im + 3)] = cv2.cvtColor(data[:, 387 | :, 388 | 3 * i_im:(3 * i_im + 3)], 389 | cv2.COLOR_RGB2HLS) 390 | 391 | hls_limits = [180, 255, 255] 392 | for ic in range(0, c): 393 | var = random_vars[ic % base] 394 | limit = hls_limits[ic % base] 395 | augmented_data[:, :, ic] = np.minimum( 396 | np.maximum(augmented_data[:, :, ic] + var, 0), limit) 397 | 398 | for i_im in range(0, int(c / 3)): 399 | augmented_data[:, :, 3 * 400 | i_im:(3 * 401 | i_im + 402 | 3)] = cv2.cvtColor(augmented_data[:, :, 3 * 403 | i_im:(3 * 404 | i_im + 405 | 3)].astype(np.uint8), cv2.COLOR_HLS2RGB) 406 | 407 | return augmented_data 408 | 409 | 410 | class ToTensor(Transform): 411 | """Converts a numpy.ndarray (H x W x C) in the range 412 | [0, 255] 
to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 413 | """ 414 | 415 | def __init__(self, dim=3): 416 | self.dim = dim 417 | 418 | def __call__(self, image): 419 | if isinstance(image, np.ndarray): 420 | # H, W, C = image.shape 421 | # handle numpy array 422 | image = torch.from_numpy(image.transpose((2, 0, 1))) 423 | # backward compatibility 424 | return image.float() / 255.0 425 | 426 | 427 | class ToTensor(Transform): 428 | """Converts a numpy.ndarray (H x W x (T x C)) in the range 429 | [0, 255] to a torch.FloatTensor of shape (C x T x H x W) in the range [0.0, 1.0]. 430 | """ 431 | 432 | def __init__(self, dim=3): 433 | self.dim = dim 434 | 435 | def __call__(self, clips): 436 | if isinstance(clips, np.ndarray): 437 | H, W, _ = clips.shape 438 | # handle numpy array 439 | clips = torch.from_numpy(clips.reshape( 440 | (H, W, -1, self.dim)).transpose((3, 2, 0, 1))) 441 | # backward compatibility 442 | return clips.float() / 255.0 443 | class To_3DTensor(Transform): 444 | 445 | def __init__(self, dim=2): 446 | self.dim = 2 447 | 448 | def __call__(self, images): 449 | if isinstance(images, torch.Tensor): 450 | images = images.unsqueeze(1) 451 | # backward compatibility 452 | return images 453 | class Grayscale(object): 454 | 455 | def __call__(self, img): 456 | gs = img.clone() 457 | gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2]) 458 | gs[1].copy_(gs[0]) 459 | gs[2].copy_(gs[0]) 460 | return gs 461 | 462 | 463 | class Saturation(object): 464 | 465 | def __init__(self, var): 466 | self.var = var 467 | 468 | def __call__(self, img): 469 | gs = Grayscale()(img) 470 | alpha = np.random.uniform(-self.var, self.var) 471 | return img.lerp(gs, alpha) 472 | 473 | 474 | class Lighting(object): 475 | """ 476 | Lighting noise(AlexNet - style PCA - based noise). 
477 | """ 478 | def __init__(self, alphastd, eigval, eigvec): 479 | self.alphastd = alphastd 480 | self.eigval = eigval 481 | self.eigvec = eigvec 482 | 483 | def __call__(self, img): 484 | if self.alphastd == 0: 485 | return img 486 | 487 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 488 | rgb = self.eigvec.type_as(img).clone()\ 489 | .mul(alpha.view(1, 3).expand(3, 3))\ 490 | .mul(self.eigval.view(1, 3).expand(3, 3))\ 491 | .sum(1).squeeze() 492 | 493 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 494 | class Brightness(object): 495 | 496 | def __init__(self, var): 497 | self.var = var 498 | 499 | def __call__(self, img): 500 | gs = img.new().resize_as_(img).zero_() 501 | alpha = np.random.uniform(-self.var, self.var) 502 | return img.lerp(gs, alpha) 503 | 504 | class Contrast(object): 505 | 506 | def __init__(self, var): 507 | self.var = var 508 | 509 | def __call__(self, img): 510 | gs = Grayscale()(img) 511 | gs.fill_(gs.mean()) 512 | alpha = np.random.uniform(-self.var, self.var) 513 | return img.lerp(gs, alpha) 514 | 515 | class ColorJitter(object): 516 | 517 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 518 | self.brightness = brightness 519 | self.contrast = contrast 520 | self.saturation = saturation 521 | 522 | def __call__(self, img): 523 | self.transforms = [] 524 | if self.brightness != 0: 525 | self.transforms.append(Brightness(self.brightness)) 526 | if self.contrast != 0: 527 | self.transforms.append(Contrast(self.contrast)) 528 | if self.saturation != 0: 529 | self.transforms.append(Saturation(self.saturation)) 530 | 531 | np.random.shuffle(self.transforms) 532 | transform = Compose(self.transforms) 533 | # print(transform) 534 | return transform(img) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | from tensorboardX import SummaryWriter 8 | 9 | from dataset.prepare_dataset import get_dataloader 10 | from models.SmallBig import get_model 11 | 12 | from tools.tools import save_checkpoint, is_main_process 13 | 14 | from train_val import train, validate 15 | 16 | parser = argparse.ArgumentParser(description='SmallBig Training') 17 | parser.add_argument('--batch_size', default=1, type=int, 18 | help="Total batch size for training.") 19 | parser.add_argument('--t_length', default=8, type=int, 20 | help="Total length of sampling frames.") 21 | parser.add_argument('--t_stride', default=8, type=int, 22 | help="Temporal stride between each frame.") 23 | parser.add_argument('--num_clips', default=1, type=int, 24 | help="Total number of clips for training or testing.") 25 | parser.add_argument( 26 | '--crop_num', 27 | default=1, 28 | type=int, 29 | help="Total number of crops for each frame during full-resolution testing.") 30 | parser.add_argument('--image_tmpl', default='image_{:06d}.jpg', type=str, 31 | help="The name format of each frames you saved.") 32 | parser.add_argument('--seed', default=0, type=int, 33 | help="Random Seed") 34 | parser.add_argument( 35 | '--dataset', 36 | default='kinetics', 37 | choices=[ 38 | "kinetics", 39 | "imagenet", 40 | "something"], 41 | help="Choose dataset for training and validation") 42 | parser.add_argument( 43 | '--phase', 44 | default='Val', 45 | choices=[ 46 | "Train", 47 | "Val", 48 | "Fntest"], 49 | help="Different phases have different sampling methods.") 50 | 
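# The three phases presumably correspond to the sampling modes in dataset/kinetics.py:
# 'Fntest' appears to be the full-resolution test setting, in which --num_clips temporal
# clips and --crop_num spatial crops per video are evaluated and their scores averaged,
# while 'Train' and 'Val' sample a single clip per video.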
parser.add_argument('--root_path', default='/dataset/kinetics', type=str, 51 | help='root path for accessing your image data') 52 | parser.add_argument('--val_list_file', default='/dataset/kinetics/val.txt', 53 | type=str, help='path for your data list(txt)') 54 | parser.add_argument( 55 | '--train_list_file', 56 | default='/dataset/kinetics/train.txt', 57 | type=str, 58 | help='path for your data list(txt)') 59 | parser.add_argument('--model_name', default="smallbig50_no_extra", 60 | choices=[ 61 | "res50", 62 | "slowonly50", 63 | "slowonly50_extra", 64 | "smallbig23_no_extra", 65 | "smallbig50_no_extra", 66 | "smallbig101_no_extra", 67 | "smallbig50_extra"], 68 | help="name of your model") 69 | parser.add_argument('--local_rank', type=int, default=0, 70 | help='node rank for distributed training') 71 | parser.add_argument('--distribute', action='store_true') 72 | parser.add_argument('--print_freq', type=int, default=50, 73 | help='print frequency') 74 | parser.add_argument('--test', action='store_true') 75 | parser.add_argument('--feat', action='store_true') 76 | parser.add_argument('--half', action='store_true') 77 | parser.add_argument('--imagenet', action='store_false') 78 | parser.add_argument('--num_classes', default=400, type=int, 79 | help="num classes of your dataset") 80 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 81 | parser.add_argument('--num_epochs', default=100, type=int) 82 | parser.add_argument('--resume', type=str, 83 | help="Checkpoint path that you want to restart training.") 84 | parser.add_argument('--check_dir', type=str, 85 | help="Location to store your model") 86 | parser.add_argument('--log_dir', type=str, 87 | help="Location to store your logs") 88 | 89 | 90 | def set_logger(args): 91 | import time 92 | logdir = os.path.join(args.log_dir, args.model_name) 93 | if not os.path.exists(logdir): 94 | os.makedirs(logdir, exist_ok=True) 95 | log_file = args.model_name + '_' + args.dataset + '_t_length_' + str(args.t_length) + '_t_stride_' + str(args.t_stride) + '_batch_' + str( 96 | args.batch_size) + '_lr_' + str(args.lr) + "_logfile_" + time.strftime("%d_%b_%Y_%H:%M:%S", time.localtime()) 97 | log_file = os.path.join(logdir, log_file) 98 | if not os.path.exists(log_file): 99 | os.makedirs(log_file, exist_ok=True) 100 | log_file = os.path.join(log_file, "logfile_" + time.strftime("%d_%b_%Y_%H:%M:%S", 101 | time.localtime())) 102 | handlers = [logging.FileHandler(log_file), logging.StreamHandler()] 103 | 104 | """ add '%(filename)s:%(lineno)d %(levelname)s:' to format show source file """ 105 | logging.basicConfig(level= logging.INFO, 106 | format='%(asctime)s: %(message)s', 107 | datefmt='%Y-%m-%d %H:%M:%S', 108 | handlers=handlers) 109 | 110 | 111 | def train_model(args): 112 | global best_metric, epoch_resume 113 | epoch_resume = 0 114 | best_metric = 0 115 | model = get_model(args) 116 | 117 | if args.distribute: 118 | model = model.cuda() 119 | model = torch.nn.parallel.DistributedDataParallel( 120 | model, device_ids=[args.local_rank]) 121 | else: 122 | model = torch.nn.DataParallel(model).cuda() 123 | writer = None 124 | if is_main_process(): 125 | log_file = args.model_name + '_' + args.dataset + '_t_length_' + str(args.t_length) + '_t_stride_' + str( 126 | args.t_stride) + '_batch_' + str( 127 | args.batch_size) + '_lr_' + str(args.lr) + "_logfile_" + time.strftime("%d_%b_%Y_%H:%M:%S", 128 | time.localtime()) 129 | log_file = os.path.join(args.log_dir, args.model_name, log_file) 130 | writer = SummaryWriter(log_dir=log_file) 
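    # The code below builds the dataloaders, an SGD optimizer (momentum 0.9, weight decay 1e-4)
    # and a CosineAnnealingLR schedule over --num_epochs epochs. The --resume branch then loads a
    # checkpoint tolerantly: it strips the 'module.' prefix that (Distributed)DataParallel adds to
    # parameter names (e.g. 'module.layer1.0.conv1.weight' -> 'layer1.0.conv1.weight') and copies
    # only tensors whose keys and shapes match the current state_dict, so a partially matching
    # checkpoint can still be used to restart training.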
131 | print(model) 132 | dataloaders, dataset_sizes, samplers = get_dataloader(args) 133 | optimizer = torch.optim.SGD( 134 | model.parameters(), 135 | lr=args.lr, 136 | weight_decay=1e-4, 137 | momentum=0.9) 138 | criterion = nn.CrossEntropyLoss().cuda() 139 | 140 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 141 | optimizer, T_max=args.num_epochs) 142 | if args.resume: 143 | checkpoint = torch.load(args.resume, map_location='cpu') 144 | epoch_resume = checkpoint['epoch'] 145 | best_metric = checkpoint['best_metric'] 146 | model_dict = model.state_dict() 147 | idx = 0 148 | print(len(model_dict)) 149 | print(len(checkpoint['state_dict'])) 150 | for k, v in checkpoint['state_dict'].items(): 151 | k = k.replace('module.', '') 152 | if k in model_dict: 153 | if v.shape == model_dict[k].shape: 154 | model_dict[k] = v.cuda() 155 | idx += 1 156 | print(idx) 157 | print('upload parameter already') 158 | model.load_state_dict(model_dict) 159 | optimizer.load_state_dict(checkpoint['optimizer']) 160 | print(("=> loaded checkpoint '{}' (epoch {})" 161 | .format(args.resume, checkpoint['epoch']))) 162 | print(best_metric) 163 | elif is_main_process(): 164 | print(("=> no checkpoint found at '{}'".format(args.resume))) 165 | 166 | for epoch in range(epoch_resume, args.num_epochs): 167 | if args.distribute: 168 | samplers['train'].set_epoch(epoch) 169 | samplers['val'].set_epoch(epoch) 170 | end = time.time() 171 | train(dataloaders['train'], model, criterion, optimizer, 172 | epoch, args.print_freq, writer, args=args) 173 | scheduler.step() 174 | if epoch >= 0: 175 | metric = validate( 176 | dataloaders['val'], 177 | model, 178 | criterion, 179 | args.print_freq, 180 | epoch + 1, 181 | writer, 182 | args=args) 183 | if is_main_process(): 184 | print(metric) 185 | # remember best prec@1 and save checkpoint 186 | is_best = metric > best_metric 187 | best_metric = max(metric, best_metric) 188 | print(best_metric) 189 | save_checkpoint({ 190 | 'epoch': epoch + 1, 191 | 'state_dict': model.state_dict(), 192 | 'best_metric': best_metric, 193 | 'optimizer': optimizer.state_dict(), 194 | }, is_best, 195 | str('current'), 196 | args.check_dir, 197 | args = args, 198 | name=args.model_name) 199 | 200 | time_elapsed = time.time() - end 201 | if is_main_process(): 202 | print( 203 | f"Training complete in {time_elapsed//3600}h {(time_elapsed%3600)//60}m {time_elapsed %60}s") 204 | 205 | 206 | if __name__ == '__main__': 207 | args = parser.parse_args() 208 | if is_main_process(): 209 | set_logger(args) 210 | logging.info(args) 211 | if args.distribute: 212 | torch.cuda.set_device(args.local_rank) 213 | torch.distributed.init_process_group( 214 | 'nccl', 215 | init_method='env://' 216 | ) 217 | torch.manual_seed(1) 218 | torch.cuda.manual_seed_all(1) 219 | torch.backends.cudnn.deterministic = True 220 | train_model(args) 221 | -------------------------------------------------------------------------------- /models/SmallBig.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torchvision 3 | 4 | import torch.nn as nn 5 | import numpy as np 6 | from torch.utils import model_zoo 7 | 8 | from models.blocks import * 9 | from tools.tools import * 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 
'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | } 18 | 19 | 20 | class SmallBigNet(nn.Module): 21 | def __init__( 22 | self, 23 | Test, 24 | block, 25 | layers, 26 | imagenet_pre, 27 | num_classes=1000, 28 | feat=False, 29 | t_length=8, 30 | **kwargs): 31 | if not isinstance(block, list): 32 | block = [block] * 4 33 | self.inplanes = 64 34 | super(SmallBigNet, self).__init__() 35 | 36 | self.feat = feat 37 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(1, 7, 7), 38 | stride=(1, 2, 2), padding=(0, 3, 3), 39 | bias=False) 40 | self.bn1 = nn.BatchNorm3d(64) 41 | self.relu = nn.ReLU(inplace=True) 42 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) 43 | self.layer1 = self._make_layer(block[0], 64, layers[0]) 44 | self.layer2 = self._make_layer( 45 | block[1], 46 | 128, 47 | layers[1], 48 | stride=2, 49 | t_stride=1, 50 | t_length=t_length) 51 | self.layer3 = self._make_layer( 52 | block[2], 53 | 256, 54 | layers[2], 55 | stride=2, 56 | t_stride=1, 57 | t_length=t_length) 58 | self.layer4 = self._make_layer( 59 | block[3], 60 | 512, 61 | layers[3], 62 | stride=2, 63 | t_stride=1, 64 | t_length=t_length) 65 | self.avgpool = nn.AdaptiveAvgPool3d(1) 66 | 67 | self.feat_dim = 512 * block[0].expansion 68 | self.test = Test 69 | if imagenet_pre and is_main_process(): 70 | print('using imagenet pretraining weight set the BN as zero') 71 | 72 | if not feat: 73 | self.fc = nn.Linear(self.feat_dim, num_classes) 74 | 75 | for n, m in self.named_modules(): 76 | if isinstance(m, nn.Conv3d): 77 | nn.init.kaiming_normal_( 78 | m.weight, mode='fan_out', nonlinearity='relu') 79 | if 'big' in n: 80 | if isinstance(m, nn.BatchNorm3d): 81 | nn.init.constant_(m.weight, 0) 82 | nn.init.constant_(m.bias, 0) 83 | 84 | def _make_layer( 85 | self, 86 | block, 87 | planes, 88 | blocks, 89 | stride=1, 90 | t_stride=1, 91 | t_length=8): 92 | downsample = None 93 | if stride != 1 or self.inplanes != planes * block.expansion: 94 | downsample = nn.Sequential( 95 | nn.Conv3d( 96 | self.inplanes, 97 | planes * block.expansion, 98 | kernel_size=1, 99 | stride=(1, stride, stride), 100 | bias=False), 101 | nn.BatchNorm3d( 102 | planes * block.expansion), 103 | ) 104 | layers = [] 105 | layers.append( 106 | block( 107 | self.inplanes, 108 | planes, 109 | stride=stride, 110 | t_stride=t_stride, 111 | downsample=downsample, 112 | t_length=t_length)) 113 | self.inplanes = planes * block.expansion 114 | 115 | for i in range(1, blocks): 116 | layers.append(block(self.inplanes, planes, t_length=t_length)) 117 | 118 | return nn.Sequential(*layers) 119 | 120 | def forward(self, input): 121 | x = self.conv1(input) 122 | x = self.bn1(x) 123 | 124 | x = self.relu(x) 125 | x = self.maxpool(x) 126 | 127 | x = self.layer1(x) 128 | x = self.layer2(x) 129 | x = self.layer3(x) 130 | x = self.layer4(x) 131 | 132 | x = self.avgpool(x) 133 | if not self.test: 134 | x = x.view(x.size(0), -1) 135 | if not self.feat: 136 | x = self.fc(x) 137 | return x 138 | 139 | 140 | def use_image_pre_train(model, args): 141 | if '50' or '23' in args.model_name: 142 | state_dict = model_zoo.load_url(model_urls['resnet50']) 143 | elif '101' in args.model_name: 144 | state_dict = model_zoo.load_url(model_urls['resnet101']) 145 | new_state_dict = part_state_dict(state_dict, model.state_dict()) 146 | idx = 0 147 | model_dict = model.state_dict() 148 | for k, v in new_state_dict.items(): 149 | if k in model_dict: 150 | if v.shape == 
model_dict[k].shape: 151 | model_dict[k] = v.cuda() 152 | idx += 1 153 | 154 | if is_main_process(): 155 | print(len(new_state_dict)) 156 | print(idx) 157 | print('imagenet pre-trained weight upload already') 158 | model.load_state_dict(model_dict) 159 | 160 | return model 161 | def res50(args): 162 | model = SmallBigNet( 163 | args.test, [Bottleneck3D_000, Bottleneck3D_000, Bottleneck3D_000, Bottleneck3D_000], 164 | [3, 4, 6, 3], args.imagenet, num_classes=args.num_classes, feat=args.feat) 165 | #model = torchvision.models.resnet50(pretrained=False) 166 | if is_main_process(): 167 | # print_model_parm_flops(model, frame=args.t_length) 168 | #print(model) 169 | print(sum([np.prod(param.data.shape) for param in model.parameters()])) 170 | if args.imagenet: 171 | model = use_image_pre_train(model, args) 172 | return model 173 | 174 | def slowonly50(args): 175 | model = SmallBigNet( 176 | args.test, [Bottleneck3D_000, Bottleneck3D_000, Bottleneck3D_100, Bottleneck3D_100], 177 | [3, 4, 6, 3], args.imagenet, num_classes=args.num_classes, feat=args.feat) 178 | if is_main_process(): 179 | #print_model_parm_flops(model, frame=args.t_length) 180 | #print(model) 181 | print(sum([np.prod(param.data.shape) for param in model.parameters()])) 182 | if args.imagenet: 183 | model = use_image_pre_train(model, args) 184 | return model 185 | 186 | def slowonly50_extra(args): 187 | model = SmallBigNet( 188 | args.test, [Bottleneck3D_000, Bottleneck3D_000, Bottleneck3D_100_extra, Bottleneck3D_100_extra], 189 | [3, 4, 6, 3], args.imagenet, num_classes=args.num_classes, feat=args.feat) 190 | if is_main_process(): 191 | #print_model_parm_flops(model, frame=args.t_length) 192 | #print(model) 193 | print(sum([np.prod(param.data.shape) for param in model.parameters()])) 194 | if args.imagenet: 195 | model = use_image_pre_train(model, args) 196 | return model 197 | 198 | 199 | def smallbig50_no_extra(args): 200 | model = SmallBigNet( 201 | args.test, [Bottleneck3D_000, SmallBig_module, SmallBig_module, SmallBig_module], 202 | [3, 4, 6, 3], args.imagenet, num_classes=args.num_classes, feat=args.feat, t_length=args.t_length) 203 | if is_main_process(): 204 | #print(model) 205 | #print_model_parm_flops(model, frame=args.t_length) 206 | print(sum([np.prod(param.data.shape) for param in model.parameters()])) 207 | if args.imagenet: 208 | model = use_image_pre_train(model, args) 209 | return model 210 | 211 | 212 | def smallbig23_no_extra(args): 213 | model = SmallBigNet( 214 | args.test, [Bottleneck3D_000, SmallBig_module, SmallBig_module, SmallBig_module], 215 | [1, 2, 3, 1], args.imagenet, num_classes=args.num_classes, feat=args.feat, t_length=args.t_length) 216 | if is_main_process(): 217 | #print_model_parm_flops(model, frame=args.t_length) 218 | #print(model) 219 | print(sum([np.prod(param.data.shape) for param in model.parameters()])) 220 | if args.imagenet: 221 | model = use_image_pre_train(model, args) 222 | return model 223 | 224 | 225 | def smallbig101_no_extra(args): 226 | model = SmallBigNet( 227 | args.test, [Bottleneck3D_000, SmallBig_module, SmallBig_module, SmallBig_module], 228 | [3, 4, 23, 3],args.imagenet, num_classes=args.num_classes, feat=args.feat, t_length=args.t_length) 229 | if is_main_process(): 230 | #print_model_parm_flops(model, frame=args.t_length) 231 | # print(model) 232 | print(sum([np.prod(param.data.shape) for param in model.parameters()])) 233 | if args.imagenet: 234 | model = use_image_pre_train(model, args) 235 | return model 236 | 237 | def smallbig50_extra(args): 238 | model = 
SmallBigNet( 239 | args.test, [Bottleneck3D_000, SmallBig_module_extra, SmallBig_module_extra, SmallBig_module_extra], 240 | [3, 4, 6, 3], args.imagenet, num_classes=args.num_classes, feat=args.feat, t_length=args.t_length) 241 | if is_main_process(): 242 | #print(model) 243 | # print_model_parm_flops(model, frame=args.t_length) 244 | print(sum([np.prod(param.data.shape) for param in model.parameters()])) 245 | if args.imagenet: 246 | model = use_image_pre_train(model, args) 247 | return model 248 | 249 | def get_model(args): 250 | model_name_dict = { 251 | 'res50': res50(args), 252 | 'slowonly50' : slowonly50(args), 253 | 'slowonly50_extra': slowonly50_extra(args), 254 | 'smallbig23_no_extra' : smallbig23_no_extra(args), 255 | 'smallbig50_no_extra' : smallbig50_no_extra(args), 256 | 'smallbig101_no_extra': smallbig101_no_extra(args), 257 | 'smallbig50_extra': smallbig50_extra(args), 258 | } 259 | 260 | return model_name_dict[args.model_name] 261 | 262 | 263 | if __name__ == "__main__": 264 | parser = argparse.ArgumentParser(description='SmallBig Training') 265 | parser.add_argument('--test', action='store_true') 266 | parser.add_argument('--feat', action='store_true') 267 | parser.add_argument('--imagenet', action='store_true') 268 | parser.add_argument('--num_classes', default=400, type=int, 269 | help="num classes of your dataset") 270 | parser.add_argument('--t_length', default=8, type=int, 271 | help="Total length of sampling frames.") 272 | parser.add_argument('--model_name', default="smallbig50_no_extra", 273 | choices=[ 274 | "slowonly50", 275 | "smallbig23_no_extra", 276 | "smallbig50_no_extra", 277 | "smallbig101_no_extra"], 278 | help="name of your model") 279 | args = parser.parse_args() 280 | net = get_model(args) 281 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xhl-video/SmallBigNet/9e6d9ea4b61a0efb87893ab830463f56d0c5c8b4/models/__init__.py -------------------------------------------------------------------------------- /models/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class Bottleneck3D_100(nn.Module): 7 | expansion = 4 8 | 9 | def __init__( 10 | self, 11 | inplanes, 12 | planes, 13 | stride=1, 14 | t_stride=1, 15 | downsample=None, 16 | t_length=None): 17 | super(Bottleneck3D_100, self).__init__() 18 | self.conv1 = nn.Conv3d( 19 | inplanes, planes, 20 | kernel_size=(3, 1, 1), 21 | stride=(t_stride, 1, 1), 22 | padding=(1, 0, 0), 23 | bias=False) 24 | self.bn1 = nn.BatchNorm3d(planes) 25 | self.conv2 = nn.Conv3d( 26 | planes, planes, 27 | kernel_size=(1, 3, 3), 28 | stride=(1, stride, stride), 29 | padding=(0, 1, 1), bias=False) 30 | self.bn2 = nn.BatchNorm3d(planes) 31 | self.conv3 = nn.Conv3d( 32 | planes, 33 | planes * 34 | self.expansion, 35 | kernel_size=1, 36 | bias=False) 37 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 38 | self.relu = nn.ReLU(inplace=True) 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | out = self.relu(out) 52 | 53 | out = self.conv3(out) 54 | out = self.bn3(out) 55 | 56 | if self.downsample is not None: 57 | residual = self.downsample(x) 58 | 59 | out += 
residual 60 | out = self.relu(out) 61 | 62 | return out 63 | 64 | 65 | class Bottleneck3D_000(nn.Module): 66 | expansion = 4 67 | 68 | def __init__( 69 | self, 70 | inplanes, 71 | planes, 72 | stride=1, 73 | t_stride=1, 74 | downsample=None, 75 | t_length=None): 76 | super(Bottleneck3D_000, self).__init__() 77 | self.conv1 = nn.Conv3d( 78 | inplanes, planes, 79 | kernel_size=1, 80 | stride=[t_stride, 1, 1], 81 | bias=False) 82 | self.bn1 = nn.BatchNorm3d(planes) 83 | self.conv2 = nn.Conv3d( 84 | planes, planes, 85 | kernel_size=(1, 3, 3), 86 | stride=[1, stride, stride], 87 | padding=(0, 1, 1), 88 | bias=False) 89 | self.bn2 = nn.BatchNorm3d(planes) 90 | self.conv3 = nn.Conv3d( 91 | planes, 92 | planes * self.expansion, 93 | kernel_size=1, 94 | bias=False) 95 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 96 | self.relu = nn.ReLU(inplace=True) 97 | self.downsample = downsample 98 | self.stride = stride 99 | 100 | def forward(self, x): 101 | residual = x 102 | 103 | out = self.conv1(x) 104 | out = self.bn1(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv2(out) 108 | out = self.bn2(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv3(out) 112 | out = self.bn3(out) 113 | 114 | if self.downsample is not None: 115 | residual = self.downsample(x) 116 | 117 | out += residual 118 | out = self.relu(out) 119 | 120 | return out 121 | 122 | class Bottleneck3D_100_extra(nn.Module): 123 | expansion = 4 124 | 125 | def __init__( 126 | self, 127 | inplanes, 128 | planes, 129 | stride=1, 130 | t_stride=1, 131 | downsample=None, 132 | t_length=None): 133 | super(Bottleneck3D_100_extra, self).__init__() 134 | self.conv1 = nn.Conv3d( 135 | inplanes, planes, 136 | kernel_size=(3, 1, 1), 137 | stride=(t_stride, 1, 1), 138 | padding=(1, 0, 0), 139 | bias=False) 140 | self.bn1 = nn.BatchNorm3d(planes) 141 | self.conv2 = nn.Conv3d( 142 | planes, planes, 143 | kernel_size=(1, 3, 3), 144 | stride=(1, stride, stride), 145 | padding=(0, 1, 1), bias=False) 146 | self.bn2 = nn.BatchNorm3d(planes) 147 | self.conv3 = nn.Conv3d( 148 | planes, 149 | planes * 150 | self.expansion, 151 | kernel_size=1, 152 | bias=False) 153 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 154 | self.relu = nn.ReLU(inplace=True) 155 | self.downsample = downsample 156 | self.stride = stride 157 | self.big4_1 = nn.Conv3d( 158 | inplanes, 159 | inplanes // 4, 160 | kernel_size=1, 161 | padding=0, 162 | stride=1, 163 | bias=False) 164 | self.big4_2 = nn.Sequential(nn.Conv3d( 165 | inplanes // 4, 166 | inplanes, 167 | kernel_size=1, 168 | padding=0, 169 | stride=1, 170 | bias=False), 171 | nn.BatchNorm3d(inplanes)) 172 | self.big_extra = nn.AdaptiveMaxPool3d(1) 173 | self.sigmod = nn.Sigmoid() 174 | 175 | def forward(self, x): 176 | x = x + self.big4_2(self.big4_1(x) * self.sigmod(self.big4_1(self.big_extra(x)))) 177 | residual = x 178 | 179 | out = self.conv1(x) 180 | out = self.bn1(out) 181 | out = self.relu(out) 182 | 183 | out = self.conv2(out) 184 | out = self.bn2(out) 185 | out = self.relu(out) 186 | 187 | out = self.conv3(out) 188 | out = self.bn3(out) 189 | 190 | if self.downsample is not None: 191 | residual = self.downsample(x) 192 | 193 | out += residual 194 | out = self.relu(out) 195 | 196 | return out 197 | 198 | class SmallBig_module(nn.Module): 199 | expansion = 4 200 | 201 | def __init__( 202 | self, 203 | inplanes, 204 | planes, 205 | stride=1, 206 | t_stride=1, 207 | t_length=8, 208 | downsample=None): 209 | super(SmallBig_module, self).__init__() 210 | self.conv1 = nn.Conv3d( 211 | inplanes, planes, 
212 | kernel_size=1, 213 | stride=[t_stride, 1, 1], 214 | bias=False) 215 | self.bn1 = nn.BatchNorm3d(planes) 216 | self.conv2 = nn.Conv3d( 217 | planes, planes, 218 | kernel_size=(1, 3, 3), 219 | stride=[1, stride, stride], 220 | padding=(0, 1, 1), 221 | bias=False) 222 | self.bn2 = nn.BatchNorm3d(planes) 223 | self.conv3 = nn.Conv3d( 224 | planes, 225 | planes * 226 | self.expansion, 227 | kernel_size=1, 228 | bias=False) 229 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 230 | self.relu = nn.ReLU(inplace=True) 231 | 232 | self.big1 = nn.Sequential( 233 | nn.MaxPool3d( 234 | kernel_size=3, 235 | padding=1, 236 | stride=1), 237 | nn.BatchNorm3d(planes)) 238 | 239 | self.big2 = nn.Sequential( 240 | nn.MaxPool3d( 241 | kernel_size=3, 242 | padding=1, 243 | stride=1), 244 | nn.BatchNorm3d(planes)) 245 | 246 | self.big3 = nn.Sequential( 247 | nn.MaxPool3d( 248 | kernel_size=(t_length, 3, 3), 249 | padding=(0, 1, 1), 250 | stride=1), 251 | nn.BatchNorm3d(planes * self.expansion)) 252 | self.downsample = downsample 253 | self.stride = stride 254 | 255 | def forward(self, x): 256 | residual = x 257 | out = self.conv1(x) 258 | out = self.bn1(out) 259 | out = self.relu(out + self.big1[1](self.conv1(self.big1[0](x)))) 260 | out = self.relu(self.bn2(self.conv2(out)) + self.big2[1](self.conv2(self.big2[0](out)))) 261 | if self.downsample is not None: 262 | residual = self.downsample(x) 263 | out = self.bn3(self.conv3(out)) + self.big3[1](self.conv3(self.big3[0](out))) 264 | out += residual 265 | out = self.relu(out) 266 | 267 | return out 268 | 269 | class SmallBig_module_extra(nn.Module): 270 | expansion = 4 271 | 272 | def __init__( 273 | self, 274 | inplanes, 275 | planes, 276 | stride=1, 277 | t_stride=1, 278 | t_length=8, 279 | downsample=None): 280 | super(SmallBig_module_extra, self).__init__() 281 | self.conv1 = nn.Conv3d( 282 | inplanes, planes, 283 | kernel_size=1, 284 | stride=[t_stride, 1, 1], 285 | bias=False) 286 | self.bn1 = nn.BatchNorm3d(planes) 287 | self.conv2 = nn.Conv3d( 288 | planes, planes, 289 | kernel_size=(1, 3, 3), 290 | stride=[1, stride, stride], 291 | padding=(0, 1, 1), 292 | bias=False) 293 | self.bn2 = nn.BatchNorm3d(planes) 294 | self.conv3 = nn.Conv3d( 295 | planes, 296 | planes * 297 | self.expansion, 298 | kernel_size=1, 299 | bias=False) 300 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 301 | self.relu = nn.ReLU(inplace=True) 302 | 303 | self.big1 = nn.Sequential( 304 | nn.MaxPool3d( 305 | kernel_size=3, 306 | padding=1, 307 | stride=1), 308 | nn.BatchNorm3d(planes)) 309 | 310 | self.big2 = nn.Sequential( 311 | nn.MaxPool3d( 312 | kernel_size=3, 313 | padding=1, 314 | stride=1), 315 | nn.BatchNorm3d(planes)) 316 | 317 | self.big3 = nn.Sequential( 318 | nn.MaxPool3d( 319 | kernel_size=(t_length, 3, 3), 320 | padding=(0, 1, 1), 321 | stride=1), 322 | nn.BatchNorm3d(planes * self.expansion)) 323 | self.downsample = downsample 324 | self.stride = stride 325 | self.big4_1 = nn.Conv3d( 326 | inplanes, 327 | inplanes // 4, 328 | kernel_size=1, 329 | padding=0, 330 | stride=1, 331 | bias=False) 332 | self.big4_2 = nn.Sequential(nn.Conv3d( 333 | inplanes // 4, 334 | inplanes, 335 | kernel_size=1, 336 | padding=0, 337 | stride=1, 338 | bias=False), 339 | nn.BatchNorm3d(inplanes)) 340 | self.big_extra = nn.AdaptiveMaxPool3d(1) 341 | self.sigmod = nn.Sigmoid() 342 | 343 | def forward(self, x): 344 | 345 | x = x + self.big4_2(self.big4_1(x) * self.sigmod(self.big4_1(self.big_extra(x)))) 346 | 347 | residual = x 348 | out = self.conv1(x) 349 | out = self.bn1(out) 
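        # The next lines implement the small/big fusion of this block: at each stage the "small"
        # view (a plain convolution of the current feature map) is summed with a "big" view, i.e.
        # the same convolution applied to a max-pooled 3x3x3 neighbourhood (t_length x 3 x 3 for
        # big3) followed by its own BatchNorm, roughly
        #     fused = relu( bn(conv(x)) + bn_big(conv(maxpool3d(x))) ).
        # The conv weights are shared between the two branches, so the big branch only adds pooling
        # and BN parameters; the sigmoid-gated big4_1/big4_2 term applied to x above is the "extra"
        # global-context pathway that distinguishes this module from SmallBig_module.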
350 | out = self.relu(out + self.big1[1](self.conv1(self.big1[0](x)))) 351 | out = self.relu(self.bn2(self.conv2(out)) + self.big2[1](self.conv2(self.big2[0](out)))) 352 | if self.downsample is not None: 353 | residual = self.downsample(x) 354 | out = self.bn3(self.conv3(out)) + self.big3[1](self.conv3(self.big3[0](out))) 355 | out += residual 356 | out = self.relu(out) 357 | 358 | return out 359 | 360 | 361 | class SmallBig_plus_module_extra(nn.Module): 362 | expansion = 4 363 | 364 | def __init__( 365 | self, 366 | inplanes, 367 | planes, 368 | stride=1, 369 | t_stride=1, 370 | t_length=8, 371 | downsample=None): 372 | super(SmallBig_plus_module_extra, self).__init__() 373 | self.conv1 = nn.Conv3d( 374 | inplanes, planes, 375 | kernel_size=1, 376 | stride=[t_stride, 1, 1], 377 | bias=False) 378 | self.bn1 = nn.BatchNorm3d(planes) 379 | self.conv2 = nn.Conv3d( 380 | planes, planes, 381 | kernel_size=(1, 3, 3), 382 | stride=[1, stride, stride], 383 | padding=(0, 1, 1), 384 | bias=False) 385 | self.bn2 = nn.BatchNorm3d(planes) 386 | self.conv3 = nn.Conv3d( 387 | planes, 388 | planes * 389 | self.expansion, 390 | kernel_size=1, 391 | bias=False) 392 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 393 | self.relu = nn.ReLU(inplace=True) 394 | 395 | self.big1 = nn.Sequential( 396 | nn.AvgPool3d( 397 | kernel_size=(1, 2, 2), 398 | padding=0, 399 | stride=(1, 2, 2)), 400 | nn.BatchNorm3d(planes), 401 | nn.UpsamplingBilinear2d(scale_factor=2), 402 | ) 403 | 404 | self.big2 = nn.Sequential( 405 | nn.AvgPool3d( 406 | kernel_size=(1, 2, 2), 407 | padding=0, 408 | stride=(1, 2, 2)), 409 | nn.BatchNorm3d(planes), 410 | nn.UpsamplingBilinear2d(scale_factor=2), 411 | ) 412 | self.big3 = nn.Sequential( 413 | nn.AvgPool3d( 414 | kernel_size=(t_length, 2, 2), 415 | padding=0, 416 | stride=(1, 2, 2)), 417 | nn.BatchNorm3d(planes), 418 | nn.Upsample(scale_factor=(1, 2, 2)), 419 | ) 420 | self.downsample = downsample 421 | self.stride = stride 422 | self.big4_1 = nn.Conv3d( 423 | inplanes, 424 | inplanes // 4, 425 | kernel_size=1, 426 | padding=0, 427 | stride=1, 428 | bias=False) 429 | self.big4_2 = nn.Sequential(nn.Conv3d( 430 | inplanes // 4, 431 | inplanes, 432 | kernel_size=1, 433 | padding=0, 434 | stride=1, 435 | bias=False), 436 | nn.BatchNorm3d(inplanes)) 437 | self.big_extra = nn.AdaptiveMaxPool3d(1) 438 | self.sigmod = nn.Sigmoid() 439 | 440 | def forward(self, x): 441 | 442 | x = x + self.big4_2(self.big4_1(x) * self.sigmod(self.big4_1(self.big_extra(x)))) 443 | 444 | residual = x 445 | out = self.conv1(x) 446 | out = self.bn1(out) 447 | out = self.relu(out + self.big1[2](self.big1[1](self.conv1(self.big1[0](x))))) 448 | out = self.relu(self.bn2(self.conv2(out)) + self.big2[2](self.big2[1](self.conv2(self.big2[0](out))))) 449 | if self.downsample is not None: 450 | residual = self.downsample(x) 451 | out = self.bn3(self.conv3(out)) + self.big3[2](self.big3[1](self.conv3(self.big3[0](out)))) 452 | out += residual 453 | out = self.relu(out) 454 | 455 | return out 456 | -------------------------------------------------------------------------------- /scripts/kinetics.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7,8 python -m torch.distributed.launch --nproc_per_node=8 main.py \ 7 | --distribute --half --num_classes 400 --batch_size 64 --t_length 8 --t_stride 8 --image_tmpl img_{:05d}.jpg \ 8 | --dataset kinetics --root_path /data2/data/kinetics_400/rawframes_320 --val_list_file 
/data1/data/kinetics_400/RGB_val_videofolder.txt --train_list_file /data1/data/kinetics_400/RGB_train_videofolder.txt \ 9 | --model_name smallbig50_no_extra --lr 0.01 --num_epochs 100 --check_dir ./checkpoint --log_dir ./logs 10 | 11 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xhl-video/SmallBigNet/9e6d9ea4b61a0efb87893ab830463f56d0c5c8b4/tools/__init__.py -------------------------------------------------------------------------------- /tools/tools.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import logging 3 | import math 4 | import shutil 5 | import time 6 | from collections import defaultdict 7 | 8 | import torch 9 | import os 10 | import random 11 | import torch.distributed as dist 12 | from torch.autograd import Variable 13 | import numpy as np 14 | import torch.nn as nn 15 | 16 | 17 | def is_dist_avail_and_initialized(): 18 | if not dist.is_available(): 19 | return False 20 | if not dist.is_initialized(): 21 | return False 22 | return True 23 | 24 | 25 | def get_world_size(): 26 | if not is_dist_avail_and_initialized(): 27 | return 1 28 | return dist.get_world_size() 29 | 30 | 31 | def get_rank(): 32 | if not is_dist_avail_and_initialized(): 33 | return 0 34 | return dist.get_rank() 35 | 36 | 37 | def is_main_process(): 38 | return get_rank() == 0 39 | 40 | 41 | class AverageMeter(object): 42 | """Computes and stores the average and current value 43 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 44 | """ 45 | 46 | def __init__(self): 47 | self.reset() 48 | 49 | def reset(self): 50 | self.val = 0 51 | self.avg = 0 52 | self.sum = 0 53 | self.count = 0 54 | 55 | def update(self, val, n=1): 56 | self.val = val 57 | self.sum += val * n 58 | self.count += n 59 | self.avg = self.sum / self.count 60 | 61 | def synchronize_between_processes(self): 62 | """ 63 | Warning: does not synchronize the deque! 
64 | """ 65 | if not is_dist_avail_and_initialized(): 66 | return 67 | t = torch.tensor([self.count, self.sum], 68 | dtype=torch.float64, device='cuda') 69 | dist.barrier() 70 | dist.all_reduce(t) 71 | t = t.tolist() 72 | self.count = int(t[0]) 73 | self.sum = t[1] 74 | self.avg = self.sum / (self.count + 1e-5) 75 | 76 | 77 | def to_item(x): 78 | """Converts x, possibly scalar and possibly tensor, to a Python scalar.""" 79 | if isinstance(x, (float, int)): 80 | return x 81 | 82 | if float(torch.__version__[0:3]) < 0.4: 83 | assert (x.dim() == 1) and (len(x) == 1) 84 | return x[0] 85 | 86 | return x.item() 87 | 88 | 89 | def makedirs(path): 90 | if not os.path.exists(path): 91 | os.makedirs(path, 0o777) 92 | 93 | 94 | def load_value_file(file_path): 95 | with open(file_path, 'r') as input_file: 96 | value = float(input_file.read().rstrip('\n\r')) 97 | 98 | return value 99 | 100 | 101 | def save_checkpoint(state, is_best, epoch, experiment_root, 102 | filename='checkpoint_{}epoch.pth'): 103 | filename = os.path.join(experiment_root, filename.format(epoch)) 104 | logging.info("saving model to {}...".format(filename)) 105 | torch.save(state, filename) 106 | if is_best: 107 | best_name = os.path.join(experiment_root, 'model_best.pth') 108 | shutil.copyfile(filename, best_name) 109 | logging.info("saving done.") 110 | 111 | 112 | def calculate_accuracy(outputs, targets): 113 | batch_size = targets.size(0) 114 | 115 | _, pred = outputs.topk(1, 1, True) 116 | pred = pred.t() 117 | correct = pred.eq(targets.view(1, -1)) 118 | n_correct_elems = correct.float().sum().item() 119 | 120 | return n_correct_elems / batch_size 121 | 122 | 123 | def accuracy(output, target, topk=(1,)): 124 | """Computes the precision@k for the specified values of k""" 125 | maxk = max(topk) 126 | batch_size = target.size(0) 127 | 128 | _, pred = output.topk(maxk, 1, True, True) 129 | pred = pred.contiguous().t() 130 | correct = pred.eq(target.view(1, -1).expand_as(pred)).contiguous() 131 | 132 | res = [] 133 | for k in topk: 134 | correct_k = correct[:k].contiguous().view(-1).float().sum(0) 135 | res.append(correct_k.mul_(100.0 / batch_size)) 136 | return res 137 | 138 | 139 | def save_checkpoint( 140 | state, 141 | is_best, 142 | epoch, 143 | experiment_root, 144 | args = None, 145 | name='name', 146 | filename='checkpoint_{}epoch.pth'): 147 | checkdir = os.path.join(experiment_root, args.model_name) 148 | if not os.path.exists(checkdir): 149 | os.makedirs(checkdir) 150 | file = args.model_name + '_' + args.dataset + '_t_length_' + str(args.t_length) + '_t_stride_' + str( 151 | args.t_stride) + '_batch_' + str( 152 | args.batch_size) + '_lr_' + str(args.lr) + time.strftime("%d_%b_%Y_%H:%M:%S", time.localtime()) 153 | file_dir = os.path.join(checkdir, file) 154 | if not os.path.exists(file_dir): 155 | os.makedirs(file_dir, exist_ok=True) 156 | 157 | file = os.path.join(file_dir, filename.format(epoch)) 158 | 159 | torch.save(state, file) 160 | if is_best: 161 | best_name = os.path.join( 162 | file_dir, 163 | 'model_best_' + name + '.pth') 164 | shutil.copyfile(file, best_name) 165 | 166 | 167 | def adjust_learning_rate1(optimizer, base_lr, epoch): 168 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 169 | 170 | alpha = (epoch + 2000) / 2000 171 | warm = (1. 
/ 10) * (1 - alpha) + alpha 172 | lr = base_lr * warm 173 | for param_group in optimizer.param_groups: 174 | param_group['lr'] = lr 175 | 176 | 177 | def adjust_learning_rate(optimizer, base_lr, epoch, lr_steps): 178 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 179 | decay = 0.1 ** (sum(epoch >= np.array(lr_steps))) 180 | lr = base_lr * decay 181 | for param_group in optimizer.param_groups: 182 | param_group['lr'] = lr 183 | 184 | 185 | def convert_arr2matrix(actions): 186 | node = [] 187 | output = [] 188 | intermit = [] 189 | operator = [] 190 | graph = np.zeros((5, 5)) 191 | node.append([actions[0]]) 192 | 193 | operator.append([actions[1]]) 194 | idx = 2 195 | for i in range(1, 4): 196 | node.append([actions[idx * i + j] for j in range(i + 1)]) 197 | 198 | operator.append([actions[idx * (i + 1) + j] for j in range(i + 1)]) 199 | idx += 1 200 | 201 | for j in range(len(node)): 202 | ip = True 203 | for p in node[j + 1:]: 204 | if j + 2 in p: 205 | ip = False 206 | 207 | if ip: 208 | output.append(j + 2) 209 | else: 210 | intermit.append(j + 2) 211 | idy = 1 212 | for i in range(len(node)): 213 | for j in range(len(node[i])): 214 | 215 | if node[i][j] != 0: 216 | 217 | if i + 2 in intermit: 218 | graph[node[i][j] - 1][idy] = node[i][j] + 1 219 | graph[idy][node[i][j] - 1] = node[i][j] + 1 220 | else: 221 | graph[node[i][j] - 1][idy] = node[i][j] 222 | graph[idy][node[i][j] - 1] = node[i][j] 223 | idy += 1 224 | return graph 225 | 226 | 227 | class keydefaultdict(defaultdict): 228 | def __missing__(self, key): 229 | if self.default_factory is None: 230 | raise KeyError(key) 231 | else: 232 | ret = self[key] = self.default_factory(key) 233 | return ret 234 | 235 | 236 | def get_variable(inputs, cuda=False, **kwargs): 237 | if type(inputs) in [list, np.ndarray]: 238 | inputs = torch.Tensor(inputs) 239 | if cuda: 240 | out = Variable(inputs.cuda(), **kwargs) 241 | else: 242 | out = Variable(inputs, **kwargs) 243 | return out 244 | 245 | 246 | def part_state_dict(state_dict, model_dict): 247 | pretrained_dict = {} 248 | for k, v in state_dict.items(): 249 | if k in model_dict: 250 | pretrained_dict[k] = v 251 | else: 252 | print(k) 253 | pretrained_dict = inflate_state_dict(pretrained_dict, model_dict) 254 | # model_dict.update(pretrained_dict) 255 | return pretrained_dict 256 | 257 | 258 | def adjust_learning_rate2(optimizer, base_lr): 259 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 260 | 261 | lr = base_lr 262 | for param_group in optimizer.param_groups: 263 | param_group['lr'] = lr 264 | 265 | 266 | def inflate_state_dict(pretrained_dict, model_dict): 267 | for k in pretrained_dict.keys(): 268 | if k in model_dict.keys() and 'fc' not in k: 269 | if pretrained_dict[k].size() != model_dict[k].size(): 270 | assert( 271 | pretrained_dict[k].size()[:2] == model_dict[k].size()[:2]), "To inflate, channel number should match." 272 | assert(pretrained_dict[k].size()[-2:] == model_dict[k].size()[-2:]), "To inflate, spatial kernel size should match." 273 | #print("Layer {} needs inflation.".format(k)) 274 | shape = list(pretrained_dict[k].shape) 275 | shape.insert(2, 1) 276 | t_length = model_dict[k].shape[2] 277 | pretrained_dict[k] = pretrained_dict[k].reshape(shape) 278 | if t_length != 1: 279 | pretrained_dict[k] = pretrained_dict[k].expand_as( 280 | model_dict[k]) / t_length 281 | assert(pretrained_dict[k].size() == model_dict[k].size()), \ 282 | "After inflation, model shape should match." 
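            # Inflation summary: a pretrained 2D kernel of shape (C_out, C_in, kH, kW) gets a
            # singleton temporal axis inserted at dim 2, is broadcast along that axis to the
            # temporal extent of the target 3D kernel, and is divided by that extent so the
            # inflated filter initially responds like the 2D filter averaged over frames
            # (I3D-style bootstrapping). For example, a 1x1 weight of shape (256, 64, 1, 1)
            # becomes (256, 64, 1, 1, 1) and, for a temporal kernel size of 3, is tiled to
            # (256, 64, 3, 1, 1) with each temporal slice scaled by 1/3.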
283 | 284 | return pretrained_dict 285 | 286 | 287 | def print_model_parm_flops(model, frame=8): 288 | 289 | prods = {} 290 | 291 | def save_hook(name): 292 | def hook_per(self, input, output): 293 | # print 'flops:{}'.format(self.__class__.__name__) 294 | # print 'input:{}'.format(input) 295 | # print '_dim:{}'.format(input[0].dim()) 296 | # print 'input_shape:{}'.format(np.prod(input[0].shape)) 297 | # prods.append(np.prod(input[0].shape)) 298 | prods[name] = np.prod(input[0].shape) 299 | # prods.append(np.prod(input[0].shape)) 300 | 301 | return hook_per 302 | 303 | list_1 = [] 304 | 305 | def simple_hook(self, input, output): 306 | list_1.append(np.prod(input[0].shape)) 307 | 308 | list_2 = {} 309 | 310 | def simple_hook2(self, input, output): 311 | list_2['names'] = np.prod(input[0].shape) 312 | 313 | multiply_adds = False 314 | list_conv = [] 315 | 316 | def conv_hook(self, input, output): 317 | batch_size, input_channels, time_stride, input_height, input_width = input[0].size( 318 | ) 319 | output_channels, out_time, output_height, output_width = output[0].size( 320 | ) 321 | 322 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2] * ( 323 | self.in_channels / self.groups) * (2 if multiply_adds else 1) 324 | bias_ops = 1 if self.bias is not None else 0 325 | 326 | params = output_channels * (kernel_ops + bias_ops) 327 | flops = batch_size * params * output_height * output_width * out_time 328 | 329 | list_conv.append(flops) 330 | 331 | list_linear = [] 332 | 333 | def linear_hook(self, input, output): 334 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 335 | 336 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 337 | bias_ops = self.bias.nelement() 338 | 339 | flops = batch_size * (weight_ops + bias_ops) 340 | list_linear.append(flops) 341 | 342 | list_bn = [] 343 | 344 | def bn_hook(self, input, output): 345 | list_bn.append(input[0].nelement()) 346 | 347 | list_relu = [] 348 | 349 | def relu_hook(self, input, output): 350 | list_relu.append(input[0].nelement()) 351 | 352 | list_pooling = [] 353 | 354 | def pooling_hook(self, input, output): 355 | batch_size, input_channels, time_stride, input_height, input_width = input[0].size( 356 | ) 357 | output_channels, out_time, output_height, output_width = output[0].size( 358 | ) 359 | 360 | kernel_ops = self.kernel_size[0] * \ 361 | self.kernel_size[1] * self.kernel_size[2] 362 | bias_ops = 0 363 | params = output_channels * (kernel_ops + bias_ops) 364 | flops = batch_size * params * output_height * output_width * out_time 365 | 366 | list_pooling.append(flops) 367 | 368 | def foo(net): 369 | childrens = list(net.children()) 370 | if not childrens: 371 | if isinstance(net, torch.nn.Conv3d): 372 | # net.register_forward_hook(save_hook(net.__class__.__name__)) 373 | # net.register_forward_hook(simple_hook) 374 | # net.register_forward_hook(simple_hook2) 375 | net.register_forward_hook(conv_hook) 376 | if isinstance(net, torch.nn.Linear): 377 | net.register_forward_hook(linear_hook) 378 | if isinstance(net, torch.nn.BatchNorm3d): 379 | net.register_forward_hook(bn_hook) 380 | if isinstance(net, torch.nn.ReLU): 381 | net.register_forward_hook(relu_hook) 382 | if isinstance( 383 | net, torch.nn.MaxPool3d) or isinstance( 384 | net, torch.nn.AvgPool3d): 385 | net.register_forward_hook(pooling_hook) 386 | return 387 | for c in childrens: 388 | foo(c) 389 | 390 | criterion = nn.CrossEntropyLoss().cuda() 391 | 392 | model = model 393 | 394 | foo(model) 395 | input = Variable(torch.rand(1, 
3, frame, 224, 224), requires_grad=True) 396 | out = model(input) 397 | 398 | total_flops = ( 399 | sum(list_conv) + 400 | sum(list_linear) + 401 | sum(list_bn) + 402 | sum(list_relu) + 403 | sum(list_pooling)) 404 | 405 | print(' + Number of FLOPs: %.2fG' % (total_flops / 1e9)) 406 | -------------------------------------------------------------------------------- /train_val.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import torch 4 | 5 | from tools.tools import AverageMeter, accuracy, is_main_process 6 | 7 | 8 | def train( 9 | train_loader, 10 | model, 11 | criterion, 12 | optimizer, 13 | epoch, 14 | print_freq, 15 | writer, 16 | args=None,): 17 | batch_time = AverageMeter() 18 | data_time = AverageMeter() 19 | losses = AverageMeter() 20 | top1 = AverageMeter() 21 | top5 = AverageMeter() 22 | 23 | # switch to train mode 24 | model.train() 25 | end = time.time() 26 | if args.distribute: 27 | local_rank = torch.distributed.get_rank() 28 | torch.cuda.set_device(args.local_rank) 29 | device = torch.device("cuda", local_rank) 30 | else: 31 | device = torch.device("cuda") 32 | 33 | for i, (input, target) in enumerate(train_loader): 34 | 35 | 36 | data_time.update(time.time() - end) 37 | target = target.cuda(non_blocking=True, device=device) 38 | 39 | if args.half: 40 | with torch.cuda.amp.autocast(): 41 | output = model(input.cuda(device)) 42 | loss = criterion(output, target) 43 | else: 44 | output = model(input.cuda(device)) 45 | loss = criterion(output, target) 46 | 47 | optimizer.zero_grad() 48 | loss.backward() 49 | optimizer.step() 50 | 51 | batch_time.update(time.time() - end) 52 | end = time.time() 53 | 54 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 55 | top5.update(prec5.item(), input.size(0)) 56 | top1.update(prec1.item(), input.size(0)) 57 | losses.update(loss.item(), input.size(0)) 58 | if i % print_freq == 0 and is_main_process(): 59 | logging.info(('Epoch: [{0}][{1}/{2}], lr: {lr:.8f}\t' 60 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 61 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t' 62 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 63 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(epoch, 64 | int(i), 65 | int(len(train_loader)), 66 | batch_time=batch_time, 67 | loss=losses, 68 | top1=top1, 69 | top5=top5, 70 | lr=optimizer.param_groups[-1]['lr']))) 71 | if args.distribute: 72 | losses.synchronize_between_processes() 73 | top1.synchronize_between_processes() 74 | top5.synchronize_between_processes() 75 | if is_main_process(): 76 | writer.add_scalar('Train/loss', losses.avg, epoch) 77 | writer.add_scalar('Train/top1', top1.avg, epoch) 78 | writer.add_scalar('Train/lr', optimizer.param_groups[-1]['lr'], epoch) 79 | logging.info( 80 | ('Epoch {epoch} Training Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}' 81 | .format(epoch=epoch, top1=top1, top5=top5, loss=losses))) 82 | 83 | 84 | 85 | def validate(val_loader, model, criterion, print_freq, epoch, writer, args=None): 86 | batch_time = AverageMeter() 87 | losses = AverageMeter() 88 | top1 = AverageMeter() 89 | top5 = AverageMeter() 90 | model.eval() 91 | if args.distribute: 92 | local_rank = torch.distributed.get_rank() 93 | torch.cuda.set_device(args.local_rank) 94 | device = torch.device("cuda", local_rank) 95 | else: 96 | device = torch.device("cuda") 97 | with torch.no_grad(): 98 | end = time.time() 99 | for i, (input, target) in enumerate(val_loader): 100 | target = target.cuda(non_blocking=True, 
device=device) 101 | 102 | # compute output 103 | output = model(input) 104 | loss = criterion(output, target) 105 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 106 | top5.update(prec5.item(), input.size(0)) 107 | losses.update(loss.item(), input.size(0)) 108 | top1.update(prec1.item(), input.size(0)) 109 | 110 | # measure elapsed time 111 | batch_time.update(time.time() - end) 112 | end = time.time() 113 | 114 | if i % print_freq == 0 and is_main_process(): 115 | logging.info( 116 | ('Test: [{0}/{1}]\t' 117 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 118 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 119 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 120 | i, 121 | len(val_loader), 122 | batch_time=batch_time, 123 | top1=top1, 124 | top5=top5))) 125 | 126 | if args.distribute: 127 | losses.synchronize_between_processes() 128 | top1.synchronize_between_processes() 129 | top5.synchronize_between_processes() 130 | if is_main_process(): 131 | writer.add_scalar('Test/loss', losses.avg, epoch) 132 | writer.add_scalar('Test/top1', top1.avg, epoch) 133 | logging.info( 134 | ('Epoch {epoch} Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}' 135 | .format(epoch=epoch, top1=top1, top5=top5, loss=losses))) 136 | 137 | return (top1.avg + top5.avg) / 2 138 | --------------------------------------------------------------------------------
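For quick sanity checks outside the training scripts, the models above can also be built directly through `get_model`. The sketch below is illustrative and not part of the repository: it fills an `argparse.Namespace` with only the fields `get_model` reads, keeps `imagenet=False` so no pretrained weights are downloaded, and pushes a dummy clip of shape (N, C, T, H, W) through the network, the same layout as the `torch.rand(1, 3, frame, 224, 224)` probe used in `print_model_parm_flops`. Note that `get_model` instantiates every entry of its `model_name_dict` before indexing it, so expect a short delay while all listed architectures are constructed.

```
# Minimal usage sketch (illustrative, not part of the repo).
import argparse

import torch

from models.SmallBig import get_model

args = argparse.Namespace(
    test=False,          # keep the flatten + fc head in SmallBigNet.forward
    feat=False,          # return class logits rather than pooled features
    imagenet=False,      # skip downloading ImageNet weights for this check
    num_classes=400,
    t_length=8,
    model_name="smallbig50_no_extra",
)

model = get_model(args).eval()
clip = torch.rand(1, 3, args.t_length, 224, 224)   # (N, C, T, H, W)
with torch.no_grad():
    logits = model(clip)
print(logits.shape)   # expected: torch.Size([1, 400])
```

Actual training and evaluation should go through `main.py` (or `scripts/kinetics.sh`), which additionally handle distributed setup, mixed precision and checkpointing.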