├── .gitignore ├── LICENSE ├── README.md ├── asset └── arch.png ├── backbone ├── convrnn.py ├── resnet_2d3d.py └── select_backbone.py ├── dpc ├── dataset_3d.py ├── main.py └── model_3d.py ├── eval ├── dataset_3d_lc.py ├── model_3d_lc.py └── test.py ├── process_data ├── readme.md └── src │ ├── extract_frame.py │ └── write_csv.py └── utils ├── augmentation.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | */__pycache__/ 2 | */*/__pycache__/ 3 | *.pyc 4 | *.pth.tar 5 | *.pth 6 | *.cluster.local 7 | *tfevent* 8 | *.png 9 | *.pkl 10 | */tmp/ 11 | process_data/data/ 12 | !asset/*.png 13 | #!asset/*/*.png 14 | #!asset/*/*/*.png 15 | *.nfs* 16 | */model/* 17 | *.tsv 18 | *.pbtxt 19 | *.svg 20 | */test_log.md 21 | *test_log.md 22 | *notes.md 23 | *.pdf 24 | *.csv 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Tengda Han 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Video Representation Learning by Dense Predictive Coding 2 | 3 | This repository contains the implementation of Dense Predictive Coding (DPC). 4 | 5 | Links: [[Arxiv](https://arxiv.org/abs/1909.04656)] [[Video](https://youtu.be/43KIHUvHjB0)] [[Project page](http://www.robots.ox.ac.uk/~vgg/research/DPC/dpc.html)] 6 | 7 | ![arch](asset/arch.png) 8 | 9 | ### DPC Results 10 | 11 | Original result from [our paper](https://arxiv.org/abs/1909.04656): 12 | 13 | | Pretrain Dataset| Resolution | Backbone | Finetune Acc@1 (UCF101) | Finetune Acc@1 (HMDB51) | 14 | |----|----|----|----|----| 15 | |UCF101|128x128|2d3d-R18|60.6|-| 16 | |Kinetics400|128x128|2d3d-R18|68.2|34.5| 17 | |Kinetics400|224x224|2d3d-R34|75.7|35.7| 18 | 19 | Also re-implemented by other researchers: 20 | | Pretrain Dataset| Resolution | Backbone | Finetune Acc@1 (UCF101) | Finetune Acc@1 (HMDB51) | 21 | |----|----|----|----|----| 22 | |UCF101|128x128|2d3d-R18|61.35 [@kayush95](https://github.com/kayush95) |45.31 [@kayush95](https://github.com/kayush95) | 23 | 24 | ### News 25 | * 2020/10/09: Upload [3D-ResNet18-UCF101-128x128](http://www.robots.ox.ac.uk/~htd/dpc/ucf101-rgb-128_resnet18_dpc.pth.tar) pretrained weights. 
26 | * 2020/06/05: Update the link for [3D-ResNet34-Kinetics400-224x224-runningStats](https://drive.google.com/file/d/1-WpsKzPNmSWuzoF2_qVfvfLOE1fwWD4x/view?usp=sharing): the [old one](https://drive.google.com/file/d/1d2XhuUwGTgEBg2cKkQbfJG8omHaSlELZ/view?usp=sharing) did not save the BatchNorm running statistics, so it could not be used for linear-probe evaluation. The new checkpoint saves them (no weights are changed). 27 | 28 | ### Installation 29 | 30 | The implementation should work with python >= 3.6, pytorch >= 0.4, torchvision >= 0.2.2. 31 | 32 | The repo also requires cv2 (`conda install -c menpo opencv`), tensorboardX >= 1.7 (`pip install tensorboardX`), joblib, tqdm, ipdb. 33 | 34 | ### Prepare data 35 | 36 | Follow the instructions [here](process_data/). 37 | 38 | ### Self-supervised training (DPC) 39 | 40 | Change directory: `cd DPC/dpc/` 41 | 42 | * example: train DPC-RNN using 2 GPUs, with a 3D-ResNet18 backbone, on the Kinetics400 dataset at 128x128 resolution, for 300 epochs 43 | ``` 44 | python main.py --gpu 0,1 --net resnet18 --dataset k400 --batch_size 128 --img_dim 128 --epochs 300 45 | ``` 46 | 47 | * example: train DPC-RNN using 4 GPUs, with a 3D-ResNet34 backbone, on the Kinetics400 dataset at 224x224 resolution, for 150 epochs 48 | ``` 49 | python main.py --gpu 0,1,2,3 --net resnet34 --dataset k400 --batch_size 44 --img_dim 224 --epochs 150 50 | ``` 51 | 52 | ### Evaluation: supervised action classification 53 | 54 | Change directory: `cd DPC/eval/` 55 | 56 | * example: finetune pretrained DPC weights (replace `{model.pth.tar}` with the pretrained DPC model) 57 | ``` 58 | python test.py --gpu 0,1 --net resnet18 --dataset ucf101 --batch_size 128 --img_dim 128 --pretrain {model.pth.tar} --train_what ft --epochs 300 59 | ``` 60 | 61 | * example (continued): test the finetuned model (replace `{finetune_model.pth.tar}` with the finetuned classifier model) 62 | ``` 63 | python test.py --gpu 0,1 --net resnet18 --dataset ucf101 --batch_size 128 --img_dim 128 --test {finetune_model.pth.tar} 64 | ``` 65 | 66 | ### DPC-pretrained weights 67 | 68 | It took us **more than 1 week** to train the 3D-ResNet18 DPC model on Kinetics-400 at 128x128 resolution, and about **6 weeks** to train the 3D-ResNet34 DPC model on Kinetics-400 at 224x224 resolution (with 4 Nvidia P40 GPUs). 
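Training for this long is rarely done in a single run; `main.py` can pick up an interrupted run from its latest checkpoint via the `--resume` flag, which restores the epoch, iteration and optimizer state. The checkpoint path below is a placeholder:
```
python main.py --gpu 0,1 --net resnet18 --dataset k400 --batch_size 128 --img_dim 128 --epochs 300 --resume {epoch_xxx.pth.tar}
```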
69 | 70 | Download link: 71 | * Kinetics400 pretrain: 72 | - [3D-ResNet18-Kinetics400-128x128](https://drive.google.com/file/d/1jbMg2EAX8armIQA6_0YwfATh_h7rQz4u/view?usp=sharing), 73 | - [3D-ResNet34-Kinetics400-224x224](https://drive.google.com/file/d/1d2XhuUwGTgEBg2cKkQbfJG8omHaSlELZ/view?usp=sharing), 74 | - [3D-ResNet34-Kinetics400-224x224-runningStats](https://drive.google.com/file/d/1-WpsKzPNmSWuzoF2_qVfvfLOE1fwWD4x/view?usp=sharing) 75 | * UCF101 pretrain: 76 | - [3D-ResNet18-UCF101-128x128](http://www.robots.ox.ac.uk/~htd/dpc/ucf101-rgb-128_resnet18_dpc.pth.tar) 77 | 78 | * example: finetune `3D-ResNet34-Kinetics400-224x224` 79 | ``` 80 | python test.py --gpu 0,1 --net resnet34 --dataset ucf101 --batch_size 44 --img_dim 224 --pretrain {model.pth.tar} --train_what ft --epochs 300 81 | ``` 82 | 83 | ### Citation 84 | 85 | If you find the repo useful for your research, please consider citing our paper: 86 | ``` 87 | @InProceedings{Han19dpc, 88 | author = "Tengda Han and Weidi Xie and Andrew Zisserman", 89 | title = "Video Representation Learning by Dense Predictive Coding", 90 | booktitle = "Workshop on Large Scale Holistic Video Understanding, ICCV", 91 | year = "2019", 92 | } 93 | ``` 94 | For any questions, welcome to create an issue or contact Tengda Han ([htd@robots.ox.ac.uk](mailto:htd@robots.ox.ac.uk)). 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /asset/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TengdaHan/DPC/1592e5f5443c1a121d61c5b3a27b3be43f0e5fb5/asset/arch.png -------------------------------------------------------------------------------- /backbone/convrnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ConvGRUCell(nn.Module): 5 | ''' Initialize ConvGRU cell ''' 6 | def __init__(self, input_size, hidden_size, kernel_size): 7 | super(ConvGRUCell, self).__init__() 8 | self.input_size = input_size 9 | self.hidden_size = hidden_size 10 | self.kernel_size = kernel_size 11 | padding = kernel_size // 2 12 | 13 | self.reset_gate = nn.Conv2d(input_size+hidden_size, hidden_size, kernel_size, padding=padding) 14 | self.update_gate = nn.Conv2d(input_size+hidden_size, hidden_size, kernel_size, padding=padding) 15 | self.out_gate = nn.Conv2d(input_size+hidden_size, hidden_size, kernel_size, padding=padding) 16 | 17 | nn.init.orthogonal_(self.reset_gate.weight) 18 | nn.init.orthogonal_(self.update_gate.weight) 19 | nn.init.orthogonal_(self.out_gate.weight) 20 | nn.init.constant_(self.reset_gate.bias, 0.) 21 | nn.init.constant_(self.update_gate.bias, 0.) 22 | nn.init.constant_(self.out_gate.bias, 0.) 
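        # The three convolutional gates above implement a standard ConvGRU update (see forward()),
        # where [x, h] denotes channel-wise concatenation of input_tensor and hidden_state:
        #   update = sigmoid(update_gate([x, h]))
        #   reset  = sigmoid(reset_gate([x, h]))
        #   out    = tanh(out_gate([x, h * reset]))
        #   h_new  = h * (1 - update) + out * update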
23 | 24 | def forward(self, input_tensor, hidden_state): 25 | if hidden_state is None: 26 | B, C, *spatial_dim = input_tensor.size() 27 | hidden_state = torch.zeros([B,self.hidden_size,*spatial_dim]).cuda() 28 | # [B, C, H, W] 29 | combined = torch.cat([input_tensor, hidden_state], dim=1) #concat in C 30 | update = torch.sigmoid(self.update_gate(combined)) 31 | reset = torch.sigmoid(self.reset_gate(combined)) 32 | out = torch.tanh(self.out_gate(torch.cat([input_tensor, hidden_state * reset], dim=1))) 33 | new_state = hidden_state * (1 - update) + out * update 34 | return new_state 35 | 36 | 37 | class ConvGRU(nn.Module): 38 | ''' Initialize a multi-layer Conv GRU ''' 39 | def __init__(self, input_size, hidden_size, kernel_size, num_layers, dropout=0.1): 40 | super(ConvGRU, self).__init__() 41 | self.input_size = input_size 42 | self.hidden_size = hidden_size 43 | self.kernel_size = kernel_size 44 | self.num_layers = num_layers 45 | 46 | cell_list = [] 47 | for i in range(self.num_layers): 48 | if i == 0: 49 | input_dim = self.input_size 50 | else: 51 | input_dim = self.hidden_size 52 | cell = ConvGRUCell(input_dim, self.hidden_size, self.kernel_size) 53 | name = 'ConvGRUCell_' + str(i).zfill(2) 54 | 55 | setattr(self, name, cell) 56 | cell_list.append(getattr(self, name)) 57 | 58 | self.cell_list = nn.ModuleList(cell_list) 59 | self.dropout_layer = nn.Dropout(p=dropout) 60 | 61 | 62 | def forward(self, x, hidden_state=None): 63 | [B, seq_len, *_] = x.size() 64 | 65 | if hidden_state is None: 66 | hidden_state = [None] * self.num_layers 67 | # input: image sequences [B, T, C, H, W] 68 | current_layer_input = x 69 | del x 70 | 71 | last_state_list = [] 72 | 73 | for idx in range(self.num_layers): 74 | cell_hidden = hidden_state[idx] 75 | output_inner = [] 76 | for t in range(seq_len): 77 | cell_hidden = self.cell_list[idx](current_layer_input[:,t,:], cell_hidden) 78 | cell_hidden = self.dropout_layer(cell_hidden) # dropout in each time step 79 | output_inner.append(cell_hidden) 80 | 81 | layer_output = torch.stack(output_inner, dim=1) 82 | current_layer_input = layer_output 83 | 84 | last_state_list.append(cell_hidden) 85 | 86 | last_state_list = torch.stack(last_state_list, dim=1) 87 | 88 | return layer_output, last_state_list 89 | 90 | 91 | if __name__ == '__main__': 92 | crnn = ConvGRU(input_size=10, hidden_size=20, kernel_size=3, num_layers=2) 93 | data = torch.randn(4, 5, 10, 6, 6) # [B, seq_len, C, H, W], temporal axis=1 94 | output, hn = crnn(data) 95 | import ipdb; ipdb.set_trace() 96 | -------------------------------------------------------------------------------- /backbone/resnet_2d3d.py: -------------------------------------------------------------------------------- 1 | ## modified from https://github.com/kenshohara/3D-ResNets-PyTorch 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import math 7 | 8 | __all__ = [ 9 | 'ResNet2d3d_full', 'resnet18_2d3d_full', 'resnet34_2d3d_full', 'resnet50_2d3d_full', 'resnet101_2d3d_full', 10 | 'resnet152_2d3d_full', 'resnet200_2d3d_full', 11 | ] 12 | 13 | def conv3x3x3(in_planes, out_planes, stride=1, bias=False): 14 | # 3x3x3 convolution with padding 15 | return nn.Conv3d( 16 | in_planes, 17 | out_planes, 18 | kernel_size=3, 19 | stride=stride, 20 | padding=1, 21 | bias=bias) 22 | 23 | def conv1x3x3(in_planes, out_planes, stride=1, bias=False): 24 | # 1x3x3 convolution with padding 25 | return nn.Conv3d( 26 | in_planes, 27 | out_planes, 28 | kernel_size=(1,3,3), 29 | 
stride=(1,stride,stride), 30 | padding=(0,1,1), 31 | bias=bias) 32 | 33 | 34 | def downsample_basic_block(x, planes, stride): 35 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 36 | zero_pads = torch.Tensor( 37 | out.size(0), planes - out.size(1), out.size(2), out.size(3), 38 | out.size(4)).zero_() 39 | if isinstance(out.data, torch.cuda.FloatTensor): 40 | zero_pads = zero_pads.cuda() 41 | 42 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 43 | 44 | return out 45 | 46 | 47 | class BasicBlock3d(nn.Module): 48 | expansion = 1 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None, track_running_stats=True, use_final_relu=True): 51 | super(BasicBlock3d, self).__init__() 52 | bias = False 53 | self.use_final_relu = use_final_relu 54 | self.conv1 = conv3x3x3(inplanes, planes, stride, bias=bias) 55 | self.bn1 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 56 | 57 | self.relu = nn.ReLU(inplace=True) 58 | self.conv2 = conv3x3x3(planes, planes, bias=bias) 59 | self.bn2 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 60 | 61 | self.downsample = downsample 62 | self.stride = stride 63 | 64 | def forward(self, x): 65 | residual = x 66 | 67 | out = self.conv1(x) 68 | out = self.bn1(out) 69 | out = self.relu(out) 70 | 71 | out = self.conv2(out) 72 | out = self.bn2(out) 73 | 74 | if self.downsample is not None: 75 | residual = self.downsample(x) 76 | 77 | out += residual 78 | if self.use_final_relu: out = self.relu(out) 79 | 80 | return out 81 | 82 | 83 | class BasicBlock2d(nn.Module): 84 | expansion = 1 85 | 86 | def __init__(self, inplanes, planes, stride=1, downsample=None, track_running_stats=True, use_final_relu=True): 87 | super(BasicBlock2d, self).__init__() 88 | bias = False 89 | self.use_final_relu = use_final_relu 90 | self.conv1 = conv1x3x3(inplanes, planes, stride, bias=bias) 91 | self.bn1 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 92 | 93 | self.relu = nn.ReLU(inplace=True) 94 | self.conv2 = conv1x3x3(planes, planes, bias=bias) 95 | self.bn2 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 96 | 97 | self.downsample = downsample 98 | self.stride = stride 99 | 100 | def forward(self, x): 101 | residual = x 102 | 103 | out = self.conv1(x) 104 | out = self.bn1(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv2(out) 108 | out = self.bn2(out) 109 | 110 | if self.downsample is not None: 111 | residual = self.downsample(x) 112 | 113 | out += residual 114 | if self.use_final_relu: out = self.relu(out) 115 | 116 | return out 117 | 118 | 119 | class Bottleneck3d(nn.Module): 120 | expansion = 4 121 | 122 | def __init__(self, inplanes, planes, stride=1, downsample=None, track_running_stats=True, use_final_relu=True): 123 | super(Bottleneck3d, self).__init__() 124 | bias = False 125 | self.use_final_relu = use_final_relu 126 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=bias) 127 | self.bn1 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 128 | 129 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=bias) 130 | self.bn2 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 131 | 132 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=bias) 133 | self.bn3 = nn.BatchNorm3d(planes * 4, track_running_stats=track_running_stats) 134 | 135 | self.relu = nn.ReLU(inplace=True) 136 | self.downsample = downsample 137 | self.stride = stride 138 | 139 | def forward(self, x): 140 | residual = x 141 | 142 | 
out = self.conv1(x) 143 | out = self.bn1(out) 144 | out = self.relu(out) 145 | 146 | out = self.conv2(out) 147 | out = self.bn2(out) 148 | out = self.relu(out) 149 | 150 | out = self.conv3(out) 151 | out = self.bn3(out) 152 | 153 | if self.downsample is not None: 154 | residual = self.downsample(x) 155 | 156 | out += residual 157 | if self.use_final_relu: out = self.relu(out) 158 | 159 | return out 160 | 161 | 162 | class Bottleneck2d(nn.Module): 163 | expansion = 4 164 | 165 | def __init__(self, inplanes, planes, stride=1, downsample=None, track_running_stats=True, use_final_relu=True): 166 | super(Bottleneck2d, self).__init__() 167 | bias = False 168 | self.use_final_relu = use_final_relu 169 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=bias) 170 | self.bn1 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 171 | 172 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1,3,3), stride=(1,stride,stride), padding=(0,1,1), bias=bias) 173 | self.bn2 = nn.BatchNorm3d(planes, track_running_stats=track_running_stats) 174 | 175 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=bias) 176 | self.bn3 = nn.BatchNorm3d(planes * 4, track_running_stats=track_running_stats) 177 | 178 | self.relu = nn.ReLU(inplace=True) 179 | self.downsample = downsample 180 | self.stride = stride 181 | 182 | def forward(self, x): 183 | residual = x 184 | 185 | out = self.conv1(x) 186 | out = self.bn1(out) 187 | out = self.relu(out) 188 | 189 | out = self.conv2(out) 190 | out = self.bn2(out) 191 | out = self.relu(out) 192 | 193 | out = self.conv3(out) 194 | out = self.bn3(out) 195 | 196 | if self.downsample is not None: 197 | residual = self.downsample(x) 198 | 199 | out += residual 200 | if self.use_final_relu: out = self.relu(out) 201 | 202 | return out 203 | 204 | 205 | class ResNet2d3d_full(nn.Module): 206 | def __init__(self, block, layers, track_running_stats=True): 207 | super(ResNet2d3d_full, self).__init__() 208 | self.inplanes = 64 209 | self.track_running_stats = track_running_stats 210 | bias = False 211 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(1,7,7), stride=(1, 2, 2), padding=(0, 3, 3), bias=bias) 212 | self.bn1 = nn.BatchNorm3d(64, track_running_stats=track_running_stats) 213 | self.relu = nn.ReLU(inplace=True) 214 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) 215 | 216 | if not isinstance(block, list): 217 | block = [block] * 4 218 | 219 | self.layer1 = self._make_layer(block[0], 64, layers[0]) 220 | self.layer2 = self._make_layer(block[1], 128, layers[1], stride=2) 221 | self.layer3 = self._make_layer(block[2], 256, layers[2], stride=2) 222 | self.layer4 = self._make_layer(block[3], 256, layers[3], stride=2, is_final=True) 223 | # modify layer4 from exp=512 to exp=256 224 | for m in self.modules(): 225 | if isinstance(m, nn.Conv3d): 226 | m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out') 227 | if m.bias is not None: m.bias.data.zero_() 228 | elif isinstance(m, nn.BatchNorm3d): 229 | m.weight.data.fill_(1) 230 | m.bias.data.zero_() 231 | 232 | def _make_layer(self, block, planes, blocks, stride=1, is_final=False): 233 | downsample = None 234 | if stride != 1 or self.inplanes != planes * block.expansion: 235 | # customized_stride to deal with 2d or 3d residual blocks 236 | if (block == Bottleneck2d) or (block == BasicBlock2d): 237 | customized_stride = (1, stride, stride) 238 | else: 239 | customized_stride = stride 240 | 241 | downsample = nn.Sequential( 242 | nn.Conv3d(self.inplanes, planes * 
block.expansion, kernel_size=1, stride=customized_stride, bias=False), 243 | nn.BatchNorm3d(planes * block.expansion, track_running_stats=self.track_running_stats) 244 | ) 245 | 246 | layers = [] 247 | layers.append(block(self.inplanes, planes, stride, downsample, track_running_stats=self.track_running_stats)) 248 | self.inplanes = planes * block.expansion 249 | if is_final: # if is final block, no ReLU in the final output 250 | for i in range(1, blocks-1): 251 | layers.append(block(self.inplanes, planes, track_running_stats=self.track_running_stats)) 252 | layers.append(block(self.inplanes, planes, track_running_stats=self.track_running_stats, use_final_relu=False)) 253 | else: 254 | for i in range(1, blocks): 255 | layers.append(block(self.inplanes, planes, track_running_stats=self.track_running_stats)) 256 | 257 | return nn.Sequential(*layers) 258 | 259 | def forward(self, x): 260 | x = self.conv1(x) 261 | x = self.bn1(x) 262 | x = self.relu(x) 263 | x = self.maxpool(x) 264 | 265 | x = self.layer1(x) 266 | x = self.layer2(x) 267 | x = self.layer3(x) 268 | x = self.layer4(x) 269 | 270 | return x 271 | 272 | 273 | ## full resnet 274 | def resnet18_2d3d_full(**kwargs): 275 | '''Constructs a ResNet-18 model. ''' 276 | model = ResNet2d3d_full([BasicBlock2d, BasicBlock2d, BasicBlock3d, BasicBlock3d], 277 | [2, 2, 2, 2], **kwargs) 278 | return model 279 | 280 | def resnet34_2d3d_full(**kwargs): 281 | '''Constructs a ResNet-34 model. ''' 282 | model = ResNet2d3d_full([BasicBlock2d, BasicBlock2d, BasicBlock3d, BasicBlock3d], 283 | [3, 4, 6, 3], **kwargs) 284 | return model 285 | 286 | def resnet50_2d3d_full(**kwargs): 287 | '''Constructs a ResNet-50 model. ''' 288 | model = ResNet2d3d_full([Bottleneck2d, Bottleneck2d, Bottleneck3d, Bottleneck3d], 289 | [3, 4, 6, 3], **kwargs) 290 | return model 291 | 292 | def resnet101_2d3d_full(**kwargs): 293 | '''Constructs a ResNet-101 model. ''' 294 | model = ResNet2d3d_full([Bottleneck2d, Bottleneck2d, Bottleneck3d, Bottleneck3d], 295 | [3, 4, 23, 3], **kwargs) 296 | return model 297 | 298 | def resnet152_2d3d_full(**kwargs): 299 | '''Constructs a ResNet-152 model. ''' 300 | model = ResNet2d3d_full([Bottleneck2d, Bottleneck2d, Bottleneck3d, Bottleneck3d], 301 | [3, 8, 36, 3], **kwargs) 302 | return model 303 | 304 | def resnet200_2d3d_full(**kwargs): 305 | '''Constructs a ResNet-200 model. 
''' 306 | model = ResNet2d3d_full([Bottleneck2d, Bottleneck2d, Bottleneck3d, Bottleneck3d], 307 | [3, 24, 36, 3], **kwargs) 308 | return model 309 | 310 | def neq_load_customized(model, pretrained_dict): 311 | ''' load pre-trained model in a not-equal way, 312 | when new model has been partially modified ''' 313 | model_dict = model.state_dict() 314 | tmp = {} 315 | print('\n=======Check Weights Loading======') 316 | print('Weights not used from pretrained file:') 317 | for k, v in pretrained_dict.items(): 318 | if k in model_dict: 319 | tmp[k] = v 320 | else: 321 | print(k) 322 | print('---------------------------') 323 | print('Weights not loaded into new model:') 324 | for k, v in model_dict.items(): 325 | if k not in pretrained_dict: 326 | print(k) 327 | print('===================================\n') 328 | # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 329 | del pretrained_dict 330 | model_dict.update(tmp) 331 | del tmp 332 | model.load_state_dict(model_dict) 333 | return model 334 | 335 | 336 | if __name__ == '__main__': 337 | mymodel = resnet18_2d3d_full() 338 | mydata = torch.FloatTensor(4, 3, 16, 128, 128) 339 | nn.init.normal_(mydata) 340 | import ipdb; ipdb.set_trace() 341 | mymodel(mydata) 342 | -------------------------------------------------------------------------------- /backbone/select_backbone.py: -------------------------------------------------------------------------------- 1 | from resnet_2d3d import * 2 | 3 | def select_resnet(network, track_running_stats=True,): 4 | param = {'feature_size': 1024} 5 | if network == 'resnet18': 6 | model = resnet18_2d3d_full(track_running_stats=track_running_stats) 7 | param['feature_size'] = 256 8 | elif network == 'resnet34': 9 | model = resnet34_2d3d_full(track_running_stats=track_running_stats) 10 | param['feature_size'] = 256 11 | elif network == 'resnet50': 12 | model = resnet50_2d3d_full(track_running_stats=track_running_stats) 13 | elif network == 'resnet101': 14 | model = resnet101_2d3d_full(track_running_stats=track_running_stats) 15 | elif network == 'resnet152': 16 | model = resnet152_2d3d_full(track_running_stats=track_running_stats) 17 | elif network == 'resnet200': 18 | model = resnet200_2d3d_full(track_running_stats=track_running_stats) 19 | else: raise IOError('model type is wrong') 20 | 21 | return model, param -------------------------------------------------------------------------------- /dpc/dataset_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | from torchvision import transforms 4 | import os 5 | import sys 6 | import time 7 | import pickle 8 | import glob 9 | import csv 10 | import pandas as pd 11 | import numpy as np 12 | import cv2 13 | sys.path.append('../utils') 14 | from augmentation import * 15 | from tqdm import tqdm 16 | from joblib import Parallel, delayed 17 | 18 | def pil_loader(path): 19 | with open(path, 'rb') as f: 20 | with Image.open(f) as img: 21 | return img.convert('RGB') 22 | 23 | 24 | class Kinetics400_full_3d(data.Dataset): 25 | def __init__(self, 26 | mode='train', 27 | transform=None, 28 | seq_len=10, 29 | num_seq=5, 30 | downsample=3, 31 | epsilon=5, 32 | unit_test=False, 33 | big=False, 34 | return_label=False): 35 | self.mode = mode 36 | self.transform = transform 37 | self.seq_len = seq_len 38 | self.num_seq = num_seq 39 | self.downsample = downsample 40 | self.epsilon = epsilon 41 | self.unit_test = unit_test 42 | self.return_label = return_label 43 | 44 
| if big: print('Using Kinetics400 full data (256x256)') 45 | else: print('Using Kinetics400 full data (150x150)') 46 | 47 | # get action list 48 | self.action_dict_encode = {} 49 | self.action_dict_decode = {} 50 | action_file = os.path.join('../process_data/data/kinetics400', 'classInd.txt') 51 | action_df = pd.read_csv(action_file, sep=',', header=None) 52 | for _, row in action_df.iterrows(): 53 | act_id, act_name = row 54 | act_id = int(act_id) - 1 # let id start from 0 55 | self.action_dict_decode[act_id] = act_name 56 | self.action_dict_encode[act_name] = act_id 57 | 58 | # splits 59 | if big: 60 | if mode == 'train': 61 | split = '../process_data/data/kinetics400_256/train_split.csv' 62 | video_info = pd.read_csv(split, header=None) 63 | elif (mode == 'val') or (mode == 'test'): 64 | split = '../process_data/data/kinetics400_256/val_split.csv' 65 | video_info = pd.read_csv(split, header=None) 66 | else: raise ValueError('wrong mode') 67 | else: # small 68 | if mode == 'train': 69 | split = '../process_data/data/kinetics400/train_split.csv' 70 | video_info = pd.read_csv(split, header=None) 71 | elif (mode == 'val') or (mode == 'test'): 72 | split = '../process_data/data/kinetics400/val_split.csv' 73 | video_info = pd.read_csv(split, header=None) 74 | else: raise ValueError('wrong mode') 75 | 76 | drop_idx = [] 77 | print('filter out too short videos ...') 78 | for idx, row in tqdm(video_info.iterrows(), total=len(video_info)): 79 | vpath, vlen = row 80 | if vlen-self.num_seq*self.seq_len*self.downsample <= 0: 81 | drop_idx.append(idx) 82 | self.video_info = video_info.drop(drop_idx, axis=0) 83 | 84 | if mode == 'val': self.video_info = self.video_info.sample(frac=0.3, random_state=666) 85 | if self.unit_test: self.video_info = self.video_info.sample(32, random_state=666) 86 | # shuffle not necessary because use RandomSampler 87 | 88 | def idx_sampler(self, vlen, vpath): 89 | '''sample index from a video''' 90 | if vlen-self.num_seq*self.seq_len*self.downsample <= 0: return [None] 91 | n = 1 92 | start_idx = np.random.choice(range(vlen-self.num_seq*self.seq_len*self.downsample), n) 93 | seq_idx = np.expand_dims(np.arange(self.num_seq), -1)*self.downsample*self.seq_len + start_idx 94 | seq_idx_block = seq_idx + np.expand_dims(np.arange(self.seq_len),0)*self.downsample 95 | return [seq_idx_block, vpath] 96 | 97 | def __getitem__(self, index): 98 | vpath, vlen = self.video_info.iloc[index] 99 | items = self.idx_sampler(vlen, vpath) 100 | if items is None: print(vpath) 101 | 102 | idx_block, vpath = items 103 | assert idx_block.shape == (self.num_seq, self.seq_len) 104 | idx_block = idx_block.reshape(self.num_seq*self.seq_len) 105 | 106 | seq = [pil_loader(os.path.join(vpath, 'image_%05d.jpg' % (i+1))) for i in idx_block] 107 | t_seq = self.transform(seq) # apply same transform 108 | 109 | (C, H, W) = t_seq[0].size() 110 | t_seq = torch.stack(t_seq, 0) 111 | t_seq = t_seq.view(self.num_seq, self.seq_len, C, H, W).transpose(1,2) 112 | 113 | if self.return_label: 114 | try: 115 | vname = vpath.split('/')[-3] 116 | vid = self.encode_action(vname) 117 | except: 118 | vname = vpath.split('/')[-2] 119 | vid = self.encode_action(vname) 120 | 121 | label = torch.LongTensor([vid]) 122 | return t_seq, label 123 | 124 | return t_seq 125 | 126 | def __len__(self): 127 | return len(self.video_info) 128 | 129 | def encode_action(self, action_name): 130 | '''give action name, return category''' 131 | return self.action_dict_encode[action_name] 132 | 133 | def decode_action(self, action_code): 134 | 
'''give action code, return action name''' 135 | return self.action_dict_decode[action_code] 136 | 137 | 138 | class UCF101_3d(data.Dataset): 139 | def __init__(self, 140 | mode='train', 141 | transform=None, 142 | seq_len=10, 143 | num_seq = 5, 144 | downsample=3, 145 | epsilon=5, 146 | which_split=1, 147 | return_label=False): 148 | self.mode = mode 149 | self.transform = transform 150 | self.seq_len = seq_len 151 | self.num_seq = num_seq 152 | self.downsample = downsample 153 | self.epsilon = epsilon 154 | self.which_split = which_split 155 | self.return_label = return_label 156 | 157 | # splits 158 | if mode == 'train': 159 | split = '../process_data/data/ucf101/train_split%02d.csv' % self.which_split 160 | video_info = pd.read_csv(split, header=None) 161 | elif (mode == 'val') or (mode == 'test'): # use val for test 162 | split = '../process_data/data/ucf101/test_split%02d.csv' % self.which_split 163 | video_info = pd.read_csv(split, header=None) 164 | else: raise ValueError('wrong mode') 165 | 166 | # get action list 167 | self.action_dict_encode = {} 168 | self.action_dict_decode = {} 169 | action_file = os.path.join('../process_data/data/ucf101', 'classInd.txt') 170 | action_df = pd.read_csv(action_file, sep=' ', header=None) 171 | for _, row in action_df.iterrows(): 172 | act_id, act_name = row 173 | self.action_dict_decode[act_id] = act_name 174 | self.action_dict_encode[act_name] = act_id 175 | 176 | # filter out too short videos: 177 | drop_idx = [] 178 | for idx, row in video_info.iterrows(): 179 | vpath, vlen = row 180 | if vlen-self.num_seq*self.seq_len*self.downsample <= 0: 181 | drop_idx.append(idx) 182 | self.video_info = video_info.drop(drop_idx, axis=0) 183 | 184 | if mode == 'val': self.video_info = self.video_info.sample(frac=0.3) 185 | # shuffle not required due to external sampler 186 | 187 | def idx_sampler(self, vlen, vpath): 188 | '''sample index from a video''' 189 | if vlen-self.num_seq*self.seq_len*self.downsample <= 0: return [None] 190 | n = 1 191 | start_idx = np.random.choice(range(vlen-self.num_seq*self.seq_len*self.downsample), n) 192 | seq_idx = np.expand_dims(np.arange(self.num_seq), -1)*self.downsample*self.seq_len + start_idx 193 | seq_idx_block = seq_idx + np.expand_dims(np.arange(self.seq_len),0)*self.downsample 194 | return [seq_idx_block, vpath] 195 | 196 | 197 | def __getitem__(self, index): 198 | vpath, vlen = self.video_info.iloc[index] 199 | items = self.idx_sampler(vlen, vpath) 200 | if items is None: print(vpath) 201 | 202 | idx_block, vpath = items 203 | assert idx_block.shape == (self.num_seq, self.seq_len) 204 | idx_block = idx_block.reshape(self.num_seq*self.seq_len) 205 | 206 | seq = [pil_loader(os.path.join(vpath, 'image_%05d.jpg' % (i+1))) for i in idx_block] 207 | t_seq = self.transform(seq) # apply same transform 208 | 209 | (C, H, W) = t_seq[0].size() 210 | t_seq = torch.stack(t_seq, 0) 211 | t_seq = t_seq.view(self.num_seq, self.seq_len, C, H, W).transpose(1,2) 212 | 213 | if self.return_label: 214 | try: 215 | vname = vpath.split('/')[-3] 216 | vid = self.encode_action(vname) 217 | except: 218 | vname = vpath.split('/')[-2] 219 | vid = self.encode_action(vname) 220 | label = torch.LongTensor([vid]) 221 | return t_seq, label 222 | 223 | return t_seq 224 | 225 | def __len__(self): 226 | return len(self.video_info) 227 | 228 | def encode_action(self, action_name): 229 | '''give action name, return action code''' 230 | return self.action_dict_encode[action_name] 231 | 232 | def decode_action(self, action_code): 233 | '''give action 
code, return action name''' 234 | return self.action_dict_decode[action_code] 235 | 236 | -------------------------------------------------------------------------------- /dpc/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import re 5 | import argparse 6 | import numpy as np 7 | from tqdm import tqdm 8 | from tensorboardX import SummaryWriter 9 | import matplotlib.pyplot as plt 10 | plt.switch_backend('agg') 11 | 12 | sys.path.append('../utils') 13 | from dataset_3d import * 14 | from model_3d import * 15 | from resnet_2d3d import neq_load_customized 16 | from augmentation import * 17 | from utils import AverageMeter, save_checkpoint, denorm, calc_topk_accuracy 18 | 19 | import torch 20 | import torch.optim as optim 21 | from torch.utils import data 22 | from torchvision import datasets, models, transforms 23 | import torchvision.utils as vutils 24 | 25 | torch.backends.cudnn.benchmark = True 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--net', default='resnet18', type=str) 29 | parser.add_argument('--model', default='dpc-rnn', type=str) 30 | parser.add_argument('--dataset', default='ucf101', type=str) 31 | parser.add_argument('--seq_len', default=5, type=int, help='number of frames in each video block') 32 | parser.add_argument('--num_seq', default=8, type=int, help='number of video blocks') 33 | parser.add_argument('--pred_step', default=3, type=int) 34 | parser.add_argument('--ds', default=3, type=int, help='frame downsampling rate') 35 | parser.add_argument('--batch_size', default=4, type=int) 36 | parser.add_argument('--lr', default=1e-3, type=float, help='learning rate') 37 | parser.add_argument('--wd', default=1e-5, type=float, help='weight decay') 38 | parser.add_argument('--resume', default='', type=str, help='path of model to resume') 39 | parser.add_argument('--pretrain', default='', type=str, help='path of pretrained model') 40 | parser.add_argument('--epochs', default=10, type=int, help='number of total epochs to run') 41 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)') 42 | parser.add_argument('--gpu', default='0,1', type=str) 43 | parser.add_argument('--print_freq', default=5, type=int, help='frequency of printing output during training') 44 | parser.add_argument('--reset_lr', action='store_true', help='Reset learning rate when resume training?') 45 | parser.add_argument('--prefix', default='tmp', type=str, help='prefix of checkpoint filename') 46 | parser.add_argument('--train_what', default='all', type=str) 47 | parser.add_argument('--img_dim', default=128, type=int) 48 | 49 | def main(): 50 | torch.manual_seed(0) 51 | np.random.seed(0) 52 | global args; args = parser.parse_args() 53 | os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu) 54 | global cuda; cuda = torch.device('cuda') 55 | 56 | ### dpc model ### 57 | if args.model == 'dpc-rnn': 58 | model = DPC_RNN(sample_size=args.img_dim, 59 | num_seq=args.num_seq, 60 | seq_len=args.seq_len, 61 | network=args.net, 62 | pred_step=args.pred_step) 63 | else: raise ValueError('wrong model!') 64 | 65 | model = nn.DataParallel(model) 66 | model = model.to(cuda) 67 | global criterion; criterion = nn.CrossEntropyLoss() 68 | 69 | ### optimizer ### 70 | if args.train_what == 'last': 71 | for name, param in model.module.resnet.named_parameters(): 72 | param.requires_grad = False 73 | else: pass # train all layers 74 | 75 | print('\n===========Check 
Grad============') 76 | for name, param in model.named_parameters(): 77 | print(name, param.requires_grad) 78 | print('=================================\n') 79 | 80 | params = model.parameters() 81 | optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd) 82 | args.old_lr = None 83 | 84 | best_acc = 0 85 | global iteration; iteration = 0 86 | 87 | ### restart training ### 88 | if args.resume: 89 | if os.path.isfile(args.resume): 90 | args.old_lr = float(re.search('_lr(.+?)_', args.resume).group(1)) 91 | print("=> loading resumed checkpoint '{}'".format(args.resume)) 92 | checkpoint = torch.load(args.resume, map_location=torch.device('cpu')) 93 | args.start_epoch = checkpoint['epoch'] 94 | iteration = checkpoint['iteration'] 95 | best_acc = checkpoint['best_acc'] 96 | model.load_state_dict(checkpoint['state_dict']) 97 | if not args.reset_lr: # if didn't reset lr, load old optimizer 98 | optimizer.load_state_dict(checkpoint['optimizer']) 99 | else: print('==== Change lr from %f to %f ====' % (args.old_lr, args.lr)) 100 | print("=> loaded resumed checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) 101 | else: 102 | print("[Warning] no checkpoint found at '{}'".format(args.resume)) 103 | 104 | if args.pretrain: 105 | if os.path.isfile(args.pretrain): 106 | print("=> loading pretrained checkpoint '{}'".format(args.pretrain)) 107 | checkpoint = torch.load(args.pretrain, map_location=torch.device('cpu')) 108 | model = neq_load_customized(model, checkpoint['state_dict']) 109 | print("=> loaded pretrained checkpoint '{}' (epoch {})" 110 | .format(args.pretrain, checkpoint['epoch'])) 111 | else: 112 | print("=> no checkpoint found at '{}'".format(args.pretrain)) 113 | 114 | ### load data ### 115 | if args.dataset == 'ucf101': # designed for ucf101, short size=256, rand crop to 224x224 then scale to 128x128 116 | transform = transforms.Compose([ 117 | RandomHorizontalFlip(consistent=True), 118 | RandomCrop(size=224, consistent=True), 119 | Scale(size=(args.img_dim,args.img_dim)), 120 | RandomGray(consistent=False, p=0.5), 121 | ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=1.0), 122 | ToTensor(), 123 | Normalize() 124 | ]) 125 | elif args.dataset == 'k400': # designed for kinetics400, short size=150, rand crop to 128x128 126 | transform = transforms.Compose([ 127 | RandomSizedCrop(size=args.img_dim, consistent=True, p=1.0), 128 | RandomHorizontalFlip(consistent=True), 129 | RandomGray(consistent=False, p=0.5), 130 | ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=1.0), 131 | ToTensor(), 132 | Normalize() 133 | ]) 134 | 135 | train_loader = get_data(transform, 'train') 136 | val_loader = get_data(transform, 'val') 137 | 138 | # setup tools 139 | global de_normalize; de_normalize = denorm() 140 | global img_path; img_path, model_path = set_path(args) 141 | global writer_train 142 | try: # old version 143 | writer_val = SummaryWriter(log_dir=os.path.join(img_path, 'val')) 144 | writer_train = SummaryWriter(log_dir=os.path.join(img_path, 'train')) 145 | except: # v1.7 146 | writer_val = SummaryWriter(logdir=os.path.join(img_path, 'val')) 147 | writer_train = SummaryWriter(logdir=os.path.join(img_path, 'train')) 148 | 149 | ### main loop ### 150 | for epoch in range(args.start_epoch, args.epochs): 151 | train_loss, train_acc, train_accuracy_list = train(train_loader, model, optimizer, epoch) 152 | val_loss, val_acc, val_accuracy_list = validate(val_loader, model, epoch) 153 | 154 | # save curve 155 | 
writer_train.add_scalar('global/loss', train_loss, epoch) 156 | writer_train.add_scalar('global/accuracy', train_acc, epoch) 157 | writer_val.add_scalar('global/loss', val_loss, epoch) 158 | writer_val.add_scalar('global/accuracy', val_acc, epoch) 159 | writer_train.add_scalar('accuracy/top1', train_accuracy_list[0], epoch) 160 | writer_train.add_scalar('accuracy/top3', train_accuracy_list[1], epoch) 161 | writer_train.add_scalar('accuracy/top5', train_accuracy_list[2], epoch) 162 | writer_val.add_scalar('accuracy/top1', val_accuracy_list[0], epoch) 163 | writer_val.add_scalar('accuracy/top3', val_accuracy_list[1], epoch) 164 | writer_val.add_scalar('accuracy/top5', val_accuracy_list[2], epoch) 165 | 166 | # save check_point 167 | is_best = val_acc > best_acc; best_acc = max(val_acc, best_acc) 168 | save_checkpoint({'epoch': epoch+1, 169 | 'net': args.net, 170 | 'state_dict': model.state_dict(), 171 | 'best_acc': best_acc, 172 | 'optimizer': optimizer.state_dict(), 173 | 'iteration': iteration}, 174 | is_best, filename=os.path.join(model_path, 'epoch%s.pth.tar' % str(epoch+1)), keep_all=False) 175 | 176 | print('Training from ep %d to ep %d finished' % (args.start_epoch, args.epochs)) 177 | 178 | def process_output(mask): 179 | '''task mask as input, compute the target for contrastive loss''' 180 | # dot product is computed in parallel gpus, so get less easy neg, bounded by batch size in each gpu''' 181 | # mask meaning: -2: omit, -1: temporal neg (hard), 0: easy neg, 1: pos, -3: spatial neg 182 | (B, NP, SQ, B2, NS, _) = mask.size() # [B, P, SQ, B, N, SQ] 183 | target = mask == 1 184 | target.requires_grad = False 185 | return target, (B, B2, NS, NP, SQ) 186 | 187 | def train(data_loader, model, optimizer, epoch): 188 | losses = AverageMeter() 189 | accuracy = AverageMeter() 190 | accuracy_list = [AverageMeter(), AverageMeter(), AverageMeter()] 191 | model.train() 192 | global iteration 193 | 194 | for idx, input_seq in enumerate(data_loader): 195 | tic = time.time() 196 | input_seq = input_seq.to(cuda) 197 | B = input_seq.size(0) 198 | [score_, mask_] = model(input_seq) 199 | # visualize 200 | if (iteration == 0) or (iteration == args.print_freq): 201 | if B > 2: input_seq = input_seq[0:2,:] 202 | writer_train.add_image('input_seq', 203 | de_normalize(vutils.make_grid( 204 | input_seq.transpose(2,3).contiguous().view(-1,3,args.img_dim,args.img_dim), 205 | nrow=args.num_seq*args.seq_len)), 206 | iteration) 207 | del input_seq 208 | 209 | if idx == 0: target_, (_, B2, NS, NP, SQ) = process_output(mask_) 210 | 211 | # score is a 6d tensor: [B, P, SQ, B2, N, SQ] 212 | # similarity matrix is computed inside each gpu, thus here B == num_gpu * B2 213 | score_flattened = score_.view(B*NP*SQ, B2*NS*SQ) 214 | target_flattened = target_.view(B*NP*SQ, B2*NS*SQ).to(cuda) 215 | target_flattened = target_flattened.to(int).argmax(dim=1) 216 | 217 | loss = criterion(score_flattened, target_flattened) 218 | top1, top3, top5 = calc_topk_accuracy(score_flattened, target_flattened, (1,3,5)) 219 | 220 | accuracy_list[0].update(top1.item(), B) 221 | accuracy_list[1].update(top3.item(), B) 222 | accuracy_list[2].update(top5.item(), B) 223 | 224 | losses.update(loss.item(), B) 225 | accuracy.update(top1.item(), B) 226 | 227 | del score_ 228 | 229 | optimizer.zero_grad() 230 | loss.backward() 231 | optimizer.step() 232 | 233 | del loss 234 | 235 | if idx % args.print_freq == 0: 236 | print('Epoch: [{0}][{1}/{2}]\t' 237 | 'Loss {loss.val:.6f} ({loss.local_avg:.4f})\t' 238 | 'Acc: top1 {3:.4f}; top3 {4:.4f}; top5 
{5:.4f} T:{6:.2f}\t'.format( 239 | epoch, idx, len(data_loader), top1, top3, top5, time.time()-tic, loss=losses)) 240 | 241 | writer_train.add_scalar('local/loss', losses.val, iteration) 242 | writer_train.add_scalar('local/accuracy', accuracy.val, iteration) 243 | 244 | iteration += 1 245 | 246 | return losses.local_avg, accuracy.local_avg, [i.local_avg for i in accuracy_list] 247 | 248 | 249 | def validate(data_loader, model, epoch): 250 | losses = AverageMeter() 251 | accuracy = AverageMeter() 252 | accuracy_list = [AverageMeter(), AverageMeter(), AverageMeter()] 253 | model.eval() 254 | 255 | with torch.no_grad(): 256 | for idx, input_seq in tqdm(enumerate(data_loader), total=len(data_loader)): 257 | input_seq = input_seq.to(cuda) 258 | B = input_seq.size(0) 259 | [score_, mask_] = model(input_seq) 260 | del input_seq 261 | 262 | if idx == 0: target_, (_, B2, NS, NP, SQ) = process_output(mask_) 263 | 264 | # [B, P, SQ, B, N, SQ] 265 | score_flattened = score_.view(B*NP*SQ, B2*NS*SQ) 266 | target_flattened = target_.view(B*NP*SQ, B2*NS*SQ).to(cuda) 267 | target_flattened = target_flattened.to(int).argmax(dim=1) 268 | 269 | loss = criterion(score_flattened, target_flattened) 270 | top1, top3, top5 = calc_topk_accuracy(score_flattened, target_flattened, (1,3,5)) 271 | 272 | losses.update(loss.item(), B) 273 | accuracy.update(top1.item(), B) 274 | 275 | accuracy_list[0].update(top1.item(), B) 276 | accuracy_list[1].update(top3.item(), B) 277 | accuracy_list[2].update(top5.item(), B) 278 | 279 | print('[{0}/{1}] Loss {loss.local_avg:.4f}\t' 280 | 'Acc: top1 {2:.4f}; top3 {3:.4f}; top5 {4:.4f} \t'.format( 281 | epoch, args.epochs, *[i.avg for i in accuracy_list], loss=losses)) 282 | return losses.local_avg, accuracy.local_avg, [i.local_avg for i in accuracy_list] 283 | 284 | 285 | def get_data(transform, mode='train'): 286 | print('Loading data for "%s" ...' 
% mode) 287 | if args.dataset == 'k400': 288 | use_big_K400 = args.img_dim > 140 289 | dataset = Kinetics400_full_3d(mode=mode, 290 | transform=transform, 291 | seq_len=args.seq_len, 292 | num_seq=args.num_seq, 293 | downsample=5, 294 | big=use_big_K400) 295 | elif args.dataset == 'ucf101': 296 | dataset = UCF101_3d(mode=mode, 297 | transform=transform, 298 | seq_len=args.seq_len, 299 | num_seq=args.num_seq, 300 | downsample=args.ds) 301 | else: 302 | raise ValueError('dataset not supported') 303 | 304 | sampler = data.RandomSampler(dataset) 305 | 306 | if mode == 'train': 307 | data_loader = data.DataLoader(dataset, 308 | batch_size=args.batch_size, 309 | sampler=sampler, 310 | shuffle=False, 311 | num_workers=32, 312 | pin_memory=True, 313 | drop_last=True) 314 | elif mode == 'val': 315 | data_loader = data.DataLoader(dataset, 316 | batch_size=args.batch_size, 317 | sampler=sampler, 318 | shuffle=False, 319 | num_workers=32, 320 | pin_memory=True, 321 | drop_last=True) 322 | print('"%s" dataset size: %d' % (mode, len(dataset))) 323 | return data_loader 324 | 325 | def set_path(args): 326 | if args.resume: exp_path = os.path.dirname(os.path.dirname(args.resume)) 327 | else: 328 | exp_path = 'log_{args.prefix}/{args.dataset}-{args.img_dim}_{0}_{args.model}_\ 329 | bs{args.batch_size}_lr{1}_seq{args.num_seq}_pred{args.pred_step}_len{args.seq_len}_ds{args.ds}_\ 330 | train-{args.train_what}{2}'.format( 331 | 'r%s' % args.net[6::], \ 332 | args.old_lr if args.old_lr is not None else args.lr, \ 333 | '_pt=%s' % args.pretrain.replace('/','-') if args.pretrain else '', \ 334 | args=args) 335 | img_path = os.path.join(exp_path, 'img') 336 | model_path = os.path.join(exp_path, 'model') 337 | if not os.path.exists(img_path): os.makedirs(img_path) 338 | if not os.path.exists(model_path): os.makedirs(model_path) 339 | return img_path, model_path 340 | 341 | if __name__ == '__main__': 342 | main() 343 | -------------------------------------------------------------------------------- /dpc/model_3d.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import math 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | sys.path.append('../backbone') 10 | from select_backbone import select_resnet 11 | from convrnn import ConvGRU 12 | 13 | 14 | class DPC_RNN(nn.Module): 15 | '''DPC with RNN''' 16 | def __init__(self, sample_size, num_seq=8, seq_len=5, pred_step=3, network='resnet50'): 17 | super(DPC_RNN, self).__init__() 18 | torch.cuda.manual_seed(233) 19 | print('Using DPC-RNN model') 20 | self.sample_size = sample_size 21 | self.num_seq = num_seq 22 | self.seq_len = seq_len 23 | self.pred_step = pred_step 24 | self.last_duration = int(math.ceil(seq_len / 4)) 25 | self.last_size = int(math.ceil(sample_size / 32)) 26 | print('final feature map has size %dx%d' % (self.last_size, self.last_size)) 27 | 28 | self.backbone, self.param = select_resnet(network, track_running_stats=False) 29 | self.param['num_layers'] = 1 # param for GRU 30 | self.param['hidden_size'] = self.param['feature_size'] # param for GRU 31 | 32 | self.agg = ConvGRU(input_size=self.param['feature_size'], 33 | hidden_size=self.param['hidden_size'], 34 | kernel_size=1, 35 | num_layers=self.param['num_layers']) 36 | self.network_pred = nn.Sequential( 37 | nn.Conv2d(self.param['feature_size'], self.param['feature_size'], kernel_size=1, padding=0), 38 | nn.ReLU(inplace=True), 39 | 
nn.Conv2d(self.param['feature_size'], self.param['feature_size'], kernel_size=1, padding=0) 40 | ) 41 | self.mask = None 42 | self.relu = nn.ReLU(inplace=False) 43 | self._initialize_weights(self.agg) 44 | self._initialize_weights(self.network_pred) 45 | 46 | def forward(self, block): 47 | # block: [B, N, C, SL, W, H] 48 | ### extract feature ### 49 | (B, N, C, SL, H, W) = block.shape 50 | block = block.view(B*N, C, SL, H, W) 51 | feature = self.backbone(block) 52 | del block 53 | feature = F.avg_pool3d(feature, (self.last_duration, 1, 1), stride=(1, 1, 1)) 54 | 55 | feature_inf_all = feature.view(B, N, self.param['feature_size'], self.last_size, self.last_size) # before ReLU, (-inf, +inf) 56 | feature = self.relu(feature) # [0, +inf) 57 | feature = feature.view(B, N, self.param['feature_size'], self.last_size, self.last_size) # [B,N,D,6,6], [0, +inf) 58 | feature_inf = feature_inf_all[:, N-self.pred_step::, :].contiguous() 59 | del feature_inf_all 60 | 61 | ### aggregate, predict future ### 62 | _, hidden = self.agg(feature[:, 0:N-self.pred_step, :].contiguous()) 63 | hidden = hidden[:,-1,:] # after tanh, (-1,1). get the hidden state of last layer, last time step 64 | 65 | pred = [] 66 | for i in range(self.pred_step): 67 | # sequentially pred future 68 | p_tmp = self.network_pred(hidden) 69 | pred.append(p_tmp) 70 | _, hidden = self.agg(self.relu(p_tmp).unsqueeze(1), hidden.unsqueeze(0)) 71 | hidden = hidden[:,-1,:] 72 | pred = torch.stack(pred, 1) # B, pred_step, xxx 73 | del hidden 74 | 75 | 76 | ### Get similarity score ### 77 | # pred: [B, pred_step, D, last_size, last_size] 78 | # GT: [B, N, D, last_size, last_size] 79 | N = self.pred_step 80 | # dot product D dimension in pred-GT pair, get a 6d tensor. First 3 dims are from pred, last 3 dims are from GT. 
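        # The matmul below contracts the feature dimension D, giving
        #   score[i, p, s1, j, n, s2] = <pred feature at (i, p, s1), GT feature at (j, n, s2)>
        # with shape [B, pred_step, last_size**2, B, pred_step, last_size**2].
        # In main.py this 6d score is flattened into a (B*pred_step*last_size**2) x
        # (B*pred_step*last_size**2) similarity matrix and trained with cross-entropy,
        # using the entries marked 1 in the mask below as positives (an InfoNCE-style contrastive loss).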
81 | pred = pred.permute(0,1,3,4,2).contiguous().view(B*self.pred_step*self.last_size**2, self.param['feature_size']) 82 | feature_inf = feature_inf.permute(0,1,3,4,2).contiguous().view(B*N*self.last_size**2, self.param['feature_size']).transpose(0,1) 83 | score = torch.matmul(pred, feature_inf).view(B, self.pred_step, self.last_size**2, B, N, self.last_size**2) 84 | del feature_inf, pred 85 | 86 | if self.mask is None: # only compute mask once 87 | # mask meaning: -2: omit, -1: temporal neg (hard), 0: easy neg, 1: pos, -3: spatial neg 88 | mask = torch.zeros((B, self.pred_step, self.last_size**2, B, N, self.last_size**2), dtype=torch.int8, requires_grad=False).detach().cuda() 89 | mask[torch.arange(B), :, :, torch.arange(B), :, :] = -3 # spatial neg 90 | for k in range(B): 91 | mask[k, :, torch.arange(self.last_size**2), k, :, torch.arange(self.last_size**2)] = -1 # temporal neg 92 | tmp = mask.permute(0, 2, 1, 3, 5, 4).contiguous().view(B*self.last_size**2, self.pred_step, B*self.last_size**2, N) 93 | for j in range(B*self.last_size**2): 94 | tmp[j, torch.arange(self.pred_step), j, torch.arange(N-self.pred_step, N)] = 1 # pos 95 | mask = tmp.view(B, self.last_size**2, self.pred_step, B, self.last_size**2, N).permute(0,2,1,3,5,4) 96 | self.mask = mask 97 | 98 | return [score, self.mask] 99 | 100 | def _initialize_weights(self, module): 101 | for name, param in module.named_parameters(): 102 | if 'bias' in name: 103 | nn.init.constant_(param, 0.0) 104 | elif 'weight' in name: 105 | nn.init.orthogonal_(param, 1) 106 | # other resnet weights have been initialized in resnet itself 107 | 108 | def reset_mask(self): 109 | self.mask = None 110 | 111 | -------------------------------------------------------------------------------- /eval/dataset_3d_lc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | from torchvision import transforms 4 | import os 5 | import sys 6 | import time 7 | import pickle 8 | import csv 9 | import glob 10 | import pandas as pd 11 | import numpy as np 12 | import cv2 13 | sys.path.append('../utils') 14 | from augmentation import * 15 | from tqdm import tqdm 16 | from joblib import Parallel, delayed 17 | 18 | def pil_loader(path): 19 | with open(path, 'rb') as f: 20 | with Image.open(f) as img: 21 | return img.convert('RGB') 22 | 23 | class UCF101_3d(data.Dataset): 24 | def __init__(self, 25 | mode='train', 26 | transform=None, 27 | seq_len=10, 28 | num_seq =1, 29 | downsample=3, 30 | epsilon=5, 31 | which_split=1): 32 | self.mode = mode 33 | self.transform = transform 34 | self.seq_len = seq_len 35 | self.num_seq = num_seq 36 | self.downsample = downsample 37 | self.epsilon = epsilon 38 | self.which_split = which_split 39 | 40 | # splits 41 | if mode == 'train': 42 | split = '../process_data/data/ucf101/train_split%02d.csv' % self.which_split 43 | video_info = pd.read_csv(split, header=None) 44 | elif (mode == 'val') or (mode == 'test'): 45 | split = '../process_data/data/ucf101/test_split%02d.csv' % self.which_split # use test for val, temporary 46 | video_info = pd.read_csv(split, header=None) 47 | else: raise ValueError('wrong mode') 48 | 49 | # get action list 50 | self.action_dict_encode = {} 51 | self.action_dict_decode = {} 52 | 53 | action_file = os.path.join('../process_data/data/ucf101', 'classInd.txt') 54 | action_df = pd.read_csv(action_file, sep=' ', header=None) 55 | for _, row in action_df.iterrows(): 56 | act_id, act_name = row 57 | act_id = int(act_id) - 1 # let id 
start from 0 58 | self.action_dict_decode[act_id] = act_name 59 | self.action_dict_encode[act_name] = act_id 60 | 61 | # filter out too short videos: 62 | drop_idx = [] 63 | for idx, row in video_info.iterrows(): 64 | vpath, vlen = row 65 | if vlen-self.num_seq*self.seq_len*self.downsample <= 0: 66 | drop_idx.append(idx) 67 | self.video_info = video_info.drop(drop_idx, axis=0) 68 | 69 | if mode == 'val': self.video_info = self.video_info.sample(frac=0.3) 70 | # shuffle not required 71 | 72 | def idx_sampler(self, vlen, vpath): 73 | '''sample index from a video''' 74 | if vlen-self.num_seq*self.seq_len*self.downsample <= 0: return [None] 75 | n = 1 76 | if self.mode == 'test': 77 | seq_idx_block = np.arange(0, vlen, self.downsample) # all possible frames with downsampling 78 | return [seq_idx_block, vpath] 79 | start_idx = np.random.choice(range(vlen-self.num_seq*self.seq_len*self.downsample), n) 80 | seq_idx = np.expand_dims(np.arange(self.num_seq), -1)*self.downsample*self.seq_len + start_idx 81 | seq_idx_block = seq_idx + np.expand_dims(np.arange(self.seq_len),0)*self.downsample 82 | return [seq_idx_block, vpath] 83 | 84 | 85 | def __getitem__(self, index): 86 | vpath, vlen = self.video_info.iloc[index] 87 | items = self.idx_sampler(vlen, vpath) 88 | if items is None: print(vpath) 89 | 90 | idx_block, vpath = items 91 | if self.mode != 'test': 92 | assert idx_block.shape == (self.num_seq, self.seq_len) 93 | idx_block = idx_block.reshape(self.num_seq*self.seq_len) 94 | 95 | seq = [pil_loader(os.path.join(vpath, 'image_%05d.jpg' % (i+1))) for i in idx_block] 96 | t_seq = self.transform(seq) # apply same transform 97 | 98 | num_crop = None 99 | try: 100 | (C, H, W) = t_seq[0].size() 101 | t_seq = torch.stack(t_seq, 0) 102 | except: 103 | (C, H, W) = t_seq[0][0].size() 104 | tmp = [torch.stack(i, 0) for i in t_seq] 105 | assert len(tmp) == 5 106 | num_crop = 5 107 | t_seq = torch.stack(tmp, 1) 108 | 109 | if self.mode == 'test': 110 | # return all available clips, but cut into length = num_seq 111 | SL = t_seq.size(0) 112 | clips = []; i = 0 113 | while i+self.seq_len <= SL: 114 | clips.append(t_seq[i:i+self.seq_len, :]) 115 | # i += self.seq_len//2 116 | i += self.seq_len 117 | if num_crop: 118 | # half overlap: 119 | clips = [torch.stack(clips[i:i+self.num_seq], 0).permute(2,0,3,1,4,5) for i in range(0,len(clips)+1-self.num_seq,self.num_seq//2)] 120 | NC = len(clips) 121 | t_seq = torch.stack(clips, 0).view(NC*num_crop, self.num_seq, C, self.seq_len, H, W) 122 | else: 123 | # half overlap: 124 | clips = [torch.stack(clips[i:i+self.num_seq], 0).transpose(1,2) for i in range(0,len(clips)+1-self.num_seq,self.num_seq//2)] 125 | t_seq = torch.stack(clips, 0) 126 | else: 127 | t_seq = t_seq.view(self.num_seq, self.seq_len, C, H, W).transpose(1,2) 128 | 129 | try: 130 | vname = vpath.split('/')[-3] 131 | vid = self.encode_action(vname) 132 | except: 133 | vname = vpath.split('/')[-2] 134 | vid = self.encode_action(vname) 135 | 136 | label = torch.LongTensor([vid]) 137 | 138 | return t_seq, label 139 | 140 | def __len__(self): 141 | return len(self.video_info) 142 | 143 | def encode_action(self, action_name): 144 | '''give action name, return category''' 145 | return self.action_dict_encode[action_name] 146 | 147 | def decode_action(self, action_code): 148 | '''give action code, return action name''' 149 | return self.action_dict_decode[action_code] 150 | 151 | 152 | class HMDB51_3d(data.Dataset): 153 | def __init__(self, 154 | mode='train', 155 | transform=None, 156 | seq_len=10, 157 | num_seq=1, 
158 | downsample=1, 159 | epsilon=5, 160 | which_split=1): 161 | self.mode = mode 162 | self.transform = transform 163 | self.seq_len = seq_len 164 | self.num_seq = num_seq 165 | self.downsample = downsample 166 | self.epsilon = epsilon 167 | self.which_split = which_split 168 | 169 | # splits 170 | if mode == 'train': 171 | split = '../process_data/data/hmdb51/train_split%02d.csv' % self.which_split 172 | video_info = pd.read_csv(split, header=None) 173 | elif (mode == 'val') or (mode == 'test'): 174 | split = '../process_data/data/hmdb51/test_split%02d.csv' % self.which_split # use test for val, temporary 175 | video_info = pd.read_csv(split, header=None) 176 | else: raise ValueError('wrong mode') 177 | 178 | # get action list 179 | self.action_dict_encode = {} 180 | self.action_dict_decode = {} 181 | 182 | action_file = os.path.join('../process_data/data/hmdb51', 'classInd.txt') 183 | action_df = pd.read_csv(action_file, sep=' ', header=None) 184 | for _, row in action_df.iterrows(): 185 | act_id, act_name = row 186 | act_id = int(act_id) - 1 # let id start from 0 187 | self.action_dict_decode[act_id] = act_name 188 | self.action_dict_encode[act_name] = act_id 189 | 190 | # filter out too short videos: 191 | drop_idx = [] 192 | for idx, row in video_info.iterrows(): 193 | vpath, vlen = row 194 | if vlen-self.num_seq*self.seq_len*self.downsample <= 0: 195 | drop_idx.append(idx) 196 | self.video_info = video_info.drop(drop_idx, axis=0) 197 | 198 | if mode == 'val': self.video_info = self.video_info.sample(frac=0.3) 199 | # shuffle not required 200 | 201 | def idx_sampler(self, vlen, vpath): 202 | '''sample index from a video''' 203 | if vlen-self.num_seq*self.seq_len*self.downsample <= 0: return [None] 204 | n = 1 205 | if self.mode == 'test': 206 | seq_idx_block = np.arange(0, vlen, self.downsample) # all possible frames with downsampling 207 | return [seq_idx_block, vpath] 208 | start_idx = np.random.choice(range(vlen-self.num_seq*self.seq_len*self.downsample), n) 209 | seq_idx = np.expand_dims(np.arange(self.num_seq), -1)*self.downsample*self.seq_len + start_idx 210 | seq_idx_block = seq_idx + np.expand_dims(np.arange(self.seq_len),0)*self.downsample 211 | return [seq_idx_block, vpath] 212 | 213 | 214 | def __getitem__(self, index): 215 | vpath, vlen = self.video_info.iloc[index] 216 | items = self.idx_sampler(vlen, vpath) 217 | if items is None: print(vpath) 218 | 219 | idx_block, vpath = items 220 | if self.mode != 'test': 221 | assert idx_block.shape == (self.num_seq, self.seq_len) 222 | idx_block = idx_block.reshape(self.num_seq*self.seq_len) 223 | 224 | seq = [pil_loader(os.path.join(vpath, 'image_%05d.jpg' % (i+1))) for i in idx_block] 225 | t_seq = self.transform(seq) # apply same transform 226 | 227 | num_crop = None 228 | try: 229 | (C, H, W) = t_seq[0].size() 230 | t_seq = torch.stack(t_seq, 0) 231 | except: 232 | (C, H, W) = t_seq[0][0].size() 233 | tmp = [torch.stack(i, 0) for i in t_seq] 234 | assert len(tmp) == 5 235 | num_crop = 5 236 | t_seq = torch.stack(tmp, 1) 237 | # print(t_seq.size()) 238 | # import ipdb; ipdb.set_trace() 239 | if self.mode == 'test': 240 | # return all available clips, but cut into length = num_seq 241 | SL = t_seq.size(0) 242 | clips = []; i = 0 243 | while i+self.seq_len <= SL: 244 | clips.append(t_seq[i:i+self.seq_len, :]) 245 | # i += self.seq_len//2 246 | i += self.seq_len 247 | if num_crop: 248 | # half overlap: 249 | clips = [torch.stack(clips[i:i+self.num_seq], 0).permute(2,0,3,1,4,5) for i in 
range(0,len(clips)+1-self.num_seq,self.num_seq//2)] 250 | NC = len(clips) 251 | t_seq = torch.stack(clips, 0).view(NC*num_crop, self.num_seq, C, self.seq_len, H, W) 252 | else: 253 | # half overlap: 254 | clips = [torch.stack(clips[i:i+self.num_seq], 0).transpose(1,2) for i in range(0,len(clips)+1-self.num_seq,3*self.num_seq//4)] 255 | t_seq = torch.stack(clips, 0) 256 | else: 257 | t_seq = t_seq.view(self.num_seq, self.seq_len, C, H, W).transpose(1,2) 258 | 259 | try: 260 | vname = vpath.split('/')[-3] 261 | vid = self.encode_action(vname) 262 | except: 263 | vname = vpath.split('/')[-2] 264 | vid = self.encode_action(vname) 265 | 266 | label = torch.LongTensor([vid]) 267 | 268 | return t_seq, label 269 | 270 | def __len__(self): 271 | return len(self.video_info) 272 | 273 | def encode_action(self, action_name): 274 | '''give action name, return category''' 275 | return self.action_dict_encode[action_name] 276 | 277 | def decode_action(self, action_code): 278 | '''give action code, return action name''' 279 | return self.action_dict_decode[action_code] 280 | 281 | -------------------------------------------------------------------------------- /eval/model_3d_lc.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import sys 4 | sys.path.append('../backbone') 5 | from select_backbone import select_resnet 6 | from convrnn import ConvGRU 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | class LC(nn.Module): 13 | def __init__(self, sample_size, num_seq, seq_len, 14 | network='resnet18', dropout=0.5, num_class=101): 15 | super(LC, self).__init__() 16 | torch.cuda.manual_seed(666) 17 | self.sample_size = sample_size 18 | self.num_seq = num_seq 19 | self.seq_len = seq_len 20 | self.num_class = num_class 21 | print('=> Using RNN + FC model ') 22 | 23 | print('=> Use 2D-3D %s!' 
% network) 24 | self.last_duration = int(math.ceil(seq_len / 4)) 25 | self.last_size = int(math.ceil(sample_size / 32)) 26 | track_running_stats = True 27 | 28 | self.backbone, self.param = select_resnet(network, track_running_stats=track_running_stats) 29 | self.param['num_layers'] = 1 30 | self.param['hidden_size'] = self.param['feature_size'] 31 | 32 | print('=> using ConvRNN, kernel_size = 1') 33 | self.agg = ConvGRU(input_size=self.param['feature_size'], 34 | hidden_size=self.param['hidden_size'], 35 | kernel_size=1, 36 | num_layers=self.param['num_layers']) 37 | self._initialize_weights(self.agg) 38 | 39 | self.final_bn = nn.BatchNorm1d(self.param['feature_size']) 40 | self.final_bn.weight.data.fill_(1) 41 | self.final_bn.bias.data.zero_() 42 | 43 | self.final_fc = nn.Sequential(nn.Dropout(dropout), 44 | nn.Linear(self.param['feature_size'], self.num_class)) 45 | self._initialize_weights(self.final_fc) 46 | 47 | def forward(self, block): 48 | # seq1: [B, N, C, SL, W, H] 49 | (B, N, C, SL, H, W) = block.shape 50 | block = block.view(B*N, C, SL, H, W) 51 | feature = self.backbone(block) 52 | del block 53 | feature = F.relu(feature) 54 | 55 | feature = F.avg_pool3d(feature, (self.last_duration, 1, 1), stride=1) 56 | feature = feature.view(B, N, self.param['feature_size'], self.last_size, self.last_size) # [B*N,D,last_size,last_size] 57 | context, _ = self.agg(feature) 58 | context = context[:,-1,:].unsqueeze(1) 59 | context = F.avg_pool3d(context, (1, self.last_size, self.last_size), stride=1).squeeze(-1).squeeze(-1) 60 | del feature 61 | 62 | context = self.final_bn(context.transpose(-1,-2)).transpose(-1,-2) # [B,N,C] -> [B,C,N] -> BN() -> [B,N,C], because BN operates on id=1 channel. 63 | output = self.final_fc(context).view(B, -1, self.num_class) 64 | 65 | return output, context 66 | 67 | def _initialize_weights(self, module): 68 | for name, param in module.named_parameters(): 69 | if 'bias' in name: 70 | nn.init.constant_(param, 0.0) 71 | elif 'weight' in name: 72 | nn.init.orthogonal_(param, 1) 73 | # other resnet weights have been initialized in resnet_3d.py 74 | 75 | 76 | -------------------------------------------------------------------------------- /eval/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import argparse 5 | import re 6 | import numpy as np 7 | from tqdm import tqdm 8 | from tensorboardX import SummaryWriter 9 | 10 | sys.path.append('../utils') 11 | sys.path.append('../backbone') 12 | from dataset_3d_lc import UCF101_3d, HMDB51_3d 13 | from model_3d_lc import * 14 | from resnet_2d3d import neq_load_customized 15 | from augmentation import * 16 | from utils import AverageMeter, ConfusionMeter, save_checkpoint, write_log, calc_topk_accuracy, denorm, calc_accuracy 17 | 18 | import torch 19 | import torch.optim as optim 20 | from torch.utils import data 21 | import torch.nn as nn 22 | from torchvision import datasets, models, transforms 23 | import torchvision.utils as vutils 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--net', default='resnet18', type=str) 27 | parser.add_argument('--model', default='lc', type=str) 28 | parser.add_argument('--dataset', default='ucf101', type=str) 29 | parser.add_argument('--split', default=1, type=int) 30 | parser.add_argument('--seq_len', default=5, type=int) 31 | parser.add_argument('--num_seq', default=8, type=int) 32 | parser.add_argument('--num_class', default=101, type=int) 33 | parser.add_argument('--dropout', 
default=0.5, type=float) 34 | parser.add_argument('--ds', default=3, type=int) 35 | parser.add_argument('--batch_size', default=4, type=int) 36 | parser.add_argument('--lr', default=1e-3, type=float) 37 | parser.add_argument('--wd', default=1e-3, type=float, help='weight decay') 38 | parser.add_argument('--resume', default='', type=str) 39 | parser.add_argument('--pretrain', default='random', type=str) 40 | parser.add_argument('--test', default='', type=str) 41 | parser.add_argument('--epochs', default=10, type=int, help='number of total epochs to run') 42 | parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)') 43 | parser.add_argument('--gpu', default='0,1', type=str) 44 | parser.add_argument('--print_freq', default=5, type=int) 45 | parser.add_argument('--reset_lr', action='store_true', help='Reset learning rate when resume training?') 46 | parser.add_argument('--train_what', default='last', type=str, help='Train what parameters?') 47 | parser.add_argument('--prefix', default='tmp', type=str) 48 | parser.add_argument('--img_dim', default=128, type=int) 49 | 50 | 51 | def main(): 52 | global args; args = parser.parse_args() 53 | os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu) 54 | global cuda; cuda = torch.device('cuda') 55 | 56 | if args.dataset == 'ucf101': args.num_class = 101 57 | elif args.dataset == 'hmdb51': args.num_class = 51 58 | 59 | ### classifier model ### 60 | if args.model == 'lc': 61 | model = LC(sample_size=args.img_dim, 62 | num_seq=args.num_seq, 63 | seq_len=args.seq_len, 64 | network=args.net, 65 | num_class=args.num_class, 66 | dropout=args.dropout) 67 | else: 68 | raise ValueError('wrong model!') 69 | 70 | model = nn.DataParallel(model) 71 | model = model.to(cuda) 72 | global criterion; criterion = nn.CrossEntropyLoss() 73 | 74 | ### optimizer ### 75 | params = None 76 | if args.train_what == 'ft': 77 | print('=> finetune backbone with smaller lr') 78 | params = [] 79 | for name, param in model.module.named_parameters(): 80 | if ('resnet' in name) or ('rnn' in name): 81 | params.append({'params': param, 'lr': args.lr/10}) 82 | else: 83 | params.append({'params': param}) 84 | else: pass # train all layers 85 | 86 | print('\n===========Check Grad============') 87 | for name, param in model.named_parameters(): 88 | print(name, param.requires_grad) 89 | print('=================================\n') 90 | 91 | if params is None: params = model.parameters() 92 | 93 | optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd) 94 | if args.dataset == 'hmdb51': 95 | lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=[150,250,300], repeat=1) 96 | elif args.dataset == 'ucf101': 97 | if args.img_dim == 224: lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=[300,400,500], repeat=1) 98 | else: lr_lambda = lambda ep: MultiStepLR_Restart_Multiplier(ep, gamma=0.1, step=[60, 80, 100], repeat=1) 99 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) 100 | 101 | args.old_lr = None 102 | best_acc = 0 103 | global iteration; iteration = 0 104 | 105 | ### restart training ### 106 | if args.test: 107 | if os.path.isfile(args.test): 108 | print("=> loading testing checkpoint '{}'".format(args.test)) 109 | checkpoint = torch.load(args.test) 110 | try: model.load_state_dict(checkpoint['state_dict']) 111 | except: 112 | print('=> [Warning]: weight structure is not equal to test model; Use non-equal load ==') 113 | model = neq_load_customized(model, 
checkpoint['state_dict']) 114 | print("=> loaded testing checkpoint '{}' (epoch {})".format(args.test, checkpoint['epoch'])) 115 | global num_epoch; num_epoch = checkpoint['epoch'] 116 | elif args.test == 'random': 117 | print("=> [Warning] loaded random weights") 118 | else: 119 | raise ValueError() 120 | 121 | transform = transforms.Compose([ 122 | RandomSizedCrop(consistent=True, size=224, p=0.0), 123 | Scale(size=(args.img_dim,args.img_dim)), 124 | ToTensor(), 125 | Normalize() 126 | ]) 127 | test_loader = get_data(transform, 'test') 128 | test_loss, test_acc = test(test_loader, model) 129 | sys.exit() 130 | else: # not test 131 | torch.backends.cudnn.benchmark = True 132 | 133 | if args.resume: 134 | if os.path.isfile(args.resume): 135 | args.old_lr = float(re.search('_lr(.+?)_', args.resume).group(1)) 136 | print("=> loading resumed checkpoint '{}'".format(args.resume)) 137 | checkpoint = torch.load(args.resume, map_location=torch.device('cpu')) 138 | args.start_epoch = checkpoint['epoch'] 139 | best_acc = checkpoint['best_acc'] 140 | model.load_state_dict(checkpoint['state_dict']) 141 | if not args.reset_lr: # if didn't reset lr, load old optimizer 142 | optimizer.load_state_dict(checkpoint['optimizer']) 143 | else: print('==== Change lr from %f to %f ====' % (args.old_lr, args.lr)) 144 | iteration = checkpoint['iteration'] 145 | print("=> loaded resumed checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) 146 | else: 147 | print("=> no checkpoint found at '{}'".format(args.resume)) 148 | 149 | if (not args.resume) and args.pretrain: 150 | if args.pretrain == 'random': 151 | print('=> using random weights') 152 | elif os.path.isfile(args.pretrain): 153 | print("=> loading pretrained checkpoint '{}'".format(args.pretrain)) 154 | checkpoint = torch.load(args.pretrain, map_location=torch.device('cpu')) 155 | model = neq_load_customized(model, checkpoint['state_dict']) 156 | print("=> loaded pretrained checkpoint '{}' (epoch {})".format(args.pretrain, checkpoint['epoch'])) 157 | else: 158 | print("=> no checkpoint found at '{}'".format(args.pretrain)) 159 | 160 | ### load data ### 161 | transform = transforms.Compose([ 162 | RandomSizedCrop(consistent=True, size=224, p=1.0), 163 | Scale(size=(args.img_dim,args.img_dim)), 164 | RandomHorizontalFlip(consistent=True), 165 | ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.25, p=0.3, consistent=True), 166 | ToTensor(), 167 | Normalize() 168 | ]) 169 | val_transform = transforms.Compose([ 170 | RandomSizedCrop(consistent=True, size=224, p=0.3), 171 | Scale(size=(args.img_dim,args.img_dim)), 172 | RandomHorizontalFlip(consistent=True), 173 | ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.3, consistent=True), 174 | ToTensor(), 175 | Normalize() 176 | ]) 177 | 178 | train_loader = get_data(transform, 'train') 179 | val_loader = get_data(val_transform, 'val') 180 | 181 | # setup tools 182 | global de_normalize; de_normalize = denorm() 183 | global img_path; img_path, model_path = set_path(args) 184 | global writer_train 185 | try: # old version 186 | writer_val = SummaryWriter(log_dir=os.path.join(img_path, 'val')) 187 | writer_train = SummaryWriter(log_dir=os.path.join(img_path, 'train')) 188 | except: # v1.7 189 | writer_val = SummaryWriter(logdir=os.path.join(img_path, 'val')) 190 | writer_train = SummaryWriter(logdir=os.path.join(img_path, 'train')) 191 | 192 | ### main loop ### 193 | for epoch in range(args.start_epoch, args.epochs): 194 | train_loss, train_acc = train(train_loader, 
model, optimizer, epoch) 195 | val_loss, val_acc = validate(val_loader, model) 196 | scheduler.step(epoch) 197 | 198 | writer_train.add_scalar('global/loss', train_loss, epoch) 199 | writer_train.add_scalar('global/accuracy', train_acc, epoch) 200 | writer_val.add_scalar('global/loss', val_loss, epoch) 201 | writer_val.add_scalar('global/accuracy', val_acc, epoch) 202 | 203 | # save check_point 204 | is_best = val_acc > best_acc 205 | best_acc = max(val_acc, best_acc) 206 | save_checkpoint({ 207 | 'epoch': epoch+1, 208 | 'net': args.net, 209 | 'state_dict': model.state_dict(), 210 | 'best_acc': best_acc, 211 | 'optimizer': optimizer.state_dict(), 212 | 'iteration': iteration 213 | }, is_best, filename=os.path.join(model_path, 'epoch%s.pth.tar' % str(epoch+1)), keep_all=False) 214 | 215 | print('Training from ep %d to ep %d finished' % (args.start_epoch, args.epochs)) 216 | 217 | 218 | def train(data_loader, model, optimizer, epoch): 219 | losses = AverageMeter() 220 | accuracy = AverageMeter() 221 | model.train() 222 | global iteration 223 | 224 | for idx, (input_seq, target) in enumerate(data_loader): 225 | tic = time.time() 226 | input_seq = input_seq.to(cuda) 227 | target = target.to(cuda) 228 | B = input_seq.size(0) 229 | output, _ = model(input_seq) 230 | 231 | # visualize 232 | if (iteration == 0) or (iteration == args.print_freq): 233 | if B > 2: input_seq = input_seq[0:2,:] 234 | writer_train.add_image('input_seq', 235 | de_normalize(vutils.make_grid( 236 | input_seq.transpose(2,3).contiguous().view(-1,3,args.img_dim,args.img_dim), 237 | nrow=args.num_seq*args.seq_len)), 238 | iteration) 239 | del input_seq 240 | 241 | [_, N, D] = output.size() 242 | output = output.view(B*N, D) 243 | target = target.repeat(1, N).view(-1) 244 | 245 | loss = criterion(output, target) 246 | acc = calc_accuracy(output, target) 247 | 248 | del target 249 | 250 | losses.update(loss.item(), B) 251 | accuracy.update(acc.item(), B) 252 | 253 | optimizer.zero_grad() 254 | loss.backward() 255 | optimizer.step() 256 | 257 | if idx % args.print_freq == 0: 258 | print('Epoch: [{0}][{1}/{2}]\t' 259 | 'Loss {loss.val:.4f} ({loss.local_avg:.4f})\t' 260 | 'Acc: {acc.val:.4f} ({acc.local_avg:.4f}) T:{3:.2f}\t'.format( 261 | epoch, idx, len(data_loader), time.time()-tic, 262 | loss=losses, acc=accuracy)) 263 | 264 | total_weight = 0.0 265 | decay_weight = 0.0 266 | for m in model.parameters(): 267 | if m.requires_grad: decay_weight += m.norm(2).data 268 | total_weight += m.norm(2).data 269 | print('Decay weight / Total weight: %.3f/%.3f' % (decay_weight, total_weight)) 270 | 271 | writer_train.add_scalar('local/loss', losses.val, iteration) 272 | writer_train.add_scalar('local/accuracy', accuracy.val, iteration) 273 | 274 | iteration += 1 275 | 276 | return losses.local_avg, accuracy.local_avg 277 | 278 | def validate(data_loader, model): 279 | losses = AverageMeter() 280 | accuracy = AverageMeter() 281 | model.eval() 282 | with torch.no_grad(): 283 | for idx, (input_seq, target) in tqdm(enumerate(data_loader), total=len(data_loader)): 284 | input_seq = input_seq.to(cuda) 285 | target = target.to(cuda) 286 | B = input_seq.size(0) 287 | output, _ = model(input_seq) 288 | 289 | [_, N, D] = output.size() 290 | output = output.view(B*N, D) 291 | target = target.repeat(1, N).view(-1) 292 | 293 | loss = criterion(output, target) 294 | acc = calc_accuracy(output, target) 295 | 296 | losses.update(loss.item(), B) 297 | accuracy.update(acc.item(), B) 298 | 299 | print('Loss {loss.avg:.4f}\t' 300 | 'Acc: {acc.avg:.4f} 
\t'.format(loss=losses, acc=accuracy)) 301 | return losses.avg, accuracy.avg 302 | 303 | def test(data_loader, model): 304 | losses = AverageMeter() 305 | acc_top1 = AverageMeter() 306 | acc_top5 = AverageMeter() 307 | confusion_mat = ConfusionMeter(args.num_class) 308 | model.eval() 309 | with torch.no_grad(): 310 | for idx, (input_seq, target) in tqdm(enumerate(data_loader), total=len(data_loader)): 311 | input_seq = input_seq.to(cuda) 312 | target = target.to(cuda) 313 | B = input_seq.size(0) 314 | input_seq = input_seq.squeeze(0) # squeeze the '1' batch dim 315 | output, _ = model(input_seq) 316 | del input_seq 317 | top1, top5 = calc_topk_accuracy(torch.mean( 318 | torch.mean( 319 | nn.functional.softmax(output,2), 320 | 0),0, keepdim=True), 321 | target, (1,5)) 322 | acc_top1.update(top1.item(), B) 323 | acc_top5.update(top5.item(), B) 324 | del top1, top5 325 | 326 | output = torch.mean(torch.mean(output, 0), 0, keepdim=True) 327 | loss = criterion(output, target.squeeze(-1)) 328 | 329 | losses.update(loss.item(), B) 330 | del loss 331 | 332 | 333 | _, pred = torch.max(output, 1) 334 | confusion_mat.update(pred, target.view(-1).byte()) 335 | 336 | print('Loss {loss.avg:.4f}\t' 337 | 'Acc top1: {top1.avg:.4f} Acc top5: {top5.avg:.4f} \t'.format(loss=losses, top1=acc_top1, top5=acc_top5)) 338 | confusion_mat.plot_mat(args.test+'.svg') 339 | write_log(content='Loss {loss.avg:.4f}\t Acc top1: {top1.avg:.4f} Acc top5: {top5.avg:.4f} \t'.format(loss=losses, top1=acc_top1, top5=acc_top5, args=args), 340 | epoch=num_epoch, 341 | filename=os.path.join(os.path.dirname(args.test), 'test_log.md')) 342 | import ipdb; ipdb.set_trace() 343 | return losses.avg, [acc_top1.avg, acc_top5.avg] 344 | 345 | def get_data(transform, mode='train'): 346 | print('Loading data for "%s" ...' 
% mode) 347 | global dataset 348 | if args.dataset == 'ucf101': 349 | dataset = UCF101_3d(mode=mode, 350 | transform=transform, 351 | seq_len=args.seq_len, 352 | num_seq=args.num_seq, 353 | downsample=args.ds, 354 | which_split=args.split) 355 | elif args.dataset == 'hmdb51': 356 | dataset = HMDB51_3d(mode=mode, 357 | transform=transform, 358 | seq_len=args.seq_len, 359 | num_seq=args.num_seq, 360 | downsample=args.ds, 361 | which_split=args.split) 362 | else: 363 | raise ValueError('dataset not supported') 364 | my_sampler = data.RandomSampler(dataset) 365 | if mode == 'train': 366 | data_loader = data.DataLoader(dataset, 367 | batch_size=args.batch_size, 368 | sampler=my_sampler, 369 | shuffle=False, 370 | num_workers=16, 371 | pin_memory=True, 372 | drop_last=True) 373 | elif mode == 'val': 374 | data_loader = data.DataLoader(dataset, 375 | batch_size=args.batch_size, 376 | sampler=my_sampler, 377 | shuffle=False, 378 | num_workers=16, 379 | pin_memory=True, 380 | drop_last=True) 381 | elif mode == 'test': 382 | data_loader = data.DataLoader(dataset, 383 | batch_size=1, 384 | sampler=my_sampler, 385 | shuffle=False, 386 | num_workers=16, 387 | pin_memory=True) 388 | print('"%s" dataset size: %d' % (mode, len(dataset))) 389 | return data_loader 390 | 391 | def set_path(args): 392 | if args.resume: exp_path = os.path.dirname(os.path.dirname(args.resume)) 393 | else: 394 | exp_path = 'log_{args.prefix}/{args.dataset}-{args.img_dim}-\ 395 | sp{args.split}_{0}_{args.model}_bs{args.batch_size}_\ 396 | lr{1}_wd{args.wd}_ds{args.ds}_seq{args.num_seq}_len{args.seq_len}_\ 397 | dp{args.dropout}_train-{args.train_what}{2}'.format( 398 | 'r%s' % args.net[6::], \ 399 | args.old_lr if args.old_lr is not None else args.lr, \ 400 | '_pt='+args.pretrain.replace('/','-') if args.pretrain else '', \ 401 | args=args) 402 | img_path = os.path.join(exp_path, 'img') 403 | model_path = os.path.join(exp_path, 'model') 404 | if not os.path.exists(img_path): os.makedirs(img_path) 405 | if not os.path.exists(model_path): os.makedirs(model_path) 406 | return img_path, model_path 407 | 408 | def MultiStepLR_Restart_Multiplier(epoch, gamma=0.1, step=[10,15,20], repeat=3): 409 | '''return the multipier for LambdaLR, 410 | 0 <= ep < 10: gamma^0 411 | 10 <= ep < 15: gamma^1 412 | 15 <= ep < 20: gamma^2 413 | 20 <= ep < 30: gamma^0 ... repeat 3 cycles and then keep gamma^2''' 414 | max_step = max(step) 415 | effective_epoch = epoch % max_step 416 | if epoch // max_step >= repeat: 417 | exp = len(step) - 1 418 | else: 419 | exp = len([i for i in step if effective_epoch>=i]) 420 | return gamma ** exp 421 | 422 | if __name__ == '__main__': 423 | main() 424 | -------------------------------------------------------------------------------- /process_data/readme.md: -------------------------------------------------------------------------------- 1 | ## Process data 2 | 3 | This folder has some tools to process UCF101, HMDB51 and Kinetics400 datasets. 4 | 5 | ### 1. Download 6 | 7 | Download the videos from source: 8 | [UCF101 source](https://www.crcv.ucf.edu/data/UCF101.php), 9 | [HMDB51 source](http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/#Downloads), 10 | [Kinetics400 source](https://deepmind.com/research/publications/kinetics-human-action-video-dataset). 
11 | 12 | Make sure datasets are stored as follows: 13 | 14 | * UCF101 15 | ``` 16 | {your_path}/UCF101/videos/{action class}/{video name}.avi 17 | {your_path}/UCF101/splits_classification/trainlist{01/02/03}.txt 18 | {your_path}/UCF101/splits_classification/testlist{01/02/03}.txt 19 | ``` 20 | 21 | * HMDB51 22 | ``` 23 | {your_path}/HMDB51/videos/{action class}/{video name}.avi 24 | {your_path}/HMDB51/split/testTrainMulti_7030_splits/{action class}_test_split{1/2/3}.txt 25 | ``` 26 | 27 | * Kinetics400 28 | ``` 29 | {your_path}/Kinetics400/videos/train_split/{action class}/{video name}.mp4 30 | {your_path}/Kinetics400/videos/val_split/{action class}/{video name}.mp4 31 | ``` 32 | Also keep the downloaded csv files and make sure you have: 33 | ``` 34 | {your_path}/Kinetics/kinetics_train/kinetics_train.csv 35 | {your_path}/Kinetics/kinetics_val/kinetics_val.csv 36 | {your_path}/Kinetics/kinetics_test/kinetics_test.csv 37 | ``` 38 | 39 | ### 2. Extract frames 40 | 41 | Edit the path arguments in the `main_*()` functions, then run `python extract_frame.py`. Video frames will be extracted. 42 | 43 | ### 3. Collect all paths into csv 44 | 45 | Edit the path arguments in the `main_*()` functions, then run `python write_csv.py`. The csv files will be stored in the `data/` directory. 46 | 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /process_data/src/extract_frame.py: -------------------------------------------------------------------------------- 1 | from joblib import delayed, Parallel 2 | import os 3 | import sys 4 | import glob 5 | from tqdm import tqdm 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | plt.switch_backend('agg') 9 | 10 | def extract_video_opencv(v_path, f_root, dim=240): 11 | '''v_path: single video path; 12 | f_root: root to store frames''' 13 | v_class = v_path.split('/')[-2] 14 | v_name = os.path.basename(v_path)[0:-4] 15 | out_dir = os.path.join(f_root, v_class, v_name) 16 | if not os.path.exists(out_dir): 17 | os.makedirs(out_dir) 18 | 19 | vidcap = cv2.VideoCapture(v_path) 20 | nb_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 21 | width = vidcap.get(cv2.CAP_PROP_FRAME_WIDTH) # float 22 | height = vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT) # float 23 | if (width == 0) or (height==0): 24 | print(v_path, 'not successfully loaded, drop ..'); return 25 | new_dim = resize_dim(width, height, dim) 26 | 27 | success, image = vidcap.read() 28 | count = 1 29 | while success: 30 | image = cv2.resize(image, new_dim, interpolation = cv2.INTER_LINEAR) 31 | cv2.imwrite(os.path.join(out_dir, 'image_%05d.jpg' % count), image, 32 | [cv2.IMWRITE_JPEG_QUALITY, 80])# quality from 0-100, 95 is default, high is good 33 | success, image = vidcap.read() 34 | count += 1 35 | if nb_frames > count: 36 | print('/'.join(out_dir.split('/')[-2::]), 'NOT extracted successfully: %df/%df' % (count, nb_frames)) 37 | vidcap.release() 38 | 39 | def resize_dim(w, h, target): 40 | '''resize (w, h), such that the smaller side is target, keep the aspect ratio''' 41 | if w >= h: 42 | return (int(target * w / h), int(target)) 43 | else: 44 | return (int(target), int(target * h / w)) 45 | 46 | def main_UCF101(v_root, f_root): 47 | print('extracting UCF101 ... 
') 48 | print('extracting videos from %s' % v_root) 49 | print('frame save to %s' % f_root) 50 | 51 | if not os.path.exists(f_root): os.makedirs(f_root) 52 | v_act_root = glob.glob(os.path.join(v_root, '*/')) 53 | for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): 54 | v_paths = glob.glob(os.path.join(j, '*.avi')) 55 | v_paths = sorted(v_paths) 56 | Parallel(n_jobs=32)(delayed(extract_video_opencv)(p, f_root) for p in tqdm(v_paths, total=len(v_paths))) 57 | 58 | def main_HMDB51(v_root, f_root): 59 | print('extracting HMDB51 ... ') 60 | print('extracting videos from %s' % v_root) 61 | print('frame save to %s' % f_root) 62 | 63 | if not os.path.exists(f_root): os.makedirs(f_root) 64 | v_act_root = glob.glob(os.path.join(v_root, '*/')) 65 | for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): 66 | v_paths = glob.glob(os.path.join(j, '*.avi')) 67 | v_paths = sorted(v_paths) 68 | Parallel(n_jobs=32)(delayed(extract_video_opencv)(p, f_root) for p in tqdm(v_paths, total=len(v_paths))) 69 | 70 | def main_kinetics400(v_root, f_root, dim=150): 71 | print('extracting Kinetics400 ... ') 72 | for basename in ['train_split', 'val_split']: 73 | v_root_real = v_root + '/' + basename 74 | if not os.path.exists(v_root_real): 75 | print('Wrong v_root'); sys.exit() 76 | f_root_real = '/scratch/local/ssd/htd/kinetics400/frame_full' + '/' + basename 77 | print('Extract to: \nframe: %s' % f_root_real) 78 | if not os.path.exists(f_root_real): os.makedirs(f_root_real) 79 | v_act_root = glob.glob(os.path.join(v_root_real, '*/')) 80 | v_act_root = sorted(v_act_root) 81 | 82 | # if resume, remember to delete the last video folder 83 | for i, j in tqdm(enumerate(v_act_root), total=len(v_act_root)): 84 | v_paths = glob.glob(os.path.join(j, '*.mp4')) 85 | v_paths = sorted(v_paths) 86 | # for resume: 87 | v_class = j.split('/')[-2] 88 | out_dir = os.path.join(f_root_real, v_class) 89 | if os.path.exists(out_dir): print(out_dir, 'exists!'); continue 90 | print('extracting: %s' % v_class) 91 | # dim = 150 (crop to 128 later) or 256 (crop to 224 later) 92 | Parallel(n_jobs=32)(delayed(extract_video_opencv)(p, f_root_real, dim=dim) for p in tqdm(v_paths, total=len(v_paths))) 93 | 94 | 95 | if __name__ == '__main__': 96 | # v_root is the video source path, f_root is where to store frames 97 | # edit 'your_path' here: 98 | 99 | main_UCF101(v_root='your_path/UCF101/videos', 100 | f_root='your_path/UCF101/frame') 101 | 102 | # main_HMDB51(v_root='your_path/HMDB51/videos', 103 | # f_root='your_path/HMDB51/frame') 104 | 105 | # main_kinetics400(v_root='your_path/Kinetics400/videos', 106 | # f_root='your_path/Kinetics400/frame', dim=150) 107 | 108 | # main_kinetics400(v_root='your_path/Kinetics400_256/videos', 109 | # f_root='your_path/Kinetics400_256/frame', dim=256) 110 | -------------------------------------------------------------------------------- /process_data/src/write_csv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import glob 4 | import pandas as pd; from joblib import delayed, Parallel; from tqdm import tqdm  # used by the Kinetics400 split helpers (get_split) below 5 | def write_list(data_list, path): 6 | with open(path, 'w') as f: 7 | writer = csv.writer(f, delimiter=',') 8 | for row in data_list: 9 | if row: writer.writerow(row) 10 | print('split saved to %s' % path) 11 | 12 | def main_UCF101(f_root, splits_root, csv_root='../data/ucf101/'): 13 | '''generate training/testing split, count number of available frames, save in csv''' 14 | if not os.path.exists(csv_root): os.makedirs(csv_root) 15 | for which_split in [1,2,3]: 16 | train_set = [] 17 | test_set =
[] 18 | train_split_file = os.path.join(splits_root, 'trainlist%02d.txt' % which_split) 19 | with open(train_split_file, 'r') as f: 20 | for line in f: 21 | vpath = os.path.join(f_root, line.split(' ')[0][0:-4]) + '/' 22 | train_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 23 | 24 | test_split_file = os.path.join(splits_root, 'testlist%02d.txt' % which_split) 25 | with open(test_split_file, 'r') as f: 26 | for line in f: 27 | vpath = os.path.join(f_root, line.rstrip()[0:-4]) + '/' 28 | test_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 29 | 30 | write_list(train_set, os.path.join(csv_root, 'train_split%02d.csv' % which_split)) 31 | write_list(test_set, os.path.join(csv_root, 'test_split%02d.csv' % which_split)) 32 | 33 | 34 | def main_HMDB51(f_root, splits_root, csv_root='../data/hmdb51/'): 35 | '''generate training/testing split, count number of available frames, save in csv''' 36 | if not os.path.exists(csv_root): os.makedirs(csv_root) 37 | for which_split in [1,2,3]: 38 | train_set = [] 39 | test_set = [] 40 | split_files = sorted(glob.glob(os.path.join(splits_root, '*_test_split%d.txt' % which_split))) 41 | assert len(split_files) == 51 42 | for split_file in split_files: 43 | action_name = os.path.basename(split_file)[0:-16] 44 | with open(split_file, 'r') as f: 45 | for line in f: 46 | video_name = line.split(' ')[0] 47 | _type = line.split(' ')[1] 48 | vpath = os.path.join(f_root, action_name, video_name[0:-4]) + '/' 49 | if _type == '1': 50 | train_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 51 | elif _type == '2': 52 | test_set.append([vpath, len(glob.glob(os.path.join(vpath, '*.jpg')))]) 53 | 54 | write_list(train_set, os.path.join(csv_root, 'train_split%02d.csv' % which_split)) 55 | write_list(test_set, os.path.join(csv_root, 'test_split%02d.csv' % which_split)) 56 | 57 | ### For Kinetics ### 58 | def get_split(root, split_path, mode): 59 | print('processing %s split ...' 
% mode) 60 | print('checking %s' % root) 61 | split_list = [] 62 | split_content = pd.read_csv(split_path).iloc[:,0:4] 63 | split_list = Parallel(n_jobs=64)\ 64 | (delayed(check_exists)(row, root) \ 65 | for i, row in tqdm(split_content.iterrows(), total=len(split_content))) 66 | return split_list 67 | 68 | def check_exists(row, root): 69 | dirname = '_'.join([row['youtube_id'], '%06d' % row['time_start'], '%06d' % row['time_end']]) 70 | full_dirname = os.path.join(root, row['label'], dirname) 71 | if os.path.exists(full_dirname): 72 | n_frames = len(glob.glob(os.path.join(full_dirname, '*.jpg'))) 73 | return [full_dirname, n_frames] 74 | else: 75 | return None 76 | 77 | def main_Kinetics400(mode, k400_path, f_root, csv_root='../data/kinetics400'): 78 | train_split_path = os.path.join(k400_path, 'kinetics_train/kinetics_train.csv') 79 | val_split_path = os.path.join(k400_path, 'kinetics_val/kinetics_val.csv') 80 | test_split_path = os.path.join(k400_path, 'kinetics_test/kinetics_test.csv') 81 | if not os.path.exists(csv_root): os.makedirs(csv_root) 82 | if mode == 'train': 83 | train_split = get_split(os.path.join(f_root, 'train_split'), train_split_path, 'train') 84 | write_list(train_split, os.path.join(csv_root, 'train_split.csv')) 85 | elif mode == 'val': 86 | val_split = get_split(os.path.join(f_root, 'val_split'), val_split_path, 'val') 87 | write_list(val_split, os.path.join(csv_root, 'val_split.csv')) 88 | elif mode == 'test': 89 | test_split = get_split(f_root, test_split_path, 'test') 90 | write_list(test_split, os.path.join(csv_root, 'test_split.csv')) 91 | else: 92 | raise IOError('wrong mode') 93 | 94 | if __name__ == '__main__': 95 | # f_root is the frame path 96 | # edit 'your_path' here: 97 | 98 | main_UCF101(f_root='your_path/UCF101/frame', 99 | splits_root='your_path/UCF101/splits_classification') 100 | 101 | # main_HMDB51(f_root='your_path/HMDB51/frame', 102 | # splits_root='your_path/HMDB51/split/testTrainMulti_7030_splits') 103 | 104 | # main_Kinetics400(mode='train', # train or val or test 105 | # k400_path='your_path/Kinetics', 106 | # f_root='your_path/Kinetics400/frame') 107 | 108 | # main_Kinetics400(mode='train', # train or val or test 109 | # k400_path='your_path/Kinetics', 110 | # f_root='your_path/Kinetics400_256/frame', 111 | # csv_root='../data/kinetics400_256') -------------------------------------------------------------------------------- /utils/augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numbers 3 | import math 4 | import collections 5 | import numpy as np 6 | from PIL import ImageOps, Image 7 | from joblib import Parallel, delayed 8 | 9 | import torchvision 10 | from torchvision import transforms 11 | import torchvision.transforms.functional as F 12 | 13 | class Padding: 14 | def __init__(self, pad): 15 | self.pad = pad 16 | 17 | def __call__(self, img): 18 | return ImageOps.expand(img, border=self.pad, fill=0) 19 | 20 | class Scale: 21 | def __init__(self, size, interpolation=Image.NEAREST): 22 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 23 | self.size = size 24 | self.interpolation = interpolation 25 | 26 | def __call__(self, imgmap): 27 | # assert len(imgmap) > 1 # list of images 28 | img1 = imgmap[0] 29 | if isinstance(self.size, int): 30 | w, h = img1.size 31 | if (w <= h and w == self.size) or (h <= w and h == self.size): 32 | return imgmap 33 | if w < h: 34 | ow = self.size 35 | oh = int(self.size * h / w) 36 | 
return [i.resize((ow, oh), self.interpolation) for i in imgmap] 37 | else: 38 | oh = self.size 39 | ow = int(self.size * w / h) 40 | return [i.resize((ow, oh), self.interpolation) for i in imgmap] 41 | else: 42 | return [i.resize(self.size, self.interpolation) for i in imgmap] 43 | 44 | 45 | class CenterCrop: 46 | def __init__(self, size, consistent=True): 47 | if isinstance(size, numbers.Number): 48 | self.size = (int(size), int(size)) 49 | else: 50 | self.size = size 51 | 52 | def __call__(self, imgmap): 53 | img1 = imgmap[0] 54 | w, h = img1.size 55 | th, tw = self.size 56 | x1 = int(round((w - tw) / 2.)) 57 | y1 = int(round((h - th) / 2.)) 58 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 59 | 60 | 61 | class RandomCropWithProb: 62 | def __init__(self, size, p=0.8, consistent=True): 63 | if isinstance(size, numbers.Number): 64 | self.size = (int(size), int(size)) 65 | else: 66 | self.size = size 67 | self.consistent = consistent 68 | self.threshold = p 69 | 70 | def __call__(self, imgmap): 71 | img1 = imgmap[0] 72 | w, h = img1.size 73 | if self.size is not None: 74 | th, tw = self.size 75 | if w == tw and h == th: 76 | return imgmap 77 | if self.consistent: 78 | if random.random() < self.threshold: 79 | x1 = random.randint(0, w - tw) 80 | y1 = random.randint(0, h - th) 81 | else: 82 | x1 = int(round((w - tw) / 2.)) 83 | y1 = int(round((h - th) / 2.)) 84 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 85 | else: 86 | result = [] 87 | for i in imgmap: 88 | if random.random() < self.threshold: 89 | x1 = random.randint(0, w - tw) 90 | y1 = random.randint(0, h - th) 91 | else: 92 | x1 = int(round((w - tw) / 2.)) 93 | y1 = int(round((h - th) / 2.)) 94 | result.append(i.crop((x1, y1, x1 + tw, y1 + th))) 95 | return result 96 | else: 97 | return imgmap 98 | 99 | class RandomCrop: 100 | def __init__(self, size, consistent=True): 101 | if isinstance(size, numbers.Number): 102 | self.size = (int(size), int(size)) 103 | else: 104 | self.size = size 105 | self.consistent = consistent 106 | 107 | def __call__(self, imgmap, flowmap=None): 108 | img1 = imgmap[0] 109 | w, h = img1.size 110 | if self.size is not None: 111 | th, tw = self.size 112 | if w == tw and h == th: 113 | return imgmap 114 | if not flowmap: 115 | if self.consistent: 116 | x1 = random.randint(0, w - tw) 117 | y1 = random.randint(0, h - th) 118 | return [i.crop((x1, y1, x1 + tw, y1 + th)) for i in imgmap] 119 | else: 120 | result = [] 121 | for i in imgmap: 122 | x1 = random.randint(0, w - tw) 123 | y1 = random.randint(0, h - th) 124 | result.append(i.crop((x1, y1, x1 + tw, y1 + th))) 125 | return result 126 | elif flowmap is not None: 127 | assert (not self.consistent) 128 | result = [] 129 | for idx, i in enumerate(imgmap): 130 | proposal = [] 131 | for j in range(3): # number of proposal: use the one with largest optical flow 132 | x = random.randint(0, w - tw) 133 | y = random.randint(0, h - th) 134 | proposal.append([x, y, abs(np.mean(flowmap[idx,y:y+th,x:x+tw,:]))]) 135 | [x1, y1, _] = max(proposal, key=lambda x: x[-1]) 136 | result.append(i.crop((x1, y1, x1 + tw, y1 + th))) 137 | return result 138 | else: 139 | raise ValueError('wrong case') 140 | else: 141 | return imgmap 142 | 143 | 144 | class RandomSizedCrop: 145 | def __init__(self, size, interpolation=Image.BILINEAR, consistent=True, p=1.0): 146 | self.size = size 147 | self.interpolation = interpolation 148 | self.consistent = consistent 149 | self.threshold = p 150 | 151 | def __call__(self, imgmap): 152 | img1 = imgmap[0] 153 | if 
random.random() < self.threshold: # do RandomSizedCrop 154 | for attempt in range(10): 155 | area = img1.size[0] * img1.size[1] 156 | target_area = random.uniform(0.5, 1) * area 157 | aspect_ratio = random.uniform(3. / 4, 4. / 3) 158 | 159 | w = int(round(math.sqrt(target_area * aspect_ratio))) 160 | h = int(round(math.sqrt(target_area / aspect_ratio))) 161 | 162 | if self.consistent: 163 | if random.random() < 0.5: 164 | w, h = h, w 165 | if w <= img1.size[0] and h <= img1.size[1]: 166 | x1 = random.randint(0, img1.size[0] - w) 167 | y1 = random.randint(0, img1.size[1] - h) 168 | 169 | imgmap = [i.crop((x1, y1, x1 + w, y1 + h)) for i in imgmap] 170 | for i in imgmap: assert(i.size == (w, h)) 171 | 172 | return [i.resize((self.size, self.size), self.interpolation) for i in imgmap] 173 | else: 174 | result = [] 175 | for i in imgmap: 176 | if random.random() < 0.5: 177 | w, h = h, w 178 | if w <= img1.size[0] and h <= img1.size[1]: 179 | x1 = random.randint(0, img1.size[0] - w) 180 | y1 = random.randint(0, img1.size[1] - h) 181 | result.append(i.crop((x1, y1, x1 + w, y1 + h))) 182 | assert(result[-1].size == (w, h)) 183 | else: 184 | result.append(i) 185 | 186 | assert len(result) == len(imgmap) 187 | return [i.resize((self.size, self.size), self.interpolation) for i in result] 188 | 189 | # Fallback 190 | scale = Scale(self.size, interpolation=self.interpolation) 191 | crop = CenterCrop(self.size) 192 | return crop(scale(imgmap)) 193 | else: # don't do RandomSizedCrop, do CenterCrop 194 | crop = CenterCrop(self.size) 195 | return crop(imgmap) 196 | 197 | 198 | class RandomHorizontalFlip: 199 | def __init__(self, consistent=True, command=None): 200 | self.consistent = consistent 201 | if command == 'left': 202 | self.threshold = 0 203 | elif command == 'right': 204 | self.threshold = 1 205 | else: 206 | self.threshold = 0.5 207 | def __call__(self, imgmap): 208 | if self.consistent: 209 | if random.random() < self.threshold: 210 | return [i.transpose(Image.FLIP_LEFT_RIGHT) for i in imgmap] 211 | else: 212 | return imgmap 213 | else: 214 | result = [] 215 | for i in imgmap: 216 | if random.random() < self.threshold: 217 | result.append(i.transpose(Image.FLIP_LEFT_RIGHT)) 218 | else: 219 | result.append(i) 220 | assert len(result) == len(imgmap) 221 | return result 222 | 223 | 224 | class RandomGray: 225 | '''Actually it is a channel splitting, not strictly grayscale images''' 226 | def __init__(self, consistent=True, p=0.5): 227 | self.consistent = consistent 228 | self.p = p # probability to apply grayscale 229 | def __call__(self, imgmap): 230 | if self.consistent: 231 | if random.random() < self.p: 232 | return [self.grayscale(i) for i in imgmap] 233 | else: 234 | return imgmap 235 | else: 236 | result = [] 237 | for i in imgmap: 238 | if random.random() < self.p: 239 | result.append(self.grayscale(i)) 240 | else: 241 | result.append(i) 242 | assert len(result) == len(imgmap) 243 | return result 244 | 245 | def grayscale(self, img): 246 | channel = np.random.choice(3) 247 | np_img = np.array(img)[:,:,channel] 248 | np_img = np.dstack([np_img, np_img, np_img]) 249 | img = Image.fromarray(np_img, 'RGB') 250 | return img 251 | 252 | 253 | class ColorJitter(object): 254 | """Randomly change the brightness, contrast and saturation of an image. --modified from pytorch source code 255 | Args: 256 | brightness (float or tuple of float (min, max)): How much to jitter brightness. 257 | brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] 258 | or the given [min, max]. 
Should be non negative numbers. 259 | contrast (float or tuple of float (min, max)): How much to jitter contrast. 260 | contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] 261 | or the given [min, max]. Should be non negative numbers. 262 | saturation (float or tuple of float (min, max)): How much to jitter saturation. 263 | saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] 264 | or the given [min, max]. Should be non negative numbers. 265 | hue (float or tuple of float (min, max)): How much to jitter hue. 266 | hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. 267 | Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 268 | """ 269 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, consistent=False, p=1.0): 270 | self.brightness = self._check_input(brightness, 'brightness') 271 | self.contrast = self._check_input(contrast, 'contrast') 272 | self.saturation = self._check_input(saturation, 'saturation') 273 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), 274 | clip_first_on_zero=False) 275 | self.consistent = consistent 276 | self.threshold = p 277 | 278 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): 279 | if isinstance(value, numbers.Number): 280 | if value < 0: 281 | raise ValueError("If {} is a single number, it must be non negative.".format(name)) 282 | value = [center - value, center + value] 283 | if clip_first_on_zero: 284 | value[0] = max(value[0], 0) 285 | elif isinstance(value, (tuple, list)) and len(value) == 2: 286 | if not bound[0] <= value[0] <= value[1] <= bound[1]: 287 | raise ValueError("{} values should be between {}".format(name, bound)) 288 | else: 289 | raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) 290 | 291 | # if value is 0 or (1., 1.) for brightness/contrast/saturation 292 | # or (0., 0.) for hue, do nothing 293 | if value[0] == value[1] == center: 294 | value = None 295 | return value 296 | 297 | @staticmethod 298 | def get_params(brightness, contrast, saturation, hue): 299 | """Get a randomized transform to be applied on image. 300 | Arguments are same as that of __init__. 301 | Returns: 302 | Transform which randomly adjusts brightness, contrast and 303 | saturation in a random order. 
304 | """ 305 | transforms = [] 306 | 307 | if brightness is not None: 308 | brightness_factor = random.uniform(brightness[0], brightness[1]) 309 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) 310 | 311 | if contrast is not None: 312 | contrast_factor = random.uniform(contrast[0], contrast[1]) 313 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) 314 | 315 | if saturation is not None: 316 | saturation_factor = random.uniform(saturation[0], saturation[1]) 317 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) 318 | 319 | if hue is not None: 320 | hue_factor = random.uniform(hue[0], hue[1]) 321 | transforms.append(torchvision.transforms.Lambda(lambda img: F.adjust_hue(img, hue_factor))) 322 | 323 | random.shuffle(transforms) 324 | transform = torchvision.transforms.Compose(transforms) 325 | 326 | return transform 327 | 328 | def __call__(self, imgmap): 329 | if random.random() < self.threshold: # do ColorJitter 330 | if self.consistent: 331 | transform = self.get_params(self.brightness, self.contrast, 332 | self.saturation, self.hue) 333 | return [transform(i) for i in imgmap] 334 | else: 335 | result = [] 336 | for img in imgmap: 337 | transform = self.get_params(self.brightness, self.contrast, 338 | self.saturation, self.hue) 339 | result.append(transform(img)) 340 | return result 341 | else: # don't do ColorJitter, do nothing 342 | return imgmap 343 | 344 | def __repr__(self): 345 | format_string = self.__class__.__name__ + '(' 346 | format_string += 'brightness={0}'.format(self.brightness) 347 | format_string += ', contrast={0}'.format(self.contrast) 348 | format_string += ', saturation={0}'.format(self.saturation) 349 | format_string += ', hue={0})'.format(self.hue) 350 | return format_string 351 | 352 | 353 | class RandomRotation: 354 | def __init__(self, consistent=True, degree=15, p=1.0): 355 | self.consistent = consistent 356 | self.degree = degree 357 | self.threshold = p 358 | def __call__(self, imgmap): 359 | if random.random() < self.threshold: # do RandomRotation 360 | if self.consistent: 361 | deg = np.random.randint(-self.degree, self.degree, 1)[0] 362 | return [i.rotate(deg, expand=True) for i in imgmap] 363 | else: 364 | return [i.rotate(np.random.randint(-self.degree, self.degree, 1)[0], expand=True) for i in imgmap] 365 | else: # don't do RandomRotation, do nothing 366 | return imgmap 367 | 368 | class ToTensor: 369 | def __call__(self, imgmap): 370 | totensor = transforms.ToTensor() 371 | return [totensor(i) for i in imgmap] 372 | 373 | class Normalize: 374 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 375 | self.mean = mean 376 | self.std = std 377 | def __call__(self, imgmap): 378 | normalize = transforms.Normalize(mean=self.mean, std=self.std) 379 | return [normalize(i) for i in imgmap] 380 | 381 | 382 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pickle 4 | import os 5 | from datetime import datetime 6 | import glob 7 | import re 8 | import matplotlib.pyplot as plt 9 | plt.switch_backend('agg') 10 | from collections import deque 11 | from tqdm import tqdm 12 | from torchvision import transforms 13 | 14 | def save_checkpoint(state, is_best=0, gap=1, filename='models/checkpoint.pth.tar', 
keep_all=False): 15 | torch.save(state, filename) 16 | last_epoch_path = os.path.join(os.path.dirname(filename), 17 | 'epoch%s.pth.tar' % str(state['epoch']-gap)) 18 | if not keep_all: 19 | try: os.remove(last_epoch_path) 20 | except: pass 21 | if is_best: 22 | past_best = glob.glob(os.path.join(os.path.dirname(filename), 'model_best_*.pth.tar')) 23 | for i in past_best: 24 | try: os.remove(i) 25 | except: pass 26 | torch.save(state, os.path.join(os.path.dirname(filename), 'model_best_epoch%s.pth.tar' % str(state['epoch']))) 27 | 28 | def write_log(content, epoch, filename): 29 | if not os.path.exists(filename): 30 | log_file = open(filename, 'w') 31 | else: 32 | log_file = open(filename, 'a') 33 | log_file.write('## Epoch %d:\n' % epoch) 34 | log_file.write('time: %s\n' % str(datetime.now())) 35 | log_file.write(content + '\n\n') 36 | log_file.close() 37 | 38 | def calc_topk_accuracy(output, target, topk=(1,)): 39 | ''' 40 | Modified from: https://gist.github.com/agermanidis/275b23ad7a10ee89adccf021536bb97e 41 | Given predicted and ground truth labels, 42 | calculate top-k accuracies. 43 | ''' 44 | maxk = max(topk) 45 | batch_size = target.size(0) 46 | 47 | _, pred = output.topk(maxk, 1, True, True) 48 | pred = pred.t() 49 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 50 | 51 | res = [] 52 | for k in topk: 53 | correct_k = correct[:k].contiguous().view(-1).float().sum(0) 54 | res.append(correct_k.mul_(1 / batch_size)) 55 | return res 56 | 57 | def calc_accuracy(output, target): 58 | '''output: (B, N); target: (B)''' 59 | target = target.squeeze() 60 | _, pred = torch.max(output, 1) 61 | return torch.mean((pred == target).float()) 62 | 63 | def calc_accuracy_binary(output, target): 64 | '''output, target: (B, N), output is logits, before sigmoid ''' 65 | pred = output > 0 66 | acc = torch.mean((pred == target.byte()).float()) 67 | del pred, output, target 68 | return acc 69 | 70 | def denorm(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 71 | assert len(mean)==len(std)==3 72 | inv_mean = [-mean[i]/std[i] for i in range(3)] 73 | inv_std = [1/i for i in std] 74 | return transforms.Normalize(mean=inv_mean, std=inv_std) 75 | 76 | 77 | class AverageMeter(object): 78 | """Computes and stores the average and current value""" 79 | def __init__(self): 80 | self.reset() 81 | 82 | def reset(self): 83 | self.val = 0 84 | self.avg = 0 85 | self.sum = 0 86 | self.count = 0 87 | self.local_history = deque([]) 88 | self.local_avg = 0 89 | self.history = [] 90 | self.dict = {} # save all data values here 91 | self.save_dict = {} # save mean and std here, for summary table 92 | 93 | def update(self, val, n=1, history=0, step=5): 94 | self.val = val 95 | self.sum += val * n 96 | self.count += n 97 | self.avg = self.sum / self.count 98 | if history: 99 | self.history.append(val) 100 | if step > 0: 101 | self.local_history.append(val) 102 | if len(self.local_history) > step: 103 | self.local_history.popleft() 104 | self.local_avg = np.average(self.local_history) 105 | 106 | def dict_update(self, val, key): 107 | if key in self.dict.keys(): 108 | self.dict[key].append(val) 109 | else: 110 | self.dict[key] = [val] 111 | 112 | def __len__(self): 113 | return self.count 114 | 115 | 116 | class AccuracyTable(object): 117 | '''compute accuracy for each class''' 118 | def __init__(self): 119 | self.dict = {} 120 | 121 | def update(self, pred, tar): 122 | pred = torch.squeeze(pred) 123 | tar = torch.squeeze(tar) 124 | for i, j in zip(pred, tar): 125 | i = int(i) 126 | j = int(j) 127 | if j not in 
self.dict.keys(): 128 | self.dict[j] = {'count':0,'correct':0} 129 | self.dict[j]['count'] += 1 130 | if i == j: 131 | self.dict[j]['correct'] += 1 132 | 133 | def print_table(self, label): 134 | for key in self.dict.keys(): 135 | acc = self.dict[key]['correct'] / self.dict[key]['count'] 136 | print('%s: %2d, accuracy: %3d/%3d = %0.6f' \ 137 | % (label, key, self.dict[key]['correct'], self.dict[key]['count'], acc)) 138 | 139 | 140 | class ConfusionMeter(object): 141 | '''compute and show confusion matrix''' 142 | def __init__(self, num_class): 143 | self.num_class = num_class 144 | self.mat = np.zeros((num_class, num_class)) 145 | self.precision = [] 146 | self.recall = [] 147 | 148 | def update(self, pred, tar): 149 | pred, tar = pred.cpu().numpy(), tar.cpu().numpy() 150 | pred = np.squeeze(pred) 151 | tar = np.squeeze(tar) 152 | for p,t in zip(pred.flat, tar.flat): 153 | self.mat[p][t] += 1 154 | 155 | def print_mat(self): 156 | print('Confusion Matrix: (target in columns)') 157 | print(self.mat) 158 | 159 | def plot_mat(self, path, dictionary=None, annotate=False): 160 | plt.figure(dpi=600) 161 | plt.imshow(self.mat, 162 | cmap=plt.cm.jet, 163 | interpolation=None, 164 | extent=(0.5, np.shape(self.mat)[0]+0.5, np.shape(self.mat)[1]+0.5, 0.5)) 165 | width, height = self.mat.shape 166 | if annotate: 167 | for x in range(width): 168 | for y in range(height): 169 | plt.annotate(str(int(self.mat[x][y])), xy=(y+1, x+1), 170 | horizontalalignment='center', 171 | verticalalignment='center', 172 | fontsize=8) 173 | 174 | if dictionary is not None: 175 | plt.xticks([i+1 for i in range(width)], 176 | [dictionary[i] for i in range(width)], 177 | rotation='vertical') 178 | plt.yticks([i+1 for i in range(height)], 179 | [dictionary[i] for i in range(height)]) 180 | plt.xlabel('Ground Truth') 181 | plt.ylabel('Prediction') 182 | plt.colorbar() 183 | plt.tight_layout() 184 | plt.savefig(path, format='svg') 185 | plt.clf() 186 | 187 | # for i in range(width): 188 | # if np.sum(self.mat[i,:]) != 0: 189 | # self.precision.append(self.mat[i,i] / np.sum(self.mat[i,:])) 190 | # if np.sum(self.mat[:,i]) != 0: 191 | # self.recall.append(self.mat[i,i] / np.sum(self.mat[:,i])) 192 | # print('Average Precision: %0.4f' % np.mean(self.precision)) 193 | # print('Average Recall: %0.4f' % np.mean(self.recall)) 194 | 195 | 196 | 197 | 198 | --------------------------------------------------------------------------------
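The metric helpers defined in `utils/utils.py` above (`AverageMeter`, `ConfusionMeter`, `calc_topk_accuracy`) are the building blocks of the evaluation loop in `eval/test.py`. Below is a minimal usage sketch — not a file from the repository — assuming it is run from inside `utils/` (or with `utils/` on the Python path), with random tensors standing in for real classifier outputs and labels:

```
import torch
from utils import AverageMeter, ConfusionMeter, calc_topk_accuracy

num_class = 51                            # e.g. HMDB51
top1_meter, top5_meter = AverageMeter(), AverageMeter()
confusion = ConfusionMeter(num_class)

for _ in range(4):                        # pretend these are batches from a val loader
    logits = torch.randn(8, num_class)    # [B, num_class] classifier outputs
    target = torch.randint(0, num_class, (8,))
    top1, top5 = calc_topk_accuracy(logits, target, topk=(1, 5))
    top1_meter.update(top1.item(), 8)     # second argument weights the running average by batch size
    top5_meter.update(top5.item(), 8)
    confusion.update(logits.argmax(dim=1), target)

print('top1: %.4f  top5: %.4f' % (top1_meter.avg, top5_meter.avg))
confusion.print_mat()                     # raw confusion matrix, targets in columns
```

`eval/test.py` follows the same pattern, except that it first averages the softmax scores over all clips of a video before calling `calc_topk_accuracy`, and saves the confusion matrix to an svg file with `plot_mat`.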