├── .gitignore ├── .gitmodules ├── README.md ├── clip_test.py ├── compute_flops.py ├── data ├── hmdb51 │ ├── hmdb51_train_split1_list.txt │ ├── hmdb51_train_split2_list.txt │ ├── hmdb51_train_split3_list.txt │ ├── hmdb51_val_split1_list.txt │ ├── hmdb51_val_split2_list.txt │ └── hmdb51_val_split3_list.txt ├── kinetics200 │ ├── 400_200_label_mapping.txt │ ├── create_kinetics200_list.py │ ├── kinetics200_train_list.txt │ ├── kinetics200_train_list_org.txt │ ├── kinetics200_val_list.txt │ ├── kinetics200_val_list_org.txt │ ├── kinetics_train_list.txt │ └── kinetics_val_list.txt ├── kinetics400 │ ├── count.py │ ├── create_xlw_list.py │ ├── kinetics_train_list.txt │ ├── kinetics_train_list_xlw │ ├── kinetics_val_list.txt │ └── kinetics_val_list_xlw ├── sthsth_v1 │ ├── create_sthsth_v1_list.py │ ├── something-something-v1-labels.csv │ ├── something-something-v1-test.csv │ ├── something-something-v1-train.csv │ ├── something-something-v1-validation.csv │ ├── sthv1_train_list.txt │ └── sthv1_val_list.txt └── ucf101 │ ├── ucf101_train_split1_list.txt │ ├── ucf101_train_split2_list.txt │ ├── ucf101_train_split3_list.txt │ ├── ucf101_val_split1_list.txt │ ├── ucf101_val_split2_list.txt │ └── ucf101_val_split3_list.txt ├── finetune_bn_frozen.py ├── finetune_fc.py ├── lib ├── dataset.py ├── models.py ├── modules │ ├── __init__.py │ ├── pooling.py │ └── scale.py ├── networks │ ├── __init__.py │ ├── mnet2.py │ ├── mnet2_3d.py │ ├── part_inflate_resnet_3d.py │ ├── resnet.py │ ├── resnet_3d.py │ └── resnet_3d_nodown.py ├── opts.py ├── transforms.py └── utils │ ├── deprefix.py │ ├── tools.py │ ├── vis_comb.py │ └── visualization.py ├── main.py ├── main_20bn.py ├── main_imagenet.py ├── scripts ├── imagenet_2d_res26.sh └── kinetics400_3d_res50_slowonly_im_pre.sh ├── test_10crop.py ├── test_kaiming.py └── train_val.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | access/ 3 | output/ 4 | models/ 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | env/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | .idea 105 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "data/kinetics200/Mini-Kinetics-200"] 2 | path = data/kinetics200/Mini-Kinetics-200 3 | url = https://github.com/BannyStone/Mini-Kinetics-200 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Video-Classification-Pytorch 2 | 3 | ***This is an archived repo. We strongly recommend PySlowFast or mmaction for video understanding***. 4 | 5 | This repository contains 3D and 2D models for video classification. The code is based on PyTorch 1.0. 6 | So far, it supports the following datasets: 7 | Kinetics-400, Mini-Kinetics-200, UCF101, HMDB51 8 | 9 | ## Results 10 | 11 | ### Kinetics-400 12 | 13 | We report the baselines with a ResNet-50 backbone on the Kinetics-400 validation set below (all models are trained on the training set). 14 | All models are trained on a single server with 8 GTX 1080 Ti GPUs. 15 | 16 | | network | pretrain data | spatial resolution | input frames | sampling stride | backbone | top1 | top5 | 17 | | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | 18 | | ResNet50-SlowOnly | ImageNet-1K | 224x224 | 8 | 8 | ResNet50 | 73.77 | 91.17 | 19 | 20 | 21 | ## Get the Code 22 | ```Shell 23 | git clone --recursive https://github.com/BannyStone/Video_Classification_PyTorch.git 24 | ``` 25 | 26 | ## Preparing Dataset 27 | ### Kinetics-400 28 | ```Shell 29 | cd data/kinetics400 30 | mkdir access && cd access 31 | ln -s $YOUR_KINETICS400_DATASET_TRAIN_DIR$ RGB_train 32 | ln -s $YOUR_KINETICS400_DATASET_VAL_DIR$ RGB_val 33 | ``` 34 | Note that: 35 | - The reported models are trained with the Kinetics data provided by Xiaolong Wang: https://github.com/facebookresearch/video-nonlocal-net/blob/master/DATASET.md 36 | - In the train and validation lists for all datasets, each line represents one video: the first element is the video frame directory, the second is the number of frames, and the third is the class index. Please prepare your own lists accordingly, because different video parsing methods may lead to different frame numbers.
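As a minimal sketch, such a list can be generated from extracted frame folders in the spirit of the helper scripts under `data/` (the output filename, the example video-id-to-label mapping, and the `access/RGB_train/<video_id>/` layout below are illustrative assumptions, not part of this repo):
```python
import os

# Build "<frame_dir> <num_frames> <class_index>" lines, one per video.
# frame_root and video_labels are placeholders: point them at your own
# extracted-frame directory and your own video-id -> class-index mapping.
frame_root = "access/RGB_train"
video_labels = {"D32_1gwq35E": 66, "-G-5CJ0JkKY": 254}

with open("my_train_list.txt", "w") as f:
    for vid, class_index in video_labels.items():
        num_frames = len(os.listdir(os.path.join(frame_root, vid)))  # frames found on disk
        f.write("RGB_train/{} {} {}\n".format(vid, num_frames, class_index))
```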
We show part of Kinetics-400 train list as an example: 37 | ```shell 38 | RGB_train/D32_1gwq35E 300 66 39 | RGB_train/-G-5CJ0JkKY 250 254 40 | RGB_train/4uZ27ivBl00 300 341 41 | RGB_train/pZP-dHUuGiA 240 369 42 | ``` 43 | - This code can read the image files in each video frame folder according to the image template argument *image_tmpl*, such as *image_{:06d}.jpg*. 44 | 45 | ## Training 46 | Execute training script: 47 | ```Shell 48 | ./scripts/kinetics400_3d_res50_slowonly_im_pre.sh 49 | ``` 50 | 51 | We show script *kinetics400_3d_res50_slowonly_im_pre.sh* here: 52 | ```Shell 53 | python main.py \ 54 | kinetics400 \ 55 | data/kinetics400/kinetics_train_list_xlw \ 56 | data/kinetics400/kinetics_val_list_xlw \ 57 | --arch resnet50_3d_slowonly \ 58 | --dro 0.5 \ 59 | --mode 3D \ 60 | --t_length 8 \ 61 | --t_stride 8 \ 62 | --pretrained \ 63 | --epochs 110 \ 64 | --batch-size 96 \ 65 | --lr 0.02 \ 66 | --wd 0.0001 \ 67 | --lr_steps 50 80 100 \ 68 | --workers 16 \ 69 | ``` 70 | 71 | ## Testing 72 | ```Shell 73 | python ./test_kaiming.py \ 74 | kinetics400 \ 75 | data/kinetics400/kinetics_val_list_xlw \ 76 | output/kinetics400_resnet50_3d_slowonly_3D_length8_stride8_dropout0.5/model_best.pth \ 77 | --arch resnet50_3d_slowonly \ 78 | --mode TSN+3D \ 79 | --batch_size 1 \ 80 | --num_segments 10 \ 81 | --input_size 256 \ 82 | --t_length 8 \ 83 | --t_stride 8 \ 84 | --dropout 0.5 \ 85 | --workers 12 \ 86 | --image_tmpl image_{:06d}.jpg \ 87 | 88 | ``` 89 | -------------------------------------------------------------------------------- /clip_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import shutil 5 | import logging 6 | 7 | import torch 8 | import torchvision 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim 12 | 13 | from lib.dataset import VideoDataSet 14 | from lib.models import VideoModule 15 | from lib.transforms import * 16 | from lib.utils.tools import * 17 | from lib.opts import args 18 | 19 | from train_val import train, validate 20 | 21 | def main(): 22 | global args, best_metric 23 | 24 | # specify dataset 25 | if args.dataset == 'ucf101': 26 | num_class = 101 27 | elif args.dataset == 'hmdb51': 28 | num_class = 51 29 | elif args.dataset == 'kinetics400': 30 | num_class = 400 31 | elif args.dataset == 'kinetics200': 32 | num_class = 200 33 | else: 34 | raise ValueError('Unknown dataset '+args.dataset) 35 | 36 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), 37 | "data/{}/access".format(args.dataset)) 38 | 39 | # create model 40 | org_model = VideoModule(num_class=num_class, 41 | base_model_name=args.arch, 42 | dropout=args.dropout, 43 | pretrained=args.pretrained, 44 | pretrained_model=args.pretrained_model) 45 | num_params = 0 46 | for param in org_model.parameters(): 47 | num_params += param.reshape((-1, 1)).shape[0] 48 | print("Model Size is {:.3f}M".format(num_params/1000000)) 49 | 50 | model = torch.nn.DataParallel(org_model).cuda() 51 | 52 | criterion = torch.nn.CrossEntropyLoss().cuda() 53 | 54 | optimizer = torch.optim.SGD(model.parameters(), 55 | args.lr, 56 | momentum=args.momentum, 57 | weight_decay=args.weight_decay) 58 | 59 | # optionally resume from a checkpoint 60 | if args.resume: 61 | if os.path.isfile(args.resume): 62 | print(("=> loading checkpoint '{}'".format(args.resume))) 63 | checkpoint = torch.load(args.resume) 64 | args.start_epoch = checkpoint['epoch'] 65 | best_metric = 
checkpoint['best_metric'] 66 | model.load_state_dict(checkpoint['state_dict']) 67 | optimizer.load_state_dict(checkpoint['optimizer']) 68 | print(("=> loaded checkpoint '{}' (epoch {})" 69 | .format(args.resume, checkpoint['epoch']))) 70 | else: 71 | print(("=> no checkpoint found at '{}'".format(args.resume))) 72 | 73 | ## val data 74 | val_transform = torchvision.transforms.Compose([ 75 | GroupScale(args.new_size), 76 | GroupCenterCrop(args.crop_size), 77 | Stack(mode=args.mode), 78 | ToTorchFormatTensor(), 79 | GroupNormalize(), 80 | ]) 81 | val_dataset = VideoDataSet(root_path=data_root, 82 | list_file=args.val_list, 83 | t_length=args.t_length, 84 | t_stride=args.t_stride, 85 | num_segments=args.num_segments, 86 | image_tmpl=args.image_tmpl, 87 | transform=val_transform, 88 | phase="Val") 89 | val_loader = torch.utils.data.DataLoader( 90 | val_dataset, 91 | batch_size=args.batch_size, shuffle=False, 92 | num_workers=args.workers, pin_memory=True) 93 | 94 | if args.mode != "3D": 95 | cudnn.benchmark = True 96 | 97 | validate(val_loader, model, criterion, args.print_freq, args.start_epoch) 98 | 99 | 100 | if __name__ == '__main__': 101 | main() 102 | -------------------------------------------------------------------------------- /compute_flops.py: -------------------------------------------------------------------------------- 1 | from lib.networks.part_inflate_resnet_3d import * 2 | from lib.modules import * 3 | import torch 4 | from lib.networks.km_resnet_3d_beta import TKMConv, compute_tkmconv, km_resnet26_3d_v2_sample, km_resnet50_3d_v2_sample 5 | def count_GloAvgPool3d(m, x, y): 6 | m.total_ops = torch.Tensor([int(0)]) 7 | 8 | from thop import profile 9 | model = km_resnet50_3d_v2_sample() 10 | model.fc = torch.nn.Linear(2048, 400) 11 | flops, params = profile(model, input_size=(1, 3, 8, 224,224), custom_ops={GloAvgPool3d: count_GloAvgPool3d, TKMConv: compute_tkmconv}) 12 | print("params: {}".format(params/1000000)) 13 | print("flops: {}".format(flops/1000000000)) 14 | -------------------------------------------------------------------------------- /data/kinetics200/400_200_label_mapping.txt: -------------------------------------------------------------------------------- 1 | 0 0 2 | 1 1 3 | 5 2 4 | 6 3 5 | 11 4 6 | 12 5 7 | 14 6 8 | 16 7 9 | 18 8 10 | 19 9 11 | 22 10 12 | 24 11 13 | 27 12 14 | 29 13 15 | 31 14 16 | 32 15 17 | 34 16 18 | 36 17 19 | 37 18 20 | 40 19 21 | 41 20 22 | 42 21 23 | 43 22 24 | 48 23 25 | 49 24 26 | 50 25 27 | 55 26 28 | 56 27 29 | 59 28 30 | 60 29 31 | 68 30 32 | 69 31 33 | 70 32 34 | 75 33 35 | 77 34 36 | 78 35 37 | 79 36 38 | 80 37 39 | 83 38 40 | 84 39 41 | 86 40 42 | 87 41 43 | 88 42 44 | 93 43 45 | 97 44 46 | 99 45 47 | 103 46 48 | 104 47 49 | 107 48 50 | 108 49 51 | 109 50 52 | 115 51 53 | 116 52 54 | 123 53 55 | 124 54 56 | 125 55 57 | 126 56 58 | 127 57 59 | 130 58 60 | 132 59 61 | 133 60 62 | 134 61 63 | 140 62 64 | 142 63 65 | 143 64 66 | 147 65 67 | 148 66 68 | 149 67 69 | 151 68 70 | 152 69 71 | 153 70 72 | 159 71 73 | 161 72 74 | 162 73 75 | 164 74 76 | 166 75 77 | 167 76 78 | 169 77 79 | 172 78 80 | 174 79 81 | 177 80 82 | 180 81 83 | 182 82 84 | 183 83 85 | 188 84 86 | 189 85 87 | 192 86 88 | 193 87 89 | 197 88 90 | 199 89 91 | 201 90 92 | 204 91 93 | 205 92 94 | 206 93 95 | 208 94 96 | 209 95 97 | 212 96 98 | 214 97 99 | 217 98 100 | 218 99 101 | 219 100 102 | 220 101 103 | 221 102 104 | 223 103 105 | 224 104 106 | 225 105 107 | 227 106 108 | 229 107 109 | 230 108 110 | 232 109 111 | 233 110 112 | 234 111 113 | 235 112 114 | 240 113 115 | 
242 114 116 | 243 115 117 | 244 116 118 | 245 117 119 | 246 118 120 | 247 119 121 | 248 120 122 | 249 121 123 | 250 122 124 | 251 123 125 | 252 124 126 | 253 125 127 | 254 126 128 | 255 127 129 | 256 128 130 | 258 129 131 | 261 130 132 | 262 131 133 | 264 132 134 | 269 133 135 | 273 134 136 | 275 135 137 | 277 136 138 | 278 137 139 | 280 138 140 | 282 139 141 | 283 140 142 | 285 141 143 | 286 142 144 | 289 143 145 | 291 144 146 | 292 145 147 | 294 146 148 | 298 147 149 | 299 148 150 | 301 149 151 | 302 150 152 | 304 151 153 | 305 152 154 | 306 153 155 | 307 154 156 | 308 155 157 | 313 156 158 | 315 157 159 | 316 158 160 | 317 159 161 | 318 160 162 | 321 161 163 | 322 162 164 | 323 163 165 | 325 164 166 | 326 165 167 | 327 166 168 | 330 167 169 | 331 168 170 | 334 169 171 | 336 170 172 | 337 171 173 | 339 172 174 | 340 173 175 | 346 174 176 | 348 175 177 | 349 176 178 | 350 177 179 | 356 178 180 | 358 179 181 | 360 180 182 | 364 181 183 | 365 182 184 | 367 183 185 | 369 184 186 | 371 185 187 | 373 186 188 | 378 187 189 | 379 188 190 | 380 189 191 | 382 190 192 | 383 191 193 | 387 192 194 | 389 193 195 | 390 194 196 | 391 195 197 | 393 196 198 | 394 197 199 | 398 198 200 | 399 199 201 | -------------------------------------------------------------------------------- /data/kinetics200/create_kinetics200_list.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | # extract target 200-class videos from the original videos 4 | with open('kinetics_train_list.txt') as tr400: 5 | with open('Mini-Kinetics-200/train_ytid_list.txt') as miniTr: 6 | with open('kinetics200_train_list_org.txt', 'w') as tr200: 7 | # build indices for original 400-class train list 8 | lines = tr400.readlines() 9 | ytid_line_dict = dict() 10 | for line in lines: 11 | ytid = line.strip().split()[0].split('/')[1] 12 | ytid_line_dict[ytid] = line 13 | # extract target lines and write them into tr200 file 14 | lines = miniTr.readlines() 15 | for line in lines: 16 | ytid = line.strip() 17 | if ytid in ytid_line_dict: 18 | target_line = ytid_line_dict[ytid] 19 | tr200.write(target_line) 20 | else: 21 | print("{} is not in original video list".format(ytid)) 22 | 23 | with open('kinetics_val_list.txt') as va400: 24 | with open('Mini-Kinetics-200/val_ytid_list.txt') as miniVa: 25 | with open('kinetics200_val_list_org.txt', 'w') as va200: 26 | # build indices for original 400-class val list 27 | lines = va400.readlines() 28 | ytid_line_dict = dict() 29 | for line in lines: 30 | ytid = line.strip().split()[0].split('/')[1] 31 | ytid_line_dict[ytid] = line 32 | # extract target lines and write them into va200 file 33 | lines = miniVa.readlines() 34 | for line in lines: 35 | ytid = line.strip() 36 | if ytid in ytid_line_dict: 37 | target_line = ytid_line_dict[ytid] 38 | va200.write(target_line) 39 | else: 40 | print("{} is not in original video list".format(ytid)) 41 | 42 | # summarize all the 200 categories of Mini-Kinetics 43 | # Train and val 44 | cats_tr = set() 45 | cats_va = set() 46 | 47 | with open("kinetics200_train_list_org.txt") as f: 48 | lines = f.readlines() 49 | for line in lines: 50 | label_id = int(line.strip().split()[-1]) 51 | cats_tr.add(label_id) 52 | 53 | with open("kinetics200_val_list_org.txt") as f: 54 | lines = f.readlines() 55 | for line in lines: 56 | label_id = int(line.strip().split()[-1]) 57 | cats_va.add(label_id) 58 | 59 | assert(cats_tr == cats_va) 60 | 61 | # build 400-class 200-class dictionary 62 | _400_200_dict = dict() 63 | for i, cat in 
enumerate(cats_tr): 64 | _400_200_dict[cat] = i 65 | 66 | with open('400_200_label_mapping.txt', 'w') as f: 67 | for key, value in _400_200_dict.items(): 68 | f.write("{} {}\n".format(key, value)) 69 | 70 | with open('kinetics200_train_list_org.txt') as f_src: 71 | with open('kinetics200_train_list.txt', 'w') as f_dst: 72 | lines = f_src.readlines() 73 | for line in lines: 74 | items = line.strip().split() 75 | items[-1] = str(_400_200_dict[int(items[-1])]) 76 | new_line = ' '.join(items) 77 | f_dst.write(new_line + '\n') 78 | 79 | with open('kinetics200_val_list_org.txt') as f_src: 80 | with open('kinetics200_val_list.txt', 'w') as f_dst: 81 | lines = f_src.readlines() 82 | for line in lines: 83 | items = line.strip().split() 84 | items[-1] = str(_400_200_dict[int(items[-1])]) 85 | new_line = ' '.join(items) 86 | f_dst.write(new_line + '\n') 87 | 88 | # pdb.set_trace() -------------------------------------------------------------------------------- /data/kinetics400/count.py: -------------------------------------------------------------------------------- 1 | frames = [] 2 | with open("kinetics_val_list.txt") as f: 3 | lines = f.readlines() 4 | for line in lines: 5 | items = line.strip().split() 6 | frames.append(int(items[1])) 7 | 8 | total = len(frames) 9 | count60 = 0 10 | count120 = 0 11 | count240 = 0 12 | 13 | for fr in frames: 14 | if fr > 60: 15 | count60 += 1 16 | if fr > 120: 17 | count120 += 1 18 | if fr > 240: 19 | count240 += 1 20 | 21 | print("60: ", count60, total, count60/total) 22 | print("120: ", count120, total, count120/total) 23 | print("240: ", count240, total, count240/total) -------------------------------------------------------------------------------- /data/kinetics400/create_xlw_list.py: -------------------------------------------------------------------------------- 1 | access = "access/" 2 | import os 3 | from tqdm import tqdm 4 | 5 | # with open("kinetics_val_list.txt") as f_old: 6 | # with open("kinetics_val_list_xlw", 'w') as f_new: 7 | # old_lines = f_old.readlines() 8 | # for line in old_lines: 9 | # vid_path, num_fr, label = line.strip().split() 10 | # if os.path.exists(access+vid_path): 11 | # new_num_fr = len(os.listdir(access+vid_path)) 12 | # f_new.write(" ".join([vid_path, str(new_num_fr), label]) + '\n') 13 | 14 | # with open("kinetics_train_list.txt") as f_old: 15 | # with open("kinetics_train_list_xlw", 'w') as f_new: 16 | # old_lines = f_old.readlines() 17 | # for line in old_lines: 18 | # vid_path, num_fr, label = line.strip().split() 19 | # if os.path.exists(access+vid_path): 20 | # new_num_fr = len(os.listdir(access+vid_path)) 21 | # f_new.write(" ".join([vid_path, str(new_num_fr), label]) + '\n') 22 | 23 | with open("kinetics_train_list_xlw") as f: 24 | lines = f.readlines() 25 | for line in tqdm(lines): 26 | vid_path, num_fr, label = line.strip().split() 27 | images = os.listdir(access+vid_path) 28 | images.sort() 29 | last_image = images[-1] 30 | # import pdb 31 | # pdb.set_trace() 32 | if int(last_image[6:-4]) != int(num_fr): 33 | print(vid_path) -------------------------------------------------------------------------------- /data/sthsth_v1/create_sthsth_v1_list.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import collections 4 | from collections import OrderedDict 5 | 6 | frame_root = "/media/SSD/zhoulei/20bn-something-something-v1" 7 | 8 | f_tr = open("sthv1_train_list.txt", 'w') 9 | f_va = open("sthv1_val_list.txt", 'w') 10 | 11 | name_id = OrderedDict() 12 | 
with open('something-something-v1-labels.csv', newline='') as csvfile: 13 | reader = csv.reader(csvfile, delimiter=';') 14 | for i, row in enumerate(reader): 15 | assert(len(row) == 1), "the length of row must be one" 16 | name_id[row[0]] = i 17 | 18 | with open('something-something-v1-train.csv', newline='') as csvfile: 19 | reader = csv.reader(csvfile, delimiter=';') 20 | for row in reader: 21 | dir_name = row[0] 22 | class_name = row[1] 23 | class_id = name_id[class_name] 24 | 25 | vid_dir = os.path.join(frame_root, dir_name) 26 | frame_num = len(os.listdir(vid_dir)) 27 | 28 | line = ' '.join(("RGB/"+dir_name, str(frame_num), str(class_id)+'\n')) 29 | f_tr.write(line) 30 | 31 | with open('something-something-v1-validation.csv', newline='') as csvfile: 32 | reader = csv.reader(csvfile, delimiter=';') 33 | for row in reader: 34 | dir_name = row[0] 35 | class_name = row[1] 36 | class_id = name_id[class_name] 37 | 38 | vid_dir = os.path.join(frame_root, dir_name) 39 | frame_num = len(os.listdir(vid_dir)) 40 | 41 | line = ' '.join(("RGB/"+dir_name, str(frame_num), str(class_id)+'\n')) 42 | f_va.write(line) -------------------------------------------------------------------------------- /data/sthsth_v1/something-something-v1-labels.csv: -------------------------------------------------------------------------------- 1 | Holding something 2 | Turning something upside down 3 | Turning the camera left while filming something 4 | Stacking number of something 5 | Turning the camera right while filming something 6 | Opening something 7 | Approaching something with your camera 8 | Picking something up 9 | Pushing something so that it almost falls off but doesn't 10 | Folding something 11 | Moving something away from the camera 12 | Closing something 13 | Moving away from something with your camera 14 | Turning the camera downwards while filming something 15 | Pushing something so that it slightly moves 16 | Turning the camera upwards while filming something 17 | Pretending to pick something up 18 | Showing something to the camera 19 | Moving something up 20 | Plugging something into something 21 | Unfolding something 22 | Putting something onto something 23 | Showing that something is empty 24 | Pretending to put something on a surface 25 | Taking something from somewhere 26 | Putting something next to something 27 | Moving something towards the camera 28 | Showing a photo of something to the camera 29 | Pushing something with something 30 | Throwing something 31 | Pushing something from left to right 32 | Something falling like a feather or paper 33 | Throwing something in the air and letting it fall 34 | Throwing something against something 35 | Lifting something with something on it 36 | Taking one of many similar things on the table 37 | Showing something behind something 38 | Putting something into something 39 | Tearing something just a little bit 40 | Moving something away from something 41 | Tearing something into two pieces 42 | Pushing something from right to left 43 | Holding something next to something 44 | Putting something, something and something on the table 45 | Pretending to take something from somewhere 46 | Moving something closer to something 47 | Pretending to put something next to something 48 | Uncovering something 49 | Something falling like a rock 50 | Putting something and something on the table 51 | Pouring something into something 52 | Moving something down 53 | Pulling something from right to left 54 | Throwing something in the air and catching it 55 | Tilting something 
with something on it until it falls off 56 | Putting something in front of something 57 | Pretending to turn something upside down 58 | Putting something on a surface 59 | Pretending to throw something 60 | Showing something on top of something 61 | Covering something with something 62 | Squeezing something 63 | Putting something similar to other things that are already on the table 64 | Lifting up one end of something, then letting it drop down 65 | Taking something out of something 66 | Moving part of something 67 | Pulling something from left to right 68 | Lifting something up completely without letting it drop down 69 | Attaching something to something 70 | Putting something behind something 71 | Moving something and something closer to each other 72 | Holding something in front of something 73 | Pushing something so that it falls off the table 74 | Holding something over something 75 | Pretending to open something without actually opening it 76 | Removing something, revealing something behind 77 | Hitting something with something 78 | Moving something and something away from each other 79 | Touching (without moving) part of something 80 | Pretending to put something into something 81 | Showing that something is inside something 82 | Lifting something up completely, then letting it drop down 83 | Pretending to take something out of something 84 | Holding something behind something 85 | Laying something on the table on its side, not upright 86 | Poking something so it slightly moves 87 | Pretending to close something without actually closing it 88 | Putting something upright on the table 89 | Dropping something in front of something 90 | Dropping something behind something 91 | Lifting up one end of something without letting it drop down 92 | Rolling something on a flat surface 93 | Throwing something onto a surface 94 | Showing something next to something 95 | Dropping something onto something 96 | Stuffing something into something 97 | Dropping something into something 98 | Piling something up 99 | Letting something roll along a flat surface 100 | Twisting something 101 | Spinning something that quickly stops spinning 102 | Putting number of something onto something 103 | Putting something underneath something 104 | Moving something across a surface without it falling down 105 | Plugging something into something but pulling it right out as you remove your hand 106 | Dropping something next to something 107 | Poking something so that it falls over 108 | Spinning something so it continues spinning 109 | Poking something so lightly that it doesn't or almost doesn't move 110 | Wiping something off of something 111 | Moving something across a surface until it falls down 112 | Pretending to poke something 113 | Putting something that cannot actually stand upright upright on the table, so it falls on its side 114 | Pulling something out of something 115 | Scooping something up with something 116 | Pretending to be tearing something that is not tearable 117 | Burying something in something 118 | Tipping something over 119 | Tilting something with something on it slightly so it doesn't fall down 120 | Pretending to put something onto something 121 | Bending something until it breaks 122 | Letting something roll down a slanted surface 123 | Trying to bend something unbendable so nothing happens 124 | Bending something so that it deforms 125 | Digging something out of something 126 | Pretending to put something underneath something 127 | Putting something on a flat surface without letting it 
roll 128 | Putting something on the edge of something so it is not supported and falls down 129 | Spreading something onto something 130 | Pretending to put something behind something 131 | Sprinkling something onto something 132 | Something colliding with something and both come to a halt 133 | Pushing something off of something 134 | Putting something that can't roll onto a slanted surface, so it stays where it is 135 | Lifting a surface with something on it until it starts sliding down 136 | Pretending or failing to wipe something off of something 137 | Trying but failing to attach something to something because it doesn't stick 138 | Pulling something from behind of something 139 | Pushing something so it spins 140 | Pouring something onto something 141 | Pulling two ends of something but nothing happens 142 | Moving something and something so they pass each other 143 | Pretending to sprinkle air onto something 144 | Putting something that can't roll onto a slanted surface, so it slides down 145 | Something colliding with something and both are being deflected 146 | Pretending to squeeze something 147 | Pulling something onto something 148 | Putting something onto something else that cannot support it so it falls down 149 | Lifting a surface with something on it but not enough for it to slide down 150 | Pouring something out of something 151 | Moving something and something so they collide with each other 152 | Tipping something with something in it over, so something in it falls out 153 | Letting something roll up a slanted surface, so it rolls back down 154 | Pretending to scoop something up with something 155 | Pretending to pour something out of something, but something is empty 156 | Pulling two ends of something so that it gets stretched 157 | Failing to put something into something because something does not fit 158 | Pretending or trying and failing to twist something 159 | Trying to pour something into something, but missing so it spills next to it 160 | Something being deflected from something 161 | Poking a stack of something so the stack collapses 162 | Spilling something onto something 163 | Pulling two ends of something so that it separates into two pieces 164 | Pouring something into something until it overflows 165 | Pretending to spread air onto something 166 | Twisting (wringing) something wet until water comes out 167 | Poking a hole into something soft 168 | Spilling something next to something 169 | Poking a stack of something without the stack collapsing 170 | Putting something onto a slanted surface but it doesn't glide down 171 | Pushing something onto something 172 | Poking something so that it spins around 173 | Spilling something behind something 174 | Poking a hole into some substance 175 | -------------------------------------------------------------------------------- /finetune_bn_frozen.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import shutil 5 | import logging 6 | 7 | import torch 8 | import torchvision 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim 12 | 13 | from lib.dataset import VideoDataSet 14 | from lib.models import VideoModule 15 | from lib.transforms import * 16 | from lib.utils.tools import * 17 | from lib.opts import args 18 | from lib.modules import * 19 | 20 | from train_val import train, validate,finetune_bn_frozen 21 | 22 | best_metric = 0 23 | 24 | def main(): 25 | global args, best_metric 26 | 27 
| if 'ucf101' in args.dataset: 28 | num_class = 101 29 | elif 'hmdb51' in args.dataset: 30 | num_class = 51 31 | elif args.dataset == 'kinetics400': 32 | num_class = 400 33 | elif args.dataset == 'kinetics200': 34 | num_class = 200 35 | else: 36 | raise ValueError('Unknown dataset '+args.dataset) 37 | 38 | # data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), 39 | # "data/{}/access".format(args.dataset)) 40 | 41 | if "ucf101" in args.dataset or "hmdb51" in args.dataset: 42 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), 43 | "data/{}/access".format(args.dataset[:-3])) 44 | else: 45 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), 46 | "data/{}/access".format(args.dataset)) 47 | 48 | # create model 49 | org_model = VideoModule(num_class=num_class, 50 | base_model_name=args.arch, 51 | dropout=args.dropout, 52 | pretrained=args.pretrained, 53 | pretrained_model=args.pretrained_model) 54 | num_params = 0 55 | for param in org_model.parameters(): 56 | num_params += param.reshape((-1, 1)).shape[0] 57 | print("Model Size is {:.3f}M".format(num_params/1000000)) 58 | 59 | model = torch.nn.DataParallel(org_model).cuda() 60 | # model = org_model 61 | 62 | # define loss function (criterion) and optimizer 63 | criterion = torch.nn.CrossEntropyLoss().cuda() 64 | 65 | # optim_params = [param[1] for param in model.named_parameters() if "classifier" in param[0]] 66 | # import pdb 67 | # pdb.set_trace() 68 | optimizer = torch.optim.SGD(model.parameters(), 69 | args.lr, 70 | momentum=args.momentum, 71 | weight_decay=args.weight_decay) 72 | 73 | # optionally resume from a checkpoint 74 | if args.resume: 75 | if os.path.isfile(args.resume): 76 | print(("=> loading checkpoint '{}'".format(args.resume))) 77 | checkpoint = torch.load(args.resume) 78 | args.start_epoch = checkpoint['epoch'] 79 | best_metric = checkpoint['best_metric'] 80 | model.load_state_dict(checkpoint['state_dict']) 81 | optimizer.load_state_dict(checkpoint['optimizer']) 82 | print(("=> loaded checkpoint '{}' (epoch {})" 83 | .format(args.resume, checkpoint['epoch']))) 84 | else: 85 | print(("=> no checkpoint found at '{}'".format(args.resume))) 86 | 87 | # Data loading code 88 | ## train data 89 | train_transform = torchvision.transforms.Compose([ 90 | GroupScale(args.new_size), 91 | GroupMultiScaleCrop(input_size=args.crop_size, scales=[1, .875, .75, .66]), 92 | GroupRandomHorizontalFlip(), 93 | Stack(mode=args.mode), 94 | ToTorchFormatTensor(), 95 | GroupNormalize(), 96 | ]) 97 | train_dataset = VideoDataSet(root_path=data_root, 98 | list_file=args.train_list, 99 | t_length=args.t_length, 100 | t_stride=args.t_stride, 101 | num_segments=args.num_segments, 102 | image_tmpl=args.image_tmpl, 103 | transform=train_transform, 104 | phase="Train") 105 | train_loader = torch.utils.data.DataLoader( 106 | train_dataset, 107 | batch_size=args.batch_size, shuffle=True, drop_last=True, 108 | num_workers=args.workers, pin_memory=True) 109 | 110 | ## val data 111 | val_transform = torchvision.transforms.Compose([ 112 | GroupScale(args.new_size), 113 | GroupCenterCrop(args.crop_size), 114 | Stack(mode=args.mode), 115 | ToTorchFormatTensor(), 116 | GroupNormalize(), 117 | ]) 118 | val_dataset = VideoDataSet(root_path=data_root, 119 | list_file=args.val_list, 120 | t_length=args.t_length, 121 | t_stride=args.t_stride, 122 | num_segments=args.num_segments, 123 | image_tmpl=args.image_tmpl, 124 | transform=val_transform, 125 | phase="Val") 126 | val_loader = torch.utils.data.DataLoader( 127 | 
val_dataset, 128 | batch_size=args.batch_size, shuffle=False, 129 | num_workers=args.workers, pin_memory=True) 130 | 131 | if args.mode != "3D": 132 | cudnn.benchmark = True 133 | 134 | if args.resume: 135 | validate(val_loader, model, criterion, args.print_freq, args.start_epoch) 136 | torch.cuda.empty_cache() 137 | 138 | for epoch in range(args.start_epoch, args.epochs): 139 | adjust_learning_rate(optimizer, args.lr, epoch, args.lr_steps) 140 | 141 | # train for one epoch 142 | finetune_bn_frozen(train_loader, model, criterion, optimizer, epoch, args.print_freq) 143 | 144 | # evaluate on validation set 145 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1: 146 | metric = validate(val_loader, model, criterion, args.print_freq, epoch + 1) 147 | torch.cuda.empty_cache() 148 | 149 | # remember best prec@1 and save checkpoint 150 | is_best = metric > best_metric 151 | best_metric = max(metric, best_metric) 152 | save_checkpoint({ 153 | 'epoch': epoch + 1, 154 | 'arch': args.arch, 155 | 'state_dict': model.state_dict(), 156 | 'best_metric': best_metric, 157 | 'optimizer': optimizer.state_dict(), 158 | }, is_best, epoch + 1, args.experiment_root) 159 | 160 | if __name__ == '__main__': 161 | main() 162 | -------------------------------------------------------------------------------- /finetune_fc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import shutil 5 | import logging 6 | 7 | import torch 8 | import torchvision 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim 12 | 13 | from lib.dataset import VideoDataSet 14 | from lib.models import VideoModule 15 | from lib.transforms import * 16 | from lib.utils.tools import * 17 | from lib.opts import args 18 | 19 | from train_val import train, validate, finetune_fc 20 | 21 | best_metric = 0 22 | 23 | def main(): 24 | global args, best_metric 25 | 26 | # specify dataset 27 | if 'ucf101' in args.dataset: 28 | num_class = 101 29 | elif 'hmdb51' in args.dataset: 30 | num_class = 51 31 | elif args.dataset == 'kinetics400': 32 | num_class = 400 33 | elif args.dataset == 'kinetics200': 34 | num_class = 200 35 | else: 36 | raise ValueError('Unknown dataset '+args.dataset) 37 | 38 | if "ucf101" in args.dataset or "hmdb51" in args.dataset: 39 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), 40 | "data/{}/access".format(args.dataset[:-3])) 41 | else: 42 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), 43 | "data/{}/access".format(args.dataset)) 44 | 45 | # create model 46 | org_model = VideoModule(num_class=num_class, 47 | base_model_name=args.arch, 48 | dropout=args.dropout, 49 | pretrained=args.pretrained, 50 | pretrained_model=args.pretrained_model) 51 | num_params = 0 52 | for param in org_model.parameters(): 53 | num_params += param.reshape((-1, 1)).shape[0] 54 | print("Model Size is {:.3f}M".format(num_params/1000000)) 55 | 56 | model = torch.nn.DataParallel(org_model).cuda() 57 | # model = org_model 58 | 59 | # define loss function (criterion) and optimizer 60 | criterion = torch.nn.CrossEntropyLoss().cuda() 61 | 62 | optim_params = [param[1] for param in model.named_parameters() if "classifier" in param[0]] 63 | # import pdb 64 | # pdb.set_trace() 65 | optimizer = torch.optim.SGD(optim_params, 66 | args.lr, 67 | momentum=args.momentum, 68 | weight_decay=args.weight_decay) 69 | 70 | # optionally resume from a checkpoint 71 | if args.resume: 72 | if 
os.path.isfile(args.resume): 73 | print(("=> loading checkpoint '{}'".format(args.resume))) 74 | checkpoint = torch.load(args.resume) 75 | args.start_epoch = checkpoint['epoch'] 76 | best_metric = checkpoint['best_metric'] 77 | model.load_state_dict(checkpoint['state_dict']) 78 | optimizer.load_state_dict(checkpoint['optimizer']) 79 | print(("=> loaded checkpoint '{}' (epoch {})" 80 | .format(args.resume, checkpoint['epoch']))) 81 | else: 82 | print(("=> no checkpoint found at '{}'".format(args.resume))) 83 | 84 | # Data loading code 85 | ## train data 86 | train_transform = torchvision.transforms.Compose([ 87 | GroupScale(args.new_size), 88 | GroupMultiScaleCrop(input_size=args.crop_size, scales=[1, .875, .75, .66]), 89 | GroupRandomHorizontalFlip(), 90 | Stack(mode=args.mode), 91 | ToTorchFormatTensor(), 92 | GroupNormalize(), 93 | ]) 94 | train_dataset = VideoDataSet(root_path=data_root, 95 | list_file=args.train_list, 96 | t_length=args.t_length, 97 | t_stride=args.t_stride, 98 | num_segments=args.num_segments, 99 | image_tmpl=args.image_tmpl, 100 | transform=train_transform, 101 | phase="Train") 102 | train_loader = torch.utils.data.DataLoader( 103 | train_dataset, 104 | batch_size=args.batch_size, shuffle=True, drop_last=True, 105 | num_workers=args.workers, pin_memory=True) 106 | 107 | ## val data 108 | val_transform = torchvision.transforms.Compose([ 109 | GroupScale(args.new_size), 110 | GroupCenterCrop(args.crop_size), 111 | Stack(mode=args.mode), 112 | ToTorchFormatTensor(), 113 | GroupNormalize(), 114 | ]) 115 | val_dataset = VideoDataSet(root_path=data_root, 116 | list_file=args.val_list, 117 | t_length=args.t_length, 118 | t_stride=args.t_stride, 119 | num_segments=args.num_segments, 120 | image_tmpl=args.image_tmpl, 121 | transform=val_transform, 122 | phase="Val") 123 | val_loader = torch.utils.data.DataLoader( 124 | val_dataset, 125 | batch_size=args.batch_size, shuffle=False, 126 | num_workers=args.workers, pin_memory=True) 127 | 128 | if args.mode != "3D": 129 | cudnn.benchmark = True 130 | 131 | # validate(val_loader, model, criterion, args.print_freq, args.start_epoch) 132 | 133 | for epoch in range(args.start_epoch, args.epochs): 134 | adjust_learning_rate(optimizer, args.lr, epoch, args.lr_steps) 135 | 136 | # train for one epoch 137 | finetune_fc(train_loader, model, criterion, optimizer, epoch, args.print_freq) 138 | 139 | # evaluate on validation set 140 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1: 141 | metric = validate(val_loader, model, criterion, args.print_freq, epoch + 1) 142 | 143 | # remember best prec@1 and save checkpoint 144 | is_best = metric > best_metric 145 | best_metric = max(metric, best_metric) 146 | save_checkpoint({ 147 | 'epoch': epoch + 1, 148 | 'arch': args.arch, 149 | 'state_dict': model.state_dict(), 150 | 'best_metric': best_metric, 151 | 'optimizer': optimizer.state_dict(), 152 | }, is_best, epoch + 1, args.experiment_root) 153 | 154 | if __name__ == '__main__': 155 | main() 156 | -------------------------------------------------------------------------------- /lib/dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | from numpy.random import randint 8 | 9 | import torch 10 | 11 | class VideoRecord(object): 12 | def __init__(self, row, root_path): 13 | self._data = row 14 | self._root_path = root_path 15 | 16 | @property 17 | def path(self): 18 | 
return os.path.join(self._root_path, self._data[0]) 19 | 20 | @property 21 | def num_frames(self): 22 | return int(self._data[1]) 23 | 24 | @property 25 | def label(self): 26 | return int(self._data[2]) 27 | 28 | class VideoDebugDataSet(data.Dataset): 29 | """ 30 | """ 31 | def __len__(self): 32 | return 100 33 | 34 | def __getitem__(self, index): 35 | np.random.seed(12345) 36 | input_tensor = (np.random.random_sample((3,18,224,224)) - 0.5) * 2 37 | return torch.from_numpy(input_tensor).to(torch.float), 0 38 | 39 | class VideoDataSet(data.Dataset): 40 | def __init__(self, root_path, list_file, 41 | t_length=32, t_stride=2, num_segments=1, 42 | image_tmpl='img_{:05d}.jpg', 43 | transform=None, style="Dense", 44 | phase="Train"): 45 | """ 46 | :style: Dense, for 2D and 3D model, and Sparse for TSN model 47 | :phase: Train, Val, Test 48 | """ 49 | 50 | self.root_path = root_path 51 | self.list_file = list_file 52 | self.t_length = t_length 53 | self.t_stride = t_stride 54 | self.num_segments = num_segments 55 | self.image_tmpl = image_tmpl 56 | self.transform = transform 57 | assert(style in ("Dense", "UnevenDense")), "Only support Dense and UnevenDense" 58 | self.style = style 59 | self.phase = phase 60 | assert(t_length > 0), "Length of time must be bigger than zero." 61 | assert(t_stride > 0), "Stride of time must be bigger than zero." 62 | 63 | self._parse_list() 64 | 65 | def _load_image(self, directory, idx): 66 | return [Image.open(os.path.join(directory, self.image_tmpl.format(idx))).convert('RGB')] 67 | 68 | def _parse_list(self): 69 | self.video_list = [VideoRecord(x.strip().split(' '), self.root_path) for x in open(self.list_file)] 70 | # self.video_list = [VideoRecord(x.strip().split(' '), self.root_path) for x in open(self.list_file) if VideoRecord(x.strip().split(' '), self.root_path).num_frames > 240] 71 | # print(len(self.video_list)) 72 | 73 | @staticmethod 74 | def dense_sampler(num_frames, length, stride=1): 75 | t_length = length 76 | t_stride = stride 77 | # compute offsets 78 | offset = 0 79 | average_duration = num_frames - (t_length - 1) * t_stride - 1 80 | if average_duration >= 0: 81 | offset = randint(average_duration + 1) 82 | elif num_frames > t_length: 83 | while(t_stride - 1 > 0): 84 | t_stride -= 1 85 | average_duration = num_frames - (t_length - 1) * t_stride - 1 86 | if average_duration >= 0: 87 | offset = randint(average_duration + 1) 88 | break 89 | assert(t_stride >= 1), "temporal stride must be bigger than zero." 
90 | else: 91 | t_stride = 1 92 | # sampling 93 | samples = [] 94 | for i in range(t_length): 95 | samples.append(offset + i * t_stride + 1) 96 | return samples 97 | 98 | def _sample_indices(self, record): 99 | """ 100 | :param record: VideoRecord 101 | :return: list 102 | """ 103 | if self.style == "Dense": 104 | frames = [] 105 | average_duration = record.num_frames / self.num_segments 106 | offsets = [average_duration * i for i in range(self.num_segments)] 107 | for i in range(self.num_segments): 108 | samples = self.dense_sampler(average_duration, self.t_length, self.t_stride) 109 | samples = [sample + offsets[i] for sample in samples] 110 | frames.extend(samples) 111 | return {"dense": frames} 112 | elif self.style == "UnevenDense": 113 | sparse_frames = [] 114 | average_duration = record.num_frames / self.num_segments 115 | offsets = [average_duration * i for i in range(self.num_segments)] 116 | dense_frames = self.dense_sampler(record.num_frames, self.t_length, self.t_stride) 117 | dense_seg = -1 118 | for i in range(self.num_segments): 119 | if dense_frames[self.t_length//2] >= offsets[self.num_segments - i - 1]: 120 | dense_seg = self.num_segments - i - 1 121 | break 122 | else: 123 | continue 124 | assert(dense_seg != -1) 125 | # dense_seg = randint(self.num_segments) 126 | for i in range(self.num_segments): 127 | # if i == dense_seg: 128 | # samples = self.dense_sampler(average_duration, self.t_length, self.t_stride) 129 | # samples = [sample + offsets[i] for sample in samples] 130 | # dense_frames.extend(samples) 131 | # dense_seg = -1 # set dense seg to -1 and check after sampling. 132 | if i != dense_seg: 133 | samples = self.dense_sampler(average_duration, 1) 134 | samples = [sample + offsets[i] for sample in samples] 135 | sparse_frames.extend(samples) 136 | return {"dense":dense_frames, "sparse":sparse_frames} 137 | else: 138 | return 139 | 140 | def _get_val_indices(self, record): 141 | """ 142 | get indices in val phase 143 | """ 144 | # valid_offset_range = record.num_frames - (self.t_length - 1) * self.t_stride - 1 145 | valid_offset_range = record.num_frames - (self.t_length - 1) * self.t_stride - 1 146 | offset = int(valid_offset_range / 2.0) 147 | if offset < 0: 148 | offset = 0 149 | samples = [] 150 | for i in range(self.t_length): 151 | samples.append(offset + i * self.t_stride + 1) 152 | return {"dense": samples} 153 | 154 | def _get_test_indices(self, record): 155 | """ 156 | get indices in test phase 157 | """ 158 | valid_offset_range = record.num_frames - (self.t_length - 1) * self.t_stride - 1 159 | interval = valid_offset_range / (self.num_segments - 1) 160 | offsets = [] 161 | for i in range(self.num_segments): 162 | offset = int(i * interval) 163 | if offset > valid_offset_range: 164 | offset = valid_offset_range 165 | if offset < 0: 166 | offset = 0 167 | offsets.append(offset + 1) 168 | frames = [] 169 | for i in range(self.num_segments): 170 | for j in range(self.t_length): 171 | frames.append(offsets[i] + j*self.t_stride) 172 | # frames.append(offsets[i]+j) 173 | return {"dense": frames} 174 | 175 | def __getitem__(self, index): 176 | record = self.video_list[index] 177 | 178 | if self.phase == "Train": 179 | indices = self._sample_indices(record) 180 | return self.get(record, indices, self.phase) 181 | elif self.phase == "Val": 182 | indices = self._get_val_indices(record) 183 | return self.get(record, indices, self.phase) 184 | elif self.phase == "Test": 185 | indices = self._get_test_indices(record) 186 | return self.get(record, indices, 
self.phase) 187 | else: 188 | raise TypeError("Unsuported phase {}".format(self.phase)) 189 | 190 | def get(self, record, indices, phase): 191 | # dense process data 192 | def dense_process_data(): 193 | images = list() 194 | for ind in indices['dense']: 195 | ptr = int(ind) 196 | if ptr <= record.num_frames: 197 | imgs = self._load_image(record.path, ptr) 198 | else: 199 | imgs = self._load_image(record.path, record.num_frames) 200 | images.extend(imgs) 201 | return self.transform(images) 202 | # unevendense process data 203 | def unevendense_process_data(): 204 | dense_images = list() 205 | sparse_images = list() 206 | for ind in indices['dense']: 207 | ptr = int(ind) 208 | if ptr <= record.num_frames: 209 | imgs = self._load_image(record.path, ptr) 210 | else: 211 | imgs = self._load_image(record.path, record.num_frames) 212 | dense_images.extend(imgs) 213 | for ind in indices['sparse']: 214 | ptr = int(ind) 215 | if ptr <= record.num_frames: 216 | imgs = self._load_image(record.path, ptr) 217 | else: 218 | imgs = self._load_image(record.path, record.num_frames) 219 | sparse_images.extend(imgs) 220 | 221 | images = dense_images + sparse_images 222 | return self.transform(images) 223 | if phase == "Train": 224 | if self.style == "Dense": 225 | process_data = dense_process_data() 226 | elif self.style == "UnevenDense": 227 | process_data = unevendense_process_data() 228 | elif phase in ("Val", "Test"): 229 | process_data = dense_process_data() 230 | return process_data, record.label 231 | 232 | def __len__(self): 233 | return len(self.video_list) 234 | 235 | class ShortVideoDataSet(VideoDataSet): 236 | def __init__(self, root_path, list_file, 237 | t_length=32, t_stride=2, num_segments=1, 238 | image_tmpl='img_{:05d}.jpg', 239 | transform=None, style="Dense", 240 | phase="Train"): 241 | """ 242 | :style: Dense, for 2D and 3D model, and Sparse for TSN model 243 | :phase: Train, Val, Test 244 | """ 245 | 246 | super(ShortVideoDataSet, self).__init__(root_path, 247 | list_file, t_length, t_stride, num_segments, 248 | image_tmpl, transform, style, phase) 249 | 250 | 251 | def _get_val_indices(self, record): 252 | """ 253 | get indices in val phase 254 | """ 255 | # valid_offset_range = record.num_frames - (self.t_length - 1) * self.t_stride - 1 256 | t_stride = self.t_stride 257 | valid_offset_range = record.num_frames - (self.t_length - 1) * t_stride - 1 258 | offset = int(valid_offset_range / 2.0) 259 | 260 | if record.num_frames > self.t_length: 261 | while(offset < 0 and t_stride > 1): 262 | t_stride -= 1 263 | valid_offset_range = record.num_frames - (self.t_length - 1) * t_stride - 1 264 | offset = int(valid_offset_range / 2.0) 265 | else: 266 | t_stride = 1 267 | valid_offset_range = record.num_frames - (self.t_length - 1) * t_stride - 1 268 | offset = int(valid_offset_range / 2.0) 269 | 270 | if offset < 0: 271 | offset = 0 272 | samples = [] 273 | for i in range(self.t_length): 274 | samples.append(offset + i * t_stride + 1) 275 | return {"dense": samples} 276 | 277 | def _get_test_indices(self, record): 278 | """ 279 | get indices in test phase 280 | """ 281 | t_stride = self.t_stride 282 | valid_offset_range = record.num_frames - (self.t_length - 1) * t_stride - 1 283 | while(valid_offset_range < (self.num_segments - 1) and t_stride > 1): 284 | t_stride -= 1 285 | valid_offset_range = record.num_frames - (self.t_length - 1) * t_stride - 1 286 | if valid_offset_range < 0: 287 | valid_offset_range = 0 288 | interval = valid_offset_range / (self.num_segments - 1) 289 | offsets = [] 
290 | for i in range(self.num_segments): 291 | offset = int(i * interval) 292 | if offset > valid_offset_range+1: 293 | offset = valid_offset_range+1 294 | if offset < 0: 295 | offset = 0 296 | offsets.append(offset + 1) 297 | frames = [] 298 | for i in range(self.num_segments): 299 | for j in range(self.t_length): 300 | frames.append(offsets[i] + j * t_stride) 301 | # frames.append(offsets[i]+j) 302 | return {"dense": frames} 303 | 304 | 305 | if __name__ == "__main__": 306 | td = VideoDataSet(root_path="../data/kinetics400/access/kinetics_train_rgb_img_256_340/", 307 | list_file="../data/kinetics400/kinetics_train_list.txt", 308 | t_length=16, 309 | t_stride=4, 310 | num_segments=3, 311 | image_tmpl="image_{:06d}.jpg", 312 | style="UnevenDense", 313 | phase="Train") 314 | # sample0 = td[0] 315 | import pdb 316 | pdb.set_trace() 317 | -------------------------------------------------------------------------------- /lib/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch import nn 3 | from torch.nn.parameter import Parameter 4 | from .networks import * 5 | 6 | from .transforms import * 7 | 8 | class VideoModule(nn.Module): 9 | def __init__(self, num_class, base_model_name='resnet50', 10 | before_softmax=True, dropout=0.8, pretrained=True, pretrained_model=None): 11 | super(VideoModule, self).__init__() 12 | self.num_class = num_class 13 | self.base_model_name = base_model_name 14 | self.before_softmax = before_softmax 15 | self.dropout = dropout 16 | self.pretrained = pretrained 17 | self.pretrained_model = pretrained_model 18 | # self.finetune = finetune 19 | 20 | self._prepare_base_model(base_model_name) 21 | 22 | if not self.before_softmax: 23 | self.softmax = nn.Softmax() 24 | 25 | def _prepare_base_model(self, base_model_name): 26 | """ 27 | base_model+(dropout)+classifier 28 | """ 29 | base_model_dict = None 30 | classifier_dict = None 31 | if self.pretrained and self.pretrained_model: 32 | model_dict = torch.load(self.pretrained_model) 33 | base_model_dict = {k: v for k, v in model_dict.items() if "classifier" not in k} 34 | classifier_dict = {'.'.join(k.split('.')[1:]): v for k, v in model_dict.items() if "classifier" in k} 35 | # base model 36 | if "resnet" in base_model_name: 37 | self.base_model = eval(base_model_name)(pretrained=self.pretrained, \ 38 | feat=True, pretrained_model=base_model_dict) 39 | elif base_model_name == "mnet2": 40 | model_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 41 | "../models/mobilenet_v2.pth.tar") 42 | self.base_model = mnet2(pretrained=model_path, feat=True) 43 | elif base_model_name == "mnet2_3d": 44 | model_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 45 | "../models/mobilenet_v2.pth.tar") 46 | self.base_model = mnet2_3d(pretrained=model_path, feat=True) 47 | elif "fst" in base_model_name or "msv" in base_model_name or "gsv" in base_model_name: 48 | self.base_model = eval(base_model_name)(pretrained=self.pretrained, 49 | feat=True, pretrained_model=base_model_dict) 50 | else: 51 | raise ValueError('Unknown base model: {}'.format(base_model)) 52 | 53 | # classifier: (dropout) + fc 54 | if self.dropout == 0: 55 | self.classifier = nn.Linear(self.base_model.feat_dim, self.num_class) 56 | elif self.dropout > 0: 57 | self.classifier = nn.Sequential(nn.Dropout(self.dropout), nn.Linear(self.base_model.feat_dim, self.num_class)) 58 | 59 | # init classifier 60 | for m in self.classifier.modules(): 61 | if isinstance(m, nn.Linear): 62 | 
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='linear') 63 | nn.init.constant_(m.bias, 0) 64 | 65 | # if self.pretrained and self.pretrained_model: 66 | # self.classifier.load_state_dict(classifier_dict) 67 | 68 | # if self.finetune: 69 | # print("Finetune") 70 | # for param in self.base_model.parameters(): 71 | # param.requires_grad = False 72 | # for m in self.base_model.modules(): 73 | # if isinstance(m, nn.BatchNorm3d): 74 | # m.eval() 75 | 76 | # import pdb 77 | # pdb.set_trace() 78 | 79 | def forward(self, input): 80 | out = self.base_model(input) 81 | out = self.classifier(out) 82 | 83 | if not self.before_softmax: 84 | out = self.softmax(out) 85 | 86 | return out 87 | 88 | def get_augmentation(self): 89 | return torchvision.transforms.Compose([GroupMultiScaleCrop(input_size=224, scales=[1, .875, .75, .66]), 90 | GroupRandomHorizontalFlip()]) 91 | 92 | class TSN(nn.Module): 93 | """Temporal Segment Network 94 | 95 | """ 96 | def __init__(self, batch_size, video_module, num_segments=1, t_length=1, 97 | crop_fusion_type='max', mode="3D"): 98 | super(TSN, self).__init__() 99 | self.t_length = t_length 100 | self.batch_size = batch_size 101 | self.num_segments = num_segments 102 | self.video_module = video_module 103 | self.crop_fusion_type = crop_fusion_type 104 | self.mode = mode 105 | 106 | def forward(self, input): 107 | # reshape input first 108 | shape = input.shape 109 | if "3D" in self.mode: 110 | assert(len(shape)) == 5, "In 3D mode, input must have 5 dims." 111 | shape = (shape[0], shape[1], shape[2]//self.t_length, self.t_length) + shape[3:] 112 | input = input.view(shape).permute((0, 2, 1, 3, 4, 5)).contiguous() 113 | shape = (input.shape[0] * input.shape[1], ) + input.shape[2:] 114 | input = input.view(shape) 115 | elif "2D" in self.mode: 116 | assert(len(shape)) == 4, "In 2D mode, input must have 4 dims." 
117 | shape = (shape[0]*shape[1]//3, 3,) + shape[2:] 118 | input = input.view(shape) 119 | else: 120 | raise Exception("Unsupported mode.") 121 | 122 | # base network forward 123 | output = self.video_module(input) 124 | # fuse output 125 | output = output.view((self.batch_size, 126 | output.shape[0] // (self.batch_size * self.num_segments), 127 | self.num_segments, output.shape[1])) 128 | 129 | output_max = output.max(1)[0].squeeze(1) 130 | pred_max = output_max.mean(1).squeeze(1) 131 | output_ave = output.mean(1).squeeze(1) 132 | pred_ave = output_ave.mean(1).squeeze(1) 133 | # if self.crop_fusion_type == 'max': 134 | # # pdb.set_trace() 135 | # output = output.max(1)[0].squeeze(1) 136 | # elif self.crop_fusion_type == 'avg': 137 | # output = output.mean(1).squeeze(1) 138 | # pred = output.mean(1).squeeze(1) 139 | return (output_max, pred_max, output_ave, pred_ave) 140 | -------------------------------------------------------------------------------- /lib/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .scale import * 2 | from .pooling import * -------------------------------------------------------------------------------- /lib/modules/pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class GloAvgPool3d(nn.Module): 6 | def __init__(self): 7 | super(GloAvgPool3d, self).__init__() 8 | self.stride = 1 9 | self.padding = 0 10 | self.ceil_mode = False 11 | self.count_include_pad = True 12 | 13 | def forward(self, input): 14 | input_shape = input.shape 15 | kernel_size = input_shape[2:] 16 | return F.avg_pool3d(input, kernel_size, self.stride, 17 | self.padding, self.ceil_mode, self.count_include_pad) 18 | 19 | class GloSptMaxPool3d(nn.Module): 20 | def __init__(self): 21 | super(GloSptMaxPool3d, self).__init__() 22 | self.stride = 1 23 | self.padding = 0 24 | self.ceil_mode = False 25 | self.count_include_pad = True 26 | 27 | def forward(self, input): 28 | input_shape = input.shape 29 | kernel_size = (1,) + input_shape[3:] 30 | return F.max_pool3d(input, kernel_size=kernel_size, stride=self.stride, 31 | padding=self.padding, ceil_mode=self.ceil_mode) 32 | 33 | class GloSptAvgPool3d(nn.Module): 34 | def __init__(self): 35 | super(GloSptAvgPool3d, self).__init__() 36 | self.stride = 1 37 | self.padding = 0 38 | self.ceil_mode = False 39 | self.count_include_pad = True 40 | 41 | def forward(self, input): 42 | input_shape = input.shape 43 | kernel_size = (1, ) + input_shape[3:] 44 | return F.avg_pool3d(input, kernel_size, self.stride, 45 | self.padding, self.ceil_mode, self.count_include_pad) -------------------------------------------------------------------------------- /lib/modules/scale.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | 5 | class Scale2d(nn.Module): 6 | def __init__(self, out_channels): 7 | super(Scale2d, self).__init__() 8 | self.scale = Parameter(torch.Tensor(1, out_channels, 1, 1)) 9 | 10 | def forward(self, input): 11 | return input * self.scale 12 | 13 | class Scale3d(nn.Module): 14 | def __init__(self, out_channels): 15 | super(Scale3d, self).__init__() 16 | self.scale = Parameter(torch.Tensor(1, out_channels, 1, 1, 1)) 17 | 18 | def forward(self, input): 19 | return input * self.scale 
-------------------------------------------------------------------------------- /lib/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .mnet2 import * 2 | from .mnet2_3d import * 3 | from .resnet import * 4 | from .resnet_3d import * 5 | from .part_inflate_resnet_3d import * 6 | from .resnet_3d_nodown import * 7 | -------------------------------------------------------------------------------- /lib/networks/mnet2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import os 5 | 6 | __all__ = ['mnet2'] 7 | 8 | def conv_bn(inp, oup, stride): 9 | return nn.Sequential( 10 | nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False), 11 | nn.BatchNorm2d(oup), 12 | nn.ReLU6(inplace=True) 13 | ) 14 | 15 | 16 | def conv_1x1_bn(inp, oup): 17 | return nn.Sequential( 18 | nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False), 19 | nn.BatchNorm2d(oup), 20 | nn.ReLU6(inplace=True) 21 | ) 22 | 23 | 24 | class InvertedResidual(nn.Module): 25 | def __init__(self, inp, oup, stride, expand_ratio): 26 | super(InvertedResidual, self).__init__() 27 | self.stride = stride 28 | assert stride in [1, 2] 29 | 30 | hidden_dim = round(inp * expand_ratio) 31 | self.use_res_connect = self.stride == 1 and inp == oup 32 | 33 | if expand_ratio == 1: 34 | self.conv = nn.Sequential( 35 | # dw 36 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride, padding=1, groups=hidden_dim, bias=False), 37 | nn.BatchNorm2d(hidden_dim), 38 | nn.ReLU6(inplace=True), 39 | # pw-linear 40 | nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, padding=0, bias=False), 41 | nn.BatchNorm2d(oup), 42 | ) 43 | else: 44 | self.conv = nn.Sequential( 45 | # pw 46 | nn.Conv2d(inp, hidden_dim, kernel_size=1, stride=1, padding=0, bias=False), 47 | nn.BatchNorm2d(hidden_dim), 48 | nn.ReLU6(inplace=True), 49 | # dw 50 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride, padding=1, groups=hidden_dim, bias=False), 51 | nn.BatchNorm2d(hidden_dim), 52 | nn.ReLU6(inplace=True), 53 | # pw-linear 54 | nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, padding=0, bias=False), 55 | nn.BatchNorm2d(oup), 56 | ) 57 | 58 | def forward(self, x): 59 | if self.use_res_connect: 60 | return x + self.conv(x) 61 | else: 62 | return self.conv(x) 63 | 64 | 65 | class MobileNetV2(nn.Module): 66 | def __init__(self, n_class=1000, input_size=224, width_mult=1., feat=False): 67 | super(MobileNetV2, self).__init__() 68 | self.feat = feat 69 | block = InvertedResidual 70 | input_channel = 32 71 | last_channel = 1280 72 | interverted_residual_setting = [ 73 | # t, c, n, s 74 | [1, 16, 1, 1], 75 | [6, 24, 2, 2], 76 | [6, 32, 3, 2], 77 | [6, 64, 4, 2], 78 | [6, 96, 3, 1], 79 | [6, 160, 3, 2], 80 | [6, 320, 1, 1], 81 | ] 82 | 83 | # building first layer 84 | assert input_size % 32 == 0 85 | input_channel = int(input_channel * width_mult) 86 | self.feat_dim = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 87 | self.features = [conv_bn(3, input_channel, 2)] 88 | # building inverted residual blocks 89 | for t, c, n, s in interverted_residual_setting: 90 | output_channel = int(c * width_mult) 91 | for i in range(n): 92 | if i == 0: 93 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 94 | else: 95 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 96 | input_channel = output_channel 97 | # building last 
several layers 98 | self.features.append(conv_1x1_bn(input_channel, self.feat_dim)) 99 | # make it nn.Sequential 100 | self.features = nn.Sequential(*self.features) 101 | self.avgpool = nn.AvgPool2d(7, stride=1) 102 | 103 | # building classifier 104 | if not self.feat: 105 | self.classifier = nn.Sequential( 106 | nn.Dropout(0.2), 107 | nn.Linear(self.feat_dim, n_class), 108 | ) 109 | 110 | self._initialize_weights() 111 | 112 | def forward(self, x): 113 | x = self.features(x) 114 | x = self.avgpool(x) 115 | x = x.view(x.size(0), -1) 116 | if not self.feat: 117 | x = self.classifier(x) 118 | return x 119 | 120 | def _initialize_weights(self): 121 | for m in self.modules(): 122 | if isinstance(m, nn.Conv2d): 123 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 124 | m.weight.data.normal_(0, math.sqrt(2. / n)) 125 | if m.bias is not None: 126 | m.bias.data.zero_() 127 | elif isinstance(m, nn.BatchNorm2d): 128 | m.weight.data.fill_(1) 129 | m.bias.data.zero_() 130 | elif isinstance(m, nn.Linear): 131 | n = m.weight.size(1) 132 | m.weight.data.normal_(0, 0.01) 133 | m.bias.data.zero_() 134 | 135 | def part_state_dict(state_dict, model_dict): 136 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict} 137 | model_dict.update(pretrained_dict) 138 | return model_dict 139 | 140 | def mnet2(pretrained=None, feat=False): 141 | if feat: 142 | assert(pretrained != None and os.path.exists(pretrained)), "pretrained model must be ready when using feat." 143 | model = MobileNetV2(feat=feat) 144 | if feat: 145 | state_dict = part_state_dict(torch.load(pretrained, map_location=lambda storage, loc: storage), 146 | model.state_dict()) 147 | model.load_state_dict(state_dict) 148 | return model 149 | -------------------------------------------------------------------------------- /lib/networks/mnet2_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import os 5 | 6 | __all__ = ["mnet2_3d"] 7 | 8 | def conv_bn(inp, oup, stride, t_stride=1): 9 | return nn.Sequential( 10 | nn.Conv3d(inp, oup, kernel_size=(1, 3, 3), 11 | stride=(t_stride, stride, stride), padding=(0, 1, 1), bias=False), 12 | nn.BatchNorm3d(oup), 13 | nn.ReLU6(inplace=True) 14 | ) 15 | 16 | 17 | def conv_1x1x1_bn(inp, oup): 18 | return nn.Sequential( 19 | nn.Conv3d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False), 20 | nn.BatchNorm3d(oup), 21 | nn.ReLU6(inplace=True) 22 | ) 23 | 24 | 25 | class InvertedResidual(nn.Module): 26 | def __init__(self, inp, oup, stride, t_stride, expand_ratio, t_radius=1): 27 | super(InvertedResidual, self).__init__() 28 | self.stride = stride 29 | self.t_stride = t_stride 30 | self.t_radius = t_radius 31 | assert stride in [1, 2] and t_stride in [1, 2] 32 | 33 | hidden_dim = round(inp * expand_ratio) 34 | self.use_res_connect = self.stride == 1 and inp == oup 35 | 36 | if expand_ratio == 1: 37 | assert(t_stride == 1), "Temporal stride must be one when expand ratio is one." 
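# With expand_ratio == 1 this branch has no leading temporal (k x 1 x 1) conv,
# so a temporal stride here could only drop frames without any temporal
# aggregation; presumably that is why the assertion above forbids it.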
38 | self.conv = nn.Sequential( 39 | # dw 40 | nn.Conv3d(hidden_dim, hidden_dim, kernel_size=(1, 3, 3), stride=(t_stride, stride, stride), 41 | padding=(0, 1, 1), groups=hidden_dim, bias=False), 42 | nn.BatchNorm3d(hidden_dim), 43 | nn.ReLU6(inplace=True), 44 | # pw-linear 45 | nn.Conv3d(hidden_dim, oup, kernel_size=1, stride=1, padding=0, bias=False), 46 | nn.BatchNorm3d(oup), 47 | ) 48 | else: 49 | self.conv = nn.Sequential( 50 | # pw 51 | nn.Conv3d(inp, hidden_dim, kernel_size=(t_radius * 2 + 1, 1, 1), 52 | stride=(t_stride, 1, 1), padding=(t_radius, 0, 0), bias=False), 53 | nn.BatchNorm3d(hidden_dim), 54 | nn.ReLU6(inplace=True), 55 | # dw 56 | nn.Conv3d(hidden_dim, hidden_dim, kernel_size=(1, 3, 3), stride=(1, stride, stride), 57 | padding=(0, 1, 1), groups=hidden_dim, bias=False), 58 | nn.BatchNorm3d(hidden_dim), 59 | nn.ReLU6(inplace=True), 60 | # pw-linear 61 | nn.Conv3d(hidden_dim, oup, kernel_size=1, stride=1, padding=0, bias=False), 62 | nn.BatchNorm3d(oup), 63 | ) 64 | 65 | def forward(self, x): 66 | if self.use_res_connect: 67 | return x + self.conv(x) 68 | else: 69 | return self.conv(x) 70 | 71 | 72 | class MobileNetV2_3D(nn.Module): 73 | def __init__(self, n_class=1000, input_size=224, width_mult=1., feat=False): 74 | super(MobileNetV2_3D, self).__init__() 75 | self.feat = feat 76 | block = InvertedResidual 77 | input_channel = 32 78 | last_channel = 1280 79 | interverted_residual_setting = [ 80 | # t, c, n, s, ts, r 81 | [1, 16, 1, 1, 1, 0], 82 | [6, 24, 2, 2, 1, 0], 83 | [6, 32, 3, 2, 1, 0], 84 | [6, 64, 4, 2, 1, 1], 85 | [6, 96, 3, 1, 2, 1], 86 | [6, 160, 3, 2, 2, 1], 87 | [6, 320, 1, 1, 1, 1], 88 | ] 89 | 90 | # building first layer 91 | assert input_size % 32 == 0 92 | input_channel = int(input_channel * width_mult) 93 | self.feat_dim = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 94 | self.features = [conv_bn(3, input_channel, 2)] 95 | # building inverted residual blocks 96 | for t, c, n, s, ts, r in interverted_residual_setting: 97 | output_channel = int(c * width_mult) 98 | for i in range(n): 99 | if i == 0: 100 | self.features.append(block(input_channel, output_channel, s, ts, expand_ratio=t, t_radius=r)) 101 | else: 102 | self.features.append(block(input_channel, output_channel, 1, 1, expand_ratio=t, t_radius=r)) 103 | input_channel = output_channel 104 | # building last several layers 105 | self.features.append(conv_1x1x1_bn(input_channel, self.feat_dim)) 106 | # make it nn.Sequential 107 | self.features = nn.Sequential(*self.features) 108 | self.avgpool = nn.AvgPool3d(kernel_size=(4, 7, 7), stride=1) 109 | 110 | # building classifier 111 | if not self.feat: 112 | self.classifier = nn.Sequential( 113 | nn.Dropout(0.2), 114 | nn.Linear(self.feat_dim, n_class), 115 | ) 116 | 117 | self._initialize_weights() 118 | 119 | def forward(self, x): 120 | x = self.features(x) 121 | x = self.avgpool(x) 122 | x = x.view(x.size(0), -1) 123 | if not self.feat: 124 | x = self.classifier(x) 125 | return x 126 | 127 | def _initialize_weights(self): 128 | for m in self.modules(): 129 | if isinstance(m, nn.Conv3d): 130 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 131 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 132 | if m.bias is not None: 133 | m.bias.data.zero_() 134 | elif isinstance(m, nn.BatchNorm3d): 135 | m.weight.data.fill_(1) 136 | m.bias.data.zero_() 137 | elif isinstance(m, nn.Linear): 138 | n = m.weight.size(1) 139 | m.weight.data.normal_(0, 0.01) 140 | m.bias.data.zero_() 141 | 142 | def part_state_dict(state_dict, model_dict): 143 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict} 144 | pretrained_dict = inflate_state_dict(pretrained_dict, model_dict) 145 | model_dict.update(pretrained_dict) 146 | return model_dict 147 | 148 | 149 | def inflate_state_dict(pretrained_dict, model_dict): 150 | for k in pretrained_dict.keys(): 151 | if pretrained_dict[k].size() != model_dict[k].size(): 152 | assert(pretrained_dict[k].size()[:2] == model_dict[k].size()[:2]), \ 153 | "To inflate, channel number should match." 154 | assert(pretrained_dict[k].size()[-2:] == model_dict[k].size()[-2:]), \ 155 | "To inflate, spatial kernel size should match." 156 | print("Layer {} needs inflation.".format(k)) 157 | shape = list(pretrained_dict[k].shape) 158 | shape.insert(2, 1) 159 | t_length = model_dict[k].shape[2] 160 | pretrained_dict[k] = pretrained_dict[k].reshape(shape) 161 | if t_length != 1: 162 | pretrained_dict[k] = pretrained_dict[k].expand_as(model_dict[k]) / t_length 163 | assert(pretrained_dict[k].size() == model_dict[k].size()), \ 164 | "After inflation, model shape should match." 165 | return pretrained_dict 166 | 167 | def mnet2_3d(pretrained=None, feat=False): 168 | if pretrained != None: 169 | assert(os.path.exists(pretrained)), "pretrained model does not exist." 170 | model = MobileNetV2_3D(feat=feat) 171 | if pretrained: 172 | state_dict = torch.load(pretrained, map_location=lambda storage, loc: storage) 173 | state_dict = part_state_dict(state_dict, model.state_dict()) 174 | model.load_state_dict(state_dict) 175 | return model 176 | -------------------------------------------------------------------------------- /lib/networks/part_inflate_resnet_3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modify the original file to make the class support feature extraction 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.parameter import Parameter 7 | import torch.nn.functional as F 8 | import math 9 | import torch.utils.model_zoo as model_zoo 10 | from ..modules import * 11 | 12 | 13 | __all__ = ["pib_resnet26_3d_v1", "pib_resnet50_3d_slow", "pib_resnet26_3d_v1_1", "pib_resnet26_3d_full", "pib_resnet26_2d_full"] 14 | 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | } 22 | 23 | class Bottleneck3D_000(nn.Module): 24 | expansion = 4 25 | 26 | def __init__(self, inplanes, planes, stride=1, t_stride=1, downsample=None): 27 | super(Bottleneck3D_000, self).__init__() 28 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, 29 | stride=[t_stride, 1, 1], bias=False) 30 | self.bn1 = nn.BatchNorm3d(planes) 31 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3), 32 | stride=[1, stride, stride], padding=(0, 1, 1), bias=False) 33 | self.bn2 = nn.BatchNorm3d(planes) 34 | self.conv3 = nn.Conv3d(planes, planes * self.expansion, 
kernel_size=1, bias=False) 35 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | residual = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv3(out) 52 | out = self.bn3(out) 53 | 54 | if self.downsample is not None: 55 | residual = self.downsample(x) 56 | 57 | out += residual 58 | out = self.relu(out) 59 | 60 | return out 61 | 62 | class PIBottleneck3D(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, ratio=0.5, stride=1, t_stride=1, downsample=None): 66 | super(PIBottleneck3D, self).__init__() 67 | self.ratio = ratio 68 | if ratio == 1: 69 | self.conv1_t = nn.Conv3d(inplanes, planes, 70 | kernel_size=(3, 1, 1), 71 | stride=(t_stride, 1, 1), 72 | padding=(1, 0, 0), 73 | bias=False) 74 | elif ratio == 0: 75 | self.conv1_p = nn.Conv3d(inplanes, planes, 76 | kernel_size=(1, 1, 1), 77 | stride=(t_stride, 1, 1), 78 | padding=(0, 0, 0), 79 | bias=False) 80 | else: 81 | self.conv1_t = nn.Conv3d(inplanes, int(planes * ratio), 82 | kernel_size=(3, 1, 1), 83 | stride=(t_stride, 1, 1), 84 | padding=(1, 0, 0), 85 | bias=False) 86 | self.conv1_p = nn.Conv3d(inplanes, int(planes*(1-ratio)), 87 | kernel_size=(1, 1, 1), 88 | stride=(t_stride, 1, 1), 89 | padding=(0, 0, 0), 90 | bias=False) 91 | self.bn1 = nn.BatchNorm3d(planes) 92 | self.conv2 = nn.Conv3d(planes, planes, 93 | kernel_size=(1, 3, 3), 94 | stride=(1, stride, stride), 95 | padding=(0, 1, 1), 96 | bias=False) 97 | self.bn2 = nn.BatchNorm3d(planes) 98 | self.conv3 = nn.Conv3d(planes, planes * self.expansion, 99 | kernel_size=1, 100 | bias=False) 101 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 102 | self.relu = nn.ReLU(inplace=True) 103 | self.downsample = downsample 104 | self.stride = stride 105 | 106 | def forward(self, x): 107 | residual = x 108 | 109 | if self.ratio == 1: 110 | out = self.conv1_t(x) 111 | elif self.ratio == 0: 112 | out = self.conv1_p(x) 113 | else: 114 | out_t = self.conv1_t(x) 115 | out_p = self.conv1_p(x) 116 | out = torch.cat((out_t, out_p), dim=1) 117 | out = self.bn1(out) 118 | out = self.relu(out) 119 | 120 | out = self.conv2(out) 121 | out = self.bn2(out) 122 | out = self.relu(out) 123 | 124 | out = self.conv3(out) 125 | out = self.bn3(out) 126 | 127 | if self.downsample is not None: 128 | residual = self.downsample(x) 129 | 130 | out += residual 131 | out = self.relu(out) 132 | 133 | return out 134 | 135 | class PIBResNet3D_8fr(nn.Module): 136 | 137 | def __init__(self, block, layers, ratios, num_classes=1000, feat=False, **kwargs): 138 | if not isinstance(block, list): 139 | block = [block] * 4 140 | else: 141 | assert(len(block)) == 4, "Block number must be 4 for ResNet-Stype networks." 
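# Each entry of `ratios` sets the partial-inflation ratio of one stage: in
# PIBottleneck3D a ratio r routes a fraction r of the first conv's output
# channels through the temporal 3x1x1 branch (conv1_t) and the remaining
# channels through the plain 1x1x1 branch (conv1_p).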
142 | self.inplanes = 64 143 | super(PIBResNet3D_8fr, self).__init__() 144 | self.feat = feat 145 | self.conv1 = nn.Conv3d(3, 64, 146 | kernel_size=(1, 7, 7), 147 | stride=(1, 2, 2), 148 | padding=(0, 3, 3), 149 | bias=False) 150 | self.bn1 = nn.BatchNorm3d(64) 151 | self.relu = nn.ReLU(inplace=True) 152 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), 153 | stride=(1, 2, 2), 154 | padding=(0, 1, 1)) 155 | self.layer1 = self._make_layer(block[0], 64, layers[0], inf_ratio=ratios[0]) 156 | self.layer2 = self._make_layer(block[1], 128, layers[1], inf_ratio=ratios[1], stride=2) 157 | self.layer3 = self._make_layer(block[2], 256, layers[2], inf_ratio=ratios[2], stride=2, t_stride=2) 158 | self.layer4 = self._make_layer(block[3], 512, layers[3], inf_ratio=ratios[3], stride=2, t_stride=2) 159 | self.avgpool = GloAvgPool3d() 160 | self.feat_dim = 512 * block[0].expansion 161 | if not feat: 162 | self.fc = nn.Linear(512 * block[0].expansion, num_classes) 163 | 164 | for n, m in self.named_modules(): 165 | if isinstance(m, nn.Conv3d): 166 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 167 | elif isinstance(m, nn.BatchNorm3d) and "conv_t" not in n: 168 | nn.init.constant_(m.weight, 1) 169 | nn.init.constant_(m.bias, 0) 170 | elif isinstance(m, Scale3d): 171 | nn.init.constant_(m.scale, 0) 172 | 173 | 174 | def _make_layer(self, block, planes, blocks, inf_ratio, stride=1, t_stride=1): 175 | downsample = None 176 | if stride != 1 or self.inplanes != planes * block.expansion: 177 | downsample = nn.Sequential( 178 | nn.Conv3d(self.inplanes, planes * block.expansion, 179 | kernel_size=1, stride=(t_stride, stride, stride), bias=False), 180 | nn.BatchNorm3d(planes * block.expansion), 181 | ) 182 | 183 | layers = [] 184 | layers.append(block(self.inplanes, planes, inf_ratio, stride=stride, t_stride=t_stride, downsample=downsample)) 185 | self.inplanes = planes * block.expansion 186 | for i in range(1, blocks): 187 | layers.append(block(self.inplanes, planes, inf_ratio)) 188 | 189 | return nn.Sequential(*layers) 190 | 191 | def forward(self, x): 192 | x = self.conv1(x) 193 | x = self.bn1(x) 194 | x = self.relu(x) 195 | x = self.maxpool(x) 196 | 197 | x = self.layer1(x) 198 | x = self.layer2(x) 199 | x = self.layer3(x) 200 | x = self.layer4(x) 201 | 202 | 203 | x = self.avgpool(x) 204 | x = x.view(x.size(0), -1) 205 | if not self.feat: 206 | x = self.fc(x) 207 | 208 | return x 209 | 210 | 211 | def part_state_dict(state_dict, model_dict, ratios): 212 | assert(len(ratios) == 4), "Length of ratios must equal to stage number" 213 | added_dict = {} 214 | for k, v in state_dict.items(): 215 | # import pdb 216 | # pdb.set_trace() 217 | if ".conv1.weight" in k and "layer" in k: 218 | # import pdb 219 | # pdb.set_trace() 220 | ratio = ratios[int(k[k.index("layer")+5])-1] 221 | out_channels = v.shape[0] 222 | slice_index = int(out_channels*ratio) 223 | if ratio == 1: 224 | new_k = k[:k.index(".conv1.weight")]+'.conv1_t.weight' 225 | added_dict.update({new_k: v[:slice_index,...]}) 226 | elif ratio == 0: 227 | new_k = k[:k.index(".conv1.weight")]+'.conv1_p.weight' 228 | added_dict.update({new_k: v[slice_index:,...]}) 229 | else: 230 | new_k = k[:k.index(".conv1.weight")]+'.conv1_t.weight' 231 | added_dict.update({new_k: v[:slice_index,...]}) 232 | new_k = k[:k.index(".conv1.weight")]+'.conv1_p.weight' 233 | added_dict.update({new_k: v[slice_index:,...]}) 234 | 235 | state_dict.update(added_dict) 236 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict} 237 | 
pretrained_dict = inflate_state_dict(pretrained_dict, model_dict) 238 | model_dict.update(pretrained_dict) 239 | return model_dict 240 | 241 | 242 | def inflate_state_dict(pretrained_dict, model_dict): 243 | for k in pretrained_dict.keys(): 244 | if pretrained_dict[k].size() != model_dict[k].size(): 245 | assert(pretrained_dict[k].size()[:2] == model_dict[k].size()[:2]), \ 246 | "To inflate, channel number should match." 247 | assert(pretrained_dict[k].size()[-2:] == model_dict[k].size()[-2:]), \ 248 | "To inflate, spatial kernel size should match." 249 | print("Layer {} needs inflation.".format(k)) 250 | shape = list(pretrained_dict[k].shape) 251 | shape.insert(2, 1) 252 | t_length = model_dict[k].shape[2] 253 | pretrained_dict[k] = pretrained_dict[k].reshape(shape) 254 | if t_length != 1: 255 | pretrained_dict[k] = pretrained_dict[k].expand_as(model_dict[k]) / t_length 256 | assert(pretrained_dict[k].size() == model_dict[k].size()), \ 257 | "After inflation, model shape should match." 258 | 259 | return pretrained_dict 260 | 261 | def pib_resnet26_3d_v1(pretrained=False, feat=False, **kwargs): 262 | """Constructs a ResNet-50 model. 263 | Args: 264 | pretrained (bool): If True, returns a model pre-trained on ImageNet 265 | """ 266 | ratios = (1/8, 1/4, 1/2, 1) 267 | model = PIBResNet3D_8fr([PIBottleneck3D, PIBottleneck3D, PIBottleneck3D, PIBottleneck3D], 268 | [2, 2, 2, 2], ratios, feat=feat, **kwargs) 269 | if pretrained: 270 | if kwargs['pretrained_model'] is None: 271 | pass 272 | # state_dict = model_zoo.load_url(model_urls['resnet50']) 273 | else: 274 | print("Using specified pretrain model") 275 | state_dict = kwargs['pretrained_model'] 276 | if feat: 277 | new_state_dict = part_state_dict(state_dict, model.state_dict(), ratios) 278 | model.load_state_dict(new_state_dict) 279 | return model 280 | 281 | def pib_resnet26_3d_full(pretrained=False, feat=False, **kwargs): 282 | """Constructs a ResNet-50 model. 283 | Args: 284 | pretrained (bool): If True, returns a model pre-trained on ImageNet 285 | """ 286 | ratios = (1, 1, 1, 1) 287 | model = PIBResNet3D_8fr([PIBottleneck3D, PIBottleneck3D, PIBottleneck3D, PIBottleneck3D], 288 | [2, 2, 2, 2], ratios, feat=feat, **kwargs) 289 | if pretrained: 290 | if kwargs['pretrained_model'] is None: 291 | pass 292 | # state_dict = model_zoo.load_url(model_urls['resnet50']) 293 | else: 294 | print("Using specified pretrain model") 295 | state_dict = kwargs['pretrained_model'] 296 | if feat: 297 | new_state_dict = part_state_dict(state_dict, model.state_dict(), ratios) 298 | model.load_state_dict(new_state_dict) 299 | return model 300 | 301 | def pib_resnet26_2d_full(pretrained=False, feat=False, **kwargs): 302 | """Constructs a ResNet-50 model. 303 | Args: 304 | pretrained (bool): If True, returns a model pre-trained on ImageNet 305 | """ 306 | ratios = (0, 0, 0, 0) 307 | model = PIBResNet3D_8fr([PIBottleneck3D, PIBottleneck3D, PIBottleneck3D, PIBottleneck3D], 308 | [2, 2, 2, 2], ratios, feat=feat, **kwargs) 309 | if pretrained: 310 | if kwargs['pretrained_model'] is None: 311 | pass 312 | # state_dict = model_zoo.load_url(model_urls['resnet50']) 313 | else: 314 | print("Using specified pretrain model") 315 | state_dict = kwargs['pretrained_model'] 316 | if feat: 317 | new_state_dict = part_state_dict(state_dict, model.state_dict(), ratios) 318 | model.load_state_dict(new_state_dict) 319 | return model 320 | 321 | def pib_resnet26_3d_v1_1(pretrained=False, feat=False, **kwargs): 322 | """Constructs a ResNet-50 model. 
323 | Args: 324 | pretrained (bool): If True, returns a model pre-trained on ImageNet 325 | """ 326 | ratios = (1/2, 1/2, 1/2, 1/2) 327 | model = PIBResNet3D_8fr([PIBottleneck3D, PIBottleneck3D, PIBottleneck3D, PIBottleneck3D], 328 | [2, 2, 2, 2], ratios, feat=feat, **kwargs) 329 | if pretrained: 330 | if kwargs['pretrained_model'] is None: 331 | pass 332 | # state_dict = model_zoo.load_url(model_urls['resnet50']) 333 | else: 334 | print("Using specified pretrain model") 335 | state_dict = kwargs['pretrained_model'] 336 | if feat: 337 | new_state_dict = part_state_dict(state_dict, model.state_dict(), ratios) 338 | model.load_state_dict(new_state_dict) 339 | return model 340 | 341 | def pib_resnet50_3d_slow(pretrained=False, feat=False, **kwargs): 342 | """Constructs a ResNet-50 model. 343 | Args: 344 | pretrained (bool): If True, returns a model pre-trained on ImageNet 345 | """ 346 | ratios = (0, 0, 1, 1) 347 | model = PIBResNet3D_8fr([PIBottleneck3D, PIBottleneck3D, PIBottleneck3D, PIBottleneck3D], 348 | [3, 4, 6, 3], ratios, feat=feat, **kwargs) 349 | if pretrained: 350 | if kwargs['pretrained_model'] is None: 351 | state_dict = model_zoo.load_url(model_urls['resnet50']) 352 | else: 353 | print("Using specified pretrain model") 354 | state_dict = kwargs['pretrained_model'] 355 | if feat: 356 | new_state_dict = part_state_dict(state_dict, model.state_dict(), ratios) 357 | model.load_state_dict(new_state_dict) 358 | return model -------------------------------------------------------------------------------- /lib/networks/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modify the original file to make the class support feature extraction 3 | """ 4 | 5 | import torch.nn as nn 6 | import math 7 | import torch.utils.model_zoo as model_zoo 8 | import torch 9 | from torch.nn.parameter import Parameter 10 | from ..modules import * 11 | 12 | 13 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet26', 'resnet26_point', 'resnet50', 'resnet101', 14 | 'resnet152'] 15 | 16 | 17 | model_urls = { 18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 23 | } 24 | 25 | 26 | def conv3x3(in_planes, out_planes, stride=1): 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=1, bias=False) 30 | 31 | # class Scale2d(nn.Module): 32 | # def __init__(self, out_channels): 33 | # super(Scale2d, self).__init__() 34 | # self.scale = Parameter(torch.Tensor(1, out_channels, 1, 1)) 35 | 36 | # def forward(self, input): 37 | # return input * self.scale 38 | 39 | class BasicBlock(nn.Module): 40 | expansion = 1 41 | 42 | def __init__(self, inplanes, planes, stride=1, downsample=None): 43 | super(BasicBlock, self).__init__() 44 | self.conv1 = conv3x3(inplanes, planes, stride) 45 | self.bn1 = nn.BatchNorm2d(planes) 46 | self.relu = nn.ReLU(inplace=True) 47 | self.conv2 = conv3x3(planes, planes) 48 | self.bn2 = nn.BatchNorm2d(planes) 49 | self.downsample = downsample 50 | self.stride = stride 51 | 52 | def forward(self, x): 53 | residual = x 54 | 55 | out = self.conv1(x) 56 | out = self.bn1(out) 57 | out = self.relu(out) 58 | 59 | out = self.conv2(out) 60 | out 
= self.bn2(out) 61 | 62 | if self.downsample is not None: 63 | residual = self.downsample(x) 64 | 65 | out += residual 66 | out = self.relu(out) 67 | 68 | return out 69 | 70 | 71 | class Bottleneck(nn.Module): 72 | expansion = 4 73 | 74 | def __init__(self, inplanes, planes, stride=1, downsample=None): 75 | super(Bottleneck, self).__init__() 76 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 77 | self.bn1 = nn.BatchNorm2d(planes) 78 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 79 | padding=1, bias=False) 80 | self.bn2 = nn.BatchNorm2d(planes) 81 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 82 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 83 | self.relu = nn.ReLU(inplace=True) 84 | self.downsample = downsample 85 | self.stride = stride 86 | 87 | def forward(self, x): 88 | residual = x 89 | 90 | out = self.conv1(x) 91 | out = self.bn1(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv2(out) 95 | out = self.bn2(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv3(out) 99 | out = self.bn3(out) 100 | 101 | if self.downsample is not None: 102 | residual = self.downsample(x) 103 | 104 | out += residual 105 | out = self.relu(out) 106 | 107 | return out 108 | 109 | class PointBottleneck(nn.Module): 110 | expansion = 4 111 | 112 | def __init__(self, inplanes, planes, stride=1, downsample=None): 113 | super(PointBottleneck, self).__init__() 114 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 115 | self.bn1 = nn.BatchNorm2d(planes) 116 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 117 | padding=1, bias=False) 118 | self.bn2 = nn.BatchNorm2d(planes) 119 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 120 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 121 | self.relu = nn.ReLU(inplace=True) 122 | self.downsample = downsample 123 | if self.downsample is None: 124 | self.conv_p = nn.Conv2d(inplanes, planes * self.expansion, kernel_size=1, bias=False) 125 | self.stride = stride 126 | 127 | def forward(self, x): 128 | residual = x 129 | 130 | out = self.conv1(x) 131 | out = self.bn1(out) 132 | out = self.relu(out) 133 | 134 | out = self.conv2(out) 135 | out = self.bn2(out) 136 | out = self.relu(out) 137 | 138 | out = self.conv3(out) 139 | out = self.bn3(out) 140 | 141 | if self.downsample is not None: 142 | residual = self.downsample(x) 143 | else: 144 | residual = self.conv_p(x) 145 | 146 | out += residual 147 | out = self.relu(out) 148 | 149 | return out 150 | 151 | class SCBottleneck(nn.Module): 152 | expansion = 4 153 | 154 | def __init__(self, inplanes, planes, stride=1, downsample=None): 155 | super(SCBottleneck, self).__init__() 156 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 157 | self.bn1 = nn.BatchNorm2d(planes) 158 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 159 | padding=1, bias=False) 160 | self.bn2 = nn.BatchNorm2d(planes) 161 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 162 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 163 | self.relu = nn.ReLU(inplace=True) 164 | self.sc = Scale2d(out_channels=inplanes) 165 | self.downsample = downsample 166 | self.stride = stride 167 | 168 | def forward(self, x): 169 | # residual = x 170 | 171 | out = self.conv1(x) 172 | out = self.bn1(out) 173 | out = self.relu(out) 174 | 175 | out = self.conv2(out) 176 | out = self.bn2(out) 177 | out = self.relu(out) 178 | 179 | 
out = self.conv3(out) 180 | out = self.bn3(out) 181 | 182 | # if residual.device == torch.device('cuda:0'): 183 | # print(self.sc.scale.view(-1)[:20].data) 184 | residual = self.sc(x) 185 | if self.downsample is not None: 186 | residual = self.downsample(residual) 187 | 188 | out += residual 189 | out = self.relu(out) 190 | 191 | return out 192 | 193 | 194 | class ResNet(nn.Module): 195 | 196 | def __init__(self, block, layers, num_classes=1000, feat=False): 197 | self.inplanes = 64 198 | super(ResNet, self).__init__() 199 | self.feat = feat 200 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 201 | bias=False) 202 | self.bn1 = nn.BatchNorm2d(64) 203 | self.relu = nn.ReLU(inplace=True) 204 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 205 | self.layer1 = self._make_layer(block, 64, layers[0]) 206 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 207 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 208 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 209 | self.avgpool = nn.AvgPool2d(7, stride=1) 210 | self.feat_dim = 512 * block.expansion 211 | if not feat: 212 | self.fc = nn.Linear(512 * block.expansion, num_classes) 213 | 214 | for m in self.modules(): 215 | if isinstance(m, nn.Conv2d): 216 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 217 | elif isinstance(m, nn.BatchNorm2d): 218 | nn.init.constant_(m.weight, 1) 219 | nn.init.constant_(m.bias, 0) 220 | elif isinstance(m, Scale2d): 221 | nn.init.constant_(m.scale, 1) 222 | 223 | def _make_layer(self, block, planes, blocks, stride=1): 224 | downsample = None 225 | if stride != 1 or self.inplanes != planes * block.expansion: 226 | downsample = nn.Sequential( 227 | nn.Conv2d(self.inplanes, planes * block.expansion, 228 | kernel_size=1, stride=stride, bias=False), 229 | nn.BatchNorm2d(planes * block.expansion), 230 | ) 231 | 232 | layers = [] 233 | layers.append(block(self.inplanes, planes, stride, downsample)) 234 | self.inplanes = planes * block.expansion 235 | for i in range(1, blocks): 236 | layers.append(block(self.inplanes, planes)) 237 | 238 | return nn.Sequential(*layers) 239 | 240 | def forward(self, x): 241 | x = self.conv1(x) 242 | x = self.bn1(x) 243 | x = self.relu(x) 244 | x = self.maxpool(x) 245 | 246 | x = self.layer1(x) 247 | x = self.layer2(x) 248 | x = self.layer3(x) 249 | x = self.layer4(x) 250 | 251 | x = self.avgpool(x) 252 | x = x.view(x.size(0), -1) 253 | if not self.feat: 254 | x = self.fc(x) 255 | 256 | return x 257 | 258 | 259 | def part_state_dict(state_dict, model_dict): 260 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict} 261 | model_dict.update(pretrained_dict) 262 | return model_dict 263 | 264 | 265 | def resnet18(pretrained=False, feat=False, **kwargs): 266 | """Constructs a ResNet-18 model. 267 | Args: 268 | pretrained (bool): If True, returns a model pre-trained on ImageNet 269 | """ 270 | model = ResNet(BasicBlock, [2, 2, 2, 2], feat=feat, **kwargs) 271 | if feat: 272 | state_dict = part_state_dict(model_zoo.load_url(model_urls['resnet18']), model.state_dict()) 273 | if pretrained: 274 | model.load_state_dict(state_dict) 275 | return model 276 | 277 | 278 | def resnet34(pretrained=False, feat=False, **kwargs): 279 | """Constructs a ResNet-34 model. 
280 | Args: 281 | pretrained (bool): If True, returns a model pre-trained on ImageNet 282 | """ 283 | model = ResNet(BasicBlock, [3, 4, 6, 3], feat=feat, **kwargs) 284 | if feat: 285 | state_dict = part_state_dict(model_zoo.load_url(model_urls['resnet34']), model.state_dict()) 286 | if pretrained: 287 | model.load_state_dict(state_dict) 288 | return model 289 | 290 | def resnet26(pretrained=False, feat=False, **kwargs): 291 | """Constructs a ResNet-50 model. 292 | Args: 293 | pretrained (bool): If True, returns a model pre-trained on ImageNet 294 | """ 295 | model = ResNet(Bottleneck, [2, 2, 2, 2], feat=feat, **kwargs) 296 | return model 297 | 298 | def resnet26_sc(pretrained=False, feat=False, **kwargs): 299 | """Constructs a ResNet-50 model. 300 | Args: 301 | pretrained (bool): If True, returns a model pre-trained on ImageNet 302 | """ 303 | model = ResNet(SCBottleneck, [2, 2, 2, 2], feat=feat, **kwargs) 304 | return model 305 | 306 | def resnet26_point(pretrained=False, feat=False, **kwargs): 307 | """Constructs a ResNet-50 model. 308 | Args: 309 | pretrained (bool): If True, returns a model pre-trained on ImageNet 310 | """ 311 | model = ResNet(PointBottleneck, [2, 2, 2, 2], feat=feat, **kwargs) 312 | return model 313 | 314 | def resnet50(pretrained=False, feat=False, **kwargs): 315 | """Constructs a ResNet-50 model. 316 | Args: 317 | pretrained (bool): If True, returns a model pre-trained on ImageNet 318 | """ 319 | model = ResNet(Bottleneck, [3, 4, 6, 3], feat=feat, **kwargs) 320 | if feat: 321 | state_dict = part_state_dict(model_zoo.load_url(model_urls['resnet50']), model.state_dict()) 322 | if pretrained: 323 | model.load_state_dict(state_dict) 324 | return model 325 | 326 | 327 | def resnet101(pretrained=False, feat=False, **kwargs): 328 | """Constructs a ResNet-101 model. 329 | Args: 330 | pretrained (bool): If True, returns a model pre-trained on ImageNet 331 | """ 332 | model = ResNet(Bottleneck, [3, 4, 23, 3], feat=feat, **kwargs) 333 | if feat: 334 | state_dict = part_state_dict(model_zoo.load_url(model_urls['resnet101']), model.state_dict()) 335 | if pretrained: 336 | model.load_state_dict(state_dict) 337 | return model 338 | 339 | 340 | def resnet152(pretrained=False, feat=False, **kwargs): 341 | """Constructs a ResNet-152 model. 
342 | Args: 343 | pretrained (bool): If True, returns a model pre-trained on ImageNet 344 | """ 345 | model = ResNet(Bottleneck, [3, 8, 36, 3], feat=feat, **kwargs) 346 | if feat: 347 | state_dict = part_state_dict(model_zoo.load_url(model_urls['resnet152']), model.state_dict()) 348 | if pretrained: 349 | model.load_state_dict(state_dict) 350 | return model 351 | -------------------------------------------------------------------------------- /lib/networks/resnet_3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modify the original file to make the class support feature extraction 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import math 8 | import torch.utils.model_zoo as model_zoo 9 | 10 | __all__ = ['resnet50_3d_v3','resnet26_3d_v3','resnet101_3d_v1'] 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | } 19 | 20 | class GloAvgPool3d(nn.Module): 21 | def __init__(self): 22 | super(GloAvgPool3d, self).__init__() 23 | self.stride = 1 24 | self.padding = 0 25 | self.ceil_mode = False 26 | self.count_include_pad = True 27 | 28 | def forward(self, input): 29 | input_shape = input.shape 30 | kernel_size = input_shape[2:] 31 | return F.avg_pool3d(input, kernel_size, self.stride, 32 | self.padding, self.ceil_mode, self.count_include_pad) 33 | 34 | class Bottleneck3D_100(nn.Module): 35 | expansion = 4 36 | 37 | def __init__(self, inplanes, planes, stride=1, t_stride=1, downsample=None): 38 | super(Bottleneck3D_100, self).__init__() 39 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=(3, 1, 1), 40 | stride=(t_stride, 1, 1), 41 | padding=(1, 0, 0), bias=False) 42 | self.bn1 = nn.BatchNorm3d(planes) 43 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3), 44 | stride=(1, stride, stride), padding=(0, 1, 1), bias=False) 45 | self.bn2 = nn.BatchNorm3d(planes) 46 | self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False) 47 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.downsample = downsample 50 | self.stride = stride 51 | 52 | def forward(self, x): 53 | residual = x 54 | 55 | out = self.conv1(x) 56 | out = self.bn1(out) 57 | out = self.relu(out) 58 | 59 | out = self.conv2(out) 60 | out = self.bn2(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv3(out) 64 | out = self.bn3(out) 65 | 66 | if self.downsample is not None: 67 | residual = self.downsample(x) 68 | 69 | out += residual 70 | out = self.relu(out) 71 | 72 | return out 73 | 74 | class Bottleneck3D_101(nn.Module): 75 | expansion = 4 76 | 77 | def __init__(self, inplanes, planes, stride=1, t_stride=1, downsample=None): 78 | super(Bottleneck3D_101, self).__init__() 79 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=(3, 1, 1), 80 | stride=(t_stride, 1, 1), 81 | padding=(1, 0, 0), bias=False) 82 | self.bn1 = nn.BatchNorm3d(planes) 83 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3), 84 | stride=(1, stride, stride), padding=(0, 1, 1), bias=False) 85 | self.bn2 = nn.BatchNorm3d(planes) 86 | self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=(3, 1, 1), 87 | 
stride=1, 88 | padding=(1, 0, 0), 89 | bias=False) 90 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.downsample = downsample 93 | self.stride = stride 94 | 95 | def forward(self, x): 96 | residual = x 97 | 98 | out = self.conv1(x) 99 | out = self.bn1(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv2(out) 103 | out = self.bn2(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv3(out) 107 | out = self.bn3(out) 108 | 109 | if self.downsample is not None: 110 | residual = self.downsample(x) 111 | 112 | out += residual 113 | out = self.relu(out) 114 | 115 | return out 116 | 117 | class Bottleneck3D_000(nn.Module): 118 | expansion = 4 119 | 120 | def __init__(self, inplanes, planes, stride=1, t_stride=1, downsample=None): 121 | super(Bottleneck3D_000, self).__init__() 122 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, 123 | stride=[t_stride, 1, 1], bias=False) 124 | self.bn1 = nn.BatchNorm3d(planes) 125 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3), 126 | stride=[1, stride, stride], padding=(0, 1, 1), bias=False) 127 | self.bn2 = nn.BatchNorm3d(planes) 128 | self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False) 129 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 130 | self.relu = nn.ReLU(inplace=True) 131 | self.downsample = downsample 132 | self.stride = stride 133 | 134 | def forward(self, x): 135 | residual = x 136 | 137 | out = self.conv1(x) 138 | out = self.bn1(out) 139 | out = self.relu(out) 140 | 141 | out = self.conv2(out) 142 | out = self.bn2(out) 143 | out = self.relu(out) 144 | 145 | out = self.conv3(out) 146 | out = self.bn3(out) 147 | 148 | if self.downsample is not None: 149 | residual = self.downsample(x) 150 | 151 | out += residual 152 | out = self.relu(out) 153 | 154 | return out 155 | 156 | 157 | class ResNet3D(nn.Module): 158 | 159 | def __init__(self, block, layers, num_classes=1000, feat=False, **kwargs): 160 | if not isinstance(block, list): 161 | block = [block] * 4 162 | else: 163 | assert(len(block)) == 4, "Block number must be 4 for ResNet-Stype networks." 
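# `block` may be a single bottleneck class or a list of four, one per stage;
# mixing Bottleneck3D_000 / _100 / _101 across stages is what distinguishes
# the resnet*_3d_v1 / v2 / v3 constructors defined at the bottom of this file.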
164 | self.inplanes = 64 165 | super(ResNet3D, self).__init__() 166 | self.feat = feat 167 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(1, 7, 7), 168 | stride=(1, 2, 2), padding=(0, 3, 3), 169 | bias=False) 170 | self.bn1 = nn.BatchNorm3d(64) 171 | self.relu = nn.ReLU(inplace=True) 172 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) 173 | self.layer1 = self._make_layer(block[0], 64, layers[0]) 174 | self.layer2 = self._make_layer(block[1], 128, layers[1], stride=2, t_stride=2) 175 | self.layer3 = self._make_layer(block[2], 256, layers[2], stride=2, t_stride=2) 176 | self.layer4 = self._make_layer(block[3], 512, layers[3], stride=2, t_stride=2) 177 | self.avgpool = GloAvgPool3d() 178 | self.feat_dim = 512 * block[0].expansion 179 | if not feat: 180 | self.fc = nn.Linear(512 * block[0].expansion, num_classes) 181 | 182 | for n, m in self.named_modules(): 183 | if isinstance(m, nn.Conv3d): 184 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 185 | elif isinstance(m, nn.BatchNorm3d): 186 | nn.init.constant_(m.weight, 1) 187 | nn.init.constant_(m.bias, 0) 188 | 189 | def _make_layer(self, block, planes, blocks, stride=1, t_stride=1): 190 | downsample = None 191 | if stride != 1 or self.inplanes != planes * block.expansion: 192 | downsample = nn.Sequential( 193 | nn.Conv3d(self.inplanes, planes * block.expansion, 194 | kernel_size=1, stride=(t_stride, stride, stride), bias=False), 195 | nn.BatchNorm3d(planes * block.expansion), 196 | ) 197 | 198 | layers = [] 199 | layers.append(block(self.inplanes, planes, stride=stride, t_stride=t_stride, downsample=downsample)) 200 | self.inplanes = planes * block.expansion 201 | for i in range(1, blocks): 202 | layers.append(block(self.inplanes, planes)) 203 | 204 | return nn.Sequential(*layers) 205 | 206 | def forward(self, x): 207 | x = self.conv1(x) 208 | x = self.bn1(x) 209 | x = self.relu(x) 210 | x = self.maxpool(x) 211 | 212 | x = self.layer1(x) 213 | x = self.layer2(x) 214 | x = self.layer3(x) 215 | x = self.layer4(x) 216 | 217 | 218 | x = self.avgpool(x) 219 | x = x.view(x.size(0), -1) 220 | if not self.feat: 221 | print("WARNING!!!!!!!") 222 | x = self.fc(x) 223 | 224 | return x 225 | 226 | 227 | def part_state_dict(state_dict, model_dict): 228 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict} 229 | pretrained_dict = inflate_state_dict(pretrained_dict, model_dict) 230 | model_dict.update(pretrained_dict) 231 | return model_dict 232 | 233 | 234 | def inflate_state_dict(pretrained_dict, model_dict): 235 | for k in pretrained_dict.keys(): 236 | if pretrained_dict[k].size() != model_dict[k].size(): 237 | assert(pretrained_dict[k].size()[:2] == model_dict[k].size()[:2]), \ 238 | "To inflate, channel number should match." 239 | assert(pretrained_dict[k].size()[-2:] == model_dict[k].size()[-2:]), \ 240 | "To inflate, spatial kernel size should match." 241 | print("Layer {} needs inflation.".format(k)) 242 | shape = list(pretrained_dict[k].shape) 243 | shape.insert(2, 1) 244 | t_length = model_dict[k].shape[2] 245 | pretrained_dict[k] = pretrained_dict[k].reshape(shape) 246 | if t_length != 1: 247 | pretrained_dict[k] = pretrained_dict[k].expand_as(model_dict[k]) / t_length 248 | assert(pretrained_dict[k].size() == model_dict[k].size()), \ 249 | "After inflation, model shape should match." 250 | 251 | return pretrained_dict 252 | 253 | def resnet50_3d_v1(pretrained=False, feat=False, **kwargs): 254 | """Constructs a ResNet-50 model. 
255 | Args: 256 | pretrained (bool): If True, returns a model pre-trained on ImageNet 257 | """ 258 | model = ResNet3D([Bottleneck3D_000, Bottleneck3D_100, Bottleneck3D_101, Bottleneck3D_101], 259 | [3, 4, 6, 3], feat=feat, **kwargs) 260 | # import pdb 261 | # pdb.set_trace() 262 | if pretrained: 263 | if kwargs['pretrained_model'] is None: 264 | state_dict = model_zoo.load_url(model_urls['resnet50']) 265 | else: 266 | print("Using specified pretrain model") 267 | state_dict = kwargs['pretrained_model'] 268 | if feat: 269 | new_state_dict = part_state_dict(state_dict, model.state_dict()) 270 | model.load_state_dict(new_state_dict) 271 | return model 272 | 273 | def resnet50_3d_v2(pretrained=False, feat=False, **kwargs): 274 | """Constructs a ResNet-50 model. 275 | Args: 276 | pretrained (bool): If True, returns a model pre-trained on ImageNet 277 | """ 278 | model = ResNet3D([Bottleneck3D_000, Bottleneck3D_000, Bottleneck3D_100, Bottleneck3D_100], 279 | [3, 4, 6, 3], feat=feat, **kwargs) 280 | # import pdb 281 | # pdb.set_trace() 282 | if pretrained: 283 | if kwargs['pretrained_model'] is None: 284 | state_dict = model_zoo.load_url(model_urls['resnet50']) 285 | else: 286 | print("Using specified pretrain model") 287 | state_dict = kwargs['pretrained_model'] 288 | if feat: 289 | new_state_dict = part_state_dict(state_dict, model.state_dict()) 290 | model.load_state_dict(new_state_dict) 291 | return model 292 | 293 | def resnet50_3d_v3(pretrained=False, feat=False, **kwargs): 294 | """Constructs a ResNet-50 model. 295 | Args: 296 | pretrained (bool): If True, returns a model pre-trained on ImageNet 297 | """ 298 | model = ResNet3D([Bottleneck3D_000, Bottleneck3D_100, Bottleneck3D_100, Bottleneck3D_100], 299 | [3, 4, 6, 3], feat=feat, **kwargs) 300 | # import pdb 301 | # pdb.set_trace() 302 | if pretrained: 303 | if kwargs['pretrained_model'] is None: 304 | state_dict = model_zoo.load_url(model_urls['resnet50']) 305 | else: 306 | print("Using specified pretrain model") 307 | state_dict = kwargs['pretrained_model'] 308 | if feat: 309 | new_state_dict = part_state_dict(state_dict, model.state_dict()) 310 | model.load_state_dict(new_state_dict) 311 | return model 312 | 313 | def resnet26_3d_v1(pretrained=False, feat=False, **kwargs): 314 | """Constructs a ResNet-50 model. 315 | Args: 316 | pretrained (bool): If True, returns a model pre-trained on ImageNet 317 | """ 318 | model = ResNet3D([Bottleneck3D_100, Bottleneck3D_100, Bottleneck3D_100, Bottleneck3D_100], 319 | [2, 2, 2, 2], feat=feat, **kwargs) 320 | # import pdb 321 | # pdb.set_trace() 322 | if pretrained: 323 | if kwargs['pretrained_model'] is None: 324 | raise ValueError("pretrained model must be specified") 325 | else: 326 | print("Using specified pretrain model") 327 | state_dict = kwargs['pretrained_model'] 328 | if feat: 329 | new_state_dict = part_state_dict(state_dict, model.state_dict()) 330 | model.load_state_dict(new_state_dict) 331 | return model 332 | 333 | def resnet26_3d_v3(pretrained=False, feat=False, **kwargs): 334 | """Constructs a ResNet-50 model. 
335 | Args: 336 | pretrained (bool): If True, returns a model pre-trained on ImageNet 337 | """ 338 | model = ResNet3D([Bottleneck3D_000, Bottleneck3D_100, Bottleneck3D_100, Bottleneck3D_100], 339 | [2, 2, 2, 2], feat=feat, **kwargs) 340 | # import pdb 341 | # pdb.set_trace() 342 | if pretrained: 343 | if kwargs['pretrained_model'] is None: 344 | raise ValueError("pretrained model must be specified") 345 | else: 346 | print("Using specified pretrain model") 347 | state_dict = kwargs['pretrained_model'] 348 | if feat: 349 | new_state_dict = part_state_dict(state_dict, model.state_dict()) 350 | model.load_state_dict(new_state_dict) 351 | return model 352 | 353 | def resnet101_3d_v1(pretrained=False, feat=False, **kwargs): 354 | """Constructs a ResNet-50 model. 355 | Args: 356 | pretrained (bool): If True, returns a model pre-trained on ImageNet 357 | """ 358 | model = ResNet3D([Bottleneck3D_000, Bottleneck3D_100, Bottleneck3D_101, Bottleneck3D_101], 359 | [3, 4, 23, 3], feat=feat, **kwargs) 360 | # import pdb 361 | # pdb.set_trace() 362 | if pretrained: 363 | if kwargs['pretrained_model'] is None: 364 | state_dict = model_zoo.load_url(model_urls['resnet101']) 365 | else: 366 | print("Using specified pretrain model") 367 | state_dict = kwargs['pretrained_model'] 368 | if feat: 369 | new_state_dict = part_state_dict(state_dict, model.state_dict()) 370 | model.load_state_dict(new_state_dict) 371 | return model -------------------------------------------------------------------------------- /lib/networks/resnet_3d_nodown.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modify the original file to make the class support feature extraction 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import math 8 | import torch.utils.model_zoo as model_zoo 9 | 10 | __all__ = ["resnet50_3d_slowonly"] 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | } 19 | 20 | class GloAvgPool3d(nn.Module): 21 | def __init__(self): 22 | super(GloAvgPool3d, self).__init__() 23 | self.stride = 1 24 | self.padding = 0 25 | self.ceil_mode = False 26 | self.count_include_pad = True 27 | 28 | def forward(self, input): 29 | input_shape = input.shape 30 | kernel_size = input_shape[2:] 31 | return F.avg_pool3d(input, kernel_size, self.stride, 32 | self.padding, self.ceil_mode, self.count_include_pad) 33 | 34 | class Bottleneck3D_100(nn.Module): 35 | expansion = 4 36 | 37 | def __init__(self, inplanes, planes, stride=1, t_stride=1, downsample=None): 38 | super(Bottleneck3D_100, self).__init__() 39 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=(3, 1, 1), 40 | stride=(t_stride, 1, 1), 41 | padding=(1, 0, 0), bias=False) 42 | self.bn1 = nn.BatchNorm3d(planes) 43 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3), 44 | stride=(1, stride, stride), padding=(0, 1, 1), bias=False) 45 | self.bn2 = nn.BatchNorm3d(planes) 46 | self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False) 47 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.downsample = downsample 50 | self.stride = stride 51 | 52 | 
def forward(self, x): 53 | residual = x 54 | 55 | out = self.conv1(x) 56 | out = self.bn1(out) 57 | out = self.relu(out) 58 | 59 | out = self.conv2(out) 60 | out = self.bn2(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv3(out) 64 | out = self.bn3(out) 65 | 66 | if self.downsample is not None: 67 | residual = self.downsample(x) 68 | 69 | out += residual 70 | out = self.relu(out) 71 | 72 | return out 73 | 74 | class Bottleneck3D_101(nn.Module): 75 | expansion = 4 76 | 77 | def __init__(self, inplanes, planes, stride=1, t_stride=1, downsample=None): 78 | super(Bottleneck3D_101, self).__init__() 79 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=(3, 1, 1), 80 | stride=(t_stride, 1, 1), 81 | padding=(1, 0, 0), bias=False) 82 | self.bn1 = nn.BatchNorm3d(planes) 83 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3), 84 | stride=(1, stride, stride), padding=(0, 1, 1), bias=False) 85 | self.bn2 = nn.BatchNorm3d(planes) 86 | self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=(3, 1, 1), 87 | stride=1, 88 | padding=(1, 0, 0), 89 | bias=False) 90 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.downsample = downsample 93 | self.stride = stride 94 | 95 | def forward(self, x): 96 | residual = x 97 | 98 | out = self.conv1(x) 99 | out = self.bn1(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv2(out) 103 | out = self.bn2(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv3(out) 107 | out = self.bn3(out) 108 | 109 | if self.downsample is not None: 110 | residual = self.downsample(x) 111 | 112 | out += residual 113 | out = self.relu(out) 114 | 115 | return out 116 | 117 | class Bottleneck3D_000(nn.Module): 118 | expansion = 4 119 | 120 | def __init__(self, inplanes, planes, stride=1, t_stride=1, downsample=None): 121 | super(Bottleneck3D_000, self).__init__() 122 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, 123 | stride=[t_stride, 1, 1], bias=False) 124 | self.bn1 = nn.BatchNorm3d(planes) 125 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3), 126 | stride=[1, stride, stride], padding=(0, 1, 1), bias=False) 127 | self.bn2 = nn.BatchNorm3d(planes) 128 | self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False) 129 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 130 | self.relu = nn.ReLU(inplace=True) 131 | self.downsample = downsample 132 | self.stride = stride 133 | 134 | def forward(self, x): 135 | residual = x 136 | 137 | out = self.conv1(x) 138 | out = self.bn1(out) 139 | out = self.relu(out) 140 | 141 | out = self.conv2(out) 142 | out = self.bn2(out) 143 | out = self.relu(out) 144 | 145 | out = self.conv3(out) 146 | out = self.bn3(out) 147 | 148 | if self.downsample is not None: 149 | residual = self.downsample(x) 150 | 151 | out += residual 152 | out = self.relu(out) 153 | 154 | return out 155 | 156 | class ResNet3D_nodown(nn.Module): 157 | 158 | def __init__(self, block, layers, num_classes=1000, feat=False, **kwargs): 159 | if not isinstance(block, list): 160 | block = [block] * 4 161 | else: 162 | assert(len(block)) == 4, "Block number must be 4 for ResNet-Stype networks." 
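# Unlike ResNet3D in resnet_3d.py, the layers built below never pass t_stride,
# so the temporal length is preserved end to end; this is the "slow-only"
# layout exposed through resnet50_3d_slowonly at the end of this file.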
163 | self.inplanes = 64 164 | super(ResNet3D_nodown, self).__init__() 165 | self.feat = feat 166 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(1, 7, 7), 167 | stride=(1, 2, 2), padding=(0, 3, 3), 168 | bias=False) 169 | self.bn1 = nn.BatchNorm3d(64) 170 | self.relu = nn.ReLU(inplace=True) 171 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) 172 | self.layer1 = self._make_layer(block[0], 64, layers[0]) 173 | self.layer2 = self._make_layer(block[1], 128, layers[1], stride=2) 174 | self.layer3 = self._make_layer(block[2], 256, layers[2], stride=2) 175 | self.layer4 = self._make_layer(block[3], 512, layers[3], stride=2) 176 | self.avgpool = GloAvgPool3d() 177 | self.feat_dim = 512 * block[0].expansion 178 | if not feat: 179 | self.fc = nn.Linear(512 * block[0].expansion, num_classes) 180 | 181 | for n, m in self.named_modules(): 182 | if isinstance(m, nn.Conv3d): 183 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 184 | elif isinstance(m, nn.BatchNorm3d): 185 | nn.init.constant_(m.weight, 1) 186 | nn.init.constant_(m.bias, 0) 187 | 188 | def _make_layer(self, block, planes, blocks, stride=1, t_stride=1): 189 | downsample = None 190 | if stride != 1 or self.inplanes != planes * block.expansion: 191 | downsample = nn.Sequential( 192 | nn.Conv3d(self.inplanes, planes * block.expansion, 193 | kernel_size=1, stride=(t_stride, stride, stride), bias=False), 194 | nn.BatchNorm3d(planes * block.expansion), 195 | ) 196 | 197 | layers = [] 198 | layers.append(block(self.inplanes, planes, stride=stride, t_stride=t_stride, downsample=downsample)) 199 | self.inplanes = planes * block.expansion 200 | for i in range(1, blocks): 201 | layers.append(block(self.inplanes, planes)) 202 | 203 | return nn.Sequential(*layers) 204 | 205 | def forward(self, x): 206 | x = self.conv1(x) 207 | x = self.bn1(x) 208 | x = self.relu(x) 209 | x = self.maxpool(x) 210 | 211 | x = self.layer1(x) 212 | x = self.layer2(x) 213 | x = self.layer3(x) 214 | x = self.layer4(x) 215 | 216 | x = self.avgpool(x) 217 | # print(x.shape) 218 | x = x.view(x.size(0), -1) 219 | if not self.feat: 220 | print("WARNING!!!!!!!") 221 | x = self.fc(x) 222 | 223 | return x 224 | 225 | def part_state_dict(state_dict, model_dict): 226 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict} 227 | pretrained_dict = inflate_state_dict(pretrained_dict, model_dict) 228 | model_dict.update(pretrained_dict) 229 | return model_dict 230 | 231 | 232 | def inflate_state_dict(pretrained_dict, model_dict): 233 | for k in pretrained_dict.keys(): 234 | if pretrained_dict[k].size() != model_dict[k].size(): 235 | assert(pretrained_dict[k].size()[:2] == model_dict[k].size()[:2]), \ 236 | "To inflate, channel number should match." 237 | assert(pretrained_dict[k].size()[-2:] == model_dict[k].size()[-2:]), \ 238 | "To inflate, spatial kernel size should match." 239 | print("Layer {} needs inflation.".format(k)) 240 | shape = list(pretrained_dict[k].shape) 241 | shape.insert(2, 1) 242 | t_length = model_dict[k].shape[2] 243 | pretrained_dict[k] = pretrained_dict[k].reshape(shape) 244 | if t_length != 1: 245 | pretrained_dict[k] = pretrained_dict[k].expand_as(model_dict[k]) / t_length 246 | assert(pretrained_dict[k].size() == model_dict[k].size()), \ 247 | "After inflation, model shape should match." 248 | 249 | return pretrained_dict 250 | 251 | def resnet50_3d_slowonly(pretrained=False, feat=False, **kwargs): 252 | """Constructs a ResNet-50 model. 
253 | Args: 254 | pretrained (bool): If True, returns a model pre-trained on ImageNet 255 | """ 256 | model = ResNet3D_nodown([Bottleneck3D_000, Bottleneck3D_000, Bottleneck3D_100, Bottleneck3D_100], 257 | [3, 4, 6, 3], feat=feat, **kwargs) 258 | # import pdb 259 | # pdb.set_trace() 260 | if pretrained: 261 | if kwargs['pretrained_model'] is None: 262 | state_dict = model_zoo.load_url(model_urls['resnet50']) 263 | else: 264 | print("Using specified pretrain model") 265 | state_dict = kwargs['pretrained_model'] 266 | if feat: 267 | new_state_dict = part_state_dict(state_dict, model.state_dict()) 268 | model.load_state_dict(new_state_dict) 269 | return model 270 | -------------------------------------------------------------------------------- /lib/opts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import argparse 4 | 5 | def set_logger(debug_mode=False): 6 | import time 7 | from time import gmtime, strftime 8 | logdir = os.path.join(args.experiment_root, 'log') 9 | if not os.path.exists(logdir): 10 | os.makedirs(logdir) 11 | log_file = "logfile_" + time.strftime("%d_%b_%Y_%H:%M:%S", time.localtime()) 12 | log_file = os.path.join(logdir, log_file) 13 | handlers = [logging.FileHandler(log_file), logging.StreamHandler()] 14 | 15 | """ add '%(filename)s:%(lineno)d %(levelname)s:' to format show source file """ 16 | logging.basicConfig(level=logging.DEBUG if debug_mode else logging.INFO, 17 | format='%(asctime)s: %(message)s', 18 | datefmt='%Y-%m-%d %H:%M:%S', 19 | handlers = handlers) 20 | 21 | parser = argparse.ArgumentParser(description="PyTorch implementation of Video Classification") 22 | parser.add_argument('dataset', type=str) 23 | parser.add_argument('train_list', type=str) 24 | parser.add_argument('val_list', type=str) 25 | 26 | # ========================= Model Configs ========================== 27 | parser.add_argument('--arch', '-a', type=str, default="resnet18") 28 | parser.add_argument('--shadow', action='store_true') 29 | parser.add_argument('--dropout', '--do', default=0.2, type=float, 30 | metavar='DO', help='dropout ratio (default: 0.2)') 31 | parser.add_argument('--mode', type=str, default='3D', choices=['3D', 'TSN', '2D']) 32 | parser.add_argument('--new_size', type=int, default=256) 33 | parser.add_argument('--crop_size', type=int, default=224) 34 | parser.add_argument('--t_length', type=int, default=32, help="time length") 35 | parser.add_argument('--t_stride', type=int, default=2, help="time stride between frames") 36 | parser.add_argument('--num_segments', type=int, default=1) 37 | parser.add_argument('--pretrained', action='store_true') 38 | parser.add_argument('--pretrained_model', type=str, default=None) 39 | 40 | # ========================= Learning Configs ========================== 41 | parser.add_argument('--epochs', default=60, type=int, metavar='N', 42 | help='number of total epochs to run') 43 | parser.add_argument('-b', '--batch-size', default=256, type=int, 44 | metavar='N', help='mini-batch size (default: 256)') 45 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, 46 | metavar='LR', help='initial learning rate') 47 | parser.add_argument('--lr_steps', default=[40, 70, 70], type=float, nargs="+", 48 | metavar='LRSteps', help='epochs to decay learning rate by 10') 49 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 50 | help='momentum') 51 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 52 | metavar='W', 
help='weight decay (default: 5e-4)') 53 | 54 | # ========================= Monitor Configs ========================== 55 | parser.add_argument('--print-freq', '-p', default=20, type=int, 56 | metavar='N', help='print frequency (default: 20)') 57 | parser.add_argument('--eval-freq', '-ef', default=2, type=int, 58 | metavar='N', help='evaluation frequency (default: 2)') 59 | 60 | # ========================= Runtime Configs ========================== 61 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 62 | help='number of data loading workers (default: 4)') 63 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 64 | help='path to latest checkpoint (default: none)') 65 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 66 | help='evaluate model on validation set') 67 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 68 | help='manual epoch number (useful on restarts)') 69 | parser.add_argument('--output_root', type=str, default="./output") 70 | parser.add_argument('--image_tmpl', type=str, default="image_{:06d}.jpg") 71 | 72 | args = parser.parse_args() 73 | if args.mode == "2D": 74 | args.t_length = 1 75 | 76 | experiment_id = '_'.join(map(str, [args.dataset, args.arch, args.mode, 77 | 'length'+str(args.t_length), 'stride'+str(args.t_stride), 78 | 'dropout'+str(args.dropout)])) 79 | 80 | if args.pretrained and args.pretrained_model: 81 | if "2d" in args.pretrained_model: 82 | experiment_id += '_2dpretrained' 83 | 84 | if args.shadow: 85 | experiment_id += '_shadow' 86 | 87 | args.experiment_root = os.path.join(args.output_root, experiment_id) 88 | # init logger 89 | set_logger() 90 | logging.info(args) 91 | if not os.path.exists(args.experiment_root): 92 | os.makedirs(args.experiment_root) 93 | -------------------------------------------------------------------------------- /lib/transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | class GroupRandomCrop(object): 10 | def __init__(self, size): 11 | if isinstance(size, numbers.Number): 12 | self.size = (int(size), int(size)) 13 | else: 14 | self.size = size 15 | 16 | def __call__(self, img_group): 17 | 18 | w, h = img_group[0].size 19 | th, tw = self.size 20 | 21 | out_images = list() 22 | 23 | x1 = random.randint(0, w - tw) 24 | y1 = random.randint(0, h - th) 25 | 26 | for img in img_group: 27 | assert(img.size[0] == w and img.size[1] == h) 28 | if w == tw and h == th: 29 | out_images.append(img) 30 | else: 31 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 32 | 33 | return out_images 34 | 35 | class GroupCenterCrop(object): 36 | def __init__(self, size): 37 | self.worker = torchvision.transforms.CenterCrop(size) 38 | 39 | def __call__(self, img_group): 40 | return [self.worker(img) for img in img_group] 41 | 42 | 43 | class GroupRandomHorizontalFlip(object): 44 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 45 | There is no need to define an init function. 
46 | """ 47 | def __call__(self, img_group): 48 | v = random.random() 49 | if v < 0.5: 50 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 51 | return ret 52 | else: 53 | return img_group 54 | 55 | class GroupNormalize(object): 56 | def __init__(self, 57 | mean=[0.485, 0.456, 0.406], 58 | std=[0.229, 0.224, 0.225]): 59 | self.mean = mean 60 | self.std = std 61 | 62 | def __call__(self, tensor): 63 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 64 | rep_std = self.std * (tensor.size()[0]//len(self.std)) 65 | 66 | # TODO: make efficient 67 | for t, m, s in zip(tensor, rep_mean, rep_std): 68 | t.sub_(m).div_(s) 69 | 70 | return tensor 71 | 72 | 73 | class GroupScale(object): 74 | """ Rescales the input PIL.Image to the given 'size'. 75 | 'size' will be the size of the smaller edge. 76 | For example, if height > width, then image will be 77 | rescaled to (size * height / width, size) 78 | size: size of the smaller edge 79 | interpolation: Default: PIL.Image.BILINEAR 80 | """ 81 | 82 | def __init__(self, size, interpolation=Image.BILINEAR): 83 | self.worker = torchvision.transforms.Resize(size, interpolation) 84 | 85 | def __call__(self, img_group): 86 | return [self.worker(img) for img in img_group] 87 | 88 | class GroupRandomScale(object): 89 | """ Rescales the input PIL.Image to the given 'size'. 90 | 'size' will be the size of the smaller edge. 91 | For example, if height > width, then image will be 92 | rescaled to (size * height / width, size) 93 | size: size of the smaller edge 94 | interpolation: Default: PIL.Image.BILINEAR 95 | """ 96 | 97 | def __init__(self, smallest_size=256, largest_size=320, interpolation=Image.BILINEAR): 98 | self.smallest_size = smallest_size 99 | self.largest_size = largest_size 100 | self.interpolation = interpolation 101 | 102 | def __call__(self, img_group): 103 | size = random.randint(self.smallest_size, self.largest_size) 104 | # print(size) 105 | self.worker = torchvision.transforms.Resize(size, self.interpolation) 106 | return [self.worker(img) for img in img_group] 107 | 108 | class GroupOverSample(object): 109 | def __init__(self, crop_size, scale_size=None): 110 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 111 | 112 | if scale_size is not None: 113 | self.scale_worker = GroupScale(scale_size) 114 | else: 115 | self.scale_worker = None 116 | 117 | def __call__(self, img_group): 118 | 119 | if self.scale_worker is not None: 120 | img_group = self.scale_worker(img_group) 121 | 122 | image_w, image_h = img_group[0].size 123 | crop_w, crop_h = self.crop_size 124 | 125 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h) 126 | oversample_group = list() 127 | for o_w, o_h in offsets: 128 | normal_group = list() 129 | flip_group = list() 130 | for i, img in enumerate(img_group): 131 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 132 | normal_group.append(crop) 133 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 134 | flip_group.append(flip_crop) 135 | 136 | oversample_group.extend(normal_group) 137 | oversample_group.extend(flip_group) 138 | return oversample_group 139 | 140 | class GroupOverSampleKaiming(object): 141 | def __init__(self, crop_size, scale_size=None): 142 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 143 | 144 | if scale_size is not None: 145 | self.scale_worker = GroupScale(scale_size) 146 | else: 147 | self.scale_worker = None 148 | 149 | def __call__(self, 
img_group): 150 | 151 | if self.scale_worker is not None: 152 | img_group = self.scale_worker(img_group) 153 | 154 | image_w, image_h = img_group[0].size 155 | crop_w, crop_h = self.crop_size 156 | 157 | offsets = self.fill_fix_offset(image_w, image_h, crop_w, crop_h) 158 | oversample_group = list() 159 | for o_w, o_h in offsets: 160 | normal_group = list() 161 | # flip_group = list() 162 | for i, img in enumerate(img_group): 163 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 164 | normal_group.append(crop) 165 | # flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 166 | # flip_group.append(flip_crop) 167 | 168 | oversample_group.extend(normal_group) 169 | # oversample_group.extend(flip_group) 170 | return oversample_group 171 | 172 | def fill_fix_offset(self, image_w, image_h, crop_w, crop_h): 173 | # assert(crop_h == image_h), "In Kaiming mode, crop_h should equal to image_h" 174 | ret = list() 175 | if image_w == 256: 176 | h_step = (image_h - crop_h) // 4 177 | ret.append((0, 0)) # upper 178 | ret.append((0, 4 * h_step)) # down 179 | ret.append((0, 2 * h_step)) # center 180 | elif image_h == 256: 181 | w_step = (image_w - crop_w) // 4 182 | ret.append((0, 0)) # left 183 | ret.append((4 * w_step, 0)) # right 184 | ret.append((2 * w_step, 0)) # center 185 | else: 186 | raise ValueError("Either image_w or image_h should be equal to 256") 187 | 188 | return ret 189 | 190 | 191 | class GroupMultiScaleCrop(object): 192 | 193 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 194 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 195 | self.scales = scales if scales is not None else [1, .875, .75, .66] 196 | self.max_distort = max_distort 197 | self.fix_crop = fix_crop 198 | self.more_fix_crop = more_fix_crop 199 | self.interpolation = Image.BILINEAR 200 | 201 | def __call__(self, img_group): 202 | 203 | im_size = img_group[0].size 204 | 205 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 206 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 207 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) 208 | for img in crop_img_group] 209 | return ret_img_group 210 | 211 | def _sample_crop_size(self, im_size): 212 | image_w, image_h = im_size[0], im_size[1] 213 | 214 | # find a crop size 215 | base_size = min(image_w, image_h) 216 | crop_sizes = [int(base_size * x) for x in self.scales] 217 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 218 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 219 | 220 | pairs = [] 221 | for i, h in enumerate(crop_h): 222 | for j, w in enumerate(crop_w): 223 | if abs(i - j) <= self.max_distort: 224 | pairs.append((w, h)) 225 | 226 | crop_pair = random.choice(pairs) 227 | if not self.fix_crop: 228 | w_offset = random.randint(0, image_w - crop_pair[0]) 229 | h_offset = random.randint(0, image_h - crop_pair[1]) 230 | else: 231 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 232 | 233 | return crop_pair[0], crop_pair[1], w_offset, h_offset 234 | 235 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 236 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 237 | return random.choice(offsets) 238 | 239 | @staticmethod 240 | def fill_fix_offset(more_fix_crop, 
image_w, image_h, crop_w, crop_h): 241 | w_step = (image_w - crop_w) // 4 242 | h_step = (image_h - crop_h) // 4 243 | 244 | ret = list() 245 | ret.append((0, 0)) # upper left 246 | ret.append((4 * w_step, 0)) # upper right 247 | ret.append((0, 4 * h_step)) # lower left 248 | ret.append((4 * w_step, 4 * h_step)) # lower right 249 | ret.append((2 * w_step, 2 * h_step)) # center 250 | 251 | if more_fix_crop: 252 | ret.append((0, 2 * h_step)) # center left 253 | ret.append((4 * w_step, 2 * h_step)) # center right 254 | ret.append((2 * w_step, 4 * h_step)) # lower center 255 | ret.append((2 * w_step, 0 * h_step)) # upper center 256 | 257 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 258 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 259 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 260 | ret.append((3 * w_step, 3 * h_step)) # lower right quarter 261 | 262 | return ret 263 | 264 | 265 | class GroupRandomSizedCrop(object): 266 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 267 | and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio 268 | This is popularly used to train the Inception networks 269 | size: size of the smaller edge 270 | interpolation: Default: PIL.Image.BILINEAR 271 | """ 272 | def __init__(self, size, interpolation=Image.BILINEAR): 273 | self.size = size 274 | self.interpolation = interpolation 275 | 276 | def __call__(self, img_group): 277 | for attempt in range(10): 278 | area = img_group[0].size[0] * img_group[0].size[1] 279 | target_area = random.uniform(0.08, 1.0) * area 280 | aspect_ratio = random.uniform(3. / 4, 4. / 3) 281 | 282 | w = int(round(math.sqrt(target_area * aspect_ratio))) 283 | h = int(round(math.sqrt(target_area / aspect_ratio))) 284 | 285 | if random.random() < 0.5: 286 | w, h = h, w 287 | 288 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 289 | x1 = random.randint(0, img_group[0].size[0] - w) 290 | y1 = random.randint(0, img_group[0].size[1] - h) 291 | found = True 292 | break 293 | else: 294 | found = False 295 | x1 = 0 296 | y1 = 0 297 | 298 | if found: 299 | out_group = list() 300 | for img in img_group: 301 | img = img.crop((x1, y1, x1 + w, y1 + h)) 302 | assert(img.size == (w, h)) 303 | out_group.append(img.resize((self.size, self.size), self.interpolation)) 304 | return out_group 305 | else: 306 | # Fallback 307 | scale = GroupScale(self.size, interpolation=self.interpolation) 308 | crop = GroupRandomCrop(self.size) 309 | return crop(scale(img_group)) 310 | 311 | 312 | class Stack(object): 313 | 314 | def __init__(self, mode="3D"): 315 | """Support modes: ["3D", "TSN+2D", "2D", "TSN+3D"] 316 | """ 317 | assert(mode in ["3D", "TSN+2D", "2D", "TSN+3D"]), "Unsupported mode: {}".format(mode) 318 | self.mode = mode 319 | 320 | def __call__(self, img_group): 321 | """Only support RGB mode now 322 | img_group: list([h, w, c]) 323 | """ 324 | assert(img_group[0].mode == 'RGB'), "Must read images in RGB mode." 325 | if "3D" in self.mode: 326 | imgs = np.concatenate([np.array(img)[np.newaxis, ...]
for img in img_group], axis=0) 327 | imgs = torch.from_numpy(imgs).permute(3, 0, 1, 2).contiguous() 328 | elif "2D" in self.mode: 329 | imgs = np.concatenate([np.array(img) for img in img_group], axis=2) 330 | imgs = torch.from_numpy(imgs).permute(2, 0, 1).contiguous() 331 | else: 332 | raise Exception("Unsupported mode.") 333 | return imgs 334 | 335 | 336 | class ToTorchFormatTensor(object): 337 | """ Converts a torch.Tensor in the range [0, 255] 338 | to a torch.FloatTensor in the range [0.0, 1.0] """ 339 | def __init__(self, div=True): 340 | self.div = div 341 | 342 | def __call__(self, imgs): 343 | assert(isinstance(imgs, torch.Tensor)), "imgs must be a torch.Tensor." 344 | return imgs.float().div(255) if self.div else imgs.float() 345 | 346 | 347 | class IdentityTransform(object): 348 | 349 | def __call__(self, data): 350 | return data 351 | 352 | 353 | if __name__ == "__main__": 354 | trans = torchvision.transforms.Compose([ 355 | GroupMultiScaleCrop(input_size=224, scales=[1, .875, .75, .66]), 356 | Stack(mode="2D"), 357 | ToTorchFormatTensor(), 358 | GroupNormalize()] 359 | ) 360 | 361 | im = Image.open('/home/leizhou/CVPR2019/vid_cls/lena.png') 362 | 363 | color_group = [im] 364 | rst = trans(color_group) -------------------------------------------------------------------------------- /lib/utils/deprefix.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import pdb 4 | 5 | parser = argparse.ArgumentParser(description="Remove PyTorch Model Prefix") 6 | parser.add_argument('src_model', type=str) 7 | parser.add_argument('dst_model', type=str) 8 | 9 | args = parser.parse_args() 10 | state_dict = torch.load(args.src_model, map_location=lambda storage, loc: storage) 11 | state_dict = state_dict['state_dict'] 12 | # pdb.set_trace()  # leftover breakpoint disabled so the script runs end-to-end 13 | state_dict = {('.'.join(k.split('.')[1:]) if "module" in k else k): v for k, v in state_dict.items()} 14 | state_dict = {('.'.join(k.split('.')[1:]) if "base_model" in k else k): v for k, v in state_dict.items()} 15 | torch.save(state_dict, args.dst_model) 16 | -------------------------------------------------------------------------------- /lib/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import logging 4 | import torch 5 | import shutil 6 | from bisect import bisect_right  # used by WarmupMultiStepLR.get_lr below 7 | __all__ = ['AverageMeter', 'save_checkpoint', 'adjust_learning_rate', 'accuracy'] 8 | 9 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 10 | def __init__( 11 | self, 12 | optimizer, 13 | milestones, 14 | gamma=0.1, 15 | warmup_factor=1.0 / 3, 16 | warmup_iters=500, 17 | warmup_method="linear", 18 | last_epoch=-1, 19 | ): 20 | if not list(milestones) == sorted(milestones): 21 | raise ValueError( 22 | "Milestones should be a list of" " increasing integers. 
Got {}", 23 | milestones, 24 | ) 25 | 26 | if warmup_method not in ("constant", "linear"): 27 | raise ValueError( 28 | "Only 'constant' or 'linear' warmup_method accepted" 29 | "got {}".format(warmup_method) 30 | ) 31 | self.milestones = milestones 32 | self.gamma = gamma 33 | self.warmup_factor = warmup_factor 34 | self.warmup_iters = warmup_iters 35 | self.warmup_method = warmup_method 36 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 37 | 38 | def get_lr(self): 39 | warmup_factor = 1 40 | if self.last_epoch < self.warmup_iters: 41 | if self.warmup_method == "constant": 42 | warmup_factor = self.warmup_factor 43 | elif self.warmup_method == "linear": 44 | alpha = float(self.last_epoch) / self.warmup_iters 45 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 46 | return [ 47 | base_lr 48 | * warmup_factor 49 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 50 | for base_lr in self.base_lrs 51 | ] 52 | 53 | class AverageMeter(object): 54 | """Computes and stores the average and current value""" 55 | def __init__(self): 56 | self.reset() 57 | 58 | def reset(self): 59 | self.val = 0 60 | self.avg = 0 61 | self.sum = 0 62 | self.count = 0 63 | 64 | def update(self, val, n=1): 65 | self.val = val 66 | self.sum += val * n 67 | self.count += n 68 | self.avg = self.sum / self.count 69 | 70 | def save_checkpoint(state, is_best, epoch, experiment_root, filename='checkpoint_{}epoch.pth'): 71 | filename = os.path.join(experiment_root, filename.format(epoch)) 72 | logging.info("saving model to {}...".format(filename)) 73 | torch.save(state, filename) 74 | if is_best: 75 | best_name = os.path.join(experiment_root, 'model_best.pth') 76 | shutil.copyfile(filename, best_name) 77 | logging.info("saving done.") 78 | 79 | def adjust_learning_rate(optimizer, base_lr, epoch, lr_steps): 80 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 81 | decay = 0.1 ** (sum(epoch >= np.array(lr_steps))) 82 | lr = base_lr * decay 83 | for param_group in optimizer.param_groups: 84 | param_group['lr'] = lr 85 | 86 | def accuracy(output, target, topk=(1,)): 87 | """Computes the precision@k for the specified values of k""" 88 | maxk = max(topk) 89 | batch_size = target.size(0) 90 | 91 | _, pred = output.topk(maxk, 1, True, True) 92 | pred = pred.t() 93 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 94 | 95 | res = [] 96 | for k in topk: 97 | correct_k = correct[:k].view(-1).float().sum(0) 98 | res.append(correct_k.mul_(100.0 / batch_size)) 99 | return res -------------------------------------------------------------------------------- /lib/utils/vis_comb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib.pyplot as plt 3 | 4 | # files = os.listdir("log") 5 | # for ind, file in enumerate(files): 6 | # if not file.startswith('logfile'): 7 | # files.pop(ind) 8 | # print(files) 9 | # files = ["log/"+file for file in files] 10 | # files.sort() 11 | 12 | class log_parser(): 13 | def __init__(self, landmark, log_file, key_words=[], 14 | base_key_words=['Loss', 'Prec@1', 'Prec@5']): 15 | super(log_parser, self).__init__() 16 | with open(log_file) as f: 17 | self.lines = f.readlines() 18 | self.log_info = dict() 19 | self.landmark = landmark 20 | self.key_words = key_words + base_key_words 21 | for word in self.key_words: 22 | self.log_info[word] = [] 23 | 24 | def __add__(self, other): 25 | """Add two log parsers of same type. 
26 | """ 27 | assert hasattr(self, "hist") and hasattr(other, "hist"), "Parse before adding." 28 | for key in self.hist.keys(): 29 | assert key in other.hist, "Mush share key when adding." 30 | self.hist[key].update(other.hist[key]) 31 | return self 32 | 33 | def parse(self): 34 | # parse info into list 35 | for line in self.lines: 36 | items = line.strip().split() 37 | if self.landmark not in items: 38 | continue 39 | for word in self.key_words: 40 | assert(word in items), "Key word should be in target line." 41 | for word in self.key_words: 42 | ind = items.index(word) + 1 43 | if word == "Epoch:": 44 | self.log_info[word].append(items[ind]) 45 | else: 46 | self.log_info[word].append(float(items[ind])) 47 | 48 | # convert epoch string 49 | self.convert_epoch_string() 50 | # find the key for the later dict 51 | if "Epoch:" in self.key_words: 52 | key = "Epoch:" 53 | else: 54 | key = "Epoch" 55 | 56 | # build hist 57 | self.hist = {} 58 | for word in self.key_words: 59 | if "Epoch" not in word: 60 | self.hist[word] = {} 61 | for k, v in zip(self.log_info[key], self.log_info[word]): 62 | self.hist[word].update({k: v}) 63 | 64 | def convert_epoch_string(self): 65 | if "Epoch:" in self.log_info: 66 | epochs = self.log_info['Epoch:'] 67 | for idx, epoch_str in enumerate(epochs): 68 | epoch_num, fraction = epoch_str[1:-2].split("][") 69 | epoch = float(epoch_num) + eval(fraction) 70 | epochs[idx] = epoch 71 | 72 | def plot(dir, tr_landmark="lr:", ts_landmark="Testing"): 73 | 74 | files = os.listdir(dir) 75 | for ind, file in enumerate(files): 76 | if not file.startswith('logfile'): 77 | files.pop(ind) 78 | # print(files) 79 | files = [os.path.join(dir, file) for file in files] 80 | files.sort() 81 | 82 | file = files[0] 83 | tr_parser_base = log_parser(tr_landmark, files[0], key_words=['Epoch:']) 84 | tr_parser_base.parse() 85 | ts_parser_base = log_parser(ts_landmark, files[0], key_words=['Epoch']) 86 | ts_parser_base.parse() 87 | if len(files) > 1: 88 | for file in files[1:]: 89 | tr_parser = log_parser(tr_landmark, file, key_words=['Epoch:']) 90 | tr_parser.parse() 91 | ts_parser = log_parser(ts_landmark, file, key_words=['Epoch']) 92 | ts_parser.parse() 93 | tr_parser_base += tr_parser 94 | ts_parser_base += ts_parser 95 | 96 | return ts_parser_base 97 | # fig, ax = plt.subplots() 98 | # ax.plot(tr_parser_base.hist['Loss'].keys(), tr_parser_base.hist['Loss'].values(), label='Train Loss') 99 | # ax.plot(ts_parser_base.hist['Loss'].keys(), ts_parser_base.hist['Loss'].values(), label='Val Loss') 100 | # ax.set(xlabel="Epoch", ylabel='Loss', title='Loss') 101 | # ax.grid() 102 | # ax.legend(loc='upper right', shadow=False, fontsize='x-large') 103 | # plt.show() 104 | 105 | # fig, ax = plt.subplots() 106 | # ax.plot(ts_parser_base.hist['Prec@1'].keys(), ts_parser_base.hist['Prec@1'].values(), label='Prec@1') 107 | # ax.plot(ts_parser_base.hist['Prec@5'].keys(), ts_parser_base.hist['Prec@5'].values(), 'g--', label='Prec@5') 108 | # ax.set(xlabel="Epoch", ylabel='Prec', title='Test Acc') 109 | # ax.grid() 110 | # ax.legend(loc='lower right', shadow=False, fontsize='x-large') 111 | # plt.show() 112 | 113 | def designated_plot(baseline_parser, sd2_st1_parser, sd2_st4_parser, sd4_st1_parser, sd4_st4_parser, sf5_st1_parser): 114 | fig, ax = plt.subplots() 115 | ax.plot(baseline_parser.hist['Prec@1'].keys(), baseline_parser.hist['Prec@1'].values(), label='FST') 116 | ax.plot(sd2_st1_parser.hist['Prec@1'].keys(), sd2_st1_parser.hist['Prec@1'].values(), label='dilation2_stage1') 117 | 
ax.plot(sd2_st4_parser.hist['Prec@1'].keys(), sd2_st4_parser.hist['Prec@1'].values(), label='dilation2_stage4') 118 | ax.plot(sd4_st1_parser.hist['Prec@1'].keys(), sd4_st1_parser.hist['Prec@1'].values(), label='dilation4_stage1') 119 | ax.plot(sd4_st4_parser.hist['Prec@1'].keys(), sd4_st4_parser.hist['Prec@1'].values(), label='dilation4_stage4') 120 | ax.plot(sf5_st1_parser.hist['Prec@1'].keys(), sf5_st1_parser.hist['Prec@1'].values(), label='s_kernel5_stage1') 121 | ax.set(xlabel="Epoch", ylabel='Prec', title='Test Acc') 122 | ax.grid() 123 | ax.legend(loc='lower right', shadow=False, fontsize='x-large') 124 | plt.show() 125 | 126 | if __name__ == "__main__": 127 | baseline_parser = plot('/home/leizhou/Research/vid_cls/output/kinetics200_fst_resnet18_x4_3D_length16_stride4_dropout0.2/log') 128 | sd2_st1_parser = plot('/home/leizhou/Research/vid_cls/output/kinetics200_fst_resnet18_sd2_st1_x4_3D_length16_stride4_dropout0.2/log') 129 | sd2_st4_parser = plot('/home/leizhou/Research/vid_cls/output/kinetics200_fst_resnet18_sd2_st4_x4_3D_length16_stride4_dropout0.2/log') 130 | sd4_st1_parser = plot('/home/leizhou/Research/vid_cls/output/kinetics200_fst_resnet18_sd4_st1_x4_3D_length16_stride4_dropout0.2/log') 131 | sd4_st4_parser = plot('/home/leizhou/Research/vid_cls/output/kinetics200_fst_resnet18_sd4_st4_x4_3D_length16_stride4_dropout0.2/log') 132 | sf5_st1_parser = plot('/home/leizhou/Research/vid_cls/output/kinetics200_fst_resnet18_sf5_st1_x4_3D_length16_stride4_dropout0.2/log') 133 | designated_plot(baseline_parser, sd2_st1_parser, sd2_st4_parser, sd4_st1_parser, sd4_st4_parser, sf5_st1_parser) -------------------------------------------------------------------------------- /lib/utils/visualization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib.pyplot as plt 3 | 4 | files = os.listdir("log") 5 | for ind, file in enumerate(files): 6 | if not file.startswith('logfile'): 7 | files.pop(ind) 8 | print(files) 9 | files = ["log/"+file for file in files] 10 | files.sort() 11 | 12 | class log_parser(): 13 | def __init__(self, landmark, log_file, key_words=[], 14 | base_key_words=['Loss', 'Prec@1', 'Prec@5']): 15 | super(log_parser, self).__init__() 16 | with open(log_file) as f: 17 | self.lines = f.readlines() 18 | self.log_info = dict() 19 | self.landmark = landmark 20 | self.key_words = key_words + base_key_words 21 | for word in self.key_words: 22 | self.log_info[word] = [] 23 | 24 | def __add__(self, other): 25 | """Add two log parsers of same type. 26 | """ 27 | assert hasattr(self, "hist") and hasattr(other, "hist"), "Parse before adding." 28 | for key in self.hist.keys(): 29 | assert key in other.hist, "Mush share key when adding." 30 | self.hist[key].update(other.hist[key]) 31 | return self 32 | 33 | def parse(self): 34 | # parse info into list 35 | for line in self.lines: 36 | items = line.strip().split() 37 | if self.landmark not in items: 38 | continue 39 | for word in self.key_words: 40 | assert(word in items), "Key word should be in target line." 
41 | for word in self.key_words: 42 | ind = items.index(word) + 1 43 | if word == "Epoch:": 44 | self.log_info[word].append(items[ind]) 45 | else: 46 | self.log_info[word].append(float(items[ind])) 47 | 48 | # convert epoch string 49 | self.convert_epoch_string() 50 | # find the key for the later dict 51 | if "Epoch:" in self.key_words: 52 | key = "Epoch:" 53 | else: 54 | key = "Epoch" 55 | 56 | # build hist 57 | self.hist = {} 58 | for word in self.key_words: 59 | if "Epoch" not in word: 60 | self.hist[word] = {} 61 | for k, v in zip(self.log_info[key], self.log_info[word]): 62 | self.hist[word].update({k: v}) 63 | 64 | def convert_epoch_string(self): 65 | if "Epoch:" in self.log_info: 66 | epochs = self.log_info['Epoch:'] 67 | for idx, epoch_str in enumerate(epochs): 68 | epoch_num, fraction = epoch_str[1:-2].split("][") 69 | epoch = float(epoch_num) + eval(fraction) 70 | epochs[idx] = epoch 71 | 72 | def plot(files, tr_landmark="lr:", ts_landmark="Testing"): 73 | if not isinstance(files, list): 74 | files = [files] 75 | 76 | file = files[0] 77 | tr_parser_base = log_parser(tr_landmark, files[0], key_words=['Epoch:']) 78 | tr_parser_base.parse() 79 | ts_parser_base = log_parser(ts_landmark, files[0], key_words=['Epoch']) 80 | ts_parser_base.parse() 81 | if len(files) > 1: 82 | for file in files[1:]: 83 | tr_parser = log_parser(tr_landmark, file, key_words=['Epoch:']) 84 | tr_parser.parse() 85 | ts_parser = log_parser(ts_landmark, file, key_words=['Epoch']) 86 | ts_parser.parse() 87 | tr_parser_base += tr_parser 88 | ts_parser_base += ts_parser 89 | 90 | fig, ax = plt.subplots() 91 | ax.plot(tr_parser_base.hist['Loss'].keys(), tr_parser_base.hist['Loss'].values(), label='Train Loss') 92 | ax.plot(ts_parser_base.hist['Loss'].keys(), ts_parser_base.hist['Loss'].values(), label='Val Loss') 93 | ax.set(xlabel="Epoch", ylabel='Loss', title='Loss') 94 | ax.grid() 95 | ax.legend(loc='upper right', shadow=False, fontsize='x-large') 96 | plt.show() 97 | 98 | fig, ax = plt.subplots() 99 | ax.plot(ts_parser_base.hist['Prec@1'].keys(), ts_parser_base.hist['Prec@1'].values(), label='Prec@1') 100 | ax.plot(ts_parser_base.hist['Prec@5'].keys(), ts_parser_base.hist['Prec@5'].values(), 'g--', label='Prec@5') 101 | ax.set(xlabel="Epoch", ylabel='Prec', title='Test Acc') 102 | ax.grid() 103 | ax.legend(loc='lower right', shadow=False, fontsize='x-large') 104 | plt.show() 105 | 106 | if __name__ == "__main__": 107 | plot(files) 108 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import shutil 5 | import logging 6 | 7 | import torch 8 | import torchvision 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim 12 | 13 | from lib.dataset import VideoDataSet 14 | from lib.models import VideoModule 15 | from lib.transforms import * 16 | from lib.utils.tools import * 17 | from lib.opts import args 18 | from lib.modules import * 19 | 20 | from train_val import train, validate 21 | 22 | best_metric = 0 23 | 24 | def main(): 25 | global args, best_metric 26 | 27 | # specify dataset 28 | if 'ucf101' in args.dataset: 29 | num_class = 101 30 | elif 'hmdb51' in args.dataset: 31 | num_class = 51 32 | elif args.dataset == 'kinetics400': 33 | num_class = 400 34 | elif args.dataset == 'kinetics200': 35 | num_class = 200 36 | else: 37 | raise ValueError('Unknown dataset '+args.dataset) 38 | 39 | # data_root = 
os.path.join(os.path.dirname(os.path.abspath(__file__)), 40 | # "data/{}/access".format(args.dataset)) 41 | 42 | if "ucf101" in args.dataset or "hmdb51" in args.dataset: 43 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), 44 | "data/{}/access".format(args.dataset[:-3])) 45 | else: 46 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), 47 | "data/{}/access".format(args.dataset)) 48 | 49 | # create model 50 | org_model = VideoModule(num_class=num_class, 51 | base_model_name=args.arch, 52 | dropout=args.dropout, 53 | pretrained=args.pretrained, 54 | pretrained_model=args.pretrained_model) 55 | num_params = 0 56 | for param in org_model.parameters(): 57 | num_params += param.reshape((-1, 1)).shape[0] 58 | logging.info("Model Size is {:.3f}M".format(num_params/1000000)) 59 | 60 | model = torch.nn.DataParallel(org_model).cuda() 61 | # model = org_model 62 | 63 | # define loss function (criterion) and optimizer 64 | criterion = torch.nn.CrossEntropyLoss().cuda() 65 | 66 | optimizer = torch.optim.SGD(model.parameters(), 67 | args.lr, 68 | momentum=args.momentum, 69 | weight_decay=args.weight_decay) 70 | 71 | # optionally resume from a checkpoint 72 | if args.resume: 73 | if os.path.isfile(args.resume): 74 | print(("=> loading checkpoint '{}'".format(args.resume))) 75 | checkpoint = torch.load(args.resume) 76 | args.start_epoch = checkpoint['epoch'] 77 | best_metric = checkpoint['best_metric'] 78 | model.load_state_dict(checkpoint['state_dict']) 79 | optimizer.load_state_dict(checkpoint['optimizer']) 80 | print(("=> loaded checkpoint '{}' (epoch {})" 81 | .format(args.resume, checkpoint['epoch']))) 82 | else: 83 | print(("=> no checkpoint found at '{}'".format(args.resume))) 84 | 85 | # Data loading code 86 | ## train data 87 | # train_transform = torchvision.transforms.Compose([ 88 | # GroupScale(args.new_size), 89 | # GroupMultiScaleCrop(input_size=args.crop_size, scales=[1, .875, .75, .66]), 90 | # GroupRandomHorizontalFlip(), 91 | # Stack(mode=args.mode), 92 | # ToTorchFormatTensor(), 93 | # GroupNormalize(), 94 | # ]) 95 | train_transform = torchvision.transforms.Compose([ 96 | GroupRandomScale(), 97 | GroupRandomCrop(size=args.crop_size), 98 | GroupRandomHorizontalFlip(), 99 | Stack(mode=args.mode), 100 | ToTorchFormatTensor(), 101 | GroupNormalize(), 102 | ]) 103 | train_dataset = VideoDataSet(root_path=data_root, 104 | list_file=args.train_list, 105 | t_length=args.t_length, 106 | t_stride=args.t_stride, 107 | num_segments=args.num_segments, 108 | image_tmpl=args.image_tmpl, 109 | transform=train_transform, 110 | phase="Train") 111 | train_loader = torch.utils.data.DataLoader( 112 | train_dataset, 113 | batch_size=args.batch_size, shuffle=True, drop_last=True, 114 | num_workers=args.workers, pin_memory=True) 115 | 116 | ## val data 117 | val_transform = torchvision.transforms.Compose([ 118 | GroupScale(args.new_size), 119 | GroupCenterCrop(args.crop_size), 120 | Stack(mode=args.mode), 121 | ToTorchFormatTensor(), 122 | GroupNormalize(), 123 | ]) 124 | val_dataset = VideoDataSet(root_path=data_root, 125 | list_file=args.val_list, 126 | t_length=args.t_length, 127 | t_stride=args.t_stride, 128 | num_segments=args.num_segments, 129 | image_tmpl=args.image_tmpl, 130 | transform=val_transform, 131 | phase="Val") 132 | val_loader = torch.utils.data.DataLoader( 133 | val_dataset, 134 | batch_size=args.batch_size, shuffle=False, 135 | num_workers=args.workers, pin_memory=True) 136 | 137 | if args.mode != "3D": 138 | cudnn.benchmark = True 139 | 140 | # 
validate(val_loader, model, criterion, args.print_freq, args.start_epoch) 141 | # torch.cuda.empty_cache() 142 | if args.resume: 143 | validate(val_loader, model, criterion, args.print_freq, args.start_epoch) 144 | torch.cuda.empty_cache() 145 | 146 | for epoch in range(args.start_epoch, args.epochs): 147 | adjust_learning_rate(optimizer, args.lr, epoch, args.lr_steps) 148 | 149 | # train for one epoch 150 | train(train_loader, model, criterion, optimizer, epoch, args.print_freq) 151 | 152 | # evaluate on validation set 153 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1: 154 | metric = validate(val_loader, model, criterion, args.print_freq, epoch + 1) 155 | torch.cuda.empty_cache() 156 | 157 | # remember best prec@1 and save checkpoint 158 | is_best = metric > best_metric 159 | best_metric = max(metric, best_metric) 160 | save_checkpoint({ 161 | 'epoch': epoch + 1, 162 | 'arch': args.arch, 163 | 'state_dict': model.state_dict(), 164 | 'best_metric': best_metric, 165 | 'optimizer': optimizer.state_dict(), 166 | }, is_best, epoch + 1, args.experiment_root) 167 | 168 | if __name__ == '__main__': 169 | main() 170 | -------------------------------------------------------------------------------- /main_20bn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import shutil 5 | import logging 6 | 7 | import torch 8 | import torchvision 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim 12 | 13 | from lib.dataset import VideoDataSet, ShortVideoDataSet 14 | from lib.models import VideoModule 15 | from lib.transforms import * 16 | from lib.utils.tools import * 17 | from lib.opts import args 18 | from lib.modules import * 19 | 20 | from train_val import train, validate 21 | 22 | best_metric = 0 23 | 24 | def main(): 25 | global args, best_metric 26 | 27 | # specify dataset 28 | if 'sthsth_v1' in args.dataset: 29 | num_class = 174 30 | elif 'sthsth_v2' in args.dataset: 31 | num_class = 174 32 | else: 33 | raise ValueError('Unknown dataset '+args.dataset) 34 | 35 | # data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), 36 | # "data/{}/access".format(args.dataset)) 37 | 38 | if "ucf101" in args.dataset or "hmdb51" in args.dataset: 39 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), 40 | "data/{}/access".format(args.dataset[:-3])) 41 | else: 42 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), 43 | "data/{}/access".format(args.dataset)) 44 | 45 | # create model 46 | org_model = VideoModule(num_class=num_class, 47 | base_model_name=args.arch, 48 | dropout=args.dropout, 49 | pretrained=args.pretrained, 50 | pretrained_model=args.pretrained_model) 51 | num_params = 0 52 | for param in org_model.parameters(): 53 | num_params += param.reshape((-1, 1)).shape[0] 54 | logging.info("Model Size is {:.3f}M".format(num_params/1000000)) 55 | 56 | model = torch.nn.DataParallel(org_model).cuda() 57 | # model = org_model 58 | 59 | # define loss function (criterion) and optimizer 60 | criterion = torch.nn.CrossEntropyLoss().cuda() 61 | 62 | optimizer = torch.optim.SGD(model.parameters(), 63 | args.lr, 64 | momentum=args.momentum, 65 | weight_decay=args.weight_decay) 66 | 67 | # optionally resume from a checkpoint 68 | if args.resume: 69 | if os.path.isfile(args.resume): 70 | print(("=> loading checkpoint '{}'".format(args.resume))) 71 | checkpoint = torch.load(args.resume) 72 | args.start_epoch = checkpoint['epoch'] 
73 | best_metric = checkpoint['best_metric'] 74 | model.load_state_dict(checkpoint['state_dict']) 75 | optimizer.load_state_dict(checkpoint['optimizer']) 76 | print(("=> loaded checkpoint '{}' (epoch {})" 77 | .format(args.resume, checkpoint['epoch']))) 78 | else: 79 | print(("=> no checkpoint found at '{}'".format(args.resume))) 80 | 81 | # Data loading code 82 | ## train data 83 | train_transform = torchvision.transforms.Compose([ 84 | GroupScale(args.new_size), 85 | GroupMultiScaleCrop(input_size=args.crop_size, scales=[1, .875, .75, .66]), 86 | # GroupRandomHorizontalFlip(), 87 | Stack(mode=args.mode), 88 | ToTorchFormatTensor(), 89 | GroupNormalize(), 90 | ]) 91 | train_dataset = VideoDataSet(root_path=data_root, 92 | list_file=args.train_list, 93 | t_length=args.t_length, 94 | t_stride=args.t_stride, 95 | num_segments=args.num_segments, 96 | image_tmpl=args.image_tmpl, 97 | transform=train_transform, 98 | phase="Train") 99 | train_loader = torch.utils.data.DataLoader( 100 | train_dataset, 101 | batch_size=args.batch_size, shuffle=True, drop_last=True, 102 | num_workers=args.workers, pin_memory=True) 103 | 104 | ## val data 105 | val_transform = torchvision.transforms.Compose([ 106 | GroupScale(args.new_size), 107 | GroupCenterCrop(args.crop_size), 108 | Stack(mode=args.mode), 109 | ToTorchFormatTensor(), 110 | GroupNormalize(), 111 | ]) 112 | val_dataset = ShortVideoDataSet(root_path=data_root, 113 | list_file=args.val_list, 114 | t_length=args.t_length, 115 | t_stride=args.t_stride, 116 | num_segments=args.num_segments, 117 | image_tmpl=args.image_tmpl, 118 | transform=val_transform, 119 | phase="Val") 120 | val_loader = torch.utils.data.DataLoader( 121 | val_dataset, 122 | batch_size=args.batch_size, shuffle=False, 123 | num_workers=args.workers, pin_memory=True) 124 | 125 | if args.mode != "3D": 126 | cudnn.benchmark = True 127 | 128 | if args.resume: 129 | validate(val_loader, model, criterion, args.print_freq, args.start_epoch) 130 | torch.cuda.empty_cache() 131 | 132 | for epoch in range(args.start_epoch, args.epochs): 133 | adjust_learning_rate(optimizer, args.lr, epoch, args.lr_steps) 134 | 135 | # train for one epoch 136 | train(train_loader, model, criterion, optimizer, epoch, args.print_freq) 137 | 138 | # evaluate on validation set 139 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1: 140 | metric = validate(val_loader, model, criterion, args.print_freq, epoch + 1) 141 | torch.cuda.empty_cache() 142 | 143 | # remember best prec@1 and save checkpoint 144 | is_best = metric > best_metric 145 | best_metric = max(metric, best_metric) 146 | save_checkpoint({ 147 | 'epoch': epoch + 1, 148 | 'arch': args.arch, 149 | 'state_dict': model.state_dict(), 150 | 'best_metric': best_metric, 151 | 'optimizer': optimizer.state_dict(), 152 | }, is_best, epoch + 1, args.experiment_root) 153 | 154 | if __name__ == '__main__': 155 | main() 156 | -------------------------------------------------------------------------------- /main_imagenet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import shutil 5 | import logging 6 | 7 | import torch 8 | import torchvision 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim 12 | 13 | import torchvision.transforms as transforms 14 | import torchvision.datasets as datasets 15 | from torchvision.models.resnet import * 16 | from torchvision.models.vgg import * 17 | from lib.networks.resnet import resnet26, 
resnet26_sc, resnet26_point 18 | from lib.networks.gsv_resnet_2d_v3 import gsv_resnet50_2d_v3 19 | from lib.modules import * 20 | from lib.utils.tools import * 21 | from lib.opts import args 22 | 23 | from train_val import train, validate 24 | 25 | best_metric = 0 26 | 27 | def main(): 28 | global args, best_metric 29 | 30 | # specify dataset 31 | if args.dataset == 'imagenet': 32 | num_class = 1000 33 | else: 34 | raise ValueError('Unknown dataset '+args.dataset) 35 | 36 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), 37 | "data/{}/access".format(args.dataset)) 38 | 39 | # create model 40 | org_model = eval(args.arch)(pretrained=args.pretrained)#, feat=False, num_classes=num_class) 41 | num_params = 0 42 | for param in org_model.parameters(): 43 | num_params += param.reshape((-1, 1)).shape[0] 44 | print("Model Size is {:.3f}M".format(num_params/1000000)) 45 | 46 | model = torch.nn.DataParallel(org_model).cuda() 47 | # model = org_model 48 | 49 | # define loss function (criterion) and optimizer 50 | criterion = torch.nn.CrossEntropyLoss().cuda() 51 | 52 | # scale params 53 | scale_parameters = [] 54 | other_parameters = [] 55 | for m in model.modules(): 56 | if isinstance(m, Scale2d): 57 | scale_parameters.append(m.scale) 58 | elif isinstance(m, nn.Conv2d): 59 | other_parameters.append(m.weight) 60 | if m.bias is not None: 61 | other_parameters.append(m.bias) 62 | elif isinstance(m, nn.BatchNorm2d): 63 | other_parameters.append(m.weight) 64 | other_parameters.append(m.bias) 65 | elif isinstance(m, nn.Linear): 66 | other_parameters.append(m.weight) 67 | other_parameters.append(m.bias) 68 | 69 | optimizer = torch.optim.SGD([{"params": other_parameters}, 70 | {"params": scale_parameters, "weight_decay": 0}], 71 | args.lr, 72 | momentum=args.momentum, 73 | weight_decay=args.weight_decay) 74 | 75 | # optionally resume from a checkpoint 76 | if args.resume: 77 | if os.path.isfile(args.resume): 78 | print(("=> loading checkpoint '{}'".format(args.resume))) 79 | checkpoint = torch.load(args.resume) 80 | args.start_epoch = checkpoint['epoch'] 81 | best_metric = checkpoint['best_metric'] 82 | model.load_state_dict(checkpoint['state_dict']) 83 | optimizer.load_state_dict(checkpoint['optimizer']) 84 | print(("=> loaded checkpoint '{}' (epoch {})" 85 | .format(args.resume, checkpoint['epoch']))) 86 | else: 87 | print(("=> no checkpoint found at '{}'".format(args.resume))) 88 | 89 | train_transform = transforms.Compose([ 90 | transforms.RandomResizedCrop(224), 91 | transforms.RandomHorizontalFlip(), 92 | transforms.ToTensor(), 93 | transforms.Normalize( 94 | mean=[0.485, 0.456, 0.406], 95 | std=[0.229, 0.224, 0.225]), 96 | ]) 97 | train_dataset = datasets.ImageFolder( 98 | os.path.join(data_root, 'train'), 99 | train_transform) 100 | train_loader = torch.utils.data.DataLoader( 101 | train_dataset, batch_size=args.batch_size, shuffle=True, 102 | num_workers=args.workers, pin_memory=True) 103 | 104 | val_transform = transforms.Compose([ 105 | transforms.Resize(256), 106 | transforms.CenterCrop(224), 107 | transforms.ToTensor(), 108 | transforms.Normalize( 109 | mean=[0.485, 0.456, 0.406], 110 | std=[0.229, 0.224, 0.225]), 111 | ]) 112 | val_dataset = datasets.ImageFolder( 113 | os.path.join(data_root, 'val'), 114 | val_transform) 115 | val_loader = torch.utils.data.DataLoader( 116 | val_dataset, 117 | batch_size=args.batch_size, shuffle=False, 118 | num_workers=args.workers, pin_memory=True) 119 | 120 | cudnn.benchmark = True 121 | # validate(val_loader, model, criterion, 
args.print_freq, args.start_epoch) 122 | 123 | for epoch in range(args.start_epoch, args.epochs): 124 | adjust_learning_rate(optimizer, args.lr, epoch, args.lr_steps) 125 | 126 | # train for one epoch 127 | train(train_loader, model, criterion, optimizer, epoch, args.print_freq) 128 | 129 | # evaluate on validation set 130 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1: 131 | metric = validate(val_loader, model, criterion, args.print_freq, epoch + 1) 132 | 133 | # remember best prec@1 and save checkpoint 134 | is_best = metric > best_metric 135 | best_metric = max(metric, best_metric) 136 | save_checkpoint({ 137 | 'epoch': epoch + 1, 138 | 'arch': args.arch, 139 | 'state_dict': model.state_dict(), 140 | 'best_metric': best_metric, 141 | 'optimizer': optimizer.state_dict(), 142 | }, is_best, epoch + 1, args.experiment_root) 143 | 144 | if __name__ == '__main__': 145 | main() 146 | 147 | 148 | 149 | 150 | 151 | 152 | # import argparse 153 | # import os 154 | # import random 155 | # import shutil 156 | # import time 157 | # import warnings 158 | 159 | # import torch 160 | # import torch.nn as nn 161 | # import torch.nn.parallel 162 | # import torch.backends.cudnn as cudnn 163 | # import torch.distributed as dist 164 | # import torch.optim 165 | # import torch.utils.data 166 | # import torch.utils.data.distributed 167 | # import torchvision.transforms as transforms 168 | # import torchvision.datasets as datasets 169 | # import torchvision.models as models 170 | # from lib.networks.mnet2 import * 171 | 172 | # model_names = sorted(name for name in models.__dict__ 173 | # if name.islower() and not name.startswith("__") 174 | # and callable(models.__dict__[name])) 175 | 176 | # parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 177 | # parser.add_argument('data', metavar='DIR', 178 | # help='path to dataset') 179 | # # parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', 180 | # # choices=model_names, 181 | # # help='model architecture: ' + 182 | # # ' | '.join(model_names) + 183 | # # ' (default: resnet18)') 184 | # parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', 185 | # help='model architecture') 186 | # parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 187 | # help='number of data loading workers (default: 4)') 188 | # parser.add_argument('--epochs', default=90, type=int, metavar='N', 189 | # help='number of total epochs to run') 190 | # parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 191 | # help='manual epoch number (useful on restarts)') 192 | # parser.add_argument('-b', '--batch-size', default=256, type=int, 193 | # metavar='N', help='mini-batch size (default: 256)') 194 | # parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 195 | # metavar='LR', help='initial learning rate') 196 | # parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 197 | # help='momentum') 198 | # parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 199 | # metavar='W', help='weight decay (default: 1e-4)') 200 | # parser.add_argument('--print-freq', '-p', default=10, type=int, 201 | # metavar='N', help='print frequency (default: 10)') 202 | # parser.add_argument('--resume', default='', type=str, metavar='PATH', 203 | # help='path to latest checkpoint (default: none)') 204 | # parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 205 | # help='evaluate model on validation set') 206 | # 
parser.add_argument('--pretrained', dest='pretrained', action='store_true', 207 | # help='use pre-trained model') 208 | # parser.add_argument('--world-size', default=1, type=int, 209 | # help='number of distributed processes') 210 | # parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 211 | # help='url used to set up distributed training') 212 | # parser.add_argument('--dist-backend', default='gloo', type=str, 213 | # help='distributed backend') 214 | # parser.add_argument('--seed', default=None, type=int, 215 | # help='seed for initializing training. ') 216 | # parser.add_argument('--gpu', default=None, type=int, 217 | # help='GPU id to use.') 218 | 219 | # best_prec1 = 0 220 | 221 | 222 | # def main(): 223 | # global args, best_prec1 224 | # args = parser.parse_args() 225 | 226 | # if args.seed is not None: 227 | # random.seed(args.seed) 228 | # torch.manual_seed(args.seed) 229 | # cudnn.deterministic = True 230 | # warnings.warn('You have chosen to seed training. ' 231 | # 'This will turn on the CUDNN deterministic setting, ' 232 | # 'which can slow down your training considerably! ' 233 | # 'You may see unexpected behavior when restarting ' 234 | # 'from checkpoints.') 235 | 236 | # if args.gpu is not None: 237 | # warnings.warn('You have chosen a specific GPU. This will completely ' 238 | # 'disable data parallelism.') 239 | 240 | # args.distributed = args.world_size > 1 241 | 242 | # if args.distributed: 243 | # dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 244 | # world_size=args.world_size) 245 | 246 | # # create model 247 | # if args.pretrained: 248 | # print("=> using pre-trained model '{}'".format(args.arch)) 249 | # if args.arch in models.__dict__: 250 | # model = models.__dict__[args.arch](pretrained=True) 251 | # elif args.arch == "mnet2": 252 | # model = mnet2("/home/leizhou/CVPR2019/vid_cls/models/mobilenet_v2.pth.tar") 253 | # else: 254 | # print("=> creating model '{}'".format(args.arch)) 255 | # if args.arch in models.__dict__: 256 | # model = models.__dict__[args.arch]() 257 | # else: 258 | # model = eval(args.arch)() 259 | 260 | # if args.gpu is not None: 261 | # model = model.cuda(args.gpu) 262 | # elif args.distributed: 263 | # model.cuda() 264 | # model = torch.nn.parallel.DistributedDataParallel(model) 265 | # else: 266 | # if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 267 | # model.features = torch.nn.DataParallel(model.features) 268 | # model.cuda() 269 | # else: 270 | # model = torch.nn.DataParallel(model).cuda() 271 | 272 | # # define loss function (criterion) and optimizer 273 | # criterion = nn.CrossEntropyLoss().cuda(args.gpu) 274 | 275 | # optimizer = torch.optim.SGD(model.parameters(), args.lr, 276 | # momentum=args.momentum, 277 | # weight_decay=args.weight_decay) 278 | 279 | # # optionally resume from a checkpoint 280 | # if args.resume: 281 | # if os.path.isfile(args.resume): 282 | # print("=> loading checkpoint '{}'".format(args.resume)) 283 | # checkpoint = torch.load(args.resume) 284 | # args.start_epoch = checkpoint['epoch'] 285 | # best_prec1 = checkpoint['best_prec1'] 286 | # model.load_state_dict(checkpoint['state_dict']) 287 | # optimizer.load_state_dict(checkpoint['optimizer']) 288 | # print("=> loaded checkpoint '{}' (epoch {})" 289 | # .format(args.resume, checkpoint['epoch'])) 290 | # else: 291 | # print("=> no checkpoint found at '{}'".format(args.resume)) 292 | 293 | # cudnn.benchmark = True 294 | 295 | # # Data loading code 296 | # traindir = 
os.path.join(args.data, 'train') 297 | # valdir = os.path.join(args.data, 'val') 298 | # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 299 | # std=[0.229, 0.224, 0.225]) 300 | 301 | # train_dataset = datasets.ImageFolder( 302 | # traindir, 303 | # transforms.Compose([ 304 | # transforms.RandomResizedCrop(224), 305 | # transforms.RandomHorizontalFlip(), 306 | # transforms.ToTensor(), 307 | # normalize, 308 | # ])) 309 | 310 | # if args.distributed: 311 | # train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 312 | # else: 313 | # train_sampler = None 314 | 315 | # train_loader = torch.utils.data.DataLoader( 316 | # train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 317 | # num_workers=args.workers, pin_memory=True, sampler=train_sampler) 318 | 319 | # val_loader = torch.utils.data.DataLoader( 320 | # datasets.ImageFolder(valdir, transforms.Compose([ 321 | # transforms.Resize(256), 322 | # transforms.CenterCrop(224), 323 | # transforms.ToTensor(), 324 | # normalize, 325 | # ])), 326 | # batch_size=args.batch_size, shuffle=False, 327 | # num_workers=args.workers, pin_memory=True) 328 | 329 | # if args.evaluate: 330 | # validate(val_loader, model, criterion) 331 | # return 332 | 333 | # for epoch in range(args.start_epoch, args.epochs): 334 | # if args.distributed: 335 | # train_sampler.set_epoch(epoch) 336 | # adjust_learning_rate(optimizer, epoch) 337 | 338 | # # train for one epoch 339 | # train(train_loader, model, criterion, optimizer, epoch) 340 | 341 | # # evaluate on validation set 342 | # prec1 = validate(val_loader, model, criterion) 343 | 344 | # # remember best prec@1 and save checkpoint 345 | # is_best = prec1 > best_prec1 346 | # best_prec1 = max(prec1, best_prec1) 347 | # save_checkpoint({ 348 | # 'epoch': epoch + 1, 349 | # 'arch': args.arch, 350 | # 'state_dict': model.state_dict(), 351 | # 'best_prec1': best_prec1, 352 | # 'optimizer' : optimizer.state_dict(), 353 | # }, is_best) 354 | 355 | 356 | # def train(train_loader, model, criterion, optimizer, epoch): 357 | # batch_time = AverageMeter() 358 | # data_time = AverageMeter() 359 | # losses = AverageMeter() 360 | # top1 = AverageMeter() 361 | # top5 = AverageMeter() 362 | 363 | # # switch to train mode 364 | # model.train() 365 | 366 | # end = time.time() 367 | # for i, (input, target) in enumerate(train_loader): 368 | # # measure data loading time 369 | # data_time.update(time.time() - end) 370 | 371 | # if args.gpu is not None: 372 | # input = input.cuda(args.gpu, non_blocking=True) 373 | # target = target.cuda(args.gpu, non_blocking=True) 374 | 375 | # # compute output 376 | # output = model(input) 377 | # loss = criterion(output, target) 378 | 379 | # # measure accuracy and record loss 380 | # prec1, prec5 = accuracy(output, target, topk=(1, 5)) 381 | # losses.update(loss.item(), input.size(0)) 382 | # top1.update(prec1[0], input.size(0)) 383 | # top5.update(prec5[0], input.size(0)) 384 | 385 | # # compute gradient and do SGD step 386 | # optimizer.zero_grad() 387 | # loss.backward() 388 | # optimizer.step() 389 | 390 | # # measure elapsed time 391 | # batch_time.update(time.time() - end) 392 | # end = time.time() 393 | 394 | # if i % args.print_freq == 0: 395 | # print('Epoch: [{0}][{1}/{2}]\t' 396 | # 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 397 | # 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 398 | # 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 399 | # 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 400 | # 'Prec@5 {top5.val:.3f} 
({top5.avg:.3f})'.format( 401 | # epoch, i, len(train_loader), batch_time=batch_time, 402 | # data_time=data_time, loss=losses, top1=top1, top5=top5)) 403 | 404 | 405 | # def validate(val_loader, model, criterion): 406 | # batch_time = AverageMeter() 407 | # losses = AverageMeter() 408 | # top1 = AverageMeter() 409 | # top5 = AverageMeter() 410 | 411 | # # switch to evaluate mode 412 | # model.eval() 413 | 414 | # with torch.no_grad(): 415 | # end = time.time() 416 | # for i, (input, target) in enumerate(val_loader): 417 | # if args.gpu is not None: 418 | # input = input.cuda(args.gpu, non_blocking=True) 419 | # target = target.cuda(args.gpu, non_blocking=True) 420 | 421 | # # compute output 422 | # output = model(input) 423 | # loss = criterion(output, target) 424 | 425 | # # measure accuracy and record loss 426 | # prec1, prec5 = accuracy(output, target, topk=(1, 5)) 427 | # losses.update(loss.item(), input.size(0)) 428 | # top1.update(prec1[0], input.size(0)) 429 | # top5.update(prec5[0], input.size(0)) 430 | 431 | # # measure elapsed time 432 | # batch_time.update(time.time() - end) 433 | # end = time.time() 434 | 435 | # if i % args.print_freq == 0: 436 | # print('Test: [{0}/{1}]\t' 437 | # 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 438 | # 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 439 | # 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 440 | # 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 441 | # i, len(val_loader), batch_time=batch_time, loss=losses, 442 | # top1=top1, top5=top5)) 443 | 444 | # print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' 445 | # .format(top1=top1, top5=top5)) 446 | 447 | # return top1.avg 448 | 449 | 450 | # def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 451 | # torch.save(state, filename) 452 | # if is_best: 453 | # shutil.copyfile(filename, 'model_best.pth.tar') 454 | 455 | 456 | # class AverageMeter(object): 457 | # """Computes and stores the average and current value""" 458 | # def __init__(self): 459 | # self.reset() 460 | 461 | # def reset(self): 462 | # self.val = 0 463 | # self.avg = 0 464 | # self.sum = 0 465 | # self.count = 0 466 | 467 | # def update(self, val, n=1): 468 | # self.val = val 469 | # self.sum += val * n 470 | # self.count += n 471 | # self.avg = self.sum / self.count 472 | 473 | 474 | # def adjust_learning_rate(optimizer, epoch): 475 | # """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 476 | # lr = args.lr * (0.1 ** (epoch // 30)) 477 | # for param_group in optimizer.param_groups: 478 | # param_group['lr'] = lr 479 | 480 | 481 | # def accuracy(output, target, topk=(1,)): 482 | # """Computes the precision@k for the specified values of k""" 483 | # with torch.no_grad(): 484 | # maxk = max(topk) 485 | # batch_size = target.size(0) 486 | 487 | # _, pred = output.topk(maxk, 1, True, True) 488 | # pred = pred.t() 489 | # correct = pred.eq(target.view(1, -1).expand_as(pred)) 490 | 491 | # res = [] 492 | # for k in topk: 493 | # correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 494 | # res.append(correct_k.mul_(100.0 / batch_size)) 495 | # return res 496 | 497 | 498 | # if __name__ == '__main__': 499 | # main() -------------------------------------------------------------------------------- /scripts/imagenet_2d_res26.sh: -------------------------------------------------------------------------------- 1 | # CUDA_LAUNCH_BLOCKING=1 \ 2 | CUDA_VISIBLE_DEVICES=4,5,6,7 \ 3 | python main_imagenet.py \ 4 | imagenet \ 5 | placeholder \ 6 | placeholder \ 7 | --arch resnet26 \ 8 | 
--epochs 100 \ 9 | --batch-size 512 \ 10 | --lr 0.1 \ 11 | --lr_steps 30 50 70 90 \ 12 | --workers 20 \ 13 | --weight-decay 0.0001 \ 14 | --eval-freq 1 \ 15 | -------------------------------------------------------------------------------- /scripts/kinetics400_3d_res50_slowonly_im_pre.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ 2 | python main.py \ 3 | kinetics400 \ 4 | data/kinetics400/kinetics_train_list_xlw \ 5 | data/kinetics400/kinetics_val_list_xlw \ 6 | --arch resnet50_3d_slowonly \ 7 | --dro 0.5 \ 8 | --mode 3D \ 9 | --t_length 8 \ 10 | --t_stride 8 \ 11 | --pretrained \ 12 | --epochs 110 \ 13 | --batch-size 96 \ 14 | --lr 0.02 \ 15 | --wd 0.0001 \ 16 | --lr_steps 50 80 100 \ 17 | --workers 16 \ 18 | 19 | python ./test_kaiming.py \ 20 | kinetics400 \ 21 | data/kinetics400/kinetics_val_list_xlw \ 22 | output/kinetics400_resnet50_3d_slowonly_3D_length8_stride8_dropout0.5/model_best.pth \ 23 | --arch resnet50_3d_slowonly \ 24 | --mode TSN+3D \ 25 | --batch_size 1 \ 26 | --num_segments 10 \ 27 | --input_size 256 \ 28 | --t_length 8 \ 29 | --t_stride 8 \ 30 | --dropout 0.5 \ 31 | --workers 12 \ 32 | --image_tmpl image_{:06d}.jpg \ -------------------------------------------------------------------------------- /test_10crop.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import os 4 | import numpy as np 5 | import torch.nn.parallel 6 | import torch.optim 7 | # from sklearn.metrics import confusion_matrix 8 | 9 | from lib.dataset import VideoDataSet 10 | from lib.models import VideoModule, TSN 11 | from lib.transforms import * 12 | from lib.utils.tools import AverageMeter, accuracy 13 | 14 | import pdb 15 | 16 | # options 17 | parser = argparse.ArgumentParser( 18 | description="Standard video-level testing") 19 | parser.add_argument('dataset', type=str, choices=['ucf101', 'hmdb51', 'kinetics400', 'kinetics200']) 20 | parser.add_argument('test_list', type=str) 21 | parser.add_argument('weights', type=str) 22 | parser.add_argument('--arch', type=str, default="resnet50_3d_v1") 23 | parser.add_argument('--mode', type=str, default="TSN+3D") 24 | parser.add_argument('--save_scores', type=str, default=None) 25 | parser.add_argument('--batch_size', type=int, default=2) 26 | parser.add_argument('--num_segments', type=int, default=10) 27 | parser.add_argument('--input_size', type=int, default=224) 28 | parser.add_argument('--resize', type=int, default=256) 29 | parser.add_argument('--t_length', type=int, default=16) 30 | parser.add_argument('--t_stride', type=int, default=4) 31 | parser.add_argument('--crop_fusion_type', type=str, default='avg', 32 | choices=['avg', 'max', 'topk']) 33 | parser.add_argument('--image_tmpl', type=str) 34 | parser.add_argument('--dropout', type=float, default=0.2) 35 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 36 | help='number of data loading workers (default: 4)') 37 | 38 | args = parser.parse_args() 39 | 40 | def main(): 41 | if args.dataset == 'ucf101': 42 | num_class = 101 43 | elif args.dataset == 'hmdb51': 44 | num_class = 51 45 | elif args.dataset == 'kinetics400': 46 | num_class = 400 47 | elif args.dataset == 'kinetics200': 48 | num_class = 200 49 | else: 50 | raise ValueError('Unknown dataset '+args.dataset) 51 | 52 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), 53 | "data/{}/access".format(args.dataset)) 54 | 55 | net = 
VideoModule(num_class=num_class, 56 | base_model_name=args.arch, 57 | dropout=args.dropout, 58 | pretrained=False) 59 | 60 | # compute params number of a model 61 | num_params = 0 62 | for param in net.parameters(): 63 | num_params += param.reshape((-1, 1)).shape[0] 64 | print("Model Size is {:.3f}M".format(num_params / 1000000)) 65 | 66 | net = torch.nn.DataParallel(net).cuda() 67 | net.eval() 68 | 69 | # load weights 70 | model_state = torch.load(args.weights) 71 | state_dict = model_state['state_dict'] 72 | test_epoch = model_state['epoch'] 73 | arch = model_state['arch'] 74 | assert arch == args.arch 75 | net.load_state_dict(state_dict) 76 | tsn = TSN(args.batch_size, net, 77 | args.num_segments, args.t_length, 78 | crop_fusion_type=args.crop_fusion_type, 79 | mode=args.mode).cuda() 80 | 81 | ## test data 82 | test_transform = torchvision.transforms.Compose([ 83 | GroupOverSample(args.input_size, args.resize), 84 | Stack(mode=args.mode), 85 | ToTorchFormatTensor(), 86 | GroupNormalize(), 87 | ]) 88 | test_dataset = VideoDataSet( 89 | root_path=data_root, 90 | list_file=args.test_list, 91 | t_length=args.t_length, 92 | t_stride=args.t_stride, 93 | num_segments=args.num_segments, 94 | image_tmpl=args.image_tmpl, 95 | transform=test_transform, 96 | phase="Test") 97 | test_loader = torch.utils.data.DataLoader( 98 | test_dataset, 99 | batch_size=args.batch_size, shuffle=False, 100 | num_workers=args.workers, pin_memory=True) 101 | 102 | # Test 103 | batch_timer = AverageMeter() 104 | top1 = AverageMeter() 105 | top5 = AverageMeter() 106 | results = None 107 | 108 | # set eval mode 109 | tsn.eval() 110 | 111 | end = time.time() 112 | for ind, (data, label) in enumerate(test_loader): 113 | label = label.cuda(non_blocking=True) 114 | 115 | with torch.no_grad(): 116 | output, pred, _, _ = tsn(data) 117 | prec1, prec5 = accuracy(pred, label, topk=(1, 5)) 118 | top1.update(prec1.item(), data.shape[0]) 119 | top5.update(prec5.item(), data.shape[0]) 120 | 121 | # pdb.set_trace() 122 | batch_timer.update(time.time() - end) 123 | end = time.time() 124 | if results is not None: 125 | # accumulate scores across batches (assign the result, otherwise only the first batch is kept) 126 | results = np.concatenate((results, output.cpu().numpy()), axis=0) 127 | else: 128 | results = output.cpu().numpy() 129 | print("{0}/{1} done, Batch: {batch_timer.val:.3f}({batch_timer.avg:.3f}), \ 130 | Top1: {top1.val:>6.3f}({top1.avg:>6.3f}), \ 131 | Top5: {top5.val:>6.3f}({top5.avg:>6.3f})".
132 | format(ind + 1, len(test_loader), 133 | batch_timer=batch_timer, 134 | top1=top1, top5=top5)) 135 | target_file = os.path.join(args.save_scores, "arch_{0}-epoch_{1}-top1_{2}-top5_{3}.npz".format(arch, test_epoch, top1.avg, top5.avg)) 136 | print("saving {}".format(target_file)) 137 | np.savez(target_file, results) 138 | if __name__ == "__main__": 139 | main() 140 | -------------------------------------------------------------------------------- /test_kaiming.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import os 4 | import numpy as np 5 | import torch.nn.parallel 6 | import torch.optim 7 | # from sklearn.metrics import confusion_matrix 8 | 9 | from lib.dataset import VideoDataSet, ShortVideoDataSet 10 | from lib.models import VideoModule, TSN 11 | from lib.transforms import * 12 | from lib.utils.tools import AverageMeter, accuracy 13 | 14 | import pdb 15 | import logging 16 | 17 | def set_logger(debug_mode=False): 18 | import time 19 | from time import gmtime, strftime 20 | logdir = os.path.join(args.experiment_root, 'log') 21 | if not os.path.exists(logdir): 22 | os.makedirs(logdir) 23 | log_file = "logfile_" + time.strftime("%d_%b_%Y_%H:%M:%S", time.localtime()) 24 | log_file = os.path.join(logdir, log_file) 25 | handlers = [logging.FileHandler(log_file), logging.StreamHandler()] 26 | 27 | """ add '%(filename)s:%(lineno)d %(levelname)s:' to format show source file """ 28 | logging.basicConfig(level=logging.DEBUG if debug_mode else logging.INFO, 29 | format='%(asctime)s: %(message)s', 30 | datefmt='%Y-%m-%d %H:%M:%S', 31 | handlers = handlers) 32 | 33 | # options 34 | parser = argparse.ArgumentParser( 35 | description="Standard video-level testing") 36 | parser.add_argument('dataset', type=str) 37 | parser.add_argument('test_list', type=str) 38 | parser.add_argument('weights', type=str) 39 | parser.add_argument('--arch', type=str, default="resnet50_3d_v1") 40 | parser.add_argument('--mode', type=str, default="TSN+3D") 41 | # parser.add_argument('--save_scores', type=str, default=None) 42 | parser.add_argument('--batch_size', type=int, default=2) 43 | parser.add_argument('--num_segments', type=int, default=10) 44 | parser.add_argument('--input_size', type=int, default=224) 45 | parser.add_argument('--resize', type=int, default=256) 46 | parser.add_argument('--t_length', type=int, default=16) 47 | parser.add_argument('--t_stride', type=int, default=4) 48 | # parser.add_argument('--crop_fusion_type', type=str, default='avg', 49 | # choices=['avg', 'max', 'topk']) 50 | parser.add_argument('--image_tmpl', type=str) 51 | parser.add_argument('--dropout', type=float, default=0.2) 52 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', 53 | help='number of data loading workers (default: 32)') 54 | 55 | args = parser.parse_args() 56 | 57 | experiment_id = '_'.join(map(str, ['test', args.dataset, args.arch, args.mode, 58 | 'length'+str(args.t_length), 'stride'+str(args.t_stride), 59 | 'seg'+str(args.num_segments)])) 60 | 61 | args.experiment_root = os.path.join('./output', experiment_id) 62 | 63 | if not os.path.exists(args.experiment_root): 64 | os.makedirs(args.experiment_root) 65 | set_logger() 66 | logging.info(args) 67 | 68 | def main(): 69 | if args.dataset == 'ucf101': 70 | num_class = 101 71 | elif args.dataset == 'hmdb51': 72 | num_class = 51 73 | elif args.dataset == 'kinetics400': 74 | num_class = 400 75 | elif args.dataset == 'kinetics200': 76 | num_class = 200 77 | elif
args.dataset == 'sthsth_v1': 78 | num_class = 174 79 | else: 80 | raise ValueError('Unknown dataset '+args.dataset) 81 | 82 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), 83 | "data/{}/access".format(args.dataset)) 84 | 85 | net = VideoModule(num_class=num_class, 86 | base_model_name=args.arch, 87 | dropout=args.dropout, 88 | pretrained=False) 89 | 90 | # compute params number of a model 91 | num_params = 0 92 | for param in net.parameters(): 93 | num_params += param.reshape((-1, 1)).shape[0] 94 | logging.info("Model Size is {:.3f}M".format(num_params / 1000000)) 95 | 96 | net = torch.nn.DataParallel(net).cuda() 97 | net.eval() 98 | 99 | # load weights 100 | model_state = torch.load(args.weights) 101 | state_dict = model_state['state_dict'] 102 | test_epoch = model_state['epoch'] 103 | best_metric = model_state['best_metric'] 104 | arch = model_state['arch'] 105 | logging.info("Model Epoch: {}; Best_Top1: {}".format(test_epoch, best_metric)) 106 | assert arch == args.arch 107 | net.load_state_dict(state_dict) 108 | tsn = TSN(args.batch_size, net, 109 | args.num_segments, args.t_length, 110 | mode=args.mode).cuda() 111 | 112 | ## test data 113 | test_transform = torchvision.transforms.Compose([ 114 | GroupScale(256), 115 | GroupOverSampleKaiming(args.input_size), 116 | Stack(mode=args.mode), 117 | ToTorchFormatTensor(), 118 | GroupNormalize(), 119 | ]) 120 | test_dataset = VideoDataSet( 121 | root_path=data_root, 122 | list_file=args.test_list, 123 | t_length=args.t_length, 124 | t_stride=args.t_stride, 125 | num_segments=args.num_segments, 126 | image_tmpl=args.image_tmpl, 127 | transform=test_transform, 128 | phase="Test") 129 | test_loader = torch.utils.data.DataLoader( 130 | test_dataset, 131 | batch_size=args.batch_size, shuffle=False, 132 | num_workers=args.workers, pin_memory=True) 133 | 134 | # Test 135 | batch_timer = AverageMeter() 136 | top1_m = AverageMeter() 137 | top5_m = AverageMeter() 138 | top1_a = AverageMeter() 139 | top5_a = AverageMeter() 140 | results_m = None 141 | results_a = None 142 | 143 | # set eval mode 144 | tsn.eval() 145 | 146 | end = time.time() 147 | for ind, (data, label) in enumerate(test_loader): 148 | label = label.cuda(non_blocking=True) 149 | 150 | with torch.no_grad(): 151 | output_m, pred_m, output_a, pred_a = tsn(data) 152 | prec1_m, prec5_m = accuracy(pred_m, label, topk=(1, 5)) 153 | prec1_a, prec5_a = accuracy(pred_a, label, topk=(1, 5)) 154 | top1_m.update(prec1_m.item(), data.shape[0]) 155 | top5_m.update(prec5_m.item(), data.shape[0]) 156 | top1_a.update(prec1_a.item(), data.shape[0]) 157 | top5_a.update(prec5_a.item(), data.shape[0]) 158 | 159 | # pdb.set_trace() 160 | batch_timer.update(time.time() - end) 161 | end = time.time() 162 | if results_m is not None: 163 | # accumulate max-fusion scores (assign the result, otherwise only the first batch is kept) 164 | results_m = np.concatenate((results_m, output_m.cpu().numpy()), axis=0) 165 | else: 166 | results_m = output_m.cpu().numpy() 167 | 168 | if results_a is not None: 169 | # accumulate avg-fusion scores (assign the result, otherwise only the first batch is kept) 170 | results_a = np.concatenate((results_a, output_a.cpu().numpy()), axis=0) 171 | else: 172 | results_a = output_a.cpu().numpy() 173 | logging.info("{0}/{1} done, Batch: {batch_timer.val:.3f}({batch_timer.avg:.3f}), maxTop1: {top1_m.val:>6.3f}({top1_m.avg:>6.3f}), maxTop5: {top5_m.val:>6.3f}({top5_m.avg:>6.3f}), avgTop1: {top1_a.val:>6.3f}({top1_a.avg:>6.3f}), avgTop5: {top5_a.val:>6.3f}({top5_a.avg:>6.3f})".
174 | format(ind + 1, len(test_loader), 175 | batch_timer=batch_timer, 176 | top1_m=top1_m, top5_m=top5_m, top1_a=top1_a, top5_a=top5_a)) 177 | max_target_file = os.path.join(args.experiment_root, "arch_{0}-epoch_{1}-top1_{2}-top5_{3}_max.npz".format(arch, test_epoch, top1_m.avg, top5_m.avg)) 178 | avg_target_file = os.path.join(args.experiment_root, "arch_{0}-epoch_{1}-top1_{2}-top5_{3}_avg.npz".format(arch, test_epoch, top1_a.avg, top5_a.avg)) 179 | logging.info("saving {}".format(max_target_file)) 180 | np.savez(max_target_file, results_m) 181 | logging.info("saving {}".format(avg_target_file)) 182 | np.savez(avg_target_file, results_a) 183 | if __name__ == "__main__": 184 | main() 185 | -------------------------------------------------------------------------------- /train_val.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn.utils import clip_grad_norm_ 8 | 9 | from lib.utils.tools import * 10 | 11 | def set_bn_eval(m): 12 | classname = m.__class__.__name__ 13 | if classname.find('BatchNorm') != -1: 14 | m.eval() 15 | 16 | def train(train_loader, model, criterion, optimizer, epoch, print_freq): 17 | batch_time = AverageMeter() 18 | data_time = AverageMeter() 19 | losses = AverageMeter() 20 | top1 = AverageMeter() 21 | top5 = AverageMeter() 22 | 23 | # switch to train mode 24 | model.train() 25 | 26 | end = time.time() 27 | for i, (input, target) in enumerate(train_loader): 28 | # measure data loading time 29 | data_time.update(time.time() - end) 30 | 31 | # input = input.cuda() 32 | target = target.cuda(non_blocking=True) 33 | 34 | # compute output 35 | output = model(input) 36 | loss = criterion(output, target) 37 | 38 | # measure accuracy and record loss 39 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 40 | losses.update(loss.item(), input.size(0)) 41 | top1.update(prec1.item(), input.size(0)) 42 | top5.update(prec5.item(), input.size(0)) 43 | 44 | # compute gradient and do SGD step 45 | optimizer.zero_grad() 46 | loss.backward() 47 | 48 | # clip gradients 49 | # total_norm = clip_grad_norm_(model.parameters(), 20) 50 | # if total_norm > 20: 51 | # print("clipping gradient: {} with coef {}".format(total_norm, 20 / total_norm)) 52 | 53 | optimizer.step() 54 | 55 | # measure elapsed time 56 | batch_time.update(time.time() - end) 57 | end = time.time() 58 | 59 | if i % print_freq == 0: 60 | logging.info(('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t' 61 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 62 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 63 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t' 64 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 65 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format( 66 | epoch, i, len(train_loader), batch_time=batch_time, 67 | data_time=data_time, loss=losses, top1=top1, 68 | top5=top5, lr=optimizer.param_groups[-1]['lr']))) 69 | 70 | def finetune_fc(train_loader, model, criterion, optimizer, epoch, print_freq): 71 | batch_time = AverageMeter() 72 | data_time = AverageMeter() 73 | losses = AverageMeter() 74 | top1 = AverageMeter() 75 | top5 = AverageMeter() 76 | 77 | model.train() 78 | 79 | # keep BatchNorm and Dropout layers in eval mode while finetuning the classifier 80 | for m in model.modules(): 81 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 82 | m.eval() 83 | if isinstance(m, nn.Dropout): 84 | m.eval() 85 | 86 | # block gradients to base model 87 | for param in model.named_parameters(): 88 | if "base_model" in param[0]: 89 |
param[1].requires_grad = False 90 | 91 | end = time.time() 92 | for i, (input, target) in enumerate(train_loader): 93 | # measure data loading time 94 | data_time.update(time.time() - end) 95 | 96 | # input = input.cuda() 97 | target = target.cuda(non_blocking=True) 98 | 99 | # compute output 100 | output = model(input) 101 | loss = criterion(output, target) 102 | 103 | # measure accuracy and record loss 104 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 105 | losses.update(loss.item(), input.size(0)) 106 | top1.update(prec1.item(), input.size(0)) 107 | top5.update(prec5.item(), input.size(0)) 108 | 109 | # compute gradient and do SGD step 110 | optimizer.zero_grad() 111 | loss.backward() 112 | optimizer.step() 113 | 114 | # measure elapsed time 115 | batch_time.update(time.time() - end) 116 | end = time.time() 117 | 118 | if i % print_freq == 0: 119 | logging.info(('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t' 120 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 121 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 122 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t' 123 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 124 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format( 125 | epoch, i, len(train_loader), batch_time=batch_time, 126 | data_time=data_time, loss=losses, top1=top1, 127 | top5=top5, lr=optimizer.param_groups[-1]['lr']))) 128 | 129 | def finetune_bn_frozen(train_loader, model, criterion, optimizer, epoch, print_freq): 130 | batch_time = AverageMeter() 131 | data_time = AverageMeter() 132 | losses = AverageMeter() 133 | top1 = AverageMeter() 134 | top5 = AverageMeter() 135 | 136 | model.train() 137 | 138 | # model.apply(set_bn_eval) 139 | 140 | # switch mode 141 | for n, m in model.named_modules(): 142 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 143 | # m.eval() 144 | if "base_model.bn1" in n: 145 | print(n) 146 | pass 147 | else: 148 | for p in m.parameters(): 149 | p.requires_grad = False 150 | m.eval() 151 | 152 | # for n, m in model.named_modules(): 153 | # if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 154 | # m.eval() 155 | # if isinstance(m, nn.Dropout): 156 | # m.eval() 157 | 158 | # block gradients to base model 159 | # for param in model.named_parameters(): 160 | # if "bn" in param[0]: 161 | # print(param[1].requires_grad) 162 | # if "base_model" in param[0]: 163 | # param[1].requires_grad = False 164 | 165 | end = time.time() 166 | for i, (input, target) in enumerate(train_loader): 167 | # print(model.module.base_model.bn1.weight.view(-1)[:3]) 168 | # print(model.module.base_model.bn1.running_mean.view(-1)[:3]) 169 | # import pdb 170 | # pdb.set_trace() 171 | # print("conv1", model.state_dict()['module.base_model.conv1.weight'].view(-1)[0:3]) 172 | # print("fc", model.state_dict()['module.classifier.1.weight'].view(-1)[0:3]) 173 | # print(model.state_dict().view(-1)[0:3]) 174 | # measure data loading time 175 | data_time.update(time.time() - end) 176 | 177 | # input = input.cuda() 178 | target = target.cuda(non_blocking=True) 179 | 180 | # compute output 181 | output = model(input) 182 | loss = criterion(output, target) 183 | 184 | # measure accuracy and record loss 185 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 186 | losses.update(loss.item(), input.size(0)) 187 | top1.update(prec1.item(), input.size(0)) 188 | top5.update(prec5.item(), input.size(0)) 189 | 190 | # compute gradient and do SGD step 191 | optimizer.zero_grad() 192 | loss.backward() 193 | # for param in model.parameters(): 194 | # 
param.grad.data.clamp_(-1, 1) 195 | total_norm = clip_grad_norm_(model.parameters(), 20) 196 | if total_norm > 20: 197 | print("clipping gradient: {} with coef {}".format(total_norm, 20 / total_norm)) 198 | 199 | optimizer.step() 200 | 201 | # measure elapsed time 202 | batch_time.update(time.time() - end) 203 | end = time.time() 204 | 205 | if i % print_freq == 0: 206 | logging.info(('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t' 207 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 208 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 209 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t' 210 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 211 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format( 212 | epoch, i, len(train_loader), batch_time=batch_time, 213 | data_time=data_time, loss=losses, top1=top1, 214 | top5=top5, lr=optimizer.param_groups[-1]['lr']))) 215 | 216 | def validate(val_loader, model, criterion, print_freq, epoch, logger=None): 217 | batch_time = AverageMeter() 218 | losses = AverageMeter() 219 | top1 = AverageMeter() 220 | top5 = AverageMeter() 221 | 222 | # switch to evaluate mode 223 | model.eval() 224 | 225 | with torch.no_grad(): 226 | end = time.time() 227 | for i, (input, target) in enumerate(val_loader): 228 | target = target.cuda(non_blocking=True) 229 | 230 | # compute output 231 | output = model(input) 232 | loss = criterion(output, target) 233 | 234 | # measure accuracy and record loss 235 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 236 | losses.update(loss.item(), input.size(0)) 237 | top1.update(prec1.item(), input.size(0)) 238 | top5.update(prec5.item(), input.size(0)) 239 | 240 | # measure elapsed time 241 | batch_time.update(time.time() - end) 242 | end = time.time() 243 | 244 | if i % print_freq == 0: 245 | logging.info(('Test: [{0}/{1}]\t' 246 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 247 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t' 248 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 249 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 250 | i, len(val_loader), batch_time=batch_time, loss=losses, 251 | top1=top1, top5=top5))) 252 | 253 | logging.info(('Epoch {epoch} Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}' 254 | .format(epoch=epoch, top1=top1, top5=top5, loss=losses))) 255 | 256 | # return (top1.avg + top5.avg) / 2 257 | return top1.avg --------------------------------------------------------------------------------
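For reference, the following is a minimal sketch, not part of the repository, of how the `train` and `validate` helpers in `train_val.py` are typically driven by an epoch loop (in the spirit of `main.py` and the scripts above). The toy model and random-tensor loaders stand in for `VideoModule` and `VideoDataSet`, and all hyperparameters are illustrative placeholders; it assumes the repo root is on the Python path and a CUDA GPU is available.
```python
# Illustrative driver sketch (NOT part of this repository).
# Exercises train()/validate() from train_val.py with a toy model and random data;
# VideoModule/VideoDataSet and every hyperparameter below are stand-ins.
import logging

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from train_val import train, validate


def run():
    logging.basicConfig(level=logging.INFO)

    # Toy 10-class classifier standing in for VideoModule; DataParallel moves the
    # CPU batches onto the GPU, matching how train()/validate() feed inputs.
    model = torch.nn.DataParallel(
        nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 10))).cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01,
                                momentum=0.9, weight_decay=1e-4)

    # Random tensors standing in for the real video loaders.
    inputs, targets = torch.randn(256, 8), torch.randint(0, 10, (256,))
    train_loader = DataLoader(TensorDataset(inputs, targets), batch_size=32, shuffle=True)
    val_loader = DataLoader(TensorDataset(inputs, targets), batch_size=32)

    best_prec1 = 0.0
    for epoch in range(3):
        train(train_loader, model, criterion, optimizer, epoch, print_freq=1)
        prec1 = validate(val_loader, model, criterion, print_freq=1, epoch=epoch)
        best_prec1 = max(prec1, best_prec1)  # used to decide which checkpoint to keep
    logging.info("best Prec@1: {:.3f}".format(best_prec1))


if __name__ == "__main__":
    run()
```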