├── .idea
│   ├── NAS_spatiotemporal.iml
│   ├── deployment.xml
│   ├── encodings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── other.xml
│   ├── remote-mappings.xml
│   └── workspace.xml
├── NAS_utils
│   ├── __init__.py
│   └── ops.py
├── README.md
├── Spatiotemporal Fusion in 3D CNNs A Probabilistic View.pdf
├── Supplementary Materials.pdf.pdf
├── action.zip
├── args.py
├── dataset
│   ├── IO.py
│   ├── __init__.py
│   ├── augment.py
│   └── config.py
├── main.py
├── models
│   ├── __init__.py
│   ├── densenet_3d.py
│   ├── densenet_3d_forstat.py
│   ├── mobilenet_v2_3d.py
│   └── resnet_3d.py
├── philly_distributed_utils
│   ├── __init__.py
│   ├── distributed.py
│   └── env.py
├── tools
│   ├── __init__.py
│   ├── ckpt_checker.py
│   ├── generate_label_sthsthv1.py
│   ├── generate_label_ucf101.py
│   ├── statistics.py
│   ├── to_hdf5.py
│   └── visualize.py
└── utils.py

--------------------------------------------------------------------------------
/NAS_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scenarios/Probabilistic3DCNN/444385aeabc001282064877faba7a15a787a4f94/NAS_utils/__init__.py

--------------------------------------------------------------------------------
/NAS_utils/ops.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | import torch as t
4 | import torch.nn.functional as F
5 | 
6 | from torch.nn import Module, Conv3d, Linear
7 | 
8 | from args import parser
9 | args = parser.parse_args()
10 | 
11 | _NASAS = args.enable_nasas
12 | #TRAINING_SIZE = 86017
13 | 
14 | class Conv3d_with_CD(Conv3d):
15 |     def __init__(self, in_channels, out_channels, kernel_size,
16 |                  stride=1, padding=0, dilation=1, groups=1, bias=True,
17 |                  weight_reg=10.0, drop_reg=1.0, p_init=1e-1, deterministic=False, debug=False, split=1, split_pattern=None, training_size=0, deact_nasas=False):
18 |         super(Conv3d_with_CD, self).__init__(in_channels, out_channels, kernel_size, stride=stride,
19 |                                              padding=padding, dilation=dilation, groups=groups, bias=bias)
20 | 
21 |         self._weight_reg = weight_reg / training_size
22 |         self._drop_reg = drop_reg / training_size
23 |         self._det_mode = deterministic
24 |         self._debug_mode = debug
25 |         self._deterministic = deterministic
26 |         self._deact_nasas = deact_nasas
27 | 
28 |         self.split = split
29 |         self.split_pattern = split_pattern
30 |         if self.split_pattern:
31 |             assert len(self.split_pattern) == self.split
32 |             assert sum(split_pattern) == self.in_channels
33 | 
34 |         #self._noise_shape = (self.in_channels, 1, 1, 1)
35 |         self._noise_shape = (1, 1, 1, 1) #if not self._deterministic else (self.in_channels, 1, 1, 1)
36 |         self._eps = 1e-8
37 |         self._temp = 1. / 5.
38 |         self._p_init = p_init
39 | 
40 |         self.p_logit = t.nn.Parameter(t.Tensor([np.log(self._p_init) - np.log(1. - self._p_init)]*self.split)) if _NASAS and not self._deact_nasas else None
41 | 
42 |         if self._deterministic:
43 |             print('Using deterministic drop.')
44 |             self.unif_noise_var = t.zeros(size=[1]+list(self._noise_shape)).uniform_(0,1)
45 |             self.unif_noise_variable = t.nn.Parameter(self.unif_noise_var, requires_grad=False)
46 |         if self._debug_mode:
47 |             if self.in_channels == 64:
48 |                 self.p_logit.register_hook(print)
49 | 
50 |     def _concrete_dropout(self, input):
51 |         if self.split_pattern:
52 |             _p = self.p_logit[0].expand(self.split_pattern[0])
53 |             if self.split > 1:
54 |                 _p = t.cat(
55 |                     (_p, self.p_logit[1:].view(-1,1).expand(self.split-1, self.split_pattern[1]).reshape(-1)),
56 |                     dim=0
57 |                 )
58 |         else:
59 |             assert self.split == 1
60 |             #_p = self.p_logit[0].expand(self.in_channels)
61 |             _p = self.p_logit[0] #if not self._deterministic else self.p_logit[0].expand(self.in_channels)
62 |         _p = _p.sigmoid().view([1]+list(self._noise_shape))
63 | 
64 |         if self._deterministic:
65 |             drop_tensor = t.floor(self.unif_noise_variable.cuda() + _p)
66 |             random_tensor = 1. - drop_tensor
67 |         else:
68 |             unif_noise_1 = t.rand(size=[input.shape[0]]+list(self._noise_shape)).cuda()
69 |             unif_noise_2 = t.rand(size=[input.shape[0]]+list(self._noise_shape)).cuda()
70 | 
71 |             drop_prob = (
72 |                 t.log(_p + self._eps)
73 |                 - t.log(1. - _p + self._eps)
74 |                 + t.log(-t.log(unif_noise_1 + self._eps) + self._eps)
75 |                 - t.log(-t.log(unif_noise_2 + self._eps) + self._eps)
76 |             )
77 | 
78 |             drop_prob = t.sigmoid(drop_prob/self._temp)
79 |             random_tensor = 1. - drop_prob
80 | 
81 |         return input * random_tensor
82 | 
83 | 
84 |     def forward(self, input):
85 |         input = self._concrete_dropout(input) if _NASAS and not self._deact_nasas else input
86 |         return F.conv3d(input, self.weight, self.bias, self.stride,
87 |                         self.padding, self.dilation, self.groups)
88 | 
89 |     @property
90 |     def KLreg(self):
91 |         if self._deact_nasas:
92 |             return 0.0
93 | 
94 |         if self.split_pattern:
95 |             _p = self.p_logit[0].expand(self.split_pattern[0])
96 |             if self.split > 1:
97 |                 _p = t.cat(
98 |                     (_p, self.p_logit[1:].view(-1,1).expand(self.split-1, self.split_pattern[1]).reshape(-1)),
99 |                     dim=0
100 |                 )
101 |         else:
102 |             assert self.split == 1
103 |             #_p = self.p_logit[0].expand(self.in_channels)
104 |             _p = self.p_logit[0]
105 |         _p = _p.sigmoid()
106 |         # deprecated by split version
107 |         weight_regularizer = self._weight_reg * t.sum(self.weight**2) * (1. - _p)
108 |         #weight_regularizer = self._weight_reg * t.sum((self.weight**2) * (1. - _p.view([1, self.in_channels, 1, 1, 1])))
109 |         dropout_regularizer = _p * t.log(_p)
110 |         dropout_regularizer += (1. - _p) * t.log(1. - _p)
111 |         # deprecated by split version
112 |         #dropout_regularizer *= self._drop_reg * self.in_channels
113 |         dropout_regularizer *= self._drop_reg
114 |         return weight_regularizer + t.sum(dropout_regularizer)
115 | 
116 |     @property
117 |     def p(self):
118 |         if self._deact_nasas:
119 |             return None
120 |         return self.p_logit.sigmoid()
121 | 
122 | 
123 | class Linear_with_CD(Linear):
124 |     def __init__(self, in_features, out_features, bias=True,
125 |                  weight_reg=10.0, drop_reg=1.0, p_init=1e-1, deterministic=False, debug=False, split=1, split_pattern=None, training_size=0, deact_nasas=False):
126 |         super(Linear_with_CD, self).__init__(in_features, out_features, bias)
127 | 
128 |         self._weight_reg = weight_reg / training_size
129 |         self._drop_reg = drop_reg / training_size
130 |         self._det_mode = deterministic
131 |         self._debug_mode = debug
132 |         self._deterministic = deterministic
133 |         self._deact_nasas = deact_nasas
134 | 
135 |         self.split = split
136 |         self.split_pattern = split_pattern
137 |         if self.split_pattern:
138 |             assert len(self.split_pattern) == self.split
139 |             assert sum(split_pattern) == self.in_features
140 | 
141 |         self._noise_shape = (self.in_features,)
142 |         self._eps = 1e-8
143 |         self._temp = 1. / 5.
144 |         self._p_init = p_init
145 | 
146 |         self.p_logit = t.nn.Parameter(t.Tensor([np.log(self._p_init) - np.log(1. - self._p_init)] * self.split)) if _NASAS and not self._deact_nasas else None
147 | 
148 |         if self._deterministic:
149 |             print('Using deterministic drop.')
150 |             self.unif_noise_var = t.zeros(size=[1] + list(self._noise_shape)).uniform_(0,1)
151 |             self.unif_noise_variable = t.nn.Parameter(self.unif_noise_var, requires_grad=False)
152 |         if self._debug_mode:
153 |             self.p_logit.register_hook(print)
154 | 
155 |     def _concrete_dropout(self, input):
156 |         if self.split_pattern:
157 |             _p = self.p_logit[0].expand(self.split_pattern[0])
158 |             if self.split > 1:
159 |                 _p = t.cat(
160 |                     (_p, self.p_logit[1:].view(-1,1).expand(self.split-1, self.split_pattern[1]).reshape(-1)),
161 |                     dim=0
162 |                 )
163 |         else:
164 |             assert self.split == 1
165 |             _p = self.p_logit[0].expand(self.in_features)
166 |         _p = _p.sigmoid().view([1]+list(self._noise_shape))
167 | 
168 |         if self._deterministic:
169 |             drop_tensor = t.floor(self.unif_noise_variable.cuda() + _p)
170 |             random_tensor = 1. - drop_tensor
171 |         else:
172 |             unif_noise_1 = t.rand(size=[input.shape[0]]+list(self._noise_shape)).cuda()
173 |             unif_noise_2 = t.rand(size=[input.shape[0]]+list(self._noise_shape)).cuda()
174 | 
175 |             drop_prob = (
176 |                 t.log(_p + self._eps)
177 |                 - t.log(1. - _p + self._eps)
178 |                 + t.log(-t.log(unif_noise_1 + self._eps) + self._eps)
179 |                 - t.log(-t.log(unif_noise_2 + self._eps) + self._eps)
180 |             )
181 | 
182 |             drop_prob = t.sigmoid(drop_prob/self._temp)
183 |             random_tensor = 1. - drop_prob
184 | 
185 |         return input * random_tensor
186 | 
187 | 
188 |     def forward(self, input):
189 |         input = self._concrete_dropout(input) if _NASAS and not self._deact_nasas else input
190 |         return F.linear(input, self.weight, self.bias)
191 | 
192 |     @property
193 |     def KLreg(self):
194 |         if self._deact_nasas:
195 |             return 0.0
196 | 
197 |         if self.split_pattern:
198 |             _p = self.p_logit[0].expand(self.split_pattern[0])
199 |             if self.split > 1:
200 |                 _p = t.cat(
201 |                     (_p, self.p_logit[1:].view(-1,1).expand(self.split-1, self.split_pattern[1]).reshape(-1)),
202 |                     dim=0
203 |                 )
204 |         else:
205 |             assert self.split == 1
206 |             _p = self.p_logit[0].expand(self.in_features)
207 |         _p = _p.sigmoid()
208 |         # deprecated by split version
209 |         #weight_regularizer = self._weight_reg * t.sum(self.weight**2) * (1. - _p)
210 |         weight_regularizer = self._weight_reg * t.sum((self.weight**2) * (1. - _p.view([1, self.in_features])))
211 |         dropout_regularizer = _p * t.log(_p)
212 |         dropout_regularizer += (1. - _p) * t.log(1. - _p)
213 |         # deprecated by split version
214 |         #dropout_regularizer *= self._drop_reg * self.in_channels
215 |         dropout_regularizer *= self._drop_reg
216 |         return weight_regularizer + t.sum(dropout_regularizer)
217 | 
218 |     @property
219 |     def p(self):
220 |         if self._deact_nasas:
221 |             return None
222 |         return self.p_logit.sigmoid()
223 | 
224 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spatiotemporal Fusion in 3D CNNs: A Probabilistic View
2 | 
3 | Experimental code for the CVPR 2020 oral paper "Spatiotemporal Fusion in 3D CNNs: A Probabilistic View".
4 | 
5 | The official (re-organized) code is still under review and will appear in the official Microsoft repository.
6 | 
7 | 
8 | # Reference
9 | 
10 | [1] Yizhou Zhou, Xiaoyan Sun, Chong Luo, Zheng-Jun Zha and Wenjun Zeng. Spatiotemporal Fusion in 3D CNNs: A Probabilistic View. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2020, pp. 9829-9838.
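
# Usage sketch

`NAS_utils/ops.py` implements the concrete-dropout layers that carry the learnable drop probabilities used for the probabilistic fusion search. The snippet below is a minimal, hypothetical example rather than part of the released training code: it assumes a CUDA device (the layers call `.cuda()` internally) and that the process is launched with `--enable_nasas`, since `ops.py` parses the command-line flags at import time.

```python
import torch as t
from NAS_utils.ops import Conv3d_with_CD

# One concrete-dropout 3D convolution; training_size matches the
# --training_size default so the KL term is scaled per training sample.
conv = Conv3d_with_CD(16, 32, kernel_size=(3, 3, 3), padding=1,
                      weight_reg=10.0, drop_reg=1.0, p_init=0.1,
                      training_size=86017).cuda()

x = t.randn(2, 16, 8, 56, 56).cuda()  # (batch, channels, T, H, W)
out = conv(x)                         # concrete dropout on the input, then conv3d

loss = out.mean() + conv.KLreg        # placeholder objective plus the layer's KL term
loss.backward()                       # gradients also reach the drop logits p_logit
print(conv.p)                         # current drop probability, sigmoid(p_logit)
```

In the full models every `Conv3d_with_CD`/`Linear_with_CD` layer contributes its `KLreg` to the loss, and the learned `p` values drive the deterministic drop applied when `--selection_mode` or `--finetune_mode` is set.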
11 | -------------------------------------------------------------------------------- /Spatiotemporal Fusion in 3D CNNs A Probabilistic View.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scenarios/Probabilistic3DCNN/444385aeabc001282064877faba7a15a787a4f94/Spatiotemporal Fusion in 3D CNNs A Probabilistic View.pdf -------------------------------------------------------------------------------- /Supplementary Materials.pdf.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scenarios/Probabilistic3DCNN/444385aeabc001282064877faba7a15a787a4f94/Supplementary Materials.pdf.pdf -------------------------------------------------------------------------------- /action.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scenarios/Probabilistic3DCNN/444385aeabc001282064877faba7a15a787a4f94/action.zip -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | from philly_distributed_utils.env import get_master_ip 2 | from philly_distributed_utils.distributed import ompi_size 3 | 4 | import argparse 5 | parser = argparse.ArgumentParser(description="PyTorch implementation of NAS_spatiotemporal") 6 | parser.add_argument('--dataset', type=str, default="something") 7 | parser.add_argument('--modality', type=str, default='RGB', choices=['RGB', 'Flow']) 8 | parser.add_argument('--train_list', type=str, default="") 9 | parser.add_argument('--val_list', type=str, default="") 10 | parser.add_argument('--root_path', type=str, default="/mnt/data/") 11 | parser.add_argument('--store_name', type=str, default="") 12 | # ========================= Model Configs ========================== 13 | parser.add_argument('--arch', type=str, default="Dense3D121") 14 | parser.add_argument('--num_segments', type=int, default=1) 15 | parser.add_argument('--consensus_type', type=str, default='avg') 16 | parser.add_argument('--k', type=int, default=3) 17 | 18 | parser.add_argument('--dropout', '--do', default=0.5, type=float, 19 | metavar='DO', help='dropout ratio (default: 0.5)') 20 | parser.add_argument('--loss_type', type=str, default="nll", 21 | choices=['nll']) 22 | parser.add_argument('--img_feature_dim', default=256, type=int, help="the feature dimension for each frame") 23 | parser.add_argument('--suffix', type=str, default=None) 24 | parser.add_argument('--pretrain', type=str, default='imagenet') 25 | parser.add_argument('--tune_from', type=str, default=None, help='fine-tune from checkpoint') 26 | 27 | parser.add_argument('--enable_nasas', default=False, action="store_true", 28 | help='enable NASAS for architecture search') 29 | parser.add_argument('--temporal_nasas_only', default=False, action="store_true", 30 | help='only enable NASAS on temporal axis for architecture search') 31 | 32 | parser.add_argument('--cross_warmup', default=False, action="store_true", 33 | help='cross warmup for NASAS') 34 | 35 | parser.add_argument('--weight_reg', type=float, default=10.0, 36 | help='weight regularization used for nasas') 37 | parser.add_argument('--p_init', type=float, default=0.1, 38 | help='initial p used for nasas') 39 | parser.add_argument('--selection_mode', default=False, action="store_true", 40 | help='use selection mode in nasas') 41 | parser.add_argument('--test_mode', default=False, 
action="store_true",
42 |                     help='use test mode in nasas')
43 | parser.add_argument('--finetune_mode', default=False, action="store_true",
44 |                     help='use finetune mode in nasas')
45 | parser.add_argument('--training_size', default=86017, type=int,
46 |                     help='number of training samples')
47 | 
48 | 
49 | parser.add_argument('--net_version', default='pure_fused', type=str,
50 |                     help='densenet 3d version')
51 | 
52 | # ========================= Learning Configs ==========================
53 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
54 |                     help='number of total epochs to run')
55 | parser.add_argument('-b', '--batch-size', default=32, type=int,
56 |                     metavar='N', help='mini-batch size (default: 32)')
57 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
58 |                     metavar='LR', help='initial learning rate')
59 | parser.add_argument('--lr_type', default='step', type=str,
60 |                     metavar='LRtype', help='learning rate type')
61 | parser.add_argument('--lr_steps', default=[30, 60, 80], type=float, nargs="+",
62 |                     metavar='LRSteps', help='epochs to decay learning rate by 10')
63 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
64 |                     help='momentum')
65 | parser.add_argument('--weight_decay', '--wd', default=5e-4, type=float,
66 |                     metavar='W', help='weight decay (default: 5e-4)')
67 | parser.add_argument('--clip-gradient', '--gd', default=None, type=float,
68 |                     metavar='W', help='gradient norm clipping (default: disabled)')
69 | parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true")
70 | 
71 | 
72 | # ========================= Monitor Configs ==========================
73 | parser.add_argument('--print-freq', '-p', default=20, type=int,
74 |                     metavar='N', help='print frequency (default: 20)')
75 | parser.add_argument('--eval-freq', '-ef', default=5, type=int,
76 |                     metavar='N', help='evaluation frequency (default: 5)')
77 | parser.add_argument('--test_split', type=int, default=0,
78 |                     help='The index of test file')
79 | 
80 | 
81 | # ========================= Runtime Configs ==========================
82 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
83 |                     help='number of data loading workers (default: 8)')
84 | parser.add_argument('--resume',
85 |                     default='/mnt/log/NAS_spatiotemporal/checkpoint/warmup/NAS_sptp_something_RGB_Dense3D121_avg_segment1_e50_droprate0.5_num_dense_sample32_dense/ckpt.best.pth.tar',
86 |                     type=str, metavar='PATH', help='path to latest checkpoint')
87 | parser.add_argument('--break_resume', default=False, action="store_true",
88 |                     help='restore from an interrupted (broken) run')
89 | parser.add_argument('--warmup', default=False, action="store_true",
90 |                     help='run warmup initialization')
91 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
92 |                     help='evaluate model on validation set')
93 | parser.add_argument('--snapshot_pref', type=str, default="")
94 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
95 |                     help='manual epoch number (useful on restarts)')
96 | parser.add_argument('--gpus', nargs='+', type=int, default=None)
97 | parser.add_argument('--flow_prefix', default="", type=str)
98 | parser.add_argument('--root_log', type=str, default='/mnt/log/NAS_spatiotemporal/log')
99 | parser.add_argument('--root_model', type=str, default='/mnt/log/NAS_spatiotemporal/checkpoint')
100 | 
101 | parser.add_argument('--shift', default=False, action="store_true", help='use shift for models')
102 | parser.add_argument('--shift_div', default=8, type=int, help='number of div for shift (default: 8)')
103 | parser.add_argument('--shift_place', default='blockres', type=str, help='place for shift (default: blockres)')
104 | 
105 | parser.add_argument('--temporal_pool', default=False, action="store_true", help='add temporal pooling')
106 | parser.add_argument('--non_local', default=False, action="store_true", help='add non local block')
107 | 
108 | parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample for video dataset')
109 | parser.add_argument('--dense_sample_stride', default=1, type=int, help='dense sample stride for dense sample')
110 | parser.add_argument('--num_dense_sample', default=32, type=int, help='dense sample number for dense sample')
111 | parser.add_argument('--random_dense_sample_stride', default=False, action="store_true", help='use random dense sample stride for video dataset')
112 | 
113 | parser.add_argument('--syncbn', default=False, action="store_true", help='Synchronized batch normalization')
114 | parser.add_argument('--use_zip', default=False, action="store_true", help='Use ZIP file for data I/O')
115 | parser.add_argument('--freeze_bn', default=False, action="store_true", help='Freeze batch normalization')
116 | # ========================= Distributed Configs ==========================
117 | parser.add_argument('--local_rank', type=int)
118 | parser.add_argument('--node_rank', type=int, default=-1)
119 | parser.add_argument('--dist-url',
120 |                     default='', #'tcp://' + get_master_ip() + ':23456',
121 |                     type=str,
122 |                     help='url used to set up distributed training')
123 | parser.add_argument('--world-size', default=0, #ompi_size(),
124 |                     type=int, help='number of distributed processes')
125 | parser.add_argument('--dist-backend', default='nccl', type=str,
126 |                     help='distributed backend')
127 | parser.add_argument('--philly-mpi-multi-node', default=False, action="store_true",
128 |                     help='mpi multiple node distributed')
129 | parser.add_argument('--philly-nccl-multi-node', default=False, action="store_true",
130 |                     help='nccl multiple node distributed')
131 | 
--------------------------------------------------------------------------------
/dataset/IO.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | 
3 | from zipfile import ZipFile
4 | from PIL import Image
5 | import os
6 | import numpy as np
7 | from numpy.random import randint
8 | 
9 | 
10 | class VideoRecord(object):
11 |     def __init__(self, row):
12 |         self._data = row
13 | 
14 |     @property
15 |     def path(self):
16 |         return self._data[0]
17 | 
18 |     @property
19 |     def num_frames(self):
20 |         return int(self._data[1])
21 | 
22 |     @property
23 |     def label(self):
24 |         return int(self._data[2])
25 | 
26 | 
27 | class TSNDataSet(data.Dataset):
28 |     def __init__(self, root_path, list_file,
29 |                  num_segments=3, new_length=1, modality='RGB',
30 |                  image_tmpl='img_{:05d}.jpg', transform=None,
31 |                  random_shift=True, test_mode=False,
32 |                  remove_missing=False, dense_sample=False, num_dense_sample=32, dense_sample_stride=1, random_dense_sample_stride=False, is_zip=False):
33 | 
34 |         self.root_path = root_path
35 |         self.list_file = list_file
36 |         self.num_segments = num_segments
37 |         self.new_length = new_length
38 |         self.modality = modality
39 |         self.image_tmpl = image_tmpl
40 |         self.transform = transform
41 |         self.random_shift = random_shift
42 |         self.test_mode = test_mode
43 |         self.remove_missing = remove_missing
44 |         self.dense_sample = dense_sample  # using
dense sample as I3D 45 | self.num_dense_sample = num_dense_sample 46 | self.dense_sample_stride = dense_sample_stride 47 | self.random_dense_sample_stride = random_dense_sample_stride 48 | self.is_zip = is_zip 49 | if self.dense_sample: 50 | print('=> Using dense sample for the dataset...') 51 | 52 | if self.modality == 'RGBDiff': 53 | self.new_length += 1 # Diff needs one more image to calculate diff 54 | 55 | self._parse_list() 56 | 57 | def _load_image(self, directory, idx, zip_f=None): 58 | if self.modality == 'RGB' or self.modality == 'RGBDiff': 59 | try: 60 | if self.is_zip: 61 | return [Image.open(zip_f.open(self.image_tmpl.format(idx))).convert('RGB')] 62 | else: 63 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')] 64 | except Exception: 65 | print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx))) 66 | if self.is_zip: 67 | return [Image.open(zip_f.open(self.image_tmpl.format(1))).convert('RGB')] 68 | else: 69 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')] 70 | elif self.modality == 'Flow': 71 | if self.image_tmpl == 'flow_{}_{:05d}.jpg': # ucf 72 | x_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('x', idx))).convert( 73 | 'L') 74 | y_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('y', idx))).convert( 75 | 'L') 76 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg': # something v1 flow 77 | x_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)), self.image_tmpl. 78 | format(int(directory), 'x', idx))).convert('L') 79 | y_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)), self.image_tmpl. 
80 | format(int(directory), 'y', idx))).convert('L') 81 | else: 82 | try: 83 | # idx_skip = 1 + (idx-1)*5 84 | flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert( 85 | 'RGB') 86 | except Exception: 87 | print('error loading flow file:', 88 | os.path.join(self.root_path, directory, self.image_tmpl.format(idx))) 89 | flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB') 90 | # the input flow file is RGB image with (flow_x, flow_y, blank) for each channel 91 | flow_x, flow_y, _ = flow.split() 92 | x_img = flow_x.convert('L') 93 | y_img = flow_y.convert('L') 94 | 95 | return [x_img, y_img] 96 | 97 | def _parse_list(self): 98 | # check the frame number is large >3: 99 | tmp = [x.strip().split(' ') for x in open(self.list_file)] 100 | tmp = [[' '.join(x[:-2]), x[-2], x[-1]] for x in tmp] 101 | if not self.test_mode or self.remove_missing: 102 | if self.test_mode and 'kinetics' in self.root_path: 103 | tmp = [item for item in tmp if int(item[1]) >= 32] 104 | print('####################### Heavy remove #######################') 105 | else: 106 | tmp = [item for item in tmp if int(item[1]) >= 3] 107 | self.video_list = [VideoRecord(item) for item in tmp] 108 | 109 | if self.image_tmpl == '{:06d}-{}_{:05d}.jpg': 110 | for v in self.video_list: 111 | v._data[1] = int(v._data[1]) / 2 112 | print('video number:%d' % (len(self.video_list))) 113 | 114 | def _sample_indices(self, record): 115 | """ 116 | 117 | :param record: VideoRecord 118 | :return: list 119 | """ 120 | if self.dense_sample: # i3d dense sample 121 | sample_pos = max(1, 1 + record.num_frames - self.num_dense_sample * self.dense_sample_stride) 122 | t_stride = self.num_dense_sample * self.dense_sample_stride // self.num_segments 123 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 124 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)] 125 | return np.array(offsets) + 1 126 | else: # normal sample 127 | average_duration = (record.num_frames - self.new_length + 1) // self.num_segments 128 | if average_duration > 0: 129 | offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, 130 | size=self.num_segments) 131 | elif record.num_frames > self.num_segments: 132 | offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments)) 133 | else: 134 | offsets = np.zeros((self.num_segments,)) 135 | return offsets + 1 136 | 137 | def _get_val_indices(self, record): 138 | if self.dense_sample: # i3d dense sample 139 | sample_pos = max(1, 1 + record.num_frames - self.num_dense_sample * self.dense_sample_stride) 140 | t_stride = self.num_dense_sample * self.dense_sample_stride // self.num_segments 141 | #start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 142 | start_idx = 0 if sample_pos == 1 else sample_pos//2 143 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)] 144 | return np.array(offsets) + 1 145 | else: 146 | if record.num_frames > self.num_segments + self.new_length - 1: 147 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) 148 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) 149 | else: 150 | offsets = np.zeros((self.num_segments,)) 151 | return offsets + 1 152 | 153 | def _get_test_indices(self, record): 154 | if self.dense_sample: 155 | sample_pos = max(1, 1 + 
record.num_frames - self.num_dense_sample * self.dense_sample_stride) 156 | t_stride = self.num_dense_sample * self.dense_sample_stride // self.num_segments 157 | start_list = np.linspace(0, sample_pos - 1, num=2, dtype=int) 158 | offsets = [] 159 | for start_idx in start_list.tolist(): 160 | offsets += [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)] 161 | return np.array(offsets) + 1 162 | else: 163 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) 164 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) 165 | return offsets + 1 166 | 167 | def __getitem__(self, index): 168 | record = self.video_list[index] 169 | # check this is a legit video folder 170 | 171 | if self.image_tmpl == 'flow_{}_{:05d}.jpg': 172 | file_name = self.image_tmpl.format('x', 1) 173 | full_path = os.path.join(self.root_path, record.path, file_name) 174 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg': 175 | file_name = self.image_tmpl.format(int(record.path), 'x', 1) 176 | full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name) 177 | else: 178 | file_name = self.image_tmpl.format(1) 179 | full_path = os.path.join(self.root_path, record.path, file_name) 180 | 181 | if not self.is_zip: 182 | while not os.path.exists(full_path): 183 | print('################## Not Found:', os.path.join(self.root_path, record.path, file_name)) 184 | index = np.random.randint(len(self.video_list)) 185 | record = self.video_list[index] 186 | if self.image_tmpl == 'flow_{}_{:05d}.jpg': 187 | file_name = self.image_tmpl.format('x', 1) 188 | full_path = os.path.join(self.root_path, record.path, file_name) 189 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg': 190 | file_name = self.image_tmpl.format(int(record.path), 'x', 1) 191 | full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name) 192 | else: 193 | file_name = self.image_tmpl.format(1) 194 | full_path = os.path.join(self.root_path, record.path, file_name) 195 | 196 | if not self.test_mode: 197 | segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record) 198 | else: 199 | segment_indices = self._get_test_indices(record) 200 | return self.get(record, segment_indices) 201 | 202 | def get(self, record, indices): 203 | 204 | images = list() 205 | zip_f = ZipFile(os.path.join(self.root_path, record.path, 'RGB_frames.zip'), mode='r') if self.is_zip else None 206 | if self.dense_sample: 207 | assert self.num_segments == 1, "dense sample needs segment number to be 1." 
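        # Dense sampling (I3D-style): starting from the anchor frame chosen by the
        # index routines above, the loop below gathers num_dense_sample frames spaced
        # dense_sample_stride apart; with random_dense_sample_stride enabled during
        # training, the per-step stride is drawn from [1, dense_sample_stride]. The
        # index stops advancing near the end of the video so it never overruns it.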
208 | for seg_ind in indices: 209 | p = int(seg_ind) 210 | for i in range(self.num_dense_sample): 211 | seg_imgs = self._load_image(record.path, p, zip_f) 212 | images.extend(seg_imgs) 213 | if p < record.num_frames - self.dense_sample_stride: 214 | if self.random_dense_sample_stride and self.random_shift: 215 | p += randint(1, self.dense_sample_stride+1) 216 | else: 217 | p += self.dense_sample_stride 218 | else: 219 | for seg_ind in indices: 220 | p = int(seg_ind) 221 | for i in range(self.new_length): 222 | seg_imgs = self._load_image(record.path, p) 223 | images.extend(seg_imgs) 224 | if p < record.num_frames: 225 | p += 1 226 | 227 | process_data = self.transform(images) 228 | if zip_f: 229 | zip_f.close() 230 | return process_data, record.label 231 | 232 | def __len__(self): 233 | return len(self.video_list) 234 | 235 | if __name__ == '__main__': 236 | TSNDataSet('', '/data/home/v-yizzh/workspace/code/NAS_spatiotemporal/dataset/val_videofolder.txt') 237 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scenarios/Probabilistic3DCNN/444385aeabc001282064877faba7a15a787a4f94/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/augment.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | 10 | class GroupRandomCrop(object): 11 | def __init__(self, size, repeat=0): 12 | 13 | self.repeat = repeat 14 | if isinstance(size, numbers.Number): 15 | self.size = (int(size), int(size)) 16 | else: 17 | self.size = size 18 | 19 | def __call__(self, img_group): 20 | 21 | h, w, _ = img_group[0].shape 22 | th, tw = self.size 23 | 24 | x1 = random.randint(0, w - tw) 25 | y1 = random.randint(0, h - th) 26 | cropped_img_group = img_group[:, y1 : y1+th, x1 : x1+tw, :] 27 | 28 | for i in range(self.repeat): 29 | x1 = random.randint(0, w - tw) 30 | y1 = random.randint(0, h - th) 31 | cropped_img_group = np.concatenate((cropped_img_group, img_group[:, y1 : y1+th, x1 : x1+tw, :]), axis=0) 32 | ''' 33 | out_images = list() 34 | 35 | x1 = random.randint(0, w - tw) 36 | y1 = random.randint(0, h - th) 37 | 38 | for img in img_group: 39 | assert(img.size[0] == w and img.size[1] == h) 40 | if w == tw and h == th: 41 | out_images.append(img) 42 | else: 43 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 44 | 45 | 46 | return out_images 47 | ''' 48 | return cropped_img_group 49 | 50 | 51 | class GroupCenterCrop(object): 52 | ''' 53 | def __init__(self, size): 54 | self.worker = torchvision.transforms.CenterCrop(size) 55 | 56 | def __call__(self, img_group): 57 | return [self.worker(img) for img in img_group] 58 | ''' 59 | 60 | def __init__(self, size): 61 | if isinstance(size, numbers.Number): 62 | self.size = (int(size), int(size)) 63 | else: 64 | self.size = size 65 | 66 | def __call__(self, img_group): 67 | 68 | h, w, _ = img_group[0].shape 69 | th, tw = self.size 70 | 71 | assert th <= h and tw <= w, "target size must be smaller than original size." 
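        # Centre the (th, tw) window: half of the leftover width/height goes on
        # each side of the crop.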
72 | 73 | x1 = (w - tw) // 2 74 | y1 = (h - th) // 2 75 | 76 | return img_group[:, y1: y1 + th, x1: x1 + tw, :] 77 | 78 | 79 | class GroupRandomHorizontalFlip(object): 80 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 81 | """ 82 | def __init__(self, is_flow=False): 83 | self.is_flow = is_flow 84 | ''' 85 | def __call__(self, img_group, is_flow=False): 86 | v = random.random() 87 | if v < 0.5: 88 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 89 | if self.is_flow: 90 | for i in range(0, len(ret), 2): 91 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping 92 | return ret 93 | else: 94 | return img_group 95 | ''' 96 | def __call__(self, img_group, is_flow=False): 97 | assert is_flow is False, "Currently only RGB flip is supported." 98 | v = random.random() 99 | if v < 0.5: 100 | return img_group[:, :, ::-1, :].copy() 101 | else: 102 | return img_group 103 | 104 | 105 | class GroupNormalize(object): 106 | def __init__(self, mean, std): 107 | self.mean = torch.from_numpy(np.array(mean, dtype=np.float32)).view(-1,1,1,1) 108 | self.std = torch.from_numpy(np.array(std, dtype=np.float32)).view(-1,1,1,1) 109 | 110 | def __call__(self, tensor): 111 | ''' 112 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 113 | rep_std = self.std * (tensor.size()[0]//len(self.std)) 114 | 115 | # TODO: make efficient 116 | for t, m, s in zip(tensor, rep_mean, rep_std): 117 | t.sub_(m).div_(s) 118 | ''' 119 | tensor.sub_(self.mean).div_(self.std) 120 | return tensor 121 | 122 | 123 | class GroupScale(object): 124 | """ Rescales the input PIL.Image to the given 'size'. 125 | 'size' will be the size of the smaller edge. 126 | For example, if height > width, then image will be 127 | rescaled to (size * height / width, size) 128 | size: size of the smaller edge 129 | interpolation: Default: PIL.Image.BILINEAR 130 | """ 131 | def __init__(self, size, interpolation=Image.BILINEAR): 132 | if isinstance(size, int): 133 | size = [size] 134 | else: 135 | assert isinstance(size, list), "Size is list or int." 
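        # One torchvision Resize worker is built per candidate size; __call__ picks
        # one at random, which doubles as scale augmentation when a list is given.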
136 | self.worker = [torchvision.transforms.Resize(this_size, interpolation) for this_size in size] 137 | 138 | def __call__(self, img_group): 139 | this_worker = self.worker[np.random.randint(len(self.worker))] 140 | return [this_worker(img) for img in img_group] 141 | 142 | 143 | class GroupOverSample(object): 144 | def __init__(self, crop_size, scale_size=None, flip=True): 145 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 146 | 147 | if scale_size is not None: 148 | self.scale_worker = GroupScale(scale_size) 149 | else: 150 | self.scale_worker = None 151 | self.flip = flip 152 | 153 | def __call__(self, img_group): 154 | 155 | if self.scale_worker is not None: 156 | img_group = self.scale_worker(img_group) 157 | 158 | image_w, image_h = img_group[0].size 159 | crop_w, crop_h = self.crop_size 160 | 161 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h) 162 | oversample_group = list() 163 | for o_w, o_h in offsets: 164 | normal_group = list() 165 | flip_group = list() 166 | for i, img in enumerate(img_group): 167 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 168 | normal_group.append(crop) 169 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 170 | 171 | if img.mode == 'L' and i % 2 == 0: 172 | flip_group.append(ImageOps.invert(flip_crop)) 173 | else: 174 | flip_group.append(flip_crop) 175 | 176 | oversample_group.extend(normal_group) 177 | if self.flip: 178 | oversample_group.extend(flip_group) 179 | return oversample_group 180 | 181 | 182 | class GroupFullResSample(object): 183 | def __init__(self, crop_size, scale_size=None, flip=True): 184 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 185 | 186 | if scale_size is not None: 187 | self.scale_worker = GroupScale(scale_size) 188 | else: 189 | self.scale_worker = None 190 | self.flip = flip 191 | 192 | def __call__(self, img_group): 193 | 194 | if self.scale_worker is not None: 195 | img_group = self.scale_worker(img_group) 196 | 197 | image_w, image_h = img_group[0].size 198 | crop_w, crop_h = self.crop_size 199 | 200 | w_step = (image_w - crop_w) // 4 201 | h_step = (image_h - crop_h) // 4 202 | 203 | offsets = list() 204 | offsets.append((0 * w_step, 2 * h_step)) # left 205 | offsets.append((4 * w_step, 2 * h_step)) # right 206 | offsets.append((2 * w_step, 2 * h_step)) # center 207 | 208 | oversample_group = list() 209 | for o_w, o_h in offsets: 210 | normal_group = list() 211 | flip_group = list() 212 | for i, img in enumerate(img_group): 213 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 214 | normal_group.append(crop) 215 | if self.flip: 216 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 217 | 218 | if img.mode == 'L' and i % 2 == 0: 219 | flip_group.append(ImageOps.invert(flip_crop)) 220 | else: 221 | flip_group.append(flip_crop) 222 | 223 | oversample_group.extend(normal_group) 224 | oversample_group.extend(flip_group) 225 | return oversample_group 226 | 227 | 228 | class GroupMultiScaleCrop(object): 229 | 230 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 231 | self.scales = scales if scales is not None else [1, .875, .75, .66] 232 | self.max_distort = max_distort 233 | self.fix_crop = fix_crop 234 | self.more_fix_crop = more_fix_crop 235 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 236 | self.interpolation = Image.BILINEAR 237 | 238 | def __call__(self, 
img_group):
239 | 
240 |         im_size = img_group[0].size
241 | 
242 |         crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
243 |         crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
244 |         ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
245 |                          for img in crop_img_group]
246 |         return ret_img_group
247 | 
248 |     def _sample_crop_size(self, im_size):
249 |         image_w, image_h = im_size[0], im_size[1]
250 | 
251 |         # find a crop size
252 |         base_size = min(image_w, image_h)
253 |         crop_sizes = [int(base_size * x) for x in self.scales]
254 |         crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
255 |         crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
256 | 
257 |         pairs = []
258 |         for i, h in enumerate(crop_h):
259 |             for j, w in enumerate(crop_w):
260 |                 if abs(i - j) <= self.max_distort:
261 |                     pairs.append((w, h))
262 | 
263 |         crop_pair = random.choice(pairs)
264 |         if not self.fix_crop:
265 |             w_offset = random.randint(0, image_w - crop_pair[0])
266 |             h_offset = random.randint(0, image_h - crop_pair[1])
267 |         else:
268 |             w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
269 | 
270 |         return crop_pair[0], crop_pair[1], w_offset, h_offset
271 | 
272 |     def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
273 |         offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
274 |         return random.choice(offsets)
275 | 
276 |     @staticmethod
277 |     def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
278 |         w_step = (image_w - crop_w) // 4
279 |         h_step = (image_h - crop_h) // 4
280 | 
281 |         ret = list()
282 |         ret.append((0, 0))  # upper left
283 |         ret.append((4 * w_step, 0))  # upper right
284 |         ret.append((0, 4 * h_step))  # lower left
285 |         ret.append((4 * w_step, 4 * h_step))  # lower right
286 |         ret.append((2 * w_step, 2 * h_step))  # center
287 | 
288 |         if more_fix_crop:
289 |             ret.append((0, 2 * h_step))  # center left
290 |             ret.append((4 * w_step, 2 * h_step))  # center right
291 |             ret.append((2 * w_step, 4 * h_step))  # lower center
292 |             ret.append((2 * w_step, 0 * h_step))  # upper center
293 | 
294 |             ret.append((1 * w_step, 1 * h_step))  # upper left quarter
295 |             ret.append((3 * w_step, 1 * h_step))  # upper right quarter
296 |             ret.append((1 * w_step, 3 * h_step))  # lower left quarter
297 |             ret.append((3 * w_step, 3 * h_step))  # lower right quarter
298 | 
299 |         return ret
300 | 
301 | 
302 | class GroupRandomSizedCrop(object):
303 |     """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
304 |     and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio.
305 |     This is popularly used to train the Inception networks
306 |     size: size of the smaller edge
307 |     interpolation: Default: PIL.Image.BILINEAR
308 |     """
309 |     def __init__(self, size, interpolation=Image.BILINEAR):
310 |         self.size = size
311 |         self.interpolation = interpolation
312 | 
313 |     def __call__(self, img_group):
314 |         for attempt in range(10):
315 |             area = img_group[0].size[0] * img_group[0].size[1]
316 |             target_area = random.uniform(0.08, 1.0) * area
317 |             aspect_ratio = random.uniform(3. / 4, 4.
/ 3) 318 | 319 | w = int(round(math.sqrt(target_area * aspect_ratio))) 320 | h = int(round(math.sqrt(target_area / aspect_ratio))) 321 | 322 | if random.random() < 0.5: 323 | w, h = h, w 324 | 325 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 326 | x1 = random.randint(0, img_group[0].size[0] - w) 327 | y1 = random.randint(0, img_group[0].size[1] - h) 328 | found = True 329 | break 330 | else: 331 | found = False 332 | x1 = 0 333 | y1 = 0 334 | 335 | if found: 336 | out_group = list() 337 | for img in img_group: 338 | img = img.crop((x1, y1, x1 + w, y1 + h)) 339 | assert(img.size == (w, h)) 340 | out_group.append(img.resize((self.size, self.size), self.interpolation)) 341 | return out_group 342 | else: 343 | # Fallback 344 | scale = GroupScale(self.size, interpolation=self.interpolation) 345 | crop = GroupRandomCrop(self.size) 346 | return crop(scale(img_group)) 347 | 348 | 349 | class Stack(object): 350 | 351 | def __init__(self, roll=False): 352 | self.roll = roll 353 | 354 | def __call__(self, img_group): 355 | if img_group[0].mode == 'L': 356 | return np.concatenate([np.expand_dims(np.expand_dims(x, 2), 0) for x in img_group], axis=0) 357 | elif img_group[0].mode == 'RGB': 358 | if self.roll: 359 | return np.concatenate([np.expand_dims(np.array(x)[:, :, ::-1], axis=0) for x in img_group], axis=0) 360 | else: 361 | return np.concatenate([np.expand_dims(np.array(x), axis=0) for x in img_group], axis=0) 362 | 363 | 364 | class ToTorchFormatTensor(object): 365 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 366 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 367 | def __init__(self, div=True): 368 | self.div = div 369 | 370 | def __call__(self, pic): 371 | assert isinstance(pic, np.ndarray), "Require numpy array input." 372 | if isinstance(pic, np.ndarray): 373 | # handle numpy array 374 | img = torch.from_numpy(pic).permute(3, 0, 1, 2).contiguous() 375 | else: 376 | # handle PIL Image 377 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 378 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 379 | # put it from HWC to CHW format 380 | # yikes, this transpose takes 80% of the loading time/CPU 381 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 382 | return img.float().div(255) if self.div else img.float() 383 | 384 | 385 | class IdentityTransform(object): 386 | 387 | def __call__(self, data): 388 | return data 389 | 390 | 391 | def get_train_augmentation(modality='RGB', flip=True, div=True, roll=False): 392 | assert modality == 'RGB', "Currently only RGB augmentation is supported." 
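    # RGB training pipeline: random-scale resize -> stack frames into one numpy
    # volume -> random 224x224 crop -> (optional) random horizontal flip ->
    # float tensor in [0, 1] with layout (C, T, H, W).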
393 | if modality == 'RGB': 394 | if flip: 395 | #return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66]), 396 | # GroupRandomHorizontalFlip(is_flow=False)]) 397 | return torchvision.transforms.Compose([GroupScale([256, 288, 320]), 398 | Stack(roll=roll), 399 | GroupRandomCrop(224), 400 | GroupRandomHorizontalFlip(is_flow=False), 401 | ToTorchFormatTensor(div=div)] 402 | ) 403 | else: 404 | print('#' * 20, 'NO FLIP!!!') 405 | #return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66])]) 406 | return torchvision.transforms.Compose([GroupScale(256), 407 | Stack(roll=roll), 408 | GroupRandomCrop(224), 409 | ToTorchFormatTensor(div=div)] 410 | ) 411 | elif modality == 'Flow': 412 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 413 | GroupRandomHorizontalFlip(is_flow=True)]) 414 | elif modality == 'RGBDiff': 415 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 416 | GroupRandomHorizontalFlip(is_flow=False)]) 417 | 418 | 419 | def get_val_augmentation(modality='RGB', div=True, roll=False): 420 | assert modality == 'RGB', "Currently only RGB augmentation is supported." 421 | if modality == 'RGB': 422 | return torchvision.transforms.Compose([GroupScale(256), 423 | Stack(roll=roll), 424 | #GroupRandomCrop(224), 425 | GroupCenterCrop(256), 426 | ToTorchFormatTensor(div=div)] 427 | ) 428 | elif modality == 'Flow': 429 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 430 | GroupRandomHorizontalFlip(is_flow=True)]) 431 | elif modality == 'RGBDiff': 432 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 433 | GroupRandomHorizontalFlip(is_flow=False)]) 434 | 435 | 436 | def get_test_augmentation(modality='RGB', div=True, roll=False): 437 | assert modality == 'RGB', "Currently only RGB augmentation is supported." 438 | if modality == 'RGB': 439 | return torchvision.transforms.Compose([GroupScale(256), 440 | Stack(roll=roll), 441 | #GroupRandomCrop(256, 2), 442 | GroupCenterCrop(256), 443 | ToTorchFormatTensor(div=div)] 444 | ) 445 | elif modality == 'Flow': 446 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 447 | GroupRandomHorizontalFlip(is_flow=True)]) 448 | elif modality == 'RGBDiff': 449 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 450 | GroupRandomHorizontalFlip(is_flow=False)]) 451 | 452 | 453 | def get_selection_augmentation(modality='RGB', div=True, roll=False): 454 | assert modality == 'RGB', "Currently only RGB augmentation is supported." 
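    # The selection phase reuses the deterministic evaluation pipeline: resize to
    # 256, stack, center-crop 256, convert to tensor; no random crop or flip.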
455 | if modality == 'RGB': 456 | return torchvision.transforms.Compose([GroupScale(256), 457 | Stack(roll=roll), 458 | GroupCenterCrop(256), 459 | ToTorchFormatTensor(div=div)] 460 | ) 461 | elif modality == 'Flow': 462 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 463 | GroupRandomHorizontalFlip(is_flow=True)]) 464 | elif modality == 'RGBDiff': 465 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), 466 | GroupRandomHorizontalFlip(is_flow=False)]) 467 | 468 | if __name__ == "__main__": 469 | trans = torchvision.transforms.Compose([GroupScale(256), 470 | Stack(), 471 | GroupRandomCrop(224), 472 | GroupRandomHorizontalFlip(is_flow=False), 473 | ToTorchFormatTensor(), 474 | GroupNormalize(mean=[.485, .456, .406], std=[.229, .224, .225])] 475 | ) 476 | 477 | im = Image.open('/mnt/data/somethingsomethingv1_raw/20bn-something-something-v1/2/00002.jpg') 478 | 479 | color_group = [im] * 6 480 | rst = trans(color_group) 481 | 482 | gray_group = [im.convert('L')] * 9 483 | gray_rst = trans(gray_group) 484 | 485 | trans2 = torchvision.transforms.Compose([ 486 | GroupRandomSizedCrop(256), 487 | Stack(), 488 | ToTorchFormatTensor(), 489 | GroupNormalize( 490 | mean=[.485, .456, .406], 491 | std=[.229, .224, .225]) 492 | ]) 493 | print(trans2(color_group)) -------------------------------------------------------------------------------- /dataset/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ROOT_DATASET = '/mnt/data/' 4 | 5 | 6 | def return_ucf101(modality): 7 | filename_categories = 'UCF101/labels/classInd.txt' 8 | if modality == 'RGB': 9 | root_data = ROOT_DATASET + 'UCF101/jpg' 10 | filename_imglist_train = 'UCF101/file_list/ucf101_rgb_train_split_1.txt' 11 | filename_imglist_val = 'UCF101/file_list/ucf101_rgb_val_split_1.txt' 12 | prefix = 'img_{:05d}.jpg' 13 | elif modality == 'Flow': 14 | root_data = ROOT_DATASET + 'UCF101/jpg' 15 | filename_imglist_train = 'UCF101/file_list/ucf101_flow_train_split_1.txt' 16 | filename_imglist_val = 'UCF101/file_list/ucf101_flow_val_split_1.txt' 17 | prefix = 'flow_{}_{:05d}.jpg' 18 | else: 19 | raise NotImplementedError('no such modality:' + modality) 20 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 21 | 22 | 23 | def return_ucf101_zip(modality, root_path=''): 24 | assert modality == 'RGB', "Currently RGB only." 
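    # Zip variant: frames for each video are read from a per-video RGB_frames.zip
    # archive (see TSNDataSet in dataset/IO.py) instead of loose JPEG files.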
25 | filename_categories = 'ucf101_zip/classInd.txt' 26 | if modality == 'RGB': 27 | root_data = os.path.join(root_path, 'ucf101_zip/') 28 | filename_imglist_train = 'ucf101_zip/train_videofolder.txt' 29 | filename_imglist_val = 'ucf101_zip/val_videofolder.txt' 30 | prefix = 'image_{:05d}.jpg' 31 | elif modality == 'Flow': 32 | root_data = ROOT_DATASET + 'UCF101/jpg' 33 | filename_imglist_train = 'UCF101/file_list/ucf101_flow_train_split_1.txt' 34 | filename_imglist_val = 'UCF101/file_list/ucf101_flow_val_split_1.txt' 35 | prefix = 'flow_{}_{:05d}.jpg' 36 | else: 37 | raise NotImplementedError('no such modality:' + modality) 38 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 39 | 40 | 41 | def return_hmdb51(modality): 42 | filename_categories = 51 43 | if modality == 'RGB': 44 | root_data = ROOT_DATASET + 'HMDB51/images' 45 | filename_imglist_train = 'HMDB51/splits/hmdb51_rgb_train_split_1.txt' 46 | filename_imglist_val = 'HMDB51/splits/hmdb51_rgb_val_split_1.txt' 47 | prefix = 'img_{:05d}.jpg' 48 | elif modality == 'Flow': 49 | root_data = ROOT_DATASET + 'HMDB51/images' 50 | filename_imglist_train = 'HMDB51/splits/hmdb51_flow_train_split_1.txt' 51 | filename_imglist_val = 'HMDB51/splits/hmdb51_flow_val_split_1.txt' 52 | prefix = 'flow_{}_{:05d}.jpg' 53 | else: 54 | raise NotImplementedError('no such modality:' + modality) 55 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 56 | 57 | 58 | def return_something(modality, root_path=''): 59 | assert modality == 'RGB', "Currently RGB only." 60 | filename_categories = '20bn-something-something-v1/category.txt' 61 | if modality == 'RGB': 62 | root_data = os.path.join(root_path, '20bn-something-something-v1/') 63 | filename_imglist_train = '20bn-something-something-v1/train_videofolder.txt' 64 | filename_imglist_val = '20bn-something-something-v1/val_videofolder.txt' 65 | prefix = '{:05d}.jpg' 66 | elif modality == 'Flow': 67 | root_data = os.path.join(root_path, 'something/v1/20bn-something-something-v1-flow/') 68 | filename_imglist_train = 'something/v1/train_videofolder_flow.txt' 69 | filename_imglist_val = 'something/v1/val_videofolder_flow.txt' 70 | prefix = '{:06d}-{}_{:05d}.jpg' 71 | else: 72 | print('no such modality:'+modality) 73 | raise NotImplementedError 74 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 75 | 76 | 77 | def return_something_zip(modality, root_path=''): 78 | assert modality == 'RGB', "Currently RGB only." 79 | filename_categories = '20bn-something-something-v1_zip/category.txt' 80 | if modality == 'RGB': 81 | root_data = os.path.join(root_path, '20bn-something-something-v1_zip/') 82 | filename_imglist_train = '20bn-something-something-v1_zip/train_videofolder.txt' 83 | filename_imglist_val = '20bn-something-something-v1_zip/val_videofolder.txt' 84 | prefix = '{:05d}.jpg' 85 | elif modality == 'Flow': 86 | root_data = os.path.join(root_path, 'something/v1/20bn-something-something-v1-flow/') 87 | filename_imglist_train = 'something/v1/train_videofolder_flow.txt' 88 | filename_imglist_val = 'something/v1/val_videofolder_flow.txt' 89 | prefix = '{:06d}-{}_{:05d}.jpg' 90 | else: 91 | print('no such modality:'+modality) 92 | raise NotImplementedError 93 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 94 | 95 | 96 | def return_somethingv2(modality, root_path=''): 97 | assert modality == 'RGB', "Currently RGB only." 
98 | filename_categories = '20bn-something-something-v2/category.txt' 99 | if modality == 'RGB': 100 | root_data = os.path.join(root_path, '20bn-something-something-v2/') 101 | filename_imglist_train = '20bn-something-something-v2/train_videofolder.txt' 102 | filename_imglist_val = '20bn-something-something-v2/val_videofolder.txt' 103 | prefix = '{:06d}.jpg' 104 | elif modality == 'Flow': 105 | root_data = ROOT_DATASET + 'something/v2/20bn-something-something-v2-flow' 106 | filename_imglist_train = 'something/v2/train_videofolder_flow.txt' 107 | filename_imglist_val = 'something/v2/val_videofolder_flow.txt' 108 | prefix = '{:06d}.jpg' 109 | else: 110 | raise NotImplementedError('no such modality:'+modality) 111 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 112 | 113 | 114 | def return_somethingv2_zip(modality, root_path=''): 115 | assert modality == 'RGB', "Currently RGB only." 116 | filename_categories = '20bn-something-something-v2_zip/category.txt' 117 | if modality == 'RGB': 118 | root_data = os.path.join(root_path, '20bn-something-something-v2_zip/') 119 | filename_imglist_train = '20bn-something-something-v2_zip/train_videofolder.txt' 120 | filename_imglist_val = '20bn-something-something-v2_zip/val_videofolder.txt' 121 | prefix = '{:06d}.jpg' 122 | elif modality == 'Flow': 123 | root_data = ROOT_DATASET + 'something/v2/20bn-something-something-v2-flow' 124 | filename_imglist_train = 'something/v2/train_videofolder_flow.txt' 125 | filename_imglist_val = 'something/v2/val_videofolder_flow.txt' 126 | prefix = '{:06d}.jpg' 127 | else: 128 | raise NotImplementedError('no such modality:'+modality) 129 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 130 | 131 | 132 | def return_jester(modality): 133 | filename_categories = 'jester/category.txt' 134 | if modality == 'RGB': 135 | prefix = '{:05d}.jpg' 136 | root_data = ROOT_DATASET + 'jester/20bn-jester-v1' 137 | filename_imglist_train = 'jester/train_videofolder.txt' 138 | filename_imglist_val = 'jester/val_videofolder.txt' 139 | else: 140 | raise NotImplementedError('no such modality:'+modality) 141 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 142 | 143 | 144 | def return_kinetics(modality, root_path=''): 145 | filename_categories = 400 146 | if modality == 'RGB': 147 | root_data = os.path.join(root_path, 'kinetics400_frame/') 148 | filename_imglist_train = 'kinetics400_frame/train_videofolder.txt' 149 | filename_imglist_val = 'kinetics400_frame/val_videofolder.txt' 150 | prefix = 'img_{:05d}.jpg' 151 | else: 152 | raise NotImplementedError('no such modality:' + modality) 153 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 154 | 155 | 156 | def return_kinetics_zip(modality, root_path=''): 157 | filename_categories = 400 158 | if modality == 'RGB': 159 | root_data = os.path.join(root_path, 'kinetics400_frame_zip/') 160 | filename_imglist_train = 'kinetics400_frame_zip/train_videofolder.txt' 161 | filename_imglist_val = 'kinetics400_frame_zip/val_videofolder.txt' 162 | prefix = 'img_{:05d}.jpg' 163 | else: 164 | raise NotImplementedError('no such modality:' + modality) 165 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix 166 | 167 | 168 | def return_dataset(dataset, modality, root_path=''): 169 | dict_single = {'jester': return_jester, 'something': return_something, 'something_zip': return_something_zip, 170 | 
'somethingv2': return_somethingv2, 'somethingv2_zip': return_somethingv2_zip, 171 | 'ucf101': return_ucf101, 'ucf101_zip': return_ucf101_zip, 'hmdb51': return_hmdb51, 172 | 'kinetics': return_kinetics, 'kinetics_zip': return_kinetics_zip} 173 | if dataset in dict_single: 174 | file_categories, file_imglist_train, file_imglist_val, root_data, prefix = dict_single[dataset](modality, root_path) 175 | else: 176 | raise ValueError('Unknown dataset '+dataset) 177 | 178 | file_imglist_train = os.path.join(root_path, file_imglist_train) 179 | file_imglist_val = os.path.join(root_path, file_imglist_val) 180 | if isinstance(file_categories, str): 181 | file_categories = os.path.join(root_path, file_categories) 182 | with open(file_categories) as f: 183 | lines = f.readlines() 184 | categories = [item.rstrip() for item in lines] 185 | else: # number of categories 186 | categories = [None] * file_categories 187 | n_class = len(categories) 188 | print('{}: {} classes'.format(dataset, n_class)) 189 | return n_class, file_imglist_train, file_imglist_val, root_data, prefix -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scenarios/Probabilistic3DCNN/444385aeabc001282064877faba7a15a787a4f94/models/__init__.py -------------------------------------------------------------------------------- /models/densenet_3d.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | import types 7 | 8 | import os 9 | 10 | from collections import OrderedDict 11 | from functools import partial 12 | 13 | from args import parser 14 | from NAS_utils.ops import Conv3d_with_CD, Linear_with_CD 15 | 16 | args = parser.parse_args() 17 | 18 | #Conv2d = Conv2d 19 | Conv3d = partial(Conv3d_with_CD, weight_reg=args.weight_reg, deterministic=True if args.selection_mode or args.finetune_mode else False, training_size=args.training_size, p_init=args.p_init) 20 | Linear = partial(Linear_with_CD, weight_reg=args.weight_reg, deterministic=True if args.selection_mode or args.finetune_mode else False, training_size=args.training_size, p_init=args.p_init) 21 | nnConv2d = nn.Conv2d 22 | BatchNorm3d = partial(nn.BatchNorm3d, track_running_stats=not args.freeze_bn) 23 | _TEMPORAL_NASAS_ONLY = args.temporal_nasas_only 24 | _TEMPORAL_NODOWNSAMPLE = 'v1d3' in args.net_version or ('pure' in args.net_version and ('something' in args.dataset or 'ucf' in args.dataset)) 25 | 26 | __all__ = ['densenet121', 'densenet169', 'densenet201', 'densenet161'] 27 | 28 | 29 | model_urls = { 30 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 31 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 32 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 33 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 34 | } 35 | 36 | 37 | input_sizes = {} 38 | means = {} 39 | stds = {} 40 | 41 | 42 | for model_name in __all__: 43 | input_sizes[model_name] = [3, 224, 224] 44 | means[model_name] = [0.485, 0.456, 0.406] 45 | stds[model_name] = [0.229, 0.224, 0.225] 46 | 47 | 48 | pretrained_settings = {} 49 | 50 | 51 | for model_name in __all__: 52 | pretrained_settings[model_name] = { 53 | 'imagenet': { 54 | 'url': 
model_urls[model_name], 55 | 'input_space': 'RGB', 56 | 'input_size': input_sizes[model_name], 57 | 'crop_size': input_sizes[model_name][-1] * 256 // 224, 58 | 'input_range': [0, 1], 59 | 'mean': means[model_name], 60 | 'std': stds[model_name] 61 | #'num_classes': 174 62 | } 63 | } 64 | 65 | 66 | def load_pretrained(model, num_classes, settings): 67 | #assert num_classes == settings['num_classes'], \ 68 | # "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) 69 | try: 70 | state_dict = torch.load('/log/checkpoint/Densenet121_2D_ImagenetPretrained/densenet121-a639ec97.pth') 71 | except: 72 | state_dict = model_zoo.load_url(settings['url']) 73 | state_dict = update_state_dict(state_dict) 74 | mk, uk = model.load_state_dict(state_dict, strict=False) 75 | model.input_space = settings['input_space'] 76 | model.input_size = settings['input_size'] 77 | model.input_range = settings['input_range'] 78 | model.mean = settings['mean'] 79 | model.std = settings['std'] 80 | return model 81 | 82 | 83 | def update_state_dict(state_dict): 84 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 85 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 86 | # They are also in the checkpoints in model_urls. This pattern is used 87 | # to find such keys. 88 | pattern = re.compile( 89 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 90 | for key in list(state_dict.keys()): 91 | res = pattern.match(key) 92 | if res: 93 | new_key = res.group(1) + res.group(2) 94 | state_dict[new_key] = state_dict[key] 95 | del state_dict[key] 96 | 97 | # Inflate to 3d densenet 98 | pattern = re.compile( 99 | r'^(.*)((?:conv|norm)(?:[012]?)\.(?:weight|bias|running_mean|running_var))$') 100 | for key in list(state_dict.keys()): 101 | res = pattern.match(key) 102 | if res: 103 | v = state_dict[key] 104 | if 'conv' in key: 105 | v = torch.unsqueeze(v, dim=2) 106 | if 'conv0' in key: 107 | v = v.repeat([1, 1, 5, 1, 1]) 108 | v /= 5.0 109 | state_dict[key] = v 110 | elif 'conv1' in key: 111 | new_key_btnk = res.group(1) + 'bottleneck.' + res.group(2) 112 | state_dict[new_key_btnk] = v 113 | if 'v1' in args.net_version or 'pure_temporal' in args.net_version: 114 | new_key_tmpr = res.group(1) + 'temporal.' + res.group(2) 115 | state_dict[new_key_tmpr] = v.repeat([1, 1, 3, 1, 1]) / 3.0 116 | del state_dict[key] 117 | elif 'conv2' in key: 118 | new_key_sptl = res.group(1) + 'spatial.' + res.group(2) 119 | state_dict[new_key_sptl] = v 120 | del state_dict[key] 121 | else: 122 | if 'transition' in key: 123 | new_key_btnk = res.group(1) + 'original.' + res.group(2) 124 | state_dict[new_key_btnk] = v 125 | if args.net_version in ['v1', 'v1d2', 'vt', 'v1d3', 'pure_temporal']: 126 | new_key_tmpr = res.group(1) + 'temporal.' + res.group(2) 127 | state_dict[new_key_tmpr] = v.repeat([1, 1, 3, 1, 1]) / 3.0 128 | del state_dict[key] 129 | else: 130 | state_dict[key] = v 131 | else: 132 | if 'norm1' in key: 133 | new_key_btnk = res.group(1) + 'bottleneck.' + res.group(2) 134 | state_dict[new_key_btnk] = v 135 | if args.net_version in ['v1d2', 'v1nt', 'v1d3', 'pure_temporal']: 136 | new_key_tmpr = res.group(1) + 'temporal.' + res.group(2) 137 | state_dict[new_key_tmpr] = v 138 | del state_dict[key] 139 | elif 'norm2' in key: 140 | new_key_sptl = res.group(1) + 'spatial.' + res.group(2) 141 | state_dict[new_key_sptl] = v 142 | del state_dict[key] 143 | else: 144 | if 'transition' in key: 145 | new_key_btnk = res.group(1) + 'original.' 
+ res.group(2) 146 | state_dict[new_key_btnk] = v 147 | if args.net_version in ['v1d2', 'vt', 'v1d3', 'pure_temporal']: 148 | new_key_tmpr = res.group(1) + 'temporal.' + res.group(2) 149 | state_dict[new_key_tmpr] = v 150 | del state_dict[key] 151 | else: 152 | state_dict[key] = v 153 | 154 | if 'classifier' in key: 155 | del state_dict[key] 156 | return state_dict 157 | 158 | 159 | class _DenseLayer(nn.Sequential): 160 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, split=1, split_pattern=None): 161 | super(_DenseLayer, self).__init__() 162 | if 'pure_temporal' not in args.net_version: 163 | self.bottleneck = nn.Sequential(OrderedDict([ 164 | ('norm1', BatchNorm3d(num_input_features)), 165 | ('relu1', nn.ReLU(inplace=True)), 166 | ('conv1', Conv3d(num_input_features, bn_size * 167 | growth_rate, kernel_size=1, stride=1, bias=False, split=split, 168 | split_pattern=split_pattern)) 169 | ])) 170 | self.spatial = nn.Sequential(OrderedDict([ 171 | ('norm2', BatchNorm3d(bn_size * growth_rate)), 172 | ('relu2', nn.ReLU(inplace=True)), 173 | ('conv2', Conv3d(bn_size * growth_rate, growth_rate, 174 | kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=False, split=1, deact_nasas=_TEMPORAL_NASAS_ONLY)) 175 | ])) 176 | 177 | if 'v1' in args.net_version: 178 | self.temporal = nn.Sequential(OrderedDict([ 179 | ('norm1', BatchNorm3d(num_input_features)), 180 | ('relu1', nn.ReLU(inplace=True)), 181 | ('conv1', Conv3d(num_input_features, bn_size * 182 | growth_rate, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False, 183 | split=split, 184 | split_pattern=split_pattern)) 185 | ])) 186 | elif 'v2' in args.net_version: 187 | self.temporal = nn.Sequential(OrderedDict([ 188 | ('norm1', BatchNorm3d(bn_size * growth_rate)), 189 | ('relu1', nn.ReLU(inplace=True)), 190 | ('conv1', Conv3d(bn_size * growth_rate, bn_size * growth_rate, 191 | kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False, split=1)) 192 | ])) 193 | elif 'v3' in args.net_version: 194 | self.temporal = nn.Sequential(OrderedDict([ 195 | ('norm1', BatchNorm3d(growth_rate)), 196 | ('relu1', nn.ReLU(inplace=True)), 197 | ('conv1', Conv3d(growth_rate, growth_rate, 198 | kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False, split=1)) 199 | ])) 200 | elif 'pure_temporal' in args.net_version: 201 | self.temporal = nn.Sequential(OrderedDict([ 202 | ('norm1', BatchNorm3d(num_input_features)), 203 | ('relu1', nn.ReLU(inplace=True)), 204 | ('conv1', Conv3d(num_input_features, bn_size * 205 | growth_rate, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False, 206 | split=split, 207 | split_pattern=split_pattern)) 208 | ])) 209 | else: 210 | pass 211 | self.drop_rate = drop_rate 212 | 213 | def forward(self, x): 214 | if 'v1' in args.net_version: 215 | new_features = self.temporal.forward(x) + self.bottleneck.forward(x) 216 | new_features = self.spatial.forward(new_features) 217 | elif 'v2' in args.net_version: 218 | new_features = self.bottleneck.forward(x) 219 | new_features = self.temporal.forward(new_features) + new_features 220 | new_features = self.spatial.forward(new_features) 221 | elif 'v3' in args.net_version: 222 | new_features = self.bottleneck.forward(x) 223 | new_features = self.spatial.forward(new_features) 224 | new_features = self.temporal.forward(new_features) + new_features 225 | elif 'pure_temporal' in args.net_version: 226 | new_features = self.temporal.forward(x) 227 | new_features = self.spatial.forward(new_features) 228 | else: 229 | new_features = 
self.bottleneck.forward(x) 230 | new_features = self.spatial.forward(new_features) 231 | #if self.drop_rate > 0: 232 | # new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 233 | return torch.cat([x, new_features], 1) 234 | 235 | 236 | class _DenseBlock(nn.Sequential): 237 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 238 | super(_DenseBlock, self).__init__() 239 | self._split_pattern = [num_input_features] 240 | for i in range(num_layers): 241 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate, split=1, split_pattern=None)#split=i+1, split_pattern=self._split_pattern) 242 | self.add_module('denselayer%d' % (i + 1), layer) 243 | # DO NOT use += in-place operator here! 244 | self._split_pattern = self._split_pattern + [growth_rate] 245 | 246 | 247 | class _Transition(nn.Sequential): 248 | def __init__(self, num_input_features, num_output_features, split=1, split_pattern=None, temporal_pool_size=1): 249 | super(_Transition, self).__init__() 250 | ''' 251 | self.add_module('norm', nn.BatchNorm3d(num_input_features)) 252 | self.add_module('relu', nn.ReLU(inplace=True)) 253 | self.add_module('conv', Conv3d(num_input_features, num_output_features, 254 | kernel_size=1, stride=1, bias=False, split=split, split_pattern=split_pattern)) 255 | self.add_module('pool', nn.AvgPool3d(kernel_size=(temporal_pool_size, 2, 2), stride=(temporal_pool_size, 2, 2))) 256 | ''' 257 | if 'pure_temporal' not in args.net_version: 258 | self.original = nn.Sequential(OrderedDict([ 259 | ('norm', BatchNorm3d(num_input_features)), 260 | ('relu', nn.ReLU(inplace=True)), 261 | ('conv', Conv3d(num_input_features, num_output_features, 262 | kernel_size=1, stride=1, bias=False, split=split, split_pattern=split_pattern)) 263 | ])) 264 | self.transition_pool = nn.Sequential(OrderedDict([ 265 | ('pool', nn.AvgPool3d(kernel_size=(temporal_pool_size, 2, 2), stride=(temporal_pool_size, 2, 2))) 266 | ])) 267 | 268 | if args.net_version in ['v1', 'v1d2', 'vt', 'v1d3']: 269 | self.temporal = nn.Sequential(OrderedDict([ 270 | ('norm', BatchNorm3d(num_input_features)), 271 | ('relu', nn.ReLU(inplace=True)), 272 | ('conv', Conv3d(num_input_features, num_output_features, 273 | kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False, split=split, 274 | split_pattern=split_pattern)) 275 | ])) 276 | elif args.net_version in ['v2', 'v3', 'v4']: 277 | self.temporal = nn.Sequential(OrderedDict([ 278 | ('norm', BatchNorm3d(num_output_features)), 279 | ('relu', nn.ReLU(inplace=True)), 280 | ('conv', Conv3d(num_output_features, num_output_features, 281 | kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False, split=1)) 282 | ])) 283 | elif 'pure_temporal' in args.net_version: 284 | self.temporal = nn.Sequential(OrderedDict([ 285 | ('norm', BatchNorm3d(num_input_features)), 286 | ('relu', nn.ReLU(inplace=True)), 287 | ('conv', Conv3d(num_input_features, num_output_features, 288 | kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False, split=split, 289 | split_pattern=split_pattern)) 290 | ])) 291 | else: 292 | pass 293 | 294 | def forward(self, input): 295 | if args.net_version in ['v1', 'v1d2', 'vt', 'v1d3']: 296 | new_features = self.original(input) + self.temporal(input) 297 | elif args.net_version in ['v2', 'v3', 'v4']: 298 | new_features = self.original(input) 299 | new_features = self.temporal(new_features) + new_features 300 | elif 'pure_temporal' in args.net_version: 301 | new_features = self.temporal(input) 
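# Shape note (added comment, not in the original source): every branch above
# preserves (T, H, W). The 1x1x1 'original' conv and the 3x1x1 'temporal' conv
# with padding (1, 0, 0) are both size-preserving, so the elementwise sums are
# well-defined; e.g. transition1 of DenseNet-121 maps (N, 256, T, H, W) to
# (N, 128, T, H, W) in either branch before transition_pool downsamples.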
302 | else: 303 | new_features = self.original(input) 304 | new_features = self.transition_pool(new_features) 305 | return new_features 306 | 307 | 308 | class DenseNet(nn.Module): 309 | r"""Densenet-BC model class, based on 310 | `"Densely Connected Convolutional Networks" `_ 311 | 312 | Args: 313 | growth_rate (int) - how many filters to add each layer (`k` in paper) 314 | block_config (list of 4 ints) - how many layers in each pooling block 315 | num_init_features (int) - the number of filters to learn in the first convolution layer 316 | bn_size (int) - multiplicative factor for number of bottle neck layers 317 | (i.e. bn_size * k features in the bottleneck layer) 318 | drop_rate (float) - dropout rate after each dense layer 319 | num_classes (int) - number of classification classes 320 | """ 321 | 322 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 323 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 324 | 325 | super(DenseNet, self).__init__() 326 | self.drop_rate = drop_rate 327 | 328 | # First convolution 329 | self.features = nn.Sequential(OrderedDict([ 330 | ('conv0', nn.Conv3d(3, num_init_features, kernel_size=(5, 7, 7), stride=(1, 2, 2) if _TEMPORAL_NODOWNSAMPLE else 2, padding=(2, 3, 3), bias=False)), 331 | ('norm0', BatchNorm3d(num_init_features)), 332 | ('relu0', nn.ReLU(inplace=True)), 333 | ('pool0', nn.MaxPool3d(kernel_size=3, stride=(1, 2, 2) if _TEMPORAL_NODOWNSAMPLE else 2, padding=1)), 334 | ])) 335 | 336 | # Each denseblock 337 | num_features = num_init_features 338 | downsample_pos = [-1] if _TEMPORAL_NODOWNSAMPLE else [0] 339 | for i, num_layers in enumerate(block_config): 340 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 341 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=self.drop_rate) 342 | self.features.add_module('denseblock%d' % (i + 1), block) 343 | num_features = num_features + num_layers * growth_rate 344 | if i != len(block_config) - 1: 345 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2, 346 | split=1,#split=num_layers+1, 347 | split_pattern=None, #split_pattern=[num_features - num_layers * growth_rate]+[growth_rate]*num_layers, 348 | temporal_pool_size=2 if i in downsample_pos else 1) 349 | self.features.add_module('transition%d' % (i + 1), trans) 350 | num_features = num_features // 2 351 | 352 | # Final batch norm 353 | self.features.add_module('norm5', BatchNorm3d(num_features)) 354 | 355 | # Linear layer 356 | self.classifier = Linear(num_features, num_classes) 357 | 358 | # Official init from torch repo. 359 | for m in self.modules(): 360 | if isinstance(m, nn.Conv3d): 361 | nn.init.kaiming_normal_(m.weight) 362 | elif isinstance(m, nn.BatchNorm3d): 363 | nn.init.constant_(m.weight, 0) 364 | nn.init.constant_(m.bias, 0) 365 | elif isinstance(m, nn.Linear): 366 | nn.init.constant_(m.bias, 0) 367 | 368 | def forward(self, x): 369 | features = self.features(x) 370 | out = F.relu(features, inplace=True) 371 | out = F.adaptive_avg_pool3d(out, (1, 1, 1)).view(features.size(0), -1) 372 | out = F.dropout(out, p=self.drop_rate, training=self.training) 373 | out = self.classifier(out) 374 | return out 375 | 376 | 377 | def _load_state_dict(model, model_url): 378 | # '.'s are no longer allowed in module names, but previous _DenseLayer 379 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 380 | # They are also in the checkpoints in model_urls. This pattern is used 381 | # to find such keys. 
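# Illustrative example of the renaming this regex performs (the key below is
# hypothetical, assuming torchvision's DenseNet checkpoint layout):
#   'features.denseblock1.denselayer1.norm.1.weight'
#       -> 'features.denseblock1.denselayer1.norm1.weight'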
382 | pattern = re.compile( 383 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 384 | state_dict = model_zoo.load_url(model_url) 385 | for key in list(state_dict.keys()): 386 | res = pattern.match(key) 387 | if res: 388 | new_key = res.group(1) + res.group(2) 389 | state_dict[new_key] = state_dict[key] 390 | del state_dict[key] 391 | model.load_state_dict(state_dict) 392 | 393 | 394 | def modify_densenets(model): 395 | # Modify attributs 396 | model.last_linear = model.classifier 397 | del model.classifier 398 | 399 | def logits(self, features): 400 | x = F.relu(features, inplace=True) 401 | x = F.avg_pool2d(x, kernel_size=7, stride=1) 402 | x = x.view(x.size(0), -1) 403 | x = self.last_linear(x) 404 | return x 405 | 406 | def forward(self, input): 407 | x = self.features(input) 408 | x = self.logits(x) 409 | return x 410 | 411 | # Modify methods 412 | model.logits = types.MethodType(logits, model) 413 | model.forward = types.MethodType(forward, model) 414 | return model 415 | 416 | 417 | def _densenet121(num_classes, **kwargs): 418 | r"""Densenet-121 model from 419 | `"Densely Connected Convolutional Networks" `_ 420 | 421 | Args: 422 | pretrained (bool): If True, returns a model pre-trained on ImageNet 423 | """ 424 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=num_classes, 425 | **kwargs) 426 | return model 427 | 428 | def _densenet169(pretrained=False, **kwargs): 429 | r"""Densenet-121 model from 430 | `"Densely Connected Convolutional Networks" `_ 431 | 432 | Args: 433 | pretrained (bool): If True, returns a model pre-trained on ImageNet 434 | """ 435 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 436 | **kwargs) 437 | if pretrained: 438 | _load_state_dict(model, model_urls['densenet169']) 439 | return model 440 | 441 | def densenet121(num_classes=1000, pretrained='imagenet', drop_rate=0.0): 442 | r"""Densenet-121 model from 443 | `"Densely Connected Convolutional Networks" ` 444 | """ 445 | model = _densenet121(num_classes=num_classes, drop_rate=drop_rate) 446 | if pretrained is not None: 447 | settings = pretrained_settings['densenet121'][pretrained] 448 | model = load_pretrained(model, num_classes, settings) 449 | return model 450 | 451 | 452 | def densenet169(num_classes=1000, pretrained='imagenet'): 453 | r"""Densenet-121 model from 454 | `"Densely Connected Convolutional Networks" ` 455 | """ 456 | model = _densenet169(pretrained=False) 457 | if pretrained is not None: 458 | settings = pretrained_settings['densenet169'][pretrained] 459 | model = load_pretrained(model, num_classes, settings) 460 | #model = modify_densenets(model) 461 | return model 462 | 463 | 464 | def densenet201(pretrained=False, **kwargs): 465 | r"""Densenet-201 model from 466 | `"Densely Connected Convolutional Networks" `_ 467 | 468 | Args: 469 | pretrained (bool): If True, returns a model pre-trained on ImageNet 470 | """ 471 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 472 | **kwargs) 473 | if pretrained: 474 | _load_state_dict(model, model_urls['densenet201']) 475 | return model 476 | 477 | 478 | def densenet161(pretrained=False, **kwargs): 479 | r"""Densenet-161 model from 480 | `"Densely Connected Convolutional Networks" `_ 481 | 482 | Args: 483 | pretrained (bool): If True, returns a model pre-trained on ImageNet 484 | """ 485 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 486 | 
**kwargs) 487 | if pretrained: 488 | _load_state_dict(model, model_urls['densenet161']) 489 | return model -------------------------------------------------------------------------------- /models/densenet_3d_forstat.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | import types 7 | 8 | import os 9 | 10 | from collections import OrderedDict 11 | from functools import partial 12 | 13 | from args import parser 14 | from NAS_utils.ops import Conv3d_with_CD, Linear_with_CD 15 | 16 | args = parser.parse_args() 17 | 18 | #Conv2d = Conv2d 19 | Conv3d = nn.Conv3d#partial(Conv3d_with_CD, weight_reg=args.weight_reg, deterministic=True if args.selection_mode or args.finetune_mode else False, training_size=args.training_size, p_init=args.p_init) 20 | Linear = nn.Linear #partial(Linear_with_CD, weight_reg=args.weight_reg, deterministic=True if args.selection_mode or args.finetune_mode else False, training_size=args.training_size, p_init=args.p_init) 21 | nnConv2d = nn.Conv2d 22 | BatchNorm3d = partial(nn.BatchNorm3d, track_running_stats=not args.freeze_bn) 23 | _TEMPORAL_NASAS_ONLY = args.temporal_nasas_only 24 | 25 | __all__ = ['densenet121', 'densenet169', 'densenet201', 'densenet161'] 26 | 27 | 28 | model_urls = { 29 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 30 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 31 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 32 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 33 | } 34 | 35 | 36 | input_sizes = {} 37 | means = {} 38 | stds = {} 39 | 40 | 41 | for model_name in __all__: 42 | input_sizes[model_name] = [3, 224, 224] 43 | means[model_name] = [0.485, 0.456, 0.406] 44 | stds[model_name] = [0.229, 0.224, 0.225] 45 | 46 | 47 | pretrained_settings = {} 48 | 49 | 50 | for model_name in __all__: 51 | pretrained_settings[model_name] = { 52 | 'imagenet': { 53 | 'url': model_urls[model_name], 54 | 'input_space': 'RGB', 55 | 'input_size': input_sizes[model_name], 56 | 'crop_size': input_sizes[model_name][-1] * 256 // 224, 57 | 'input_range': [0, 1], 58 | 'mean': means[model_name], 59 | 'std': stds[model_name] 60 | #'num_classes': 174 61 | } 62 | } 63 | 64 | 65 | def load_pretrained(model, num_classes, settings): 66 | #assert num_classes == settings['num_classes'], \ 67 | # "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) 68 | try: 69 | state_dict = torch.load('/log/checkpoint/Densenet121_2D_ImagenetPretrained/densenet121-a639ec97.pth') 70 | except: 71 | state_dict = model_zoo.load_url(settings['url']) 72 | state_dict = update_state_dict(state_dict) 73 | mk, uk = model.load_state_dict(state_dict, strict=False) 74 | model.input_space = settings['input_space'] 75 | model.input_size = settings['input_size'] 76 | model.input_range = settings['input_range'] 77 | model.mean = settings['mean'] 78 | model.std = settings['std'] 79 | return model 80 | 81 | 82 | def update_state_dict(state_dict): 83 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 84 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 85 | # They are also in the checkpoints in model_urls. This pattern is used 86 | # to find such keys. 
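# Minimal sketch of the 2D -> 3D inflation rule applied below (added comment;
# w2d, w3d and T are illustrative names, not variables from this file):
#   w3d = w2d.unsqueeze(2)                  # (C_out, C_in, 1, kH, kW)
#   w3d = w3d.repeat(1, 1, T, 1, 1) / T     # tile across T frames, rescale
# Dividing by T turns the tiled filter into a temporal average, so on a
# temporally constant input the inflated 3D conv reproduces the 2D response.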
87 | pattern = re.compile( 88 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 89 | for key in list(state_dict.keys()): 90 | res = pattern.match(key) 91 | if res: 92 | new_key = res.group(1) + res.group(2) 93 | state_dict[new_key] = state_dict[key] 94 | del state_dict[key] 95 | 96 | # Inflate to 3d densenet 97 | pattern = re.compile( 98 | r'^(.*)((?:conv|norm)(?:[012]?)\.(?:weight|bias|running_mean|running_var))$') 99 | for key in list(state_dict.keys()): 100 | res = pattern.match(key) 101 | if res: 102 | v = state_dict[key] 103 | if 'conv' in key: 104 | v = torch.unsqueeze(v, dim=2) 105 | if 'conv0' in key: 106 | v = v.repeat([1, 1, 5, 1, 1]) 107 | v /= 5.0 108 | state_dict[key] = v 109 | elif 'conv1' in key: 110 | new_key_btnk = res.group(1) + 'bottleneck.' + res.group(2) 111 | state_dict[new_key_btnk] = v 112 | if 'v1' in args.net_version: 113 | new_key_tmpr = res.group(1) + 'temporal.' + res.group(2) 114 | state_dict[new_key_tmpr] = v.repeat([1, 1, 3, 1, 1]) / 3.0 115 | del state_dict[key] 116 | elif 'conv2' in key: 117 | new_key_sptl = res.group(1) + 'spatial.' + res.group(2) 118 | state_dict[new_key_sptl] = v 119 | del state_dict[key] 120 | else: 121 | if 'transition' in key: 122 | new_key_btnk = res.group(1) + 'original.' + res.group(2) 123 | state_dict[new_key_btnk] = v 124 | if args.net_version in ['v1', 'v1d2', 'vt', 'v1d3']: 125 | new_key_tmpr = res.group(1) + 'temporal.' + res.group(2) 126 | state_dict[new_key_tmpr] = v.repeat([1, 1, 3, 1, 1]) / 3.0 127 | del state_dict[key] 128 | else: 129 | state_dict[key] = v 130 | else: 131 | if 'norm1' in key: 132 | new_key_btnk = res.group(1) + 'bottleneck.' + res.group(2) 133 | state_dict[new_key_btnk] = v 134 | if args.net_version in ['v1d2', 'v1nt', 'v1d3']: 135 | new_key_tmpr = res.group(1) + 'temporal.' + res.group(2) 136 | state_dict[new_key_tmpr] = v 137 | del state_dict[key] 138 | elif 'norm2' in key: 139 | new_key_sptl = res.group(1) + 'spatial.' + res.group(2) 140 | state_dict[new_key_sptl] = v 141 | del state_dict[key] 142 | else: 143 | if 'transition' in key: 144 | new_key_btnk = res.group(1) + 'original.' + res.group(2) 145 | state_dict[new_key_btnk] = v 146 | if args.net_version in ['v1d2', 'vt', 'v1d3']: 147 | new_key_tmpr = res.group(1) + 'temporal.' 
+ res.group(2) 148 | state_dict[new_key_tmpr] = v 149 | del state_dict[key] 150 | else: 151 | state_dict[key] = v 152 | 153 | if 'classifier' in key: 154 | del state_dict[key] 155 | return state_dict 156 | 157 | 158 | class _DenseLayer(nn.Sequential): 159 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, split=1, split_pattern=None): 160 | super(_DenseLayer, self).__init__() 161 | self.bottleneck = nn.Sequential(OrderedDict([ 162 | ('norm1', BatchNorm3d(num_input_features)), 163 | ('relu1', nn.ReLU(inplace=True)), 164 | ('conv1', Conv3d(num_input_features, bn_size * 165 | growth_rate, kernel_size=1, stride=1, bias=False)) 166 | ])) 167 | self.spatial = nn.Sequential(OrderedDict([ 168 | ('norm2', BatchNorm3d(bn_size * growth_rate)), 169 | ('relu2', nn.ReLU(inplace=True)), 170 | ('conv2', Conv3d(bn_size * growth_rate, growth_rate, 171 | kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=False)) 172 | ])) 173 | 174 | if 'v1' in args.net_version: 175 | self.temporal = nn.Sequential(OrderedDict([ 176 | ('norm1', BatchNorm3d(num_input_features)), 177 | ('relu1', nn.ReLU(inplace=True)), 178 | ('conv1', Conv3d(num_input_features, bn_size * 179 | growth_rate, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False)) 180 | ])) 181 | elif 'v2' in args.net_version: 182 | self.temporal = nn.Sequential(OrderedDict([ 183 | ('norm1', BatchNorm3d(bn_size * growth_rate)), 184 | ('relu1', nn.ReLU(inplace=True)), 185 | ('conv1', Conv3d(bn_size * growth_rate, bn_size * growth_rate, 186 | kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False)) 187 | ])) 188 | elif 'v3' in args.net_version: 189 | self.temporal = nn.Sequential(OrderedDict([ 190 | ('norm1', BatchNorm3d(growth_rate)), 191 | ('relu1', nn.ReLU(inplace=True)), 192 | ('conv1', Conv3d(growth_rate, growth_rate, 193 | kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False)) 194 | ])) 195 | else: 196 | pass 197 | self.drop_rate = drop_rate 198 | 199 | def forward(self, x): 200 | if 'v1' in args.net_version: 201 | new_features = self.temporal.forward(x) + self.bottleneck.forward(x) 202 | new_features = self.spatial.forward(new_features) 203 | elif 'v2' in args.net_version: 204 | new_features = self.bottleneck.forward(x) 205 | new_features = self.temporal.forward(new_features) + new_features 206 | new_features = self.spatial.forward(new_features) 207 | elif 'v3' in args.net_version: 208 | new_features = self.bottleneck.forward(x) 209 | new_features = self.spatial.forward(new_features) 210 | new_features = self.temporal.forward(new_features) + new_features 211 | else: 212 | new_features = self.bottleneck.forward(x) 213 | new_features = self.spatial.forward(new_features) 214 | #if self.drop_rate > 0: 215 | # new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 216 | return torch.cat([x, new_features], 1) 217 | 218 | 219 | class _DenseBlock(nn.Sequential): 220 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 221 | super(_DenseBlock, self).__init__() 222 | self._split_pattern = [num_input_features] 223 | for i in range(num_layers): 224 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate, split=1, split_pattern=None)#split=i+1, split_pattern=self._split_pattern) 225 | self.add_module('denselayer%d' % (i + 1), layer) 226 | # DO NOT use += in-place operator here! 
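# Why the warning above holds (added comment): '+=' on a list mutates the
# object in place (like list.extend), whereas 'a = a + [x]' rebinds the name
# to a fresh list. Layers built in earlier iterations may hold a reference to
# the current pattern list, and in-place mutation would silently grow their
# split_pattern too:
#   a = [1]; b = a; a += [2]      # b == [1, 2]  (shared object mutated)
#   a = [1]; b = a; a = a + [2]   # b == [1]     (b keeps the old list)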
227 | self._split_pattern = self._split_pattern + [growth_rate] 228 | 229 | 230 | class _Transition(nn.Sequential): 231 | def __init__(self, num_input_features, num_output_features, split=1, split_pattern=None, temporal_pool_size=1): 232 | super(_Transition, self).__init__() 233 | ''' 234 | self.add_module('norm', nn.BatchNorm3d(num_input_features)) 235 | self.add_module('relu', nn.ReLU(inplace=True)) 236 | self.add_module('conv', Conv3d(num_input_features, num_output_features, 237 | kernel_size=1, stride=1, bias=False, split=split, split_pattern=split_pattern)) 238 | self.add_module('pool', nn.AvgPool3d(kernel_size=(temporal_pool_size, 2, 2), stride=(temporal_pool_size, 2, 2))) 239 | ''' 240 | self.original = nn.Sequential(OrderedDict([ 241 | ('norm', BatchNorm3d(num_input_features)), 242 | ('relu', nn.ReLU(inplace=True)), 243 | ('conv', Conv3d(num_input_features, num_output_features, 244 | kernel_size=1, stride=1, bias=False)) 245 | ])) 246 | self.transition_pool = nn.Sequential(OrderedDict([ 247 | ('pool', nn.AvgPool3d(kernel_size=(temporal_pool_size, 2, 2), stride=(temporal_pool_size, 2, 2))) 248 | ])) 249 | 250 | if args.net_version in ['v1', 'v1d2', 'vt', 'v1d3']: 251 | self.temporal = nn.Sequential(OrderedDict([ 252 | ('norm', BatchNorm3d(num_input_features)), 253 | ('relu', nn.ReLU(inplace=True)), 254 | ('conv', Conv3d(num_input_features, num_output_features, 255 | kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False)) 256 | ])) 257 | elif args.net_version in ['v2', 'v3', 'v4']: 258 | self.temporal = nn.Sequential(OrderedDict([ 259 | ('norm', BatchNorm3d(num_output_features)), 260 | ('relu', nn.ReLU(inplace=True)), 261 | ('conv', Conv3d(num_output_features, num_output_features, 262 | kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False)) 263 | ])) 264 | else: 265 | pass 266 | 267 | def forward(self, input): 268 | if args.net_version in ['v1', 'v1d2', 'vt', 'v1d3']: 269 | new_features = self.original(input) + self.temporal(input) 270 | elif args.net_version in ['v2', 'v3', 'v4']: 271 | new_features = self.original(input) 272 | new_features = self.temporal(new_features) + new_features 273 | else: 274 | new_features = self.original(input) 275 | new_features = self.transition_pool(new_features) 276 | return new_features 277 | 278 | 279 | class DenseNet(nn.Module): 280 | r"""Densenet-BC model class, based on 281 | `"Densely Connected Convolutional Networks" `_ 282 | 283 | Args: 284 | growth_rate (int) - how many filters to add each layer (`k` in paper) 285 | block_config (list of 4 ints) - how many layers in each pooling block 286 | num_init_features (int) - the number of filters to learn in the first convolution layer 287 | bn_size (int) - multiplicative factor for number of bottle neck layers 288 | (i.e. 
bn_size * k features in the bottleneck layer) 289 | drop_rate (float) - dropout rate after each dense layer 290 | num_classes (int) - number of classification classes 291 | """ 292 | 293 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 294 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 295 | 296 | super(DenseNet, self).__init__() 297 | self.drop_rate = drop_rate 298 | 299 | # First convolution 300 | self.features = nn.Sequential(OrderedDict([ 301 | ('conv0', nn.Conv3d(3, num_init_features, kernel_size=(5, 7, 7), stride=(1, 2, 2) if args.net_version=='v1d3' else 2, padding=(2, 3, 3), bias=False)), 302 | ('norm0', BatchNorm3d(num_init_features)), 303 | ('relu0', nn.ReLU(inplace=True)), 304 | ('pool0', nn.MaxPool3d(kernel_size=3, stride=(1, 2, 2) if args.net_version=='v1d3' or args.random_dense_sample_stride else 2, padding=1)), 305 | ])) 306 | 307 | # Each denseblock 308 | num_features = num_init_features 309 | downsample_pos = [-1] if args.net_version=='v1d3' else [0] 310 | for i, num_layers in enumerate(block_config): 311 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 312 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=self.drop_rate) 313 | self.features.add_module('denseblock%d' % (i + 1), block) 314 | num_features = num_features + num_layers * growth_rate 315 | if i != len(block_config) - 1: 316 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2, 317 | split=1,#split=num_layers+1, 318 | split_pattern=None, #split_pattern=[num_features - num_layers * growth_rate]+[growth_rate]*num_layers, 319 | temporal_pool_size=2 if i in downsample_pos else 1) 320 | self.features.add_module('transition%d' % (i + 1), trans) 321 | num_features = num_features // 2 322 | 323 | # Final batch norm 324 | self.features.add_module('norm5', BatchNorm3d(num_features)) 325 | 326 | # Linear layer 327 | self.classifier = Linear(num_features, num_classes) 328 | 329 | # Official init from torch repo. 330 | for m in self.modules(): 331 | if isinstance(m, nn.Conv3d): 332 | nn.init.kaiming_normal_(m.weight) 333 | elif isinstance(m, nn.BatchNorm3d): 334 | nn.init.constant_(m.weight, 0) 335 | nn.init.constant_(m.bias, 0) 336 | elif isinstance(m, nn.Linear): 337 | nn.init.constant_(m.bias, 0) 338 | 339 | def forward(self, x): 340 | features = self.features(x) 341 | out = F.relu(features, inplace=True) 342 | out = F.adaptive_avg_pool3d(out, (1, 1, 1)).view(features.size(0), -1) 343 | out = F.dropout(out, p=self.drop_rate, training=self.training) 344 | out = self.classifier(out) 345 | return out 346 | 347 | 348 | def _load_state_dict(model, model_url): 349 | # '.'s are no longer allowed in module names, but previous _DenseLayer 350 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 351 | # They are also in the checkpoints in model_urls. This pattern is used 352 | # to find such keys. 
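# Caution (added comment): unlike load_pretrained above, this helper loads the
# 2D torchvision checkpoint without the 3D inflation from update_state_dict and
# calls load_state_dict with the default strict=True, so it will fail on the 3D
# model defined in this file; it survives here from the original 2D code path.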
353 | pattern = re.compile( 354 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 355 | state_dict = model_zoo.load_url(model_url) 356 | for key in list(state_dict.keys()): 357 | res = pattern.match(key) 358 | if res: 359 | new_key = res.group(1) + res.group(2) 360 | state_dict[new_key] = state_dict[key] 361 | del state_dict[key] 362 | model.load_state_dict(state_dict) 363 | 364 | 365 | def modify_densenets(model): 366 | # Modify attributs 367 | model.last_linear = model.classifier 368 | del model.classifier 369 | 370 | def logits(self, features): 371 | x = F.relu(features, inplace=True) 372 | x = F.avg_pool2d(x, kernel_size=7, stride=1) 373 | x = x.view(x.size(0), -1) 374 | x = self.last_linear(x) 375 | return x 376 | 377 | def forward(self, input): 378 | x = self.features(input) 379 | x = self.logits(x) 380 | return x 381 | 382 | # Modify methods 383 | model.logits = types.MethodType(logits, model) 384 | model.forward = types.MethodType(forward, model) 385 | return model 386 | 387 | 388 | def _densenet121(num_classes, **kwargs): 389 | r"""Densenet-121 model from 390 | `"Densely Connected Convolutional Networks" `_ 391 | 392 | Args: 393 | pretrained (bool): If True, returns a model pre-trained on ImageNet 394 | """ 395 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=num_classes, 396 | **kwargs) 397 | return model 398 | 399 | def _densenet169(pretrained=False, **kwargs): 400 | r"""Densenet-121 model from 401 | `"Densely Connected Convolutional Networks" `_ 402 | 403 | Args: 404 | pretrained (bool): If True, returns a model pre-trained on ImageNet 405 | """ 406 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 407 | **kwargs) 408 | if pretrained: 409 | _load_state_dict(model, model_urls['densenet169']) 410 | return model 411 | 412 | def densenet121(num_classes=1000, pretrained='imagenet', drop_rate=0.0): 413 | r"""Densenet-121 model from 414 | `"Densely Connected Convolutional Networks" ` 415 | """ 416 | model = _densenet121(num_classes=num_classes, drop_rate=drop_rate) 417 | if pretrained is not None: 418 | settings = pretrained_settings['densenet121'][pretrained] 419 | model = load_pretrained(model, num_classes, settings) 420 | return model 421 | 422 | 423 | def densenet169(num_classes=1000, pretrained='imagenet'): 424 | r"""Densenet-121 model from 425 | `"Densely Connected Convolutional Networks" ` 426 | """ 427 | model = _densenet169(pretrained=False) 428 | if pretrained is not None: 429 | settings = pretrained_settings['densenet169'][pretrained] 430 | model = load_pretrained(model, num_classes, settings) 431 | #model = modify_densenets(model) 432 | return model 433 | 434 | 435 | def densenet201(pretrained=False, **kwargs): 436 | r"""Densenet-201 model from 437 | `"Densely Connected Convolutional Networks" `_ 438 | 439 | Args: 440 | pretrained (bool): If True, returns a model pre-trained on ImageNet 441 | """ 442 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 443 | **kwargs) 444 | if pretrained: 445 | _load_state_dict(model, model_urls['densenet201']) 446 | return model 447 | 448 | 449 | def densenet161(pretrained=False, **kwargs): 450 | r"""Densenet-161 model from 451 | `"Densely Connected Convolutional Networks" `_ 452 | 453 | Args: 454 | pretrained (bool): If True, returns a model pre-trained on ImageNet 455 | """ 456 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 457 | 
**kwargs) 458 | if pretrained: 459 | _load_state_dict(model, model_urls['densenet161']) 460 | return model -------------------------------------------------------------------------------- /models/mobilenet_v2_3d.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.functional import F 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | from args import parser 6 | args = parser.parse_args() 7 | 8 | import torch 9 | import re 10 | 11 | 12 | __all__ = ['mobilenet_v2'] 13 | 14 | model_urls = { 15 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 16 | } 17 | 18 | input_sizes = {} 19 | means = {} 20 | stds = {} 21 | 22 | for model_name in __all__: 23 | input_sizes[model_name] = [3, 224, 224] 24 | means[model_name] = [0.485, 0.456, 0.406] 25 | stds[model_name] = [0.229, 0.224, 0.225] 26 | 27 | pretrained_settings = {} 28 | 29 | 30 | for model_name in __all__: 31 | pretrained_settings[model_name] = { 32 | 'imagenet': { 33 | 'url': model_urls[model_name], 34 | 'input_space': 'RGB', 35 | 'input_size': input_sizes[model_name], 36 | 'crop_size': input_sizes[model_name][-1] * 256 // 224, 37 | 'input_range': [0, 1], 38 | 'mean': means[model_name], 39 | 'std': stds[model_name] 40 | #'num_classes': 174 41 | } 42 | } 43 | 44 | 45 | def update_state_dict(state_dict): 46 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 47 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 48 | # They are also in the checkpoints in model_urls. This pattern is used 49 | # to find such keys. 50 | """ 51 | pattern = re.compile( 52 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 53 | for key in list(state_dict.keys()): 54 | res = pattern.match(key) 55 | if res: 56 | new_key = res.group(1) + res.group(2) 57 | state_dict[new_key] = state_dict[key] 58 | del state_dict[key] 59 | """ 60 | # Inflate to 3d densenet 61 | pattern = re.compile( 62 | r'^(.*)((?:conv|bn)(?:[0123]?)\.(?:weight|bias|running_mean|running_var))$') 63 | for key in list(state_dict.keys()): 64 | if True: 65 | v = state_dict[key] 66 | if 'features.0.' in key: 67 | if 'features.0.0.weight' in key: 68 | v = torch.unsqueeze(v, dim=2) 69 | v = v.repeat([1, 1, 5, 1, 1]) 70 | v /= 5.0 71 | state_dict[key] = v 72 | else: 73 | pass 74 | elif 'features.1.' in key: 75 | if 'conv.0' in key: 76 | if 'conv.0.0' in key: 77 | v = torch.unsqueeze(v, dim=2) 78 | new_key_btnk = key.replace('conv.0', 'depth_wise') 79 | state_dict[new_key_btnk] = v 80 | del state_dict[key] 81 | elif 'conv.1' in key: 82 | v = torch.unsqueeze(v, dim=2) 83 | new_key_btnk = key.replace('conv.1', 'point_wise') 84 | state_dict[new_key_btnk] = v 85 | del state_dict[key] 86 | else: 87 | assert 'conv.2' in key 88 | new_key_btnk = key.replace('conv.2', 'bn') 89 | state_dict[new_key_btnk] = v 90 | del state_dict[key] 91 | elif 'features.18.' in key: 92 | if 'features.18.0.weight' in key: 93 | v = torch.unsqueeze(v, dim=2) 94 | state_dict[key] = v 95 | else: 96 | pass 97 | elif 'classifier' in key: 98 | pass 99 | else: 100 | if 'conv.0.' in key: 101 | if 'conv.0.0.' 
in key: 102 | v = torch.unsqueeze(v, dim=2) 103 | new_key_btnk = key.replace('conv.0', 'bottleneck') 104 | state_dict[new_key_btnk] = v 105 | new_key_btnk = key.replace('conv.0', 'temporal') 106 | state_dict[new_key_btnk] = v.repeat([1, 1, 3, 1, 1]) / 3.0 107 | else: 108 | new_key_btnk = key.replace('conv.0', 'bottleneck') 109 | state_dict[new_key_btnk] = v 110 | new_key_btnk = key.replace('conv.0', 'temporal') 111 | state_dict[new_key_btnk] = v 112 | del state_dict[key] 113 | elif 'conv.1.' in key: 114 | if 'conv.1.0.' in key: 115 | v = torch.unsqueeze(v, dim=2) 116 | new_key_btnk = key.replace('conv.1', 'depth_wise') 117 | state_dict[new_key_btnk] = v 118 | del state_dict[key] 119 | elif 'conv.2.' in key: 120 | v = torch.unsqueeze(v, dim=2) 121 | new_key_btnk = key.replace('conv.2', 'point_wise') 122 | state_dict[new_key_btnk] = v 123 | del state_dict[key] 124 | else: 125 | assert 'conv.3.' in key 126 | new_key_btnk = key.replace('conv.3', 'bn') 127 | state_dict[new_key_btnk] = v 128 | del state_dict[key] 129 | 130 | return state_dict 131 | 132 | 133 | def load_pretrained(model, num_classes, settings): 134 | #assert num_classes == settings['num_classes'], \ 135 | # "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) 136 | state_dict = model_zoo.load_url(settings['url']) 137 | state_dict = update_state_dict(state_dict) 138 | mk, uk = model.load_state_dict(state_dict, strict=False) 139 | print('mk: {}'.format(mk)) 140 | print('uk: {}'.format(uk)) 141 | model.input_space = settings['input_space'] 142 | model.input_size = settings['input_size'] 143 | model.input_range = settings['input_range'] 144 | model.mean = settings['mean'] 145 | model.std = settings['std'] 146 | return model 147 | 148 | 149 | def _make_divisible(v, divisor, min_value=None): 150 | """ 151 | This function is taken from the original tf repo. 152 | It ensures that all layers have a channel number that is divisible by 8 153 | It can be seen here: 154 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 155 | :param v: 156 | :param divisor: 157 | :param min_value: 158 | :return: 159 | """ 160 | if min_value is None: 161 | min_value = divisor 162 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 163 | # Make sure that round down does not go down by more than 10%. 
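# Worked examples (added comment; values follow from the code below):
#   _make_divisible(32 * 0.35, 8): 11.2 first rounds to 8, and 8 < 0.9 * 11.2,
#   so the guard below bumps the result to 16.
#   _make_divisible(91, 8): rounds to 88, which is within 10% of 91, so no bump.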
164 | if new_v < 0.9 * v: 165 | new_v += divisor 166 | return new_v 167 | 168 | 169 | class ConvBNReLU(nn.Sequential): 170 | def __init__(self, in_planes, out_planes, kernel_size=(1, 3, 3), stride=(1, 1, 1), groups=1): 171 | 172 | padding = tuple([(k - 1) // 2 for k in kernel_size]) 173 | super(ConvBNReLU, self).__init__( 174 | nn.Conv3d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 175 | nn.BatchNorm3d(out_planes), 176 | nn.ReLU6(inplace=True) 177 | ) 178 | 179 | 180 | class InvertedResidual(nn.Module): 181 | def __init__(self, inp, oup, stride, expand_ratio, modality=None): 182 | super(InvertedResidual, self).__init__() 183 | self.stride = stride 184 | self.modality = modality 185 | self.expand_ratio = expand_ratio 186 | assert stride in [1, 2] 187 | 188 | hidden_dim = int(round(inp * expand_ratio)) 189 | self.use_res_connect = self.stride == 1 and inp == oup 190 | 191 | if expand_ratio != 1: 192 | if args.net_version == 'pure_spatial': 193 | self.bottleneck = ConvBNReLU(inp, hidden_dim, kernel_size=(1, 1, 1)) 194 | elif args.net_version == 'pure_temporal': 195 | self.temporal = ConvBNReLU(inp, hidden_dim, kernel_size=(3, 1, 1)) 196 | elif args.net_version == 'pure_fused': 197 | self.bottleneck = ConvBNReLU(inp, hidden_dim, kernel_size=(1, 1, 1)) 198 | self.temporal = ConvBNReLU(inp, hidden_dim, kernel_size=(3, 1, 1)) 199 | elif args.net_version == 'pure_adaptive': 200 | assert self.modality is not None 201 | if self.modality == 'fused': 202 | self.bottleneck = ConvBNReLU(inp, hidden_dim, kernel_size=(1, 1, 1)) 203 | self.temporal = ConvBNReLU(inp, hidden_dim, kernel_size=(3, 1, 1)) 204 | elif self.modality == 'spatial': 205 | self.bottleneck = ConvBNReLU(inp, hidden_dim, kernel_size=(1, 1, 1)) 206 | else: 207 | assert self.modality == 'temporal' 208 | self.temporal = ConvBNReLU(inp, hidden_dim, kernel_size=(3, 1, 1)) 209 | else: 210 | self.bottleneck = ConvBNReLU(inp, hidden_dim, kernel_size=(1, 1, 1)) 211 | # pw 212 | self.depth_wise = ConvBNReLU(hidden_dim, hidden_dim, stride=(1, stride, stride), groups=hidden_dim) 213 | self.point_wise = nn.Conv3d(hidden_dim, oup, 1, 1, 0, bias=False) 214 | self.bn = nn.BatchNorm3d(oup) 215 | 216 | def forward(self, x): 217 | if self.expand_ratio != 1: 218 | if args.net_version == 'pure_spatial': 219 | new_features = self.bottleneck(x) 220 | elif args.net_version == 'pure_temporal': 221 | new_features = self.temporal(x) 222 | elif args.net_version == 'pure_fused': 223 | new_features = self.bottleneck(x) + self.temporal(x) 224 | elif args.net_version == 'pure_adaptive': 225 | assert self.modality is not None 226 | if self.modality == 'fused': 227 | new_features = self.bottleneck(x) + self.temporal(x) 228 | elif self.modality == 'spatial': 229 | new_features = self.bottleneck(x) 230 | else: 231 | assert self.modality == 'temporal' 232 | new_features = self.temporal(x) 233 | else: 234 | new_features = self.bottleneck(x) 235 | else: 236 | new_features = x 237 | new_features = self.bn(self.point_wise(self.depth_wise(new_features))) 238 | if self.use_res_connect: 239 | return x + new_features 240 | else: 241 | return new_features 242 | 243 | 244 | class MobileNetV2(nn.Module): 245 | def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8, drop_rate=0.0): 246 | """ 247 | MobileNet V2 main class 248 | 249 | Args: 250 | num_classes (int): Number of classes 251 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 252 | 
inverted_residual_setting: Network structure 253 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 254 | Set to 1 to turn off rounding 255 | """ 256 | super(MobileNetV2, self).__init__() 257 | block = InvertedResidual 258 | input_channel = 32 259 | last_channel = 1280 260 | 261 | if inverted_residual_setting is None: 262 | if 'something' in args.dataset: 263 | inverted_residual_setting = [ 264 | # t, c, n, s, m 265 | [1, 16, 1, 1, 'fused'], 266 | [6, 24, 2, 2, 'fused'], 267 | [6, 32, 3, 2, 'fused'], 268 | [6, 64, 4, 2, 'temporal'], 269 | [6, 96, 3, 1, 'temporal'], 270 | [6, 160, 3, 2, 'fused'], 271 | [6, 320, 1, 1, 'fused'], 272 | ] 273 | else: 274 | inverted_residual_setting = [ 275 | # t, c, n, s, m 276 | [1, 16, 1, 1, 'fused'], 277 | [6, 24, 2, 2, 'fused'], 278 | [6, 32, 3, 2, 'spatial'], 279 | [6, 64, 4, 2, 'fused'], 280 | [6, 96, 3, 1, 'spatial'], 281 | [6, 160, 3, 2, 'temporal'], 282 | [6, 320, 1, 1, 'fused'], 283 | ] 284 | 285 | # only check the first element, assuming user knows t,c,n,s are required 286 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 5: 287 | raise ValueError("inverted_residual_setting should be non-empty " 288 | "or a 5-element list, got {}".format(inverted_residual_setting)) 289 | 290 | # building first layer 291 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 292 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 293 | features = [ConvBNReLU(3, input_channel, stride=2, kernel_size=(5, 3, 3))] 294 | # building inverted residual blocks 295 | for t, c, n, s, m in inverted_residual_setting: 296 | output_channel = _make_divisible(c * width_mult, round_nearest) 297 | for i in range(n): 298 | stride = s if i == 0 else 1 299 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, modality=m)) 300 | input_channel = output_channel 301 | # building last several layers 302 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=(1, 1, 1))) 303 | # make it nn.Sequential 304 | self.features = nn.Sequential(*features) 305 | 306 | # building classifier 307 | self.new_classifier = nn.Sequential( 308 | nn.Dropout(drop_rate), 309 | nn.Linear(self.last_channel, num_classes), 310 | ) 311 | 312 | # weight initialization 313 | for m in self.modules(): 314 | if isinstance(m, nn.Conv2d): 315 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 316 | if m.bias is not None: 317 | nn.init.zeros_(m.bias) 318 | elif isinstance(m, nn.BatchNorm2d): 319 | nn.init.ones_(m.weight) 320 | nn.init.zeros_(m.bias) 321 | elif isinstance(m, nn.Linear): 322 | nn.init.normal_(m.weight, 0, 0.01) 323 | nn.init.zeros_(m.bias) 324 | 325 | def forward(self, x): 326 | x = self.features(x) 327 | x = F.adaptive_avg_pool3d(x, (1, 1, 1)).view(x.size(0), -1) 328 | x = self.new_classifier(x) 329 | return x 330 | 331 | 332 | def mobilenet_v2(pretrained='imagenet', progress=True, **kwargs): 333 | """ 334 | Constructs a MobileNetV2 architecture from 335 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 
336 | 337 | Args: 338 | pretrained (bool): If True, returns a model pre-trained on ImageNet 339 | progress (bool): If True, displays a progress bar of the download to stderr 340 | """ 341 | model = MobileNetV2(**kwargs) 342 | if pretrained: 343 | settings = pretrained_settings['mobilenet_v2'][pretrained] 344 | model = load_pretrained(model, kwargs['num_classes'], settings) 345 | return model 346 | 347 | 348 | if __name__ == '__main__': 349 | mobilenet_v2(pretrained='imagenet', num_classes=1000) -------------------------------------------------------------------------------- /models/resnet_3d.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | import torch.nn.functional as F 7 | 8 | from collections import OrderedDict 9 | #from functools import partial 10 | 11 | from args import parser 12 | #from NAS_utils.ops import Conv3d_with_CD, Linear_with_CD 13 | 14 | args = parser.parse_args() 15 | 16 | #Conv2d = Conv2d 17 | Conv3d = nn.Conv3d #partial(Conv3d_with_CD, weight_reg=args.weight_reg, deterministic=True if args.selection_mode else False, training_size=args.training_size) 18 | Linear = nn.Linear #partial(Linear_with_CD, weight_reg=args.weight_reg, deterministic=True if args.selection_mode else False, training_size=args.training_size) 19 | nnConv2d = nn.Conv2d 20 | 21 | _TEMPORAL_NASAS_ONLY = args.temporal_nasas_only 22 | 23 | 24 | __all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 25 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 26 | 'wide_resnet50_2', 'wide_resnet101_2'] 27 | 28 | 29 | model_urls = { 30 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 31 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 32 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 33 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 34 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 35 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 36 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 37 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 38 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 39 | } 40 | 41 | input_sizes = {} 42 | means = {} 43 | stds = {} 44 | 45 | for model_name in __all__: 46 | input_sizes[model_name] = [3, 224, 224] 47 | means[model_name] = [0.485, 0.456, 0.406] 48 | stds[model_name] = [0.229, 0.224, 0.225] 49 | 50 | pretrained_settings = {} 51 | 52 | for model_name in __all__: 53 | pretrained_settings[model_name] = { 54 | 'imagenet': { 55 | 'url': model_urls[model_name], 56 | 'input_space': 'RGB', 57 | 'input_size': input_sizes[model_name], 58 | 'crop_size': input_sizes[model_name][-1] * 256 // 224, 59 | 'input_range': [0, 1], 60 | 'mean': means[model_name], 61 | 'std': stds[model_name] 62 | #'num_classes': 174 63 | } 64 | } 65 | 66 | 67 | def update_state_dict(state_dict): 68 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 69 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 70 | # They are also in the checkpoints in model_urls. This pattern is used 71 | # to find such keys. 
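# Note (added comment): despite the comment above, copied from the DenseNet
# loader, this function inflates a 2D *ResNet* checkpoint. The stem conv1
# kernel is tiled to depth 5 and divided by 5; each block's first 1x1 conv is
# duplicated into the parallel 'bottleneck' (1x1x1) and 'temporal' (3x1x1,
# tiled by 3 and divided by 3) branches; conv2/conv3 are only unsqueezed to
# depth-1 kernels, and the 2D 'fc' weights are dropped.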
72 | """ 73 | pattern = re.compile( 74 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 75 | for key in list(state_dict.keys()): 76 | res = pattern.match(key) 77 | if res: 78 | new_key = res.group(1) + res.group(2) 79 | state_dict[new_key] = state_dict[key] 80 | del state_dict[key] 81 | """ 82 | # Inflate to 3d densenet 83 | pattern = re.compile( 84 | r'^(.*)((?:conv|bn)(?:[0123]?)\.(?:weight|bias|running_mean|running_var))$') 85 | for key in list(state_dict.keys()): 86 | res = pattern.match(key) 87 | if res: 88 | v = state_dict[key] 89 | if 'conv' in key: 90 | v = torch.unsqueeze(v, dim=2) 91 | if 'layer' not in key: 92 | v = v.repeat([1, 1, 5, 1, 1]) 93 | v /= 5.0 94 | state_dict[key] = v 95 | elif 'conv1' in key: 96 | new_key_btnk = res.group(1) + 'bottleneck.' + res.group(2) 97 | state_dict[new_key_btnk] = v 98 | new_key_tmpr = res.group(1) + 'temporal.' + res.group(2) 99 | state_dict[new_key_tmpr] = v.repeat([1, 1, 3, 1, 1]) / 3.0 100 | del state_dict[key] 101 | elif 'conv2' in key: 102 | state_dict[key] = v 103 | elif 'conv3' in key: 104 | state_dict[key] = v 105 | else: 106 | if 'bn1' in key: 107 | pass 108 | elif 'bn2' in key: 109 | pass 110 | else: 111 | pass 112 | if 'downsample' in key: 113 | v = state_dict[key] 114 | if 'downsample.0' in key: 115 | v = torch.unsqueeze(v, dim=2) 116 | state_dict[key] = v 117 | if 'fc' in key: 118 | del state_dict[key] 119 | return state_dict 120 | 121 | 122 | def load_pretrained(model, num_classes, settings): 123 | #assert num_classes == settings['num_classes'], \ 124 | # "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) 125 | state_dict = model_zoo.load_url(settings['url']) 126 | state_dict = update_state_dict(state_dict) 127 | mk, uk = model.load_state_dict(state_dict, strict=False) 128 | model.input_space = settings['input_space'] 129 | model.input_size = settings['input_size'] 130 | model.input_range = settings['input_range'] 131 | model.mean = settings['mean'] 132 | model.std = settings['std'] 133 | return model 134 | 135 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 136 | """3x3 convolution with padding""" 137 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 138 | padding=dilation, groups=groups, bias=False, dilation=dilation) 139 | 140 | 141 | def conv1x1(in_planes, out_planes, stride=1): 142 | """1x1 convolution""" 143 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 144 | 145 | 146 | class BasicBlock(nn.Module): 147 | expansion = 1 148 | 149 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 150 | base_width=64, dilation=1, norm_layer=None): 151 | super(BasicBlock, self).__init__() 152 | if norm_layer is None: 153 | norm_layer = nn.BatchNorm2d 154 | if groups != 1 or base_width != 64: 155 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 156 | if dilation > 1: 157 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 158 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 159 | self.conv1 = conv3x3(inplanes, planes, stride) 160 | self.bn1 = norm_layer(planes) 161 | self.relu = nn.ReLU(inplace=True) 162 | self.conv2 = conv3x3(planes, planes) 163 | self.bn2 = norm_layer(planes) 164 | self.downsample = downsample 165 | self.stride = stride 166 | 167 | def forward(self, x): 168 | identity = x 169 | 170 | out = self.conv1(x) 171 | out = self.bn1(out) 172 | out = self.relu(out) 173 
| 174 | out = self.conv2(out) 175 | out = self.bn2(out) 176 | 177 | if self.downsample is not None: 178 | identity = self.downsample(x) 179 | 180 | out += identity 181 | out = self.relu(out) 182 | 183 | return out 184 | 185 | 186 | class Bottleneck(nn.Module): 187 | expansion = 4 188 | 189 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 190 | base_width=64, dilation=1, norm_layer=None, temporal_stride=1, enable_fuse=False, modality='temporal'): 191 | super(Bottleneck, self).__init__() 192 | self.enable_fuse=enable_fuse 193 | self.modality = modality 194 | 195 | if norm_layer is None: 196 | norm_layer = nn.BatchNorm3d 197 | width = int(planes * (base_width / 64.)) * groups 198 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 199 | if args.net_version == 'pure_fused': 200 | self.bottleneck = nn.Sequential(OrderedDict([ 201 | ('conv1', Conv3d(inplanes, width, kernel_size=1, stride=1, bias=False)) 202 | ])) 203 | self.temporal = nn.Sequential(OrderedDict([ 204 | ('conv1', Conv3d(inplanes, width, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False)) 205 | ])) 206 | elif args.net_version == 'pure_spatial': 207 | self.bottleneck = nn.Sequential(OrderedDict([ 208 | ('conv1', Conv3d(inplanes, width, kernel_size=1, stride=1, bias=False)) 209 | ])) 210 | elif args.net_version == 'pure_temporal': 211 | self.temporal = nn.Sequential(OrderedDict([ 212 | ('conv1', Conv3d(inplanes, width, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False)) 213 | ])) 214 | else: 215 | assert args.net_version == 'pure_adaptive', 'Unknown network version: {}'.format(args.net_version) 216 | if self.enable_fuse: 217 | self.bottleneck = nn.Sequential(OrderedDict([ 218 | ('conv1', Conv3d(inplanes, width, kernel_size=1, stride=1, bias=False)) 219 | ])) 220 | self.temporal = nn.Sequential(OrderedDict([ 221 | ('conv1', Conv3d(inplanes, width, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False)) 222 | ])) 223 | else: 224 | if self.modality == 'temporal': 225 | self.temporal = nn.Sequential(OrderedDict([ 226 | ('conv1', Conv3d(inplanes, width, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=False)) 227 | ])) 228 | else: 229 | self.bottleneck = nn.Sequential(OrderedDict([ 230 | ('conv1', Conv3d(inplanes, width, kernel_size=1, stride=1, bias=False)) 231 | ])) 232 | 233 | 234 | self.bn1 = norm_layer(width) 235 | 236 | if temporal_stride != 1: 237 | self.temporal_pool = nn.AvgPool3d(kernel_size=(temporal_stride, 1, 1), stride=(temporal_stride, 1, 1)) 238 | 239 | self.conv2 = Conv3d(width, width, groups=groups, dilation=dilation, kernel_size=(1, 3, 3), stride=(1, stride, stride), padding=(0, 1, 1), bias=False) 240 | self.bn2 = norm_layer(width) 241 | self.conv3 = Conv3d(width, planes * self.expansion, kernel_size=1, stride=1, bias=False) 242 | self.bn3 = norm_layer(planes * self.expansion) 243 | self.relu = nn.ReLU(inplace=True) 244 | self.downsample = downsample 245 | self.stride = stride 246 | self.temporal_stride = temporal_stride 247 | 248 | def forward(self, x): 249 | identity = x 250 | 251 | if self.temporal_stride != 1: 252 | out = self.temporal_pool(x) 253 | else: 254 | out = x 255 | if args.net_version == 'pure_fused': 256 | out = self.temporal(out) + self.bottleneck(out) 257 | elif args.net_version == 'pure_spatial': 258 | out = self.bottleneck(out) 259 | elif args.net_version == 'pure_temporal': 260 | out = self.temporal(out) 261 | else: 262 | assert args.net_version == 'pure_adaptive', 'Unknown network version: 
{}'.format(args.net_version))
263 | if self.enable_fuse:
264 | out = self.temporal(out) + self.bottleneck(out)
265 | else:
266 | if self.modality == 'temporal':
267 | out = self.temporal(out)
268 | else:
269 | out = self.bottleneck(out)
270 |
271 | out = self.bn1(out)
272 | out = self.relu(out)
273 |
274 | out = self.conv2(out)
275 | out = self.bn2(out)
276 | out = self.relu(out)
277 |
278 | out = self.conv3(out)
279 | out = self.bn3(out)
280 |
281 | if self.downsample is not None:
282 | identity = self.downsample(x)
283 |
284 | out += identity
285 | out = self.relu(out)
286 |
287 | return out
288 |
289 |
290 | class ResNet(nn.Module):
291 |
292 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
293 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
294 | norm_layer=None, drop_rate=0.0):
295 | super(ResNet, self).__init__()
296 | self.drop_rate = drop_rate
297 |
298 | if norm_layer is None:
299 | norm_layer = nn.BatchNorm3d
300 | self._norm_layer = norm_layer
301 |
302 | self.inplanes = 64
303 | self.dilation = 1
304 | if replace_stride_with_dilation is None:
305 | # each element in the tuple indicates if we should replace
306 | # the 2x2 stride with a dilated convolution instead
307 | replace_stride_with_dilation = [False, False, False]
308 | if len(replace_stride_with_dilation) != 3:
309 | raise ValueError("replace_stride_with_dilation should be None "
310 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
311 | self.groups = groups
312 | self.base_width = width_per_group
313 | self.conv1 = nn.Conv3d(3, self.inplanes, kernel_size=(5, 7, 7), stride=(2, 2, 2) if 'kinetics' in args.dataset or 'ucf' in args.dataset else (1, 2, 2), padding=(2, 3, 3),
314 | bias=False)
315 | self.bn1 = norm_layer(self.inplanes)
316 | self.relu = nn.ReLU(inplace=True)
317 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1))
318 | self.layer1 = self._make_layer(block, 64, layers[0], temporal_stride=1, enable_fuse=True)
319 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
320 | dilate=replace_stride_with_dilation[0], temporal_stride=1, enable_fuse=True)
321 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
322 | dilate=replace_stride_with_dilation[1], temporal_stride=1, enable_fuse=False)
323 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
324 | dilate=replace_stride_with_dilation[2], temporal_stride=1, enable_fuse=True)
325 | self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) # was AdaptiveAvgPool2d, which cannot take 5-D input; forward() pools via F.adaptive_avg_pool3d
326 | self.new_fc = Linear(512 * block.expansion, num_classes)
327 |
328 | for m in self.modules():
329 | if isinstance(m, nn.Conv3d):
330 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
331 | elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
332 | nn.init.constant_(m.weight, 1)
333 | nn.init.constant_(m.bias, 0)
334 | elif isinstance(m, nn.Linear):
335 | nn.init.constant_(m.bias, 0)
336 |
337 | # Zero-initialize the last BN in each residual branch,
338 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
339 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 340 | if zero_init_residual: 341 | for m in self.modules(): 342 | if isinstance(m, Bottleneck): 343 | nn.init.constant_(m.bn3.weight, 0) 344 | elif isinstance(m, BasicBlock): 345 | nn.init.constant_(m.bn2.weight, 0) 346 | 347 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False, temporal_stride=1, enable_fuse=False): 348 | norm_layer = self._norm_layer 349 | downsample = None 350 | previous_dilation = self.dilation 351 | if dilate: 352 | self.dilation *= stride 353 | stride = 1 354 | if stride != 1: 355 | downsample = nn.Sequential(OrderedDict([ 356 | ('avepool', nn.AvgPool3d(kernel_size=(temporal_stride, 3, 3), stride=(temporal_stride, stride, stride), padding=(0, 1, 1))), 357 | ('0', Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=(1, 1, 1), bias=False)), 358 | ('1', norm_layer(planes * block.expansion)) 359 | ])) 360 | elif self.inplanes != planes * block.expansion: 361 | assert temporal_stride==1, 'temporal stride != 1' 362 | downsample = nn.Sequential(OrderedDict([ 363 | ('0', Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=(1, 1, 1), bias=False)), 364 | ('1', norm_layer(planes * block.expansion)) 365 | ])) 366 | layers = [] 367 | if blocks == 23: 368 | if 'kinetics' in args.dataset: 369 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 370 | self.base_width, previous_dilation, norm_layer, temporal_stride=temporal_stride, 371 | enable_fuse=False, modality='spatial')) 372 | self.inplanes = planes * block.expansion 373 | for _ in range(1, 5): 374 | layers.append(block(self.inplanes, planes, groups=self.groups, 375 | base_width=self.base_width, dilation=self.dilation, 376 | norm_layer=norm_layer, enable_fuse=False, modality='spatial')) 377 | for _ in range(5, 15): 378 | layers.append(block(self.inplanes, planes, groups=self.groups, 379 | base_width=self.base_width, dilation=self.dilation, 380 | norm_layer=norm_layer, enable_fuse=True)) 381 | for _ in range(15, 19): 382 | layers.append(block(self.inplanes, planes, groups=self.groups, 383 | base_width=self.base_width, dilation=self.dilation, 384 | norm_layer=norm_layer, enable_fuse=False, modality='spatial')) 385 | for _ in range(19, blocks): 386 | layers.append(block(self.inplanes, planes, groups=self.groups, 387 | base_width=self.base_width, dilation=self.dilation, 388 | norm_layer=norm_layer, enable_fuse=False, modality='temporal')) 389 | else: 390 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 391 | self.base_width, previous_dilation, norm_layer, temporal_stride=temporal_stride, 392 | enable_fuse=False)) 393 | self.inplanes = planes * block.expansion 394 | for _ in range(1, 10): 395 | layers.append(block(self.inplanes, planes, groups=self.groups, 396 | base_width=self.base_width, dilation=self.dilation, 397 | norm_layer=norm_layer, enable_fuse=True)) 398 | for _ in range(10, blocks): 399 | layers.append(block(self.inplanes, planes, groups=self.groups, 400 | base_width=self.base_width, dilation=self.dilation, 401 | norm_layer=norm_layer, enable_fuse=False)) 402 | else: 403 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 404 | self.base_width, previous_dilation, norm_layer, temporal_stride=temporal_stride, 405 | enable_fuse=enable_fuse)) 406 | self.inplanes = planes * block.expansion 407 | for _ in range(1, blocks): 408 | layers.append(block(self.inplanes, planes, groups=self.groups, 409 | 
base_width=self.base_width, dilation=self.dilation,
410 | norm_layer=norm_layer, enable_fuse=enable_fuse))
411 |
412 | return nn.Sequential(*layers)
413 |
414 | def forward(self, x):
415 | x = self.conv1(x)
416 | x = self.bn1(x)
417 | x = self.relu(x)
418 | x = self.maxpool(x)
419 |
420 | x = self.layer1(x)
421 | x = self.layer2(x)
422 | x = self.layer3(x)
423 | x = self.layer4(x)
424 |
425 | x = F.adaptive_avg_pool3d(x, (1, 1, 1)).view(x.size(0), -1)
426 | x = F.dropout(x, p=self.drop_rate, training=self.training)
427 | x = self.new_fc(x)
428 |
429 | return x
430 |
431 |
432 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
433 | model = ResNet(block, layers, **kwargs)
434 | if pretrained: # pretrained is False/None to skip, or a settings key such as 'imagenet' ('is not None' would wrongly accept False)
435 | settings = pretrained_settings[arch][pretrained]
436 | model = load_pretrained(model, kwargs['num_classes'], settings)
437 | return model
438 |
439 |
440 | def resnet18(pretrained=False, progress=True, **kwargs):
441 | r"""ResNet-18 model from
442 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
443 |
444 | Args:
445 | pretrained (bool): If True, returns a model pre-trained on ImageNet
446 | progress (bool): If True, displays a progress bar of the download to stderr
447 | """
448 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
449 | **kwargs)
450 |
451 |
452 |
453 | def resnet34(pretrained=False, progress=True, **kwargs):
454 | r"""ResNet-34 model from
455 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
456 |
457 | Args:
458 | pretrained (bool): If True, returns a model pre-trained on ImageNet
459 | progress (bool): If True, displays a progress bar of the download to stderr
460 | """
461 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
462 | **kwargs)
463 |
464 |
465 |
466 | def resnet50(pretrained='imagenet', progress=True, **kwargs):
467 | r"""ResNet-50 model from
468 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
469 |
470 | Args:
471 | pretrained (str): name of the pre-trained settings to load (e.g. 'imagenet'), or None to skip
472 | progress (bool): If True, displays a progress bar of the download to stderr
473 | """
474 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
475 | **kwargs)
476 |
477 |
478 |
479 | def resnet101(pretrained=False, progress=True, **kwargs):
480 | r"""ResNet-101 model from
481 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
482 |
483 | Args:
484 | pretrained (bool): If True, returns a model pre-trained on ImageNet
485 | progress (bool): If True, displays a progress bar of the download to stderr
486 | """
487 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
488 | **kwargs)
489 |
490 |
491 |
492 | def resnet152(pretrained=False, progress=True, **kwargs):
493 | r"""ResNet-152 model from
494 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
495 |
496 | Args:
497 | pretrained (bool): If True, returns a model pre-trained on ImageNet
498 | progress (bool): If True, displays a progress bar of the download to stderr
499 | """
500 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
501 | **kwargs)
502 |
503 |
504 |
505 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
506 | r"""ResNeXt-50 32x4d model from
507 | `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_
508 |
509 | Args:
510 | pretrained (bool): If True, returns a model pre-trained on ImageNet
511 | progress (bool): If True, displays a progress bar of the download to stderr
512 | """
513 | kwargs['groups'] = 32
514 | kwargs['width_per_group'] = 4
515 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
516 | pretrained, progress, **kwargs)
517 |
518 |
519 |
520 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
521 | r"""ResNeXt-101 32x8d model from
522 | `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_
523 |
524 | Args:
525 | pretrained (bool): If True, returns a model pre-trained on ImageNet
526 | progress (bool): If True, displays a progress bar of the download to stderr
527 | """
528 | kwargs['groups'] = 32
529 | kwargs['width_per_group'] = 8
530 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
531 | pretrained, progress, **kwargs)
532 |
533 |
534 |
535 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
536 | r"""Wide ResNet-50-2 model from
537 | `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
538 |
539 | The model is the same as ResNet except that the number of bottleneck channels
540 | is twice as large in every block. The number of channels in outer 1x1
541 | convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
542 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
543 |
544 | Args:
545 | pretrained (bool): If True, returns a model pre-trained on ImageNet
546 | progress (bool): If True, displays a progress bar of the download to stderr
547 | """
548 | kwargs['width_per_group'] = 64 * 2
549 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
550 | pretrained, progress, **kwargs)
551 |
552 |
553 |
554 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
555 | r"""Wide ResNet-101-2 model from
556 | `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
557 |
558 | The model is the same as ResNet except that the number of bottleneck channels
559 | is twice as large in every block. The number of channels in outer 1x1
560 | convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
561 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
562 |
563 | Args:
564 | pretrained (bool): If True, returns a model pre-trained on ImageNet
565 | progress (bool): If True, displays a progress bar of the download to stderr
566 | """
567 | kwargs['width_per_group'] = 64 * 2
568 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
569 | pretrained, progress, **kwargs)
570 |
571 |
572 | if __name__ == '__main__':
573 | resnet50(pretrained='imagenet', num_classes=1000)
-------------------------------------------------------------------------------- /philly_distributed_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scenarios/Probabilistic3DCNN/444385aeabc001282064877faba7a15a787a4f94/philly_distributed_utils/__init__.py -------------------------------------------------------------------------------- /philly_distributed_utils/distributed.py: --------------------------------------------------------------------------------
1 | import os
2 | import os.path as op
3 | import numpy as np
4 | import subprocess
5 | from contextlib import contextmanager
6 | import logging
7 |
8 | def ompi_rank():
9 | """Find OMPI world rank without calling mpi functions
10 | :rtype: int
11 | """
12 | return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0)
13 |
14 |
15 | def ompi_size():
16 | """Find OMPI world size without calling mpi functions
17 | :rtype: int
18 | """
19 | return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1)
20 |
21 |
22 | def ompi_local_rank():
23 | """Find OMPI local rank without calling mpi functions
24 | :rtype: int
25 | """
26 | return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0)
27 |
28 |
29 | def ompi_local_size():
30 | """Find OMPI local size without calling mpi functions
31 | :rtype: int
32 | """
33 | return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_SIZE') or 1)
34 |
35 |
36 | @contextmanager
37 | def run_and_terminate_process(*args, **kwargs):
38 | """Run a process and terminate it at the end
39 | """
40 | p = None
41 | try:
42 | p = subprocess.Popen(*args, **kwargs)
43 | yield p
44 | finally:
45 | if not p:
46 | return
47 | try:
48 | p.terminate() # send sigterm
49 | except OSError:
50 | pass
51 | try:
52 | p.kill() # send sigkill
53 | except OSError:
54 | pass
55 |
56 |
57 | def get_gpus_nocache():
58 | """List of NVIDIA GPUs
59 | """
60 | cmds = 'nvidia-smi --query-gpu=name --format=csv,noheader'.split(' ')
61 | with run_and_terminate_process(cmds,
62 | stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
63 | bufsize=1) as process:
64 | return [line.decode().strip() for line in iter(process.stdout.readline, b'')] # stdout yields bytes under Python 3, so the EOF sentinel must be b''
65 |
66 |
67 | def get_gpus():
68 | """List of NVIDIA GPUs
69 | """
70 | return get_gpus_nocache()
71 |
72 |
73 | def gpu_indices(divisible=True):
74 | """Get the GPU device indices for this process/rank
75 | :param divisible: if GPU count of all ranks must be the same
76 | :rtype: list[int]
77 | """
78 | local_size = ompi_local_size()
79 | local_rank = ompi_local_rank()
80 | assert 0 <= local_rank < local_size, "Invalid local_rank: {} local_size: {}".format(local_rank, local_size)
81 | gpu_count = len(get_gpus())
82 | assert gpu_count >= local_size > 0, "GPU count: {} must be >= LOCAL_SIZE: {} > 0".format(gpu_count, local_size)
83 | if divisible:
84 | ngpu = gpu_count // local_size # floor division: true division gives float indices under Python 3
85 | gpus = np.arange(local_rank * ngpu, (local_rank + 1) * ngpu)
86 | if gpu_count % local_size != 0:
87 | logging.warning("gpu_count: {} not divisible by local_size: {}; some GPUs may be unused".format(
88 | gpu_count, local_size
89 | ))
90 | else:
91
| gpus = np.array_split(range(gpu_count), local_size)[local_rank] 92 | return gpus 93 | -------------------------------------------------------------------------------- /philly_distributed_utils/env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as op 3 | import re 4 | 5 | def _vc_home(): 6 | """Find philly's VC home in scratch space 7 | :rtype: str 8 | """ 9 | home = os.environ.get('PHILLY_VC_NFS_DIRECTORY', os.environ.get('PHILLY_VC_DIRECTORY')) 10 | if not home: 11 | home = op.expanduser('~') 12 | home = '/'.join(home.split('/')[:5]) 13 | return home 14 | 15 | 16 | _VC_HOME = _vc_home() 17 | 18 | 19 | def vc_name(): 20 | """Find philly's VC name 21 | :rtype: str 22 | """ 23 | name = os.environ.get('PHILLY_VC') 24 | if name: 25 | return name 26 | name = op.basename(_VC_HOME) 27 | if name: 28 | return name 29 | return op.basename(op.dirname(_VC_HOME)) 30 | 31 | 32 | _VC_NAME = vc_name() 33 | 34 | 35 | def _vc_hdfs_base(): 36 | base = os.environ.get("PHILLY_DATA_DIRECTORY") or os.environ.get("PHILLY_HDFS_PREFIX") 37 | if base: 38 | return base 39 | for base in ["/hdfs", "/home"]: 40 | if op.isdir(base): 41 | return base 42 | return _VC_HOME 43 | 44 | 45 | def vc_hdfs_root(): 46 | """Find the HDFS root of the VC 47 | :rtype: str 48 | """ 49 | path = os.environ.get('PHILLY_VC_HDFS_DIRECTORY') 50 | if path: 51 | return path 52 | path = op.join(os.environ.get('PHILLY_HDFS_PREFIX', _vc_hdfs_base()), _VC_NAME) 53 | return path 54 | 55 | 56 | _VC_HDFS_ROOT = vc_hdfs_root() 57 | 58 | 59 | def expand_vc_user(path): 60 | """Expand ~ to VC's home 61 | :param path: the path to expand VC user 62 | :type path: str 63 | :return:/var/storage/shared/$VC_NAME 64 | :rtype: str 65 | """ 66 | if path.startswith('~'): 67 | path = op.abspath(op.join(_VC_HOME, '.' 
+ path[1:]))
68 |
69 | return path
70 |
71 | def abspath(path, roots=None):
72 | """Expand ~ to VC's home and resolve relative paths to absolute paths
73 | :param path: the path to resolve
74 | :type path: str
75 | :param roots: CWD roots to resolve relative paths to them
76 | :type roots: list
77 | """
78 | path = expand_vc_user(path)
79 | if op.isabs(path):
80 | return path
81 | if not roots:
82 | roots = ["~"]
83 | roots = [expand_vc_user(root) for root in roots]
84 | for root in roots:
85 | resolved = op.abspath(op.join(root, path))
86 | if op.isfile(resolved) or op.isdir(resolved):
87 | return resolved
88 | # return assuming the first root (even though it does not exist)
89 | return op.abspath(op.join(roots[0], path))
90 |
91 |
92 | def job_id(path=None):
93 | """Get the philly job ID (from a path)
94 | :param path: Path to search for app id
95 | :rtype: str
96 | """
97 | if path is None:
98 | return os.environ.get('PHILLY_JOB_ID') or job_id(op.expanduser('~'))
99 | m = re.search(r'/(?P<app_id>application_[\d_]+)[/\w]*$', path) # named group restored; it is read via m.group('app_id') below
100 | if m:
101 | return m.group('app_id')
102 | return ''
103 |
104 |
105 | def get_model_path(path=None):
106 | """Find the default location to output/models
107 | """
108 | return abspath(op.join('sys', 'jobs', job_id(path), 'models'), roots=[vc_hdfs_root()])
109 |
110 |
111 | def get_master_machine():
112 | mpi_host_file = op.expanduser('~/mpi-hosts')
113 | with open(mpi_host_file, 'r') as f:
114 | master_name = f.readline().strip()
115 | return master_name
116 |
117 |
118 | def get_master_ip(master_name=None):
119 | if master_name is None:
120 | master_name = get_master_machine()
121 | etc_host_file = '/etc/hosts'
122 | with open(etc_host_file, 'r') as f:
123 | name_ip_pairs = f.readlines()
124 | name2ip = {}
125 | for name_ip_pair in name_ip_pairs:
126 | pair_list = name_ip_pair.split(' ')
127 | key = pair_list[1].strip()
128 | value = pair_list[0]
129 | name2ip[key] = value
130 | return name2ip[master_name]
131 |
-------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scenarios/Probabilistic3DCNN/444385aeabc001282064877faba7a15a787a4f94/tools/__init__.py -------------------------------------------------------------------------------- /tools/ckpt_checker.py: --------------------------------------------------------------------------------
1 | import argparse
2 | parser = argparse.ArgumentParser(description="PyTorch implementation of NAS_spatiotemporal")
3 | parser.add_argument('--ckptpth', type=str, default='/mnt/log/NAS_spatiotemporal/checkpoint/NAS_sptp_nasas__ls250.0_something_RGB_Dense3D121_avg_segment1_e90_droprate0.0_num_dense_sample32_dense_sample_stride1_dense_Netv1d3Bz2by16Lr0.005SbnTpbtShare40-60-80/NAS_sptp_nasas_selection_ls250.0_something_RGB_Dense3D121_avg_segment1_e50_droprate0.0_num_dense_sample32_dense_Netv1d3Spbtsharels250temptSelect') # checkpoint directory; main() joins 'ckpt.best.1.pth.tar' onto it, so the file name must not be part of the default
4 |
5 | import os
6 | import numpy as np
7 | import csv
8 |
9 | def alpha_checker(state_dict, path=''):
10 | with open(os.path.join(path, 'p_log.csv'), mode='w') as csv_file:
11 | fields = ['index', 'S', 'T']
12 | csv_writer = csv.DictWriter(csv_file, fieldnames=fields)
13 | csv_writer.writeheader()
14 |
15 | records = {}
16 | for name, value in state_dict.items():
17 | if 'p_logit' in name and 'classifier' not in name:
18 | print('{} {}'.format(name, value.sigmoid().item()))
19 | name = name.replace('module.features.', '')
20 | name = name.replace('.conv1.p_logit', '')
21 | name = name.replace('.conv.p_logit', '')
22 | name = name.replace('bottleneck', 'S')
23 | name = name.replace('temporal', 'T')
24 | name = name.replace('denseblock', 'B')
25 | name = name.replace('denselayer', 'L')
26 | name = name.replace('original', 'S')
27 | if '.S' in name:
28 | if name.replace('.S', '') in records.keys():
29 | records[name.replace('.S', '')]['S'] = value.sigmoid().item()
30 | else:
31 | records[name.replace('.S', '')] = {}
32 | records[name.replace('.S', '')]['S'] = value.sigmoid().item()
33 | #csv_writer.writerow({'index': name, 'S': value.sigmoid().item(), 'T': state_dict[name.replace('.S', '.T')].sigmoid().item()})
34 | elif '.T' in name:
35 | if name.replace('.T', '') in records.keys():
36 | records[name.replace('.T', '')]['T'] = value.sigmoid().item()
37 | else:
38 | records[name.replace('.T', '')] = {}
39 | records[name.replace('.T', '')]['T'] = value.sigmoid().item()
40 | else:
41 | pass
42 |
43 | for name, value in records.items():
44 | csv_writer.writerow({'index': name, 'S': value['S'], 'T': value['T']})
45 |
46 |
47 | def sptp_checker(state_dict):
48 | t_count = 0
49 | s_count = 0
50 | for name, value in state_dict.items():
51 | if 'p_logit' in name and 'classifier' not in name:
52 | #print('{}: {}'.format(name, value.sigmoid()))
53 | stensor = 1 - np.floor(state_dict[name.replace('p_logit', 'unif_noise_variable')].cpu().item() + value.sigmoid().cpu().item())
54 | if int(stensor) == 0:
55 | if 'temporal' in name:
56 | t_count += 1
57 | elif 'bottleneck' in name or 'original' in name: # was `'bottleneck' or 'original' in name`, which is always true
58 | s_count += 1
59 | print('{}: {}'.format(name, stensor))
60 | #if 'norm' in name:
61 | # print('{}: {}'.format(name, value))
62 | print('{}: {}'.format('t_count', t_count))
63 | print('{}: {}'.format('s_count', s_count))
64 |
65 |
66 | def main():
67 | args = parser.parse_args()
68 | import torch
69 | checkpoint = torch.load(os.path.join(args.ckptpth, 'ckpt.best.1.pth.tar'), map_location='cpu')
70 | state_dict = checkpoint['state_dict']
71 | #alpha_checker(state_dict, path=args.ckptpth)
72 | sptp_checker(state_dict)
73 |
74 |
75 | if __name__ == '__main__':
76 | main()
-------------------------------------------------------------------------------- /tools/generate_label_sthsthv1.py: --------------------------------------------------------------------------------
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # arXiv:1811.08383
3 | # Ji Lin*, Chuang Gan, Song Han
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5 | # ------------------------------------------------------
6 | # Code adapted from https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py
7 | # processing the raw data of the video Something-Something-V1
8 |
9 | import os
10 |
11 | root_path = '/mnt/data/somethingsomethingv1/'
12 |
13 | if __name__ == '__main__':
14 | os.chdir(root_path)
15 | dataset_name = 'something-something-v1' # 'jester-v1'
16 | with open('%s-labels.csv' % dataset_name) as f:
17 | lines = f.readlines()
18 | categories = []
19 | for line in lines:
20 | line = line.rstrip()
21 | categories.append(line)
22 | categories = sorted(categories)
23 | with open('category.txt', 'w') as f:
24 | f.write('\n'.join(categories))
25 |
26 | dict_categories = {}
27 | for i, category in enumerate(categories):
28 | dict_categories[category] = i
29 |
30 | files_input = ['%s-validation.csv' % dataset_name, '%s-train.csv' % dataset_name]
31 | files_output = ['val_videofolder_azure.txt', 'train_videofolder_azure.txt']
32 | for (filename_input, filename_output) in zip(files_input, files_output):
33 | with open(filename_input) as f:
34 | lines = f.readlines()
35 | folders = []
36 | idx_categories = []
37 | for line in lines:
38 | line = line.rstrip()
39 | items = line.split(';')
40 | folders.append(items[0])
41 | idx_categories.append(dict_categories[items[1]])
42 | output = []
43 | for i in range(len(folders)):
44 | curFolder = folders[i]
45 | curIDX = idx_categories[i]
46 | # counting the number of frames in each video folder
47 | dir_files = os.listdir(os.path.join('./img', curFolder))
48 | output.append('%s %d %d' % (os.path.join('', curFolder), len(dir_files), curIDX))
49 | print('%d/%d' % (i, len(folders)))
50 | with open(filename_output, 'w') as f:
51 | f.write('\n'.join(output))
-------------------------------------------------------------------------------- /tools/generate_label_ucf101.py: --------------------------------------------------------------------------------
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
2 | # arXiv:1811.08383
3 | # Ji Lin*, Chuang Gan, Song Han
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
5 | # ------------------------------------------------------
6 | # Code adapted from https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py
7 | # processing the raw data of the video UCF101
8 |
9 | import os
10 |
11 | root_path = '/home/sda/data-writable/ucf101/'
12 |
13 | if __name__ == '__main__':
14 | os.chdir(root_path)
15 | dataset_name = 'ucf101'
16 | with open('%s-labels.txt' % dataset_name) as f:
17 | lines = f.readlines()
18 | dict_categories = {}
19 | for line in lines:
20 | line = line.rstrip()
21 | line = line.split()
22 | dict_categories[line[1]] = int(line[0]) - 1 # the labels file is 1-indexed
23 |
24 |
25 | files_input = ['testlist01.txt', 'trainlist01.txt']
26 | files_output = ['val_videofolder.txt', 'train_videofolder.txt']
27 | for (filename_input, filename_output) in zip(files_input, files_output):
28 | with open(filename_input) as f:
29 | lines = f.readlines()
30 | folders = []
31 | idx_categories = []
32 | for line in lines:
33 | line = line.rstrip()
34 | items = line.split('.')
35 | folders.append(items[0])
36 | idx_categories.append(dict_categories[items[0].split('/')[0]])
37 | output = []
38 | for i in range(len(folders)):
39 | curFolder = folders[i]
40 | curIDX = idx_categories[i]
41 | # counting the number of frames in each video folder
42 | dir_files = os.listdir(os.path.join('./UCF-101_image', curFolder, 'i'))
43 | output.append('%s %d %d' % (os.path.join('', curFolder), len(dir_files), curIDX))
44 | print('%d/%d' % (i, len(folders)))
45 | with open(filename_output, 'w') as f:
46 | f.write('\n'.join(output))
-------------------------------------------------------------------------------- /tools/statistics.py: --------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 |
4 | from thop import profile
5 | from models.densenet_3d_forstat import densenet121
6 | from args import parser
7 |
8 | import torch
9 |
10 | args = parser.parse_args()
11 | model = densenet121(num_classes=174, pretrained=None, drop_rate=0.5)
12 | input = torch.randn(1, 3, 128, 256, 256)
13 | flops, params = profile(model, inputs=(input, ))
14 |
15 | print(flops)
16 | print(params)
-------------------------------------------------------------------------------- /tools/to_hdf5.py: --------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | import h5py
5 | from zipfile import ZipFile
6 | import numpy as np
7 |
8 | from PIL import Image
9 |
10 |
11 | class VideoRecord(object):
12 | def __init__(self, row):
13 | self._data = row
14 |
15 | @property
16 | def path(self):
17 | return self._data[0]
18 |
19 | @property
20 | def num_frames(self):
21 | return int(self._data[1])
22 |
23 | @property
24 | def label(self):
25 | return int(self._data[2])
26 |
27 |
28 | class Converter(object):
29 | def __init__(self, root_path, list_file, target_path, image_tmpl):
30 | self._list_file = list_file
31 | self._root_path = root_path
32 | self._target_path = target_path
33 | self._image_tmpl = image_tmpl
34 |
35 | self._parse_list()
36 |
37 | def _parse_list(self):
38 | tmp = [x.strip().split(' ') for x in open(self._list_file)]
39 | tmp = [[' '.join(x[:-2]), x[-2], x[-1]] for x in tmp]
40 | self._video_list = [VideoRecord(item) for item in tmp]
41 |
42 | if self._image_tmpl == '{:06d}-{}_{:05d}.jpg':
43 | for v in self._video_list:
44 | v._data[1] = int(v._data[1]) // 2 # floor division keeps the frame count an integer under Python 3
45 | print('video number:%d' % (len(self._video_list)))
46 |
47 | def _full_path(self, directory, idx):
48 | return os.path.join(self._root_path, directory, self._image_tmpl.format(idx))
49 |
50 | def _load_image(self, directory, idx):
51 | try:
52 | return Image.open(os.path.join(self._root_path, directory, self._image_tmpl.format(idx))).convert('RGB')
53 | except Exception:
54 | print('error loading image:', os.path.join(self._root_path, directory, self._image_tmpl.format(idx)))
55 | return Image.open(os.path.join(self._root_path, directory, self._image_tmpl.format(1))).convert('RGB')
56 |
57 | def convert(self):
58 | raise NotImplementedError()
59 |
60 |
61 | class HDF5Converter(Converter):
62 | def __init__(self, root_path, list_file, target_path, image_tmpl):
63 | super(HDF5Converter, self).__init__(root_path, list_file, target_path, image_tmpl)
64 |
65 | def convert(self):
66 | for record in self._video_list:
67 | if not os.path.exists(os.path.join(self._target_path, record.path)):
68 | os.makedirs(os.path.join(self._target_path, record.path))
69 | assert not os.path.exists(os.path.join(self._target_path, record.path, 'RGB_frames')), "{} already exists".format(os.path.join(self._target_path, record.path, 'RGB_frames'))
70 | with h5py.File(os.path.join(self._target_path, record.path, 'RGB_frames'), 'w') as hdf:
71 | for idx in range(record.num_frames):
72 | img = np.asarray((self._load_image(record.path, idx+1)), dtype="uint8")
73 | hdf.create_dataset(record.path+"/"+self._image_tmpl.format(idx+1), data=img, dtype="uint8")
74 | print("{} Done".format(record.path))
75 |
76 |
77 | class ZIPConverter(Converter):
78 | def __init__(self, root_path, list_file, target_path, image_tmpl):
79 | super(ZIPConverter, self).__init__(root_path, list_file, target_path, image_tmpl)
80 |
81 | def convert(self):
82 | _video_num = len(self._video_list)
83 | for i, record in enumerate(self._video_list):
84 | if not os.path.exists(os.path.join(self._target_path, record.path)):
85 | os.makedirs(os.path.join(self._target_path, record.path))
86 | assert not os.path.exists(os.path.join(self._target_path, record.path, 'RGB_frames.zip')), "{} already exists".format(os.path.join(self._target_path, record.path, 'RGB_frames.zip'))
87 | with ZipFile(os.path.join(self._target_path, record.path, 'RGB_frames.zip'), 'w') as zipf:
88 | for idx in range(record.num_frames):
89 | #img = np.asarray((self._load_image(record.path, idx+1)), dtype="uint8")
90 | zipf.write(self._full_path(record.path, idx+1), arcname=self._image_tmpl.format(idx+1))
91 | print("{} of {} ({}) Done".format(str(i), str(_video_num), record.path))
92 |
93 |
94 | def main(list_file):
95 | root_path = "/home/sda/data-writable/something-something/"
96 | target_path = "/home/sdb/writable/20bn-something-something-v1_zip/"
97 | #list_file = "/home/sda/data-writable/kinetics400_frame/val_videofolder.txt"
98 | list_file = os.path.join(root_path, list_file)
99 |
100 | cvt = ZIPConverter(os.path.join(root_path, '20bn-something-something-v1'), list_file, target_path, '{:05d}.jpg')
101 | cvt.convert()
102 |
103 |
104 | if __name__ == '__main__':
105 | main(sys.argv[1])
106 |
-------------------------------------------------------------------------------- /tools/visualize.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import pandas as pd
4 | import matplotlib.pyplot as plt
5 | import re
6 | import os
7 |
8 | def visualization(args):
9 | log_path = os.path.join(args.root_log, args.store_name)
10 | checkpoint_path = os.path.join(args.root_model, args.store_name)
11 | current_checkpoint_path = os.path.join(checkpoint_path, "ckpt.pth.tar")
12 | best_checkpoint_path = os.path.join(checkpoint_path, "ckpt.best.pth.tar")
13 |
14 | p_record_path = os.path.join(log_path, "p_record.txt")
15 | if os.path.isfile(p_record_path):
16 | # not the first run: restore the plot counter from disk
17 | try:
18 | with open(p_record_path, 'r') as f:
19 | p_record = int(f.read())
20 | assert 0 <= p_record <= 100000, "p_record out of range"
21 | print("history record p: ", p_record)
22 | except (IOError, ValueError) as E:
23 | print(E)
24 | print("p_record.txt is missing or unreadable; resetting the counter")
25 | p_record = 0
26 | else:
27 | # first run: start the plot counter at zero
28 | p_record = 0
29 | print("init record p: ", p_record)
30 |
31 |
32 | # get data
33 | try:
34 | files = os.listdir(log_path)
35 | log_pattern = re.compile(r'(log)(-*\d+)(\.csv)')
36 | log_number = re.compile(r'-*\d+')
37 | log_numbers = []
38 | for f in files:
39 | if log_pattern.match(f):
40 | this_log_number = int(log_number.findall(f)[0])
41 | log_numbers.append(this_log_number)
42 | min_log_number = min(log_numbers)
43 | min_log = 'log' + str(min_log_number) + '.csv'
44 | data = pd.read_csv(os.path.join(log_path, min_log), delimiter='\n', engine='python', header=None,
45 | error_bad_lines=False)
46 | data = [str(x) for x in data.values]
47 | test = [re.findall(r'-?\d+\.?\d*e?-?\d*?', x) for x in data if "Testing" in x]
48 | train = [re.findall(r'-?\d+\.?\d*e?-?\d*?', x) for x in data if "Worker" in x]
49 | y_val = [float(x[1]) for x in test]
50 | y_train_batch = [float(x[-5]) for x in train]
51 | y_train_avg_epoch = []
52 |
53 | loss_val = [float(x[-1]) for x in test]
54 | loss_train_batch = [float(x[9]) for x in train]
55 | CE_loss_train_batch = [float(x[11]) for x in train]
56 | KL_loss_train_batch = [float(x[13]) for x in train]
57 | loss_train_avg_epoch = []
58 | CE_loss_train_avg_epoch = []
59 | KL_loss_train_avg_epoch = []
60 |
61 | for i, tmp in enumerate(train):
62 | if float(tmp[-5]) == float(tmp[-4]) and i > 0:
63 | y_train_avg_epoch.append(float(train[i - 1][-4]))
64 | loss_train_avg_epoch.append(float(train[i - 1][10]))
65 | CE_loss_train_avg_epoch.append(float(train[i - 1][12]))
66 | KL_loss_train_avg_epoch.append(float(train[i - 1][14]))
67 |
68 | num_epochs = len(y_train_avg_epoch)
69 | num_batchs = len(y_train_batch)
70 | num_vals = len(y_val)
71 | x_train_avg_epoch = np.array(range(num_epochs))
72 | x_train_batch = np.array(range(num_batchs)) * (num_epochs / num_batchs)
73 | x_val = np.array(range(num_vals)) * (num_epochs / num_vals)
74 |
75 | best_val_y = max(y_val)
76 | best_val_x = x_val[y_val.index(best_val_y)]
77 |
78 | if True:
79 | current_checkpoint = torch.load(current_checkpoint_path, map_location=lambda storage, loc: storage)
80 | best_checkpoint = torch.load(best_checkpoint_path, map_location=lambda storage, loc: storage)
81 | current_state_dict = current_checkpoint['state_dict']
82 | best_state_dict = best_checkpoint['state_dict']
83 | assert current_state_dict.keys() == best_state_dict.keys()
84 | current_p_logit = [(x, torch.sigmoid(torch.tensor(float(current_state_dict[x])))) for x in
85 | current_state_dict.keys() if
86 | "p_logit" in x]
87 | best_p_logit = [(x, torch.sigmoid(torch.tensor(float(best_state_dict[x])))) for x in
88 | best_state_dict.keys() if
89 | "p_logit" in x]
90 | X_p = list(range(len(current_p_logit)))
91 | X_p_ticks = [loc for loc, value in current_p_logit]
92 | Y_p_current = np.array([value.item() for loc, value in current_p_logit])
93 | Y_p_best = np.array([value.item() for loc, value in best_p_logit])
94 | except Exception as e:
95 | print("visual exception: ", e)
96 | return
97 |
98 | else:
99 | print("log and checkpoint data loaded successfully, generating the result picture")
100 | # result
101 | plt.figure(figsize=(20, 10))
102 |
103 | plt.subplot(121)
104 | plt.title("prec1@{}".format("something"))
105 | plt.plot(x_train_batch, y_train_batch, label="train batches")
106 | plt.plot(x_train_avg_epoch, y_train_avg_epoch, marker='*', label="train epochs average")
107 | plt.plot(x_val, y_val, marker='o', label="test per {} epochs".format(round(num_epochs / num_vals)))
108 | plt.annotate('best: {}'.format(best_val_y),
109 | xy=(best_val_x, best_val_y),
110 | xycoords='data',
111 | xytext=(50, 50),
112 | textcoords='offset points',
113 | fontsize=16,
114 | arrowprops=dict(arrowstyle='->', connectionstyle="arc3, rad=.2"))
115 |
116 | plt.xlabel("epochs")
117 | plt.ylabel("top1%")
118 | plt.legend(loc="best")
119 |
120 | plt.subplot(122)
121 | plt.title("loss@{}".format("something"))
122 | plt.plot(x_train_batch, loss_train_batch, label="train batches' loss")
123 | plt.plot(x_train_batch, CE_loss_train_batch, label="train batches' CrossEntropy loss")
124 | plt.plot(x_train_batch, KL_loss_train_batch, label="train batches' KL loss")
125 | plt.plot(x_train_avg_epoch, loss_train_avg_epoch, marker="*", label="train epoch avg's loss")
126 | plt.plot(x_train_avg_epoch, CE_loss_train_avg_epoch, marker="*", label="train epoch avg's CE loss")
127 | plt.plot(x_train_avg_epoch, KL_loss_train_avg_epoch, marker="*", label="train epoch avg's KL loss")
128 | plt.plot(x_val, loss_val, marker="*", label="test per {} epochs' loss".format(round(num_epochs / num_vals)))
129 | plt.xlabel("epochs")
130 | plt.ylabel("loss")
131 |
132 | plt.legend(loc='best')
133 |
134 | plt.savefig(os.path.join(log_path, "result.png"), bbox_inches='tight')
135 |
136 | if True:
137 |
138 | # current p
139 | p_record += 1
140 | with open(p_record_path, "w") as f:
141 | f.write(str(p_record))
142 |
143 | plt.figure(figsize=(20, 10))
144 | plt.bar(X_p, Y_p_current, facecolor='#ff9800')
145 | plt.plot(X_p, Y_p_current, marker='^', markersize=15, linewidth=3)
146 | plt.title("p to drop connection")
147 | # annotate each bar with its drop probability
148 | for x, y in zip(X_p, Y_p_current):
149 | plt.text(x, y + 0.03,
150 | '%.2f' % y, ha='center', va='bottom',
151 | fontdict={'color': '#0091ea',
152 | 'size': 16})
153 | plt.ylim(0., 1.)
154 | plt.xticks(X_p, X_p_ticks, size="small", rotation=85)
155 | plt.yticks([])
156 | plt.savefig(os.path.join(log_path, "current{}_p.png".format(p_record)), bbox_inches='tight')
157 |
158 | # best p
159 | plt.figure(figsize=(20, 10))
160 | plt.bar(X_p, Y_p_best, facecolor='#0000FF')
161 | plt.plot(X_p, Y_p_best, marker='^', markersize=15, linewidth=3)
162 | plt.title("p to drop connection")
163 | # annotate each bar with its drop probability
164 | for x, y in zip(X_p, Y_p_best):
165 | plt.text(x, y + 0.03,
166 | '%.2f' % y, ha='center', va='bottom',
167 | fontdict={'color': '#0091ea',
168 | 'size': 16})
169 | plt.ylim(0., 1.)
170 | plt.xticks(X_p, X_p_ticks, size="small", rotation=85)
171 | plt.yticks([])
172 | plt.savefig(os.path.join(log_path, "best_p.png"), bbox_inches='tight')
173 |
174 | plt.close('all')
175 |
176 | if __name__ == '__main__':
177 | import argparse
178 | parser = argparse.ArgumentParser(description="PyTorch implementation of NAS_spatiotemporal")
179 | parser.add_argument('--root_log', type=str, default='/mnt/log/NAS_spatiotemporal/log')
180 | parser.add_argument('--root_model', type=str, default='/mnt/log/NAS_spatiotemporal/checkpoint')
181 | parser.add_argument('--store_name', type=str, default="")
182 |
183 | args = parser.parse_args()
184 | visualization(args)
-------------------------------------------------------------------------------- /utils.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def softmax(scores):
5 | es = np.exp(scores - scores.max(axis=-1)[..., None])
6 | return es / es.sum(axis=-1)[..., None]
7 |
8 |
9 | class AverageMeter(object):
10 | """Computes and stores the average and current value"""
11 |
12 | def __init__(self):
13 | self.reset()
14 |
15 | def reset(self):
16 | self.val = 0
17 | self.avg = 0
18 | self.sum = 0
19 | self.count = 0
20 |
21 | def update(self, val, n=1):
22 | self.val = val
23 | self.sum += val * n
24 | self.count += n
25 | self.avg = self.sum / self.count
26 |
27 |
28 | def accuracy(output, target, topk=(1,)):
29 | """Computes the precision@k for the specified values of k"""
30 | maxk = max(topk)
31 | batch_size = target.size(0)
32 |
33 | _, pred = output.topk(maxk, 1, True, True)
34 | pred = pred.t()
35 | correct = pred.eq(target.view(1, -1).expand_as(pred))
36 |
37 | res = []
38 | for k in topk:
39 | correct_k = correct[:k].view(-1).float().sum(0)
40 | res.append(correct_k.mul_(100.0 / batch_size))
41 | return res
--------------------------------------------------------------------------------
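Usage sketch (editor's addition, not a file in the repository): the builders in models/resnet_3d.py return a 3D ResNet whose bottleneck blocks mix a spatial (1x1x1) branch and a temporal (3x1x1) branch according to --net_version. The following minimal smoke test is written under the assumptions that args.py parses with its defaults (so args.net_version and args.dataset are defined at import time) and that, with the fix above, pretrained=None skips the ImageNet inflation; the tensor shapes and the class count (174, for Something-Something-V1) are illustrative.

    import torch
    from models.resnet_3d import resnet50

    # Build the 3D ResNet-50 head-to-toe; pretrained=None skips the 2D-to-3D weight inflation.
    model = resnet50(pretrained=None, num_classes=174)
    model.eval()

    # Input layout is (batch, channels, frames, height, width).
    clip = torch.randn(1, 3, 32, 224, 224)
    with torch.no_grad():
        logits = model(clip)
    print(logits.shape)  # expected: torch.Size([1, 174])

If the NASAS concrete-dropout ops are active (--enable_nasas), the Conv3d layers sample their drop noise with .cuda() calls, so in that configuration the sketch should be run on a CUDA device.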