├── x_temporal ├── utils │ ├── __init__.py │ ├── calculate_map.py │ ├── optimizer_helper.py │ ├── log_helper.py │ ├── metrics.py │ ├── dataset_helper.py │ ├── multiprocessing.py │ ├── model_helper.py │ ├── utils.py │ ├── dist_helper.py │ └── lr_helper.py ├── cuda_shift │ ├── __init__.py │ ├── src │ │ ├── shift_cuda.h │ │ ├── cuda │ │ │ ├── shift_kernel_cuda.h │ │ │ └── shift_kernel_cuda.cu │ │ ├── shift_cuda.c │ │ └── shift_cuda.cpp │ └── rtc_wrap.py ├── models │ ├── __init__.py │ ├── resnet3D.py │ ├── slowfast.py │ ├── stresnet.py │ └── resnet.py ├── core │ ├── __init__.py │ ├── calculate_map.py │ ├── basic_ops.py │ ├── utils.py │ ├── tin.py │ ├── models_entry.py │ ├── non_local.py │ ├── tsm.py │ ├── dataset.py │ ├── models.py │ └── transforms.py ├── __init__.py ├── interface │ ├── __init__.py │ └── temporal_helper.py ├── test.py └── train.py ├── requirements.txt ├── easy_setup.sh ├── .gitignore ├── tools ├── vid2img_sthv2.sh ├── gen_label_video_source.py ├── vid2img_sthv2.py ├── gen_label_sthv2.py ├── vid2img_kinetics.py ├── gen_label_sthv1.py └── gen_label_kinetics.py ├── experiments ├── r3d │ ├── test.sh │ ├── train.sh │ └── default.yaml ├── tin │ ├── test.sh │ ├── train.sh │ └── default.yaml ├── tsm │ ├── test.sh │ ├── train.sh │ └── default.yaml ├── tsn │ ├── test.sh │ ├── train.sh │ └── default.yaml ├── r2plus1d │ ├── test.sh │ ├── train.sh │ └── default.yaml └── slowfast │ ├── test.sh │ ├── train.sh │ └── default.yaml ├── LICENSE ├── ModelZoo.md ├── Configuration.md ├── setup.py └── ReadMe.md /x_temporal/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /x_temporal/cuda_shift/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /x_temporal/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | decord 2 | tqdm 3 | tensorboardX 4 | scikit-learn 5 | -------------------------------------------------------------------------------- /x_temporal/core/__init__.py: -------------------------------------------------------------------------------- 1 | from x_temporal.core.basic_ops import * 2 | -------------------------------------------------------------------------------- /easy_setup.sh: -------------------------------------------------------------------------------- 1 | python -m pip install --user -r requirements.txt 2 | python setup.py develop --user 3 | -------------------------------------------------------------------------------- /x_temporal/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | 3 | 4 | from .interface import TemporalHelper 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.tar 2 | *.bak 3 | *.pyc 4 | *.npy 5 | *.o 6 | *.so 7 | 8 | build/ 9 | _ext/ 10 | checkpoint/ 11 | __pycache__/ 12 | -------------------------------------------------------------------------------- /tools/vid2img_sthv2.sh: -------------------------------------------------------------------------------- 1 | 
LOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $1 -n1 --gres=gpu:0 --ntasks-per-node=1 --cpus-per-task=36 \ 2 | python -u vid2img_sthv2.py 3 | -------------------------------------------------------------------------------- /x_temporal/interface/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .sci import Metric, SpringCommonInterface 3 | except ImportError: 4 | pass 5 | 6 | from .temporal_helper import TemporalHelper 7 | 8 | -------------------------------------------------------------------------------- /experiments/r3d/test.sh: -------------------------------------------------------------------------------- 1 | T=`date +%m%d%H%M` 2 | ROOT=../.. 3 | cfg=default.yaml 4 | 5 | export PYTHONPATH=$ROOT:$PYTHONPATH 6 | 7 | python $ROOT/x_temporal/test.py --config $cfg | tee log.test.$T 8 | -------------------------------------------------------------------------------- /experiments/r3d/train.sh: -------------------------------------------------------------------------------- 1 | T=`date +%m%d%H%M` 2 | ROOT=../.. 3 | cfg=default.yaml 4 | 5 | export PYTHONPATH=$ROOT:$PYTHONPATH 6 | 7 | python $ROOT/x_temporal/train.py --config $cfg | tee log.train.$T 8 | -------------------------------------------------------------------------------- /experiments/tin/test.sh: -------------------------------------------------------------------------------- 1 | T=`date +%m%d%H%M` 2 | ROOT=../.. 3 | cfg=default.yaml 4 | 5 | export PYTHONPATH=$ROOT:$PYTHONPATH 6 | 7 | python $ROOT/x_temporal/test.py --config $cfg | tee log.test.$T 8 | -------------------------------------------------------------------------------- /experiments/tin/train.sh: -------------------------------------------------------------------------------- 1 | T=`date +%m%d%H%M` 2 | ROOT=../.. 3 | cfg=default.yaml 4 | 5 | export PYTHONPATH=$ROOT:$PYTHONPATH 6 | 7 | python $ROOT/x_temporal/train.py --config $cfg | tee log.train.$T 8 | -------------------------------------------------------------------------------- /experiments/tsm/test.sh: -------------------------------------------------------------------------------- 1 | T=`date +%m%d%H%M` 2 | ROOT=../.. 3 | cfg=default.yaml 4 | 5 | export PYTHONPATH=$ROOT:$PYTHONPATH 6 | 7 | python $ROOT/x_temporal/test.py --config $cfg | tee log.test.$T 8 | -------------------------------------------------------------------------------- /experiments/tsm/train.sh: -------------------------------------------------------------------------------- 1 | T=`date +%m%d%H%M` 2 | ROOT=../.. 3 | cfg=default.yaml 4 | 5 | export PYTHONPATH=$ROOT:$PYTHONPATH 6 | 7 | python $ROOT/x_temporal/train.py --config $cfg | tee log.train.$T 8 | -------------------------------------------------------------------------------- /experiments/tsn/test.sh: -------------------------------------------------------------------------------- 1 | T=`date +%m%d%H%M` 2 | ROOT=../.. 3 | cfg=default.yaml 4 | 5 | export PYTHONPATH=$ROOT:$PYTHONPATH 6 | 7 | python $ROOT/x_temporal/test.py --config $cfg | tee log.test.$T 8 | -------------------------------------------------------------------------------- /experiments/tsn/train.sh: -------------------------------------------------------------------------------- 1 | T=`date +%m%d%H%M` 2 | ROOT=../.. 
3 | cfg=default.yaml 4 | 5 | export PYTHONPATH=$ROOT:$PYTHONPATH 6 | 7 | python $ROOT/x_temporal/train.py --config $cfg | tee log.train.$T 8 | -------------------------------------------------------------------------------- /experiments/r2plus1d/test.sh: -------------------------------------------------------------------------------- 1 | T=`date +%m%d%H%M` 2 | ROOT=../.. 3 | cfg=default.yaml 4 | 5 | export PYTHONPATH=$ROOT:$PYTHONPATH 6 | 7 | python $ROOT/x_temporal/test.py --config $cfg | tee log.test.$T 8 | -------------------------------------------------------------------------------- /experiments/r2plus1d/train.sh: -------------------------------------------------------------------------------- 1 | T=`date +%m%d%H%M` 2 | ROOT=../.. 3 | cfg=default.yaml 4 | 5 | export PYTHONPATH=$ROOT:$PYTHONPATH 6 | 7 | python $ROOT/x_temporal/train.py --config $cfg | tee log.train.$T 8 | -------------------------------------------------------------------------------- /experiments/slowfast/test.sh: -------------------------------------------------------------------------------- 1 | T=`date +%m%d%H%M` 2 | ROOT=../.. 3 | cfg=default.yaml 4 | 5 | export PYTHONPATH=$ROOT:$PYTHONPATH 6 | 7 | python $ROOT/x_temporal/test.py --config $cfg | tee log.test.$T 8 | -------------------------------------------------------------------------------- /experiments/slowfast/train.sh: -------------------------------------------------------------------------------- 1 | T=`date +%m%d%H%M` 2 | ROOT=../.. 3 | cfg=default.yaml 4 | 5 | export PYTHONPATH=$ROOT:$PYTHONPATH 6 | 7 | python $ROOT/x_temporal/train.py --config $cfg | tee log.train.$T 8 | -------------------------------------------------------------------------------- /x_temporal/cuda_shift/src/shift_cuda.h: -------------------------------------------------------------------------------- 1 | at::Tensor shift_featuremap_cuda_forward(at::Tensor &data, 2 | at::Tensor &shift, at::Tensor &out); 3 | 4 | at::Tensor shift_featuremap_cuda_backward(at::Tensor &grad_output, 5 | at::Tensor &shift, at::Tensor &grad_input); 6 | -------------------------------------------------------------------------------- /x_temporal/utils/calculate_map.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import average_precision_score 3 | 4 | def calculate_mAP(y_pred, y_true): 5 | y_pred = y_pred.detach().cpu().numpy() 6 | y_true = y_true.detach().cpu().numpy() 7 | values = [] 8 | for i in range(len(y_pred)): 9 | values.append(average_precision_score(y_true[i], y_pred[i], average='macro')) 10 | return np.mean(values) * 100 11 | -------------------------------------------------------------------------------- /x_temporal/core/calculate_map.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import average_precision_score 3 | 4 | 5 | def calculate_mAP(y_pred, y_true): 6 | y_pred = y_pred.detach().cpu().numpy() 7 | y_true = y_true.detach().cpu().numpy() 8 | values = [] 9 | for i in range(len(y_pred)): 10 | values.append( 11 | average_precision_score( 12 | y_true[i], 13 | y_pred[i], 14 | average='macro')) 15 | return np.mean(values) * 100 16 | -------------------------------------------------------------------------------- /x_temporal/utils/optimizer_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | 4 | 5 | def build_cls_instance(module, cfg): 6 | """Build instance 
for given cls""" 7 | cls = getattr(module, cfg['type']) 8 | return cls(**cfg['kwargs']) 9 | 10 | 11 | def build_optimizer(cfg_optim, model): 12 | cfg_optim = copy.deepcopy(cfg_optim) 13 | trainable_params = [p for p in model.parameters() if p.requires_grad] 14 | cfg_optim['kwargs']['params'] = trainable_params 15 | optim_type = cfg_optim['type'] 16 | optimizer = build_cls_instance(torch.optim, cfg_optim) 17 | return optimizer 18 | -------------------------------------------------------------------------------- /x_temporal/cuda_shift/rtc_wrap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | from . import _C as backend 5 | 6 | class ShiftFeatureFunc(Function): 7 | def __init__(self): 8 | super(ShiftFeatureFunc, self).__init__() 9 | 10 | def forward(self, data, shift): 11 | if not data.is_cuda or not shift.is_cuda: 12 | raise NotImplementedError 13 | 14 | if data.requires_grad: 15 | self.save_for_backward(shift) 16 | 17 | out = torch.zeros_like(data) 18 | backend.shift_featuremap_cuda_forward(data, shift, out) 19 | return out 20 | 21 | def backward(self, grad_output): 22 | if not grad_output.is_cuda: 23 | raise NotImplementedError 24 | shift = self.saved_tensors[0] 25 | data_grad_input = grad_output.new(*grad_output.size()).zero_() 26 | shift_grad_input = shift.new(*shift.size()).zero_() 27 | backend.shift_featuremap_cuda_backward(grad_output, shift, data_grad_input) 28 | return data_grad_input, shift_grad_input 29 | -------------------------------------------------------------------------------- /x_temporal/cuda_shift/src/cuda/shift_kernel_cuda.h: -------------------------------------------------------------------------------- 1 | #ifndef Shift_FeatureMap_CUDA 2 | #define Shift_FeatureMap_CUDA 3 | 4 | #ifdef __cplusplus 5 | extern "C" 6 | { 7 | #endif 8 | 9 | void ShiftDataCudaForward(cudaStream_t stream, 10 | float* data, 11 | int* shift, 12 | const int batch_size, 13 | const int channels, 14 | const int tsize, 15 | const int hwsize, 16 | const int groupsize, 17 | float* out); 18 | 19 | void ShiftDataCudaBackward(cudaStream_t stream, 20 | float* grad_output, 21 | int* shift, 22 | const int batch_size, 23 | const int channels, 24 | const int tsize, 25 | const int hwsize, 26 | const int groupsize, 27 | float* grad_input); 28 | #ifdef __cplusplus 29 | } 30 | #endif 31 | 32 | #endif 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Yu Liu (Louis) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /experiments/slowfast/default.yaml: -------------------------------------------------------------------------------- 1 | version: 1.0 2 | config: 3 | gpus: 8 4 | seed: 2020 5 | dataset: 6 | workers: 3 7 | num_class: 600 8 | num_segments: 32 9 | batch_size: 32 10 | img_prefix: 'image_{:05d}.jpg' 11 | video_source: False 12 | dense_sample: True 13 | modality: RGB 14 | flow_prefix: '' 15 | root_dir: /path 16 | flip: True 17 | dense_sample_rate: 2 18 | input_mean: [0.485, 0.456, 0.406] 19 | input_std: [0.229, 0.224 ,0.225] 20 | crop_size: 112 21 | scale_size: 128 22 | train: 23 | meta_file: /path 24 | val: 25 | meta_file: /path 26 | test: 27 | meta_file: /path 28 | 29 | net: 30 | arch: sfresnet50 31 | model_type: 3D 32 | dropout: 0.5 33 | 34 | evaluate: 35 | spatial_crops: 1 36 | temporal_samples: 1 37 | 38 | trainer: 39 | print_freq: 20 40 | eval_freq: 5 41 | epochs: 120 42 | start_epoch: 0 43 | loss_type: nll 44 | clip_gradient: 20 45 | lr_scheduler: 46 | warmup_epochs: 10 47 | type: CosineAnnealingLR 48 | kwargs: 49 | T_max: 120 50 | optimizer: 51 | type: SGD 52 | kwargs: 53 | lr: 0.4 54 | momentum: 0.9 55 | weight_decay: 0.0005 56 | nesterov: True 57 | 58 | 59 | saver: 60 | save_dir: 'checkpoint/' 61 | #pretrain_model: '/path' 62 | #resume_model: '/path' 63 | -------------------------------------------------------------------------------- /experiments/r3d/default.yaml: -------------------------------------------------------------------------------- 1 | version: 1.0 2 | config: 3 | gpus: 8 4 | seed: 2020 5 | dataset: 6 | workers: 3 7 | num_class: 102 8 | num_segments: 16 9 | batch_size: 32 10 | img_prefix: 'image_{:05d}.jpg' 11 | video_source: False 12 | dense_sample: True 13 | modality: RGB 14 | flow_prefix: '' 15 | root_dir: /path 16 | flip: True 17 | dense_sample_rate: 1 18 | input_mean: [0.485, 0.456, 0.406] 19 | input_std: [0.229, 0.224 ,0.225] 20 | crop_size: 112 21 | scale_size: 128 22 | train: 23 | meta_file: /path 24 | val: 25 | meta_file: /path 26 | test: 27 | meta_file: /path 28 | 29 | net: 30 | arch: resnet3D18 31 | model_type: 3D 32 | dropout: 0.0 33 | 34 | evaluate: 35 | spatial_crops: 1 36 | temporal_samples: 1 37 | 38 | trainer: 39 | print_freq: 20 40 | eval_freq: 5 41 | epochs: 120 42 | start_epoch: 0 43 | loss_type: nll 44 | clip_gradient: 20 45 | lr_scheduler: 46 | warmup_epochs: 20 47 | warmup_type: exp 48 | type: MultiStepLR 49 | kwargs: 50 | milestones: [50, 80, 100] 51 | gamma: 0.15 52 | optimizer: 53 | type: SGD 54 | kwargs: 55 | lr: 0.2 56 | momentum: 0.9 57 | weight_decay: 0.01 58 | nesterov: True 59 | 60 | 61 | saver: 62 | save_dir: 'checkpoint/' 63 | #pretrain_model: '/path' 64 | -------------------------------------------------------------------------------- /experiments/r2plus1d/default.yaml: -------------------------------------------------------------------------------- 1 | version: 1.0 2 | config: 3 | gpus: 8 4 | seed: 2020 5 | dataset: 6 | workers: 3 7 | num_class: 102 8 | num_segments: 16 9 | batch_size: 32 10 | img_prefix: 'image_{:05d}.jpg' 11 | video_source: False 12 | dense_sample: True 13 | modality: RGB 14 | flow_prefix: '' 15 | root_dir: /path 16 | flip: True 17 | dense_sample_rate: 2 18 | 
input_mean: [0.485, 0.456, 0.406] 19 | input_std: [0.229, 0.224 ,0.225] 20 | crop_size: 112 21 | scale_size: 128 22 | train: 23 | meta_file: /path 24 | val: 25 | meta_file: /path 26 | test: 27 | meta_file: /path 28 | 29 | net: 30 | arch: stresnet18 31 | model_type: 3D 32 | dropout: 0.0 33 | max_pooling: True 34 | 35 | evaluate: 36 | spatial_crops: 1 37 | temporal_samples: 1 38 | 39 | trainer: 40 | print_freq: 20 41 | eval_freq: 5 42 | epochs: 120 43 | start_epoch: 0 44 | loss_type: nll 45 | clip_gradient: 20 46 | lr_scheduler: 47 | warmup_epochs: 10 48 | type: CosineAnnealingLR 49 | kwargs: 50 | T_max: 120 51 | optimizer: 52 | type: SGD 53 | kwargs: 54 | lr: 0.4 55 | momentum: 0.9 56 | weight_decay: 0.0005 57 | nesterov: True 58 | 59 | 60 | saver: 61 | save_dir: 'checkpoint/' 62 | #pretrain_model: '/path' 63 | #resume_model: '/path' 64 | -------------------------------------------------------------------------------- /tools/gen_label_video_source.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from decord import VideoReader 4 | from decord import cpu 5 | 6 | if __name__ == '__main__': 7 | root_dir = '' # video data root path 8 | dataset_name = 'hmdb51' 9 | with open(os.path.join('../datasets', dataset_name, 'category.txt')) as f: 10 | lines = f.readlines() 11 | categories = [] 12 | for line in lines: 13 | line = line.rstrip() 14 | categories.append(line) 15 | 16 | 17 | dict_categories = {} 18 | for i, category in enumerate(categories): 19 | dict_categories[category] = i 20 | 21 | filename_input = os.path.join('../datasets', dataset_name, 'vallist.txt') 22 | filename_output = 'test_videofolder.txt' 23 | with open(filename_input) as f: 24 | lines = f.readlines() 25 | videos = [] 26 | idx_categories = [] 27 | for line in lines: 28 | line = line.rstrip() 29 | videos.append(line) 30 | label = line.split('/')[0] 31 | idx_categories.append(dict_categories[label]) 32 | output = [] 33 | for i in range(len(videos)): 34 | curVideo = videos[i] 35 | curIDX = idx_categories[i] 36 | video_file = os.path.join(root_dir, curVideo) 37 | vr = VideoReader(os.path.join(root_dir, curVideo), ctx=cpu(0)) 38 | output.append('%s %d %d' % (curVideo, len(vr), curIDX)) 39 | print('%d/%d' % (i, len(vr))) 40 | with open(filename_output, 'w') as f: 41 | f.write('\n'.join(output)) 42 | -------------------------------------------------------------------------------- /experiments/tin/default.yaml: -------------------------------------------------------------------------------- 1 | version: 1.0 2 | config: 3 | gpus: 8 4 | seed: 2020 5 | dataset: 6 | workers: 3 7 | num_class: 174 8 | num_segments: 8 9 | batch_size: 8 10 | img_prefix: '{:05d}.jpg' 11 | video_source: False 12 | dense_sample: False 13 | modality: RGB 14 | flow_prefix: '' 15 | root_dir: /path 16 | flip: False 17 | input_mean: [0.485, 0.456, 0.406] 18 | input_std: [0.229, 0.224 ,0.225] 19 | crop_size: 224 20 | scale_size: 256 21 | train: 22 | meta_file: /path 23 | val: 24 | meta_file: /path 25 | test: 26 | meta_file: /path 27 | 28 | net: 29 | arch: resnet50 30 | model_type: 2D 31 | tin: True 32 | shift_div: 4 33 | consensus_type: avg 34 | dropout: 0.8 35 | img_feature_dim: 256 36 | pretrain: True # imagenet pretrain for 2D network 37 | 38 | 39 | trainer: 40 | print_freq: 20 41 | eval_freq: 1 42 | epochs: 35 43 | start_epoch: 0 44 | loss_type: nll 45 | no_partial_bn: True 46 | clip_gradient: 20 47 | lr_scheduler: 48 | warmup_epochs: 1 49 | warmup_type: linear 50 | type: CosineAnnealingLR 51 | kwargs: 52 | 
T_max: 30 53 | optimizer: 54 | type: SGD 55 | kwargs: 56 | lr: 0.02 57 | momentum: 0.9 58 | weight_decay: 0.0005 59 | nesterov: True 60 | 61 | 62 | saver: 63 | save_dir: 'checkpoint/' 64 | #pretrain_model: '/path' 65 | #resume_model: '/path' 66 | -------------------------------------------------------------------------------- /x_temporal/core/basic_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Identity(torch.nn.Module): 5 | def forward(self, input): 6 | return input 7 | 8 | 9 | class SegmentConsensus(torch.autograd.Function): 10 | 11 | def __init__(self, consensus_type, dim=1): 12 | self.consensus_type = consensus_type 13 | self.dim = dim 14 | self.shape = None 15 | 16 | def forward(self, input_tensor): 17 | self.shape = input_tensor.size() 18 | if self.consensus_type == 'avg': 19 | output = input_tensor.mean(dim=self.dim, keepdim=True) 20 | elif self.consensus_type == 'identity': 21 | output = input_tensor 22 | else: 23 | output = None 24 | 25 | return output 26 | 27 | def backward(self, grad_output): 28 | if self.consensus_type == 'avg': 29 | grad_in = grad_output.expand( 30 | self.shape) / float(self.shape[self.dim]) 31 | elif self.consensus_type == 'identity': 32 | grad_in = grad_output 33 | else: 34 | grad_in = None 35 | 36 | return grad_in 37 | 38 | 39 | class ConsensusModule(torch.nn.Module): 40 | 41 | def __init__(self, consensus_type, dim=1): 42 | super(ConsensusModule, self).__init__() 43 | self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity' 44 | self.dim = dim 45 | 46 | def forward(self, input): 47 | return SegmentConsensus(self.consensus_type, self.dim)(input) 48 | -------------------------------------------------------------------------------- /experiments/tsm/default.yaml: -------------------------------------------------------------------------------- 1 | version: 1.0 2 | config: 3 | gpus: 8 4 | seed: 2020 5 | dataset: 6 | workers: 3 7 | num_class: 102 8 | num_segments: 8 9 | batch_size: 8 10 | img_prefix: 'image_{:05d}.jpg' 11 | video_source: False 12 | dense_sample: False 13 | modality: RGB 14 | flow_prefix: '' 15 | root_dir: /path 16 | flip: True 17 | input_mean: [0.485, 0.456, 0.406] 18 | input_std: [0.229, 0.224 ,0.225] 19 | crop_size: 224 20 | scale_size: 256 21 | train: 22 | meta_file: /path 23 | val: 24 | meta_file: /path 25 | test: 26 | meta_file: /path 27 | 28 | net: 29 | arch: resnet50 30 | model_type: 2D 31 | shift: True 32 | shift_div: 8 33 | tin: False 34 | consensus_type: avg 35 | dropout: 0.8 36 | img_feature_dim: 256 37 | non_local: False 38 | pretrain: True # imagenet pretrain for 2D network 39 | 40 | 41 | trainer: 42 | print_freq: 20 43 | eval_freq: 1 44 | epochs: 30 45 | start_epoch: 0 46 | loss_type: nll 47 | no_partial_bn: True 48 | clip_gradient: 20 49 | lr_scheduler: 50 | warmup_epochs: 5 51 | type: CosineAnnealingLR 52 | kwargs: 53 | T_max: 30 54 | optimizer: 55 | type: SGD 56 | kwargs: 57 | lr: 0.01 58 | momentum: 0.9 59 | weight_decay: 0.0005 60 | nesterov: True 61 | 62 | 63 | saver: 64 | save_dir: 'checkpoint/' 65 | #pretrain_path: '/path' 66 | #resume_path: '/path' 67 | -------------------------------------------------------------------------------- /experiments/tsn/default.yaml: -------------------------------------------------------------------------------- 1 | version: 1.0 2 | config: 3 | gpus: 8 4 | seed: 2020 5 | dataset: 6 | workers: 3 7 | num_class: 102 8 | num_segments: 8 9 | batch_size: 8 10 | img_prefix: 'image_{:05d}.jpg' 11 | 
video_source: False 12 | dense_sample: False 13 | modality: RGB 14 | flow_prefix: '' 15 | root_dir: /path 16 | flip: True 17 | input_mean: [0.485, 0.456, 0.406] 18 | input_std: [0.229, 0.224 ,0.225] 19 | crop_size: 224 20 | scale_size: 256 21 | train: 22 | meta_file: /path 23 | val: 24 | meta_file: /path 25 | test: 26 | meta_file: /path 27 | 28 | net: 29 | arch: resnet50 30 | model_type: 2D 31 | consensus_type: avg 32 | dropout: 0.8 33 | img_feature_dim: 256 34 | non_local: False 35 | pretrain: True # imagenet pretrain for 2D network 36 | 37 | evaluate: 38 | spatial_crops: 1 39 | temporal_samples: 1 40 | 41 | trainer: 42 | print_freq: 20 43 | eval_freq: 1 44 | epochs: 30 45 | start_epoch: 0 46 | loss_type: nll 47 | no_partial_bn: True 48 | clip_gradient: 20 49 | lr_scheduler: 50 | warmup_epochs: 5 51 | type: CosineAnnealingLR 52 | kwargs: 53 | T_max: 30 54 | optimizer: 55 | type: SGD 56 | kwargs: 57 | lr: 0.01 58 | momentum: 0.9 59 | weight_decay: 0.0005 60 | nesterov: True 61 | 62 | 63 | saver: 64 | save_dir: 'checkpoint/' 65 | #pretrain_model: '/path' 66 | #resume_model: '/path' 67 | -------------------------------------------------------------------------------- /x_temporal/utils/log_helper.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | import torch 5 | from x_temporal.utils.dist_helper import is_master_proc 6 | 7 | 8 | logs = set() 9 | 10 | 11 | def init_log(name, level=logging.INFO): 12 | if (name, level) in logs: 13 | return 14 | 15 | logs.add((name, level)) 16 | logger = logging.getLogger(name) 17 | logger.setLevel(level) 18 | ch = logging.StreamHandler(stream=sys.stdout) 19 | ch.setLevel(level) 20 | 21 | logger.addFilter(lambda record: is_master_proc()) 22 | 23 | format_str = f'%(asctime)s-%(filename)s#%(lineno)d:%(message)s' 24 | formatter = logging.Formatter(format_str) 25 | ch.setFormatter(formatter) 26 | logger.addHandler(ch) 27 | 28 | def get_log_format(multi_class=False): 29 | if multi_class: 30 | return ('Epoch: [{2}/{3}]\tIter: [{0}/{1}]\t' 31 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 32 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 33 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 34 | 'mAP {mAP.val:.3f} ({mAP.avg:.3f})\t' 35 | 'LR {lr:.4f}') 36 | else: 37 | return ('Epoch: [{2}/{3}]\tIter: [{0}/{1}]\t' 38 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 39 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 40 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 41 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 42 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t' 43 | 'LR {lr:.4f}') 44 | -------------------------------------------------------------------------------- /x_temporal/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Metric(metaclass=abc.ABCMeta): 5 | @abc.abstractmethod 6 | def __str__(self): 7 | pass 8 | 9 | @abc.abstractmethod 10 | def __repr__(self): 11 | pass 12 | 13 | @abc.abstractmethod 14 | def __eq__(self, other): 15 | pass 16 | 17 | @abc.abstractmethod 18 | def __ne__(self, other): 19 | pass 20 | 21 | @abc.abstractmethod 22 | def __gt__(self, other): 23 | pass 24 | 25 | @abc.abstractmethod 26 | def __lt__(self, other): 27 | pass 28 | 29 | @abc.abstractmethod 30 | def __ge__(self, other): 31 | pass 32 | 33 | @abc.abstractmethod 34 | def __le__(self, other): 35 | pass 36 | 37 | 38 | class Top1Metric(Metric): 39 | def __init__(self, top1, top5, loss=None): 40 | self.top1 = top1 41 | 
self.top5 = top5 42 | self.loss = loss 43 | 44 | def __str__(self): 45 | return "Prec@1: %.5f\tPrec@5: %.5f" % (self.top1, self.top5) 46 | 47 | def __repr__(self): 48 | return "Prec@1: %.5f\tPrec@5: %.5f" % (self.top1, self.top5) 49 | 50 | def __eq__(self, other): 51 | return self.top1 == other.top1 52 | 53 | def __ne__(self, other): 54 | return self.top1 != other.top1 55 | 56 | def __gt__(self, other): 57 | return self.top1 > other.top1 58 | 59 | def __lt__(self, other): 60 | return self.top1 < other.top1 61 | 62 | def __ge__(self, other): 63 | return self.top1 >= other.top1 64 | 65 | def __le__(self, other): 66 | return self.top1 <= other.top1 67 | -------------------------------------------------------------------------------- /tools/vid2img_sthv2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import threading 3 | 4 | NUM_THREADS = 24 5 | # Downloaded webm videos 6 | VIDEO_ROOT = '' 7 | # Directory for extracted frames 8 | FRAME_ROOT = '' 9 | 10 | 11 | def split(l, n): 12 | """Yield successive n-sized chunks from l.""" 13 | for i in range(0, len(l), n): 14 | yield l[i:i + n] 15 | 16 | 17 | def extract(video, tmpl='%06d.jpg'): 18 | # os.system(f'ffmpeg -i {VIDEO_ROOT}/{video} -vf -threads 1 -vf scale=-1:256 -q:v 0 ' 19 | # f'{FRAME_ROOT}/{video[:-5]}/{tmpl}') 20 | cmd = 'ffmpeg -i \"{}/{}\" -threads 1 -vf scale=-1:256 -q:v 0 \"{}/{}/%06d.jpg\"'.format(VIDEO_ROOT, video, 21 | FRAME_ROOT, video[:-5]) 22 | print(cmd) 23 | os.system(cmd) 24 | 25 | 26 | def target(video_list): 27 | for video in video_list: 28 | if not os.path.exists(os.path.join(FRAME_ROOT, video[:-5])): 29 | os.makedirs(os.path.join(FRAME_ROOT, video[:-5])) 30 | extract(video) 31 | 32 | 33 | if __name__ == '__main__': 34 | if not os.path.exists(VIDEO_ROOT): 35 | raise ValueError('Please download videos and set VIDEO_ROOT variable.') 36 | if not os.path.exists(FRAME_ROOT): 37 | os.makedirs(FRAME_ROOT) 38 | 39 | video_list = os.listdir(VIDEO_ROOT) 40 | splits = list(split(video_list, NUM_THREADS)) 41 | 42 | threads = [] 43 | for i, split in enumerate(splits): 44 | thread = threading.Thread(target=target, args=(split,)) 45 | thread.start() 46 | threads.append(thread) 47 | 48 | for thread in threads: 49 | thread.join() 50 | -------------------------------------------------------------------------------- /x_temporal/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | from easydict import EasyDict 4 | import torch 5 | 6 | from x_temporal.interface.temporal_helper import TemporalHelper 7 | from x_temporal.utils.multiprocessing import mrun 8 | 9 | 10 | parser = argparse.ArgumentParser(description='X-Temporal') 11 | parser.add_argument('--config', type=str, help='the path of config file') 12 | parser.add_argument("--shard_id", help="The shard id of current node, Starts from 0 to num_shards - 1", 13 | default=0, type=int) 14 | parser.add_argument("--num_shards", help="Number of shards using by the job", 15 | default=1, type=int) 16 | parser.add_argument("--init_method", help="Initialization method, includes TCP or shared file-system", 17 | default="tcp://localhost:9999", type=str) 18 | parser.add_argument('--dist_backend', default='nccl', type=str) 19 | 20 | def main(): 21 | args = parser.parse_args() 22 | 23 | with open(args.config) as f: 24 | config = yaml.load(f, Loader=yaml.FullLoader) 25 | 26 | config = EasyDict(config['config']) 27 | if config.gpus > 1: 28 | torch.multiprocessing.spawn( 29 | mrun, 
30 | nprocs=config.gpus, 31 | args=(config.gpus, 32 | args.init_method, 33 | args.shard_id, 34 | args.num_shards, 35 | args.dist_backend, 36 | config, 37 | 'test', 38 | ), 39 | daemon=False) 40 | else: 41 | temporal_helper = TemporalHelper(config, inference_only=True) 42 | temporal_helper.evaluate() 43 | 44 | 45 | if __name__ == '__main__': 46 | torch.multiprocessing.set_start_method("forkserver") 47 | main() 48 | -------------------------------------------------------------------------------- /x_temporal/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | from easydict import EasyDict 4 | 5 | import torch 6 | import numpy as np 7 | 8 | from x_temporal.interface.temporal_helper import TemporalHelper 9 | from x_temporal.utils.multiprocessing import mrun 10 | 11 | 12 | parser = argparse.ArgumentParser(description='X-Temporal') 13 | parser.add_argument('--config', type=str, help='the path of config file') 14 | parser.add_argument("--shard_id", help="The shard id of current node, Starts from 0 to num_shards - 1", 15 | default=0, type=int) 16 | parser.add_argument("--num_shards", help="Number of shards using by the job", 17 | default=1, type=int) 18 | parser.add_argument("--init_method", help="Initialization method, includes TCP or shared file-system", 19 | default="tcp://localhost:9999", type=str) 20 | parser.add_argument('--dist_backend', default='nccl', type=str) 21 | 22 | def main(): 23 | args = parser.parse_args() 24 | 25 | with open(args.config) as f: 26 | config = yaml.load(f, Loader=yaml.FullLoader) 27 | 28 | 29 | config = EasyDict(config['config']) 30 | if config.gpus > 1: 31 | torch.multiprocessing.spawn( 32 | mrun, 33 | nprocs=config.gpus, 34 | args=(config.gpus, 35 | args.init_method, 36 | args.shard_id, 37 | args.num_shards, 38 | args.dist_backend, 39 | config, 40 | 'train', 41 | ), 42 | daemon=False) 43 | else: 44 | temporal_helper = TemporalHelper(config) 45 | temporal_helper.train() 46 | 47 | 48 | if __name__ == '__main__': 49 | torch.multiprocessing.set_start_method("forkserver") 50 | main() 51 | -------------------------------------------------------------------------------- /x_temporal/core/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | 5 | 6 | def softmax(scores): 7 | es = np.exp(scores - scores.max(axis=-1)[..., None]) 8 | return es / es.sum(axis=-1)[..., None] 9 | 10 | 11 | class AverageMeter(object): 12 | """Computes and stores the average and current value""" 13 | 14 | def __init__(self): 15 | self.reset() 16 | 17 | def reset(self): 18 | self.val = 0 19 | self.avg = 0 20 | self.sum = 0 21 | self.count = 0 22 | 23 | def update(self, val, n=1): 24 | self.val = val 25 | self.sum += val * n 26 | self.count += n 27 | self.avg = self.sum / self.count 28 | 29 | 30 | def accuracy(output, target, topk=(1,)): 31 | """Computes the precision@k for the specified values of k""" 32 | maxk = max(topk) 33 | batch_size = target.size(0) 34 | 35 | _, pred = output.topk(maxk, 1, True, True) 36 | pred = pred.t() 37 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 38 | 39 | res = [] 40 | for k in topk: 41 | correct_k = correct[:k].view(-1).float().sum(0) 42 | res.append(correct_k.mul_(100.0 / batch_size)) 43 | return res 44 | 45 | 46 | def save_bias(net, checkpoint_dir, epoch, iters): 47 | weight_data = {} 48 | bias_data = {} 49 | idx = 0 50 | for layer_id in range(1, 5): 51 | layer = getattr(net, 'layer' + 
str(layer_id)) 52 | blocks = list(layer.children()) 53 | for i, b in enumerate(blocks): 54 | bias_data[idx] = blocks[i].conv1.buffer[0] 55 | weight_data[idx] = blocks[i].conv1.buffer[1] 56 | idx += 1 57 | w_save_path = os.path.join( 58 | checkpoint_dir, 'data', '%d_%d_weight.npz' % 59 | (epoch, iters)) 60 | b_save_path = os.path.join( 61 | checkpoint_dir, 'data', '%d_%d_bias.npz' % 62 | (epoch, iters)) 63 | np.savez(w_save_path, weight_data) 64 | np.savez(b_save_path, bias_data) 65 | -------------------------------------------------------------------------------- /x_temporal/utils/dataset_helper.py: -------------------------------------------------------------------------------- 1 | from x_temporal.core.transforms import * 2 | from x_temporal.core.dataset import VideoDataSet 3 | from torch.utils.data.distributed import DistributedSampler 4 | from torch.utils.data.sampler import RandomSampler 5 | 6 | 7 | def get_val_crop_transform(config, spatial_crops): 8 | crop_size = config.crop_size 9 | scale_size = config.scale_size 10 | if spatial_crops == 1: 11 | crop_aug = GroupCenterCrop(crop_size) 12 | elif spatial_crops == 3: 13 | crop_aug = GroupFullResSample(crop_size, scale_size, flip=False) 14 | elif spatial_crops == 5: 15 | crop_aug = GroupOverSample(crop_size, scale_size, flip=False) 16 | else: 17 | crop_aug = MultiGroupRandomCrop(crop_size, spatial_crops) 18 | return crop_aug 19 | 20 | 21 | def get_dataset(config, data_type, test_mode, transform, data_length, temporal_samples=1): 22 | dataset = VideoDataSet(config.root_dir, config[data_type].meta_file, 23 | num_segments=config.num_segments, 24 | new_length=data_length, 25 | modality=config.modality, 26 | image_tmpl=config.img_prefix, 27 | test_mode=False, 28 | random_shift=not test_mode, 29 | transform=transform, 30 | dense_sample=config.dense_sample, 31 | dense_sample_rate=config.get('dense_sample_rate', 1), 32 | video_source=config.video_source, 33 | temporal_samples=temporal_samples, 34 | multi_class=config.get('multi_class', False), 35 | ) 36 | return dataset 37 | 38 | 39 | def shuffle_dataset(loader, cur_epoch): 40 | assert isinstance( 41 | loader.sampler, (RandomSampler, DistributedSampler) 42 | ), "Sampler type '{}' not supported".format(type(loader.sampler)) 43 | # RandomSampler handles shuffling automatically 44 | if isinstance(loader.sampler, DistributedSampler): 45 | # DistributedSampler shuffles data based on epoch 46 | loader.sampler.set_epoch(cur_epoch) 47 | -------------------------------------------------------------------------------- /x_temporal/utils/multiprocessing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from x_temporal.interface.temporal_helper import TemporalHelper 4 | 5 | def mrun( 6 | local_rank, num_proc, init_method, shard_id, num_shards, backend, config, mode 7 | ): 8 | """ 9 | Runs a function from a child process. 10 | Args: 11 | local_rank (int): rank of the current process on the current machine. 12 | num_proc (int): number of processes per machine. 13 | init_method (string): method to initialize the distributed training. 14 | TCP initialization: equiring a network address reachable from all 15 | processes followed by the port. 16 | Shared file-system initialization: makes use of a file system that 17 | is shared and visible from all machines. The URL should start with 18 | file:// and contain a path to a non-existent file on a shared file 19 | system. 20 | shard_id (int): the rank of the current machine. 
21 | num_shards (int): number of overall machines for the distributed 22 | training job. 23 | backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are 24 | supports, each with different capabilities. Details can be found 25 | here: 26 | https://pytorch.org/docs/stable/distributed.html 27 | config (CfgNode): configs. 28 | """ 29 | # Initialize the process group. 30 | world_size = num_proc * num_shards 31 | rank = shard_id * num_proc + local_rank 32 | 33 | try: 34 | torch.distributed.init_process_group( 35 | backend=backend, 36 | init_method=init_method, 37 | world_size=world_size, 38 | rank=rank, 39 | ) 40 | except Exception as e: 41 | raise e 42 | 43 | torch.cuda.set_device(local_rank) 44 | if mode == 'train': 45 | temporal_helper = TemporalHelper(config) 46 | temporal_helper.train() 47 | else: 48 | temporal_helper = TemporalHelper(config, inference_only=True) 49 | temporal_helper.evaluate() 50 | -------------------------------------------------------------------------------- /x_temporal/cuda_shift/src/shift_cuda.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include "shift_cuda.h" 3 | #include "cuda/shift_kernel_cuda.h" 4 | 5 | extern THCState *state; 6 | void shift_featuremap_cuda_forward(THCudaTensor *data, THCudaIntTensor *shift, THCudaTensor *out) 7 | { 8 | THArgCheck(THCudaTensor_isContiguous(state, data), 1, "data tensor has to be contiguous"); 9 | THArgCheck(THCudaTensor_isContiguous(state, shift), 1, "shift tensor has to be contiguous"); 10 | 11 | int batch_size = THCudaTensor_size(state, data, 0); 12 | int channels = THCudaTensor_size(state, data, 2); 13 | int tsize = THCudaTensor_size(state, data, 1); 14 | int hwsize = THCudaTensor_size(state, data, 3); 15 | int groupsize = THCudaTensor_size(state, shift, 1); 16 | 17 | ShiftDataCudaForward(THCState_getCurrentStream(state), 18 | THCudaTensor_data(state, data), 19 | THCudaIntTensor_data(state, shift), 20 | batch_size, 21 | channels, 22 | tsize, 23 | hwsize, 24 | groupsize, 25 | THCudaTensor_data(state, out)); 26 | } 27 | 28 | void shift_featuremap_cuda_backward(THCudaTensor *grad_output, THCudaIntTensor *shift, THCudaTensor *grad_input) 29 | { 30 | THArgCheck(THCudaTensor_isContiguous(state, grad_output), 1, "data tensor has to be contiguous"); 31 | THArgCheck(THCudaTensor_isContiguous(state, shift), 1, "shift tensor has to be contiguous"); 32 | 33 | int batch_size = THCudaTensor_size(state, grad_output, 0); 34 | int channels = THCudaTensor_size(state, grad_output, 2); 35 | int tsize = THCudaTensor_size(state, grad_output, 1); 36 | int hwsize = THCudaTensor_size(state, grad_output, 3); 37 | int groupsize = THCudaTensor_size(state, shift, 1); 38 | 39 | ShiftDataCudaBackward(THCState_getCurrentStream(state), 40 | THCudaTensor_data(state, grad_output), 41 | THCudaIntTensor_data(state, shift), 42 | batch_size, 43 | channels, 44 | tsize, 45 | hwsize, 46 | groupsize, 47 | THCudaTensor_data(state, grad_input)); 48 | } 49 | -------------------------------------------------------------------------------- /x_temporal/cuda_shift/src/shift_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "cuda/shift_kernel_cuda.h" 4 | #include 5 | 6 | extern THCState *state; 7 | at::Tensor shift_featuremap_cuda_forward(const at::Tensor &data, const at::Tensor &shift, const at::Tensor &out) 8 | { 9 | THArgCheck(data.is_contiguous(), 1, "data tensor has to be contiguous"); 10 | THArgCheck(shift.is_contiguous(), 
1, "shift tensor has to be contiguous"); 11 | 12 | int batch_size = data.size(0); 13 | int channels = data.size(2); 14 | int tsize = data.size(1); 15 | int hwsize = data.size(3); 16 | int groupsize = shift.size(1); 17 | 18 | ShiftDataCudaForward(THCState_getCurrentStream(state), 19 | data.data(), 20 | shift.data(), 21 | batch_size, 22 | channels, 23 | tsize, 24 | hwsize, 25 | groupsize, 26 | out.data()); 27 | return out; 28 | } 29 | 30 | at::Tensor shift_featuremap_cuda_backward(const at::Tensor &grad_output, const at::Tensor &shift, const at::Tensor &grad_input) 31 | { 32 | THArgCheck(grad_output.is_contiguous(), 1, "data tensor has to be contiguous"); 33 | THArgCheck(shift.is_contiguous(), 1, "shift tensor has to be contiguous"); 34 | 35 | int batch_size = grad_output.size(0); 36 | int channels = grad_output.size(2); 37 | int tsize = grad_output.size(1); 38 | int hwsize = grad_output.size(3); 39 | int groupsize = shift.size(1); 40 | 41 | ShiftDataCudaBackward(THCState_getCurrentStream(state), 42 | grad_output.data(), 43 | shift.data(), 44 | batch_size, 45 | channels, 46 | tsize, 47 | hwsize, 48 | groupsize, 49 | grad_input.data()); 50 | return grad_input; 51 | } 52 | 53 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 54 | m.def("shift_featuremap_cuda_forward", &shift_featuremap_cuda_forward, "shift_featuremap_cuda_forward"); 55 | m.def("shift_featuremap_cuda_backward", &shift_featuremap_cuda_backward, "shift_featuremap_cuda_backward"); 56 | } 57 | 58 | -------------------------------------------------------------------------------- /tools/gen_label_sthv2.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | # ------------------------------------------------------ 6 | # Code adapted from https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py 7 | # processing the raw data of the video Something-Something-V2 8 | 9 | import os 10 | import json 11 | 12 | if __name__ == '__main__': 13 | dataset_name = 'something-something-v2' # 'jester-v1' 14 | with open('%s-labels.json' % dataset_name) as f: 15 | data = json.load(f) 16 | categories = [] 17 | for i, (cat, idx) in enumerate(data.items()): 18 | assert i == int(idx) # make sure the rank is right 19 | categories.append(cat) 20 | 21 | with open('category.txt', 'w') as f: 22 | f.write('\n'.join(categories)) 23 | 24 | dict_categories = {} 25 | for i, category in enumerate(categories): 26 | dict_categories[category] = i 27 | 28 | files_input = [ 29 | '%s-validation.json' % 30 | dataset_name, 31 | '%s-train.json' % 32 | dataset_name, 33 | '%s-test.json' % 34 | dataset_name] 35 | files_output = [ 36 | 'val_videofolder.txt', 37 | 'train_videofolder.txt', 38 | 'test_videofolder.txt'] 39 | for (filename_input, filename_output) in zip(files_input, files_output): 40 | with open(filename_input) as f: 41 | data = json.load(f) 42 | folders = [] 43 | idx_categories = [] 44 | for item in data: 45 | folders.append(item['id']) 46 | if 'test' not in filename_input: 47 | idx_categories.append( 48 | dict_categories[item['template'].replace('[', '').replace(']', '')]) 49 | else: 50 | idx_categories.append(0) 51 | output = [] 52 | for i in range(len(folders)): 53 | curFolder = folders[i] 54 | curIDX = idx_categories[i] 55 | # counting the number of frames in each video folders 56 | dir_files = os.listdir( 57 | os.path.join( 58 | 
'../something/v2/20bn-something-something-v2-frames', 59 | curFolder)) 60 | output.append('%s %d %d' % (curFolder, len(dir_files), curIDX)) 61 | print('%d/%d' % (i, len(folders))) 62 | with open(filename_output, 'w') as f: 63 | f.write('\n'.join(output)) 64 | -------------------------------------------------------------------------------- /ModelZoo.md: -------------------------------------------------------------------------------- 1 | #### TSN 2 | 3 | | Model | Dataset | Frames | Input | Top-1 | Top-1\* | mAP | mAP\* | Link | 4 | | :-------- | :----------- | :----- | -------- | :---- | ------- | :--- | ---- | ------------------------------------------------------------ | 5 | | resnet101 | MMit | 5 | 224\*224 | - | - | 58.9 | 60.7 | [model](https://drive.google.com/open?id=1fM53qYCceZEpdtnc06XmXjyMrb7Is7_a) | 6 | | resnet50 | Kinetics-600 | 8 | 224\*224 | 67.5 | 70.0 | - | - | [model](https://drive.google.com/open?id=1PWiCd15_VnBAwh3n-zzqzbGi6xVIhAeN) | 7 | 8 | 9 | 10 | 11 | #### TIN 12 | 13 | | Model | Dataset | | Input | Top-1 | Top-1\* | mAP | mAP\* | Link | 14 | | :-------- | :---------- | ---- | :------- | :---- | :---- | :--- | ---- | ------------------------------------------------------------ | 15 | | resnet50 | MMit | 8 | 224\*224 | - | - | 62.2 | 62.8 | [model](https://drive.google.com/open?id=1f1kXH0cv7rJyc590ksasQ4GCVD9B-Jx8) | 16 | | resnet50 | MMit | 16 | 224\*224 | - | - | 62.5 | 62.9 | [model](https://drive.google.com/open?id=1Tqsfqqol5udoVX0KhnexGZzzGZs0G9MZ) | 17 | | resnet101 | MMit | 8 | 224\*224 | - | - | 62.2 | 63.0 | [model](https://drive.google.com/open?id=140dJeXaUVvqnLyI8h4wEYxxgLpXRBScF) | 18 | | resnet50 | Something | 8 | 224\*224 | 46.0 | 47.1 | - | - | [model](https://drive.google.com/open?id=1xibYXjyvOsteoJNmSylXD9E7MKfIQYJQ) | 19 | | resnet101 | Kinetics-700 | 8 | 224\*224 | 61.9 | 64.3 | - | - | [model](https://drive.google.com/file/d/11HhSJSkrU6_NMyvm2gvAFSM0vy_yr-eg/view?usp=sharing) | 20 | 21 | 22 | 23 | 24 | #### SlowFast 25 | 26 | | Model | Dataset | Frames | Input | Top-1 | Top-1\* | mAP | mAP\* | Link | 27 | | :---------- | :----------- | :---------- | -------- | :---- | :---- | :--- | ---- | ------- | 28 | | SlowFast101 | Kinetics-700 | 64(8 \* 8) | 112\*112 | - | 65.2 | - | - | [model](https://drive.google.com/open?id=1IITbtSIAIfhHiZPtwB5GtSzq2evp20Ga) | 29 | | SlowFast101 | MMit | 64(8 \* 8) | 112\*112 | - | - | 59.9 | 61.5 | [model](https://drive.google.com/open?id=1dDilpoOGFpLql0a5M8gyEGkbX4JtpfRN) | 30 | | SlowFast50 | Kinetics-600 | 64(8 \* 8) | 112\*112 | 70.0 | 77.5 | - | - | [model](https://drive.google.com/open?id=1QPh3tKH9VzuaHr0oG3va3yDdKeomqLkm) | 31 | | SlowFast50 | Kinetics-600 | 64(8 \* 8) | 224\*224 | 72.3 | 79.8 | - | - | [model](https://drive.google.com/open?id=1WnuJxNHv1E81rtP-GNviVhIOVQffvl2s) | 32 | 33 | 34 | 35 | **\* :** Means using multi crops and multi clips (3 * 10) when testing 36 | 37 | **TSM** Models can refer to [Github](https://github.com/mit-han-lab/temporal-shift-module) 38 | 39 | **MMit :** Multi-Moments in Time 40 | -------------------------------------------------------------------------------- /tools/vid2img_kinetics.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | from __future__ import print_function, division 7 | import os 8 | import sys 9 | import subprocess 10 | from 
multiprocessing import Pool 11 | from tqdm import tqdm 12 | 13 | n_thread = 100 14 | 15 | 16 | def vid2jpg(file_name, class_path, dst_class_path): 17 | if '.avi' not in file_name: 18 | return 19 | name, ext = os.path.splitext(file_name) 20 | dst_directory_path = os.path.join(dst_class_path, name) 21 | 22 | video_file_path = os.path.join(class_path, file_name) 23 | try: 24 | if os.path.exists(dst_directory_path): 25 | if not os.path.exists(os.path.join( 26 | dst_directory_path, 'img_00001.jpg')): 27 | subprocess.call( 28 | 'rm -r \"{}\"'.format(dst_directory_path), 29 | shell=True) 30 | print('remove {}'.format(dst_directory_path)) 31 | os.mkdir(dst_directory_path) 32 | else: 33 | print('*** convert has been done: {}'.format(dst_directory_path)) 34 | return 35 | else: 36 | os.mkdir(dst_directory_path) 37 | except BaseException: 38 | print(dst_directory_path) 39 | return 40 | cmd = 'ffmpeg -i \"{}\" -threads 1 -vf scale=-1:331 -q:v 0 \"{}/img_%05d.jpg\"'.format( 41 | video_file_path, dst_directory_path) 42 | # print(cmd) 43 | subprocess.call(cmd, shell=True, 44 | stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) 45 | 46 | 47 | def class_process(dir_path, dst_dir_path, class_name): 48 | print('*' * 20, class_name, '*' * 20) 49 | class_path = os.path.join(dir_path, class_name) 50 | if not os.path.isdir(class_path): 51 | print('*** is not a dir {}'.format(class_path)) 52 | return 53 | 54 | dst_class_path = os.path.join(dst_dir_path, class_name) 55 | if not os.path.exists(dst_class_path): 56 | os.mkdir(dst_class_path) 57 | 58 | vid_list = sorted(os.listdir(class_path)) 59 | p = Pool(n_thread) 60 | from functools import partial 61 | worker = partial( 62 | vid2jpg, 63 | class_path=class_path, 64 | dst_class_path=dst_class_path) 65 | for _ in tqdm(p.imap_unordered(worker, vid_list), total=len(vid_list)): 66 | pass 67 | # p.map(worker, vid_list) 68 | p.close() 69 | p.join() 70 | 71 | print('\n') 72 | 73 | 74 | if __name__ == "__main__": 75 | dir_path = sys.argv[1] 76 | dst_dir_path = sys.argv[2] 77 | 78 | class_list = sorted(os.listdir(dir_path)) 79 | for class_name in class_list: 80 | class_process(dir_path, dst_dir_path, class_name) 81 | 82 | class_name = 'test' 83 | class_process(dir_path, dst_dir_path, class_name) 84 | -------------------------------------------------------------------------------- /tools/gen_label_sthv1.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | # ------------------------------------------------------ 6 | # Code adapted from https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py 7 | # processing the raw data of the video Something-Something-V1 8 | 9 | import os 10 | 11 | if __name__ == '__main__': 12 | dataset_name = 'something-something-v1' # 'jester-v1' 13 | with open('%s-labels.csv' % dataset_name) as f: 14 | lines = f.readlines() 15 | categories = [] 16 | for line in lines: 17 | line = line.rstrip() 18 | categories.append(line) 19 | categories = sorted(categories) 20 | with open('category.txt', 'w') as f: 21 | f.write('\n'.join(categories)) 22 | 23 | dict_categories = {} 24 | for i, category in enumerate(categories): 25 | dict_categories[category] = i 26 | 27 | ''' 28 | files_input = ['%s-validation.csv' % dataset_name, '%s-train.csv' % dataset_name] 29 | files_output = ['val_videofolder.txt', 
'train_videofolder.txt'] 30 | for (filename_input, filename_output) in zip(files_input, files_output): 31 | with open(filename_input) as f: 32 | lines = f.readlines() 33 | folders = [] 34 | idx_categories = [] 35 | for line in lines: 36 | line = line.rstrip() 37 | items = line.split(';') 38 | folders.append(items[0]) 39 | idx_categories.append(dict_categories[items[1]]) 40 | output = [] 41 | for i in range(len(folders)): 42 | curFolder = folders[i] 43 | curIDX = idx_categories[i] 44 | # counting the number of frames in each video folders 45 | dir_files = os.listdir(os.path.join('../something/v1/20bn-something-something-v1', curFolder)) 46 | output.append('%s %d %d' % ('' + curFolder, len(dir_files), curIDX)) 47 | print('%d/%d' % (i, len(folders))) 48 | with open(filename_output, 'w') as f: 49 | f.write('\n'.join(output)) 50 | ''' 51 | filename_input = '%s-test.csv' % dataset_name 52 | filename_output = 'test_videofolder.txt' 53 | with open(filename_input) as f: 54 | lines = f.readlines() 55 | folders = [] 56 | idx_categories = [] 57 | for line in lines: 58 | line = line.rstrip() 59 | folders.append(line) 60 | idx_categories.append(0) 61 | output = [] 62 | for i in range(len(folders)): 63 | curFolder = folders[i] 64 | curIDX = idx_categories[i] 65 | dir_files = os.listdir( 66 | os.path.join( 67 | '../something/v1/20bn-something-something-v1', 68 | curFolder)) 69 | output.append('%s %d %d' % ('' + curFolder, len(dir_files), curIDX)) 70 | print('%d/%d' % (i, len(folders))) 71 | with open(filename_output, 'w') as f: 72 | f.write('\n'.join(output)) 73 | -------------------------------------------------------------------------------- /Configuration.md: -------------------------------------------------------------------------------- 1 | # Train 2 | ## Example for a 3D model 3 | ```yaml 4 | version: 1.0 # version 5 | config: 6 |   dataset: 7 |     workers: 3 # number of workers per process 8 |     num_class: 102 # Total number of dataset categories 9 |     num_segments: 16 # input frames 10 |     batch_size: 32 11 |     img_prefix: 'image_{:05d}.jpg' # naming pattern of the extracted frames; required when reading RGB frames as input 12 |     video_source: False # Whether to read videos directly (instead of extracted frames) 13 |     dense_sample: True # Whether to use dense sampling (otherwise uniform sampling) 14 |     modality: RGB # RGB, FLOW 15 |     flow_prefix: '' 16 |     root_dir: /path # The root directory where the dataset files are located 17 |     flip: True # Use flip as augmentation 18 |     dense_sample_rate: 2 # dense sampling rate (sample every n frames) 19 |     input_mean: [0.485, 0.456, 0.406] # ImageNet normalization statistics 20 |     input_std: [0.229, 0.224, 0.225] 21 |     crop_size: 112 22 |     scale_size: 128 # The short side of the frame is resized to this size during augmentation 23 |     train: 24 |       meta_file: /path 25 |     val: 26 |       meta_file: /path 27 |     test: 28 |       meta_file: /path 29 | 30 |   net: 31 |     arch: stresnet18 # model name and depth 32 |     model_type: 3D # 2D or 3D 33 |     dropout: 0.0 34 |     max_pooling: True # Use a max-pooling layer after Conv1 to reduce the spatial size to 1/2 (only works for R(2+1)D models) 35 | 36 |   trainer: 37 |     print_freq: 20 # output log every n iterations 38 |     eval_freq: 5 # eval every n epochs and output log 39 |     epochs: 120 # total training epochs 40 |     start_epoch: 0 41 |     loss_type: nll 42 |     no_partial_bn: False # FreezeBN (currently only for 2D models) 43 |     clip_gradient: 20 # Gradient clipping threshold
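# --- annotation added for clarity; not part of the original Configuration.md ---
# The lr_scheduler / optimizer blocks below are instantiated reflectively. In
# x_temporal/utils/optimizer_helper.py this is roughly:
#   cls = getattr(torch.optim, cfg_optim['type'])       # e.g. torch.optim.SGD
#   cfg_optim['kwargs']['params'] = trainable_params
#   optimizer = cls(**cfg_optim['kwargs'])               # SGD(params=..., lr=0.4, momentum=0.9, weight_decay=0.0005, nesterov=True)
# The scheduler is presumably resolved the same way against torch.optim.lr_scheduler
# (see x_temporal/utils/lr_helper.py, not shown here), e.g. CosineAnnealingLR(optimizer, T_max=120);
# warmup_epochs is assumed to be handled by the framework's own warmup logic, not by the PyTorch class.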
44 |     lr_scheduler: # Configuration follows the corresponding PyTorch lr_scheduler class 45 |       warmup_epochs: 10 46 |       type: CosineAnnealingLR 47 |       kwargs: 48 |         T_max: 120 49 |     optimizer: # Configuration follows the corresponding PyTorch optimizer class 50 |       type: SGD 51 |       kwargs: 52 |         lr: 0.4 53 |         momentum: 0.9 54 |         weight_decay: 0.0005 55 |         nesterov: True 56 | 57 | 58 |   saver: 59 |     save_dir: 'checkpoint/' # checkpoint save path 60 |     pretrain_model: '/path' # Path of the pretrained model to load 61 |     resume_model: '/path' # Path of the checkpoint to resume from 62 | 63 | ``` 64 | 65 | ## Example for a 2D model (only the fields that differ from the 3D example are listed) 66 | ```yaml 67 |   net: 68 |     arch: resnet50 69 |     model_type: 2D 70 |     shift: True # TSM model switch 71 |     shift_div: 8 72 |     tin: False # TIN model switch 73 |     consensus_type: avg # Consensus function used to aggregate per-frame predictions; avg is the usual choice 74 |     dropout: 0.8 75 |     non_local: False 76 |     pretrain: True # imagenet pretrain for 2D network 77 | 78 | ``` 79 | 80 | # Val & Test 81 | When testing video models, it is common to sample each video densely and average the logits as the final result. 82 | ```yaml 83 |   evaluate: 84 |     spatial_crops: 3 # The number of crops in the spatial dimension 85 |     temporal_samples: 10 # The number of clips in the temporal dimension 86 | 87 | # Finally, the number of samples used for testing each video is 3 * 10 = 30 88 | ``` 89 | -------------------------------------------------------------------------------- /tools/gen_label_kinetics.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | dataset_path = 'kinetics_600/train_frames' 5 | label_path = '' 6 | 7 | if __name__ == '__main__': 8 | with open('kinetics_label_map.txt') as f: 9 | categories = f.readlines() 10 | categories = [ 11 | c.strip().replace( 12 | ' ', 13 | '_').replace( 14 | '"', 15 | '').replace( 16 | '(', 17 | '').replace( 18 | ')', 19 | '').replace( 20 | "'", 21 | '') for c in categories] 22 | assert len(set(categories)) == 600 23 | dict_categories = {} 24 | for i, category in enumerate(categories): 25 | dict_categories[category] = i 26 | 27 | print(dict_categories) 28 | 29 | files_input = ['kinetics-600_val.csv', 'kinetics-600_train.csv'] 30 | files_output = ['val_videofolder.txt', 'train_videofolder.txt'] 31 | for (filename_input, filename_output) in zip(files_input, files_output): 32 | count_cat = {k: 0 for k in dict_categories.keys()} 33 | with open(os.path.join(label_path, filename_input)) as f: 34 | lines = f.readlines()[1:] 35 | folders = [] 36 | idx_categories = [] 37 | categories_list = [] 38 | for line in lines: 39 | line = line.rstrip() 40 | items = line.split(',') 41 | st = int(items[2]) 42 | et = int(items[3]) 43 | folders.append(items[1] + '_' + "%06d" % st + '_' + "%06d" % et) 44 | this_category = items[0].replace( 45 | ' ', 46 | '_').replace( 47 | '"', 48 | '').replace( 49 | '(', 50 | '').replace( 51 | ')', 52 | '').replace( 53 | "'", 54 | '') 55 | categories_list.append(this_category) 56 | idx_categories.append(dict_categories[this_category]) 57 | count_cat[this_category] += 1 58 | print(max(count_cat.values())) 59 | 60 | assert len(idx_categories) == len(folders) 61 | missing_folders = [] 62 | output = [] 63 | for i in range(len(folders)): 64 | curFolder = folders[i] 65 | curIDX = idx_categories[i] 66 | # count the number of frames in each video folder 67 | img_dir =
os.path.join(dataset_path, categories_list[i], curFolder) 68 | if not os.path.exists(img_dir): 69 | missing_folders.append(img_dir) 70 | # print(missing_folders) 71 | else: 72 | dir_files = os.listdir(img_dir) 73 | output.append( 74 | '%s %d %d' % 75 | (os.path.join( 76 | categories_list[i], 77 | curFolder), 78 | len(dir_files), 79 | curIDX)) 80 | print( 81 | '%d/%d, missing %d' % 82 | (i, len(folders), len(missing_folders))) 83 | with open(os.path.join(label_path, filename_output), 'w') as f: 84 | f.write('\n'.join(output)) 85 | with open(os.path.join(label_path, 'missing_' + filename_output), 'w') as f: 86 | f.write('\n'.join(missing_folders)) 87 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import io 5 | import re 6 | import glob 7 | import os 8 | import subprocess 9 | 10 | import numpy as np 11 | import torch 12 | from setuptools import Extension 13 | from setuptools import find_packages 14 | from setuptools import setup 15 | from torch.utils.cpp_extension import CppExtension 16 | from torch.utils.cpp_extension import CUDAExtension 17 | 18 | def _find_cuda_home(): 19 | # guess rule 3 of torch.utils.cpp_extension 20 | nvcc = subprocess.check_output(['which', 'nvcc']).decode().rstrip('\r\n') 21 | cuda_home = os.path.dirname(os.path.dirname(nvcc)) 22 | print(f'find cuda home:{cuda_home}') 23 | return cuda_home 24 | 25 | 26 | # remember to overwrite PyTorch auto-detected cuda home which 27 | # may not be our expected 28 | torch.utils.cpp_extension.CUDA_HOME = _find_cuda_home() 29 | CUDA_HOME = torch.utils.cpp_extension.CUDA_HOME 30 | 31 | CORE_DIR = 'x_temporal' 32 | EXT_DIR = 'cuda_shift' 33 | 34 | def recursive_glob(base_dir, pattern): 35 | files = [] 36 | for root, subdirs, subfiles in os.walk(base_dir): 37 | files += glob.glob(os.path.join(root, pattern)) 38 | return files 39 | 40 | def get_extensions(): 41 | # this_dir = os.path.dirname(os.path.abspath(__file__)) 42 | sources_dir = os.path.join(CORE_DIR, EXT_DIR, "src") 43 | 44 | source_cpu = recursive_glob(sources_dir, "*.cpp") 45 | source_cuda = recursive_glob(sources_dir, "*.cu") 46 | 47 | sources = source_cpu 48 | extension = CppExtension 49 | 50 | extra_compile_args = {"cxx": []} 51 | define_macros = [] 52 | assert torch.cuda.is_available() and CUDA_HOME is not None 53 | 54 | if torch.cuda.is_available() and CUDA_HOME is not None: 55 | extension = CUDAExtension 56 | sources += source_cuda 57 | define_macros += [("WITH_CUDA", None)] 58 | extra_compile_args["nvcc"] = [ 59 | "-DCUDA_HAS_FP16=1", 60 | "-D__CUDA_NO_HALF_OPERATORS__", 61 | "-D__CUDA_NO_HALF_CONVERSIONS__", 62 | "-D__CUDA_NO_HALF2_OPERATORS__", 63 | ] 64 | 65 | # sources = [os.path.join(extensions_dir, s) for s in sources] 66 | print(f'sources:{sources}') 67 | 68 | include_dirs = [sources_dir] 69 | 70 | ext_modules = [ 71 | extension( 72 | f"{CORE_DIR}.{EXT_DIR}._C", 73 | sources, 74 | include_dirs=include_dirs, 75 | define_macros=define_macros, 76 | extra_compile_args=extra_compile_args, 77 | ) 78 | ] 79 | 80 | return ext_modules 81 | 82 | with io.open("x_temporal/__init__.py", "rt", encoding="utf8") as f: 83 | version = re.search(r'__version__ = "(\D*)(.*?)"', f.read(), re.M).group(2) 84 | 85 | setup( 86 | name="x_temporal", 87 | version=version, 88 | author="X-Lab Temporal Team", 89 | 
url="http://gitlab.bj.sensetime.com/spring-ce/element/x-temporal", 90 | description="Video Understanding Framework in Distributed PyTorch", 91 | author_email="spring-support@senstime.com", 92 | package_data={ 93 | }, 94 | packages=find_packages(exclude=( 95 | "scripts", 96 | "experiments" 97 | )), 98 | ext_modules=get_extensions(), 99 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 100 | ) 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /x_temporal/utils/model_helper.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | 4 | import torch 5 | 6 | logger = logging.getLogger('global') 7 | 8 | 9 | def load_state_dict(model, other_state_dict, strict=False): 10 | """ 11 | 1. load resume model or pretained detection model 12 | 2. load pretrained clssification model 13 | """ 14 | def remove_prefix(state_dict, prefix): 15 | """Old style model is stored with all names of parameters share common prefix 'module.'""" 16 | def f(x): return x.split(prefix, 1)[-1] if x.startswith(prefix) else x 17 | return {f(key): value for key, value in state_dict.items()} 18 | 19 | def add_prefix(state_dict, prefix): 20 | def f(x): return 'module.' + x 21 | return {f(key): value for key, value in state_dict.items()} 22 | 23 | logger.info( 24 | "try to load the whole resume model or pretrained model...") 25 | 26 | model_state_dict = model.state_dict() 27 | model_keys = model_state_dict.keys() 28 | other_keys = other_state_dict.keys() 29 | 30 | model_key_sample = list(model_keys)[0] 31 | other_key_sample = list(other_keys)[0] 32 | if model_key_sample.startswith('module') and not other_key_sample.startswith('module'): 33 | other_state_dict = add_prefix(other_state_dict, 'module.') 34 | elif not model_key_sample.startswith('module') and other_key_sample.startswith('module'): 35 | other_state_dict = remove_prefix(other_state_dict, 'module.') 36 | 37 | other_keys = other_state_dict.keys() 38 | shared_keys, unexpected_keys, missing_keys \ 39 | = check_keys(model_keys, other_keys) 40 | 41 | incompatible_keys = set([]) 42 | for key in other_keys: 43 | if key in model_keys: 44 | if model_state_dict[key].shape != other_state_dict[key].shape: 45 | incompatible_keys.add(key) 46 | 47 | for key in incompatible_keys: 48 | other_state_dict.pop(key) 49 | unexpected_keys = unexpected_keys | incompatible_keys 50 | model.load_state_dict(other_state_dict, strict=strict) 51 | 52 | num_share_keys = len(shared_keys) 53 | display_info("model", shared_keys, unexpected_keys, missing_keys) 54 | if num_share_keys == 0: 55 | logger.info( 56 | 'failed to load the whole detection model directly,' 57 | 'try to load each part seperately...') 58 | for mname, module in model.named_children(): 59 | module.load_state_dict(other_state_dict, strict=strict) 60 | module_keys = module.state_dict().keys() 61 | other_keys = other_state_dict.keys() 62 | 63 | # check and display info module by module 64 | shared_keys, unexpected_keys, missing_keys, \ 65 | = check_keys(module_keys, other_keys) 66 | display_info(mname, shared_keys, unexpected_keys, missing_keys) 67 | num_share_keys += len(shared_keys) 68 | else: 69 | display_info("model", shared_keys, unexpected_keys, missing_keys) 70 | return num_share_keys 71 | 72 | 73 | def check_keys(own_keys, other_keys): 74 | own_keys = set(own_keys) 75 | other_keys = set(other_keys) 76 | shared_keys = own_keys & other_keys 77 | unexpected_keys = other_keys - own_keys 78 | missing_keys 
= own_keys - other_keys 79 | return shared_keys, unexpected_keys, missing_keys 80 | 81 | 82 | def display_info(mname, shared_keys, unexpected_keys, missing_keys): 83 | info = "load {}:{} shared keys, {} unexpected keys, {} missing keys.".format( 84 | mname, len(shared_keys), len(unexpected_keys), len(missing_keys)) 85 | 86 | if len(missing_keys) > 0: 87 | info += "\nmissing keys are as follows:\n {}".format( 88 | "\n ".join(missing_keys)) 89 | logger.info(info) 90 | -------------------------------------------------------------------------------- /x_temporal/utils/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import copy 5 | 6 | import torch 7 | import numpy as np 8 | 9 | 10 | def format_cfg(cfg): 11 | """Format experiment config for friendly display""" 12 | 13 | def list2str(cfg): 14 | for key, value in cfg.items(): 15 | if isinstance(value, dict): 16 | cfg[key] = list2str(value) 17 | elif isinstance(value, list): 18 | if len(value) == 0 or isinstance(value[0], (int, float)): 19 | cfg[key] = str(value) 20 | else: 21 | for i, item in enumerate(value): 22 | if isinstance(item, dict): 23 | value[i] = list2str(item) 24 | cfg[key] = value 25 | return cfg 26 | 27 | cfg = list2str(copy.deepcopy(cfg)) 28 | json_str = json.dumps(cfg, indent=2, ensure_ascii=False).split(r"\n") 29 | json_str = [re.sub(r"(\"|,$|\{|\}|\[$)", "", line) 30 | for line in json_str if line.strip() not in "{}[]"] 31 | cfg_str = r"\n".join([line.rstrip() for line in json_str if line.strip()]) 32 | return cfg_str 33 | 34 | 35 | def accuracy(output, target, topk=(1,)): 36 | """Computes the precision@k for the specified values of k""" 37 | maxk = max(topk) 38 | batch_size = target.size(0) 39 | 40 | _, pred = output.topk(maxk, 1, True, True) 41 | pred = pred.t() 42 | correct = pred.eq(target.contiguous().view(1, -1).expand_as(pred)) 43 | 44 | res = [] 45 | for k in topk: 46 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 47 | res.append(correct_k.mul_(100.0 / batch_size)) 48 | return res 49 | 50 | 51 | class AverageMeter(object): 52 | """Computes and stores the average and current value""" 53 | 54 | def __init__(self, length=0): 55 | self.length = length 56 | self.reset() 57 | 58 | def reset(self): 59 | if self.length > 0: 60 | self.history = [] 61 | else: 62 | self.count = 0 63 | self.sum = 0.0 64 | self.val = 0.0 65 | self.avg = 0.0 66 | 67 | def update(self, val, num=1): 68 | if self.length > 0: 69 | # currently assert num==1 to avoid bad usage, refine when there are some explict requirements 70 | #assert num == 1 71 | self.history.append(val) 72 | if len(self.history) > self.length: 73 | del self.history[0] 74 | 75 | self.val = self.history[-1] 76 | self.avg = np.mean(self.history) 77 | else: 78 | self.val = val 79 | self.sum += val * num 80 | self.count += num 81 | self.avg = self.sum / self.count 82 | 83 | 84 | def load_checkpoint(ckpt_path): 85 | """Load state_dict from checkpoint""" 86 | 87 | def remove_prefix(state_dict, prefix): 88 | """Old style model is stored with all names of parameters share common prefix 'module.'""" 89 | def f(x): return x.split(prefix, 1)[-1] if x.startswith(prefix) else x 90 | return {f(key): value for key, value in state_dict.items()} 91 | 92 | assert os.path.exists(ckpt_path), f'No such file: {ckpt_path}' 93 | device = torch.cuda.current_device() 94 | ckpt_dict = torch.load( 95 | ckpt_path, 96 | map_location=lambda storage, 97 | loc: storage.cuda(device)) 98 | 99 | # 
handle different storage format between pretrain vs resume 100 | if 'model' in ckpt_dict: 101 | state_dict = ckpt_dict['model'] 102 | elif 'state_dict' in ckpt_dict: 103 | state_dict = ckpt_dict['state_dict'] 104 | else: 105 | state_dict = ckpt_dict 106 | 107 | state_dict = remove_prefix(state_dict, 'module.') 108 | ckpt_dict['model'] = state_dict 109 | 110 | return ckpt_dict 111 | -------------------------------------------------------------------------------- /x_temporal/cuda_shift/src/cuda/shift_kernel_cuda.cu: -------------------------------------------------------------------------------- 1 | #include "shift_kernel_cuda.h" 2 | #include 3 | #include 4 | 5 | #define CUDA_KERNEL_LOOP(i, n) \ 6 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 7 | i < (n); \ 8 | i += blockDim.x * gridDim.x) 9 | 10 | const int CUDA_NUM_THREADS = 1024; 11 | inline int GET_BLOCKS(const int N) 12 | { 13 | return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; 14 | } 15 | 16 | __global__ void ShiftDataCudaForwardKernel(int n, float* data, int* shift, const int batch_size, const int channels, const int tsize, const int hwsize, const int groupsize, int groupchannel, float* out){ 17 | CUDA_KERNEL_LOOP(index, n) 18 | { 19 | const int hw_index = index % hwsize; 20 | const int j = (index / hwsize) % channels; 21 | 22 | const int n_index = (index / hwsize / channels) % batch_size; 23 | int group_id = j / groupchannel; 24 | int t_shift = shift[n_index * groupsize + group_id]; 25 | int offset = n_index * tsize * hwsize * channels + hwsize* j + hw_index; 26 | for(int i=0; i < tsize; i++) 27 | { 28 | int now_t = i + t_shift; 29 | int data_id = i * hwsize * channels + offset; 30 | if (now_t < 0 || now_t >= tsize) { 31 | continue; 32 | } 33 | int out_id = now_t * hwsize * channels +offset; 34 | out[out_id] = data[data_id]; 35 | } 36 | } 37 | } 38 | 39 | 40 | __global__ void ShiftDataCudaBackwardKernel(int n, float* data, int* shift, const int batch_size, const int channels, const int tsize, const int hwsize, const int groupsize, int groupchannel, float* out){ 41 | CUDA_KERNEL_LOOP(index, n) 42 | { 43 | const int hw_index = index % hwsize; 44 | const int j = (index / hwsize) % channels; 45 | const int n_index = (index / hwsize / channels) % batch_size; 46 | int group_id = j / groupchannel; 47 | int t_shift = shift[n_index * groupsize + group_id]; 48 | int offset = n_index * tsize * hwsize * channels + hwsize* j + hw_index; 49 | for(int i=0; i < tsize; i++) 50 | { 51 | int now_t = i - t_shift; 52 | int data_id = i * hwsize * channels + offset; 53 | if (now_t < 0 || now_t >= tsize) { 54 | continue; 55 | } 56 | int out_id = now_t * hwsize * channels +offset; 57 | out[out_id] = data[data_id]; 58 | 59 | } 60 | } 61 | } 62 | 63 | void ShiftDataCudaForward(cudaStream_t stream, 64 | float* data, 65 | int* shift, 66 | const int batch_size, 67 | const int channels, 68 | const int tsize, 69 | const int hwsize, 70 | const int groupsize, 71 | float* out){ 72 | const int num_kernels = batch_size * hwsize * channels; 73 | int groupchannel = channels / groupsize; 74 | ShiftDataCudaForwardKernel<<>>(num_kernels, data, shift, batch_size, channels, tsize, hwsize, groupsize, groupchannel, out); 75 | } 76 | 77 | void ShiftDataCudaBackward(cudaStream_t stream, 78 | float* data, 79 | int* shift, 80 | const int batch_size, 81 | const int channels, 82 | const int tsize, 83 | const int hwsize, 84 | const int groupsize, 85 | float* out){ 86 | const int num_kernels = batch_size * hwsize * channels; 87 | int groupchannel = channels / groupsize; 88 
| ShiftDataCudaBackwardKernel<<>>(num_kernels, data, shift, batch_size, channels, tsize, hwsize, groupsize, groupchannel, out); 89 | } 90 | 91 | -------------------------------------------------------------------------------- /x_temporal/utils/dist_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import shutil 4 | import torch 5 | import math 6 | import numpy as np 7 | from collections import defaultdict 8 | import torch.distributed as dist 9 | 10 | 11 | 12 | def load_state(path, model, optimizer=None): 13 | 14 | rank = get_rank() 15 | 16 | def map_func(storage, location): 17 | return storage.cuda() 18 | 19 | if os.path.isfile(path): 20 | if rank == 0: 21 | print("=> loading checkpoint '{}'".format(path)) 22 | 23 | checkpoint = torch.load(path, map_location=map_func) 24 | model.load_state_dict(checkpoint['state_dict'], strict=False) 25 | 26 | if rank == 0: 27 | ckpt_keys = set(checkpoint['state_dict'].keys()) 28 | own_keys = set(model.state_dict().keys()) 29 | missing_keys = own_keys - ckpt_keys 30 | for k in missing_keys: 31 | print( 32 | 'caution: missing keys from checkpoint {}: {}'.format( 33 | path, k)) 34 | 35 | if optimizer is not None: 36 | best_prec1 = checkpoint['best_prec1'] 37 | last_iter = checkpoint['step'] 38 | optimizer.load_state_dict(checkpoint['optimizer']) 39 | if rank == 0: 40 | print( 41 | "=> also loaded optimizer from checkpoint '{}' (iter {})".format( 42 | path, last_iter)) 43 | return best_prec1, last_iter 44 | else: 45 | if rank == 0: 46 | print("=> no checkpoint found at '{}'".format(path)) 47 | 48 | 49 | def is_master_proc(num_gpus=8): 50 | """ 51 | Determines if the current process is the master process. 52 | """ 53 | if torch.distributed.is_initialized(): 54 | return dist.get_rank() % num_gpus == 0 55 | else: 56 | return True 57 | 58 | 59 | def all_gather(tensors): 60 | """ 61 | All gathers the provided tensors from all processes across machines. 62 | Args: 63 | tensors (list): tensors to perform all gather across all processes in 64 | all machines. 65 | """ 66 | 67 | gather_list = [] 68 | output_tensor = [] 69 | world_size = dist.get_world_size() 70 | for tensor in tensors: 71 | tensor_placeholder = [ 72 | torch.ones_like(tensor) for _ in range(world_size) 73 | ] 74 | dist.all_gather(tensor_placeholder, tensor, async_op=False) 75 | gather_list.append(tensor_placeholder) 76 | for gathered_tensor in gather_list: 77 | output_tensor.append(torch.cat(gathered_tensor, dim=0)) 78 | return output_tensor 79 | 80 | 81 | def all_reduce(tensor, average=True): 82 | """ 83 | All reduce the provided tensors from all processes across machines. 84 | Args: 85 | tensor (tensor): tensor to perform all reduce across all processes in 86 | all machines. 87 | average (bool): scales the reduced tensor by the number of overall 88 | processes across all machines. 89 | """ 90 | 91 | dist.all_reduce(tensor, async_op=False) 92 | if average: 93 | world_size = dist.get_world_size() 94 | tensor.mul_(1.0 / world_size) 95 | return tensor 96 | 97 | 98 | def get_world_size(): 99 | """ 100 | Get the size of the world. 101 | """ 102 | if not dist.is_available(): 103 | return 1 104 | if not dist.is_initialized(): 105 | return 1 106 | return dist.get_world_size() 107 | 108 | 109 | def get_rank(): 110 | """ 111 | Get the rank of the current process. 
112 | """ 113 | if not dist.is_available(): 114 | return 0 115 | if not dist.is_initialized(): 116 | return 0 117 | return dist.get_rank() 118 | 119 | 120 | def synchronize(): 121 | """ 122 | Helper function to synchronize (barrier) among all processes when 123 | using distributed training 124 | """ 125 | if not dist.is_available(): 126 | return 127 | if not dist.is_initialized(): 128 | return 129 | world_size = dist.get_world_size() 130 | if world_size == 1: 131 | return 132 | dist.barrier() 133 | 134 | 135 | -------------------------------------------------------------------------------- /x_temporal/utils/lr_helper.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import copy 3 | 4 | import torch 5 | 6 | logger = logging.getLogger('global') 7 | 8 | 9 | class ExponentialWarmUpLR(object): 10 | """Scheduler that update learning rate exponentially 11 | """ 12 | 13 | def __init__(self, warmup_iter, init_lr, target_lr): 14 | self.lr_scale = target_lr / init_lr 15 | self.gamma = self.lr_scale**(1.0 / max(1, warmup_iter)) 16 | self.warmup_iter = warmup_iter 17 | 18 | def get_lr(self, last_epoch, base_lrs, optimizer): 19 | return [base_lr * self.gamma**last_epoch / 20 | self.lr_scale for base_lr in base_lrs] 21 | 22 | 23 | class LinearWarmUpLR(object): 24 | """Scheduler that update learning rate linearly 25 | """ 26 | 27 | def __init__(self, warmup_iter, init_lr, target_lr): 28 | self.lr_gap = target_lr - init_lr 29 | self.gamma = self.lr_gap / max(1, warmup_iter) 30 | self.warmup_iter = warmup_iter 31 | 32 | def get_lr(self, last_epoch, base_lrs, optimizer): 33 | return [base_lr + self.gamma * last_epoch - 34 | self.lr_gap for base_lr in base_lrs] 35 | 36 | 37 | _warmup_lr = { 38 | 'linear': LinearWarmUpLR, 39 | 'exp': ExponentialWarmUpLR 40 | } 41 | 42 | 43 | def build_warmup_scheduler(cfg_scheduler, optimizer, data_size, lr_scale): 44 | 45 | target_lr = [group.get('initial_lr', group['lr']) 46 | for group in optimizer.param_groups][0] 47 | warmup_epochs = cfg_scheduler.get('warmup_epochs', 0) 48 | # no linear scaling if no warmup 49 | if warmup_epochs > 0: 50 | init_lr = target_lr / float(lr_scale) 51 | else: 52 | init_lr = target_lr 53 | warmup_iter = int(warmup_epochs * data_size) 54 | 55 | warmup_type = cfg_scheduler.get('warmup_type', 'exp') 56 | assert warmup_type in _warmup_lr, f'warmup scheduler {warmup_type} not supported' 57 | 58 | return _warmup_lr[warmup_type](warmup_iter, init_lr, target_lr) 59 | 60 | 61 | def prepare_scheduler(cfg_scheduler, optimizer, data_size): 62 | """Convert epoch to iteration""" 63 | 64 | cfg = copy.deepcopy(cfg_scheduler) 65 | 66 | cfg['kwargs']['optimizer'] = optimizer 67 | if cfg['type'] == 'MultiStepLR': 68 | cfg['kwargs']['milestones'] = [ 69 | int(e * data_size) for e in cfg['kwargs']['milestones']] 70 | elif cfg['type'] == 'StepLR': 71 | cfg['kwargs']['step_size'] = cfg['kwargs']['step_size'] * data_size 72 | elif cfg['type'] == 'ReduceLROnPlateau': 73 | cfg['kwargs']['patience'] = cfg['kwargs']['patience'] * data_size 74 | elif cfg['type'] == 'CosineAnnealingLR': 75 | cfg['kwargs']['T_max'] = cfg['kwargs']['T_max'] * data_size 76 | else: 77 | raise NotImplementedError(f'{cfg} is not supported') 78 | scheduler = getattr(torch.optim.lr_scheduler, cfg['type']) 79 | return scheduler, cfg 80 | 81 | 82 | def build_scheduler(cfg_scheduler, optimizer, data_size, lr_scale): 83 | """ Build composed warmup scheduler and standard scheduler. 
84 | There will be no linar scaling process if no warmup 85 | """ 86 | standard_scheduler_class, cfg_scheduler = prepare_scheduler( 87 | cfg_scheduler, optimizer, data_size) 88 | 89 | warmup_scheduler = build_warmup_scheduler( 90 | cfg_scheduler, optimizer, data_size, lr_scale) 91 | 92 | class ChainIterLR(standard_scheduler_class): 93 | """Unified scheduler that chains warmup scheduler and standard scheduler 94 | """ 95 | 96 | def __init__(self, *args, **kwargs): 97 | super(ChainIterLR, self).__init__(*args, **kwargs) 98 | 99 | def get_lr(self): 100 | if self.last_iter <= warmup_scheduler.warmup_iter: 101 | return warmup_scheduler.get_lr( 102 | self.last_iter, self.base_lrs, self.optimizer) 103 | else: 104 | return super(ChainIterLR, self).get_lr() 105 | 106 | @property 107 | def last_iter(self): 108 | return self.last_epoch 109 | 110 | return ChainIterLR(**cfg_scheduler['kwargs']) 111 | 112 | 113 | if __name__ == '__main__': 114 | import torchvision 115 | import os 116 | import sys 117 | model = torchvision.models.resnet18() 118 | params = [p for p in model.parameters() if p.requires_grad] 119 | optimizer = torch.optim.SGD(params, lr=0.1) 120 | cfg_scheduler = { 121 | 'warmup_epochs': 1, 122 | 'type': 'MultiStepLR', 123 | 'kwargs': { 124 | 'milestones': [10, 20], 125 | 'gamma': 0.1 126 | } 127 | } 128 | scheduler = build_scheduler( 129 | cfg_scheduler, 130 | optimizer, 131 | data_size=5, 132 | lr_scale=10) 133 | if len(sys.argv) > 1 and os.path.exists(sys.argv[1]): 134 | state = torch.load(sys.argv[1]) 135 | optimizer.load_state_dict(state['optimizer']) 136 | scheduler.load_state_dict(state['scheduler']) 137 | start_iter = scheduler.last_iter 138 | for i in range(start_iter, 120): 139 | if i % 30 == 0: 140 | state = { 141 | 'optimizer': optimizer.state_dict(), 142 | 'scheduler': scheduler.state_dict()} 143 | torch.save(state, f'iter{i}.pkl') 144 | scheduler.step() 145 | print(f'iter:{i}, lr:{scheduler.get_lr()}') 146 | -------------------------------------------------------------------------------- /ReadMe.md: -------------------------------------------------------------------------------- 1 | # X-Temporal 2 | 3 | **Easily implement SOTA video understanding methods with PyTorch on multiple machines and GPUs** 4 | 5 | X-Temporal is an open source video understanding codebase from Sensetime X-Lab group that provides state-of-the-art video classification models, including papers "[Temporal Segment Networks](https://arxiv.org/abs/1608.00859)", "[Temporal Interlacing Network](https://arxiv.org/abs/2001.06499)", "[Temporal Shift Module](https://arxiv.org/abs/1811.08383)", "[ResNet 3D](https://arxiv.org/pdf/1711.11248)", "[SlowFast Networks for Video Recognition](https://arxiv.org/abs/1812.03982)", and "[Non-local Neural Networks](https://arxiv.org/abs/1711.07971)". 6 | 7 | *This repo includes all models and codes used in our 1st place solution in ICCV19-Multi Moments in Time Challenge [Challenge Website](http://moments.csail.mit.edu/results2019.html)* 8 | 9 | ## Introduction 10 | * Support popular video understanding frameworks 11 | * SlowFast 12 | * R(2+1)D 13 | * R3D 14 | * TSN 15 | * TIN 16 | * TSM 17 | * Support various datasets (Kinetics, Something2Something, Multi-Moments in Time...) 18 | * Take raw video as input 19 | * Take video RGB frames as input 20 | * Take video Flow frames as input 21 | * Support Multi-label dataset 22 | * High-performance and modular design can help rapid implementation and evaluation of novel video research ideas. 
23 |
24 |
25 |
26 | ## Updates
27 | v0.1.0 (08/04/2020)
28 | > X-Temporal is online!
29 |
30 | ## Get started
31 | ### Prerequisites
32 |
33 | The code is built with the following libraries:
34 |
35 | - [PyTorch](https://pytorch.org/) 1.0 or higher
36 | - [TensorboardX](https://github.com/lanpa/tensorboardX)
37 | - [tqdm](https://github.com/tqdm/tqdm.git)
38 | - [sklearn](https://github.com/scikit-learn/scikit-learn)
39 | - [scikit-learn](https://scikit-learn.org/stable/)
40 | - [decord](https://github.com/dmlc/decord)
41 |
42 | For extracting frames from video data, you may need [ffmpeg](https://www.ffmpeg.org/).
43 |
44 | ### Installation
45 | 1. Clone the repo
46 | ```bash
47 | git clone https://github.com/Sense-X/X-Temporal.git X-Temporal
48 | cd X-Temporal
49 | ```
50 | 2. Run the install script
51 | ```bash
52 | ./easy_setup.sh
53 | ```
54 |
55 |
56 | ### Prepare dataset
57 | Each row in the dataset meta file describes one video and has three columns: the extracted-frame folder, the number of frames, and the category id. For example:
58 | ```
59 | abseiling/Tdd9inAW1VY_000361_000371 300 0
60 | zumba/x0KPHFRbzDo_000087_000097 300 599
61 | ```
62 |
63 | You can also read the original video files directly; X-Temporal uses the Decord library for on-the-fly frame extraction.
64 | ```
65 | abseiling/Tdd9inAW1VY_000361_000371.mkv 300 0
66 | zumba/x0KPHFRbzDo_000087_000097.mkv 300 599
67 | ```
68 | The **tools** folder provides scripts for extracting frames and generating dataset meta files.
69 |
70 | ### About multi-label classification
71 | A multi-label dataset meta file uses the same three columns: the video path, the number of frames, and a comma-separated list of category ids.
72 | ```
73 | trimming/getty-cutting-meat-cleaver-video-id163936215_13.mp4 90 144,246
74 | exercising/meta-935267_68.mp4 92 69
75 | cooking/yt-SSLy25MQb9g_307.mp4 91 264,311,7,188,246
76 | ```
77 |
78 | YAML config:
79 | ```
80 | trainer:
81 | loss_type: bce
82 | dataset:
83 | multi_class: True
84 | ```
85 |
86 | ### Training
87 | 1. Create a folder for the experiment.
88 | ```bash
89 | cd /path/to/X-Temporal
90 | mkdir -p experiments/test
91 | ```
92 |
93 | 2. Create a config, or copy one from an existing experiment.
94 | ```bash
95 | cp experiments/r2plus1d/default.yaml experiments/test
96 | cp experiments/r2plus1d/train.sh experiments/test
97 | ```
98 |
99 | 3. Set up the training script; ROOT and the cfg file may need to be changed according to your setup.
100 | ```bash
101 | T=`date +%m%d%H%M`
102 | ROOT=../..
103 | cfg=default.yaml
104 |
105 | export PYTHONPATH=$ROOT:$PYTHONPATH
106 |
107 | python $ROOT/x_temporal/train.py --config $cfg | tee log.train.$T
108 | ```
109 |
110 | 4. Start training.
111 | ```bash
112 | ./train.sh
113 | ```
114 |
115 | ### Testing
116 | 1. Set the resume_model path in the config.
117 | ```yaml
118 | saver: # Required.
119 | resume_model: checkpoints/ckpt_e13.pth # checkpoint to test
120 | ```
121 | 2. Set the parameters in the evaluate section of the config, e.g. how many spatial crops and temporal samples to use at test time (it is recommended to reduce the batch size by the same proportion).
122 | ```yaml
123 | evaluate:
124 | spatial_crops: 3
125 | temporal_samples: 10
126 | ```
127 | 3. Modify train.sh or create a new test.sh; the main change is replacing train.py with test.py. A sample is shown below:
128 | ```bash
129 | T=`date +%m%d%H%M`
130 | ROOT=../..
131 | cfg=default.yaml 132 | 133 | export PYTHONPATH=$ROOT:$PYTHONPATH 134 | 135 | python $ROOT/x_temporal/test.py --config $cfg | tee log.test.$T 136 | ``` 137 | 4. Start Testing 138 | ```bash 139 | ./test.sh 140 | ``` 141 | 142 | ## [ModelZoo](ModelZoo.md) 143 | 144 | 145 | ## LICENSE 146 | X-Temporal is released under the [MIT license](LICENSE). 147 | 148 | ## [Configuration details](Configuration.md) 149 | 150 | ## Reference 151 | Kindly cite our publications if this repo and algorithms help in your research. 152 | ``` 153 | @article{zhang2020top, 154 | title={Top-1 Solution of Multi-Moments in Time Challenge 2019}, 155 | author={Zhang, Manyuan and Shao, Hao and Song, Guanglu and Liu, Yu and Yan, Junjie}, 156 | journal={arXiv preprint arXiv:2003.05837}, 157 | year={2020} 158 | } 159 | 160 | @article{shao2020temporal, 161 | title={Temporal Interlacing Network}, 162 | author={Hao Shao and Shengju Qian and Yu Liu}, 163 | year={2020}, 164 | journal={AAAI}, 165 | } 166 | ``` 167 | 168 | ## Contributors 169 | X-Temporal is maintained by Hao Shao and ManYuan Zhang and [Yu Liu](http://liuyu.us/). 170 | -------------------------------------------------------------------------------- /x_temporal/core/tin.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | import random 4 | import logging 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from x_temporal.cuda_shift.rtc_wrap import ShiftFeatureFunc 12 | 13 | logger = logging.getLogger('global') 14 | 15 | def solve_sigmoid(x): 16 | return -math.log(1.0 / x - 1) 17 | 18 | 19 | def linear_sampler(data, bias): 20 | ''' 21 | data: N * T * C * H * W 22 | bias: N * T * Groups 23 | weight: N * T 24 | ''' 25 | N, T, C, H, W = data.shape 26 | bias_0 = torch.floor(bias).int() 27 | bias_1 = bias_0 + 1 28 | 29 | # N * T * C * H * W 30 | sf1 = ShiftFeatureFunc() 31 | sf2 = ShiftFeatureFunc() 32 | 33 | data = data.view(N, T, C, H * W).contiguous() 34 | data_0 = sf1(data, bias_0) 35 | data_1 = sf2(data, bias_1) 36 | 37 | w_0 = 1 - (bias - bias_0.float()) 38 | w_1 = 1 - w_0 39 | 40 | groupsize = bias.shape[1] 41 | w_0 = w_0[:, :, None].repeat(1, 1, C // groupsize) 42 | w_0 = w_0.view(w_0.size(0), -1) 43 | w_1 = w_1[:, :, None].repeat(1, 1, C // groupsize) 44 | w_1 = w_1.view(w_1.size(0), -1) 45 | 46 | w_0 = w_0[:, None, :, None] 47 | w_1 = w_1[:, None, :, None] 48 | 49 | out = w_0 * data_0 + w_1 * data_1 50 | out = out.view(N, T, C, H, W) 51 | 52 | return out 53 | 54 | 55 | class WeightConvNet(nn.Module): 56 | def __init__(self, in_channels, groups, n_segment): 57 | super(WeightConvNet, self).__init__() 58 | self.lastlayer = nn.Conv1d(in_channels, groups, 3, padding=1) 59 | self.groups = groups 60 | 61 | def forward(self, x): 62 | N, C, T = x.shape 63 | x = self.lastlayer(x) 64 | x = x.view(N, self.groups, T) 65 | x = x.permute(0, 2, 1) 66 | return x 67 | 68 | 69 | class BiasConvFc2Net(nn.Module): 70 | def __init__(self, in_channels, groups, 71 | n_segment, kernel_size=3, padding=1): 72 | super(BiasConvFc2Net, self).__init__() 73 | self.conv = nn.Conv1d(in_channels, 1, kernel_size, padding=padding) 74 | self.fc = nn.Linear(n_segment, n_segment) 75 | self.relu = nn.ReLU() 76 | self.lastlayer = nn.Linear(n_segment, groups) 77 | 78 | def forward(self, x): 79 | N, C, T = x.shape 80 | x = self.conv(x) 81 | x = x.view(N, T) 82 | x = self.relu(self.fc(x)) 83 | x = self.lastlayer(x) 84 | x = x.view(N, 1, -1) 85 | return x 86 | 87 | 88 | 
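# BiasNet (below) squashes its output with 4 * (sigmoid(x) - 0.5), so the fractional
# temporal shifts it produces for linear_sampler lie in the range (-2, 2) frames; its
# last-layer bias init of 0.5108 equals solve_sigmoid(0.625) (see solve_sigmoid above),
# i.e. a raw value that this squashing maps to a shift of 0.5.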
class BiasNet(nn.Module): 89 | def __init__(self, in_channels, groups, n_segment): 90 | super(BiasNet, self).__init__() 91 | self.sigmoid = nn.Sigmoid() 92 | self.net = BiasConvFc2Net(in_channels, groups, n_segment, 3, 1) 93 | self.net.lastlayer.bias.data[...] = 0.5108 94 | 95 | def forward(self, x): 96 | x = self.net(x) 97 | x = 4 * (self.sigmoid(x) - 0.5) 98 | return x 99 | 100 | 101 | class WeightNet(nn.Module): 102 | def __init__(self, in_channels, groups, n_segment): 103 | super(WeightNet, self).__init__() 104 | self.sigmoid = nn.Sigmoid() 105 | self.groups = groups * 2 106 | 107 | self.net = WeightConvNet(in_channels, groups, n_segment) 108 | 109 | self.net.lastlayer.bias.data[...] = 0 110 | 111 | def forward(self, x): 112 | x = self.net(x) 113 | x = 2 * self.sigmoid(x) 114 | return x 115 | 116 | 117 | class TemporalDeform(nn.Module): 118 | def __init__(self, in_channels, n_segment=3, shift_div=1): 119 | super(TemporalDeform, self).__init__() 120 | self.n_segment = n_segment 121 | self.shift_div = shift_div 122 | self.in_channels = in_channels 123 | 124 | self.biasnet = BiasNet(in_channels // shift_div, 2, n_segment) 125 | self.weightnet = WeightNet(in_channels // shift_div, 2, n_segment) 126 | 127 | @staticmethod 128 | def _init_weights(weights, std): 129 | fan_out = weights.size(0) 130 | fan_in = weights.size(1) * weights.size(2) 131 | w = np.random.normal(0.0, std, (fan_out, fan_in)) 132 | return torch.from_numpy(w.reshape(weights.size())) 133 | 134 | def forward(self, x): 135 | nt, c, h, w = x.size() 136 | n_batch = nt // self.n_segment 137 | fold = c // self.shift_div 138 | 139 | out = torch.zeros_like(x) 140 | x_def = x[:, :fold, :] 141 | 142 | x_def = x_def.view(n_batch, self.n_segment, fold, h, w) 143 | 144 | x_pooled = torch.mean(x_def, 3) 145 | x_pooled_1d = torch.mean(x_pooled, 3) 146 | x_pooled_1d = x_pooled_1d.permute(0, 2, 1).contiguous() 147 | # N * T * C 148 | 149 | x_bias = self.biasnet(x_pooled_1d).view(n_batch, -1) 150 | x_weight = self.weightnet(x_pooled_1d) 151 | 152 | x_bias = torch.cat([x_bias, -x_bias], 1) 153 | x_sa = linear_sampler(x_def, x_bias) 154 | 155 | x_weight = x_weight[:, :, :, None] 156 | x_weight = x_weight.repeat(1, 1, 2, fold // 2 // 2) 157 | x_weight = x_weight.view(x_weight.size(0), x_weight.size(1), -1) 158 | 159 | x_weight = x_weight[:, :, :, None, None] 160 | x_sa = x_sa * x_weight 161 | x_sa = x_sa.contiguous().view(nt, fold, h, w) 162 | 163 | out[:, :fold, :] = x_sa 164 | out[:, fold:, :] = x[:, fold:, :] 165 | return out 166 | 167 | 168 | class CombinNet(nn.Module): 169 | def __init__(self, net1, net2): 170 | super(CombinNet, self).__init__() 171 | self.net1 = net1 172 | self.net2 = net2 173 | 174 | def forward(self, x): 175 | x = self.net1(x) 176 | x = self.net2(x) 177 | return x 178 | 179 | 180 | def make_temporal_interlace(net, n_segment, place='blockres', shift_div=1): 181 | n_segment_list = [n_segment] * 4 182 | assert n_segment_list[-1] > 0 183 | logger.info('=> n_segment per stage: {}'.format(n_segment_list)) 184 | 185 | import torchvision 186 | n_round = 1 187 | if len(list(net.layer3.children())) >= 23: 188 | logger.info('=> Using n_round {} to insert temporal shift'.format(n_round)) 189 | 190 | def make_block_interlace(stage, this_segment, shift_div): 191 | blocks = list(stage.children()) 192 | logger.info('=> Processing stage with {} blocks residual'.format(len(blocks))) 193 | for i, b in enumerate(blocks): 194 | if i % n_round == 0: 195 | tds = TemporalDeform( 196 | b.conv1.in_channels, 197 | n_segment=this_segment, 198 | 
shift_div=shift_div) 199 | blocks[i].conv1 = CombinNet(tds, blocks[i].conv1) 200 | return nn.Sequential(*blocks) 201 | 202 | net.layer1 = make_block_interlace(net.layer1, n_segment_list[0], shift_div) 203 | net.layer2 = make_block_interlace(net.layer2, n_segment_list[1], shift_div) 204 | net.layer3 = make_block_interlace(net.layer3, n_segment_list[2], shift_div) 205 | net.layer4 = make_block_interlace(net.layer4, n_segment_list[3], shift_div) 206 | -------------------------------------------------------------------------------- /x_temporal/core/models_entry.py: -------------------------------------------------------------------------------- 1 | from x_temporal.core.models import TSN 2 | from x_temporal.core.transforms import * 3 | from x_temporal.models.stresnet import * 4 | from x_temporal.models.slowfast import * 5 | from x_temporal.models.resnet3D import * 6 | 7 | import torchvision 8 | 9 | 10 | def get_model(config): 11 | num_class = config.dataset.num_class 12 | dropout = config.net.dropout 13 | arch = config.net.arch 14 | 15 | if config.net.model_type == '2D': 16 | model = TSN(num_class, config.dataset.num_segments, config.dataset.modality, 17 | base_model=config.net.arch, 18 | consensus_type=config.net.consensus_type, 19 | dropout=config.net.dropout, 20 | img_feature_dim=config.net.img_feature_dim, 21 | partial_bn=not config.trainer.no_partial_bn, 22 | is_shift=config.net.get('shift', False), shift_div=config.net.get('shift_div', 8), 23 | non_local=config.net.get('non_local', False), 24 | tin=config.net.get('tin', False), 25 | pretrain=config.net.get('pretrain', True), 26 | ) 27 | 28 | elif config.net.model_type == '3D': 29 | if arch.startswith('stresnet'): 30 | model = globals()[arch](sample_size=config.dataset.crop_size, sample_duration=config.dataset.num_segments, 31 | num_classes=num_class, max_pooling=config.net.max_pooling, dropout=dropout) 32 | elif arch.startswith('sfresnet') or arch.startswith('resnet3D'): 33 | model = globals()[arch](sample_size=config.dataset.crop_size, sample_duration=config.dataset.num_segments, 34 | num_classes=num_class, dropout=dropout) 35 | else: 36 | raise ValueError("Not Found Arch: %s" % arch) 37 | 38 | cur_device = torch.cuda.current_device() 39 | model = model.cuda(device=cur_device) 40 | if config.gpus > 1: 41 | model = torch.nn.parallel.DistributedDataParallel( 42 | module=model, device_ids=[cur_device], output_device=cur_device 43 | ) 44 | 45 | return model 46 | 47 | 48 | def get_augmentation(config): 49 | if config.dataset.modality == 'RGB': 50 | if config.dataset.flip: 51 | return torchvision.transforms.Compose([GroupMultiScaleCrop(config.dataset.crop_size, [1, .875, .75, .66]), 52 | GroupRandomHorizontalFlip(is_flow=False)]) 53 | else: 54 | return torchvision.transforms.Compose( 55 | [GroupMultiScaleCrop(config.dataset.crop_size, [1, .875, .75, .66])]) 56 | elif config.dataset.modality == 'Flow': 57 | return torchvision.transforms.Compose([GroupMultiScaleCrop(config.dataset.crop_size, [1, .875, .75]), 58 | GroupRandomHorizontalFlip(is_flow=True)]) 59 | elif config.dataset.modality == 'RGBDiff': 60 | return torchvision.transforms.Compose([GroupMultiScaleCrop(config.dataset.crop_size, [1, .875, .75]), 61 | GroupRandomHorizontalFlip(is_flow=False)]) 62 | 63 | 64 | def get_optim_policies(model, args): 65 | first_conv_weight = [] 66 | first_conv_bias = [] 67 | normal_weight = [] 68 | normal_bias = [] 69 | lr5_weight = [] 70 | lr10_bias = [] 71 | cs_weight = [] 72 | cs_bias = [] 73 | bn = [] 74 | custom_ops = [] 75 | 76 | conv_cnt = 0 77 | 
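# These counters track how many conv and BN layers have been visited while walking
# model.modules(): the first conv layer is routed into the first_conv_* groups (which
# get a larger lr_mult when the modality is Flow), and unless no_partialbn is set only
# the first BN layer contributes trainable parameters to the returned groups.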
bn_cnt = 0 78 | linear_cnt = 0 79 | fc_lr5 = not (args.tune_from and args.dataset in args.tune_from), 80 | for m in model.modules(): 81 | if isinstance(m, torch.nn.Conv2d) or isinstance( 82 | m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv3d): 83 | ps = list(m.parameters()) 84 | conv_cnt += 1 85 | if conv_cnt == 1: 86 | first_conv_weight.append(ps[0]) 87 | if len(ps) == 2: 88 | first_conv_bias.append(ps[1]) 89 | else: 90 | normal_weight.append(ps[0]) 91 | if len(ps) == 2: 92 | normal_bias.append(ps[1]) 93 | elif isinstance(m, torch.nn.Linear) and (not args.is_dtn): 94 | ps = list(m.parameters()) 95 | if fc_lr5: 96 | lr5_weight.append(ps[0]) 97 | else: 98 | normal_weight.append(ps[0]) 99 | if len(ps) == 2: 100 | if fc_lr5: 101 | lr10_bias.append(ps[1]) 102 | else: 103 | normal_bias.append(ps[1]) 104 | 105 | elif isinstance(m, torch.nn.Linear) and args.is_dtn: 106 | linear_cnt += 1 107 | ps = list(m.parameters()) 108 | 109 | if linear_cnt < 33: 110 | cs_weight.append(ps[0]) 111 | cs_bias.append(ps[1]) 112 | else: 113 | if fc_lr5: 114 | lr5_weight.append(ps[0]) 115 | else: 116 | normal_weight.append(ps[0]) 117 | if len(ps) == 2: 118 | if fc_lr5: 119 | lr10_bias.append(ps[1]) 120 | else: 121 | normal_bias.append(ps[1]) 122 | 123 | elif isinstance(m, torch.nn.BatchNorm2d): 124 | bn_cnt += 1 125 | # later BN's are frozen 126 | if args.no_partialbn or bn_cnt == 1: 127 | bn.extend(list(m.parameters())) 128 | elif isinstance(m, torch.nn.BatchNorm3d): 129 | bn_cnt += 1 130 | # later BN's are frozen 131 | if args.no_partialbn or bn_cnt == 1: 132 | bn.extend(list(m.parameters())) 133 | elif len(m._modules) == 0: 134 | if len(list(m.parameters())) > 0: 135 | raise ValueError( 136 | "New atomic module type: {}. Need to give it a learning policy".format( 137 | type(m))) 138 | 139 | return [ 140 | {'params': first_conv_weight, 'lr_mult': 5 if args.modality == 'Flow' else 1, 'decay_mult': 1, 141 | 'name': "first_conv_weight"}, 142 | {'params': first_conv_bias, 'lr_mult': 10 if args.modality == 'Flow' else 2, 'decay_mult': 0, 143 | 'name': "first_conv_bias"}, 144 | {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1, 145 | 'name': "normal_weight"}, 146 | {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0, 147 | 'name': "normal_bias"}, 148 | {'params': bn, 'lr_mult': 1, 'decay_mult': 0, 149 | 'name': "BN scale/shift"}, 150 | {'params': custom_ops, 'lr_mult': 1, 'decay_mult': 1, 151 | 'name': "custom_ops"}, 152 | # for fc 153 | {'params': lr5_weight, 'lr_mult': 5, 'decay_mult': 1, 154 | 'name': "lr5_weight"}, 155 | {'params': lr10_bias, 'lr_mult': 10, 'decay_mult': 0, 156 | 'name': "lr10_bias"}, 157 | {'params': cs_weight, 'lr_mult': 1, 'decay_mult': 1, 158 | 'name': "cs_weight"}, 159 | {'params': cs_bias, 'lr_mult': 2, 'decay_mult': 0, 160 | 'name': "cs_bias"}, 161 | ] 162 | -------------------------------------------------------------------------------- /x_temporal/core/non_local.py: -------------------------------------------------------------------------------- 1 | # Non-local block using embedded gaussian 2 | # Code from 3 | # https://github.com/AlexHex7/Non-local_pytorch/blob/master/Non-Local_pytorch_0.3.1/lib/non_local_embedded_gaussian.py 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | 9 | class _NonLocalBlockND(nn.Module): 10 | def __init__(self, in_channels, inter_channels=None, 11 | dimension=3, sub_sample=True, bn_layer=True): 12 | super(_NonLocalBlockND, self).__init__() 13 | 14 | assert dimension in [1, 2, 3] 15 | 16 | self.dimension = 
dimension 17 | self.sub_sample = sub_sample 18 | 19 | self.in_channels = in_channels 20 | self.inter_channels = inter_channels 21 | 22 | if self.inter_channels is None: 23 | self.inter_channels = in_channels // 2 24 | if self.inter_channels == 0: 25 | self.inter_channels = 1 26 | 27 | if dimension == 3: 28 | conv_nd = nn.Conv3d 29 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 30 | bn = nn.BatchNorm3d 31 | elif dimension == 2: 32 | conv_nd = nn.Conv2d 33 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 34 | bn = nn.BatchNorm2d 35 | else: 36 | conv_nd = nn.Conv1d 37 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 38 | bn = nn.BatchNorm1d 39 | 40 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 41 | kernel_size=1, stride=1, padding=0) 42 | 43 | if bn_layer: 44 | self.W = nn.Sequential( 45 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 46 | kernel_size=1, stride=1, padding=0), 47 | bn(self.in_channels) 48 | ) 49 | nn.init.constant_(self.W[1].weight, 0) 50 | nn.init.constant_(self.W[1].bias, 0) 51 | else: 52 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 53 | kernel_size=1, stride=1, padding=0) 54 | nn.init.constant_(self.W.weight, 0) 55 | nn.init.constant_(self.W.bias, 0) 56 | 57 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 58 | kernel_size=1, stride=1, padding=0) 59 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 60 | kernel_size=1, stride=1, padding=0) 61 | 62 | if sub_sample: 63 | self.g = nn.Sequential(self.g, max_pool_layer) 64 | self.phi = nn.Sequential(self.phi, max_pool_layer) 65 | 66 | def forward(self, x): 67 | ''' 68 | :param x: (b, c, t, h, w) 69 | :return: 70 | ''' 71 | 72 | batch_size = x.size(0) 73 | 74 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 75 | g_x = g_x.permute(0, 2, 1) 76 | 77 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 78 | theta_x = theta_x.permute(0, 2, 1) 79 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 80 | f = torch.matmul(theta_x, phi_x) 81 | f_div_C = F.softmax(f, dim=-1) 82 | 83 | y = torch.matmul(f_div_C, g_x) 84 | y = y.permute(0, 2, 1).contiguous() 85 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 86 | W_y = self.W(y) 87 | z = W_y + x 88 | 89 | return z 90 | 91 | 92 | class NONLocalBlock1D(_NonLocalBlockND): 93 | def __init__(self, in_channels, inter_channels=None, 94 | sub_sample=True, bn_layer=True): 95 | super(NONLocalBlock1D, self).__init__(in_channels, 96 | inter_channels=inter_channels, 97 | dimension=1, sub_sample=sub_sample, 98 | bn_layer=bn_layer) 99 | 100 | 101 | class NONLocalBlock2D(_NonLocalBlockND): 102 | def __init__(self, in_channels, inter_channels=None, 103 | sub_sample=True, bn_layer=True): 104 | super(NONLocalBlock2D, self).__init__(in_channels, 105 | inter_channels=inter_channels, 106 | dimension=2, sub_sample=sub_sample, 107 | bn_layer=bn_layer) 108 | 109 | 110 | class NONLocalBlock3D(_NonLocalBlockND): 111 | def __init__(self, in_channels, inter_channels=None, 112 | sub_sample=True, bn_layer=True): 113 | super(NONLocalBlock3D, self).__init__(in_channels, 114 | inter_channels=inter_channels, 115 | dimension=3, sub_sample=sub_sample, 116 | bn_layer=bn_layer) 117 | 118 | 119 | class NL3DWrapper(nn.Module): 120 | def __init__(self, block, n_segment): 121 | super(NL3DWrapper, self).__init__() 122 | self.block = block 123 | self.nl = NONLocalBlock3D(block.bn3.num_features) 124 | 
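# n_segment is stored so forward() can fold the flattened (n*t, c, h, w) activations
# back into (n, c, t, h, w) before the 3D non-local block and flatten them again after.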
self.n_segment = n_segment 125 | 126 | def forward(self, x): 127 | x = self.block(x) 128 | 129 | nt, c, h, w = x.size() 130 | x = x.view( 131 | nt // 132 | self.n_segment, 133 | self.n_segment, 134 | c, 135 | h, 136 | w).transpose( 137 | 1, 138 | 2) # n, c, t, h, w 139 | x = self.nl(x) 140 | x = x.transpose(1, 2).contiguous().view(nt, c, h, w) 141 | return x 142 | 143 | 144 | def make_non_local(net, n_segment): 145 | import torchvision 146 | import archs 147 | # isinstance(net, torchvision.models.ResNet) or isinstance(net, 148 | # archs.small_resnet.ResNet): 149 | if True: 150 | net.layer2 = nn.Sequential( 151 | NL3DWrapper(net.layer2[0], n_segment), 152 | net.layer2[1], 153 | NL3DWrapper(net.layer2[2], n_segment), 154 | net.layer2[3], 155 | ) 156 | net.layer3 = nn.Sequential( 157 | NL3DWrapper(net.layer3[0], n_segment), 158 | net.layer3[1], 159 | NL3DWrapper(net.layer3[2], n_segment), 160 | net.layer3[3], 161 | NL3DWrapper(net.layer3[4], n_segment), 162 | net.layer3[5], 163 | ) 164 | else: 165 | raise NotImplementedError 166 | 167 | 168 | if __name__ == '__main__': 169 | from torch.autograd import Variable 170 | import torch 171 | 172 | sub_sample = True 173 | bn_layer = True 174 | 175 | img = Variable(torch.zeros(2, 3, 20)) 176 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer) 177 | out = net(img) 178 | print(out.size()) 179 | 180 | img = Variable(torch.zeros(2, 3, 20, 20)) 181 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer) 182 | out = net(img) 183 | print(out.size()) 184 | 185 | img = Variable(torch.randn(2, 3, 10, 20, 20)) 186 | net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer) 187 | out = net(img) 188 | print(out.size()) 189 | -------------------------------------------------------------------------------- /x_temporal/core/tsm.py: -------------------------------------------------------------------------------- 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" 2 | # arXiv:1811.08383 3 | # Ji Lin*, Chuang Gan, Song Han 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class TemporalShift(nn.Module): 12 | def __init__(self, net, n_segment=3, n_div=8, inplace=False): 13 | super(TemporalShift, self).__init__() 14 | self.net = net 15 | self.n_segment = n_segment 16 | self.fold_div = n_div 17 | self.inplace = inplace 18 | 19 | def forward(self, x): 20 | x = self.shift( 21 | x, 22 | self.n_segment, 23 | fold_div=self.fold_div, 24 | inplace=self.inplace) 25 | return self.net(x) 26 | 27 | @staticmethod 28 | def shift(x, n_segment, fold_div=3, inplace=False): 29 | nt, c, h, w = x.size() 30 | n_batch = nt // n_segment 31 | x = x.view(n_batch, n_segment, c, h, w) 32 | 33 | fold = c // fold_div 34 | if inplace: 35 | out = InplaceShift.apply(x, fold) 36 | else: 37 | out = torch.zeros_like(x) 38 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left 39 | out[:, 1:, fold: 2 * fold] = x[:, :- 40 | 1, fold: 2 * fold] # shift right 41 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift 42 | 43 | return out.view(nt, c, h, w) 44 | 45 | 46 | class InplaceShift(torch.autograd.Function): 47 | # Special thanks to @raoyongming for the help to this function 48 | @staticmethod 49 | def forward(ctx, input, fold): 50 | # not support higher order gradient 51 | # input = input.detach_() 52 | ctx.fold_ = fold 53 | n, t, c, h, w = input.size() 54 | buffer = input.data.new(n, t, fold, h, w).zero_() 55 | buffer[:, :-1] = input.data[:, 
1:, :fold] 56 | input.data[:, :, :fold] = buffer 57 | buffer.zero_() 58 | buffer[:, 1:] = input.data[:, :-1, fold: 2 * fold] 59 | input.data[:, :, fold: 2 * fold] = buffer 60 | return input 61 | 62 | @staticmethod 63 | def backward(ctx, grad_output): 64 | # grad_output = grad_output.detach_() 65 | fold = ctx.fold_ 66 | n, t, c, h, w = grad_output.size() 67 | buffer = grad_output.data.new(n, t, fold, h, w).zero_() 68 | buffer[:, 1:] = grad_output.data[:, :-1, :fold] 69 | grad_output.data[:, :, :fold] = buffer 70 | buffer.zero_() 71 | buffer[:, :-1] = grad_output.data[:, 1:, fold: 2 * fold] 72 | grad_output.data[:, :, fold: 2 * fold] = buffer 73 | return grad_output, None 74 | 75 | 76 | class TemporalPool(nn.Module): 77 | def __init__(self, net, n_segment): 78 | super(TemporalPool, self).__init__() 79 | self.net = net 80 | self.n_segment = n_segment 81 | 82 | def forward(self, x): 83 | x = self.temporal_pool(x, n_segment=self.n_segment) 84 | return self.net(x) 85 | 86 | @staticmethod 87 | def temporal_pool(x, n_segment): 88 | nt, c, h, w = x.size() 89 | n_batch = nt // n_segment 90 | x = x.view( 91 | n_batch, 92 | n_segment, 93 | c, 94 | h, 95 | w).transpose( 96 | 1, 97 | 2) # n, c, t, h, w 98 | x = F.max_pool3d( 99 | x, kernel_size=( 100 | 3, 1, 1), stride=( 101 | 2, 1, 1), padding=( 102 | 1, 0, 0)) 103 | x = x.transpose(1, 2).contiguous().view(nt // 2, c, h, w) 104 | return x 105 | 106 | 107 | def make_temporal_shift(net, n_segment, n_div=8, 108 | place='blockres', temporal_pool=False): 109 | if temporal_pool: 110 | n_segment_list = [ 111 | n_segment, 112 | n_segment // 2, 113 | n_segment // 2, 114 | n_segment // 2] 115 | else: 116 | n_segment_list = [n_segment] * 4 117 | assert n_segment_list[-1] > 0 118 | 119 | import torchvision 120 | if True: # isinstance(net, torchvision.models.ResNet): 121 | if place == 'block': 122 | def make_block_temporal(stage, this_segment): 123 | blocks = list(stage.children()) 124 | for i, b in enumerate(blocks): 125 | blocks[i] = TemporalShift( 126 | b, n_segment=this_segment, n_div=n_div) 127 | return nn.Sequential(*(blocks)) 128 | 129 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0]) 130 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1]) 131 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2]) 132 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3]) 133 | 134 | elif 'blockres' in place: 135 | n_round = 1 136 | if len(list(net.layer3.children())) >= 23: 137 | print( 138 | '=> Using n_round {} to insert temporal shift'.format(n_round)) 139 | 140 | def make_block_temporal(stage, this_segment): 141 | blocks = list(stage.children()) 142 | for i, b in enumerate(blocks): 143 | if i % n_round == 0: 144 | blocks[i].conv1 = TemporalShift( 145 | b.conv1, n_segment=this_segment, n_div=n_div) 146 | return nn.Sequential(*blocks) 147 | 148 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0]) 149 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1]) 150 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2]) 151 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3]) 152 | else: 153 | raise NotImplementedError(place) 154 | 155 | 156 | def make_temporal_pool(net, n_segment): 157 | import torchvision 158 | if isinstance(net, torchvision.models.ResNet): 159 | print('=> Injecting nonlocal pooling') 160 | net.layer2 = TemporalPool(net.layer2, n_segment) 161 | else: 162 | raise NotImplementedError 163 | 164 | 165 | if __name__ == '__main__': 166 | # test inplace shift v.s. 
vanilla shift 167 | tsm1 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=False) 168 | tsm2 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=True) 169 | 170 | print('=> Testing CPU...') 171 | # test forward 172 | with torch.no_grad(): 173 | for i in range(10): 174 | x = torch.rand(2 * 8, 3, 224, 224) 175 | y1 = tsm1(x) 176 | y2 = tsm2(x) 177 | assert torch.norm(y1 - y2).item() < 1e-5 178 | 179 | # test backward 180 | with torch.enable_grad(): 181 | for i in range(10): 182 | x1 = torch.rand(2 * 8, 3, 224, 224) 183 | x1.requires_grad_() 184 | x2 = x1.clone() 185 | y1 = tsm1(x1) 186 | y2 = tsm2(x2) 187 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0] 188 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0] 189 | assert torch.norm(grad1 - grad2).item() < 1e-5 190 | 191 | print('=> Testing GPU...') 192 | tsm1.cuda() 193 | tsm2.cuda() 194 | # test forward 195 | with torch.no_grad(): 196 | for i in range(10): 197 | x = torch.rand(2 * 8, 3, 224, 224).cuda() 198 | y1 = tsm1(x) 199 | y2 = tsm2(x) 200 | assert torch.norm(y1 - y2).item() < 1e-5 201 | 202 | # test backward 203 | with torch.enable_grad(): 204 | for i in range(10): 205 | x1 = torch.rand(2 * 8, 3, 224, 224).cuda() 206 | x1.requires_grad_() 207 | x2 = x1.clone() 208 | y1 = tsm1(x1) 209 | y2 = tsm2(x2) 210 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0] 211 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0] 212 | assert torch.norm(grad1 - grad2).item() < 1e-5 213 | print('Test passed.') 214 | -------------------------------------------------------------------------------- /x_temporal/models/resnet3D.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from functools import partial 8 | 9 | 10 | 11 | __all__ = [ 12 | 'ResNet3D', 'resnet3D10', 'resnet3D18', 'resnet3D34', 'resnet3D50', 'resnet3D101', 13 | 'resnet3D152', 'resnet3D200' 14 | ] 15 | 16 | 17 | def conv3x3x3(in_planes, out_planes, stride=1): 18 | # 3x3x3 convolution with padding 19 | return nn.Conv3d( 20 | in_planes, 21 | out_planes, 22 | kernel_size=3, 23 | stride=stride, 24 | padding=1, 25 | bias=False) 26 | 27 | 28 | def downsample_basic_block(x, planes, stride): 29 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 30 | zero_pads = torch.Tensor( 31 | out.size(0), planes - out.size(1), out.size(2), out.size(3), 32 | out.size(4)).zero_() 33 | if isinstance(out.data, torch.cuda.FloatTensor): 34 | zero_pads = zero_pads.cuda() 35 | 36 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 37 | 38 | return out 39 | 40 | 41 | class BasicBlock(nn.Module): 42 | expansion = 1 43 | 44 | def __init__(self, inplanes, planes, stride=1, downsample=None): 45 | super(BasicBlock, self).__init__() 46 | self.conv1 = conv3x3x3(inplanes, planes, stride) 47 | self.bn1 = BN(planes) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.conv2 = conv3x3x3(planes, planes) 50 | self.bn2 = BN(planes) 51 | self.downsample = downsample 52 | self.stride = stride 53 | 54 | def forward(self, x): 55 | residual = x 56 | 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | 61 | out = self.conv2(out) 62 | out = self.bn2(out) 63 | 64 | if self.downsample is not None: 65 | residual = self.downsample(x) 66 | 67 | out += residual 68 | out = self.relu(out) 69 | 70 | return out 71 | 72 | 73 | class Bottleneck(nn.Module): 74 | expansion = 4 75 | 76 | def __init__(self, 
inplanes, planes, stride=1, downsample=None): 77 | super(Bottleneck, self).__init__() 78 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 79 | self.bn1 = BN(planes) 80 | self.conv2 = nn.Conv3d( 81 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 82 | self.bn2 = BN(planes) 83 | self.conv3 = nn.Conv3d( 84 | planes, 85 | planes * 4, 86 | kernel_size=1, 87 | stride=1, 88 | padding=0, 89 | dilation=1, 90 | groups=1, 91 | bias=False) 92 | self.bn3 = BN(planes * 4) 93 | self.relu = nn.ReLU(inplace=True) 94 | self.downsample = downsample 95 | self.stride = stride 96 | 97 | def forward(self, x): 98 | residual = x 99 | 100 | out = self.conv1(x) 101 | out = self.bn1(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv2(out) 105 | out = self.bn2(out) 106 | out = self.relu(out) 107 | out = self.conv3(out) 108 | out = self.bn3(out) 109 | 110 | if self.downsample is not None: 111 | residual = self.downsample(x) 112 | 113 | out += residual 114 | out = self.relu(out) 115 | 116 | return out 117 | 118 | 119 | class ResNet3D(nn.Module): 120 | 121 | def __init__(self, 122 | block, 123 | layers, 124 | sample_size, 125 | sample_duration, 126 | dropout=0.5, 127 | shortcut_type='B', 128 | num_classes=700, 129 | ): 130 | 131 | global BN 132 | 133 | BN = nn.BatchNorm3d 134 | 135 | 136 | self.inplanes = 64 137 | self.n_segment = sample_duration 138 | super(ResNet3D, self).__init__() 139 | self.conv1 = nn.Conv3d( 140 | 3, 141 | 64, 142 | kernel_size=7, 143 | stride=(1, 2, 2), 144 | padding=(3, 3, 3), 145 | bias=False) 146 | self.bn1 = BN(64) 147 | self.relu = nn.ReLU(inplace=True) 148 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) 149 | self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type) 150 | self.layer2 = self._make_layer( 151 | block, 128, layers[1], shortcut_type, stride=2) 152 | self.layer3 = self._make_layer( 153 | block, 256, layers[2], shortcut_type, stride=2) 154 | self.layer4 = self._make_layer( 155 | block, 512, layers[3], shortcut_type, stride=2) 156 | last_duration = int(math.ceil(sample_duration / 16)) 157 | last_size = int(math.ceil(sample_size / 32)) 158 | self.avgpool = nn.AvgPool3d( 159 | (last_duration, last_size, last_size), stride=1) 160 | self.dropout = nn.Dropout(dropout) 161 | self.fc = nn.Linear(512 * block.expansion, num_classes) 162 | 163 | for m in self.modules(): 164 | if isinstance(m, nn.Conv3d): 165 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 166 | elif isinstance(m, nn.BatchNorm3d): 167 | m.weight.data.fill_(1) 168 | m.bias.data.zero_() 169 | 170 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1): 171 | downsample = None 172 | if stride != 1 or self.inplanes != planes * block.expansion: 173 | if shortcut_type == 'A': 174 | downsample = partial( 175 | downsample_basic_block, 176 | planes=planes * block.expansion, 177 | stride=stride) 178 | else: 179 | downsample = nn.Sequential( 180 | nn.Conv3d( 181 | self.inplanes, 182 | planes * block.expansion, 183 | kernel_size=1, 184 | stride=stride, 185 | bias=False), BN(planes * block.expansion)) 186 | 187 | layers = [] 188 | layers.append(block(self.inplanes, planes, stride, downsample)) 189 | self.inplanes = planes * block.expansion 190 | for i in range(1, blocks): 191 | layers.append(block(self.inplanes, planes)) 192 | 193 | return nn.Sequential(*layers) 194 | 195 | def forward(self, x): 196 | x = self.conv1(x) 197 | x = self.bn1(x) 198 | x = self.relu(x) 199 | x = self.maxpool(x) 200 | 201 | x = self.layer1(x) 202 | x = 
self.layer2(x) 203 | x = self.layer3(x) 204 | x = self.layer4(x) 205 | 206 | x = self.avgpool(x) 207 | 208 | x = x.view(x.size(0), -1) 209 | x = self.dropout(x) 210 | x = self.fc(x) 211 | 212 | return x 213 | 214 | 215 | def resnet3D10(**kwargs): 216 | """Constructs a ResNet-18 model. 217 | """ 218 | model = ResNet3D(BasicBlock, [1, 1, 1, 1], **kwargs) 219 | return model 220 | 221 | 222 | def resnet3D18(**kwargs): 223 | """Constructs a ResNet-18 model. 224 | """ 225 | model = ResNet3D(BasicBlock, [2, 2, 2, 2], **kwargs) 226 | return model 227 | 228 | 229 | def resnet3D34(**kwargs): 230 | """Constructs a ResNet-34 model. 231 | """ 232 | model = ResNet3D(BasicBlock, [3, 4, 6, 3], **kwargs) 233 | return model 234 | 235 | 236 | def resnet3D50(**kwargs): 237 | """Constructs a ResNet-50 model. 238 | """ 239 | model = ResNet3D(Bottleneck, [3, 4, 6, 3], **kwargs) 240 | return model 241 | 242 | 243 | def resnet3D101(**kwargs): 244 | """Constructs a ResNet-101 model. 245 | """ 246 | model = ResNet3D(Bottleneck, [3, 4, 23, 3], **kwargs) 247 | return model 248 | 249 | 250 | def resnet3D152(**kwargs): 251 | """Constructs a ResNet-101 model. 252 | """ 253 | model = ResNet3D(Bottleneck, [3, 8, 36, 3], **kwargs) 254 | return model 255 | 256 | 257 | def resnet3D200(**kwargs): 258 | """Constructs a ResNet-101 model. 259 | """ 260 | model = ResNet3D(Bottleneck, [3, 24, 36, 3], **kwargs) 261 | return model 262 | -------------------------------------------------------------------------------- /x_temporal/models/slowfast.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | __all__ = ['sfresnet50', 'sfresnet101','sfresnet152', 'sfresnet200'] 7 | 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1, downsample=None, head_conv=1): 14 | super(Bottleneck, self).__init__() 15 | if head_conv == 1: 16 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 17 | self.bn1 = BN(planes) 18 | elif head_conv == 3: 19 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=(3, 1, 1), bias=False, padding=(1, 0, 0)) 20 | self.bn1 = BN(planes) 21 | else: 22 | raise ValueError("Unsupported head_conv!") 23 | self.conv2 = nn.Conv3d( 24 | planes, planes, kernel_size=(1, 3, 3), stride=(1,stride,stride), padding=(0, 1, 1), bias=False) 25 | self.bn2 = BN(planes) 26 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False) 27 | self.bn3 = BN(planes * 4) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | residual = x 34 | 35 | out = self.conv1(x) 36 | out = self.bn1(out) 37 | out = self.relu(out) 38 | 39 | out = self.conv2(out) 40 | out = self.bn2(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv3(out) 44 | out = self.bn3(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | out += residual 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class SlowFast(nn.Module): 55 | def __init__(self, block=Bottleneck, layers=[3, 4, 6, 3], num_classes=700, dropout=0.5, 56 | sample_duration=0, sample_size=112): 57 | 58 | super(SlowFast, self).__init__() 59 | 60 | 61 | global BN 62 | BN = nn.BatchNorm3d 63 | 64 | self.n_segment = sample_duration 65 | 66 | self.fast_inplanes = 8 67 | self.fast_conv1 = nn.Conv3d(3, 8, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 
3), bias=False) 68 | self.fast_bn1 = BN(8) 69 | self.fast_relu = nn.ReLU(inplace=True) 70 | self.fast_maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) 71 | self.fast_res2 = self._make_layer_fast(block, 8, layers[0], head_conv=3) 72 | self.fast_res3 = self._make_layer_fast( 73 | block, 16, layers[1], stride=2, head_conv=3) 74 | self.fast_res4 = self._make_layer_fast( 75 | block, 32, layers[2], stride=2, head_conv=3) 76 | self.fast_res5 = self._make_layer_fast( 77 | block, 64, layers[3], stride=2, head_conv=3) 78 | 79 | self.lateral_p1 = nn.Conv3d(8, 8*2, kernel_size=(5, 1, 1), stride=(4, 1 ,1), bias=False, padding=(2, 0, 0)) 80 | self.lateral_res2 = nn.Conv3d(32,32*2, kernel_size=(5, 1, 1), stride=(4, 1 ,1), bias=False, padding=(2, 0, 0)) 81 | self.lateral_res3 = nn.Conv3d(64,64*2, kernel_size=(5, 1, 1), stride=(4, 1 ,1), bias=False, padding=(2, 0, 0)) 82 | self.lateral_res4 = nn.Conv3d(128,128*2, kernel_size=(5, 1, 1), stride=(4, 1 ,1), bias=False, padding=(2, 0, 0)) 83 | 84 | self.slow_inplanes = 64+64//8*2 85 | self.slow_conv1 = nn.Conv3d(3, 64, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False) 86 | self.slow_bn1 = BN(64) 87 | self.slow_relu = nn.ReLU(inplace=True) 88 | self.slow_maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) 89 | self.slow_res2 = self._make_layer_slow(block, 64, layers[0], head_conv=1) 90 | self.slow_res3 = self._make_layer_slow( 91 | block, 128, layers[1], stride=2, head_conv=1) 92 | self.slow_res4 = self._make_layer_slow( 93 | block, 256, layers[2], stride=2, head_conv=3) 94 | self.slow_res5 = self._make_layer_slow( 95 | block, 512, layers[3], stride=2, head_conv=3) 96 | self.dp = nn.Dropout(dropout) 97 | self.fc = nn.Linear(self.fast_inplanes+2048, num_classes, bias=False) 98 | 99 | def forward(self, input): 100 | fast, lateral = self.FastPath(input[:, :, ::1, :, :]) 101 | slow = self.SlowPath(input[:, :, ::4, :, :], lateral) 102 | x = torch.cat([slow, fast], dim=1) 103 | x = self.dp(x) 104 | x = self.fc(x) 105 | return x 106 | 107 | 108 | 109 | def SlowPath(self, input, lateral): 110 | x = self.slow_conv1(input) 111 | x = self.slow_bn1(x) 112 | x = self.slow_relu(x) 113 | x = self.slow_maxpool(x) 114 | x = torch.cat([x, lateral[0]],dim=1) 115 | x = self.slow_res2(x) 116 | x = torch.cat([x, lateral[1]],dim=1) 117 | x = self.slow_res3(x) 118 | x = torch.cat([x, lateral[2]],dim=1) 119 | x = self.slow_res4(x) 120 | x = torch.cat([x, lateral[3]],dim=1) 121 | x = self.slow_res5(x) 122 | x = nn.AdaptiveAvgPool3d(1)(x) 123 | x = x.view(-1, x.size(1)) 124 | return x 125 | 126 | def FastPath(self, input): 127 | lateral = [] 128 | x = self.fast_conv1(input) 129 | x = self.fast_bn1(x) 130 | x = self.fast_relu(x) 131 | pool1 = self.fast_maxpool(x) 132 | lateral_p = self.lateral_p1(pool1) 133 | lateral.append(lateral_p) 134 | 135 | res2 = self.fast_res2(pool1) 136 | lateral_res2 = self.lateral_res2(res2) 137 | lateral.append(lateral_res2) 138 | 139 | res3 = self.fast_res3(res2) 140 | lateral_res3 = self.lateral_res3(res3) 141 | lateral.append(lateral_res3) 142 | 143 | res4 = self.fast_res4(res3) 144 | lateral_res4 = self.lateral_res4(res4) 145 | lateral.append(lateral_res4) 146 | 147 | res5 = self.fast_res5(res4) 148 | x = nn.AdaptiveAvgPool3d(1)(res5) 149 | x = x.view(-1, x.size(1)) 150 | 151 | return x, lateral 152 | 153 | def _make_layer_fast(self, block, planes, blocks, stride=1, head_conv=1): 154 | downsample = None 155 | if stride != 1 or self.fast_inplanes != planes * 
block.expansion:
156 |             downsample = nn.Sequential(
157 |                 nn.Conv3d(
158 |                     self.fast_inplanes,
159 |                     planes * block.expansion,
160 |                     kernel_size=1,
161 |                     stride=(1,stride,stride),
162 |                     bias=False), BN(planes * block.expansion))
163 | 
164 |         layers = []
165 |         layers.append(block(self.fast_inplanes, planes, stride, downsample, head_conv=head_conv))
166 |         self.fast_inplanes = planes * block.expansion
167 |         for i in range(1, blocks):
168 |             layers.append(block(self.fast_inplanes, planes, head_conv=head_conv))
169 |         return nn.Sequential(*layers)
170 | 
171 |     def _make_layer_slow(self, block, planes, blocks, stride=1, head_conv=1):
172 |         downsample = None
173 |         if stride != 1 or self.slow_inplanes != planes * block.expansion:
174 |             downsample = nn.Sequential(
175 |                 nn.Conv3d(
176 |                     self.slow_inplanes,
177 |                     planes * block.expansion,
178 |                     kernel_size=1,
179 |                     stride=(1,stride,stride),
180 |                     bias=False), BN(planes * block.expansion))
181 | 
182 |         layers = []
183 |         layers.append(block(self.slow_inplanes, planes, stride, downsample, head_conv=head_conv))
184 |         self.slow_inplanes = planes * block.expansion
185 |         for i in range(1, blocks):
186 |             layers.append(block(self.slow_inplanes, planes, head_conv=head_conv))
187 | 
188 |         self.slow_inplanes = planes * block.expansion + planes * block.expansion//8*2
189 |         return nn.Sequential(*layers)
190 | 
191 | 
192 | 
193 | 
194 | def sfresnet50(**kwargs):
195 |     """Constructs a SlowFast model with a ResNet-50 backbone.
196 |     """
197 |     model = SlowFast(Bottleneck, [3, 4, 6, 3], **kwargs)
198 |     return model
199 | 
200 | 
201 | def sfresnet101(**kwargs):
202 |     """Constructs a SlowFast model with a ResNet-101 backbone.
203 |     """
204 |     model = SlowFast(Bottleneck, [3, 4, 23, 3], **kwargs)
205 |     return model
206 | 
207 | 
208 | def sfresnet152(**kwargs):
209 |     """Constructs a SlowFast model with a ResNet-152 backbone.
210 |     """
211 |     model = SlowFast(Bottleneck, [3, 8, 36, 3], **kwargs)
212 |     return model
213 | 
214 | 
215 | def sfresnet200(**kwargs):
216 |     """Constructs a SlowFast model with a ResNet-200 backbone.
217 | """ 218 | model = SlowFast(Bottleneck, [3, 24, 36, 3], **kwargs) 219 | return model 220 | -------------------------------------------------------------------------------- /x_temporal/models/stresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from functools import partial 7 | from torch.nn.modules.utils import _triple 8 | 9 | 10 | __all__ = [ 11 | 'ResNet', 'stresnet10', 'stresnet18', 'stresnet34', 'stresnet50', 'stresnet101', 12 | 'stresnet152', 'stresnet200' 13 | ] 14 | 15 | 16 | class SpatioTemporalConv(nn.Module): 17 | 18 | def __init__(self, in_channels, out_channels, kernel_size=3, 19 | stride=1, padding=1, bias=False): 20 | super(SpatioTemporalConv, self).__init__() 21 | 22 | kernel_size = _triple(kernel_size) 23 | stride = _triple(stride) 24 | padding = _triple(padding) 25 | 26 | spatial_kernel_size = [1, kernel_size[1], kernel_size[2]] 27 | spatial_stride = [1, stride[1], stride[2]] 28 | spatial_padding = [0, padding[1], padding[2]] 29 | 30 | temporal_kernel_size = [kernel_size[0], 1, 1] 31 | temporal_stride = [stride[0], 1, 1] 32 | temporal_padding = [padding[0], 0, 0] 33 | 34 | intermed_channels = int(math.floor((kernel_size[0] * kernel_size[1] * kernel_size[2] * in_channels * out_channels) / 35 | (kernel_size[1] * kernel_size[2] * in_channels + kernel_size[0] * out_channels))) 36 | 37 | self.spatial_conv = nn.Conv3d(in_channels, intermed_channels, spatial_kernel_size, 38 | stride=spatial_stride, padding=spatial_padding, bias=bias) 39 | self.bn = BN(intermed_channels) 40 | self.relu = nn.ReLU() 41 | 42 | self.temporal_conv = nn.Conv3d(intermed_channels, out_channels, temporal_kernel_size, 43 | stride=temporal_stride, padding=temporal_padding, bias=bias) 44 | 45 | def forward(self, x): 46 | x = self.relu(self.bn(self.spatial_conv(x))) 47 | x = self.temporal_conv(x) 48 | return x 49 | 50 | 51 | def conv3x3x3(in_planes, out_planes, stride=1): 52 | # 3x3x3 convolution with padding 53 | return nn.Conv3d( 54 | in_planes, 55 | out_planes, 56 | kernel_size=3, 57 | stride=stride, 58 | padding=1, 59 | bias=False) 60 | 61 | 62 | def downsample_basic_block(x, planes, stride): 63 | out = F.avg_pool3d(x, kernel_size=stride, stride=stride) 64 | zero_pads = torch.Tensor( 65 | out.size(0), planes - out.size(1), out.size(2), out.size(3), 66 | out.size(4)).zero_() 67 | if isinstance(out.data, torch.cuda.FloatTensor): 68 | zero_pads = zero_pads.cuda() 69 | 70 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 71 | 72 | return out 73 | 74 | 75 | class BasicBlock(nn.Module): 76 | expansion = 1 77 | 78 | def __init__(self, inplanes, planes, stride=1, downsample=None): 79 | super(BasicBlock, self).__init__() 80 | self.conv1 = SpatioTemporalConv(inplanes, planes, 3, stride=stride) 81 | self.bn1 = BN(planes) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.conv2 = SpatioTemporalConv(planes, planes) 84 | self.bn2 = BN(planes) 85 | self.downsample = downsample 86 | self.stride = stride 87 | 88 | def forward(self, x): 89 | residual = x 90 | 91 | out = self.conv1(x) 92 | out = self.bn1(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv2(out) 96 | out = self.bn2(out) 97 | 98 | if self.downsample is not None: 99 | residual = self.downsample(x) 100 | 101 | out += residual 102 | out = self.relu(out) 103 | 104 | return out 105 | 106 | 107 | class Bottleneck(nn.Module): 108 | expansion = 4 109 | 110 | def 
__init__(self, inplanes, planes, stride=1, downsample=None): 111 | super(Bottleneck, self).__init__() 112 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 113 | self.bn1 = BN(planes) 114 | self.conv2 = SpatioTemporalConv( 115 | planes, 116 | planes, 117 | kernel_size=3, 118 | stride=stride, 119 | padding=1, 120 | bias=False) 121 | self.bn2 = BN(planes) 122 | self.conv3 = nn.Conv3d( 123 | planes, 124 | planes * 4, 125 | kernel_size=1, 126 | stride=1, 127 | padding=0, 128 | dilation=1, 129 | groups=1, 130 | bias=False) 131 | self.bn3 = BN(planes * 4) 132 | self.relu = nn.ReLU(inplace=True) 133 | self.downsample = downsample 134 | self.stride = stride 135 | 136 | def forward(self, x): 137 | residual = x 138 | 139 | out = self.conv1(x) 140 | out = self.bn1(out) 141 | out = self.relu(out) 142 | 143 | out = self.conv2(out) 144 | out = self.bn2(out) 145 | out = self.relu(out) 146 | out = self.conv3(out) 147 | out = self.bn3(out) 148 | 149 | if self.downsample is not None: 150 | residual = self.downsample(x) 151 | 152 | out += residual 153 | out = self.relu(out) 154 | 155 | return out 156 | 157 | 158 | class ResNet(nn.Module): 159 | 160 | def __init__(self, 161 | block, 162 | layers, 163 | sample_size, 164 | sample_duration, 165 | dropout=0.5, 166 | shortcut_type='B', 167 | num_classes=700, 168 | max_pooling=False): 169 | 170 | global BN 171 | BN = nn.BatchNorm3d 172 | 173 | self.inplanes = 64 174 | self.n_segment = sample_duration 175 | self.max_pooling = max_pooling 176 | super(ResNet, self).__init__() 177 | self.conv1 = SpatioTemporalConv( 178 | 3, 64, kernel_size=( 179 | 3, 7, 7), stride=( 180 | 1, 2, 2), padding=( 181 | 1, 3, 3), bias=False) 182 | self.bn1 = BN(64) 183 | self.relu = nn.ReLU(inplace=True) 184 | self.maxpool = nn.MaxPool3d( 185 | kernel_size=( 186 | 1, 3, 3), stride=( 187 | 1, 2, 2), padding=( 188 | 0, 1, 1)) 189 | self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type) 190 | self.layer2 = self._make_layer( 191 | block, 128, layers[1], shortcut_type, stride=2) 192 | self.layer3 = self._make_layer( 193 | block, 256, layers[2], shortcut_type, stride=2) 194 | self.layer4 = self._make_layer( 195 | block, 512, layers[3], shortcut_type, stride=2) 196 | last_duration = int(math.ceil(sample_duration / 8)) 197 | last_size = int(math.ceil(sample_size / 16)) 198 | self.avgpool = nn.AvgPool3d( 199 | (last_duration, last_size, last_size), stride=1) 200 | self.avgpool = torch.nn.AdaptiveAvgPool3d(1) 201 | self.dropout = nn.Dropout(dropout) 202 | self.fc = nn.Linear(512 * block.expansion, num_classes) 203 | 204 | for m in self.modules(): 205 | if isinstance(m, nn.Conv3d): 206 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 207 | elif isinstance(m, nn.BatchNorm3d): 208 | m.weight.data.fill_(1) 209 | m.bias.data.zero_() 210 | 211 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1): 212 | downsample = None 213 | if stride != 1 or self.inplanes != planes * block.expansion: 214 | if shortcut_type == 'A': 215 | downsample = partial( 216 | downsample_basic_block, 217 | planes=planes * block.expansion, 218 | stride=stride) 219 | else: 220 | downsample = nn.Sequential( 221 | nn.Conv3d( 222 | self.inplanes, 223 | planes * block.expansion, 224 | kernel_size=1, 225 | stride=stride, 226 | bias=False), BN(planes * block.expansion)) 227 | 228 | layers = [] 229 | layers.append(block(self.inplanes, planes, stride, downsample)) 230 | self.inplanes = planes * block.expansion 231 | for i in range(1, blocks): 232 | layers.append(block(self.inplanes, 
planes)) 233 | 234 | return nn.Sequential(*layers) 235 | 236 | 237 | def forward(self, x): 238 | x = self.conv1(x) 239 | x = self.bn1(x) 240 | x = self.relu(x) 241 | if self.max_pooling: 242 | x = self.maxpool(x) 243 | 244 | x = self.layer1(x) 245 | x = self.layer2(x) 246 | x = self.layer3(x) 247 | x = self.layer4(x) 248 | x = self.avgpool(x) 249 | 250 | feature = x.view(x.size(0), -1) 251 | feature = self.dropout(feature) 252 | x = self.fc(feature) 253 | 254 | return x 255 | 256 | 257 | def stresnet10(**kwargs): 258 | """Constructs a ResNet-18 model. 259 | """ 260 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 261 | return model 262 | 263 | 264 | def stresnet18(**kwargs): 265 | """Constructs a ResNet-18 model. 266 | """ 267 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 268 | return model 269 | 270 | 271 | def stresnet34(**kwargs): 272 | """Constructs a ResNet-34 model. 273 | """ 274 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 275 | return model 276 | 277 | 278 | def stresnet50(**kwargs): 279 | """Constructs a ResNet-50 model. 280 | """ 281 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 282 | return model 283 | 284 | 285 | def stresnet101(**kwargs): 286 | """Constructs a ResNet-101 model. 287 | """ 288 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 289 | return model 290 | 291 | 292 | def stresnet152(**kwargs): 293 | """Constructs a ResNet-101 model. 294 | """ 295 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 296 | return model 297 | 298 | 299 | def stresnet200(**kwargs): 300 | """Constructs a ResNet-101 model. 301 | """ 302 | model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs) 303 | return model 304 | -------------------------------------------------------------------------------- /x_temporal/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | from torch.utils.checkpoint import checkpoint 5 | 6 | BN = None 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet50c', 'resnet50d', 'resnet101', 9 | 'resnet152'] 10 | 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | } 19 | 20 | 21 | def conv3x3(in_planes, out_planes, stride=1): 22 | "3x3 convolution with padding" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 24 | padding=1, bias=False) 25 | 26 | 27 | class BasicBlock(nn.Module): 28 | expansion = 1 29 | 30 | def __init__(self, inplanes, planes, stride=1, downsample=None): 31 | super(BasicBlock, self).__init__() 32 | self.conv1 = conv3x3(inplanes, planes, stride) 33 | self.bn1 = BN(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes) 36 | self.bn2 = BN(planes) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | residual = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(x) 52 | 53 | out += residual 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | 59 | 
class Bottleneck(nn.Module): 60 | expansion = 4 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None): 63 | super(Bottleneck, self).__init__() 64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 65 | self.bn1 = BN(planes) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 67 | padding=1, bias=False) 68 | self.bn2 = BN(planes) 69 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 70 | self.bn3 = BN(planes * 4) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | 76 | def forward(self, x): 77 | residual = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | if self.downsample is not None: 91 | residual = self.downsample(x) 92 | 93 | out += residual 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | 99 | class ResNet(nn.Module): 100 | 101 | def __init__(self, block, layers, num_classes=1000, deep_stem=False, 102 | avg_down=False, 103 | ): 104 | 105 | global BN 106 | BN = nn.BatchNorm2d 107 | 108 | 109 | self.inplanes = 64 110 | super(ResNet, self).__init__() 111 | 112 | self.deep_stem = deep_stem 113 | self.avg_down = avg_down 114 | 115 | if self.deep_stem: 116 | self.conv1 = nn.Sequential( 117 | nn.Conv2d( 118 | 3, 119 | 32, 120 | kernel_size=3, 121 | stride=2, 122 | padding=1, 123 | bias=False), 124 | BN(32), 125 | nn.ReLU(inplace=True), 126 | nn.Conv2d( 127 | 32, 128 | 32, 129 | kernel_size=3, 130 | stride=1, 131 | padding=1, 132 | bias=False), 133 | BN(32), 134 | nn.ReLU(inplace=True), 135 | nn.Conv2d( 136 | 32, 137 | 64, 138 | kernel_size=3, 139 | stride=1, 140 | padding=1, 141 | bias=False), 142 | ) 143 | else: 144 | self.conv1 = nn.Conv2d( 145 | 3, 146 | 64, 147 | kernel_size=7, 148 | stride=2, 149 | padding=3, 150 | bias=False) 151 | self.bn1 = BN(64) 152 | self.relu = nn.ReLU(inplace=True) 153 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 154 | self.layer1 = self._make_layer(block, 64, layers[0]) 155 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 156 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 157 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 158 | self.avgpool = nn.AdaptiveAvgPool2d(1) 159 | self.fc = nn.Linear(512 * block.expansion, num_classes) 160 | 161 | for m in self.modules(): 162 | if isinstance(m, nn.Conv2d): 163 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 164 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 165 | elif isinstance(m, nn.BatchNorm2d): 166 | m.weight.data.fill_(1) 167 | m.bias.data.zero_() 168 | 169 | 170 | def _make_layer(self, block, planes, blocks, stride=1, avg_down=False): 171 | downsample = None 172 | if stride != 1 or self.inplanes != planes * block.expansion: 173 | if self.avg_down: 174 | downsample = nn.Sequential( 175 | nn.AvgPool2d( 176 | stride, 177 | stride=stride, 178 | ceil_mode=True, 179 | count_include_pad=False), 180 | nn.Conv2d(self.inplanes, planes * block.expansion, 181 | kernel_size=1, stride=1, bias=False), 182 | BN(planes * block.expansion), 183 | ) 184 | else: 185 | downsample = nn.Sequential( 186 | nn.Conv2d(self.inplanes, planes * block.expansion, 187 | kernel_size=1, stride=stride, bias=False), 188 | BN(planes * block.expansion), 189 | ) 190 | 191 | layers = [] 192 | layers.append(block(self.inplanes, planes, stride, downsample)) 193 | self.inplanes = planes * block.expansion 194 | self.dropout = nn.Dropout(0.5) 195 | for i in range(1, blocks): 196 | layers.append(block(self.inplanes, planes)) 197 | 198 | return nn.Sequential(*layers) 199 | 200 | def forward(self, x): 201 | x = self.conv1(x) 202 | x = self.bn1(x) 203 | x = self.relu(x) 204 | x = self.maxpool(x) 205 | 206 | x = self.layer1(x) 207 | x = self.layer2(x) 208 | x = self.layer3(x) 209 | x = self.layer4(x) 210 | 211 | x = self.avgpool(x) 212 | x = x.view(x.size(0), -1) 213 | x = self.dropout(x) 214 | x = self.fc(x) 215 | 216 | return x 217 | 218 | 219 | def resnet18(pretrained=False, **kwargs): 220 | """Constructs a ResNet-18 model. 221 | 222 | Args: 223 | pretrained (bool): If True, returns a model pre-trained on ImageNet 224 | """ 225 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 226 | if pretrained: 227 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 228 | return model 229 | 230 | 231 | def resnet34(pretrained=False, **kwargs): 232 | """Constructs a ResNet-34 model. 233 | 234 | Args: 235 | pretrained (bool): If True, returns a model pre-trained on ImageNet 236 | """ 237 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 238 | if pretrained: 239 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 240 | return model 241 | 242 | 243 | def resnet50(pretrained=False, **kwargs): 244 | """Constructs a ResNet-50 model. 245 | 246 | Args: 247 | pretrained (bool): If True, returns a model pre-trained on ImageNet 248 | """ 249 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 250 | if pretrained: 251 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 252 | return model 253 | 254 | 255 | def resnet50c(pretrained=False, **kwargs): 256 | """Constructs a ResNet-50 model. 257 | 258 | Args: 259 | pretrained (bool): If True, returns a model pre-trained on ImageNet 260 | """ 261 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs, deep_stem=True) 262 | if pretrained: 263 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 264 | return model 265 | 266 | 267 | def resnet50d(pretrained=False, **kwargs): 268 | """Constructs a ResNet-50 model. 269 | 270 | Args: 271 | pretrained (bool): If True, returns a model pre-trained on ImageNet 272 | """ 273 | model = ResNet(Bottleneck, [3, 4, 6, 3], ** 274 | kwargs, deep_stem=True, avg_down=True) 275 | if pretrained: 276 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 277 | return model 278 | 279 | 280 | def resnet101(pretrained=False, **kwargs): 281 | """Constructs a ResNet-101 model. 
282 | 283 | Args: 284 | pretrained (bool): If True, returns a model pre-trained on ImageNet 285 | """ 286 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 287 | if pretrained: 288 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 289 | return model 290 | 291 | 292 | def resnet152(pretrained=False, **kwargs): 293 | """Constructs a ResNet-152 model. 294 | 295 | Args: 296 | pretrained (bool): If True, returns a model pre-trained on ImageNet 297 | """ 298 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 299 | if pretrained: 300 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 301 | return model 302 | -------------------------------------------------------------------------------- /x_temporal/core/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import logging 4 | 5 | import torch 6 | import torch.utils.data as data 7 | from PIL import Image 8 | import numpy as np 9 | from numpy.random import randint 10 | from decord import VideoReader 11 | from decord import cpu 12 | 13 | logger = logging.getLogger('global') 14 | 15 | 16 | 17 | class VideoRecord(object): 18 | def __init__(self, row): 19 | self._data = row 20 | 21 | @property 22 | def path(self): 23 | return self._data[0] 24 | 25 | @property 26 | def num_frames(self): 27 | return int(self._data[1]) 28 | 29 | @property 30 | def label(self): 31 | return int(self._data[2]) 32 | 33 | @property 34 | def mlabel(self): 35 | labels = torch.tensor([int(x) 36 | for x in self._data[2].split(',')]).long() 37 | onehot = torch.FloatTensor(313) 38 | onehot.zero_() 39 | onehot[labels] = 1 40 | return onehot 41 | 42 | 43 | class VideoDataSet(data.Dataset): 44 | def __init__(self, root_path, list_file, 45 | num_segments=3, new_length=1, modality='RGB', 46 | image_tmpl='img_{:05d}.jpg', transform=None, 47 | random_shift=True, test_mode=False, 48 | remove_missing=False, dense_sample=False, multi_class=False, 49 | temporal_samples=1, reverse_samples=False, dense_sample_rate=2, 50 | video_source=False): 51 | 52 | self.root_path = root_path 53 | self.list_file = list_file 54 | self.num_segments = num_segments 55 | self.new_length = new_length 56 | self.modality = modality 57 | self.image_tmpl = image_tmpl 58 | self.transform = transform 59 | self.random_shift = random_shift 60 | self.test_mode = test_mode 61 | self.remove_missing = remove_missing 62 | self.dense_sample = dense_sample # using dense sample as I3D 63 | self.multi_class = multi_class 64 | self.dense_sample_rate = dense_sample_rate 65 | self.video_source = video_source 66 | 67 | # new args for test 68 | self.temporal_samples = temporal_samples 69 | self.reverse_samples = reverse_samples 70 | 71 | if self.dense_sample: 72 | logger.info('=> Using dense sample for the dataset...') 73 | 74 | if self.modality == 'RGBDiff': 75 | self.new_length += 1 # Diff needs one more image to calculate diff 76 | 77 | self._parse_list() 78 | 79 | 80 | def _load_image(self, directory, idx): 81 | if self.modality == 'RGB' or self.modality == 'RGBDiff': 82 | try: 83 | filename = os.path.join( 84 | self.root_path, directory, self.image_tmpl.format(idx)) 85 | img = Image.open(filename).convert('RGB') 86 | return [img] 87 | except Exception as e: 88 | logger.info(e) 89 | logger.info( 90 | 'error loading image: %s' % 91 | os.path.join( 92 | self.root_path, 93 | directory, 94 | self.image_tmpl.format(idx))) 95 | return [Image.open(os.path.join( 96 | self.root_path, directory, 
self.image_tmpl.format(1))).convert('RGB')]
97 |         elif self.modality == 'Flow':
98 |             if self.image_tmpl == 'flow_{}_{:05d}.jpg':  # ucf
99 |                 x_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('x', idx))).convert(
100 |                     'L')
101 |                 y_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('y', idx))).convert(
102 |                     'L')
103 |             elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':  # something v1 flow
104 |                 x_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)), self.image_tmpl.
105 |                                                 format(int(directory), 'x', idx))).convert('L')
106 |                 y_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)), self.image_tmpl.
107 |                                                 format(int(directory), 'y', idx))).convert('L')
108 |             else:
109 |                 try:
110 |                     filename = os.path.join(
111 |                         self.root_path, directory, self.image_tmpl.format(idx))
112 |                     flow = Image.open(filename).convert('RGB')
113 |                 except Exception:
114 |                     logger.info('error loading flow file: %s' %
115 |                                 os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
116 |                     flow = Image.open(
117 |                         os.path.join(
118 |                             self.root_path,
119 |                             directory,
120 |                             self.image_tmpl.format(1))).convert('RGB')
121 |                 flow_x, flow_y, _ = flow.split()
122 |                 x_img = flow_x.convert('L')
123 |                 y_img = flow_y.convert('L')
124 | 
125 |             return [x_img, y_img]
126 | 
127 |     def _parse_list(self):
128 |         # keep only videos that have at least 3 frames
129 |         tmp = [x.strip().split(' ') for x in open(self.list_file)]
130 |         if not self.test_mode or self.remove_missing:
131 |             tmp = [item for item in tmp if int(item[1]) >= 3]
132 |         self.video_list = [VideoRecord(item) for item in tmp]
133 | 
134 |         if self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
135 |             for v in self.video_list:
136 |                 v._data[1] = int(v._data[1]) / 2
137 |         logger.info('video number:%d' % (len(self.video_list)))
138 | 
139 |     def _sample_indices(self, record):
140 |         """
141 | 
142 |         :param record: VideoRecord
143 |         :return: list
144 |         """
145 |         if self.dense_sample:  # i3d dense sample
146 |             sample_range = self.num_segments * self.dense_sample_rate
147 |             sample_pos = max(1, 1 + record.num_frames - sample_range)
148 |             t_stride = self.dense_sample_rate
149 |             start_idx = 0 if sample_pos == 1 else np.random.randint(
150 |                 0, sample_pos - 1)
151 |             offsets = [
152 |                 (idx * t_stride + start_idx) %
153 |                 record.num_frames for idx in range(
154 |                     self.num_segments)]
155 |             return np.array(offsets) + 1
156 |         else:  # normal sample
157 |             average_duration = (
158 |                 record.num_frames - self.new_length + 1) // self.num_segments
159 |             if average_duration > 0:
160 |                 offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration,
161 |                                                                                                   size=self.num_segments)
162 |             elif record.num_frames > self.num_segments:
163 |                 offsets = np.sort(
164 |                     randint(
165 |                         record.num_frames -
166 |                         self.new_length +
167 |                         1,
168 |                         size=self.num_segments))
169 |             else:
170 |                 offsets = np.zeros((self.num_segments,))
171 |             return offsets + 1
172 | 
173 |     def _get_val_indices(self, record):
174 |         if self.dense_sample:  # i3d dense sample
175 |             sample_range = self.num_segments * self.dense_sample_rate
176 |             sample_pos = max(1, 1 + record.num_frames - sample_range)
177 |             t_stride = self.dense_sample_rate
178 |             if self.temporal_samples == 1:
179 |                 start_idx = 0 if sample_pos == 1 else sample_pos // 2
180 |                 offsets = [
181 |                     (idx * t_stride + start_idx) %
182 |                     record.num_frames for idx in range(
183 |                         self.num_segments)]
184 |             else:
185 |                 start_list = np.linspace(0, sample_pos - 1,
num=self.temporal_samples, dtype=int) 186 | offsets = [] 187 | for start_idx in start_list.tolist(): 188 | offsets += [(idx * t_stride + start_idx) % 189 | record.num_frames for idx in range(self.num_segments)] 190 | return np.array(offsets) + 1 191 | else: 192 | t_offsets = [] 193 | tick = (record.num_frames - self.new_length + 1) / \ 194 | float(self.num_segments) 195 | offsets = np.array([1 + int(tick / 2.0 + tick * x) 196 | for x in range(self.num_segments)]) 197 | t_offsets.append(offsets) 198 | 199 | average_duration = ( 200 | record.num_frames - self.new_length + 1) // self.num_segments 201 | for i in range(self.temporal_samples - 1): 202 | offsets = np.multiply(list(range(self.num_segments)), 203 | average_duration) + randint(average_duration, 204 | size=self.num_segments) 205 | t_offsets.append(offsets + 1) 206 | 207 | t_offsets = np.stack(t_offsets).reshape(-1) 208 | return t_offsets 209 | 210 | def _get_test_indices(self, record): 211 | if self.dense_sample: 212 | sample_range = self.num_segments * self.dense_sample_rate 213 | sample_pos = max(1, 1 + record.num_frames - sample_range) 214 | t_stride = self.dense_sample_rate 215 | start_list = np.linspace(0, sample_pos - 1, num=self.temporal_samples, dtype=int) 216 | offsets = [] 217 | for start_idx in start_list.tolist(): 218 | offsets += [(idx * t_stride + start_idx) % 219 | record.num_frames for idx in range(self.num_segments)] 220 | return np.array(offsets) + 1 221 | else: 222 | t_offsets = [] 223 | tick = (record.num_frames - self.new_length + 1) / \ 224 | float(self.num_segments) 225 | offsets = np.array([1 + int(tick / 2.0 + tick * x) 226 | for x in range(self.num_segments)]) 227 | t_offsets.append(offsets) 228 | 229 | average_duration = ( 230 | record.num_frames - self.new_length + 1) // self.num_segments 231 | for i in range(self.temporal_samples - 1): 232 | offsets = np.multiply(list(range(self.num_segments)), 233 | average_duration) + randint(average_duration, 234 | size=self.num_segments) 235 | t_offsets.append(offsets + 1) 236 | 237 | t_offsets = np.stack(t_offsets).reshape(-1) 238 | return t_offsets 239 | 240 | def __getitem__(self, index): 241 | record = self.video_list[index] 242 | 243 | # check this is a legit video folder 244 | if self.video_source: 245 | full_path = os.path.join(self.root_path, record.path) 246 | while not os.path.exists(full_path): 247 | logger.info( 248 | '################## Not Found: %s' % 249 | os.path.join( 250 | self.root_path, 251 | record.path)) 252 | index = np.random.randint(len(self.video_list)) 253 | record = self.video_list[index] 254 | full_path = os.path.join(self.root_path, record.path) 255 | else: 256 | if self.image_tmpl == 'flow_{}_{:05d}.jpg': 257 | file_name = self.image_tmpl.format('x', 1) 258 | full_path = os.path.join( 259 | self.root_path, record.path, file_name) 260 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg': 261 | file_name = self.image_tmpl.format(int(record.path), 'x', 1) 262 | full_path = os.path.join( 263 | self.root_path, '{:06d}'.format(int(record.path)), file_name) 264 | else: 265 | file_name = self.image_tmpl.format(1) 266 | full_path = os.path.join( 267 | self.root_path, record.path, file_name) 268 | while not os.path.exists(full_path): 269 | logger.info( 270 | '################## Not Found: %s' % 271 | os.path.join( 272 | self.root_path, 273 | record.path, 274 | file_name)) 275 | index = np.random.randint(len(self.video_list)) 276 | record = self.video_list[index] 277 | if self.image_tmpl == 'flow_{}_{:05d}.jpg': 278 | file_name = 
self.image_tmpl.format('x', 1) 279 | full_path = os.path.join( 280 | self.root_path, record.path, file_name) 281 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg': 282 | file_name = self.image_tmpl.format( 283 | int(record.path), 'x', 1) 284 | full_path = os.path.join( 285 | self.root_path, '{:06d}'.format(int(record.path)), file_name) 286 | else: 287 | file_name = self.image_tmpl.format(1) 288 | full_path = os.path.join( 289 | self.root_path, record.path, file_name) 290 | 291 | if not self.test_mode: 292 | segment_indices = self._sample_indices( 293 | record) if self.random_shift else self._get_val_indices(record) 294 | else: 295 | segment_indices = self._get_test_indices(record) 296 | return self.get(record, segment_indices, record.path) 297 | 298 | def get(self, record, indices, path): 299 | images = list() 300 | if not self.video_source: 301 | for seg_ind in indices: 302 | p = int(seg_ind) 303 | seg_imgs = self._load_image(path, p) 304 | images.extend(seg_imgs) 305 | else: 306 | vr = VideoReader( 307 | os.path.join( 308 | self.root_path, 309 | record.path), 310 | ctx=cpu(0)) 311 | for seg_ind in indices: 312 | try: 313 | images.append(Image.fromarray(vr[seg_ind-1].asnumpy())) 314 | except Exception as e: 315 | images.append(Image.fromarray(vr[0].asnumpy())) 316 | 317 | process_data = self.transform(images) 318 | if self.multi_class: 319 | return process_data, record.mlabel 320 | else: 321 | return process_data, record.label 322 | 323 | def __len__(self): 324 | return len(self.video_list) 325 | -------------------------------------------------------------------------------- /x_temporal/core/models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from x_temporal.core.basic_ops import ConsensusModule 4 | from x_temporal.core.transforms import * 5 | from torch.nn.init import normal_, constant_ 6 | from x_temporal.models.resnet import * 7 | from x_temporal.models.slowfast import * 8 | 9 | import logging 10 | logger = logging.getLogger('global') 11 | 12 | 13 | 14 | class TSN(nn.Module): 15 | def __init__(self, num_class, num_segments, modality, 16 | base_model='resnet101', new_length=None, 17 | consensus_type='avg', before_softmax=True, 18 | dropout=0.8, img_feature_dim=256, 19 | crop_num=1, partial_bn=True, print_spec=True, pretrain=True, 20 | is_shift=False, shift_div=8, shift_place='blockres', 21 | temporal_pool=False, non_local=False, tin=False): 22 | super(TSN, self).__init__() 23 | self.modality = modality 24 | self.num_segments = num_segments 25 | self.reshape = True 26 | self.before_softmax = before_softmax 27 | self.dropout = dropout 28 | self.crop_num = crop_num 29 | self.consensus_type = consensus_type 30 | # the dimension of the CNN feature to represent each frame 31 | self.img_feature_dim = img_feature_dim 32 | self.pretrain = pretrain 33 | 34 | self.is_shift = is_shift 35 | self.shift_div = shift_div 36 | self.shift_place = shift_place 37 | 38 | 39 | self.tin = tin 40 | 41 | self.base_model_name = base_model 42 | self.temporal_pool = temporal_pool 43 | self.non_local = non_local 44 | 45 | if not before_softmax and consensus_type != 'avg': 46 | raise ValueError("Only avg consensus can be used after Softmax") 47 | 48 | if new_length is None: 49 | self.new_length = 1 if modality == "RGB" else 5 50 | else: 51 | self.new_length = new_length 52 | if print_spec: 53 | logger.info((""" 54 | Initializing with base model: {}. 
55 | Model Configurations: 56 | input_modality: {} 57 | num_segments: {} 58 | new_length: {} 59 | consensus_module: {} 60 | dropout_ratio: {} 61 | img_feature_dim: {} 62 | """.format(base_model, self.modality, self.num_segments, self.new_length, consensus_type, self.dropout, self.img_feature_dim))) 63 | 64 | self._prepare_base_model(base_model) 65 | 66 | feature_dim = self._prepare_tsn(num_class) 67 | 68 | if self.modality == 'Flow': 69 | logger.info("Converting the ImageNet model to a flow init model") 70 | self.base_model = self._construct_flow_model(self.base_model) 71 | logger.info("Done. Flow model ready...") 72 | elif self.modality == 'RGBDiff': 73 | logger.info("Converting the ImageNet model to RGB+Diff init model") 74 | self.base_model = self._construct_diff_model(self.base_model) 75 | logger.info("Done. RGBDiff model ready.") 76 | 77 | self.consensus = ConsensusModule(consensus_type) 78 | 79 | if not self.before_softmax: 80 | self.softmax = nn.Softmax() 81 | 82 | self._enable_pbn = partial_bn 83 | if partial_bn: 84 | self.partialBN(True) 85 | 86 | def _prepare_tsn(self, num_class): 87 | feature_dim = getattr(self.base_model, 88 | self.base_model.last_layer_name).in_features 89 | if self.dropout == 0: 90 | setattr( 91 | self.base_model, 92 | self.base_model.last_layer_name, 93 | nn.Linear( 94 | feature_dim, 95 | num_class)) 96 | self.new_fc = None 97 | else: 98 | setattr( 99 | self.base_model, 100 | self.base_model.last_layer_name, 101 | nn.Dropout( 102 | p=self.dropout)) 103 | if self.consensus_type in ['TRN', 'TRNmultiscale']: 104 | self.new_fc = nn.Linear(feature_dim, self.img_feature_dim) 105 | else: 106 | self.new_fc = nn.Linear(feature_dim, num_class) 107 | 108 | std = 0.001 109 | if self.new_fc is None: 110 | normal_( 111 | getattr( 112 | self.base_model, 113 | self.base_model.last_layer_name).weight, 114 | 0, 115 | std) 116 | constant_( 117 | getattr( 118 | self.base_model, 119 | self.base_model.last_layer_name).bias, 120 | 0) 121 | else: 122 | if hasattr(self.new_fc, 'weight'): 123 | normal_(self.new_fc.weight, 0, std) 124 | constant_(self.new_fc.bias, 0) 125 | return feature_dim 126 | 127 | def _prepare_base_model(self, base_model, config={}): 128 | logger.info('=> base model: {}'.format(base_model)) 129 | 130 | if base_model.startswith('resnet'): 131 | self.base_model = globals()[base_model](pretrained=self.pretrain) 132 | 133 | if self.is_shift: 134 | logger.info('Adding temporal shift...') 135 | from x_temporal.core.tsm import make_temporal_shift 136 | make_temporal_shift(self.base_model, self.num_segments, 137 | n_div=self.shift_div, place=self.shift_place) 138 | 139 | if self.tin: 140 | logger.info('Adding temporal interlace conv...') 141 | from x_temporal.core.tin import make_temporal_interlace 142 | make_temporal_interlace( 143 | self.base_model, 144 | self.num_segments, 145 | shift_div=self.shift_div) 146 | 147 | if self.non_local: 148 | logger.info('Adding non-local module...') 149 | from x_temporal.core.non_local import make_non_local 150 | make_non_local(self.base_model, self.num_segments) 151 | 152 | self.base_model.last_layer_name = 'fc' 153 | self.input_size = 224 154 | self.input_mean = [0.485, 0.456, 0.406] 155 | self.input_std = [0.229, 0.224, 0.225] 156 | 157 | self.base_model.avgpool = nn.AdaptiveAvgPool2d(1) 158 | 159 | if self.modality == 'Flow': 160 | self.input_mean = [0.5] 161 | self.input_std = [np.mean(self.input_std)] 162 | elif self.modality == 'RGBDiff': 163 | self.input_mean = [0.485, 0.456, 0.406] + \ 164 | [0] * 3 * self.new_length 165 | 
self.input_std = self.input_std + \ 166 | [np.mean(self.input_std) * 2] * 3 * self.new_length 167 | 168 | def train(self, mode=True): 169 | """ 170 | Override the default train() to freeze the BN parameters 171 | :return: 172 | """ 173 | super(TSN, self).train(mode) 174 | count = 0 175 | if self._enable_pbn and mode: 176 | logger.info("Freezing BatchNorm2D except the first one.") 177 | for m in self.base_model.modules(): 178 | if isinstance(m, nn.BatchNorm2d): 179 | count += 1 180 | if count >= (2 if self._enable_pbn else 1): 181 | m.eval() 182 | 183 | 184 | def partialBN(self, enable): 185 | self._enable_pbn = enable 186 | 187 | def forward(self, input, no_reshape=False): 188 | if not no_reshape: 189 | sample_len = (3 if self.modality == "RGB" else 2) * self.new_length 190 | 191 | if self.modality == 'RGBDiff': 192 | sample_len = 3 * self.new_length 193 | input = self._get_diff(input) 194 | 195 | base_out = self.base_model(input.view( 196 | (-1, sample_len) + input.size()[-2:])) 197 | else: 198 | base_out = self.base_model(input) 199 | 200 | if self.dropout > 0: 201 | base_out = self.new_fc(base_out) 202 | 203 | if not self.before_softmax: 204 | base_out = self.softmax(base_out) 205 | 206 | if self.reshape: 207 | if self.is_shift and self.temporal_pool: 208 | base_out = base_out.view( 209 | (-1, self.num_segments // 2) + base_out.size()[1:]) 210 | else: 211 | base_out = base_out.view( 212 | (-1, self.num_segments) + base_out.size()[1:]) 213 | output = self.consensus(base_out) 214 | return output.squeeze(1) 215 | 216 | def _get_diff(self, input, keep_rgb=False): 217 | input_c = 3 if self.modality in ["RGB", "RGBDiff"] else 2 218 | input_view = input.view( 219 | (-1, self.num_segments, self.new_length + 1, input_c,) + input.size()[2:]) 220 | if keep_rgb: 221 | new_data = input_view.clone() 222 | else: 223 | new_data = input_view[:, :, 1:, :, :, :].clone() 224 | 225 | for x in reversed(list(range(1, self.new_length + 1))): 226 | if keep_rgb: 227 | new_data[:, :, x, :, :, :] = input_view[:, :, x, 228 | :, :, :] - input_view[:, :, x - 1, :, :, :] 229 | else: 230 | new_data[:, :, x - 1, :, :, :] = input_view[:, :, 231 | x, :, :, :] - input_view[:, :, x - 1, :, :, :] 232 | 233 | return new_data 234 | 235 | def _construct_flow_model(self, base_model): 236 | # modify the convolution layers 237 | # Torch models are usually defined in a hierarchical way. 
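        # Rough sketch of the cross-modality initialisation performed below, assuming a
        # ResNet-style stem whose first Conv2d holds ImageNet-pretrained RGB kernels:
        #   rgb_w  = conv1.weight.data                      # (out_channels, 3, kH, kW)
        #   flow_w = rgb_w.mean(dim=1, keepdim=True)        # (out_channels, 1, kH, kW)
        #   flow_w = flow_w.expand(out_channels, 2 * new_length, kH, kW).contiguous()
        # i.e. the RGB kernels are averaged over the input-channel axis and tiled so that
        # the rebuilt first convolution accepts 2 * new_length stacked optical-flow
        # channels (10 channels for the default Flow new_length of 5).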
238 | # nn.modules.children() return all sub modules in a DFS manner 239 | modules = list(self.base_model.modules()) 240 | first_conv_idx = list( 241 | filter( 242 | lambda x: isinstance( 243 | modules[x], nn.Conv2d), list( 244 | range( 245 | len(modules)))))[0] 246 | conv_layer = modules[first_conv_idx] 247 | container = modules[first_conv_idx - 1] 248 | 249 | # modify parameters, assume the first blob contains the convolution 250 | # kernels 251 | params = [x.clone() for x in conv_layer.parameters()] 252 | kernel_size = params[0].size() 253 | new_kernel_size = kernel_size[:1] + \ 254 | (2 * self.new_length, ) + kernel_size[2:] 255 | new_kernels = params[0].data.mean( 256 | dim=1, keepdim=True).expand(new_kernel_size).contiguous() 257 | 258 | new_conv = nn.Conv2d(2 * self.new_length, conv_layer.out_channels, 259 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, 260 | bias=True if len(params) == 2 else False) 261 | new_conv.weight.data = new_kernels 262 | if len(params) == 2: 263 | new_conv.bias.data = params[1].data # add bias if neccessary 264 | # remove .weight suffix to get the layer name 265 | layer_name = list(container.state_dict().keys())[0][:-7] 266 | 267 | # replace the first convlution layer 268 | setattr(container, layer_name, new_conv) 269 | 270 | if self.base_model_name == 'BNInception': 271 | import torch.utils.model_zoo as model_zoo 272 | sd = model_zoo.load_url( 273 | 'https://www.dropbox.com/s/35ftw2t4mxxgjae/BNInceptionFlow-ef652051.pth.tar?dl=1') 274 | base_model.load_state_dict(sd) 275 | logger.info('=> Loading pretrained Flow weight done...') 276 | else: 277 | logger.info('#' * 30, 'Warning! No Flow pretrained model is found') 278 | return base_model 279 | 280 | def _construct_diff_model(self, base_model, keep_rgb=False): 281 | # modify the convolution layers 282 | # Torch models are usually defined in a hierarchical way. 
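        # Same idea as _construct_flow_model, but for stacked RGB differences: the first
        # Conv2d is rebuilt to take 3 * new_length difference channels (3 + 3 * new_length
        # when keep_rgb is True), initialised from the channel-wise mean of the pretrained
        # RGB kernels.  NOTE: a bare filter(...) object is not subscriptable under
        # Python 3, so the index lookup below would normally need the list(...) wrapper
        # that the flow variant above already uses.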
283 | # nn.modules.children() return all sub modules in a DFS manner 284 | modules = list(self.base_model.modules()) 285 | first_conv_idx = filter( 286 | lambda x: isinstance( 287 | modules[x], nn.Conv2d), list( 288 | range( 289 | len(modules))))[0] 290 | conv_layer = modules[first_conv_idx] 291 | container = modules[first_conv_idx - 1] 292 | 293 | # modify parameters, assume the first blob contains the convolution 294 | # kernels 295 | params = [x.clone() for x in conv_layer.parameters()] 296 | kernel_size = params[0].size() 297 | if not keep_rgb: 298 | new_kernel_size = kernel_size[:1] + \ 299 | (3 * self.new_length,) + kernel_size[2:] 300 | new_kernels = params[0].data.mean( 301 | dim=1, keepdim=True).expand(new_kernel_size).contiguous() 302 | else: 303 | new_kernel_size = kernel_size[:1] + \ 304 | (3 * self.new_length,) + kernel_size[2:] 305 | new_kernels = torch.cat((params[0].data, params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()), 306 | 1) 307 | new_kernel_size = kernel_size[:1] + \ 308 | (3 + 3 * self.new_length,) + kernel_size[2:] 309 | 310 | new_conv = nn.Conv2d(new_kernel_size[1], conv_layer.out_channels, 311 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, 312 | bias=True if len(params) == 2 else False) 313 | new_conv.weight.data = new_kernels 314 | if len(params) == 2: 315 | new_conv.bias.data = params[1].data # add bias if neccessary 316 | # remove .weight suffix to get the layer name 317 | layer_name = list(container.state_dict().keys())[0][:-7] 318 | 319 | # replace the first convolution layer 320 | setattr(container, layer_name, new_conv) 321 | return base_model 322 | 323 | @property 324 | def crop_size(self): 325 | return self.input_size 326 | 327 | @property 328 | def scale_size(self): 329 | return self.input_size * 256 // 224 330 | 331 | def get_augmentation(self, flip=True): 332 | if self.modality == 'RGB': 333 | if flip: 334 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66]), 335 | GroupRandomHorizontalFlip(is_flow=False)]) 336 | else: 337 | logger.info('#' * 20, 'NO FLIP!!!') 338 | return torchvision.transforms.Compose( 339 | [GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66])]) 340 | elif self.modality == 'Flow': 341 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 342 | GroupRandomHorizontalFlip(is_flow=True)]) 343 | elif self.modality == 'RGBDiff': 344 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 345 | GroupRandomHorizontalFlip(is_flow=False)]) 346 | -------------------------------------------------------------------------------- /x_temporal/core/transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | 10 | class GroupRandomCrop(object): 11 | def __init__(self, size): 12 | if isinstance(size, numbers.Number): 13 | self.size = (int(size), int(size)) 14 | else: 15 | self.size = size 16 | 17 | def __call__(self, img_group): 18 | 19 | w, h = img_group[0].size 20 | th, tw = self.size 21 | 22 | out_images = list() 23 | 24 | x1 = random.randint(0, w - tw) 25 | y1 = random.randint(0, h - th) 26 | 27 | for img in img_group: 28 | assert(img.size[0] == w and img.size[1] == h) 29 | if w == tw and h == th: 30 | out_images.append(img) 31 | else: 32 | out_images.append(img.crop((x1, y1, x1 + 
tw, y1 + th))) 33 | 34 | return out_images 35 | 36 | 37 | class MultiGroupRandomCrop(object): 38 | def __init__(self, size, groups=1): 39 | if isinstance(size, numbers.Number): 40 | self.size = (int(size), int(size)) 41 | else: 42 | self.size = size 43 | self.groups = groups 44 | 45 | def __call__(self, img_group): 46 | 47 | w, h = img_group[0].size 48 | th, tw = self.size 49 | 50 | out_images = list() 51 | 52 | for i in range(self.groups): 53 | x1 = random.randint(0, w - tw) 54 | y1 = random.randint(0, h - th) 55 | 56 | for img in img_group: 57 | assert(img.size[0] == w and img.size[1] == h) 58 | if w == tw and h == th: 59 | out_images.append(img) 60 | else: 61 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 62 | 63 | return out_images 64 | 65 | 66 | class GroupCenterCrop(object): 67 | def __init__(self, size): 68 | self.worker = torchvision.transforms.CenterCrop(size) 69 | 70 | def __call__(self, img_group): 71 | return [self.worker(img) for img in img_group] 72 | 73 | 74 | class GroupRandomHorizontalFlip(object): 75 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 76 | """ 77 | 78 | def __init__(self, is_flow=False): 79 | self.is_flow = is_flow 80 | 81 | def __call__(self, img_group, is_flow=False): 82 | v = random.random() 83 | if v < 0.5: 84 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 85 | if self.is_flow: 86 | for i in range(0, len(ret), 2): 87 | # invert flow pixel values when flipping 88 | ret[i] = ImageOps.invert(ret[i]) 89 | return ret 90 | else: 91 | return img_group 92 | 93 | 94 | class GroupNormalize(object): 95 | def __init__(self, mean, std): 96 | self.mean = mean 97 | self.std = std 98 | 99 | def __call__(self, tensor): 100 | rep_mean = self.mean * (tensor.size()[0] // len(self.mean)) 101 | rep_std = self.std * (tensor.size()[0] // len(self.std)) 102 | 103 | # TODO: make efficient 104 | for t, m, s in zip(tensor, rep_mean, rep_std): 105 | t.sub_(m).div_(s) 106 | 107 | return tensor 108 | 109 | 110 | class GroupScale(object): 111 | """ Rescales the input PIL.Image to the given 'size'. 112 | 'size' will be the size of the smaller edge. 
113 | For example, if height > width, then image will be 114 | rescaled to (size * height / width, size) 115 | size: size of the smaller edge 116 | interpolation: Default: PIL.Image.BILINEAR 117 | """ 118 | 119 | def __init__(self, size, interpolation=Image.BILINEAR): 120 | self.worker = torchvision.transforms.Resize(size, interpolation) 121 | 122 | def __call__(self, img_group): 123 | return [self.worker(img) for img in img_group] 124 | 125 | 126 | class GroupOverSample(object): 127 | def __init__(self, crop_size, scale_size=None, flip=True): 128 | self.crop_size = crop_size if not isinstance( 129 | crop_size, int) else (crop_size, crop_size) 130 | 131 | if scale_size is not None: 132 | self.scale_worker = GroupScale(scale_size) 133 | else: 134 | self.scale_worker = None 135 | self.flip = flip 136 | 137 | def __call__(self, img_group): 138 | 139 | if self.scale_worker is not None: 140 | img_group = self.scale_worker(img_group) 141 | 142 | image_w, image_h = img_group[0].size 143 | crop_w, crop_h = self.crop_size 144 | 145 | offsets = GroupMultiScaleCrop.fill_fix_offset( 146 | False, image_w, image_h, crop_w, crop_h) 147 | oversample_group = list() 148 | for o_w, o_h in offsets: 149 | normal_group = list() 150 | flip_group = list() 151 | for i, img in enumerate(img_group): 152 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 153 | normal_group.append(crop) 154 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 155 | 156 | if img.mode == 'L' and i % 2 == 0: 157 | flip_group.append(ImageOps.invert(flip_crop)) 158 | else: 159 | flip_group.append(flip_crop) 160 | 161 | oversample_group.extend(normal_group) 162 | if self.flip: 163 | oversample_group.extend(flip_group) 164 | return oversample_group 165 | 166 | 167 | class GroupFullResSample(object): 168 | def __init__(self, crop_size, scale_size=None, flip=True): 169 | self.crop_size = crop_size if not isinstance( 170 | crop_size, int) else (crop_size, crop_size) 171 | 172 | if scale_size is not None: 173 | self.scale_worker = GroupScale(scale_size) 174 | else: 175 | self.scale_worker = None 176 | self.flip = flip 177 | 178 | def __call__(self, img_group): 179 | 180 | if self.scale_worker is not None: 181 | img_group = self.scale_worker(img_group) 182 | 183 | image_w, image_h = img_group[0].size 184 | crop_w, crop_h = self.crop_size 185 | 186 | w_step = (image_w - crop_w) // 4 187 | h_step = (image_h - crop_h) // 4 188 | 189 | offsets = list() 190 | offsets.append((0 * w_step, 2 * h_step)) # left 191 | offsets.append((4 * w_step, 2 * h_step)) # right 192 | offsets.append((2 * w_step, 2 * h_step)) # center 193 | 194 | oversample_group = list() 195 | for o_w, o_h in offsets: 196 | normal_group = list() 197 | flip_group = list() 198 | for i, img in enumerate(img_group): 199 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 200 | normal_group.append(crop) 201 | if self.flip: 202 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 203 | 204 | if img.mode == 'L' and i % 2 == 0: 205 | flip_group.append(ImageOps.invert(flip_crop)) 206 | else: 207 | flip_group.append(flip_crop) 208 | 209 | oversample_group.extend(normal_group) 210 | oversample_group.extend(flip_group) 211 | return oversample_group 212 | 213 | 214 | class GroupMultiScaleCrop(object): 215 | 216 | def __init__(self, input_size, scales=None, max_distort=1, 217 | fix_crop=True, more_fix_crop=True): 218 | self.scales = scales if scales is not None else [1, .875, .75, .66] 219 | self.max_distort = max_distort 220 | self.fix_crop = fix_crop 221 | 
self.more_fix_crop = more_fix_crop 222 | self.input_size = input_size if not isinstance(input_size, int) else [ 223 | input_size, input_size] 224 | self.interpolation = Image.BILINEAR 225 | 226 | def __call__(self, img_group): 227 | 228 | im_size = img_group[0].size 229 | 230 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 231 | crop_img_group = [ 232 | img.crop( 233 | (offset_w, 234 | offset_h, 235 | offset_w + 236 | crop_w, 237 | offset_h + 238 | crop_h)) for img in img_group] 239 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) 240 | for img in crop_img_group] 241 | return ret_img_group 242 | 243 | def _sample_crop_size(self, im_size): 244 | image_w, image_h = im_size[0], im_size[1] 245 | 246 | # find a crop size 247 | base_size = min(image_w, image_h) 248 | crop_sizes = [int(base_size * x) for x in self.scales] 249 | crop_h = [ 250 | self.input_size[1] if abs( 251 | x - self.input_size[1]) < 3 else x for x in crop_sizes] 252 | crop_w = [ 253 | self.input_size[0] if abs( 254 | x - self.input_size[0]) < 3 else x for x in crop_sizes] 255 | 256 | pairs = [] 257 | for i, h in enumerate(crop_h): 258 | for j, w in enumerate(crop_w): 259 | if abs(i - j) <= self.max_distort: 260 | pairs.append((w, h)) 261 | 262 | crop_pair = random.choice(pairs) 263 | if not self.fix_crop: 264 | w_offset = random.randint(0, image_w - crop_pair[0]) 265 | h_offset = random.randint(0, image_h - crop_pair[1]) 266 | else: 267 | w_offset, h_offset = self._sample_fix_offset( 268 | image_w, image_h, crop_pair[0], crop_pair[1]) 269 | 270 | return crop_pair[0], crop_pair[1], w_offset, h_offset 271 | 272 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 273 | offsets = self.fill_fix_offset( 274 | self.more_fix_crop, image_w, image_h, crop_w, crop_h) 275 | return random.choice(offsets) 276 | 277 | @staticmethod 278 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 279 | w_step = (image_w - crop_w) // 4 280 | h_step = (image_h - crop_h) // 4 281 | 282 | ret = list() 283 | ret.append((0, 0)) # upper left 284 | ret.append((4 * w_step, 0)) # upper right 285 | ret.append((0, 4 * h_step)) # lower left 286 | ret.append((4 * w_step, 4 * h_step)) # lower right 287 | ret.append((2 * w_step, 2 * h_step)) # center 288 | 289 | if more_fix_crop: 290 | ret.append((0, 2 * h_step)) # center left 291 | ret.append((4 * w_step, 2 * h_step)) # center right 292 | ret.append((2 * w_step, 4 * h_step)) # lower center 293 | ret.append((2 * w_step, 0 * h_step)) # upper center 294 | 295 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 296 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 297 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 298 | ret.append((3 * w_step, 3 * h_step)) # lower right quarter 299 | 300 | return ret 301 | 302 | 303 | class GroupRandomSizedCrop(object): 304 | """Randomly crops the given PIL.Image to a random size of (0.08 to 1.0) of the original size 305 | and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio. 306 | This is popularly used to train the Inception networks 307 | size: size of the smaller edge 308 | interpolation: Default: PIL.Image.BILINEAR 309 | """ 310 | 311 | def __init__(self, size, interpolation=Image.BILINEAR): 312 | self.size = size 313 | self.interpolation = interpolation 314 | 315 | def __call__(self, img_group): 316 | for attempt in range(10): 317 | area = img_group[0].size[0] * img_group[0].size[1] 318 | target_area = random.uniform(0.08, 1.0) * area
319 | aspect_ratio = random.uniform(3. / 4, 4. / 3) 320 | 321 | w = int(round(math.sqrt(target_area * aspect_ratio))) 322 | h = int(round(math.sqrt(target_area / aspect_ratio))) 323 | 324 | if random.random() < 0.5: 325 | w, h = h, w 326 | 327 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 328 | x1 = random.randint(0, img_group[0].size[0] - w) 329 | y1 = random.randint(0, img_group[0].size[1] - h) 330 | found = True 331 | break 332 | else: 333 | found = False 334 | x1 = 0 335 | y1 = 0 336 | 337 | if found: 338 | out_group = list() 339 | for img in img_group: 340 | img = img.crop((x1, y1, x1 + w, y1 + h)) 341 | assert(img.size == (w, h)) 342 | out_group.append( 343 | img.resize( 344 | (self.size, self.size), self.interpolation)) 345 | return out_group 346 | else: 347 | # Fallback 348 | scale = GroupScale(self.size, interpolation=self.interpolation) 349 | crop = GroupRandomCrop(self.size) 350 | return crop(scale(img_group)) 351 | 352 | 353 | class ConvertDataFormat(object): 354 | def __init__(self, model_type): 355 | self.model_type = model_type 356 | 357 | def __call__(self, images): 358 | if self.model_type == '2D': 359 | return images 360 | tc, h, w = images.size() 361 | t = tc // 3 362 | images = images.view(t, 3, h, w) 363 | images = images.permute(1, 0, 2, 3) 364 | return images 365 | 366 | 367 | class Stack(object): 368 | 369 | def __init__(self, roll=False): 370 | self.roll = roll 371 | 372 | def __call__(self, img_group): 373 | if img_group[0].mode == 'L': 374 | return np.concatenate([np.expand_dims(x, 2) 375 | for x in img_group], axis=2) 376 | elif img_group[0].mode == 'RGB': 377 | if self.roll: 378 | return np.concatenate([np.array(x)[:, :, ::-1] 379 | for x in img_group], axis=2) 380 | else: 381 | #print(np.concatenate(img_group, axis=2).shape) 382 | # print(img_group[0].shape) 383 | return np.concatenate(img_group, axis=2) 384 | 385 | 386 | class ToTorchFormatTensor(object): 387 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 388 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 389 | 390 | def __init__(self, div=True): 391 | self.div = div 392 | 393 | def __call__(self, pic): 394 | if isinstance(pic, np.ndarray): 395 | # handle numpy array 396 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 397 | else: 398 | # handle PIL Image 399 | img = torch.ByteTensor( 400 | torch.ByteStorage.from_buffer( 401 | pic.tobytes())) 402 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 403 | # put it from HWC to CHW format 404 | # yikes, this transpose takes 80% of the loading time/CPU 405 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 406 | return img.float().div(255) if self.div else img.float() 407 | 408 | 409 | class IdentityTransform(object): 410 | 411 | def __call__(self, data): 412 | return data 413 | 414 | 415 | if __name__ == "__main__": 416 | trans = torchvision.transforms.Compose([ 417 | GroupScale(256), 418 | GroupRandomCrop(224), 419 | Stack(), 420 | ToTorchFormatTensor(), 421 | GroupNormalize( 422 | mean=[.485, .456, .406], 423 | std=[.229, .224, .225] 424 | )] 425 | ) 426 | 427 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png') 428 | 429 | color_group = [im] * 3 430 | rst = trans(color_group) 431 | 432 | gray_group = [im.convert('L')] * 9 433 | gray_rst = trans(gray_group) 434 | 435 | trans2 = torchvision.transforms.Compose([ 436 | GroupRandomSizedCrop(256), 437 | Stack(), 438 | ToTorchFormatTensor(), 439 | GroupNormalize( 440 | mean=[.485, .456, .406], 441 | 
std=[.229, .224, .225]) 442 | ]) 443 | print(trans2(color_group)) 444 | -------------------------------------------------------------------------------- /x_temporal/interface/temporal_helper.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import copy 3 | import math 4 | import os 5 | import json 6 | import shutil 7 | import time 8 | 9 | 10 | import numpy as np 11 | import torch 12 | from torch.nn.utils import clip_grad_norm_ 13 | from torch.utils.data.distributed import DistributedSampler 14 | from tensorboardX import SummaryWriter 15 | 16 | from x_temporal.utils.log_helper import init_log, get_log_format 17 | from x_temporal.utils.lr_helper import build_scheduler 18 | from x_temporal.utils.optimizer_helper import build_optimizer 19 | from x_temporal.utils.metrics import Top1Metric 20 | from x_temporal.utils.utils import format_cfg, accuracy, AverageMeter, load_checkpoint 21 | from x_temporal.utils.dist_helper import get_rank, get_world_size, all_gather, all_reduce 22 | from x_temporal.utils.model_helper import load_state_dict 23 | from x_temporal.utils.dataset_helper import get_val_crop_transform, get_dataset, shuffle_dataset 24 | from x_temporal.core.models_entry import get_model, get_augmentation 25 | from x_temporal.core.transforms import * 26 | from x_temporal.core.dataset import VideoDataSet 27 | 28 | 29 | class TemporalHelper(object): 30 | def __init__(self, config, work_dir='./', ckpt_dict=None, inference_only=False): 31 | """ 32 | Args: 33 | config: configuration for training and testing 34 | """ 35 | self.work_dir = work_dir 36 | self.inference_only = inference_only 37 | self.config = copy.deepcopy(config) 38 | 39 | self._setup_env() 40 | self._init_metrics() 41 | self._build() 42 | self._resume(ckpt_dict) 43 | self._ready() 44 | self._last_time = time.time() 45 | self.logger.info('Running with config:\n{}'.format(format_cfg(self.config))) 46 | 47 | def _resume(self, ckpt=None): 48 | """load state from given checkpoint or from pretrain_model/resume_model 49 | """ 50 | if ckpt is None: 51 | ckpt = self.load_pretrain_or_resume() 52 | if ckpt is None: 53 | self.logger.info('Train from scratch...') 54 | return 55 | 56 | # load model weights 57 | load_state_dict(self.model, ckpt['model'], strict=False) 58 | if not self.inference_only: 59 | if 'optimizer' in ckpt: 60 | self.start_iter = ckpt['epoch'] * self.epoch_iters 61 | self.cur_epoch = self.start_epoch = ckpt['epoch'] 62 | self.optimizer.load_state_dict(ckpt['optimizer']) 63 | self.lr_scheduler.load_state_dict(ckpt['lr_scheduler']) 64 | self.logger.info(f'resume from epoch:{self.start_epoch}') 65 | 66 | def _build(self): 67 | self.data_loaders = self._build_dataloaders() 68 | self.model = self.build_model() 69 | self.criterion = self.build_criterion() 70 | 71 | if not self.inference_only: 72 | self.optimizer = build_optimizer(self.config.trainer.optimizer, self.model) 73 | self.lr_scheduler = build_scheduler( 74 | self.config.trainer['lr_scheduler'], 75 | self.optimizer, self.epoch_iters, self.world_size * self.config.dataset.batch_size) 76 | 77 | def build_criterion(self): 78 | if self.config.trainer.loss_type == 'nll': 79 | criterion = torch.nn.CrossEntropyLoss() 80 | elif self.config.trainer.loss_type == 'bce': 81 | criterion = torch.nn.BCEWithLogitsLoss() 82 | else: 83 | raise ValueError("Unknown loss type") 84 | return criterion 85 | 86 | 87 | def _build_dataloaders(self): 88 | dataloader_dict = {} 89 | data_types = ['val', 'test'] 90 | if not
self.inference_only: 91 | data_types.append('train') 92 | for data_type in data_types: 93 | if data_type in self.config.dataset: 94 | dataloader_dict[data_type] = self._build_dataloader(data_type) 95 | return dataloader_dict 96 | 97 | def _build_dataloader(self, data_type): 98 | dargs = self.config.dataset 99 | if dargs.modality == 'RGB': 100 | data_length = 1 101 | elif dargs.modality in ['Flow', 'RGBDiff']: 102 | data_length = 5 103 | 104 | if dargs.modality != 'RGBDiff': 105 | normalize = GroupNormalize(dargs.input_mean, dargs.input_std) 106 | else: 107 | normalize = IdentityTransform() 108 | 109 | 110 | if data_type == 'train': 111 | train_augmentation = get_augmentation(self.config) 112 | transform = torchvision.transforms.Compose([ 113 | train_augmentation, 114 | Stack(roll=False), 115 | ToTorchFormatTensor(div=True), 116 | normalize, 117 | ConvertDataFormat(self.config.net.model_type), 118 | ]) 119 | dataset = get_dataset(dargs, data_type, False, transform, data_length) 120 | self.train_data_size = len(dataset) 121 | self.epoch_iters = math.ceil(float(self.train_data_size) /dargs.batch_size / self.world_size) 122 | self.max_iter = self.epoch_iters * self.config.trainer.epochs 123 | sampler = DistributedSampler(dataset) if self.config.gpus > 1 else None 124 | 125 | train_loader = torch.utils.data.DataLoader( 126 | dataset, 127 | batch_size=dargs.batch_size, shuffle=(False if sampler else True), 128 | num_workers=dargs.workers, pin_memory=True, 129 | drop_last=True, sampler=sampler) 130 | return train_loader 131 | 132 | else: 133 | if self.inference_only: 134 | spatial_crops = self.config.get('evaluate', {}).get('spatial_crops', 1) 135 | temporal_samples = self.config.get('evaluate', {}).get('temporal_samples', 1) 136 | else: 137 | spatial_crops = 1 138 | temporal_samples = 1 139 | 140 | crop_aug = get_val_crop_transform(self.config.dataset, spatial_crops) 141 | transform = torchvision.transforms.Compose([ 142 | GroupScale(int(dargs.scale_size)), 143 | crop_aug, 144 | Stack(roll=False), 145 | ToTorchFormatTensor(div=True), 146 | normalize, 147 | ConvertDataFormat(self.config.net.model_type), 148 | ]) 149 | 150 | dataset = get_dataset(dargs, data_type, True, transform, data_length, temporal_samples) 151 | sampler = DistributedSampler(dataset) if self.config.gpus > 1 else None 152 | val_loader = torch.utils.data.DataLoader( 153 | dataset, 154 | batch_size=dargs.batch_size, shuffle=(False if sampler else True), 155 | drop_last=False, num_workers=dargs.workers, 156 | pin_memory=True, sampler=sampler) 157 | return val_loader 158 | 159 | 160 | def _setup_env(self): 161 | # set random seed 162 | np.random.seed(self.config.get('seed', 2020)) 163 | torch.manual_seed(self.config.get('seed', 2020)) 164 | 165 | init_log('global', logging.INFO) 166 | self.logger = logging.getLogger('global') 167 | self.rank = get_rank() 168 | self.world_size = get_world_size() 169 | self.start_iter = self.cur_iter = 0 170 | self.cur_epoch = 0 171 | self.best_prec1 = 0.0 172 | self.multi_class = self.config.dataset.get('multi_class', False) 173 | if self.multi_class: 174 | from x_temporal.utils.calculate_map import calculate_mAP 175 | self.calculate_mAP = calculate_mAP 176 | if self.rank == 0 and not self.inference_only: 177 | self.tb_logger = SummaryWriter(os.path.join(self.work_dir, 'events')) 178 | if not os.path.exists(self.work_dir): 179 | os.makedirs(self.work_dir) 180 | 181 | if not os.path.exists(os.path.join(self.work_dir, self.config.saver.save_dir)): 182 | os.makedirs(os.path.join(self.work_dir, 
self.config.saver.save_dir)) 183 | 184 | def _ready(self): 185 | self.model = self.model.cuda() 186 | 187 | def build_model(self): 188 | model = get_model(self.config).cuda() 189 | return model 190 | 191 | def get_dump_dict(self): 192 | return { 193 | 'epoch': self.cur_epoch, 194 | 'optimizer': self.optimizer.state_dict(), 195 | 'model': self.model.state_dict(), 196 | 'lr_scheduler': self.lr_scheduler.state_dict(), 197 | 'best_prec1': self.best_prec1 198 | } 199 | 200 | def get_batch(self, batch_type='train'): 201 | assert batch_type in self.data_loaders 202 | if not hasattr(self, 'data_iterators'): 203 | self.data_iterators = {} 204 | if batch_type not in self.data_iterators: 205 | iterator = self.data_iterators[batch_type] = iter(self.data_loaders[batch_type]) 206 | else: 207 | iterator = self.data_iterators[batch_type] 208 | 209 | try: 210 | batch = next(iterator) 211 | except StopIteration as e: # noqa 212 | shuffle_dataset(self.data_loaders[batch_type], self.cur_epoch) 213 | iterator = self.data_iterators[batch_type] = iter(self.data_loaders[batch_type]) 214 | batch = next(iterator) 215 | 216 | batch[0] = batch[0].cuda(non_blocking=True) 217 | batch[1] = batch[1].cuda(non_blocking=True) 218 | 219 | return batch 220 | 221 | def get_total_iter(self): 222 | return self.max_iter 223 | 224 | @staticmethod 225 | def load_weights(model, ckpt): 226 | assert 'model' in ckpt or 'state_dict' in ckpt 227 | model.load_state_dict(ckpt.get('model', ckpt.get('state_dict', {})), False) 228 | 229 | 230 | def forward(self, batch): 231 | data_time = time.time() - self._last_time 232 | output = self.model(batch[0]) 233 | loss = self.criterion(output, batch[1]) 234 | if self.multi_class: 235 | mAP = self.calculate_mAP(output, batch[1]) 236 | self._preverse_for_show = [loss.detach(), data_time, mAP] 237 | else: 238 | prec1, prec5 = accuracy(output, batch[1], topk=(1, min(5, self.config.dataset.num_class))) 239 | self._preverse_for_show = [loss.detach(), data_time, prec1.detach(), prec5.detach()] 240 | return loss 241 | 242 | def backward(self, loss): 243 | self.model.zero_grad() 244 | self.optimizer.zero_grad() 245 | loss.backward() 246 | return loss 247 | 248 | def update(self): 249 | self.optimizer.step() 250 | self.lr_scheduler.step() 251 | batch_time = time.time() - self._last_time 252 | if self.multi_class: 253 | loss, data_time, mAP = self._preverse_for_show 254 | self.reduce_update_metrics(loss, data_time, batch_time, mAP=mAP) 255 | else: 256 | loss, data_time, top1, top5 = self._preverse_for_show 257 | self.reduce_update_metrics(loss, data_time, batch_time, prec1=top1, prec5=top5) 258 | self._last_time = time.time() 259 | 260 | def _init_metrics(self): 261 | self.metrics = {} 262 | self.metrics['losses'] = AverageMeter(self.config.trainer.print_freq) 263 | self.metrics['batch_time'] = AverageMeter(self.config.trainer.print_freq) 264 | self.metrics['data_time'] = AverageMeter(self.config.trainer.print_freq) 265 | if self.multi_class: 266 | self.metrics['mAP'] = AverageMeter(self.config.trainer.print_freq) 267 | else: 268 | self.metrics['top1'] = AverageMeter(self.config.trainer.print_freq) 269 | self.metrics['top5'] = AverageMeter(self.config.trainer.print_freq) 270 | 271 | def reduce_update_metrics(self, loss, data_time, batch_time, prec1=None, prec5=None, mAP=None): 272 | reduced_loss = loss.clone() 273 | if self.config.gpus > 1: 274 | all_reduce(reduced_loss) 275 | 276 | self.metrics['losses'].update(reduced_loss.item()) 277 | self.metrics['batch_time'].update(batch_time) 278 | 
self.metrics['data_time'].update(data_time) 279 | 280 | if self.multi_class: 281 | reduced_mAP = torch.Tensor([mAP]).cuda() 282 | if self.config.gpus > 1: 283 | all_reduce(reduced_mAP) 284 | self.metrics['mAP'].update(reduced_mAP.item()) 285 | else: 286 | reduced_prec1 = prec1.clone() 287 | reduced_prec5 = prec5.clone() 288 | if self.config.gpus > 1: 289 | all_reduce(reduced_prec1) 290 | all_reduce(reduced_prec5) 291 | self.metrics['top1'].update(reduced_prec1.item()) 292 | self.metrics['top5'].update(reduced_prec5.item()) 293 | 294 | def reset_metrics(self): 295 | for key in self.metrics: 296 | self.metrics[key].reset() 297 | 298 | def train(self): 299 | self.model.cuda().train() 300 | for iter_idx in range(self.start_iter, self.max_iter): 301 | self.cur_epoch = int(float(iter_idx + 1) / self.epoch_iters) 302 | self.cur_iter = iter_idx 303 | inputs = self.get_batch('train') 304 | loss = self.forward(inputs) 305 | 306 | self.backward(loss) 307 | if self.config.trainer.clip_gradient > 0: 308 | clip_grad_norm_(self.model.parameters(), self.config.trainer.clip_gradient) 309 | self.update() 310 | 311 | if iter_idx % self.config.trainer.print_freq == 0 and self.rank == 0: 312 | self.tb_logger.add_scalar('loss_train', self.metrics['losses'].avg, iter_idx) 313 | self.tb_logger.add_scalar('lr', self.lr_scheduler.get_lr()[0], iter_idx) 314 | log_formatter = get_log_format(self.multi_class) 315 | if self.multi_class: 316 | self.tb_logger.add_scalar('mAP_train', self.metrics['mAP'].avg, iter_idx) 317 | self.logger.info(log_formatter.format( 318 | iter_idx, self.max_iter, self.cur_epoch + 1, self.config.trainer.epochs, 319 | batch_time=self.metrics['batch_time'], data_time=self.metrics['data_time'], loss=self.metrics['losses'], 320 | mAP=self.metrics['mAP'], lr=self.lr_scheduler.get_lr()[0])) 321 | else: 322 | self.tb_logger.add_scalar('acc1_train', self.metrics['top1'].avg, iter_idx) 323 | self.tb_logger.add_scalar('acc5_train', self.metrics['top5'].avg, iter_idx) 324 | self.logger.info(log_formatter.format( 325 | iter_idx, self.max_iter, self.cur_epoch + 1, self.config.trainer.epochs, 326 | batch_time=self.metrics['batch_time'], data_time=self.metrics['data_time'], loss=self.metrics['losses'], 327 | top1=self.metrics['top1'], top5=self.metrics['top5'], lr=self.lr_scheduler.get_lr()[0])) 328 | 329 | if (iter_idx == self.max_iter - 1) or (iter_idx % self.epoch_iters == 0 and iter_idx > 0 and \ 330 | self.cur_epoch % self.config.trainer.eval_freq == 0): 331 | metric = self.evaluate() 332 | 333 | if self.rank == 0 and self.tb_logger is not None: 334 | self.tb_logger.add_scalar('loss_val', metric.loss, iter_idx) 335 | if self.multi_class: 336 | self.tb_logger.add_scalar('mAP_val', metric.top1, iter_idx) 337 | else: 338 | self.tb_logger.add_scalar('acc1_val', metric.top1, iter_idx) 339 | self.tb_logger.add_scalar('acc5_val', metric.top5, iter_idx) 340 | 341 | if self.rank == 0: 342 | # remember best prec@1 and save checkpoint 343 | is_best = metric.top1 > self.best_prec1 344 | self.best_prec1 = max(metric.top1, self.best_prec1) 345 | self.save_checkpoint({ 346 | 'epoch': self.cur_epoch, 347 | 'optimizer': self.optimizer.state_dict(), 348 | 'model': self.model.state_dict(), 349 | 'lr_scheduler': self.lr_scheduler.state_dict(), 350 | 'best_prec1': self.best_prec1 351 | }, is_best) 352 | 353 | if self.multi_class: 354 | self.logger.info(' * Best mAP {:.3f}'.format(self.best_prec1)) 355 | else: 356 | self.logger.info(' * Best Prec@1 {:.3f}'.format(self.best_prec1)) 357 | 358 | 359 | end = time.time() 360 | 
361 | if self.rank == 0: self.tb_logger.close() 362 | 363 | def save_checkpoint(self, state, is_best): 364 | torch.save(state, os.path.join(self.work_dir, self.config.saver.save_dir, 'ckpt.pth.tar')) 365 | if is_best: 366 | shutil.copyfile(os.path.join(self.work_dir, self.config.saver.save_dir, 'ckpt.pth.tar'), 367 | os.path.join(self.work_dir, self.config.saver.save_dir, 'ckpt_best.pth.tar')) 368 | 369 | @torch.no_grad() 370 | def evaluate(self): 371 | batch_time = AverageMeter(0) 372 | losses = AverageMeter(0) 373 | if self.multi_class: 374 | mAPs = AverageMeter(0) 375 | else: 376 | top1 = AverageMeter(0) 377 | top5 = AverageMeter(0) 378 | 379 | if self.inference_only: 380 | spatial_crops = self.config.get('evaluate', {}).get('spatial_crops', 1) 381 | temporal_samples = self.config.get('evaluate', {}).get('temporal_samples', 1) 382 | else: 383 | spatial_crops = 1 384 | temporal_samples = 1 385 | dup_samples = spatial_crops * temporal_samples 386 | 387 | self.model.cuda().eval() 388 | test_loader = self.data_loaders['test'] 389 | test_len = len(test_loader) 390 | end = time.time() 391 | for iter_idx in range(test_len): 392 | inputs = self.get_batch('test') 393 | isizes = inputs[0].shape 394 | 395 | if self.config.net.model_type == '2D': 396 | inputs[0] = inputs[0].view( 397 | isizes[0] * dup_samples, -1, isizes[2], isizes[3]) 398 | else: 399 | inputs[0] = inputs[0].view( 400 | isizes[0], isizes[1], dup_samples, -1, isizes[3], isizes[4] 401 | ) 402 | inputs[0] = inputs[0].permute(0, 2, 1, 3, 4, 5).contiguous() 403 | inputs[0] = inputs[0].view(isizes[0] * dup_samples, isizes[1], -1, isizes[3], isizes[4]) 404 | 405 | output = self.model(inputs[0]) 406 | osizes = output.shape 407 | 408 | output = output.view((osizes[0] // dup_samples, -1, osizes[1])) 409 | output = torch.mean(output, 1) 410 | 411 | 412 | loss = self.criterion(output, inputs[1]) 413 | num = inputs[0].size(0) 414 | losses.update(loss.item(), num) 415 | if self.multi_class: 416 | mAP = self.calculate_mAP(output, inputs[1]) 417 | mAPs.update(mAP, num) 418 | else: 419 | prec1, prec5 = accuracy(output, inputs[1], topk=(1, min(5, self.config.dataset.num_class))) 420 | top1.update(prec1.item(), num) 421 | top5.update(prec5.item(), num) 422 | 423 | # measure elapsed time 424 | batch_time.update(time.time() - end) 425 | end = time.time() 426 | 427 | if iter_idx % self.config.trainer.print_freq == 0: 428 | self.logger.info('Test: [{0}/{1}]\tTime {batch_time.val:.3f} ({batch_time.avg:.3f})'.format( 429 | iter_idx, test_len, batch_time=batch_time)) 430 | 431 | total_num = torch.Tensor([losses.count]).cuda() 432 | loss_sum = torch.Tensor([losses.avg*losses.count]).cuda() 433 | 434 | if self.config.gpus > 1: 435 | all_reduce(total_num, False) 436 | all_reduce(loss_sum, False) 437 | final_loss = loss_sum.item()/total_num.item() 438 | 439 | if self.multi_class: 440 | mAP_sum = torch.Tensor([mAPs.avg*mAPs.count]).cuda() 441 | if self.config.gpus > 1: 442 | all_reduce(mAP_sum) 443 | final_mAP = mAP_sum.item()/total_num.item() 444 | self.logger.info(' * mAP {:.3f}\tLoss {:.3f}\ttotal_num={}'.format(final_mAP, final_loss, 445 | total_num.item())) 446 | metric = Top1Metric(final_mAP, 0, final_loss) 447 | else: 448 | top1_sum = torch.Tensor([top1.avg*top1.count]).cuda() 449 | top5_sum = torch.Tensor([top5.avg*top5.count]).cuda() 450 | if self.config.gpus > 1: 451 | all_reduce(top1_sum, False) 452 | all_reduce(top5_sum, False) 453 | final_top1 = top1_sum.item()/total_num.item() 454 | final_top5 = top5_sum.item()/total_num.item() 455 | 
self.logger.info(' * Prec@1 {:.3f}\tPrec@5 {:.3f}\tLoss {:.3f}\ttotal_num={}'.format(final_top1, 456 | final_top5, final_loss, total_num.item())) 457 | metric = Top1Metric(final_top1, final_top5, final_loss) 458 | 459 | self.model.cuda().train() 460 | return metric 461 | 462 | def load_pretrain_or_resume(self): 463 | if 'resume_model' in self.config.saver: 464 | self.logger.info('Load checkpoint from {}'.format(self.config.saver['resume_model'])) 465 | return load_checkpoint(self.config.saver['resume_model']) 466 | elif 'pretrain_model' in self.config.saver: 467 | state = load_checkpoint(self.config.saver['pretrain_model']) 468 | self.logger.info('Load checkpoint from {}'.format(self.config.saver['pretrain_model'])) 469 | return {'model': state['model']} 470 | else: 471 | self.logger.info('Load nothing! No weights provided') 472 | return None 473 | 474 | --------------------------------------------------------------------------------
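
For orientation, TemporalHelper above is the public entry point re-exported from x_temporal/__init__.py. The lines below are not part of the repository; they are a minimal usage sketch under stated assumptions: the experiment config is loaded into an attribute-style mapping (EasyDict and PyYAML are assumptions here, not pinned in requirements.txt; any mapping with attribute access and .get() would do), and a CUDA device is available, since the helper moves the model and batches to GPU unconditionally.

import yaml
from easydict import EasyDict  # assumption: attribute-style dict, not a declared dependency

from x_temporal import TemporalHelper

# Hypothetical config loading; the experiments/*/default.yaml files define dataset, net, trainer, saver, gpus, etc.
with open('experiments/tsn/default.yaml') as f:
    config = EasyDict(yaml.safe_load(f))

helper = TemporalHelper(config, work_dir='./')  # builds dataloaders, model, criterion, optimizer, lr scheduler
helper.train()                                  # runs for get_total_iter() iterations, evaluating and checkpointing periodically
metric = helper.evaluate()                      # Top1Metric with top1, top5 and loss fields

For a single-GPU run the config's gpus field should be 1, so that the DistributedSampler and all_reduce code paths shown above are skipped.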