├── .gitignore
├── .gitmodules
├── README.md
├── clip_test.py
├── compute_flops.py
├── data
│   ├── hmdb51
│   │   ├── hmdb51_train_split1_list.txt
│   │   ├── hmdb51_train_split2_list.txt
│   │   ├── hmdb51_train_split3_list.txt
│   │   ├── hmdb51_val_split1_list.txt
│   │   ├── hmdb51_val_split2_list.txt
│   │   └── hmdb51_val_split3_list.txt
│   ├── kinetics200
│   │   ├── 400_200_label_mapping.txt
│   │   ├── create_kinetics200_list.py
│   │   ├── kinetics200_train_list.txt
│   │   ├── kinetics200_train_list_org.txt
│   │   ├── kinetics200_val_list.txt
│   │   ├── kinetics200_val_list_org.txt
│   │   ├── kinetics_train_list.txt
│   │   └── kinetics_val_list.txt
│   ├── kinetics400
│   │   ├── count.py
│   │   ├── create_xlw_list.py
│   │   ├── kinetics_train_list.txt
│   │   ├── kinetics_train_list_xlw
│   │   ├── kinetics_val_list.txt
│   │   └── kinetics_val_list_xlw
│   ├── sthsth_v1
│   │   ├── create_sthsth_v1_list.py
│   │   ├── something-something-v1-labels.csv
│   │   ├── something-something-v1-test.csv
│   │   ├── something-something-v1-train.csv
│   │   ├── something-something-v1-validation.csv
│   │   ├── sthv1_train_list.txt
│   │   └── sthv1_val_list.txt
│   └── ucf101
│       ├── ucf101_train_split1_list.txt
│       ├── ucf101_train_split2_list.txt
│       ├── ucf101_train_split3_list.txt
│       ├── ucf101_val_split1_list.txt
│       ├── ucf101_val_split2_list.txt
│       └── ucf101_val_split3_list.txt
├── finetune_bn_frozen.py
├── finetune_fc.py
├── lib
│   ├── dataset.py
│   ├── models.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── pooling.py
│   │   └── scale.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── mnet2.py
│   │   ├── mnet2_3d.py
│   │   ├── part_inflate_resnet_3d.py
│   │   ├── resnet.py
│   │   ├── resnet_3d.py
│   │   └── resnet_3d_nodown.py
│   ├── opts.py
│   ├── transforms.py
│   └── utils
│       ├── deprefix.py
│       ├── tools.py
│       ├── vis_comb.py
│       └── visualization.py
├── main.py
├── main_20bn.py
├── main_imagenet.py
├── scripts
│   ├── imagenet_2d_res26.sh
│   └── kinetics400_3d_res50_slowonly_im_pre.sh
├── test_10crop.py
├── test_kaiming.py
└── train_val.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | access/
3 | output/
4 | models/
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | env/
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 |
58 | # Flask stuff:
59 | instance/
60 | .webassets-cache
61 |
62 | # Scrapy stuff:
63 | .scrapy
64 |
65 | # Sphinx documentation
66 | docs/_build/
67 |
68 | # PyBuilder
69 | target/
70 |
71 | # Jupyter Notebook
72 | .ipynb_checkpoints
73 |
74 | # pyenv
75 | .python-version
76 |
77 | # celery beat schedule file
78 | celerybeat-schedule
79 |
80 | # SageMath parsed files
81 | *.sage.py
82 |
83 | # dotenv
84 | .env
85 |
86 | # virtualenv
87 | .venv
88 | venv/
89 | ENV/
90 |
91 | # Spyder project settings
92 | .spyderproject
93 | .spyproject
94 |
95 | # Rope project settings
96 | .ropeproject
97 |
98 | # mkdocs documentation
99 | /site
100 |
101 | # mypy
102 | .mypy_cache/
103 |
104 | .idea
105 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "data/kinetics200/Mini-Kinetics-200"]
2 | path = data/kinetics200/Mini-Kinetics-200
3 | url = https://github.com/BannyStone/Mini-Kinetics-200
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Video-Classification-Pytorch
2 |
3 | ***This is an archived repo. We strongly recommend PySlowFast or mmaction for video understanding***.
4 |
5 | This repository contains 3D and 2D models for video classification. The code is based on PyTorch 1.0.
6 | It currently supports the following datasets:
7 | Kinetics-400, Mini-Kinetics-200, UCF101, and HMDB51.
8 |
9 | ## Results
10 |
11 | ### Kinetics-400
12 |
13 | We report the baselines with a ResNet-50 backbone on the Kinetics-400 validation set below (all models are trained on the training set).
14 | All models are trained on a single server with 8 GTX 1080 Ti GPUs.
15 |
16 | | network | pretrain data | spatial resolution | input frames | sampling stride | backbone | top1 | top5 |
17 | | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ |
18 | | ResNet50-SlowOnly | ImageNet-1K | 224x224 | 8 | 8 | ResNet50 | 73.77 | 91.17 |
19 |
20 |
21 | ## Get the Code
22 | ```Shell
23 | git clone --recursive https://github.com/BannyStone/Video_Classification_PyTorch.git
24 | ```
25 |
26 | ## Preparing Dataset
27 | ### Kinetics-400
28 | ```Shell
29 | cd data/kinetics400
30 | mkdir access && cd access
31 | ln -s $YOUR_KINETICS400_DATASET_TRAIN_DIR$ RGB_train
32 | ln -s $YOUR_KINETICS400_DATASET_VAL_DIR$ RGB_val
33 | ```
34 | Note that:
35 | - The reported models are trained with the Kinetics data provided by Xiaolong Wang: https://github.com/facebookresearch/video-nonlocal-net/blob/master/DATASET.md
36 | - In the train and validation lists for all datasets, each line represents one video: the first element is the video frame directory, the second is the number of frames, and the third is the class index. Please prepare your own lists accordingly, because different video parsing methods may lead to different frame numbers (a minimal list-generation sketch is shown after these notes). We show part of the Kinetics-400 train list as an example:
37 | ```shell
38 | RGB_train/D32_1gwq35E 300 66
39 | RGB_train/-G-5CJ0JkKY 250 254
40 | RGB_train/4uZ27ivBl00 300 341
41 | RGB_train/pZP-dHUuGiA 240 369
42 | ```
43 | - This code reads the image files in each video frame folder according to the image template argument *image_tmpl*, e.g. *image_{:06d}.jpg*.
44 |
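Below is a minimal sketch for generating such a list, assuming frames are already extracted into one folder per video under `data/<dataset>/access` and that you have your own mapping from video id to class index; all names and paths in it are placeholders, not part of this repo:

```python
import os

# Hypothetical inputs -- adjust to your own frame layout and label source.
frame_root = "data/kinetics400/access"       # root holding the symlinked frame directories
subdir = "RGB_train"                         # RGB_train or RGB_val
video_labels = {"D32_1gwq35E": 66}           # video id -> class index, from your own annotations

with open("my_train_list.txt", "w") as f_out:
    for vid, label in video_labels.items():
        vid_dir = os.path.join(frame_root, subdir, vid)
        if not os.path.isdir(vid_dir):
            continue  # skip videos whose frames were not extracted
        # Count the extracted frames, as data/kinetics400/create_xlw_list.py does.
        num_frames = len(os.listdir(vid_dir))
        f_out.write("{} {} {}\n".format(subdir + "/" + vid, num_frames, label))
```

The dataloader then builds each frame path as `image_tmpl.format(frame_index)` with 1-based indices (see `lib/dataset.py`), so with the template above `image_000001.jpg` should be the first frame in every video folder.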
45 | ## Training
46 | Execute the training script:
47 | ```Shell
48 | ./scripts/kinetics400_3d_res50_slowonly_im_pre.sh
49 | ```
50 |
51 | We show the script *kinetics400_3d_res50_slowonly_im_pre.sh* here:
52 | ```Shell
53 | python main.py \
54 | kinetics400 \
55 | data/kinetics400/kinetics_train_list_xlw \
56 | data/kinetics400/kinetics_val_list_xlw \
57 | --arch resnet50_3d_slowonly \
58 | --dro 0.5 \
59 | --mode 3D \
60 | --t_length 8 \
61 | --t_stride 8 \
62 | --pretrained \
63 | --epochs 110 \
64 | --batch-size 96 \
65 | --lr 0.02 \
66 | --wd 0.0001 \
67 | --lr_steps 50 80 100 \
68 | --workers 16 \
69 | ```
70 |
71 | ## Testing
72 | ```Shell
73 | python ./test_kaiming.py \
74 | kinetics400 \
75 | data/kinetics400/kinetics_val_list_xlw \
76 | output/kinetics400_resnet50_3d_slowonly_3D_length8_stride8_dropout0.5/model_best.pth \
77 | --arch resnet50_3d_slowonly \
78 | --mode TSN+3D \
79 | --batch_size 1 \
80 | --num_segments 10 \
81 | --input_size 256 \
82 | --t_length 8 \
83 | --t_stride 8 \
84 | --dropout 0.5 \
85 | --workers 12 \
86 | --image_tmpl image_{:06d}.jpg \
87 |
88 | ```
89 |
--------------------------------------------------------------------------------
/clip_test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import shutil
5 | import logging
6 |
7 | import torch
8 | import torchvision
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.optim
12 |
13 | from lib.dataset import VideoDataSet
14 | from lib.models import VideoModule
15 | from lib.transforms import *
16 | from lib.utils.tools import *
17 | from lib.opts import args
18 |
19 | from train_val import train, validate
20 |
21 | def main():
22 | global args, best_metric
23 |
24 | # specify dataset
25 | if args.dataset == 'ucf101':
26 | num_class = 101
27 | elif args.dataset == 'hmdb51':
28 | num_class = 51
29 | elif args.dataset == 'kinetics400':
30 | num_class = 400
31 | elif args.dataset == 'kinetics200':
32 | num_class = 200
33 | else:
34 | raise ValueError('Unknown dataset '+args.dataset)
35 |
36 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)),
37 | "data/{}/access".format(args.dataset))
38 |
39 | # create model
40 | org_model = VideoModule(num_class=num_class,
41 | base_model_name=args.arch,
42 | dropout=args.dropout,
43 | pretrained=args.pretrained,
44 | pretrained_model=args.pretrained_model)
45 | num_params = 0
46 | for param in org_model.parameters():
47 | num_params += param.reshape((-1, 1)).shape[0]
48 | print("Model Size is {:.3f}M".format(num_params/1000000))
49 |
50 | model = torch.nn.DataParallel(org_model).cuda()
51 |
52 | criterion = torch.nn.CrossEntropyLoss().cuda()
53 |
54 | optimizer = torch.optim.SGD(model.parameters(),
55 | args.lr,
56 | momentum=args.momentum,
57 | weight_decay=args.weight_decay)
58 |
59 | # optionally resume from a checkpoint
60 | if args.resume:
61 | if os.path.isfile(args.resume):
62 | print(("=> loading checkpoint '{}'".format(args.resume)))
63 | checkpoint = torch.load(args.resume)
64 | args.start_epoch = checkpoint['epoch']
65 | best_metric = checkpoint['best_metric']
66 | model.load_state_dict(checkpoint['state_dict'])
67 | optimizer.load_state_dict(checkpoint['optimizer'])
68 | print(("=> loaded checkpoint '{}' (epoch {})"
69 | .format(args.resume, checkpoint['epoch'])))
70 | else:
71 | print(("=> no checkpoint found at '{}'".format(args.resume)))
72 |
73 | ## val data
74 | val_transform = torchvision.transforms.Compose([
75 | GroupScale(args.new_size),
76 | GroupCenterCrop(args.crop_size),
77 | Stack(mode=args.mode),
78 | ToTorchFormatTensor(),
79 | GroupNormalize(),
80 | ])
81 | val_dataset = VideoDataSet(root_path=data_root,
82 | list_file=args.val_list,
83 | t_length=args.t_length,
84 | t_stride=args.t_stride,
85 | num_segments=args.num_segments,
86 | image_tmpl=args.image_tmpl,
87 | transform=val_transform,
88 | phase="Val")
89 | val_loader = torch.utils.data.DataLoader(
90 | val_dataset,
91 | batch_size=args.batch_size, shuffle=False,
92 | num_workers=args.workers, pin_memory=True)
93 |
94 | if args.mode != "3D":
95 | cudnn.benchmark = True
96 |
97 | validate(val_loader, model, criterion, args.print_freq, args.start_epoch)
98 |
99 |
100 | if __name__ == '__main__':
101 | main()
102 |
--------------------------------------------------------------------------------
/compute_flops.py:
--------------------------------------------------------------------------------
1 | from lib.networks.part_inflate_resnet_3d import *
2 | from lib.modules import *
3 | import torch
4 | from lib.networks.km_resnet_3d_beta import TKMConv, compute_tkmconv, km_resnet26_3d_v2_sample, km_resnet50_3d_v2_sample
5 | def count_GloAvgPool3d(m, x, y):
6 | m.total_ops = torch.Tensor([int(0)])
7 |
8 | from thop import profile
9 | model = km_resnet50_3d_v2_sample()
10 | model.fc = torch.nn.Linear(2048, 400)
11 | flops, params = profile(model, input_size=(1, 3, 8, 224,224), custom_ops={GloAvgPool3d: count_GloAvgPool3d, TKMConv: compute_tkmconv})
12 | print("params: {}".format(params/1000000))
13 | print("flops: {}".format(flops/1000000000))
14 |
--------------------------------------------------------------------------------
/data/kinetics200/400_200_label_mapping.txt:
--------------------------------------------------------------------------------
1 | 0 0
2 | 1 1
3 | 5 2
4 | 6 3
5 | 11 4
6 | 12 5
7 | 14 6
8 | 16 7
9 | 18 8
10 | 19 9
11 | 22 10
12 | 24 11
13 | 27 12
14 | 29 13
15 | 31 14
16 | 32 15
17 | 34 16
18 | 36 17
19 | 37 18
20 | 40 19
21 | 41 20
22 | 42 21
23 | 43 22
24 | 48 23
25 | 49 24
26 | 50 25
27 | 55 26
28 | 56 27
29 | 59 28
30 | 60 29
31 | 68 30
32 | 69 31
33 | 70 32
34 | 75 33
35 | 77 34
36 | 78 35
37 | 79 36
38 | 80 37
39 | 83 38
40 | 84 39
41 | 86 40
42 | 87 41
43 | 88 42
44 | 93 43
45 | 97 44
46 | 99 45
47 | 103 46
48 | 104 47
49 | 107 48
50 | 108 49
51 | 109 50
52 | 115 51
53 | 116 52
54 | 123 53
55 | 124 54
56 | 125 55
57 | 126 56
58 | 127 57
59 | 130 58
60 | 132 59
61 | 133 60
62 | 134 61
63 | 140 62
64 | 142 63
65 | 143 64
66 | 147 65
67 | 148 66
68 | 149 67
69 | 151 68
70 | 152 69
71 | 153 70
72 | 159 71
73 | 161 72
74 | 162 73
75 | 164 74
76 | 166 75
77 | 167 76
78 | 169 77
79 | 172 78
80 | 174 79
81 | 177 80
82 | 180 81
83 | 182 82
84 | 183 83
85 | 188 84
86 | 189 85
87 | 192 86
88 | 193 87
89 | 197 88
90 | 199 89
91 | 201 90
92 | 204 91
93 | 205 92
94 | 206 93
95 | 208 94
96 | 209 95
97 | 212 96
98 | 214 97
99 | 217 98
100 | 218 99
101 | 219 100
102 | 220 101
103 | 221 102
104 | 223 103
105 | 224 104
106 | 225 105
107 | 227 106
108 | 229 107
109 | 230 108
110 | 232 109
111 | 233 110
112 | 234 111
113 | 235 112
114 | 240 113
115 | 242 114
116 | 243 115
117 | 244 116
118 | 245 117
119 | 246 118
120 | 247 119
121 | 248 120
122 | 249 121
123 | 250 122
124 | 251 123
125 | 252 124
126 | 253 125
127 | 254 126
128 | 255 127
129 | 256 128
130 | 258 129
131 | 261 130
132 | 262 131
133 | 264 132
134 | 269 133
135 | 273 134
136 | 275 135
137 | 277 136
138 | 278 137
139 | 280 138
140 | 282 139
141 | 283 140
142 | 285 141
143 | 286 142
144 | 289 143
145 | 291 144
146 | 292 145
147 | 294 146
148 | 298 147
149 | 299 148
150 | 301 149
151 | 302 150
152 | 304 151
153 | 305 152
154 | 306 153
155 | 307 154
156 | 308 155
157 | 313 156
158 | 315 157
159 | 316 158
160 | 317 159
161 | 318 160
162 | 321 161
163 | 322 162
164 | 323 163
165 | 325 164
166 | 326 165
167 | 327 166
168 | 330 167
169 | 331 168
170 | 334 169
171 | 336 170
172 | 337 171
173 | 339 172
174 | 340 173
175 | 346 174
176 | 348 175
177 | 349 176
178 | 350 177
179 | 356 178
180 | 358 179
181 | 360 180
182 | 364 181
183 | 365 182
184 | 367 183
185 | 369 184
186 | 371 185
187 | 373 186
188 | 378 187
189 | 379 188
190 | 380 189
191 | 382 190
192 | 383 191
193 | 387 192
194 | 389 193
195 | 390 194
196 | 391 195
197 | 393 196
198 | 394 197
199 | 398 198
200 | 399 199
201 |
--------------------------------------------------------------------------------
/data/kinetics200/create_kinetics200_list.py:
--------------------------------------------------------------------------------
1 | import pdb
2 |
3 | # extract target 200-class videos from the original videos
4 | with open('kinetics_train_list.txt') as tr400:
5 | with open('Mini-Kinetics-200/train_ytid_list.txt') as miniTr:
6 | with open('kinetics200_train_list_org.txt', 'w') as tr200:
7 | # build indices for original 400-class train list
8 | lines = tr400.readlines()
9 | ytid_line_dict = dict()
10 | for line in lines:
11 | ytid = line.strip().split()[0].split('/')[1]
12 | ytid_line_dict[ytid] = line
13 | # extract target lines and write them into tr200 file
14 | lines = miniTr.readlines()
15 | for line in lines:
16 | ytid = line.strip()
17 | if ytid in ytid_line_dict:
18 | target_line = ytid_line_dict[ytid]
19 | tr200.write(target_line)
20 | else:
21 | print("{} is not in original video list".format(ytid))
22 |
23 | with open('kinetics_val_list.txt') as va400:
24 | with open('Mini-Kinetics-200/val_ytid_list.txt') as miniVa:
25 | with open('kinetics200_val_list_org.txt', 'w') as va200:
26 | # build indices for original 400-class val list
27 | lines = va400.readlines()
28 | ytid_line_dict = dict()
29 | for line in lines:
30 | ytid = line.strip().split()[0].split('/')[1]
31 | ytid_line_dict[ytid] = line
32 | # extract target lines and write them into va200 file
33 | lines = miniVa.readlines()
34 | for line in lines:
35 | ytid = line.strip()
36 | if ytid in ytid_line_dict:
37 | target_line = ytid_line_dict[ytid]
38 | va200.write(target_line)
39 | else:
40 | print("{} is not in original video list".format(ytid))
41 |
42 | # summarize all the 200 categories of Mini-Kinetics
43 | # Train and val
44 | cats_tr = set()
45 | cats_va = set()
46 |
47 | with open("kinetics200_train_list_org.txt") as f:
48 | lines = f.readlines()
49 | for line in lines:
50 | label_id = int(line.strip().split()[-1])
51 | cats_tr.add(label_id)
52 |
53 | with open("kinetics200_val_list_org.txt") as f:
54 | lines = f.readlines()
55 | for line in lines:
56 | label_id = int(line.strip().split()[-1])
57 | cats_va.add(label_id)
58 |
59 | assert(cats_tr == cats_va)
60 |
61 | # build 400-class 200-class dictionary
62 | _400_200_dict = dict()
63 | for i, cat in enumerate(cats_tr):
64 | _400_200_dict[cat] = i
65 |
66 | with open('400_200_label_mapping.txt', 'w') as f:
67 | for key, value in _400_200_dict.items():
68 | f.write("{} {}\n".format(key, value))
69 |
70 | with open('kinetics200_train_list_org.txt') as f_src:
71 | with open('kinetics200_train_list.txt', 'w') as f_dst:
72 | lines = f_src.readlines()
73 | for line in lines:
74 | items = line.strip().split()
75 | items[-1] = str(_400_200_dict[int(items[-1])])
76 | new_line = ' '.join(items)
77 | f_dst.write(new_line + '\n')
78 |
79 | with open('kinetics200_val_list_org.txt') as f_src:
80 | with open('kinetics200_val_list.txt', 'w') as f_dst:
81 | lines = f_src.readlines()
82 | for line in lines:
83 | items = line.strip().split()
84 | items[-1] = str(_400_200_dict[int(items[-1])])
85 | new_line = ' '.join(items)
86 | f_dst.write(new_line + '\n')
87 |
88 | # pdb.set_trace()
--------------------------------------------------------------------------------
/data/kinetics400/count.py:
--------------------------------------------------------------------------------
1 | frames = []
2 | with open("kinetics_val_list.txt") as f:
3 | lines = f.readlines()
4 | for line in lines:
5 | items = line.strip().split()
6 | frames.append(int(items[1]))
7 |
8 | total = len(frames)
9 | count60 = 0
10 | count120 = 0
11 | count240 = 0
12 |
13 | for fr in frames:
14 | if fr > 60:
15 | count60 += 1
16 | if fr > 120:
17 | count120 += 1
18 | if fr > 240:
19 | count240 += 1
20 |
21 | print("60: ", count60, total, count60/total)
22 | print("120: ", count120, total, count120/total)
23 | print("240: ", count240, total, count240/total)
--------------------------------------------------------------------------------
/data/kinetics400/create_xlw_list.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 |
4 | access = "access/"
5 | # with open("kinetics_val_list.txt") as f_old:
6 | # with open("kinetics_val_list_xlw", 'w') as f_new:
7 | # old_lines = f_old.readlines()
8 | # for line in old_lines:
9 | # vid_path, num_fr, label = line.strip().split()
10 | # if os.path.exists(access+vid_path):
11 | # new_num_fr = len(os.listdir(access+vid_path))
12 | # f_new.write(" ".join([vid_path, str(new_num_fr), label]) + '\n')
13 |
14 | # with open("kinetics_train_list.txt") as f_old:
15 | # with open("kinetics_train_list_xlw", 'w') as f_new:
16 | # old_lines = f_old.readlines()
17 | # for line in old_lines:
18 | # vid_path, num_fr, label = line.strip().split()
19 | # if os.path.exists(access+vid_path):
20 | # new_num_fr = len(os.listdir(access+vid_path))
21 | # f_new.write(" ".join([vid_path, str(new_num_fr), label]) + '\n')
22 |
23 | with open("kinetics_train_list_xlw") as f:
24 | lines = f.readlines()
25 | for line in tqdm(lines):
26 | vid_path, num_fr, label = line.strip().split()
27 | images = os.listdir(access+vid_path)
28 | images.sort()
29 | last_image = images[-1]
30 | # import pdb
31 | # pdb.set_trace()
32 | if int(last_image[6:-4]) != int(num_fr):
33 | print(vid_path)
--------------------------------------------------------------------------------
/data/sthsth_v1/create_sthsth_v1_list.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import collections
4 | from collections import OrderedDict
5 |
6 | frame_root = "/media/SSD/zhoulei/20bn-something-something-v1"
7 |
8 | f_tr = open("sthv1_train_list.txt", 'w')
9 | f_va = open("sthv1_val_list.txt", 'w')
10 |
11 | name_id = OrderedDict()
12 | with open('something-something-v1-labels.csv', newline='') as csvfile:
13 | reader = csv.reader(csvfile, delimiter=';')
14 | for i, row in enumerate(reader):
15 | assert(len(row) == 1), "the length of row must be one"
16 | name_id[row[0]] = i
17 |
18 | with open('something-something-v1-train.csv', newline='') as csvfile:
19 | reader = csv.reader(csvfile, delimiter=';')
20 | for row in reader:
21 | dir_name = row[0]
22 | class_name = row[1]
23 | class_id = name_id[class_name]
24 |
25 | vid_dir = os.path.join(frame_root, dir_name)
26 | frame_num = len(os.listdir(vid_dir))
27 |
28 | line = ' '.join(("RGB/"+dir_name, str(frame_num), str(class_id)+'\n'))
29 | f_tr.write(line)
30 |
31 | with open('something-something-v1-validation.csv', newline='') as csvfile:
32 | reader = csv.reader(csvfile, delimiter=';')
33 | for row in reader:
34 | dir_name = row[0]
35 | class_name = row[1]
36 | class_id = name_id[class_name]
37 |
38 | vid_dir = os.path.join(frame_root, dir_name)
39 | frame_num = len(os.listdir(vid_dir))
40 |
41 | line = ' '.join(("RGB/"+dir_name, str(frame_num), str(class_id)+'\n'))
42 | f_va.write(line)
--------------------------------------------------------------------------------
/data/sthsth_v1/something-something-v1-labels.csv:
--------------------------------------------------------------------------------
1 | Holding something
2 | Turning something upside down
3 | Turning the camera left while filming something
4 | Stacking number of something
5 | Turning the camera right while filming something
6 | Opening something
7 | Approaching something with your camera
8 | Picking something up
9 | Pushing something so that it almost falls off but doesn't
10 | Folding something
11 | Moving something away from the camera
12 | Closing something
13 | Moving away from something with your camera
14 | Turning the camera downwards while filming something
15 | Pushing something so that it slightly moves
16 | Turning the camera upwards while filming something
17 | Pretending to pick something up
18 | Showing something to the camera
19 | Moving something up
20 | Plugging something into something
21 | Unfolding something
22 | Putting something onto something
23 | Showing that something is empty
24 | Pretending to put something on a surface
25 | Taking something from somewhere
26 | Putting something next to something
27 | Moving something towards the camera
28 | Showing a photo of something to the camera
29 | Pushing something with something
30 | Throwing something
31 | Pushing something from left to right
32 | Something falling like a feather or paper
33 | Throwing something in the air and letting it fall
34 | Throwing something against something
35 | Lifting something with something on it
36 | Taking one of many similar things on the table
37 | Showing something behind something
38 | Putting something into something
39 | Tearing something just a little bit
40 | Moving something away from something
41 | Tearing something into two pieces
42 | Pushing something from right to left
43 | Holding something next to something
44 | Putting something, something and something on the table
45 | Pretending to take something from somewhere
46 | Moving something closer to something
47 | Pretending to put something next to something
48 | Uncovering something
49 | Something falling like a rock
50 | Putting something and something on the table
51 | Pouring something into something
52 | Moving something down
53 | Pulling something from right to left
54 | Throwing something in the air and catching it
55 | Tilting something with something on it until it falls off
56 | Putting something in front of something
57 | Pretending to turn something upside down
58 | Putting something on a surface
59 | Pretending to throw something
60 | Showing something on top of something
61 | Covering something with something
62 | Squeezing something
63 | Putting something similar to other things that are already on the table
64 | Lifting up one end of something, then letting it drop down
65 | Taking something out of something
66 | Moving part of something
67 | Pulling something from left to right
68 | Lifting something up completely without letting it drop down
69 | Attaching something to something
70 | Putting something behind something
71 | Moving something and something closer to each other
72 | Holding something in front of something
73 | Pushing something so that it falls off the table
74 | Holding something over something
75 | Pretending to open something without actually opening it
76 | Removing something, revealing something behind
77 | Hitting something with something
78 | Moving something and something away from each other
79 | Touching (without moving) part of something
80 | Pretending to put something into something
81 | Showing that something is inside something
82 | Lifting something up completely, then letting it drop down
83 | Pretending to take something out of something
84 | Holding something behind something
85 | Laying something on the table on its side, not upright
86 | Poking something so it slightly moves
87 | Pretending to close something without actually closing it
88 | Putting something upright on the table
89 | Dropping something in front of something
90 | Dropping something behind something
91 | Lifting up one end of something without letting it drop down
92 | Rolling something on a flat surface
93 | Throwing something onto a surface
94 | Showing something next to something
95 | Dropping something onto something
96 | Stuffing something into something
97 | Dropping something into something
98 | Piling something up
99 | Letting something roll along a flat surface
100 | Twisting something
101 | Spinning something that quickly stops spinning
102 | Putting number of something onto something
103 | Putting something underneath something
104 | Moving something across a surface without it falling down
105 | Plugging something into something but pulling it right out as you remove your hand
106 | Dropping something next to something
107 | Poking something so that it falls over
108 | Spinning something so it continues spinning
109 | Poking something so lightly that it doesn't or almost doesn't move
110 | Wiping something off of something
111 | Moving something across a surface until it falls down
112 | Pretending to poke something
113 | Putting something that cannot actually stand upright upright on the table, so it falls on its side
114 | Pulling something out of something
115 | Scooping something up with something
116 | Pretending to be tearing something that is not tearable
117 | Burying something in something
118 | Tipping something over
119 | Tilting something with something on it slightly so it doesn't fall down
120 | Pretending to put something onto something
121 | Bending something until it breaks
122 | Letting something roll down a slanted surface
123 | Trying to bend something unbendable so nothing happens
124 | Bending something so that it deforms
125 | Digging something out of something
126 | Pretending to put something underneath something
127 | Putting something on a flat surface without letting it roll
128 | Putting something on the edge of something so it is not supported and falls down
129 | Spreading something onto something
130 | Pretending to put something behind something
131 | Sprinkling something onto something
132 | Something colliding with something and both come to a halt
133 | Pushing something off of something
134 | Putting something that can't roll onto a slanted surface, so it stays where it is
135 | Lifting a surface with something on it until it starts sliding down
136 | Pretending or failing to wipe something off of something
137 | Trying but failing to attach something to something because it doesn't stick
138 | Pulling something from behind of something
139 | Pushing something so it spins
140 | Pouring something onto something
141 | Pulling two ends of something but nothing happens
142 | Moving something and something so they pass each other
143 | Pretending to sprinkle air onto something
144 | Putting something that can't roll onto a slanted surface, so it slides down
145 | Something colliding with something and both are being deflected
146 | Pretending to squeeze something
147 | Pulling something onto something
148 | Putting something onto something else that cannot support it so it falls down
149 | Lifting a surface with something on it but not enough for it to slide down
150 | Pouring something out of something
151 | Moving something and something so they collide with each other
152 | Tipping something with something in it over, so something in it falls out
153 | Letting something roll up a slanted surface, so it rolls back down
154 | Pretending to scoop something up with something
155 | Pretending to pour something out of something, but something is empty
156 | Pulling two ends of something so that it gets stretched
157 | Failing to put something into something because something does not fit
158 | Pretending or trying and failing to twist something
159 | Trying to pour something into something, but missing so it spills next to it
160 | Something being deflected from something
161 | Poking a stack of something so the stack collapses
162 | Spilling something onto something
163 | Pulling two ends of something so that it separates into two pieces
164 | Pouring something into something until it overflows
165 | Pretending to spread air onto something
166 | Twisting (wringing) something wet until water comes out
167 | Poking a hole into something soft
168 | Spilling something next to something
169 | Poking a stack of something without the stack collapsing
170 | Putting something onto a slanted surface but it doesn't glide down
171 | Pushing something onto something
172 | Poking something so that it spins around
173 | Spilling something behind something
174 | Poking a hole into some substance
175 |
--------------------------------------------------------------------------------
/finetune_bn_frozen.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import shutil
5 | import logging
6 |
7 | import torch
8 | import torchvision
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.optim
12 |
13 | from lib.dataset import VideoDataSet
14 | from lib.models import VideoModule
15 | from lib.transforms import *
16 | from lib.utils.tools import *
17 | from lib.opts import args
18 | from lib.modules import *
19 |
20 | from train_val import train, validate,finetune_bn_frozen
21 |
22 | best_metric = 0
23 |
24 | def main():
25 | global args, best_metric
26 |
27 | if 'ucf101' in args.dataset:
28 | num_class = 101
29 | elif 'hmdb51' in args.dataset:
30 | num_class = 51
31 | elif args.dataset == 'kinetics400':
32 | num_class = 400
33 | elif args.dataset == 'kinetics200':
34 | num_class = 200
35 | else:
36 | raise ValueError('Unknown dataset '+args.dataset)
37 |
38 | # data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)),
39 | # "data/{}/access".format(args.dataset))
40 |
41 | if "ucf101" in args.dataset or "hmdb51" in args.dataset:
42 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)),
43 | "data/{}/access".format(args.dataset[:-3]))
44 | else:
45 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)),
46 | "data/{}/access".format(args.dataset))
47 |
48 | # create model
49 | org_model = VideoModule(num_class=num_class,
50 | base_model_name=args.arch,
51 | dropout=args.dropout,
52 | pretrained=args.pretrained,
53 | pretrained_model=args.pretrained_model)
54 | num_params = 0
55 | for param in org_model.parameters():
56 | num_params += param.reshape((-1, 1)).shape[0]
57 | print("Model Size is {:.3f}M".format(num_params/1000000))
58 |
59 | model = torch.nn.DataParallel(org_model).cuda()
60 | # model = org_model
61 |
62 | # define loss function (criterion) and optimizer
63 | criterion = torch.nn.CrossEntropyLoss().cuda()
64 |
65 | # optim_params = [param[1] for param in model.named_parameters() if "classifier" in param[0]]
66 | # import pdb
67 | # pdb.set_trace()
68 | optimizer = torch.optim.SGD(model.parameters(),
69 | args.lr,
70 | momentum=args.momentum,
71 | weight_decay=args.weight_decay)
72 |
73 | # optionally resume from a checkpoint
74 | if args.resume:
75 | if os.path.isfile(args.resume):
76 | print(("=> loading checkpoint '{}'".format(args.resume)))
77 | checkpoint = torch.load(args.resume)
78 | args.start_epoch = checkpoint['epoch']
79 | best_metric = checkpoint['best_metric']
80 | model.load_state_dict(checkpoint['state_dict'])
81 | optimizer.load_state_dict(checkpoint['optimizer'])
82 | print(("=> loaded checkpoint '{}' (epoch {})"
83 | .format(args.resume, checkpoint['epoch'])))
84 | else:
85 | print(("=> no checkpoint found at '{}'".format(args.resume)))
86 |
87 | # Data loading code
88 | ## train data
89 | train_transform = torchvision.transforms.Compose([
90 | GroupScale(args.new_size),
91 | GroupMultiScaleCrop(input_size=args.crop_size, scales=[1, .875, .75, .66]),
92 | GroupRandomHorizontalFlip(),
93 | Stack(mode=args.mode),
94 | ToTorchFormatTensor(),
95 | GroupNormalize(),
96 | ])
97 | train_dataset = VideoDataSet(root_path=data_root,
98 | list_file=args.train_list,
99 | t_length=args.t_length,
100 | t_stride=args.t_stride,
101 | num_segments=args.num_segments,
102 | image_tmpl=args.image_tmpl,
103 | transform=train_transform,
104 | phase="Train")
105 | train_loader = torch.utils.data.DataLoader(
106 | train_dataset,
107 | batch_size=args.batch_size, shuffle=True, drop_last=True,
108 | num_workers=args.workers, pin_memory=True)
109 |
110 | ## val data
111 | val_transform = torchvision.transforms.Compose([
112 | GroupScale(args.new_size),
113 | GroupCenterCrop(args.crop_size),
114 | Stack(mode=args.mode),
115 | ToTorchFormatTensor(),
116 | GroupNormalize(),
117 | ])
118 | val_dataset = VideoDataSet(root_path=data_root,
119 | list_file=args.val_list,
120 | t_length=args.t_length,
121 | t_stride=args.t_stride,
122 | num_segments=args.num_segments,
123 | image_tmpl=args.image_tmpl,
124 | transform=val_transform,
125 | phase="Val")
126 | val_loader = torch.utils.data.DataLoader(
127 | val_dataset,
128 | batch_size=args.batch_size, shuffle=False,
129 | num_workers=args.workers, pin_memory=True)
130 |
131 | if args.mode != "3D":
132 | cudnn.benchmark = True
133 |
134 | if args.resume:
135 | validate(val_loader, model, criterion, args.print_freq, args.start_epoch)
136 | torch.cuda.empty_cache()
137 |
138 | for epoch in range(args.start_epoch, args.epochs):
139 | adjust_learning_rate(optimizer, args.lr, epoch, args.lr_steps)
140 |
141 | # train for one epoch
142 | finetune_bn_frozen(train_loader, model, criterion, optimizer, epoch, args.print_freq)
143 |
144 | # evaluate on validation set
145 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
146 | metric = validate(val_loader, model, criterion, args.print_freq, epoch + 1)
147 | torch.cuda.empty_cache()
148 |
149 | # remember best prec@1 and save checkpoint
150 | is_best = metric > best_metric
151 | best_metric = max(metric, best_metric)
152 | save_checkpoint({
153 | 'epoch': epoch + 1,
154 | 'arch': args.arch,
155 | 'state_dict': model.state_dict(),
156 | 'best_metric': best_metric,
157 | 'optimizer': optimizer.state_dict(),
158 | }, is_best, epoch + 1, args.experiment_root)
159 |
160 | if __name__ == '__main__':
161 | main()
162 |
--------------------------------------------------------------------------------
/finetune_fc.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import shutil
5 | import logging
6 |
7 | import torch
8 | import torchvision
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.optim
12 |
13 | from lib.dataset import VideoDataSet
14 | from lib.models import VideoModule
15 | from lib.transforms import *
16 | from lib.utils.tools import *
17 | from lib.opts import args
18 |
19 | from train_val import train, validate, finetune_fc
20 |
21 | best_metric = 0
22 |
23 | def main():
24 | global args, best_metric
25 |
26 | # specify dataset
27 | if 'ucf101' in args.dataset:
28 | num_class = 101
29 | elif 'hmdb51' in args.dataset:
30 | num_class = 51
31 | elif args.dataset == 'kinetics400':
32 | num_class = 400
33 | elif args.dataset == 'kinetics200':
34 | num_class = 200
35 | else:
36 | raise ValueError('Unknown dataset '+args.dataset)
37 |
38 | if "ucf101" in args.dataset or "hmdb51" in args.dataset:
39 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)),
40 | "data/{}/access".format(args.dataset[:-3]))
41 | else:
42 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)),
43 | "data/{}/access".format(args.dataset))
44 |
45 | # create model
46 | org_model = VideoModule(num_class=num_class,
47 | base_model_name=args.arch,
48 | dropout=args.dropout,
49 | pretrained=args.pretrained,
50 | pretrained_model=args.pretrained_model)
51 | num_params = 0
52 | for param in org_model.parameters():
53 | num_params += param.reshape((-1, 1)).shape[0]
54 | print("Model Size is {:.3f}M".format(num_params/1000000))
55 |
56 | model = torch.nn.DataParallel(org_model).cuda()
57 | # model = org_model
58 |
59 | # define loss function (criterion) and optimizer
60 | criterion = torch.nn.CrossEntropyLoss().cuda()
61 |
62 | optim_params = [param[1] for param in model.named_parameters() if "classifier" in param[0]]
63 | # import pdb
64 | # pdb.set_trace()
65 | optimizer = torch.optim.SGD(optim_params,
66 | args.lr,
67 | momentum=args.momentum,
68 | weight_decay=args.weight_decay)
69 |
70 | # optionally resume from a checkpoint
71 | if args.resume:
72 | if os.path.isfile(args.resume):
73 | print(("=> loading checkpoint '{}'".format(args.resume)))
74 | checkpoint = torch.load(args.resume)
75 | args.start_epoch = checkpoint['epoch']
76 | best_metric = checkpoint['best_metric']
77 | model.load_state_dict(checkpoint['state_dict'])
78 | optimizer.load_state_dict(checkpoint['optimizer'])
79 | print(("=> loaded checkpoint '{}' (epoch {})"
80 | .format(args.resume, checkpoint['epoch'])))
81 | else:
82 | print(("=> no checkpoint found at '{}'".format(args.resume)))
83 |
84 | # Data loading code
85 | ## train data
86 | train_transform = torchvision.transforms.Compose([
87 | GroupScale(args.new_size),
88 | GroupMultiScaleCrop(input_size=args.crop_size, scales=[1, .875, .75, .66]),
89 | GroupRandomHorizontalFlip(),
90 | Stack(mode=args.mode),
91 | ToTorchFormatTensor(),
92 | GroupNormalize(),
93 | ])
94 | train_dataset = VideoDataSet(root_path=data_root,
95 | list_file=args.train_list,
96 | t_length=args.t_length,
97 | t_stride=args.t_stride,
98 | num_segments=args.num_segments,
99 | image_tmpl=args.image_tmpl,
100 | transform=train_transform,
101 | phase="Train")
102 | train_loader = torch.utils.data.DataLoader(
103 | train_dataset,
104 | batch_size=args.batch_size, shuffle=True, drop_last=True,
105 | num_workers=args.workers, pin_memory=True)
106 |
107 | ## val data
108 | val_transform = torchvision.transforms.Compose([
109 | GroupScale(args.new_size),
110 | GroupCenterCrop(args.crop_size),
111 | Stack(mode=args.mode),
112 | ToTorchFormatTensor(),
113 | GroupNormalize(),
114 | ])
115 | val_dataset = VideoDataSet(root_path=data_root,
116 | list_file=args.val_list,
117 | t_length=args.t_length,
118 | t_stride=args.t_stride,
119 | num_segments=args.num_segments,
120 | image_tmpl=args.image_tmpl,
121 | transform=val_transform,
122 | phase="Val")
123 | val_loader = torch.utils.data.DataLoader(
124 | val_dataset,
125 | batch_size=args.batch_size, shuffle=False,
126 | num_workers=args.workers, pin_memory=True)
127 |
128 | if args.mode != "3D":
129 | cudnn.benchmark = True
130 |
131 | # validate(val_loader, model, criterion, args.print_freq, args.start_epoch)
132 |
133 | for epoch in range(args.start_epoch, args.epochs):
134 | adjust_learning_rate(optimizer, args.lr, epoch, args.lr_steps)
135 |
136 | # train for one epoch
137 | finetune_fc(train_loader, model, criterion, optimizer, epoch, args.print_freq)
138 |
139 | # evaluate on validation set
140 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
141 | metric = validate(val_loader, model, criterion, args.print_freq, epoch + 1)
142 |
143 | # remember best prec@1 and save checkpoint
144 | is_best = metric > best_metric
145 | best_metric = max(metric, best_metric)
146 | save_checkpoint({
147 | 'epoch': epoch + 1,
148 | 'arch': args.arch,
149 | 'state_dict': model.state_dict(),
150 | 'best_metric': best_metric,
151 | 'optimizer': optimizer.state_dict(),
152 | }, is_best, epoch + 1, args.experiment_root)
153 |
154 | if __name__ == '__main__':
155 | main()
156 |
--------------------------------------------------------------------------------
/lib/dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | from numpy.random import randint
8 |
9 | import torch
10 |
11 | class VideoRecord(object):
12 | def __init__(self, row, root_path):
13 | self._data = row
14 | self._root_path = root_path
15 |
16 | @property
17 | def path(self):
18 | return os.path.join(self._root_path, self._data[0])
19 |
20 | @property
21 | def num_frames(self):
22 | return int(self._data[1])
23 |
24 | @property
25 | def label(self):
26 | return int(self._data[2])
27 |
28 | class VideoDebugDataSet(data.Dataset):
29 | """
30 | """
31 | def __len__(self):
32 | return 100
33 |
34 | def __getitem__(self, index):
35 | np.random.seed(12345)
36 | input_tensor = (np.random.random_sample((3,18,224,224)) - 0.5) * 2
37 | return torch.from_numpy(input_tensor).to(torch.float), 0
38 |
39 | class VideoDataSet(data.Dataset):
40 | def __init__(self, root_path, list_file,
41 | t_length=32, t_stride=2, num_segments=1,
42 | image_tmpl='img_{:05d}.jpg',
43 | transform=None, style="Dense",
44 | phase="Train"):
45 | """
46 | :style: Dense, for 2D and 3D model, and Sparse for TSN model
47 | :phase: Train, Val, Test
48 | """
49 |
50 | self.root_path = root_path
51 | self.list_file = list_file
52 | self.t_length = t_length
53 | self.t_stride = t_stride
54 | self.num_segments = num_segments
55 | self.image_tmpl = image_tmpl
56 | self.transform = transform
57 | assert(style in ("Dense", "UnevenDense")), "Only support Dense and UnevenDense"
58 | self.style = style
59 | self.phase = phase
60 | assert(t_length > 0), "Length of time must be bigger than zero."
61 | assert(t_stride > 0), "Stride of time must be bigger than zero."
62 |
63 | self._parse_list()
64 |
65 | def _load_image(self, directory, idx):
66 | return [Image.open(os.path.join(directory, self.image_tmpl.format(idx))).convert('RGB')]
67 |
68 | def _parse_list(self):
69 | self.video_list = [VideoRecord(x.strip().split(' '), self.root_path) for x in open(self.list_file)]
70 | # self.video_list = [VideoRecord(x.strip().split(' '), self.root_path) for x in open(self.list_file) if VideoRecord(x.strip().split(' '), self.root_path).num_frames > 240]
71 | # print(len(self.video_list))
72 |
73 | @staticmethod
74 | def dense_sampler(num_frames, length, stride=1):
75 | t_length = length
76 | t_stride = stride
77 | # compute offsets
78 | offset = 0
79 | average_duration = num_frames - (t_length - 1) * t_stride - 1
80 | if average_duration >= 0:
81 | offset = randint(average_duration + 1)
82 | elif num_frames > t_length:
83 | while(t_stride - 1 > 0):
84 | t_stride -= 1
85 | average_duration = num_frames - (t_length - 1) * t_stride - 1
86 | if average_duration >= 0:
87 | offset = randint(average_duration + 1)
88 | break
89 | assert(t_stride >= 1), "temporal stride must be bigger than zero."
90 | else:
91 | t_stride = 1
92 | # sampling
93 | samples = []
94 | for i in range(t_length):
95 | samples.append(offset + i * t_stride + 1)
96 | return samples
97 |
98 | def _sample_indices(self, record):
99 | """
100 | :param record: VideoRecord
101 | :return: list
102 | """
103 | if self.style == "Dense":
104 | frames = []
105 | average_duration = record.num_frames / self.num_segments
106 | offsets = [average_duration * i for i in range(self.num_segments)]
107 | for i in range(self.num_segments):
108 | samples = self.dense_sampler(average_duration, self.t_length, self.t_stride)
109 | samples = [sample + offsets[i] for sample in samples]
110 | frames.extend(samples)
111 | return {"dense": frames}
112 | elif self.style == "UnevenDense":
113 | sparse_frames = []
114 | average_duration = record.num_frames / self.num_segments
115 | offsets = [average_duration * i for i in range(self.num_segments)]
116 | dense_frames = self.dense_sampler(record.num_frames, self.t_length, self.t_stride)
117 | dense_seg = -1
118 | for i in range(self.num_segments):
119 | if dense_frames[self.t_length//2] >= offsets[self.num_segments - i - 1]:
120 | dense_seg = self.num_segments - i - 1
121 | break
122 | else:
123 | continue
124 | assert(dense_seg != -1)
125 | # dense_seg = randint(self.num_segments)
126 | for i in range(self.num_segments):
127 | # if i == dense_seg:
128 | # samples = self.dense_sampler(average_duration, self.t_length, self.t_stride)
129 | # samples = [sample + offsets[i] for sample in samples]
130 | # dense_frames.extend(samples)
131 | # dense_seg = -1 # set dense seg to -1 and check after sampling.
132 | if i != dense_seg:
133 | samples = self.dense_sampler(average_duration, 1)
134 | samples = [sample + offsets[i] for sample in samples]
135 | sparse_frames.extend(samples)
136 | return {"dense":dense_frames, "sparse":sparse_frames}
137 | else:
138 | return
139 |
140 | def _get_val_indices(self, record):
141 | """
142 | get indices in val phase
143 | """
144 | # valid_offset_range = record.num_frames - (self.t_length - 1) * self.t_stride - 1
145 | valid_offset_range = record.num_frames - (self.t_length - 1) * self.t_stride - 1
146 | offset = int(valid_offset_range / 2.0)
147 | if offset < 0:
148 | offset = 0
149 | samples = []
150 | for i in range(self.t_length):
151 | samples.append(offset + i * self.t_stride + 1)
152 | return {"dense": samples}
153 |
154 | def _get_test_indices(self, record):
155 | """
156 | get indices in test phase
157 | """
158 | valid_offset_range = record.num_frames - (self.t_length - 1) * self.t_stride - 1
159 | interval = valid_offset_range / (self.num_segments - 1)
160 | offsets = []
161 | for i in range(self.num_segments):
162 | offset = int(i * interval)
163 | if offset > valid_offset_range:
164 | offset = valid_offset_range
165 | if offset < 0:
166 | offset = 0
167 | offsets.append(offset + 1)
168 | frames = []
169 | for i in range(self.num_segments):
170 | for j in range(self.t_length):
171 | frames.append(offsets[i] + j*self.t_stride)
172 | # frames.append(offsets[i]+j)
173 | return {"dense": frames}
174 |
175 | def __getitem__(self, index):
176 | record = self.video_list[index]
177 |
178 | if self.phase == "Train":
179 | indices = self._sample_indices(record)
180 | return self.get(record, indices, self.phase)
181 | elif self.phase == "Val":
182 | indices = self._get_val_indices(record)
183 | return self.get(record, indices, self.phase)
184 | elif self.phase == "Test":
185 | indices = self._get_test_indices(record)
186 | return self.get(record, indices, self.phase)
187 | else:
188 |             raise TypeError("Unsupported phase {}".format(self.phase))
189 |
190 | def get(self, record, indices, phase):
191 | # dense process data
192 | def dense_process_data():
193 | images = list()
194 | for ind in indices['dense']:
195 | ptr = int(ind)
196 | if ptr <= record.num_frames:
197 | imgs = self._load_image(record.path, ptr)
198 | else:
199 | imgs = self._load_image(record.path, record.num_frames)
200 | images.extend(imgs)
201 | return self.transform(images)
202 | # unevendense process data
203 | def unevendense_process_data():
204 | dense_images = list()
205 | sparse_images = list()
206 | for ind in indices['dense']:
207 | ptr = int(ind)
208 | if ptr <= record.num_frames:
209 | imgs = self._load_image(record.path, ptr)
210 | else:
211 | imgs = self._load_image(record.path, record.num_frames)
212 | dense_images.extend(imgs)
213 | for ind in indices['sparse']:
214 | ptr = int(ind)
215 | if ptr <= record.num_frames:
216 | imgs = self._load_image(record.path, ptr)
217 | else:
218 | imgs = self._load_image(record.path, record.num_frames)
219 | sparse_images.extend(imgs)
220 |
221 | images = dense_images + sparse_images
222 | return self.transform(images)
223 | if phase == "Train":
224 | if self.style == "Dense":
225 | process_data = dense_process_data()
226 | elif self.style == "UnevenDense":
227 | process_data = unevendense_process_data()
228 | elif phase in ("Val", "Test"):
229 | process_data = dense_process_data()
230 | return process_data, record.label
231 |
232 | def __len__(self):
233 | return len(self.video_list)
234 |
235 | class ShortVideoDataSet(VideoDataSet):
236 | def __init__(self, root_path, list_file,
237 | t_length=32, t_stride=2, num_segments=1,
238 | image_tmpl='img_{:05d}.jpg',
239 | transform=None, style="Dense",
240 | phase="Train"):
241 | """
242 | :style: Dense, for 2D and 3D model, and Sparse for TSN model
243 | :phase: Train, Val, Test
244 | """
245 |
246 | super(ShortVideoDataSet, self).__init__(root_path,
247 | list_file, t_length, t_stride, num_segments,
248 | image_tmpl, transform, style, phase)
249 |
250 |
251 | def _get_val_indices(self, record):
252 | """
253 | get indices in val phase
254 | """
255 | # valid_offset_range = record.num_frames - (self.t_length - 1) * self.t_stride - 1
256 | t_stride = self.t_stride
257 | valid_offset_range = record.num_frames - (self.t_length - 1) * t_stride - 1
258 | offset = int(valid_offset_range / 2.0)
259 |
260 | if record.num_frames > self.t_length:
261 | while(offset < 0 and t_stride > 1):
262 | t_stride -= 1
263 | valid_offset_range = record.num_frames - (self.t_length - 1) * t_stride - 1
264 | offset = int(valid_offset_range / 2.0)
265 | else:
266 | t_stride = 1
267 | valid_offset_range = record.num_frames - (self.t_length - 1) * t_stride - 1
268 | offset = int(valid_offset_range / 2.0)
269 |
270 | if offset < 0:
271 | offset = 0
272 | samples = []
273 | for i in range(self.t_length):
274 | samples.append(offset + i * t_stride + 1)
275 | return {"dense": samples}
276 |
277 | def _get_test_indices(self, record):
278 | """
279 | get indices in test phase
280 | """
281 | t_stride = self.t_stride
282 | valid_offset_range = record.num_frames - (self.t_length - 1) * t_stride - 1
283 | while(valid_offset_range < (self.num_segments - 1) and t_stride > 1):
284 | t_stride -= 1
285 | valid_offset_range = record.num_frames - (self.t_length - 1) * t_stride - 1
286 | if valid_offset_range < 0:
287 | valid_offset_range = 0
288 | interval = valid_offset_range / (self.num_segments - 1)
289 | offsets = []
290 | for i in range(self.num_segments):
291 | offset = int(i * interval)
292 | if offset > valid_offset_range+1:
293 | offset = valid_offset_range+1
294 | if offset < 0:
295 | offset = 0
296 | offsets.append(offset + 1)
297 | frames = []
298 | for i in range(self.num_segments):
299 | for j in range(self.t_length):
300 | frames.append(offsets[i] + j * t_stride)
301 | # frames.append(offsets[i]+j)
302 | return {"dense": frames}
303 |
304 |
305 | if __name__ == "__main__":
306 | td = VideoDataSet(root_path="../data/kinetics400/access/kinetics_train_rgb_img_256_340/",
307 | list_file="../data/kinetics400/kinetics_train_list.txt",
308 | t_length=16,
309 | t_stride=4,
310 | num_segments=3,
311 | image_tmpl="image_{:06d}.jpg",
312 | style="UnevenDense",
313 | phase="Train")
314 | # sample0 = td[0]
315 | import pdb
316 | pdb.set_trace()
317 |
--------------------------------------------------------------------------------
/lib/models.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch import nn
3 | from torch.nn.parameter import Parameter
4 | from .networks import *
5 |
6 | from .transforms import *
7 |
8 | class VideoModule(nn.Module):
9 | def __init__(self, num_class, base_model_name='resnet50',
10 | before_softmax=True, dropout=0.8, pretrained=True, pretrained_model=None):
11 | super(VideoModule, self).__init__()
12 | self.num_class = num_class
13 | self.base_model_name = base_model_name
14 | self.before_softmax = before_softmax
15 | self.dropout = dropout
16 | self.pretrained = pretrained
17 | self.pretrained_model = pretrained_model
18 | # self.finetune = finetune
19 |
20 | self._prepare_base_model(base_model_name)
21 |
22 | if not self.before_softmax:
23 | self.softmax = nn.Softmax()
24 |
25 | def _prepare_base_model(self, base_model_name):
26 | """
27 | base_model+(dropout)+classifier
28 | """
29 | base_model_dict = None
30 | classifier_dict = None
31 | if self.pretrained and self.pretrained_model:
32 | model_dict = torch.load(self.pretrained_model)
33 | base_model_dict = {k: v for k, v in model_dict.items() if "classifier" not in k}
34 | classifier_dict = {'.'.join(k.split('.')[1:]): v for k, v in model_dict.items() if "classifier" in k}
35 | # base model
36 | if "resnet" in base_model_name:
37 | self.base_model = eval(base_model_name)(pretrained=self.pretrained, \
38 | feat=True, pretrained_model=base_model_dict)
39 | elif base_model_name == "mnet2":
40 | model_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
41 | "../models/mobilenet_v2.pth.tar")
42 | self.base_model = mnet2(pretrained=model_path, feat=True)
43 | elif base_model_name == "mnet2_3d":
44 | model_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
45 | "../models/mobilenet_v2.pth.tar")
46 | self.base_model = mnet2_3d(pretrained=model_path, feat=True)
47 | elif "fst" in base_model_name or "msv" in base_model_name or "gsv" in base_model_name:
48 | self.base_model = eval(base_model_name)(pretrained=self.pretrained,
49 | feat=True, pretrained_model=base_model_dict)
50 | else:
51 |             raise ValueError('Unknown base model: {}'.format(base_model_name))
52 |
53 | # classifier: (dropout) + fc
54 | if self.dropout == 0:
55 | self.classifier = nn.Linear(self.base_model.feat_dim, self.num_class)
56 | elif self.dropout > 0:
57 | self.classifier = nn.Sequential(nn.Dropout(self.dropout), nn.Linear(self.base_model.feat_dim, self.num_class))
58 |
59 | # init classifier
60 | for m in self.classifier.modules():
61 | if isinstance(m, nn.Linear):
62 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='linear')
63 | nn.init.constant_(m.bias, 0)
64 |
65 | # if self.pretrained and self.pretrained_model:
66 | # self.classifier.load_state_dict(classifier_dict)
67 |
68 | # if self.finetune:
69 | # print("Finetune")
70 | # for param in self.base_model.parameters():
71 | # param.requires_grad = False
72 | # for m in self.base_model.modules():
73 | # if isinstance(m, nn.BatchNorm3d):
74 | # m.eval()
75 |
76 | # import pdb
77 | # pdb.set_trace()
78 |
79 | def forward(self, input):
80 | out = self.base_model(input)
81 | out = self.classifier(out)
82 |
83 | if not self.before_softmax:
84 | out = self.softmax(out)
85 |
86 | return out
87 |
88 | def get_augmentation(self):
89 | return torchvision.transforms.Compose([GroupMultiScaleCrop(input_size=224, scales=[1, .875, .75, .66]),
90 | GroupRandomHorizontalFlip()])
91 |
92 | class TSN(nn.Module):
93 | """Temporal Segment Network
94 |
95 | """
96 | def __init__(self, batch_size, video_module, num_segments=1, t_length=1,
97 | crop_fusion_type='max', mode="3D"):
98 | super(TSN, self).__init__()
99 | self.t_length = t_length
100 | self.batch_size = batch_size
101 | self.num_segments = num_segments
102 | self.video_module = video_module
103 | self.crop_fusion_type = crop_fusion_type
104 | self.mode = mode
105 |
106 | def forward(self, input):
107 | # reshape input first
108 | shape = input.shape
109 | if "3D" in self.mode:
110 | assert(len(shape)) == 5, "In 3D mode, input must have 5 dims."
111 | shape = (shape[0], shape[1], shape[2]//self.t_length, self.t_length) + shape[3:]
112 | input = input.view(shape).permute((0, 2, 1, 3, 4, 5)).contiguous()
113 | shape = (input.shape[0] * input.shape[1], ) + input.shape[2:]
114 | input = input.view(shape)
115 | elif "2D" in self.mode:
116 | assert(len(shape)) == 4, "In 2D mode, input must have 4 dims."
117 | shape = (shape[0]*shape[1]//3, 3,) + shape[2:]
118 | input = input.view(shape)
119 | else:
120 | raise Exception("Unsupported mode.")
121 |
122 | # base network forward
123 | output = self.video_module(input)
124 | # fuse output
125 | output = output.view((self.batch_size,
126 | output.shape[0] // (self.batch_size * self.num_segments),
127 | self.num_segments, output.shape[1]))
128 |
129 | output_max = output.max(1)[0].squeeze(1)
130 | pred_max = output_max.mean(1).squeeze(1)
131 | output_ave = output.mean(1).squeeze(1)
132 | pred_ave = output_ave.mean(1).squeeze(1)
133 | # if self.crop_fusion_type == 'max':
134 | # # pdb.set_trace()
135 | # output = output.max(1)[0].squeeze(1)
136 | # elif self.crop_fusion_type == 'avg':
137 | # output = output.mean(1).squeeze(1)
138 | # pred = output.mean(1).squeeze(1)
139 | return (output_max, pred_max, output_ave, pred_ave)
140 |
--------------------------------------------------------------------------------
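
A minimal sketch of how the TSN wrapper defined above is driven in 3D mode. The dummy backbone, clip sizes, and import path are illustrative assumptions (the repository root is assumed to be on PYTHONPATH); in the real pipeline the wrapped module would be the video model defined earlier in this file.

import torch
import torch.nn as nn
from lib.models import TSN  # assumes the repository root is on PYTHONPATH

class DummyVideoModule(nn.Module):
    """Stand-in for the wrapped backbone: maps a clip (N, 3, T, H, W) to class scores."""
    def __init__(self, num_class=10):
        super(DummyVideoModule, self).__init__()
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(3, num_class)

    def forward(self, x):
        return self.fc(self.pool(x).flatten(1))

batch_size, num_segments, t_length = 2, 3, 8   # arbitrary example sizes
# In 3D mode TSN expects (batch, 3, num_segments * t_length, H, W) and regroups
# the frame axis into num_segments clips of t_length frames before the forward pass.
clips = torch.randn(batch_size, 3, num_segments * t_length, 112, 112)
tsn = TSN(batch_size, DummyVideoModule(), num_segments=num_segments,
          t_length=t_length, mode="3D")
output_max, pred_max, output_ave, pred_ave = tsn(clips)
print(pred_max.shape, pred_ave.shape)  # both torch.Size([2, 10])
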
/lib/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .scale import *
2 | from .pooling import *
--------------------------------------------------------------------------------
/lib/modules/pooling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class GloAvgPool3d(nn.Module):
6 | def __init__(self):
7 | super(GloAvgPool3d, self).__init__()
8 | self.stride = 1
9 | self.padding = 0
10 | self.ceil_mode = False
11 | self.count_include_pad = True
12 |
13 | def forward(self, input):
14 | input_shape = input.shape
15 | kernel_size = input_shape[2:]
16 | return F.avg_pool3d(input, kernel_size, self.stride,
17 | self.padding, self.ceil_mode, self.count_include_pad)
18 |
19 | class GloSptMaxPool3d(nn.Module):
20 | def __init__(self):
21 | super(GloSptMaxPool3d, self).__init__()
22 | self.stride = 1
23 | self.padding = 0
24 | self.ceil_mode = False
25 | self.count_include_pad = True
26 |
27 | def forward(self, input):
28 | input_shape = input.shape
29 | kernel_size = (1,) + input_shape[3:]
30 | return F.max_pool3d(input, kernel_size=kernel_size, stride=self.stride,
31 | padding=self.padding, ceil_mode=self.ceil_mode)
32 |
33 | class GloSptAvgPool3d(nn.Module):
34 | def __init__(self):
35 | super(GloSptAvgPool3d, self).__init__()
36 | self.stride = 1
37 | self.padding = 0
38 | self.ceil_mode = False
39 | self.count_include_pad = True
40 |
41 | def forward(self, input):
42 | input_shape = input.shape
43 | kernel_size = (1, ) + input_shape[3:]
44 | return F.avg_pool3d(input, kernel_size, self.stride,
45 | self.padding, self.ceil_mode, self.count_include_pad)
--------------------------------------------------------------------------------
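
A quick shape check for the pooling modules above, assuming the repository root is importable; the input size is an arbitrary example. GloAvgPool3d collapses the full (T, H, W) extent, while GloSptMaxPool3d pools only spatially and keeps the temporal axis.

import torch
from lib.modules.pooling import GloAvgPool3d, GloSptMaxPool3d  # assumes repo root on PYTHONPATH

x = torch.randn(2, 64, 4, 7, 7)            # (N, C, T, H, W), arbitrary example
print(GloAvgPool3d()(x).shape)             # torch.Size([2, 64, 1, 1, 1])
print(GloSptMaxPool3d()(x).shape)          # torch.Size([2, 64, 4, 1, 1])
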
/lib/modules/scale.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.parameter import Parameter
4 |
5 | class Scale2d(nn.Module):
6 | def __init__(self, out_channels):
7 | super(Scale2d, self).__init__()
8 | self.scale = Parameter(torch.Tensor(1, out_channels, 1, 1))
9 |
10 | def forward(self, input):
11 | return input * self.scale
12 |
13 | class Scale3d(nn.Module):
14 | def __init__(self, out_channels):
15 | super(Scale3d, self).__init__()
16 | self.scale = Parameter(torch.Tensor(1, out_channels, 1, 1, 1))
17 |
18 | def forward(self, input):
19 | return input * self.scale
--------------------------------------------------------------------------------
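
Scale2d and Scale3d are learnable per-channel multipliers. The parameter is allocated uninitialized (torch.Tensor), so the owning network is expected to initialize it, as the ResNet variants below do via nn.init.constant_. A minimal sketch, with an arbitrary channel count:

import torch
import torch.nn as nn
from lib.modules.scale import Scale3d  # assumes repo root on PYTHONPATH

sc = Scale3d(out_channels=8)           # 8 channels is an arbitrary example
nn.init.constant_(sc.scale, 1.0)       # start as the identity scaling
x = torch.randn(2, 8, 4, 7, 7)
assert torch.allclose(sc(x), x)        # a scale of 1 leaves the input unchanged
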
/lib/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .mnet2 import *
2 | from .mnet2_3d import *
3 | from .resnet import *
4 | from .resnet_3d import *
5 | from .part_inflate_resnet_3d import *
6 | from .resnet_3d_nodown import *
7 |
--------------------------------------------------------------------------------
/lib/networks/mnet2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import os
5 |
6 | __all__ = ['mnet2']
7 |
8 | def conv_bn(inp, oup, stride):
9 | return nn.Sequential(
10 | nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
11 | nn.BatchNorm2d(oup),
12 | nn.ReLU6(inplace=True)
13 | )
14 |
15 |
16 | def conv_1x1_bn(inp, oup):
17 | return nn.Sequential(
18 | nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
19 | nn.BatchNorm2d(oup),
20 | nn.ReLU6(inplace=True)
21 | )
22 |
23 |
24 | class InvertedResidual(nn.Module):
25 | def __init__(self, inp, oup, stride, expand_ratio):
26 | super(InvertedResidual, self).__init__()
27 | self.stride = stride
28 | assert stride in [1, 2]
29 |
30 | hidden_dim = round(inp * expand_ratio)
31 | self.use_res_connect = self.stride == 1 and inp == oup
32 |
33 | if expand_ratio == 1:
34 | self.conv = nn.Sequential(
35 | # dw
36 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride, padding=1, groups=hidden_dim, bias=False),
37 | nn.BatchNorm2d(hidden_dim),
38 | nn.ReLU6(inplace=True),
39 | # pw-linear
40 | nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, padding=0, bias=False),
41 | nn.BatchNorm2d(oup),
42 | )
43 | else:
44 | self.conv = nn.Sequential(
45 | # pw
46 | nn.Conv2d(inp, hidden_dim, kernel_size=1, stride=1, padding=0, bias=False),
47 | nn.BatchNorm2d(hidden_dim),
48 | nn.ReLU6(inplace=True),
49 | # dw
50 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride, padding=1, groups=hidden_dim, bias=False),
51 | nn.BatchNorm2d(hidden_dim),
52 | nn.ReLU6(inplace=True),
53 | # pw-linear
54 | nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, padding=0, bias=False),
55 | nn.BatchNorm2d(oup),
56 | )
57 |
58 | def forward(self, x):
59 | if self.use_res_connect:
60 | return x + self.conv(x)
61 | else:
62 | return self.conv(x)
63 |
64 |
65 | class MobileNetV2(nn.Module):
66 | def __init__(self, n_class=1000, input_size=224, width_mult=1., feat=False):
67 | super(MobileNetV2, self).__init__()
68 | self.feat = feat
69 | block = InvertedResidual
70 | input_channel = 32
71 | last_channel = 1280
72 | interverted_residual_setting = [
73 | # t, c, n, s
74 | [1, 16, 1, 1],
75 | [6, 24, 2, 2],
76 | [6, 32, 3, 2],
77 | [6, 64, 4, 2],
78 | [6, 96, 3, 1],
79 | [6, 160, 3, 2],
80 | [6, 320, 1, 1],
81 | ]
82 |
83 | # building first layer
84 | assert input_size % 32 == 0
85 | input_channel = int(input_channel * width_mult)
86 | self.feat_dim = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
87 | self.features = [conv_bn(3, input_channel, 2)]
88 | # building inverted residual blocks
89 | for t, c, n, s in interverted_residual_setting:
90 | output_channel = int(c * width_mult)
91 | for i in range(n):
92 | if i == 0:
93 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
94 | else:
95 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
96 | input_channel = output_channel
97 | # building last several layers
98 | self.features.append(conv_1x1_bn(input_channel, self.feat_dim))
99 | # make it nn.Sequential
100 | self.features = nn.Sequential(*self.features)
101 | self.avgpool = nn.AvgPool2d(7, stride=1)
102 |
103 | # building classifier
104 | if not self.feat:
105 | self.classifier = nn.Sequential(
106 | nn.Dropout(0.2),
107 | nn.Linear(self.feat_dim, n_class),
108 | )
109 |
110 | self._initialize_weights()
111 |
112 | def forward(self, x):
113 | x = self.features(x)
114 | x = self.avgpool(x)
115 | x = x.view(x.size(0), -1)
116 | if not self.feat:
117 | x = self.classifier(x)
118 | return x
119 |
120 | def _initialize_weights(self):
121 | for m in self.modules():
122 | if isinstance(m, nn.Conv2d):
123 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
124 | m.weight.data.normal_(0, math.sqrt(2. / n))
125 | if m.bias is not None:
126 | m.bias.data.zero_()
127 | elif isinstance(m, nn.BatchNorm2d):
128 | m.weight.data.fill_(1)
129 | m.bias.data.zero_()
130 | elif isinstance(m, nn.Linear):
131 | n = m.weight.size(1)
132 | m.weight.data.normal_(0, 0.01)
133 | m.bias.data.zero_()
134 |
135 | def part_state_dict(state_dict, model_dict):
136 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
137 | model_dict.update(pretrained_dict)
138 | return model_dict
139 |
140 | def mnet2(pretrained=None, feat=False):
141 | if feat:
142 |         assert pretrained is not None and os.path.exists(pretrained), "a pretrained model path must be provided when feat is used."
143 | model = MobileNetV2(feat=feat)
144 | if feat:
145 | state_dict = part_state_dict(torch.load(pretrained, map_location=lambda storage, loc: storage),
146 | model.state_dict())
147 | model.load_state_dict(state_dict)
148 | return model
149 |
--------------------------------------------------------------------------------
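
A minimal sketch of the mnet2 factory above. feat=True requires an on-disk checkpoint, so the sketch builds the plain classifier head instead; the 224x224 input matches the assert in MobileNetV2.__init__.

import torch
from lib.networks.mnet2 import mnet2   # assumes repo root on PYTHONPATH

model = mnet2(pretrained=None, feat=False)   # ImageNet-style 1000-way classifier head
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)                          # torch.Size([1, 1000])
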
/lib/networks/mnet2_3d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import os
5 |
6 | __all__ = ["mnet2_3d"]
7 |
8 | def conv_bn(inp, oup, stride, t_stride=1):
9 | return nn.Sequential(
10 | nn.Conv3d(inp, oup, kernel_size=(1, 3, 3),
11 | stride=(t_stride, stride, stride), padding=(0, 1, 1), bias=False),
12 | nn.BatchNorm3d(oup),
13 | nn.ReLU6(inplace=True)
14 | )
15 |
16 |
17 | def conv_1x1x1_bn(inp, oup):
18 | return nn.Sequential(
19 | nn.Conv3d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
20 | nn.BatchNorm3d(oup),
21 | nn.ReLU6(inplace=True)
22 | )
23 |
24 |
25 | class InvertedResidual(nn.Module):
26 | def __init__(self, inp, oup, stride, t_stride, expand_ratio, t_radius=1):
27 | super(InvertedResidual, self).__init__()
28 | self.stride = stride
29 | self.t_stride = t_stride
30 | self.t_radius = t_radius
31 | assert stride in [1, 2] and t_stride in [1, 2]
32 |
33 | hidden_dim = round(inp * expand_ratio)
34 | self.use_res_connect = self.stride == 1 and inp == oup
35 |
36 | if expand_ratio == 1:
37 | assert(t_stride == 1), "Temporal stride must be one when expand ratio is one."
38 | self.conv = nn.Sequential(
39 | # dw
40 | nn.Conv3d(hidden_dim, hidden_dim, kernel_size=(1, 3, 3), stride=(t_stride, stride, stride),
41 | padding=(0, 1, 1), groups=hidden_dim, bias=False),
42 | nn.BatchNorm3d(hidden_dim),
43 | nn.ReLU6(inplace=True),
44 | # pw-linear
45 | nn.Conv3d(hidden_dim, oup, kernel_size=1, stride=1, padding=0, bias=False),
46 | nn.BatchNorm3d(oup),
47 | )
48 | else:
49 | self.conv = nn.Sequential(
50 | # pw
51 | nn.Conv3d(inp, hidden_dim, kernel_size=(t_radius * 2 + 1, 1, 1),
52 | stride=(t_stride, 1, 1), padding=(t_radius, 0, 0), bias=False),
53 | nn.BatchNorm3d(hidden_dim),
54 | nn.ReLU6(inplace=True),
55 | # dw
56 | nn.Conv3d(hidden_dim, hidden_dim, kernel_size=(1, 3, 3), stride=(1, stride, stride),
57 | padding=(0, 1, 1), groups=hidden_dim, bias=False),
58 | nn.BatchNorm3d(hidden_dim),
59 | nn.ReLU6(inplace=True),
60 | # pw-linear
61 | nn.Conv3d(hidden_dim, oup, kernel_size=1, stride=1, padding=0, bias=False),
62 | nn.BatchNorm3d(oup),
63 | )
64 |
65 | def forward(self, x):
66 | if self.use_res_connect:
67 | return x + self.conv(x)
68 | else:
69 | return self.conv(x)
70 |
71 |
72 | class MobileNetV2_3D(nn.Module):
73 | def __init__(self, n_class=1000, input_size=224, width_mult=1., feat=False):
74 | super(MobileNetV2_3D, self).__init__()
75 | self.feat = feat
76 | block = InvertedResidual
77 | input_channel = 32
78 | last_channel = 1280
79 | interverted_residual_setting = [
80 | # t, c, n, s, ts, r
81 | [1, 16, 1, 1, 1, 0],
82 | [6, 24, 2, 2, 1, 0],
83 | [6, 32, 3, 2, 1, 0],
84 | [6, 64, 4, 2, 1, 1],
85 | [6, 96, 3, 1, 2, 1],
86 | [6, 160, 3, 2, 2, 1],
87 | [6, 320, 1, 1, 1, 1],
88 | ]
89 |
90 | # building first layer
91 | assert input_size % 32 == 0
92 | input_channel = int(input_channel * width_mult)
93 | self.feat_dim = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
94 | self.features = [conv_bn(3, input_channel, 2)]
95 | # building inverted residual blocks
96 | for t, c, n, s, ts, r in interverted_residual_setting:
97 | output_channel = int(c * width_mult)
98 | for i in range(n):
99 | if i == 0:
100 | self.features.append(block(input_channel, output_channel, s, ts, expand_ratio=t, t_radius=r))
101 | else:
102 | self.features.append(block(input_channel, output_channel, 1, 1, expand_ratio=t, t_radius=r))
103 | input_channel = output_channel
104 | # building last several layers
105 | self.features.append(conv_1x1x1_bn(input_channel, self.feat_dim))
106 | # make it nn.Sequential
107 | self.features = nn.Sequential(*self.features)
108 | self.avgpool = nn.AvgPool3d(kernel_size=(4, 7, 7), stride=1)
109 |
110 | # building classifier
111 | if not self.feat:
112 | self.classifier = nn.Sequential(
113 | nn.Dropout(0.2),
114 | nn.Linear(self.feat_dim, n_class),
115 | )
116 |
117 | self._initialize_weights()
118 |
119 | def forward(self, x):
120 | x = self.features(x)
121 | x = self.avgpool(x)
122 | x = x.view(x.size(0), -1)
123 | if not self.feat:
124 | x = self.classifier(x)
125 | return x
126 |
127 | def _initialize_weights(self):
128 | for m in self.modules():
129 | if isinstance(m, nn.Conv3d):
130 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
131 | m.weight.data.normal_(0, math.sqrt(2. / n))
132 | if m.bias is not None:
133 | m.bias.data.zero_()
134 | elif isinstance(m, nn.BatchNorm3d):
135 | m.weight.data.fill_(1)
136 | m.bias.data.zero_()
137 | elif isinstance(m, nn.Linear):
138 | n = m.weight.size(1)
139 | m.weight.data.normal_(0, 0.01)
140 | m.bias.data.zero_()
141 |
142 | def part_state_dict(state_dict, model_dict):
143 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
144 | pretrained_dict = inflate_state_dict(pretrained_dict, model_dict)
145 | model_dict.update(pretrained_dict)
146 | return model_dict
147 |
148 |
149 | def inflate_state_dict(pretrained_dict, model_dict):
150 | for k in pretrained_dict.keys():
151 | if pretrained_dict[k].size() != model_dict[k].size():
152 | assert(pretrained_dict[k].size()[:2] == model_dict[k].size()[:2]), \
153 | "To inflate, channel number should match."
154 | assert(pretrained_dict[k].size()[-2:] == model_dict[k].size()[-2:]), \
155 | "To inflate, spatial kernel size should match."
156 | print("Layer {} needs inflation.".format(k))
157 | shape = list(pretrained_dict[k].shape)
158 | shape.insert(2, 1)
159 | t_length = model_dict[k].shape[2]
160 | pretrained_dict[k] = pretrained_dict[k].reshape(shape)
161 | if t_length != 1:
162 | pretrained_dict[k] = pretrained_dict[k].expand_as(model_dict[k]) / t_length
163 | assert(pretrained_dict[k].size() == model_dict[k].size()), \
164 | "After inflation, model shape should match."
165 | return pretrained_dict
166 |
167 | def mnet2_3d(pretrained=None, feat=False):
168 |     if pretrained is not None:
169 |         assert os.path.exists(pretrained), "pretrained model does not exist."
170 | model = MobileNetV2_3D(feat=feat)
171 | if pretrained:
172 | state_dict = torch.load(pretrained, map_location=lambda storage, loc: storage)
173 | state_dict = part_state_dict(state_dict, model.state_dict())
174 | model.load_state_dict(state_dict)
175 | return model
176 |
--------------------------------------------------------------------------------
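
A toy illustration of the inflation rule implemented by inflate_state_dict above: a pretrained 2D kernel of shape (O, I, kh, kw) is reshaped to (O, I, 1, kh, kw), replicated along the new temporal axis and divided by its length, so the inflated 3D convolution initially reproduces the 2D response on temporally constant input. The sizes below are arbitrary.

import torch

w2d = torch.randn(8, 3, 3, 3)        # pretrained 2D kernel (O, I, kh, kw), arbitrary sizes
t_length = 3                         # temporal extent of the target 3D kernel
w3d = w2d.reshape(8, 3, 1, 3, 3).expand(8, 3, t_length, 3, 3) / t_length
assert torch.allclose(w3d.sum(dim=2), w2d)   # summing over time recovers the 2D kernel
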
/lib/networks/part_inflate_resnet_3d.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from the original file so the classes support feature extraction.
3 | """
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn.parameter import Parameter
7 | import torch.nn.functional as F
8 | import math
9 | import torch.utils.model_zoo as model_zoo
10 | from ..modules import *
11 |
12 |
13 | __all__ = ["pib_resnet26_3d_v1", "pib_resnet50_3d_slow", "pib_resnet26_3d_v1_1", "pib_resnet26_3d_full", "pib_resnet26_2d_full"]
14 |
15 | model_urls = {
16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
21 | }
22 |
23 | class Bottleneck3D_000(nn.Module):
24 | expansion = 4
25 |
26 | def __init__(self, inplanes, planes, stride=1, t_stride=1, downsample=None):
27 | super(Bottleneck3D_000, self).__init__()
28 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1,
29 | stride=[t_stride, 1, 1], bias=False)
30 | self.bn1 = nn.BatchNorm3d(planes)
31 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3),
32 | stride=[1, stride, stride], padding=(0, 1, 1), bias=False)
33 | self.bn2 = nn.BatchNorm3d(planes)
34 | self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False)
35 | self.bn3 = nn.BatchNorm3d(planes * self.expansion)
36 | self.relu = nn.ReLU(inplace=True)
37 | self.downsample = downsample
38 | self.stride = stride
39 |
40 | def forward(self, x):
41 | residual = x
42 |
43 | out = self.conv1(x)
44 | out = self.bn1(out)
45 | out = self.relu(out)
46 |
47 | out = self.conv2(out)
48 | out = self.bn2(out)
49 | out = self.relu(out)
50 |
51 | out = self.conv3(out)
52 | out = self.bn3(out)
53 |
54 | if self.downsample is not None:
55 | residual = self.downsample(x)
56 |
57 | out += residual
58 | out = self.relu(out)
59 |
60 | return out
61 |
62 | class PIBottleneck3D(nn.Module):
63 | expansion = 4
64 |
65 | def __init__(self, inplanes, planes, ratio=0.5, stride=1, t_stride=1, downsample=None):
66 | super(PIBottleneck3D, self).__init__()
67 | self.ratio = ratio
68 | if ratio == 1:
69 | self.conv1_t = nn.Conv3d(inplanes, planes,
70 | kernel_size=(3, 1, 1),
71 | stride=(t_stride, 1, 1),
72 | padding=(1, 0, 0),
73 | bias=False)
74 | elif ratio == 0:
75 | self.conv1_p = nn.Conv3d(inplanes, planes,
76 | kernel_size=(1, 1, 1),
77 | stride=(t_stride, 1, 1),
78 | padding=(0, 0, 0),
79 | bias=False)
80 | else:
81 | self.conv1_t = nn.Conv3d(inplanes, int(planes * ratio),
82 | kernel_size=(3, 1, 1),
83 | stride=(t_stride, 1, 1),
84 | padding=(1, 0, 0),
85 | bias=False)
86 | self.conv1_p = nn.Conv3d(inplanes, int(planes*(1-ratio)),
87 | kernel_size=(1, 1, 1),
88 | stride=(t_stride, 1, 1),
89 | padding=(0, 0, 0),
90 | bias=False)
91 | self.bn1 = nn.BatchNorm3d(planes)
92 | self.conv2 = nn.Conv3d(planes, planes,
93 | kernel_size=(1, 3, 3),
94 | stride=(1, stride, stride),
95 | padding=(0, 1, 1),
96 | bias=False)
97 | self.bn2 = nn.BatchNorm3d(planes)
98 | self.conv3 = nn.Conv3d(planes, planes * self.expansion,
99 | kernel_size=1,
100 | bias=False)
101 | self.bn3 = nn.BatchNorm3d(planes * self.expansion)
102 | self.relu = nn.ReLU(inplace=True)
103 | self.downsample = downsample
104 | self.stride = stride
105 |
106 | def forward(self, x):
107 | residual = x
108 |
109 | if self.ratio == 1:
110 | out = self.conv1_t(x)
111 | elif self.ratio == 0:
112 | out = self.conv1_p(x)
113 | else:
114 | out_t = self.conv1_t(x)
115 | out_p = self.conv1_p(x)
116 | out = torch.cat((out_t, out_p), dim=1)
117 | out = self.bn1(out)
118 | out = self.relu(out)
119 |
120 | out = self.conv2(out)
121 | out = self.bn2(out)
122 | out = self.relu(out)
123 |
124 | out = self.conv3(out)
125 | out = self.bn3(out)
126 |
127 | if self.downsample is not None:
128 | residual = self.downsample(x)
129 |
130 | out += residual
131 | out = self.relu(out)
132 |
133 | return out
134 |
135 | class PIBResNet3D_8fr(nn.Module):
136 |
137 | def __init__(self, block, layers, ratios, num_classes=1000, feat=False, **kwargs):
138 | if not isinstance(block, list):
139 | block = [block] * 4
140 | else:
141 |             assert len(block) == 4, "Block number must be 4 for ResNet-style networks."
142 | self.inplanes = 64
143 | super(PIBResNet3D_8fr, self).__init__()
144 | self.feat = feat
145 | self.conv1 = nn.Conv3d(3, 64,
146 | kernel_size=(1, 7, 7),
147 | stride=(1, 2, 2),
148 | padding=(0, 3, 3),
149 | bias=False)
150 | self.bn1 = nn.BatchNorm3d(64)
151 | self.relu = nn.ReLU(inplace=True)
152 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3),
153 | stride=(1, 2, 2),
154 | padding=(0, 1, 1))
155 | self.layer1 = self._make_layer(block[0], 64, layers[0], inf_ratio=ratios[0])
156 | self.layer2 = self._make_layer(block[1], 128, layers[1], inf_ratio=ratios[1], stride=2)
157 | self.layer3 = self._make_layer(block[2], 256, layers[2], inf_ratio=ratios[2], stride=2, t_stride=2)
158 | self.layer4 = self._make_layer(block[3], 512, layers[3], inf_ratio=ratios[3], stride=2, t_stride=2)
159 | self.avgpool = GloAvgPool3d()
160 | self.feat_dim = 512 * block[0].expansion
161 | if not feat:
162 | self.fc = nn.Linear(512 * block[0].expansion, num_classes)
163 |
164 | for n, m in self.named_modules():
165 | if isinstance(m, nn.Conv3d):
166 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
167 | elif isinstance(m, nn.BatchNorm3d) and "conv_t" not in n:
168 | nn.init.constant_(m.weight, 1)
169 | nn.init.constant_(m.bias, 0)
170 | elif isinstance(m, Scale3d):
171 | nn.init.constant_(m.scale, 0)
172 |
173 |
174 | def _make_layer(self, block, planes, blocks, inf_ratio, stride=1, t_stride=1):
175 | downsample = None
176 | if stride != 1 or self.inplanes != planes * block.expansion:
177 | downsample = nn.Sequential(
178 | nn.Conv3d(self.inplanes, planes * block.expansion,
179 | kernel_size=1, stride=(t_stride, stride, stride), bias=False),
180 | nn.BatchNorm3d(planes * block.expansion),
181 | )
182 |
183 | layers = []
184 | layers.append(block(self.inplanes, planes, inf_ratio, stride=stride, t_stride=t_stride, downsample=downsample))
185 | self.inplanes = planes * block.expansion
186 | for i in range(1, blocks):
187 | layers.append(block(self.inplanes, planes, inf_ratio))
188 |
189 | return nn.Sequential(*layers)
190 |
191 | def forward(self, x):
192 | x = self.conv1(x)
193 | x = self.bn1(x)
194 | x = self.relu(x)
195 | x = self.maxpool(x)
196 |
197 | x = self.layer1(x)
198 | x = self.layer2(x)
199 | x = self.layer3(x)
200 | x = self.layer4(x)
201 |
202 |
203 | x = self.avgpool(x)
204 | x = x.view(x.size(0), -1)
205 | if not self.feat:
206 | x = self.fc(x)
207 |
208 | return x
209 |
210 |
211 | def part_state_dict(state_dict, model_dict, ratios):
212 |     assert len(ratios) == 4, "Length of ratios must equal the number of stages."
213 | added_dict = {}
214 | for k, v in state_dict.items():
215 | # import pdb
216 | # pdb.set_trace()
217 | if ".conv1.weight" in k and "layer" in k:
218 | # import pdb
219 | # pdb.set_trace()
220 | ratio = ratios[int(k[k.index("layer")+5])-1]
221 | out_channels = v.shape[0]
222 | slice_index = int(out_channels*ratio)
223 | if ratio == 1:
224 | new_k = k[:k.index(".conv1.weight")]+'.conv1_t.weight'
225 | added_dict.update({new_k: v[:slice_index,...]})
226 | elif ratio == 0:
227 | new_k = k[:k.index(".conv1.weight")]+'.conv1_p.weight'
228 | added_dict.update({new_k: v[slice_index:,...]})
229 | else:
230 | new_k = k[:k.index(".conv1.weight")]+'.conv1_t.weight'
231 | added_dict.update({new_k: v[:slice_index,...]})
232 | new_k = k[:k.index(".conv1.weight")]+'.conv1_p.weight'
233 | added_dict.update({new_k: v[slice_index:,...]})
234 |
235 | state_dict.update(added_dict)
236 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
237 | pretrained_dict = inflate_state_dict(pretrained_dict, model_dict)
238 | model_dict.update(pretrained_dict)
239 | return model_dict
240 |
241 |
242 | def inflate_state_dict(pretrained_dict, model_dict):
243 | for k in pretrained_dict.keys():
244 | if pretrained_dict[k].size() != model_dict[k].size():
245 | assert(pretrained_dict[k].size()[:2] == model_dict[k].size()[:2]), \
246 | "To inflate, channel number should match."
247 | assert(pretrained_dict[k].size()[-2:] == model_dict[k].size()[-2:]), \
248 | "To inflate, spatial kernel size should match."
249 | print("Layer {} needs inflation.".format(k))
250 | shape = list(pretrained_dict[k].shape)
251 | shape.insert(2, 1)
252 | t_length = model_dict[k].shape[2]
253 | pretrained_dict[k] = pretrained_dict[k].reshape(shape)
254 | if t_length != 1:
255 | pretrained_dict[k] = pretrained_dict[k].expand_as(model_dict[k]) / t_length
256 | assert(pretrained_dict[k].size() == model_dict[k].size()), \
257 | "After inflation, model shape should match."
258 |
259 | return pretrained_dict
260 |
261 | def pib_resnet26_3d_v1(pretrained=False, feat=False, **kwargs):
262 |     """Constructs a partially inflated 3D ResNet-26 model (v1).
263 | Args:
264 | pretrained (bool): If True, returns a model pre-trained on ImageNet
265 | """
266 | ratios = (1/8, 1/4, 1/2, 1)
267 | model = PIBResNet3D_8fr([PIBottleneck3D, PIBottleneck3D, PIBottleneck3D, PIBottleneck3D],
268 | [2, 2, 2, 2], ratios, feat=feat, **kwargs)
269 | if pretrained:
270 | if kwargs['pretrained_model'] is None:
271 | pass
272 | # state_dict = model_zoo.load_url(model_urls['resnet50'])
273 | else:
274 | print("Using specified pretrain model")
275 | state_dict = kwargs['pretrained_model']
276 | if feat:
277 | new_state_dict = part_state_dict(state_dict, model.state_dict(), ratios)
278 | model.load_state_dict(new_state_dict)
279 | return model
280 |
281 | def pib_resnet26_3d_full(pretrained=False, feat=False, **kwargs):
282 |     """Constructs a fully inflated 3D ResNet-26 model.
283 | Args:
284 | pretrained (bool): If True, returns a model pre-trained on ImageNet
285 | """
286 | ratios = (1, 1, 1, 1)
287 | model = PIBResNet3D_8fr([PIBottleneck3D, PIBottleneck3D, PIBottleneck3D, PIBottleneck3D],
288 | [2, 2, 2, 2], ratios, feat=feat, **kwargs)
289 | if pretrained:
290 | if kwargs['pretrained_model'] is None:
291 | pass
292 | # state_dict = model_zoo.load_url(model_urls['resnet50'])
293 | else:
294 | print("Using specified pretrain model")
295 | state_dict = kwargs['pretrained_model']
296 | if feat:
297 | new_state_dict = part_state_dict(state_dict, model.state_dict(), ratios)
298 | model.load_state_dict(new_state_dict)
299 | return model
300 |
301 | def pib_resnet26_2d_full(pretrained=False, feat=False, **kwargs):
302 |     """Constructs a ResNet-26 model with no temporal inflation (effectively 2D).
303 | Args:
304 | pretrained (bool): If True, returns a model pre-trained on ImageNet
305 | """
306 | ratios = (0, 0, 0, 0)
307 | model = PIBResNet3D_8fr([PIBottleneck3D, PIBottleneck3D, PIBottleneck3D, PIBottleneck3D],
308 | [2, 2, 2, 2], ratios, feat=feat, **kwargs)
309 | if pretrained:
310 | if kwargs['pretrained_model'] is None:
311 | pass
312 | # state_dict = model_zoo.load_url(model_urls['resnet50'])
313 | else:
314 | print("Using specified pretrain model")
315 | state_dict = kwargs['pretrained_model']
316 | if feat:
317 | new_state_dict = part_state_dict(state_dict, model.state_dict(), ratios)
318 | model.load_state_dict(new_state_dict)
319 | return model
320 |
321 | def pib_resnet26_3d_v1_1(pretrained=False, feat=False, **kwargs):
322 |     """Constructs a partially inflated 3D ResNet-26 model (v1_1, half of the channels inflated in every stage).
323 | Args:
324 | pretrained (bool): If True, returns a model pre-trained on ImageNet
325 | """
326 | ratios = (1/2, 1/2, 1/2, 1/2)
327 | model = PIBResNet3D_8fr([PIBottleneck3D, PIBottleneck3D, PIBottleneck3D, PIBottleneck3D],
328 | [2, 2, 2, 2], ratios, feat=feat, **kwargs)
329 | if pretrained:
330 | if kwargs['pretrained_model'] is None:
331 | pass
332 | # state_dict = model_zoo.load_url(model_urls['resnet50'])
333 | else:
334 | print("Using specified pretrain model")
335 | state_dict = kwargs['pretrained_model']
336 | if feat:
337 | new_state_dict = part_state_dict(state_dict, model.state_dict(), ratios)
338 | model.load_state_dict(new_state_dict)
339 | return model
340 |
341 | def pib_resnet50_3d_slow(pretrained=False, feat=False, **kwargs):
342 |     """Constructs a partially inflated 3D ResNet-50 model (slow variant).
343 | Args:
344 | pretrained (bool): If True, returns a model pre-trained on ImageNet
345 | """
346 | ratios = (0, 0, 1, 1)
347 | model = PIBResNet3D_8fr([PIBottleneck3D, PIBottleneck3D, PIBottleneck3D, PIBottleneck3D],
348 | [3, 4, 6, 3], ratios, feat=feat, **kwargs)
349 | if pretrained:
350 | if kwargs['pretrained_model'] is None:
351 | state_dict = model_zoo.load_url(model_urls['resnet50'])
352 | else:
353 | print("Using specified pretrain model")
354 | state_dict = kwargs['pretrained_model']
355 | if feat:
356 | new_state_dict = part_state_dict(state_dict, model.state_dict(), ratios)
357 | model.load_state_dict(new_state_dict)
358 | return model
--------------------------------------------------------------------------------
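
A minimal sketch of the channel split performed by part_state_dict above for a partially inflated stage: with ratio r, the first r * out_channels filters of a pretrained layerX.Y.conv1 weight seed the temporal branch (conv1_t, later inflated to a 3x1x1 kernel by inflate_state_dict), and the remainder seed the point-wise branch (conv1_p). The layer name and sizes here are illustrative only.

import torch

ratio = 0.5
v = torch.randn(64, 64, 1, 1)            # stands in for a pretrained 'layer1.0.conv1.weight'
slice_index = int(v.shape[0] * ratio)
conv1_t = v[:slice_index, ...]           # (32, 64, 1, 1) -> later inflated to (32, 64, 3, 1, 1)
conv1_p = v[slice_index:, ...]           # (32, 64, 1, 1), stays point-wise
print(conv1_t.shape, conv1_p.shape)
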
/lib/networks/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from the original file so the classes support feature extraction.
3 | """
4 |
5 | import torch.nn as nn
6 | import math
7 | import torch.utils.model_zoo as model_zoo
8 | import torch
9 | from torch.nn.parameter import Parameter
10 | from ..modules import *
11 |
12 |
13 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet26', 'resnet26_point', 'resnet50', 'resnet101',
14 | 'resnet152']
15 |
16 |
17 | model_urls = {
18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
23 | }
24 |
25 |
26 | def conv3x3(in_planes, out_planes, stride=1):
27 | """3x3 convolution with padding"""
28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
29 | padding=1, bias=False)
30 |
31 | # class Scale2d(nn.Module):
32 | # def __init__(self, out_channels):
33 | # super(Scale2d, self).__init__()
34 | # self.scale = Parameter(torch.Tensor(1, out_channels, 1, 1))
35 |
36 | # def forward(self, input):
37 | # return input * self.scale
38 |
39 | class BasicBlock(nn.Module):
40 | expansion = 1
41 |
42 | def __init__(self, inplanes, planes, stride=1, downsample=None):
43 | super(BasicBlock, self).__init__()
44 | self.conv1 = conv3x3(inplanes, planes, stride)
45 | self.bn1 = nn.BatchNorm2d(planes)
46 | self.relu = nn.ReLU(inplace=True)
47 | self.conv2 = conv3x3(planes, planes)
48 | self.bn2 = nn.BatchNorm2d(planes)
49 | self.downsample = downsample
50 | self.stride = stride
51 |
52 | def forward(self, x):
53 | residual = x
54 |
55 | out = self.conv1(x)
56 | out = self.bn1(out)
57 | out = self.relu(out)
58 |
59 | out = self.conv2(out)
60 | out = self.bn2(out)
61 |
62 | if self.downsample is not None:
63 | residual = self.downsample(x)
64 |
65 | out += residual
66 | out = self.relu(out)
67 |
68 | return out
69 |
70 |
71 | class Bottleneck(nn.Module):
72 | expansion = 4
73 |
74 | def __init__(self, inplanes, planes, stride=1, downsample=None):
75 | super(Bottleneck, self).__init__()
76 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
77 | self.bn1 = nn.BatchNorm2d(planes)
78 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
79 | padding=1, bias=False)
80 | self.bn2 = nn.BatchNorm2d(planes)
81 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
82 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
83 | self.relu = nn.ReLU(inplace=True)
84 | self.downsample = downsample
85 | self.stride = stride
86 |
87 | def forward(self, x):
88 | residual = x
89 |
90 | out = self.conv1(x)
91 | out = self.bn1(out)
92 | out = self.relu(out)
93 |
94 | out = self.conv2(out)
95 | out = self.bn2(out)
96 | out = self.relu(out)
97 |
98 | out = self.conv3(out)
99 | out = self.bn3(out)
100 |
101 | if self.downsample is not None:
102 | residual = self.downsample(x)
103 |
104 | out += residual
105 | out = self.relu(out)
106 |
107 | return out
108 |
109 | class PointBottleneck(nn.Module):
110 | expansion = 4
111 |
112 | def __init__(self, inplanes, planes, stride=1, downsample=None):
113 | super(PointBottleneck, self).__init__()
114 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
115 | self.bn1 = nn.BatchNorm2d(planes)
116 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
117 | padding=1, bias=False)
118 | self.bn2 = nn.BatchNorm2d(planes)
119 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
120 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
121 | self.relu = nn.ReLU(inplace=True)
122 | self.downsample = downsample
123 | if self.downsample is None:
124 | self.conv_p = nn.Conv2d(inplanes, planes * self.expansion, kernel_size=1, bias=False)
125 | self.stride = stride
126 |
127 | def forward(self, x):
128 | residual = x
129 |
130 | out = self.conv1(x)
131 | out = self.bn1(out)
132 | out = self.relu(out)
133 |
134 | out = self.conv2(out)
135 | out = self.bn2(out)
136 | out = self.relu(out)
137 |
138 | out = self.conv3(out)
139 | out = self.bn3(out)
140 |
141 | if self.downsample is not None:
142 | residual = self.downsample(x)
143 | else:
144 | residual = self.conv_p(x)
145 |
146 | out += residual
147 | out = self.relu(out)
148 |
149 | return out
150 |
151 | class SCBottleneck(nn.Module):
152 | expansion = 4
153 |
154 | def __init__(self, inplanes, planes, stride=1, downsample=None):
155 | super(SCBottleneck, self).__init__()
156 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
157 | self.bn1 = nn.BatchNorm2d(planes)
158 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
159 | padding=1, bias=False)
160 | self.bn2 = nn.BatchNorm2d(planes)
161 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
162 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
163 | self.relu = nn.ReLU(inplace=True)
164 | self.sc = Scale2d(out_channels=inplanes)
165 | self.downsample = downsample
166 | self.stride = stride
167 |
168 | def forward(self, x):
169 | # residual = x
170 |
171 | out = self.conv1(x)
172 | out = self.bn1(out)
173 | out = self.relu(out)
174 |
175 | out = self.conv2(out)
176 | out = self.bn2(out)
177 | out = self.relu(out)
178 |
179 | out = self.conv3(out)
180 | out = self.bn3(out)
181 |
182 | # if residual.device == torch.device('cuda:0'):
183 | # print(self.sc.scale.view(-1)[:20].data)
184 | residual = self.sc(x)
185 | if self.downsample is not None:
186 | residual = self.downsample(residual)
187 |
188 | out += residual
189 | out = self.relu(out)
190 |
191 | return out
192 |
193 |
194 | class ResNet(nn.Module):
195 |
196 | def __init__(self, block, layers, num_classes=1000, feat=False):
197 | self.inplanes = 64
198 | super(ResNet, self).__init__()
199 | self.feat = feat
200 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
201 | bias=False)
202 | self.bn1 = nn.BatchNorm2d(64)
203 | self.relu = nn.ReLU(inplace=True)
204 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
205 | self.layer1 = self._make_layer(block, 64, layers[0])
206 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
207 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
208 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
209 | self.avgpool = nn.AvgPool2d(7, stride=1)
210 | self.feat_dim = 512 * block.expansion
211 | if not feat:
212 | self.fc = nn.Linear(512 * block.expansion, num_classes)
213 |
214 | for m in self.modules():
215 | if isinstance(m, nn.Conv2d):
216 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
217 | elif isinstance(m, nn.BatchNorm2d):
218 | nn.init.constant_(m.weight, 1)
219 | nn.init.constant_(m.bias, 0)
220 | elif isinstance(m, Scale2d):
221 | nn.init.constant_(m.scale, 1)
222 |
223 | def _make_layer(self, block, planes, blocks, stride=1):
224 | downsample = None
225 | if stride != 1 or self.inplanes != planes * block.expansion:
226 | downsample = nn.Sequential(
227 | nn.Conv2d(self.inplanes, planes * block.expansion,
228 | kernel_size=1, stride=stride, bias=False),
229 | nn.BatchNorm2d(planes * block.expansion),
230 | )
231 |
232 | layers = []
233 | layers.append(block(self.inplanes, planes, stride, downsample))
234 | self.inplanes = planes * block.expansion
235 | for i in range(1, blocks):
236 | layers.append(block(self.inplanes, planes))
237 |
238 | return nn.Sequential(*layers)
239 |
240 | def forward(self, x):
241 | x = self.conv1(x)
242 | x = self.bn1(x)
243 | x = self.relu(x)
244 | x = self.maxpool(x)
245 |
246 | x = self.layer1(x)
247 | x = self.layer2(x)
248 | x = self.layer3(x)
249 | x = self.layer4(x)
250 |
251 | x = self.avgpool(x)
252 | x = x.view(x.size(0), -1)
253 | if not self.feat:
254 | x = self.fc(x)
255 |
256 | return x
257 |
258 |
259 | def part_state_dict(state_dict, model_dict):
260 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
261 | model_dict.update(pretrained_dict)
262 | return model_dict
263 |
264 |
265 | def resnet18(pretrained=False, feat=False, **kwargs):
266 | """Constructs a ResNet-18 model.
267 | Args:
268 | pretrained (bool): If True, returns a model pre-trained on ImageNet
269 | """
270 | model = ResNet(BasicBlock, [2, 2, 2, 2], feat=feat, **kwargs)
271 | if feat:
272 | state_dict = part_state_dict(model_zoo.load_url(model_urls['resnet18']), model.state_dict())
273 | if pretrained:
274 | model.load_state_dict(state_dict)
275 | return model
276 |
277 |
278 | def resnet34(pretrained=False, feat=False, **kwargs):
279 | """Constructs a ResNet-34 model.
280 | Args:
281 | pretrained (bool): If True, returns a model pre-trained on ImageNet
282 | """
283 | model = ResNet(BasicBlock, [3, 4, 6, 3], feat=feat, **kwargs)
284 | if feat:
285 | state_dict = part_state_dict(model_zoo.load_url(model_urls['resnet34']), model.state_dict())
286 | if pretrained:
287 | model.load_state_dict(state_dict)
288 | return model
289 |
290 | def resnet26(pretrained=False, feat=False, **kwargs):
291 |     """Constructs a ResNet-26 model.
292 | Args:
293 | pretrained (bool): If True, returns a model pre-trained on ImageNet
294 | """
295 | model = ResNet(Bottleneck, [2, 2, 2, 2], feat=feat, **kwargs)
296 | return model
297 |
298 | def resnet26_sc(pretrained=False, feat=False, **kwargs):
299 |     """Constructs a ResNet-26 model built from SCBottleneck blocks.
300 | Args:
301 | pretrained (bool): If True, returns a model pre-trained on ImageNet
302 | """
303 | model = ResNet(SCBottleneck, [2, 2, 2, 2], feat=feat, **kwargs)
304 | return model
305 |
306 | def resnet26_point(pretrained=False, feat=False, **kwargs):
307 |     """Constructs a ResNet-26 model built from PointBottleneck blocks.
308 | Args:
309 | pretrained (bool): If True, returns a model pre-trained on ImageNet
310 | """
311 | model = ResNet(PointBottleneck, [2, 2, 2, 2], feat=feat, **kwargs)
312 | return model
313 |
314 | def resnet50(pretrained=False, feat=False, **kwargs):
315 | """Constructs a ResNet-50 model.
316 | Args:
317 | pretrained (bool): If True, returns a model pre-trained on ImageNet
318 | """
319 | model = ResNet(Bottleneck, [3, 4, 6, 3], feat=feat, **kwargs)
320 | if feat:
321 | state_dict = part_state_dict(model_zoo.load_url(model_urls['resnet50']), model.state_dict())
322 | if pretrained:
323 | model.load_state_dict(state_dict)
324 | return model
325 |
326 |
327 | def resnet101(pretrained=False, feat=False, **kwargs):
328 | """Constructs a ResNet-101 model.
329 | Args:
330 | pretrained (bool): If True, returns a model pre-trained on ImageNet
331 | """
332 | model = ResNet(Bottleneck, [3, 4, 23, 3], feat=feat, **kwargs)
333 | if feat:
334 | state_dict = part_state_dict(model_zoo.load_url(model_urls['resnet101']), model.state_dict())
335 | if pretrained:
336 | model.load_state_dict(state_dict)
337 | return model
338 |
339 |
340 | def resnet152(pretrained=False, feat=False, **kwargs):
341 | """Constructs a ResNet-152 model.
342 | Args:
343 | pretrained (bool): If True, returns a model pre-trained on ImageNet
344 | """
345 | model = ResNet(Bottleneck, [3, 8, 36, 3], feat=feat, **kwargs)
346 | if feat:
347 | state_dict = part_state_dict(model_zoo.load_url(model_urls['resnet152']), model.state_dict())
348 | if pretrained:
349 | model.load_state_dict(state_dict)
350 | return model
351 |
--------------------------------------------------------------------------------
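
A minimal sketch of using the 2D backbones above as feature extractors; resnet26 has no checkpoint wired up, so it is built from scratch, and the 224x224 input matches the fixed AvgPool2d(7) at the end of ResNet.

import torch
from lib.networks.resnet import resnet26   # assumes repo root on PYTHONPATH

backbone = resnet26(feat=True)               # no fc head; exposes feat_dim instead
feats = backbone(torch.randn(2, 3, 224, 224))
print(backbone.feat_dim, feats.shape)        # 2048 torch.Size([2, 2048])
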
/lib/networks/resnet_3d.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from the original file so the classes support feature extraction.
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import math
8 | import torch.utils.model_zoo as model_zoo
9 |
10 | __all__ = ['resnet50_3d_v3','resnet26_3d_v3','resnet101_3d_v1']
11 |
12 | model_urls = {
13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
18 | }
19 |
20 | class GloAvgPool3d(nn.Module):
21 | def __init__(self):
22 | super(GloAvgPool3d, self).__init__()
23 | self.stride = 1
24 | self.padding = 0
25 | self.ceil_mode = False
26 | self.count_include_pad = True
27 |
28 | def forward(self, input):
29 | input_shape = input.shape
30 | kernel_size = input_shape[2:]
31 | return F.avg_pool3d(input, kernel_size, self.stride,
32 | self.padding, self.ceil_mode, self.count_include_pad)
33 |
34 | class Bottleneck3D_100(nn.Module):
35 | expansion = 4
36 |
37 | def __init__(self, inplanes, planes, stride=1, t_stride=1, downsample=None):
38 | super(Bottleneck3D_100, self).__init__()
39 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=(3, 1, 1),
40 | stride=(t_stride, 1, 1),
41 | padding=(1, 0, 0), bias=False)
42 | self.bn1 = nn.BatchNorm3d(planes)
43 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3),
44 | stride=(1, stride, stride), padding=(0, 1, 1), bias=False)
45 | self.bn2 = nn.BatchNorm3d(planes)
46 | self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False)
47 | self.bn3 = nn.BatchNorm3d(planes * self.expansion)
48 | self.relu = nn.ReLU(inplace=True)
49 | self.downsample = downsample
50 | self.stride = stride
51 |
52 | def forward(self, x):
53 | residual = x
54 |
55 | out = self.conv1(x)
56 | out = self.bn1(out)
57 | out = self.relu(out)
58 |
59 | out = self.conv2(out)
60 | out = self.bn2(out)
61 | out = self.relu(out)
62 |
63 | out = self.conv3(out)
64 | out = self.bn3(out)
65 |
66 | if self.downsample is not None:
67 | residual = self.downsample(x)
68 |
69 | out += residual
70 | out = self.relu(out)
71 |
72 | return out
73 |
74 | class Bottleneck3D_101(nn.Module):
75 | expansion = 4
76 |
77 | def __init__(self, inplanes, planes, stride=1, t_stride=1, downsample=None):
78 | super(Bottleneck3D_101, self).__init__()
79 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=(3, 1, 1),
80 | stride=(t_stride, 1, 1),
81 | padding=(1, 0, 0), bias=False)
82 | self.bn1 = nn.BatchNorm3d(planes)
83 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3),
84 | stride=(1, stride, stride), padding=(0, 1, 1), bias=False)
85 | self.bn2 = nn.BatchNorm3d(planes)
86 | self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=(3, 1, 1),
87 | stride=1,
88 | padding=(1, 0, 0),
89 | bias=False)
90 | self.bn3 = nn.BatchNorm3d(planes * self.expansion)
91 | self.relu = nn.ReLU(inplace=True)
92 | self.downsample = downsample
93 | self.stride = stride
94 |
95 | def forward(self, x):
96 | residual = x
97 |
98 | out = self.conv1(x)
99 | out = self.bn1(out)
100 | out = self.relu(out)
101 |
102 | out = self.conv2(out)
103 | out = self.bn2(out)
104 | out = self.relu(out)
105 |
106 | out = self.conv3(out)
107 | out = self.bn3(out)
108 |
109 | if self.downsample is not None:
110 | residual = self.downsample(x)
111 |
112 | out += residual
113 | out = self.relu(out)
114 |
115 | return out
116 |
117 | class Bottleneck3D_000(nn.Module):
118 | expansion = 4
119 |
120 | def __init__(self, inplanes, planes, stride=1, t_stride=1, downsample=None):
121 | super(Bottleneck3D_000, self).__init__()
122 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1,
123 | stride=[t_stride, 1, 1], bias=False)
124 | self.bn1 = nn.BatchNorm3d(planes)
125 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3),
126 | stride=[1, stride, stride], padding=(0, 1, 1), bias=False)
127 | self.bn2 = nn.BatchNorm3d(planes)
128 | self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False)
129 | self.bn3 = nn.BatchNorm3d(planes * self.expansion)
130 | self.relu = nn.ReLU(inplace=True)
131 | self.downsample = downsample
132 | self.stride = stride
133 |
134 | def forward(self, x):
135 | residual = x
136 |
137 | out = self.conv1(x)
138 | out = self.bn1(out)
139 | out = self.relu(out)
140 |
141 | out = self.conv2(out)
142 | out = self.bn2(out)
143 | out = self.relu(out)
144 |
145 | out = self.conv3(out)
146 | out = self.bn3(out)
147 |
148 | if self.downsample is not None:
149 | residual = self.downsample(x)
150 |
151 | out += residual
152 | out = self.relu(out)
153 |
154 | return out
155 |
156 |
157 | class ResNet3D(nn.Module):
158 |
159 | def __init__(self, block, layers, num_classes=1000, feat=False, **kwargs):
160 | if not isinstance(block, list):
161 | block = [block] * 4
162 | else:
163 |             assert len(block) == 4, "Block number must be 4 for ResNet-style networks."
164 | self.inplanes = 64
165 | super(ResNet3D, self).__init__()
166 | self.feat = feat
167 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(1, 7, 7),
168 | stride=(1, 2, 2), padding=(0, 3, 3),
169 | bias=False)
170 | self.bn1 = nn.BatchNorm3d(64)
171 | self.relu = nn.ReLU(inplace=True)
172 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
173 | self.layer1 = self._make_layer(block[0], 64, layers[0])
174 | self.layer2 = self._make_layer(block[1], 128, layers[1], stride=2, t_stride=2)
175 | self.layer3 = self._make_layer(block[2], 256, layers[2], stride=2, t_stride=2)
176 | self.layer4 = self._make_layer(block[3], 512, layers[3], stride=2, t_stride=2)
177 | self.avgpool = GloAvgPool3d()
178 | self.feat_dim = 512 * block[0].expansion
179 | if not feat:
180 | self.fc = nn.Linear(512 * block[0].expansion, num_classes)
181 |
182 | for n, m in self.named_modules():
183 | if isinstance(m, nn.Conv3d):
184 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
185 | elif isinstance(m, nn.BatchNorm3d):
186 | nn.init.constant_(m.weight, 1)
187 | nn.init.constant_(m.bias, 0)
188 |
189 | def _make_layer(self, block, planes, blocks, stride=1, t_stride=1):
190 | downsample = None
191 | if stride != 1 or self.inplanes != planes * block.expansion:
192 | downsample = nn.Sequential(
193 | nn.Conv3d(self.inplanes, planes * block.expansion,
194 | kernel_size=1, stride=(t_stride, stride, stride), bias=False),
195 | nn.BatchNorm3d(planes * block.expansion),
196 | )
197 |
198 | layers = []
199 | layers.append(block(self.inplanes, planes, stride=stride, t_stride=t_stride, downsample=downsample))
200 | self.inplanes = planes * block.expansion
201 | for i in range(1, blocks):
202 | layers.append(block(self.inplanes, planes))
203 |
204 | return nn.Sequential(*layers)
205 |
206 | def forward(self, x):
207 | x = self.conv1(x)
208 | x = self.bn1(x)
209 | x = self.relu(x)
210 | x = self.maxpool(x)
211 |
212 | x = self.layer1(x)
213 | x = self.layer2(x)
214 | x = self.layer3(x)
215 | x = self.layer4(x)
216 |
217 |
218 | x = self.avgpool(x)
219 | x = x.view(x.size(0), -1)
220 | if not self.feat:
221 |             print("WARNING: feat is False, applying the classification fc inside the backbone forward.")
222 | x = self.fc(x)
223 |
224 | return x
225 |
226 |
227 | def part_state_dict(state_dict, model_dict):
228 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
229 | pretrained_dict = inflate_state_dict(pretrained_dict, model_dict)
230 | model_dict.update(pretrained_dict)
231 | return model_dict
232 |
233 |
234 | def inflate_state_dict(pretrained_dict, model_dict):
235 | for k in pretrained_dict.keys():
236 | if pretrained_dict[k].size() != model_dict[k].size():
237 | assert(pretrained_dict[k].size()[:2] == model_dict[k].size()[:2]), \
238 | "To inflate, channel number should match."
239 | assert(pretrained_dict[k].size()[-2:] == model_dict[k].size()[-2:]), \
240 | "To inflate, spatial kernel size should match."
241 | print("Layer {} needs inflation.".format(k))
242 | shape = list(pretrained_dict[k].shape)
243 | shape.insert(2, 1)
244 | t_length = model_dict[k].shape[2]
245 | pretrained_dict[k] = pretrained_dict[k].reshape(shape)
246 | if t_length != 1:
247 | pretrained_dict[k] = pretrained_dict[k].expand_as(model_dict[k]) / t_length
248 | assert(pretrained_dict[k].size() == model_dict[k].size()), \
249 | "After inflation, model shape should match."
250 |
251 | return pretrained_dict
252 |
253 | def resnet50_3d_v1(pretrained=False, feat=False, **kwargs):
254 | """Constructs a ResNet-50 model.
255 | Args:
256 | pretrained (bool): If True, returns a model pre-trained on ImageNet
257 | """
258 | model = ResNet3D([Bottleneck3D_000, Bottleneck3D_100, Bottleneck3D_101, Bottleneck3D_101],
259 | [3, 4, 6, 3], feat=feat, **kwargs)
260 | # import pdb
261 | # pdb.set_trace()
262 | if pretrained:
263 | if kwargs['pretrained_model'] is None:
264 | state_dict = model_zoo.load_url(model_urls['resnet50'])
265 | else:
266 | print("Using specified pretrain model")
267 | state_dict = kwargs['pretrained_model']
268 | if feat:
269 | new_state_dict = part_state_dict(state_dict, model.state_dict())
270 | model.load_state_dict(new_state_dict)
271 | return model
272 |
273 | def resnet50_3d_v2(pretrained=False, feat=False, **kwargs):
274 | """Constructs a ResNet-50 model.
275 | Args:
276 | pretrained (bool): If True, returns a model pre-trained on ImageNet
277 | """
278 | model = ResNet3D([Bottleneck3D_000, Bottleneck3D_000, Bottleneck3D_100, Bottleneck3D_100],
279 | [3, 4, 6, 3], feat=feat, **kwargs)
280 | # import pdb
281 | # pdb.set_trace()
282 | if pretrained:
283 | if kwargs['pretrained_model'] is None:
284 | state_dict = model_zoo.load_url(model_urls['resnet50'])
285 | else:
286 | print("Using specified pretrain model")
287 | state_dict = kwargs['pretrained_model']
288 | if feat:
289 | new_state_dict = part_state_dict(state_dict, model.state_dict())
290 | model.load_state_dict(new_state_dict)
291 | return model
292 |
293 | def resnet50_3d_v3(pretrained=False, feat=False, **kwargs):
294 | """Constructs a ResNet-50 model.
295 | Args:
296 | pretrained (bool): If True, returns a model pre-trained on ImageNet
297 | """
298 | model = ResNet3D([Bottleneck3D_000, Bottleneck3D_100, Bottleneck3D_100, Bottleneck3D_100],
299 | [3, 4, 6, 3], feat=feat, **kwargs)
300 | # import pdb
301 | # pdb.set_trace()
302 | if pretrained:
303 | if kwargs['pretrained_model'] is None:
304 | state_dict = model_zoo.load_url(model_urls['resnet50'])
305 | else:
306 | print("Using specified pretrain model")
307 | state_dict = kwargs['pretrained_model']
308 | if feat:
309 | new_state_dict = part_state_dict(state_dict, model.state_dict())
310 | model.load_state_dict(new_state_dict)
311 | return model
312 |
313 | def resnet26_3d_v1(pretrained=False, feat=False, **kwargs):
314 |     """Constructs a 3D ResNet-26 model (v1).
315 | Args:
316 | pretrained (bool): If True, returns a model pre-trained on ImageNet
317 | """
318 | model = ResNet3D([Bottleneck3D_100, Bottleneck3D_100, Bottleneck3D_100, Bottleneck3D_100],
319 | [2, 2, 2, 2], feat=feat, **kwargs)
320 | # import pdb
321 | # pdb.set_trace()
322 | if pretrained:
323 | if kwargs['pretrained_model'] is None:
324 | raise ValueError("pretrained model must be specified")
325 | else:
326 | print("Using specified pretrain model")
327 | state_dict = kwargs['pretrained_model']
328 | if feat:
329 | new_state_dict = part_state_dict(state_dict, model.state_dict())
330 | model.load_state_dict(new_state_dict)
331 | return model
332 |
333 | def resnet26_3d_v3(pretrained=False, feat=False, **kwargs):
334 |     """Constructs a 3D ResNet-26 model (v3).
335 | Args:
336 | pretrained (bool): If True, returns a model pre-trained on ImageNet
337 | """
338 | model = ResNet3D([Bottleneck3D_000, Bottleneck3D_100, Bottleneck3D_100, Bottleneck3D_100],
339 | [2, 2, 2, 2], feat=feat, **kwargs)
340 | # import pdb
341 | # pdb.set_trace()
342 | if pretrained:
343 | if kwargs['pretrained_model'] is None:
344 | raise ValueError("pretrained model must be specified")
345 | else:
346 | print("Using specified pretrain model")
347 | state_dict = kwargs['pretrained_model']
348 | if feat:
349 | new_state_dict = part_state_dict(state_dict, model.state_dict())
350 | model.load_state_dict(new_state_dict)
351 | return model
352 |
353 | def resnet101_3d_v1(pretrained=False, feat=False, **kwargs):
354 |     """Constructs a 3D ResNet-101 model (v1).
355 | Args:
356 | pretrained (bool): If True, returns a model pre-trained on ImageNet
357 | """
358 | model = ResNet3D([Bottleneck3D_000, Bottleneck3D_100, Bottleneck3D_101, Bottleneck3D_101],
359 | [3, 4, 23, 3], feat=feat, **kwargs)
360 | # import pdb
361 | # pdb.set_trace()
362 | if pretrained:
363 | if kwargs['pretrained_model'] is None:
364 | state_dict = model_zoo.load_url(model_urls['resnet101'])
365 | else:
366 | print("Using specified pretrain model")
367 | state_dict = kwargs['pretrained_model']
368 | if feat:
369 | new_state_dict = part_state_dict(state_dict, model.state_dict())
370 | model.load_state_dict(new_state_dict)
371 | return model
--------------------------------------------------------------------------------
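
A minimal sketch of the inflated 3D backbones above, built without pretrained weights; the 8-frame 112x112 clip is an arbitrary example, since GloAvgPool3d adapts to whatever spatio-temporal extent reaches it.

import torch
from lib.networks.resnet_3d import resnet26_3d_v3   # assumes repo root on PYTHONPATH

net = resnet26_3d_v3(pretrained=False, feat=True)
clip = torch.randn(2, 3, 8, 112, 112)        # (N, C, T, H, W), arbitrary example
feats = net(clip)
print(feats.shape)                           # torch.Size([2, 2048])
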
/lib/networks/resnet_3d_nodown.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from the original file so the classes support feature extraction.
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import math
8 | import torch.utils.model_zoo as model_zoo
9 |
10 | __all__ = ["resnet50_3d_slowonly"]
11 |
12 | model_urls = {
13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
18 | }
19 |
20 | class GloAvgPool3d(nn.Module):
21 | def __init__(self):
22 | super(GloAvgPool3d, self).__init__()
23 | self.stride = 1
24 | self.padding = 0
25 | self.ceil_mode = False
26 | self.count_include_pad = True
27 |
28 | def forward(self, input):
29 | input_shape = input.shape
30 | kernel_size = input_shape[2:]
31 | return F.avg_pool3d(input, kernel_size, self.stride,
32 | self.padding, self.ceil_mode, self.count_include_pad)
33 |
34 | class Bottleneck3D_100(nn.Module):
35 | expansion = 4
36 |
37 | def __init__(self, inplanes, planes, stride=1, t_stride=1, downsample=None):
38 | super(Bottleneck3D_100, self).__init__()
39 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=(3, 1, 1),
40 | stride=(t_stride, 1, 1),
41 | padding=(1, 0, 0), bias=False)
42 | self.bn1 = nn.BatchNorm3d(planes)
43 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3),
44 | stride=(1, stride, stride), padding=(0, 1, 1), bias=False)
45 | self.bn2 = nn.BatchNorm3d(planes)
46 | self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False)
47 | self.bn3 = nn.BatchNorm3d(planes * self.expansion)
48 | self.relu = nn.ReLU(inplace=True)
49 | self.downsample = downsample
50 | self.stride = stride
51 |
52 | def forward(self, x):
53 | residual = x
54 |
55 | out = self.conv1(x)
56 | out = self.bn1(out)
57 | out = self.relu(out)
58 |
59 | out = self.conv2(out)
60 | out = self.bn2(out)
61 | out = self.relu(out)
62 |
63 | out = self.conv3(out)
64 | out = self.bn3(out)
65 |
66 | if self.downsample is not None:
67 | residual = self.downsample(x)
68 |
69 | out += residual
70 | out = self.relu(out)
71 |
72 | return out
73 |
74 | class Bottleneck3D_101(nn.Module):
75 | expansion = 4
76 |
77 | def __init__(self, inplanes, planes, stride=1, t_stride=1, downsample=None):
78 | super(Bottleneck3D_101, self).__init__()
79 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=(3, 1, 1),
80 | stride=(t_stride, 1, 1),
81 | padding=(1, 0, 0), bias=False)
82 | self.bn1 = nn.BatchNorm3d(planes)
83 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3),
84 | stride=(1, stride, stride), padding=(0, 1, 1), bias=False)
85 | self.bn2 = nn.BatchNorm3d(planes)
86 | self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=(3, 1, 1),
87 | stride=1,
88 | padding=(1, 0, 0),
89 | bias=False)
90 | self.bn3 = nn.BatchNorm3d(planes * self.expansion)
91 | self.relu = nn.ReLU(inplace=True)
92 | self.downsample = downsample
93 | self.stride = stride
94 |
95 | def forward(self, x):
96 | residual = x
97 |
98 | out = self.conv1(x)
99 | out = self.bn1(out)
100 | out = self.relu(out)
101 |
102 | out = self.conv2(out)
103 | out = self.bn2(out)
104 | out = self.relu(out)
105 |
106 | out = self.conv3(out)
107 | out = self.bn3(out)
108 |
109 | if self.downsample is not None:
110 | residual = self.downsample(x)
111 |
112 | out += residual
113 | out = self.relu(out)
114 |
115 | return out
116 |
117 | class Bottleneck3D_000(nn.Module):
118 | expansion = 4
119 |
120 | def __init__(self, inplanes, planes, stride=1, t_stride=1, downsample=None):
121 | super(Bottleneck3D_000, self).__init__()
122 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1,
123 | stride=[t_stride, 1, 1], bias=False)
124 | self.bn1 = nn.BatchNorm3d(planes)
125 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3),
126 | stride=[1, stride, stride], padding=(0, 1, 1), bias=False)
127 | self.bn2 = nn.BatchNorm3d(planes)
128 | self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False)
129 | self.bn3 = nn.BatchNorm3d(planes * self.expansion)
130 | self.relu = nn.ReLU(inplace=True)
131 | self.downsample = downsample
132 | self.stride = stride
133 |
134 | def forward(self, x):
135 | residual = x
136 |
137 | out = self.conv1(x)
138 | out = self.bn1(out)
139 | out = self.relu(out)
140 |
141 | out = self.conv2(out)
142 | out = self.bn2(out)
143 | out = self.relu(out)
144 |
145 | out = self.conv3(out)
146 | out = self.bn3(out)
147 |
148 | if self.downsample is not None:
149 | residual = self.downsample(x)
150 |
151 | out += residual
152 | out = self.relu(out)
153 |
154 | return out
155 |
156 | class ResNet3D_nodown(nn.Module):
157 |
158 | def __init__(self, block, layers, num_classes=1000, feat=False, **kwargs):
159 | if not isinstance(block, list):
160 | block = [block] * 4
161 | else:
162 |             assert len(block) == 4, "Block number must be 4 for ResNet-style networks."
163 | self.inplanes = 64
164 | super(ResNet3D_nodown, self).__init__()
165 | self.feat = feat
166 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(1, 7, 7),
167 | stride=(1, 2, 2), padding=(0, 3, 3),
168 | bias=False)
169 | self.bn1 = nn.BatchNorm3d(64)
170 | self.relu = nn.ReLU(inplace=True)
171 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
172 | self.layer1 = self._make_layer(block[0], 64, layers[0])
173 | self.layer2 = self._make_layer(block[1], 128, layers[1], stride=2)
174 | self.layer3 = self._make_layer(block[2], 256, layers[2], stride=2)
175 | self.layer4 = self._make_layer(block[3], 512, layers[3], stride=2)
176 | self.avgpool = GloAvgPool3d()
177 | self.feat_dim = 512 * block[0].expansion
178 | if not feat:
179 | self.fc = nn.Linear(512 * block[0].expansion, num_classes)
180 |
181 | for n, m in self.named_modules():
182 | if isinstance(m, nn.Conv3d):
183 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
184 | elif isinstance(m, nn.BatchNorm3d):
185 | nn.init.constant_(m.weight, 1)
186 | nn.init.constant_(m.bias, 0)
187 |
188 | def _make_layer(self, block, planes, blocks, stride=1, t_stride=1):
189 | downsample = None
190 | if stride != 1 or self.inplanes != planes * block.expansion:
191 | downsample = nn.Sequential(
192 | nn.Conv3d(self.inplanes, planes * block.expansion,
193 | kernel_size=1, stride=(t_stride, stride, stride), bias=False),
194 | nn.BatchNorm3d(planes * block.expansion),
195 | )
196 |
197 | layers = []
198 | layers.append(block(self.inplanes, planes, stride=stride, t_stride=t_stride, downsample=downsample))
199 | self.inplanes = planes * block.expansion
200 | for i in range(1, blocks):
201 | layers.append(block(self.inplanes, planes))
202 |
203 | return nn.Sequential(*layers)
204 |
205 | def forward(self, x):
206 | x = self.conv1(x)
207 | x = self.bn1(x)
208 | x = self.relu(x)
209 | x = self.maxpool(x)
210 |
211 | x = self.layer1(x)
212 | x = self.layer2(x)
213 | x = self.layer3(x)
214 | x = self.layer4(x)
215 |
216 | x = self.avgpool(x)
217 | # print(x.shape)
218 | x = x.view(x.size(0), -1)
219 | if not self.feat:
220 |             print("WARNING: classification head (fc) is active; expected feat=True for feature extraction.")
221 | x = self.fc(x)
222 |
223 | return x
224 |
225 | def part_state_dict(state_dict, model_dict):
226 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
227 | pretrained_dict = inflate_state_dict(pretrained_dict, model_dict)
228 | model_dict.update(pretrained_dict)
229 | return model_dict
230 |
231 |
232 | def inflate_state_dict(pretrained_dict, model_dict):
233 | for k in pretrained_dict.keys():
234 | if pretrained_dict[k].size() != model_dict[k].size():
235 | assert(pretrained_dict[k].size()[:2] == model_dict[k].size()[:2]), \
236 | "To inflate, channel number should match."
237 | assert(pretrained_dict[k].size()[-2:] == model_dict[k].size()[-2:]), \
238 | "To inflate, spatial kernel size should match."
239 | print("Layer {} needs inflation.".format(k))
240 | shape = list(pretrained_dict[k].shape)
241 | shape.insert(2, 1)
242 | t_length = model_dict[k].shape[2]
243 | pretrained_dict[k] = pretrained_dict[k].reshape(shape)
244 | if t_length != 1:
245 | pretrained_dict[k] = pretrained_dict[k].expand_as(model_dict[k]) / t_length
246 | assert(pretrained_dict[k].size() == model_dict[k].size()), \
247 | "After inflation, model shape should match."
248 |
249 | return pretrained_dict
250 |
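# A hedged sketch of the inflation performed above (shapes are illustrative,
# not taken from a real checkpoint): a 2D conv weight [out_c, in_c, kh, kw]
# gains a singleton temporal axis, is repeated along time, and is divided by
# the temporal length so the inflated filter reproduces the 2D response on a
# temporally constant input:
#   w2d = torch.randn(64, 3, 7, 7)                  # pretrained 2D kernel
#   t_length = 3                                     # temporal size of the 3D kernel
#   w3d = w2d.reshape(64, 3, 1, 7, 7).expand(64, 3, t_length, 7, 7) / t_length
#   assert w3d.shape == (64, 3, 3, 7, 7)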
251 | def resnet50_3d_slowonly(pretrained=False, feat=False, **kwargs):
252 |     """Constructs a SlowOnly-style 3D ResNet-50 model.
253 |     Args:
254 |         pretrained (bool): If True, initializes from 2D ResNet-50 ImageNet weights (inflated along time)
255 |     """
256 | model = ResNet3D_nodown([Bottleneck3D_000, Bottleneck3D_000, Bottleneck3D_100, Bottleneck3D_100],
257 | [3, 4, 6, 3], feat=feat, **kwargs)
258 | # import pdb
259 | # pdb.set_trace()
260 | if pretrained:
261 |         if kwargs.get('pretrained_model') is None:
262 | state_dict = model_zoo.load_url(model_urls['resnet50'])
263 | else:
264 | print("Using specified pretrain model")
265 | state_dict = kwargs['pretrained_model']
266 | if feat:
267 | new_state_dict = part_state_dict(state_dict, model.state_dict())
268 | model.load_state_dict(new_state_dict)
269 | return model
270 |
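# Usage sketch (input sizes are assumptions, not prescribed by this file):
#   net = resnet50_3d_slowonly(pretrained=False, feat=True)
#   clip = torch.randn(1, 3, 8, 224, 224)            # [N, C, T, H, W]
#   out = net(clip)                                   # -> torch.Size([1, 2048]) pooled features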
--------------------------------------------------------------------------------
/lib/opts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import argparse
4 |
5 | def set_logger(debug_mode=False):
6 | import time
7 | from time import gmtime, strftime
8 | logdir = os.path.join(args.experiment_root, 'log')
9 | if not os.path.exists(logdir):
10 | os.makedirs(logdir)
11 | log_file = "logfile_" + time.strftime("%d_%b_%Y_%H:%M:%S", time.localtime())
12 | log_file = os.path.join(logdir, log_file)
13 | handlers = [logging.FileHandler(log_file), logging.StreamHandler()]
14 |
15 |     """ Add '%(filename)s:%(lineno)d %(levelname)s:' to the format string to show the source file. """
16 | logging.basicConfig(level=logging.DEBUG if debug_mode else logging.INFO,
17 | format='%(asctime)s: %(message)s',
18 | datefmt='%Y-%m-%d %H:%M:%S',
19 | handlers = handlers)
20 |
21 | parser = argparse.ArgumentParser(description="PyTorch implementation of Video Classification")
22 | parser.add_argument('dataset', type=str)
23 | parser.add_argument('train_list', type=str)
24 | parser.add_argument('val_list', type=str)
25 |
26 | # ========================= Model Configs ==========================
27 | parser.add_argument('--arch', '-a', type=str, default="resnet18")
28 | parser.add_argument('--shadow', action='store_true')
29 | parser.add_argument('--dropout', '--do', default=0.2, type=float,
30 | metavar='DO', help='dropout ratio (default: 0.2)')
31 | parser.add_argument('--mode', type=str, default='3D', choices=['3D', 'TSN', '2D'])
32 | parser.add_argument('--new_size', type=int, default=256)
33 | parser.add_argument('--crop_size', type=int, default=224)
34 | parser.add_argument('--t_length', type=int, default=32, help="time length")
35 | parser.add_argument('--t_stride', type=int, default=2, help="time stride between frames")
36 | parser.add_argument('--num_segments', type=int, default=1)
37 | parser.add_argument('--pretrained', action='store_true')
38 | parser.add_argument('--pretrained_model', type=str, default=None)
39 |
40 | # ========================= Learning Configs ==========================
41 | parser.add_argument('--epochs', default=60, type=int, metavar='N',
42 | help='number of total epochs to run')
43 | parser.add_argument('-b', '--batch-size', default=256, type=int,
44 | metavar='N', help='mini-batch size (default: 256)')
45 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
46 | metavar='LR', help='initial learning rate')
47 | parser.add_argument('--lr_steps', default=[40, 70, 70], type=float, nargs="+",
48 | metavar='LRSteps', help='epochs to decay learning rate by 10')
49 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
50 | help='momentum')
51 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
52 | metavar='W', help='weight decay (default: 5e-4)')
53 |
54 | # ========================= Monitor Configs ==========================
55 | parser.add_argument('--print-freq', '-p', default=20, type=int,
56 | metavar='N', help='print frequency (default: 20)')
57 | parser.add_argument('--eval-freq', '-ef', default=2, type=int,
58 | metavar='N', help='evaluation frequency (default: 2)')
59 |
60 | # ========================= Runtime Configs ==========================
61 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
62 | help='number of data loading workers (default: 4)')
63 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
64 | help='path to latest checkpoint (default: none)')
65 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
66 | help='evaluate model on validation set')
67 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
68 | help='manual epoch number (useful on restarts)')
69 | parser.add_argument('--output_root', type=str, default="./output")
70 | parser.add_argument('--image_tmpl', type=str, default="image_{:06d}.jpg")
71 |
72 | args = parser.parse_args()
73 | if args.mode == "2D":
74 | args.t_length = 1
75 |
76 | experiment_id = '_'.join(map(str, [args.dataset, args.arch, args.mode,
77 | 'length'+str(args.t_length), 'stride'+str(args.t_stride),
78 | 'dropout'+str(args.dropout)]))
79 |
80 | if args.pretrained and args.pretrained_model:
81 | if "2d" in args.pretrained_model:
82 | experiment_id += '_2dpretrained'
83 |
84 | if args.shadow:
85 | experiment_id += '_shadow'
86 |
87 | args.experiment_root = os.path.join(args.output_root, experiment_id)
88 | # init logger
89 | set_logger()
90 | logging.info(args)
91 | if not os.path.exists(args.experiment_root):
92 | os.makedirs(args.experiment_root)
93 |
--------------------------------------------------------------------------------
/lib/transforms.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import random
3 | from PIL import Image, ImageOps
4 | import numpy as np
5 | import numbers
6 | import math
7 | import torch
8 |
9 | class GroupRandomCrop(object):
10 | def __init__(self, size):
11 | if isinstance(size, numbers.Number):
12 | self.size = (int(size), int(size))
13 | else:
14 | self.size = size
15 |
16 | def __call__(self, img_group):
17 |
18 | w, h = img_group[0].size
19 | th, tw = self.size
20 |
21 | out_images = list()
22 |
23 | x1 = random.randint(0, w - tw)
24 | y1 = random.randint(0, h - th)
25 |
26 | for img in img_group:
27 | assert(img.size[0] == w and img.size[1] == h)
28 | if w == tw and h == th:
29 | out_images.append(img)
30 | else:
31 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
32 |
33 | return out_images
34 |
35 | class GroupCenterCrop(object):
36 | def __init__(self, size):
37 | self.worker = torchvision.transforms.CenterCrop(size)
38 |
39 | def __call__(self, img_group):
40 | return [self.worker(img) for img in img_group]
41 |
42 |
43 | class GroupRandomHorizontalFlip(object):
44 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5
45 | There is no need to define an init function.
46 | """
47 | def __call__(self, img_group):
48 | v = random.random()
49 | if v < 0.5:
50 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
51 | return ret
52 | else:
53 | return img_group
54 |
55 | class GroupNormalize(object):
56 | def __init__(self,
57 | mean=[0.485, 0.456, 0.406],
58 | std=[0.229, 0.224, 0.225]):
59 | self.mean = mean
60 | self.std = std
61 |
62 | def __call__(self, tensor):
63 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean))
64 | rep_std = self.std * (tensor.size()[0]//len(self.std))
65 |
66 | # TODO: make efficient
67 | for t, m, s in zip(tensor, rep_mean, rep_std):
68 | t.sub_(m).div_(s)
69 |
70 | return tensor
71 |
72 |
73 | class GroupScale(object):
74 | """ Rescales the input PIL.Image to the given 'size'.
75 | 'size' will be the size of the smaller edge.
76 | For example, if height > width, then image will be
77 | rescaled to (size * height / width, size)
78 | size: size of the smaller edge
79 | interpolation: Default: PIL.Image.BILINEAR
80 | """
81 |
82 | def __init__(self, size, interpolation=Image.BILINEAR):
83 | self.worker = torchvision.transforms.Resize(size, interpolation)
84 |
85 | def __call__(self, img_group):
86 | return [self.worker(img) for img in img_group]
87 |
88 | class GroupRandomScale(object):
89 | """ Rescales the input PIL.Image to the given 'size'.
90 | 'size' will be the size of the smaller edge.
91 | For example, if height > width, then image will be
92 | rescaled to (size * height / width, size)
93 | size: size of the smaller edge
94 | interpolation: Default: PIL.Image.BILINEAR
95 | """
96 |
97 | def __init__(self, smallest_size=256, largest_size=320, interpolation=Image.BILINEAR):
98 | self.smallest_size = smallest_size
99 | self.largest_size = largest_size
100 | self.interpolation = interpolation
101 |
102 | def __call__(self, img_group):
103 | size = random.randint(self.smallest_size, self.largest_size)
104 | # print(size)
105 | self.worker = torchvision.transforms.Resize(size, self.interpolation)
106 | return [self.worker(img) for img in img_group]
107 |
108 | class GroupOverSample(object):
109 | def __init__(self, crop_size, scale_size=None):
110 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size)
111 |
112 | if scale_size is not None:
113 | self.scale_worker = GroupScale(scale_size)
114 | else:
115 | self.scale_worker = None
116 |
117 | def __call__(self, img_group):
118 |
119 | if self.scale_worker is not None:
120 | img_group = self.scale_worker(img_group)
121 |
122 | image_w, image_h = img_group[0].size
123 | crop_w, crop_h = self.crop_size
124 |
125 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h)
126 | oversample_group = list()
127 | for o_w, o_h in offsets:
128 | normal_group = list()
129 | flip_group = list()
130 | for i, img in enumerate(img_group):
131 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
132 | normal_group.append(crop)
133 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
134 | flip_group.append(flip_crop)
135 |
136 | oversample_group.extend(normal_group)
137 | oversample_group.extend(flip_group)
138 | return oversample_group
139 |
140 | class GroupOverSampleKaiming(object):
141 | def __init__(self, crop_size, scale_size=None):
142 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size)
143 |
144 | if scale_size is not None:
145 | self.scale_worker = GroupScale(scale_size)
146 | else:
147 | self.scale_worker = None
148 |
149 | def __call__(self, img_group):
150 |
151 | if self.scale_worker is not None:
152 | img_group = self.scale_worker(img_group)
153 |
154 | image_w, image_h = img_group[0].size
155 | crop_w, crop_h = self.crop_size
156 |
157 | offsets = self.fill_fix_offset(image_w, image_h, crop_w, crop_h)
158 | oversample_group = list()
159 | for o_w, o_h in offsets:
160 | normal_group = list()
161 | # flip_group = list()
162 | for i, img in enumerate(img_group):
163 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
164 | normal_group.append(crop)
165 | # flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
166 | # flip_group.append(flip_crop)
167 |
168 | oversample_group.extend(normal_group)
169 | # oversample_group.extend(flip_group)
170 | return oversample_group
171 |
172 | def fill_fix_offset(self, image_w, image_h, crop_w, crop_h):
173 | # assert(crop_h == image_h), "In Kaiming mode, crop_h should equal to image_h"
174 | ret = list()
175 | if image_w == 256:
176 | h_step = (image_h - crop_h) // 4
177 | ret.append((0, 0)) # upper
178 | ret.append((0, 4 * h_step)) # down
179 | ret.append((0, 2 * h_step)) # center
180 | elif image_h == 256:
181 | w_step = (image_w - crop_w) // 4
182 | ret.append((0, 0)) # left
183 | ret.append((4 * w_step, 0)) # right
184 | ret.append((2 * w_step, 0)) # center
185 | else:
186 | raise ValueError("Either image_w or image_h should be equal to 256")
187 |
188 | return ret
189 |
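# Offset sketch (image size chosen for illustration): for a 340x256 (w x h)
# image and a 256x256 crop, the short side already matches the crop, so the
# three crops slide along the width:
#   w_step  = (340 - 256) // 4 = 21
#   offsets = [(0, 0), (84, 0), (42, 0)]    # left, right, center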
190 |
191 | class GroupMultiScaleCrop(object):
192 |
193 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
194 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
195 | self.scales = scales if scales is not None else [1, .875, .75, .66]
196 | self.max_distort = max_distort
197 | self.fix_crop = fix_crop
198 | self.more_fix_crop = more_fix_crop
199 | self.interpolation = Image.BILINEAR
200 |
201 | def __call__(self, img_group):
202 |
203 | im_size = img_group[0].size
204 |
205 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
206 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
207 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
208 | for img in crop_img_group]
209 | return ret_img_group
210 |
211 | def _sample_crop_size(self, im_size):
212 | image_w, image_h = im_size[0], im_size[1]
213 |
214 | # find a crop size
215 | base_size = min(image_w, image_h)
216 | crop_sizes = [int(base_size * x) for x in self.scales]
217 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
218 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
219 |
220 | pairs = []
221 | for i, h in enumerate(crop_h):
222 | for j, w in enumerate(crop_w):
223 | if abs(i - j) <= self.max_distort:
224 | pairs.append((w, h))
225 |
226 | crop_pair = random.choice(pairs)
227 | if not self.fix_crop:
228 | w_offset = random.randint(0, image_w - crop_pair[0])
229 | h_offset = random.randint(0, image_h - crop_pair[1])
230 | else:
231 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
232 |
233 | return crop_pair[0], crop_pair[1], w_offset, h_offset
234 |
235 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
236 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
237 | return random.choice(offsets)
238 |
239 | @staticmethod
240 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
241 | w_step = (image_w - crop_w) // 4
242 | h_step = (image_h - crop_h) // 4
243 |
244 | ret = list()
245 | ret.append((0, 0)) # upper left
246 | ret.append((4 * w_step, 0)) # upper right
247 | ret.append((0, 4 * h_step)) # lower left
248 | ret.append((4 * w_step, 4 * h_step)) # lower right
249 | ret.append((2 * w_step, 2 * h_step)) # center
250 |
251 | if more_fix_crop:
252 | ret.append((0, 2 * h_step)) # center left
253 | ret.append((4 * w_step, 2 * h_step)) # center right
254 | ret.append((2 * w_step, 4 * h_step)) # lower center
255 | ret.append((2 * w_step, 0 * h_step)) # upper center
256 |
257 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter
258 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter
259 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter
260 |             ret.append((3 * w_step, 3 * h_step))  # lower right quarter
261 |
262 | return ret
263 |
264 |
265 | class GroupRandomSizedCrop(object):
266 |     """Randomly crops the given PIL.Image to a random size (0.08 to 1.0 of the original area)
267 |     and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio.
268 | This is popularly used to train the Inception networks
269 | size: size of the smaller edge
270 | interpolation: Default: PIL.Image.BILINEAR
271 | """
272 | def __init__(self, size, interpolation=Image.BILINEAR):
273 | self.size = size
274 | self.interpolation = interpolation
275 |
276 | def __call__(self, img_group):
277 | for attempt in range(10):
278 | area = img_group[0].size[0] * img_group[0].size[1]
279 | target_area = random.uniform(0.08, 1.0) * area
280 | aspect_ratio = random.uniform(3. / 4, 4. / 3)
281 |
282 | w = int(round(math.sqrt(target_area * aspect_ratio)))
283 | h = int(round(math.sqrt(target_area / aspect_ratio)))
284 |
285 | if random.random() < 0.5:
286 | w, h = h, w
287 |
288 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
289 | x1 = random.randint(0, img_group[0].size[0] - w)
290 | y1 = random.randint(0, img_group[0].size[1] - h)
291 | found = True
292 | break
293 | else:
294 | found = False
295 | x1 = 0
296 | y1 = 0
297 |
298 | if found:
299 | out_group = list()
300 | for img in img_group:
301 | img = img.crop((x1, y1, x1 + w, y1 + h))
302 | assert(img.size == (w, h))
303 | out_group.append(img.resize((self.size, self.size), self.interpolation))
304 | return out_group
305 | else:
306 | # Fallback
307 | scale = GroupScale(self.size, interpolation=self.interpolation)
308 | crop = GroupRandomCrop(self.size)
309 | return crop(scale(img_group))
310 |
311 |
312 | class Stack(object):
313 |
314 | def __init__(self, mode="3D"):
315 |         """Supported modes: ["3D", "TSN+2D", "2D", "TSN+3D"]
316 |         """
317 |         assert(mode in ["3D", "TSN+2D", "2D", "TSN+3D"]), "Unsupported mode: {}".format(mode)
318 | self.mode = mode
319 |
320 | def __call__(self, img_group):
321 |         """Only supports RGB mode for now.
322 | img_group: list([h, w, c])
323 | """
324 | assert(img_group[0].mode == 'RGB'), "Must read images in RGB mode."
325 | if "3D" in self.mode:
326 | imgs = np.concatenate([np.array(img)[np.newaxis, ...] for img in img_group], axis=0)
327 | imgs = torch.from_numpy(imgs).permute(3, 0, 1, 2).contiguous()
328 | elif "2D" in self.mode:
329 | imgs = np.concatenate([np.array(img) for img in img_group], axis=2)
330 | imgs = torch.from_numpy(imgs).permute(2, 0, 1).contiguous()
331 | else:
332 | raise Exception("Unsupported mode.")
333 | return imgs
334 |
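# Shape sketch (8 RGB frames of 224x224, chosen for illustration):
#   Stack(mode="3D")(img_group) -> torch.Size([3, 8, 224, 224])    # [C, T, H, W]
#   Stack(mode="2D")(img_group) -> torch.Size([24, 224, 224])      # [C*T, H, W]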
335 |
336 | class ToTorchFormatTensor(object):
337 | """ Converts a torch.Tensor in the range [0, 255]
338 | to a torch.FloatTensor in the range [0.0, 1.0] """
339 | def __init__(self, div=True):
340 | self.div = div
341 |
342 | def __call__(self, imgs):
343 | assert(isinstance(imgs, torch.Tensor)), "pic must be torch.Tensor."
344 |         return imgs.float().div(255) if self.div else imgs.float()
345 |
346 |
347 | class IdentityTransform(object):
348 |
349 | def __call__(self, data):
350 | return data
351 |
352 |
353 | if __name__ == "__main__":
354 | trans = torchvision.transforms.Compose([
355 | GroupMultiScaleCrop(input_size=224, scales=[1, .875, .75, .66]),
356 | Stack(mode="2D"),
357 | ToTorchFormatTensor(),
358 | GroupNormalize()]
359 | )
360 |
361 | im = Image.open('/home/leizhou/CVPR2019/vid_cls/lena.png')
362 |
363 | color_group = [im]
364 | rst = trans(color_group)
--------------------------------------------------------------------------------
/lib/utils/deprefix.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import pdb
4 |
5 | parser = argparse.ArgumentParser(description="Remove PyTorch Model Prefix")
6 | parser.add_argument('src_model', type=str)
7 | parser.add_argument('dst_model', type=str)
8 |
9 | args = parser.parse_args()
10 | state_dict = torch.load(args.src_model, map_location=lambda storage, loc: storage)
11 | state_dict = state_dict['state_dict']
12 | # pdb.set_trace()  # debug breakpoint disabled so the script runs non-interactively
13 | state_dict = {('.'.join(k.split('.')[1:]) if "module" in k else k): v for k, v in state_dict.items()}
14 | state_dict = {('.'.join(k.split('.')[1:]) if "base_model" in k else k): v for k, v in state_dict.items()}
15 | torch.save(state_dict, args.dst_model)
16 |
--------------------------------------------------------------------------------
/lib/utils/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import logging
4 | import torch
5 | import shutil
6 | from bisect import bisect_right
7 | __all__ = ['AverageMeter', 'save_checkpoint', 'adjust_learning_rate', 'accuracy']
8 |
9 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
10 | def __init__(
11 | self,
12 | optimizer,
13 | milestones,
14 | gamma=0.1,
15 | warmup_factor=1.0 / 3,
16 | warmup_iters=500,
17 | warmup_method="linear",
18 | last_epoch=-1,
19 | ):
20 | if not list(milestones) == sorted(milestones):
21 |             raise ValueError(
22 |                 "Milestones should be a list of"
23 |                 " increasing integers. Got {}".format(milestones)
24 |             )
25 |
26 | if warmup_method not in ("constant", "linear"):
27 | raise ValueError(
28 |                 "Only 'constant' or 'linear' warmup_method accepted, "
29 |                 "got {}".format(warmup_method)
30 | )
31 | self.milestones = milestones
32 | self.gamma = gamma
33 | self.warmup_factor = warmup_factor
34 | self.warmup_iters = warmup_iters
35 | self.warmup_method = warmup_method
36 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
37 |
38 | def get_lr(self):
39 | warmup_factor = 1
40 | if self.last_epoch < self.warmup_iters:
41 | if self.warmup_method == "constant":
42 | warmup_factor = self.warmup_factor
43 | elif self.warmup_method == "linear":
44 | alpha = float(self.last_epoch) / self.warmup_iters
45 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
46 | return [
47 | base_lr
48 | * warmup_factor
49 | * self.gamma ** bisect_right(self.milestones, self.last_epoch)
50 | for base_lr in self.base_lrs
51 | ]
52 |
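# Worked example (values are assumptions; assumes the scheduler is stepped per
# iteration): with base_lr=0.1, warmup_factor=1/3, warmup_iters=500 and
# warmup_method="linear", at iteration 250:
#   alpha         = 250 / 500 = 0.5
#   warmup_factor = (1/3) * (1 - 0.5) + 0.5 = 2/3
#   lr            = 0.1 * 2/3 * gamma**0 ≈ 0.0667   # before the first milestone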
53 | class AverageMeter(object):
54 | """Computes and stores the average and current value"""
55 | def __init__(self):
56 | self.reset()
57 |
58 | def reset(self):
59 | self.val = 0
60 | self.avg = 0
61 | self.sum = 0
62 | self.count = 0
63 |
64 | def update(self, val, n=1):
65 | self.val = val
66 | self.sum += val * n
67 | self.count += n
68 | self.avg = self.sum / self.count
69 |
70 | def save_checkpoint(state, is_best, epoch, experiment_root, filename='checkpoint_{}epoch.pth'):
71 | filename = os.path.join(experiment_root, filename.format(epoch))
72 | logging.info("saving model to {}...".format(filename))
73 | torch.save(state, filename)
74 | if is_best:
75 | best_name = os.path.join(experiment_root, 'model_best.pth')
76 | shutil.copyfile(filename, best_name)
77 | logging.info("saving done.")
78 |
79 | def adjust_learning_rate(optimizer, base_lr, epoch, lr_steps):
80 |     """Sets the learning rate to the initial LR decayed by 10 at each milestone in lr_steps"""
81 | decay = 0.1 ** (sum(epoch >= np.array(lr_steps)))
82 | lr = base_lr * decay
83 | for param_group in optimizer.param_groups:
84 | param_group['lr'] = lr
85 |
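# Worked example (hypothetical values): with base_lr=0.01 and lr_steps=[40, 70],
#   epoch 10 -> decay = 0.1**0 -> lr = 0.01
#   epoch 50 -> decay = 0.1**1 -> lr = 0.001
#   epoch 80 -> decay = 0.1**2 -> lr = 0.0001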
86 | def accuracy(output, target, topk=(1,)):
87 | """Computes the precision@k for the specified values of k"""
88 | maxk = max(topk)
89 | batch_size = target.size(0)
90 |
91 | _, pred = output.topk(maxk, 1, True, True)
92 | pred = pred.t()
93 | correct = pred.eq(target.view(1, -1).expand_as(pred))
94 |
95 | res = []
96 | for k in topk:
97 | correct_k = correct[:k].view(-1).float().sum(0)
98 | res.append(correct_k.mul_(100.0 / batch_size))
99 | return res
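# Worked example (toy numbers, not from the training code):
#   output = torch.tensor([[0.1, 0.9, 0.0], [0.8, 0.1, 0.1]])
#   target = torch.tensor([1, 2])
#   accuracy(output, target, topk=(1,)) -> [tensor(50.)]   # top-1 correct on 1 of 2 samples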
--------------------------------------------------------------------------------
/lib/utils/vis_comb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import matplotlib.pyplot as plt
3 |
4 | # files = os.listdir("log")
5 | # for ind, file in enumerate(files):
6 | # if not file.startswith('logfile'):
7 | # files.pop(ind)
8 | # print(files)
9 | # files = ["log/"+file for file in files]
10 | # files.sort()
11 |
12 | class log_parser():
13 | def __init__(self, landmark, log_file, key_words=[],
14 | base_key_words=['Loss', 'Prec@1', 'Prec@5']):
15 | super(log_parser, self).__init__()
16 | with open(log_file) as f:
17 | self.lines = f.readlines()
18 | self.log_info = dict()
19 | self.landmark = landmark
20 | self.key_words = key_words + base_key_words
21 | for word in self.key_words:
22 | self.log_info[word] = []
23 |
24 | def __add__(self, other):
25 | """Add two log parsers of same type.
26 | """
27 | assert hasattr(self, "hist") and hasattr(other, "hist"), "Parse before adding."
28 | for key in self.hist.keys():
29 |             assert key in other.hist, "Must share keys when adding."
30 | self.hist[key].update(other.hist[key])
31 | return self
32 |
33 | def parse(self):
34 | # parse info into list
35 | for line in self.lines:
36 | items = line.strip().split()
37 | if self.landmark not in items:
38 | continue
39 | for word in self.key_words:
40 | assert(word in items), "Key word should be in target line."
41 | for word in self.key_words:
42 | ind = items.index(word) + 1
43 | if word == "Epoch:":
44 | self.log_info[word].append(items[ind])
45 | else:
46 | self.log_info[word].append(float(items[ind]))
47 |
48 | # convert epoch string
49 | self.convert_epoch_string()
50 | # find the key for the later dict
51 | if "Epoch:" in self.key_words:
52 | key = "Epoch:"
53 | else:
54 | key = "Epoch"
55 |
56 | # build hist
57 | self.hist = {}
58 | for word in self.key_words:
59 | if "Epoch" not in word:
60 | self.hist[word] = {}
61 | for k, v in zip(self.log_info[key], self.log_info[word]):
62 | self.hist[word].update({k: v})
63 |
64 | def convert_epoch_string(self):
65 | if "Epoch:" in self.log_info:
66 | epochs = self.log_info['Epoch:']
67 | for idx, epoch_str in enumerate(epochs):
68 | epoch_num, fraction = epoch_str[1:-2].split("][")
69 | epoch = float(epoch_num) + eval(fraction)
70 | epochs[idx] = epoch
71 |
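# Parsing sketch (assumes the training log writes the epoch field with a
# trailing comma, e.g. "[3][100/500],", so the slice [1:-2] strips "[" and "],"):
#   "[3][100/500],"[1:-2]  -> "3][100/500"
#   .split("][")           -> ["3", "100/500"]
#   epoch                  -> 3 + 100/500 = 3.2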
72 | def plot(dir, tr_landmark="lr:", ts_landmark="Testing"):
73 |
74 | files = os.listdir(dir)
75 |     # popping while enumerating skips the entry after each removal,
76 |     # so build the filtered list instead
77 |     files = [file for file in files if file.startswith('logfile')]
78 | # print(files)
79 | files = [os.path.join(dir, file) for file in files]
80 | files.sort()
81 |
82 | file = files[0]
83 | tr_parser_base = log_parser(tr_landmark, files[0], key_words=['Epoch:'])
84 | tr_parser_base.parse()
85 | ts_parser_base = log_parser(ts_landmark, files[0], key_words=['Epoch'])
86 | ts_parser_base.parse()
87 | if len(files) > 1:
88 | for file in files[1:]:
89 | tr_parser = log_parser(tr_landmark, file, key_words=['Epoch:'])
90 | tr_parser.parse()
91 | ts_parser = log_parser(ts_landmark, file, key_words=['Epoch'])
92 | ts_parser.parse()
93 | tr_parser_base += tr_parser
94 | ts_parser_base += ts_parser
95 |
96 | return ts_parser_base
97 | # fig, ax = plt.subplots()
98 | # ax.plot(tr_parser_base.hist['Loss'].keys(), tr_parser_base.hist['Loss'].values(), label='Train Loss')
99 | # ax.plot(ts_parser_base.hist['Loss'].keys(), ts_parser_base.hist['Loss'].values(), label='Val Loss')
100 | # ax.set(xlabel="Epoch", ylabel='Loss', title='Loss')
101 | # ax.grid()
102 | # ax.legend(loc='upper right', shadow=False, fontsize='x-large')
103 | # plt.show()
104 |
105 | # fig, ax = plt.subplots()
106 | # ax.plot(ts_parser_base.hist['Prec@1'].keys(), ts_parser_base.hist['Prec@1'].values(), label='Prec@1')
107 | # ax.plot(ts_parser_base.hist['Prec@5'].keys(), ts_parser_base.hist['Prec@5'].values(), 'g--', label='Prec@5')
108 | # ax.set(xlabel="Epoch", ylabel='Prec', title='Test Acc')
109 | # ax.grid()
110 | # ax.legend(loc='lower right', shadow=False, fontsize='x-large')
111 | # plt.show()
112 |
113 | def designated_plot(baseline_parser, sd2_st1_parser, sd2_st4_parser, sd4_st1_parser, sd4_st4_parser, sf5_st1_parser):
114 | fig, ax = plt.subplots()
115 | ax.plot(baseline_parser.hist['Prec@1'].keys(), baseline_parser.hist['Prec@1'].values(), label='FST')
116 | ax.plot(sd2_st1_parser.hist['Prec@1'].keys(), sd2_st1_parser.hist['Prec@1'].values(), label='dilation2_stage1')
117 | ax.plot(sd2_st4_parser.hist['Prec@1'].keys(), sd2_st4_parser.hist['Prec@1'].values(), label='dilation2_stage4')
118 | ax.plot(sd4_st1_parser.hist['Prec@1'].keys(), sd4_st1_parser.hist['Prec@1'].values(), label='dilation4_stage1')
119 | ax.plot(sd4_st4_parser.hist['Prec@1'].keys(), sd4_st4_parser.hist['Prec@1'].values(), label='dilation4_stage4')
120 | ax.plot(sf5_st1_parser.hist['Prec@1'].keys(), sf5_st1_parser.hist['Prec@1'].values(), label='s_kernel5_stage1')
121 | ax.set(xlabel="Epoch", ylabel='Prec', title='Test Acc')
122 | ax.grid()
123 | ax.legend(loc='lower right', shadow=False, fontsize='x-large')
124 | plt.show()
125 |
126 | if __name__ == "__main__":
127 | baseline_parser = plot('/home/leizhou/Research/vid_cls/output/kinetics200_fst_resnet18_x4_3D_length16_stride4_dropout0.2/log')
128 | sd2_st1_parser = plot('/home/leizhou/Research/vid_cls/output/kinetics200_fst_resnet18_sd2_st1_x4_3D_length16_stride4_dropout0.2/log')
129 | sd2_st4_parser = plot('/home/leizhou/Research/vid_cls/output/kinetics200_fst_resnet18_sd2_st4_x4_3D_length16_stride4_dropout0.2/log')
130 | sd4_st1_parser = plot('/home/leizhou/Research/vid_cls/output/kinetics200_fst_resnet18_sd4_st1_x4_3D_length16_stride4_dropout0.2/log')
131 | sd4_st4_parser = plot('/home/leizhou/Research/vid_cls/output/kinetics200_fst_resnet18_sd4_st4_x4_3D_length16_stride4_dropout0.2/log')
132 | sf5_st1_parser = plot('/home/leizhou/Research/vid_cls/output/kinetics200_fst_resnet18_sf5_st1_x4_3D_length16_stride4_dropout0.2/log')
133 | designated_plot(baseline_parser, sd2_st1_parser, sd2_st4_parser, sd4_st1_parser, sd4_st4_parser, sf5_st1_parser)
--------------------------------------------------------------------------------
/lib/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import os
2 | import matplotlib.pyplot as plt
3 |
4 | files = os.listdir("log")
5 | # popping while enumerating skips the entry after each removal,
6 | # so build the filtered list instead
7 | files = [file for file in files if file.startswith('logfile')]
8 | print(files)
9 | files = ["log/"+file for file in files]
10 | files.sort()
11 |
12 | class log_parser():
13 | def __init__(self, landmark, log_file, key_words=[],
14 | base_key_words=['Loss', 'Prec@1', 'Prec@5']):
15 | super(log_parser, self).__init__()
16 | with open(log_file) as f:
17 | self.lines = f.readlines()
18 | self.log_info = dict()
19 | self.landmark = landmark
20 | self.key_words = key_words + base_key_words
21 | for word in self.key_words:
22 | self.log_info[word] = []
23 |
24 | def __add__(self, other):
25 | """Add two log parsers of same type.
26 | """
27 | assert hasattr(self, "hist") and hasattr(other, "hist"), "Parse before adding."
28 | for key in self.hist.keys():
29 |             assert key in other.hist, "Must share keys when adding."
30 | self.hist[key].update(other.hist[key])
31 | return self
32 |
33 | def parse(self):
34 | # parse info into list
35 | for line in self.lines:
36 | items = line.strip().split()
37 | if self.landmark not in items:
38 | continue
39 | for word in self.key_words:
40 | assert(word in items), "Key word should be in target line."
41 | for word in self.key_words:
42 | ind = items.index(word) + 1
43 | if word == "Epoch:":
44 | self.log_info[word].append(items[ind])
45 | else:
46 | self.log_info[word].append(float(items[ind]))
47 |
48 | # convert epoch string
49 | self.convert_epoch_string()
50 | # find the key for the later dict
51 | if "Epoch:" in self.key_words:
52 | key = "Epoch:"
53 | else:
54 | key = "Epoch"
55 |
56 | # build hist
57 | self.hist = {}
58 | for word in self.key_words:
59 | if "Epoch" not in word:
60 | self.hist[word] = {}
61 | for k, v in zip(self.log_info[key], self.log_info[word]):
62 | self.hist[word].update({k: v})
63 |
64 | def convert_epoch_string(self):
65 | if "Epoch:" in self.log_info:
66 | epochs = self.log_info['Epoch:']
67 | for idx, epoch_str in enumerate(epochs):
68 | epoch_num, fraction = epoch_str[1:-2].split("][")
69 | epoch = float(epoch_num) + eval(fraction)
70 | epochs[idx] = epoch
71 |
72 | def plot(files, tr_landmark="lr:", ts_landmark="Testing"):
73 | if not isinstance(files, list):
74 | files = [files]
75 |
76 | file = files[0]
77 | tr_parser_base = log_parser(tr_landmark, files[0], key_words=['Epoch:'])
78 | tr_parser_base.parse()
79 | ts_parser_base = log_parser(ts_landmark, files[0], key_words=['Epoch'])
80 | ts_parser_base.parse()
81 | if len(files) > 1:
82 | for file in files[1:]:
83 | tr_parser = log_parser(tr_landmark, file, key_words=['Epoch:'])
84 | tr_parser.parse()
85 | ts_parser = log_parser(ts_landmark, file, key_words=['Epoch'])
86 | ts_parser.parse()
87 | tr_parser_base += tr_parser
88 | ts_parser_base += ts_parser
89 |
90 | fig, ax = plt.subplots()
91 | ax.plot(tr_parser_base.hist['Loss'].keys(), tr_parser_base.hist['Loss'].values(), label='Train Loss')
92 | ax.plot(ts_parser_base.hist['Loss'].keys(), ts_parser_base.hist['Loss'].values(), label='Val Loss')
93 | ax.set(xlabel="Epoch", ylabel='Loss', title='Loss')
94 | ax.grid()
95 | ax.legend(loc='upper right', shadow=False, fontsize='x-large')
96 | plt.show()
97 |
98 | fig, ax = plt.subplots()
99 | ax.plot(ts_parser_base.hist['Prec@1'].keys(), ts_parser_base.hist['Prec@1'].values(), label='Prec@1')
100 | ax.plot(ts_parser_base.hist['Prec@5'].keys(), ts_parser_base.hist['Prec@5'].values(), 'g--', label='Prec@5')
101 | ax.set(xlabel="Epoch", ylabel='Prec', title='Test Acc')
102 | ax.grid()
103 | ax.legend(loc='lower right', shadow=False, fontsize='x-large')
104 | plt.show()
105 |
106 | if __name__ == "__main__":
107 | plot(files)
108 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import shutil
5 | import logging
6 |
7 | import torch
8 | import torchvision
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.optim
12 |
13 | from lib.dataset import VideoDataSet
14 | from lib.models import VideoModule
15 | from lib.transforms import *
16 | from lib.utils.tools import *
17 | from lib.opts import args
18 | from lib.modules import *
19 |
20 | from train_val import train, validate
21 |
22 | best_metric = 0
23 |
24 | def main():
25 | global args, best_metric
26 |
27 | # specify dataset
28 | if 'ucf101' in args.dataset:
29 | num_class = 101
30 | elif 'hmdb51' in args.dataset:
31 | num_class = 51
32 | elif args.dataset == 'kinetics400':
33 | num_class = 400
34 | elif args.dataset == 'kinetics200':
35 | num_class = 200
36 | else:
37 | raise ValueError('Unknown dataset '+args.dataset)
38 |
39 | # data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)),
40 | # "data/{}/access".format(args.dataset))
41 |
42 | if "ucf101" in args.dataset or "hmdb51" in args.dataset:
43 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)),
44 | "data/{}/access".format(args.dataset[:-3]))
45 | else:
46 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)),
47 | "data/{}/access".format(args.dataset))
48 |
49 | # create model
50 | org_model = VideoModule(num_class=num_class,
51 | base_model_name=args.arch,
52 | dropout=args.dropout,
53 | pretrained=args.pretrained,
54 | pretrained_model=args.pretrained_model)
55 | num_params = 0
56 | for param in org_model.parameters():
57 |         num_params += param.numel()
58 | logging.info("Model Size is {:.3f}M".format(num_params/1000000))
59 |
60 | model = torch.nn.DataParallel(org_model).cuda()
61 | # model = org_model
62 |
63 | # define loss function (criterion) and optimizer
64 | criterion = torch.nn.CrossEntropyLoss().cuda()
65 |
66 | optimizer = torch.optim.SGD(model.parameters(),
67 | args.lr,
68 | momentum=args.momentum,
69 | weight_decay=args.weight_decay)
70 |
71 | # optionally resume from a checkpoint
72 | if args.resume:
73 | if os.path.isfile(args.resume):
74 | print(("=> loading checkpoint '{}'".format(args.resume)))
75 | checkpoint = torch.load(args.resume)
76 | args.start_epoch = checkpoint['epoch']
77 | best_metric = checkpoint['best_metric']
78 | model.load_state_dict(checkpoint['state_dict'])
79 | optimizer.load_state_dict(checkpoint['optimizer'])
80 | print(("=> loaded checkpoint '{}' (epoch {})"
81 | .format(args.resume, checkpoint['epoch'])))
82 | else:
83 | print(("=> no checkpoint found at '{}'".format(args.resume)))
84 |
85 | # Data loading code
86 | ## train data
87 | # train_transform = torchvision.transforms.Compose([
88 | # GroupScale(args.new_size),
89 | # GroupMultiScaleCrop(input_size=args.crop_size, scales=[1, .875, .75, .66]),
90 | # GroupRandomHorizontalFlip(),
91 | # Stack(mode=args.mode),
92 | # ToTorchFormatTensor(),
93 | # GroupNormalize(),
94 | # ])
95 | train_transform = torchvision.transforms.Compose([
96 | GroupRandomScale(),
97 | GroupRandomCrop(size=args.crop_size),
98 | GroupRandomHorizontalFlip(),
99 | Stack(mode=args.mode),
100 | ToTorchFormatTensor(),
101 | GroupNormalize(),
102 | ])
103 | train_dataset = VideoDataSet(root_path=data_root,
104 | list_file=args.train_list,
105 | t_length=args.t_length,
106 | t_stride=args.t_stride,
107 | num_segments=args.num_segments,
108 | image_tmpl=args.image_tmpl,
109 | transform=train_transform,
110 | phase="Train")
111 | train_loader = torch.utils.data.DataLoader(
112 | train_dataset,
113 | batch_size=args.batch_size, shuffle=True, drop_last=True,
114 | num_workers=args.workers, pin_memory=True)
115 |
116 | ## val data
117 | val_transform = torchvision.transforms.Compose([
118 | GroupScale(args.new_size),
119 | GroupCenterCrop(args.crop_size),
120 | Stack(mode=args.mode),
121 | ToTorchFormatTensor(),
122 | GroupNormalize(),
123 | ])
124 | val_dataset = VideoDataSet(root_path=data_root,
125 | list_file=args.val_list,
126 | t_length=args.t_length,
127 | t_stride=args.t_stride,
128 | num_segments=args.num_segments,
129 | image_tmpl=args.image_tmpl,
130 | transform=val_transform,
131 | phase="Val")
132 | val_loader = torch.utils.data.DataLoader(
133 | val_dataset,
134 | batch_size=args.batch_size, shuffle=False,
135 | num_workers=args.workers, pin_memory=True)
136 |
137 | if args.mode != "3D":
138 | cudnn.benchmark = True
139 |
140 | # validate(val_loader, model, criterion, args.print_freq, args.start_epoch)
141 | # torch.cuda.empty_cache()
142 | if args.resume:
143 | validate(val_loader, model, criterion, args.print_freq, args.start_epoch)
144 | torch.cuda.empty_cache()
145 |
146 | for epoch in range(args.start_epoch, args.epochs):
147 | adjust_learning_rate(optimizer, args.lr, epoch, args.lr_steps)
148 |
149 | # train for one epoch
150 | train(train_loader, model, criterion, optimizer, epoch, args.print_freq)
151 |
152 | # evaluate on validation set
153 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
154 | metric = validate(val_loader, model, criterion, args.print_freq, epoch + 1)
155 | torch.cuda.empty_cache()
156 |
157 | # remember best prec@1 and save checkpoint
158 | is_best = metric > best_metric
159 | best_metric = max(metric, best_metric)
160 | save_checkpoint({
161 | 'epoch': epoch + 1,
162 | 'arch': args.arch,
163 | 'state_dict': model.state_dict(),
164 | 'best_metric': best_metric,
165 | 'optimizer': optimizer.state_dict(),
166 | }, is_best, epoch + 1, args.experiment_root)
167 |
168 | if __name__ == '__main__':
169 | main()
170 |
--------------------------------------------------------------------------------
/main_20bn.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import shutil
5 | import logging
6 |
7 | import torch
8 | import torchvision
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.optim
12 |
13 | from lib.dataset import VideoDataSet, ShortVideoDataSet
14 | from lib.models import VideoModule
15 | from lib.transforms import *
16 | from lib.utils.tools import *
17 | from lib.opts import args
18 | from lib.modules import *
19 |
20 | from train_val import train, validate
21 |
22 | best_metric = 0
23 |
24 | def main():
25 | global args, best_metric
26 |
27 | # specify dataset
28 | if 'sthsth_v1' in args.dataset:
29 | num_class = 174
30 | elif 'sthsth_v2' in args.dataset:
31 | num_class = 174
32 | else:
33 | raise ValueError('Unknown dataset '+args.dataset)
34 |
35 | # data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)),
36 | # "data/{}/access".format(args.dataset))
37 |
38 | if "ucf101" in args.dataset or "hmdb51" in args.dataset:
39 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)),
40 | "data/{}/access".format(args.dataset[:-3]))
41 | else:
42 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)),
43 | "data/{}/access".format(args.dataset))
44 |
45 | # create model
46 | org_model = VideoModule(num_class=num_class,
47 | base_model_name=args.arch,
48 | dropout=args.dropout,
49 | pretrained=args.pretrained,
50 | pretrained_model=args.pretrained_model)
51 | num_params = 0
52 | for param in org_model.parameters():
53 |         num_params += param.numel()
54 | logging.info("Model Size is {:.3f}M".format(num_params/1000000))
55 |
56 | model = torch.nn.DataParallel(org_model).cuda()
57 | # model = org_model
58 |
59 | # define loss function (criterion) and optimizer
60 | criterion = torch.nn.CrossEntropyLoss().cuda()
61 |
62 | optimizer = torch.optim.SGD(model.parameters(),
63 | args.lr,
64 | momentum=args.momentum,
65 | weight_decay=args.weight_decay)
66 |
67 | # optionally resume from a checkpoint
68 | if args.resume:
69 | if os.path.isfile(args.resume):
70 | print(("=> loading checkpoint '{}'".format(args.resume)))
71 | checkpoint = torch.load(args.resume)
72 | args.start_epoch = checkpoint['epoch']
73 | best_metric = checkpoint['best_metric']
74 | model.load_state_dict(checkpoint['state_dict'])
75 | optimizer.load_state_dict(checkpoint['optimizer'])
76 | print(("=> loaded checkpoint '{}' (epoch {})"
77 | .format(args.resume, checkpoint['epoch'])))
78 | else:
79 | print(("=> no checkpoint found at '{}'".format(args.resume)))
80 |
81 | # Data loading code
82 | ## train data
83 | train_transform = torchvision.transforms.Compose([
84 | GroupScale(args.new_size),
85 | GroupMultiScaleCrop(input_size=args.crop_size, scales=[1, .875, .75, .66]),
86 | # GroupRandomHorizontalFlip(),
87 | Stack(mode=args.mode),
88 | ToTorchFormatTensor(),
89 | GroupNormalize(),
90 | ])
91 | train_dataset = VideoDataSet(root_path=data_root,
92 | list_file=args.train_list,
93 | t_length=args.t_length,
94 | t_stride=args.t_stride,
95 | num_segments=args.num_segments,
96 | image_tmpl=args.image_tmpl,
97 | transform=train_transform,
98 | phase="Train")
99 | train_loader = torch.utils.data.DataLoader(
100 | train_dataset,
101 | batch_size=args.batch_size, shuffle=True, drop_last=True,
102 | num_workers=args.workers, pin_memory=True)
103 |
104 | ## val data
105 | val_transform = torchvision.transforms.Compose([
106 | GroupScale(args.new_size),
107 | GroupCenterCrop(args.crop_size),
108 | Stack(mode=args.mode),
109 | ToTorchFormatTensor(),
110 | GroupNormalize(),
111 | ])
112 | val_dataset = ShortVideoDataSet(root_path=data_root,
113 | list_file=args.val_list,
114 | t_length=args.t_length,
115 | t_stride=args.t_stride,
116 | num_segments=args.num_segments,
117 | image_tmpl=args.image_tmpl,
118 | transform=val_transform,
119 | phase="Val")
120 | val_loader = torch.utils.data.DataLoader(
121 | val_dataset,
122 | batch_size=args.batch_size, shuffle=False,
123 | num_workers=args.workers, pin_memory=True)
124 |
125 | if args.mode != "3D":
126 | cudnn.benchmark = True
127 |
128 | if args.resume:
129 | validate(val_loader, model, criterion, args.print_freq, args.start_epoch)
130 | torch.cuda.empty_cache()
131 |
132 | for epoch in range(args.start_epoch, args.epochs):
133 | adjust_learning_rate(optimizer, args.lr, epoch, args.lr_steps)
134 |
135 | # train for one epoch
136 | train(train_loader, model, criterion, optimizer, epoch, args.print_freq)
137 |
138 | # evaluate on validation set
139 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
140 | metric = validate(val_loader, model, criterion, args.print_freq, epoch + 1)
141 | torch.cuda.empty_cache()
142 |
143 | # remember best prec@1 and save checkpoint
144 | is_best = metric > best_metric
145 | best_metric = max(metric, best_metric)
146 | save_checkpoint({
147 | 'epoch': epoch + 1,
148 | 'arch': args.arch,
149 | 'state_dict': model.state_dict(),
150 | 'best_metric': best_metric,
151 | 'optimizer': optimizer.state_dict(),
152 | }, is_best, epoch + 1, args.experiment_root)
153 |
154 | if __name__ == '__main__':
155 | main()
156 |
--------------------------------------------------------------------------------
/main_imagenet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import shutil
5 | import logging
6 |
7 | import torch
8 | import torchvision
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.optim
12 | import torch.nn as nn
13 | import torchvision.transforms as transforms
14 | import torchvision.datasets as datasets
15 | from torchvision.models.resnet import *
16 | from torchvision.models.vgg import *
17 | from lib.networks.resnet import resnet26, resnet26_sc, resnet26_point
18 | from lib.networks.gsv_resnet_2d_v3 import gsv_resnet50_2d_v3
19 | from lib.modules import *
20 | from lib.utils.tools import *
21 | from lib.opts import args
22 |
23 | from train_val import train, validate
24 |
25 | best_metric = 0
26 |
27 | def main():
28 | global args, best_metric
29 |
30 | # specify dataset
31 | if args.dataset == 'imagenet':
32 | num_class = 1000
33 | else:
34 | raise ValueError('Unknown dataset '+args.dataset)
35 |
36 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)),
37 | "data/{}/access".format(args.dataset))
38 |
39 | # create model
40 | org_model = eval(args.arch)(pretrained=args.pretrained)#, feat=False, num_classes=num_class)
41 | num_params = 0
42 | for param in org_model.parameters():
43 |         num_params += param.numel()
44 | print("Model Size is {:.3f}M".format(num_params/1000000))
45 |
46 | model = torch.nn.DataParallel(org_model).cuda()
47 | # model = org_model
48 |
49 | # define loss function (criterion) and optimizer
50 | criterion = torch.nn.CrossEntropyLoss().cuda()
51 |
52 | # scale params
53 | scale_parameters = []
54 | other_parameters = []
55 | for m in model.modules():
56 | if isinstance(m, Scale2d):
57 | scale_parameters.append(m.scale)
58 | elif isinstance(m, nn.Conv2d):
59 | other_parameters.append(m.weight)
60 | if m.bias is not None:
61 | other_parameters.append(m.bias)
62 | elif isinstance(m, nn.BatchNorm2d):
63 | other_parameters.append(m.weight)
64 | other_parameters.append(m.bias)
65 | elif isinstance(m, nn.Linear):
66 | other_parameters.append(m.weight)
67 | other_parameters.append(m.bias)
68 |
69 | optimizer = torch.optim.SGD([{"params": other_parameters},
70 | {"params": scale_parameters, "weight_decay": 0}],
71 | args.lr,
72 | momentum=args.momentum,
73 | weight_decay=args.weight_decay)
74 |
75 | # optionally resume from a checkpoint
76 | if args.resume:
77 | if os.path.isfile(args.resume):
78 | print(("=> loading checkpoint '{}'".format(args.resume)))
79 | checkpoint = torch.load(args.resume)
80 | args.start_epoch = checkpoint['epoch']
81 | best_metric = checkpoint['best_metric']
82 | model.load_state_dict(checkpoint['state_dict'])
83 | optimizer.load_state_dict(checkpoint['optimizer'])
84 | print(("=> loaded checkpoint '{}' (epoch {})"
85 | .format(args.resume, checkpoint['epoch'])))
86 | else:
87 | print(("=> no checkpoint found at '{}'".format(args.resume)))
88 |
89 | train_transform = transforms.Compose([
90 | transforms.RandomResizedCrop(224),
91 | transforms.RandomHorizontalFlip(),
92 | transforms.ToTensor(),
93 | transforms.Normalize(
94 | mean=[0.485, 0.456, 0.406],
95 | std=[0.229, 0.224, 0.225]),
96 | ])
97 | train_dataset = datasets.ImageFolder(
98 | os.path.join(data_root, 'train'),
99 | train_transform)
100 | train_loader = torch.utils.data.DataLoader(
101 | train_dataset, batch_size=args.batch_size, shuffle=True,
102 | num_workers=args.workers, pin_memory=True)
103 |
104 | val_transform = transforms.Compose([
105 | transforms.Resize(256),
106 | transforms.CenterCrop(224),
107 | transforms.ToTensor(),
108 | transforms.Normalize(
109 | mean=[0.485, 0.456, 0.406],
110 | std=[0.229, 0.224, 0.225]),
111 | ])
112 | val_dataset = datasets.ImageFolder(
113 | os.path.join(data_root, 'val'),
114 | val_transform)
115 | val_loader = torch.utils.data.DataLoader(
116 | val_dataset,
117 | batch_size=args.batch_size, shuffle=False,
118 | num_workers=args.workers, pin_memory=True)
119 |
120 | cudnn.benchmark = True
121 | # validate(val_loader, model, criterion, args.print_freq, args.start_epoch)
122 |
123 | for epoch in range(args.start_epoch, args.epochs):
124 | adjust_learning_rate(optimizer, args.lr, epoch, args.lr_steps)
125 |
126 | # train for one epoch
127 | train(train_loader, model, criterion, optimizer, epoch, args.print_freq)
128 |
129 | # evaluate on validation set
130 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
131 | metric = validate(val_loader, model, criterion, args.print_freq, epoch + 1)
132 |
133 | # remember best prec@1 and save checkpoint
134 | is_best = metric > best_metric
135 | best_metric = max(metric, best_metric)
136 | save_checkpoint({
137 | 'epoch': epoch + 1,
138 | 'arch': args.arch,
139 | 'state_dict': model.state_dict(),
140 | 'best_metric': best_metric,
141 | 'optimizer': optimizer.state_dict(),
142 | }, is_best, epoch + 1, args.experiment_root)
143 |
144 | if __name__ == '__main__':
145 | main()
146 |
147 |
148 |
149 |
150 |
151 |
152 | # import argparse
153 | # import os
154 | # import random
155 | # import shutil
156 | # import time
157 | # import warnings
158 |
159 | # import torch
160 | # import torch.nn as nn
161 | # import torch.nn.parallel
162 | # import torch.backends.cudnn as cudnn
163 | # import torch.distributed as dist
164 | # import torch.optim
165 | # import torch.utils.data
166 | # import torch.utils.data.distributed
167 | # import torchvision.transforms as transforms
168 | # import torchvision.datasets as datasets
169 | # import torchvision.models as models
170 | # from lib.networks.mnet2 import *
171 |
172 | # model_names = sorted(name for name in models.__dict__
173 | # if name.islower() and not name.startswith("__")
174 | # and callable(models.__dict__[name]))
175 |
176 | # parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
177 | # parser.add_argument('data', metavar='DIR',
178 | # help='path to dataset')
179 | # # parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
180 | # # choices=model_names,
181 | # # help='model architecture: ' +
182 | # # ' | '.join(model_names) +
183 | # # ' (default: resnet18)')
184 | # parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
185 | # help='model architecture')
186 | # parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
187 | # help='number of data loading workers (default: 4)')
188 | # parser.add_argument('--epochs', default=90, type=int, metavar='N',
189 | # help='number of total epochs to run')
190 | # parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
191 | # help='manual epoch number (useful on restarts)')
192 | # parser.add_argument('-b', '--batch-size', default=256, type=int,
193 | # metavar='N', help='mini-batch size (default: 256)')
194 | # parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
195 | # metavar='LR', help='initial learning rate')
196 | # parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
197 | # help='momentum')
198 | # parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
199 | # metavar='W', help='weight decay (default: 1e-4)')
200 | # parser.add_argument('--print-freq', '-p', default=10, type=int,
201 | # metavar='N', help='print frequency (default: 10)')
202 | # parser.add_argument('--resume', default='', type=str, metavar='PATH',
203 | # help='path to latest checkpoint (default: none)')
204 | # parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
205 | # help='evaluate model on validation set')
206 | # parser.add_argument('--pretrained', dest='pretrained', action='store_true',
207 | # help='use pre-trained model')
208 | # parser.add_argument('--world-size', default=1, type=int,
209 | # help='number of distributed processes')
210 | # parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
211 | # help='url used to set up distributed training')
212 | # parser.add_argument('--dist-backend', default='gloo', type=str,
213 | # help='distributed backend')
214 | # parser.add_argument('--seed', default=None, type=int,
215 | # help='seed for initializing training. ')
216 | # parser.add_argument('--gpu', default=None, type=int,
217 | # help='GPU id to use.')
218 |
219 | # best_prec1 = 0
220 |
221 |
222 | # def main():
223 | # global args, best_prec1
224 | # args = parser.parse_args()
225 |
226 | # if args.seed is not None:
227 | # random.seed(args.seed)
228 | # torch.manual_seed(args.seed)
229 | # cudnn.deterministic = True
230 | # warnings.warn('You have chosen to seed training. '
231 | # 'This will turn on the CUDNN deterministic setting, '
232 | # 'which can slow down your training considerably! '
233 | # 'You may see unexpected behavior when restarting '
234 | # 'from checkpoints.')
235 |
236 | # if args.gpu is not None:
237 | # warnings.warn('You have chosen a specific GPU. This will completely '
238 | # 'disable data parallelism.')
239 |
240 | # args.distributed = args.world_size > 1
241 |
242 | # if args.distributed:
243 | # dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
244 | # world_size=args.world_size)
245 |
246 | # # create model
247 | # if args.pretrained:
248 | # print("=> using pre-trained model '{}'".format(args.arch))
249 | # if args.arch in models.__dict__:
250 | # model = models.__dict__[args.arch](pretrained=True)
251 | # elif args.arch == "mnet2":
252 | # model = mnet2("/home/leizhou/CVPR2019/vid_cls/models/mobilenet_v2.pth.tar")
253 | # else:
254 | # print("=> creating model '{}'".format(args.arch))
255 | # if args.arch in models.__dict__:
256 | # model = models.__dict__[args.arch]()
257 | # else:
258 | # model = eval(args.arch)()
259 |
260 | # if args.gpu is not None:
261 | # model = model.cuda(args.gpu)
262 | # elif args.distributed:
263 | # model.cuda()
264 | # model = torch.nn.parallel.DistributedDataParallel(model)
265 | # else:
266 | # if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
267 | # model.features = torch.nn.DataParallel(model.features)
268 | # model.cuda()
269 | # else:
270 | # model = torch.nn.DataParallel(model).cuda()
271 |
272 | # # define loss function (criterion) and optimizer
273 | # criterion = nn.CrossEntropyLoss().cuda(args.gpu)
274 |
275 | # optimizer = torch.optim.SGD(model.parameters(), args.lr,
276 | # momentum=args.momentum,
277 | # weight_decay=args.weight_decay)
278 |
279 | # # optionally resume from a checkpoint
280 | # if args.resume:
281 | # if os.path.isfile(args.resume):
282 | # print("=> loading checkpoint '{}'".format(args.resume))
283 | # checkpoint = torch.load(args.resume)
284 | # args.start_epoch = checkpoint['epoch']
285 | # best_prec1 = checkpoint['best_prec1']
286 | # model.load_state_dict(checkpoint['state_dict'])
287 | # optimizer.load_state_dict(checkpoint['optimizer'])
288 | # print("=> loaded checkpoint '{}' (epoch {})"
289 | # .format(args.resume, checkpoint['epoch']))
290 | # else:
291 | # print("=> no checkpoint found at '{}'".format(args.resume))
292 |
293 | # cudnn.benchmark = True
294 |
295 | # # Data loading code
296 | # traindir = os.path.join(args.data, 'train')
297 | # valdir = os.path.join(args.data, 'val')
298 | # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
299 | # std=[0.229, 0.224, 0.225])
300 |
301 | # train_dataset = datasets.ImageFolder(
302 | # traindir,
303 | # transforms.Compose([
304 | # transforms.RandomResizedCrop(224),
305 | # transforms.RandomHorizontalFlip(),
306 | # transforms.ToTensor(),
307 | # normalize,
308 | # ]))
309 |
310 | # if args.distributed:
311 | # train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
312 | # else:
313 | # train_sampler = None
314 |
315 | # train_loader = torch.utils.data.DataLoader(
316 | # train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
317 | # num_workers=args.workers, pin_memory=True, sampler=train_sampler)
318 |
319 | # val_loader = torch.utils.data.DataLoader(
320 | # datasets.ImageFolder(valdir, transforms.Compose([
321 | # transforms.Resize(256),
322 | # transforms.CenterCrop(224),
323 | # transforms.ToTensor(),
324 | # normalize,
325 | # ])),
326 | # batch_size=args.batch_size, shuffle=False,
327 | # num_workers=args.workers, pin_memory=True)
328 |
329 | # if args.evaluate:
330 | # validate(val_loader, model, criterion)
331 | # return
332 |
333 | # for epoch in range(args.start_epoch, args.epochs):
334 | # if args.distributed:
335 | # train_sampler.set_epoch(epoch)
336 | # adjust_learning_rate(optimizer, epoch)
337 |
338 | # # train for one epoch
339 | # train(train_loader, model, criterion, optimizer, epoch)
340 |
341 | # # evaluate on validation set
342 | # prec1 = validate(val_loader, model, criterion)
343 |
344 | # # remember best prec@1 and save checkpoint
345 | # is_best = prec1 > best_prec1
346 | # best_prec1 = max(prec1, best_prec1)
347 | # save_checkpoint({
348 | # 'epoch': epoch + 1,
349 | # 'arch': args.arch,
350 | # 'state_dict': model.state_dict(),
351 | # 'best_prec1': best_prec1,
352 | # 'optimizer' : optimizer.state_dict(),
353 | # }, is_best)
354 |
355 |
356 | # def train(train_loader, model, criterion, optimizer, epoch):
357 | # batch_time = AverageMeter()
358 | # data_time = AverageMeter()
359 | # losses = AverageMeter()
360 | # top1 = AverageMeter()
361 | # top5 = AverageMeter()
362 |
363 | # # switch to train mode
364 | # model.train()
365 |
366 | # end = time.time()
367 | # for i, (input, target) in enumerate(train_loader):
368 | # # measure data loading time
369 | # data_time.update(time.time() - end)
370 |
371 | # if args.gpu is not None:
372 | # input = input.cuda(args.gpu, non_blocking=True)
373 | # target = target.cuda(args.gpu, non_blocking=True)
374 |
375 | # # compute output
376 | # output = model(input)
377 | # loss = criterion(output, target)
378 |
379 | # # measure accuracy and record loss
380 | # prec1, prec5 = accuracy(output, target, topk=(1, 5))
381 | # losses.update(loss.item(), input.size(0))
382 | # top1.update(prec1[0], input.size(0))
383 | # top5.update(prec5[0], input.size(0))
384 |
385 | # # compute gradient and do SGD step
386 | # optimizer.zero_grad()
387 | # loss.backward()
388 | # optimizer.step()
389 |
390 | # # measure elapsed time
391 | # batch_time.update(time.time() - end)
392 | # end = time.time()
393 |
394 | # if i % args.print_freq == 0:
395 | # print('Epoch: [{0}][{1}/{2}]\t'
396 | # 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
397 | # 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
398 | # 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
399 | # 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
400 | # 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
401 | # epoch, i, len(train_loader), batch_time=batch_time,
402 | # data_time=data_time, loss=losses, top1=top1, top5=top5))
403 |
404 |
405 | # def validate(val_loader, model, criterion):
406 | # batch_time = AverageMeter()
407 | # losses = AverageMeter()
408 | # top1 = AverageMeter()
409 | # top5 = AverageMeter()
410 |
411 | # # switch to evaluate mode
412 | # model.eval()
413 |
414 | # with torch.no_grad():
415 | # end = time.time()
416 | # for i, (input, target) in enumerate(val_loader):
417 | # if args.gpu is not None:
418 | # input = input.cuda(args.gpu, non_blocking=True)
419 | # target = target.cuda(args.gpu, non_blocking=True)
420 |
421 | # # compute output
422 | # output = model(input)
423 | # loss = criterion(output, target)
424 |
425 | # # measure accuracy and record loss
426 | # prec1, prec5 = accuracy(output, target, topk=(1, 5))
427 | # losses.update(loss.item(), input.size(0))
428 | # top1.update(prec1[0], input.size(0))
429 | # top5.update(prec5[0], input.size(0))
430 |
431 | # # measure elapsed time
432 | # batch_time.update(time.time() - end)
433 | # end = time.time()
434 |
435 | # if i % args.print_freq == 0:
436 | # print('Test: [{0}/{1}]\t'
437 | # 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
438 | # 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
439 | # 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
440 | # 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
441 | # i, len(val_loader), batch_time=batch_time, loss=losses,
442 | # top1=top1, top5=top5))
443 |
444 | # print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
445 | # .format(top1=top1, top5=top5))
446 |
447 | # return top1.avg
448 |
449 |
450 | # def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
451 | # torch.save(state, filename)
452 | # if is_best:
453 | # shutil.copyfile(filename, 'model_best.pth.tar')
454 |
455 |
456 | # class AverageMeter(object):
457 | # """Computes and stores the average and current value"""
458 | # def __init__(self):
459 | # self.reset()
460 |
461 | # def reset(self):
462 | # self.val = 0
463 | # self.avg = 0
464 | # self.sum = 0
465 | # self.count = 0
466 |
467 | # def update(self, val, n=1):
468 | # self.val = val
469 | # self.sum += val * n
470 | # self.count += n
471 | # self.avg = self.sum / self.count
472 |
473 |
474 | # def adjust_learning_rate(optimizer, epoch):
475 | # """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
476 | # lr = args.lr * (0.1 ** (epoch // 30))
477 | # for param_group in optimizer.param_groups:
478 | # param_group['lr'] = lr
479 |
480 |
481 | # def accuracy(output, target, topk=(1,)):
482 | # """Computes the precision@k for the specified values of k"""
483 | # with torch.no_grad():
484 | # maxk = max(topk)
485 | # batch_size = target.size(0)
486 |
487 | # _, pred = output.topk(maxk, 1, True, True)
488 | # pred = pred.t()
489 | # correct = pred.eq(target.view(1, -1).expand_as(pred))
490 |
491 | # res = []
492 | # for k in topk:
493 | # correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
494 | # res.append(correct_k.mul_(100.0 / batch_size))
495 | # return res
496 |
497 |
498 | # if __name__ == '__main__':
499 | # main()
--------------------------------------------------------------------------------
/scripts/imagenet_2d_res26.sh:
--------------------------------------------------------------------------------
1 | # CUDA_LAUNCH_BLOCKING=1 \
2 | CUDA_VISIBLE_DEVICES=4,5,6,7 \
3 | python main_imagenet.py \
4 | imagenet \
5 | placeholder \
6 | placeholder \
7 | --arch resnet26 \
8 | --epochs 100 \
9 | --batch-size 512 \
10 | --lr 0.1 \
11 | --lr_steps 30 50 70 90 \
12 | --workers 20 \
13 | --weight-decay 0.0001 \
14 | --eval-freq 1 \
15 |
--------------------------------------------------------------------------------
/scripts/kinetics400_3d_res50_slowonly_im_pre.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
2 | python main.py \
3 | kinetics400 \
4 | data/kinetics400/kinetics_train_list_xlw \
5 | data/kinetics400/kinetics_val_list_xlw \
6 | --arch resnet50_3d_slowonly \
7 | --dro 0.5 \
8 | --mode 3D \
9 | --t_length 8 \
10 | --t_stride 8 \
11 | --pretrained \
12 | --epochs 110 \
13 | --batch-size 96 \
14 | --lr 0.02 \
15 | --wd 0.0001 \
16 | --lr_steps 50 80 100 \
17 | --workers 16 \
18 |
19 | python ./test_kaiming.py \
20 | kinetics400 \
21 | data/kinetics400/kinetics_val_list_xlw \
22 | output/kinetics400_resnet50_3d_slowonly_3D_length8_stride8_dropout0.5/model_best.pth \
23 | --arch resnet50_3d_slowonly \
24 | --mode TSN+3D \
25 | --batch_size 1 \
26 | --num_segments 10 \
27 | --input_size 256 \
28 | --t_length 8 \
29 | --t_stride 8 \
30 | --dropout 0.5 \
31 | --workers 12 \
32 | --image_tmpl image_{:06d}.jpg \
--------------------------------------------------------------------------------
/test_10crop.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import os
4 | import numpy as np
5 | import torch.nn.parallel
6 | import torch.optim
7 | # from sklearn.metrics import confusion_matrix
8 |
9 | from lib.dataset import VideoDataSet
10 | from lib.models import VideoModule, TSN
11 | from lib.transforms import *
12 | from lib.utils.tools import AverageMeter, accuracy
13 |
14 | import pdb
15 |
16 | # options
17 | parser = argparse.ArgumentParser(
18 | description="Standard video-level testing")
19 | parser.add_argument('dataset', type=str, choices=['ucf101', 'hmdb51', 'kinetics400', 'kinetics200'])
20 | parser.add_argument('test_list', type=str)
21 | parser.add_argument('weights', type=str)
22 | parser.add_argument('--arch', type=str, default="resnet50_3d_v1")
23 | parser.add_argument('--mode', type=str, default="TSN+3D")
24 | parser.add_argument('--save_scores', type=str, default=None)
25 | parser.add_argument('--batch_size', type=int, default=2)
26 | parser.add_argument('--num_segments', type=int, default=10)
27 | parser.add_argument('--input_size', type=int, default=224)
28 | parser.add_argument('--resize', type=int, default=256)
29 | parser.add_argument('--t_length', type=int, default=16)
30 | parser.add_argument('--t_stride', type=int, default=4)
31 | parser.add_argument('--crop_fusion_type', type=str, default='avg',
32 | choices=['avg', 'max', 'topk'])
33 | parser.add_argument('--image_tmpl', type=str)
34 | parser.add_argument('--dropout', type=float, default=0.2)
35 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
36 |                     help='number of data loading workers (default: 32)')
37 |
38 | args = parser.parse_args()
39 |
40 | def main():
41 | if args.dataset == 'ucf101':
42 | num_class = 101
43 | elif args.dataset == 'hmdb51':
44 | num_class = 51
45 | elif args.dataset == 'kinetics400':
46 | num_class = 400
47 | elif args.dataset == 'kinetics200':
48 | num_class = 200
49 | else:
50 | raise ValueError('Unknown dataset '+args.dataset)
51 |
52 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)),
53 | "data/{}/access".format(args.dataset))
54 |
55 | net = VideoModule(num_class=num_class,
56 | base_model_name=args.arch,
57 | dropout=args.dropout,
58 | pretrained=False)
59 |
60 | # compute params number of a model
61 | num_params = 0
62 | for param in net.parameters():
63 | num_params += param.reshape((-1, 1)).shape[0]
64 | print("Model Size is {:.3f}M".format(num_params / 1000000))
65 |
66 | net = torch.nn.DataParallel(net).cuda()
67 | net.eval()
68 |
69 | # load weights
70 | model_state = torch.load(args.weights)
71 | state_dict = model_state['state_dict']
72 | test_epoch = model_state['epoch']
73 | arch = model_state['arch']
74 | assert arch == args.arch
75 | net.load_state_dict(state_dict)
76 | tsn = TSN(args.batch_size, net,
77 | args.num_segments, args.t_length,
78 | crop_fusion_type=args.crop_fusion_type,
79 | mode=args.mode).cuda()
80 |
81 | ## test data
82 | test_transform = torchvision.transforms.Compose([
83 | GroupOverSample(args.input_size, args.resize),
84 | Stack(mode=args.mode),
85 | ToTorchFormatTensor(),
86 | GroupNormalize(),
87 | ])
88 | test_dataset = VideoDataSet(
89 | root_path=data_root,
90 | list_file=args.test_list,
91 | t_length=args.t_length,
92 | t_stride=args.t_stride,
93 | num_segments=args.num_segments,
94 | image_tmpl=args.image_tmpl,
95 | transform=test_transform,
96 | phase="Test")
97 | test_loader = torch.utils.data.DataLoader(
98 | test_dataset,
99 | batch_size=args.batch_size, shuffle=False,
100 | num_workers=args.workers, pin_memory=True)
101 |
102 | # Test
103 | batch_timer = AverageMeter()
104 | top1 = AverageMeter()
105 | top5 = AverageMeter()
106 | results = None
107 |
108 | # set eval mode
109 | tsn.eval()
110 |
111 | end = time.time()
112 | for ind, (data, label) in enumerate(test_loader):
113 | label = label.cuda(non_blocking=True)
114 |
115 | with torch.no_grad():
116 | output, pred, _, _ = tsn(data)
117 | prec1, prec5 = accuracy(pred, label, topk=(1, 5))
118 | top1.update(prec1.item(), data.shape[0])
119 | top5.update(prec5.item(), data.shape[0])
120 |
121 | # pdb.set_trace()
122 | batch_timer.update(time.time() - end)
123 | end = time.time()
124 | if results is not None:
125 |             results = np.concatenate((results, output.cpu().numpy()), axis=0)  # accumulate scores across batches
126 | else:
127 | results = output.cpu().numpy()
128 |         print("{0}/{1} done, Batch: {batch_timer.val:.3f}({batch_timer.avg:.3f}), "
129 |               "Top1: {top1.val:>6.3f}({top1.avg:>6.3f}), "
130 |               "Top5: {top5.val:>6.3f}({top5.avg:>6.3f})".
131 |               format(ind + 1, len(test_loader),
132 |                      batch_timer=batch_timer,
133 |                      top1=top1, top5=top5))
134 | target_file = os.path.join(args.save_scores, "arch_{0}-epoch_{1}-top1_{2}-top5_{3}.npz".format(arch, test_epoch, top1.avg, top5.avg))
135 | print("saving {}".format(target_file))
136 | np.savez(target_file, results)
137 | if __name__ == "__main__":
138 | main()
139 |
--------------------------------------------------------------------------------
/test_kaiming.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import os
4 | import numpy as np
5 | import torch.nn.parallel
6 | import torch.optim
7 | # from sklearn.metrics import confusion_matrix
8 |
9 | from lib.dataset import VideoDataSet, ShortVideoDataSet
10 | from lib.models import VideoModule, TSN
11 | from lib.transforms import *
12 | from lib.utils.tools import AverageMeter, accuracy
13 |
14 | import pdb
15 | import logging
16 |
17 | def set_logger(debug_mode=False):
18 | import time
19 | from time import gmtime, strftime
20 | logdir = os.path.join(args.experiment_root, 'log')
21 | if not os.path.exists(logdir):
22 | os.makedirs(logdir)
23 | log_file = "logfile_" + time.strftime("%d_%b_%Y_%H:%M:%S", time.localtime())
24 | log_file = os.path.join(logdir, log_file)
25 | handlers = [logging.FileHandler(log_file), logging.StreamHandler()]
26 |
27 |     # add '%(filename)s:%(lineno)d %(levelname)s:' to the format string to show the source file and line number
28 | logging.basicConfig(level=logging.DEBUG if debug_mode else logging.INFO,
29 | format='%(asctime)s: %(message)s',
30 | datefmt='%Y-%m-%d %H:%M:%S',
31 | handlers = handlers)
32 |
33 | # options
34 | parser = argparse.ArgumentParser(
35 | description="Standard video-level testing")
36 | parser.add_argument('dataset', type=str)
37 | parser.add_argument('test_list', type=str)
38 | parser.add_argument('weights', type=str)
39 | parser.add_argument('--arch', type=str, default="resnet50_3d_v1")
40 | parser.add_argument('--mode', type=str, default="TSN+3D")
41 | # parser.add_argument('--save_scores', type=str, default=None)
42 | parser.add_argument('--batch_size', type=int, default=2)
43 | parser.add_argument('--num_segments', type=int, default=10)
44 | parser.add_argument('--input_size', type=int, default=224)
45 | parser.add_argument('--resize', type=int, default=256)
46 | parser.add_argument('--t_length', type=int, default=16)
47 | parser.add_argument('--t_stride', type=int, default=4)
48 | # parser.add_argument('--crop_fusion_type', type=str, default='avg',
49 | # choices=['avg', 'max', 'topk'])
50 | parser.add_argument('--image_tmpl', type=str)
51 | parser.add_argument('--dropout', type=float, default=0.2)
52 | parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
53 |                     help='number of data loading workers (default: 32)')
54 |
55 | args = parser.parse_args()
56 |
57 | experiment_id = '_'.join(map(str, ['test', args.dataset, args.arch, args.mode,
58 | 'length'+str(args.t_length), 'stride'+str(args.t_stride),
59 | 'seg'+str(args.num_segments)]))
60 |
61 | args.experiment_root = os.path.join('./output', experiment_id)
62 |
63 | set_logger()
64 | logging.info(args)
65 | if not os.path.exists(args.experiment_root):
66 | os.makedirs(args.experiment_root)
67 |
68 | def main():
69 | if args.dataset == 'ucf101':
70 | num_class = 101
71 | elif args.dataset == 'hmdb51':
72 | num_class = 51
73 | elif args.dataset == 'kinetics400':
74 | num_class = 400
75 | elif args.dataset == 'kinetics200':
76 | num_class = 200
77 | elif args.dataset == 'sthsth_v1':
78 | num_class = 174
79 | else:
80 | raise ValueError('Unknown dataset '+args.dataset)
81 |
82 | data_root = os.path.join(os.path.dirname(os.path.abspath(__file__)),
83 | "data/{}/access".format(args.dataset))
84 |
85 | net = VideoModule(num_class=num_class,
86 | base_model_name=args.arch,
87 | dropout=args.dropout,
88 | pretrained=False)
89 |
90 | # compute params number of a model
91 | num_params = 0
92 | for param in net.parameters():
93 | num_params += param.reshape((-1, 1)).shape[0]
94 | logging.info("Model Size is {:.3f}M".format(num_params / 1000000))
95 |
96 | net = torch.nn.DataParallel(net).cuda()
97 | net.eval()
98 |
99 | # load weights
100 | model_state = torch.load(args.weights)
101 | state_dict = model_state['state_dict']
102 | test_epoch = model_state['epoch']
103 | best_metric = model_state['best_metric']
104 | arch = model_state['arch']
105 | logging.info("Model Epoch: {}; Best_Top1: {}".format(test_epoch, best_metric))
106 | assert arch == args.arch
107 | net.load_state_dict(state_dict)
108 | tsn = TSN(args.batch_size, net,
109 | args.num_segments, args.t_length,
110 | mode=args.mode).cuda()
111 |
112 | ## test data
113 | test_transform = torchvision.transforms.Compose([
114 | GroupScale(256),
115 | GroupOverSampleKaiming(args.input_size),
116 | Stack(mode=args.mode),
117 | ToTorchFormatTensor(),
118 | GroupNormalize(),
119 | ])
120 | test_dataset = VideoDataSet(
121 | root_path=data_root,
122 | list_file=args.test_list,
123 | t_length=args.t_length,
124 | t_stride=args.t_stride,
125 | num_segments=args.num_segments,
126 | image_tmpl=args.image_tmpl,
127 | transform=test_transform,
128 | phase="Test")
129 | test_loader = torch.utils.data.DataLoader(
130 | test_dataset,
131 | batch_size=args.batch_size, shuffle=False,
132 | num_workers=args.workers, pin_memory=True)
133 |
134 | # Test
135 | batch_timer = AverageMeter()
136 | top1_m = AverageMeter()
137 | top5_m = AverageMeter()
138 | top1_a = AverageMeter()
139 | top5_a = AverageMeter()
140 | results_m = None
141 | results_a = None
142 |
143 | # set eval mode
144 | tsn.eval()
145 |
146 | end = time.time()
147 | for ind, (data, label) in enumerate(test_loader):
148 | label = label.cuda(non_blocking=True)
149 |
150 | with torch.no_grad():
151 | output_m, pred_m, output_a, pred_a = tsn(data)
152 | prec1_m, prec5_m = accuracy(pred_m, label, topk=(1, 5))
153 | prec1_a, prec5_a = accuracy(pred_a, label, topk=(1, 5))
154 | top1_m.update(prec1_m.item(), data.shape[0])
155 | top5_m.update(prec5_m.item(), data.shape[0])
156 | top1_a.update(prec1_a.item(), data.shape[0])
157 | top5_a.update(prec5_a.item(), data.shape[0])
158 |
159 | # pdb.set_trace()
160 | batch_timer.update(time.time() - end)
161 | end = time.time()
162 | if results_m is not None:
163 |             results_m = np.concatenate((results_m, output_m.cpu().numpy()), axis=0)  # accumulate max-fused scores
164 | else:
165 | results_m = output_m.cpu().numpy()
166 |
167 | if results_a is not None:
168 |             results_a = np.concatenate((results_a, output_a.cpu().numpy()), axis=0)  # accumulate avg-fused scores
169 | else:
170 | results_a = output_a.cpu().numpy()
171 | logging.info("{0}/{1} done, Batch: {batch_timer.val:.3f}({batch_timer.avg:.3f}), maxTop1: {top1_m.val:>6.3f}({top1_m.avg:>6.3f}), maxTop5: {top5_m.val:>6.3f}({top5_m.avg:>6.3f}), avgTop1: {top1_a.val:>6.3f}({top1_a.avg:>6.3f}), avgTop5: {top5_a.val:>6.3f}({top5_a.avg:>6.3f})".
172 | format(ind + 1, len(test_loader),
173 | batch_timer=batch_timer,
174 | top1_m=top1_m, top5_m=top5_m, top1_a=top1_a, top5_a=top5_a))
175 | max_target_file = os.path.join(args.experiment_root, "arch_{0}-epoch_{1}-top1_{2}-top5_{3}_max.npz".format(arch, test_epoch, top1_m.avg, top5_m.avg))
176 | avg_target_file = os.path.join(args.experiment_root, "arch_{0}-epoch_{1}-top1_{2}-top5_{3}_avg.npz".format(arch, test_epoch, top1_a.avg, top5_a.avg))
177 | print("saving {}".format(max_target_file))
178 | np.savez(max_target_file, results_m)
179 | print("saving {}".format(avg_target_file))
180 | np.savez(avg_target_file, results_a)
181 | if __name__ == "__main__":
182 | main()
183 |
--------------------------------------------------------------------------------
/train_val.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import logging
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn.utils import clip_grad_norm_
8 |
9 | from lib.utils.tools import *
10 |
11 | def set_bn_eval(m):
12 | classname = m.__class__.__name__
13 | if classname.find('BatchNorm') != -1:
14 | m.eval()
15 |
16 | def train(train_loader, model, criterion, optimizer, epoch, print_freq):
17 | batch_time = AverageMeter()
18 | data_time = AverageMeter()
19 | losses = AverageMeter()
20 | top1 = AverageMeter()
21 | top5 = AverageMeter()
22 |
23 | # switch to train mode
24 | model.train()
25 |
26 | end = time.time()
27 | for i, (input, target) in enumerate(train_loader):
28 | # measure data loading time
29 | data_time.update(time.time() - end)
30 |
31 | # input = input.cuda()
32 | target = target.cuda(non_blocking=True)
33 |
34 | # compute output
35 | output = model(input)
36 | loss = criterion(output, target)
37 |
38 | # measure accuracy and record loss
39 | prec1, prec5 = accuracy(output, target, topk=(1, 5))
40 | losses.update(loss.item(), input.size(0))
41 | top1.update(prec1.item(), input.size(0))
42 | top5.update(prec5.item(), input.size(0))
43 |
44 | # compute gradient and do SGD step
45 | optimizer.zero_grad()
46 | loss.backward()
47 |
48 | # clip gradients
49 | # total_norm = clip_grad_norm_(model.parameters(), 20)
50 | # if total_norm > 20:
51 | # print("clipping gradient: {} with coef {}".format(total_norm, 20 / total_norm))
52 |
53 | optimizer.step()
54 |
55 | # measure elapsed time
56 | batch_time.update(time.time() - end)
57 | end = time.time()
58 |
59 | if i % print_freq == 0:
60 | logging.info(('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
61 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
62 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
63 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
64 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
65 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(
66 | epoch, i, len(train_loader), batch_time=batch_time,
67 | data_time=data_time, loss=losses, top1=top1,
68 | top5=top5, lr=optimizer.param_groups[-1]['lr'])))
69 |
70 | def finetune_fc(train_loader, model, criterion, optimizer, epoch, print_freq):
71 | batch_time = AverageMeter()
72 | data_time = AverageMeter()
73 | losses = AverageMeter()
74 | top1 = AverageMeter()
75 | top5 = AverageMeter()
76 |
77 | model.train()
78 |
79 | # switch mode
80 | for m in model.modules():
81 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
82 | m.eval()
83 | if isinstance(m, nn.Dropout):
84 | m.eval()
85 |
86 | # block gradients to base model
87 | for param in model.named_parameters():
88 | if "base_model" in param[0]:
89 | param[1].requires_grad = False
90 |
91 | end = time.time()
92 | for i, (input, target) in enumerate(train_loader):
93 | # measure data loading time
94 | data_time.update(time.time() - end)
95 |
96 | # input = input.cuda()
97 | target = target.cuda(non_blocking=True)
98 |
99 | # compute output
100 | output = model(input)
101 | loss = criterion(output, target)
102 |
103 | # measure accuracy and record loss
104 | prec1, prec5 = accuracy(output, target, topk=(1, 5))
105 | losses.update(loss.item(), input.size(0))
106 | top1.update(prec1.item(), input.size(0))
107 | top5.update(prec5.item(), input.size(0))
108 |
109 | # compute gradient and do SGD step
110 | optimizer.zero_grad()
111 | loss.backward()
112 | optimizer.step()
113 |
114 | # measure elapsed time
115 | batch_time.update(time.time() - end)
116 | end = time.time()
117 |
118 | if i % print_freq == 0:
119 | logging.info(('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
120 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
121 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
122 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
123 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
124 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(
125 | epoch, i, len(train_loader), batch_time=batch_time,
126 | data_time=data_time, loss=losses, top1=top1,
127 | top5=top5, lr=optimizer.param_groups[-1]['lr'])))
128 |
129 | def finetune_bn_frozen(train_loader, model, criterion, optimizer, epoch, print_freq):
130 | batch_time = AverageMeter()
131 | data_time = AverageMeter()
132 | losses = AverageMeter()
133 | top1 = AverageMeter()
134 | top5 = AverageMeter()
135 |
136 | model.train()
137 |
138 | # model.apply(set_bn_eval)
139 |
140 | # switch mode
141 | for n, m in model.named_modules():
142 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
143 | # m.eval()
144 | if "base_model.bn1" in n:
145 | print(n)
146 | pass
147 | else:
148 | for p in m.parameters():
149 | p.requires_grad = False
150 | m.eval()
151 |
152 | # for n, m in model.named_modules():
153 | # if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
154 | # m.eval()
155 | # if isinstance(m, nn.Dropout):
156 | # m.eval()
157 |
158 | # block gradients to base model
159 | # for param in model.named_parameters():
160 | # if "bn" in param[0]:
161 | # print(param[1].requires_grad)
162 | # if "base_model" in param[0]:
163 | # param[1].requires_grad = False
164 |
165 | end = time.time()
166 | for i, (input, target) in enumerate(train_loader):
167 | # print(model.module.base_model.bn1.weight.view(-1)[:3])
168 | # print(model.module.base_model.bn1.running_mean.view(-1)[:3])
169 | # import pdb
170 | # pdb.set_trace()
171 | # print("conv1", model.state_dict()['module.base_model.conv1.weight'].view(-1)[0:3])
172 | # print("fc", model.state_dict()['module.classifier.1.weight'].view(-1)[0:3])
173 | # print(model.state_dict().view(-1)[0:3])
174 | # measure data loading time
175 | data_time.update(time.time() - end)
176 |
177 | # input = input.cuda()
178 | target = target.cuda(non_blocking=True)
179 |
180 | # compute output
181 | output = model(input)
182 | loss = criterion(output, target)
183 |
184 | # measure accuracy and record loss
185 | prec1, prec5 = accuracy(output, target, topk=(1, 5))
186 | losses.update(loss.item(), input.size(0))
187 | top1.update(prec1.item(), input.size(0))
188 | top5.update(prec5.item(), input.size(0))
189 |
190 | # compute gradient and do SGD step
191 | optimizer.zero_grad()
192 | loss.backward()
193 | # for param in model.parameters():
194 | # param.grad.data.clamp_(-1, 1)
195 |         total_norm = clip_grad_norm_(model.parameters(), 20)
196 | if total_norm > 20:
197 | print("clipping gradient: {} with coef {}".format(total_norm, 20 / total_norm))
198 |
199 | optimizer.step()
200 |
201 | # measure elapsed time
202 | batch_time.update(time.time() - end)
203 | end = time.time()
204 |
205 | if i % print_freq == 0:
206 | logging.info(('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
207 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
208 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
209 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
210 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
211 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(
212 | epoch, i, len(train_loader), batch_time=batch_time,
213 | data_time=data_time, loss=losses, top1=top1,
214 | top5=top5, lr=optimizer.param_groups[-1]['lr'])))
215 |
216 | def validate(val_loader, model, criterion, print_freq, epoch, logger=None):
217 | batch_time = AverageMeter()
218 | losses = AverageMeter()
219 | top1 = AverageMeter()
220 | top5 = AverageMeter()
221 |
222 | # switch to evaluate mode
223 | model.eval()
224 |
225 | with torch.no_grad():
226 | end = time.time()
227 | for i, (input, target) in enumerate(val_loader):
228 | target = target.cuda(non_blocking=True)
229 |
230 | # compute output
231 | output = model(input)
232 | loss = criterion(output, target)
233 |
234 | # measure accuracy and record loss
235 | prec1, prec5 = accuracy(output, target, topk=(1, 5))
236 | losses.update(loss.item(), input.size(0))
237 | top1.update(prec1.item(), input.size(0))
238 | top5.update(prec5.item(), input.size(0))
239 |
240 | # measure elapsed time
241 | batch_time.update(time.time() - end)
242 | end = time.time()
243 |
244 | if i % print_freq == 0:
245 | logging.info(('Test: [{0}/{1}]\t'
246 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
247 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
248 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
249 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
250 | i, len(val_loader), batch_time=batch_time, loss=losses,
251 | top1=top1, top5=top5)))
252 |
253 | logging.info(('Epoch {epoch} Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}'
254 | .format(epoch=epoch, top1=top1, top5=top5, loss=losses)))
255 |
256 | # return (top1.avg + top5.avg) / 2
257 | return top1.avg
--------------------------------------------------------------------------------