├── README.md
├── TCM.py
├── imgs
│   └── TCM.jpg
├── main_something.py
├── models.py
├── ops
│   ├── __init__.py
│   ├── basic_ops.py
│   ├── dataset.py
│   ├── resnet.py
│   ├── transforms.py
│   ├── utils.py
│   └── video_dataset.py
├── opts.py
├── requirements.txt
├── resnet_TSM.py
└── tsm_util.py

/README.md:
--------------------------------------------------------------------------------
1 | # Temporal Correlation Module
2 | ## Overview
3 | We release the PyTorch code of the Temporal Correlation Module (TCM).
4 | ![tcm](imgs/TCM.jpg)
5 |
6 | ## Requirements
7 | * Python >= 3.6
8 | * PyTorch >= 1.2 with CUDA
9 | * [Pytorch-Correlation-extension](https://github.com/ClementPinard/Pytorch-Correlation-extension)
10 | * pip install -r requirements.txt
11 |
12 | ## Data Preparation
13 | Data preparation is the same as for [TSM](https://github.com/mit-han-lab/temporal-shift-module); please refer to the [TSM](https://github.com/mit-han-lab/temporal-shift-module) repo for a detailed guide to data pre-processing.
14 |
15 | ## Training
16 |
17 | For training TCM with a TSM-ResNet50 backbone on Something-Something V1 with 8-segment input (4 NVIDIA 2080 Ti GPUs were used here):
18 | ```
19 | export CUDA_VISIBLE_DEVICES=0,1,2,3
20 | python -u main_something.py something RGB \
21 | "/path/to/somethingV1/train_videofolder.txt" "/path/to/somethingV1/val_videofolder_latest.txt" \
22 | --arch "TCM_resnet50" --num_segments 8 --mode 1 --gd 200 --lr 0.01 --lr_steps 30 40 45 --epochs 50 -b 32 -i 1 -j 4 --dropout 0.5 \
23 | --snapshot_pref logs --consensus_type avg --eval-freq 1 --rgb_prefix "" \
24 | --no_partialbn --val_output_folder logs -p 20 --nesterov "True" \
25 | 2>&1 | tee -a logs/log.txt
26 | ```
27 |
28 | **Note that the learning rate should scale linearly with the batch size: if you increase the batch size from 32 to 64, the corresponding learning rate should be 0.01\*(64/32)=0.02.**
29 |
30 | ## Testing
31 | For testing on Something-Something V1 with 8-segment input:
32 | ```
33 | python -u main_something.py something RGB \
34 | "/path/to/somethingV1/train_videofolder.txt" "/path/to/somethingV1/val_videofolder_latest.txt" \
35 | --arch "TCM_resnet50" --num_segments 8 --mode 1 --gd 200 --lr 0.01 --lr_steps 30 40 45 --epochs 50 -b 32 -i 1 -j 4 --dropout 0.5 \
36 | --snapshot_pref logs --consensus_type avg --eval-freq 1 --rgb_prefix "" \
37 | --no_partialbn --val_output_folder logs -p 20 --nesterov "True" \
38 | --evaluate --resume "/PATH/TO/SAVED/WEIGHT/FILE" \
39 | 2>&1 | tee -a logs/log.txt
40 | ```
41 |
42 |
43 | ## Pre-trained Models
44 | Since we have reorganized the code structure and renamed the TCM modules for public release, the old checkpoints cannot be loaded under the new names. We plan to retrain the models with the new code and release them for evaluation.
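The released checkpoints follow the format written by `save_checkpoint()` in `main_something.py` (`epoch`, `best_prec1`, and a `state_dict` saved from the `DataParallel`-wrapped model). Below is a minimal sketch for inspecting such a checkpoint, assuming a local copy of the release asset; the `TSN` construction is abbreviated to a comment and is not a full recipe:

```python
import torch

# Assumed local path to the release asset listed in the table below.
ckpt = torch.load("sthv1_8f_best.pth", map_location="cpu")

print(ckpt["epoch"], ckpt["best_prec1"])   # training metadata stored by save_checkpoint()
state_dict = ckpt["state_dict"]            # keys carry the "module." prefix (DataParallel)

# model = torch.nn.DataParallel(TSN(174, 8, "RGB", base_model="TCM_resnet50", ...))
# model.load_state_dict(state_dict)
```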
45 | 46 | Currently available pre-trained models: 47 | 48 | | Model | Backbone | Pretrained | Input | Dataset | Top-1val | pth | 49 | | ------- | ------------ | ---------- | ----- | ------------ | ------------- | ------------------------------------------------------------ | 50 | | TCM-R50 | TSM-ResNet50 | ImageNet | 8 | Something V1 | 52.2 | [sthv1_8f_best.pth](https://github.com/zphyix/TCM/releases/download/v1/sthv1_8f_best.pth) | 51 | 52 | -------------------------------------------------------------------------------- /TCM.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from spatial_correlation_sampler import SpatialCorrelationSampler 5 | import ipdb 6 | 7 | 8 | class TCM(nn.Module): # # Multi-scale Temporal Dynamics Module 9 | 10 | def __init__(self, num_segments, expansion = 1, pos=2): 11 | super(TCM, self).__init__() 12 | self.num_segments = num_segments 13 | self.mtdm = MTDM(num_segments, expansion=expansion, pos=pos) 14 | self.tam = TAM(num_segments=num_segments, expansion=expansion, pos=pos) 15 | 16 | def forward(self,x): 17 | out = self.mtdm(x) 18 | out = self.tam(out) 19 | return x + out 20 | 21 | 22 | class TAM(nn.Module): # Temporal Attention Module 23 | def __init__(self, num_segments, expansion = 1, pos=2): 24 | super(TAM, self).__init__() 25 | self.num_segments = num_segments 26 | self.expansion = expansion 27 | self.pos = pos 28 | self.out_channel = 64*(2**(self.pos-1))*self.expansion 29 | self.c1 = 16 30 | self.c2 = 32 31 | self.c3 = 64 32 | 33 | self.conv1 = nn.Sequential( 34 | nn.Conv2d(6, 6, kernel_size=3, stride=1, padding=1, groups=3, bias=False), 35 | nn.BatchNorm2d(6), 36 | nn.ReLU(), 37 | nn.Conv2d(6, 6, kernel_size=3, stride=1, padding=1, groups=3, bias=False), 38 | nn.BatchNorm2d(6), 39 | nn.ReLU(), 40 | nn.Conv2d(6, 6, kernel_size=3, stride=1, padding=1, groups=3, bias=False), 41 | nn.BatchNorm2d(6), 42 | nn.ReLU(), 43 | nn.Conv2d(6, self.c1, kernel_size=1, stride=1, padding=0, bias=False), 44 | nn.BatchNorm2d(self.c1), 45 | nn.ReLU() 46 | ) 47 | 48 | self.conv2 = nn.Sequential( 49 | nn.Conv2d(self.c1, self.c1, kernel_size=3, stride=1, padding=1, groups=self.c1, bias=False), 50 | nn.BatchNorm2d(self.c1), 51 | nn.ReLU(), 52 | nn.Conv2d(self.c1, self.c2, kernel_size=1, stride=1, padding=0, bias=False), 53 | nn.BatchNorm2d(self.c2), 54 | nn.ReLU() 55 | ) 56 | self.conv3 = nn.Sequential( 57 | nn.Conv2d(self.c2, self.c2, kernel_size=3, stride=1, padding=1, groups=self.c2, bias=False), 58 | nn.BatchNorm2d(self.c2), 59 | nn.ReLU(), 60 | nn.Conv2d(self.c2, self.c3, kernel_size=1, stride=1, padding=0, bias=False), 61 | nn.BatchNorm2d(self.c3), 62 | nn.ReLU() 63 | ) 64 | self.conv4 = nn.Sequential( 65 | nn.Conv2d(self.c3, self.c3, kernel_size=3, stride=1, padding=1, groups=self.c3, bias=False), 66 | nn.BatchNorm2d(self.c3), 67 | nn.ReLU(), 68 | nn.Conv2d(self.c3, self.out_channel, kernel_size=1, stride=1, padding=0, bias=False), 69 | nn.BatchNorm2d(self.out_channel), 70 | nn.ReLU() 71 | ) 72 | 73 | k_size = int(math.log(num_segments, 2)) 74 | if ( k_size & 1) == 0: # is odd number 75 | k_size = k_size + 1 76 | 77 | self.ETA = eca_layer(self.num_segments, k_size=k_size) # efficient temporal attention 78 | 79 | 80 | def forward(self, x): 81 | x = self.conv1(x) 82 | x = self.conv2(x) 83 | x = self.conv3(x) 84 | x = self.conv4(x) 85 | 86 | # temporal efficient channle attention 87 | x = x.view((-1, self.num_segments) + x.size()[1:]) # N T C H W 88 | N,T,C,H,W = x.size() 89 | 
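# NOTE (editorial comment added for clarity; not part of the original source):
# TAM ends by applying channel attention (eca_layer) along the *temporal* axis:
# the code below permutes and reshapes the tensor so that T plays the role of
# the channel dimension, which is why the attention module is constructed as
# eca_layer(self.num_segments, k_size=k_size) above. The 1D-conv kernel size is
# derived from log2(num_segments) and bumped to the next odd value when it
# comes out even, e.g. (under this reading of the code):
#   num_segments = 8  -> log2 = 3 (odd)  -> k_size = 3
#   num_segments = 16 -> log2 = 4 (even) -> k_size = 5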
x = x.permute(0,2,1,3,4).contiguous() # N C T H W 90 | x = x.view(-1, T, H, W) # NC,T,H,W 91 | x = self.ETA(x) 92 | x = x.view(N,C,T,H,W).permute(0,2,1,3,4).contiguous() # N,T,C,H,W 93 | x = x.view(-1, C, H, W) 94 | return x 95 | 96 | 97 | class MTDM(nn.Module): # # Multi-scale Temporal Dynamics Module 98 | 99 | def __init__(self, num_segments, expansion = 1, pos=2): 100 | super(MTDM, self).__init__() 101 | patchs = [15, 15, 7, 3] 102 | self.patch = patchs[pos-1] 103 | self.patch_dilation = 1 104 | self.soft_argmax = nn.Softmax(dim=1) 105 | self.expansion = expansion 106 | self.num_segments = num_segments 107 | 108 | self.chnl_reduction = nn.Sequential( 109 | nn.Conv2d(128*self.expansion, 64, kernel_size=1, stride=1, padding=0, bias=False), 110 | nn.BatchNorm2d(64), 111 | nn.ReLU(inplace=True) 112 | ) 113 | 114 | self.matching_layer = Matching_layer(ks=1, patch=self.patch, stride=1, pad=0, patch_dilation=self.patch_dilation) 115 | 116 | def L2normalize(self, x, d=1): 117 | eps = 1e-6 118 | norm = x ** 2 119 | norm = norm.sum(dim=d, keepdim=True) + eps 120 | norm = norm ** (0.5) 121 | return (x / norm) 122 | 123 | def apply_binary_kernel(self, match, h, w, region): 124 | # binary kernel 125 | x_line = torch.arange(w, dtype=torch.float).to('cuda').detach() 126 | y_line = torch.arange(h, dtype=torch.float).to('cuda').detach() 127 | x_kernel_1 = x_line.view(1,1,1,1,w).expand(1,1,w,h,w).to('cuda').detach() 128 | y_kernel_1 = y_line.view(1,1,1,h,1).expand(1,h,1,h,w).to('cuda').detach() 129 | x_kernel_2 = x_line.view(1,1,w,1,1).expand(1,1,w,h,w).to('cuda').detach() 130 | y_kernel_2 = y_line.view(1,h,1,1,1).expand(1,h,1,h,w).to('cuda').detach() 131 | 132 | ones = torch.ones(1).to('cuda').detach() 133 | zeros = torch.zeros(1).to('cuda').detach() 134 | 135 | eps = 1e-6 136 | kx = torch.where(torch.abs(x_kernel_1 - x_kernel_2)<=region, ones, zeros).to('cuda').detach() 137 | ky = torch.where(torch.abs(y_kernel_1 - y_kernel_2)<=region, ones, zeros).to('cuda').detach() 138 | kernel = kx * ky + eps 139 | kernel = kernel.view(1,h*w,h*w).to('cuda').detach() 140 | return match* kernel 141 | 142 | 143 | def apply_gaussian_kernel(self, corr, h,w,p, sigma=5): 144 | b, c, s = corr.size() 145 | 146 | x = torch.arange(p, dtype=torch.float).to('cuda').detach() 147 | y = torch.arange(p, dtype=torch.float).to('cuda').detach() 148 | 149 | idx = corr.max(dim=1)[1] # b x hw get maximum value along channel 150 | idx_y = (idx // p).view(b, 1, 1, h, w).float() 151 | idx_x = (idx % p).view(b, 1, 1, h, w).float() 152 | 153 | x = x.view(1,1,p,1,1).expand(1, 1, p, h, w).to('cuda').detach() 154 | y = y.view(1,p,1,1,1).expand(1, p, 1, h, w).to('cuda').detach() 155 | 156 | gauss_kernel = torch.exp(-((x-idx_x)**2 + (y-idx_y)**2) / (2 * sigma**2)) 157 | gauss_kernel = gauss_kernel.view(b, p*p, h*w)#.permute(0,2,1).contiguous() 158 | 159 | return gauss_kernel * corr 160 | 161 | def match_to_flow_soft(self, match, k, h,w, temperature=1, mode='softmax'): 162 | b, c , s = match.size() 163 | idx = torch.arange(h*w, dtype=torch.float32).to('cuda') 164 | idx_x = idx % w 165 | idx_x = idx_x.repeat(b,k,1).to('cuda') 166 | idx_y = torch.floor(idx / w) 167 | idx_y = idx_y.repeat(b,k,1).to('cuda') 168 | 169 | soft_idx_x = idx_x[:,:1] 170 | soft_idx_y = idx_y[:,:1] 171 | displacement = (self.patch-1)/2 172 | 173 | topk_value, topk_idx = torch.topk(match, k, dim=1) # (B*T-1, k, H*W) 174 | topk_value = topk_value.view(-1,k,h,w) 175 | 176 | match = self.apply_gaussian_kernel(match, h, w, self.patch, sigma=5) 177 | match = match*temperature 178 | 
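# NOTE (editorial comment added for clarity; not part of the original source):
# The code around this point implements a soft-argmax over the correlation map:
# apply_gaussian_kernel() above has already suppressed responses far from the
# hard argmax, the softmax below turns each (patch x patch) correlation column
# into a probability map, and the expected x/y offsets under that map give a
# sub-pixel displacement, finally normalised to roughly [-1, 1] by dividing by
# patch_dilation * displacement. Rough sketch of the idea for one position
# (illustrative only):
#   prob   = softmax(corr * temperature)   # shape (patch*patch,)
#   flow_x = sum(prob * x_offsets)         # expectation over candidate offsets
#   flow_y = sum(prob * y_offsets)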
match_pre = self.soft_argmax(match) 179 | smax = match_pre 180 | smax = smax.view(b,self.patch,self.patch,h,w) 181 | x_kernel = torch.arange(-displacement*self.patch_dilation, displacement*self.patch_dilation+1, step=self.patch_dilation, dtype=torch.float).to('cuda') 182 | y_kernel = torch.arange(-displacement*self.patch_dilation, displacement*self.patch_dilation+1, step=self.patch_dilation, dtype=torch.float).to('cuda') 183 | x_mult = x_kernel.expand(b,self.patch).view(b,self.patch,1,1) 184 | y_mult = y_kernel.expand(b,self.patch).view(b,self.patch,1,1) 185 | 186 | smax_x = smax.sum(dim=1, keepdim=False) #(b,w=k,h,w) 187 | smax_y = smax.sum(dim=2, keepdim=False) #(b,h=k,h,w) 188 | flow_x = (smax_x*x_mult).sum(dim=1, keepdim=True).view(-1,1,h*w) # (b,1,h,w) 189 | flow_y = (smax_y*y_mult).sum(dim=1, keepdim=True).view(-1,1,h*w) # (b,1,h,w) 190 | 191 | flow_x = (flow_x / (self.patch_dilation * displacement)) 192 | flow_y = (flow_y / (self.patch_dilation * displacement)) 193 | 194 | return flow_x, flow_y, topk_value 195 | 196 | def flow_computation(self, x, pos=0, temperature=100): 197 | 198 | size = x.size() 199 | x = x.view((-1, self.num_segments) + size[1:]) # N T C H W 200 | x = x.permute(0,2,1,3,4).contiguous() # B C T H W 201 | 202 | # match to flow 203 | k = 1 204 | b,c,t,h,w = x.size() 205 | t = t-1 206 | 207 | if pos == 0: 208 | x_pre = x[:,:,0,:].unsqueeze(dim=2).expand((b,c,t,h,w)).permute(0,2,1,3,4).contiguous().view(-1,c,h,w) 209 | else: 210 | x_pre = x[:,:,:-1].permute(0,2,1,3,4).contiguous().view(-1,c,h,w) 211 | 212 | #x_pre = x[:,:,0,:].unsqueeze(dim=2).expand((b,c,t-1,h,w)) 213 | x_post = x[:,:,1:].permute(0,2,1,3,4).contiguous().view(-1,c,h,w) 214 | 215 | match = self.matching_layer(x_pre, x_post) # (B*T-1*group, H*W, H*W) 216 | u, v, confidence = self.match_to_flow_soft(match, k, h, w, temperature) 217 | flow = torch.cat([u,v], dim=1).view(-1, 2*k, h, w) # (b, 2, h, w) 218 | 219 | return flow, confidence 220 | 221 | def forward(self,x): 222 | # multi-scale temporal action feature 223 | x_redu = self.chnl_reduction(x) 224 | flow_1, match_v1 = self.flow_computation(x_redu, pos=1) 225 | flow_2, match_v2 = self.flow_computation(x_redu, pos=0) 226 | 227 | x1 = torch.cat([flow_1, match_v1], dim=1) 228 | x2 = torch.cat([flow_2, match_v2], dim=1) 229 | 230 | _, c, h, w = x1.size() 231 | x1 = x1.view(-1,self.num_segments-1,c,h,w) 232 | x2 = x2.view(-1,self.num_segments-1,c,h,w) 233 | 234 | x1 = torch.cat([x1,x1[:,-1:,:,:,:]], dim=1) ## (b,t,3,h,w) 235 | x2 = torch.cat([x2,x2[:,-1:,:,:,:]], dim=1) ## (b,t,3,h,w) 236 | 237 | out = torch.cat([x1,x2], dim=2) 238 | out = out.view(-1,2*c,h,w) 239 | return out 240 | 241 | class Matching_layer(nn.Module): 242 | def __init__(self, ks, patch, stride, pad, patch_dilation): 243 | super(Matching_layer, self).__init__() 244 | self.relu = nn.ReLU() 245 | self.patch = patch 246 | self.correlation_sampler = SpatialCorrelationSampler(ks, patch, stride, pad, patch_dilation) 247 | 248 | def L2normalize(self, x, d=1): 249 | eps = 1e-6 250 | norm = x ** 2 251 | norm = norm.sum(dim=d, keepdim=True) + eps 252 | norm = norm ** (0.5) 253 | return (x / norm) 254 | 255 | def forward(self, feature1, feature2): 256 | feature1 = self.L2normalize(feature1) 257 | feature2 = self.L2normalize(feature2) 258 | b, c, h1, w1 = feature1.size() 259 | b, c, h2, w2 = feature2.size() 260 | corr = self.correlation_sampler(feature1, feature2) 261 | corr = corr.view(b, self.patch * self.patch, h1* w1) # Channel : target // Spatial grid : source 262 | corr = self.relu(corr) 263 | 
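# NOTE (editorial comment added for clarity; not part of the original source):
# Shape bookkeeping for MTDM.forward() above, under this reading of the code:
# each flow_computation() call yields a 2-channel flow plus a 1-channel
# confidence map per frame pair, i.e. 3 channels. The map for the last frame
# pair is duplicated along the time axis so that each temporal range
# (adjacent-frame and first-frame-anchored) covers all T segments, and the two
# are concatenated to 2*3 = 6 channels -- matching the in_channels=6 of
# TAM.conv1 that consumes MTDM's output inside TCM.forward().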
return corr 264 | 265 | 266 | class eca_layer(nn.Module): 267 | """Constructs a ECA module. 268 | 269 | Args: 270 | channel: Number of channels of the input feature map 271 | k_size: Adaptive selection of kernel size 272 | """ 273 | def __init__(self, channel, k_size=3): 274 | super(eca_layer, self).__init__() 275 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 276 | self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) 277 | self.sigmoid = nn.Sigmoid() 278 | 279 | def forward(self, x): 280 | # x: input features with shape [b, c, h, w] 281 | b, c, h, w = x.size() 282 | 283 | # feature descriptor on the global spatial information 284 | y = self.avg_pool(x) 285 | 286 | # Two different branches of ECA module 287 | y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) 288 | 289 | # Multi-scale information fusion 290 | y = self.sigmoid(y) 291 | 292 | return x * y.expand_as(x) -------------------------------------------------------------------------------- /imgs/TCM.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzfly/TCM/34685f7c86915228b1849fd1d20c0a704aa66ab2/imgs/TCM.jpg -------------------------------------------------------------------------------- /main_something.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import shutil 5 | import torch 6 | import torchvision 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | from torch.nn.utils import clip_grad_norm_ 11 | 12 | from ops.dataset import TSNDataSet 13 | from ops.transforms import * 14 | 15 | from models import TSN 16 | from opts import parser 17 | import sys 18 | import torch.utils.model_zoo as model_zoo 19 | from torch.nn.init import constant_, xavier_uniform_ 20 | 21 | import ipdb 22 | 23 | #os.environ["CUDA_VISIBLE_DEVICES"]='0,1,2,3' 24 | best_prec1 = 0 25 | 26 | 27 | def main(): 28 | global args, best_prec1 29 | args = parser.parse_args() 30 | 31 | print("------------------------------------") 32 | print("Environment Versions:") 33 | print("- Python: {}".format(sys.version)) 34 | print("- PyTorch: {}".format(torch.__version__)) 35 | print("- TorchVison: {}".format(torchvision.__version__)) 36 | 37 | args_dict = args.__dict__ 38 | print("------------------------------------") 39 | print(args.arch+" Configurations:") 40 | for key in args_dict.keys(): 41 | print("- {}: {}".format(key, args_dict[key])) 42 | print("------------------------------------") 43 | print (args.mode) 44 | if args.dataset == 'ucf101': 45 | num_class = 101 46 | rgb_read_format = "{:05d}.jpg" 47 | elif args.dataset == 'hmdb51': 48 | num_class = 51 49 | rgb_read_format = "{:05d}.jpg" 50 | elif args.dataset == 'kinetics': 51 | num_class = 400 52 | rgb_read_format = "{:05d}.jpg" 53 | elif args.dataset == 'something': 54 | num_class = 174 55 | rgb_read_format = "{:05d}.jpg" 56 | elif args.dataset == 'somethingv2': 57 | num_class = 174 58 | rgb_read_format = "img_{:05d}.jpg" 59 | elif args.dataset == 'NTU_RGBD': 60 | num_class = 120 61 | rgb_read_format = "{:05d}.jpg" 62 | elif args.dataset == 'tinykinetics': 63 | num_class = 150 64 | rgb_read_format = "{:05d}.jpg" 65 | else: 66 | raise ValueError('Unknown dataset '+args.dataset) 67 | 68 | model = TSN(num_class, args.num_segments, args.modality, 69 | base_model=args.arch, 70 | consensus_type=args.consensus_type, dropout=args.dropout, partial_bn=not 
args.no_partialbn, non_local=args.non_local) 71 | 72 | crop_size = model.crop_size 73 | scale_size = model.scale_size 74 | input_mean = model.input_mean 75 | input_std = model.input_std 76 | # Optimizer s also support specifying per-parameter options. 77 | # To do this, pass in an iterable of dict s. 78 | # Each of them will define a separate parameter group, 79 | # and should contain a params key, containing a list of parameters belonging to it. 80 | # Other keys should match the keyword arguments accepted by the optimizers, 81 | # and will be used as optimization options for this group. 82 | policies = model.get_optim_policies(args.dataset) 83 | 84 | train_augmentation = model.get_augmentation() 85 | 86 | model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda() 87 | 88 | model_dict = model.state_dict() 89 | 90 | if args.arch == "resnet50": 91 | new_state_dict = {} #model_dict 92 | div = False 93 | roll = True 94 | elif args.arch == "resnet34": 95 | pretrained_dict={} 96 | new_state_dict = {} #model_dict 97 | for k, v in model_dict.items(): 98 | if ('fc' not in k): 99 | new_state_dict.update({k:v}) 100 | div = False 101 | roll = True 102 | elif (args.arch[:3] == "TCM" ): 103 | pretrained_dict={} 104 | new_state_dict = {} #model_dict 105 | for k, v in model_dict.items(): 106 | if ('fc' not in k): 107 | new_state_dict.update({k:v}) 108 | div = True 109 | roll = False 110 | 111 | 112 | if args.resume: 113 | if os.path.isfile(args.resume): 114 | print(("=> loading checkpoint '{}'".format(args.resume))) 115 | checkpoint = torch.load(args.resume) 116 | args.start_epoch = checkpoint['epoch'] 117 | best_prec1 = checkpoint['best_prec1'] 118 | model.load_state_dict(checkpoint['state_dict']) 119 | print(("=> loaded checkpoint '{}' (epoch {})" 120 | .format(args.resume, checkpoint['epoch']))) 121 | else: 122 | print(("=> no checkpoint found at '{}'".format(args.resume))) 123 | 124 | cudnn.benchmark = True 125 | 126 | # Data loading code 127 | if args.modality != 'RGBDiff': 128 | normalize = GroupNormalize(input_mean, input_std) 129 | else: 130 | normalize = IdentityTransform() 131 | 132 | if args.modality == 'RGB': 133 | data_length = 1 134 | elif args.modality in ['Flow', 'RGBDiff']: 135 | data_length = 1 136 | 137 | train_loader = torch.utils.data.DataLoader( 138 | TSNDataSet("", args.train_list, num_segments=args.num_segments, 139 | new_length=data_length, 140 | modality=args.modality, 141 | mode = args.mode, 142 | image_tmpl=args.rgb_prefix+rgb_read_format if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+rgb_read_format, 143 | img_start_idx=args.img_start_idx, 144 | transform=torchvision.transforms.Compose([ 145 | GroupScale((240,320)), 146 | # GroupScale(int(scale_size)), 147 | train_augmentation, 148 | Stack(roll=roll), 149 | ToTorchFormatTensor(div=div), 150 | normalize, 151 | ])), 152 | batch_size=args.batch_size, shuffle=True, 153 | num_workers=args.workers, pin_memory=True) 154 | 155 | val_loader = torch.utils.data.DataLoader( 156 | TSNDataSet("", args.val_list, num_segments=args.num_segments, 157 | new_length=data_length, 158 | modality=args.modality, 159 | mode =args.mode, 160 | image_tmpl=args.rgb_prefix+rgb_read_format if args.modality in ["RGB", "RGBDiff"] else args.flow_prefix+rgb_read_format, 161 | img_start_idx=args.img_start_idx, 162 | random_shift=False, 163 | transform=torchvision.transforms.Compose([ 164 | GroupScale((240,320)), 165 | # GroupScale((224)), 166 | # GroupScale(int(scale_size)), 167 | GroupCenterCrop(crop_size), 168 | Stack(roll=roll), 169 | 
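# NOTE (editorial comment added for clarity; not part of the original source):
# Validation pipeline in this Compose: each sampled frame is resized to
# 240x320, center-cropped to the model's crop_size (224), stacked along the
# channel axis, converted to a tensor (div/roll depend on the backbone chosen
# above), and finally normalised with the ImageNet mean/std -- the same order
# as the training pipeline except that GroupMultiScaleCrop and random flipping
# are replaced by a deterministic center crop.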
ToTorchFormatTensor(div=div), 170 | normalize, 171 | ])), 172 | batch_size=args.batch_size, shuffle=False, 173 | num_workers=args.workers, pin_memory=True) 174 | 175 | # define loss function (criterion) and optimizer 176 | if args.loss_type == 'nll': 177 | criterion = torch.nn.CrossEntropyLoss().cuda() 178 | 179 | else: 180 | raise ValueError("Unknown loss type") 181 | 182 | for group in policies: 183 | print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format( 184 | group['name'], len(group['params']), group['lr_mult'], group['decay_mult']))) 185 | 186 | optimizer = torch.optim.SGD(policies, 187 | args.lr, 188 | momentum=args.momentum, 189 | weight_decay=args.weight_decay,nesterov=args.nesterov) 190 | 191 | output_list = [] 192 | if args.evaluate: 193 | prec1, score_tensor = validate(val_loader,model,criterion,temperature=100) 194 | output_list.append(score_tensor) 195 | save_validation_score(output_list, filename='score.pt') 196 | print("validation score saved in {}".format('/'.join((args.val_output_folder, 'score_inf5.pt')))) 197 | return 198 | 199 | for epoch in range(args.start_epoch, args.epochs): 200 | adjust_learning_rate(optimizer, epoch, args.lr_steps) 201 | # train for one epoch 202 | temperature = train(train_loader, model, criterion, optimizer, epoch) 203 | 204 | # evaluate on validation set 205 | if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1: 206 | prec1, score_tensor = validate(val_loader, model, criterion, temperature=temperature) 207 | 208 | output_list.append(score_tensor) 209 | 210 | # remember best prec@1 and save checkpoint 211 | is_best = prec1 > best_prec1 212 | best_prec1 = max(prec1, best_prec1) 213 | 214 | output_best = 'Best Prec@1: %.3f\n' % (best_prec1) 215 | print(output_best) 216 | 217 | save_checkpoint({ 218 | 'epoch': epoch + 1, 219 | 'arch': args.arch, 220 | 'state_dict': model.state_dict(), 221 | 'best_prec1': best_prec1, 222 | }, is_best) 223 | 224 | # save validation score 225 | save_validation_score(output_list) 226 | print("validation score saved in {}".format('/'.join((args.val_output_folder, 'score.pt')))) 227 | 228 | 229 | def train(train_loader, model, criterion, optimizer, epoch): 230 | batch_time = AverageMeter() 231 | data_time = AverageMeter() 232 | losses = AverageMeter() 233 | top1 = AverageMeter() 234 | top5 = AverageMeter() 235 | 236 | 237 | # temperature 238 | increase = pow(1.05, epoch) 239 | temperature = 100 # * increase 240 | print (temperature) 241 | 242 | 243 | # In PyTorch 0.4, "volatile=True" is deprecated. 
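# NOTE (editorial comment added for clarity; not part of the original source):
# The SGD optimizer built earlier in main() takes the per-parameter "policies"
# returned by model.get_optim_policies(): each dict becomes one parameter
# group, and the extra 'lr_mult' / 'decay_mult' keys are carried along inside
# the group so that adjust_learning_rate() can later set
#   group['lr']           = base_lr * decay * group['lr_mult']
#   group['weight_decay'] = args.weight_decay * group['decay_mult']
# e.g. biases get lr_mult=2 with decay_mult=0 (no weight decay), BN scale/shift
# gets decay_mult=0, and the first conv gets a larger lr_mult for optical flow.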
244 | torch.set_grad_enabled(True) 245 | 246 | if args.no_partialbn: 247 | model.module.partialBN(False) 248 | else: 249 | model.module.partialBN(True) 250 | 251 | # switch to train mode 252 | model.train() 253 | 254 | end = time.time() 255 | for i, (input, target) in enumerate(train_loader): 256 | # discard final batch 257 | if i == len(train_loader)-1: 258 | break 259 | # measure data loading time 260 | data_time.update(time.time() - end) 261 | 262 | # target size: [batch_size] 263 | target = target.cuda() 264 | input_var = input 265 | target_var = target 266 | output = model(input_var, temperature) 267 | loss = criterion(output, target_var) 268 | 269 | # measure accuracy and record loss 270 | prec1, prec5 = accuracy(output.data, target, topk=(1,5)) 271 | losses.update(loss.item(), input.size(0)) 272 | top1.update(prec1.item(), input.size(0)) 273 | top5.update(prec5.item(), input.size(0)) 274 | 275 | # compute gradient and do SGD step 276 | loss.backward() 277 | 278 | if i % args.iter_size == 0: 279 | # scale down gradients when iter size is functioning 280 | if args.iter_size != 1: 281 | for g in optimizer.param_groups: 282 | for p in g['params']: 283 | p.grad /= args.iter_size 284 | 285 | if args.clip_gradient is not None: 286 | total_norm = clip_grad_norm_(model.parameters(), args.clip_gradient) 287 | if total_norm > args.clip_gradient: 288 | print("clipping gradient: {} with coef {}".format(total_norm, args.clip_gradient / total_norm)) 289 | else: 290 | total_norm = 0 291 | 292 | optimizer.step() 293 | optimizer.zero_grad() 294 | 295 | 296 | # measure elapsed time 297 | batch_time.update(time.time() - end) 298 | end = time.time() 299 | 300 | if i % args.print_freq == 0: 301 | print(('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t' 302 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 303 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 304 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 305 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 306 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 307 | epoch, i, len(train_loader), batch_time=batch_time, 308 | data_time=data_time, loss=losses, top1=top1, top5=top5, lr=optimizer.param_groups[-2]['lr']))) 309 | # print(('Flow_Con_Loss {loss.val:.4f} ({loss.avg:.4f})'.format(loss=flow_con_losses))) 310 | return temperature 311 | 312 | 313 | def validate(val_loader, model, criterion, temperature, logger=None): 314 | batch_time = AverageMeter() 315 | losses = AverageMeter() 316 | top1 = AverageMeter() 317 | top5 = AverageMeter() 318 | 319 | # another losses 320 | flow_con_losses = AverageMeter() 321 | 322 | # In PyTorch 0.4, "volatile=True" is deprecated. 
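# NOTE (editorial comment added for clarity; not part of the original source):
# In the training loop above, loss.backward() runs every iteration but
# optimizer.step() / zero_grad() only run every args.iter_size iterations, with
# the accumulated gradients divided by iter_size first. The effective batch
# size is therefore batch_size * iter_size (with -i 1 in the README commands
# this reduces to plain per-batch SGD), and gradient clipping via
# args.clip_gradient is applied to the accumulated gradient just before the
# step.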
323 | torch.set_grad_enabled(False) 324 | # torch.no_grad() 325 | # switch to evaluate mode 326 | model.eval() 327 | # model.train() 328 | 329 | output_list = [] 330 | pred_arr = [] 331 | target_arr = [] 332 | end = time.time() 333 | for i, (input, target) in enumerate(val_loader): 334 | # discard final batch 335 | if i == len(val_loader)-1: 336 | break 337 | target = target.cuda() 338 | 339 | input_var = input 340 | target_var = target 341 | 342 | output= model(input_var, temperature) 343 | loss = criterion(output, target_var) 344 | 345 | # class acc 346 | pred = torch.argmax(output.data, dim=1) 347 | pred_arr.extend(pred) 348 | target_arr.extend(target) 349 | 350 | # measure accuracy and record loss 351 | prec1, prec5 = accuracy(output.data, target, topk=(1,5)) 352 | 353 | losses.update(loss.item(), input.size(0)) 354 | top1.update(prec1.item(), input.size(0)) 355 | top5.update(prec5.item(), input.size(0)) 356 | 357 | # measure elapsed time 358 | batch_time.update(time.time() - end) 359 | end = time.time() 360 | 361 | output_list.append(output) 362 | 363 | if i % args.print_freq == 0: 364 | print(('Test: [{0}/{1}]\t' 365 | 'Time {batch_time.val:.3f} ({batch_time.avg:.4f})\t' 366 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 367 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 368 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 369 | i, len(val_loader), batch_time=batch_time, loss=losses, top1=top1, top5=top5))) 370 | # print(('Flow_Con_Loss {loss.val:.4f} ({loss.avg:.4f})'.format(loss=flow_con_losses))) 371 | output_tensor = torch.cat(output_list, dim=0) 372 | 373 | print(('Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f} Time {batch_time.avg:.4f}' 374 | .format(top1=top1, top5=top5, loss=losses, batch_time=batch_time))) 375 | return top1.avg, output_tensor 376 | 377 | def save_checkpoint(state, is_best, filename='latest.pth.tar'): 378 | filename = os.path.join(args.snapshot_pref, filename) 379 | torch.save(state, filename) 380 | if is_best: 381 | best_name = '_'.join((args.modality.lower(), 'model_best.pth.tar')) 382 | best_name = os.path.join(args.snapshot_pref, best_name) 383 | shutil.copyfile(filename, best_name) 384 | 385 | def save_validation_score(score, filename='score.pt'): 386 | filename = '/'.join((args.val_output_folder, filename)) 387 | torch.save(score, filename) 388 | 389 | class AverageMeter(object): 390 | """Computes and stores the average and current value""" 391 | def __init__(self): 392 | self.reset() 393 | 394 | def reset(self): 395 | self.val = 0 396 | self.avg = 0 397 | self.sum = 0 398 | self.count = 0 399 | 400 | def update(self, val, n=1): 401 | self.val = val 402 | self.sum += val * n 403 | self.count += n 404 | self.avg = self.sum / self.count 405 | 406 | 407 | def adjust_learning_rate(optimizer, epoch, lr_steps): 408 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 409 | decay = 0.1 ** (sum(epoch >= np.array(lr_steps))) 410 | lr = args.lr * decay 411 | decay = args.weight_decay 412 | for param_group in optimizer.param_groups: 413 | param_group['lr'] = lr * param_group['lr_mult'] 414 | param_group['weight_decay'] = decay * param_group['decay_mult'] 415 | 416 | 417 | def accuracy(output, target, topk=(1,)): 418 | """Computes the precision@k for the specified values of k""" 419 | maxk = max(topk) 420 | batch_size = target.size(0) 421 | 422 | _, pred = output.topk(maxk, 1, True, True) 423 | pred = pred.t() 424 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 425 | res = [] 426 | for k in topk: 
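# NOTE (editorial comment added for clarity; not part of the original source):
# correct has shape (maxk, batch): row j marks whether the (j+1)-th ranked
# prediction equals the target. Summing correct[:k] therefore counts the
# samples whose target appears anywhere in the top-k predictions, which the
# line below converts to a percentage of the batch.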
427 | correct_k = correct[:k].reshape(-1).float().sum(0) 428 | res.append(correct_k.mul_(100.0 / batch_size)) 429 | return res 430 | 431 | 432 | if __name__ == '__main__': 433 | main() 434 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from ops.basic_ops import ConsensusModule, Identity 4 | from ops.transforms import * 5 | from torch.nn.init import xavier_uniform_, constant_ 6 | 7 | 8 | class TSN(nn.Module): 9 | def __init__(self, num_class, num_segments, modality, 10 | base_model='resnet101', dataset='something', new_length=None, 11 | consensus_type='avg', before_softmax=True, 12 | dropout=0.8,fc_lr5=True, 13 | crop_num=1, partial_bn=True, non_local=False): 14 | super(TSN, self).__init__() 15 | self.modality = modality 16 | self.num_segments = num_segments 17 | self.reshape = True 18 | self.before_softmax = before_softmax 19 | self.dropout = dropout 20 | self.crop_num = crop_num 21 | self.consensus_type = consensus_type 22 | self.base_model_name = base_model 23 | self.dataset = dataset 24 | self.fc_lr5 = fc_lr5 25 | self.non_local = non_local 26 | if not before_softmax and consensus_type != 'avg': 27 | raise ValueError("Only avg consensus can be used after Softmax") 28 | 29 | if new_length is None: 30 | self.new_length = 1 if modality == "RGB" else 1 31 | else: 32 | self.new_length = new_length 33 | 34 | print((""" 35 | Initializing TSN with base model: {}. 36 | TSN Configurations: 37 | input_modality: {} 38 | num_segments: {} 39 | new_length: {} 40 | consensus_module: {} 41 | dropout_ratio: {} 42 | """.format(base_model, self.modality, self.num_segments, self.new_length, consensus_type, self.dropout))) 43 | 44 | if (base_model[:3] == 'TCM'): 45 | if 'resnet50' in base_model: 46 | from resnet_TSM import resnet50 as resnet 47 | self.base_model = resnet(True, shift='TSM', num_segments = self.num_segments, enable_TCM = 1) 48 | print("Backbone: resnet50") 49 | elif 'resnet101' in base_model: 50 | from resnet_TSM import resnet101 as resnet 51 | self.base_model = resnet(True, shift='TSM', num_segments = self.num_segments, enable_TCM = 1) 52 | print("Backbone: resnet101") 53 | else: 54 | raise ValueError('Unknown base model: {}'.format(base_model)) 55 | 56 | if self.non_local: 57 | print('Adding non-local module...') 58 | from non_local import make_non_local_resnet50_layer4 as make_non_local 59 | make_non_local(self.base_model, self.num_segments) 60 | 61 | self.base_model.last_layer_name = 'fc1' 62 | self.input_size = 224 63 | self.input_mean = [0.485, 0.456, 0.406] 64 | self.input_std = [0.229, 0.224, 0.225] 65 | feature_dim = self._prepare_tsn(num_class) 66 | else: 67 | self._prepare_base_model(base_model) 68 | feature_dim = self._prepare_tsn(num_class) 69 | 70 | if self.modality == 'Flow': 71 | print("Converting the ImageNet model to a flow init model") 72 | self.base_model = self._construct_flow_model(self.base_model) 73 | print("Done. Flow model ready...") 74 | elif self.modality == 'RGBDiff': 75 | print("Converting the ImageNet model to RGB+Diff init model") 76 | self.base_model = self._construct_diff_model(self.base_model) 77 | print("Done. 
RGBDiff model ready.") 78 | 79 | self.consensus = ConsensusModule(consensus_type) 80 | 81 | if not self.before_softmax: 82 | self.softmax = nn.Softmax() 83 | 84 | self._enable_pbn = partial_bn 85 | if partial_bn: 86 | self.partialBN(True) 87 | 88 | def _prepare_tsn(self, num_class): 89 | feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_channels 90 | if self.dropout == 0: 91 | setattr(self.base_model, self.base_model.last_layer_name, nn.Conv1d(feature_dim, num_class, kernel_size=1, stride=1, padding=0,bias=True)) 92 | self.new_fc = None 93 | else: 94 | setattr(self.base_model, self.base_model.last_layer_name, nn.Dropout(p=self.dropout)) 95 | self.new_fc = nn.Conv1d(feature_dim, num_class, kernel_size=1, stride=1, padding=0,bias=True) 96 | 97 | std = 0.001 98 | 99 | if self.new_fc is None: 100 | xavier_uniform_(getattr(self.base_model, self.base_model.last_layer_name).weight) 101 | constant_(getattr(self.base_model, self.base_model.last_layer_name).bias, 0) 102 | else: 103 | xavier_uniform_(self.new_fc.weight) 104 | constant_(self.new_fc.bias, 0) 105 | 106 | return feature_dim 107 | 108 | def _prepare_base_model(self, base_model): 109 | 110 | if 'resnet' in base_model: 111 | self.base_model = getattr(torchvision.models, base_model)(True) 112 | self.base_model.last_layer_name = 'fc' 113 | self.input_size = 224 114 | self.input_mean = [0.485, 0.456, 0.406] 115 | self.input_std = [0.229, 0.224, 0.225] 116 | 117 | if self.modality == 'Flow': 118 | self.input_mean = [0.5] 119 | self.input_std = [np.mean(self.input_std)] 120 | elif self.modality == 'RGBDiff': 121 | self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length 122 | self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length 123 | else: 124 | raise ValueError('Unknown base model: {}'.format(base_model)) 125 | 126 | def train(self, mode=True): 127 | """ 128 | Override the default train() to freeze the BN parameters 129 | :return: 130 | """ 131 | super(TSN, self).train(mode) 132 | count = 0 133 | if self._enable_pbn: 134 | print("Freezing BatchNorm2D except the first one.") 135 | for m in self.base_model.modules(): 136 | if isinstance(m, nn.BatchNorm2d): 137 | count += 1 138 | if count >= (2 if self._enable_pbn else 1): 139 | m.eval() 140 | 141 | # shutdown update in frozen mode 142 | m.weight.requires_grad = False 143 | m.bias.requires_grad = False 144 | else: 145 | print("No BN layer Freezing.") 146 | 147 | def partialBN(self, enable): 148 | self._enable_pbn = enable 149 | 150 | def get_optim_policies(self, dataset): 151 | first_conv_weight = [] 152 | first_conv_bias = [] 153 | normal_weight = [] 154 | normal_bias = [] 155 | lr5_weight = [] 156 | lr10_bias = [] 157 | bn = [] 158 | custom_ops = [] 159 | 160 | conv_cnt = 0 161 | bn_cnt = 0 162 | for m in self.modules(): 163 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv3d): 164 | ps = list(m.parameters()) 165 | conv_cnt += 1 166 | if conv_cnt == 1: 167 | first_conv_weight.append(ps[0]) 168 | if len(ps) == 2: 169 | first_conv_bias.append(ps[1]) 170 | else: 171 | normal_weight.append(ps[0]) 172 | if len(ps) == 2: 173 | normal_bias.append(ps[1]) 174 | elif isinstance(m, torch.nn.Linear): 175 | ps = list(m.parameters()) 176 | if self.fc_lr5: 177 | lr5_weight.append(ps[0]) 178 | else: 179 | normal_weight.append(ps[0]) 180 | if len(ps) == 2: 181 | if self.fc_lr5: 182 | lr10_bias.append(ps[1]) 183 | else: 184 | normal_bias.append(ps[1]) 185 | 186 | elif isinstance(m, 
torch.nn.BatchNorm2d): 187 | bn_cnt += 1 188 | # later BN's are frozen 189 | if not self._enable_pbn or bn_cnt == 1: 190 | bn.extend(list(m.parameters())) 191 | elif isinstance(m, torch.nn.BatchNorm3d): 192 | bn_cnt += 1 193 | # later BN's are frozen 194 | if not self._enable_pbn or bn_cnt == 1: 195 | bn.extend(list(m.parameters())) 196 | elif len(m._modules) == 0: 197 | if len(list(m.parameters())) > 0: 198 | raise ValueError("New atomic module type: {}. Need to give it a learning policy".format(type(m))) 199 | 200 | return [ 201 | {'params': first_conv_weight, 'lr_mult': 5 if self.modality == 'Flow' else 1, 'decay_mult': 1, 202 | 'name': "first_conv_weight"}, 203 | {'params': first_conv_bias, 'lr_mult': 10 if self.modality == 'Flow' else 2, 'decay_mult': 0, 204 | 'name': "first_conv_bias"}, 205 | {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1, 206 | 'name': "normal_weight"}, 207 | {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0, 208 | 'name': "normal_bias"}, 209 | {'params': bn, 'lr_mult': 1, 'decay_mult': 0, 210 | 'name': "BN scale/shift"}, 211 | {'params': custom_ops, 'lr_mult': 1, 'decay_mult': 1, 212 | 'name': "custom_ops"}, 213 | # for fc 214 | {'params': lr5_weight, 'lr_mult': 5 if self.dataset == 'kinetics' else 1, 'decay_mult': 1, 215 | 'name': "lr5_weight"}, 216 | {'params': lr10_bias, 'lr_mult': 10 if self.dataset == 'kinetics' else 2, 'decay_mult': 0, 217 | 'name': "lr10_bias"}, 218 | ] 219 | 220 | 221 | def get_optim_policies_BN2to1D(self): 222 | first_conv_weight = [] 223 | first_conv_bias = [] 224 | normal_weight = [] 225 | normal_bias = [] 226 | bn = [] 227 | last_conv_weight = [] 228 | last_conv_bias = [] 229 | 230 | conv_cnt = 0 231 | bn_cnt = 0 232 | for m in self.modules(): 233 | # (conv1d or conv2d) 1st layer's params will be append to list: first_conv_weight & first_conv_bias, total num 1 respectively(1 conv2d) 234 | # (conv1d or conv2d or Linear) from 2nd layers' params will be append to list: normal_weight & normal_bias, total num 69 respectively(68 Conv2d + 1 Linear) 235 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv1d): 236 | ps = list(m.parameters()) 237 | conv_cnt += 1 238 | if conv_cnt == 1: 239 | first_conv_weight.append(ps[0]) 240 | if len(ps) == 2: 241 | first_conv_bias.append(ps[1]) 242 | else: 243 | normal_weight.append(ps[0]) 244 | if len(ps) == 2: 245 | normal_bias.append(ps[1]) 246 | elif isinstance(m, torch.nn.Conv3d): 247 | ps = list(m.parameters()) 248 | last_conv_weight.append(ps[0]) 249 | if len(ps) == 2: 250 | last_conv_bias.append(ps[1]) 251 | elif isinstance(m, torch.nn.Linear): 252 | ps = list(m.parameters()) 253 | normal_weight.append(ps[0]) 254 | if len(ps) == 2: 255 | normal_bias.append(ps[1]) 256 | # (BatchNorm1d or BatchNorm2d) params will be append to list: bn, total num 2 (enabled pbn, so only: 1st BN layer's weight + 1st BN layer's bias) 257 | elif isinstance(m, torch.nn.BatchNorm1d): 258 | bn.extend(list(m.parameters())) 259 | elif isinstance(m, torch.nn.BatchNorm2d): 260 | bn_cnt += 1 261 | # later BN's are frozen 262 | if not self._enable_pbn or bn_cnt == 1: 263 | bn.extend(list(m.parameters())) 264 | elif isinstance(m, torch.nn.BatchNorm3d): 265 | bn_cnt += 1 266 | # 4 267 | # later BN's are frozen 268 | if not self._enable_pbn or bn_cnt == 1: 269 | bn.extend(list(m.parameters())) 270 | elif len(m._modules) == 0: 271 | if len(list(m.parameters())) > 0: 272 | raise ValueError("New atomic module type: {}. 
Need to give it a learning policy".format(type(m))) 273 | return [ 274 | {'params': first_conv_weight, 'lr_mult': 5 if self.modality == 'Flow' else 1, 'decay_mult': 1, 275 | 'name': "first_conv_weight"}, 276 | {'params': first_conv_bias, 'lr_mult': 10 if self.modality == 'Flow' else 2, 'decay_mult': 0, 277 | 'name': "first_conv_bias"}, 278 | {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1, 279 | 'name': "normal_weight"}, 280 | {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0, 281 | 'name': "normal_bias"}, 282 | {'params': last_conv_weight, 'lr_mult': 5, 'decay_mult': 1, 283 | 'name': "last_conv_weight"}, 284 | {'params': last_conv_bias, 'lr_mult': 10, 'decay_mult': 0, 285 | 'name': "last_conv_bias"}, 286 | {'params': bn, 'lr_mult': 1, 'decay_mult': 0, 287 | 'name': "BN scale/shift"}, 288 | ] 289 | 290 | def forward(self, input, temperature): 291 | sample_len = (3 if self.modality == "RGB" else 2) * self.new_length 292 | 293 | if self.modality == 'RGBDiff': 294 | sample_len = 3 * self.new_length 295 | input = self._get_diff(input) 296 | 297 | # input.size(): [32, 9, 224, 224] 298 | # after view() func: [96, 3, 224, 224] 299 | if (self.base_model_name == "C3DRes18") : 300 | before_permute = input.view((-1, sample_len) + input.size()[-2:]) 301 | input_var = torch.transpose(before_permute.view((-1, self.num_segments) + before_permute.size()[1:]), 1, 2) 302 | elif ("Res3D" in self.base_model_name): 303 | before_permute = input.view((-1, sample_len) + input.size()[-2:]) 304 | input_var = torch.transpose(before_permute.view((-1, self.num_segments) + before_permute.size()[1:]), 1, 2) 305 | elif (self.base_model_name in ["I3D", "I3D_flow"]): # [B, C, T, W, H] 306 | before_permute = input.view((-1, sample_len) + input.size()[-2:]) 307 | input_var = torch.transpose(before_permute.view((-1, self.num_segments) + before_permute.size()[1:]), 1, 2) 308 | else: 309 | input_var = input.view((-1, sample_len) + input.size()[-2:]) 310 | 311 | base_out = self.base_model(input_var, temperature) 312 | # zc comments 313 | if self.dropout > 0: 314 | #import ipdb; ipdb.set_trace() 315 | base_out = self.new_fc(base_out) 316 | 317 | if not self.before_softmax: 318 | base_out = self.softmax(base_out) 319 | # zc comments end 320 | 321 | if self.reshape: 322 | if "flow" in self.base_model_name: 323 | base_out = base_out.view((-1, (self.num_segments)) + base_out.size()[1:]) 324 | else: 325 | base_out = base_out.view((-1, (self.num_segments)) + base_out.size()[1:]) 326 | 327 | output = self.consensus(base_out) 328 | output = output.squeeze(3).squeeze(1) 329 | 330 | return output 331 | 332 | 333 | def _get_diff(self, input, keep_rgb=False): 334 | input_c = 3 if self.modality in ["RGB", "RGBDiff"] else 2 335 | input_view = input.view((-1, self.num_segments, self.new_length + 1, input_c,) + input.size()[2:]) 336 | if keep_rgb: 337 | new_data = input_view.clone() 338 | else: 339 | new_data = input_view[:, :, 1:, :, :, :].clone() 340 | 341 | for x in reversed(list(range(1, self.new_length + 1))): 342 | if keep_rgb: 343 | new_data[:, :, x, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :] 344 | else: 345 | new_data[:, :, x - 1, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :] 346 | 347 | return new_data 348 | 349 | 350 | def _construct_flow_model(self, base_model): 351 | # modify the convolution layers 352 | # Torch models are usually defined in a hierarchical way. 
353 | # nn.modules.children() return all sub modules in a DFS manner 354 | modules = list(self.base_model.modules()) 355 | first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0] 356 | conv_layer = modules[first_conv_idx] 357 | container = modules[first_conv_idx - 1] 358 | 359 | # modify parameters, assume the first blob contains the convolution kernels 360 | params = [x.clone() for x in conv_layer.parameters()] 361 | kernel_size = params[0].size() 362 | new_kernel_size = kernel_size[:1] + (2 * self.new_length, ) + kernel_size[2:] 363 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous() 364 | 365 | new_conv = nn.Conv2d(2 * self.new_length, conv_layer.out_channels, 366 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, 367 | bias=True if len(params) == 2 else False) 368 | new_conv.weight.data = new_kernels 369 | if len(params) == 2: 370 | new_conv.bias.data = params[1].data # add bias if neccessary 371 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name 372 | 373 | # replace the first convlution layer 374 | setattr(container, layer_name, new_conv) 375 | return base_model 376 | 377 | def _construct_diff_model(self, base_model, keep_rgb=False): 378 | # modify the convolution layers 379 | # Torch models are usually defined in a hierarchical way. 380 | # nn.modules.children() return all sub modules in a DFS manner 381 | modules = list(self.base_model.modules()) 382 | first_conv_idx = filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules))))[0] 383 | conv_layer = modules[first_conv_idx] 384 | container = modules[first_conv_idx - 1] 385 | 386 | # modify parameters, assume the first blob contains the convolution kernels 387 | params = [x.clone() for x in conv_layer.parameters()] 388 | kernel_size = params[0].size() 389 | if not keep_rgb: 390 | new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:] 391 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous() 392 | else: 393 | new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:] 394 | new_kernels = torch.cat((params[0].data, params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()), 395 | 1) 396 | new_kernel_size = kernel_size[:1] + (3 + 3 * self.new_length,) + kernel_size[2:] 397 | 398 | new_conv = nn.Conv2d(new_kernel_size[1], conv_layer.out_channels, 399 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, 400 | bias=True if len(params) == 2 else False) 401 | new_conv.weight.data = new_kernels 402 | if len(params) == 2: 403 | new_conv.bias.data = params[1].data # add bias if neccessary 404 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name 405 | 406 | # replace the first convolution layer 407 | setattr(container, layer_name, new_conv) 408 | return base_model 409 | 410 | @property 411 | def crop_size(self): 412 | return self.input_size 413 | 414 | @property 415 | def scale_size(self): 416 | return self.input_size * 256 // 224 417 | 418 | def get_augmentation(self): 419 | if self.modality == 'RGB': 420 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66]),GroupRandomHorizontalFlip(selective_flip=True, is_flow=False)]) 421 | # return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875]), 422 | # GroupRandomHorizontalFlip(is_flow=False)]) 
423 | elif self.modality == 'Flow': 424 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 425 | GroupRandomHorizontalFlip(is_flow=True)]) 426 | elif self.modality == 'RGBDiff': 427 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 428 | GroupRandomHorizontalFlip(is_flow=False)]) -------------------------------------------------------------------------------- /ops/__init__.py: -------------------------------------------------------------------------------- 1 | from ops.basic_ops import * -------------------------------------------------------------------------------- /ops/basic_ops.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # phoenixyli 李岩 @2020-04-02 17:10:52 3 | 4 | import torch 5 | import ipdb 6 | 7 | 8 | class Identity(torch.nn.Module): 9 | """Identity module 10 | 11 | x = x 12 | """ 13 | 14 | def forward(self, input): 15 | return input 16 | 17 | 18 | class SegmentConsensus(torch.autograd.Function): 19 | 20 | @staticmethod 21 | def forward(ctx, input_tensor, consensus_type, dim=1): 22 | ctx.consensus_type = consensus_type 23 | ctx.dim = dim 24 | ctx.shape = input_tensor.size() 25 | 26 | if ctx.consensus_type == 'avg': 27 | output = input_tensor.mean(dim=ctx.dim, keepdim=True) 28 | elif ctx.consensus_type == 'identity': 29 | output = input_tensor 30 | else: 31 | output = None 32 | 33 | return output 34 | 35 | @staticmethod 36 | def backward(ctx, grad_output): 37 | #ipdb.set_trace() 38 | if ctx.consensus_type == 'avg': 39 | grad_in = grad_output.expand(ctx.shape) / float(ctx.shape[ctx.dim]) 40 | elif ctx.consensus_type == 'identity': 41 | grad_in = grad_output 42 | else: 43 | grad_in = None 44 | 45 | return grad_in, None, None 46 | 47 | 48 | class ConsensusModule(torch.nn.Module): 49 | """ 50 | """ 51 | 52 | def __init__(self, consensus_type, dim=1): 53 | super(ConsensusModule, self).__init__() 54 | self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity' 55 | self.dim = dim 56 | 57 | def forward(self, x): 58 | return SegmentConsensus.apply(x, self.consensus_type, self.dim) 59 | -------------------------------------------------------------------------------- /ops/dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | from numpy.random import randint 8 | import ipdb 9 | 10 | class VideoRecord(object): 11 | def __init__(self, row): 12 | self._data = row 13 | 14 | @property 15 | def path(self): 16 | return self._data[0] 17 | 18 | @property 19 | def num_frames(self): 20 | return int(self._data[1]) 21 | 22 | @property 23 | def label(self): 24 | return int(self._data[2]) 25 | 26 | 27 | class TSNDataSet(data.Dataset): 28 | def __init__(self, root_path, list_file, 29 | num_segments=3, new_length=1, interval=2, modality='RGB', mode=1, stride=8, 30 | image_tmpl='img_{:05d}.jpg',img_start_idx=1, transform=None, 31 | force_grayscale=False, random_shift=True, test_mode=False): 32 | 33 | self.root_path = root_path 34 | self.list_file = list_file 35 | self.num_segments = num_segments 36 | self.new_length = new_length 37 | 38 | self.modality = modality 39 | self.image_tmpl = image_tmpl 40 | self.transform = transform 41 | self.random_shift = random_shift 42 | self.test_mode = test_mode 43 | 44 | self.interval= interval 45 | self.mode = mode 46 | self.stride = 
stride 47 | self.start_idx = img_start_idx 48 | 49 | if self.modality == 'RGBDiff': 50 | self.new_length += 1# Diff needs one more image to calculate diff 51 | 52 | self._parse_list() 53 | #ipdb.set_trace() 54 | 55 | def _load_image(self, directory, idx): 56 | 57 | ####################################################### 58 | #if os.path.exists(os.path.join(self.root_path, directory, self.image_tmpl.format(0))): 59 | # idx = idx - 1 # frame number start from 0 60 | ####################################################### 61 | 62 | if self.start_idx == 0 : 63 | idx = idx - 1 64 | 65 | if self.modality == 'RGB' or self.modality == 'RGBDiff': 66 | return [Image.open(os.path.join(directory, self.image_tmpl.format(idx))).convert('RGB')] 67 | elif self.modality == 'Flow': 68 | ''' 69 | img = Image.open(os.path.join(directory, self.image_tmpl.format(idx))).convert('RGB') 70 | flow_x, flow_y, _ = img.split() 71 | x_img = flow_x.convert('L') 72 | y_img = flow_y.convert('L') 73 | ''' 74 | x_img = Image.open(os.path.join(directory, self.image_tmpl.format('x', idx))).convert('L') 75 | y_img = Image.open(os.path.join(directory, self.image_tmpl.format('y', idx))).convert('L') 76 | 77 | return [x_img, y_img] 78 | 79 | def _parse_list(self): 80 | self.video_list = [VideoRecord(x.strip().split(' ')) for x in open(self.list_file)] 81 | 82 | def _sample_indices(self, record): 83 | """ 84 | :param record: VideoRecord 85 | :return: list 86 | """ 87 | if self.mode==0: # i3d dense sample 88 | sample_pos = max(1, 1 + record.num_frames - 64) 89 | t_stride = 64 // self.num_segments 90 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 91 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)] 92 | # print (offsets) 93 | return np.array(offsets) + 1 94 | elif self.mode: # normal sample 95 | average_duration = (record.num_frames - self.new_length + 1) // self.num_segments 96 | if average_duration > 0: 97 | offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, 98 | size=self.num_segments) 99 | elif record.num_frames > self.num_segments: 100 | offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments)) 101 | else: 102 | offsets = np.zeros((self.num_segments,)) 103 | # print (offsets) 104 | return offsets + 1 105 | 106 | def _get_val_indices(self, record): 107 | if (self.mode==0): # i3d dense sample 108 | sample_pos = max(1, 1 + record.num_frames - 64) 109 | t_stride = 64 // self.num_segments 110 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 111 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)] 112 | return np.array(offsets) + 1 113 | else: 114 | if record.num_frames > self.num_segments + self.new_length - 1: 115 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) 116 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) 117 | else: 118 | offsets = np.zeros((self.num_segments,)) 119 | return offsets + 1 120 | 121 | def _get_test_indices(self, record): 122 | if (self.mode==0): 123 | sample_pos = max(1, 1 + record.num_frames - 64) 124 | t_stride = 64 // self.num_segments 125 | start_list = np.linspace(0, sample_pos - 1, num=10, dtype=int) 126 | offsets = [] 127 | for start_idx in start_list.tolist(): 128 | offsets += [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)] 129 | return np.array(offsets) + 1 130 | 
131 | elif self.mode==2: # tsm twice sample 132 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) 133 | 134 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)] + 135 | [int(tick * x) for x in range(self.num_segments)]) 136 | 137 | return offsets + 1 138 | 139 | else: 140 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) 141 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) 142 | return offsets + 1 143 | 144 | 145 | def __getitem__(self, index): 146 | record = self.video_list[index] 147 | 148 | if not self.test_mode: 149 | segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record) 150 | # print (segment_indices) 151 | else: 152 | segment_indices = self._get_test_indices(record) 153 | # print (segment_indices) 154 | 155 | return self.get(record, segment_indices) 156 | 157 | def get(self, record, indices): 158 | 159 | images = list() 160 | for seg_ind in indices: 161 | p = int(seg_ind) 162 | for i in range(self.new_length): 163 | seg_imgs = self._load_image(record.path, p) 164 | images.extend(seg_imgs) 165 | if p < record.num_frames: 166 | p += 1 167 | 168 | process_data, _ = self.transform((images,record.label)) 169 | return process_data, record.label 170 | 171 | def __len__(self): 172 | return len(self.video_list) 173 | -------------------------------------------------------------------------------- /ops/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | from .utils import load_state_dict_from_url 5 | from typing import Type, Any, Callable, Union, List, Optional 6 | 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 10 | 'wide_resnet50_2', 'wide_resnet101_2'] 11 | 12 | 13 | model_urls = { 14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 19 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 20 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 21 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 22 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 23 | } 24 | 25 | 26 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=dilation, groups=groups, bias=False, dilation=dilation) 30 | 31 | 32 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 33 | """1x1 convolution""" 34 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 35 | 36 | 37 | class BasicBlock(nn.Module): 38 | expansion: int = 1 39 | 40 | def __init__( 41 | self, 42 | inplanes: int, 43 | planes: int, 44 | stride: int = 1, 45 | downsample: Optional[nn.Module] = None, 46 | groups: int = 1, 47 | base_width: int = 64, 48 | dilation: int = 1, 49 
| norm_layer: Optional[Callable[..., nn.Module]] = None 50 | ) -> None: 51 | super(BasicBlock, self).__init__() 52 | if norm_layer is None: 53 | norm_layer = nn.BatchNorm2d 54 | if groups != 1 or base_width != 64: 55 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 56 | if dilation > 1: 57 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 58 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 59 | self.conv1 = conv3x3(inplanes, planes, stride) 60 | self.bn1 = norm_layer(planes) 61 | self.relu = nn.ReLU(inplace=True) 62 | self.conv2 = conv3x3(planes, planes) 63 | self.bn2 = norm_layer(planes) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x: Tensor) -> Tensor: 68 | identity = x 69 | 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.bn2(out) 76 | 77 | if self.downsample is not None: 78 | identity = self.downsample(x) 79 | 80 | out += identity 81 | out = self.relu(out) 82 | 83 | return out 84 | 85 | 86 | class Bottleneck(nn.Module): 87 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 88 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 89 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 90 | # This variant is also known as ResNet V1.5 and improves accuracy according to 91 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 92 | 93 | expansion: int = 4 94 | 95 | def __init__( 96 | self, 97 | inplanes: int, 98 | planes: int, 99 | stride: int = 1, 100 | downsample: Optional[nn.Module] = None, 101 | groups: int = 1, 102 | base_width: int = 64, 103 | dilation: int = 1, 104 | norm_layer: Optional[Callable[..., nn.Module]] = None 105 | ) -> None: 106 | super(Bottleneck, self).__init__() 107 | if norm_layer is None: 108 | norm_layer = nn.BatchNorm2d 109 | width = int(planes * (base_width / 64.)) * groups 110 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 111 | self.conv1 = conv1x1(inplanes, width) 112 | self.bn1 = norm_layer(width) 113 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 114 | self.bn2 = norm_layer(width) 115 | self.conv3 = conv1x1(width, planes * self.expansion) 116 | self.bn3 = norm_layer(planes * self.expansion) 117 | self.relu = nn.ReLU(inplace=True) 118 | self.downsample = downsample 119 | self.stride = stride 120 | 121 | def forward(self, x: Tensor) -> Tensor: 122 | identity = x 123 | 124 | out = self.conv1(x) 125 | out = self.bn1(out) 126 | out = self.relu(out) 127 | 128 | out = self.conv2(out) 129 | out = self.bn2(out) 130 | out = self.relu(out) 131 | 132 | out = self.conv3(out) 133 | out = self.bn3(out) 134 | 135 | if self.downsample is not None: 136 | identity = self.downsample(x) 137 | 138 | out += identity 139 | out = self.relu(out) 140 | 141 | return out 142 | 143 | 144 | class ResNet(nn.Module): 145 | 146 | def __init__( 147 | self, 148 | block: Type[Union[BasicBlock, Bottleneck]], 149 | layers: List[int], 150 | num_classes: int = 1000, 151 | zero_init_residual: bool = False, 152 | groups: int = 1, 153 | width_per_group: int = 64, 154 | replace_stride_with_dilation: Optional[List[bool]] = None, 155 | norm_layer: Optional[Callable[..., nn.Module]] = None 156 | ) -> None: 157 | super(ResNet, self).__init__() 158 | if norm_layer is None: 159 | norm_layer = 
nn.BatchNorm2d 160 | self._norm_layer = norm_layer 161 | 162 | self.inplanes = 64 163 | self.dilation = 1 164 | if replace_stride_with_dilation is None: 165 | # each element in the tuple indicates if we should replace 166 | # the 2x2 stride with a dilated convolution instead 167 | replace_stride_with_dilation = [False, False, False] 168 | if len(replace_stride_with_dilation) != 3: 169 | raise ValueError("replace_stride_with_dilation should be None " 170 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 171 | self.groups = groups 172 | self.base_width = width_per_group 173 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 174 | bias=False) 175 | self.bn1 = norm_layer(self.inplanes) 176 | self.relu = nn.ReLU(inplace=True) 177 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 178 | self.layer1 = self._make_layer(block, 64, layers[0]) 179 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 180 | dilate=replace_stride_with_dilation[0]) 181 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 182 | dilate=replace_stride_with_dilation[1]) 183 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 184 | dilate=replace_stride_with_dilation[2]) 185 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 186 | self.fc = nn.Linear(512 * block.expansion, num_classes) 187 | 188 | for m in self.modules(): 189 | if isinstance(m, nn.Conv2d): 190 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 191 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 192 | nn.init.constant_(m.weight, 1) 193 | nn.init.constant_(m.bias, 0) 194 | 195 | # Zero-initialize the last BN in each residual branch, 196 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
197 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 198 | if zero_init_residual: 199 | for m in self.modules(): 200 | if isinstance(m, Bottleneck): 201 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 202 | elif isinstance(m, BasicBlock): 203 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 204 | 205 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 206 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 207 | norm_layer = self._norm_layer 208 | downsample = None 209 | previous_dilation = self.dilation 210 | if dilate: 211 | self.dilation *= stride 212 | stride = 1 213 | if stride != 1 or self.inplanes != planes * block.expansion: 214 | downsample = nn.Sequential( 215 | conv1x1(self.inplanes, planes * block.expansion, stride), 216 | norm_layer(planes * block.expansion), 217 | ) 218 | 219 | layers = [] 220 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 221 | self.base_width, previous_dilation, norm_layer)) 222 | self.inplanes = planes * block.expansion 223 | for _ in range(1, blocks): 224 | layers.append(block(self.inplanes, planes, groups=self.groups, 225 | base_width=self.base_width, dilation=self.dilation, 226 | norm_layer=norm_layer)) 227 | 228 | return nn.Sequential(*layers) 229 | 230 | def _forward_impl(self, x: Tensor) -> Tensor: 231 | # See note [TorchScript super()] 232 | x = self.conv1(x) 233 | x = self.bn1(x) 234 | x = self.relu(x) 235 | x = self.maxpool(x) 236 | 237 | x = self.layer1(x) 238 | x = self.layer2(x) 239 | x = self.layer3(x) 240 | x = self.layer4(x) 241 | 242 | x = self.avgpool(x) 243 | x = torch.flatten(x, 1) 244 | x = self.fc(x) 245 | 246 | return x 247 | 248 | def forward(self, x: Tensor) -> Tensor: 249 | return self._forward_impl(x) 250 | 251 | 252 | def _resnet( 253 | arch: str, 254 | block: Type[Union[BasicBlock, Bottleneck]], 255 | layers: List[int], 256 | pretrained: bool, 257 | progress: bool, 258 | **kwargs: Any 259 | ) -> ResNet: 260 | model = ResNet(block, layers, **kwargs) 261 | if pretrained: 262 | state_dict = load_state_dict_from_url(model_urls[arch], 263 | progress=progress) 264 | model.load_state_dict(state_dict) 265 | return model 266 | 267 | 268 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 269 | r"""ResNet-18 model from 270 | `"Deep Residual Learning for Image Recognition" `_ 271 | 272 | Args: 273 | pretrained (bool): If True, returns a model pre-trained on ImageNet 274 | progress (bool): If True, displays a progress bar of the download to stderr 275 | """ 276 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 277 | **kwargs) 278 | 279 | 280 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 281 | r"""ResNet-34 model from 282 | `"Deep Residual Learning for Image Recognition" `_ 283 | 284 | Args: 285 | pretrained (bool): If True, returns a model pre-trained on ImageNet 286 | progress (bool): If True, displays a progress bar of the download to stderr 287 | """ 288 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 289 | **kwargs) 290 | 291 | 292 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 293 | r"""ResNet-50 model from 294 | `"Deep Residual Learning for Image Recognition" `_ 295 | 296 | Args: 297 | pretrained (bool): If True, returns a model pre-trained on ImageNet 298 | progress (bool): If True, displays a progress bar of the 
download to stderr 299 | """ 300 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 301 | **kwargs) 302 | 303 | 304 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 305 | r"""ResNet-101 model from 306 | `"Deep Residual Learning for Image Recognition" `_ 307 | 308 | Args: 309 | pretrained (bool): If True, returns a model pre-trained on ImageNet 310 | progress (bool): If True, displays a progress bar of the download to stderr 311 | """ 312 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 313 | **kwargs) 314 | 315 | 316 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 317 | r"""ResNet-152 model from 318 | `"Deep Residual Learning for Image Recognition" `_ 319 | 320 | Args: 321 | pretrained (bool): If True, returns a model pre-trained on ImageNet 322 | progress (bool): If True, displays a progress bar of the download to stderr 323 | """ 324 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 325 | **kwargs) 326 | 327 | 328 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 329 | r"""ResNeXt-50 32x4d model from 330 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 331 | 332 | Args: 333 | pretrained (bool): If True, returns a model pre-trained on ImageNet 334 | progress (bool): If True, displays a progress bar of the download to stderr 335 | """ 336 | kwargs['groups'] = 32 337 | kwargs['width_per_group'] = 4 338 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 339 | pretrained, progress, **kwargs) 340 | 341 | 342 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 343 | r"""ResNeXt-101 32x8d model from 344 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 345 | 346 | Args: 347 | pretrained (bool): If True, returns a model pre-trained on ImageNet 348 | progress (bool): If True, displays a progress bar of the download to stderr 349 | """ 350 | kwargs['groups'] = 32 351 | kwargs['width_per_group'] = 8 352 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 353 | pretrained, progress, **kwargs) 354 | 355 | 356 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 357 | r"""Wide ResNet-50-2 model from 358 | `"Wide Residual Networks" `_ 359 | 360 | The model is the same as ResNet except for the bottleneck number of channels 361 | which is twice larger in every block. The number of channels in outer 1x1 362 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 363 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 364 | 365 | Args: 366 | pretrained (bool): If True, returns a model pre-trained on ImageNet 367 | progress (bool): If True, displays a progress bar of the download to stderr 368 | """ 369 | kwargs['width_per_group'] = 64 * 2 370 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 371 | pretrained, progress, **kwargs) 372 | 373 | 374 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 375 | r"""Wide ResNet-101-2 model from 376 | `"Wide Residual Networks" `_ 377 | 378 | The model is the same as ResNet except for the bottleneck number of channels 379 | which is twice larger in every block. The number of channels in outer 1x1 380 | convolutions is the same, e.g. 
last block in ResNet-50 has 2048-512-2048 381 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 382 | 383 | Args: 384 | pretrained (bool): If True, returns a model pre-trained on ImageNet 385 | progress (bool): If True, displays a progress bar of the download to stderr 386 | """ 387 | kwargs['width_per_group'] = 64 * 2 388 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 389 | pretrained, progress, **kwargs) 390 | -------------------------------------------------------------------------------- /ops/transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | 10 | class GroupRandomCrop(object): 11 | def __init__(self, size): 12 | if isinstance(size, numbers.Number): 13 | self.size = (int(size), int(size)) 14 | else: 15 | self.size = size 16 | 17 | def __call__(self, img_tuple): 18 | img_group, label = img_tuple 19 | 20 | w, h = img_group[0].size 21 | th, tw = self.size 22 | 23 | out_images = list() 24 | 25 | x1 = random.randint(0, w - tw) 26 | y1 = random.randint(0, h - th) 27 | 28 | for img in img_group: 29 | assert(img.size[0] == w and img.size[1] == h) 30 | if w == tw and h == th: 31 | out_images.append(img) 32 | else: 33 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 34 | 35 | return (out_images, label) 36 | 37 | 38 | class GroupCenterCrop(object): 39 | def __init__(self, size): 40 | self.worker = torchvision.transforms.CenterCrop(size) 41 | 42 | def __call__(self, img_tuple): 43 | img_group, label = img_tuple 44 | return ([self.worker(img) for img in img_group], label) 45 | 46 | 47 | class GroupRandomHorizontalFlip(object): 48 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 49 | """ 50 | def __init__(self, selective_flip=True, is_flow=False): 51 | self.is_flow = is_flow 52 | self.class_LeftRight = [86,87,93,94,166,167] if selective_flip else [] 53 | 54 | def __call__(self, img_tuple, is_flow=False): 55 | img_group, label = img_tuple 56 | v = random.random() 57 | if (label not in self.class_LeftRight) and v < 0.5: 58 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 59 | if self.is_flow: 60 | for i in range(0, len(ret), 2): 61 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping 62 | return (ret, label) 63 | else: 64 | return img_tuple 65 | 66 | class GroupNormalize(object): 67 | def __init__(self, mean, std): 68 | self.mean = mean 69 | self.std = std 70 | 71 | def __call__(self, tensor_tuple): 72 | tensor, label = tensor_tuple 73 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 74 | rep_std = self.std * (tensor.size()[0]//len(self.std)) 75 | 76 | # TODO: make efficient 77 | for t, m, s in zip(tensor, rep_mean, rep_std): 78 | t.sub_(m).div_(s) 79 | 80 | return (tensor,label) 81 | 82 | 83 | class GroupGrayScale(object): 84 | def __init__(self, size): 85 | self.worker = torchvision.transforms.Grayscale(size) 86 | 87 | def __call__(self, img_tuple): 88 | img_group, label = img_tuple 89 | return ([self.worker(img) for img in img_group], label) 90 | 91 | 92 | class GroupScale(object): 93 | """ Rescales the input PIL.Image to the given 'size'. 94 | 'size' will be the size of the smaller edge. 
95 | For example, if height > width, then image will be 96 | rescaled to (size * height / width, size) 97 | size: size of the smaller edge 98 | interpolation: Default: PIL.Image.BILINEAR 99 | """ 100 | 101 | def __init__(self, size, interpolation=Image.BILINEAR): 102 | self.worker = torchvision.transforms.Resize(size, interpolation) 103 | 104 | def __call__(self, img_tuple): 105 | img_group, label = img_tuple 106 | return ([self.worker(img) for img in img_group], label) 107 | 108 | 109 | class GroupOverSample(object): 110 | def __init__(self, crop_size, scale_size=None): 111 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 112 | 113 | if scale_size is not None: 114 | self.scale_worker = GroupScale(scale_size) 115 | else: 116 | self.scale_worker = None 117 | 118 | def __call__(self, img_tuple): 119 | if self.scale_worker is not None: 120 | img_tuple = self.scale_worker(img_tuple) 121 | 122 | img_group, label = img_tuple 123 | 124 | image_w, image_h = img_group[0].size 125 | crop_w, crop_h = self.crop_size 126 | 127 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h) 128 | oversample_group = list() 129 | for o_w, o_h in offsets: 130 | normal_group = list() 131 | flip_group = list() 132 | for i, img in enumerate(img_group): 133 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 134 | normal_group.append(crop) 135 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 136 | 137 | if img.mode == 'L' and i % 2 == 0: 138 | flip_group.append(ImageOps.invert(flip_crop)) 139 | else: 140 | flip_group.append(flip_crop) 141 | 142 | oversample_group.extend(normal_group) 143 | oversample_group.extend(flip_group) 144 | return (oversample_group, label) 145 | 146 | class GroupFullResSample(object): 147 | def __init__(self, crop_size, scale_size=None, flip=True): 148 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 149 | 150 | if scale_size is not None: 151 | self.scale_worker = GroupScale(scale_size) 152 | else: 153 | self.scale_worker = None 154 | self.flip = flip 155 | 156 | def __call__(self, img_tuple): 157 | 158 | if self.scale_worker is not None: 159 | img_tuple = self.scale_worker(img_tuple) 160 | 161 | img_group, label = img_tuple 162 | image_w, image_h = img_group[0].size 163 | crop_w, crop_h = self.crop_size 164 | 165 | w_step = (image_w - crop_w) // 4 166 | h_step = (image_h - crop_h) // 4 167 | 168 | offsets = list() 169 | offsets.append((0 * w_step, 2 * h_step)) # left 170 | offsets.append((4 * w_step, 2 * h_step)) # right 171 | offsets.append((2 * w_step, 2 * h_step)) # center 172 | 173 | oversample_group = list() 174 | for o_w, o_h in offsets: 175 | normal_group = list() 176 | flip_group = list() 177 | for i, img in enumerate(img_group): 178 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 179 | normal_group.append(crop) 180 | if self.flip: 181 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 182 | 183 | if img.mode == 'L' and i % 2 == 0: 184 | flip_group.append(ImageOps.invert(flip_crop)) 185 | else: 186 | flip_group.append(flip_crop) 187 | 188 | oversample_group.extend(normal_group) 189 | oversample_group.extend(flip_group) 190 | return (oversample_group, label) 191 | 192 | class GroupMultiScaleCrop(object): 193 | 194 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 195 | self.scales = scales if scales is not None else [1, .875, .75, .66] 196 | self.max_distort = max_distort 197 | self.fix_crop =
fix_crop 198 | self.more_fix_crop = more_fix_crop 199 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 200 | self.interpolation = Image.BILINEAR 201 | 202 | def __call__(self, img_tuple): 203 | img_group, label = img_tuple 204 | 205 | im_size = img_group[0].size 206 | 207 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 208 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 209 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group] 210 | return (ret_img_group, label) 211 | 212 | def _sample_crop_size(self, im_size): 213 | image_w, image_h = im_size[0], im_size[1] 214 | 215 | # find a crop size 216 | base_size = min(image_w, image_h) 217 | crop_sizes = [int(base_size * x) for x in self.scales] 218 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 219 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 220 | 221 | pairs = [] 222 | for i, h in enumerate(crop_h): 223 | for j, w in enumerate(crop_w): 224 | if abs(i - j) <= self.max_distort: 225 | pairs.append((w, h)) 226 | 227 | crop_pair = random.choice(pairs) 228 | if not self.fix_crop: 229 | w_offset = random.randint(0, image_w - crop_pair[0]) 230 | h_offset = random.randint(0, image_h - crop_pair[1]) 231 | else: 232 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 233 | 234 | return crop_pair[0], crop_pair[1], w_offset, h_offset 235 | 236 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 237 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 238 | return random.choice(offsets) 239 | 240 | @staticmethod 241 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 242 | w_step = (image_w - crop_w) // 4 243 | h_step = (image_h - crop_h) // 4 244 | 245 | ret = list() 246 | ret.append((0, 0)) # upper left 247 | ret.append((4 * w_step, 0)) # upper right 248 | ret.append((0, 4 * h_step)) # lower left 249 | ret.append((4 * w_step, 4 * h_step)) # lower right 250 | ret.append((2 * w_step, 2 * h_step)) # center 251 | 252 | if more_fix_crop: 253 | ret.append((0, 2 * h_step)) # center left 254 | ret.append((4 * w_step, 2 * h_step)) # center right 255 | ret.append((2 * w_step, 4 * h_step)) # lower center 256 | ret.append((2 * w_step, 0 * h_step)) # upper center 257 | 258 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 259 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 260 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 261 | ret.append((3 * w_step, 3 * h_step)) # lower righ quarter 262 | 263 | return ret 264 | 265 | 266 | class GroupRandomSizedCrop(object): 267 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 268 | and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio 269 | This is popularly used to train the Inception networks 270 | size: size of the smaller edge 271 | interpolation: Default: PIL.Image.BILINEAR 272 | """ 273 | def __init__(self, size, interpolation=Image.BILINEAR): 274 | self.size = size 275 | self.interpolation = interpolation 276 | 277 | def __call__(self, img_tuple): 278 | img_group, label = img_tuple 279 | 280 | for attempt in range(10): 281 | area = img_group[0].size[0] * img_group[0].size[1] 282 | target_area = random.uniform(0.08, 1.0) * area 283 | 
aspect_ratio = random.uniform(3. / 4, 4. / 3) 284 | 285 | w = int(round(math.sqrt(target_area * aspect_ratio))) 286 | h = int(round(math.sqrt(target_area / aspect_ratio))) 287 | 288 | if random.random() < 0.5: 289 | w, h = h, w 290 | 291 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 292 | x1 = random.randint(0, img_group[0].size[0] - w) 293 | y1 = random.randint(0, img_group[0].size[1] - h) 294 | found = True 295 | break 296 | else: 297 | found = False 298 | x1 = 0 299 | y1 = 0 300 | 301 | if found: 302 | out_group = list() 303 | for img in img_group: 304 | img = img.crop((x1, y1, x1 + w, y1 + h)) 305 | assert(img.size == (w, h)) 306 | out_group.append(img.resize((self.size, self.size), self.interpolation)) 307 | return out_group 308 | else: 309 | # Fallback 310 | scale = GroupScale(self.size, interpolation=self.interpolation) 311 | crop = GroupRandomCrop(self.size) 312 | return crop(scale(img_group)) 313 | 314 | 315 | class Stack(object): 316 | 317 | def __init__(self, roll=False): 318 | self.roll = roll 319 | 320 | def __call__(self, img_tuple): 321 | img_group, label = img_tuple 322 | 323 | if img_group[0].mode == 'L': 324 | return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label) 325 | elif img_group[0].mode == 'RGB': 326 | if self.roll: 327 | return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label) 328 | else: 329 | return (np.concatenate(img_group, axis=2), label) 330 | 331 | 332 | class ToTorchFormatTensor(object): 333 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 334 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 335 | def __init__(self, div=True): 336 | self.div = div 337 | 338 | def __call__(self, pic_tuple): 339 | pic, label = pic_tuple 340 | 341 | if isinstance(pic, np.ndarray): 342 | # handle numpy array 343 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 344 | else: 345 | # handle PIL Image 346 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 347 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 348 | # put it from HWC to CHW format 349 | # yikes, this transpose takes 80% of the loading time/CPU 350 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 351 | return (img.float().div(255.) 
if self.div else img.float(), label) 352 | 353 | 354 | class IdentityTransform(object): 355 | 356 | def __call__(self, data): 357 | return data 358 | 359 | 360 | if __name__ == "__main__": 361 | trans = torchvision.transforms.Compose([ 362 | GroupScale(256), 363 | GroupRandomCrop(224), 364 | Stack(), 365 | ToTorchFormatTensor(), 366 | GroupNormalize( 367 | mean=[.485, .456, .406], 368 | std=[.229, .224, .225] 369 | )] 370 | ) 371 | 372 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png') 373 | 374 | color_group = [im] * 3 375 | rst = trans(color_group) 376 | 377 | gray_group = [im.convert('L')] * 9 378 | gray_rst = trans(gray_group) 379 | 380 | trans2 = torchvision.transforms.Compose([ 381 | GroupRandomSizedCrop(256), 382 | Stack(), 383 | ToTorchFormatTensor(), 384 | GroupNormalize( 385 | mean=[.485, .456, .406], 386 | std=[.229, .224, .225]) 387 | ]) 388 | print(trans2(color_group)) -------------------------------------------------------------------------------- /ops/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics import confusion_matrix 4 | 5 | def get_grad_hook(name): 6 | def hook(m, grad_in, grad_out): 7 | print((name, grad_out[0].data.abs().mean(), grad_in[0].data.abs().mean())) 8 | print((grad_out[0].size())) 9 | print((grad_in[0].size())) 10 | 11 | print((grad_out[0])) 12 | print((grad_in[0])) 13 | 14 | return hook 15 | 16 | 17 | def softmax(scores): 18 | es = np.exp(scores - scores.max(axis=-1)[..., None]) 19 | return es / es.sum(axis=-1)[..., None] 20 | 21 | 22 | def log_add(log_a, log_b): 23 | return log_a + np.log(1 + np.exp(log_b - log_a)) 24 | 25 | 26 | def class_accuracy(prediction, label): 27 | cf = confusion_matrix(prediction, label) 28 | cls_cnt = cf.sum(axis=1) 29 | cls_hit = np.diag(cf) 30 | 31 | cls_acc = cls_hit / cls_cnt.astype(float) 32 | 33 | mean_cls_acc = cls_acc.mean() 34 | 35 | return cls_acc, mean_cls_acc -------------------------------------------------------------------------------- /ops/video_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | from numpy.random import randint 8 | import decord 9 | 10 | import ipdb 11 | 12 | 13 | class VideoRecord(object): 14 | def __init__(self, row): 15 | self._data = row 16 | 17 | @property 18 | def path(self): 19 | return self._data[0] 20 | 21 | @property 22 | def num_frames(self): 23 | return int(self._data[1]) 24 | 25 | @property 26 | def label(self): 27 | return int(self._data[2]) 28 | 29 | 30 | class TSNDataSet(data.Dataset): 31 | def __init__(self, root_path, list_file, 32 | num_segments=3, new_length=1, interval=2, modality='RGB', mode=1, stride=8, 33 | image_tmpl='img_{:05d}.jpg',img_start_idx=1, transform=None, 34 | force_grayscale=False, random_shift=True, test_mode=False): 35 | 36 | self.root_path = root_path 37 | self.list_file = list_file 38 | self.num_segments = num_segments 39 | self.new_length = new_length 40 | 41 | self.modality = modality 42 | self.image_tmpl = image_tmpl 43 | self.transform = transform 44 | self.random_shift = random_shift 45 | self.test_mode = test_mode 46 | 47 | self.interval= interval 48 | self.mode = mode 49 | self.stride = stride 50 | self.start_idx = img_start_idx 51 | 52 | if self.modality == 'RGBDiff': 53 | self.new_length += 1# Diff needs one more image to calculate diff 54 | 55 | 
self._parse_list() 56 | #ipdb.set_trace() 57 | 58 | def _load_image(self, directory, idx): 59 | 60 | ####################################################### 61 | #if os.path.exists(os.path.join(self.root_path, directory, self.image_tmpl.format(0))): 62 | # idx = idx - 1 # frame number start from 0 63 | ####################################################### 64 | 65 | if self.start_idx == 0 : 66 | idx = idx - 1 67 | 68 | if self.modality == 'RGB' or self.modality == 'RGBDiff': 69 | return [Image.open(os.path.join(directory, self.image_tmpl.format(idx))).convert('RGB')] 70 | elif self.modality == 'Flow': 71 | ''' 72 | img = Image.open(os.path.join(directory, self.image_tmpl.format(idx))).convert('RGB') 73 | flow_x, flow_y, _ = img.split() 74 | x_img = flow_x.convert('L') 75 | y_img = flow_y.convert('L') 76 | ''' 77 | x_img = Image.open(os.path.join(directory, self.image_tmpl.format('x', idx))).convert('L') 78 | y_img = Image.open(os.path.join(directory, self.image_tmpl.format('y', idx))).convert('L') 79 | 80 | return [x_img, y_img] 81 | 82 | def _parse_list(self): 83 | self.video_list = [VideoRecord(x.strip().split(' ')) for x in open(self.list_file)] 84 | 85 | def _sample_indices(self, record): 86 | """ 87 | :param record: VideoRecord 88 | :return: list 89 | """ 90 | if self.mode==0: # i3d dense sample 91 | sample_pos = max(1, 1 + record.num_frames - 64) 92 | t_stride = 64 // self.num_segments 93 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 94 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)] 95 | # print (offsets) 96 | return np.array(offsets) + 1 97 | elif self.mode: # normal sample 98 | average_duration = (record.num_frames - self.new_length + 1) // self.num_segments 99 | if average_duration > 0: 100 | offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, 101 | size=self.num_segments) 102 | elif record.num_frames > self.num_segments: 103 | offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments)) 104 | else: 105 | offsets = np.zeros((self.num_segments,)) 106 | # print (offsets) 107 | return offsets + 1 108 | 109 | def _get_val_indices(self, record): 110 | if (self.mode==0): # i3d dense sample 111 | sample_pos = max(1, 1 + record.num_frames - 64) 112 | t_stride = 64 // self.num_segments 113 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 114 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)] 115 | return np.array(offsets) + 1 116 | else: 117 | if record.num_frames > self.num_segments + self.new_length - 1: 118 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) 119 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) 120 | else: 121 | offsets = np.zeros((self.num_segments,)) 122 | return offsets + 1 123 | 124 | def _get_test_indices(self, record): 125 | if (self.mode==0): 126 | sample_pos = max(1, 1 + record.num_frames - 64) 127 | t_stride = 64 // self.num_segments 128 | start_list = np.linspace(0, sample_pos - 1, num=10, dtype=int) 129 | offsets = [] 130 | for start_idx in start_list.tolist(): 131 | offsets += [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)] 132 | return np.array(offsets) + 1 133 | 134 | elif self.mode==2: # tsm twice sample 135 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) 136 | 137 | offsets = 
np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)] + 138 | [int(tick * x) for x in range(self.num_segments)]) 139 | 140 | return offsets + 1 141 | 142 | else: 143 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) 144 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) 145 | return offsets + 1 146 | 147 | 148 | def __getitem__(self, index): 149 | record = self.video_list[index] 150 | 151 | if not self.test_mode: 152 | segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record) 153 | # print (segment_indices) 154 | else: 155 | segment_indices = self._get_test_indices(record) 156 | # print (segment_indices) 157 | 158 | return self.get(record, segment_indices) 159 | 160 | def get(self, record, indices): 161 | 162 | images = list() 163 | for seg_ind in indices: 164 | p = int(seg_ind) 165 | for i in range(self.new_length): 166 | seg_imgs = self._load_image(record.path, p) 167 | images.extend(seg_imgs) 168 | if p < record.num_frames: 169 | p += 1 170 | 171 | process_data, _ = self.transform((images,record.label)) 172 | return process_data, record.label 173 | 174 | def __len__(self): 175 | return len(self.video_list) 176 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser(description="PyTorch implementation of Temporal Segment Networks") 3 | parser.add_argument('dataset', type=str, choices=['ucf101', 'hmdb51', 'kinetics', 'something','somethingv2','jhmdb', 'tinykinetics']) 4 | parser.add_argument('modality', type=str, choices=['gray', 'RGB', 'Flow', 'RGBDiff']) 5 | parser.add_argument('train_list', type=str) 6 | parser.add_argument('val_list', type=str) 7 | parser.add_argument('--img_start_idx', type=int, default=1) 8 | 9 | # ========================= Model Configs ========================== 10 | parser.add_argument('--arch', type=str, default="resnet101") 11 | parser.add_argument('--num_segments', type=int, default=3) 12 | parser.add_argument('--mode', type=int, default=1) 13 | parser.add_argument('--consensus_type', type=str, default='avg', 14 | choices=['avg', 'max', 'topk', 'identity', 'rnn', 'cnn']) 15 | parser.add_argument('--k', type=int, default=3) 16 | 17 | parser.add_argument('--dropout', '--do', default=0.5, type=float, 18 | metavar='DO', help='dropout ratio (default: 0.5)') 19 | parser.add_argument('--loss_type', type=str, default="nll", 20 | choices=['nll']) 21 | parser.add_argument('--rep_flow', default=False, action='store_true') 22 | 23 | # ========================= Learning Configs ========================== 24 | parser.add_argument('--epochs', default=45, type=int, metavar='N', 25 | help='number of total epochs to run') 26 | parser.add_argument('-b', '--batch-size', default=256, type=int, 27 | metavar='N', help='mini-batch size (default: 256)') 28 | parser.add_argument('-i', '--iter-size', default=1, type=int, 29 | metavar='N', help='number of iterations before on update') 30 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 31 | metavar='LR', help='initial learning rate') 32 | parser.add_argument('--lr_steps', default=[20, 40], type=float, nargs="+", 33 | metavar='LRSteps', help='epochs to decay learning rate by 10') 34 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 35 | help='momentum') 36 | 
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 37 | metavar='W', help='weight decay (default: 5e-4)') 38 | parser.add_argument('--clip-gradient', '--gd', default=None, type=float, 39 | metavar='W', help='gradient norm clipping (default: disabled)') 40 | parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true") 41 | parser.add_argument('--nesterov', default=False) 42 | 43 | # ========================= Monitor Configs ========================== 44 | parser.add_argument('--print-freq', '-p', default=20, type=int, 45 | metavar='N', help='print frequency (default: 10)') 46 | parser.add_argument('--eval-freq', '-ef', default=5, type=int, 47 | metavar='N', help='evaluation frequency (default: 5)') 48 | 49 | 50 | # ========================= Runtime Configs ========================== 51 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 52 | help='number of data loading workers (default: 4)') 53 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 54 | help='path to latest checkpoint (default: none)') 55 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 56 | help='evaluate model on validation set') 57 | parser.add_argument('--snapshot_pref', type=str, default="") 58 | parser.add_argument('--val_output_folder', type=str, default="", help="folder location to store validation scores") 59 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 60 | help='manual epoch number (useful on restarts)') 61 | parser.add_argument('--gpus', nargs='+', type=int, default=None) 62 | parser.add_argument('--flow_prefix', default="flow_{}_", type=str) 63 | parser.add_argument('--rgb_prefix', default="img_", type=str) 64 | 65 | parser.add_argument('--non_local', default=False, action="store_true", help='add non local block') 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.12.0 2 | backcall==0.2.0 3 | blessings==1.7 4 | cachetools==4.2.1 5 | certifi==2020.12.5 6 | chardet==4.0.0 7 | cycler==0.10.0 8 | Cython==0.29.22 9 | decorator==5.0.7 10 | google-auth==1.28.1 11 | google-auth-oauthlib==0.4.4 12 | gpustat==0.6.0 13 | grpcio==1.37.0 14 | idna==2.10 15 | ipdb==0.13.7 16 | ipython==7.22.0 17 | ipython-genutils==0.2.0 18 | jedi==0.18.0 19 | kiwisolver==1.3.1 20 | Markdown==3.3.4 21 | matplotlib==3.4.1 22 | mkl-fft==1.3.0 23 | mkl-random==1.2.0 24 | mkl-service==2.3.0 25 | nvidia-ml-py3==7.352.0 26 | oauthlib==3.1.0 27 | opencv-python==4.5.1.48 28 | pandas==1.2.3 29 | parso==0.8.2 30 | pexpect==4.8.0 31 | pickleshare==0.7.5 32 | prompt-toolkit==3.0.18 33 | protobuf==3.15.8 34 | psutil==5.8.0 35 | ptyprocess==0.7.0 36 | pyasn1==0.4.8 37 | pyasn1-modules==0.2.8 38 | pycocotools==2.0.2 39 | Pygments==2.8.1 40 | pyparsing==2.4.7 41 | python-dateutil==2.8.1 42 | pytz==2021.1 43 | PyYAML==5.4.1 44 | requests==2.25.1 45 | requests-oauthlib==1.3.0 46 | rsa==4.7.2 47 | scipy==1.6.2 48 | seaborn==0.11.1 49 | tensorboard==2.4.1 50 | tensorboard-plugin-wit==1.8.0 51 | thop==0.0.31.post2005241907 52 | toml==0.10.2 53 | tqdm==4.60.0 54 | traitlets==5.0.5 55 | wcwidth==0.2.5 56 | Werkzeug==1.0.1 57 | -------------------------------------------------------------------------------- /resnet_TSM.py: -------------------------------------------------------------------------------- 1 | """ 2 | An example combining `Temporal Shift Module` with 
`ResNet`. This implementation 3 | is based on `Temporal Segment Networks`, which merges temporal dimension into 4 | batch, i.e. inputs [N*T, C, H, W]. Here we show the case with residual connections 5 | and zero padding with 8 frames as input. 6 | """ 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from tsm_util import tsm 10 | import torch.utils.model_zoo as model_zoo 11 | from TCM import TCM 12 | 13 | 14 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 15 | 'resnet152'] 16 | 17 | 18 | model_urls = { 19 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 20 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 21 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 22 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 23 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 24 | } 25 | 26 | 27 | def conv3x3(in_planes, out_planes, stride=1): 28 | """3x3 convolution with padding""" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=1, bias=False) 31 | 32 | 33 | def conv1x1(in_planes, out_planes, stride=1): 34 | """1x1 convolution""" 35 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 36 | 37 | def conv1x1x1(in_planes, out_planes, stride=1): 38 | """1x1x1 convolution""" 39 | return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=True) 40 | 41 | class BasicBlock(nn.Module): 42 | expansion = 1 43 | 44 | def __init__(self, inplanes, planes, num_segments, stride=1, downsample=None, remainder=0): 45 | super(BasicBlock, self).__init__() 46 | self.conv1 = conv3x3(inplanes, planes, stride) 47 | self.bn1 = nn.BatchNorm2d(planes) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.conv2 = conv3x3(planes, planes) 50 | self.bn2 = nn.BatchNorm2d(planes) 51 | self.downsample = downsample 52 | self.stride = stride 53 | self.remainder= remainder 54 | self.num_segments = num_segments 55 | 56 | def forward(self, x): 57 | identity = x 58 | out = tsm(x, self.num_segments, 'zero') 59 | out = self.conv1(out) 60 | out = self.bn1(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv2(out) 64 | out = self.bn2(out) 65 | 66 | if self.downsample is not None: 67 | identity = self.downsample(x) 68 | 69 | out += identity 70 | out = self.relu(out) 71 | 72 | return out 73 | 74 | class Bottleneck(nn.Module): 75 | expansion = 4 76 | 77 | def __init__(self, inplanes, planes, num_segments, stride=1, downsample=None, remainder=0): 78 | super(Bottleneck, self).__init__() 79 | self.conv1 = conv1x1(inplanes, planes) 80 | self.bn1 = nn.BatchNorm2d(planes) 81 | self.conv2 = conv3x3(planes, planes, stride) 82 | self.bn2 = nn.BatchNorm2d(planes) 83 | self.conv3 = conv1x1(planes, planes * self.expansion) 84 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 85 | self.relu = nn.ReLU(inplace=True) 86 | self.downsample = downsample 87 | self.stride = stride 88 | self.remainder= remainder 89 | self.num_segments = num_segments 90 | 91 | def forward(self, x): 92 | identity = x 93 | out = tsm(x, self.num_segments, 'zero') 94 | out = self.conv1(out) 95 | out = self.bn1(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv2(out) 99 | out = self.bn2(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv3(out) 103 | out = self.bn3(out) 104 | 105 | if self.downsample is not None: 106 | identity = self.downsample(x) 107 | 108 | out += identity 109 | out = self.relu(out) 110 | 111 | return out 112 
| 113 | class ResNet(nn.Module): 114 | 115 | def __init__(self, block, layers, num_segments, enable_TCM, num_classes=1000, zero_init_residual=False): 116 | super(ResNet, self).__init__() 117 | self.num_segments = num_segments 118 | self.enable_TCM = enable_TCM 119 | 120 | self.inplanes = 64 121 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 122 | bias=False) 123 | self.bn1 = nn.BatchNorm2d(64) 124 | 125 | self.relu = nn.ReLU(inplace=True) 126 | self.sigmoid = nn.Sigmoid() 127 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 128 | 129 | 130 | if enable_TCM: 131 | self.tcm_layer = TCM(self.num_segments, expansion = block.expansion, pos=2) 132 | 133 | self.layer1 = self._make_layer(block, 64, layers[0], num_segments=num_segments) 134 | self.layer2 = self._make_layer(block, 128, layers[1], num_segments=num_segments, stride=2) 135 | self.layer3 = self._make_layer(block, 256, layers[2], num_segments=num_segments, stride=2) 136 | self.layer4 = self._make_layer(block, 512, layers[3], num_segments=num_segments, stride=2) 137 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 138 | # self.fc1 = nn.Linear(512 * block.expansion, num_classes) 139 | self.fc1 = nn.Conv1d(512*block.expansion, num_classes, kernel_size=1, stride=1, padding=0,bias=True) 140 | 141 | for m in self.modules(): 142 | if isinstance(m, nn.Conv2d): 143 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 144 | elif isinstance(m, nn.BatchNorm2d): 145 | nn.init.constant_(m.weight, 1) 146 | nn.init.constant_(m.bias, 0) 147 | elif isinstance(m, nn.BatchNorm3d): 148 | nn.init.constant_(m.weight, 1) 149 | nn.init.constant_(m.bias, 0) 150 | 151 | # Zero-initialize the last BN in each residual branch, 152 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 153 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 154 | if zero_init_residual: 155 | for m in self.modules(): 156 | if isinstance(m, Bottleneck): 157 | nn.init.constant_(m.bn3.weight, 0) 158 | elif isinstance(m, BasicBlock): 159 | nn.init.constant_(m.bn2.weight, 0) 160 | 161 | 162 | def _make_layer(self, block, planes, blocks, num_segments, stride=1): 163 | downsample = None 164 | if stride != 1 or self.inplanes != planes * block.expansion: 165 | downsample = nn.Sequential( 166 | conv1x1(self.inplanes, planes * block.expansion, stride), 167 | nn.BatchNorm2d(planes * block.expansion), 168 | ) 169 | 170 | layers = [] 171 | layers.append(block(self.inplanes, planes, num_segments, stride, downsample)) 172 | self.inplanes = planes * block.expansion 173 | for i in range(1, blocks): 174 | remainder =int( i % 3) 175 | layers.append(block(self.inplanes, planes, num_segments, remainder=remainder)) 176 | 177 | return nn.Sequential(*layers) 178 | 179 | 180 | def forward(self, x, temperature): 181 | input =x 182 | 183 | x = self.conv1(x) 184 | x = self.bn1(x) 185 | x = self.relu(x) 186 | x = self.maxpool(x) 187 | 188 | x = self.layer1(x) 189 | x = self.layer2(x) 190 | 191 | # Flow 192 | if (self.enable_TCM == 1): 193 | x = self.tcm_layer(x) 194 | 195 | x = self.layer3(x) 196 | x = self.layer4(x) 197 | x = self.avgpool(x) 198 | x = x.view(x.size(0), -1,1) 199 | 200 | x = self.fc1(x) 201 | return x 202 | 203 | 204 | def resnet18(pretrained=False, shift='TSM',num_segments = 8, enable_TCM=0, **kwargs): 205 | """Constructs a ResNet-18 model. 
206 | 207 | Args: 208 | pretrained (bool): If True, returns a model pre-trained on ImageNet 209 | """ 210 | if (shift =='TSM'): 211 | model = ResNet(BasicBlock, [2, 2, 2, 2], num_segments=num_segments , enable_TCM=enable_TCM, **kwargs) 212 | if pretrained: 213 | pretrained_dict = model_zoo.load_url(model_urls['resnet18']) 214 | new_state_dict = model.state_dict() 215 | for k, v in pretrained_dict.items(): 216 | if (k in new_state_dict): 217 | new_state_dict.update({k:v}) 218 | # print ("%s layer has pretrained weights" % k) 219 | model.load_state_dict(new_state_dict) 220 | return model 221 | 222 | 223 | def resnet34(pretrained=False, shift='TSM',num_segments = 8, enable_TCM=0,**kwargs): 224 | """Constructs a ResNet-34 model. 225 | 226 | Args: 227 | pretrained (bool): If True, returns a model pre-trained on ImageNet 228 | """ 229 | if (shift =='TSM'): 230 | model = ResNet(BasicBlock, [3, 4, 6, 3],num_segments=num_segments , enable_TCM=enable_TCM, **kwargs) 231 | if pretrained: 232 | pretrained_dict = model_zoo.load_url(model_urls['resnet34']) 233 | new_state_dict = model.state_dict() 234 | for k, v in pretrained_dict.items(): 235 | if (k in new_state_dict): 236 | new_state_dict.update({k:v}) 237 | # print ("%s layer has pretrained weights" % k) 238 | model.load_state_dict(new_state_dict) 239 | return model 240 | 241 | 242 | def resnet50(pretrained=False, shift='TSM', num_segments = 8, enable_TCM=0, **kwargs): 243 | """Constructs a ResNet-50 model. 244 | 245 | Args: 246 | pretrained (bool): If True, returns a model pre-trained on ImageNet 247 | """ 248 | if (shift =='TSM'): 249 | model = ResNet(Bottleneck, [3, 4, 6, 3],num_segments=num_segments , enable_TCM=enable_TCM, **kwargs) 250 | if pretrained: 251 | pretrained_dict = model_zoo.load_url(model_urls['resnet50']) 252 | new_state_dict = model.state_dict() 253 | for k, v in pretrained_dict.items(): 254 | if (k in new_state_dict): 255 | new_state_dict.update({k:v}) 256 | # print ("%s layer has pretrained weights" % k) 257 | model.load_state_dict(new_state_dict) 258 | return model 259 | 260 | 261 | def resnet101(pretrained=False, shift='TSM',num_segments = 8, enable_TCM=0, **kwargs): 262 | """Constructs a ResNet-101 model. 263 | 264 | Args: 265 | pretrained (bool): If True, returns a model pre-trained on ImageNet 266 | """ 267 | if (shift =='TSM'): 268 | model = ResNet(Bottleneck, [3, 4, 23, 3],num_segments=num_segments , enable_TCM=enable_TCM, **kwargs) 269 | if pretrained: 270 | pretrained_dict = model_zoo.load_url(model_urls['resnet101']) 271 | new_state_dict = model.state_dict() 272 | for k, v in pretrained_dict.items(): 273 | if (k in new_state_dict): 274 | new_state_dict.update({k:v}) 275 | # print ("%s layer has pretrained weights" % k) 276 | model.load_state_dict(new_state_dict) 277 | return model 278 | -------------------------------------------------------------------------------- /tsm_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def tsm(tensor, duration, version='zero'): 5 | # tensor [N*T, C, H, W] 6 | size = tensor.size() 7 | tensor = tensor.view((-1, duration) + size[1:]) 8 | # tensor [N, T, C, H, W] 9 | pre_tensor, post_tensor, peri_tensor = tensor.split([size[1] // 8, 10 | size[1] // 8, 11 | 3*size[1] // 4], dim=2) 12 | if version == 'zero': 13 | pre_tensor = F.pad(pre_tensor, (0, 0, 0, 0, 0, 0, 0, 1))[:, 1: , ...] #F.pad(pre_tensor, (0, 0, 0, 0, 0, 0, 1, 0))[:, :-1, ...] 
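        # F.pad pads the trailing dims first, so the 8-value tuple above appends one zero frame at the
        # end of the T axis; slicing off frame 0 with [:, 1:, ...] makes this C//8 chunk read from the
        # next frame. The mirrored pad/slice below shifts the second C//8 chunk to read from the
        # previous frame, while the remaining 3C//4 channels (peri_tensor) are left in place.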
14 | post_tensor = F.pad(post_tensor, (0, 0, 0, 0, 0, 0, 1, 0))[:, :-1, ...] #F.pad(post_tensor, (0, 0, 0, 0, 0, 0, 0, 1))[:, 1: , ...] 15 | elif version == 'circulant': 16 | pre_tensor = torch.cat((pre_tensor [:, -1: , ...], 17 | pre_tensor [:, :-1, ...]), dim=1) 18 | post_tensor = torch.cat((post_tensor[:, 1: , ...], 19 | post_tensor[:, :1 , ...]), dim=1) 20 | else: 21 | raise ValueError('Unknown TSM version: {}'.format(version)) 22 | return torch.cat((pre_tensor, post_tensor, peri_tensor), dim=2).view(size) --------------------------------------------------------------------------------
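Usage sketch for the `tsm` helper above (illustrative only, not part of the repository; the toy shapes and the root-level `from tsm_util import tsm` import are assumptions): with `version='zero'`, the first C/8 channels of each frame read from the next frame, the second C/8 read from the previous frame, and the remaining 3C/4 are untouched, with zero-filling at the clip boundaries.
```python
import torch

from tsm_util import tsm  # assumes the repository root is on PYTHONPATH

# Toy input: N clips of T frames, folded into the batch dim as in TSN/TSM.
N, T, C, H, W = 2, 8, 64, 7, 7          # C must be divisible by 8 for the channel split
x = torch.randn(N * T, C, H, W)

y = tsm(x, T, version='zero')           # output keeps the [N*T, C, H, W] shape
assert y.shape == x.shape

xv = x.view(N, T, C, H, W)
yv = y.view(N, T, C, H, W)
# First C//8 channels are shifted to read from the next frame (last frame becomes zeros).
assert torch.equal(yv[:, :-1, :C // 8], xv[:, 1:, :C // 8])
# Next C//8 channels read from the previous frame (first frame becomes zeros).
assert torch.equal(yv[:, 1:, C // 8:C // 4], xv[:, :-1, C // 8:C // 4])
# The remaining 3C//4 channels are untouched.
assert torch.equal(yv[:, :, C // 4:], xv[:, :, C // 4:])
```
This mirrors how `BasicBlock` and `Bottleneck` in `resnet_TSM.py` apply `tsm(x, self.num_segments, 'zero')` before their first convolution.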