├── .gitignore
├── README.md
├── lib
│   ├── data
│   │   ├── data_container.py
│   │   ├── data_set.py
│   │   ├── spatial_transformation.py
│   │   └── temporal_sampling.py
│   ├── models
│   │   └── resnet.py
│   ├── netwrapper
│   │   ├── fast_pathway.py
│   │   ├── slow_pathway.py
│   │   └── two_stream_net.py
│   └── utils
│       ├── UCF101-TrainTest-Split.ipynb
│       ├── config.py
│       ├── config_file_handling.py
│       └── miscellaneous.py
└── tools
    ├── _init_lib_path.py
    ├── epoch_loop.py
    ├── test.py
    ├── train.py
    ├── trainer.py
    ├── tv_abc.py
    └── validator.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | parts/ 18 | sdist/ 19 | var/ 20 | wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .nox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | .hypothesis/ 47 | .pytest_cache/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | db.sqlite3 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # IPython 75 | profile_default/ 76 | ipython_config.py 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | .dmypy.json 109 | dmypy.json 110 | .idea 111 | 112 | # deep learning framework directories 113 | experiment/ 114 | dataset/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SlowFast Networks for Video Recognition in PyTorch 2 | This is a PyTorch implementation of the "SlowFast Networks for Video Recognition" paper by Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He, published at ICCV 2019. The official code has not been released yet. This implementation is motivated by the code found [here](https://github.com/r1ch88/SlowFastNetworks). 3 | 4 | ## Contents 5 | 6 | 1. [Introduction](#introduction) 7 | 2. [Installation](#installation) 8 | 3. [Pre-trained Base Networks](#pre-trained-base-networks) 9 | 4. [Datasets](#datasets) 10 | 5. [Preparation](#preparation) 11 | 6. [Training and Testing Procedures](#training-and-testing-procedures) 12 | 7. [Experimental Results](#experimental-results) 13 | 14 | ## Introduction 15 | Action recognition is one of the core tasks in video understanding, and it is as central there as image classification is in the static image domain. Two common deep learning approaches started far apart and have recently been converging toward each other. The first approach uses 3D convolutional layers that process the input spatiotemporal tensor directly, while the second approach is inspired by the human visual system and relies on a two-stream network architecture. Recently, [Feichtenhofer et al.](https://arxiv.org/abs/1812.03982) proposed to extend two-stream networks with the idea of having an expert network on each pathway: (1) the slow pathway has high parametric capacity and processes the RGB frames at a low temporal rate, while (2) the fast pathway operates at a high temporal rate, which gives it a wider temporal receptive field but lower parametric capacity. This has shown promising improvements over the state of the art in action recognition and detection on the Kinetics and AVA datasets respectively. 16 | This repository is motivated by [this](https://github.com/r1ch88/SlowFastNetworks) SlowFast implementation and extends it to improve the usability, readability, and modularity of the code.
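The following sketch illustrates the two input streams; it mirrors the frame subsampling performed in tools/tv_abc.py under the default hyperparameters (`TAU = 16`, `ALPHA = 8` in lib/utils/config.py). The tensor names and shapes are illustrative:

```python
import torch

# A 64-frame clip as produced by the dataloader: (batch, channels, frames, height, width).
clip = torch.randn(1, 3, 64, 112, 112)
tau, alpha = 16, 8  # cfg.SLOWFAST.TAU and cfg.SLOWFAST.ALPHA

x_slow = clip[:, :, ::tau]           # 4 frames:  slow pathway, high capacity, low frame rate
x_fast = clip[:, :, ::tau // alpha]  # 32 frames: fast pathway, low capacity, high frame rate
```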
17 | 18 | ## Installation 19 | 20 | 1. Clone the slowfast-networks-pytorch repository 21 | 22 | ```shell 23 | # Clone the repository 24 | git clone https://github.com/mbiparva/slowfast-networks-pytorch.git 25 | ``` 26 | 27 | 2. Go into the tools directory 28 | 29 | ```shell 30 | cd tools 31 | ``` 32 | 33 | 3. Run the training or testing script 34 | ```shell 35 | # to train 36 | python train.py 37 | # to test 38 | python test.py 39 | ``` 40 | 41 | ## Pre-trained Base Networks 42 | The pre-trained networks have not been officially released yet. The original work trains the proposed network on a cluster of 128 GPUs (i.e. 8 * 8). We provide this implementation as a proof-of-concept model that improves over the previous implementations of this work. 43 | 44 | ## Datasets 45 | Download UCF101 from the original publisher [here](https://www.crcv.ucf.edu/data/UCF101.php). Extract the zip files into the dataset directory so that they respect the following directory hierarchy; the provided dataloader can then easily find the directories of the different categories. Please run the Jupyter notebook located in lib/utils to split the original dataset into the training and validation sets according to the split number. 46 | 47 | ### Directory Hierarchy 48 | Please make sure the downloaded dataset folders and files sit according to the following structure: 49 | 50 | ``` 51 | dataset 52 | | | UCF101 53 | | | | training 54 | │ │ │ | ApplyEyeMakeup 55 | │ │ │ | ApplyLipstick 56 | │ │ │ | ... 57 | | | | validation 58 | │ │ │ | ApplyEyeMakeup 59 | │ │ │ | ApplyLipstick 60 | │ │ │ | ... 61 | ``` 62 | ## Preparation 63 | This implementation is tested with the following packages: 64 | * Python 3.7 65 | * PyTorch 1.2 66 | * CUDA 10.1 67 | * EasyDict 68 | 69 | ## Training and Testing Procedures 70 | You can train or test the network by using "train.py" or "test.py" as follows. 71 | 72 | ### Training Script 73 | You can use tools/train.py to start training the network. Passing --help shows the list of optional arguments, such as "--use-gpu" and "--gpu-id". You can also load a custom cfg file to customize the reference configuration without editing it, and you can override individual options on the command line via "--set".
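For example, the following call (with illustrative values) selects a GPU and overrides two training options; "--set" consumes the rest of the command line as key/value pairs, so it must come last:

```shell
python train.py --gpu-id 0 --set TRAIN.BATCH_SIZE 16 TRAIN.LR 1e-3
```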
74 | 75 | ### Test Script 76 | You can use tools/test.py to test the network by loading a custom network snapshot. You have to pass "--pre-trained-id" and "--pre-trained-epoch" to specify the network id and the epoch at which the snapshot was taken. 77 | 78 | ### Configuration File 79 | All of the configuration hyperparameters are set in lib/utils/config.py. If you want to change them permanently, simply edit the file with the settings you would like. Otherwise, use the approaches mentioned above to change them temporarily.
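A custom cfg file passed via "--cfg" only needs to list the options it overrides; every key must already exist in lib/utils/config.py with a matching type. A minimal example (illustrative values, hypothetical file name) could look like:

```yaml
# custom.yml
NUM_EPOCH: 20
TRAIN:
  BATCH_SIZE: 16
  LR: 0.001
```

and would then be loaded with `python train.py --cfg custom.yml`.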
80 | 81 | ## Experimental Results 82 | Currently, the implementation is trained and tested on UCF101 with the following results. We will update the table gradually as we search for hyperparameters that improve the prediction results. The numbers are top-1 classification accuracy. 83 | 84 | | Run | Training | Test | 85 | | ----------------------- |:-------------:| -----:| 86 | | Aug 2019, Split #1 | 92% | 42% | 87 | -------------------------------------------------------------------------------- /lib/data/data_container.py: -------------------------------------------------------------------------------- 1 | from utils.config import cfg 2 | 3 | import os 4 | import datetime 5 | 6 | import numpy as np 7 | from torch.utils.data import DataLoader 8 | import torch.utils.data 9 | from data.data_set import UCF101 10 | import data.spatial_transformation as Transformation 11 | 12 | numpy_type_map = { 13 | 'float64': torch.DoubleTensor, 14 | 'float32': torch.FloatTensor, 15 | 'float16': torch.HalfTensor, 16 | 'int64': torch.LongTensor, 17 | 'int32': torch.IntTensor, 18 | 'int16': torch.ShortTensor, 19 | 'int8': torch.CharTensor, 20 | 'uint8': torch.ByteTensor, 21 | } 22 | 23 | 24 | def ds_worker_init_fn(worker_id): 25 | np.random.seed(datetime.datetime.now().microsecond + worker_id) 26 | 27 | 28 | class DataContainer: 29 | def __init__(self, mode): 30 | self.dataset, self.dataloader = None, None 31 | self.mode = mode 32 | self.mode_cfg = cfg.get(self.mode.upper()) 33 | 34 | self.create() 35 | 36 | def create(self): 37 | self.create_dataset() 38 | self.create_dataloader() 39 | 40 | def create_transform(self): 41 | h, w = cfg.SPATIAL_INPUT_SIZE 42 | assert h == w 43 | transformations_final = [ 44 | Transformation.ToTensor(), 45 | Transformation.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet-driven 46 | ] 47 | if self.mode == 'train': 48 | # transformations = [ 49 | # Transformation.CenterCrop(240), 50 | # Transformation.Resize(h), 51 | # ] 52 | transformations = [ 53 | # Transformation.RandomCornerCrop(240, crop_scale=(240, 224, 192, 168), border=0.25), 54 | Transformation.RandomCornerCrop(240, crop_scale=(0.66, 1.0), border=0.25), 55 | Transformation.Resize(h), # This is necessary for async-resolution streaming 56 | Transformation.RandomHorizontalFlip() 57 | ] 58 | elif self.mode == 'valid': 59 | transformations = [ 60 | Transformation.CenterCrop(240), 61 | Transformation.Resize(h), 62 | ] 63 | else: 64 | raise NotImplementedError 65 | 66 | return Transformation.Compose( 67 | [Transformation.ToPILImage('RGB')] + transformations + transformations_final # since cv2 loads into np 68 | ) 69 | 70 | def create_dataset(self): 71 | spatial_transform = self.create_transform() 72 | 73 | if cfg.DATASET_NAME == 'UCF101': 74 | self.dataset = UCF101(self.mode, spatial_transform) 75 | 76 | def create_dataloader(self): 77 | self.dataloader = DataLoader(self.dataset, 78 | batch_size=self.mode_cfg.BATCH_SIZE, 79 | shuffle=self.mode_cfg.SHUFFLE, 80 | num_workers=4, 81 | pin_memory=True, 82 | drop_last=True, 83 | worker_init_fn=ds_worker_init_fn, 84 | ) 85 | -------------------------------------------------------------------------------- /lib/data/data_set.py: -------------------------------------------------------------------------------- 1 | from utils.config import cfg 2 | 3 | import os 4 | import numpy as np 5 | import cv2 6 | import torch 7 | from torch.utils.data import Dataset 8 | from data.temporal_sampling import TemporalSampler 9 | 10 | 11 | class UCF101(Dataset): 12 | def __init__(self, mode, spatial_trans): 13 | self.mode = mode 14 | self.dataset_path = os.path.join(cfg.DATASET_ROOT, 'training' if self.mode == 'train' else 'validation') 15 | self.spatial_trans = spatial_trans 16 | 17 | self.temporal_sampler = TemporalSampler(cfg.FRAME_SAMPLING_METHOD) 18 | 19 | self.file_names, file_labels = [], [] 20 | for l in sorted(os.listdir(self.dataset_path)): 21 | for f in os.listdir(os.path.join(self.dataset_path, l)): 22 | self.file_names.append(os.path.join(self.dataset_path, l, f)) 23 | file_labels.append(l) 24 | 25 | self.l2i = {l: i for i, l in enumerate(sorted(set(file_labels)))} 26 | self.file_labels = np.asarray(list(map(lambda lab: self.l2i[lab], file_labels)), dtype=np.int) 27 | 28 | def __getitem__(self, index): 29 | f_name, f_label = self.file_names[index], self.file_labels[index] 30 | 31 | f_frames, f_height, f_width = self.video_info_retrieval(f_name) 32 | 33 | f_frame_list = self.temporal_sampler.frame_sampler(f_frames) 34 | 35 | frame_bank = self.load_frames(f_name, f_frame_list) 36 | 37 | frames_transformed = [] 38 | 39 | self.spatial_trans.randomize_parameters() 40 | for i in frame_bank: 41 | if cfg.FRAME_RANDOMIZATION: 42 | self.spatial_trans.randomize_parameters() 43 | frames_transformed.append(self.spatial_trans(i)) 44 | 45 | frames_packed = self.pack_frames(frames_transformed) 46 | 47 | return frames_packed, {'file_path': f_name, 'file_name': f_name.split('/')[-1], 'nframes': f_frames, 'label': f_label} 48 | 49 | def __len__(self): 50 | return len(self.file_names) 51 | 52 | @staticmethod 53 | def count_frames_accurately(vid_cap): 54 | frames = 0 55 | retaining, frame_data = vid_cap.read() 56 | 57 | while retaining: 58 | frames += 1 59 | retaining, frame_data = vid_cap.read() 60 | 61 | assert frames > 0, 'video file is corrupted, could not count frames' 62 | return frames 63 | 64 | def video_info_retrieval(self, file_name): 65 | vid_cap = cv2.VideoCapture(file_name) 66 | f_c = self.count_frames_accurately(vid_cap) # count frames manually; cv2.CAP_PROP_FRAME_COUNT is error-prone
67 | f_h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 68 | f_w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 69 | vid_cap.release() 70 | 71 | return f_c, f_h, f_w 72 | 73 | @staticmethod 74 | def load_frames(fname, frame_list): 75 | vid_cap = cv2.VideoCapture(fname) 76 | 77 | retaining, frame_data = vid_cap.read() 78 | assert retaining, 'the video clip has no readable frames; the file may be corrupted' 79 | frame_count = 0 80 | frame_bank = [] 81 | frame_list_iter = iter(frame_list) 82 | f = next(frame_list_iter) 83 | break_out = False 84 | 85 | while retaining: 86 | frame_data = frame_data[:, :, ::-1] # OpenCV loads in BGR 87 | while f == frame_count: 88 | frame_bank.append(frame_data) 89 | try: 90 | f = next(frame_list_iter) 91 | except StopIteration: 92 | break_out = True 93 | break 94 | 95 | if break_out: 96 | break 97 | 98 | retaining, frame_data = vid_cap.read() 99 | frame_count += 1 100 | 101 | vid_cap.release() 102 | 103 | return frame_bank 104 | 105 | @staticmethod 106 | def pack_frames(frames): 107 | frames_out = torch.stack(frames).transpose(1, 0) 108 | 109 | return frames_out 110 | -------------------------------------------------------------------------------- /lib/data/spatial_transformation.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numbers 3 | import random 4 | 5 | import torch 6 | import torchvision.transforms 7 | from torchvision.transforms import functional 8 | import numpy as np 9 | from PIL import Image 10 | try: 11 | import accimage 12 | except ImportError: 13 | accimage = None 14 | 15 | _pil_interpolation_to_str = { 16 | Image.NEAREST: 'PIL.Image.NEAREST', 17 | Image.BILINEAR: 'PIL.Image.BILINEAR', 18 | Image.BICUBIC: 'PIL.Image.BICUBIC', 19 | Image.LANCZOS: 'PIL.Image.LANCZOS', 20 | } 21 | 22 | 23 | def _is_pil_image(img): 24 | if accimage is not None: 25 | return isinstance(img, (Image.Image, accimage.Image)) 26 | else: 27 | return isinstance(img, Image.Image) 28 | 29 | 30 | def _is_tensor_image(img): 31 | return torch.is_tensor(img) and img.ndimension() == 3 32 | 33 | 34 | def _is_numpy_image(img): 35 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 36 | 37 | 38 | class Compose(torchvision.transforms.Compose): 39 | """Compose class that has randomization 40 | """ 41 | 42 | def __init__(self, transforms): 43 | super().__init__(transforms) 44 | 45 | def randomize_parameters(self): 46 | for t in self.transforms: 47 | t.randomize_parameters() 48 | 49 | 50 | def random_corner_crop(img, crop_position, size, b_w=0, b_h=0): 51 | w, h = img.size 52 | c_h, c_w = size 53 | 54 | if b_w and b_h: 55 | b_w, b_h = int((w - c_w) * b_w), int((h - c_h) * b_h) 56 | 57 | if c_w > w or c_h > h: 58 | raise ValueError("Requested crop size {} is bigger than input size {}".format(size, (h, w))) 59 | 60 | if crop_position == 'center': 61 | return functional.center_crop(img, (c_h, c_w)) 62 | elif crop_position == 'tl': 63 | return img.crop((b_w, b_h, b_w + c_w, b_h + c_h)) 64 | elif crop_position == 'tr': 65 | return img.crop((w - c_w - b_w, b_h, w - b_w, b_h + c_h)) 66 | elif crop_position == 'bl': 67 | return img.crop((b_w, h - c_h - b_h, b_w + c_w, h - b_h)) 68 | elif crop_position == 'br': 69 | return img.crop((w - c_w - b_w, h - c_h - b_h, w - b_w, h - b_h)) 70 | else: 71 | raise NotImplementedError 72 | 73 | 74 | # ------------------------------------------------- 75 | # transform re-implementation 76 | # ------------------------------------------------- 77 | class 
ToTensor(torchvision.transforms.ToTensor): 78 | 79 | def randomize_parameters(self): 80 | pass 81 | 82 | 83 | class Normalize(torchvision.transforms.Normalize): 84 | 85 | def randomize_parameters(self): 86 | pass 87 | 88 | 89 | class Resize(torchvision.transforms.Resize): 90 | 91 | def randomize_parameters(self): 92 | pass 93 | 94 | 95 | class CenterCrop(torchvision.transforms.CenterCrop): 96 | 97 | def randomize_parameters(self): 98 | pass 99 | 100 | 101 | class ToPILImage(torchvision.transforms.ToPILImage): 102 | 103 | def randomize_parameters(self): 104 | pass 105 | 106 | 107 | class RandomCornerCrop(object): 108 | """Randomly crops the given PIL Image at the four corners or center. 109 | Args: 110 | size (sequence or int): Desired output size of the crop. If size is an 111 | int instead of sequence like (h, w), a square crop (size, size) is 112 | made. 113 | crop_position (sequence): Desired crop position of the crop. If it is None, 114 | the position is randomly selected. Default choices are 115 | ('center', 'tl', 'tr', 'bl', 'br'). 116 | crop_scale (sequence): Desired list of scales in the range (0, image_size) 117 | to randomly crop from corners. 118 | """ 119 | def __init__(self, size, crop_position=None, crop_scale=1.0, border=0): 120 | self.size, self.crop_size, self.border = size, size, border 121 | self.border_w, self.border_h = 0, 0 122 | self.randomize_corner, self.randomize_scale = True, True 123 | self.default_positions = ('center', 'tl', 'tr', 'bl', 'br') 124 | if crop_position is not None: 125 | self.randomize_corner, self.crop_position = False, crop_position 126 | if isinstance(crop_scale, tuple): 127 | self.crop_scale = crop_scale 128 | elif isinstance(crop_scale, float): 129 | self.randomize_scale = False 130 | self.crop_size = self.size * crop_scale 131 | self.crop_size = (int(self.crop_size), int(self.crop_size)) 132 | else: 133 | raise Exception('NotDefined') 134 | 135 | def __call__(self, img): 136 | return random_corner_crop(img, self.crop_position, self.crop_size, self.border_w, self.border_h) 137 | 138 | def randomize_parameters(self): 139 | if self.randomize_corner: 140 | self.crop_position = random.choice(self.default_positions) 141 | if self.randomize_scale: 142 | if len(self.crop_scale) == 2 and self.crop_scale[0] < 2: 143 | self.crop_size = self.size * np.random.uniform(low=self.crop_scale[0], high=self.crop_scale[1]) 144 | else: 145 | self.crop_size = np.random.choice(self.crop_scale) 146 | self.crop_size = (int(self.crop_size), int(self.crop_size)) 147 | if self.border > 0: 148 | self.border_w, self.border_h = self.border*np.random.rand(), self.border*np.random.rand() 149 | 150 | 151 | class RandomHorizontalFlip(object): 152 | """Horizontally flip the given PIL Image randomly with a given probability. 153 | Args: 154 | p (float): probability of the image being flipped. Default value is 0.5 155 | """ 156 | 157 | def __init__(self, p=0.5): 158 | self.p, self.random_value = p, None 159 | 160 | def __call__(self, img): 161 | """ 162 | Args: 163 | img (PIL Image): Image to be flipped. 164 | Returns: 165 | PIL Image: Randomly flipped image. 166 | """ 167 | if self.random_value < self.p: 168 | return functional.hflip(img) 169 | return img 170 | 171 | def __repr__(self): 172 | return self.__class__.__name__ + '(p={})'.format(self.p) 173 | 174 | def randomize_parameters(self): 175 | self.random_value = random.random() 176 | 177 | 178 | class RandomResizedCrop(object): 179 | """Crop the given PIL Image to random size. 
180 | A square crop whose side is a random fraction (default: 0.75 to 1.0) of the 181 | shorter image side is made at a random position. This crop is finally 182 | resized to the given size. 183 | Unlike torchvision's RandomResizedCrop, no aspect-ratio jitter is applied. 184 | Args: 185 | size: expected output size of each edge 186 | scale: range of the crop side relative to the shorter image side 187 | interpolation: Default: PIL.Image.BILINEAR 188 | """ 189 | 190 | def __init__(self, size, scale=(0.75, 1.0), interpolation=Image.BILINEAR): 191 | self.size = (size, size) 192 | self.interpolation = interpolation 193 | self.scale = scale 194 | self.r_scale, self.r_tl = None, None 195 | 196 | def __call__(self, img): 197 | """ 198 | Args: 199 | img (PIL Image): Image to be cropped and resized. 200 | Returns: 201 | PIL Image: Randomly cropped and resized image. 202 | """ 203 | w, h = img.size 204 | r_tl_w, r_tl_h = self.r_tl 205 | 206 | min_length = min(w, h) 207 | crop_size = int(min_length * self.r_scale) 208 | 209 | i = r_tl_h * (h - crop_size) 210 | j = r_tl_w * (w - crop_size) 211 | 212 | return functional.resized_crop(img, i, j, crop_size, crop_size, self.size, self.interpolation) 213 | 214 | def __repr__(self): 215 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 216 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 217 | format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) 218 | format_string += ', interpolation={0})'.format(interpolate_str) 219 | return format_string 220 | 221 | def randomize_parameters(self): 222 | self.r_scale = random.uniform(*self.scale) 223 | self.r_tl = [random.random(), random.random()] 224 | -------------------------------------------------------------------------------- /lib/data/temporal_sampling.py: -------------------------------------------------------------------------------- 1 | from utils.config import cfg 2 | 3 | import math 4 | import numpy as np 5 | import os 6 | 7 | 8 | class TemporalSampler: 9 | def __init__(self, frame_sampling_method): 10 | self.frame_sampling_method = frame_sampling_method 11 | self.nf = cfg.NFRAMES_PER_VIDEO if self.frame_sampling_method != 'f25' else 25 12 | 13 | def frame_sampler(self, in_nframes): 14 | if self.frame_sampling_method == 'uniform': 15 | num_frames = max(1, self.nf-1) 16 | sample_rate = max(in_nframes // num_frames, 1) 17 | frame_samples = np.arange(0, in_nframes, sample_rate) 18 | 19 | elif self.frame_sampling_method == 'temporal_stride': 20 | frame_samples = np.arange(0, in_nframes, cfg.TEMPORAL_STRIDE[0]) 21 | if len(frame_samples) < self.nf: 22 | frame_samples = np.linspace(0, in_nframes, self.nf, endpoint=False).astype(np.int) 23 | 24 | elif self.frame_sampling_method == 'random': 25 | frame_samples = np.random.permutation(in_nframes) 26 | 27 | elif self.frame_sampling_method == 'temporal_stride_random': 28 | temporal_stride = np.random.randint(cfg.TEMPORAL_STRIDE[0], cfg.TEMPORAL_STRIDE[1]) 29 | frame_samples = np.arange(0, in_nframes, temporal_stride) 30 | 31 | elif self.frame_sampling_method == 'f25': 32 | frame_samples = np.linspace(0, in_nframes, 25, endpoint=False).round().tolist() 33 | 34 | else: 35 | raise NotImplementedError 36 | 37 | # pad or trim the sampled frame list below to exactly self.nf entries.
38 | if len(frame_samples) < self.nf: 39 | add_frames, difference = 0, self.nf - len(frame_samples) 40 | while difference > 0: 41 | next_len = len(frame_samples) 42 | add_samples = np.linspace(cfg.TEMPORAL_INPUT_SIZE / 4, in_nframes - cfg.TEMPORAL_INPUT_SIZE / 4, 43 | next_len, endpoint=False) 44 | add_samples = add_samples.round().tolist() 45 | add_samples = add_samples[np.random.randint(1, len(add_samples), 1)[0]] 46 | if add_frames > 20: 47 | frame_samples = np.linspace(cfg.TEMPORAL_INPUT_SIZE / 4, in_nframes - cfg.TEMPORAL_INPUT_SIZE / 4, 48 | self.nf, endpoint=False) 49 | frame_samples = frame_samples.round().tolist() 50 | else: 51 | frame_samples = np.append(frame_samples, add_samples) 52 | difference = self.nf - len(frame_samples) 53 | add_frames += 1 54 | frame_samples = np.sort(frame_samples) 55 | elif len(frame_samples) > self.nf: 56 | start = np.random.randint(0, len(frame_samples)-self.nf) 57 | frame_samples = frame_samples[start:start+self.nf] 58 | else: 59 | pass 60 | 61 | if self.frame_sampling_method == 'random': 62 | frame_samples = np.sort(frame_samples) 63 | 64 | assert (frame_samples == np.sort(frame_samples)).all() # ensure the frame numbers are sorted 65 | 66 | return frame_samples 67 | -------------------------------------------------------------------------------- /lib/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as nn_init 4 | import math 5 | import torch.utils.model_zoo as model_zoo 6 | import numpy as np 7 | from torch.nn import Parameter 8 | 9 | __all__ = ['ResNet3D', 'resnet50'] 10 | 11 | 12 | model_urls = { 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | } 15 | 16 | 17 | # ============================= 18 | # ********* 3D ResNet ********* 19 | # ============================= 20 | def conv1x3x3(in_planes, out_planes, stride=1): 21 | """1x3x3 convolution with padding""" 22 | return nn.Conv3d(in_planes, out_planes, 23 | kernel_size=(1, 3, 3), stride=(1, stride, stride), padding=(0, 1, 1), bias=False) 24 | 25 | 26 | class BasicBlock3D(nn.Module): 27 | expansion = 1 28 | 29 | def __init__(self, inplanes, planes, stride=1, downsample=None): 30 | super(BasicBlock3D, self).__init__() 31 | self.conv1 = conv1x3x3(inplanes, planes, stride) 32 | self.bn1 = nn.BatchNorm3d(planes) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.conv2 = conv1x3x3(planes, planes) 35 | self.bn2 = nn.BatchNorm3d(planes) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x, residual=None): 40 | if residual is None: 41 | residual = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(x) 52 | 53 | out += residual 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | def init_temporal(self, strategy): 59 | raise NotImplementedError 60 | 61 | 62 | class Bottleneck3D(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None, bias=False, head_conv=1): 66 | super().__init__() 67 | 68 | if head_conv == 1: 69 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 70 | self.bn1 = nn.BatchNorm3d(planes) 71 | elif head_conv == 3: 72 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=(3, 1, 1), bias=False, padding=(1, 0, 0)) 73 | self.bn1 = nn.BatchNorm3d(planes) 74 | else: 75 | raise 
ValueError("Unsupported head_conv!") 76 | 77 | self.conv2 = nn.Conv3d(planes, planes, 78 | kernel_size=(1, 3, 3), stride=(1, stride, stride), padding=(0, 1, 1), bias=bias) 79 | self.bn2 = nn.BatchNorm3d(planes) 80 | 81 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=bias) 82 | self.bn3 = nn.BatchNorm3d(planes * 4) 83 | self.relu = nn.ReLU(inplace=True) 84 | self.downsample = downsample 85 | self.stride = stride 86 | 87 | def forward(self, x): 88 | residual = x 89 | 90 | out = self.conv1(x) 91 | out = self.bn1(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv2(out) 95 | out = self.bn2(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv3(out) 99 | out = self.bn3(out) 100 | 101 | if self.downsample is not None: 102 | residual = self.downsample(x) 103 | 104 | out += residual 105 | out = self.relu(out) 106 | 107 | return out 108 | 109 | 110 | class ResNet3D(nn.Module): 111 | def __init__(self, block, layers, **kwargs): 112 | super().__init__() 113 | in_channels, num_classes = kwargs['in_channels'], kwargs['num_classes'] 114 | self.alpha = kwargs['alpha'] 115 | self.slow = kwargs['slow'] # slow->1 else fast->0 116 | self.t2s_mul = kwargs['t2s_mul'] 117 | self.inplanes = (64 + 64//self.alpha*self.t2s_mul) if self.slow else 64//self.alpha 118 | self.conv1 = nn.Conv3d(in_channels, 64//(1 if self.slow else self.alpha), 119 | kernel_size=(1 if self.slow else 5, 7, 7), 120 | stride=(1, 2, 2), padding=(0 if self.slow else 2, 3, 3), bias=False) 121 | self.bn1 = nn.BatchNorm3d(64//(1 if self.slow else self.alpha)) 122 | self.relu = nn.ReLU(inplace=True) 123 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) 124 | self.layer1 = self._make_layer(block, 64//(1 if self.slow else self.alpha), layers[0], 125 | head_conv=1 if self.slow else 3) 126 | self.layer2 = self._make_layer(block, 128//(1 if self.slow else self.alpha), layers[1], stride=2, 127 | head_conv=1 if self.slow else 3) 128 | self.layer3 = self._make_layer(block, 256//(1 if self.slow else self.alpha), layers[2], stride=2, 129 | head_conv=3) 130 | self.layer4 = self._make_layer(block, 512//(1 if self.slow else self.alpha), layers[3], stride=2, 131 | head_conv=3) 132 | 133 | def init_params(self): 134 | for m in self.modules(): 135 | if isinstance(m, nn.Conv3d): 136 | # n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 137 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 138 | # nn_init.normal_(m.weight) 139 | # nn_init.xavier_normal_(m.weight) 140 | nn_init.kaiming_normal_(m.weight) 141 | if m.bias is not None: 142 | nn_init.constant_(m.bias, 0) 143 | elif isinstance(m, nn.BatchNorm3d): 144 | # m.weight.data.fill_(1) 145 | # m.bias.data.zero_() 146 | nn_init.constant_(m.weight, 1) 147 | nn_init.constant_(m.bias, 0) 148 | 149 | def forward(self, x): 150 | raise NotImplementedError('use each pathway network\'s forward function') 151 | 152 | def _make_layer(self, block, planes, blocks, stride=1, head_conv=1): 153 | downsample = None 154 | if stride != 1 or self.inplanes != planes * block.expansion: 155 | downsample = nn.Sequential( 156 | nn.Conv3d( 157 | self.inplanes, 158 | planes * block.expansion, 159 | kernel_size=1, 160 | stride=(1, stride, stride), 161 | bias=False), nn.BatchNorm3d(planes * block.expansion)) 162 | 163 | layers = list() 164 | layers.append(block(self.inplanes, planes, stride, downsample, head_conv=head_conv)) 165 | self.inplanes = planes * block.expansion 166 | for i in range(1, blocks): 167 | layers.append(block(self.inplanes, planes, head_conv=head_conv)) 168 | 169 | self.inplanes += self.slow * block.expansion * planes // self.alpha * self.t2s_mul  # the slow path reserves extra input channels for the concatenated lateral features 170 | 171 | return nn.Sequential(*layers) 172 | 173 | 174 | def resnet50(pretrained=False, **kwargs): 175 | """Constructs a ResNet-50 model. 176 | 177 | Args: 178 | pretrained (bool): If True, returns a model pre-trained on ImageNet 179 | """ 180 | model = ResNet3D(Bottleneck3D, [3, 4, 6, 3], **kwargs) 181 | if pretrained: 182 | pt = model_zoo.load_url(model_urls['resnet50']) 183 | pt.pop('conv1.weight') 184 | pt.pop('fc.weight') 185 | pt.pop('fc.bias') 186 | model_dict = model.state_dict() 187 | model_dict.update(pt) 188 | model.load_state_dict(model_dict) 189 | 190 | return model 191 | -------------------------------------------------------------------------------- /lib/netwrapper/fast_pathway.py: -------------------------------------------------------------------------------- 1 | from utils.config import cfg 2 | 3 | import models.resnet 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class FastNet(models.resnet.ResNet3D): 9 | def __init__(self, block, layers, **kwargs): 10 | super().__init__(block, layers, **kwargs) 11 | self.l_maxpool = nn.Conv3d(64//self.alpha, 64//self.alpha*self.t2s_mul, 12 | kernel_size=(5, 1, 1), stride=(8, 1, 1), bias=False, padding=(2, 0, 0)) 13 | self.l_layer1 = nn.Conv3d(4*64//self.alpha, 4*64//self.alpha*self.t2s_mul, 14 | kernel_size=(5, 1, 1), stride=(8, 1, 1), bias=False, padding=(2, 0, 0)) 15 | self.l_layer2 = nn.Conv3d(8*64//self.alpha, 8*64//self.alpha*self.t2s_mul, 16 | kernel_size=(5, 1, 1), stride=(8, 1, 1), bias=False, padding=(2, 0, 0)) 17 | self.l_layer3 = nn.Conv3d(16*64//self.alpha, 16*64//self.alpha*self.t2s_mul, 18 | kernel_size=(5, 1, 1), stride=(8, 1, 1), bias=False, padding=(2, 0, 0)) 19 | self.init_params() 20 | 21 | def forward(self, x): 22 | laterals = [] 23 | 24 | x = self.conv1(x) 25 | x = self.bn1(x) 26 | x = self.relu(x) 27 | x = self.maxpool(x) 28 | laterals.append(self.l_maxpool(x)) 29 | 30 | x = self.layer1(x) 31 | laterals.append(self.l_layer1(x)) 32 | 33 | x = self.layer2(x) 34 | laterals.append(self.l_layer2(x)) 35 | 36 | x = self.layer3(x) 37 | laterals.append(self.l_layer3(x)) 38 | 39 | x = self.layer4(x) 40 | 41 | x = F.adaptive_avg_pool3d(x, 1) 42 | x = x.view(-1, x.size(1)) 43 | 44 | return x, laterals 45 | 46 | 47 | def resnet50_f(**kwargs): 48 | """Constructs the fast-pathway variant of ResNet-50.
49 | """ 50 | model = FastNet(models.resnet.Bottleneck3D, [3, 4, 6, 3], **kwargs) 51 | 52 | return model 53 | -------------------------------------------------------------------------------- /lib/netwrapper/slow_pathway.py: -------------------------------------------------------------------------------- 1 | from utils.config import cfg 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import models.resnet 7 | 8 | 9 | class SlowNet(models.resnet.ResNet3D): 10 | def __init__(self, block, layers, **kwargs): 11 | super().__init__(block, layers, **kwargs) 12 | self.init_params() 13 | 14 | def forward(self, x): 15 | x, laterals = x 16 | 17 | x = self.conv1(x) 18 | x = self.bn1(x) 19 | x = self.relu(x) 20 | x = self.maxpool(x) 21 | 22 | x = torch.cat([x, laterals[0]], dim=1) 23 | x = self.layer1(x) 24 | 25 | x = torch.cat([x, laterals[1]], dim=1) 26 | x = self.layer2(x) 27 | 28 | x = torch.cat([x, laterals[2]], dim=1) 29 | x = self.layer3(x) 30 | 31 | x = torch.cat([x, laterals[3]], dim=1) 32 | x = self.layer4(x) 33 | 34 | x = F.adaptive_avg_pool3d(x, 1) 35 | x = x.view(-1, x.size(1)) 36 | 37 | return x 38 | 39 | 40 | def resnet50_s(**kwargs): 41 | """Constructs a ResNet-50 model. 42 | """ 43 | model = SlowNet(models.resnet.Bottleneck3D, [3, 4, 6, 3], **kwargs) 44 | 45 | return model 46 | -------------------------------------------------------------------------------- /lib/netwrapper/two_stream_net.py: -------------------------------------------------------------------------------- 1 | import os 2 | from utils.config import cfg 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.init as nn_init 7 | import torch.optim as optim 8 | 9 | from netwrapper.slow_pathway import resnet50_s 10 | from netwrapper.fast_pathway import resnet50_f 11 | 12 | 13 | # noinspection PyProtectedMember 14 | class StepLRestart(optim.lr_scheduler._LRScheduler): 15 | """The same as StepLR, but this one has restart. 
16 | """ 17 | def __init__(self, optimizer, step_size, restart_size, gamma=0.1, last_epoch=-1): 18 | self.step_size = step_size 19 | self.restart_size = restart_size 20 | assert self.restart_size > self.step_size 21 | self.gamma = gamma 22 | super(StepLRestart, self).__init__(optimizer, last_epoch) 23 | 24 | def get_lr(self): 25 | return [base_lr * self.gamma ** ((self.last_epoch % self.restart_size) // self.step_size) 26 | for base_lr in self.base_lrs] 27 | 28 | 29 | class TwoStreamNet(nn.Module): 30 | 31 | def __init__(self, device): 32 | super().__init__() 33 | self.slow_net, self.fast_net = None, None 34 | self.criterion, self.optimizer, self.scheduler = None, None, None 35 | 36 | self.create_load(device) 37 | 38 | self.dropout = nn.Dropout(cfg.SLOWFAST.DP).to(device) 39 | self.fc = nn.Linear(4*512 + 4*512//cfg.SLOWFAST.ALPHA, cfg.NUM_CLASSES, 40 | bias=False).to(device) 41 | # nn_init.normal_(self.fc.weight) 42 | # nn_init.xavier_normal_(self.fc.weight) 43 | nn_init.kaiming_normal_(self.fc.weight) 44 | 45 | self.setup_optimizer() 46 | 47 | def create_load(self, device): 48 | if cfg.PRETRAINED_MODE == 'Custom': 49 | self.create_net() 50 | self.load(cfg.PT_PATH) 51 | else: 52 | self.create_net() 53 | 54 | self.slow_net = self.slow_net.to(device) 55 | self.fast_net = self.fast_net.to(device) 56 | 57 | def create_net(self): 58 | self.slow_net = resnet50_s(**{ 59 | 'in_channels': cfg.CHANNEL_INPUT_SIZE, 60 | 'num_classes': cfg.NUM_CLASSES, 61 | 'alpha': cfg.SLOWFAST.ALPHA, 62 | 'slow': 1, 63 | 't2s_mul': cfg.SLOWFAST.T2S_MUL, 64 | }) 65 | self.fast_net = resnet50_f(**{ 66 | 'in_channels': cfg.CHANNEL_INPUT_SIZE, 67 | 'num_classes': cfg.NUM_CLASSES, 68 | 'alpha': cfg.SLOWFAST.ALPHA, 69 | 'slow': 0, 70 | 't2s_mul': cfg.SLOWFAST.T2S_MUL, 71 | }) 72 | 73 | def setup_optimizer(self): 74 | self.criterion = nn.CrossEntropyLoss() 75 | 76 | self.optimizer = optim.SGD(params=self.parameters(), 77 | lr=cfg.TRAIN.LR, 78 | weight_decay=cfg.TRAIN.WEIGHT_DECAY, 79 | momentum=cfg.TRAIN.MOMENTUM, 80 | nesterov=cfg.TRAIN.NESTEROV) 81 | 82 | if cfg.TRAIN.SCHEDULER_MODE: 83 | if cfg.TRAIN.SCHEDULER_TYPE == 'step': 84 | self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=cfg.TRAIN.SCHEDULER_STEP_MILESTONE, 85 | gamma=0.1) 86 | elif cfg.TRAIN.SCHEDULER_TYPE == 'step_restart': 87 | self.scheduler = StepLRestart(self.optimizer, step_size=4, restart_size=8, gamma=0.1) 88 | elif cfg.TRAIN.SCHEDULER_TYPE == 'multi': 89 | self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, 90 | milestones=cfg.TRAIN.SCHEDULER_MULTI_MILESTONE, 91 | gamma=0.1) 92 | elif cfg.TRAIN.SCHEDULER_TYPE == 'lambda': 93 | def lr_lambda(e): return 1 if e < 5 else .5 if e < 10 else .1 if e < 15 else .01 94 | self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_lambda) 95 | elif cfg.TRAIN.SCHEDULER_TYPE == 'plateau': 96 | self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.1, patience=5, 97 | cooldown=0, 98 | verbose=True) 99 | else: 100 | raise NotImplementedError 101 | 102 | def schedule_step(self, metric=None): 103 | if cfg.TRAIN.SCHEDULER_MODE: 104 | if cfg.TRAIN.SCHEDULER_TYPE in ['step', 'step_restart', 'multi', 'lambda']: 105 | self.scheduler.step() 106 | if cfg.TRAIN.SCHEDULER_TYPE == 'plateau': 107 | self.scheduler.step(metric['loss'].avg) 108 | 109 | def save(self, file_path, e): 110 | torch.save(self.state_dict(), os.path.join(file_path, '{:03d}.pth'.format(e))) 111 | 112 | def load(self, file_path): 113 | self.load_state_dict(torch.load(file_path)) 114 | 115 | 
def forward(self, x): 116 | x_slow, x_fast = x 117 | x_fast, laterals = self.fast_net(x_fast) 118 | x_slow = self.slow_net((x_slow, laterals)) 119 | 120 | x = torch.cat([x_slow, x_fast], dim=1) 121 | 122 | x = self.dropout(x) 123 | x = self.fc(x) 124 | 125 | return x 126 | 127 | def loss_update(self, p, a, step=True): 128 | loss = self.criterion(p, a) 129 | 130 | if step: 131 | loss.backward() 132 | self.optimizer.step() 133 | self.optimizer.zero_grad() 134 | 135 | return loss.item() 136 | -------------------------------------------------------------------------------- /lib/utils/config.py: -------------------------------------------------------------------------------- 1 | """Config file setting hyperparameters 2 | 3 | This file specifies default config options. You should not 4 | change values in this file. Instead, you should write a config file (in yaml) 5 | and use cfg_from_file(yaml_file) to load it and override the default options. 6 | """ 7 | 8 | from easydict import EasyDict as edict 9 | import os 10 | import datetime 11 | import socket 12 | 13 | __C = edict() 14 | cfg = __C # from config.py import cfg 15 | 16 | 17 | # ================ 18 | # GENERAL 19 | # ================ 20 | 21 | # Set modes 22 | __C.TRAINING = True 23 | __C.VALIDATING = True 24 | 25 | # Root directory of project 26 | __C.ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) 27 | 28 | # Data directory 29 | __C.DATASET_DIR = os.path.abspath(os.path.join(__C.ROOT_DIR, 'dataset')) 30 | 31 | # Model directory 32 | __C.MODELS_DIR = os.path.abspath(os.path.join(__C.ROOT_DIR, 'lib', 'models')) 33 | 34 | # Experiment directory 35 | __C.EXPERIMENT_DIR = os.path.abspath(os.path.join(__C.ROOT_DIR, 'experiment')) 36 | 37 | # Set meters to use for experimental evaluation 38 | __C.METERS = ['loss', 'label_accuracy'] 39 | 40 | # Use GPU 41 | __C.USE_GPU = True 42 | 43 | # Default GPU device id 44 | __C.GPU_ID = 0 45 | 46 | # Number of epochs 47 | __C.NUM_EPOCH = 40 48 | 49 | # Dataset name 50 | __C.DATASET_NAME = ('UCF101', )[0] 51 | 52 | if __C.DATASET_NAME == 'UCF101': 53 | __C.SPLIT_NO = 1 54 | 55 | # Number of categories 56 | __C.NUM_CLASSES = 101 57 | 58 | __C.DATASET_ROOT = os.path.join(__C.DATASET_DIR, __C.DATASET_NAME) 59 | 60 | # Normalize database samples according to some mean and std values 61 | __C.DATASET_NORM = True 62 | 63 | # Input data size 64 | __C.SPATIAL_INPUT_SIZE = (112, 112) 65 | __C.CHANNEL_INPUT_SIZE = 3 66 | 67 | # Set parameters for snapshot and verbose routines 68 | __C.MODEL_ID = datetime.datetime.now().strftime('%Y%m%d_%H%M%S_%f') 69 | __C.SNAPSHOT = True 70 | __C.SNAPSHOT_INTERVAL = 5 71 | __C.VERBOSE = True 72 | __C.VERBOSE_INTERVAL = 10 73 | __C.VALID_INTERVAL = 1 74 | 75 | # Network Architecture 76 | __C.NET_ARCH = ('resnet', )[0] 77 | 78 | # Pre-trained network 79 | __C.PRETRAINED_MODE = (None, 'Custom')[0] 80 | 81 | # Path to the pre-segmentation network 82 | __C.PT_PATH = os.path.join(__C.EXPERIMENT_DIR, 'snapshot', '20181010_124618_219443', '079.pt') 83 | 84 | # ============================= 85 | # Spatiotemporal ResNet options 86 | # ============================= 87 | __C.RST = edict() 88 | 89 | __C.FRAME_SAMPLING_METHOD = ('uniform', 'temporal_stride', 'random', 'temporal_stride_random')[1] 90 | __C.NFRAMES_PER_VIDEO = 64 # T x tau 91 | __C.TEMPORAL_STRIDE = (1, 25) 92 | __C.FRAME_RANDOMIZATION = False 93 | 94 | # ============================= 95 | # SlowFast ResNet options 96 | # ============================= 97 | __C.SLOWFAST = edict() 98 | 99 | # T = 
NFRAMES_PER_VIDEO // TAU 100 | __C.SLOWFAST.TAU = 16 101 | __C.SLOWFAST.ALPHA = 8 102 | __C.SLOWFAST.T2S_MUL = 2 103 | __C.SLOWFAST.DP = 0.5 104 | 105 | # ================ 106 | # Training options 107 | # ================ 108 | if __C.TRAINING: 109 | __C.TRAIN = edict() 110 | 111 | # Images to use per minibatch 112 | __C.TRAIN.BATCH_SIZE = 32 113 | 114 | # Shuffle the dataset 115 | __C.TRAIN.SHUFFLE = True 116 | 117 | # Learning parameters are set below 118 | __C.TRAIN.LR = 1e-3 119 | __C.TRAIN.WEIGHT_DECAY = 1e-5 120 | __C.TRAIN.MOMENTUM = 0.90 121 | __C.TRAIN.NESTEROV = False 122 | __C.TRAIN.SCHEDULER_MODE = False 123 | __C.TRAIN.SCHEDULER_TYPE = ('step', 'step_restart', 'multi', 'lambda', 'plateau')[0] 124 | __C.TRAIN.SCHEDULER_STEP_MILESTONE = 10 125 | __C.TRAIN.SCHEDULER_MULTI_MILESTONE = [10] 126 | 127 | # ================ 128 | # Validation options 129 | # ================ 130 | if __C.VALIDATING: 131 | __C.VALID = edict() 132 | 133 | # Images to use per minibatch 134 | __C.VALID.BATCH_SIZE = __C.TRAIN.BATCH_SIZE 135 | 136 | # Shuffle the dataset 137 | __C.VALID.SHUFFLE = False 138 | -------------------------------------------------------------------------------- /lib/utils/config_file_handling.py: -------------------------------------------------------------------------------- 1 | """Config file handling module 2 | 3 | This file specifies file handling routing to manipulate configurations. 4 | """ 5 | import numpy as np 6 | from easydict import EasyDict as edict 7 | import os 8 | import yaml 9 | from utils.config import cfg 10 | 11 | 12 | def _merge_a_into_b(a, b): 13 | """Merge config dictionary a into config dictionary b, clobbering the 14 | options in b whenever they are also specified in a. 15 | """ 16 | if type(a) is not edict: 17 | return 18 | 19 | for k, v in a.items(): 20 | # a must specify keys that are in b 21 | if k not in b: 22 | raise KeyError('{} is not a valid config key'.format(k)) 23 | 24 | # the types must match, too 25 | old_type = type(b[k]) 26 | if old_type is not type(v): 27 | if isinstance(b[k], np.ndarray): 28 | v = np.array(v, dtype=b[k].dtype) 29 | else: 30 | raise ValueError(('Type mismatch ({} vs. 
{}) ' 31 | 'for config key: {}').format(type(b[k]), 32 | type(v), k)) 33 | 34 | # recursively merge dicts 35 | if type(v) is edict: 36 | try: 37 | _merge_a_into_b(a[k], b[k]) 38 | except Exception: 39 | print('Error under config key: {}'.format(k)) 40 | raise 41 | else: 42 | b[k] = v 43 | 44 | 45 | def cfg_from_file(filename): 46 | """Load a config file and merge it into the default options.""" 47 | import yaml 48 | with open(filename, 'r') as f: 49 | yaml_cfg = edict(yaml.safe_load(f)) 50 | 51 | _merge_a_into_b(yaml_cfg, cfg) 52 | 53 | 54 | # noinspection PyBroadException 55 | def cfg_from_list(cfg_list): 56 | """Set config keys via list (e.g., from command line).""" 57 | from ast import literal_eval 58 | assert len(cfg_list) % 2 == 0 59 | for k, v in zip(cfg_list[0::2], cfg_list[1::2]): 60 | key_list = k.upper().split('.') 61 | d = cfg 62 | for subkey in key_list[:-1]: 63 | assert subkey in d 64 | d = d[subkey] 65 | subkey = key_list[-1] 66 | assert subkey in d 67 | try: 68 | value = literal_eval(v) 69 | except Exception: 70 | # handle the case when v is a string literal 71 | value = v 72 | assert isinstance(value, type(d[subkey])), 'type {} does not match original type {}'.format( 73 | type(value), type(d[subkey]) 74 | ) 75 | d[subkey] = value 76 | 77 | 78 | def cfg_to_file(cfg_in, path_in, name_in): 79 | with open(os.path.join(path_in, '{}.yml'.format(name_in)), 'w') as output_file: 80 | yaml.dump(cfg_in, output_file, default_flow_style=False) 81 | -------------------------------------------------------------------------------- /lib/utils/miscellaneous.py: -------------------------------------------------------------------------------- 1 | # This is meant to contain miscellaneous functions and routines 2 | # such as showing, saving, loading images and results. 
3 | import sys 4 | import os 5 | from PIL import Image 6 | import numpy as np 7 | 8 | 9 | class AverageMeter: 10 | """Computes and stores the average and current value""" 11 | def __init__(self): 12 | self.val, self.avg, self.sum, self.count = (0,)*4 13 | self.reset() 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | self.val = val 23 | self.sum += val * n 24 | self.count += n 25 | self.avg = self.sum / self.count 26 | 27 | 28 | def print_size(path, image_size=(342, 256)): 29 | check_all = False 30 | outliers = [] 31 | for m, n in enumerate(os.listdir(path)): 32 | dir_name = os.path.join(path, n) 33 | if os.path.isfile(dir_name): 34 | continue 35 | if check_all: 36 | for k, i in enumerate(os.listdir(dir_name)): 37 | image = Image.open(os.path.join(dir_name, i)) 38 | assert image_size == image.size, '{2}:{0}|{1}'.format(image_size, image.size, n) 39 | print('{2}/{3}:{0}/{1}'.format(k, len(os.listdir(dir_name)), m, len(os.listdir(path)))) 40 | else: 41 | dir_content = os.listdir(dir_name) 42 | image_name = np.random.choice(dir_content) 43 | image_path = os.path.join(dir_name, image_name) 44 | image = Image.open(image_path) 45 | if not image_size == image.size: 46 | print('---one outlier detected---') 47 | outliers.append('{2}:{0}|{1}'.format(image_size, image.size, n)) 48 | print('{0}/{1}:{2}'.format(m, len(os.listdir(path)), n)) 49 | for o in outliers: 50 | print(o) 51 | 52 | 53 | if __name__ == '__main__': 54 | if sys.argv[1] == 'print_size': 55 | print_size(sys.argv[2], (int(sys.argv[3]), int(sys.argv[4]))) 56 | -------------------------------------------------------------------------------- /tools/_init_lib_path.py: -------------------------------------------------------------------------------- 1 | 2 | """ Simply add lib path to sys for STNet """ 3 | 4 | import os.path 5 | import sys 6 | 7 | 8 | def add_path(path): 9 | if path not in sys.path: 10 | sys.path.insert(0, path) 11 | 12 | 13 | this_dir = os.path.dirname(__file__) 14 | lib_path = os.path.join(this_dir, '..', 'lib') 15 | add_path(os.path.normpath(lib_path)) 16 | -------------------------------------------------------------------------------- /tools/epoch_loop.py: -------------------------------------------------------------------------------- 1 | from utils.config import cfg 2 | from utils.config_file_handling import cfg_to_file 3 | 4 | import os 5 | import time 6 | from trainer import Trainer 7 | from validator import Validator 8 | 9 | import torch 10 | from tensorboardX import SummaryWriter 11 | from netwrapper.two_stream_net import TwoStreamNet 12 | 13 | started_time = time.time() 14 | 15 | 16 | class EpochLoop: 17 | def __init__(self): 18 | self.trainer, self.validator = None, None 19 | self.device, self.net = None, None 20 | self.logger_writer = None 21 | 22 | self.setup_gpu() 23 | 24 | def setup_gpu(self): 25 | cuda_device_id = cfg.GPU_ID 26 | if cfg.USE_GPU and torch.cuda.is_available(): 27 | self.device = torch.device('cuda:{}'.format(cuda_device_id)) 28 | else: 29 | self.device = torch.device('cpu') 30 | 31 | def setup_logger(self): 32 | logger_dir = os.path.join(cfg.EXPERIMENT_DIR, 33 | 'logger_{}_{}'.format(cfg.DATASET_NAME, cfg.NET_ARCH), 34 | cfg.MODEL_ID) 35 | if not os.path.exists(logger_dir): 36 | os.makedirs(logger_dir) 37 | self.logger_writer = SummaryWriter(logger_dir) 38 | cfg_to_file(cfg, logger_dir, '{}_cfg'.format(cfg.MODEL_ID)) 39 | 40 | def logger_update(self, e, mode): 41 | if e == 0: 42 | self.setup_logger() 
43 | if mode == 'train' and self.trainer: 44 | for k, m_avg in self.trainer.get_avg(): 45 | self.logger_writer.add_scalar('{}/{}'.format('train', k), m_avg, e) 46 | if mode == 'valid' and self.validator: 47 | for k, m_avg in self.validator.get_avg(): 48 | self.logger_writer.add_scalar('{}/{}'.format('valid', k), m_avg, e) 49 | 50 | def check_if_save_snapshot(self, e): 51 | if cfg.SNAPSHOT and (e + 1) % cfg.SNAPSHOT_INTERVAL == 0: 52 | file_path = os.path.join(cfg.EXPERIMENT_DIR, 53 | 'snapshot_{}_{}'.format(cfg.DATASET_NAME, cfg.NET_ARCH), 54 | cfg.MODEL_ID) 55 | if not os.path.exists(file_path): 56 | os.makedirs(file_path) 57 | 58 | self.net.save(file_path, e) 59 | 60 | def check_if_validating(self, e): 61 | if cfg.VALIDATING and (e + 1) % cfg.VALID_INTERVAL == 0: 62 | self.validator_epoch_loop(e) 63 | 64 | def main(self): 65 | self.create_sets() 66 | self.setup_net() 67 | self.run() 68 | 69 | def create_sets(self): 70 | self.trainer = Trainer('train', cfg.METERS, self.device) if cfg.TRAINING else None 71 | self.validator = Validator('valid', cfg.METERS, self.device) if cfg.VALIDATING else None 72 | 73 | def setup_net(self): 74 | self.net = TwoStreamNet(self.device) 75 | 76 | def run(self): 77 | if cfg.TRAINING: 78 | self.trainer_epoch_loop() 79 | elif cfg.VALIDATING: 80 | self.validator_epoch_loop(0) 81 | elif cfg.get('TESTING'):  # TESTING is not defined in config.py by default 82 | raise NotImplementedError('TESTING mode is not implemented yet') 83 | else: 84 | raise NotImplementedError('One of {TRAINING, VALIDATING, TESTING} must be set to True') 85 | 86 | def trainer_epoch_loop(self): 87 | for e in range(cfg.NUM_EPOCH): 88 | self.trainer.set_net_mode(self.net) 89 | 90 | self.trainer.reset_meters() 91 | 92 | self.trainer.batch_loop(self.net, e, started_time) 93 | 94 | self.check_if_save_snapshot(e) 95 | 96 | self.check_if_validating(e) 97 | 98 | self.logger_update(e, mode='train') 99 | 100 | self.net.schedule_step(metric=self.validator.meters if self.validator else None) 101 | 102 | def validator_epoch_loop(self, e): 103 | self.validator.set_net_mode(self.net) 104 | 105 | self.validator.reset_meters() 106 | 107 | self.validator.batch_loop(self.net, e, started_time) 108 | 109 | self.logger_update(e, mode='valid') 110 | -------------------------------------------------------------------------------- /tools/test.py: -------------------------------------------------------------------------------- 1 | import _init_lib_path 2 | 3 | import os 4 | from datetime import datetime 5 | import datetime as dt 6 | import time 7 | 8 | from utils.config import cfg 9 | from epoch_loop import EpochLoop 10 | 11 | import argparse 12 | from utils.config_file_handling import cfg_from_file, cfg_from_list 13 | from pprint import PrettyPrinter 14 | 15 | pp = PrettyPrinter(indent=4) 16 | cfg.TRAINING = False 17 | cfg.VALIDATING = True 18 | cfg.PRETRAINED_MODE = 'Custom' 19 | 20 | 21 | def parse_args(): 22 | """ 23 | Parse input arguments 24 | """ 25 | parser = argparse.ArgumentParser(description='Testing the network') 26 | 27 | parser.add_argument('-d', '--dataset-dir', dest='dataset_dir', 28 | help='dataset directory', type=str, required=False) 29 | parser.add_argument('-e', '--experiment-dir', dest='experiment_dir', 30 | help='a directory used to write experiment results', type=str, required=False) 31 | parser.add_argument('-i', '--pre-trained-id', dest='pt_id', 32 | help='the pre-trained network id that you want to load for testing', type=str, required=True) 33 | parser.add_argument('-p', '--pre-trained-epoch', dest='pt_epoch', 34 | help='the epoch at which a snapshot for the id is taken', 
type=str, required=True) 35 | parser.add_argument('-u', '--use-gpu', dest='use_gpu', 36 | help='whether to use gpu for the net inference', type=int, required=False) 37 | parser.add_argument('-g', '--gpu-id', dest='gpu_id', 38 | help='gpu id to use', type=int, required=False) 39 | parser.add_argument('-c', '--cfg', dest='cfg_file', 40 | help='optional config file to override the defaults', default=None, type=str) 41 | parser.add_argument('-s', '--set', dest='set_cfg', 42 | help='set config arg parameters', default=None, nargs=argparse.REMAINDER) 43 | return parser.parse_args() 44 | 45 | 46 | def set_positional_cfg(args_in): 47 | args_list = [] 48 | for n, a in args_in.__dict__.items(): 49 | if a is not None and n not in ['cfg_file', 'set_cfg', 'pt_id', 'pt_epoch']:  # pt_id/pt_epoch are not cfg keys 50 | args_list += [n, a] 51 | return args_list 52 | 53 | 54 | def main(): 55 | epoch_loop = EpochLoop() 56 | 57 | try: 58 | epoch_loop.main() 59 | except KeyboardInterrupt: 60 | print('*** The experiment is terminated by a keyboard interruption') 61 | 62 | 63 | if __name__ == '__main__': 64 | args = parse_args() 65 | 66 | print('Called with args:') 67 | print(args) 68 | 69 | if args.cfg_file is not None: 70 | cfg_from_file(args.cfg_file) 71 | if args.set_cfg is not None: 72 | cfg_from_list(args.set_cfg) 73 | 74 | cfg_from_list(set_positional_cfg(args)) # input arguments override cfg files and defaults 75 | 76 | cfg.PT_PATH = os.path.join(cfg.EXPERIMENT_DIR, 77 | 'snapshot_{}_{}'.format(cfg.DATASET_NAME, cfg.NET_ARCH), 78 | args.pt_id, 79 | '{:03d}.pth'.format(int(args.pt_epoch)))  # snapshots are saved as '{:03d}.pth' 80 | 81 | print('configuration file cfg is loaded for testing ...') 82 | pp.pprint(cfg) 83 | 84 | started_time = time.time() 85 | print('*** started @', datetime.now()) 86 | main() 87 | length = time.time() - started_time 88 | print('*** ended @', datetime.now()) 89 | print('took', dt.timedelta(seconds=int(length))) 90 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import _init_lib_path 2 | 3 | from datetime import datetime 4 | import datetime as dt 5 | import time 6 | import os 7 | 8 | from utils.config import cfg 9 | from epoch_loop import EpochLoop 10 | 11 | import argparse 12 | from utils.config_file_handling import cfg_from_file, cfg_from_list 13 | from pprint import PrettyPrinter 14 | 15 | pp = PrettyPrinter(indent=4) 16 | cfg.TRAINING = True 17 | 18 | 19 | def parse_args(): 20 | """ 21 | Parse input arguments 22 | """ 23 | parser = argparse.ArgumentParser(description='Training the network') 24 | 25 | parser.add_argument('-d', '--dataset-dir', dest='dataset_dir', 26 | help='dataset directory', type=str, required=False) 27 | parser.add_argument('-e', '--experiment-dir', dest='experiment_dir', 28 | help='a directory used to write experiment results', type=str, required=False) 29 | parser.add_argument('-i', '--pre-trained-id', dest='pt_id', 30 | help='the pre-trained network id that you want to load for testing', type=str, required=False) 31 | parser.add_argument('-p', '--pre-trained-epoch', dest='pt_epoch', 32 | help='the epoch at which a snapshot for the id is taken', type=str, required=False) 33 | parser.add_argument('-u', '--use-gpu', dest='use_gpu', 34 | help='whether to use gpu for the net inference', type=int, required=False) 35 | parser.add_argument('-g', '--gpu-id', dest='gpu_id', 36 | help='gpu id to use', type=int, required=False) 37 | parser.add_argument('-c', '--cfg', dest='cfg_file', 38 | help='optional config file to override 
the defaults', default=None, type=str) 39 | parser.add_argument('-s', '--set', dest='set_cfg', 40 | help='set config arg parameters', default=None, nargs=argparse.REMAINDER) 41 | return parser.parse_args() 42 | 43 | 44 | def set_positional_cfg(args_in): 45 | args_list = [] 46 | for n, a in args_in.__dict__.items(): 47 | if a is not None and n not in ['cfg_file', 'set_cfg', 'pt_id', 'pt_epoch']: 48 | args_list += [n, a] 49 | return args_list 50 | 51 | 52 | def main(): 53 | epoch_loop = EpochLoop() 54 | 55 | try: 56 | epoch_loop.main() 57 | except KeyboardInterrupt: 58 | print('*** The experiment is terminated by a keyboard interruption') 59 | 60 | 61 | if __name__ == '__main__': 62 | args = parse_args() 63 | 64 | print('Called with args:') 65 | print(args) 66 | 67 | if args.cfg_file is not None: 68 | cfg_from_file(args.cfg_file) 69 | if args.set_cfg is not None: 70 | cfg_from_list(args.set_cfg) 71 | 72 | cfg_from_list(set_positional_cfg(args)) # input arguments override cfg files and defaults 73 | 74 | if args.pt_id: 75 | assert args.pt_epoch, 'you must set the epoch for the pre-trained model' 76 | cfg.PT_PATH = os.path.join(cfg.EXPERIMENT_DIR, 77 | 'snapshot_{}_{}'.format(cfg.DATASET_NAME, cfg.NET_ARCH), 78 | args.pt_id, 79 | '{:03d}.pth'.format(int(args.pt_epoch)))  # snapshots are saved as '{:03d}.pth' 80 | cfg.PRETRAINED_MODE = 'Custom' 81 | 82 | print('configuration file cfg is loaded for training ...') 83 | pp.pprint(cfg) 84 | 85 | started_time = time.time() 86 | print('*** started @', datetime.now()) 87 | main() 88 | length = time.time() - started_time 89 | print('*** ended @', datetime.now()) 90 | print('took', dt.timedelta(seconds=int(length))) 91 | -------------------------------------------------------------------------------- /tools/trainer.py: -------------------------------------------------------------------------------- 1 | from tv_abc import TVBase 2 | 3 | 4 | class Trainer(TVBase): 5 | def __init__(self, mode, meters, device): 6 | super().__init__(mode, meters, device) 7 | 8 | def set_net_mode(self, net): 9 | net.train() 10 | 11 | def batch_main(self, net, x_slow, x_fast, annotation): 12 | p = net.forward((x_slow, x_fast)) 13 | 14 | a = self.generate_gt(annotation) 15 | 16 | loss = net.loss_update(p, a, step=True) 17 | 18 | acc = self.evaluate(p, a) 19 | 20 | return {'loss': loss, 21 | 'label_accuracy': acc} 22 | -------------------------------------------------------------------------------- /tools/tv_abc.py: -------------------------------------------------------------------------------- 1 | from utils.config import cfg 2 | import time 3 | from abc import ABC, abstractmethod 4 | import datetime 5 | import torch 6 | 7 | from data.data_container import DataContainer 8 | from utils.miscellaneous import AverageMeter 9 | 10 | 11 | class TVBase(ABC): 12 | def __init__(self, mode, meters, device): 13 | self.data_container = None 14 | assert mode in ('train', 'valid') 15 | self.mode, self.device = mode, device 16 | self.meters = {m: AverageMeter() for m in meters} 17 | 18 | self.create_dataset() 19 | 20 | def reset_meters(self): 21 | for m in self.meters.values(): 22 | m.reset() 23 | 24 | def update_meters(self, **kwargs): 25 | for k, m in self.meters.items(): 26 | try: 27 | m.update(kwargs[k]) 28 | except KeyError: 29 | raise KeyError('Key {} is not defined in the dictionary'.format(k)) 30 | 31 | def get_avg(self): 32 | for k, m in self.meters.items(): 33 | yield k, m.avg 34 | 35 | def create_dataset(self): 36 | self.data_container = DataContainer(self.mode) 37 | 38 | def result_print(self, i, epoch, batch_time, 
started_time): 39 | print( 40 | '{4} {0} [{1}][{2}/{3}]\t' 41 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 42 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 43 | 'Label Acc. {label.val:.4f} ({label.avg:.4f})'.format( 44 | self.mode.upper(), epoch, i, len(self.data_container.dataloader) - 1, 45 | str(datetime.timedelta(seconds=int(time.time() - started_time))), batch_time=batch_time, 46 | loss=self.meters['loss'], 47 | label=self.meters['label_accuracy'])) 48 | 49 | @abstractmethod 50 | def set_net_mode(self, net): 51 | pass 52 | 53 | def generate_gt(self, annotation): 54 | return annotation['label'].to(self.device) 55 | 56 | # noinspection PyUnresolvedReferences 57 | @staticmethod 58 | def evaluate(p, a): 59 | return (p.argmax(dim=1) == a).sum().item() / len(a) 60 | 61 | @abstractmethod 62 | def batch_main(self, net, x_slow, x_fast, annotation): 63 | pass 64 | 65 | def batch_loop(self, net, epoch, started_time): 66 | batch_time = AverageMeter() 67 | end = time.time() 68 | for i, (image, annotation) in enumerate(self.data_container.dataloader): 69 | 70 | x_slow = image[:, :, ::cfg.SLOWFAST.TAU, :, :].to(self.device) 71 | x_fast = image[:, :, ::cfg.SLOWFAST.TAU//cfg.SLOWFAST.ALPHA, :, :].to(self.device) 72 | 73 | results = self.batch_main(net, x_slow, x_fast, annotation) 74 | self.update_meters(**results) 75 | 76 | batch_time.update(time.time() - end) 77 | end = time.time() 78 | 79 | self.result_print(i, epoch, batch_time, started_time) 80 | -------------------------------------------------------------------------------- /tools/validator.py: -------------------------------------------------------------------------------- 1 | from tv_abc import TVBase 2 | import torch 3 | 4 | 5 | class Validator(TVBase): 6 | def __init__(self, mode, meters, device): 7 | super().__init__(mode, meters, device) 8 | 9 | def set_net_mode(self, net): 10 | net.eval() 11 | 12 | def batch_main(self, net, x_slow, x_fast, annotation): 13 | with torch.no_grad(): 14 | p = net.forward((x_slow, x_fast)) 15 | 16 | a = self.generate_gt(annotation) 17 | 18 | loss = net.loss_update(p, a, step=False) 19 | 20 | acc = self.evaluate(p, a) 21 | 22 | return {'loss': loss, 23 | 'label_accuracy': acc} 24 | --------------------------------------------------------------------------------