├── .gitignore
├── README.md
├── archs
│   ├── __init__.py
│   ├── resnet.py
│   └── utils.py
├── datasets
│   └── Mini-Something-V2
│       └── data
│           ├── category_mini.txt
│           ├── train_videofolder_mini.txt
│           └── val_videofolder_mini.txt
├── docs
│   ├── index.html
│   └── resources
│       ├── .gitkeep
│       ├── Mini-Something_results.png
│       ├── Semi-Supervised Action Recognition With Temporal Contrastive Learning.pdf
│       ├── TCL_bib.txt
│       ├── arch.png
│       ├── arch_small.png
│       └── comp.png
├── main.py
├── ops
│   ├── __init__.py
│   ├── basic_ops.py
│   ├── dataset.py
│   ├── dataset_config.py
│   ├── models.py
│   ├── non_local.py
│   ├── temporal_shift.py
│   ├── transforms.py
│   └── utils.py
├── opts.py
├── requirements.txt
├── root_dataset.yaml
└── tools
    ├── extract_videos_kinetics.py
    ├── extract_videos_moments.py
    └── extract_videos_st2st_v2.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/log
2 | **/checkpoint
3 | **/__pycache__
4 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Semi-Supervised Action Recognition with Temporal Contrastive Learning [[Paper]](https://arxiv.org/pdf/2102.02751.pdf) [[Website]](https://cvir.github.io/TCL/)
2 | 
3 | This repository contains the implementation details of our Temporal Contrastive Learning (TCL) approach for action recognition in videos.
4 | 
5 | Ankit Singh*, Omprakash Chakraborty*, Ashutosh Varshney, Rameswar Panda, Rogerio Feris, Kate Saenko and Abir Das, "Semi-Supervised Action Recognition with Temporal Contrastive Learning"\
6 | *: Equal contributions
7 | 
8 | If you use the codes and models from this repo, please cite our work. Thanks!
9 | 
10 | ```
11 | @InProceedings{Singh_2021_CVPR,
12 |     author    = {Singh, Ankit and Chakraborty, Omprakash and Varshney, Ashutosh and Panda, Rameswar and Feris, Rogerio and Saenko, Kate and Das, Abir},
13 |     title     = {Semi-Supervised Action Recognition With Temporal Contrastive Learning},
14 |     booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
15 |     month     = {June},
16 |     year      = {2021},
17 |     pages     = {10389-10399}
18 | }
19 | ```
20 | 
21 | ## Requirements
22 | The code is written for Python `3.6.10`, but should work for other versions with minor modifications.
23 | ```
24 | pip install -r requirements.txt
25 | ```
26 | ## Data Preparation
27 | 
28 | The dataloader (ops/dataset.py) can load videos (image sequences) stored in the following format:
29 | ```
30 | -- dataset_dir
31 | ---- data
32 | ------category.txt
33 | ------train.txt
34 | ------val.txt
35 | ---- Frames
36 | ------ video_0_folder
37 | -------- 00001.jpg
38 | -------- 00002.jpg
39 | -------- ...
40 | ------ video_1_folder
41 | ------ ...
42 | ```
43 | For each dataset, `root_dataset.yaml` should contain the `dataset_dir` where that dataset is stored.
44 | 
45 | 
46 | Each line in `train.txt` and `val.txt` contains 3 elements separated by spaces.
47 | The three elements are (in order): (1) relative path to `video_x_folder` from `dataset_dir`, (2) total number of frames, (3) label id (a numeric value).
48 | 
49 | E.g., a `video_x` that has `300` frames and belongs to label `1`:
50 | ```
51 | path/to/video_x_folder 300 1
52 | ```
53 | 
54 | After that, the paths of `category.txt`, `Frames`, `train.txt` and `val.txt` should be set accordingly in ops/dataset_config.py.
55 | 
56 | Samples for some datasets are already provided in the respective files.
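To make the annotation format concrete, below is a small, hypothetical helper (not part of the repository) that parses a `train.txt` / `val.txt` file into records. It follows the three-element line format described above and mirrors the whitespace `split()` that `ops/dataset.py` performs internally; the function and class names here are illustrative only.

```python
# Hypothetical parser for the "<relative_path> <num_frames> <label_id>" line format
# described above; ops/dataset.py performs an equivalent split when building its video list.
from typing import List, NamedTuple


class VideoRecord(NamedTuple):
    path: str        # relative path to the frame folder, from dataset_dir
    num_frames: int  # total number of extracted frames in that folder
    label: int       # numeric label id


def parse_video_list(list_file: str) -> List[VideoRecord]:
    records = []
    with open(list_file) as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) != 3:
                continue  # skip empty or malformed lines
            path, num_frames, label = parts
            records.append(VideoRecord(path, int(num_frames), int(label)))
    return records


# For the example line above, parse_video_list would yield:
# VideoRecord(path='path/to/video_x_folder', num_frames=300, label=1)
```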
57 | 
58 | We provide three sample scripts in the `tools` folder to help convert some datasets, but the details in the scripts (e.g., the path to the videos) must be set accordingly.
59 | 
60 | ## Mini-datasets
61 | We provide the [`category_mini.txt`](datasets/Mini-Something-V2/data/category_mini.txt), [`train_videofolder_mini.txt`](datasets/Mini-Something-V2/data/train_videofolder_mini.txt) and [`val_videofolder_mini.txt`](datasets/Mini-Something-V2/data/val_videofolder_mini.txt) files for the Mini-Something-Something V2 dataset.
62 | 
63 | ## Python script overview
64 | 
65 | `main.py` - Contains the code for Temporal Contrastive Learning (TCL) with the two-pathway model.
66 | 
67 | `opts.py` - Contains the default values for the different parameters used in the two-pathway model.
68 | 
69 | `ops/dataset_config.py` - Contains the configuration (file locations, prefixes) for the different datasets, e.g. Kinetics, Jester, Something-Something V2.
70 | 
71 | `ops/dataset.py` - Contains the code for how frames are sampled from a video.
72 | 
73 | ### Key Parameters:
74 | `use_group_contrastive`: whether to use the group contrastive loss \
75 | `use_finetuning`: whether to run a finetuning stage at the end of training \
76 | `finetune_start_epoch`: epoch at which finetuning starts \
77 | `finetune_lr`: learning rate for the finetuning stage, if different from the normal one \
78 | `gamma_finetune`: weight of the pseudo-label loss (pl_loss) in the finetuning stage \
79 | `finetune_stage_eval_freq`: evaluation/printing frequency during the finetuning stage \
80 | `threshold`: confidence threshold used in the finetuning stage to select pseudo-labels \
81 | `sup_thresh`: epoch until which supervised-only training is run \
82 | `percentage`: fraction of unlabeled data, e.g. 0.99, 0.95 \
83 | `gamma`: weight of the instance contrastive loss \
84 | `lr`: initial learning rate \
85 | `mu`: ratio of unlabeled to labeled data \
86 | `flip`: whether to use horizontal flips in the transforms
87 | 
88 | 
89 | ### Training TCL
90 | - To run the `x%` labeled-data scenario, a folder named `Run_x` is expected, in which the labeled and unlabeled splits are written according to the input seed.
91 | - All models and logs are stored inside a subfolder of the checkpoint directory. A different subfolder is created on each execution.
92 | 
93 | ### Sample Code to train TCL
94 | 
95 | `python main.py somethingv2 RGB --seed 123 --strategy classwise
96 | --arch resnet18 --num_segments 8 --second_segments 4 --threshold 0.8 --gd 20 --lr 0.02 --wd 1e-4
97 | --epochs 400 --percentage 0.95 --batch-size 8 -j 16 --dropout 0.5 --consensus_type=avg --eval-freq=1 --print-freq 50
98 | --shift --shift_div=8 --shift_place=blockres --npb --gpus 0 1 --mu 3 --gamma 9 --gamma_finetune 1
99 | --use_group_contrastive --use_finetuning --finetune_start_epoch 350 --sup_thresh 50 --valbatchsize 16 --finetune_lr 0.002`
100 | 
101 | ## Reference
102 | 
103 | The implementation reused some portions from [TSM](https://github.com/mit-han-lab/temporal-shift-module) [1].
104 | 
105 | 
106 | 1. Lin, Ji, Chuang Gan, and Song Han. "TSM: Temporal shift module for efficient video understanding." Proceedings of the IEEE/CVF International Conference on Computer Vision. 2019.
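For readers who want a concrete picture of the instance contrastive loss weighted by `gamma` above, the following is a minimal, self-contained sketch of a SimCLR-style loss between the class-probability outputs of the same batch of videos played at two speeds (the fast and slow pathways). It mirrors the structure of `simclr_loss` in `main.py`, but it is illustrative only: the function name, the default temperature of 0.5, and the explicit L2 normalization are assumptions rather than the repository's exact settings (the repo reads its temperature from `args.Temperature`).

```python
import torch
import torch.nn.functional as F


def paired_contrastive_loss(output_fast: torch.Tensor,
                            output_slow: torch.Tensor,
                            temperature: float = 0.5) -> torch.Tensor:
    """Illustrative SimCLR-style loss between two views (fast/slow pathway)
    of the same videos. Both inputs have shape (batch, num_classes)."""
    # L2-normalize so dot products become cosine similarities (assumption;
    # main.py instead divides by the product of the norms explicitly).
    fast = F.normalize(output_fast, dim=1)
    slow = F.normalize(output_slow, dim=1)
    out = torch.cat([fast, slow], dim=0)                      # (2B, C)
    sim = torch.exp(out @ out.t() / temperature)              # all pairwise similarities
    # Positive pairs: the same video at the two playback speeds.
    pos = torch.exp((fast * slow).sum(dim=-1) / temperature)
    pos = torch.cat([pos, pos], dim=0)                        # (2B,)
    # Remove each sample's similarity with itself (exp(1/T) after normalization)
    # from the denominator, as main.py does with its norm_sum term.
    self_sim = torch.exp(torch.ones(out.size(0), device=out.device) / temperature)
    denom = sim.sum(dim=-1) - self_sim
    return (-torch.log(pos / denom)).mean()


# Usage sketch: loss = gamma * paired_contrastive_loss(model(fast_clips), model(slow_clips))
```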
107 | -------------------------------------------------------------------------------- /archs/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import resnet_dict 2 | -------------------------------------------------------------------------------- /archs/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .utils import load_state_dict_from_url 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 7 | 'wide_resnet50_2', 'wide_resnet101_2'] 8 | 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 17 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 18 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 19 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 20 | } 21 | 22 | resnet_dict={} 23 | 24 | def resnet_dic(function): 25 | resnet_dict[function.__name__] = function 26 | return function 27 | 28 | 29 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 30 | """3x3 convolution with padding""" 31 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 32 | padding=dilation, groups=groups, bias=False, dilation=dilation) 33 | 34 | 35 | def conv1x1(in_planes, out_planes, stride=1): 36 | """1x1 convolution""" 37 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 38 | 39 | 40 | class BasicBlock(nn.Module): 41 | expansion = 1 42 | 43 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 44 | base_width=64, dilation=1, norm_layer=None): 45 | super(BasicBlock, self).__init__() 46 | if norm_layer is None: 47 | norm_layer = nn.BatchNorm2d 48 | if groups != 1 or base_width != 64: 49 | raise ValueError( 50 | 'BasicBlock only supports groups=1 and base_width=64') 51 | if dilation > 1: 52 | raise NotImplementedError( 53 | "Dilation > 1 not supported in BasicBlock") 54 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 55 | self.conv1 = conv3x3(inplanes, planes, stride) 56 | self.bn1 = norm_layer(planes) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.conv2 = conv3x3(planes, planes) 59 | self.bn2 = norm_layer(planes) 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | identity = x 65 | 66 | out = self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv2(out) 71 | out = self.bn2(out) 72 | 73 | if self.downsample is not None: 74 | identity = self.downsample(x) 75 | 76 | out += identity 77 | out = self.relu(out) 78 | 79 | return out 80 | 81 | 82 | class Bottleneck(nn.Module): 83 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 84 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 85 | # according to "Deep residual learning for image 
recognition"https://arxiv.org/abs/1512.03385. 86 | # This variant is also known as ResNet V1.5 and improves accuracy according to 87 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 88 | 89 | expansion = 4 90 | 91 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 92 | base_width=64, dilation=1, norm_layer=None): 93 | super(Bottleneck, self).__init__() 94 | if norm_layer is None: 95 | norm_layer = nn.BatchNorm2d 96 | width = int(planes * (base_width / 64.)) * groups 97 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 98 | self.conv1 = conv1x1(inplanes, width) 99 | self.bn1 = norm_layer(width) 100 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 101 | self.bn2 = norm_layer(width) 102 | self.conv3 = conv1x1(width, planes * self.expansion) 103 | self.bn3 = norm_layer(planes * self.expansion) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.downsample = downsample 106 | self.stride = stride 107 | 108 | def forward(self, x): 109 | identity = x 110 | 111 | out = self.conv1(x) 112 | out = self.bn1(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv2(out) 116 | out = self.bn2(out) 117 | out = self.relu(out) 118 | 119 | out = self.conv3(out) 120 | out = self.bn3(out) 121 | 122 | if self.downsample is not None: 123 | identity = self.downsample(x) 124 | 125 | out += identity 126 | out = self.relu(out) 127 | 128 | return out 129 | 130 | 131 | class ResNet(nn.Module): 132 | 133 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, 134 | norm_layer=None): 135 | super(ResNet, self).__init__() 136 | if norm_layer is None: 137 | norm_layer = nn.BatchNorm2d 138 | self._norm_layer = norm_layer 139 | 140 | self.inplanes = 64 141 | self.dilation = 1 142 | if replace_stride_with_dilation is None: 143 | # each element in the tuple indicates if we should replace 144 | # the 2x2 stride with a dilated convolution instead 145 | replace_stride_with_dilation = [False, False, False] 146 | if len(replace_stride_with_dilation) != 3: 147 | raise ValueError("replace_stride_with_dilation should be None " 148 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 149 | self.groups = groups 150 | self.base_width = width_per_group 151 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 152 | bias=False) 153 | self.bn1 = norm_layer(self.inplanes) 154 | self.relu = nn.ReLU(inplace=True) 155 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 156 | self.layer1 = self._make_layer(block, 64, layers[0]) 157 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 158 | dilate=replace_stride_with_dilation[0]) 159 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 160 | dilate=replace_stride_with_dilation[1]) 161 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 162 | dilate=replace_stride_with_dilation[2]) 163 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 164 | self.fc = nn.Linear(512 * block.expansion, num_classes) 165 | 166 | for m in self.modules(): 167 | if isinstance(m, nn.Conv2d): 168 | nn.init.kaiming_normal_( 169 | m.weight, mode='fan_out', nonlinearity='relu') 170 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 171 | nn.init.constant_(m.weight, 1) 172 | nn.init.constant_(m.bias, 0) 173 | 174 | # Zero-initialize the last BN in each residual branch, 175 | # so that the residual branch starts with zeros, and each 
residual block behaves like an identity. 176 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 177 | if zero_init_residual: 178 | for m in self.modules(): 179 | if isinstance(m, Bottleneck): 180 | nn.init.constant_(m.bn3.weight, 0) 181 | elif isinstance(m, BasicBlock): 182 | nn.init.constant_(m.bn2.weight, 0) 183 | 184 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 185 | norm_layer = self._norm_layer 186 | downsample = None 187 | previous_dilation = self.dilation 188 | if dilate: 189 | self.dilation *= stride 190 | stride = 1 191 | if stride != 1 or self.inplanes != planes * block.expansion: 192 | downsample = nn.Sequential( 193 | conv1x1(self.inplanes, planes * block.expansion, stride), 194 | norm_layer(planes * block.expansion), 195 | ) 196 | 197 | layers = [] 198 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 199 | self.base_width, previous_dilation, norm_layer)) 200 | self.inplanes = planes * block.expansion 201 | for _ in range(1, blocks): 202 | layers.append(block(self.inplanes, planes, groups=self.groups, 203 | base_width=self.base_width, dilation=self.dilation, 204 | norm_layer=norm_layer)) 205 | 206 | return nn.Sequential(*layers) 207 | 208 | def _forward_impl(self, x, unlabeled=False): 209 | # See note [TorchScript super()] 210 | x = self.conv1(x) 211 | x = self.bn1(x) 212 | x = self.relu(x) 213 | x = self.maxpool(x) 214 | #print("size of x is layer1 is {}".format(x.size())) 215 | x = self.layer1(x, unlabeled) 216 | #print("size of x is layer2 is {}".format(x.size())) 217 | 218 | x = self.layer2(x, unlabeled) 219 | #print("size of x is layer3 is {}".format(x.size())) 220 | 221 | x = self.layer3(x, unlabeled) 222 | #print("size of x is layer4 is {}".format(x.size())) 223 | 224 | x = self.layer4(x, unlabeled) 225 | 226 | x = self.avgpool(x) 227 | x = torch.flatten(x, 1) 228 | x = self.fc(x) 229 | 230 | return x 231 | 232 | def forward(self, x, unlabeled=False): 233 | return self._forward_impl(x, unlabeled) 234 | 235 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 236 | model = ResNet(block, layers, **kwargs) 237 | if pretrained: 238 | state_dict = load_state_dict_from_url(model_urls[arch], 239 | progress=progress) 240 | model.load_state_dict(state_dict) 241 | return model 242 | 243 | @resnet_dic 244 | def resnet18(pretrained=False, progress=True, **kwargs): 245 | r"""ResNet-18 model from 246 | `"Deep Residual Learning for Image Recognition" `_ 247 | Args: 248 | pretrained (bool): If True, returns a model pre-trained on ImageNet 249 | progress (bool): If True, displays a progress bar of the download to stderr 250 | """ 251 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 252 | **kwargs) 253 | 254 | @resnet_dic 255 | def resnet34(pretrained=False, progress=True, **kwargs): 256 | r"""ResNet-34 model from 257 | `"Deep Residual Learning for Image Recognition" `_ 258 | Args: 259 | pretrained (bool): If True, returns a model pre-trained on ImageNet 260 | progress (bool): If True, displays a progress bar of the download to stderr 261 | """ 262 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 263 | **kwargs) 264 | 265 | @resnet_dic 266 | def resnet50(pretrained=False, progress=True, **kwargs): 267 | r"""ResNet-50 model from 268 | `"Deep Residual Learning for Image Recognition" `_ 269 | Args: 270 | pretrained (bool): If True, returns a model pre-trained on ImageNet 271 | progress (bool): If True, displays a progress bar of the 
download to stderr 272 | """ 273 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 274 | **kwargs) 275 | 276 | @resnet_dic 277 | def resnet101(pretrained=False, progress=True, **kwargs): 278 | r"""ResNet-101 model from 279 | `"Deep Residual Learning for Image Recognition" `_ 280 | Args: 281 | pretrained (bool): If True, returns a model pre-trained on ImageNet 282 | progress (bool): If True, displays a progress bar of the download to stderr 283 | """ 284 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 285 | **kwargs) 286 | 287 | @resnet_dic 288 | def resnet152(pretrained=False, progress=True, **kwargs): 289 | r"""ResNet-152 model from 290 | `"Deep Residual Learning for Image Recognition" `_ 291 | Args: 292 | pretrained (bool): If True, returns a model pre-trained on ImageNet 293 | progress (bool): If True, displays a progress bar of the download to stderr 294 | """ 295 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 296 | **kwargs) 297 | 298 | @resnet_dic 299 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 300 | r"""ResNeXt-50 32x4d model from 301 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 302 | Args: 303 | pretrained (bool): If True, returns a model pre-trained on ImageNet 304 | progress (bool): If True, displays a progress bar of the download to stderr 305 | """ 306 | kwargs['groups'] = 32 307 | kwargs['width_per_group'] = 4 308 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 309 | pretrained, progress, **kwargs) 310 | 311 | @resnet_dic 312 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 313 | r"""ResNeXt-101 32x8d model from 314 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 315 | Args: 316 | pretrained (bool): If True, returns a model pre-trained on ImageNet 317 | progress (bool): If True, displays a progress bar of the download to stderr 318 | """ 319 | kwargs['groups'] = 32 320 | kwargs['width_per_group'] = 8 321 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 322 | pretrained, progress, **kwargs) 323 | 324 | @resnet_dic 325 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 326 | r"""Wide ResNet-50-2 model from 327 | `"Wide Residual Networks" `_ 328 | The model is the same as ResNet except for the bottleneck number of channels 329 | which is twice larger in every block. The number of channels in outer 1x1 330 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 331 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 332 | Args: 333 | pretrained (bool): If True, returns a model pre-trained on ImageNet 334 | progress (bool): If True, displays a progress bar of the download to stderr 335 | """ 336 | kwargs['width_per_group'] = 64 * 2 337 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 338 | pretrained, progress, **kwargs) 339 | 340 | @resnet_dic 341 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 342 | r"""Wide ResNet-101-2 model from 343 | `"Wide Residual Networks" `_ 344 | The model is the same as ResNet except for the bottleneck number of channels 345 | which is twice larger in every block. The number of channels in outer 1x1 346 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 347 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
348 | Args: 349 | pretrained (bool): If True, returns a model pre-trained on ImageNet 350 | progress (bool): If True, displays a progress bar of the download to stderr 351 | """ 352 | kwargs['width_per_group'] = 64 * 2 353 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 354 | pretrained, progress, **kwargs) 355 | -------------------------------------------------------------------------------- /archs/utils.py: -------------------------------------------------------------------------------- 1 | try: 2 | from torch.hub import load_state_dict_from_url 3 | except ImportError: 4 | from torch.utils.model_zoo import load_url as load_state_dict_from_url -------------------------------------------------------------------------------- /datasets/Mini-Something-V2/data/category_mini.txt: -------------------------------------------------------------------------------- 1 | Approaching something with your camera 2 | Attaching something to something 3 | Bending something until it breaks 4 | Digging something out of something 5 | Dropping something in front of something 6 | Dropping something into something 7 | Dropping something next to something 8 | Failing to put something into something because something does not fit 9 | Folding something 10 | Holding something behind something 11 | Holding something next to something 12 | Letting something roll up a slanted surface, so it rolls back down 13 | Lifting up one end of something without letting it drop down 14 | Lifting up one end of something, then letting it drop down 15 | Moving something and something away from each other 16 | Moving something and something closer to each other 17 | Moving something and something so they pass each other 18 | Moving something away from the camera 19 | Moving something towards the camera 20 | Moving away from something with your camera 21 | Opening something 22 | Picking something up 23 | Poking something so it slightly moves 24 | Poking something so lightly that it doesn't or almost doesn't move 25 | Poking something so that it falls over 26 | Poking a stack of something so the stack collapses 27 | Pretending or failing to wipe something off of something 28 | Pretending or trying and failing to twist something 29 | Pretending to be tearing something that is not tearable 30 | Pretending to close something without actually closing it 31 | Pretending to open something without actually opening it 32 | Pretending to pick something up 33 | Pretending to pour something out of something, but something is empty 34 | Pretending to put something behind something 35 | Pretending to put something into something 36 | Pretending to put something on a surface 37 | Pretending to put something onto something 38 | Pretending to scoop something up with something 39 | Pretending to spread air onto something 40 | Pretending to take something out of something 41 | Pretending to turn something upside down 42 | Pulling something from right to left 43 | Pulling something out of something 44 | Pulling two ends of something but nothing happens 45 | Pulling two ends of something so that it gets stretched 46 | Pulling two ends of something so that it separates into two pieces 47 | Pushing something from left to right 48 | Pushing something onto something 49 | Pushing something so it spins 50 | Pushing something so that it almost falls off but doesn't 51 | Pushing something so that it slightly moves 52 | Pushing something with something 53 | Putting something on a surface 54 | Putting something onto something else that cannot support 
it so it falls down 55 | Putting something that can't roll onto a slanted surface, so it stays where it is 56 | Showing something behind something 57 | Showing something next to something 58 | Showing something to the camera 59 | Showing a photo of something to the camera 60 | Spilling something next to something 61 | Spilling something onto something 62 | Spinning something that quickly stops spinning 63 | Spreading something onto something 64 | Sprinkling something onto something 65 | Stacking number of something 66 | Stuffing something into something 67 | Taking something from somewhere 68 | Tearing something just a little bit 69 | Throwing something 70 | Throwing something in the air and catching it 71 | Throwing something in the air and letting it fall 72 | Throwing something onto a surface 73 | Tilting something with something on it slightly so it doesn't fall down 74 | Tilting something with something on it until it falls off 75 | Touching (without moving) part of something 76 | Trying but failing to attach something to something because it doesn't stick 77 | Trying to bend something unbendable so nothing happens 78 | Trying to pour something into something, but missing so it spills next to it 79 | Turning the camera right while filming something 80 | Twisting something 81 | Uncovering something 82 | Unfolding something 83 | Something being deflected from something 84 | Something colliding with something and both are being deflected 85 | Something colliding with something and both come to a halt 86 | Something falling like a feather or paper 87 | Something falling like a rock -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 138 | 139 | 140 | 141 | 142 | Semi-Supervised Action Recognition with Temporal Contrastive Learning 143 | 144 | 145 | 146 | 147 | 148 |
Semi-Supervised Action Recognition with Temporal Contrastive Learning

Ankit Singh (1), Omprakash Chakraborty (2), Ashutosh Varshney (2), Rameswar Panda (3),
Rogerio Feris (3), Kate Saenko (3,4), Abir Das (2)

(1) IIT Madras   (2) IIT Kharagpur   (3) MIT-IBM Watson AI Lab   (4) Boston University

CVPR 2021

Abstract

Learning to recognize actions from only a handful of labeled videos is a challenging problem due to the scarcity of tediously collected activity labels. We approach this problem by learning a two-pathway temporal contrastive model using unlabeled videos at two different speeds, leveraging the fact that changing video speed does not change an action. Specifically, we propose to maximize the similarity between encoded representations of the same video at two different speeds as well as minimize the similarity between different videos played at different speeds. This way we use the rich supervisory information in terms of ‘time’ that is present in otherwise unsupervised pool of videos. With this simple yet effective strategy of manipulating video playback rates, we considerably outperform video extensions of sophisticated state-of-the-art semi-supervised image recognition methods across multiple diverse benchmark datasets and network architectures. Interestingly, our proposed approach benefits from out-of-domain unlabeled videos, showing generalization and robustness. We also perform rigorous ablations and analysis to validate our approach.

Comparative Study of TCL

Comparison of top-1 accuracy for TCL (Ours) with Pseudo-Label and FixMatch baselines trained with different percentages of labeled training data.

Results on Mini-Something-V2

Paper, code and other details

Ankit Singh*, Omprakash Chakraborty*, Ashutosh Varshney, Rameswar Panda, Rogerio Feris, Kate Saenko, Abir Das
Semi-Supervised Action Recognition with Temporal Contrastive Learning
Computer Vision and Pattern Recognition (CVPR), 2021
[PDF] [Supp] [Code] [Bibtex] [Video Presentation] [Poster]
256 | 259 | 260 | 261 | -------------------------------------------------------------------------------- /docs/resources/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/resources/Mini-Something_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVIR/TCL/661076c893120cdf980e10adb4440f773a9979fd/docs/resources/Mini-Something_results.png -------------------------------------------------------------------------------- /docs/resources/Semi-Supervised Action Recognition With Temporal Contrastive Learning.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVIR/TCL/661076c893120cdf980e10adb4440f773a9979fd/docs/resources/Semi-Supervised Action Recognition With Temporal Contrastive Learning.pdf -------------------------------------------------------------------------------- /docs/resources/TCL_bib.txt: -------------------------------------------------------------------------------- 1 | @InProceedings{Singh_2021_CVPR, 2 | author = {Singh, Ankit and Chakraborty, Omprakash and Varshney, Ashutosh and Panda, Rameswar and Feris, Rogerio and Saenko, Kate and Das, Abir}, 3 | title = {Semi-Supervised Action Recognition With Temporal Contrastive Learning}, 4 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 5 | month = {June}, 6 | year = {2021}, 7 | pages = {10389-10399} 8 | } -------------------------------------------------------------------------------- /docs/resources/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVIR/TCL/661076c893120cdf980e10adb4440f773a9979fd/docs/resources/arch.png -------------------------------------------------------------------------------- /docs/resources/arch_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVIR/TCL/661076c893120cdf980e10adb4440f773a9979fd/docs/resources/arch_small.png -------------------------------------------------------------------------------- /docs/resources/comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVIR/TCL/661076c893120cdf980e10adb4440f773a9979fd/docs/resources/comp.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import shutil 4 | import random 5 | import datetime 6 | import numpy as np 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import torch.nn.functional as F 11 | from collections import defaultdict 12 | from torch.nn.utils import clip_grad_norm_ 13 | 14 | from ops.dataset import TSNDataSet 15 | from ops.models import TSN 16 | from ops.transforms import * 17 | from opts import parser 18 | from ops import dataset_config 19 | from ops.utils import AverageMeter, accuracy 20 | from ops.temporal_shift import make_temporal_pool 21 | 22 | best_prec1 = 0 23 | 24 | 25 | def main(): 26 | print(datetime.datetime.now().strftime("%Y%m%d-%H%M%S")) 27 | global args, best_prec1 28 | args = parser.parse_args() 29 | 30 | ##asset check #### 31 | if args.use_finetuning: 32 | assert 
args.finetune_start_epoch > args.sup_thresh 33 | 34 | num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset,args.modality) 35 | full_arch_name = args.arch 36 | if args.temporal_pool: 37 | full_arch_name += '_tpool' 38 | args.store_name = '_'.join( 39 | ['TCL', datetime.datetime.now().strftime("%Y%m%d-%H%M%S"), args.dataset, full_arch_name, 'p%.2f' % args.percentage,'th%.2f' % args.threshold,'gamma%0.2f' % args.gamma,'mu%0.2f'% args.mu,'seed%d' % args.seed,'seg%d' % args.num_segments, 'bs%d' % args.batch_size, 40 | 'e{}'.format(args.epochs)]) 41 | if args.dense_sample: 42 | args.store_name += '_dense' 43 | if args.non_local > 0: 44 | args.store_name += '_nl' 45 | if args.suffix is not None: 46 | args.store_name += '_{}'.format(args.suffix) 47 | print('storing name: ' + args.store_name) 48 | 49 | check_rootfolders() 50 | 51 | args.labeled_train_list, args.unlabeled_train_list=get_training_filenames(args.train_list) 52 | 53 | model = TSN(num_class, args.num_segments, args.modality, 54 | base_model=args.arch, 55 | consensus_type=args.consensus_type, 56 | dropout=args.dropout, 57 | img_feature_dim=args.img_feature_dim, 58 | partial_bn=not args.no_partialbn, 59 | pretrain=args.pretrain, 60 | second_segments = args.second_segments, 61 | is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place, 62 | fc_lr5=not (args.tune_from and args.dataset in args.tune_from), 63 | temporal_pool=args.temporal_pool, 64 | non_local=args.non_local) 65 | print("==============model desccription=============") 66 | print(model) 67 | crop_size = model.crop_size 68 | scale_size = model.scale_size 69 | input_mean = model.input_mean 70 | input_std = model.input_std 71 | policies = model.get_optim_policies() 72 | train_augmentation = model.get_augmentation(flip=args.flip) 73 | 74 | model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda() 75 | 76 | optimizer = torch.optim.SGD(policies, 77 | args.lr, 78 | momentum=args.momentum, 79 | weight_decay=args.weight_decay) 80 | 81 | if args.resume: 82 | if args.temporal_pool: # early temporal pool so that we can load the state_dict 83 | make_temporal_pool(model.module.base_model, args.num_segments) 84 | if os.path.isfile(args.resume): 85 | print(("=> loading checkpoint '{}'".format(args.resume))) 86 | checkpoint = torch.load(args.resume) 87 | args.start_epoch = checkpoint['epoch'] 88 | best_prec1 = checkpoint['best_prec1'] 89 | model.load_state_dict(checkpoint['state_dict']) 90 | optimizer.load_state_dict(checkpoint['optimizer']) 91 | print(("=> loaded checkpoint '{}' (epoch {})" 92 | .format(args.evaluate, checkpoint['epoch']))) 93 | else: 94 | print(("=> no checkpoint found at '{}'".format(args.resume))) 95 | 96 | if args.tune_from: 97 | print(("=> fine-tuning from '{}'".format(args.tune_from))) 98 | sd = torch.load(args.tune_from) 99 | sd = sd['state_dict'] 100 | model_dict = model.state_dict() 101 | replace_dict = [] 102 | for k, v in sd.items(): 103 | if k not in model_dict and k.replace('.net', '') in model_dict: 104 | print('=> Load after remove .net: ', k) 105 | replace_dict.append((k, k.replace('.net', ''))) 106 | for k, v in model_dict.items(): 107 | if k not in sd and k.replace('.net', '') in sd: 108 | print('=> Load after adding .net: ', k) 109 | replace_dict.append((k.replace('.net', ''), k)) 110 | 111 | for k, k_new in replace_dict: 112 | sd[k_new] = sd.pop(k) 113 | keys1 = set(list(sd.keys())) 114 | keys2 = set(list(model_dict.keys())) 115 | set_diff = (keys1 - keys2) | (keys2 - 
keys1) 116 | print('#### Notice: keys that failed to load: {}'.format(set_diff)) 117 | if args.dataset not in args.tune_from: # new dataset 118 | print('=> New dataset, do not load fc weights') 119 | sd = {k: v for k, v in sd.items() if 'fc' not in k} 120 | model_dict.update(sd) 121 | model.load_state_dict(model_dict) 122 | 123 | if args.temporal_pool and not args.resume: 124 | make_temporal_pool(model.module.base_model, args.num_segments) 125 | 126 | cudnn.benchmark = True 127 | 128 | # Data loading code 129 | if args.modality != 'RGBDiff': 130 | normalize = GroupNormalize(input_mean, input_std) 131 | else: 132 | normalize = IdentityTransform() 133 | 134 | if args.modality == 'RGB': 135 | data_length = 1 136 | elif args.modality in ['Flow', 'RGBDiff']: 137 | data_length = 5 138 | 139 | labeled_trainloader = torch.utils.data.DataLoader( 140 | TSNDataSet(args.root_path, args.labeled_train_list, unlabeled=False, 141 | num_segments=args.num_segments, 142 | new_length=data_length, 143 | modality=args.modality, 144 | image_tmpl=prefix, 145 | second_segments = args.second_segments, 146 | transform=torchvision.transforms.Compose([ 147 | train_augmentation, 148 | Stack( 149 | roll=(args.arch in ['BNInception', 'InceptionV3'])), 150 | ToTorchFormatTensor( 151 | div=(args.arch not in ['BNInception', 'InceptionV3'])), 152 | normalize, 153 | ]), dense_sample=args.dense_sample), 154 | batch_size=args.batch_size, shuffle=True, 155 | num_workers=args.workers, pin_memory=True, 156 | drop_last=False) # prevent something not % n_GPU 157 | 158 | unlabeled_trainloader = torch.utils.data.DataLoader( 159 | TSNDataSet(args.root_path, args.unlabeled_train_list, unlabeled=True, 160 | num_segments=args.num_segments, 161 | new_length=data_length, 162 | modality=args.modality, 163 | image_tmpl=prefix, 164 | second_segments = args.second_segments, 165 | transform=torchvision.transforms.Compose([ 166 | train_augmentation, 167 | Stack( 168 | roll=(args.arch in ['BNInception', 'InceptionV3'])), 169 | ToTorchFormatTensor( 170 | div=(args.arch not in ['BNInception', 'InceptionV3'])), 171 | normalize, 172 | ]), dense_sample=args.dense_sample), 173 | batch_size=np.int(np.round(args.mu * args.batch_size)), shuffle=True, 174 | num_workers=args.workers, pin_memory=True, 175 | drop_last=False) # prevent something not % n_GPU 176 | 177 | val_loader = torch.utils.data.DataLoader( 178 | TSNDataSet(args.root_path, args.val_list, unlabeled=False, 179 | num_segments=args.num_segments, 180 | new_length=data_length, 181 | modality=args.modality, 182 | image_tmpl=prefix, 183 | random_shift=False, 184 | second_segments = args.second_segments, 185 | transform=torchvision.transforms.Compose([ 186 | GroupScale(int(scale_size)), 187 | GroupCenterCrop(crop_size), 188 | Stack( 189 | roll=(args.arch in ['BNInception', 'InceptionV3'])), 190 | ToTorchFormatTensor( 191 | div=(args.arch not in ['BNInception', 'InceptionV3'])), 192 | normalize, 193 | ]), dense_sample=args.dense_sample), 194 | batch_size=args.valbatchsize, shuffle=False, 195 | num_workers=args.workers, pin_memory=True) 196 | 197 | # define loss function (criterion) and optimizer 198 | if args.loss_type == 'nll': 199 | criterion = torch.nn.CrossEntropyLoss().cuda() 200 | else: 201 | raise ValueError("Unknown loss type") 202 | 203 | for group in policies: 204 | print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format( 205 | group['name'], len(group['params']), group['lr_mult'], group['decay_mult']))) 206 | 207 | if args.evaluate: 208 | validate(val_loader, model, criterion, 
0) 209 | return 210 | 211 | log_training = open(os.path.join( 212 | args.root_log, args.store_name, 'log.csv'), 'w') 213 | with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f: 214 | f.write(str(args)) 215 | default_start = 0 216 | is_finetune_lr_set= False 217 | for epoch in range(args.start_epoch, args.epochs): 218 | if args.use_finetuning and epoch >= args.finetune_start_epoch: 219 | args.eval_freq = args.finetune_stage_eval_freq 220 | if args.use_finetuning and epoch >= args.finetune_start_epoch and args.finetune_lr > 0.0 and not is_finetune_lr_set: 221 | args.lr = args.finetune_lr 222 | default_start = args.finetune_start_epoch 223 | is_finetune_lr_set = True 224 | adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps,default_start, using_policy=True) 225 | 226 | # train for one epoch 227 | train(labeled_trainloader, unlabeled_trainloader, model, 228 | criterion, optimizer, epoch, log_training) 229 | 230 | # evaluate on validation set 231 | if ((epoch + 1) % args.eval_freq == 0 or epoch == (args.epochs - 1) or (epoch+1)== args.finetune_start_epoch) : 232 | prec1 = validate(val_loader, model, criterion, 233 | epoch, log_training) 234 | 235 | # remember best prec@1 and save checkpoint 236 | is_best = prec1 > best_prec1 237 | best_prec1 = max(prec1, best_prec1) 238 | 239 | output_best = 'Best Prec@1: %.3f\n' % (best_prec1) 240 | print(output_best) 241 | log_training.write(output_best + '\n') 242 | log_training.flush() 243 | if args.use_finetuning and (epoch+1) == args.finetune_start_epoch: 244 | one_stage_pl=True 245 | else: 246 | one_stage_pl = False 247 | save_checkpoint({ 248 | 'epoch': epoch + 1, 249 | 'arch': args.arch, 250 | 'state_dict': model.state_dict(), 251 | 'optimizer': optimizer.state_dict(), 252 | 'best_prec1': best_prec1, 253 | }, is_best,one_stage_pl) 254 | 255 | 256 | def train(labeled_trainloader, unlabeled_trainloader, model, criterion, optimizer, epoch, log): 257 | batch_time = AverageMeter() 258 | data_time = AverageMeter() 259 | total_losses = AverageMeter() 260 | supervised_losses = AverageMeter() 261 | contrastive_losses = AverageMeter() 262 | group_contrastive_losses = AverageMeter() 263 | pl_losses = AverageMeter() 264 | top1 = AverageMeter() 265 | top5 = AverageMeter() 266 | 267 | model = model.cuda() 268 | 269 | if args.no_partialbn: 270 | model.module.partialBN(False) 271 | else: 272 | model.module.partialBN(True) 273 | 274 | # switch to train mode 275 | model.train() 276 | if epoch >= args.sup_thresh or (args.use_finetuning and epoch >= args.finetune_start_epoch): 277 | data_loader = zip(labeled_trainloader, unlabeled_trainloader) 278 | else: 279 | data_loader = labeled_trainloader 280 | 281 | end = time.time() 282 | 283 | for i, data in enumerate(data_loader): 284 | # measure data loading time 285 | data_time.update(time.time() - end) 286 | #reseting losses 287 | contrastive_loss = torch.tensor(0.0).cuda() 288 | pl_loss = torch.tensor(0.0).cuda() 289 | loss = torch.tensor(0.0).cuda() 290 | group_contrastive_loss = torch.tensor(0.0).cuda() 291 | 292 | if epoch >= args.sup_thresh or (args.use_finetuning and epoch >= args.finetune_start_epoch): 293 | 294 | (labeled_data,unlabeled_data) =data 295 | images_fast, images_slow = unlabeled_data 296 | images_slow = images_slow.cuda() 297 | images_fast = images_fast.cuda() 298 | images_slow = torch.autograd.Variable(images_slow) 299 | images_fast = torch.autograd.Variable(images_fast) 300 | 301 | # contrastive_loss 302 | output_fast = model(images_fast) 303 | if not 
args.use_finetuning or epoch < args.finetune_start_epoch: 304 | output_slow = model(images_slow, unlabeled=True) 305 | output_fast_detach = output_fast.detach() 306 | if epoch >= args.sup_thresh and epoch < args.finetune_start_epoch: 307 | contrastive_loss = simclr_loss(torch.softmax(output_fast_detach,dim=1),torch.softmax(output_slow,dim=1)) 308 | if args.use_group_contrastive: 309 | grp_unlabeled_8seg = get_group(output_fast_detach) 310 | grp_unlabeled_4seg = get_group(output_slow) 311 | group_contrastive_loss = compute_group_contrastive_loss(grp_unlabeled_8seg,grp_unlabeled_4seg) 312 | elif args.use_finetuning and epoch >= args.finetune_start_epoch: 313 | pseudo_label = torch.softmax(output_fast_detach, dim=-1) 314 | max_probs, targets_pl = torch.max(pseudo_label, dim=-1) 315 | mask = max_probs.ge(args.threshold).float() 316 | targets_pl = torch.autograd.Variable(targets_pl) 317 | 318 | pl_loss = (F.cross_entropy(output_fast, targets_pl, 319 | reduction='none') * mask).mean() 320 | else: 321 | labeled_data = data 322 | input, target = labeled_data 323 | target = target.cuda() 324 | input = input.cuda() 325 | input = torch.autograd.Variable(input) 326 | target_var = torch.autograd.Variable(target) 327 | output = model(input) 328 | loss = criterion(output, target_var) 329 | 330 | total_loss = loss + args.gamma*contrastive_loss + group_contrastive_loss + args.gamma_finetune*pl_loss 331 | # measure accuracy and record loss 332 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 333 | if epoch >= args.sup_thresh: 334 | total_losses.update(total_loss.item(), input.size(0)+args.mu*input.size(0)) 335 | else: 336 | total_losses.update(total_loss.item(), input.size(0)) 337 | supervised_losses.update(loss.item(), input.size(0)) 338 | contrastive_losses.update(contrastive_loss.item(), input.size(0)+args.mu*input.size(0)) 339 | group_contrastive_losses.update(group_contrastive_loss.item(), input.size(0)+args.mu*input.size(0)) 340 | pl_losses.update(pl_loss.item(), input.size(0)+args.mu*input.size(0)) 341 | top1.update(prec1.item(), input.size(0)) 342 | top5.update(prec5.item(), input.size(0)) 343 | 344 | # compute gradient and do SGD step 345 | total_loss.backward() 346 | 347 | if args.clip_gradient is not None: 348 | total_norm = clip_grad_norm_( 349 | model.parameters(), args.clip_gradient) 350 | 351 | optimizer.step() 352 | optimizer.zero_grad() 353 | 354 | # measure elapsed time 355 | batch_time.update(time.time() - end) 356 | end = time.time() 357 | 358 | if i % args.print_freq == 0: 359 | output = ('Epoch: [{0}][{1}], lr: {lr:.5f}\t' 360 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 361 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 362 | 'TotalLoss {total_loss.val:.4f} ({total_loss.avg:.4f})\t' 363 | 'Supervised Loss {loss.val:.4f} ({loss.avg:.4f})\t' 364 | 'Contrastive_Loss {contrastive_loss.val:.4f} ({contrastive_loss.avg:.4f})\t' 365 | 'Group_contrastive_Loss {group_contrastive_loss.val:.4f} ({group_contrastive_loss.avg:.4f})\t' 366 | 'Pseudo_Loss {pl_loss.val:.4f} ({pl_loss.avg:.4f})\t' 367 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 368 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 369 | epoch, i, batch_time=batch_time, 370 | data_time=data_time, total_loss=total_losses,loss=supervised_losses, 371 | contrastive_loss=contrastive_losses,group_contrastive_loss=group_contrastive_losses,pl_loss=pl_losses, top1=top1, top5=top5, lr=optimizer.param_groups[-1]['lr'] * 0.1)) # TODO 372 | print(output) 373 | log.write(output + '\n') 374 | log.flush() 375 | 376 | def 
validate(val_loader, model, criterion, epoch, log=None): 377 | batch_time = AverageMeter() 378 | losses = AverageMeter() 379 | top1 = AverageMeter() 380 | top5 = AverageMeter() 381 | 382 | # switch to evaluate mode 383 | model.eval() 384 | 385 | end = time.time() 386 | with torch.no_grad(): 387 | for i, (input, target) in enumerate(val_loader): 388 | target = target.cuda() 389 | 390 | # compute output 391 | output = model(input) 392 | loss = criterion(output, target) 393 | 394 | # measure accuracy and record loss 395 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 396 | 397 | losses.update(loss.item(), input.size(0)) 398 | top1.update(prec1.item(), input.size(0)) 399 | top5.update(prec5.item(), input.size(0)) 400 | 401 | # measure elapsed time 402 | batch_time.update(time.time() - end) 403 | end = time.time() 404 | 405 | if i % args.print_freq == 0: 406 | output = ('Test: [{0}/{1}]\t' 407 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 408 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 409 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 410 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 411 | i, len(val_loader), batch_time=batch_time, loss=losses, 412 | top1=top1, top5=top5)) 413 | print(output) 414 | if log is not None: 415 | log.write(output + '\n') 416 | log.flush() 417 | 418 | output = ('Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}' 419 | .format(top1=top1, top5=top5, loss=losses)) 420 | print(output) 421 | if log is not None: 422 | log.write(output + '\n') 423 | log.flush() 424 | 425 | return top1.avg 426 | 427 | def get_group(output): 428 | logits = torch.softmax(output, dim=-1) 429 | _ , target = torch.max(logits, dim=-1) 430 | groups ={} 431 | for x,y in zip(target, logits): 432 | group = groups.get(x.item(),[]) 433 | group.append(y) 434 | groups[x.item()]= group 435 | return groups 436 | 437 | def compute_group_contrastive_loss(grp_dict_un,grp_dict_lab): 438 | loss = [] 439 | l_fast =[] 440 | l_slow =[] 441 | for key in grp_dict_un.keys(): 442 | if key in grp_dict_lab: 443 | l_fast.append(torch.stack(grp_dict_un[key]).mean(dim=0)) 444 | l_slow.append(torch.stack(grp_dict_lab[key]).mean(dim=0)) 445 | if len(l_fast) > 0: 446 | l_fast = torch.stack(l_fast) 447 | l_slow = torch.stack(l_slow) 448 | loss = simclr_loss(l_fast,l_slow) 449 | loss = max(torch.tensor(0.000).cuda(),loss) 450 | else: 451 | loss= torch.tensor(0.0).cuda() 452 | return loss 453 | 454 | 455 | def save_checkpoint(state, is_best, one_stage_pl=False): 456 | filename = '%s/%s/ckpt.pth.tar' % (args.root_model, args.store_name) 457 | torch.save(state, filename) 458 | if is_best: 459 | shutil.copyfile(filename, filename.replace('pth.tar', 'best.pth.tar')) 460 | if one_stage_pl: 461 | shutil.copyfile(filename, filename.replace('pth.tar', 'before_finetune.pth.tar')) 462 | 463 | 464 | def adjust_learning_rate(optimizer, epoch, lr_type, lr_steps,default_start,using_policy): 465 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 466 | if lr_type == 'step': 467 | decay = 0.1 ** (sum(epoch >= np.array(lr_steps))) 468 | lr = args.lr * decay 469 | decay = args.weight_decay 470 | elif lr_type == 'cos': 471 | import math 472 | lr = 0.5 * args.lr * (1 + math.cos(math.pi * (epoch - default_start) / args.epochs)) 473 | decay = args.weight_decay 474 | else: 475 | raise NotImplementedError 476 | if using_policy: 477 | for param_group in optimizer.param_groups: 478 | param_group['lr'] = lr * param_group['lr_mult'] 479 | param_group['weight_decay'] = decay * 
param_group['decay_mult'] 480 | else: 481 | for param_group in optimizer.param_groups: 482 | param_group['lr'] = lr 483 | param_group['weight_decay'] = decay 484 | 485 | 486 | def check_rootfolders(): 487 | """Create log and model folder""" 488 | folders_util = [args.root_log, args.root_model, 489 | os.path.join(args.root_log, args.store_name), 490 | os.path.join(args.root_model, args.store_name)] 491 | for folder in folders_util: 492 | if not os.path.exists(folder): 493 | print('creating folder ' + folder) 494 | os.mkdir(folder) 495 | 496 | 497 | 498 | def split_file(file, unlabeled, labeled, percentage, isShuffle=True, seed=123, strategy='classwise'): 499 | """Splits a file in 2 given the `percentage` to go in the large file.""" 500 | if strategy == 'classwise': 501 | if os.path.exists(unlabeled) and os.path.exists(labeled): 502 | print("path exists with this seed and strategy") 503 | return 504 | random.seed(seed) 505 | #creating dictionary against each category 506 | def del_list(list_delete,indices_to_delete): 507 | for i in sorted(indices_to_delete, reverse=True): 508 | del(list_delete[i]) 509 | 510 | main_dict= defaultdict(list) 511 | with open(file,'r') as mainfile: 512 | lines = mainfile.readlines() 513 | for line in lines: 514 | video_info = line.strip().split() 515 | main_dict[video_info[2]].append((video_info[0],video_info[1])) 516 | with open(unlabeled,'w') as ul,\ 517 | open(labeled,'w') as l: 518 | for key,value in main_dict.items(): 519 | length_videos = len(value) 520 | ul_no_videos = int((length_videos* percentage)) 521 | indices = random.sample(range(length_videos),ul_no_videos) 522 | for index in indices: 523 | line_to_written = value[index][0] + " " + value[index][1] + " " +key+"\n" 524 | ul.write(line_to_written) 525 | del_list(value,indices) 526 | for label_index in range(len(value)): 527 | line_to_written = value[label_index][0] + " " + value[label_index][1] + " " +key+"\n" 528 | l.write(line_to_written) 529 | 530 | 531 | 532 | if strategy == 'overall': 533 | if os.path.exists(unlabeled) and os.path.exists(labeled): 534 | print("path exists with this seed and strategy") 535 | return 536 | random.seed(seed) 537 | with open(file, 'r') as fin, \ 538 | open(unlabeled, 'w') as foutBig, \ 539 | open(labeled, 'w') as foutSmall: 540 | # if didn't count you could only approximate the percentage 541 | lines = fin.readlines() 542 | random.shuffle(lines) 543 | nLines = sum(1 for line in lines) 544 | nTrain = int(nLines*percentage) 545 | i = 0 546 | for line in lines: 547 | line = line.rstrip('\n') + "\n" 548 | if i < nTrain: 549 | foutBig.write(line) 550 | i += 1 551 | else: 552 | foutSmall.write(line) 553 | 554 | def get_training_filenames(train_file_path): 555 | labeled_file_path = os.path.join("Run_"+str(int(np.round((1-args.percentage)*100))),args.dataset+'_'+str(args.seed)+args.strategy+"_labeled_training.txt") 556 | unlabeled_file_path = os.path.join("Run_"+str(int(np.round((1-args.percentage)*100))),args.dataset+'_'+str(args.seed)+args.strategy+"_unlabeled_training.txt") 557 | split_file(train_file_path, unlabeled_file_path, 558 | labeled_file_path,args.percentage, isShuffle=True,seed=args.seed, strategy=args.strategy) 559 | return labeled_file_path, unlabeled_file_path 560 | 561 | def simclr_loss(output_fast,output_slow,normalize=True): 562 | out = torch.cat((output_fast, output_slow), dim=0) 563 | sim_mat = torch.mm(out, torch.transpose(out,0,1)) 564 | if normalize: 565 | sim_mat_denom = torch.mm(torch.norm(out, dim=1).unsqueeze(1), torch.norm(out, 
dim=1).unsqueeze(1).t()) 566 | sim_mat = sim_mat / sim_mat_denom.clamp(min=1e-16) 567 | sim_mat = torch.exp(sim_mat / args.Temperature) 568 | if normalize: 569 | sim_mat_denom = torch.norm(output_fast, dim=1) * torch.norm(output_slow, dim=1) 570 | sim_match = torch.exp(torch.sum(output_fast * output_slow, dim=-1) / sim_mat_denom / args.Temperature) 571 | else: 572 | sim_match = torch.exp(torch.sum(output_fast * output_slow, dim=-1) / args.Temperature) 573 | sim_match = torch.cat((sim_match, sim_match), dim=0) 574 | norm_sum = torch.exp(torch.ones(out.size(0)) / args.Temperature ) 575 | norm_sum = norm_sum.cuda() 576 | loss = torch.mean(-torch.log(sim_match / (torch.sum(sim_mat, dim=-1) - norm_sum))) 577 | return loss 578 | 579 | if __name__ == '__main__': 580 | main() 581 | -------------------------------------------------------------------------------- /ops/__init__.py: -------------------------------------------------------------------------------- 1 | from ops.basic_ops import * -------------------------------------------------------------------------------- /ops/basic_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Identity(torch.nn.Module): 5 | def forward(self, input): 6 | return input 7 | 8 | 9 | class SegmentConsensus(torch.nn.Module): 10 | 11 | def __init__(self, consensus_type, dim=1): 12 | super(SegmentConsensus, self).__init__() 13 | self.consensus_type = consensus_type 14 | self.dim = dim 15 | self.shape = None 16 | 17 | def forward(self, input_tensor): 18 | self.shape = input_tensor.size() 19 | if self.consensus_type == 'avg': 20 | output = input_tensor.mean(dim=self.dim, keepdim=True) 21 | elif self.consensus_type == 'identity': 22 | output = input_tensor 23 | else: 24 | output = None 25 | 26 | return output 27 | 28 | 29 | class ConsensusModule(torch.nn.Module): 30 | 31 | def __init__(self, consensus_type, dim=1): 32 | super(ConsensusModule, self).__init__() 33 | self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity' 34 | self.dim = dim 35 | 36 | def forward(self, input): 37 | return SegmentConsensus(self.consensus_type, self.dim)(input) 38 | -------------------------------------------------------------------------------- /ops/dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import numpy as np 6 | from numpy.random import randint 7 | import torch 8 | import random 9 | from ops.transforms import * 10 | from torchvision import transforms 11 | 12 | class VideoRecord(object): 13 | def __init__(self, row): 14 | self._data = row 15 | 16 | @property 17 | def path(self): 18 | return self._data[0] 19 | 20 | @property 21 | def num_frames(self): 22 | return int(self._data[1]) 23 | 24 | @property 25 | def label(self): 26 | return int(self._data[2]) 27 | 28 | 29 | class TSNDataSet(data.Dataset): 30 | def __init__(self, root_path, list_file, unlabeled= False,second_segments=2, 31 | num_segments=3, new_length=1, modality='RGB', 32 | image_tmpl='img_{:05d}.jpg', transform=None, 33 | random_shift=True, test_mode=False, 34 | remove_missing=False, dense_sample=False, twice_sample=False): 35 | 36 | self.root_path = root_path 37 | self.list_file = list_file 38 | self.num_segments = num_segments 39 | self.unlabeled = unlabeled 40 | self.new_length = new_length 41 | self.second_segments = second_segments 42 | self.modality = modality 43 | self.image_tmpl = image_tmpl 44 | 
self.transform = transform 45 | self.random_shift = random_shift 46 | self.test_mode = test_mode 47 | self.remove_missing = remove_missing 48 | self.dense_sample = dense_sample # using dense sample as I3D 49 | self.twice_sample = twice_sample # twice sample for more validation 50 | if self.dense_sample: 51 | print('=> Using dense sample for the dataset...') 52 | if self.twice_sample: 53 | print('=> Using twice sample for the dataset...') 54 | 55 | if self.modality == 'RGBDiff': 56 | self.new_length += 1 # Diff needs one more image to calculate diff 57 | 58 | self._parse_list() 59 | 60 | def _load_image(self, directory, idx): 61 | if self.modality == 'RGB' or self.modality == 'RGBDiff': 62 | try: 63 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')] 64 | except Exception: 65 | print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx))) 66 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')] 67 | 68 | def _parse_list(self): 69 | tmp = [x.strip().split(' ') for x in open(self.list_file)] 70 | if not self.test_mode or self.remove_missing: 71 | tmp = [item for item in tmp if int(item[1]) >= 3] 72 | self.video_list = [VideoRecord(item) for item in tmp] 73 | 74 | if self.image_tmpl == '{:06d}-{}_{:05d}.jpg': 75 | for v in self.video_list: 76 | v._data[1] = int(v._data[1]) / 2 77 | print('video number:%d' % (len(self.video_list))) 78 | 79 | def _sample_indices(self, record,num_segments): 80 | """ 81 | 82 | :param record: VideoRecord 83 | :return: list 84 | """ 85 | if self.dense_sample: # i3d dense sample 86 | sample_pos = max(1, 1 + record.num_frames - 64) 87 | t_stride = 64 // num_segments 88 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 89 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(num_segments)] 90 | return np.array(offsets) + 1 91 | else: # normal sample 92 | average_duration = (record.num_frames - self.new_length + 1) // num_segments 93 | if average_duration > 0: 94 | offsets = np.multiply(list(range(num_segments)), average_duration) + randint(average_duration, 95 | size=num_segments) 96 | elif record.num_frames > num_segments: 97 | offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=num_segments)) 98 | else: 99 | offsets = np.zeros((num_segments,)) 100 | return offsets + 1 101 | 102 | def _get_val_indices(self, record,num_segments): 103 | if self.dense_sample: # i3d dense sample 104 | sample_pos = max(1, 1 + record.num_frames - 64) 105 | t_stride = 64 // num_segments 106 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 107 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(num_segments)] 108 | return np.array(offsets) + 1 109 | else: 110 | if record.num_frames > num_segments + self.new_length - 1: 111 | tick = (record.num_frames - self.new_length + 1) / float(num_segments) 112 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(num_segments)]) 113 | else: 114 | offsets = np.zeros((num_segments,)) 115 | return offsets + 1 116 | 117 | def _get_test_indices(self, record,num_segments): 118 | if self.dense_sample: 119 | sample_pos = max(1, 1 + record.num_frames - 64) 120 | t_stride = 64 // num_segments 121 | start_list = np.linspace(0, sample_pos - 1, num=10, dtype=int) 122 | offsets = [] 123 | for start_idx in start_list.tolist(): 124 | offsets += [(idx * t_stride + start_idx) % record.num_frames 
for idx in range(num_segments)] 125 | return np.array(offsets) + 1 126 | elif self.twice_sample: 127 | tick = (record.num_frames - self.new_length + 1) / float(num_segments) 128 | 129 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(num_segments)] + 130 | [int(tick * x) for x in range(num_segments)]) 131 | 132 | return offsets + 1 133 | else: 134 | tick = (record.num_frames - self.new_length + 1) / float(num_segments) 135 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(num_segments)]) 136 | return offsets + 1 137 | 138 | def __getitem__(self, index): 139 | record = self.video_list[index] 140 | # check this is a legit video folder 141 | if self.image_tmpl == '{:06d}-{}_{:05d}.jpg': 142 | file_name = self.image_tmpl.format(int(record.path), 'x', 1) 143 | full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name) 144 | else: 145 | file_name = self.image_tmpl.format(1) 146 | full_path = os.path.join(self.root_path, record.path, file_name) 147 | 148 | while not os.path.exists(full_path): 149 | print('################## Not Found:', os.path.join(self.root_path, record.path, file_name)) 150 | index = np.random.randint(len(self.video_list)) 151 | record = self.video_list[index] 152 | if self.image_tmpl == '{:06d}-{}_{:05d}.jpg': 153 | file_name = self.image_tmpl.format(int(record.path), 'x', 1) 154 | full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name) 155 | else: 156 | file_name = self.image_tmpl.format(1) 157 | full_path = os.path.join(self.root_path, record.path, file_name) 158 | 159 | if not self.test_mode: 160 | if self.random_shift: 161 | if self.unlabeled: 162 | segment_indices_fast = self._sample_indices(record,self.num_segments) 163 | segment_indices_slow = self._sample_indices(record, self.second_segments) 164 | fast_data,_ = self.get(record, segment_indices_fast) 165 | slow_data,_= self.get(record,segment_indices_slow) 166 | return fast_data,slow_data 167 | 168 | 169 | if not self.unlabeled: 170 | segment_indices = self._sample_indices(record,self.num_segments) 171 | else: 172 | segment_indices = self._get_val_indices(record,self.num_segments) 173 | else: 174 | segment_indices = self._get_test_indices(record,self.num_segments) 175 | return self.get(record, segment_indices) 176 | 177 | def get(self, record, indices): 178 | 179 | images = list() 180 | for seg_ind in indices: 181 | p = int(seg_ind) 182 | for i in range(self.new_length): 183 | seg_imgs = self._load_image(record.path, p) 184 | images.extend(seg_imgs) 185 | if p < record.num_frames: 186 | p += 1 187 | 188 | process_data = self.transform(images) 189 | return process_data, record.label 190 | 191 | def __len__(self): 192 | return len(self.video_list) 193 | -------------------------------------------------------------------------------- /ops/dataset_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import yaml 4 | 5 | def get_rootdataset(dataset): 6 | with open('root_dataset.yaml') as file: 7 | dataset_path = yaml.load(file, Loader=yaml.FullLoader) 8 | return dataset_path[dataset] 9 | 10 | 11 | def return_ucf101(modality): 12 | filename_categories = 'data/classInd.txt' 13 | if modality == 'RGB': 14 | root_data = ROOT_DATASET + 'Frames/' 15 | filename_imglist_train = 'data/ucf101_rgb_train_split_1.txt' 16 | filename_imglist_val = 'data/ucf101_rgb_val_split_1.txt' 17 | prefix = 'img_{:05d}.jpg' 18 | else: 19 | raise NotImplementedError('no such modality:' + modality) 20 | 
return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 21 | 22 | 23 | def return_hmdb51(modality): 24 | filename_categories = 51 25 | if modality == 'RGB': 26 | root_data = ROOT_DATASET + 'HMDB51/images' 27 | filename_imglist_train = 'HMDB51/splits/hmdb51_rgb_train_split_1.txt' 28 | filename_imglist_val = 'HMDB51/splits/hmdb51_rgb_val_split_1.txt' 29 | prefix = 'img_{:05d}.jpg' 30 | else: 31 | raise NotImplementedError('no such modality:' + modality) 32 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 33 | 34 | 35 | def return_something(modality): 36 | filename_categories = 'something/v1/category.txt' 37 | if modality == 'RGB': 38 | root_data = ROOT_DATASET + 'something/v1/20bn-something-something-v1' 39 | filename_imglist_train = 'something/v1/train_videofolder.txt' 40 | filename_imglist_val = 'something/v1/val_videofolder.txt' 41 | prefix = '{:05d}.jpg' 42 | else: 43 | print('no such modality:'+modality) 44 | raise NotImplementedError 45 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 46 | 47 | 48 | def return_somethingv2(modality): 49 | filename_categories = 'data/category_mini.txt' 50 | if modality == 'RGB': 51 | root_data = ROOT_DATASET + 'Frames' 52 | filename_imglist_train = 'data/train_videofolder_mini.txt' 53 | filename_imglist_val = 'data/val_videofolder_mini.txt' 54 | prefix = '{:06d}.jpg' 55 | else: 56 | raise NotImplementedError('no such modality:'+modality) 57 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 58 | 59 | def return_mini_moments(modality): 60 | filename_categories = 'data/categories.txt' 61 | if modality == 'RGB': 62 | root_data = ROOT_DATASET + 'Frames' 63 | filename_imglist_train = 'data/train_videofolder.txt' 64 | filename_imglist_val = 'data/val_videofolder.txt' 65 | prefix = '{:06d}.jpg' 66 | else: 67 | raise NotImplementedError('no such modality:'+modality) 68 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 69 | 70 | 71 | def return_jester(modality): 72 | filename_categories = 'data/classInd.txt' 73 | if modality == 'RGB': 74 | prefix = '{:05d}.jpg' 75 | root_data = ROOT_DATASET + 'Frames' 76 | filename_imglist_train = 'data/train_videofolder.txt' 77 | filename_imglist_val = 'data/val_videofolder.txt' 78 | else: 79 | raise NotImplementedError('no such modality:'+modality) 80 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 81 | 82 | def return_charades_ego(modality): 83 | filename_categories = 157#'Charades_v1_classes.txt' 84 | if modality == 'RGB': 85 | prefix = '{:06d}.jpg' 86 | root_data = ROOT_DATASET + 'Frames' 87 | filename_imglist_train_1p = 'data/train_only1st_segments.txt' 88 | filename_imglist_train_3p = 'data/train_only3rd_segments.txt' 89 | filename_imglist_val_1p = 'data/test_only1st_segments.txt' 90 | filename_imglist_val_3p = 'data/test_only3rd_segments.txt' 91 | else: 92 | raise NotImplementedError('no such modality:'+modality) 93 | return filename_categories, filename_imglist_train_1p, filename_imglist_train_3p, filename_imglist_val_1p, filename_imglist_val_3p, root_data, prefix 94 | 95 | 96 | def return_charades_full(modality): 97 | filename_categories = 157 98 | if modality == 'RGB': 99 | prefix = '{:06d}.jpg' 100 | root_data = ROOT_DATASET + 'Frames' 101 | filename_imglist_train = 'data/train_segments.txt'#'data/train_videofolder.txt' 102 | filename_imglist_val = 
'data/test_segments.txt' #'data/val_videofolder.txt' 103 | else: 104 | raise NotImplementedError('no such modality:'+modality) 105 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 106 | 107 | 108 | def return_kinetics(modality): 109 | filename_categories = 400 110 | if modality == 'RGB': 111 | root_data = ROOT_DATASET + 'kinetics/images' 112 | filename_imglist_train = 'kinetics/labels/train_videofolder.txt' 113 | filename_imglist_val = 'kinetics/labels/val_videofolder.txt' 114 | prefix = 'img_{:05d}.jpg' 115 | else: 116 | raise NotImplementedError('no such modality:' + modality) 117 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 118 | 119 | 120 | def return_dataset(dataset, modality): 121 | global ROOT_DATASET 122 | ROOT_DATASET = get_rootdataset(dataset) 123 | dict_single = {'jester': return_jester, 'something': return_something, 'somethingv2': return_somethingv2, 124 | 'ucf101': return_ucf101, 'hmdb51': return_hmdb51, 'mini-moments': return_mini_moments, 125 | 'kinetics': return_kinetics, 'charades_full':return_charades_full } 126 | dict_charades = {'charades_ego': return_charades_ego} 127 | if dataset in dict_single: 128 | file_categories, file_imglist_train, file_imglist_val, root_data, prefix = dict_single[dataset](modality) 129 | elif dataset in dict_charades: 130 | file_categories, filename_imglist_train_1p, filename_imglist_train_3p, filename_imglist_val_1p, filename_imglist_val_3p, root_data, prefix = dict_charades[dataset](modality) 131 | file_imglist_train_1p = os.path.join(ROOT_DATASET, filename_imglist_train_1p) 132 | file_imglist_train_3p = os.path.join(ROOT_DATASET, filename_imglist_train_3p) 133 | file_imglist_val_1p = os.path.join(ROOT_DATASET, filename_imglist_val_1p) 134 | file_imglist_val_3p = os.path.join(ROOT_DATASET, filename_imglist_val_3p) 135 | return file_categories, file_imglist_train_1p, file_imglist_train_3p, file_imglist_val_1p, file_imglist_val_3p, root_data, prefix 136 | else: 137 | raise ValueError('Unknown dataset '+dataset) 138 | 139 | file_imglist_train = os.path.join(ROOT_DATASET, file_imglist_train) 140 | file_imglist_val = os.path.join(ROOT_DATASET, file_imglist_val) 141 | if isinstance(file_categories, str): 142 | file_categories = os.path.join(ROOT_DATASET, file_categories) 143 | with open(file_categories) as f: 144 | lines = f.readlines() 145 | categories = [item.rstrip() for item in lines] 146 | else: # number of categories 147 | categories = [None] * file_categories 148 | n_class = len(categories) 149 | print('{}: {} classes'.format(dataset, n_class)) 150 | return n_class, file_imglist_train, file_imglist_val, root_data, prefix 151 | -------------------------------------------------------------------------------- /ops/models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from ops.basic_ops import ConsensusModule 4 | from ops.transforms import * 5 | from torch.nn.init import normal_, constant_ 6 | from archs.resnet import resnet_dict 7 | 8 | class TSN(nn.Module): 9 | def __init__(self, num_class, num_segments, modality, 10 | base_model='resnet18', new_length=None,second_segments=2, 11 | consensus_type='avg', before_softmax=True, 12 | dropout=0.8, img_feature_dim=256, 13 | crop_num=1, partial_bn=True, print_spec=True, pretrain='imagenet', 14 | is_shift=False, shift_div=8, shift_place='blockres', fc_lr5=False, 15 | temporal_pool=False, non_local=False): 16 | super(TSN, self).__init__() 17 
| self.modality = modality 18 | self.num_segments = num_segments 19 | self.reshape = True 20 | self.before_softmax = before_softmax 21 | self.dropout = dropout 22 | self.crop_num = crop_num 23 | self.consensus_type = consensus_type 24 | self.img_feature_dim = img_feature_dim # the dimension of the CNN feature to represent each frame 25 | self.pretrain = pretrain 26 | self.second_segments = second_segments 27 | self.is_shift = is_shift 28 | self.shift_div = shift_div 29 | self.shift_place = shift_place 30 | self.base_model_name = base_model 31 | self.fc_lr5 = fc_lr5 32 | self.temporal_pool = temporal_pool 33 | self.non_local = non_local 34 | 35 | if not before_softmax and consensus_type != 'avg': 36 | raise ValueError("Only avg consensus can be used after Softmax") 37 | 38 | if new_length is None: 39 | self.new_length = 1 if modality == "RGB" else 5 40 | else: 41 | self.new_length = new_length 42 | if print_spec: 43 | print((""" 44 | Initializing TSN with base model: {}. 45 | TSN Configurations: 46 | input_modality: {} 47 | num_segments: {} 48 | new_length: {} 49 | consensus_module: {} 50 | dropout_ratio: {} 51 | img_feature_dim: {} 52 | """.format(base_model, self.modality, self.num_segments, self.new_length, consensus_type, self.dropout, self.img_feature_dim))) 53 | 54 | self._prepare_base_model(base_model) 55 | 56 | feature_dim = self._prepare_tsn(num_class) 57 | 58 | self.consensus = ConsensusModule(consensus_type) 59 | 60 | if not self.before_softmax: 61 | self.softmax = nn.Softmax() 62 | 63 | self._enable_pbn = partial_bn 64 | if partial_bn: 65 | self.partialBN(True) 66 | 67 | def _prepare_tsn(self, num_class): 68 | feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features 69 | if self.dropout == 0: 70 | setattr(self.base_model, self.base_model.last_layer_name, nn.Linear(feature_dim, num_class)) 71 | self.new_fc = None 72 | else: 73 | setattr(self.base_model, self.base_model.last_layer_name, nn.Dropout(p=self.dropout)) 74 | self.new_fc = nn.Linear(feature_dim, num_class) 75 | 76 | std = 0.001 77 | if self.new_fc is None: 78 | normal_(getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std) 79 | constant_(getattr(self.base_model, self.base_model.last_layer_name).bias, 0) 80 | else: 81 | if hasattr(self.new_fc, 'weight'): 82 | normal_(self.new_fc.weight, 0, std) 83 | constant_(self.new_fc.bias, 0) 84 | return feature_dim 85 | 86 | def _prepare_base_model(self, base_model): 87 | print('=> base model: {}'.format(base_model)) 88 | 89 | if 'resnet' in base_model: 90 | self.base_model = resnet_dict[base_model](True if self.pretrain == 'imagenet' else False) 91 | if self.is_shift: 92 | print('Adding temporal shift...') 93 | from ops.temporal_shift import make_temporal_shift 94 | make_temporal_shift(self.base_model, self.num_segments,second_segments= self.second_segments, 95 | n_div=self.shift_div, place=self.shift_place, temporal_pool=self.temporal_pool) 96 | 97 | if self.non_local: 98 | print('Adding non-local module...') 99 | from ops.non_local import make_non_local 100 | make_non_local(self.base_model, self.num_segments) 101 | 102 | self.base_model.last_layer_name = 'fc' 103 | self.input_size = 224 104 | self.input_mean = [0.485, 0.456, 0.406] 105 | self.input_std = [0.229, 0.224, 0.225] 106 | 107 | self.base_model.avgpool = nn.AdaptiveAvgPool2d(1) 108 | else: 109 | raise ValueError('Unknown base model: {}'.format(base_model)) 110 | 111 | def train(self, mode=True): 112 | """ 113 | Override the default train() to freeze the BN parameters 114 | 
:return: 115 | """ 116 | super(TSN, self).train(mode) 117 | count = 0 118 | #this happens when training wiith pretrained weights 119 | if self._enable_pbn and mode: 120 | print("Freezing BatchNorm2D except the first one.") 121 | for m in self.base_model.modules(): 122 | if isinstance(m, nn.BatchNorm2d): 123 | count += 1 124 | if count >= (2 if self._enable_pbn else 1): 125 | m.eval() 126 | # shutdown update in frozen mode 127 | m.weight.requires_grad = False 128 | m.bias.requires_grad = False 129 | 130 | def partialBN(self, enable): 131 | self._enable_pbn = enable 132 | 133 | def get_optim_policies(self): 134 | first_conv_weight = [] 135 | first_conv_bias = [] 136 | normal_weight = [] 137 | normal_bias = [] 138 | lr5_weight = [] 139 | lr10_bias = [] 140 | bn = [] 141 | custom_ops = [] 142 | 143 | conv_cnt = 0 144 | bn_cnt = 0 145 | for m in self.modules(): 146 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv3d): 147 | ps = list(m.parameters()) 148 | conv_cnt += 1 149 | if conv_cnt == 1: 150 | first_conv_weight.append(ps[0]) 151 | if len(ps) == 2: 152 | first_conv_bias.append(ps[1]) 153 | else: 154 | normal_weight.append(ps[0]) 155 | if len(ps) == 2: 156 | normal_bias.append(ps[1]) 157 | elif isinstance(m, torch.nn.Linear): 158 | ps = list(m.parameters()) 159 | if self.fc_lr5: 160 | lr5_weight.append(ps[0]) 161 | else: 162 | normal_weight.append(ps[0]) 163 | if len(ps) == 2: 164 | if self.fc_lr5: 165 | lr10_bias.append(ps[1]) 166 | else: 167 | normal_bias.append(ps[1]) 168 | 169 | elif isinstance(m, torch.nn.BatchNorm2d): 170 | bn_cnt += 1 171 | # later BN's are frozen 172 | if not self._enable_pbn or bn_cnt == 1: 173 | bn.extend(list(m.parameters())) 174 | elif isinstance(m, torch.nn.BatchNorm3d): 175 | bn_cnt += 1 176 | # later BN's are frozen 177 | if not self._enable_pbn or bn_cnt == 1: 178 | bn.extend(list(m.parameters())) 179 | elif len(m._modules) == 0: 180 | if len(list(m.parameters())) > 0: 181 | raise ValueError("New atomic module type: {}. 
Need to give it a learning policy".format(type(m))) 182 | 183 | return [ 184 | {'params': first_conv_weight, 'lr_mult': 5 if self.modality == 'Flow' else 1, 'decay_mult': 1, 185 | 'name': "first_conv_weight"}, 186 | {'params': first_conv_bias, 'lr_mult': 10 if self.modality == 'Flow' else 2, 'decay_mult': 0, 187 | 'name': "first_conv_bias"}, 188 | {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1, 189 | 'name': "normal_weight"}, 190 | {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0, 191 | 'name': "normal_bias"}, 192 | {'params': bn, 'lr_mult': 1, 'decay_mult': 0, 193 | 'name': "BN scale/shift"}, 194 | {'params': custom_ops, 'lr_mult': 1, 'decay_mult': 1, 195 | 'name': "custom_ops"}, 196 | # for fc 197 | {'params': lr5_weight, 'lr_mult': 5, 'decay_mult': 1, 198 | 'name': "lr5_weight"}, 199 | {'params': lr10_bias, 'lr_mult': 10, 'decay_mult': 0, 200 | 'name': "lr10_bias"}, 201 | ] 202 | # return self.base_model.parameters() 203 | 204 | def forward(self, input, unlabeled=False, no_reshape=False): 205 | if not no_reshape: 206 | sample_len = (3 if self.modality == "RGB" else 2) * self.new_length 207 | 208 | base_out = self.base_model(input.view((-1, sample_len) + input.size()[-2:]),unlabeled) 209 | else: 210 | base_out = self.base_model(input,unlabeled) 211 | 212 | if self.dropout > 0: 213 | base_out = self.new_fc(base_out) 214 | 215 | if not self.before_softmax: 216 | base_out = self.softmax(base_out) 217 | 218 | num_segments=self.num_segments 219 | 220 | if unlabeled: 221 | num_segments=self.second_segments 222 | if self.reshape: 223 | if self.is_shift and self.temporal_pool: 224 | base_out = base_out.view((-1, num_segments // 2) + base_out.size()[1:]) 225 | else: 226 | base_out = base_out.view((-1, num_segments) + base_out.size()[1:]) 227 | #print(base_out.size()) 228 | output = self.consensus(base_out) 229 | return output.squeeze(1) 230 | 231 | @property 232 | def crop_size(self): 233 | return self.input_size 234 | 235 | @property 236 | def scale_size(self): 237 | return self.input_size * 256 // 224 238 | 239 | def get_augmentation(self, flip=True): 240 | if self.modality == 'RGB': 241 | if flip: # sometimes 'if flip' doesn't work 242 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66]), 243 | GroupRandomHorizontalFlip(is_flow=False)]) 244 | else: 245 | print('#' * 20, 'NO FLIP!!!') 246 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66])]) 247 | elif self.modality == 'Flow': 248 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 249 | GroupRandomHorizontalFlip(is_flow=True)]) 250 | elif self.modality == 'RGBDiff': 251 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 252 | GroupRandomHorizontalFlip(is_flow=False)]) 253 | -------------------------------------------------------------------------------- /ops/non_local.py: -------------------------------------------------------------------------------- 1 | # Non-local block using embedded gaussian 2 | # Code from 3 | # https://github.com/AlexHex7/Non-local_pytorch/blob/master/Non-Local_pytorch_0.3.1/lib/non_local_embedded_gaussian.py 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | 9 | class _NonLocalBlockND(nn.Module): 10 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 11 | super(_NonLocalBlockND, self).__init__() 12 | 13 | assert dimension in [1, 2, 3] 14 | 
15 | self.dimension = dimension 16 | self.sub_sample = sub_sample 17 | 18 | self.in_channels = in_channels 19 | self.inter_channels = inter_channels 20 | 21 | if self.inter_channels is None: 22 | self.inter_channels = in_channels // 2 23 | if self.inter_channels == 0: 24 | self.inter_channels = 1 25 | 26 | if dimension == 3: 27 | conv_nd = nn.Conv3d 28 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 29 | bn = nn.BatchNorm3d 30 | elif dimension == 2: 31 | conv_nd = nn.Conv2d 32 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 33 | bn = nn.BatchNorm2d 34 | else: 35 | conv_nd = nn.Conv1d 36 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 37 | bn = nn.BatchNorm1d 38 | 39 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 40 | kernel_size=1, stride=1, padding=0) 41 | 42 | if bn_layer: 43 | self.W = nn.Sequential( 44 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 45 | kernel_size=1, stride=1, padding=0), 46 | bn(self.in_channels) 47 | ) 48 | nn.init.constant_(self.W[1].weight, 0) 49 | nn.init.constant_(self.W[1].bias, 0) 50 | else: 51 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 52 | kernel_size=1, stride=1, padding=0) 53 | nn.init.constant_(self.W.weight, 0) 54 | nn.init.constant_(self.W.bias, 0) 55 | 56 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 57 | kernel_size=1, stride=1, padding=0) 58 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 59 | kernel_size=1, stride=1, padding=0) 60 | 61 | if sub_sample: 62 | self.g = nn.Sequential(self.g, max_pool_layer) 63 | self.phi = nn.Sequential(self.phi, max_pool_layer) 64 | 65 | def forward(self, x): 66 | ''' 67 | :param x: (b, c, t, h, w) 68 | :return: 69 | ''' 70 | 71 | batch_size = x.size(0) 72 | 73 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 74 | g_x = g_x.permute(0, 2, 1) 75 | 76 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 77 | theta_x = theta_x.permute(0, 2, 1) 78 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 79 | f = torch.matmul(theta_x, phi_x) 80 | f_div_C = F.softmax(f, dim=-1) 81 | 82 | y = torch.matmul(f_div_C, g_x) 83 | y = y.permute(0, 2, 1).contiguous() 84 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 85 | W_y = self.W(y) 86 | z = W_y + x 87 | 88 | return z 89 | 90 | 91 | class NONLocalBlock1D(_NonLocalBlockND): 92 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 93 | super(NONLocalBlock1D, self).__init__(in_channels, 94 | inter_channels=inter_channels, 95 | dimension=1, sub_sample=sub_sample, 96 | bn_layer=bn_layer) 97 | 98 | 99 | class NONLocalBlock2D(_NonLocalBlockND): 100 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 101 | super(NONLocalBlock2D, self).__init__(in_channels, 102 | inter_channels=inter_channels, 103 | dimension=2, sub_sample=sub_sample, 104 | bn_layer=bn_layer) 105 | 106 | 107 | class NONLocalBlock3D(_NonLocalBlockND): 108 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 109 | super(NONLocalBlock3D, self).__init__(in_channels, 110 | inter_channels=inter_channels, 111 | dimension=3, sub_sample=sub_sample, 112 | bn_layer=bn_layer) 113 | 114 | 115 | class NL3DWrapper(nn.Module): 116 | def __init__(self, block, n_segment): 117 | super(NL3DWrapper, self).__init__() 118 | self.block = block 119 | self.nl = NONLocalBlock3D(block.bn3.num_features) 120 | 
self.n_segment = n_segment 121 | 122 | def forward(self, x): 123 | x = self.block(x) 124 | 125 | nt, c, h, w = x.size() 126 | x = x.view(nt // self.n_segment, self.n_segment, c, h, w).transpose(1, 2) # n, c, t, h, w 127 | x = self.nl(x) 128 | x = x.transpose(1, 2).contiguous().view(nt, c, h, w) 129 | return x 130 | 131 | 132 | def make_non_local(net, n_segment): 133 | import torchvision 134 | import archs 135 | if isinstance(net, torchvision.models.ResNet): 136 | net.layer2 = nn.Sequential( 137 | NL3DWrapper(net.layer2[0], n_segment), 138 | net.layer2[1], 139 | NL3DWrapper(net.layer2[2], n_segment), 140 | net.layer2[3], 141 | ) 142 | net.layer3 = nn.Sequential( 143 | NL3DWrapper(net.layer3[0], n_segment), 144 | net.layer3[1], 145 | NL3DWrapper(net.layer3[2], n_segment), 146 | net.layer3[3], 147 | NL3DWrapper(net.layer3[4], n_segment), 148 | net.layer3[5], 149 | ) 150 | else: 151 | raise NotImplementedError 152 | 153 | 154 | if __name__ == '__main__': 155 | from torch.autograd import Variable 156 | import torch 157 | 158 | sub_sample = True 159 | bn_layer = True 160 | 161 | img = Variable(torch.zeros(2, 3, 20)) 162 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer) 163 | out = net(img) 164 | print(out.size()) 165 | 166 | img = Variable(torch.zeros(2, 3, 20, 20)) 167 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer) 168 | out = net(img) 169 | print(out.size()) 170 | 171 | img = Variable(torch.randn(2, 3, 10, 20, 20)) 172 | net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer) 173 | out = net(img) 174 | print(out.size()) -------------------------------------------------------------------------------- /ops/temporal_shift.py: -------------------------------------------------------------------------------- 1 | from archs.resnet import * 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class TemporalShift(nn.Module): 8 | def __init__(self, n_segment=3, n_div=8, inplace=False, second_segments=2): 9 | super(TemporalShift, self).__init__() 10 | self.n_segment = n_segment 11 | self.fold_div = n_div 12 | self.inplace = inplace 13 | self.second_segments = second_segments 14 | if inplace: 15 | print('=> Using in-place shift...') 16 | print('=> Using fold div: {}'.format(self.fold_div)) 17 | 18 | def forward(self, x, unlabeled): 19 | if unlabeled: 20 | #print("using unlabeled shift") 21 | x = self.shift(x,self.second_segments, 22 | fold_div=self.fold_div, inplace=self.inplace) 23 | else: 24 | #print(x.size()) 25 | x = self.shift(x, self.n_segment, 26 | fold_div=self.fold_div, inplace=self.inplace) 27 | return x 28 | 29 | @staticmethod 30 | def shift(x, n_segment, fold_div=3, inplace=False): 31 | #print("segment_size is {}".format(n_segment)) 32 | nt, c, h, w = x.size() 33 | n_batch = nt // n_segment 34 | x = x.view(n_batch, n_segment, c, h, w) 35 | 36 | fold = c // fold_div 37 | if inplace: 38 | # Due to some out of order error when performing parallel computing. 39 | # May need to write a CUDA kernel. 
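            # The in-place path below delegates to the custom autograd Function
            # `InplaceShift`, which writes the shifted channel slices back into the
            # input buffer instead of allocating a fresh tensor as the else-branch does.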
40 | #raise NotImplementedError 41 | out = InplaceShift.apply(x, fold) 42 | else: 43 | out = torch.zeros_like(x) 44 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left 45 | out[:, 1:, fold: 2 * fold] = x[:, :- 46 | 1, fold: 2 * fold] # shift right 47 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift 48 | 49 | return out.view(nt, c, h, w) 50 | 51 | 52 | class InplaceShift(torch.autograd.Function): 53 | # Special thanks to @raoyongming for the help to this function 54 | @staticmethod 55 | def forward(ctx, input, fold): 56 | # not support higher order gradient 57 | # input = input.detach_() 58 | ctx.fold_ = fold 59 | n, t, c, h, w = input.size() 60 | buffer = input.data.new(n, t, fold, h, w).zero_() 61 | buffer[:, :-1] = input.data[:, 1:, :fold] 62 | input.data[:, :, :fold] = buffer 63 | buffer.zero_() 64 | buffer[:, 1:] = input.data[:, :-1, fold: 2 * fold] 65 | input.data[:, :, fold: 2 * fold] = buffer 66 | return input 67 | 68 | @staticmethod 69 | def backward(ctx, grad_output): 70 | # grad_output = grad_output.detach_() 71 | fold = ctx.fold_ 72 | n, t, c, h, w = grad_output.size() 73 | buffer = grad_output.data.new(n, t, fold, h, w).zero_() 74 | buffer[:, 1:] = grad_output.data[:, :-1, :fold] 75 | grad_output.data[:, :, :fold] = buffer 76 | buffer.zero_() 77 | buffer[:, :-1] = grad_output.data[:, 1:, fold: 2 * fold] 78 | grad_output.data[:, :, fold: 2 * fold] = buffer 79 | return grad_output, None 80 | 81 | 82 | 83 | class TemporalPool(nn.Module): 84 | def __init__(self, net, n_segment): 85 | super(TemporalPool, self).__init__() 86 | self.net = net 87 | self.n_segment = n_segment 88 | 89 | def forward(self, x): 90 | x = self.temporal_pool(x, n_segment=self.n_segment) 91 | return self.net(x) 92 | 93 | @staticmethod 94 | def temporal_pool(x, n_segment): 95 | raise NotImplementedError 96 | # nt, c, h, w = x.size() 97 | # n_batch = nt // n_segment 98 | # x = x.view(n_batch, n_segment, c, h, w).transpose( 99 | # 1, 2) # n, c, t, h, w 100 | # x = F.max_pool3d(x, kernel_size=(3, 1, 1), 101 | # stride=(2, 1, 1), padding=(1, 0, 0)) 102 | # x = x.transpose(1, 2).contiguous().view(nt // 2, c, h, w) 103 | # return x 104 | 105 | 106 | class make_block_temporal(nn.Module): 107 | def __init__(self, stage, this_segment=3,n_div=8, second_segments=2): 108 | super(make_block_temporal, self).__init__() 109 | self.blocks = nn.ModuleList(list(stage.children())) 110 | self.second_segments = second_segments 111 | print('=> Processing stage with {} blocks'.format(len(self.blocks))) 112 | self.temporal_shift = TemporalShift(n_segment=this_segment, n_div=n_div, second_segments= self.second_segments) 113 | # for i, b in enumerate(self.blocks): 114 | # self.blocks[i]= nn.Sequential(b) 115 | def forward(self,x,unlabeled=False): 116 | for i, b in enumerate(self.blocks): 117 | x= self.temporal_shift(x,unlabeled) 118 | x = self.blocks[i](x) 119 | return x 120 | 121 | class make_blockres_temporal(nn.Module): 122 | def __init__(self, stage, this_segment=3,n_div=8, n_round=1, second_segments=2): 123 | super(make_blockres_temporal, self).__init__() 124 | self.blocks = nn.ModuleList(list(stage.children())) 125 | self.second_segments = second_segments 126 | self.n_round = n_round 127 | print('=> Processing stage with {} blocks'.format(len(self.blocks))) 128 | self.temporal_shift = TemporalShift(n_segment=this_segment, n_div=n_div, second_segments=self.second_segments) 129 | # for i, b in enumerate(self.blocks): 130 | # self.blocks[i]= nn.Sequential(b) 131 | # print(self.blocks[i]) 132 | 133 | def 
forward(self,x,unlabeled=False): 134 | #print("make_block_res_temporal_called") 135 | for i, b in enumerate(self.blocks): 136 | #print(x.size()) 137 | if i% self.n_round == 0: 138 | #print("size of x is {}".format(x.size())) 139 | x= self.temporal_shift(x,unlabeled) 140 | x = self.blocks[i](x) 141 | return x 142 | 143 | 144 | def make_temporal_shift(net, n_segment,second_segments=2, n_div=8, place='blockres', temporal_pool=False): 145 | if temporal_pool: 146 | n_segment_list = [n_segment, n_segment // 147 | 2, n_segment // 2, n_segment // 2] 148 | else: 149 | n_segment_list = [n_segment] * 4 150 | assert n_segment_list[-1] > 0 151 | print('=> n_segment per stage: {}'.format(n_segment_list)) 152 | 153 | import torchvision 154 | if isinstance(net, ResNet): 155 | if place == 'block': 156 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0], n_div, second_segments) 157 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1], n_div, second_segments) 158 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2], n_div, second_segments) 159 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3], n_div, second_segments) 160 | 161 | elif 'blockres' in place: 162 | n_round = 1 163 | if len(list(net.layer3.children())) >= 23: 164 | n_round = 2 165 | print('=> Using n_round {} to insert temporal shift'.format(n_round)) 166 | 167 | net.layer1 = make_blockres_temporal(net.layer1, n_segment_list[0], n_div, n_round, second_segments) 168 | net.layer2 = make_blockres_temporal(net.layer2, n_segment_list[1], n_div, n_round, second_segments) 169 | net.layer3 = make_blockres_temporal(net.layer3, n_segment_list[2], n_div, n_round, second_segments) 170 | net.layer4 = make_blockres_temporal(net.layer4, n_segment_list[3], n_div, n_round, second_segments) 171 | else: 172 | raise NotImplementedError(place) 173 | 174 | 175 | def make_temporal_pool(net, n_segment): 176 | import torchvision 177 | if isinstance(net, torchvision.models.ResNet): 178 | print('=> Injecting nonlocal pooling') 179 | net.layer2 = TemporalPool(net.layer2, n_segment) 180 | else: 181 | raise NotImplementedError 182 | 183 | 184 | if __name__ == '__main__': 185 | # test inplace shift v.s. 
vanilla shift 186 | tsm1 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=False) 187 | tsm2 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=True) 188 | 189 | print('=> Testing CPU...') 190 | # test forward 191 | with torch.no_grad(): 192 | for i in range(10): 193 | x = torch.rand(2 * 8, 3, 224, 224) 194 | y1 = tsm1(x) 195 | y2 = tsm2(x) 196 | assert torch.norm(y1 - y2).item() < 1e-5 197 | 198 | # test backward 199 | with torch.enable_grad(): 200 | for i in range(10): 201 | x1 = torch.rand(2 * 8, 3, 224, 224) 202 | x1.requires_grad_() 203 | x2 = x1.clone() 204 | y1 = tsm1(x1) 205 | y2 = tsm2(x2) 206 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0] 207 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0] 208 | assert torch.norm(grad1 - grad2).item() < 1e-5 209 | 210 | print('=> Testing GPU...') 211 | tsm1.cuda() 212 | tsm2.cuda() 213 | # test forward 214 | with torch.no_grad(): 215 | for i in range(10): 216 | x = torch.rand(2 * 8, 3, 224, 224).cuda() 217 | y1 = tsm1(x) 218 | y2 = tsm2(x) 219 | assert torch.norm(y1 - y2).item() < 1e-5 220 | 221 | # test backward 222 | with torch.enable_grad(): 223 | for i in range(10): 224 | x1 = torch.rand(2 * 8, 3, 224, 224).cuda() 225 | x1.requires_grad_() 226 | x2 = x1.clone() 227 | y1 = tsm1(x1) 228 | y2 = tsm2(x2) 229 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0] 230 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0] 231 | assert torch.norm(grad1 - grad2).item() < 1e-5 232 | print('Test passed.') 233 | -------------------------------------------------------------------------------- /ops/transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | 10 | class GroupRandomCrop(object): 11 | def __init__(self, size, scale_size): 12 | if isinstance(size, numbers.Number): 13 | self.size = (int(size), int(size)) 14 | self.scale_size = (int(scale_size), int(scale_size)) 15 | else: 16 | self.size = size 17 | self.scale_size = scale_size 18 | def __call__(self, img_group): 19 | 20 | if img_group[0].size < self.size: 21 | scale = GroupScale(self.scale_size) 22 | res_img_group = scale(img_group) 23 | img_group = res_img_group 24 | 25 | w, h = img_group[0].size 26 | th, tw = self.size 27 | out_images = list() 28 | x1 = random.randint(0, w - tw) 29 | y1 = random.randint(0, h - th) 30 | #print(w, h,x1,y1) 31 | for img in img_group: 32 | assert(img.size[0] == w and img.size[1] == h) 33 | if w == tw and h == th: 34 | out_images.append(img) 35 | else: 36 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 37 | 38 | return out_images 39 | 40 | 41 | class GroupCenterCrop(object): 42 | def __init__(self, size): 43 | self.worker = torchvision.transforms.CenterCrop(size) 44 | 45 | def __call__(self, img_group): 46 | return [self.worker(img) for img in img_group] 47 | 48 | 49 | class GroupRandomHorizontalFlip(object): 50 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 51 | """ 52 | def __init__(self, is_flow=False): 53 | self.is_flow = is_flow 54 | 55 | def __call__(self, img_group, is_flow=False): 56 | v = random.random() 57 | if v < 0.5: 58 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 59 | if self.is_flow: 60 | for i in range(0, len(ret), 2): 61 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping 62 | return ret 63 | else: 64 | return img_group 
65 | 66 | 67 | class GroupNormalize(object): 68 | def __init__(self, mean, std): 69 | self.mean = mean 70 | self.std = std 71 | 72 | def __call__(self, tensor): 73 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 74 | rep_std = self.std * (tensor.size()[0]//len(self.std)) 75 | 76 | # TODO: make efficient 77 | for t, m, s in zip(tensor, rep_mean, rep_std): 78 | t.sub_(m).div_(s) 79 | 80 | return tensor 81 | 82 | 83 | class GroupScale(object): 84 | """ Rescales the input PIL.Image to the given 'size'. 85 | 'size' will be the size of the smaller edge. 86 | For example, if height > width, then image will be 87 | rescaled to (size * height / width, size) 88 | size: size of the smaller edge 89 | interpolation: Default: PIL.Image.BILINEAR 90 | """ 91 | 92 | def __init__(self, size, interpolation=Image.BILINEAR): 93 | self.worker = torchvision.transforms.Resize(size, interpolation) 94 | 95 | def __call__(self, img_group): 96 | return [self.worker(img) for img in img_group] 97 | 98 | 99 | class GroupOverSample(object): 100 | def __init__(self, crop_size, scale_size=None, flip=True): 101 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 102 | 103 | if scale_size is not None: 104 | self.scale_worker = GroupScale(scale_size) 105 | else: 106 | self.scale_worker = None 107 | self.flip = flip 108 | 109 | def __call__(self, img_group): 110 | 111 | if self.scale_worker is not None: 112 | img_group = self.scale_worker(img_group) 113 | 114 | image_w, image_h = img_group[0].size 115 | crop_w, crop_h = self.crop_size 116 | 117 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h) 118 | oversample_group = list() 119 | for o_w, o_h in offsets: 120 | normal_group = list() 121 | flip_group = list() 122 | for i, img in enumerate(img_group): 123 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 124 | normal_group.append(crop) 125 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 126 | 127 | if img.mode == 'L' and i % 2 == 0: 128 | flip_group.append(ImageOps.invert(flip_crop)) 129 | else: 130 | flip_group.append(flip_crop) 131 | 132 | oversample_group.extend(normal_group) 133 | if self.flip: 134 | oversample_group.extend(flip_group) 135 | return oversample_group 136 | 137 | 138 | class GroupFullResSample(object): 139 | def __init__(self, crop_size, scale_size=None, flip=True): 140 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 141 | 142 | if scale_size is not None: 143 | self.scale_worker = GroupScale(scale_size) 144 | else: 145 | self.scale_worker = None 146 | self.flip = flip 147 | 148 | def __call__(self, img_group): 149 | 150 | if self.scale_worker is not None: 151 | img_group = self.scale_worker(img_group) 152 | 153 | image_w, image_h = img_group[0].size 154 | crop_w, crop_h = self.crop_size 155 | 156 | w_step = (image_w - crop_w) // 4 157 | h_step = (image_h - crop_h) // 4 158 | 159 | offsets = list() 160 | offsets.append((0 * w_step, 2 * h_step)) # left 161 | offsets.append((4 * w_step, 2 * h_step)) # right 162 | offsets.append((2 * w_step, 2 * h_step)) # center 163 | 164 | oversample_group = list() 165 | for o_w, o_h in offsets: 166 | normal_group = list() 167 | flip_group = list() 168 | for i, img in enumerate(img_group): 169 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 170 | normal_group.append(crop) 171 | if self.flip: 172 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 173 | 174 | if img.mode == 'L' and i % 2 == 0: 175 | 
flip_group.append(ImageOps.invert(flip_crop)) 176 | else: 177 | flip_group.append(flip_crop) 178 | 179 | oversample_group.extend(normal_group) 180 | oversample_group.extend(flip_group) 181 | return oversample_group 182 | 183 | 184 | class GroupMultiScaleCrop(object): 185 | 186 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 187 | self.scales = scales if scales is not None else [1, .875, .75, .66] 188 | self.max_distort = max_distort 189 | self.fix_crop = fix_crop 190 | self.more_fix_crop = more_fix_crop 191 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 192 | self.interpolation = Image.BILINEAR 193 | 194 | def __call__(self, img_group): 195 | 196 | im_size = img_group[0].size 197 | 198 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 199 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 200 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) 201 | for img in crop_img_group] 202 | return ret_img_group 203 | 204 | def _sample_crop_size(self, im_size): 205 | image_w, image_h = im_size[0], im_size[1] 206 | 207 | # find a crop size 208 | base_size = min(image_w, image_h) 209 | crop_sizes = [int(base_size * x) for x in self.scales] 210 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 211 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 212 | 213 | pairs = [] 214 | for i, h in enumerate(crop_h): 215 | for j, w in enumerate(crop_w): 216 | if abs(i - j) <= self.max_distort: 217 | pairs.append((w, h)) 218 | 219 | crop_pair = random.choice(pairs) 220 | if not self.fix_crop: 221 | w_offset = random.randint(0, image_w - crop_pair[0]) 222 | h_offset = random.randint(0, image_h - crop_pair[1]) 223 | else: 224 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 225 | 226 | return crop_pair[0], crop_pair[1], w_offset, h_offset 227 | 228 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 229 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 230 | return random.choice(offsets) 231 | 232 | @staticmethod 233 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 234 | w_step = (image_w - crop_w) // 4 235 | h_step = (image_h - crop_h) // 4 236 | 237 | ret = list() 238 | ret.append((0, 0)) # upper left 239 | ret.append((4 * w_step, 0)) # upper right 240 | ret.append((0, 4 * h_step)) # lower left 241 | ret.append((4 * w_step, 4 * h_step)) # lower right 242 | ret.append((2 * w_step, 2 * h_step)) # center 243 | 244 | if more_fix_crop: 245 | ret.append((0, 2 * h_step)) # center left 246 | ret.append((4 * w_step, 2 * h_step)) # center right 247 | ret.append((2 * w_step, 4 * h_step)) # lower center 248 | ret.append((2 * w_step, 0 * h_step)) # upper center 249 | 250 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 251 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 252 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 253 | ret.append((3 * w_step, 3 * h_step)) # lower righ quarter 254 | 255 | return ret 256 | 257 | 258 | class GroupRandomSizedCrop(object): 259 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 260 | and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio 261 | This is popularly used to train the 
Inception networks 262 | size: size of the smaller edge 263 | interpolation: Default: PIL.Image.BILINEAR 264 | """ 265 | def __init__(self, size, interpolation=Image.BILINEAR): 266 | self.size = size 267 | self.interpolation = interpolation 268 | 269 | def __call__(self, img_group): 270 | for attempt in range(10): 271 | area = img_group[0].size[0] * img_group[0].size[1] 272 | target_area = random.uniform(0.08, 1.0) * area 273 | aspect_ratio = random.uniform(3. / 4, 4. / 3) 274 | 275 | w = int(round(math.sqrt(target_area * aspect_ratio))) 276 | h = int(round(math.sqrt(target_area / aspect_ratio))) 277 | 278 | if random.random() < 0.5: 279 | w, h = h, w 280 | 281 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 282 | x1 = random.randint(0, img_group[0].size[0] - w) 283 | y1 = random.randint(0, img_group[0].size[1] - h) 284 | found = True 285 | break 286 | else: 287 | found = False 288 | x1 = 0 289 | y1 = 0 290 | 291 | if found: 292 | out_group = list() 293 | for img in img_group: 294 | img = img.crop((x1, y1, x1 + w, y1 + h)) 295 | assert(img.size == (w, h)) 296 | out_group.append(img.resize((self.size, self.size), self.interpolation)) 297 | return out_group 298 | else: 299 | # Fallback 300 | scale = GroupScale(self.size, interpolation=self.interpolation) 301 | crop = GroupRandomCrop(self.size) 302 | return crop(scale(img_group)) 303 | 304 | 305 | class Stack(object): 306 | 307 | def __init__(self, roll=False): 308 | self.roll = roll 309 | 310 | def __call__(self, img_group): 311 | if img_group[0].mode == 'L': 312 | return np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2) 313 | elif img_group[0].mode == 'RGB': 314 | if self.roll: 315 | return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2) 316 | else: 317 | return np.concatenate(img_group, axis=2) 318 | 319 | 320 | class ToTorchFormatTensor(object): 321 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 322 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 323 | def __init__(self, div=True): 324 | self.div = div 325 | 326 | def __call__(self, pic): 327 | if isinstance(pic, np.ndarray): 328 | # handle numpy array 329 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 330 | else: 331 | # handle PIL Image 332 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 333 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 334 | # put it from HWC to CHW format 335 | # yikes, this transpose takes 80% of the loading time/CPU 336 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 337 | return img.float().div(255) if self.div else img.float() 338 | 339 | 340 | class IdentityTransform(object): 341 | 342 | def __call__(self, data): 343 | return data 344 | 345 | 346 | if __name__ == "__main__": 347 | trans = torchvision.transforms.Compose([ 348 | GroupScale(256), 349 | GroupRandomCrop(224), 350 | Stack(), 351 | ToTorchFormatTensor(), 352 | GroupNormalize( 353 | mean=[.485, .456, .406], 354 | std=[.229, .224, .225] 355 | )] 356 | ) 357 | 358 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png') 359 | 360 | color_group = [im] * 3 361 | rst = trans(color_group) 362 | 363 | gray_group = [im.convert('L')] * 9 364 | gray_rst = trans(gray_group) 365 | 366 | trans2 = torchvision.transforms.Compose([ 367 | GroupRandomSizedCrop(256), 368 | Stack(), 369 | ToTorchFormatTensor(), 370 | GroupNormalize( 371 | mean=[.485, .456, .406], 372 | std=[.229, .224, .225]) 373 | ]) 374 | print(trans2(color_group)) 
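
The group transforms above operate on a list of PIL frames and are meant to be chained. As a rough, hypothetical sketch (the paths, image template, batch size and segment counts below are placeholders, not values taken from this repository), a train-time pipeline built from these transforms and fed to `TSNDataSet` from `ops/dataset.py` might look like:

```
# Illustrative sketch only: placeholder paths and parameters, not the repository's defaults.
import torch
import torchvision

from ops.dataset import TSNDataSet
from ops.transforms import (GroupMultiScaleCrop, GroupRandomHorizontalFlip,
                            Stack, ToTorchFormatTensor, GroupNormalize)

train_transform = torchvision.transforms.Compose([
    GroupMultiScaleCrop(224, [1, .875, .75, .66]),   # multi-scale crop, resized to 224x224
    GroupRandomHorizontalFlip(is_flow=False),        # include only when horizontal flip is desired (--flip)
    Stack(roll=False),                               # concatenate frames along the channel axis
    ToTorchFormatTensor(div=True),                   # HWC uint8 -> CHW float in [0, 1]
    GroupNormalize([.485, .456, .406], [.229, .224, .225]),
])

# The labeled pathway samples num_segments frames per video; with unlabeled=True the
# dataset also samples second_segments frames and returns (fast_data, slow_data) pairs.
train_set = TSNDataSet('path/to/Frames', 'path/to/train.txt',
                       num_segments=8, second_segments=4, modality='RGB',
                       image_tmpl='{:06d}.jpg', transform=train_transform)
loader = torch.utils.data.DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4)
```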
-------------------------------------------------------------------------------- /ops/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.nn import functional as F 3 | 4 | 5 | def softmax(scores): 6 | es = np.exp(scores - scores.max(axis=-1)[..., None]) 7 | return es / es.sum(axis=-1)[..., None] 8 | 9 | 10 | class AverageMeter(object): 11 | """Computes and stores the average and current value""" 12 | 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | 29 | def accuracy(output, target, topk=(1,)): 30 | """Computes the precision@k for the specified values of k""" 31 | maxk = max(topk) 32 | batch_size = target.size(0) 33 | 34 | _, pred = output.topk(maxk, 1, True, True) 35 | pred = pred.t() 36 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 37 | 38 | res = [] 39 | for k in topk: 40 | correct_k = correct[:k].view(-1).float().sum(0) 41 | res.append(correct_k.mul_(100.0 / batch_size)) 42 | return res 43 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser( 3 | description="PyTorch implementation of Temporal Segment Networks") 4 | parser.add_argument('dataset', type=str) 5 | parser.add_argument('modality', type=str, choices=['RGB', 'Flow']) 6 | parser.add_argument('--train_list', type=str, default="") 7 | parser.add_argument('--val_list', type=str, default="") 8 | parser.add_argument('--labeled_train_list', type=str, default="") 9 | parser.add_argument('--unlabeled_train_list', type=str, default="") 10 | parser.add_argument('--root_path', type=str, default="") 11 | parser.add_argument('--store_name', type=str, default="") 12 | # ========================= Model Configs ========================== 13 | parser.add_argument('--arch', type=str, default="resnet18") 14 | parser.add_argument('--num_segments', type=int, default=3) 15 | parser.add_argument('--consensus_type', type=str, default='avg') 16 | parser.add_argument('--k', type=int, default=3) 17 | parser.add_argument('--num_clip', type=int, default=10, help='For number of clips for Video Acc') 18 | parser.add_argument('--num_crop', type=int, default=3, help='For number of crops for Video Acc') 19 | parser.add_argument('--dropout', '--do', default=0.5, type=float, 20 | metavar='DO', help='dropout ratio (default: 0.5)') 21 | parser.add_argument('--loss_type', type=str, default="nll", 22 | choices=['nll']) 23 | parser.add_argument('--img_feature_dim', default=256, type=int, 24 | help="the feature dimension for each frame") 25 | parser.add_argument('--suffix', type=str, default=None) 26 | parser.add_argument('--pretrain', type=str, default="") 27 | parser.add_argument('--tune_from', type=str, default=None, 28 | help='fine-tune from checkpoint') 29 | parser.add_argument('--strategy', type=str, default='classwise', help='[classwise, overall] strategy for sampling') 30 | parser.add_argument('--resume_pretrain',type=str, default='pretrain', help='[finetune, pretrain] which part to resume training ONLY FOR UNS_PRETRAIN') 31 | parser.add_argument('--valbatchsize', default=16, type=int, help='mini-batch size for validation') 32 | 
parser.add_argument('--sup_thresh',default=50,type=int, help='threshold epchs for pseduo label calculation') 33 | parser.add_argument('--use_group_contrastive',action ='store_true', default=False) 34 | parser.add_argument('--use_finetuning', action ='store_true',default=False) 35 | parser.add_argument('--finetune_start_epoch', default=400, type =int, help='when to start the fine-tune using PL') 36 | parser.add_argument('--Temperature', default=0.5, type=float, help='temperature for sharpening') 37 | parser.add_argument('--finetune_lr', default=-1.0, type=float, help='set fine tune lr for last stage PL') 38 | parser.add_argument('--gamma_finetune',default=9.0, type=float, help= 'weight for pl_loss') 39 | parser.add_argument('--finetune_stage_eval_freq', default=1, type=int, help='frequency for evaluating at finetuning stage') 40 | parser.add_argument('--finetune_epochs', type=int, default=100) 41 | parser.add_argument('--start_finetune_epoch', type=int, default=0) 42 | # ========================= Learning Configs ========================== 43 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 44 | help='number of total epochs to run') 45 | parser.add_argument('-b', '--batch-size', default=8, type=int, 46 | metavar='N', help='mini-batch size (default: 256)') 47 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 48 | metavar='LR', help='initial learning rate') 49 | parser.add_argument('--lr_type', default='cos', type=str, 50 | metavar='LRtype', help='learning rate type') 51 | parser.add_argument('--lr_steps', default=[50, 100], type=float, nargs="+", 52 | metavar='LRSteps', help='epochs to decay learning rate by 10') 53 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 54 | help='momentum') 55 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 56 | metavar='W', help='weight decay (default: 5e-4)') 57 | parser.add_argument('--clip-gradient', '--gd', default=None, type=float, 58 | metavar='W', help='gradient norm clipping (default: disabled)') 59 | parser.add_argument('--no_partialbn', '--npb', 60 | default=False, action="store_true") 61 | parser.add_argument('--gamma', default=1.0, type=float, 62 | metavar='G', help='weight of contrastive loss') 63 | parser.add_argument('--percentage', default=0.95, type=float, 64 | help='should be between 0 and 1. 
decides percent of training\ 65 | data to be allocated to unlabeled data') 66 | parser.add_argument('--threshold', default=0.95, type=float, 67 | help='threshold for pseduo labels') 68 | parser.add_argument('--mu',default=8, type=float, help= 'coefficient for unlabeled data') 69 | parser.add_argument('--second_segments', default=2, type=int, help='no of segments for second branch') 70 | 71 | # ========================= Monitor Configs ========================== 72 | parser.add_argument('--print-freq', '-p', default=20, type=int, 73 | metavar='N', help='print frequency (default: 10)') 74 | parser.add_argument('--eval-freq', '-ef', default=5, type=int, 75 | metavar='N', help='evaluation frequency (default: 5)') 76 | parser.add_argument('--seed', default=123, type=int, 77 | help='seed for labeled and unlabeled data separation') 78 | # ========================= Runtime Configs ========================== 79 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 80 | help='number of data loading workers (default: 8)') 81 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 82 | help='path to latest checkpoint (default: none)') 83 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 84 | help='evaluate model on validation set') 85 | parser.add_argument('--snapshot_pref', type=str, default="") 86 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 87 | help='manual epoch number (useful on restarts)') 88 | parser.add_argument('--gpus', nargs='+', type=int, default=None) 89 | parser.add_argument('--flow_prefix', default="", type=str) 90 | parser.add_argument('--root_log', type=str, default='checkpoint') 91 | parser.add_argument('--root_model', type=str, default='checkpoint') 92 | 93 | parser.add_argument('--shift', default=False, 94 | action="store_true", help='use shift for models') 95 | parser.add_argument('--flip', action="store_true", help='Mention this flag if RandomHorizontalFlip is required else do not mention') 96 | parser.add_argument('--shift_div', default=8, type=int, 97 | help='number of div for shift (default: 8)') 98 | parser.add_argument('--shift_place', default='blockres', 99 | type=str, help='place for shift (default: stageres)') 100 | 101 | parser.add_argument('--temporal_pool', default=False, 102 | action="store_true", help='add temporal pooling') 103 | parser.add_argument('--non_local', default=False, 104 | action="store_true", help='add non local block') 105 | 106 | parser.add_argument('--dense_sample', default=False, 107 | action="store_true", help='use dense sample for video dataset') 108 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Pygments==2.6.1 2 | tqdm==4.46.0 3 | torch==1.4.0 4 | opencv_python==4.4.0.42 5 | numpy==1.19.2 6 | sortedcontainers==2.1.0 7 | torchvision==0.5.0 8 | Pillow==8.2.0 9 | PyYAML==5.4.1 10 | scikit_learn>=0.24.2 11 | -------------------------------------------------------------------------------- /root_dataset.yaml: -------------------------------------------------------------------------------- 1 | #root folder where each dataset is contained 2 | #dataset should contain the frames as well as train and 3 | #validation folder as defined in dataset_config 4 | 5 | somethingv2: 'datasets/Mini-Something-V2/' 6 | jester: '' 7 | kinetics: '' 8 | charades_ego: '' 9 | 
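
Supporting an additional dataset follows the same pattern as the entries above: add a `<name>: '<dataset_dir>'` line to this YAML file, then give `ops/dataset_config.py` a matching `return_<name>` helper and register it in the `dict_single` mapping inside `return_dataset`. A minimal, hypothetical sketch (the dataset name and relative paths are placeholders):

```
# Hypothetical helper for ops/dataset_config.py; 'mydataset' and the relative
# paths are placeholders. The same name must also appear in root_dataset.yaml
# (e.g. mydataset: 'datasets/MyDataset/') and in dict_single inside return_dataset().
def return_mydataset(modality):
    filename_categories = 'data/category.txt'
    if modality == 'RGB':
        root_data = ROOT_DATASET + 'Frames'
        filename_imglist_train = 'data/train_videofolder.txt'
        filename_imglist_val = 'data/val_videofolder.txt'
        prefix = '{:06d}.jpg'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
```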
-------------------------------------------------------------------------------- /tools/extract_videos_kinetics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import os 5 | import json 6 | import numpy as np 7 | import skvideo.io 8 | import cv2 9 | import sys 10 | import concurrent.futures 11 | from shutil import copyfile 12 | import subprocess 13 | from tqdm import tqdm 14 | from random import shuffle 15 | 16 | parser = argparse.ArgumentParser(description='[Kinetics] Video conversion') 17 | parser.add_argument('-i', '--input_root', help='location of input video and csv files.', type=str) 18 | parser.add_argument('-o', '--output_root', type=str, help='output image locations, ' 19 | 'generates `train` `val` `test` folders') 20 | parser.add_argument('-s', '--shorter_side', default=[256], type=int, nargs="+", 21 | help='shorter side of the generated image, if two values are provided, ' 22 | 'convert to the targeted size regardless of the aspect ratio. [w,h]') 23 | parser.add_argument('-p', '--num_processes', type=int, help='number of worker processes', default=36) 24 | parser.add_argument('-n', '--num_classes', type=int, help='number of classes', 25 | choices=[400, 600], default=400) 26 | parser.add_argument('--do-test-set', action='store_true', help='convert test set') 27 | 28 | args = parser.parse_args() 29 | 30 | # input 31 | # is_400 = True 32 | # shorter_side = 331 33 | # image_format = 'jpg' 34 | 35 | train_file = "{}/data/kinetics-{}_train.csv".format(args.input_root, args.num_classes) 36 | val_file = "{}/data/kinetics-{}_val.csv".format(args.input_root, args.num_classes) 37 | test_file = "{}/data/kinetics-{}_test.csv".format(args.input_root, args.num_classes) 38 | train_video_folder = "{}/train".format(args.input_root) 39 | val_video_folder = "{}/val".format(args.input_root) 40 | test_video_folder = "{}/test".format(args.input_root) 41 | 42 | # output 43 | label_file = "{}/images/kinetics-{}_label.txt".format(args.output_root, args.num_classes) 44 | train_img_folder = "{}/images/train/".format(args.output_root) 45 | val_img_folder = "{}/images/val/".format(args.output_root) 46 | test_img_folder = "{}/images/test/".format(args.output_root) 47 | train_file_list = "{}/train_{}.txt".format(args.output_root, args.num_classes) 48 | val_file_list = "{}/val_{}.txt".format(args.output_root, args.num_classes) 49 | test_file_list = "{}/test_{}.txt".format(args.output_root, args.num_classes) 50 | 51 | train_fail_file_list = "{}/train_fail_{}.txt".format(args.output_root, args.num_classes) 52 | val_fail_file_list = "{}/val_fail_{}.txt".format(args.output_root, args.num_classes) 53 | test_fail_file_list = "{}/test_fail_{}.txt".format(args.output_root, args.num_classes) 54 | 55 | 56 | if not os.path.exists(os.path.join(args.output_root, 'images')): 57 | os.makedirs(os.path.join(args.output_root, 'images')) 58 | 59 | def load_video_list(file_path, build_label=False): 60 | labels = [] 61 | videos = [] 62 | with open(file_path) as f: 63 | for line in f.readlines(): 64 | line = line.strip() 65 | if line == "": 66 | continue 67 | if args.num_classes == 400: 68 | label, youtube_id, start_time, end_time, temp, _ = line.split(",") 69 | label = label.replace("\"", "") 70 | else: 71 | label, youtube_id, start_time, end_time, temp = line.split(",") 72 | if temp.strip() == 'split': 73 | continue 74 | label_name = label.strip() 75 | video_id = "{}_{:06d}_{:06d}".format(youtube_id, int(start_time), int(end_time)) 76 | 
videos.append([video_id, label_name]) 77 | labels.append(label_name) 78 | if not build_label: 79 | return videos 80 | labels = sorted(list(set(labels))) 81 | id_to_label = {} 82 | label_to_id = {} 83 | with open(label_file, 'w') as f: 84 | for i in range(len(labels)): 85 | label_to_id[labels[i]] = i 86 | id_to_label[i] = labels[i] 87 | print(labels[i], file=f) 88 | return videos, label_to_id, id_to_label 89 | 90 | 91 | def load_test_video_list(file_path): 92 | videos = [] 93 | with open(file_path) as f: 94 | for line in f.readlines(): 95 | line = line.strip() 96 | if line == "": 97 | continue 98 | youtube_id, start_time, end_time, temp = line.split(",") 99 | if temp.strip() == 'split': 100 | continue 101 | video_id = "{}_{:06d}_{:06d}".format(youtube_id, int(start_time), int(end_time)) 102 | videos.append([video_id, "x"]) 103 | return videos 104 | 105 | 106 | train_videos, label_to_id, id_to_label = load_video_list(train_file, build_label=True) 107 | val_videos = load_video_list(val_file) 108 | test_videos = load_test_video_list(test_file) 109 | 110 | 111 | def video_to_images(video, basedir, targetdir, shorter_side): 112 | try: 113 | cls_id = label_to_id[video[1]] 114 | filename = os.path.join(basedir, video[1], video[0] + ".mp4") 115 | output_foldername = os.path.join(targetdir, video[1], video[0]) 116 | except Exception as e: # for test videos 117 | cls_id = -1 118 | filename = os.path.join(basedir, video[0] + ".mp4") 119 | output_foldername = os.path.join(targetdir, video[0]) 120 | 121 | if not os.path.exists(filename): 122 | print("{} is not existed.".format(filename)) 123 | return video[0], video[1], -2 124 | else: 125 | try: 126 | video_meta = skvideo.io.ffprobe(filename) 127 | height = int(video_meta['video']['@height']) 128 | width = int(video_meta['video']['@width']) 129 | except: 130 | print("Can not get video info: {}".format(filename)) 131 | return video[0], video[1], 0 132 | 133 | if len(shorter_side) == 1: 134 | if width > height: 135 | scale = "scale=-1:{}".format(shorter_side[0]) 136 | else: 137 | scale = "scale={}:-1".format(shorter_side[0]) 138 | else: 139 | scale = "scale={}:{}".format(shorter_side[0], shorter_side[1]) 140 | if not os.path.exists(output_foldername): 141 | os.makedirs(output_foldername) 142 | 143 | command = ['ffmpeg', 144 | '-i', '"%s"' % filename, 145 | '-vf', scale, 146 | '-threads', '1', 147 | '-loglevel', 'panic', 148 | '-q:v', '0', 149 | '"{}/'.format(output_foldername) + '%05d.jpg"'] 150 | command = ' '.join(command) 151 | try: 152 | # print(command) 153 | subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT) 154 | except: 155 | print("fail to convert {}".format(filename)) 156 | return video[0], video[1], 0 157 | 158 | # get frame num 159 | i = 0 160 | while True: 161 | img_name = os.path.join(output_foldername, "{:05d}.jpg".format(i + 1)) 162 | if os.path.exists(img_name): 163 | i += 1 164 | else: 165 | break 166 | 167 | frame_num = i 168 | # print("Finish {}, id: {} frames: {}".format(filename, cls_id, frame_num)) 169 | return video[0], cls_id, frame_num 170 | 171 | 172 | def create_train_video(shorter_side): 173 | print("Resizing to shorter side: {}".format(shorter_side)) 174 | with open(train_file_list, 'w') as f, open(train_fail_file_list, 175 | 'w') as f_w, concurrent.futures.ProcessPoolExecutor( 176 | max_workers=64) as executor: 177 | futures = [executor.submit(video_to_images, video, train_video_folder, train_img_folder, 178 | shorter_side) 179 | for video in train_videos] 180 | total_videos = len(futures) 181 | curr_idx 
= 0 182 | print("label,youtube_id,time_start,time_end,split", file=f_w, flush=True) 183 | for future in concurrent.futures.as_completed(futures): 184 | video_id, label_id, frame_num = future.result() 185 | if frame_num == 0: 186 | youtube_id = video_id[:11] 187 | time_start = int(video_id[12:18]) 188 | time_end = int(video_id[19:25]) 189 | print("{},{},{},{},train".format(label_id, youtube_id, time_start, time_end), 190 | file=f_w, flush=True) 191 | elif frame_num == -2: 192 | youtube_id = video_id[:11] 193 | time_start = int(video_id[12:18]) 194 | time_end = int(video_id[19:25]) 195 | print("{},{},{},{},train,missed".format(label_id, youtube_id, time_start, time_end), 196 | file=f_w, flush=True) 197 | else: 198 | print("{};1;{};{}".format( 199 | os.path.join('images/train', id_to_label[label_id], video_id), frame_num, 200 | label_id), file=f, flush=True) 201 | if curr_idx % 1000 == 0: 202 | print("{}/{}".format(curr_idx, total_videos), flush=True) 203 | curr_idx += 1 204 | print("Completed") 205 | 206 | 207 | def create_val_video(shorter_side): 208 | print("Resizing to shorter side: {}".format(shorter_side)) 209 | with open(val_file_list, 'w') as f, open(val_fail_file_list, 210 | 'w') as f_w, concurrent.futures.ProcessPoolExecutor( 211 | max_workers=36) as executor: 212 | futures = [ 213 | executor.submit(video_to_images, video, val_video_folder, val_img_folder, shorter_side) 214 | for video in val_videos] 215 | total_videos = len(futures) 216 | curr_idx = 0 217 | print("label,youtube_id,time_start,time_end,split", file=f_w, flush=True) 218 | for future in concurrent.futures.as_completed(futures): 219 | video_id, label_id, frame_num = future.result() 220 | if frame_num == 0: 221 | youtube_id = video_id[:11] 222 | time_start = int(video_id[12:18]) 223 | time_end = int(video_id[19:25]) 224 | print("{},{},{},{},val".format(label_id, youtube_id, time_start, time_end), 225 | file=f_w, flush=True) 226 | # print("{},{},{},{},val".format(label_id, youtube_id, time_start, time_end), flush=True) 227 | elif frame_num == -2: 228 | youtube_id = video_id[:11] 229 | time_start = int(video_id[12:18]) 230 | time_end = int(video_id[19:25]) 231 | print("{},{},{},{},val,missed".format(label_id, youtube_id, time_start, time_end), 232 | file=f_w, flush=True) 233 | else: 234 | print("{};1;{};{}".format( 235 | os.path.join('images/val', id_to_label[label_id], video_id), frame_num, 236 | label_id), file=f, flush=True) 237 | if curr_idx % 1000 == 0: 238 | print("{}/{}".format(curr_idx, total_videos)) 239 | curr_idx += 1 240 | print("Completed") 241 | 242 | 243 | def create_test_video(shorter_side): 244 | with open(test_file_list, 'w') as f, open(test_fail_file_list, 245 | 'w') as f_w, concurrent.futures.ProcessPoolExecutor( 246 | max_workers=36) as executor: 247 | futures = [executor.submit(video_to_images, video, test_video_folder, test_img_folder, 248 | shorter_side) 249 | for video in test_videos] 250 | total_videos = len(futures) 251 | curr_idx = 0 252 | print("youtube_id,time_start,time_end,split", file=f_w, flush=True) 253 | for future in concurrent.futures.as_completed(futures): 254 | video_id, label_id, frame_num = future.result() 255 | if frame_num == 0: 256 | youtube_id = video_id[:11] 257 | time_start = int(video_id[12:18]) 258 | time_end = int(video_id[19:25]) 259 | print("{},{},{},test".format(youtube_id, time_start, time_end), file=f_w, 260 | flush=True) 261 | else: 262 | print("{} 1 {}".format(os.path.join('images/test', video_id), frame_num), file=f, 263 | flush=True) 264 | 
print("{}/{}".format(curr_idx, total_videos)) 265 | curr_idx += 1 266 | print("Completed") 267 | 268 | 269 | if __name__ == "__main__": 270 | create_train_video(args.shorter_side) 271 | create_val_video(args.shorter_side) 272 | if args.do_test_set: 273 | create_test_video(args.shorter_side) 274 | -------------------------------------------------------------------------------- /tools/extract_videos_moments.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import concurrent.futures 5 | from shutil import copyfile 6 | import subprocess 7 | 8 | input_folder_root = "" 9 | if input_folder_root == "": 10 | raise ValueError("Please set input_folder_root") 11 | 12 | output_folder_root = "" 13 | if output_folder_root == "": 14 | raise ValueError("Please set output_folder_root") 15 | 16 | # input 17 | label_file = "{}/moments_categories.txt".format(input_folder_root) 18 | train_file = "{}/trainingSet.csv".format(input_folder_root) 19 | val_file = "{}/validationSet.csv".format(input_folder_root) 20 | video_folder = input_folder_root 21 | 22 | # output 23 | train_img_folder = "{}/train".format(output_folder_root) 24 | val_img_folder = "{}/val".format(output_folder_root) 25 | train_file_list = "{}/train.txt".format(output_folder_root) 26 | val_file_list = "{}/val.txt".format(output_folder_root) 27 | 28 | 29 | def load_categories(file_path): 30 | id_to_label = {} 31 | label_to_id = {} 32 | with open(file_path) as f: 33 | for label in f.readlines(): 34 | label = label.strip() 35 | if label == "": 36 | continue 37 | label = label.split(',') 38 | cls_id = int(label[-1]) 39 | id_to_label[cls_id] = label[0] 40 | label_to_id[label[0]] = cls_id 41 | return id_to_label, label_to_id 42 | 43 | 44 | id_to_label, label_to_id = load_categories(label_file) 45 | 46 | 47 | def load_video_list(file_path): 48 | videos = [] 49 | with open(file_path) as f: 50 | for line in f.readlines(): 51 | line = line.strip() 52 | if line == "": 53 | continue 54 | video_id, label_name, _, _= line.split(",") 55 | label_name = label_name.strip() 56 | videos.append([video_id, label_name]) 57 | return videos 58 | 59 | 60 | train_videos = load_video_list(train_file) 61 | val_videos = load_video_list(val_file) 62 | 63 | 64 | def video_to_images(video, basedir, targetdir): 65 | try: 66 | cls_id = label_to_id[video[1]] 67 | except: 68 | cls_id = -1 69 | assert cls_id >= 0 70 | filename = os.path.join(basedir, video[0]) 71 | video_basename = video[0].split('.')[0] 72 | output_foldername = os.path.join(targetdir, video_basename) 73 | if not os.path.exists(filename): 74 | print("{} is not existed.".format(filename)) 75 | return video[0], cls_id, 0 76 | else: 77 | if not os.path.exists(output_foldername): 78 | os.makedirs(output_foldername) 79 | 80 | command = ['ffmpeg', 81 | '-i', '"%s"' % filename, 82 | '-threads', '1', 83 | '-loglevel', 'panic', 84 | '-q:v', '0', 85 | '{}/'.format(output_foldername) + '"%05d.jpg"'] 86 | command = ' '.join(command) 87 | try: 88 | subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT) 89 | except: 90 | print("fail to convert {}".format(filename)) 91 | return video[0], cls_id, 0 92 | 93 | # get frame num 94 | i = 0 95 | while True: 96 | img_name = os.path.join(output_foldername + "/{:05d}.jpg".format(i + 1)) 97 | if os.path.exists(img_name): 98 | i += 1 99 | else: 100 | break 101 | 102 | frame_num = i 103 | print("Finish {}, id: {} frames: {}".format(filename, cls_id, frame_num)) 104 | return video_basename, cls_id, 
frame_num 105 | 106 | 107 | def create_train_video(): 108 | with open(train_file_list, 'w') as f, concurrent.futures.ProcessPoolExecutor(max_workers=36) as executor: 109 | futures = [executor.submit(video_to_images, video, os.path.join(video_folder, 'train'), train_img_folder) 110 | for video in train_videos] 111 | total_videos = len(futures) 112 | curr_idx = 0 113 | for future in concurrent.futures.as_completed(futures): 114 | video_id, label_id, frame_num = future.result() 115 | if frame_num == 0: 116 | print("Something wrong: {}".format(video_id)) 117 | else: 118 | print("{} 1 {} {}".format(os.path.join(train_img_folder, video_id), frame_num, label_id), file=f, flush=True) 119 | print("{}/{}".format(curr_idx, total_videos), flush=True) 120 | curr_idx += 1 121 | print("Completed") 122 | 123 | 124 | def create_val_video(): 125 | with open(val_file_list, 'w') as f, concurrent.futures.ProcessPoolExecutor(max_workers=36) as executor: 126 | futures = [executor.submit(video_to_images, video, os.path.join(video_folder, 'val'), val_img_folder) 127 | for video in val_videos] 128 | total_videos = len(futures) 129 | curr_idx = 0 130 | for future in concurrent.futures.as_completed(futures): 131 | video_id, label_id, frame_num = future.result() 132 | if frame_num == 0: 133 | print("Something wrong: {}".format(video_id)) 134 | else: 135 | print("{} 1 {} {}".format(os.path.join(val_img_folder, video_id), frame_num, label_id), file=f, flush=True) 136 | print("{}/{}".format(curr_idx, total_videos)) 137 | curr_idx += 1 138 | print("Completed") 139 | 140 | 141 | create_train_video() 142 | create_val_video() 143 | -------------------------------------------------------------------------------- /tools/extract_videos_st2st_v2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import json 5 | import skvideo.io 6 | import concurrent.futures 7 | import subprocess 8 | 9 | folder_root = "" 10 | 11 | if folder_root == "": 12 | raise ValueError("Please set folder_root") 13 | 14 | # input 15 | label_file = "{}/something-something-v2-labels.json".format(folder_root) 16 | train_file = "{}/something-something-v2-train.json".format(folder_root) 17 | val_file = "{}/something-something-v2-validation.json".format(folder_root) 18 | test_file = "{}/something-something-v2-test.json".format(folder_root) 19 | video_folder = "{}/20bn-something-something-v2".format(folder_root) 20 | 21 | # output 22 | train_img_folder = "{}/train".format(folder_root) 23 | val_img_folder = "{}/val".format(folder_root) 24 | test_img_folder = "{}/test".format(folder_root) 25 | train_file_list = "{}/train.txt".format(folder_root) 26 | val_file_list = "{}/val.txt".format(folder_root) 27 | test_file_list = "{}/test.txt".format(folder_root) 28 | 29 | def load_categories(file_path): 30 | id_to_label = {} 31 | label_to_id = {} 32 | with open(file_path) as f: 33 | labels = json.load(f) 34 | for label, cls_id in labels.items(): 35 | label = label 36 | id_to_label[int(cls_id)] = label 37 | label_to_id[label] = int(cls_id) 38 | return id_to_label, label_to_id 39 | 40 | 41 | id_to_label, label_to_id = load_categories(label_file) 42 | 43 | 44 | def load_video_list(file_path): 45 | videos = [] 46 | with open(file_path) as f: 47 | file_list = json.load(f) 48 | for temp in file_list: 49 | videos.append([temp['id'], temp['template'].replace( 50 | "[", "").replace("]", ""), temp['label'], temp['placeholders']]) 51 | return videos 52 | 53 | 54 | def load_test_video_list(file_path): 55 | 
videos = [] 56 | with open(file_path) as f: 57 | file_list = json.load(f) 58 | for temp in file_list: 59 | videos.append([temp['id']]) 60 | return videos 61 | 62 | 63 | train_videos = load_video_list(train_file) 64 | val_videos = load_video_list(val_file) 65 | test_videos = load_test_video_list(test_file) 66 | 67 | 68 | def resize_to_short_side(h, w, short_side=360): 69 | newh, neww = h, w 70 | if h < w: 71 | newh = short_side 72 | neww = (w / h) * newh 73 | else: 74 | neww = short_side 75 | newh = (h / w) * neww 76 | neww = int(neww + 0.5) 77 | newh = int(newh + 0.5) 78 | return newh, neww 79 | 80 | 81 | def video_to_images(video, basedir, targetdir, short_side=256): 82 | try: 83 | cls_id = label_to_id[video[1]] 84 | except: 85 | cls_id = -1 86 | filename = os.path.join(basedir, video[0] + ".webm") 87 | output_foldername = os.path.join(targetdir, video[0]) 88 | if not os.path.exists(filename): 89 | print("{} is not existed.".format(filename)) 90 | return video[0], cls_id, 0 91 | else: 92 | try: 93 | video_meta = skvideo.io.ffprobe(filename) 94 | height = int(video_meta['video']['@height']) 95 | width = int(video_meta['video']['@width']) 96 | except: 97 | print("Can not get video info: {}".format(filename)) 98 | return video[0], cls_id, 0 99 | 100 | if width > height: 101 | scale = "scale=-1:{}".format(short_side) 102 | else: 103 | scale = "scale={}:-1".format(short_side) 104 | if not os.path.exists(output_foldername): 105 | os.makedirs(output_foldername) 106 | 107 | command = ['ffmpeg', 108 | '-i', '"%s"' % filename, 109 | '-vf', scale, 110 | '-threads', '1', 111 | '-loglevel', 'panic', '-qmin', '1', '-qmax', '1', 112 | '-q:v', '0', 113 | '{}/'.format(output_foldername) + '"%05d.jpg"'] 114 | command = ' '.join(command) 115 | try: 116 | subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT) 117 | except: 118 | print("fail to convert {}".format(filename)) 119 | return video[0], cls_id, 0 120 | 121 | # get frame num 122 | i = 0 123 | while True: 124 | img_name = os.path.join(output_foldername + "/{:05d}.jpg".format(i + 1)) 125 | if os.path.exists(img_name): 126 | i += 1 127 | else: 128 | break 129 | 130 | frame_num = i 131 | print("Finish {}, id: {} frames: {}".format(filename, cls_id, frame_num)) 132 | return video[0], cls_id, frame_num 133 | 134 | 135 | def create_train_video(short_side): 136 | with open(train_file_list, 'w') as f, concurrent.futures.ProcessPoolExecutor(max_workers=36) as executor: 137 | futures = [executor.submit(video_to_images, video, video_folder, train_img_folder, int(short_side)) 138 | for video in train_videos] 139 | total_videos = len(futures) 140 | curr_idx = 0 141 | for future in concurrent.futures.as_completed(futures): 142 | video_id, label_id, frame_num = future.result() 143 | if frame_num == 0: 144 | print("Something wrong: {}".format(video_id)) 145 | else: 146 | print("{} 1 {} {}".format(os.path.join(train_img_folder, video_id), frame_num, label_id), file=f, flush=True) 147 | print("{}/{}".format(curr_idx, total_videos), flush=True) 148 | curr_idx += 1 149 | print("Completed") 150 | 151 | 152 | def create_val_video(short_side): 153 | with open(val_file_list, 'w') as f, concurrent.futures.ProcessPoolExecutor(max_workers=36) as executor: 154 | futures = [executor.submit(video_to_images, video, video_folder, val_img_folder, int(short_side)) 155 | for video in val_videos] 156 | total_videos = len(futures) 157 | curr_idx = 0 158 | for future in concurrent.futures.as_completed(futures): 159 | video_id, label_id, frame_num = future.result() 160 | if 
frame_num == 0: 161 | print("Something wrong: {}".format(video_id)) 162 | else: 163 | print("{} 1 {} {}".format(os.path.join(val_img_folder, video_id), frame_num, label_id), file=f, flush=True) 164 | print("{}/{}".format(curr_idx, total_videos)) 165 | curr_idx += 1 166 | print("Completed") 167 | 168 | 169 | def create_test_video(short_side): 170 | with open(test_file_list, 'w') as f, concurrent.futures.ProcessPoolExecutor(max_workers=36) as executor: 171 | futures = [executor.submit(video_to_images, video, video_folder, test_img_folder, int(short_side)) 172 | for video in test_videos] 173 | total_videos = len(futures) 174 | curr_idx = 0 175 | for future in concurrent.futures.as_completed(futures): 176 | video_id, label_id, frame_num = future.result() 177 | if frame_num == 0: 178 | print("Something wrong: {}".format(video_id)) 179 | else: 180 | print("{} 1 {}".format(os.path.join(test_img_folder, video_id), frame_num), file=f, flush=True) 181 | print("{}/{}".format(curr_idx, total_videos)) 182 | curr_idx += 1 183 | print("Completed") 184 | 185 | 186 | create_train_video(256) 187 | create_val_video(256) 188 | create_test_video(256) --------------------------------------------------------------------------------
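A minimal usage sketch for the conversion tool above (paths are placeholders, not part of the repo): edit `folder_root` at the top of `tools/extract_videos_st2st_v2.py` so it points at the directory containing the `something-something-v2-*.json` annotation files and the `20bn-something-something-v2` videos, then run:

`python tools/extract_videos_st2st_v2.py`

The script writes the extracted frame folders (`train/`, `val/`, `test/`) and the `train.txt`, `val.txt` and `test.txt` lists back under the same `folder_root`.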