├── .idea
│   ├── NAS_spatiotemporal.iml
│   ├── deployment.xml
│   ├── encodings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── other.xml
│   ├── remote-mappings.xml
│   └── workspace.xml
├── NAS_utils
│   ├── __init__.py
│   └── ops.py
├── README.md
├── Spatiotemporal Fusion in 3D CNNs A Probabilistic View.pdf
├── Supplementary Materials.pdf.pdf
├── action.zip
├── args.py
├── dataset
│   ├── IO.py
│   ├── __init__.py
│   ├── augment.py
│   └── config.py
├── main.py
├── models
│   ├── __init__.py
│   ├── densenet_3d.py
│   ├── densenet_3d_forstat.py
│   ├── mobilenet_v2_3d.py
│   └── resnet_3d.py
├── philly_distributed_utils
│   ├── __init__.py
│   ├── distributed.py
│   └── env.py
├── tools
│   ├── __init__.py
│   ├── ckpt_checker.py
│   ├── generate_label_sthsthv1.py
│   ├── generate_label_ucf101.py
│   ├── statistics.py
│   ├── to_hdf5.py
│   └── visualize.py
└── utils.py
/NAS_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scenarios/Probabilistic3DCNN/444385aeabc001282064877faba7a15a787a4f94/NAS_utils/__init__.py
--------------------------------------------------------------------------------
/NAS_utils/ops.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch as t
4 | import torch.nn.functional as F
5 |
6 | from torch.nn import Module, Conv3d, Linear
7 |
8 | from args import parser
9 | args = parser.parse_args()
10 |
11 | _NASAS = args.enable_nasas
12 | #TRAINING_SIZE = 86017
13 |
14 | class Conv3d_with_CD(Conv3d):
15 | def __init__(self, in_channels, out_channels, kernel_size,
16 | stride=1, padding=0, dilation=1, groups=1, bias=True,
17 | weight_reg=10.0, drop_reg=1.0, p_init=1e-1, deterministic=False, debug=False, split=1, split_pattern=None, training_size=0, deact_nasas=False):
18 | super(Conv3d_with_CD, self).__init__(in_channels, out_channels, kernel_size, stride=stride,
19 | padding=padding, dilation=dilation, groups=groups, bias=bias)
20 |
21 | self._weight_reg = weight_reg / training_size
22 | self._drop_reg = drop_reg / training_size
23 | self._det_mode = deterministic
24 | self._debug_mode = debug
25 | self._deterministic = deterministic
26 | self._deact_nasas = deact_nasas
27 |
28 | self.split = split
29 | self.split_pattern = split_pattern
30 | if self.split_pattern:
31 | assert len(self.split_pattern) == self.split
32 | assert sum(split_pattern) == self.in_channels
33 |
34 | #self._noise_shape = (self.in_channels, 1, 1, 1)
35 | self._noise_shape = (1, 1, 1, 1) #if not self._deterministic else (self.in_channels, 1, 1, 1)
36 | self._eps = 1e-8
37 | self._temp = 1. / 5.
38 | self._p_init = p_init
39 |
40 | self.p_logit = t.nn.Parameter(t.Tensor([np.log(self._p_init) - np.log(1. - self._p_init)]*self.split)) if _NASAS and not self._deact_nasas else None
41 |
42 | if self._deterministic:
  43 |             print('Using deterministic drop.')
44 | self.unif_noise_var = t.zeros(size=[1]+list(self._noise_shape)).uniform_(0,1)
45 | self.unif_noise_variable = t.nn.Parameter(self.unif_noise_var, requires_grad=False)
46 | if self._debug_mode:
47 | if self.in_channels == 64:
48 | self.p_logit.register_hook(print)
49 |
50 | def _concrete_dropout(self, input):
51 | if self.split_pattern:
52 | _p = self.p_logit[0].expand(self.split_pattern[0])
53 | if self.split > 1:
54 | _p = t.cat(
55 | (_p, self.p_logit[1:].view(-1,1).expand(self.split-1, self.split_pattern[1]).reshape(-1)),
56 | dim=0
57 | )
58 | else:
59 | assert self.split == 1
60 | #_p = self.p_logit[0].expand(self.in_channels)
61 | _p = self.p_logit[0] #if not self._deterministic else self.p_logit[0].expand(self.in_channels)
62 | _p = _p.sigmoid().view([1]+list(self._noise_shape))
63 |
64 | if self._deterministic:
65 | drop_tensor = t.floor(self.unif_noise_variable.cuda() + _p)
66 | random_tensor = 1. - drop_tensor
67 | else:
68 | unif_noise_1 = t.rand(size=[input.shape[0]]+list(self._noise_shape)).cuda()
69 | unif_noise_2 = t.rand(size=[input.shape[0]]+list(self._noise_shape)).cuda()
70 |
71 | drop_prob = (
72 | t.log(_p + self._eps)
73 | - t.log(1. - _p + self._eps)
74 | + t.log(-t.log(unif_noise_1 + self._eps) + self._eps)
75 | - t.log(-t.log(unif_noise_2 + self._eps) + self._eps)
76 | )
77 |
78 | drop_prob = t.sigmoid(drop_prob/self._temp)
79 | random_tensor = 1. - drop_prob
80 |
81 | return input * random_tensor
82 |
83 |
84 | def forward(self, input):
85 | input = self._concrete_dropout(input) if _NASAS and not self._deact_nasas else input
86 | return F.conv3d(input, self.weight, self.bias, self.stride,
87 | self.padding, self.dilation, self.groups)
88 |
89 | @property
90 | def KLreg(self):
91 | if self._deact_nasas:
92 | return 0.0
93 |
94 | if self.split_pattern:
95 | _p = self.p_logit[0].expand(self.split_pattern[0])
96 | if self.split > 1:
97 | _p = t.cat(
98 | (_p, self.p_logit[1:].view(-1,1).expand(self.split-1, self.split_pattern[1]).reshape(-1)),
99 | dim=0
100 | )
101 | else:
102 | assert self.split == 1
103 | #_p = self.p_logit[0].expand(self.in_channels)
104 | _p = self.p_logit[0]
105 | _p = _p.sigmoid()
 106 |         weight_regularizer = self._weight_reg * t.sum(self.weight**2) * (1. - _p)
 107 |         # per-channel variant, deprecated by the split version above:
 108 |         #weight_regularizer = self._weight_reg * t.sum((self.weight**2) * (1. - _p.view([1, self.in_channels, 1, 1, 1])))
109 | dropout_regularizer = _p * t.log(_p)
110 | dropout_regularizer += (1. - _p) * t.log(1. - _p)
111 | # deprecated by split version
112 | #dropout_regularizer *= self._drop_reg * self.in_channels
113 | dropout_regularizer *= self._drop_reg
114 | return weight_regularizer + t.sum(dropout_regularizer)
115 |
116 | @property
117 | def p(self):
118 | if self._deact_nasas:
119 | return None
120 | return self.p_logit.sigmoid()
121 |
122 |
123 | class Linear_with_CD(Linear):
124 | def __init__(self, in_features, out_features, bias=True,
125 | weight_reg=10.0, drop_reg=1.0, p_init=1e-1, deterministic=False, debug=False, split=1, split_pattern=None, training_size=0, deact_nasas = False):
126 | super(Linear_with_CD, self).__init__(in_features, out_features, bias)
127 |
128 | self._weight_reg = weight_reg / training_size
129 | self._drop_reg = drop_reg / training_size
130 | self._det_mode = deterministic
131 | self._debug_mode = debug
132 | self._deterministic = deterministic
133 | self._deact_nasas = deact_nasas
134 |
135 | self.split = split
136 | self.split_pattern = split_pattern
137 | if self.split_pattern:
138 | assert len(self.split_pattern) == self.split
139 | assert sum(split_pattern) == self.in_features
140 |
141 | self._noise_shape = (self.in_features,)
142 | self._eps = 1e-8
143 | self._temp = 1. / 5.
144 | self._p_init = p_init
145 |
146 | self.p_logit = t.nn.Parameter(t.Tensor([np.log(self._p_init) - np.log(1. - self._p_init)] * self.split)) if _NASAS and not self._deact_nasas else None
147 |
148 | if self._deterministic:
 149 |             print('Using deterministic drop.')
150 | self.unif_noise_var = t.zeros(size=[1] + list(self._noise_shape)).uniform_(0,1)
151 | self.unif_noise_variable = t.nn.Parameter(self.unif_noise_var, requires_grad=False)
152 | if self._debug_mode:
153 | self.p_logit.register_hook(print)
154 |
155 | def _concrete_dropout(self, input):
156 | if self.split_pattern:
157 | _p = self.p_logit[0].expand(self.split_pattern[0])
158 | if self.split > 1:
159 | _p = t.cat(
160 | (_p, self.p_logit[1:].view(-1,1).expand(self.split-1, self.split_pattern[1]).reshape(-1)),
161 | dim=0
162 | )
163 | else:
164 | assert self.split == 1
165 | _p = self.p_logit[0].expand(self.in_features)
166 | _p = _p.sigmoid().view([1]+list(self._noise_shape))
167 |
168 | if self._deterministic:
169 | drop_tensor = t.floor(self.unif_noise_variable.cuda() + _p)
170 | random_tensor = 1. - drop_tensor
171 | else:
172 | unif_noise_1 = t.rand(size=[input.shape[0]]+list(self._noise_shape)).cuda()
173 | unif_noise_2 = t.rand(size=[input.shape[0]]+list(self._noise_shape)).cuda()
174 |
175 | drop_prob = (
176 | t.log(_p + self._eps)
177 | - t.log(1. - _p + self._eps)
178 | + t.log(-t.log(unif_noise_1 + self._eps) + self._eps)
179 | - t.log(-t.log(unif_noise_2 + self._eps) + self._eps)
180 | )
181 |
182 | drop_prob = t.sigmoid(drop_prob/self._temp)
183 | random_tensor = 1. - drop_prob
184 |
185 | return input * random_tensor
186 |
187 |
188 | def forward(self, input):
189 | input = self._concrete_dropout(input) if _NASAS and not self._deact_nasas else input
190 | return F.linear(input, self.weight, self.bias)
191 |
192 | @property
193 | def KLreg(self):
194 | if self._deact_nasas:
195 | return 0.0
196 |
197 | if self.split_pattern:
198 | _p = self.p_logit[0].expand(self.split_pattern[0])
199 | if self.split > 1:
200 | _p = t.cat(
201 | (_p, self.p_logit[1:].view(-1,1).expand(self.split-1, self.split_pattern[1]).reshape(-1)),
202 | dim=0
203 | )
204 | else:
205 | assert self.split == 1
206 | _p = self.p_logit[0].expand(self.in_features)
207 | _p = _p.sigmoid()
208 | # deprecated by split version
209 | #weight_regularizer = self._weight_reg * t.sum(self.weight**2) * (1. - _p)
210 | weight_regularizer = self._weight_reg * t.sum((self.weight**2) * (1. - _p.view([1, self.in_features])))
211 | dropout_regularizer = _p * t.log(_p)
212 | dropout_regularizer += (1. - _p) * t.log(1. - _p)
213 | # deprecated by split version
214 | #dropout_regularizer *= self._drop_reg * self.in_channels
215 | dropout_regularizer *= self._drop_reg
216 | return weight_regularizer + t.sum(dropout_regularizer)
217 |
218 | @property
219 | def p(self):
220 | if self._deact_nasas:
221 | return None
222 | return self.p_logit.sigmoid()
223 |
224 |
--------------------------------------------------------------------------------
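
Note on the gate above: _concrete_dropout implements the binary-concrete (Gumbel-sigmoid) relaxation of a Bernoulli drop mask, so the drop probability behind p_logit stays differentiable. A minimal standalone sketch (not part of the repo), with the temperature and epsilon values taken from this file:

    import torch

    def concrete_drop_gate(p, shape, temp=1. / 5., eps=1e-8):
        """Differentiable relaxation of a Bernoulli(p) drop mask.

        Returns keep-weights in (0, 1); as temp -> 0 they approach hard
        0/1 samples with drop probability p.
        """
        u1 = torch.rand(shape)
        u2 = torch.rand(shape)
        logit = (
            torch.log(p + eps) - torch.log(1. - p + eps)
            + torch.log(-torch.log(u1 + eps) + eps)
            - torch.log(-torch.log(u2 + eps) + eps)
        )
        return 1. - torch.sigmoid(logit / temp)  # multiply the layer input by this

    mask = concrete_drop_gate(torch.tensor(0.1), (4, 1, 1, 1, 1))
    print(mask)  # mostly near 1; near 0 roughly 10% of the time
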
/README.md:
--------------------------------------------------------------------------------
1 | # Spatiotemporal Fusion in 3D CNNs: A Probabilistic View
2 |
   3 | Experimental code for the CVPR 2020 oral paper "Spatiotemporal Fusion in 3D CNNs: A Probabilistic View".
4 |
   5 | The official (re-organized) code is still under review and will appear in the official Microsoft repository.
6 |
7 |
8 | # Reference
9 |
  10 | [1] Yizhou Zhou, Xiaoyan Sun, Chong Luo, Zheng-Jun Zha, and Wenjun Zeng. Spatiotemporal fusion in 3D CNNs: A probabilistic view. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2020, pp. 9829-9838.
11 |
--------------------------------------------------------------------------------
/Spatiotemporal Fusion in 3D CNNs A Probabilistic View.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scenarios/Probabilistic3DCNN/444385aeabc001282064877faba7a15a787a4f94/Spatiotemporal Fusion in 3D CNNs A Probabilistic View.pdf
--------------------------------------------------------------------------------
/Supplementary Materials.pdf.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scenarios/Probabilistic3DCNN/444385aeabc001282064877faba7a15a787a4f94/Supplementary Materials.pdf.pdf
--------------------------------------------------------------------------------
/action.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scenarios/Probabilistic3DCNN/444385aeabc001282064877faba7a15a787a4f94/action.zip
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
1 | from philly_distributed_utils.env import get_master_ip
2 | from philly_distributed_utils.distributed import ompi_size
3 |
4 | import argparse
5 | parser = argparse.ArgumentParser(description="PyTorch implementation of NAS_spatiotemporal")
6 | parser.add_argument('--dataset', type=str, default="something")
7 | parser.add_argument('--modality', type=str, default='RGB', choices=['RGB', 'Flow'])
8 | parser.add_argument('--train_list', type=str, default="")
9 | parser.add_argument('--val_list', type=str, default="")
10 | parser.add_argument('--root_path', type=str, default="/mnt/data/")
11 | parser.add_argument('--store_name', type=str, default="")
12 | # ========================= Model Configs ==========================
13 | parser.add_argument('--arch', type=str, default="Dense3D121")
14 | parser.add_argument('--num_segments', type=int, default=1)
15 | parser.add_argument('--consensus_type', type=str, default='avg')
16 | parser.add_argument('--k', type=int, default=3)
17 |
18 | parser.add_argument('--dropout', '--do', default=0.5, type=float,
19 | metavar='DO', help='dropout ratio (default: 0.5)')
20 | parser.add_argument('--loss_type', type=str, default="nll",
21 | choices=['nll'])
22 | parser.add_argument('--img_feature_dim', default=256, type=int, help="the feature dimension for each frame")
23 | parser.add_argument('--suffix', type=str, default=None)
24 | parser.add_argument('--pretrain', type=str, default='imagenet')
25 | parser.add_argument('--tune_from', type=str, default=None, help='fine-tune from checkpoint')
26 |
27 | parser.add_argument('--enable_nasas', default=False, action="store_true",
28 | help='enable NASAS for architecture search')
29 | parser.add_argument('--temporal_nasas_only', default=False, action="store_true",
30 | help='only enable NASAS on temporal axis for architecture search')
31 |
32 | parser.add_argument('--cross_warmup', default=False, action="store_true",
33 | help='cross warmup for NASAS')
34 |
35 | parser.add_argument('--weight_reg', type=float, default=10.0,
36 | help='weight regularization used for nasas')
37 | parser.add_argument('--p_init', type=float, default=0.1,
38 | help='initial p used for nasas')
39 | parser.add_argument('--selection_mode', default=False, action="store_true",
40 | help='use selection mode in nasas')
41 | parser.add_argument('--test_mode', default=False, action="store_true",
42 | help='use test mode in nasas')
43 | parser.add_argument('--finetune_mode', default=False, action="store_true",
44 | help='use finetune mode in nasas')
45 | parser.add_argument('--training_size', default=86017, type=int,
46 | help='number of training samples')
47 |
48 |
49 | parser.add_argument('--net_version', default='pure_fused', type=str,
50 | help='densenet 3d version')
51 |
52 | # ========================= Learning Configs ==========================
53 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
54 | help='number of total epochs to run')
  55 | parser.add_argument('-b', '--batch-size', default=32, type=int,
  56 |                     metavar='N', help='mini-batch size (default: 32)')
57 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
58 | metavar='LR', help='initial learning rate')
59 | parser.add_argument('--lr_type', default='step', type=str,
60 | metavar='LRtype', help='learning rate type')
61 | parser.add_argument('--lr_steps', default=[30, 60, 80], type=float, nargs="+",
62 | metavar='LRSteps', help='epochs to decay learning rate by 10')
63 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
64 | help='momentum')
65 | parser.add_argument('--weight_decay', '--wd', default=5e-4, type=float,
66 | metavar='W', help='weight decay (default: 5e-4)')
67 | parser.add_argument('--clip-gradient', '--gd', default=None, type=float,
68 | metavar='W', help='gradient norm clipping (default: disabled)')
69 | parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true")
70 |
71 |
72 | # ========================= Monitor Configs ==========================
  73 | parser.add_argument('--print-freq', '-p', default=20, type=int,
  74 |                     metavar='N', help='print frequency (default: 20)')
75 | parser.add_argument('--eval-freq', '-ef', default=5, type=int,
76 | metavar='N', help='evaluation frequency (default: 5)')
77 | parser.add_argument('--test_split', type=int, default=0,
78 | help='The index of test file')
79 |
80 |
81 | # ========================= Runtime Configs ==========================
82 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
83 | help='number of data loading workers (default: 8)')
84 | parser.add_argument('--resume',
85 | default='/mnt/log/NAS_spatiotemporal/checkpoint/warmup/NAS_sptp_something_RGB_Dense3D121_avg_segment1_e50_droprate0.5_num_dense_sample32_dense/ckpt.best.pth.tar',
86 | type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
87 | parser.add_argument('--break_resume', default=False, action="store_true",
88 | help='if do break restore')
89 | parser.add_argument('--warmup', default=False, action="store_true",
90 | help='if do warmup initialization')
91 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
92 | help='evaluate model on validation set')
93 | parser.add_argument('--snapshot_pref', type=str, default="")
94 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
95 | help='manual epoch number (useful on restarts)')
96 | parser.add_argument('--gpus', nargs='+', type=int, default=None)
97 | parser.add_argument('--flow_prefix', default="", type=str)
98 | parser.add_argument('--root_log',type=str, default='/mnt/log/NAS_spatiotemporal/log')
99 | parser.add_argument('--root_model', type=str, default='/mnt/log/NAS_spatiotemporal/checkpoint')
100 |
101 | parser.add_argument('--shift', default=False, action="store_true", help='use shift for models')
102 | parser.add_argument('--shift_div', default=8, type=int, help='number of div for shift (default: 8)')
 103 | parser.add_argument('--shift_place', default='blockres', type=str, help='place for shift (default: blockres)')
104 |
105 | parser.add_argument('--temporal_pool', default=False, action="store_true", help='add temporal pooling')
106 | parser.add_argument('--non_local', default=False, action="store_true", help='add non local block')
107 |
108 | parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample for video dataset')
109 | parser.add_argument('--dense_sample_stride', default=1, type=int, help='dense sample stride for dense sample')
110 | parser.add_argument('--num_dense_sample', default=32, type=int, help='dense sample number for dense sample')
111 | parser.add_argument('--random_dense_sample_stride', default=False, action="store_true", help='use random dense sample stride for video dataset')
112 |
113 | parser.add_argument('--syncbn', default=False, action="store_true", help='Synchronized batch normalization')
114 | parser.add_argument('--use_zip', default=False, action="store_true", help='Use ZIP file for data I/O')
115 | parser.add_argument('--freeze_bn', default=False, action="store_true", help='Freeze batch normalization')
116 | # ========================= Distributed Configs ==========================
117 | parser.add_argument('--local_rank', type=int)
118 | parser.add_argument('--node_rank', type=int, default=-1)
119 | parser.add_argument('--dist-url',
120 | default='', #'tcp://' + get_master_ip() + ':23456',
121 | type=str,
122 | help='url used to set up distributed training')
123 | parser.add_argument('--world-size', default=0,#ompi_size(),
124 | type=int, help='number of distributed processes')
125 | parser.add_argument('--dist-backend', default='nccl', type=str,
126 | help='distributed backend')
 127 | parser.add_argument('--philly-mpi-multi-node', default=False, action="store_true",
 128 |                     help='mpi multiple node distributed')
129 | parser.add_argument('--philly-nccl-multi-node', default=False,action="store_true",
130 | help='nccl multiple node distributed')
131 |
--------------------------------------------------------------------------------
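
Note that NAS_utils/ops.py and models/densenet_3d.py both call parser.parse_args() at import time, so the flags above must already be on the command line (or passed explicitly) before those modules are imported. A hypothetical sketch (not part of the repo):

    from args import parser

    args = parser.parse_args(['--dataset', 'something', '--enable_nasas'])
    print(args.dataset)        # 'something'
    print(args.enable_nasas)   # True
    print(args.training_size)  # 86017 (the default)
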
/dataset/IO.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from zipfile import ZipFile
4 | from PIL import Image
5 | import os
6 | import numpy as np
7 | from numpy.random import randint
8 |
9 |
10 | class VideoRecord(object):
11 | def __init__(self, row):
12 | self._data = row
13 |
14 | @property
15 | def path(self):
16 | return self._data[0]
17 |
18 | @property
19 | def num_frames(self):
20 | return int(self._data[1])
21 |
22 | @property
23 | def label(self):
24 | return int(self._data[2])
25 |
26 |
27 | class TSNDataSet(data.Dataset):
28 | def __init__(self, root_path, list_file,
29 | num_segments=3, new_length=1, modality='RGB',
30 | image_tmpl='img_{:05d}.jpg', transform=None,
31 | random_shift=True, test_mode=False,
32 | remove_missing=False, dense_sample=False, num_dense_sample=32, dense_sample_stride=1, random_dense_sample_stride=False, is_zip=False):
33 |
34 | self.root_path = root_path
35 | self.list_file = list_file
36 | self.num_segments = num_segments
37 | self.new_length = new_length
38 | self.modality = modality
39 | self.image_tmpl = image_tmpl
40 | self.transform = transform
41 | self.random_shift = random_shift
42 | self.test_mode = test_mode
43 | self.remove_missing = remove_missing
44 | self.dense_sample = dense_sample # using dense sample as I3D
45 | self.num_dense_sample = num_dense_sample
46 | self.dense_sample_stride = dense_sample_stride
47 | self.random_dense_sample_stride = random_dense_sample_stride
48 | self.is_zip = is_zip
49 | if self.dense_sample:
50 | print('=> Using dense sample for the dataset...')
51 |
52 | if self.modality == 'RGBDiff':
53 | self.new_length += 1 # Diff needs one more image to calculate diff
54 |
55 | self._parse_list()
56 |
57 | def _load_image(self, directory, idx, zip_f=None):
58 | if self.modality == 'RGB' or self.modality == 'RGBDiff':
59 | try:
60 | if self.is_zip:
61 | return [Image.open(zip_f.open(self.image_tmpl.format(idx))).convert('RGB')]
62 | else:
63 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')]
64 | except Exception:
65 | print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
66 | if self.is_zip:
67 | return [Image.open(zip_f.open(self.image_tmpl.format(1))).convert('RGB')]
68 | else:
69 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')]
70 | elif self.modality == 'Flow':
71 | if self.image_tmpl == 'flow_{}_{:05d}.jpg': # ucf
72 | x_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('x', idx))).convert(
73 | 'L')
74 | y_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('y', idx))).convert(
75 | 'L')
76 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg': # something v1 flow
77 | x_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)), self.image_tmpl.
78 | format(int(directory), 'x', idx))).convert('L')
79 | y_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)), self.image_tmpl.
80 | format(int(directory), 'y', idx))).convert('L')
81 | else:
82 | try:
83 | # idx_skip = 1 + (idx-1)*5
84 | flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert(
85 | 'RGB')
86 | except Exception:
87 | print('error loading flow file:',
88 | os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
89 | flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')
90 | # the input flow file is RGB image with (flow_x, flow_y, blank) for each channel
91 | flow_x, flow_y, _ = flow.split()
92 | x_img = flow_x.convert('L')
93 | y_img = flow_y.convert('L')
94 |
95 | return [x_img, y_img]
96 |
97 | def _parse_list(self):
  98 |         # filter out videos with too few frames:
99 | tmp = [x.strip().split(' ') for x in open(self.list_file)]
100 | tmp = [[' '.join(x[:-2]), x[-2], x[-1]] for x in tmp]
101 | if not self.test_mode or self.remove_missing:
102 | if self.test_mode and 'kinetics' in self.root_path:
103 | tmp = [item for item in tmp if int(item[1]) >= 32]
104 | print('####################### Heavy remove #######################')
105 | else:
106 | tmp = [item for item in tmp if int(item[1]) >= 3]
107 | self.video_list = [VideoRecord(item) for item in tmp]
108 |
109 | if self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
110 | for v in self.video_list:
 111 |                 v._data[1] = int(v._data[1]) // 2  # integer division keeps num_frames an int
112 | print('video number:%d' % (len(self.video_list)))
113 |
114 | def _sample_indices(self, record):
 115 |         """Sample frame indices for a training clip.
 116 | 
 117 |         :param record: VideoRecord
 118 |         :return: list
 119 |         """
120 | if self.dense_sample: # i3d dense sample
121 | sample_pos = max(1, 1 + record.num_frames - self.num_dense_sample * self.dense_sample_stride)
122 | t_stride = self.num_dense_sample * self.dense_sample_stride // self.num_segments
123 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
124 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
125 | return np.array(offsets) + 1
126 | else: # normal sample
127 | average_duration = (record.num_frames - self.new_length + 1) // self.num_segments
128 | if average_duration > 0:
129 | offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration,
130 | size=self.num_segments)
131 | elif record.num_frames > self.num_segments:
132 | offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments))
133 | else:
134 | offsets = np.zeros((self.num_segments,))
135 | return offsets + 1
136 |
137 | def _get_val_indices(self, record):
138 | if self.dense_sample: # i3d dense sample
139 | sample_pos = max(1, 1 + record.num_frames - self.num_dense_sample * self.dense_sample_stride)
140 | t_stride = self.num_dense_sample * self.dense_sample_stride // self.num_segments
141 | #start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
142 | start_idx = 0 if sample_pos == 1 else sample_pos//2
143 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
144 | return np.array(offsets) + 1
145 | else:
146 | if record.num_frames > self.num_segments + self.new_length - 1:
147 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
148 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
149 | else:
150 | offsets = np.zeros((self.num_segments,))
151 | return offsets + 1
152 |
153 | def _get_test_indices(self, record):
154 | if self.dense_sample:
155 | sample_pos = max(1, 1 + record.num_frames - self.num_dense_sample * self.dense_sample_stride)
156 | t_stride = self.num_dense_sample * self.dense_sample_stride // self.num_segments
157 | start_list = np.linspace(0, sample_pos - 1, num=2, dtype=int)
158 | offsets = []
159 | for start_idx in start_list.tolist():
160 | offsets += [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
161 | return np.array(offsets) + 1
162 | else:
163 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
164 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
165 | return offsets + 1
166 |
167 | def __getitem__(self, index):
168 | record = self.video_list[index]
169 | # check this is a legit video folder
170 |
171 | if self.image_tmpl == 'flow_{}_{:05d}.jpg':
172 | file_name = self.image_tmpl.format('x', 1)
173 | full_path = os.path.join(self.root_path, record.path, file_name)
174 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
175 | file_name = self.image_tmpl.format(int(record.path), 'x', 1)
176 | full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name)
177 | else:
178 | file_name = self.image_tmpl.format(1)
179 | full_path = os.path.join(self.root_path, record.path, file_name)
180 |
181 | if not self.is_zip:
182 | while not os.path.exists(full_path):
183 | print('################## Not Found:', os.path.join(self.root_path, record.path, file_name))
184 | index = np.random.randint(len(self.video_list))
185 | record = self.video_list[index]
186 | if self.image_tmpl == 'flow_{}_{:05d}.jpg':
187 | file_name = self.image_tmpl.format('x', 1)
188 | full_path = os.path.join(self.root_path, record.path, file_name)
189 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
190 | file_name = self.image_tmpl.format(int(record.path), 'x', 1)
191 | full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name)
192 | else:
193 | file_name = self.image_tmpl.format(1)
194 | full_path = os.path.join(self.root_path, record.path, file_name)
195 |
196 | if not self.test_mode:
197 | segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
198 | else:
199 | segment_indices = self._get_test_indices(record)
200 | return self.get(record, segment_indices)
201 |
202 | def get(self, record, indices):
203 |
204 | images = list()
205 | zip_f = ZipFile(os.path.join(self.root_path, record.path, 'RGB_frames.zip'), mode='r') if self.is_zip else None
206 | if self.dense_sample:
207 | assert self.num_segments == 1, "dense sample needs segment number to be 1."
208 | for seg_ind in indices:
209 | p = int(seg_ind)
210 | for i in range(self.num_dense_sample):
211 | seg_imgs = self._load_image(record.path, p, zip_f)
212 | images.extend(seg_imgs)
213 | if p < record.num_frames - self.dense_sample_stride:
214 | if self.random_dense_sample_stride and self.random_shift:
215 | p += randint(1, self.dense_sample_stride+1)
216 | else:
217 | p += self.dense_sample_stride
218 | else:
219 | for seg_ind in indices:
220 | p = int(seg_ind)
221 | for i in range(self.new_length):
222 | seg_imgs = self._load_image(record.path, p)
223 | images.extend(seg_imgs)
224 | if p < record.num_frames:
225 | p += 1
226 |
227 | process_data = self.transform(images)
228 | if zip_f:
229 | zip_f.close()
230 | return process_data, record.label
231 |
232 | def __len__(self):
233 | return len(self.video_list)
234 |
235 | if __name__ == '__main__':
236 | TSNDataSet('', '/data/home/v-yizzh/workspace/code/NAS_spatiotemporal/dataset/val_videofolder.txt')
237 |
--------------------------------------------------------------------------------
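
Worked example (hypothetical, not in the repo) of the I3D-style dense sampling above: with num_segments == 1, _sample_indices picks a single 1-based start frame and get() then reads num_dense_sample frames at dense_sample_stride:

    import numpy as np

    num_frames, num_dense, stride = 100, 32, 1
    sample_pos = max(1, 1 + num_frames - num_dense * stride)  # 69
    start = np.random.randint(0, sample_pos - 1)              # in [0, 67]
    first = start + 1                                         # 1-based frame index
    clip = [first + i * stride for i in range(num_dense)]     # 32 frames
    print(clip[0], clip[-1])                                  # e.g. 12 43
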
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scenarios/Probabilistic3DCNN/444385aeabc001282064877faba7a15a787a4f94/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/augment.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import random
3 | from PIL import Image, ImageOps
4 | import numpy as np
5 | import numbers
6 | import math
7 | import torch
8 |
9 |
10 | class GroupRandomCrop(object):
11 | def __init__(self, size, repeat=0):
12 |
13 | self.repeat = repeat
14 | if isinstance(size, numbers.Number):
15 | self.size = (int(size), int(size))
16 | else:
17 | self.size = size
18 |
19 | def __call__(self, img_group):
20 |
21 | h, w, _ = img_group[0].shape
22 | th, tw = self.size
23 |
24 | x1 = random.randint(0, w - tw)
25 | y1 = random.randint(0, h - th)
26 | cropped_img_group = img_group[:, y1 : y1+th, x1 : x1+tw, :]
27 |
28 | for i in range(self.repeat):
29 | x1 = random.randint(0, w - tw)
30 | y1 = random.randint(0, h - th)
31 | cropped_img_group = np.concatenate((cropped_img_group, img_group[:, y1 : y1+th, x1 : x1+tw, :]), axis=0)
32 | '''
33 | out_images = list()
34 |
35 | x1 = random.randint(0, w - tw)
36 | y1 = random.randint(0, h - th)
37 |
38 | for img in img_group:
39 | assert(img.size[0] == w and img.size[1] == h)
40 | if w == tw and h == th:
41 | out_images.append(img)
42 | else:
43 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
44 |
45 |
46 | return out_images
47 | '''
48 | return cropped_img_group
49 |
50 |
51 | class GroupCenterCrop(object):
52 | '''
53 | def __init__(self, size):
54 | self.worker = torchvision.transforms.CenterCrop(size)
55 |
56 | def __call__(self, img_group):
57 | return [self.worker(img) for img in img_group]
58 | '''
59 |
60 | def __init__(self, size):
61 | if isinstance(size, numbers.Number):
62 | self.size = (int(size), int(size))
63 | else:
64 | self.size = size
65 |
66 | def __call__(self, img_group):
67 |
68 | h, w, _ = img_group[0].shape
69 | th, tw = self.size
70 |
71 | assert th <= h and tw <= w, "target size must be smaller than original size."
72 |
73 | x1 = (w - tw) // 2
74 | y1 = (h - th) // 2
75 |
76 | return img_group[:, y1: y1 + th, x1: x1 + tw, :]
77 |
78 |
79 | class GroupRandomHorizontalFlip(object):
80 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5
81 | """
82 | def __init__(self, is_flow=False):
83 | self.is_flow = is_flow
84 | '''
85 | def __call__(self, img_group, is_flow=False):
86 | v = random.random()
87 | if v < 0.5:
88 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
89 | if self.is_flow:
90 | for i in range(0, len(ret), 2):
91 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping
92 | return ret
93 | else:
94 | return img_group
95 | '''
96 | def __call__(self, img_group, is_flow=False):
97 | assert is_flow is False, "Currently only RGB flip is supported."
98 | v = random.random()
99 | if v < 0.5:
100 | return img_group[:, :, ::-1, :].copy()
101 | else:
102 | return img_group
103 |
104 |
105 | class GroupNormalize(object):
106 | def __init__(self, mean, std):
107 | self.mean = torch.from_numpy(np.array(mean, dtype=np.float32)).view(-1,1,1,1)
108 | self.std = torch.from_numpy(np.array(std, dtype=np.float32)).view(-1,1,1,1)
109 |
110 | def __call__(self, tensor):
111 | '''
112 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean))
113 | rep_std = self.std * (tensor.size()[0]//len(self.std))
114 |
115 | # TODO: make efficient
116 | for t, m, s in zip(tensor, rep_mean, rep_std):
117 | t.sub_(m).div_(s)
118 | '''
119 | tensor.sub_(self.mean).div_(self.std)
120 | return tensor
121 |
122 |
123 | class GroupScale(object):
124 | """ Rescales the input PIL.Image to the given 'size'.
125 | 'size' will be the size of the smaller edge.
126 | For example, if height > width, then image will be
127 | rescaled to (size * height / width, size)
128 | size: size of the smaller edge
129 | interpolation: Default: PIL.Image.BILINEAR
130 | """
131 | def __init__(self, size, interpolation=Image.BILINEAR):
132 | if isinstance(size, int):
133 | size = [size]
134 | else:
135 | assert isinstance(size, list), "Size is list or int."
136 | self.worker = [torchvision.transforms.Resize(this_size, interpolation) for this_size in size]
137 |
138 | def __call__(self, img_group):
139 | this_worker = self.worker[np.random.randint(len(self.worker))]
140 | return [this_worker(img) for img in img_group]
141 |
142 |
143 | class GroupOverSample(object):
144 | def __init__(self, crop_size, scale_size=None, flip=True):
145 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size)
146 |
147 | if scale_size is not None:
148 | self.scale_worker = GroupScale(scale_size)
149 | else:
150 | self.scale_worker = None
151 | self.flip = flip
152 |
153 | def __call__(self, img_group):
154 |
155 | if self.scale_worker is not None:
156 | img_group = self.scale_worker(img_group)
157 |
158 | image_w, image_h = img_group[0].size
159 | crop_w, crop_h = self.crop_size
160 |
161 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h)
162 | oversample_group = list()
163 | for o_w, o_h in offsets:
164 | normal_group = list()
165 | flip_group = list()
166 | for i, img in enumerate(img_group):
167 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
168 | normal_group.append(crop)
169 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
170 |
171 | if img.mode == 'L' and i % 2 == 0:
172 | flip_group.append(ImageOps.invert(flip_crop))
173 | else:
174 | flip_group.append(flip_crop)
175 |
176 | oversample_group.extend(normal_group)
177 | if self.flip:
178 | oversample_group.extend(flip_group)
179 | return oversample_group
180 |
181 |
182 | class GroupFullResSample(object):
183 | def __init__(self, crop_size, scale_size=None, flip=True):
184 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size)
185 |
186 | if scale_size is not None:
187 | self.scale_worker = GroupScale(scale_size)
188 | else:
189 | self.scale_worker = None
190 | self.flip = flip
191 |
192 | def __call__(self, img_group):
193 |
194 | if self.scale_worker is not None:
195 | img_group = self.scale_worker(img_group)
196 |
197 | image_w, image_h = img_group[0].size
198 | crop_w, crop_h = self.crop_size
199 |
200 | w_step = (image_w - crop_w) // 4
201 | h_step = (image_h - crop_h) // 4
202 |
203 | offsets = list()
204 | offsets.append((0 * w_step, 2 * h_step)) # left
205 | offsets.append((4 * w_step, 2 * h_step)) # right
206 | offsets.append((2 * w_step, 2 * h_step)) # center
207 |
208 | oversample_group = list()
209 | for o_w, o_h in offsets:
210 | normal_group = list()
211 | flip_group = list()
212 | for i, img in enumerate(img_group):
213 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
214 | normal_group.append(crop)
215 | if self.flip:
216 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
217 |
218 | if img.mode == 'L' and i % 2 == 0:
219 | flip_group.append(ImageOps.invert(flip_crop))
220 | else:
221 | flip_group.append(flip_crop)
222 |
223 | oversample_group.extend(normal_group)
224 | oversample_group.extend(flip_group)
225 | return oversample_group
226 |
227 |
228 | class GroupMultiScaleCrop(object):
229 |
230 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
231 | self.scales = scales if scales is not None else [1, .875, .75, .66]
232 | self.max_distort = max_distort
233 | self.fix_crop = fix_crop
234 | self.more_fix_crop = more_fix_crop
235 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
236 | self.interpolation = Image.BILINEAR
237 |
238 | def __call__(self, img_group):
239 |
240 | im_size = img_group[0].size
241 |
242 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
243 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
244 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
245 | for img in crop_img_group]
246 | return ret_img_group
247 |
248 | def _sample_crop_size(self, im_size):
249 | image_w, image_h = im_size[0], im_size[1]
250 |
251 | # find a crop size
252 | base_size = min(image_w, image_h)
253 | crop_sizes = [int(base_size * x) for x in self.scales]
254 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
255 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
256 |
257 | pairs = []
258 | for i, h in enumerate(crop_h):
259 | for j, w in enumerate(crop_w):
260 | if abs(i - j) <= self.max_distort:
261 | pairs.append((w, h))
262 |
263 | crop_pair = random.choice(pairs)
264 | if not self.fix_crop:
265 | w_offset = random.randint(0, image_w - crop_pair[0])
266 | h_offset = random.randint(0, image_h - crop_pair[1])
267 | else:
268 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
269 |
270 | return crop_pair[0], crop_pair[1], w_offset, h_offset
271 |
272 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
273 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
274 | return random.choice(offsets)
275 |
276 | @staticmethod
277 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
278 | w_step = (image_w - crop_w) // 4
279 | h_step = (image_h - crop_h) // 4
280 |
281 | ret = list()
282 | ret.append((0, 0)) # upper left
283 | ret.append((4 * w_step, 0)) # upper right
284 | ret.append((0, 4 * h_step)) # lower left
285 | ret.append((4 * w_step, 4 * h_step)) # lower right
286 | ret.append((2 * w_step, 2 * h_step)) # center
287 |
288 | if more_fix_crop:
289 | ret.append((0, 2 * h_step)) # center left
290 | ret.append((4 * w_step, 2 * h_step)) # center right
291 | ret.append((2 * w_step, 4 * h_step)) # lower center
292 | ret.append((2 * w_step, 0 * h_step)) # upper center
293 |
294 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter
295 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter
296 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter
 297 |             ret.append((3 * w_step, 3 * h_step))  # lower right quarter
298 |
299 | return ret
300 |
301 |
302 | class GroupRandomSizedCrop(object):
303 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
 304 |     and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio.
305 | This is popularly used to train the Inception networks
306 | size: size of the smaller edge
307 | interpolation: Default: PIL.Image.BILINEAR
308 | """
309 | def __init__(self, size, interpolation=Image.BILINEAR):
310 | self.size = size
311 | self.interpolation = interpolation
312 |
313 | def __call__(self, img_group):
314 | for attempt in range(10):
315 | area = img_group[0].size[0] * img_group[0].size[1]
316 | target_area = random.uniform(0.08, 1.0) * area
317 | aspect_ratio = random.uniform(3. / 4, 4. / 3)
318 |
319 | w = int(round(math.sqrt(target_area * aspect_ratio)))
320 | h = int(round(math.sqrt(target_area / aspect_ratio)))
321 |
322 | if random.random() < 0.5:
323 | w, h = h, w
324 |
325 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
326 | x1 = random.randint(0, img_group[0].size[0] - w)
327 | y1 = random.randint(0, img_group[0].size[1] - h)
328 | found = True
329 | break
330 | else:
331 | found = False
332 | x1 = 0
333 | y1 = 0
334 |
335 | if found:
336 | out_group = list()
337 | for img in img_group:
338 | img = img.crop((x1, y1, x1 + w, y1 + h))
339 | assert(img.size == (w, h))
340 | out_group.append(img.resize((self.size, self.size), self.interpolation))
341 | return out_group
342 | else:
343 | # Fallback
344 | scale = GroupScale(self.size, interpolation=self.interpolation)
345 | crop = GroupRandomCrop(self.size)
346 | return crop(scale(img_group))
347 |
348 |
349 | class Stack(object):
350 |
351 | def __init__(self, roll=False):
352 | self.roll = roll
353 |
354 | def __call__(self, img_group):
355 | if img_group[0].mode == 'L':
356 | return np.concatenate([np.expand_dims(np.expand_dims(x, 2), 0) for x in img_group], axis=0)
357 | elif img_group[0].mode == 'RGB':
358 | if self.roll:
359 | return np.concatenate([np.expand_dims(np.array(x)[:, :, ::-1], axis=0) for x in img_group], axis=0)
360 | else:
361 | return np.concatenate([np.expand_dims(np.array(x), axis=0) for x in img_group], axis=0)
362 |
363 |
364 | class ToTorchFormatTensor(object):
365 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
366 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
367 | def __init__(self, div=True):
368 | self.div = div
369 |
370 | def __call__(self, pic):
371 | assert isinstance(pic, np.ndarray), "Require numpy array input."
372 | if isinstance(pic, np.ndarray):
373 | # handle numpy array
374 | img = torch.from_numpy(pic).permute(3, 0, 1, 2).contiguous()
375 | else:
376 | # handle PIL Image
377 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
378 | img = img.view(pic.size[1], pic.size[0], len(pic.mode))
379 | # put it from HWC to CHW format
380 | # yikes, this transpose takes 80% of the loading time/CPU
381 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
382 | return img.float().div(255) if self.div else img.float()
383 |
384 |
385 | class IdentityTransform(object):
386 |
387 | def __call__(self, data):
388 | return data
389 |
390 |
391 | def get_train_augmentation(modality='RGB', flip=True, div=True, roll=False):
392 | assert modality == 'RGB', "Currently only RGB augmentation is supported."
393 | if modality == 'RGB':
394 | if flip:
395 | #return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66]),
396 | # GroupRandomHorizontalFlip(is_flow=False)])
397 | return torchvision.transforms.Compose([GroupScale([256, 288, 320]),
398 | Stack(roll=roll),
399 | GroupRandomCrop(224),
400 | GroupRandomHorizontalFlip(is_flow=False),
401 | ToTorchFormatTensor(div=div)]
402 | )
403 | else:
404 | print('#' * 20, 'NO FLIP!!!')
405 | #return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66])])
406 | return torchvision.transforms.Compose([GroupScale(256),
407 | Stack(roll=roll),
408 | GroupRandomCrop(224),
409 | ToTorchFormatTensor(div=div)]
410 | )
411 | elif modality == 'Flow':
412 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]),
413 | GroupRandomHorizontalFlip(is_flow=True)])
414 | elif modality == 'RGBDiff':
415 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]),
416 | GroupRandomHorizontalFlip(is_flow=False)])
417 |
418 |
419 | def get_val_augmentation(modality='RGB', div=True, roll=False):
420 | assert modality == 'RGB', "Currently only RGB augmentation is supported."
421 | if modality == 'RGB':
422 | return torchvision.transforms.Compose([GroupScale(256),
423 | Stack(roll=roll),
424 | #GroupRandomCrop(224),
425 | GroupCenterCrop(256),
426 | ToTorchFormatTensor(div=div)]
427 | )
428 | elif modality == 'Flow':
429 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]),
430 | GroupRandomHorizontalFlip(is_flow=True)])
431 | elif modality == 'RGBDiff':
432 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]),
433 | GroupRandomHorizontalFlip(is_flow=False)])
434 |
435 |
436 | def get_test_augmentation(modality='RGB', div=True, roll=False):
437 | assert modality == 'RGB', "Currently only RGB augmentation is supported."
438 | if modality == 'RGB':
439 | return torchvision.transforms.Compose([GroupScale(256),
440 | Stack(roll=roll),
441 | #GroupRandomCrop(256, 2),
442 | GroupCenterCrop(256),
443 | ToTorchFormatTensor(div=div)]
444 | )
445 | elif modality == 'Flow':
446 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]),
447 | GroupRandomHorizontalFlip(is_flow=True)])
448 | elif modality == 'RGBDiff':
449 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]),
450 | GroupRandomHorizontalFlip(is_flow=False)])
451 |
452 |
453 | def get_selection_augmentation(modality='RGB', div=True, roll=False):
454 | assert modality == 'RGB', "Currently only RGB augmentation is supported."
455 | if modality == 'RGB':
456 | return torchvision.transforms.Compose([GroupScale(256),
457 | Stack(roll=roll),
458 | GroupCenterCrop(256),
459 | ToTorchFormatTensor(div=div)]
460 | )
461 | elif modality == 'Flow':
462 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]),
463 | GroupRandomHorizontalFlip(is_flow=True)])
464 | elif modality == 'RGBDiff':
465 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]),
466 | GroupRandomHorizontalFlip(is_flow=False)])
467 |
468 | if __name__ == "__main__":
469 | trans = torchvision.transforms.Compose([GroupScale(256),
470 | Stack(),
471 | GroupRandomCrop(224),
472 | GroupRandomHorizontalFlip(is_flow=False),
473 | ToTorchFormatTensor(),
474 | GroupNormalize(mean=[.485, .456, .406], std=[.229, .224, .225])]
475 | )
476 |
477 | im = Image.open('/mnt/data/somethingsomethingv1_raw/20bn-something-something-v1/2/00002.jpg')
478 |
479 | color_group = [im] * 6
480 | rst = trans(color_group)
481 |
482 | gray_group = [im.convert('L')] * 9
483 | gray_rst = trans(gray_group)
484 |
485 | trans2 = torchvision.transforms.Compose([
486 | GroupRandomSizedCrop(256),
487 | Stack(),
488 | ToTorchFormatTensor(),
489 | GroupNormalize(
490 | mean=[.485, .456, .406],
491 | std=[.229, .224, .225])
492 | ])
493 | print(trans2(color_group))
--------------------------------------------------------------------------------
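
Note that the get_*_augmentation helpers end with ToTorchFormatTensor and leave normalization to the caller. A hypothetical usage sketch (not in the repo), reusing the ImageNet statistics from the __main__ demo above:

    import torchvision
    from dataset.augment import get_train_augmentation, GroupNormalize

    transform = torchvision.transforms.Compose([
        get_train_augmentation(modality='RGB', flip=True),
        GroupNormalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
    ])
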
/dataset/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | ROOT_DATASET = '/mnt/data/'
4 |
5 |
   6 | def return_ucf101(modality, root_path=''):  # root_path unused; paths use ROOT_DATASET
7 | filename_categories = 'UCF101/labels/classInd.txt'
8 | if modality == 'RGB':
9 | root_data = ROOT_DATASET + 'UCF101/jpg'
10 | filename_imglist_train = 'UCF101/file_list/ucf101_rgb_train_split_1.txt'
11 | filename_imglist_val = 'UCF101/file_list/ucf101_rgb_val_split_1.txt'
12 | prefix = 'img_{:05d}.jpg'
13 | elif modality == 'Flow':
14 | root_data = ROOT_DATASET + 'UCF101/jpg'
15 | filename_imglist_train = 'UCF101/file_list/ucf101_flow_train_split_1.txt'
16 | filename_imglist_val = 'UCF101/file_list/ucf101_flow_val_split_1.txt'
17 | prefix = 'flow_{}_{:05d}.jpg'
18 | else:
19 | raise NotImplementedError('no such modality:' + modality)
20 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
21 |
22 |
23 | def return_ucf101_zip(modality, root_path=''):
24 | assert modality == 'RGB', "Currently RGB only."
25 | filename_categories = 'ucf101_zip/classInd.txt'
26 | if modality == 'RGB':
27 | root_data = os.path.join(root_path, 'ucf101_zip/')
28 | filename_imglist_train = 'ucf101_zip/train_videofolder.txt'
29 | filename_imglist_val = 'ucf101_zip/val_videofolder.txt'
30 | prefix = 'image_{:05d}.jpg'
31 | elif modality == 'Flow':
32 | root_data = ROOT_DATASET + 'UCF101/jpg'
33 | filename_imglist_train = 'UCF101/file_list/ucf101_flow_train_split_1.txt'
34 | filename_imglist_val = 'UCF101/file_list/ucf101_flow_val_split_1.txt'
35 | prefix = 'flow_{}_{:05d}.jpg'
36 | else:
37 | raise NotImplementedError('no such modality:' + modality)
38 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
39 |
40 |
  41 | def return_hmdb51(modality, root_path=''):  # root_path unused; paths use ROOT_DATASET
42 | filename_categories = 51
43 | if modality == 'RGB':
44 | root_data = ROOT_DATASET + 'HMDB51/images'
45 | filename_imglist_train = 'HMDB51/splits/hmdb51_rgb_train_split_1.txt'
46 | filename_imglist_val = 'HMDB51/splits/hmdb51_rgb_val_split_1.txt'
47 | prefix = 'img_{:05d}.jpg'
48 | elif modality == 'Flow':
49 | root_data = ROOT_DATASET + 'HMDB51/images'
50 | filename_imglist_train = 'HMDB51/splits/hmdb51_flow_train_split_1.txt'
51 | filename_imglist_val = 'HMDB51/splits/hmdb51_flow_val_split_1.txt'
52 | prefix = 'flow_{}_{:05d}.jpg'
53 | else:
54 | raise NotImplementedError('no such modality:' + modality)
55 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
56 |
57 |
58 | def return_something(modality, root_path=''):
59 | assert modality == 'RGB', "Currently RGB only."
60 | filename_categories = '20bn-something-something-v1/category.txt'
61 | if modality == 'RGB':
62 | root_data = os.path.join(root_path, '20bn-something-something-v1/')
63 | filename_imglist_train = '20bn-something-something-v1/train_videofolder.txt'
64 | filename_imglist_val = '20bn-something-something-v1/val_videofolder.txt'
65 | prefix = '{:05d}.jpg'
66 | elif modality == 'Flow':
67 | root_data = os.path.join(root_path, 'something/v1/20bn-something-something-v1-flow/')
68 | filename_imglist_train = 'something/v1/train_videofolder_flow.txt'
69 | filename_imglist_val = 'something/v1/val_videofolder_flow.txt'
70 | prefix = '{:06d}-{}_{:05d}.jpg'
71 | else:
72 | print('no such modality:'+modality)
73 | raise NotImplementedError
74 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
75 |
76 |
77 | def return_something_zip(modality, root_path=''):
78 | assert modality == 'RGB', "Currently RGB only."
79 | filename_categories = '20bn-something-something-v1_zip/category.txt'
80 | if modality == 'RGB':
81 | root_data = os.path.join(root_path, '20bn-something-something-v1_zip/')
82 | filename_imglist_train = '20bn-something-something-v1_zip/train_videofolder.txt'
83 | filename_imglist_val = '20bn-something-something-v1_zip/val_videofolder.txt'
84 | prefix = '{:05d}.jpg'
85 | elif modality == 'Flow':
86 | root_data = os.path.join(root_path, 'something/v1/20bn-something-something-v1-flow/')
87 | filename_imglist_train = 'something/v1/train_videofolder_flow.txt'
88 | filename_imglist_val = 'something/v1/val_videofolder_flow.txt'
89 | prefix = '{:06d}-{}_{:05d}.jpg'
90 | else:
91 | print('no such modality:'+modality)
92 | raise NotImplementedError
93 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
94 |
95 |
96 | def return_somethingv2(modality, root_path=''):
97 | assert modality == 'RGB', "Currently RGB only."
98 | filename_categories = '20bn-something-something-v2/category.txt'
99 | if modality == 'RGB':
100 | root_data = os.path.join(root_path, '20bn-something-something-v2/')
101 | filename_imglist_train = '20bn-something-something-v2/train_videofolder.txt'
102 | filename_imglist_val = '20bn-something-something-v2/val_videofolder.txt'
103 | prefix = '{:06d}.jpg'
104 | elif modality == 'Flow':
105 | root_data = ROOT_DATASET + 'something/v2/20bn-something-something-v2-flow'
106 | filename_imglist_train = 'something/v2/train_videofolder_flow.txt'
107 | filename_imglist_val = 'something/v2/val_videofolder_flow.txt'
108 | prefix = '{:06d}.jpg'
109 | else:
110 | raise NotImplementedError('no such modality:'+modality)
111 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
112 |
113 |
114 | def return_somethingv2_zip(modality, root_path=''):
115 | assert modality == 'RGB', "Currently RGB only."
116 | filename_categories = '20bn-something-something-v2_zip/category.txt'
117 | if modality == 'RGB':
118 | root_data = os.path.join(root_path, '20bn-something-something-v2_zip/')
119 | filename_imglist_train = '20bn-something-something-v2_zip/train_videofolder.txt'
120 | filename_imglist_val = '20bn-something-something-v2_zip/val_videofolder.txt'
121 | prefix = '{:06d}.jpg'
122 | elif modality == 'Flow':
123 | root_data = ROOT_DATASET + 'something/v2/20bn-something-something-v2-flow'
124 | filename_imglist_train = 'something/v2/train_videofolder_flow.txt'
125 | filename_imglist_val = 'something/v2/val_videofolder_flow.txt'
126 | prefix = '{:06d}.jpg'
127 | else:
128 | raise NotImplementedError('no such modality:'+modality)
129 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
130 |
131 |
 132 | def return_jester(modality, root_path=''):  # root_path unused; paths use ROOT_DATASET
133 | filename_categories = 'jester/category.txt'
134 | if modality == 'RGB':
135 | prefix = '{:05d}.jpg'
136 | root_data = ROOT_DATASET + 'jester/20bn-jester-v1'
137 | filename_imglist_train = 'jester/train_videofolder.txt'
138 | filename_imglist_val = 'jester/val_videofolder.txt'
139 | else:
140 | raise NotImplementedError('no such modality:'+modality)
141 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
142 |
143 |
144 | def return_kinetics(modality, root_path=''):
145 | filename_categories = 400
146 | if modality == 'RGB':
147 | root_data = os.path.join(root_path, 'kinetics400_frame/')
148 | filename_imglist_train = 'kinetics400_frame/train_videofolder.txt'
149 | filename_imglist_val = 'kinetics400_frame/val_videofolder.txt'
150 | prefix = 'img_{:05d}.jpg'
151 | else:
152 | raise NotImplementedError('no such modality:' + modality)
153 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
154 |
155 |
156 | def return_kinetics_zip(modality, root_path=''):
157 | filename_categories = 400
158 | if modality == 'RGB':
159 | root_data = os.path.join(root_path, 'kinetics400_frame_zip/')
160 | filename_imglist_train = 'kinetics400_frame_zip/train_videofolder.txt'
161 | filename_imglist_val = 'kinetics400_frame_zip/val_videofolder.txt'
162 | prefix = 'img_{:05d}.jpg'
163 | else:
164 | raise NotImplementedError('no such modality:' + modality)
165 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
166 |
167 |
168 | def return_dataset(dataset, modality, root_path=''):
169 | dict_single = {'jester': return_jester, 'something': return_something, 'something_zip': return_something_zip,
170 | 'somethingv2': return_somethingv2, 'somethingv2_zip': return_somethingv2_zip,
171 | 'ucf101': return_ucf101, 'ucf101_zip': return_ucf101_zip, 'hmdb51': return_hmdb51,
172 | 'kinetics': return_kinetics, 'kinetics_zip': return_kinetics_zip}
173 | if dataset in dict_single:
174 | file_categories, file_imglist_train, file_imglist_val, root_data, prefix = dict_single[dataset](modality, root_path)
175 | else:
176 | raise ValueError('Unknown dataset '+dataset)
177 |
178 | file_imglist_train = os.path.join(root_path, file_imglist_train)
179 | file_imglist_val = os.path.join(root_path, file_imglist_val)
180 | if isinstance(file_categories, str):
181 | file_categories = os.path.join(root_path, file_categories)
182 | with open(file_categories) as f:
183 | lines = f.readlines()
184 | categories = [item.rstrip() for item in lines]
185 | else: # number of categories
186 | categories = [None] * file_categories
187 | n_class = len(categories)
188 | print('{}: {} classes'.format(dataset, n_class))
189 | return n_class, file_imglist_train, file_imglist_val, root_data, prefix
--------------------------------------------------------------------------------
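A minimal usage sketch for the dataset helpers above, assuming they live in dataset/config.py as the repo layout suggests; the dataset name, root path, and printed values are illustrative, not outputs of this repo:

from dataset.config import return_dataset  # assumed module path

# 'somethingv2' and root_path='/data' are placeholder choices for illustration.
n_class, train_list, val_list, root_data, prefix = return_dataset(
    'somethingv2', 'RGB', root_path='/data')
# e.g. n_class == 174 if category.txt lists 174 classes; prefix == '{:06d}.jpg'
print(n_class, train_list, val_list, root_data, prefix)
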
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scenarios/Probabilistic3DCNN/444385aeabc001282064877faba7a15a787a4f94/models/__init__.py
--------------------------------------------------------------------------------
/models/densenet_3d.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.model_zoo as model_zoo
6 | import types
7 |
8 | import os
9 |
10 | from collections import OrderedDict
11 | from functools import partial
12 |
13 | from args import parser
14 | from NAS_utils.ops import Conv3d_with_CD, Linear_with_CD
15 |
16 | args = parser.parse_args()
17 |
18 | #Conv2d = Conv2d
19 | Conv3d = partial(Conv3d_with_CD, weight_reg=args.weight_reg, deterministic=True if args.selection_mode or args.finetune_mode else False, training_size=args.training_size, p_init=args.p_init)
20 | Linear = partial(Linear_with_CD, weight_reg=args.weight_reg, deterministic=True if args.selection_mode or args.finetune_mode else False, training_size=args.training_size, p_init=args.p_init)
21 | nnConv2d = nn.Conv2d
22 | BatchNorm3d = partial(nn.BatchNorm3d, track_running_stats=not args.freeze_bn)
23 | _TEMPORAL_NASAS_ONLY = args.temporal_nasas_only
24 | _TEMPORAL_NODOWNSAMPLE = 'v1d3' in args.net_version or ('pure' in args.net_version and ('something' in args.dataset or 'ucf' in args.dataset))
25 |
26 | __all__ = ['densenet121', 'densenet169', 'densenet201', 'densenet161']
27 |
28 |
29 | model_urls = {
30 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
31 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
32 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
33 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
34 | }
35 |
36 |
37 | input_sizes = {}
38 | means = {}
39 | stds = {}
40 |
41 |
42 | for model_name in __all__:
43 | input_sizes[model_name] = [3, 224, 224]
44 | means[model_name] = [0.485, 0.456, 0.406]
45 | stds[model_name] = [0.229, 0.224, 0.225]
46 |
47 |
48 | pretrained_settings = {}
49 |
50 |
51 | for model_name in __all__:
52 | pretrained_settings[model_name] = {
53 | 'imagenet': {
54 | 'url': model_urls[model_name],
55 | 'input_space': 'RGB',
56 | 'input_size': input_sizes[model_name],
57 | 'crop_size': input_sizes[model_name][-1] * 256 // 224,
58 | 'input_range': [0, 1],
59 | 'mean': means[model_name],
60 | 'std': stds[model_name]
61 | #'num_classes': 174
62 | }
63 | }
64 |
65 |
66 | def load_pretrained(model, num_classes, settings):
67 | #assert num_classes == settings['num_classes'], \
68 | # "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)
69 | try:
70 | state_dict = torch.load('/log/checkpoint/Densenet121_2D_ImagenetPretrained/densenet121-a639ec97.pth')
71 | except Exception:  # fall back to downloading when the local checkpoint is unavailable
72 | state_dict = model_zoo.load_url(settings['url'])
73 | state_dict = update_state_dict(state_dict)
74 | mk, uk = model.load_state_dict(state_dict, strict=False)
75 | model.input_space = settings['input_space']
76 | model.input_size = settings['input_size']
77 | model.input_range = settings['input_range']
78 | model.mean = settings['mean']
79 | model.std = settings['std']
80 | return model
81 |
82 |
83 | def update_state_dict(state_dict):
84 | # '.'s are no longer allowed in module names, but previous _DenseLayer
85 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
86 | # They are also in the checkpoints in model_urls. This pattern is used
87 | # to find such keys.
88 | pattern = re.compile(
89 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
90 | for key in list(state_dict.keys()):
91 | res = pattern.match(key)
92 | if res:
93 | new_key = res.group(1) + res.group(2)
94 | state_dict[new_key] = state_dict[key]
95 | del state_dict[key]
96 |
97 | # Inflate to 3d densenet
98 | pattern = re.compile(
99 | r'^(.*)((?:conv|norm)(?:[012]?)\.(?:weight|bias|running_mean|running_var))$')
100 | for key in list(state_dict.keys()):
101 | res = pattern.match(key)
102 | if res:
103 | v = state_dict[key]
104 | if 'conv' in key:
105 | v = torch.unsqueeze(v, dim=2)
106 | if 'conv0' in key:
107 | v = v.repeat([1, 1, 5, 1, 1])
108 | v /= 5.0
109 | state_dict[key] = v
110 | elif 'conv1' in key:
111 | new_key_btnk = res.group(1) + 'bottleneck.' + res.group(2)
112 | state_dict[new_key_btnk] = v
113 | if 'v1' in args.net_version or 'pure_temporal' in args.net_version:
114 | new_key_tmpr = res.group(1) + 'temporal.' + res.group(2)
115 | state_dict[new_key_tmpr] = v.repeat([1, 1, 3, 1, 1]) / 3.0
116 | del state_dict[key]
117 | elif 'conv2' in key:
118 | new_key_sptl = res.group(1) + 'spatial.' + res.group(2)
119 | state_dict[new_key_sptl] = v
120 | del state_dict[key]
121 | else:
122 | if 'transition' in key:
123 | new_key_btnk = res.group(1) + 'original.' + res.group(2)
124 | state_dict[new_key_btnk] = v
125 | if args.net_version in ['v1', 'v1d2', 'vt', 'v1d3', 'pure_temporal']:
126 | new_key_tmpr = res.group(1) + 'temporal.' + res.group(2)
127 | state_dict[new_key_tmpr] = v.repeat([1, 1, 3, 1, 1]) / 3.0
128 | del state_dict[key]
129 | else:
130 | state_dict[key] = v
131 | else:
132 | if 'norm1' in key:
133 | new_key_btnk = res.group(1) + 'bottleneck.' + res.group(2)
134 | state_dict[new_key_btnk] = v
135 | if args.net_version in ['v1d2', 'v1nt', 'v1d3', 'pure_temporal']:
136 | new_key_tmpr = res.group(1) + 'temporal.' + res.group(2)
137 | state_dict[new_key_tmpr] = v
138 | del state_dict[key]
139 | elif 'norm2' in key:
140 | new_key_sptl = res.group(1) + 'spatial.' + res.group(2)
141 | state_dict[new_key_sptl] = v
142 | del state_dict[key]
143 | else:
144 | if 'transition' in key:
145 | new_key_btnk = res.group(1) + 'original.' + res.group(2)
146 | state_dict[new_key_btnk] = v
147 | if args.net_version in ['v1d2', 'vt', 'v1d3', 'pure_temporal']:
148 | new_key_tmpr = res.group(1) + 'temporal.' + res.group(2)
149 | state_dict[new_key_tmpr] = v
150 | del state_dict[key]
151 | else:
152 | state_dict[key] = v
153 |
154 | if 'classifier' in key:
155 | del state_dict[key]
156 | return state_dict
157 |
158 |
159 | class _DenseLayer(nn.Sequential):
160 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, split=1, split_pattern=None):
161 | super(_DenseLayer, self).__init__()
162 | if 'pure_temporal' not in args.net_version:
163 | self.bottleneck = nn.Sequential(OrderedDict([
164 | ('norm1', BatchNorm3d(num_input_features)),
165 | ('relu1', nn.ReLU(inplace=True)),
166 | ('conv1', Conv3d(num_input_features, bn_size *
167 | growth_rate, kernel_size=1, stride=1, bias=False, split=split,
168 | split_pattern=split_pattern))
169 | ]))
170 | self.spatial = nn.Sequential(OrderedDict([
171 | ('norm2', BatchNorm3d(bn_size * growth_rate)),
172 | ('relu2', nn.ReLU(inplace=True)),
173 | ('conv2', Conv3d(bn_size * growth_rate, growth_rate,
174 | kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=False, split=1, deact_nasas=_TEMPORAL_NASAS_ONLY))
175 | ]))
176 |
177 | if 'v1' in args.net_version:
178 | self.temporal = nn.Sequential(OrderedDict([
179 | ('norm1', BatchNorm3d(num_input_features)),
180 | ('relu1', nn.ReLU(inplace=True)),
181 | ('conv1', Conv3d(num_input_features, bn_size *
182 | growth_rate, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False,
183 | split=split,
184 | split_pattern=split_pattern))
185 | ]))
186 | elif 'v2' in args.net_version:
187 | self.temporal = nn.Sequential(OrderedDict([
188 | ('norm1', BatchNorm3d(bn_size * growth_rate)),
189 | ('relu1', nn.ReLU(inplace=True)),
190 | ('conv1', Conv3d(bn_size * growth_rate, bn_size * growth_rate,
191 | kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False, split=1))
192 | ]))
193 | elif 'v3' in args.net_version:
194 | self.temporal = nn.Sequential(OrderedDict([
195 | ('norm1', BatchNorm3d(growth_rate)),
196 | ('relu1', nn.ReLU(inplace=True)),
197 | ('conv1', Conv3d(growth_rate, growth_rate,
198 | kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False, split=1))
199 | ]))
200 | elif 'pure_temporal' in args.net_version:
201 | self.temporal = nn.Sequential(OrderedDict([
202 | ('norm1', BatchNorm3d(num_input_features)),
203 | ('relu1', nn.ReLU(inplace=True)),
204 | ('conv1', Conv3d(num_input_features, bn_size *
205 | growth_rate, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False,
206 | split=split,
207 | split_pattern=split_pattern))
208 | ]))
209 | else:
210 | pass
211 | self.drop_rate = drop_rate
212 |
213 | def forward(self, x):
214 | if 'v1' in args.net_version:
215 | new_features = self.temporal.forward(x) + self.bottleneck.forward(x)
216 | new_features = self.spatial.forward(new_features)
217 | elif 'v2' in args.net_version:
218 | new_features = self.bottleneck.forward(x)
219 | new_features = self.temporal.forward(new_features) + new_features
220 | new_features = self.spatial.forward(new_features)
221 | elif 'v3' in args.net_version:
222 | new_features = self.bottleneck.forward(x)
223 | new_features = self.spatial.forward(new_features)
224 | new_features = self.temporal.forward(new_features) + new_features
225 | elif 'pure_temporal' in args.net_version:
226 | new_features = self.temporal.forward(x)
227 | new_features = self.spatial.forward(new_features)
228 | else:
229 | new_features = self.bottleneck.forward(x)
230 | new_features = self.spatial.forward(new_features)
231 | #if self.drop_rate > 0:
232 | # new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
233 | return torch.cat([x, new_features], 1)
234 |
235 |
236 | class _DenseBlock(nn.Sequential):
237 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
238 | super(_DenseBlock, self).__init__()
239 | self._split_pattern = [num_input_features]
240 | for i in range(num_layers):
241 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate, split=1, split_pattern=None)#split=i+1, split_pattern=self._split_pattern)
242 | self.add_module('denselayer%d' % (i + 1), layer)
243 | # DO NOT use += in-place operator here!
244 | self._split_pattern = self._split_pattern + [growth_rate]
245 |
246 |
247 | class _Transition(nn.Sequential):
248 | def __init__(self, num_input_features, num_output_features, split=1, split_pattern=None, temporal_pool_size=1):
249 | super(_Transition, self).__init__()
250 | '''
251 | self.add_module('norm', nn.BatchNorm3d(num_input_features))
252 | self.add_module('relu', nn.ReLU(inplace=True))
253 | self.add_module('conv', Conv3d(num_input_features, num_output_features,
254 | kernel_size=1, stride=1, bias=False, split=split, split_pattern=split_pattern))
255 | self.add_module('pool', nn.AvgPool3d(kernel_size=(temporal_pool_size, 2, 2), stride=(temporal_pool_size, 2, 2)))
256 | '''
257 | if 'pure_temporal' not in args.net_version:
258 | self.original = nn.Sequential(OrderedDict([
259 | ('norm', BatchNorm3d(num_input_features)),
260 | ('relu', nn.ReLU(inplace=True)),
261 | ('conv', Conv3d(num_input_features, num_output_features,
262 | kernel_size=1, stride=1, bias=False, split=split, split_pattern=split_pattern))
263 | ]))
264 | self.transition_pool = nn.Sequential(OrderedDict([
265 | ('pool', nn.AvgPool3d(kernel_size=(temporal_pool_size, 2, 2), stride=(temporal_pool_size, 2, 2)))
266 | ]))
267 |
268 | if args.net_version in ['v1', 'v1d2', 'vt', 'v1d3']:
269 | self.temporal = nn.Sequential(OrderedDict([
270 | ('norm', BatchNorm3d(num_input_features)),
271 | ('relu', nn.ReLU(inplace=True)),
272 | ('conv', Conv3d(num_input_features, num_output_features,
273 | kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False, split=split,
274 | split_pattern=split_pattern))
275 | ]))
276 | elif args.net_version in ['v2', 'v3', 'v4']:
277 | self.temporal = nn.Sequential(OrderedDict([
278 | ('norm', BatchNorm3d(num_output_features)),
279 | ('relu', nn.ReLU(inplace=True)),
280 | ('conv', Conv3d(num_output_features, num_output_features,
281 | kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False, split=1))
282 | ]))
283 | elif 'pure_temporal' in args.net_version:
284 | self.temporal = nn.Sequential(OrderedDict([
285 | ('norm', BatchNorm3d(num_input_features)),
286 | ('relu', nn.ReLU(inplace=True)),
287 | ('conv', Conv3d(num_input_features, num_output_features,
288 | kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False, split=split,
289 | split_pattern=split_pattern))
290 | ]))
291 | else:
292 | pass
293 |
294 | def forward(self, input):
295 | if args.net_version in ['v1', 'v1d2', 'vt', 'v1d3']:
296 | new_features = self.original(input) + self.temporal(input)
297 | elif args.net_version in ['v2', 'v3', 'v4']:
298 | new_features = self.original(input)
299 | new_features = self.temporal(new_features) + new_features
300 | elif 'pure_temporal' in args.net_version:
301 | new_features = self.temporal(input)
302 | else:
303 | new_features = self.original(input)
304 | new_features = self.transition_pool(new_features)
305 | return new_features
306 |
307 |
308 | class DenseNet(nn.Module):
309 | r"""Densenet-BC model class, based on
310 | `"Densely Connected Convolutional Networks" `_
311 |
312 | Args:
313 | growth_rate (int) - how many filters to add each layer (`k` in paper)
314 | block_config (list of 4 ints) - how many layers in each pooling block
315 | num_init_features (int) - the number of filters to learn in the first convolution layer
316 | bn_size (int) - multiplicative factor for number of bottle neck layers
317 | (i.e. bn_size * k features in the bottleneck layer)
318 | drop_rate (float) - dropout rate after each dense layer
319 | num_classes (int) - number of classification classes
320 | """
321 |
322 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
323 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
324 |
325 | super(DenseNet, self).__init__()
326 | self.drop_rate = drop_rate
327 |
328 | # First convolution
329 | self.features = nn.Sequential(OrderedDict([
330 | ('conv0', nn.Conv3d(3, num_init_features, kernel_size=(5, 7, 7), stride=(1, 2, 2) if _TEMPORAL_NODOWNSAMPLE else 2, padding=(2, 3, 3), bias=False)),
331 | ('norm0', BatchNorm3d(num_init_features)),
332 | ('relu0', nn.ReLU(inplace=True)),
333 | ('pool0', nn.MaxPool3d(kernel_size=3, stride=(1, 2, 2) if _TEMPORAL_NODOWNSAMPLE else 2, padding=1)),
334 | ]))
335 |
336 | # Each denseblock
337 | num_features = num_init_features
338 | downsample_pos = [-1] if _TEMPORAL_NODOWNSAMPLE else [0]
339 | for i, num_layers in enumerate(block_config):
340 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
341 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=self.drop_rate)
342 | self.features.add_module('denseblock%d' % (i + 1), block)
343 | num_features = num_features + num_layers * growth_rate
344 | if i != len(block_config) - 1:
345 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2,
346 | split=1,#split=num_layers+1,
347 | split_pattern=None, #split_pattern=[num_features - num_layers * growth_rate]+[growth_rate]*num_layers,
348 | temporal_pool_size=2 if i in downsample_pos else 1)
349 | self.features.add_module('transition%d' % (i + 1), trans)
350 | num_features = num_features // 2
351 |
352 | # Final batch norm
353 | self.features.add_module('norm5', BatchNorm3d(num_features))
354 |
355 | # Linear layer
356 | self.classifier = Linear(num_features, num_classes)
357 |
358 | # Init adapted from the torch repo (note: BatchNorm weights are zero-initialized here; the official init uses 1)
359 | for m in self.modules():
360 | if isinstance(m, nn.Conv3d):
361 | nn.init.kaiming_normal_(m.weight)
362 | elif isinstance(m, nn.BatchNorm3d):
363 | nn.init.constant_(m.weight, 0)
364 | nn.init.constant_(m.bias, 0)
365 | elif isinstance(m, nn.Linear):
366 | nn.init.constant_(m.bias, 0)
367 |
368 | def forward(self, x):
369 | features = self.features(x)
370 | out = F.relu(features, inplace=True)
371 | out = F.adaptive_avg_pool3d(out, (1, 1, 1)).view(features.size(0), -1)
372 | out = F.dropout(out, p=self.drop_rate, training=self.training)
373 | out = self.classifier(out)
374 | return out
375 |
376 |
377 | def _load_state_dict(model, model_url):
378 | # '.'s are no longer allowed in module names, but previous _DenseLayer
379 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
380 | # They are also in the checkpoints in model_urls. This pattern is used
381 | # to find such keys.
382 | pattern = re.compile(
383 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
384 | state_dict = model_zoo.load_url(model_url)
385 | for key in list(state_dict.keys()):
386 | res = pattern.match(key)
387 | if res:
388 | new_key = res.group(1) + res.group(2)
389 | state_dict[new_key] = state_dict[key]
390 | del state_dict[key]
391 | model.load_state_dict(state_dict)
392 |
393 |
394 | def modify_densenets(model):
395 | # Modify attributes
396 | model.last_linear = model.classifier
397 | del model.classifier
398 |
399 | def logits(self, features):
400 | x = F.relu(features, inplace=True)
401 | x = F.avg_pool2d(x, kernel_size=7, stride=1)  # 2D pooling: this helper predates the 3D inflation and is not applied (call commented out below)
402 | x = x.view(x.size(0), -1)
403 | x = self.last_linear(x)
404 | return x
405 |
406 | def forward(self, input):
407 | x = self.features(input)
408 | x = self.logits(x)
409 | return x
410 |
411 | # Modify methods
412 | model.logits = types.MethodType(logits, model)
413 | model.forward = types.MethodType(forward, model)
414 | return model
415 |
416 |
417 | def _densenet121(num_classes, **kwargs):
418 | r"""Densenet-121 model from
419 | `"Densely Connected Convolutional Networks" `_
420 |
421 | Args:
422 | num_classes (int): number of output classes
423 | """
424 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=num_classes,
425 | **kwargs)
426 | return model
427 |
428 | def _densenet169(pretrained=False, **kwargs):
429 | r"""Densenet-121 model from
430 | `"Densely Connected Convolutional Networks" `_
431 |
432 | Args:
433 | pretrained (bool): If True, returns a model pre-trained on ImageNet
434 | """
435 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
436 | **kwargs)
437 | if pretrained:
438 | _load_state_dict(model, model_urls['densenet169'])
439 | return model
440 |
441 | def densenet121(num_classes=1000, pretrained='imagenet', drop_rate=0.0):
442 | r"""Densenet-121 model from
443 | `"Densely Connected Convolutional Networks" `
444 | """
445 | model = _densenet121(num_classes=num_classes, drop_rate=drop_rate)
446 | if pretrained is not None:
447 | settings = pretrained_settings['densenet121'][pretrained]
448 | model = load_pretrained(model, num_classes, settings)
449 | return model
450 |
451 |
452 | def densenet169(num_classes=1000, pretrained='imagenet'):
453 | r"""Densenet-121 model from
454 | `"Densely Connected Convolutional Networks" `
455 | """
456 | model = _densenet169(pretrained=False)
457 | if pretrained is not None:
458 | settings = pretrained_settings['densenet169'][pretrained]
459 | model = load_pretrained(model, num_classes, settings)
460 | #model = modify_densenets(model)
461 | return model
462 |
463 |
464 | def densenet201(pretrained=False, **kwargs):
465 | r"""Densenet-201 model from
466 | `"Densely Connected Convolutional Networks" `_
467 |
468 | Args:
469 | pretrained (bool): If True, returns a model pre-trained on ImageNet
470 | """
471 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
472 | **kwargs)
473 | if pretrained:
474 | _load_state_dict(model, model_urls['densenet201'])
475 | return model
476 |
477 |
478 | def densenet161(pretrained=False, **kwargs):
479 | r"""Densenet-161 model from
480 | `"Densely Connected Convolutional Networks" `_
481 |
482 | Args:
483 | pretrained (bool): If True, returns a model pre-trained on ImageNet
484 | """
485 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
486 | **kwargs)
487 | if pretrained:
488 | _load_state_dict(model, model_urls['densenet161'])
489 | return model
--------------------------------------------------------------------------------
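update_state_dict above inflates ImageNet 2D kernels to 3D: each conv weight gains a temporal axis, is tiled along it, and is rescaled so the inflated filter initially reproduces the 2D response on a temporally constant clip. A self-contained sketch of that rule (shapes are illustrative):

import torch

w2d = torch.randn(64, 3, 7, 7)  # 2D conv weight: (out_ch, in_ch, kH, kW)
t = 5                           # temporal extent, matching conv0 above
w3d = w2d.unsqueeze(2).repeat(1, 1, t, 1, 1) / t  # -> (64, 3, 5, 7, 7)

# Dividing by t preserves the response on a clip whose frames are identical:
# summing t copies of (w2d / t) convolved with x equals w2d convolved with x.
assert w3d.shape == (64, 3, 5, 7, 7)
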
/models/densenet_3d_forstat.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.model_zoo as model_zoo
6 | import types
7 |
8 | import os
9 |
10 | from collections import OrderedDict
11 | from functools import partial
12 |
13 | from args import parser
14 | from NAS_utils.ops import Conv3d_with_CD, Linear_with_CD
15 |
16 | args = parser.parse_args()
17 |
18 | #Conv2d = Conv2d
19 | Conv3d = nn.Conv3d#partial(Conv3d_with_CD, weight_reg=args.weight_reg, deterministic=True if args.selection_mode or args.finetune_mode else False, training_size=args.training_size, p_init=args.p_init)
20 | Linear = nn.Linear #partial(Linear_with_CD, weight_reg=args.weight_reg, deterministic=True if args.selection_mode or args.finetune_mode else False, training_size=args.training_size, p_init=args.p_init)
21 | nnConv2d = nn.Conv2d
22 | BatchNorm3d = partial(nn.BatchNorm3d, track_running_stats=not args.freeze_bn)
23 | _TEMPORAL_NASAS_ONLY = args.temporal_nasas_only
24 |
25 | __all__ = ['densenet121', 'densenet169', 'densenet201', 'densenet161']
26 |
27 |
28 | model_urls = {
29 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
30 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
31 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
32 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
33 | }
34 |
35 |
36 | input_sizes = {}
37 | means = {}
38 | stds = {}
39 |
40 |
41 | for model_name in __all__:
42 | input_sizes[model_name] = [3, 224, 224]
43 | means[model_name] = [0.485, 0.456, 0.406]
44 | stds[model_name] = [0.229, 0.224, 0.225]
45 |
46 |
47 | pretrained_settings = {}
48 |
49 |
50 | for model_name in __all__:
51 | pretrained_settings[model_name] = {
52 | 'imagenet': {
53 | 'url': model_urls[model_name],
54 | 'input_space': 'RGB',
55 | 'input_size': input_sizes[model_name],
56 | 'crop_size': input_sizes[model_name][-1] * 256 // 224,
57 | 'input_range': [0, 1],
58 | 'mean': means[model_name],
59 | 'std': stds[model_name]
60 | #'num_classes': 174
61 | }
62 | }
63 |
64 |
65 | def load_pretrained(model, num_classes, settings):
66 | #assert num_classes == settings['num_classes'], \
67 | # "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)
68 | try:
69 | state_dict = torch.load('/log/checkpoint/Densenet121_2D_ImagenetPretrained/densenet121-a639ec97.pth')
70 | except Exception:  # fall back to downloading when the local checkpoint is unavailable
71 | state_dict = model_zoo.load_url(settings['url'])
72 | state_dict = update_state_dict(state_dict)
73 | mk, uk = model.load_state_dict(state_dict, strict=False)
74 | model.input_space = settings['input_space']
75 | model.input_size = settings['input_size']
76 | model.input_range = settings['input_range']
77 | model.mean = settings['mean']
78 | model.std = settings['std']
79 | return model
80 |
81 |
82 | def update_state_dict(state_dict):
83 | # '.'s are no longer allowed in module names, but previous _DenseLayer
84 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
85 | # They are also in the checkpoints in model_urls. This pattern is used
86 | # to find such keys.
87 | pattern = re.compile(
88 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
89 | for key in list(state_dict.keys()):
90 | res = pattern.match(key)
91 | if res:
92 | new_key = res.group(1) + res.group(2)
93 | state_dict[new_key] = state_dict[key]
94 | del state_dict[key]
95 |
96 | # Inflate to 3d densenet
97 | pattern = re.compile(
98 | r'^(.*)((?:conv|norm)(?:[012]?)\.(?:weight|bias|running_mean|running_var))$')
99 | for key in list(state_dict.keys()):
100 | res = pattern.match(key)
101 | if res:
102 | v = state_dict[key]
103 | if 'conv' in key:
104 | v = torch.unsqueeze(v, dim=2)
105 | if 'conv0' in key:
106 | v = v.repeat([1, 1, 5, 1, 1])
107 | v /= 5.0
108 | state_dict[key] = v
109 | elif 'conv1' in key:
110 | new_key_btnk = res.group(1) + 'bottleneck.' + res.group(2)
111 | state_dict[new_key_btnk] = v
112 | if 'v1' in args.net_version:
113 | new_key_tmpr = res.group(1) + 'temporal.' + res.group(2)
114 | state_dict[new_key_tmpr] = v.repeat([1, 1, 3, 1, 1]) / 3.0
115 | del state_dict[key]
116 | elif 'conv2' in key:
117 | new_key_sptl = res.group(1) + 'spatial.' + res.group(2)
118 | state_dict[new_key_sptl] = v
119 | del state_dict[key]
120 | else:
121 | if 'transition' in key:
122 | new_key_btnk = res.group(1) + 'original.' + res.group(2)
123 | state_dict[new_key_btnk] = v
124 | if args.net_version in ['v1', 'v1d2', 'vt', 'v1d3']:
125 | new_key_tmpr = res.group(1) + 'temporal.' + res.group(2)
126 | state_dict[new_key_tmpr] = v.repeat([1, 1, 3, 1, 1]) / 3.0
127 | del state_dict[key]
128 | else:
129 | state_dict[key] = v
130 | else:
131 | if 'norm1' in key:
132 | new_key_btnk = res.group(1) + 'bottleneck.' + res.group(2)
133 | state_dict[new_key_btnk] = v
134 | if args.net_version in ['v1d2', 'v1nt', 'v1d3']:
135 | new_key_tmpr = res.group(1) + 'temporal.' + res.group(2)
136 | state_dict[new_key_tmpr] = v
137 | del state_dict[key]
138 | elif 'norm2' in key:
139 | new_key_sptl = res.group(1) + 'spatial.' + res.group(2)
140 | state_dict[new_key_sptl] = v
141 | del state_dict[key]
142 | else:
143 | if 'transition' in key:
144 | new_key_btnk = res.group(1) + 'original.' + res.group(2)
145 | state_dict[new_key_btnk] = v
146 | if args.net_version in ['v1d2', 'vt', 'v1d3']:
147 | new_key_tmpr = res.group(1) + 'temporal.' + res.group(2)
148 | state_dict[new_key_tmpr] = v
149 | del state_dict[key]
150 | else:
151 | state_dict[key] = v
152 |
153 | if 'classifier' in key:
154 | del state_dict[key]
155 | return state_dict
156 |
157 |
158 | class _DenseLayer(nn.Sequential):
159 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, split=1, split_pattern=None):
160 | super(_DenseLayer, self).__init__()
161 | self.bottleneck = nn.Sequential(OrderedDict([
162 | ('norm1', BatchNorm3d(num_input_features)),
163 | ('relu1', nn.ReLU(inplace=True)),
164 | ('conv1', Conv3d(num_input_features, bn_size *
165 | growth_rate, kernel_size=1, stride=1, bias=False))
166 | ]))
167 | self.spatial = nn.Sequential(OrderedDict([
168 | ('norm2', BatchNorm3d(bn_size * growth_rate)),
169 | ('relu2', nn.ReLU(inplace=True)),
170 | ('conv2', Conv3d(bn_size * growth_rate, growth_rate,
171 | kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=False))
172 | ]))
173 |
174 | if 'v1' in args.net_version:
175 | self.temporal = nn.Sequential(OrderedDict([
176 | ('norm1', BatchNorm3d(num_input_features)),
177 | ('relu1', nn.ReLU(inplace=True)),
178 | ('conv1', Conv3d(num_input_features, bn_size *
179 | growth_rate, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False))
180 | ]))
181 | elif 'v2' in args.net_version:
182 | self.temporal = nn.Sequential(OrderedDict([
183 | ('norm1', BatchNorm3d(bn_size * growth_rate)),
184 | ('relu1', nn.ReLU(inplace=True)),
185 | ('conv1', Conv3d(bn_size * growth_rate, bn_size * growth_rate,
186 | kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False))
187 | ]))
188 | elif 'v3' in args.net_version:
189 | self.temporal = nn.Sequential(OrderedDict([
190 | ('norm1', BatchNorm3d(growth_rate)),
191 | ('relu1', nn.ReLU(inplace=True)),
192 | ('conv1', Conv3d(growth_rate, growth_rate,
193 | kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False))
194 | ]))
195 | else:
196 | pass
197 | self.drop_rate = drop_rate
198 |
199 | def forward(self, x):
200 | if 'v1' in args.net_version:
201 | new_features = self.temporal.forward(x) + self.bottleneck.forward(x)
202 | new_features = self.spatial.forward(new_features)
203 | elif 'v2' in args.net_version:
204 | new_features = self.bottleneck.forward(x)
205 | new_features = self.temporal.forward(new_features) + new_features
206 | new_features = self.spatial.forward(new_features)
207 | elif 'v3' in args.net_version:
208 | new_features = self.bottleneck.forward(x)
209 | new_features = self.spatial.forward(new_features)
210 | new_features = self.temporal.forward(new_features) + new_features
211 | else:
212 | new_features = self.bottleneck.forward(x)
213 | new_features = self.spatial.forward(new_features)
214 | #if self.drop_rate > 0:
215 | # new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
216 | return torch.cat([x, new_features], 1)
217 |
218 |
219 | class _DenseBlock(nn.Sequential):
220 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
221 | super(_DenseBlock, self).__init__()
222 | self._split_pattern = [num_input_features]
223 | for i in range(num_layers):
224 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate, split=1, split_pattern=None)#split=i+1, split_pattern=self._split_pattern)
225 | self.add_module('denselayer%d' % (i + 1), layer)
226 | # DO NOT use += in-place operator here!
227 | self._split_pattern = self._split_pattern + [growth_rate]
228 |
229 |
230 | class _Transition(nn.Sequential):
231 | def __init__(self, num_input_features, num_output_features, split=1, split_pattern=None, temporal_pool_size=1):
232 | super(_Transition, self).__init__()
233 | '''
234 | self.add_module('norm', nn.BatchNorm3d(num_input_features))
235 | self.add_module('relu', nn.ReLU(inplace=True))
236 | self.add_module('conv', Conv3d(num_input_features, num_output_features,
237 | kernel_size=1, stride=1, bias=False, split=split, split_pattern=split_pattern))
238 | self.add_module('pool', nn.AvgPool3d(kernel_size=(temporal_pool_size, 2, 2), stride=(temporal_pool_size, 2, 2)))
239 | '''
240 | self.original = nn.Sequential(OrderedDict([
241 | ('norm', BatchNorm3d(num_input_features)),
242 | ('relu', nn.ReLU(inplace=True)),
243 | ('conv', Conv3d(num_input_features, num_output_features,
244 | kernel_size=1, stride=1, bias=False))
245 | ]))
246 | self.transition_pool = nn.Sequential(OrderedDict([
247 | ('pool', nn.AvgPool3d(kernel_size=(temporal_pool_size, 2, 2), stride=(temporal_pool_size, 2, 2)))
248 | ]))
249 |
250 | if args.net_version in ['v1', 'v1d2', 'vt', 'v1d3']:
251 | self.temporal = nn.Sequential(OrderedDict([
252 | ('norm', BatchNorm3d(num_input_features)),
253 | ('relu', nn.ReLU(inplace=True)),
254 | ('conv', Conv3d(num_input_features, num_output_features,
255 | kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False))
256 | ]))
257 | elif args.net_version in ['v2', 'v3', 'v4']:
258 | self.temporal = nn.Sequential(OrderedDict([
259 | ('norm', BatchNorm3d(num_output_features)),
260 | ('relu', nn.ReLU(inplace=True)),
261 | ('conv', Conv3d(num_output_features, num_output_features,
262 | kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False))
263 | ]))
264 | else:
265 | pass
266 |
267 | def forward(self, input):
268 | if args.net_version in ['v1', 'v1d2', 'vt', 'v1d3']:
269 | new_features = self.original(input) + self.temporal(input)
270 | elif args.net_version in ['v2', 'v3', 'v4']:
271 | new_features = self.original(input)
272 | new_features = self.temporal(new_features) + new_features
273 | else:
274 | new_features = self.original(input)
275 | new_features = self.transition_pool(new_features)
276 | return new_features
277 |
278 |
279 | class DenseNet(nn.Module):
280 | r"""Densenet-BC model class, based on
281 | `"Densely Connected Convolutional Networks" `_
282 |
283 | Args:
284 | growth_rate (int) - how many filters to add each layer (`k` in paper)
285 | block_config (list of 4 ints) - how many layers in each pooling block
286 | num_init_features (int) - the number of filters to learn in the first convolution layer
287 | bn_size (int) - multiplicative factor for number of bottle neck layers
288 | (i.e. bn_size * k features in the bottleneck layer)
289 | drop_rate (float) - dropout rate after each dense layer
290 | num_classes (int) - number of classification classes
291 | """
292 |
293 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
294 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
295 |
296 | super(DenseNet, self).__init__()
297 | self.drop_rate = drop_rate
298 |
299 | # First convolution
300 | self.features = nn.Sequential(OrderedDict([
301 | ('conv0', nn.Conv3d(3, num_init_features, kernel_size=(5, 7, 7), stride=(1, 2, 2) if args.net_version=='v1d3' else 2, padding=(2, 3, 3), bias=False)),
302 | ('norm0', BatchNorm3d(num_init_features)),
303 | ('relu0', nn.ReLU(inplace=True)),
304 | ('pool0', nn.MaxPool3d(kernel_size=3, stride=(1, 2, 2) if args.net_version=='v1d3' or args.random_dense_sample_stride else 2, padding=1)),
305 | ]))
306 |
307 | # Each denseblock
308 | num_features = num_init_features
309 | downsample_pos = [-1] if args.net_version=='v1d3' else [0]
310 | for i, num_layers in enumerate(block_config):
311 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
312 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=self.drop_rate)
313 | self.features.add_module('denseblock%d' % (i + 1), block)
314 | num_features = num_features + num_layers * growth_rate
315 | if i != len(block_config) - 1:
316 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2,
317 | split=1,#split=num_layers+1,
318 | split_pattern=None, #split_pattern=[num_features - num_layers * growth_rate]+[growth_rate]*num_layers,
319 | temporal_pool_size=2 if i in downsample_pos else 1)
320 | self.features.add_module('transition%d' % (i + 1), trans)
321 | num_features = num_features // 2
322 |
323 | # Final batch norm
324 | self.features.add_module('norm5', BatchNorm3d(num_features))
325 |
326 | # Linear layer
327 | self.classifier = Linear(num_features, num_classes)
328 |
329 | # Init adapted from the torch repo (note: BatchNorm weights are zero-initialized here; the official init uses 1)
330 | for m in self.modules():
331 | if isinstance(m, nn.Conv3d):
332 | nn.init.kaiming_normal_(m.weight)
333 | elif isinstance(m, nn.BatchNorm3d):
334 | nn.init.constant_(m.weight, 0)
335 | nn.init.constant_(m.bias, 0)
336 | elif isinstance(m, nn.Linear):
337 | nn.init.constant_(m.bias, 0)
338 |
339 | def forward(self, x):
340 | features = self.features(x)
341 | out = F.relu(features, inplace=True)
342 | out = F.adaptive_avg_pool3d(out, (1, 1, 1)).view(features.size(0), -1)
343 | out = F.dropout(out, p=self.drop_rate, training=self.training)
344 | out = self.classifier(out)
345 | return out
346 |
347 |
348 | def _load_state_dict(model, model_url):
349 | # '.'s are no longer allowed in module names, but previous _DenseLayer
350 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
351 | # They are also in the checkpoints in model_urls. This pattern is used
352 | # to find such keys.
353 | pattern = re.compile(
354 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
355 | state_dict = model_zoo.load_url(model_url)
356 | for key in list(state_dict.keys()):
357 | res = pattern.match(key)
358 | if res:
359 | new_key = res.group(1) + res.group(2)
360 | state_dict[new_key] = state_dict[key]
361 | del state_dict[key]
362 | model.load_state_dict(state_dict)
363 |
364 |
365 | def modify_densenets(model):
366 | # Modify attributes
367 | model.last_linear = model.classifier
368 | del model.classifier
369 |
370 | def logits(self, features):
371 | x = F.relu(features, inplace=True)
372 | x = F.avg_pool2d(x, kernel_size=7, stride=1)  # 2D pooling: this helper predates the 3D inflation and is not applied (call commented out below)
373 | x = x.view(x.size(0), -1)
374 | x = self.last_linear(x)
375 | return x
376 |
377 | def forward(self, input):
378 | x = self.features(input)
379 | x = self.logits(x)
380 | return x
381 |
382 | # Modify methods
383 | model.logits = types.MethodType(logits, model)
384 | model.forward = types.MethodType(forward, model)
385 | return model
386 |
387 |
388 | def _densenet121(num_classes, **kwargs):
389 | r"""Densenet-121 model from
390 | `"Densely Connected Convolutional Networks" `_
391 |
392 | Args:
393 | num_classes (int): number of output classes
394 | """
395 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=num_classes,
396 | **kwargs)
397 | return model
398 |
399 | def _densenet169(pretrained=False, **kwargs):
400 | r"""Densenet-121 model from
401 | `"Densely Connected Convolutional Networks" `_
402 |
403 | Args:
404 | pretrained (bool): If True, returns a model pre-trained on ImageNet
405 | """
406 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
407 | **kwargs)
408 | if pretrained:
409 | _load_state_dict(model, model_urls['densenet169'])
410 | return model
411 |
412 | def densenet121(num_classes=1000, pretrained='imagenet', drop_rate=0.0):
413 | r"""Densenet-121 model from
414 | `"Densely Connected Convolutional Networks" `
415 | """
416 | model = _densenet121(num_classes=num_classes, drop_rate=drop_rate)
417 | if pretrained is not None:
418 | settings = pretrained_settings['densenet121'][pretrained]
419 | model = load_pretrained(model, num_classes, settings)
420 | return model
421 |
422 |
423 | def densenet169(num_classes=1000, pretrained='imagenet'):
424 | r"""Densenet-121 model from
425 | `"Densely Connected Convolutional Networks" `
426 | """
427 | model = _densenet169(pretrained=False)
428 | if pretrained is not None:
429 | settings = pretrained_settings['densenet169'][pretrained]
430 | model = load_pretrained(model, num_classes, settings)
431 | #model = modify_densenets(model)
432 | return model
433 |
434 |
435 | def densenet201(pretrained=False, **kwargs):
436 | r"""Densenet-201 model from
437 | `"Densely Connected Convolutional Networks" `_
438 |
439 | Args:
440 | pretrained (bool): If True, returns a model pre-trained on ImageNet
441 | """
442 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
443 | **kwargs)
444 | if pretrained:
445 | _load_state_dict(model, model_urls['densenet201'])
446 | return model
447 |
448 |
449 | def densenet161(pretrained=False, **kwargs):
450 | r"""Densenet-161 model from
451 | `"Densely Connected Convolutional Networks" `_
452 |
453 | Args:
454 | pretrained (bool): If True, returns a model pre-trained on ImageNet
455 | """
456 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
457 | **kwargs)
458 | if pretrained:
459 | _load_state_dict(model, model_urls['densenet161'])
460 | return model
--------------------------------------------------------------------------------
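Both DenseNet variants above realize the 'v1' fusion as a pointwise bottleneck branch and a 3x1x1 temporal branch summed before the 1x3x3 spatial conv. A standalone sketch of that dataflow, with illustrative channel sizes rather than the model's:

import torch
import torch.nn as nn

x = torch.randn(2, 64, 8, 56, 56)  # (N, C, T, H, W)

bottleneck = nn.Conv3d(64, 128, kernel_size=1, bias=False)
temporal = nn.Conv3d(64, 128, kernel_size=(3, 1, 1), padding=(1, 0, 0), bias=False)
spatial = nn.Conv3d(128, 32, kernel_size=(1, 3, 3), padding=(0, 1, 1), bias=False)

out = spatial(temporal(x) + bottleneck(x))  # 'v1': fuse both branches, then convolve spatially
print(out.shape)  # torch.Size([2, 32, 8, 56, 56])
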
/models/mobilenet_v2_3d.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F  # canonical import; 'from torch.functional import F' only worked by accident in old torch
3 | import torch.utils.model_zoo as model_zoo
4 |
5 | from args import parser
6 | args = parser.parse_args()
7 |
8 | import torch
9 | import re
10 |
11 |
12 | __all__ = ['mobilenet_v2']
13 |
14 | model_urls = {
15 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
16 | }
17 |
18 | input_sizes = {}
19 | means = {}
20 | stds = {}
21 |
22 | for model_name in __all__:
23 | input_sizes[model_name] = [3, 224, 224]
24 | means[model_name] = [0.485, 0.456, 0.406]
25 | stds[model_name] = [0.229, 0.224, 0.225]
26 |
27 | pretrained_settings = {}
28 |
29 |
30 | for model_name in __all__:
31 | pretrained_settings[model_name] = {
32 | 'imagenet': {
33 | 'url': model_urls[model_name],
34 | 'input_space': 'RGB',
35 | 'input_size': input_sizes[model_name],
36 | 'crop_size': input_sizes[model_name][-1] * 256 // 224,
37 | 'input_range': [0, 1],
38 | 'mean': means[model_name],
39 | 'std': stds[model_name]
40 | #'num_classes': 174
41 | }
42 | }
43 |
44 |
45 | def update_state_dict(state_dict):
46 | # '.'s are no longer allowed in module names, but previous _DenseLayer
47 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
48 | # They are also in the checkpoints in model_urls. This pattern is used
49 | # to find such keys.
50 | """
51 | pattern = re.compile(
52 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
53 | for key in list(state_dict.keys()):
54 | res = pattern.match(key)
55 | if res:
56 | new_key = res.group(1) + res.group(2)
57 | state_dict[new_key] = state_dict[key]
58 | del state_dict[key]
59 | """
60 | # Inflate 2D MobileNetV2 weights to 3D
61 | pattern = re.compile(
62 | r'^(.*)((?:conv|bn)(?:[0123]?)\.(?:weight|bias|running_mean|running_var))$')
63 | for key in list(state_dict.keys()):
64 | if True:  # the compiled pattern above is unused; every key is handled by the substring checks below
65 | v = state_dict[key]
66 | if 'features.0.' in key:
67 | if 'features.0.0.weight' in key:
68 | v = torch.unsqueeze(v, dim=2)
69 | v = v.repeat([1, 1, 5, 1, 1])
70 | v /= 5.0
71 | state_dict[key] = v
72 | else:
73 | pass
74 | elif 'features.1.' in key:
75 | if 'conv.0' in key:
76 | if 'conv.0.0' in key:
77 | v = torch.unsqueeze(v, dim=2)
78 | new_key_btnk = key.replace('conv.0', 'depth_wise')
79 | state_dict[new_key_btnk] = v
80 | del state_dict[key]
81 | elif 'conv.1' in key:
82 | v = torch.unsqueeze(v, dim=2)
83 | new_key_btnk = key.replace('conv.1', 'point_wise')
84 | state_dict[new_key_btnk] = v
85 | del state_dict[key]
86 | else:
87 | assert 'conv.2' in key
88 | new_key_btnk = key.replace('conv.2', 'bn')
89 | state_dict[new_key_btnk] = v
90 | del state_dict[key]
91 | elif 'features.18.' in key:
92 | if 'features.18.0.weight' in key:
93 | v = torch.unsqueeze(v, dim=2)
94 | state_dict[key] = v
95 | else:
96 | pass
97 | elif 'classifier' in key:
98 | pass
99 | else:
100 | if 'conv.0.' in key:
101 | if 'conv.0.0.' in key:
102 | v = torch.unsqueeze(v, dim=2)
103 | new_key_btnk = key.replace('conv.0', 'bottleneck')
104 | state_dict[new_key_btnk] = v
105 | new_key_btnk = key.replace('conv.0', 'temporal')
106 | state_dict[new_key_btnk] = v.repeat([1, 1, 3, 1, 1]) / 3.0
107 | else:
108 | new_key_btnk = key.replace('conv.0', 'bottleneck')
109 | state_dict[new_key_btnk] = v
110 | new_key_btnk = key.replace('conv.0', 'temporal')
111 | state_dict[new_key_btnk] = v
112 | del state_dict[key]
113 | elif 'conv.1.' in key:
114 | if 'conv.1.0.' in key:
115 | v = torch.unsqueeze(v, dim=2)
116 | new_key_btnk = key.replace('conv.1', 'depth_wise')
117 | state_dict[new_key_btnk] = v
118 | del state_dict[key]
119 | elif 'conv.2.' in key:
120 | v = torch.unsqueeze(v, dim=2)
121 | new_key_btnk = key.replace('conv.2', 'point_wise')
122 | state_dict[new_key_btnk] = v
123 | del state_dict[key]
124 | else:
125 | assert 'conv.3.' in key
126 | new_key_btnk = key.replace('conv.3', 'bn')
127 | state_dict[new_key_btnk] = v
128 | del state_dict[key]
129 |
130 | return state_dict
131 |
132 |
133 | def load_pretrained(model, num_classes, settings):
134 | #assert num_classes == settings['num_classes'], \
135 | # "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)
136 | state_dict = model_zoo.load_url(settings['url'])
137 | state_dict = update_state_dict(state_dict)
138 | mk, uk = model.load_state_dict(state_dict, strict=False)
139 | print('mk: {}'.format(mk))
140 | print('uk: {}'.format(uk))
141 | model.input_space = settings['input_space']
142 | model.input_size = settings['input_size']
143 | model.input_range = settings['input_range']
144 | model.mean = settings['mean']
145 | model.std = settings['std']
146 | return model
147 |
148 |
149 | def _make_divisible(v, divisor, min_value=None):
150 | """
151 | This function is taken from the original tf repo.
152 | It ensures that all layers have a channel number that is divisible by 8
153 | It can be seen here:
154 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
155 | :param v:
156 | :param divisor:
157 | :param min_value:
158 | :return:
159 | """
160 | if min_value is None:
161 | min_value = divisor
162 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
163 | # Make sure that round down does not go down by more than 10%.
164 | if new_v < 0.9 * v:
165 | new_v += divisor
166 | return new_v
167 |
168 |
169 | class ConvBNReLU(nn.Sequential):
170 | def __init__(self, in_planes, out_planes, kernel_size=(1, 3, 3), stride=(1, 1, 1), groups=1):
171 |
172 | padding = tuple([(k - 1) // 2 for k in kernel_size])
173 | super(ConvBNReLU, self).__init__(
174 | nn.Conv3d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
175 | nn.BatchNorm3d(out_planes),
176 | nn.ReLU6(inplace=True)
177 | )
178 |
179 |
180 | class InvertedResidual(nn.Module):
181 | def __init__(self, inp, oup, stride, expand_ratio, modality=None):
182 | super(InvertedResidual, self).__init__()
183 | self.stride = stride
184 | self.modality = modality
185 | self.expand_ratio = expand_ratio
186 | assert stride in [1, 2]
187 |
188 | hidden_dim = int(round(inp * expand_ratio))
189 | self.use_res_connect = self.stride == 1 and inp == oup
190 |
191 | if expand_ratio != 1:
192 | if args.net_version == 'pure_spatial':
193 | self.bottleneck = ConvBNReLU(inp, hidden_dim, kernel_size=(1, 1, 1))
194 | elif args.net_version == 'pure_temporal':
195 | self.temporal = ConvBNReLU(inp, hidden_dim, kernel_size=(3, 1, 1))
196 | elif args.net_version == 'pure_fused':
197 | self.bottleneck = ConvBNReLU(inp, hidden_dim, kernel_size=(1, 1, 1))
198 | self.temporal = ConvBNReLU(inp, hidden_dim, kernel_size=(3, 1, 1))
199 | elif args.net_version == 'pure_adaptive':
200 | assert self.modality is not None
201 | if self.modality == 'fused':
202 | self.bottleneck = ConvBNReLU(inp, hidden_dim, kernel_size=(1, 1, 1))
203 | self.temporal = ConvBNReLU(inp, hidden_dim, kernel_size=(3, 1, 1))
204 | elif self.modality == 'spatial':
205 | self.bottleneck = ConvBNReLU(inp, hidden_dim, kernel_size=(1, 1, 1))
206 | else:
207 | assert self.modality == 'temporal'
208 | self.temporal = ConvBNReLU(inp, hidden_dim, kernel_size=(3, 1, 1))
209 | else:
210 | self.bottleneck = ConvBNReLU(inp, hidden_dim, kernel_size=(1, 1, 1))
211 | # dw / pw-linear: depthwise spatial conv, then linear pointwise projection
212 | self.depth_wise = ConvBNReLU(hidden_dim, hidden_dim, stride=(1, stride, stride), groups=hidden_dim)
213 | self.point_wise = nn.Conv3d(hidden_dim, oup, 1, 1, 0, bias=False)
214 | self.bn = nn.BatchNorm3d(oup)
215 |
216 | def forward(self, x):
217 | if self.expand_ratio != 1:
218 | if args.net_version == 'pure_spatial':
219 | new_features = self.bottleneck(x)
220 | elif args.net_version == 'pure_temporal':
221 | new_features = self.temporal(x)
222 | elif args.net_version == 'pure_fused':
223 | new_features = self.bottleneck(x) + self.temporal(x)
224 | elif args.net_version == 'pure_adaptive':
225 | assert self.modality is not None
226 | if self.modality == 'fused':
227 | new_features = self.bottleneck(x) + self.temporal(x)
228 | elif self.modality == 'spatial':
229 | new_features = self.bottleneck(x)
230 | else:
231 | assert self.modality == 'temporal'
232 | new_features = self.temporal(x)
233 | else:
234 | new_features = self.bottleneck(x)
235 | else:
236 | new_features = x
237 | new_features = self.bn(self.point_wise(self.depth_wise(new_features)))
238 | if self.use_res_connect:
239 | return x + new_features
240 | else:
241 | return new_features
242 |
243 |
244 | class MobileNetV2(nn.Module):
245 | def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8, drop_rate=0.0):
246 | """
247 | MobileNet V2 main class
248 |
249 | Args:
250 | num_classes (int): Number of classes
251 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
252 | inverted_residual_setting: Network structure
253 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number
254 | Set to 1 to turn off rounding
255 | """
256 | super(MobileNetV2, self).__init__()
257 | block = InvertedResidual
258 | input_channel = 32
259 | last_channel = 1280
260 |
261 | if inverted_residual_setting is None:
262 | if 'something' in args.dataset:
263 | inverted_residual_setting = [
264 | # t, c, n, s, m
265 | [1, 16, 1, 1, 'fused'],
266 | [6, 24, 2, 2, 'fused'],
267 | [6, 32, 3, 2, 'fused'],
268 | [6, 64, 4, 2, 'temporal'],
269 | [6, 96, 3, 1, 'temporal'],
270 | [6, 160, 3, 2, 'fused'],
271 | [6, 320, 1, 1, 'fused'],
272 | ]
273 | else:
274 | inverted_residual_setting = [
275 | # t, c, n, s, m
276 | [1, 16, 1, 1, 'fused'],
277 | [6, 24, 2, 2, 'fused'],
278 | [6, 32, 3, 2, 'spatial'],
279 | [6, 64, 4, 2, 'fused'],
280 | [6, 96, 3, 1, 'spatial'],
281 | [6, 160, 3, 2, 'temporal'],
282 | [6, 320, 1, 1, 'fused'],
283 | ]
284 |
285 | # only check the first element, assuming user knows t,c,n,s,m are required
286 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 5:
287 | raise ValueError("inverted_residual_setting should be a non-empty "
288 | "list of 5-element lists, got {}".format(inverted_residual_setting))
289 |
290 | # building first layer
291 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
292 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
293 | features = [ConvBNReLU(3, input_channel, stride=2, kernel_size=(5, 3, 3))]
294 | # building inverted residual blocks
295 | for t, c, n, s, m in inverted_residual_setting:
296 | output_channel = _make_divisible(c * width_mult, round_nearest)
297 | for i in range(n):
298 | stride = s if i == 0 else 1
299 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, modality=m))
300 | input_channel = output_channel
301 | # building last several layers
302 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=(1, 1, 1)))
303 | # make it nn.Sequential
304 | self.features = nn.Sequential(*features)
305 |
306 | # building classifier
307 | self.new_classifier = nn.Sequential(
308 | nn.Dropout(drop_rate),
309 | nn.Linear(self.last_channel, num_classes),
310 | )
311 |
312 | # weight initialization
313 | for m in self.modules():
314 | if isinstance(m, nn.Conv3d):  # was nn.Conv2d, which never matches in this 3D model
315 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
316 | if m.bias is not None:
317 | nn.init.zeros_(m.bias)
318 | elif isinstance(m, nn.BatchNorm3d):  # was nn.BatchNorm2d (same 2D/3D mismatch)
319 | nn.init.ones_(m.weight)
320 | nn.init.zeros_(m.bias)
321 | elif isinstance(m, nn.Linear):
322 | nn.init.normal_(m.weight, 0, 0.01)
323 | nn.init.zeros_(m.bias)
324 |
325 | def forward(self, x):
326 | x = self.features(x)
327 | x = F.adaptive_avg_pool3d(x, (1, 1, 1)).view(x.size(0), -1)
328 | x = self.new_classifier(x)
329 | return x
330 |
331 |
332 | def mobilenet_v2(pretrained='imagenet', progress=True, **kwargs):
333 | """
334 | Constructs a MobileNetV2 architecture from
335 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_.
336 |
337 | Args:
338 | pretrained (str or None): pretrained settings key (e.g. 'imagenet'); falsy values skip loading
339 | progress (bool): If True, displays a progress bar of the download to stderr
340 | """
341 | model = MobileNetV2(**kwargs)
342 | if pretrained:
343 | settings = pretrained_settings['mobilenet_v2'][pretrained]
344 | model = load_pretrained(model, kwargs.get('num_classes', 1000), settings)  # default matches MobileNetV2.__init__
345 | return model
346 |
347 |
348 | if __name__ == '__main__':
349 | mobilenet_v2(pretrained='imagenet', num_classes=1000)
--------------------------------------------------------------------------------
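_make_divisible in the file above rounds channel counts to hardware-friendly multiples while never dropping more than roughly 10% of the requested width. A quick worked example (the function body mirrors the one above, with an added comment; inputs are illustrative):

def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # round-down guard: never lose more than ~10% of v
        new_v += divisor
    return new_v

print(_make_divisible(32 * 0.75, 8))  # 24: already a multiple of 8
print(_make_divisible(32 * 1.4, 8))   # 48: 44.8 rounds up to 48
print(_make_divisible(10, 8))         # 16: rounding down to 8 would drop 20%, so bump up
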
/models/resnet_3d.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.utils.model_zoo as model_zoo
6 | import torch.nn.functional as F
7 |
8 | from collections import OrderedDict
9 | #from functools import partial
10 |
11 | from args import parser
12 | #from NAS_utils.ops import Conv3d_with_CD, Linear_with_CD
13 |
14 | args = parser.parse_args()
15 |
16 | #Conv2d = Conv2d
17 | Conv3d = nn.Conv3d #partial(Conv3d_with_CD, weight_reg=args.weight_reg, deterministic=True if args.selection_mode else False, training_size=args.training_size)
18 | Linear = nn.Linear #partial(Linear_with_CD, weight_reg=args.weight_reg, deterministic=True if args.selection_mode else False, training_size=args.training_size)
19 | nnConv2d = nn.Conv2d
20 |
21 | _TEMPORAL_NASAS_ONLY = args.temporal_nasas_only
22 |
23 |
24 | __all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101',
25 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
26 | 'wide_resnet50_2', 'wide_resnet101_2']
27 |
28 |
29 | model_urls = {
30 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
31 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
32 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
33 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
34 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
35 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
36 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
37 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
38 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
39 | }
40 |
41 | input_sizes = {}
42 | means = {}
43 | stds = {}
44 |
45 | for model_name in __all__:
46 | input_sizes[model_name] = [3, 224, 224]
47 | means[model_name] = [0.485, 0.456, 0.406]
48 | stds[model_name] = [0.229, 0.224, 0.225]
49 |
50 | pretrained_settings = {}
51 |
52 | for model_name in __all__:
53 | pretrained_settings[model_name] = {
54 | 'imagenet': {
55 | 'url': model_urls[model_name],
56 | 'input_space': 'RGB',
57 | 'input_size': input_sizes[model_name],
58 | 'crop_size': input_sizes[model_name][-1] * 256 // 224,
59 | 'input_range': [0, 1],
60 | 'mean': means[model_name],
61 | 'std': stds[model_name]
62 | #'num_classes': 174
63 | }
64 | }
65 |
66 |
67 | def update_state_dict(state_dict):
68 | # '.'s are no longer allowed in module names, but previous _DenseLayer
69 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
70 | # They are also in the checkpoints in model_urls. This pattern is used
71 | # to find such keys.
72 | """
73 | pattern = re.compile(
74 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
75 | for key in list(state_dict.keys()):
76 | res = pattern.match(key)
77 | if res:
78 | new_key = res.group(1) + res.group(2)
79 | state_dict[new_key] = state_dict[key]
80 | del state_dict[key]
81 | """
82 | # Inflate 2D ResNet weights to 3D
83 | pattern = re.compile(
84 | r'^(.*)((?:conv|bn)(?:[0123]?)\.(?:weight|bias|running_mean|running_var))$')
85 | for key in list(state_dict.keys()):
86 | res = pattern.match(key)
87 | if res:
88 | v = state_dict[key]
89 | if 'conv' in key:
90 | v = torch.unsqueeze(v, dim=2)
91 | if 'layer' not in key:
92 | v = v.repeat([1, 1, 5, 1, 1])
93 | v /= 5.0
94 | state_dict[key] = v
95 | elif 'conv1' in key:
96 | new_key_btnk = res.group(1) + 'bottleneck.' + res.group(2)
97 | state_dict[new_key_btnk] = v
98 | new_key_tmpr = res.group(1) + 'temporal.' + res.group(2)
99 | state_dict[new_key_tmpr] = v.repeat([1, 1, 3, 1, 1]) / 3.0
100 | del state_dict[key]
101 | elif 'conv2' in key:
102 | state_dict[key] = v
103 | elif 'conv3' in key:
104 | state_dict[key] = v
105 | else:
106 | if 'bn1' in key:
107 | pass
108 | elif 'bn2' in key:
109 | pass
110 | else:
111 | pass
112 | if 'downsample' in key:
113 | v = state_dict[key]
114 | if 'downsample.0' in key:
115 | v = torch.unsqueeze(v, dim=2)
116 | state_dict[key] = v
117 | if 'fc' in key:
118 | del state_dict[key]
119 | return state_dict
120 |
121 |
122 | def load_pretrained(model, num_classes, settings):
123 | #assert num_classes == settings['num_classes'], \
124 | # "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)
125 | state_dict = model_zoo.load_url(settings['url'])
126 | state_dict = update_state_dict(state_dict)
127 | mk, uk = model.load_state_dict(state_dict, strict=False)
128 | model.input_space = settings['input_space']
129 | model.input_size = settings['input_size']
130 | model.input_range = settings['input_range']
131 | model.mean = settings['mean']
132 | model.std = settings['std']
133 | return model
134 |
135 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
136 | """3x3 convolution with padding"""
137 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
138 | padding=dilation, groups=groups, bias=False, dilation=dilation)
139 |
140 |
141 | def conv1x1(in_planes, out_planes, stride=1):
142 | """1x1 convolution"""
143 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
144 |
145 |
146 | class BasicBlock(nn.Module):
147 | expansion = 1
148 |
149 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
150 | base_width=64, dilation=1, norm_layer=None):
151 | super(BasicBlock, self).__init__()
152 | if norm_layer is None:
153 | norm_layer = nn.BatchNorm2d
154 | if groups != 1 or base_width != 64:
155 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
156 | if dilation > 1:
157 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
158 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
159 | self.conv1 = conv3x3(inplanes, planes, stride)
160 | self.bn1 = norm_layer(planes)
161 | self.relu = nn.ReLU(inplace=True)
162 | self.conv2 = conv3x3(planes, planes)
163 | self.bn2 = norm_layer(planes)
164 | self.downsample = downsample
165 | self.stride = stride
166 |
167 | def forward(self, x):
168 | identity = x
169 |
170 | out = self.conv1(x)
171 | out = self.bn1(out)
172 | out = self.relu(out)
173 |
174 | out = self.conv2(out)
175 | out = self.bn2(out)
176 |
177 | if self.downsample is not None:
178 | identity = self.downsample(x)
179 |
180 | out += identity
181 | out = self.relu(out)
182 |
183 | return out
184 |
185 |
186 | class Bottleneck(nn.Module):
187 | expansion = 4
188 |
189 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
190 | base_width=64, dilation=1, norm_layer=None, temporal_stride=1, enable_fuse=False, modality='temporal'):
191 | super(Bottleneck, self).__init__()
192 | self.enable_fuse=enable_fuse
193 | self.modality = modality
194 |
195 | if norm_layer is None:
196 | norm_layer = nn.BatchNorm3d
197 | width = int(planes * (base_width / 64.)) * groups
198 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
199 | if args.net_version == 'pure_fused':
200 | self.bottleneck = nn.Sequential(OrderedDict([
201 | ('conv1', Conv3d(inplanes, width, kernel_size=1, stride=1, bias=False))
202 | ]))
203 | self.temporal = nn.Sequential(OrderedDict([
204 | ('conv1', Conv3d(inplanes, width, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False))
205 | ]))
206 | elif args.net_version == 'pure_spatial':
207 | self.bottleneck = nn.Sequential(OrderedDict([
208 | ('conv1', Conv3d(inplanes, width, kernel_size=1, stride=1, bias=False))
209 | ]))
210 | elif args.net_version == 'pure_temporal':
211 | self.temporal = nn.Sequential(OrderedDict([
212 | ('conv1', Conv3d(inplanes, width, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False))
213 | ]))
214 | else:
215 | assert args.net_version == 'pure_adaptive', 'Unknown network version: {}'.format(args.net_version)
216 | if self.enable_fuse:
217 | self.bottleneck = nn.Sequential(OrderedDict([
218 | ('conv1', Conv3d(inplanes, width, kernel_size=1, stride=1, bias=False))
219 | ]))
220 | self.temporal = nn.Sequential(OrderedDict([
221 | ('conv1', Conv3d(inplanes, width, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False))
222 | ]))
223 | else:
224 | if self.modality == 'temporal':
225 | self.temporal = nn.Sequential(OrderedDict([
226 | ('conv1', Conv3d(inplanes, width, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False))
227 | ]))
228 | else:
229 | self.bottleneck = nn.Sequential(OrderedDict([
230 | ('conv1', Conv3d(inplanes, width, kernel_size=1, stride=1, bias=False))
231 | ]))
232 |
233 |
234 | self.bn1 = norm_layer(width)
235 |
236 | if temporal_stride != 1:
237 | self.temporal_pool = nn.AvgPool3d(kernel_size=(temporal_stride, 1, 1), stride=(temporal_stride, 1, 1))
238 |
239 | self.conv2 = Conv3d(width, width, groups=groups, dilation=dilation, kernel_size=(1, 3, 3), stride=(1, stride, stride), padding=(0, 1, 1), bias=False)
240 | self.bn2 = norm_layer(width)
241 | self.conv3 = Conv3d(width, planes * self.expansion, kernel_size=1, stride=1, bias=False)
242 | self.bn3 = norm_layer(planes * self.expansion)
243 | self.relu = nn.ReLU(inplace=True)
244 | self.downsample = downsample
245 | self.stride = stride
246 | self.temporal_stride = temporal_stride
247 |
248 | def forward(self, x):
249 | identity = x
250 |
251 | if self.temporal_stride != 1:
252 | out = self.temporal_pool(x)
253 | else:
254 | out = x
255 | if args.net_version == 'pure_fused':
256 | out = self.temporal(out) + self.bottleneck(out)
257 | elif args.net_version == 'pure_spatial':
258 | out = self.bottleneck(out)
259 | elif args.net_version == 'pure_temporal':
260 | out = self.temporal(out)
261 | else:
262 | assert args.net_version == 'pure_adaptive', 'Unknown network version: {}'.format(args.net_version)
263 | if self.enable_fuse:
264 | out = self.temporal(out) + self.bottleneck(out)
265 | else:
266 | if self.modality == 'temporal':
267 | out = self.temporal(out)
268 | else:
269 | out = self.bottleneck(out)
270 |
271 | out = self.bn1(out)
272 | out = self.relu(out)
273 |
274 | out = self.conv2(out)
275 | out = self.bn2(out)
276 | out = self.relu(out)
277 |
278 | out = self.conv3(out)
279 | out = self.bn3(out)
280 |
281 | if self.downsample is not None:
282 | identity = self.downsample(x)
283 |
284 | out += identity
285 | out = self.relu(out)
286 |
287 | return out
288 |
289 |
290 | class ResNet(nn.Module):
291 |
292 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
293 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
294 | norm_layer=None, drop_rate=0.0):
295 | super(ResNet, self).__init__()
296 | self.drop_rate = drop_rate
297 |
298 | if norm_layer is None:
299 | norm_layer = nn.BatchNorm3d
300 | self._norm_layer = norm_layer
301 |
302 | self.inplanes = 64
303 | self.dilation = 1
304 | if replace_stride_with_dilation is None:
305 | # each element in the tuple indicates if we should replace
306 | # the 2x2 stride with a dilated convolution instead
307 | replace_stride_with_dilation = [False, False, False]
308 | if len(replace_stride_with_dilation) != 3:
309 | raise ValueError("replace_stride_with_dilation should be None "
310 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
311 | self.groups = groups
312 | self.base_width = width_per_group
313 | self.conv1 = nn.Conv3d(3, self.inplanes, kernel_size=(5, 7, 7), stride=(2, 2, 2) if 'kinetics' in args.dataset or 'ucf' in args.dataset else (1, 2, 2), padding=(2, 3, 3),
314 | bias=False)
315 | self.bn1 = norm_layer(self.inplanes)
316 | self.relu = nn.ReLU(inplace=True)
317 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1))
318 | self.layer1 = self._make_layer(block, 64, layers[0], temporal_stride=1, enable_fuse=True)
319 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
320 | dilate=replace_stride_with_dilation[0], temporal_stride=1, enable_fuse=True)
321 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
322 | dilate=replace_stride_with_dilation[1], temporal_stride=1, enable_fuse=False)
323 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
324 | dilate=replace_stride_with_dilation[2], temporal_stride=1, enable_fuse=True)
325 | self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
326 | self.new_fc = Linear(512 * block.expansion, num_classes)
327 |
328 | for m in self.modules():
329 | if isinstance(m, nn.Conv3d):
330 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
331 | elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
332 | nn.init.constant_(m.weight, 1)
333 | nn.init.constant_(m.bias, 0)
334 | elif isinstance(m, nn.Linear):
335 | nn.init.constant_(m.bias, 0)
336 |
337 | # Zero-initialize the last BN in each residual branch,
338 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
339 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
340 | if zero_init_residual:
341 | for m in self.modules():
342 | if isinstance(m, Bottleneck):
343 | nn.init.constant_(m.bn3.weight, 0)
344 | elif isinstance(m, BasicBlock):
345 | nn.init.constant_(m.bn2.weight, 0)
346 |
347 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False, temporal_stride=1, enable_fuse=False):
348 | norm_layer = self._norm_layer
349 | downsample = None
350 | previous_dilation = self.dilation
351 | if dilate:
352 | self.dilation *= stride
353 | stride = 1
354 | if stride != 1:
355 | downsample = nn.Sequential(OrderedDict([
356 | ('avepool', nn.AvgPool3d(kernel_size=(temporal_stride, 3, 3), stride=(temporal_stride, stride, stride), padding=(0, 1, 1))),
357 | ('0', Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=(1, 1, 1), bias=False)),
358 | ('1', norm_layer(planes * block.expansion))
359 | ]))
360 | elif self.inplanes != planes * block.expansion:
361 | assert temporal_stride==1, 'temporal stride != 1'
362 | downsample = nn.Sequential(OrderedDict([
363 | ('0', Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=(1, 1, 1), bias=False)),
364 | ('1', norm_layer(planes * block.expansion))
365 | ]))
366 | layers = []
367 | if blocks == 23:
368 | if 'kinetics' in args.dataset:
369 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
370 | self.base_width, previous_dilation, norm_layer, temporal_stride=temporal_stride,
371 | enable_fuse=False, modality='spatial'))
372 | self.inplanes = planes * block.expansion
373 | for _ in range(1, 5):
374 | layers.append(block(self.inplanes, planes, groups=self.groups,
375 | base_width=self.base_width, dilation=self.dilation,
376 | norm_layer=norm_layer, enable_fuse=False, modality='spatial'))
377 | for _ in range(5, 15):
378 | layers.append(block(self.inplanes, planes, groups=self.groups,
379 | base_width=self.base_width, dilation=self.dilation,
380 | norm_layer=norm_layer, enable_fuse=True))
381 | for _ in range(15, 19):
382 | layers.append(block(self.inplanes, planes, groups=self.groups,
383 | base_width=self.base_width, dilation=self.dilation,
384 | norm_layer=norm_layer, enable_fuse=False, modality='spatial'))
385 | for _ in range(19, blocks):
386 | layers.append(block(self.inplanes, planes, groups=self.groups,
387 | base_width=self.base_width, dilation=self.dilation,
388 | norm_layer=norm_layer, enable_fuse=False, modality='temporal'))
389 | else:
390 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
391 | self.base_width, previous_dilation, norm_layer, temporal_stride=temporal_stride,
392 | enable_fuse=False))
393 | self.inplanes = planes * block.expansion
394 | for _ in range(1, 10):
395 | layers.append(block(self.inplanes, planes, groups=self.groups,
396 | base_width=self.base_width, dilation=self.dilation,
397 | norm_layer=norm_layer, enable_fuse=True))
398 | for _ in range(10, blocks):
399 | layers.append(block(self.inplanes, planes, groups=self.groups,
400 | base_width=self.base_width, dilation=self.dilation,
401 | norm_layer=norm_layer, enable_fuse=False))
402 | else:
403 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
404 | self.base_width, previous_dilation, norm_layer, temporal_stride=temporal_stride,
405 | enable_fuse=enable_fuse))
406 | self.inplanes = planes * block.expansion
407 | for _ in range(1, blocks):
408 | layers.append(block(self.inplanes, planes, groups=self.groups,
409 | base_width=self.base_width, dilation=self.dilation,
410 | norm_layer=norm_layer, enable_fuse=enable_fuse))
411 |
412 | return nn.Sequential(*layers)
413 |
414 | def forward(self, x):
415 | x = self.conv1(x)
416 | x = self.bn1(x)
417 | x = self.relu(x)
418 | x = self.maxpool(x)
419 |
420 | x = self.layer1(x)
421 | x = self.layer2(x)
422 | x = self.layer3(x)
423 | x = self.layer4(x)
424 |
425 | x = F.adaptive_avg_pool3d(x, (1, 1, 1)).view(x.size(0), -1)
426 | x = F.dropout(x, p=self.drop_rate, training=self.training)
427 | x = self.new_fc(x)
428 |
429 | return x
430 |
431 |
432 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
433 | model = ResNet(block, layers, **kwargs)
434 | if pretrained:
435 | settings = pretrained_settings[arch][pretrained]
436 | model = load_pretrained(model, kwargs.get('num_classes', 1000), settings)
437 | return model
438 |
439 |
440 | def resnet18(pretrained=False, progress=True, **kwargs):
441 | r"""ResNet-18 model from
442 | `"Deep Residual Learning for Image Recognition" `_
443 |
444 | Args:
445 | pretrained (bool): If True, returns a model pre-trained on ImageNet
446 | progress (bool): If True, displays a progress bar of the download to stderr
447 | """
448 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
449 | **kwargs)
450 |
451 |
452 |
453 | def resnet34(pretrained=False, progress=True, **kwargs):
454 | r"""ResNet-34 model from
455 | `"Deep Residual Learning for Image Recognition" `_
456 |
457 | Args:
458 | pretrained (bool): If True, returns a model pre-trained on ImageNet
459 | progress (bool): If True, displays a progress bar of the download to stderr
460 | """
461 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
462 | **kwargs)
463 |
464 |
465 |
466 | def resnet50(pretrained='imagenet', progress=True, **kwargs):
467 | r"""ResNet-50 model from
468 | `"Deep Residual Learning for Image Recognition" `_
469 |
470 | Args:
471 | pretrained (str): pre-trained model
472 | progress (bool): If True, displays a progress bar of the download to stderr
473 | """
474 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
475 | **kwargs)
476 |
477 |
478 |
479 | def resnet101(pretrained=False, progress=True, **kwargs):
480 | r"""ResNet-101 model from
481 | `"Deep Residual Learning for Image Recognition" `_
482 |
483 | Args:
484 | pretrained (bool): If True, returns a model pre-trained on ImageNet
485 | progress (bool): If True, displays a progress bar of the download to stderr
486 | """
487 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
488 | **kwargs)
489 |
490 |
491 |
492 | def resnet152(pretrained=False, progress=True, **kwargs):
493 | r"""ResNet-152 model from
494 | `"Deep Residual Learning for Image Recognition" `_
495 |
496 | Args:
497 | pretrained (bool): If True, returns a model pre-trained on ImageNet
498 | progress (bool): If True, displays a progress bar of the download to stderr
499 | """
500 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
501 | **kwargs)
502 |
503 |
504 |
505 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
506 | r"""ResNeXt-50 32x4d model from
507 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
508 |
509 | Args:
510 | pretrained (bool): If True, returns a model pre-trained on ImageNet
511 | progress (bool): If True, displays a progress bar of the download to stderr
512 | """
513 | kwargs['groups'] = 32
514 | kwargs['width_per_group'] = 4
515 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
516 | pretrained, progress, **kwargs)
517 |
518 |
519 |
520 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
521 | r"""ResNeXt-101 32x8d model from
522 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
523 |
524 | Args:
525 | pretrained (bool): If True, returns a model pre-trained on ImageNet
526 | progress (bool): If True, displays a progress bar of the download to stderr
527 | """
528 | kwargs['groups'] = 32
529 | kwargs['width_per_group'] = 8
530 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
531 | pretrained, progress, **kwargs)
532 |
533 |
534 |
535 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
536 | r"""Wide ResNet-50-2 model from
537 | `"Wide Residual Networks" `_
538 |
539 | The model is the same as ResNet except for the bottleneck number of channels
540 | which is twice larger in every block. The number of channels in outer 1x1
541 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
542 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
543 |
544 | Args:
545 | pretrained (str): 'imagenet' loads ImageNet-pretrained weights; None or False skips loading
546 | progress (bool): If True, displays a progress bar of the download to stderr
547 | """
548 | kwargs['width_per_group'] = 64 * 2
549 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
550 | pretrained, progress, **kwargs)
551 |
552 |
553 |
554 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
555 | r"""Wide ResNet-101-2 model from
556 | `"Wide Residual Networks" `_
557 |
558 | The model is the same as ResNet except for the bottleneck number of channels
559 | which is twice larger in every block. The number of channels in outer 1x1
560 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
561 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
562 |
563 | Args:
564 | pretrained (str): 'imagenet' loads ImageNet-pretrained weights; None or False skips loading
565 | progress (bool): If True, displays a progress bar of the download to stderr
566 | """
567 | kwargs['width_per_group'] = 64 * 2
568 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
569 | pretrained, progress, **kwargs)
570 |
571 |
572 | if __name__ == '__main__':
573 | resnet50(pretrained='imagenet', num_classes=1000)
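574 |
575 |     # A minimal sketch (editor's addition, with illustrative shapes) of the
576 |     # 2d-to-3d inflation performed in update_state_dict above: a 2d kernel
577 |     # gains a singleton temporal axis, is tiled t times along it and rescaled
578 |     # by 1/t, so the inflated filter initially responds to a static clip like
579 |     # its 2d counterpart.
580 |     w2d = torch.randn(64, 3, 7, 7)  # hypothetical ImageNet stem kernel
581 |     w3d = w2d.unsqueeze(2).repeat(1, 1, 5, 1, 1) / 5.0
582 |     assert w3d.shape == (64, 3, 5, 7, 7)  # matches conv1 kernel_size=(5, 7, 7)
583 |     # The 'pure_fused' Bottleneck above likewise runs a spatial 1x1x1 conv and
584 |     # a temporal 3x1x1 conv in parallel and sums the two responses:
585 |     clip = torch.randn(1, 64, 8, 56, 56)  # N, C, T, H, W (made-up sizes)
586 |     spatial = nn.Conv3d(64, 64, kernel_size=1, bias=False)
587 |     temporal = nn.Conv3d(64, 64, kernel_size=(3, 1, 1), padding=(1, 0, 0), bias=False)
588 |     fused = spatial(clip) + temporal(clip)  # shape is still (1, 64, 8, 56, 56)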
--------------------------------------------------------------------------------
/philly_distributed_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scenarios/Probabilistic3DCNN/444385aeabc001282064877faba7a15a787a4f94/philly_distributed_utils/__init__.py
--------------------------------------------------------------------------------
/philly_distributed_utils/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as op
3 | import numpy as np
4 | import subprocess
5 | from contextlib import contextmanager
6 | import logging
7 |
8 | def ompi_rank():
9 | """Find OMPI world rank without calling mpi functions
10 | :rtype: int
11 | """
12 | return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0)
13 |
14 |
15 | def ompi_size():
16 | """Find OMPI world size without calling mpi functions
17 | :rtype: int
18 | """
19 | return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1)
20 |
21 |
22 | def ompi_local_rank():
23 | """Find OMPI local rank without calling mpi functions
24 | :rtype: int
25 | """
26 | return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0)
27 |
28 |
29 | def ompi_local_size():
30 | """Find OMPI local size without calling mpi functions
31 | :rtype: int
32 | """
33 | return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_SIZE') or 1)
34 |
35 |
36 | @contextmanager
37 | def run_and_terminate_process(*args, **kwargs):
38 | """Run a process and terminate it at the end
39 | """
40 | p = None
41 | try:
42 | p = subprocess.Popen(*args, **kwargs)
43 | yield p
44 | finally:
45 | if not p:
46 | return
47 | try:
48 | p.terminate() # send sigterm
49 | except OSError:
50 | pass
51 | try:
52 | p.kill() # send sigkill
53 | except OSError:
54 | pass
55 |
56 |
57 | def get_gpus_nocache():
58 | """List of NVIDIA GPUs
59 | """
60 | cmds = 'nvidia-smi --query-gpu=name --format=csv,noheader'.split(' ')
61 | with run_and_terminate_process(cmds,
62 | stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
63 | universal_newlines=True, bufsize=1) as process:
64 | return [line.strip() for line in iter(process.stdout.readline, "")]
65 |
66 |
67 | def get_gpus():
68 | """List of NVIDIA GPUs
69 | """
70 | return get_gpus_nocache()
71 |
72 |
73 | def gpu_indices(divisible=True):
74 | """Get the GPU device indices for this process/rank
75 | :param divisible: if GPU count of all ranks must be the same
76 | :rtype: list[int]
77 | """
78 | local_size = ompi_local_size()
79 | local_rank = ompi_local_rank()
80 | assert 0 <= local_rank < local_size, "Invalid local_rank: {} local_size: {}".format(local_rank, local_size)
81 | gpu_count = len(get_gpus())
82 | assert gpu_count >= local_size > 0, "GPU count: {} must be >= LOCAL_SIZE: {} > 0".format(gpu_count, local_size)
83 | if divisible:
84 | ngpu = gpu_count // local_size
85 | gpus = np.arange(local_rank * ngpu, (local_rank + 1) * ngpu)
86 | if gpu_count % local_size != 0:
87 | logging.warning("gpu_count: {} not divisible by local_size: {}; some GPUs may be unused".format(
88 | gpu_count, local_size
89 | ))
90 | else:
91 | gpus = np.array_split(range(gpu_count), local_size)[local_rank]
92 | return gpus
93 |
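94 |
95 | # A minimal sketch (editor's addition) of how the helpers above combine on a
96 | # single node; the OMPI values are hypothetical and nvidia-smi must be on PATH.
97 | if __name__ == '__main__':
98 |     os.environ.setdefault('OMPI_COMM_WORLD_LOCAL_RANK', '1')
99 |     os.environ.setdefault('OMPI_COMM_WORLD_LOCAL_SIZE', '2')
100 |     # e.g. on a 4-GPU machine, local rank 1 of 2 is assigned GPUs [2, 3]
101 |     print('rank {} of {} -> GPUs {}'.format(ompi_rank(), ompi_size(), gpu_indices()))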
--------------------------------------------------------------------------------
/philly_distributed_utils/env.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as op
3 | import re
4 |
5 | def _vc_home():
6 | """Find philly's VC home in scratch space
7 | :rtype: str
8 | """
9 | home = os.environ.get('PHILLY_VC_NFS_DIRECTORY', os.environ.get('PHILLY_VC_DIRECTORY'))
10 | if not home:
11 | home = op.expanduser('~')
12 | home = '/'.join(home.split('/')[:5])
13 | return home
14 |
15 |
16 | _VC_HOME = _vc_home()
17 |
18 |
19 | def vc_name():
20 | """Find philly's VC name
21 | :rtype: str
22 | """
23 | name = os.environ.get('PHILLY_VC')
24 | if name:
25 | return name
26 | name = op.basename(_VC_HOME)
27 | if name:
28 | return name
29 | return op.basename(op.dirname(_VC_HOME))
30 |
31 |
32 | _VC_NAME = vc_name()
33 |
34 |
35 | def _vc_hdfs_base():
36 | base = os.environ.get("PHILLY_DATA_DIRECTORY") or os.environ.get("PHILLY_HDFS_PREFIX")
37 | if base:
38 | return base
39 | for base in ["/hdfs", "/home"]:
40 | if op.isdir(base):
41 | return base
42 | return _VC_HOME
43 |
44 |
45 | def vc_hdfs_root():
46 | """Find the HDFS root of the VC
47 | :rtype: str
48 | """
49 | path = os.environ.get('PHILLY_VC_HDFS_DIRECTORY')
50 | if path:
51 | return path
52 | path = op.join(os.environ.get('PHILLY_HDFS_PREFIX', _vc_hdfs_base()), _VC_NAME)
53 | return path
54 |
55 |
56 | _VC_HDFS_ROOT = vc_hdfs_root()
57 |
58 |
59 | def expand_vc_user(path):
60 | """Expand ~ to VC's home
61 | :param path: the path to expand VC user
62 | :type path: str
63 | :return: /var/storage/shared/$VC_NAME
64 | :rtype: str
65 | """
66 | if path.startswith('~'):
67 | path = op.abspath(op.join(_VC_HOME, '.' + path[1:]))
68 |
69 | return path
70 |
71 | def abspath(path, roots=None):
72 | """Expand ~ to VC's home and resolve relative paths to absolute paths
73 | :param path: the path to resolve
74 | :type path: str
75 | :param roots: CWD roots to resolve relative paths to them
76 | :type roots: list
77 | """
78 | path = expand_vc_user(path)
79 | if op.isabs(path):
80 | return path
81 | if not roots:
82 | roots = ["~"]
83 | roots = [expand_vc_user(root) for root in roots]
84 | for root in roots:
85 | resolved = op.abspath(op.join(root, path))
86 | if op.isfile(resolved) or op.isdir(resolved):
87 | return resolved
88 | # return assuming the first root (even though it does not exist)
89 | return op.abspath(op.join(roots[0], path))
90 |
91 |
92 | def job_id(path=None):
93 | """Get the philly job ID (from a path)
94 | :param path: Path to search for app id
95 | :rtype: str
96 | """
97 | if path is None:
98 | return os.environ.get('PHILLY_JOB_ID') or job_id(op.expanduser('~'))
99 | m = re.search(r'/(?P<app_id>application_[\d_]+)[/\w]*$', path)
100 | if m:
101 | return m.group('app_id')
102 | return ''
103 |
104 |
105 | def get_model_path(path=None):
106 | """Find the default location to output/models
107 | """
108 | return abspath(op.join('sys', 'jobs', job_id(path), 'models'), roots=[vc_hdfs_root()])
109 |
110 |
111 | def get_master_machine():
112 | mpi_host_file = op.expanduser('~/mpi-hosts')
113 | with open(mpi_host_file, 'r') as f:
114 | master_name = f.readline().strip()
115 | return master_name
116 |
117 |
118 | def get_master_ip(master_name=None):
119 | if master_name is None:
120 | master_name = get_master_machine()
121 | etc_host_file = '/etc/hosts'
122 | with open(etc_host_file, 'r') as f:
123 | name_ip_pairs = f.readlines()
124 | name2ip = {}
125 | for name_ip_pair in name_ip_pairs:
126 | pair_list = name_ip_pair.split(' ')
127 | key = pair_list[1].strip()
128 | value = pair_list[0]
129 | name2ip[key] = value
130 | return name2ip[master_name]
131 |
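132 | # A minimal sketch (editor's addition) of job_id() pulling an application id
133 | # out of a hypothetical philly path:
134 | if __name__ == '__main__':
135 |     print(job_id('/hdfs/vc/sys/jobs/application_1234_5678/models'))
136 |     # -> application_1234_5678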
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scenarios/Probabilistic3DCNN/444385aeabc001282064877faba7a15a787a4f94/tools/__init__.py
--------------------------------------------------------------------------------
/tools/ckpt_checker.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | parser = argparse.ArgumentParser(description="PyTorch implementation of NAS_spatiotemporal")
3 | parser.add_argument('--ckptpth', type=str, default='/mnt/log/NAS_spatiotemporal/checkpoint/NAS_sptp_nasas__ls250.0_something_RGB_Dense3D121_avg_segment1_e90_droprate0.0_num_dense_sample32_dense_sample_stride1_dense_Netv1d3Bz2by16Lr0.005SbnTpbtShare40-60-80/NAS_sptp_nasas_selection_ls250.0_something_RGB_Dense3D121_avg_segment1_e50_droprate0.0_num_dense_sample32_dense_Netv1d3Spbtsharels250temptSelect/ckpt.best.1.pth.tar')
4 |
5 | import os
6 | import numpy as np
7 | import csv
8 |
9 | def alpha_checker(state_dict, path=''):
10 | with open(os.path.join(path, 'p_log.csv'), mode='w') as csv_file:
11 | fields = ['index', 'S', 'T']
12 | csv_writer = csv.DictWriter(csv_file, fieldnames=fields)
13 | csv_writer.writeheader()
14 |
15 | records = {}
16 | for name, value in state_dict.items():
17 | if 'p_logit' in name and 'classifier' not in name:
18 | print('{} {}'.format(name, value.sigmoid().item()))
19 | name = name.replace('module.features.', '')
20 | name = name.replace('.conv1.p_logit', '')
21 | name = name.replace('.conv.p_logit', '')
22 | name = name.replace('bottleneck', 'S')
23 | name = name.replace('temporal', 'T')
24 | name = name.replace('denseblock', 'B')
25 | name = name.replace('denselayer', 'L')
26 | name = name.replace('original', 'S')
27 | if '.S' in name:
28 | if name.replace('.S', '') in records.keys():
29 | records[name.replace('.S', '')]['S'] = value.sigmoid().item()
30 | else:
31 | records[name.replace('.S', '')] = {}
32 | records[name.replace('.S', '')]['S'] = value.sigmoid().item()
33 | #csv_writer.writerow({'index': name, 'S': value.sigmoid().item(), 'T': state_dict[name.replace('.S', '.T')].sigmoid().item()})
34 | elif '.T' in name:
35 | if name.replace('.T', '') in records.keys():
36 | records[name.replace('.T', '')]['T'] = value.sigmoid().item()
37 | else:
38 | records[name.replace('.T', '')] = {}
39 | records[name.replace('.T', '')]['T'] = value.sigmoid().item()
40 | else:
41 | pass
42 |
43 | for name, value in records.items():
44 | csv_writer.writerow({'index': name, 'S': value['S'], 'T': value['T']})
45 |
46 |
47 | def sptp_checker(state_dict):
48 | t_count = 0
49 | s_count = 0
50 | for name, value in state_dict.items():
51 | if 'p_logit' in name and 'classifier' not in name:
52 | #print('{}: {}'.format(name, value.sigmoid()))
53 | stensor = 1 - np.floor(state_dict[name.replace('p_logit', 'unif_noise_variable')].cpu().item() + value.sigmoid().cpu().item())
54 | if int(stensor) == 0:
55 | if 'temporal' in name:
56 | t_count += 1
57 | elif 'bottleneck' in name or 'original' in name:
58 | s_count += 1
59 | print('{}: {}'.format(name, stensor))
60 | #if 'norm' in name:
61 | # print('{}: {}'.format(name, value))
62 | print('{}: {}'.format('t_count', t_count))
63 | print('{}: {}'.format('s_count', s_count))
64 |
65 |
66 | def main():
67 | args = parser.parse_args()
68 | import torch
69 | checkpoint = torch.load(args.ckptpth, map_location='cpu')  # --ckptpth already names the .tar file
70 | state_dict = checkpoint['state_dict']
71 | #alpha_checker(state_dict, path=args.ckptpth)
72 | sptp_checker(state_dict)
73 |
74 |
75 | if __name__ == '__main__':
76 | main()
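77 |
78 | # Editor's note on the drop rule in sptp_checker, with made-up numbers: a
79 | # connection is dropped (stensor == 0) exactly when
80 | # unif_noise + sigmoid(p_logit) >= 1, i.e. the sampled binary gate closes;
81 | # otherwise stensor == 1 and the connection is kept.
82 | #   1 - np.floor(0.7 + 0.4)  # -> 0.0, dropped
83 | #   1 - np.floor(0.2 + 0.4)  # -> 1.0, kept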
--------------------------------------------------------------------------------
/tools/generate_label_sthsthv1.py:
--------------------------------------------------------------------------------
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # arXiv:1811.08383
3 | # Ji Lin*, Chuang Gan, Song Han
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5 | # ------------------------------------------------------
6 | # Code adapted from https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py
7 | # processing the raw data of the video Something-Something-V1
8 |
9 | import os
10 |
11 | root_path = '/mnt/data/somethingsomethingv1/'
12 |
13 | if __name__ == '__main__':
14 | os.chdir(root_path)
15 | dataset_name = 'something-something-v1' # 'jester-v1'
16 | with open('%s-labels.csv' % dataset_name) as f:
17 | lines = f.readlines()
18 | categories = []
19 | for line in lines:
20 | line = line.rstrip()
21 | categories.append(line)
22 | categories = sorted(categories)
23 | with open('category.txt', 'w') as f:
24 | f.write('\n'.join(categories))
25 |
26 | dict_categories = {}
27 | for i, category in enumerate(categories):
28 | dict_categories[category] = i
29 |
30 | files_input = ['%s-validation.csv' % dataset_name, '%s-train.csv' % dataset_name]
31 | files_output = ['val_videofolder_azure.txt', 'train_videofolder_azure.txt']
32 | for (filename_input, filename_output) in zip(files_input, files_output):
33 | with open(filename_input) as f:
34 | lines = f.readlines()
35 | folders = []
36 | idx_categories = []
37 | for line in lines:
38 | line = line.rstrip()
39 | items = line.split(';')
40 | folders.append(items[0])
41 | idx_categories.append(dict_categories[items[1]])
42 | output = []
43 | for i in range(len(folders)):
44 | curFolder = folders[i]
45 | curIDX = idx_categories[i]
46 | # counting the number of frames in each video folders
47 | dir_files = os.listdir(os.path.join('./img', curFolder))
48 | output.append('%s %d %d' % (os.path.join('', curFolder), len(dir_files), curIDX))
49 | print('%d/%d' % (i, len(folders)))
50 | with open(filename_output, 'w') as f:
51 | f.write('\n'.join(output))
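52 |
53 | # Each line written above follows "folder num_frames label_idx" (editor's
54 | # note; the values below are hypothetical), e.g.
55 | #   "12345 46 87" -> folder ./img/12345 holds 46 frames of category index 87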
--------------------------------------------------------------------------------
/tools/generate_label_ucf101.py:
--------------------------------------------------------------------------------
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # arXiv:1811.08383
3 | # Ji Lin*, Chuang Gan, Song Han
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5 | # ------------------------------------------------------
6 | # Code adapted from https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py
7 | # processing the raw data of the video dataset UCF101
8 |
9 | import os
10 |
11 | root_path = '/home/sda/data-writable/ucf101/'
12 |
13 | if __name__ == '__main__':
14 | os.chdir(root_path)
15 | dataset_name = 'ucf101'
16 | with open('%s-labels.txt' % dataset_name) as f:
17 | lines = f.readlines()
18 | dict_categories = {}
19 | for line in lines:
20 | line = line.rstrip()
21 | line = line.split()
22 | dict_categories[line[1]] = int(line[0])-1
23 |
24 |
25 | files_input = ['testlist01.txt', 'trainlist01.txt']
26 | files_output = ['val_videofolder.txt', 'train_videofolder.txt']
27 | for (filename_input, filename_output) in zip(files_input, files_output):
28 | with open(filename_input) as f:
29 | lines = f.readlines()
30 | folders = []
31 | idx_categories = []
32 | for line in lines:
33 | line = line.rstrip()
34 | items = line.split('.')
35 | folders.append(items[0])
36 | idx_categories.append(dict_categories[items[0].split('/')[0]])
37 | output = []
38 | for i in range(len(folders)):
39 | curFolder = folders[i]
40 | curIDX = idx_categories[i]
41 | # counting the number of frames in each video folders
42 | dir_files = os.listdir(os.path.join('./UCF-101_image', curFolder, 'i'))
43 | output.append('%s %d %d' % (os.path.join('', curFolder), len(dir_files), curIDX))
44 | print('%d/%d' % (i, len(folders)))
45 | with open(filename_output, 'w') as f:
46 | f.write('\n'.join(output))
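47 |
48 | # Editor's note: ucf101-labels.txt is assumed to follow the "index name"
49 | # layout of the official classInd.txt, e.g. "1 ApplyEyeMakeup", which the
50 | # loop above maps to a zero-based index: dict_categories['ApplyEyeMakeup'] == 0.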
--------------------------------------------------------------------------------
/tools/statistics.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 |
4 | from thop import profile
5 | from models.densenet_3d_forstat import densenet121
6 | from args import parser
7 |
8 | import torch
9 |
10 | args = parser.parse_args()
11 | model = densenet121(num_classes=174, pretrained=None, drop_rate=0.5)
12 | input = torch.randn(1, 3, 128, 256, 256)
13 | flops, params = profile(model, inputs=(input, ))
14 |
15 | print(flops)
16 | print(params)
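17 |
18 | # thop returns raw operation/parameter counts; a human-readable summary
19 | # (editor's addition) scales them to GFLOPs and millions of parameters:
20 | print('GFLOPs: {:.2f}, params: {:.2f}M'.format(flops / 1e9, params / 1e6))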
--------------------------------------------------------------------------------
/tools/to_hdf5.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | import h5py
5 | from zipfile import ZipFile
6 | import numpy as np
7 |
8 | from PIL import Image
9 |
10 |
11 | class VideoRecord(object):
12 | def __init__(self, row):
13 | self._data = row
14 |
15 | @property
16 | def path(self):
17 | return self._data[0]
18 |
19 | @property
20 | def num_frames(self):
21 | return int(self._data[1])
22 |
23 | @property
24 | def label(self):
25 | return int(self._data[2])
26 |
27 |
28 | class Converter(object):
29 | def __init__(self, root_path, list_file, target_path, image_tmpl):
30 | self._list_file = list_file
31 | self._root_path = root_path
32 | self._target_path = target_path
33 | self._image_tmpl = image_tmpl
34 |
35 | self._parse_list()
36 |
37 | def _parse_list(self):
38 | tmp = [x.strip().split(' ') for x in open(self._list_file)]
39 | tmp = [[' '.join(x[:-2]), x[-2], x[-1]] for x in tmp]
40 | self._video_list = [VideoRecord(item) for item in tmp]
41 |
42 | if self._image_tmpl == '{:06d}-{}_{:05d}.jpg':
43 | for v in self._video_list:
44 | v._data[1] = int(v._data[1]) // 2
45 | print('video number:%d' % (len(self._video_list)))
46 |
47 | def _full_path(self, directory, idx):
48 | return os.path.join(self._root_path, directory, self._image_tmpl.format(idx))
49 |
50 | def _load_image(self, directory, idx):
51 | try:
52 | return Image.open(os.path.join(self._root_path, directory, self._image_tmpl.format(idx))).convert('RGB')
53 | except Exception:
54 | print('error loading image:', os.path.join(self._root_path, directory, self._image_tmpl.format(idx)))
55 | return Image.open(os.path.join(self._root_path, directory, self._image_tmpl.format(1))).convert('RGB')
56 |
57 | def convert(self):
58 | raise NotImplementedError()
59 |
60 |
61 | class HDF5Converter(Converter):
62 | def __init__(self, root_path, list_file, target_path, image_tmpl):
63 | super(HDF5Converter, self).__init__(root_path, list_file, target_path, image_tmpl)
64 |
65 | def convert(self):
66 | for record in self._video_list:
67 | if not os.path.exists(os.path.join(self._target_path, record.path)):
68 | os.makedirs(os.path.join(self._target_path, record.path))
69 | assert not os.path.exists(os.path.join(self._target_path, record.path, 'RGB_frames')), "{} already exists".format(os.path.join(self._target_path, record.path, 'RGB_frames'))
70 | with h5py.File(os.path.join(self._target_path, record.path, 'RGB_frames'), 'w') as hdf:
71 | for idx in range(record.num_frames):
72 | img = np.asarray((self._load_image(record.path, idx+1)), dtype="uint8")
73 | hdf.create_dataset(record.path+"/"+self._image_tmpl.format(idx+1), data=img, dtype="uint8")
74 | print("{} Done".format(record.path))
75 |
76 |
77 | class ZIPConverter(Converter):
78 | def __init__(self, root_path, list_file, target_path, image_tmpl):
79 | super(ZIPConverter, self).__init__(root_path, list_file, target_path, image_tmpl)
80 |
81 | def convert(self):
82 | _video_num = len(self._video_list)
83 | for i, record in enumerate(self._video_list):
84 | if not os.path.exists(os.path.join(self._target_path, record.path)):
85 | os.makedirs(os.path.join(self._target_path, record.path))
86 | assert not os.path.exists(os.path.join(self._target_path, record.path, 'RGB_frames.zip')), "{} already exists".format(os.path.join(self._target_path, record.path, 'RGB_frames.zip'))
87 | with ZipFile(os.path.join(self._target_path, record.path, 'RGB_frames.zip'), 'w') as zipf:
88 | for idx in range(record.num_frames):
89 | #img = np.asarray((self._load_image(record.path, idx+1)), dtype="uint8")
90 | zipf.write(self._full_path(record.path, idx+1), arcname=self._image_tmpl.format(idx+1))
91 | print("{} of {} ({}) Done".format(str(i), str(_video_num), record.path))
92 |
93 |
94 | def main(list_file):
95 | root_path = "/home/sda/data-writable/something-something/"
96 | target_path = "/home/sdb/writable/20bn-something-something-v1_zip/"
97 | #list_file = "/home/sda/data-writable/kinetics400_frame/val_videofolder.txt"
98 | list_file = os.path.join(root_path, list_file)
99 |
100 | cvt = ZIPConverter(os.path.join(root_path, '20bn-something-something-v1'), list_file, target_path, '{:05d}.jpg')
101 | cvt.convert()
102 |
103 |
104 | if __name__ == '__main__':
105 | main(sys.argv[1])
106 |
107 |
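108 | # Reading one frame back from a converted HDF5 archive (editor's sketch; the
109 | # record path and frame index are hypothetical):
110 | #   with h5py.File('/target/some_video/RGB_frames', 'r') as hdf:
111 | #       frame = np.array(hdf['some_video/{:05d}.jpg'.format(1)])  # uint8 HxWxC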
--------------------------------------------------------------------------------
/tools/visualize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import pandas as pd
4 | import matplotlib.pyplot as plt
5 | import re
6 | import os
7 |
8 | def visualization(args):
9 | log_path = os.path.join(args.root_log, args.store_name)
10 | checkpoint_path = os.path.join(args.root_model, args.store_name)
11 | current_checkpoint_path = os.path.join(checkpoint_path, "ckpt.pth.tar")
12 | best_checkpoint_path = os.path.join(checkpoint_path, "ckpt.best.pth.tar")
13 |
14 | p_record_path = os.path.join(log_path, "p_record.txt")
15 | if os.path.isfile(p_record_path):
16 | # not the first run
17 | try:
18 | with open(p_record_path, 'r') as f:
19 | p_record = int(f.read())
20 | assert 0 <= p_record <= 100000, "p_record out of range"
21 | print("history record p: ", p_record)
22 | except (IOError, ValueError) as E:
23 | print(E)
24 | print("the p_record.txt file is suddenly missing")
25 | p_record = 0
26 | else:
27 | # first run
28 | p_record = 0
29 | print("init record p: ", p_record)
30 |
31 |
32 | # get data
33 | try:
34 | files = os.listdir(log_path)
35 | log_pattern = re.compile(r'(log)(-*\d+)(\.csv)')
36 | log_number = re.compile(r'-*\d+')
37 | log_numbers = []
38 | for f in files:
39 | if log_pattern.match(f):
40 | this_log_number = int(log_number.findall(f)[0])
41 | log_numbers.append(this_log_number)
42 | min_log_number = min(log_numbers)
43 | min_log = 'log' + str(min_log_number) + '.csv'
44 | data = pd.read_csv(os.path.join(log_path, min_log), delimiter='\n', engine='python', header=None,
45 | error_bad_lines=False)
46 | data = [str(x) for x in data.values]
47 | test = [re.findall(r'-?\d+\.?\d*e?-?\d*?', x) for x in data if "Testing" in x]
48 | train = [re.findall(r'-?\d+\.?\d*e?-?\d*?', x) for x in data if "Worker" in x]
49 | y_val = [float(x[1]) for x in test]
50 | y_train_batch = [float(x[-5]) for x in train]
51 | y_train_avg_epoch = []
52 |
53 | loss_val = [float(x[-1]) for x in test]
54 | loss_train_batch = [float(x[9]) for x in train]
55 | CE_loss_train_batch = [float(x[11]) for x in train]
56 | KL_loss_train_batch = [float(x[13]) for x in train]
57 | loss_train_avg_epoch = []
58 | CE_loss_train_avg_epoch = []
59 | KL_loss_train_avg_epoch = []
60 |
61 | for i, tmp in enumerate(train):
62 | if float(tmp[-5]) == float(tmp[-4]) and i > 0:
63 | y_train_avg_epoch.append(float(train[i - 1][-4]))
64 | loss_train_avg_epoch.append(float(train[i - 1][10]))
65 | CE_loss_train_avg_epoch.append(float(train[i - 1][12]))
66 | KL_loss_train_avg_epoch.append(float(train[i - 1][14]))
67 |
68 | num_epochs = len(y_train_avg_epoch)
69 | num_batches = len(y_train_batch)
70 | num_vals = len(y_val)
71 | x_train_avg_epoch = np.array(range(num_epochs))
72 | x_train_batch = np.array(range(num_batches)) * (num_epochs / num_batches)
73 | x_val = np.array(range(num_vals)) * (num_epochs / num_vals)
74 |
75 | best_val_y = max(y_val)
76 | best_val_x = x_val[y_val.index(best_val_y)]
77 |
78 | if True:
79 | current_checkpoint = torch.load(current_checkpoint_path, map_location=lambda storage, loc: storage)
80 | best_checkpoint = torch.load(best_checkpoint_path, map_location=lambda storage, loc: storage)
81 | current_state_dict = current_checkpoint['state_dict']
82 | best_state_dict = best_checkpoint['state_dict']
83 | assert current_state_dict.keys() == best_state_dict.keys()
84 | current_p_logit = [(x, torch.sigmoid(torch.tensor(float(current_state_dict[x])))) for x in
85 | current_state_dict.keys() if
86 | "p_logit" in x]
87 | best_p_logit = [(x, torch.sigmoid(torch.tensor(float(best_state_dict[x])))) for x in
88 | best_state_dict.keys() if
89 | "p_logit" in x]
90 | X_p = list(range(len(current_p_logit)))
91 | X_p_ticks = [loc for loc, value in current_p_logit]
92 | Y_p_current = np.array([value.item() for loc, value in current_p_logit])
93 | Y_p_best = np.array([value.item() for loc, value in best_p_logit])
94 | except Exception as e:
95 | print("visual exception: ", e)
96 | return
97 |
98 | else:
99 | print("log and checkpoint data load success, generating the result picture")
100 | # result
101 | plt.figure(figsize=(20, 10))
102 |
103 | plt.subplot(121)
104 | plt.title("prec1@{}".format("something"))
105 | plt.plot(x_train_batch, y_train_batch, label="train batches")
106 | plt.plot(x_train_avg_epoch, y_train_avg_epoch, marker='*', label="train epochs average")
107 | plt.plot(x_val, y_val, marker='o', label="test per {} epochs".format(round(num_epochs / num_vals)))
108 | plt.annotate('best: {}'.format(best_val_y),
109 | xy=(best_val_x, best_val_y),
110 | xycoords='data',
111 | xytext=(50, 50),
112 | textcoords='offset points',
113 | fontsize=16,
114 | arrowprops=dict(arrowstyle='->', connectionstyle="arc3, rad=.2"))
115 |
116 | plt.xlabel("epochs")
117 | plt.ylabel("top1%")
118 | plt.legend(loc="best")
119 |
120 | plt.subplot(122)
121 | plt.title("loss@{}".format("something"))
122 | plt.plot(x_train_batch, loss_train_batch, label="train batches' loss")
123 | plt.plot(x_train_batch, CE_loss_train_batch, label="train batches' CrossEntropy loss")
124 | plt.plot(x_train_batch, KL_loss_train_batch, label="train batches' KL loss")
125 | plt.plot(x_train_avg_epoch, loss_train_avg_epoch, marker="*", label="train epoch avg's loss")
126 | plt.plot(x_train_avg_epoch, CE_loss_train_avg_epoch, marker="*", label="train epoch avg's CE loss")
127 | plt.plot(x_train_avg_epoch, KL_loss_train_avg_epoch, marker="*", label="train epoch avg's KL loss")
128 | plt.plot(x_val, loss_val, marker="*", label="test per {} epochs' loss".format(round(num_epochs / num_vals)))
129 | plt.xlabel("epochs")
130 | plt.ylabel("loss")
131 |
132 | plt.legend(loc='best')
133 |
134 | plt.savefig(os.path.join(log_path, "result.png"), bbox_inches='tight')
135 |
136 | if True:
137 |
138 | # current p
139 | p_record += 1
140 | with open(p_record_path, "w") as f:
141 | f.write(str(p_record))
142 |
143 | plt.figure(figsize=(20, 10))
144 | plt.bar(X_p, Y_p_current, facecolor='#ff9800')
145 | plt.plot(X_p, Y_p_current, marker='^', markersize=15, linewidth=3)
146 | plt.title("p to drop connection")
147 | # annotate each bar with its value
148 | for x, y in zip(X_p, Y_p_current):
149 | plt.text(x, y + 0.03,
150 | '%.2f' % y, ha='center', va='bottom',
151 | fontdict={'color': '#0091ea',
152 | 'size': 16})
153 | plt.ylim(0., 1.)
154 | plt.xticks(X_p, X_p_ticks, size="small", rotation=85)
155 | plt.yticks([])
156 | plt.savefig(os.path.join(log_path, "current{}_p.png".format(p_record)), bbox_inches='tight')
157 |
158 | # best p
159 | plt.figure(figsize=(20, 10))
160 | plt.bar(X_p, Y_p_best, facecolor='#0000FF')
161 | plt.plot(X_p, Y_p_best, marker='^', markersize=15, linewidth=3)
162 | plt.title("p to drop connection")
163 | # annotate each bar with its value
164 | for x, y in zip(X_p, Y_p_best):
165 | plt.text(x, y + 0.03,
166 | '%.2f' % y, ha='center', va='bottom',
167 | fontdict={'color': '#0091ea',
168 | 'size': 16})
169 | plt.ylim(0., 1.)
170 | plt.xticks(X_p, X_p_ticks, size="small", rotation=85)
171 | plt.yticks([])
172 | plt.savefig(os.path.join(log_path, "best_p.png"), bbox_inches='tight')
173 |
174 | plt.close('all')
175 |
176 | if __name__ == '__main__':
177 | import argparse
178 | parser = argparse.ArgumentParser(description="PyTorch implementation of NAS_spatiotemporal")
179 | parser.add_argument('--root_log', type=str, default='/mnt/log/NAS_spatiotemporal/log')
180 | parser.add_argument('--root_model', type=str, default='/mnt/log/NAS_spatiotemporal/checkpoint')
181 | parser.add_argument('--store_name', type=str, default="")
182 |
183 | args = parser.parse_args()
184 | visualization(args)
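185 |
186 |     # Editor's note on the bar charts above: each p_logit parameter is squashed
187 |     # to a drop probability p = sigmoid(p_logit) in (0, 1); e.g. a p_logit of
188 |     # -2.2 gives p of roughly 0.10, a connection dropped about 10% of the time.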
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def softmax(scores):
5 | es = np.exp(scores - scores.max(axis=-1)[..., None])
6 | return es / es.sum(axis=-1)[..., None]
7 |
8 |
9 | class AverageMeter(object):
10 | """Computes and stores the average and current value"""
11 |
12 | def __init__(self):
13 | self.reset()
14 |
15 | def reset(self):
16 | self.val = 0
17 | self.avg = 0
18 | self.sum = 0
19 | self.count = 0
20 |
21 | def update(self, val, n=1):
22 | self.val = val
23 | self.sum += val * n
24 | self.count += n
25 | self.avg = self.sum / self.count
26 |
27 |
28 | def accuracy(output, target, topk=(1,)):
29 | """Computes the precision@k for the specified values of k"""
30 | maxk = max(topk)
31 | batch_size = target.size(0)
32 |
33 | _, pred = output.topk(maxk, 1, True, True)
34 | pred = pred.t()
35 | correct = pred.eq(target.view(1, -1).expand_as(pred))
36 |
37 | res = []
38 | for k in topk:
39 | correct_k = correct[:k].reshape(-1).float().sum(0)
40 | res.append(correct_k.mul_(100.0 / batch_size))
41 | return res
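42 |
43 |
44 | # Typical use of the two helpers above during validation (editor's sketch with
45 | # fabricated tensors):
46 | if __name__ == '__main__':
47 |     import torch
48 |     top1 = AverageMeter()
49 |     output = torch.randn(8, 174)  # logits for a batch of 8 clips, 174 classes
50 |     target = torch.randint(0, 174, (8,))
51 |     prec1, prec5 = accuracy(output, target, topk=(1, 5))
52 |     top1.update(prec1.item(), n=output.size(0))
53 |     print('top1 avg: {:.2f}%, top5: {:.2f}%'.format(top1.avg, prec5.item()))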
--------------------------------------------------------------------------------