├── .gitignore ├── .idea ├── dictionaries │ └── atcold.xml ├── inspectionProfiles │ └── Project_Default.xml ├── misc.xml ├── modules.xml ├── pytorch-MatchNet.iml └── vcs.xml ├── README.md ├── data ├── README.md ├── VideoFolder.py ├── add_frame_numbering.sh ├── dump_data_set.sh ├── objectify.sh ├── resize_and_sample.sh ├── resize_and_split.sh ├── sample_video.sh └── small_data_set │ ├── cup │ └── sfsdfs-nb.mp4 │ └── hand │ ├── hand5-nb.mp4 │ └── hand_1-nb.mp4 ├── image-pretraining ├── README.md ├── main.py └── model ├── main.py ├── model ├── ConvLSTMCell.py ├── DiscriminativeCell.py ├── GenerativeCell.py ├── Model01.py ├── Model02.py ├── PrednetModel.py ├── README.md ├── RG.py ├── models.svg └── utils ├── new_experiment.sh ├── notebook ├── README.md ├── data ├── display_loss.ipynb ├── figures_generator.ipynb ├── frequency_analysis.ipynb ├── get_all_embeddings.ipynb ├── get_data_stats.ipynb ├── model ├── network_bisection.ipynb ├── plot_conf.py ├── salient_regions.ipynb ├── stability_analysis.ipynb ├── utils └── verify_resnet.ipynb └── utils ├── README.md ├── check_exp_diff.sh ├── image_plot.py ├── show_error.plt ├── show_error_exp.plt ├── update_experiments.sh └── visualise.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Remove only PyCharm's workspace configuration 2 | workspace.xml 3 | 4 | # Remove [I]Python caching 5 | __pycache__ 6 | .ipynb_checkpoints/ 7 | 8 | # Remove Vim temp files 9 | *.sw* 10 | 11 | # Remove other unnecessary stuff 12 | *.dot* 13 | *.tar 14 | *.txt 15 | *.jpg 16 | *.png 17 | *.pdf 18 | *.mp4 19 | *.svg 20 | *.pth 21 | 22 | # Remove experiments 23 | last 24 | results 25 | *backup 26 | -------------------------------------------------------------------------------- /.idea/dictionaries/atcold.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | canziani 5 | conv 6 | convolutional 7 | criterions 8 | cuda 9 | fromarray 10 | intra 11 | logits 12 | lstm 13 | matplotlib 14 | optim 15 | prednet 16 | resizes 17 | sergey 18 | subsamples 19 | tanh 20 | zagoruyko 21 | 22 | 23 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 12 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/pytorch-MatchNet.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # *CortexNet* 2 | 3 | This repo contains the *PyTorch* implementation of *CortexNet*. 
4 | Check the [project website](https://engineering.purdue.edu/elab/CortexNet/) for further information. 5 | 6 | ## Project structure 7 | 8 | The project consists of the following folders and files: 9 | 10 | - [`data/`](data): contains *Bash* scripts and a *Python* class definition inherent to video data loading; 11 | - [`image-pretraining/`](image-pretraining/): hosts the code for pre-training TempoNet's discriminative branch; 12 | - [`model/`](model): stores several network architectures, including [*PredNet*](https://coxlab.github.io/prednet/), an additive feedback *Model01*, and a modulatory feedback *Model02* ([*CortexNet*](https://engineering.purdue.edu/elab/CortexNet/)); 13 | - [`notebook/`](notebook): collection of *Jupyter Notebook*s for data exploration and results visualisation; 14 | - [`utils/`](utils): scripts for 15 | - (current or former) training error plotting, 16 | - experiments `diff`, 17 | - multi-node synchronisation, 18 | - generative predictions visualisation, 19 | - network architecture graphing; 20 | - `results@`: link to the location where experimental results will be saved within 3-digit folders; 21 | - [`new_experiment.sh*`](new_experiment.sh): creates a new experiment folder, updates `last@`, prints a memo about last used settings; 22 | - `last@`: symbolic link pointing to a new results sub-directory created by `new_experiment.sh`; 23 | - [`main.py`](main.py): training script for *CortexNet* in *MatchNet* or *TempoNet* configuration; 24 | 25 | ## Dependencies 26 | 27 | + [*scikit-video*](https://github.com/scikit-video/scikit-video): accessing images / videos 28 | 29 | ```bash 30 | pip install sk-video 31 | ``` 32 | 33 | + [*tqdm*](https://github.com/tqdm/tqdm): progress bar 34 | 35 | ```bash 36 | conda config --add channels conda-forge 37 | conda update --all 38 | conda install tqdm 39 | ``` 40 | 41 | ## IDE 42 | 43 | This project has been realised with [*PyCharm*](https://www.jetbrains.com/pycharm/) by *JetBrains* and the [*Vim*](http://www.vim.org/) editor. 44 | [*Grip*](https://github.com/joeyespo/grip) has also been fundamental for crafting decent documentation locally. 45 | 46 | ## Initialise environment 47 | 48 | Once you've determined where you'd like to save your experimental results — let's call this directory `` — run the following commands from the project's root directory: 49 | 50 | ```bash 51 | ln -s results # replace 52 | mkdir results/000 && touch results/000/train.log # init. placeholder 53 | ln -s results/000 last # create pointer to the most recent result 54 | ``` 55 | 56 | ## Setup new experiment 57 | 58 | Ready to run your first experiment? 59 | Type the following: 60 | 61 | ```bash 62 | ./new_experiment.sh 63 | ``` 64 | 65 | ### GPU selection 66 | 67 | Let's say your machine has `N` GPUs. 68 | You can choose to use any of these by specifying the index `n = 0, ..., N-1`. 69 | Therefore, type `CUDA_VISIBLE_DEVICES=n` just before `python ...` in the following sections. 70 | 71 | ## Train *MatchNet* 72 | 73 | + Download *e-VDS35* (*e.g.* `e-VDS35-May17.tar`) from [here](https://engineering.purdue.edu/elab/eVDS/). 74 | + Use [`data/resize_and_split.sh`](data/resize_and_split.sh) to prepare your (video) data for training. 75 | It resizes videos present in folders of folders (*i.e.* directory of classes) and may split them into training and validation set. 76 | May also skip short videos and trim longer ones. 77 | Check [`data/README.md`](data/README.md#matchnet-mode) for more details. 78 | + Run the [`main.py`](main.py) script to start training.
79 | Use `-h` to print the command line interface (CLI) arguments help. 80 | 81 | ```bash 82 | python -u main.py --mode MatchNet | tee last/train.log 83 | ``` 84 | 85 | ## Train *TempoNet* 86 | 87 | + Download *e-VDS35* (*e.g.* `e-VDS35-May17.tar`) from [here](https://engineering.purdue.edu/elab/eVDS/). 88 | + Pre-train the forward branch (see [`image-pretraining/`](image-pretraining)) on an image data set (*e.g.* `33-image-set.tar` from [here](https://engineering.purdue.edu/elab/eVDS/)); 89 | + Use [`data/resize_and_sample.sh`](data/resize_and_sample.sh) to prepare your (video) data for training. 90 | It resizes videos present in folders of folders (*i.e.* directory of classes) and samples them. 91 | Videos are then distributed across training and validation set. 92 | May also skip short videos and trim longer ones. 93 | Check [`data/README.md`](data/README.md#temponet-mode) for more details. 94 | + Run the [`main.py`](main.py) script to start training. 95 | Use `-h` to print the CLI arguments help. 96 | 97 | ```bash 98 | python -u main.py --mode TempoNet --pre-trained | tee last/train.log 99 | ``` 100 | 101 | ## GPU selection 102 | 103 | To run on a specific GPU, say `n`, type `CUDA_VISIBLE_DEVICES=n` just before `python ...`. 104 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Video data pre-processing 2 | 3 | This folder contains the following scripts: 4 | 5 | - [`add_frame_numbering.sh*`](add_frame_numbering.sh): draws a huge number on each frame on a specific video; 6 | - [`dump_data_set.sh*`](dump_data_set.sh): get images from videos for ["traditional training"](https://github.com/pytorch/examples/tree/master/imagenet); 7 | - [`objectify.sh*`](objectify.sh): convert sampled-data from video-indexed to object-indexed; 8 | - [`resize_and_split.sh*`](resize_and_split.sh): see [below](#matchnet-mode); 9 | - [`resize_and_sample.sh*`](resize_and_sample.sh): see [below](#temponet-mode); 10 | - [`sample_video.sh*`](sample_video.sh): sample the `` into `k` time-subsampled `-` videos; 11 | - [`VideoFolder.py`](VideoFolder.py): *PyTorch* `data.Dataset`'s sub-class for video data loading. 12 | 13 | ## Remove white spaces from file names 14 | White spaces and scripts are not good friends. 15 | Replace all spaces with `_` in your source data set with the following command, where you have to replace `` with the correct location. 16 | 17 | ```bash 18 | rename -n "s/ /_/g" /*/* # for a dry run 19 | rename "s/ /_/g" /*/* # to rename the files in !!! 20 | ``` 21 | 22 | In *e-VDS35* we have already done this for you. 23 | If you plan to use your own data, this step is fundamental! 24 | 25 | ## MatchNet mode 26 | ### Resize and *split* videos in train-val 27 | 28 | In order to speed up data loading, we shall resize the video shortest side to, say, `256`. 29 | To do so run the following script. 30 | 31 | ```bash 32 | ./resize_and_split.sh 33 | ``` 34 | 35 | By default, the script will 36 | 37 | - skip videos shorter than `144` frames (`4.8`s) 38 | - trim videos longer than `654` frames (`21.8`s) 39 | - use the last `2`s for the validation split 40 | - resize the shortest side to `256`px 41 | 42 | These options can be varied and turned off by changing the header of the script, which now looks something like this. 
43 | 44 | ```bash 45 | ################################################################################ 46 | # SETTINGS ##################################################################### 47 | ################################################################################ 48 | # comment the next line if you don't want to skip short videos 49 | min_frames=144 50 | # comment the next line if you don't want to limit the max length 51 | max_frames=564 52 | # set split to 0 (seconds) if no splitting is required 53 | split=2 54 | ################################################################################ 55 | ``` 56 | 57 | ### File system 58 | 59 | Our input `data_set` looks somewhat like this. 60 | 61 | ```bash 62 | data_set 63 | ├── barcode 64 | │   ├── 20160613_140057.mp4 65 | │   ├── 20160613_140115.mp4 66 | │   ├── 20160613_140138.mp4 67 | │   ├── 20160721_023437.mp4 68 | ├── bicycle 69 | │   ├── 0707_2_(2).mov 70 | │   ├── 0707_2_(4).mov 71 | ``` 72 | 73 | Running `./resize_and_split.sh data_set/ processed-data` yields 74 | 75 | ```bash 76 | processed-data/ 77 | ├── train 78 | │   ├── barcode 79 | │   │   ├── 20160613_140057.mp4 80 | │   │   ├── 20160613_140115.mp4 81 | │   │   ├── 20160613_140138.mp4 82 | │   │   ├── 20160721_023437.mp4 83 | │   ├── bicycle 84 | │   │   ├── 0707_2_(2).mp4 85 | │   │   ├── 0707_2_(4).mp4 86 | ``` 87 | 88 | ## TempoNet mode 89 | ### Resize and *sample* videos, then split into train-val 90 | 91 | In order to speed up data loading, we shall resize the video shortest side to, say, `256`. 92 | To do so run the following script. 93 | 94 | ```bash 95 | ./resize_and_sample.sh 96 | ``` 97 | 98 | By default, the script will 99 | 100 | - skip videos shorter than `144` frames (`4.8`s) 101 | - trim videos longer than `654` frames (`21.8`s) 102 | - perform `5` subsamples, use `4` for training, `1` for validation 103 | - resize the shortest side to `256`px 104 | 105 | These options can be varied and turned off by changing the header of the script, which now looks something like this. 106 | 107 | ```bash 108 | ################################################################################ 109 | # SETTINGS ##################################################################### 110 | ################################################################################ 111 | # comment the next line if you don't want to skip short videos 112 | min_frames=144 113 | # comment the next line if you don't want to limit the max length 114 | max_frames=564 115 | # set sampling interval: k - 1 train, 1 val 116 | k=5 117 | ################################################################################ 118 | ``` 119 | 120 | The output directory will contain **as many folders as the total number of videos**. 121 | Each folder will contain the individual splits. 122 | 123 | ### From video-index- to class-major data organisation 124 | 125 | If you would like to train against object classes instead of video indices (as explained in the paper), you also need to run 126 | 127 | ```bash 128 | ./objectify.sh 129 | ``` 130 | 131 | This script generates a new directory containing **as many folders as classes**, filled with symbolic links from the source directory. 132 | You can use it for both videos and dumped-videos (images) data sets.
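Once objectified, either flavour is a standard class-per-folder data set and can be loaded with the tools already used in this repo. Below is a minimal sketch; the directory names match the examples in the next section, and importing `data.VideoFolder` assumes the project root is on your `PYTHONPATH`:

```python
from torchvision import datasets, transforms
from data.VideoFolder import VideoFolder  # this repo's video Dataset class

t = transforms.Compose((transforms.ToPILImage(), transforms.ToTensor()))
video_set = VideoFolder('object-sampled-data/train', transform=t)     # symlinked videos, one folder per class
image_set = datasets.ImageFolder('dumped-object-sampled-data/train')  # symlinked dumped frames, one folder per class
```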
133 | 134 | ### File system 135 | 136 | Running `./resize_and_sample.sh data_set/ sampled-data` yields 137 | 138 | ```bash 139 | sampled-data/ 140 | ├── train 141 | │   ├── barcode-20160613_140057 142 | │   │   ├── 1.mp4 143 | │   │   ├── 2.mp4 144 | │   │   ├── 3.mp4 145 | │   │   └── 4.mp4 146 | │   ├── barcode-20160613_140115 147 | │   │   ├── 1.mp4 148 | │   │   ├── 2.mp4 149 | ``` 150 | 151 | We can "objectify" the structure with `./objectify.sh sampled-data/ object-sampled-data` and get 152 | 153 | ```bash 154 | object-sampled-data/ 155 | ├── train 156 | │   ├── barcode 157 | │   │   ├── 20160613_140057-1.mp4 -> ../../../sampled-data/train/barcode-20160613_140057/1.mp4 158 | │   │   ├── 20160613_140057-2.mp4 -> ../../../sampled-data/train/barcode-20160613_140057/2.mp4 159 | │   │   ├── 20160613_140057-3.mp4 -> ../../../sampled-data/train/barcode-20160613_140057/3.mp4 160 | │   │   ├── 20160613_140057-4.mp4 -> ../../../sampled-data/train/barcode-20160613_140057/4.mp4 161 | │   │   ├── 20160613_140115-1.mp4 -> ../../../sampled-data/train/barcode-20160613_140115/1.mp4 162 | │   │   ├── 20160613_140115-2.mp4 -> ../../../sampled-data/train/barcode-20160613_140115/2.mp4 163 | │   │   ├── 20160613_140115-3.mp4 -> ../../../sampled-data/train/barcode-20160613_140115/3.mp4 164 | ``` 165 | 166 | To train the discriminative feed-forward branch we need to dump our `sampled-data` with `./dump_data_set.sh sampled-data/ dumped-sampled-data`. 167 | The file system will look like this 168 | 169 | ```bash 170 | dumped-sampled-data/ 171 | ├── train 172 | │   ├── barcode-20160613_140057 173 | │   │   ├── 1001.png 174 | │   │   ├── 1002.png 175 | │   │   ├── 1003.png 176 | │   │   ├── 1004.png 177 | │   │   ├── 1005.png 178 | │   │   ├── 1006.png 179 | │   │   ├── 1007.png 180 | ``` 181 | 182 | where the "training class" correspond to the **video name**. 
183 | If we wish to train against **object classes**, then we can run `./objectify.sh dumped-sampled-data/ dumped-object-sampled-data` and get the following 184 | 185 | ```bash 186 | dumped-object-sampled-data/ 187 | ├── train 188 | │   ├── barcode 189 | │   │   ├── 20160613_140057-1001.png -> ../../../dumped-sampled-data/train/barcode-20160613_140057/1001.png 190 | │   │   ├── 20160613_140057-1002.png -> ../../../dumped-sampled-data/train/barcode-20160613_140057/1002.png 191 | │   │   ├── 20160613_140057-1003.png -> ../../../dumped-sampled-data/train/barcode-20160613_140057/1003.png 192 | │   │   ├── 20160613_140057-1004.png -> ../../../dumped-sampled-data/train/barcode-20160613_140057/1004.png 193 | │   │   ├── 20160613_140057-1005.png -> ../../../dumped-sampled-data/train/barcode-20160613_140057/1005.png 194 | │   │   ├── 20160613_140057-1006.png -> ../../../dumped-sampled-data/train/barcode-20160613_140057/1006.png 195 | │   │   ├── 20160613_140057-1007.png -> ../../../dumped-sampled-data/train/barcode-20160613_140057/1007.png 196 | ``` 197 | -------------------------------------------------------------------------------- /data/VideoFolder.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import torch 3 | import torch.utils.data as data 4 | 5 | from random import shuffle as list_shuffle # for shuffling list 6 | from math import ceil 7 | from os import listdir 8 | from os.path import isdir, join, isfile 9 | from itertools import islice 10 | from numpy.core.multiarray import concatenate, ndarray 11 | from skvideo.io import FFmpegReader, ffprobe 12 | from torch.utils.data.sampler import Sampler 13 | from torchvision import transforms as trn 14 | from tqdm import tqdm 15 | from time import sleep 16 | from bisect import bisect 17 | 18 | # Implement object from https://discuss.pytorch.org/t/loading-videos-from-folders-as-a-dataset-object/568 19 | 20 | VIDEO_EXTENSIONS = ['.mp4'] # pre-processing outputs MP4s only 21 | 22 | 23 | class BatchSampler(Sampler): 24 | def __init__(self, data_source, batch_size): 25 | """ 26 | Samples batches sequentially, always in the same order. 
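At step ``j`` the sampler yields index ``samples_per_row * i + j`` for every stream ``i``, so the data set is read as ``batch_size`` parallel, contiguous video streams, each advancing by one frame per step.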
27 | 28 | :param data_source: data set to sample from 29 | :type data_source: Dataset 30 | :param batch_size: concurrent number of video streams 31 | :type batch_size: int 32 | """ 33 | self.batch_size = batch_size 34 | self.samples_per_row = ceil(len(data_source) / batch_size) 35 | self.num_samples = self.samples_per_row * batch_size 36 | 37 | def __iter__(self): 38 | return (self.samples_per_row * i + j for j in range(self.samples_per_row) for i in range(self.batch_size)) 39 | 40 | def __len__(self): 41 | return self.num_samples # fake nb of samples, transparent wrapping around 42 | 43 | 44 | class VideoCollate: 45 | def __init__(self, batch_size): 46 | self.batch_size = batch_size 47 | 48 | def __call__(self, batch: iter) -> torch.Tensor or list(torch.Tensor): 49 | """ 50 | Puts each data field into a tensor with outer dimension batch size 51 | 52 | :param batch: samples from a Dataset object 53 | :type batch: list 54 | :return: temporal batch of frames of size (t, batch_size, *frame.size()), 0 <= t < T, most likely t = T - 1 55 | :rtype: tuple 56 | """ 57 | if torch.is_tensor(batch[0]): 58 | return torch.cat(tuple(t.unsqueeze(0) for t in batch), 0).view(-1, self.batch_size, *batch[0].size()) 59 | elif isinstance(batch[0], int): 60 | return torch.LongTensor(batch).view(-1, self.batch_size) 61 | elif isinstance(batch[0], collections.Iterable): 62 | # if each batch element is not a tensor, then it should be a tuple 63 | # of tensors; in that case we collate each element in the tuple 64 | transposed = zip(*batch) 65 | return tuple(self.__call__(samples) for samples in transposed) 66 | 67 | raise TypeError(("batch must contain tensors, numbers, or lists; found {}" 68 | .format(type(batch[0])))) 69 | 70 | 71 | class VideoFolder(data.Dataset): 72 | def __init__(self, root, transform=None, target_transform=None, video_index=False, shuffle=None): 73 | """ 74 | Initialise a ``data.Dataset`` object for concurrent frame fetching from videos in a directory of folders of videos 75 | 76 | :param root: Data directory (train or validation folders path) 77 | :type root: str 78 | :param transform: image transform-ing object from ``torchvision.transforms`` 79 | :type transform: object 80 | :param target_transform: label transformation / mapping 81 | :type target_transform: object 82 | :param video_index: if ``True``, the label will be the video index instead of target class 83 | :type video_index: bool 84 | :param shuffle: ``None``, ``'init'`` or ``True`` 85 | :type shuffle: str 86 | """ 87 | classes, class_to_idx = self._find_classes(root) 88 | video_paths = self._find_videos(root, classes) 89 | videos, frames, frames_per_video, frames_per_class = self._make_data_set( 90 | root, video_paths, class_to_idx, shuffle, video_index 91 | ) 92 | 93 | self.root = root 94 | self.video_paths = video_paths 95 | self.videos = videos 96 | self.opened_videos = [[] for _ in videos] 97 | self.frames = frames 98 | self.frames_per_video = frames_per_video 99 | self.frames_per_class = frames_per_class 100 | self.classes = classes 101 | self.class_to_idx = class_to_idx 102 | self.transform = transform 103 | self.target_transform = target_transform 104 | self.alternative_target = video_index 105 | self.shuffle = shuffle 106 | 107 | def __getitem__(self, frame_idx): 108 | if frame_idx == 0: 109 | self.free() 110 | if self.shuffle is True: 111 | self._shuffle() 112 | 113 | frame_idx %= self.frames # wrap around indexing, if asking too much 114 | video_idx = bisect(self.videos, ((frame_idx,),)) # video to which frame_idx belongs 
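# entries of self.videos are ((last_frame, first_frame), (path, class_idx)), sorted by global frame index, so bisect returns the video whose frame range contains frame_idx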
115 | (last, first), (path, target) = self.videos[video_idx] # get video metadata 116 | frame = self._get_frame(frame_idx - first, video_idx, frame_idx == last) # get frame from video 117 | if self.transform is not None: # image processing 118 | frame = self.transform(frame) 119 | if self.target_transform is not None: # target processing 120 | target = self.target_transform(target) 121 | 122 | if self.alternative_target: return frame, video_idx 123 | 124 | return frame, target 125 | 126 | def __len__(self): 127 | return self.frames 128 | 129 | def _get_frame(self, seek, video_idx, last): 130 | 131 | opened_video = None # handle to opened target video 132 | if self.opened_videos[video_idx]: # if handle(s) exists for target video 133 | current = self.opened_videos[video_idx] # get handles list 134 | opened_video = next((ov for ov in current if ov[0] == seek), None) # look for matching seek 135 | 136 | if opened_video is None: # no (matching) handle found 137 | video_path = join(self.root, self.videos[video_idx][1][0]) # build video path 138 | video_file = FFmpegReader(video_path) # get a video file pointer 139 | video_iter = video_file.nextFrame() # get an iterator 140 | opened_video = [seek, islice(video_iter, seek, None), video_file] # seek video and create o.v. item 141 | self.opened_videos[video_idx].append(opened_video) # add opened video object to o.v. list 142 | 143 | opened_video[0] = seek + 1 # update seek pointer 144 | frame = next(opened_video[1]) # cache output frame 145 | if last: 146 | opened_video[2]._close() # close video file (private method?!) 147 | self.opened_videos[video_idx].remove(opened_video) # remove o.v. item 148 | 149 | return frame 150 | 151 | def free(self): 152 | """ 153 | Frees all video files' pointers 154 | """ 155 | for video in self.opened_videos: # for every opened video 156 | for _ in range(len(video)): # for as many times as pointers 157 | opened_video = video.pop() # pop an item 158 | opened_video[2]._close() # close the file 159 | 160 | def _shuffle(self): 161 | """ 162 | Shuffles the video list 163 | by regenerating the sequence to sample sequentially 164 | """ 165 | def _is_video_file(filename_): 166 | return any(filename_.endswith(extension) for extension in VIDEO_EXTENSIONS) 167 | 168 | root = self.root 169 | video_paths = self.video_paths 170 | class_to_idx = self.class_to_idx 171 | list_shuffle(video_paths) # shuffle 172 | 173 | videos = list() 174 | frames_per_video = list() 175 | frames_counter = 0 176 | for filename in tqdm(video_paths, ncols=80): 177 | class_ = filename.split('/')[0] 178 | data_path = join(root, filename) 179 | if _is_video_file(data_path): 180 | video_meta = ffprobe(data_path) 181 | start_idx = frames_counter 182 | frames = int(video_meta['video'].get('@nb_frames')) 183 | frames_per_video.append(frames) 184 | frames_counter += frames 185 | item = ((frames_counter - 1, start_idx), (filename, class_to_idx[class_])) 186 | videos.append(item) 187 | 188 | sleep(0.5) # allows for progress bar completion 189 | # update the attributes with the altered sequence 190 | self.video_paths = video_paths 191 | self.videos = videos 192 | self.frames = frames_counter 193 | self.frames_per_video = frames_per_video 194 | 195 | @staticmethod 196 | def _find_classes(data_path): 197 | classes = [d for d in listdir(data_path) if isdir(join(data_path, d))] 198 | classes.sort() 199 | class_to_idx = {classes[i]: i for i in range(len(classes))} 200 | return classes, class_to_idx 201 | 202 | @staticmethod 203 | def _find_videos(root, classes): 204 | 
return [join(c, d) for c in classes for d in listdir(join(root, c))] 205 | 206 | @staticmethod 207 | def _make_data_set(root, video_paths, class_to_idx, init_shuffle, video_index): 208 | def _is_video_file(filename_): 209 | return any(filename_.endswith(extension) for extension in VIDEO_EXTENSIONS) 210 | 211 | if init_shuffle and not video_index: 212 | list_shuffle(video_paths) # shuffle 213 | 214 | videos = list() 215 | frames_per_video = list() 216 | frames_per_class = [0] * len(class_to_idx) 217 | frames_counter = 0 218 | for filename in tqdm(video_paths, ncols=80): 219 | class_ = filename.split('/')[0] 220 | data_path = join(root, filename) 221 | if _is_video_file(data_path): 222 | video_meta = ffprobe(data_path) 223 | start_idx = frames_counter 224 | frames = int(video_meta['video'].get('@nb_frames')) 225 | frames_per_video.append(frames) 226 | frames_per_class[class_to_idx[class_]] += frames 227 | frames_counter += frames 228 | item = ((frames_counter - 1, start_idx), (filename, class_to_idx[class_])) 229 | videos.append(item) 230 | 231 | sleep(0.5) # allows for progress bar completion 232 | return videos, frames_counter, frames_per_video, frames_per_class 233 | 234 | 235 | def _test_video_folder(): 236 | from textwrap import fill, indent 237 | 238 | batch_size = 5 239 | 240 | video_data_set = VideoFolder('small_data_set/') 241 | nb_of_classes = len(video_data_set.classes) 242 | print('There are', nb_of_classes, 'classes') 243 | print(indent(fill(' '.join(video_data_set.classes), 77), ' ')) 244 | print('There are {} frames'.format(len(video_data_set))) 245 | print('Videos in the data set:', *video_data_set.videos, sep='\n') 246 | 247 | import inflect 248 | ordinal = inflect.engine().ordinal 249 | 250 | def print_list(my_list): 251 | for a, b in enumerate(my_list): 252 | print(a, ':', end=' [') 253 | print(*b, sep=',\n ', end=']\n') 254 | 255 | # get first 3 batches 256 | n = ceil(len(video_data_set) / batch_size) 257 | print('Batch size:', batch_size) 258 | print('Frames per row:', n) 259 | for big_j in range(0, n, 90): 260 | batch = list() 261 | for j in range(big_j, big_j + 90): 262 | if j >= n: break # there are no more frames 263 | batch.append(tuple(video_data_set[i * n + j][0] for i in range(batch_size))) 264 | batch[-1] = concatenate(batch[-1], 0) 265 | batch = concatenate(batch, 1) 266 | _show_numpy(batch, 1e-1) 267 | print(ordinal(big_j // 90 + 1), '90 batches of shape', batch.shape) 268 | print_list(video_data_set.opened_videos) 269 | 270 | print('Freeing resources') 271 | video_data_set.free() 272 | print_list(video_data_set.opened_videos) 273 | 274 | # get frames 50 -> 52 275 | batch = list() 276 | for i in range(50, 53): 277 | batch.append(video_data_set[i][0]) 278 | _show_numpy(concatenate(batch, 1)) 279 | print_list(video_data_set.opened_videos) 280 | 281 | 282 | def _test_data_loader(): 283 | big_t = 10 284 | batch_size = 5 285 | t = trn.Compose((trn.ToPILImage(), trn.ToTensor())) # <-- add trn.CenterCrop(224) in between for training 286 | data_set = VideoFolder('small_data_set', t) 287 | my_loader = data.DataLoader(dataset=data_set, batch_size=batch_size * big_t, shuffle=False, 288 | sampler=BatchSampler(data_set, batch_size), num_workers=0, 289 | collate_fn=VideoCollate(batch_size)) 290 | print('Is my_loader an iterator [has __next__()]:', isinstance(my_loader, collections.Iterator)) 291 | print('Is my_loader an iterable [has __iter__()]:', isinstance(my_loader, collections.Iterable)) 292 | my_iter = iter(my_loader) 293 | my_batch = next(my_iter) 294 | 
print('my_batch is a', type(my_batch), 'of length', len(my_batch)) 295 | print('my_batch[0] is a', my_batch[0].type(), 'of size', tuple(my_batch[0].size()), ' # will 224, 224') 296 | _show_torch(_tile_up(my_batch), .2) 297 | for i in range(3): _show_torch(_tile_up(next(my_iter)), .2) 298 | 299 | 300 | def _show_numpy(tensor: ndarray, zoom: float = 1.) -> None: 301 | """ 302 | Display a ndarray image on screen 303 | 304 | :param tensor: image to visualise, of size (h, w, 1/3) 305 | :type tensor: ndarray 306 | :param zoom: zoom factor 307 | :type zoom: float 308 | """ 309 | from PIL import Image 310 | shape = tuple(map(lambda s: round(s * zoom), tensor.shape)) 311 | Image.fromarray(tensor).resize((shape[1], shape[0])).show() 312 | 313 | 314 | def _show_torch(tensor: torch.FloatTensor, zoom: float = 1.) -> None: 315 | numpy_tensor = tensor.clone().mul(255).int().numpy().astype('u1').transpose(1, 2, 0) 316 | _show_numpy(numpy_tensor, zoom) 317 | 318 | 319 | def _tile_up(temporal_batch): 320 | a = torch.cat(tuple(temporal_batch[0][:, i] for i in range(temporal_batch[0].size(1))), 2) 321 | a = torch.cat(tuple(a[j] for j in range(a.size(0))), 2) 322 | return a 323 | 324 | 325 | if __name__ == '__main__': 326 | _test_video_folder() 327 | _test_data_loader() 328 | 329 | 330 | __author__ = "Alfredo Canziani" 331 | __credits__ = ["Alfredo Canziani"] 332 | __maintainer__ = "Alfredo Canziani" 333 | __email__ = "alfredo.canziani@gmail.com" 334 | __status__ = "Production" # "Prototype", "Development", or "Production" 335 | __date__ = "Feb 17" 336 | -------------------------------------------------------------------------------- /data/add_frame_numbering.sh: -------------------------------------------------------------------------------- 1 | # Add frame number to video, to check correct loading 2 | # ./add_frame_numbering.sh 256min_data_set/barcode/20160613_140057.mp4 3 | # generate labelled small_data_set/barcode/20160613_140057-nb.mp4 4 | 5 | src_video=$1 6 | dst_dir="small_data_set" 7 | dst_video="${src_video%.*}-nb.mp4" 8 | dst_video="$dst_dir/${dst_video#*/}" 9 | dst_dir="${dst_video%/*}" 10 | 11 | mkdir -p $dst_dir 12 | ffmpeg \ 13 | -i $src_video \ 14 | -filter:v "drawtext=fontsize=200:fontfile=Arial.ttf: text=%{n}: x=(w-tw)/2: y=50:fontcolor=white: box=1: boxcolor=0x00000099" \ 15 | -loglevel quiet \ 16 | $dst_video 17 | printf "$src_video --> $dst_video\n" 18 | -------------------------------------------------------------------------------- /data/dump_data_set.sh: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # Dump data set 3 | ################################################################################ 4 | # Alfredo Canziani, Apr 17 5 | ################################################################################ 6 | # Run as: 7 | # ./dump_data_set.sh src_path/ dst_path/ 8 | ################################################################################ 9 | 10 | # some colours 11 | r='\033[0;31m' # red 12 | g='\033[0;32m' # green 13 | b='\033[0;34m' # blue 14 | n='\033[0m' # none 15 | 16 | # title 17 | echo "Dumping video data set into separate frames" 18 | 19 | # assert existence of source directory 20 | src_dir=${1%/*} # remove trailing /, if present 21 | if [ ! -d $src_dir ] || [ -z $src_dir ]; then 22 | echo -e "${r}Source directory/link \"$src_dir\" is missing. 
Exiting.${n}" 23 | exit 1 24 | fi 25 | echo -e " - Source directory/link set to \"$b$src_dir$n\"" 26 | 27 | # assert existence of destination directory 28 | dst_dir="${2%/*}" # remove trailing /, if present 29 | if [ -d $dst_dir ]; then 30 | echo -e "${r}Destination directory \"$dst_dir\" already existent." \ 31 | "Exiting.${n}" 32 | exit 1 33 | fi 34 | echo -e " - Destination directory set to \"$b$dst_dir$n\"" 35 | 36 | # check if all is good 37 | printf "Does it look fine? (${g}y${g}${n}/${r}n${n}) " 38 | read ans 39 | if [ $ans == 'n' ]; then 40 | echo -e "${r}Exiting.${n}" 41 | exit 0 42 | fi 43 | 44 | # for every class 45 | for set_ in $(ls $src_dir); do 46 | 47 | printf "\nProcessing $set_ set\n" 48 | 49 | for class in $(ls $src_dir/$set_); do 50 | 51 | printf " > Processing class \"$class\":" 52 | 53 | # define src and dst 54 | src_class_dir="$src_dir/$set_/$class" 55 | dst_class_dir="$dst_dir/$set_/$class" 56 | mkdir -p $dst_class_dir 57 | 58 | # for each video in the class 59 | for video in $(ls $src_class_dir); do 60 | 61 | printf " \"$video\"" 62 | 63 | # define src and dst video paths 64 | src_video_path="$src_class_dir/$video" 65 | dst_video_path="$dst_class_dir/${video%.*}%03d.png" 66 | 67 | ffmpeg \ 68 | -loglevel error \ 69 | -i $src_video_path \ 70 | $dst_video_path 71 | 72 | done 73 | echo "" 74 | done 75 | done 76 | 77 | echo -e "${r}Done. Exiting.${n}" 78 | exit 0 79 | -------------------------------------------------------------------------------- /data/objectify.sh: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # Link videos into object classes 3 | ################################################################################ 4 | # Alfredo Canziani, Apr 17 5 | ################################################################################ 6 | # Run as: 7 | # ./objectify.sh src_path/ dst_path/ 8 | ################################################################################ 9 | 10 | # Originally used these. Now they are arguments. 
11 | src="sampled-data" # one folder per video 12 | dst="object-sampled-data" # one folder per object 13 | src="dumped-sampled-data" # one folder per video 14 | dst="dumped-object-sampled-data" # one folder per object 15 | 16 | src=${1%/*} # remove trailing /, if present 17 | dst=${2%/*} # remove trailing /, if present 18 | 19 | classes=$(ls processed-data/train/) 20 | sets="train val" 21 | 22 | for c in $classes; do 23 | echo "Processing class $c" 24 | for s in $sets; do 25 | dst_dir="$dst/$s/$c" 26 | mkdir $dst_dir 27 | for v in $(ls $src/$s/$c-*/*); do 28 | video_name=${v#$src/$s/$c-} # removes leading crap 29 | video_name=${video_name/'/'/-} # convert / into - 30 | ln -s ../../../$v $dst_dir/$video_name 31 | done 32 | done 33 | done 34 | 35 | -------------------------------------------------------------------------------- /data/resize_and_sample.sh: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # Pre-process video data 3 | ################################################################################ 4 | # Alfredo Canziani, Apr 17 5 | ################################################################################ 6 | 7 | # Pre-process video data 8 | # 9 | # - resize video data minor side to specific size 10 | # - sample frames and split into train a val sets 11 | # - skip "too short" videos 12 | # - limit max length 13 | # 14 | # Run as: 15 | # ./resize_and_split.sh src_path/ dst_path/ 16 | # 17 | # It's better to perform the resizing and the sampling together since 18 | # re-encoding is necessary when a temporal sampling is performed. 19 | # Skipping and clipping max length are also easily achievable at this point in 20 | # time. 21 | 22 | # current object video data set 23 | # 95% interval: [144, 564] -> [4.8, 21.8] seconds 24 | # mean number of frames: 354 -> 11.8 seconds 25 | 26 | ################################################################################ 27 | # SETTINGS ##################################################################### 28 | ################################################################################ 29 | # comment the next line if you don't want to skip short videos 30 | min_frames=144 31 | # comment the next line if you don't want to limit the max length 32 | max_frames=564 33 | # set sampling interval: k - 1 train, 1 val 34 | k=5 35 | ################################################################################ 36 | 37 | # some colours 38 | r='\033[0;31m' # red 39 | g='\033[0;32m' # green 40 | b='\033[0;34m' # blue 41 | n='\033[0m' # none 42 | 43 | # check min_frames setting 44 | printf " - " 45 | if [ -n "$min_frames" ]; then 46 | echo -e "Skipping videos with < $b$min_frames$n frames" 47 | skip_count=0 48 | else 49 | echo "No skipping short vidos" 50 | min_frames=0 51 | fi 52 | 53 | # check max_frames setting 54 | printf " - " 55 | if [ -n "$max_frames" ]; then 56 | echo -e "Trimming videos with > $b$max_frames$n frames" 57 | trim_count=0 58 | else 59 | echo "No trimming long vidos" 60 | fi 61 | 62 | # check split setting 63 | printf " - " 64 | echo -e "Sampling every $b$k$n frames" 65 | kk=$(awk "BEGIN{print 1/$k}") 66 | 67 | # assert existence of source directory 68 | src_dir=${1%/*} # remove trailing /, if present 69 | if [ ! -d $src_dir ] || [ -z $src_dir ]; then 70 | echo -e "${r}Source directory/link \"$src_dir\" is missing. 
Exiting.${n}" 71 | exit 1 72 | fi 73 | echo -e " - Source directory/link set to \"$b$src_dir$n\"" 74 | 75 | # assert existence of destination directory 76 | dst_dir="${2%/*}" 77 | if [ -d $dst_dir ]; then 78 | echo -e "${r}Destination directory \"$dst_dir\" already existent." \ 79 | "Exiting.${n}" 80 | exit 1 81 | fi 82 | echo -e " - Destination directory set to \"$b$dst_dir$n\"" 83 | 84 | # check if all is good 85 | printf "Does it look fine? (${g}y${g}${n}/${r}n${n}) " 86 | read ans 87 | if [ $ans == 'n' ]; then 88 | echo -e "${r}Exiting.${n}" 89 | exit 0 90 | fi 91 | 92 | # for every class 93 | for class in $(ls $src_dir); do 94 | 95 | printf "\nProcessing class \"$class\"\n" 96 | 97 | # define src 98 | src_class_dir="$src_dir/$class" 99 | 100 | # for each video in the class 101 | for video in $(ls $src_class_dir); do 102 | 103 | printf " > Loading video \"$video\". " 104 | 105 | # define src video path 106 | src_video_path="$src_class_dir/$video" 107 | 108 | # count the frames 109 | frames=$(ffprobe \ 110 | -loglevel quiet \ 111 | -show_streams \ 112 | -select_streams v \ 113 | $src_video_path | awk \ 114 | '/nb_frames=/{sub(/nb_frames=/,""); print}') 115 | 116 | # skip if too short 117 | if ((frames < min_frames)); then 118 | printf "Frames: $b$frames$n < $b$min_frames$n min frames. " 119 | echo -e "${r}Skipping.$n" 120 | ((skip_count++)) 121 | continue 122 | fi 123 | 124 | # define and make dst and val dir 125 | dst_video_path="$dst_dir/train/$class-${video%.*}" 126 | val_video_path="$dst_dir/val/$class-${video%.*}" 127 | mkdir -p $dst_video_path 128 | mkdir -p $val_video_path 129 | 130 | # get src_video frame rate 131 | fps=$(ffprobe \ 132 | -loglevel error \ 133 | -show_streams \ 134 | -select_streams v \ 135 | $src_video_path | awk \ 136 | '/avg_frame_rate=/{sub(/avg_frame_rate=/,""); print}') 137 | 138 | # get my src_video_path 139 | # discard audio stram 140 | # be quiet (show errors only) 141 | # use complex filter (1 input file, multiple output files) 142 | # rescale the video stram min side to 256 143 | # select frames using frame_number % k and send it to streams 1, ..., k 144 | # as long as frame_number < max_frames 145 | # use the input average frame rate as output fps 146 | # send each stream to a separate output file dst_video_path/{1..k-1}.mp4 147 | # and val_video_path/k.mp4 148 | printf "Rescaling and sampling" 149 | ffmpeg \ 150 | -i $src_video_path \ 151 | -an \ 152 | -loglevel error \ 153 | -filter_complex \ 154 | "setpts=$kk*PTS, \ 155 | scale=w=2*trunc(128*max(1\, iw/ih)):h=2*trunc(128*max(1\, ih/iw))[m]; \ 156 | [m]select=n=$k:e=(mod(n\,$k)+1)*lt(n\,$max_frames) \ 157 | $(for ((i=1; i<=$k; i++)); do 158 | echo -n "[a$i]" 159 | done)" \ 160 | $(for ((i=1; i<$k; i++)); do 161 | echo -n "-r $fps -map [a$i] $dst_video_path/$i.mp4 " 162 | done 163 | echo -n "-r $fps -map [a$k] $val_video_path/$k.mp4" 164 | ) 165 | 166 | # check the output stream resolution 167 | printf ' --> ' 168 | ffprobe \ 169 | -loglevel quiet \ 170 | -show_streams \ 171 | "$dst_video_path/1.mp4" | awk \ 172 | -F= \ 173 | '/^width/{printf $2"x"}; /^height/{printf $2"."}' 174 | 175 | if ((frames > max_frames)); then 176 | echo -n " Trimming $frames --> $max_frames." 
177 | ((trim_count++)) 178 | fi 179 | 180 | # new line :) 181 | printf "\n" 182 | done 183 | done 184 | 185 | printf "\n---------------\n" 186 | printf "Skipped $b%d$n videos\n" "$skip_count" 187 | printf "Trimmed $b%d$n videos" "$trim_count" 188 | printf "\n---------------\n\n" 189 | echo -e "${r}Exiting.${n}" 190 | exit 0 191 | -------------------------------------------------------------------------------- /data/resize_and_split.sh: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # Pre-process video data 3 | ################################################################################ 4 | # Alfredo Canziani, Feb 17 5 | ################################################################################ 6 | 7 | # Pre-process video data 8 | # 9 | # - resize video data minor side to specific size 10 | # - split into train a val sets 11 | # - skip "too short" videos 12 | # - limit max length 13 | # 14 | # Run as: 15 | # ./resize_and_split.sh src_path/ dst_path/ 16 | # 17 | # It's better to perform the resizing and the splitting together since 18 | # re-encoding is necessary when a temporal split is performed. 19 | # Skipping and clipping max length are also easily achievable at this point in 20 | # time. 21 | 22 | # current object video data set 23 | # 95% interval: [144, 564] -> [4.8, 21.8] seconds 24 | # mean number of frames: 354 -> 11.8 seconds 25 | 26 | ################################################################################ 27 | # SETTINGS ##################################################################### 28 | ################################################################################ 29 | # comment the next line if you don't want to skip short videos 30 | min_frames=144 31 | # comment the next line if you don't want to limit the max length 32 | max_frames=564 33 | # set split to 0 (seconds) if no splitting is required 34 | split=2 35 | ################################################################################ 36 | 37 | # some colours 38 | r='\033[0;31m' # red 39 | g='\033[0;32m' # green 40 | b='\033[0;34m' # blue 41 | n='\033[0m' # none 42 | 43 | # check min_frames setting 44 | printf " - " 45 | if [ -n "$min_frames" ]; then 46 | echo -e "Skipping videos with < $b$min_frames$n frames" 47 | skip_count=0 48 | else 49 | echo "No skipping short vidos" 50 | min_frames=0 51 | fi 52 | 53 | # check max_frames setting 54 | printf " - " 55 | if [ -n "$max_frames" ]; then 56 | echo -e "Trimming videos with > $b$max_frames$n frames" 57 | trim_count=0 58 | else 59 | echo "No trimming long vidos" 60 | fi 61 | 62 | # check split setting 63 | printf " - " 64 | if [ $split != 0 ]; then 65 | echo -e "Using last $b$split$n seconds for validation" 66 | dst="train/" 67 | else 68 | echo "No train-validation splitting will be performed" 69 | dst="" 70 | fi 71 | 72 | # assert existence of source directory 73 | src_dir=${1%/*} # remove trailing /, if present 74 | if [ ! -d $src_dir ] || [ -z $src_dir ]; then 75 | echo -e "${r}Source directory/link \"$src_dir\" is missing. Exiting.${n}" 76 | exit 1 77 | fi 78 | echo -e " - Source directory/link set to \"$b$src_dir$n\"" 79 | 80 | # assert existence of destination directory 81 | dst_dir="${2%/*}" 82 | if [ -d $dst_dir ]; then 83 | echo -e "${r}Destination directory \"$dst_dir\" already existent." 
\ 84 | "Exiting.${n}" 85 | exit 1 86 | fi 87 | echo -e " - Destination directory set to \"$b$dst_dir$n\"" 88 | 89 | # check if all is good 90 | printf "Does it look fine? (${g}y${g}${n}/${r}n${n}) " 91 | read ans 92 | if [ $ans == 'n' ]; then 93 | echo -e "${r}Exiting.${n}" 94 | exit 0 95 | fi 96 | 97 | # for every class 98 | for class in $(ls $src_dir); do 99 | 100 | printf "\nProcessing class \"$class\"\n" 101 | 102 | # define src and dst dir, make dst dir 103 | src_class_dir="$src_dir/$class" 104 | dst_class_dir="$dst_dir/$dst$class" 105 | mkdir -p $dst_class_dir 106 | 107 | # if split > 0, deal with validation dir too 108 | if [ $split != 0 ]; then 109 | val_class_dir="$dst_dir/val/$class" 110 | mkdir -p $val_class_dir 111 | fi 112 | 113 | # for each video in the class 114 | for video in $(ls $src_class_dir); do 115 | 116 | printf " > Loading video \"$video\". " 117 | 118 | # define src video path 119 | src_video_path="$src_class_dir/$video" 120 | 121 | # count the frames 122 | frames=$(ffprobe \ 123 | -loglevel quiet \ 124 | -show_streams \ 125 | -select_streams v \ 126 | $src_video_path | awk \ 127 | '/nb_frames=/{sub(/nb_frames=/,""); print}') 128 | 129 | # skip if too short 130 | if ((frames < min_frames)); then 131 | printf "Frames: $b$frames$n < $b$min_frames$n min frames. " 132 | echo -e "${r}Skipping.$n" 133 | ((skip_count++)) 134 | continue 135 | fi 136 | 137 | # get src_video duration 138 | tot_t=$(ffprobe \ 139 | -loglevel quiet \ 140 | -show_streams \ 141 | -select_streams v \ 142 | $src_video_path | awk \ 143 | '/duration=/{sub(/duration=/,""); print}') 144 | 145 | # get src_video frame rate 146 | fps=$(ffprobe \ 147 | -loglevel error \ 148 | -show_streams \ 149 | -select_streams v \ 150 | $src_video_path | awk \ 151 | '/avg_frame_rate=/{sub(/avg_frame_rate=/,""); print}') 152 | 153 | # if there is a max_frames and we are over it, redefine tot_t 154 | if [ -n "$max_frames" ] && ((frames > max_frames)); then 155 | printf "Frames: $b$frames$n > $b$max_frames$n max frames. " 156 | printf "Trimming %.2fs" "$tot_t" 157 | tot_t=$(awk \ 158 | "BEGIN{printf (\"%.4f\",$tot_t-($frames-$max_frames)/($fps))}") 159 | printf " --> %.2fs. 
" "$tot_t" 160 | ((trim_count++)) 161 | fi 162 | 163 | # compute duration in seconds 164 | end_t=$(awk "BEGIN{printf (\"%f\",$tot_t-$split)}") 165 | 166 | # compute duration in ffmpeg format 167 | ffmpeg_end_t=$(awk \ 168 | "BEGIN{printf (\"%02d:%02d:%02.4f\",$end_t/3600,($end_t%3600)/60,($end_t%60))}") 169 | 170 | # get my src_video_path 171 | # discard audio stram 172 | # use it until ffmpeg_end_t 173 | # rescale the video stram min side to 256 174 | # use the input average frame rate as output fps 175 | # be quiet (show errors only) 176 | # save at dst_video_path 177 | printf "Rescaling" 178 | dst_video_path="$dst_class_dir/${video%.*}.mp4" # replace extension 179 | ffmpeg \ 180 | -i $src_video_path \ 181 | -an \ 182 | -to $ffmpeg_end_t \ 183 | -filter:v "scale=w=2*trunc(128*max(1\, iw/ih)):h=2*trunc(128*max(1\, ih/iw))" \ 184 | -r $fps \ 185 | -loglevel error \ 186 | $dst_video_path 187 | 188 | # check the output stream resolution 189 | printf ' --> ' 190 | ffprobe \ 191 | -loglevel quiet \ 192 | -show_streams \ 193 | $dst_video_path | awk \ 194 | -F= \ 195 | '/^width/{printf $2"x"}; /^height/{printf $2"."}' 196 | 197 | if [ $split != 0 ]; then 198 | # start at Ns from the end 199 | # of my src_video_path 200 | # discard audio stram 201 | # rescale the video stram min side to 256 202 | # be quiet (show errors only) 203 | # save at val_video_path 204 | printf " Splitting" 205 | val_video_path="$val_class_dir/${video%.*}.mp4" # replace extension 206 | ffmpeg \ 207 | -sseof -$split \ 208 | -i $src_video_path \ 209 | -an \ 210 | -filter:v "scale=w=2*trunc(128*max(1\, iw/ih)):h=2*trunc(128*max(1\, ih/iw))" \ 211 | -r $fps \ 212 | -loglevel error \ 213 | $val_video_path 214 | 215 | # print temporal split 216 | duration=$(awk \ 217 | "BEGIN{printf (\"%02d:%02d:%02.4f\",$tot_t/3600,($tot_t%3600)/60,($tot_t%60))}") 218 | printf " 00:00:00 / $ffmpeg_end_t / $duration." 
219 | fi 220 | 221 | # new line :) 222 | printf "\n" 223 | done 224 | done 225 | 226 | printf "\n---------------\n" 227 | printf "Skipped $b%d$n videos\n" "$skip_count" 228 | printf "Trimmed $b%d$n videos" "$trim_count" 229 | printf "\n---------------\n\n" 230 | echo -e "${r}Exiting.${n}" 231 | exit 0 232 | -------------------------------------------------------------------------------- /data/sample_video.sh: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # Video sampler: y_i[n] = x[n_i + N * n], i < N 3 | ################################################################################ 4 | # Alfredo Canziani, Mar 17 5 | ################################################################################ 6 | # Run as 7 | # ./sample_video.sh src_video dst_prefix 8 | ################################################################################ 9 | 10 | src="small_data_set/cup/sfsdfs-nb.mp4" 11 | dst="sampled/sfsdfs-nb" 12 | src="data_set/barcode/20160613_140057.mp4" 13 | dst="sampled/20160613_140057" 14 | src="data_set/floor/VID_20160605_094332.mp4" 15 | dst="sampled/VID_20160605_094332" 16 | src="/home/atcold/Videos/20170416_184611.mp4" 17 | dst="bme-car/20170416_184611" 18 | src="/home/atcold/Videos/20170418_113638.mp4" 19 | dst="bme-chair/20170418_113638" 20 | src="/home/atcold/Videos/20160603_133515.mp4" 21 | dst="abhi-car/20160603_133515" 22 | src="/home/atcold/Videos/20170419_125021.mp4" 23 | dst="bme-chair/20170419_125021" 24 | 25 | src=$1 26 | dst=$2 27 | 28 | k=5 29 | kk=$(awk "BEGIN{print 1/$k}") 30 | ffmpeg \ 31 | -i $src \ 32 | -an \ 33 | -loglevel error \ 34 | -filter_complex \ 35 | "setpts=$kk*PTS, \ 36 | scale=w=2*trunc(128*max(1\, iw/ih)):h=2*trunc(128*max(1\, ih/iw))[m]; \ 37 | [m]select=n=$k:e=(mod(n\, $k)+1)*lt(n\, 564) \ 38 | $(for ((i=1; i<=$k; i++)); do 39 | echo -n "[a$i]" 40 | done)" \ 41 | $(for ((i=1; i<=$k; i++)); do 42 | echo -n "-r 31230000/1042111 -map [a$i] $dst-$i.mp4 " 43 | done) 44 | -------------------------------------------------------------------------------- /data/small_data_set/cup/sfsdfs-nb.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/e-lab/pytorch-CortexNet/bc28dac4e6a1ad9abb11e2fbc48d310a85e9903a/data/small_data_set/cup/sfsdfs-nb.mp4 -------------------------------------------------------------------------------- /data/small_data_set/hand/hand5-nb.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/e-lab/pytorch-CortexNet/bc28dac4e6a1ad9abb11e2fbc48d310a85e9903a/data/small_data_set/hand/hand5-nb.mp4 -------------------------------------------------------------------------------- /data/small_data_set/hand/hand_1-nb.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/e-lab/pytorch-CortexNet/bc28dac4e6a1ad9abb11e2fbc48d310a85e9903a/data/small_data_set/hand/hand_1-nb.mp4 -------------------------------------------------------------------------------- /image-pretraining/README.md: -------------------------------------------------------------------------------- 1 | # Image pre-training 2 | 3 | Find the original code at [PyTorch ImageNet example](https://github.com/pytorch/examples/tree/master/imagenet). 4 | This adaptation trains the discriminative branch of CortexNet for TempoNet. 
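Checkpoints are written to `checkpoint.pth.tar`, and the best-scoring one is copied to `model_best.pth.tar`; that checkpoint is what the top-level training script consumes when building TempoNet. A rough sketch of the round trip, where the image-set path and the value given to `--pre-trained` are assumptions rather than fixed names:

```bash
# pre-train the discriminative branch on an image set (placeholder path)
python main.py dumped-object-sampled-data/ | tee train.log
# back at the repository root, start TempoNet training from the saved checkpoint (assumed flag value)
python -u main.py --mode TempoNet --pre-trained image-pretraining/model_best.pth.tar | tee last/train.log
```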
5 | 6 | ## Training 7 | 8 | To train the discriminative branch of CortexNet, run `main.py` with the path to an image data set: 9 | 10 | ```bash 11 | python main.py | tee train.log 12 | ``` 13 | 14 | The default learning rate schedule starts at 0.1 and decays by a factor of 10 every 30 epochs. 15 | 16 | ## Usage 17 | 18 | ``` 19 | usage: main.py [-h] [-j N] [--epochs N] [--start-epoch N] [-b N] [--lr LR] 20 | [--momentum M] [--weight-decay W] [--print-freq N] 21 | [--resume PATH] [-e] [--pretrained] [--size [S [S ...]]] 22 | DIR 23 | 24 | PyTorch ImageNet Training 25 | 26 | positional arguments: 27 | DIR path to dataset 28 | 29 | optional arguments: 30 | -h, --help show this help message and exit 31 | -j N, --workers N number of data loading workers (default: 4) 32 | --epochs N number of total epochs to run 33 | --start-epoch N manual epoch number (useful on restarts) 34 | -b N, --batch-size N mini-batch size (default: 256) 35 | --lr LR, --learning-rate LR 36 | initial learning rate 37 | --momentum M momentum 38 | --weight-decay W, --wd W 39 | weight decay (default: 1e-4) 40 | --print-freq N, -p N print frequency (default: 10) 41 | --resume PATH path to latest checkpoint (default: none) 42 | -e, --evaluate evaluate model on validation set 43 | --pretrained use pre-trained model 44 | --size [S [S ...]] number and size of hidden layers 45 | ``` 46 | -------------------------------------------------------------------------------- /image-pretraining/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim 11 | import torch.utils.data 12 | import torchvision.transforms as transforms 13 | import torchvision.datasets as datasets 14 | import torchvision.models as models 15 | 16 | 17 | model_names = sorted(name for name in models.__dict__ 18 | if name.islower() and not name.startswith("__") 19 | and callable(models.__dict__[name])) 20 | 21 | 22 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 23 | parser.add_argument('data', metavar='DIR', 24 | help='path to dataset') 25 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 26 | help='number of data loading workers (default: 4)') 27 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 28 | help='number of total epochs to run') 29 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 30 | help='manual epoch number (useful on restarts)') 31 | parser.add_argument('-b', '--batch-size', default=256, type=int, 32 | metavar='N', help='mini-batch size (default: 256)') 33 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 34 | metavar='LR', help='initial learning rate') 35 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 36 | help='momentum') 37 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 38 | metavar='W', help='weight decay (default: 1e-4)') 39 | parser.add_argument('--print-freq', '-p', default=10, type=int, 40 | metavar='N', help='print frequency (default: 10)') 41 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 42 | help='path to latest checkpoint (default: none)') 43 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 44 | help='evaluate model on validation set') 45 | parser.add_argument('--pretrained', 
dest='pretrained', action='store_true', 46 | help='use pre-trained model') 47 | parser.add_argument('--size', type=int, default=(3, 32, 64, 128, 256, 256, 256), nargs='*', 48 | help='number and size of hidden layers', metavar='S') 49 | 50 | best_prec1 = 0 51 | 52 | 53 | def main(): 54 | global args, best_prec1 55 | args = parser.parse_args() 56 | args.size = tuple(args.size) 57 | 58 | # create model 59 | from model.Model02 import Model02 as Model 60 | 61 | class Capsule(nn.Module): 62 | 63 | def __init__(self): 64 | super().__init__() 65 | nb_of_classes = 33 # 970 (vid) or 35 (vid obj) or 33 (imgs) 66 | self.inner_model = Model(args.size + (nb_of_classes,), (256, 256)) 67 | 68 | def forward(self, x): 69 | (_, _), (_, video_index) = self.inner_model(x, None) 70 | return video_index 71 | 72 | model = Capsule() 73 | 74 | model = torch.nn.DataParallel(model).cuda() 75 | 76 | cudnn.benchmark = True 77 | 78 | # Data loading code 79 | traindir = os.path.join(args.data, 'train') 80 | valdir = os.path.join(args.data, 'val') 81 | # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 82 | # std=[0.229, 0.224, 0.225]) 83 | 84 | train_data = datasets.ImageFolder(traindir, transforms.Compose([ 85 | transforms.CenterCrop(256), 86 | transforms.ToTensor(), 87 | ])) 88 | train_loader = torch.utils.data.DataLoader( 89 | train_data, 90 | batch_size=args.batch_size, shuffle=True, 91 | num_workers=args.workers, pin_memory=True 92 | ) 93 | 94 | val_data = datasets.ImageFolder(valdir, transforms.Compose([transforms.CenterCrop(256), transforms.ToTensor(), ])) 95 | val_loader = torch.utils.data.DataLoader( 96 | val_data, 97 | batch_size=args.batch_size, shuffle=False, 98 | num_workers=args.workers, pin_memory=True 99 | ) 100 | 101 | # define loss function (criterion) and optimizer 102 | class_count = [0] * len(train_data.classes) 103 | for i in train_data.imgs: class_count[i[1]] += 1 104 | train_crit_weight = torch.Tensor(class_count) 105 | train_crit_weight.div_(train_crit_weight.mean()).pow_(-1) 106 | train_criterion = nn.CrossEntropyLoss(train_crit_weight).cuda() 107 | 108 | class_count = [0] * len(val_data.classes) 109 | for i in val_data.imgs: class_count[i[1]] += 1 110 | val_crit_weight = torch.Tensor(class_count) 111 | val_crit_weight.div_(val_crit_weight.mean()).pow_(-1) 112 | val_criterion = nn.CrossEntropyLoss(val_crit_weight).cuda() 113 | 114 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 115 | momentum=args.momentum, 116 | weight_decay=args.weight_decay) 117 | 118 | if args.evaluate: 119 | validate(val_loader, model, val_criterion) 120 | return 121 | 122 | for epoch in range(args.start_epoch, args.epochs): 123 | adjust_learning_rate(optimizer, epoch) 124 | 125 | # train for one epoch 126 | train(train_loader, model, train_criterion, optimizer, epoch) 127 | 128 | # evaluate on validation set 129 | prec1 = validate(val_loader, model, val_criterion) 130 | 131 | # remember best prec@1 and save checkpoint 132 | is_best = prec1 > best_prec1 133 | best_prec1 = max(prec1, best_prec1) 134 | save_checkpoint({ 135 | 'epoch': epoch + 1, 136 | 'state_dict': model.state_dict(), 137 | 'best_prec1': best_prec1, 138 | }, is_best) 139 | 140 | 141 | def train(train_loader, model, criterion, optimizer, epoch): 142 | batch_time = AverageMeter() 143 | data_time = AverageMeter() 144 | losses = AverageMeter() 145 | top1 = AverageMeter() 146 | top5 = AverageMeter() 147 | 148 | # switch to train mode 149 | model.train() 150 | 151 | end = time.time() 152 | for i, (input, target) in enumerate(train_loader): 
153 | # measure data loading time 154 | data_time.update(time.time() - end) 155 | 156 | target = target.cuda(async=True) 157 | input_var = torch.autograd.Variable(input) 158 | target_var = torch.autograd.Variable(target) 159 | 160 | # compute output 161 | output = model(input_var) 162 | loss = criterion(output, target_var) 163 | 164 | # measure accuracy and record loss 165 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 166 | losses.update(loss.data[0], input.size(0)) 167 | top1.update(prec1[0], input.size(0)) 168 | top5.update(prec5[0], input.size(0)) 169 | 170 | # compute gradient and do SGD step 171 | optimizer.zero_grad() 172 | loss.backward() 173 | optimizer.step() 174 | 175 | # measure elapsed time 176 | batch_time.update(time.time() - end) 177 | end = time.time() 178 | 179 | if i % args.print_freq == 0: 180 | print('Epoch: [{0}][{1}/{2}]\t' 181 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 182 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 183 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 184 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 185 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 186 | epoch, i, len(train_loader), batch_time=batch_time, 187 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 188 | 189 | 190 | def validate(val_loader, model, criterion): 191 | batch_time = AverageMeter() 192 | losses = AverageMeter() 193 | top1 = AverageMeter() 194 | top5 = AverageMeter() 195 | 196 | # switch to evaluate mode 197 | model.eval() 198 | 199 | end = time.time() 200 | for i, (input, target) in enumerate(val_loader): 201 | target = target.cuda(async=True) 202 | input_var = torch.autograd.Variable(input, volatile=True) 203 | target_var = torch.autograd.Variable(target, volatile=True) 204 | 205 | # compute output 206 | output = model(input_var) 207 | loss = criterion(output, target_var) 208 | 209 | # measure accuracy and record loss 210 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 211 | losses.update(loss.data[0], input.size(0)) 212 | top1.update(prec1[0], input.size(0)) 213 | top5.update(prec5[0], input.size(0)) 214 | 215 | # measure elapsed time 216 | batch_time.update(time.time() - end) 217 | end = time.time() 218 | 219 | if i % args.print_freq == 0: 220 | print('Test: [{0}/{1}]\t' 221 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 222 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 223 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 224 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 225 | i, len(val_loader), batch_time=batch_time, loss=losses, 226 | top1=top1, top5=top5)) 227 | 228 | print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' 229 | .format(top1=top1, top5=top5)) 230 | 231 | return top1.avg 232 | 233 | 234 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 235 | torch.save(state, filename) 236 | if is_best: 237 | shutil.copyfile(filename, 'model_best.pth.tar') 238 | 239 | 240 | class AverageMeter(object): 241 | """Computes and stores the average and current value""" 242 | def __init__(self): 243 | self.reset() 244 | 245 | def reset(self): 246 | self.val = 0 247 | self.avg = 0 248 | self.sum = 0 249 | self.count = 0 250 | 251 | def update(self, val, n=1): 252 | self.val = val 253 | self.sum += val * n 254 | self.count += n 255 | self.avg = self.sum / self.count 256 | 257 | 258 | def adjust_learning_rate(optimizer, epoch): 259 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 260 | lr = args.lr * (0.1 ** (epoch // 30)) 261 | for param_group in optimizer.param_groups: 262 | 
param_group['lr'] = lr 263 | 264 | 265 | def accuracy(output, target, topk=(1,)): 266 | """Computes the precision@k for the specified values of k""" 267 | maxk = max(topk) 268 | batch_size = target.size(0) 269 | 270 | _, pred = output.topk(maxk, 1, True, True) 271 | pred = pred.t() 272 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 273 | 274 | res = [] 275 | for k in topk: 276 | correct_k = correct[:k].view(-1).float().sum(0) 277 | res.append(correct_k.mul_(100.0 / batch_size)) 278 | return res 279 | 280 | 281 | if __name__ == '__main__': 282 | main() 283 | -------------------------------------------------------------------------------- /image-pretraining/model: -------------------------------------------------------------------------------- 1 | ../model/ -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as path 4 | import time 5 | from datetime import timedelta 6 | from sys import exit, argv 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | from torch.autograd import Variable as V 12 | from torch.utils.data import DataLoader 13 | from torchvision import transforms as trn 14 | 15 | from data.VideoFolder import VideoFolder, BatchSampler, VideoCollate 16 | from utils.image_plot import show_four, show_ten 17 | 18 | parser = argparse.ArgumentParser(description='PyTorch MatchNet generative model training script', 19 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 20 | _ = parser.add_argument # define add_argument shortcut 21 | _('--data', type=str, default='./data/processed-data', help='location of the video data') 22 | _('--model', type=str, default='CortexNet', help='type of auto-encoder') 23 | _('--mode', type=str, required=True, help='training mode [MatchNet|TempoNet]') 24 | _('--size', type=int, default=(3, 32, 64, 128, 256), nargs='*', help='number and size of hidden layers', metavar='S') 25 | _('--spatial-size', type=int, default=(256, 256), nargs=2, help='frame cropping size', metavar=('H', 'W')) 26 | _('--lr', type=float, default=0.1, help='initial learning rate') 27 | _('--momentum', type=float, default=0.9, metavar='M', help='momentum') 28 | _('--weight-decay', type=float, default=1e-4, metavar='W', help='weight decay') 29 | _('--mu', type=float, default=1, help='matching MSE multiplier', dest='mu', metavar='μ') 30 | _('--tau', type=float, default=0.1, help='temporal CE multiplier', dest='tau', metavar='τ') 31 | _('--pi', default='τ', help='periodical CE multiplier', dest='pi', metavar='π') 32 | _('--epochs', type=int, default=10, help='upper epoch limit') 33 | _('--batch-size', type=int, default=20, metavar='B', help='batch size') 34 | _('--big-t', type=int, default=10, help='sequence length', metavar='T') 35 | _('--seed', type=int, default=0, help='random seed') 36 | _('--log-interval', type=int, default=10, metavar='N', help='report interval') 37 | _('--save', type=str, default='last/model.pth.tar', help='path to save the final model') 38 | _('--cuda', action='store_true', help='use CUDA') 39 | _('--view', type=int, default=tuple(), help='samples to view at the end of every log-interval batches', metavar='V') 40 | _('--show-x_hat', action='store_true', help='show x_hat') 41 | _('--lr-decay', type=float, default=None, nargs=2, metavar=('D', 'E'), 42 | help='decay of D (e.g. 3.16, 10) times, every E (e.g. 
3) epochs') 43 | _('--pre-trained', type=str, default='', help='path to pre-trained model', metavar='P') 44 | args = parser.parse_args() 45 | args.size = tuple(args.size) # cast to tuple 46 | if args.lr_decay: args.lr_decay = tuple(args.lr_decay) 47 | if type(args.view) is int: args.view = (args.view,) # cast to tuple 48 | args.pi = args.tau if args.pi == 'τ' else float(args.pi) 49 | 50 | # Print current options 51 | print('CLI arguments:', ' '.join(argv[1:])) 52 | 53 | # Print current commit 54 | if path.isdir('.git'): # if we are in a repo 55 | with os.popen('git rev-parse HEAD') as pipe: # get the HEAD's hash 56 | print('Current commit hash:', pipe.read(), end='') 57 | 58 | # Set the random seed manually for reproducibility. 59 | torch.manual_seed(args.seed) 60 | if torch.cuda.is_available(): 61 | if not args.cuda: 62 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 63 | else: 64 | torch.cuda.manual_seed(args.seed) 65 | 66 | 67 | def main(): 68 | # Load data 69 | print('Define image pre-processing') 70 | # normalise? do we care? 71 | t = trn.Compose((trn.ToPILImage(), trn.CenterCrop(args.spatial_size), trn.ToTensor())) 72 | 73 | print('Define train data loader') 74 | train_data_name = 'train_data.tar' 75 | if os.access(train_data_name, os.R_OK): 76 | train_data = torch.load(train_data_name) 77 | else: 78 | train_path = path.join(args.data, 'train') 79 | if args.mode == 'MatchNet': 80 | train_data = VideoFolder(root=train_path, transform=t, video_index=True) 81 | elif args.mode == 'TempoNet': 82 | train_data = VideoFolder(root=train_path, transform=t, shuffle=True) 83 | torch.save(train_data, train_data_name) 84 | 85 | train_loader = DataLoader( 86 | dataset=train_data, 87 | batch_size=args.batch_size * args.big_t, # batch_size rows and T columns 88 | shuffle=False, 89 | sampler=BatchSampler(data_source=train_data, batch_size=args.batch_size), # given that BatchSampler knows it 90 | num_workers=1, 91 | collate_fn=VideoCollate(batch_size=args.batch_size), 92 | pin_memory=True 93 | ) 94 | 95 | print('Define validation data loader') 96 | val_data_name = 'val_data.tar' 97 | if os.access(val_data_name, os.R_OK): 98 | val_data = torch.load(val_data_name) 99 | else: 100 | val_path = path.join(args.data, 'val') 101 | if args.mode == 'MatchNet': 102 | val_data = VideoFolder(root=val_path, transform=t, video_index=True) 103 | elif args.mode == 'TempoNet': 104 | val_data = VideoFolder(root=val_path, transform=t, shuffle='init') 105 | torch.save(val_data, val_data_name) 106 | 107 | val_loader = DataLoader( 108 | dataset=val_data, 109 | batch_size=args.batch_size, # just one column of size batch_size 110 | shuffle=False, 111 | sampler=BatchSampler(data_source=val_data, batch_size=args.batch_size), 112 | num_workers=1, 113 | collate_fn=VideoCollate(batch_size=args.batch_size), 114 | pin_memory=True 115 | ) 116 | 117 | # Build the model 118 | if args.model == 'model_01': 119 | from model.Model01 import Model01 as Model 120 | elif args.model == 'model_02' or args.model == 'CortexNet': 121 | from model.Model02 import Model02 as Model 122 | elif args.model == 'model_02_rg': 123 | from model.Model02 import Model02RG as Model 124 | else: 125 | print('\n{:#^80}\n'.format(' Please select a valid model ')) 126 | exit() 127 | 128 | print('Define model') 129 | if args.mode == 'MatchNet': 130 | nb_train_videos = len(train_data.videos) 131 | model = Model(args.size + (nb_train_videos,), args.spatial_size) 132 | elif args.mode == 'TempoNet': 133 | nb_classes = 
len(train_data.classes) 134 | model = Model(args.size + (nb_classes,), args.spatial_size) 135 | 136 | if args.pre_trained: 137 | print('Load pre-trained weights') 138 | # args.pre_trained = 'image-pretraining/model02D-33IS/model_best.pth.tar' 139 | dict_33 = torch.load(args.pre_trained)['state_dict'] 140 | 141 | def load_state_dict(new_model, state_dict): 142 | own_state = new_model.state_dict() 143 | for name, param in state_dict.items(): 144 | name = name[19:] # remove 'module.inner_model.' part 145 | if name not in own_state: 146 | raise KeyError('unexpected key "{}" in state_dict' 147 | .format(name)) 148 | if name.startswith('stabiliser'): 149 | print('Skipping', name) 150 | continue 151 | if isinstance(param, nn.Parameter): 152 | # backwards compatibility for serialized parameters 153 | param = param.data 154 | own_state[name].copy_(param) 155 | 156 | missing = set(own_state.keys()) - set([k[19:] for k in state_dict.keys()]) 157 | if len(missing) > 0: 158 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 159 | 160 | load_state_dict(model, dict_33) 161 | 162 | print('Create a MSE and balanced NLL criterions') 163 | mse = nn.MSELoss() 164 | 165 | # independent CE computation 166 | nll_final = nn.CrossEntropyLoss(size_average=False) 167 | # balance classes based on frames per video; default balancing weight is 1.0f 168 | w = torch.Tensor(train_data.frames_per_video if args.mode == 'MatchNet' else train_data.frames_per_class) 169 | w.div_(w.mean()).pow_(-1) 170 | nll_train = nn.CrossEntropyLoss(w) 171 | w = torch.Tensor(val_data.frames_per_video if args.mode == 'MatchNet' else val_data.frames_per_class) 172 | w.div_(w.mean()).pow_(-1) 173 | nll_val = nn.CrossEntropyLoss(w) 174 | 175 | if args.cuda: 176 | model.cuda() 177 | mse.cuda() 178 | nll_final.cuda() 179 | nll_train.cuda() 180 | nll_val.cuda() 181 | 182 | print('Instantiate a SGD optimiser') 183 | optimiser = optim.SGD( 184 | params=model.parameters(), 185 | lr=args.lr, 186 | momentum=args.momentum, 187 | weight_decay=args.weight_decay 188 | ) 189 | 190 | # Loop over epochs 191 | for epoch in range(0, args.epochs): 192 | if args.lr_decay: adjust_learning_rate(optimiser, epoch) 193 | epoch_start_time = time.time() 194 | train(train_loader, model, (mse, nll_final, nll_train), optimiser, epoch) 195 | print(80 * '-', '| end of epoch {:3d} |'.format(epoch + 1), sep='\n', end=' ') 196 | val_loss = validate(val_loader, model, (mse, nll_final, nll_val)) 197 | elapsed_time = str(timedelta(seconds=int(time.time() - epoch_start_time))) # HH:MM:SS time format 198 | print('time: {} | mMSE {:.2e} | CE {:.2e} | rpl mMSE {:.2e} | per CE {:.2e} |'. 
199 | format(elapsed_time, val_loss['mse'] * 1e3, val_loss['ce'], val_loss['rpl'] * 1e3, val_loss['per_ce'])) 200 | print(80 * '-') 201 | 202 | if args.save != '': 203 | torch.save(model, args.save) 204 | 205 | 206 | def adjust_learning_rate(opt, epoch): 207 | """Sets the learning rate to the initial LR decayed by D every E epochs""" 208 | d, e = args.lr_decay 209 | lr = args.lr * (d ** -(epoch // e)) 210 | for param_group in opt.param_groups: 211 | param_group['lr'] = lr 212 | 213 | 214 | def selective_zero(s, new, forward=True): 215 | if new.any(): # if at least one video changed 216 | b = new.nonzero().squeeze(1) # get the list of indices 217 | if forward: # no state forward, no grad backward 218 | if isinstance(s[0], list): # recurrent G 219 | for layer in range(len(s[0])): # for every layer having a state 220 | s[0][layer] = s[0][layer].index_fill(0, V(b), 0) # mask state, zero selected indices 221 | for layer in range(len(s[1])): # for every layer having a state 222 | s[1][layer] = s[1][layer].index_fill(0, V(b), 0) # mask state, zero selected indices 223 | else: # simple convolutional G 224 | for layer in range(len(s)): # for every layer having a state 225 | s[layer] = s[layer].index_fill(0, V(b), 0) # mask state, zero selected indices 226 | else: # just no grad backward 227 | if isinstance(s[0], list): # recurrent G 228 | for layer in range(len(s[0])): # for every layer having a state 229 | s[0][layer].register_hook(lambda g: g.index_fill(0, V(b), 0)) # zero selected gradients 230 | for layer in range(len(s[1])): # for every layer having a state 231 | s[1][layer].register_hook(lambda g: g.index_fill(0, V(b), 0)) # zero selected gradients 232 | else: # simple convolutional G 233 | for layer in range(len(s)): # for every layer having a state 234 | s[layer].register_hook(lambda g: g.index_fill(0, V(b), 0)) # zero selected gradients 235 | 236 | 237 | def selective_match(x_hat, x, new): 238 | if new.any(): # if at least one video changed 239 | b = new.nonzero().squeeze(1) # get the list of indices 240 | for bb in b: x_hat[bb].copy_(x[bb]) # force the output to be the expected output 241 | 242 | 243 | def selective_cross_entropy(logits, y, new, loss, count): 244 | if not new.any(): 245 | return V(logits.data.new(1).zero_()) # returns a variable, so we don't care about what happened here 246 | b = new.nonzero().squeeze(1) # get the list of indices 247 | count['ce_count'] += len(b) 248 | return loss(logits.index_select(0, V(b)), y.index_select(0, V(b))) # performs loss for selected indices only 249 | 250 | 251 | def train(train_loader, model, loss_fun, optimiser, epoch): 252 | print('Training epoch', epoch + 1) 253 | model.train() # set model in train mode 254 | total_loss = {'mse': 0, 'ce': 0, 'ce_count': 0, 'per_ce': 0, 'rpl': 0} 255 | mse, nll_final, nll_periodic = loss_fun 256 | 257 | def compute_loss(x_, next_x, y_, state_, periodic=False): 258 | nonlocal previous_mismatch # write access to variables of the enclosing function 259 | if args.mode == 'MatchNet': 260 | if not periodic and state_: selective_zero(state_, mismatch, forward=False) # no grad to the past 261 | (x_hat, state_), (_, idx) = model(V(x_), state_) 262 | selective_zero(state_, mismatch) # no state to the future, no grad from the future 263 | selective_match(x_hat.data, next_x, mismatch + previous_mismatch) # last frame or first frame 264 | previous_mismatch = mismatch # last frame <- first frame 265 | mse_loss_ = mse(x_hat, V(next_x)) 266 | total_loss['mse'] += mse_loss_.data[0] 267 | ce_loss_ = 
selective_cross_entropy(idx, V(y_), mismatch, nll_final, total_loss) 268 | total_loss['ce'] += ce_loss_.data[0] 269 | if periodic: 270 | ce_loss_ = (ce_loss_, nll_periodic(idx, V(y_))) 271 | total_loss['per_ce'] += ce_loss_[1].data[0] 272 | total_loss['rpl'] += mse(x_hat, V(x_, volatile=True)).data[0] 273 | return ce_loss_, mse_loss_, state_, x_hat.data 274 | 275 | data_time = 0 276 | batch_time = 0 277 | end_time = time.time() 278 | state = None # reset state at the beginning of a new epoch 279 | from_past = None # performs only T - 1 steps for the first temporal batch 280 | previous_mismatch = torch.ByteTensor(args.batch_size).fill_(1) # ignore first prediction 281 | if args.cuda: previous_mismatch = previous_mismatch.cuda() 282 | for batch_nb, (x, y) in enumerate(train_loader): 283 | data_time += time.time() - end_time 284 | if args.cuda: 285 | x = x.cuda(async=True) 286 | y = y.cuda(async=True) 287 | state = repackage_state(state) 288 | loss = 0 289 | # BTT loop 290 | if args.mode == 'MatchNet': 291 | if from_past: 292 | mismatch = y[0] != from_past[1] 293 | ce_loss, mse_loss, state, _ = compute_loss(from_past[0], x[0], from_past[1], state, periodic=True) 294 | loss += mse_loss * args.mu + ce_loss[0] * args.tau + ce_loss[1] * args.pi 295 | for t in range(0, min(args.big_t, x.size(0)) - 1): # first batch we go only T - 1 steps forward / backward 296 | mismatch = y[t + 1] != y[t] 297 | ce_loss, mse_loss, state, x_hat_data = compute_loss(x[t], x[t + 1], y[t], state) 298 | loss += mse_loss * args.mu + ce_loss * args.tau 299 | elif args.mode == 'TempoNet': 300 | if from_past: 301 | mismatch = y[0] != from_past[1] 302 | ce_loss, mse_loss, state, _ = compute_loss(from_past[0], x[0], from_past[1], state) 303 | loss += mse_loss * args.mu + ce_loss * args.tau 304 | for t in range(0, min(args.big_t, x.size(0)) - 1): # first batch we go only T - 1 steps forward / backward 305 | mismatch = y[t + 1] != y[t] 306 | last = t == min(args.big_t, x.size(0)) - 2 307 | ce_loss, mse_loss, state, x_hat_data = compute_loss(x[t], x[t + 1], y[t], state, periodic=last) 308 | if not last: 309 | loss += mse_loss * args.mu + ce_loss * args.tau 310 | else: 311 | loss += mse_loss * args.mu + ce_loss[0] * args.tau + ce_loss[1] * args.pi 312 | 313 | # compute gradient and do SGD step 314 | model.zero_grad() 315 | loss.backward() 316 | optimiser.step() 317 | 318 | # save last column for future 319 | from_past = x[-1], y[-1] 320 | 321 | # measure batch time 322 | batch_time += time.time() - end_time 323 | end_time = time.time() # for computing data_time 324 | 325 | if (batch_nb + 1) % args.log_interval == 0: 326 | if args.view: 327 | for f in args.view: 328 | show_four(x[t][f], x[t + 1][f], x_hat_data[f], f + 1) 329 | if args.show_x_hat: show_ten(x[t][f], x_hat_data[f]) 330 | total_loss['mse'] /= args.log_interval * args.big_t 331 | total_loss['rpl'] /= args.log_interval * args.big_t 332 | total_loss['per_ce'] /= args.log_interval 333 | if total_loss['ce_count']: total_loss['ce'] /= total_loss['ce_count'] 334 | avg_batch_time = batch_time * 1e3 / args.log_interval 335 | avg_data_time = data_time * 1e3 / args.log_interval 336 | lr = optimiser.param_groups[0]['lr'] # assumes len(param_groups) == 1 337 | print('| epoch {:3d} | {:4d}/{:4d} batches | lr {:.3f} |' 338 | ' ms/batch {:7.2f} | ms/data {:7.2f} | mMSE {:.2e} | CE {:.2e} | rpl mMSE {:.2e} | per CE {:.2e} |'. 
339 | format(epoch + 1, batch_nb + 1, len(train_loader), lr, avg_batch_time, avg_data_time, 340 | total_loss['mse'] * 1e3, total_loss['ce'], total_loss['rpl'] * 1e3, total_loss['per_ce'])) 341 | for k in total_loss: total_loss[k] = 0 # zero the losses 342 | batch_time = 0 343 | data_time = 0 344 | 345 | 346 | def validate(val_loader, model, loss_fun): 347 | model.eval() # set model in evaluation mode 348 | total_loss = {'mse': 0, 'ce': 0, 'ce_count': 0, 'per_ce': 0, 'rpl': 0} 349 | mse, nll_final, nll_periodic = loss_fun 350 | batches = enumerate(val_loader) 351 | 352 | _, (x, y) = next(batches) 353 | if args.cuda: 354 | x = x.cuda(async=True) 355 | y = y.cuda(async=True) 356 | previous_mismatch = y[0].byte().fill_(1) # ignore first prediction 357 | state = None # reset state at the beginning of a new epoch 358 | for batch_nb, (next_x, next_y) in batches: 359 | if args.cuda: 360 | next_x = next_x.cuda(async=True) 361 | next_y = next_y.cuda(async=True) 362 | mismatch = next_y[0] != y[0] 363 | (x_hat, state), (_, idx) = model(V(x[0], volatile=True), state) # do not compute graph (volatile) 364 | selective_zero(state, mismatch) # no state to the future 365 | selective_match(x_hat.data, next_x[0], mismatch + previous_mismatch) # last frame or first frame 366 | previous_mismatch = mismatch # last frame <- first frame 367 | total_loss['mse'] += mse(x_hat, V(next_x[0])).data[0] 368 | ce_loss = selective_cross_entropy(idx, V(y[0]), mismatch, nll_final, total_loss) 369 | total_loss['ce'] += ce_loss.data[0] 370 | if batch_nb % args.big_t == 0: total_loss['per_ce'] += nll_periodic(idx, V(y[0])).data[0] 371 | total_loss['rpl'] += mse(x_hat, V(x[0])).data[0] 372 | x, y = next_x, next_y 373 | 374 | total_loss['mse'] /= len(val_loader) # average out 375 | total_loss['rpl'] /= len(val_loader) # average out 376 | total_loss['per_ce'] /= len(val_loader) / args.big_t # average out 377 | total_loss['ce'] /= total_loss['ce_count'] # average out 378 | return total_loss 379 | 380 | 381 | def repackage_state(h): 382 | """ 383 | Wraps hidden states in new Variables, to detach them from their history. 
384 | """ 385 | if not h: 386 | return None 387 | elif type(h) == V: 388 | return V(h.data) 389 | else: 390 | return list(repackage_state(v) for v in h) 391 | 392 | 393 | if __name__ == '__main__': 394 | main() 395 | 396 | __author__ = "Alfredo Canziani" 397 | __credits__ = ["Alfredo Canziani"] 398 | __maintainer__ = "Alfredo Canziani" 399 | __email__ = "alfredo.canziani@gmail.com" 400 | __status__ = "Production" # "Prototype", "Development", or "Production" 401 | __date__ = "Feb 17" 402 | -------------------------------------------------------------------------------- /model/ConvLSTMCell.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as f 4 | from torch.autograd import Variable 5 | 6 | 7 | # Define some constants 8 | KERNEL_SIZE = 3 9 | PADDING = KERNEL_SIZE // 2 10 | 11 | 12 | class ConvLSTMCell(nn.Module): 13 | """ 14 | Generate a convolutional LSTM cell 15 | """ 16 | 17 | def __init__(self, input_size, hidden_size): 18 | super().__init__() 19 | self.input_size = input_size 20 | self.hidden_size = hidden_size 21 | self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size, KERNEL_SIZE, padding=PADDING) 22 | 23 | def forward(self, input_, prev_state): 24 | 25 | # get batch and spatial sizes 26 | batch_size = input_.data.size()[0] 27 | spatial_size = input_.data.size()[2:] 28 | 29 | # generate empty prev_state, if None is provided 30 | if prev_state is None: 31 | state_size = [batch_size, self.hidden_size] + list(spatial_size) 32 | prev_state = ( 33 | Variable(torch.zeros(state_size)), 34 | Variable(torch.zeros(state_size)) 35 | ) 36 | 37 | prev_hidden, prev_cell = prev_state 38 | 39 | # data size is [batch, channel, height, width] 40 | stacked_inputs = torch.cat((input_, prev_hidden), 1) 41 | gates = self.Gates(stacked_inputs) 42 | 43 | # chunk across channel dimension 44 | in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1) 45 | 46 | # apply sigmoid non linearity 47 | in_gate = f.sigmoid(in_gate) 48 | remember_gate = f.sigmoid(remember_gate) 49 | out_gate = f.sigmoid(out_gate) 50 | 51 | # apply tanh non linearity 52 | cell_gate = f.tanh(cell_gate) 53 | 54 | # compute current cell and hidden state 55 | cell = (remember_gate * prev_cell) + (in_gate * cell_gate) 56 | hidden = out_gate * f.tanh(cell) 57 | 58 | return hidden, cell 59 | 60 | 61 | def _main(): 62 | """ 63 | Run some basic tests on the API 64 | """ 65 | 66 | # define batch_size, channels, height, width 67 | b, c, h, w = 1, 3, 4, 8 68 | d = 5 # hidden state size 69 | lr = 1e-1 # learning rate 70 | T = 6 # sequence length 71 | max_epoch = 20 # number of epochs 72 | 73 | # set manual seed 74 | torch.manual_seed(0) 75 | 76 | print('Instantiate model') 77 | model = ConvLSTMCell(c, d) 78 | print(repr(model)) 79 | 80 | print('Create input and target Variables') 81 | x = Variable(torch.rand(T, b, c, h, w)) 82 | y = Variable(torch.randn(T, b, d, h, w)) 83 | 84 | print('Create a MSE criterion') 85 | loss_fn = nn.MSELoss() 86 | 87 | print('Run for', max_epoch, 'iterations') 88 | for epoch in range(0, max_epoch): 89 | state = None 90 | loss = 0 91 | for t in range(0, T): 92 | state = model(x[t], state) 93 | loss += loss_fn(state[0], y[t]) 94 | 95 | print(' > Epoch {:2d} loss: {:.3f}'.format((epoch+1), loss.data[0])) 96 | 97 | # zero grad parameters 98 | model.zero_grad() 99 | 100 | # compute new grad parameters through time! 
101 | loss.backward() 102 | 103 | # learning_rate step against the gradient 104 | for p in model.parameters(): 105 | p.data.sub_(p.grad.data * lr) 106 | 107 | print('Input size:', list(x.data.size())) 108 | print('Target size:', list(y.data.size())) 109 | print('Last hidden state size:', list(state[0].size())) 110 | 111 | 112 | if __name__ == '__main__': 113 | _main() 114 | 115 | 116 | __author__ = "Alfredo Canziani" 117 | __credits__ = ["Alfredo Canziani"] 118 | __maintainer__ = "Alfredo Canziani" 119 | __email__ = "alfredo.canziani@gmail.com" 120 | __status__ = "Prototype" # "Prototype", "Development", or "Production" 121 | __date__ = "Jan 17" 122 | -------------------------------------------------------------------------------- /model/DiscriminativeCell.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as f 4 | from torch.autograd import Variable 5 | 6 | 7 | # Define some constants 8 | KERNEL_SIZE = 3 9 | PADDING = KERNEL_SIZE // 2 10 | POOL = 2 11 | 12 | 13 | class DiscriminativeCell(nn.Module): 14 | """ 15 | Single discriminative layer 16 | """ 17 | 18 | def __init__(self, input_size, hidden_size, first=False): 19 | """ 20 | Create a discriminative cell (bottom_up, r_state) -> error 21 | 22 | :param input_size: {'input': bottom_up_size, 'state': r_state_size} 23 | :param hidden_size: int, shooting dimensionality 24 | :param first: True/False 25 | """ 26 | super().__init__() 27 | self.input_size = input_size 28 | self.hidden_size = hidden_size 29 | self.first = first 30 | if not first: 31 | self.from_bottom = nn.Conv2d(input_size['input'], hidden_size, KERNEL_SIZE, padding=PADDING) 32 | self.from_state = nn.Conv2d(input_size['state'], hidden_size, KERNEL_SIZE, padding=PADDING) 33 | 34 | def forward(self, bottom_up, state): 35 | input_projection = self.first and bottom_up or f.relu(f.max_pool2d(self.from_bottom(bottom_up), POOL, POOL)) 36 | state_projection = f.relu(self.from_state(state)) 37 | error = f.relu(torch.cat((input_projection - state_projection, state_projection - input_projection), 1)) 38 | return error 39 | 40 | 41 | def _test_layer1(): 42 | print('Define model for layer 1') 43 | discriminator = DiscriminativeCell(input_size={'input': 3, 'state': 3}, hidden_size=3, first=True) 44 | 45 | print('Define input and state') 46 | # at the first layer we have that system_state match the input_image dimensionality 47 | input_image = Variable(torch.rand(1, 3, 8, 12)) 48 | system_state = Variable(torch.randn(1, 3, 8, 12)) 49 | 50 | print('Input has size', list(input_image.data.size())) 51 | 52 | print('Forward input and state to the model') 53 | e = discriminator(input_image, system_state) 54 | 55 | # print output size 56 | print('Layer 1 error has size', list(e.data.size())) 57 | 58 | return e 59 | 60 | 61 | def _test_layer2(input_error): 62 | print('Define model for layer 2') 63 | discriminator = DiscriminativeCell(input_size={'input': 6, 'state': 32}, hidden_size=32, first=False) 64 | 65 | print('Define a new, smaller state') 66 | system_state = Variable(torch.randn(1, 32, 4, 6)) 67 | 68 | print('Forward layer 1 output and state to the model') 69 | e = discriminator(input_error, system_state) 70 | 71 | # print output size 72 | print('Layer 2 error has size', list(e.data.size())) 73 | 74 | 75 | def _test_layers(): 76 | error = _test_layer1() 77 | _test_layer2(input_error=error) 78 | 79 | 80 | if __name__ == '__main__': 81 | _test_layers() 82 | 83 | 84 | __author__ = "Alfredo 
Canziani" 85 | __credits__ = ["Alfredo Canziani"] 86 | __maintainer__ = "Alfredo Canziani" 87 | __email__ = "alfredo.canziani@gmail.com" 88 | __status__ = "Prototype" # "Prototype", "Development", or "Production" 89 | __date__ = "Feb 17" 90 | -------------------------------------------------------------------------------- /model/GenerativeCell.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as f 4 | from torch.autograd import Variable 5 | 6 | from model.ConvLSTMCell import ConvLSTMCell 7 | 8 | 9 | class GenerativeCell(nn.Module): 10 | """ 11 | Single generative layer 12 | """ 13 | 14 | def __init__(self, input_size, hidden_size, error_init_size=None): 15 | """ 16 | Create a generative cell (error, top_down_state, r_state) -> r_state 17 | 18 | :param input_size: {'error': error_size, 'up_state': r_state_size}, r_state_size can be 0 19 | :param hidden_size: int, shooting dimensionality 20 | :param error_init_size: tuple, full size of initial (null) error 21 | """ 22 | super().__init__() 23 | self.input_size = input_size 24 | self.hidden_size = hidden_size 25 | self.error_init_size = error_init_size 26 | self.memory = ConvLSTMCell(input_size['error']+input_size['up_state'], hidden_size) 27 | 28 | def forward(self, error, top_down_state, state): 29 | if error is None: # we just started 30 | error = Variable(torch.zeros(self.error_init_size)) 31 | model_input = error 32 | if top_down_state is not None: 33 | model_input = torch.cat((error, f.upsample_nearest(top_down_state, scale_factor=2)), 1) 34 | return self.memory(model_input, state) 35 | 36 | 37 | def _test_layer2(): 38 | print('Define model for layer 2') 39 | generator = GenerativeCell(input_size={'error': 2*16, 'up_state': 0}, hidden_size=16) 40 | 41 | print('Define error and top down state') 42 | input_error = Variable(torch.randn(1, 2*16, 4, 6)) 43 | topdown_state = None 44 | 45 | print('Input error has size', list(input_error.data.size())) 46 | print('Top down state is None') 47 | 48 | print('Forward error and top down state to the model') 49 | state = None 50 | state = generator(input_error, topdown_state, state) 51 | 52 | # print output size 53 | print('Layer 2 state has size', list(state[0].data.size())) 54 | 55 | return state[0] # the element 1 is the cell state 56 | 57 | 58 | def _test_layer1(top_down_state): 59 | print('Define model for layer 1') 60 | generator = GenerativeCell(input_size={'error': 2*3, 'up_state': 16}, hidden_size=3) 61 | 62 | print('Define error and top down state') 63 | input_error = Variable(torch.randn(1, 2*3, 8, 12)) 64 | 65 | print('Input error has size', list(input_error.data.size())) 66 | print('Top down state has size', list(top_down_state.data.size())) 67 | 68 | print('Forward error and top down state to the model') 69 | state = None 70 | state = generator(input_error, top_down_state, state) 71 | 72 | # print output size 73 | print('Layer 1 state has size', list(state[0].data.size())) 74 | 75 | 76 | def _test_layers(): 77 | state = _test_layer2() 78 | _test_layer1(top_down_state=state) 79 | 80 | 81 | if __name__ == '__main__': 82 | _test_layers() 83 | 84 | 85 | __author__ = "Alfredo Canziani" 86 | __credits__ = ["Alfredo Canziani"] 87 | __maintainer__ = "Alfredo Canziani" 88 | __email__ = "alfredo.canziani@gmail.com" 89 | __status__ = "Prototype" # "Prototype", "Development", or "Production" 90 | __date__ = "Feb 17" 91 | 
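A compact summary of the two cells above, written in PredNet-style notation: this merely restates what `DiscriminativeCell.forward` and `GenerativeCell.forward` already compute, and adds no new behaviour.

```latex
% Bottom-up input and state projection (from_bottom / from_state):
A_1 = x, \qquad A_l = \mathrm{ReLU}\big(\mathrm{MaxPool}(\mathrm{Conv}(E_{l-1}))\big) \quad (l > 1)
\hat{A}_l = \mathrm{ReLU}\big(\mathrm{Conv}(R_l)\big)
% DiscriminativeCell: split-rectified error
E_l = \mathrm{ReLU}\big(\left[\, A_l - \hat{A}_l \,;\; \hat{A}_l - A_l \,\right]\big)
% GenerativeCell: representation update via a ConvLSTM
R_l^{t} = \mathrm{ConvLSTM}\big(\left[\, E_l^{t-1} \,;\; \mathrm{Upsample}(R_{l+1}^{t}) \,\right],\, R_l^{t-1}\big)
```

`PrednetModel.py`, further down in this listing, chains these per-layer updates into the full top-down (generative) and bottom-up (discriminative) passes.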
-------------------------------------------------------------------------------- /model/Model01.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as f 4 | from torch.autograd import Variable as V 5 | from math import ceil 6 | 7 | 8 | # Define some constants 9 | KERNEL_SIZE = 3 10 | PADDING = KERNEL_SIZE // 2 11 | KERNEL_STRIDE = 2 12 | OUTPUT_ADJUST = KERNEL_SIZE - 2 * PADDING 13 | 14 | 15 | class Model01(nn.Module): 16 | """ 17 | Generate a constructor for model_01 type of network 18 | """ 19 | 20 | def __init__(self, network_size: tuple, input_spatial_size: tuple) -> None: 21 | """ 22 | Initialise Model01 constructor 23 | 24 | :param network_size: (n, h1, h2, ..., emb_size, nb_videos) 25 | :type network_size: tuple 26 | :param input_spatial_size: (height, width) 27 | :type input_spatial_size: tuple 28 | """ 29 | super().__init__() 30 | self.hidden_layers = len(network_size) - 2 31 | 32 | print('\n{:-^80}'.format(' Building model ')) 33 | print('Hidden layers:', self.hidden_layers) 34 | print('Net sizing:', network_size) 35 | print('Input spatial size: {} x {}'.format(network_size[0], input_spatial_size)) 36 | 37 | # main auto-encoder blocks 38 | self.activation_size = [input_spatial_size] 39 | for layer in range(0, self.hidden_layers): 40 | # print some annotation when building model 41 | print('{:-<80}'.format('Layer ' + str(layer + 1) + ' ')) 42 | print('Bottom size: {} x {}'.format(network_size[layer], self.activation_size[-1])) 43 | self.activation_size.append(tuple(ceil(s / 2) for s in self.activation_size[layer])) 44 | print('Top size: {} x {}'.format(network_size[layer + 1], self.activation_size[-1])) 45 | 46 | # init D (discriminative) blocks 47 | setattr(self, 'D_' + str(layer + 1), nn.Conv2d( 48 | in_channels=network_size[layer], out_channels=network_size[layer + 1], 49 | kernel_size=KERNEL_SIZE, stride=KERNEL_STRIDE, padding=PADDING 50 | )) 51 | setattr(self, 'BN_D_' + str(layer + 1), nn.BatchNorm2d(network_size[layer + 1])) 52 | 53 | # init G (generative) blocks 54 | setattr(self, 'G_' + str(layer + 1), nn.ConvTranspose2d( 55 | in_channels=network_size[layer + 1], out_channels=network_size[layer], 56 | kernel_size=KERNEL_SIZE, stride=KERNEL_STRIDE, padding=PADDING 57 | )) 58 | setattr(self, 'BN_G_' + str(layer + 1), nn.BatchNorm2d(network_size[layer])) 59 | 60 | # init auxiliary classifier 61 | print('{:-<80}'.format('Classifier ')) 62 | print(network_size[-2], '-->', network_size[-1]) 63 | self.average = nn.AvgPool2d(self.activation_size[-1]) 64 | self.stabiliser = nn.Linear(network_size[-2], network_size[-1]) 65 | print(80 * '-', end='\n\n') 66 | 67 | def forward(self, x, state): 68 | activation_sizes = [x.size()] # start from the input 69 | residuals = list() 70 | for layer in range(0, self.hidden_layers): # connect discriminative blocks 71 | x = getattr(self, 'D_' + str(layer + 1))(x) 72 | residuals.append(x) 73 | if layer < self.hidden_layers - 1 and state: x += state[layer] 74 | x = f.relu(x) 75 | x = getattr(self, 'BN_D_' + str(layer + 1))(x) 76 | activation_sizes.append(x.size()) # cache output size for later retrieval 77 | state = state or [None] * (self.hidden_layers - 1) 78 | for layer in reversed(range(0, self.hidden_layers)): # connect generative blocks 79 | x = getattr(self, 'G_' + str(layer + 1))(x, activation_sizes[layer]) 80 | if layer: 81 | state[layer - 1] = x 82 | x += residuals[layer - 1] 83 | x = f.relu(x) 84 | x = getattr(self, 'BN_G_' + str(layer 
+ 1))(x) 85 | x_mean = self.average(residuals[-1]) 86 | video_index = self.stabiliser(x_mean.view(x_mean.size(0), -1)) 87 | 88 | return (x, state), (x_mean, video_index) 89 | 90 | 91 | def _test_model(): 92 | T = 2 93 | x = torch.rand(T + 1, 1, 3, 4 * 2**3 + 3, 6 * 2**3 + 5) 94 | K = 10 95 | y = torch.LongTensor(T, 1).random_(K) 96 | model_01 = Model01(network_size=(3, 6, 12, 18, K), input_spatial_size=x[0].size()[2:]) 97 | 98 | state = None 99 | (x_hat, state), (emb, idx) = model_01(V(x[0]), state) 100 | 101 | print('Input size:', tuple(x.size())) 102 | print('Output size:', tuple(x_hat.data.size())) 103 | print('Video index size:', tuple(idx.size())) 104 | for i, s in enumerate(state): 105 | print('State', i + 1, 'has size:', tuple(s.size())) 106 | print('Embedding has size:', emb.data.numel()) 107 | 108 | mse = nn.MSELoss() 109 | nll = nn.CrossEntropyLoss() 110 | x_next = V(x[1]) 111 | y_var = V(y[0]) 112 | loss_t1 = mse(x_hat, x_next) + nll(idx, y_var) 113 | 114 | from utils.visualise import show_graph 115 | show_graph(loss_t1) 116 | 117 | # run one more time 118 | (x_hat, _), (_, idx) = model_01(V(x[1]), state) 119 | 120 | x_next = V(x[2]) 121 | y_var = V(y[1]) 122 | loss_t2 = mse(x_hat, x_next) + nll(idx, y_var) 123 | loss_tot = loss_t2 + loss_t1 124 | 125 | show_graph(loss_tot) 126 | 127 | 128 | def _test_training(): 129 | K = 10 # number of training videos 130 | network_size = (3, 6, 12, 18, K) 131 | T = 6 # sequence length 132 | max_epoch = 10 # number of epochs 133 | lr = 1e-1 # learning rate 134 | 135 | # set manual seed 136 | torch.manual_seed(0) 137 | 138 | print('\n{:-^80}'.format(' Train a ' + str(network_size[:-1]) + ' layer network ')) 139 | print('Sequence length T:', T) 140 | print('Create the input image and target sequences') 141 | x = torch.rand(T + 1, 1, 3, 4 * 2**3 + 3, 6 * 2**3 + 5) 142 | y = torch.LongTensor(T, 1).random_(K) 143 | print('Input has size', tuple(x.size())) 144 | print('Target index has size', tuple(y.size())) 145 | 146 | print('Define model') 147 | model = Model01(network_size=network_size, input_spatial_size=x[0].size()[2:]) 148 | 149 | print('Create a MSE and NLL criterions') 150 | mse = nn.MSELoss() 151 | nll = nn.CrossEntropyLoss() 152 | 153 | print('Run for', max_epoch, 'iterations') 154 | for epoch in range(0, max_epoch): 155 | state = None 156 | loss = 0 157 | for t in range(0, T): 158 | (x_hat, state), (emb, idx) = model(V(x[t]), state) 159 | loss += mse(x_hat, V(x[t + 1])) + nll(idx, V(y[t])) 160 | 161 | print(' > Epoch {:2d} loss: {:.3f}'.format((epoch + 1), loss.data[0])) 162 | 163 | # zero grad parameters 164 | model.zero_grad() 165 | 166 | # compute new grad parameters through time! 
167 | loss.backward() 168 | 169 | # learning_rate step against the gradient 170 | for p in model.parameters(): 171 | p.data.sub_(p.grad.data * lr) 172 | 173 | 174 | if __name__ == '__main__': 175 | _test_model() 176 | _test_training() 177 | 178 | 179 | __author__ = "Alfredo Canziani" 180 | __credits__ = ["Alfredo Canziani"] 181 | __maintainer__ = "Alfredo Canziani" 182 | __email__ = "alfredo.canziani@gmail.com" 183 | __status__ = "Production" # "Prototype", "Development", or "Production" 184 | __date__ = "Feb 17" 185 | -------------------------------------------------------------------------------- /model/Model02.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as f 4 | from torch.autograd import Variable as V 5 | from math import ceil 6 | 7 | 8 | # Define some constants 9 | from model.RG import RG 10 | 11 | KERNEL_SIZE = 3 12 | PADDING = KERNEL_SIZE // 2 13 | KERNEL_STRIDE = 2 14 | OUTPUT_ADJUST = KERNEL_SIZE - 2 * PADDING 15 | 16 | 17 | class Model02(nn.Module): 18 | """ 19 | Generate a constructor for model_02 type of network 20 | """ 21 | 22 | def __init__(self, network_size: tuple, input_spatial_size: tuple) -> None: 23 | """ 24 | Initialise Model02 constructor 25 | 26 | :param network_size: (n, h1, h2, ..., emb_size, nb_videos) 27 | :type network_size: tuple 28 | :param input_spatial_size: (height, width) 29 | :type input_spatial_size: tuple 30 | """ 31 | super().__init__() 32 | self.hidden_layers = len(network_size) - 2 33 | 34 | print('\n{:-^80}'.format(' Building model Model02 ')) 35 | print('Hidden layers:', self.hidden_layers) 36 | print('Net sizing:', network_size) 37 | print('Input spatial size: {} x {}'.format(network_size[0], input_spatial_size)) 38 | 39 | # main auto-encoder blocks 40 | self.activation_size = [input_spatial_size] 41 | for layer in range(0, self.hidden_layers): 42 | # print some annotation when building model 43 | print('{:-<80}'.format('Layer ' + str(layer + 1) + ' ')) 44 | print('Bottom size: {} x {}'.format(network_size[layer], self.activation_size[-1])) 45 | self.activation_size.append(tuple(ceil(s / 2) for s in self.activation_size[layer])) 46 | print('Top size: {} x {}'.format(network_size[layer + 1], self.activation_size[-1])) 47 | 48 | # init D (discriminative) blocks 49 | multiplier = layer and 2 or 1 # D_n, n > 1, has intra-layer feedback 50 | setattr(self, 'D_' + str(layer + 1), nn.Conv2d( 51 | in_channels=network_size[layer] * multiplier, out_channels=network_size[layer + 1], 52 | kernel_size=KERNEL_SIZE, stride=KERNEL_STRIDE, padding=PADDING 53 | )) 54 | setattr(self, 'BN_D_' + str(layer + 1), nn.BatchNorm2d(network_size[layer + 1])) 55 | 56 | # init G (generative) blocks 57 | setattr(self, 'G_' + str(layer + 1), nn.ConvTranspose2d( 58 | in_channels=network_size[layer + 1], out_channels=network_size[layer], 59 | kernel_size=KERNEL_SIZE, stride=KERNEL_STRIDE, padding=PADDING 60 | )) 61 | setattr(self, 'BN_G_' + str(layer + 1), nn.BatchNorm2d(network_size[layer])) 62 | 63 | # init auxiliary classifier 64 | print('{:-<80}'.format('Classifier ')) 65 | print(network_size[-2], '-->', network_size[-1]) 66 | self.average = nn.AvgPool2d(self.activation_size[-1]) 67 | self.stabiliser = nn.Linear(network_size[-2], network_size[-1]) 68 | print(80 * '-', end='\n\n') 69 | 70 | def forward(self, x, state): 71 | activation_sizes = [x.size()] # start from the input 72 | residuals = list() 73 | state = state or [None] * (self.hidden_layers - 1) 74 | for 
layer in range(0, self.hidden_layers): # connect discriminative blocks 75 | if layer: # concat the input with the state for D_n, n > 1 76 | s = state[layer - 1] or V(x.data.clone().zero_()) 77 | x = torch.cat((x, s), 1) 78 | x = getattr(self, 'D_' + str(layer + 1))(x) 79 | residuals.append(x) 80 | x = f.relu(x) 81 | x = getattr(self, 'BN_D_' + str(layer + 1))(x) 82 | activation_sizes.append(x.size()) # cache output size for later retrieval 83 | for layer in reversed(range(0, self.hidden_layers)): # connect generative blocks 84 | x = getattr(self, 'G_' + str(layer + 1))(x, activation_sizes[layer]) 85 | if layer: 86 | state[layer - 1] = x 87 | x += residuals[layer - 1] 88 | x = f.relu(x) 89 | x = getattr(self, 'BN_G_' + str(layer + 1))(x) 90 | x_mean = self.average(residuals[-1]) 91 | video_index = self.stabiliser(x_mean.view(x_mean.size(0), -1)) 92 | 93 | return (x, state), (x_mean, video_index) 94 | 95 | 96 | class Model02RG(nn.Module): 97 | """ 98 | Generate a constructor for model_02_rg type of network 99 | """ 100 | 101 | def __init__(self, network_size: tuple, input_spatial_size: tuple) -> None: 102 | """ 103 | Initialise Model02RG constructor 104 | 105 | :param network_size: (n, h1, h2, ..., emb_size, nb_videos) 106 | :type network_size: tuple 107 | :param input_spatial_size: (height, width) 108 | :type input_spatial_size: tuple 109 | """ 110 | super().__init__() 111 | self.hidden_layers = len(network_size) - 2 112 | 113 | print('\n{:-^80}'.format(' Building model Model02RG ')) 114 | print('Hidden layers:', self.hidden_layers) 115 | print('Net sizing:', network_size) 116 | print('Input spatial size: {} x {}'.format(network_size[0], input_spatial_size)) 117 | 118 | # main auto-encoder blocks 119 | self.activation_size = [input_spatial_size] 120 | for layer in range(0, self.hidden_layers): 121 | # print some annotation when building model 122 | print('{:-<80}'.format('Layer ' + str(layer + 1) + ' ')) 123 | print('Bottom size: {} x {}'.format(network_size[layer], self.activation_size[-1])) 124 | self.activation_size.append(tuple(ceil(s / 2) for s in self.activation_size[layer])) 125 | print('Top size: {} x {}'.format(network_size[layer + 1], self.activation_size[-1])) 126 | 127 | # init D (discriminative) blocks 128 | multiplier = layer and 2 or 1 # D_n, n > 1, has intra-layer feedback 129 | setattr(self, 'D_' + str(layer + 1), nn.Conv2d( 130 | in_channels=network_size[layer] * multiplier, out_channels=network_size[layer + 1], 131 | kernel_size=KERNEL_SIZE, stride=KERNEL_STRIDE, padding=PADDING 132 | )) 133 | setattr(self, 'BN_D_' + str(layer + 1), nn.BatchNorm2d(network_size[layer + 1])) 134 | 135 | # init G (generative) blocks 136 | setattr(self, 'G_' + str(layer + 1), RG( 137 | in_channels=network_size[layer + 1], out_channels=network_size[layer], 138 | kernel_size=KERNEL_SIZE, stride=KERNEL_STRIDE, padding=PADDING 139 | )) 140 | setattr(self, 'BN_G_' + str(layer + 1), nn.BatchNorm2d(network_size[layer])) 141 | 142 | # init auxiliary classifier 143 | print('{:-<80}'.format('Classifier ')) 144 | print(network_size[-2], '-->', network_size[-1]) 145 | self.average = nn.AvgPool2d(self.activation_size[-1]) 146 | self.stabiliser = nn.Linear(network_size[-2], network_size[-1]) 147 | print(80 * '-', end='\n\n') 148 | 149 | def forward(self, x, state): 150 | activation_sizes = [x.size()] # start from the input 151 | residuals = list() 152 | # state[0] --> network layer state; state[1] --> generative state 153 | state = state or [[None] * (self.hidden_layers - 1), [None] * self.hidden_layers] 
154 | for layer in range(0, self.hidden_layers): # connect discriminative blocks 155 | if layer: # concat the input with the state for D_n, n > 1 156 | s = state[0][layer - 1] or V(x.data.clone().zero_()) 157 | x = torch.cat((x, s), 1) 158 | x = getattr(self, 'D_' + str(layer + 1))(x) 159 | residuals.append(x) 160 | x = f.relu(x) 161 | x = getattr(self, 'BN_D_' + str(layer + 1))(x) 162 | activation_sizes.append(x.size()) # cache output size for later retrieval 163 | for layer in reversed(range(0, self.hidden_layers)): # connect generative blocks 164 | x = getattr(self, 'G_' + str(layer + 1))((x, activation_sizes[layer]), state[1][layer]) 165 | state[1][layer] = x # h[t - 1] <- h[t] 166 | if layer: 167 | state[0][layer - 1] = x 168 | x += residuals[layer - 1] 169 | x = f.relu(x) 170 | x = getattr(self, 'BN_G_' + str(layer + 1))(x) 171 | x_mean = self.average(residuals[-1]) 172 | video_index = self.stabiliser(x_mean.view(x_mean.size(0), -1)) 173 | 174 | return (x, state), (x_mean, video_index) 175 | 176 | 177 | def _test_models(): 178 | _test_model(Model02) 179 | _test_model(Model02RG) 180 | 181 | 182 | def _test_model(Model): 183 | big_t = 2 184 | x = torch.rand(big_t + 1, 1, 3, 4 * 2**3 + 3, 6 * 2**3 + 5) 185 | big_k = 10 186 | y = torch.LongTensor(big_t, 1).random_(big_k) 187 | model = Model(network_size=(3, 6, 12, 18, big_k), input_spatial_size=x[0].size()[2:]) 188 | 189 | state = None 190 | (x_hat, state), (emb, idx) = model(V(x[0]), state) 191 | 192 | print('Input size:', tuple(x.size())) 193 | print('Output size:', tuple(x_hat.data.size())) 194 | print('Video index size:', tuple(idx.size())) 195 | for i, s in enumerate(state): 196 | if isinstance(s, list): 197 | for i, s in enumerate(state[0]): 198 | print('Net state', i + 1, 'has size:', tuple(s.size())) 199 | for i, s in enumerate(state[1]): 200 | print('G', i + 1, 'state has size:', tuple(s.size())) 201 | break 202 | else: 203 | print('State', i + 1, 'has size:', tuple(s.size())) 204 | print('Embedding has size:', emb.data.numel()) 205 | 206 | mse = nn.MSELoss() 207 | nll = nn.CrossEntropyLoss() 208 | x_next = V(x[1]) 209 | y_var = V(y[0]) 210 | loss_t1 = mse(x_hat, x_next) + nll(idx, y_var) 211 | 212 | from utils.visualise import show_graph 213 | show_graph(loss_t1) 214 | 215 | # run one more time 216 | (x_hat, _), (_, idx) = model(V(x[1]), state) 217 | 218 | x_next = V(x[2]) 219 | y_var = V(y[1]) 220 | loss_t2 = mse(x_hat, x_next) + nll(idx, y_var) 221 | loss_tot = loss_t2 + loss_t1 222 | 223 | show_graph(loss_tot) 224 | 225 | 226 | def _test_training_models(): 227 | _test_training(Model02) 228 | _test_training(Model02RG) 229 | 230 | 231 | def _test_training(Model): 232 | big_k = 10 # number of training videos 233 | network_size = (3, 6, 12, 18, big_k) 234 | big_t = 6 # sequence length 235 | max_epoch = 10 # number of epochs 236 | lr = 3.16e-2 # learning rate 237 | 238 | # set manual seed 239 | torch.manual_seed(0) 240 | 241 | print('\n{:-^80}'.format(' Train a ' + str(network_size[:-1]) + ' layer network ')) 242 | print('Sequence length T:', big_t) 243 | print('Create the input image and target sequences') 244 | x = torch.rand(big_t + 1, 1, 3, 4 * 2**3 + 3, 6 * 2**3 + 5) 245 | y = torch.LongTensor(big_t, 1).random_(big_k) 246 | print('Input has size', tuple(x.size())) 247 | print('Target index has size', tuple(y.size())) 248 | 249 | print('Define model') 250 | model = Model(network_size=network_size, input_spatial_size=x[0].size()[2:]) 251 | 252 | print('Create a MSE and NLL criterions') 253 | mse = nn.MSELoss() 254 | nll = 
nn.CrossEntropyLoss() 255 | 256 | print('Run for', max_epoch, 'iterations') 257 | for epoch in range(0, max_epoch): 258 | state = None 259 | loss = 0 260 | for t in range(0, big_t): 261 | (x_hat, state), (emb, idx) = model(V(x[t]), state) 262 | loss += mse(x_hat, V(x[t + 1])) + nll(idx, V(y[t])) 263 | 264 | print(' > Epoch {:2d} loss: {:.3f}'.format((epoch + 1), loss.data[0])) 265 | 266 | # zero grad parameters 267 | model.zero_grad() 268 | 269 | # compute new grad parameters through time! 270 | loss.backward() 271 | 272 | # learning_rate step against the gradient 273 | for p in model.parameters(): 274 | p.data.sub_(p.grad.data * lr) 275 | 276 | 277 | if __name__ == '__main__': 278 | _test_models() 279 | _test_training_models() 280 | 281 | 282 | __author__ = "Alfredo Canziani" 283 | __credits__ = ["Alfredo Canziani"] 284 | __maintainer__ = "Alfredo Canziani" 285 | __email__ = "alfredo.canziani@gmail.com" 286 | __status__ = "Production" # "Prototype", "Development", or "Production" 287 | __date__ = "Feb, Mar 17" 288 | -------------------------------------------------------------------------------- /model/PrednetModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | 5 | from model.DiscriminativeCell import DiscriminativeCell 6 | from model.GenerativeCell import GenerativeCell 7 | 8 | 9 | # Define some constants 10 | OUT_LAYER_SIZE = (3,) + tuple(2 ** p for p in range(4, 10)) 11 | ERR_LAYER_SIZE = tuple(size * 2 for size in OUT_LAYER_SIZE) 12 | IN_LAYER_SIZE = (3,) + ERR_LAYER_SIZE 13 | 14 | 15 | class PrednetModel(nn.Module): 16 | """ 17 | Build the Prednet model 18 | """ 19 | 20 | def __init__(self, error_size_list): 21 | super().__init__() 22 | self.number_of_layers = len(error_size_list) 23 | for layer in range(0, self.number_of_layers): 24 | setattr(self, 'discriminator_' + str(layer + 1), DiscriminativeCell( 25 | input_size={'input': IN_LAYER_SIZE[layer], 'state': OUT_LAYER_SIZE[layer]}, 26 | hidden_size=OUT_LAYER_SIZE[layer], 27 | first=(not layer) 28 | )) 29 | setattr(self, 'generator_' + str(layer + 1), GenerativeCell( 30 | input_size={'error': ERR_LAYER_SIZE[layer], 'up_state': 31 | OUT_LAYER_SIZE[layer + 1] if layer != self.number_of_layers - 1 else 0}, 32 | hidden_size=OUT_LAYER_SIZE[layer], 33 | error_init_size=error_size_list[layer] 34 | )) 35 | 36 | def forward(self, bottom_up_input, error, state): 37 | 38 | # generative branch 39 | up_state = None 40 | for layer in reversed(range(0, self.number_of_layers)): 41 | state[layer] = getattr(self, 'generator_' + str(layer + 1))( 42 | error[layer], up_state, state[layer] 43 | ) 44 | up_state = state[layer][0] 45 | 46 | # discriminative branch 47 | for layer in range(0, self.number_of_layers): 48 | error[layer] = getattr(self, 'discriminator_' + str(layer + 1))( 49 | layer and error[layer - 1] or bottom_up_input, 50 | state[layer][0] 51 | ) 52 | 53 | return error, state 54 | 55 | 56 | class _BuildOneLayerModel(nn.Module): 57 | """ 58 | Build a one layer Prednet model 59 | """ 60 | 61 | def __init__(self, error_size_list): 62 | super().__init__() 63 | self.discriminator = DiscriminativeCell( 64 | input_size={'input': IN_LAYER_SIZE[0], 'state': OUT_LAYER_SIZE[0]}, 65 | hidden_size=OUT_LAYER_SIZE[0], 66 | first=True 67 | ) 68 | self.generator = GenerativeCell( 69 | input_size={'error': ERR_LAYER_SIZE[0], 'up_state': 0}, 70 | hidden_size=OUT_LAYER_SIZE[0], 71 | error_init_size=error_size_list[0] 72 | ) 73 | 74 | def 
forward(self, bottom_up_input, prev_error, state): 75 | state = self.generator(prev_error, None, state) 76 | error = self.discriminator(bottom_up_input, state[0]) 77 | return error, state 78 | 79 | 80 | class _BuildTwoLayerModel(nn.Module): 81 | """ 82 | Build a two layer Prednet model 83 | """ 84 | 85 | def __init__(self, error_size_list): 86 | super().__init__() 87 | self.discriminator_1 = DiscriminativeCell( 88 | input_size={'input': IN_LAYER_SIZE[0], 'state': OUT_LAYER_SIZE[0]}, 89 | hidden_size=OUT_LAYER_SIZE[0], 90 | first=True 91 | ) 92 | self.discriminator_2 = DiscriminativeCell( 93 | input_size={'input': IN_LAYER_SIZE[1], 'state': OUT_LAYER_SIZE[1]}, 94 | hidden_size=OUT_LAYER_SIZE[1] 95 | ) 96 | self.generator_1 = GenerativeCell( 97 | input_size={'error': ERR_LAYER_SIZE[0], 'up_state': OUT_LAYER_SIZE[1]}, 98 | hidden_size=OUT_LAYER_SIZE[0], 99 | error_init_size=error_size_list[0] 100 | ) 101 | self.generator_2 = GenerativeCell( 102 | input_size={'error': ERR_LAYER_SIZE[1], 'up_state': 0}, 103 | hidden_size=OUT_LAYER_SIZE[1], 104 | error_init_size=error_size_list[1] 105 | ) 106 | 107 | def forward(self, bottom_up_input, error, state): 108 | state[1] = self.generator_2(error[1], None, state[1]) 109 | state[0] = self.generator_1(error[0], state[1][0], state[0]) 110 | error[0] = self.discriminator_1(bottom_up_input, state[0][0]) 111 | error[1] = self.discriminator_2(error[0], state[1][0]) 112 | return error, state 113 | 114 | 115 | def _test_one_layer_model(): 116 | print('\nCreate the input image') 117 | input_image = Variable(torch.rand(1, 3, 8, 12)) 118 | 119 | print('Input has size', list(input_image.data.size())) 120 | 121 | error_init_size = (1, 6, 8, 12) 122 | print('The error initialisation size is', error_init_size) 123 | 124 | print('Define a 1 layer Prednet') 125 | model = _BuildOneLayerModel((error_init_size,)) 126 | 127 | print('Forward input and state to the model') 128 | state = None 129 | error = None 130 | error, state = model(input_image, prev_error=error, state=state) 131 | 132 | print('The error has size', list(error.data.size())) 133 | print('The state has size', list(state[0].data.size())) 134 | 135 | 136 | def _test_two_layer_model(): 137 | print('\nCreate the input image') 138 | input_image = Variable(torch.rand(1, 3, 8, 12)) 139 | 140 | print('Input has size', list(input_image.data.size())) 141 | 142 | error_init_size_list = ((1, 6, 8, 12), (1, 32, 4, 6)) 143 | print('The error initialisation sizes are', *error_init_size_list) 144 | 145 | print('Define a 2 layer Prednet') 146 | model = _BuildTwoLayerModel(error_init_size_list) 147 | 148 | print('Forward input and state to the model') 149 | state = [None] * 2 150 | error = [None] * 2 151 | error, state = model(input_image, error=error, state=state) 152 | 153 | for layer in range(0, 2): 154 | print('Layer', layer + 1, 'error has size', list(error[layer].data.size())) 155 | print('Layer', layer + 1, 'state has size', list(state[layer][0].data.size())) 156 | 157 | 158 | def _test_L_layer_model(): 159 | 160 | max_number_of_layers = 5 161 | for L in range(0, max_number_of_layers): 162 | print('\n---------- Test', str(L + 1), 'layer network ----------') 163 | 164 | print('Create the input image') 165 | input_image = Variable(torch.rand(1, 3, 4 * 2 ** L, 6 * 2 ** L)) 166 | 167 | print('Input has size', list(input_image.data.size())) 168 | 169 | error_init_size_list = tuple( 170 | (1, ERR_LAYER_SIZE[l], 4 * 2 ** (L-l), 6 * 2 ** (L-l)) for l in range(0, L + 1) 171 | ) 172 | print('The error initialisation sizes are', 
*error_init_size_list) 173 | 174 | print('Define a', str(L + 1), 'layer Prednet') 175 | model = PrednetModel(error_init_size_list) 176 | 177 | print('Forward input and state to the model') 178 | state = [None] * (L + 1) 179 | error = [None] * (L + 1) 180 | error, state = model(input_image, error=error, state=state) 181 | 182 | for layer in range(0, L + 1): 183 | print('Layer', layer + 1, 'error has size', list(error[layer].data.size())) 184 | print('Layer', layer + 1, 'state has size', list(state[layer][0].data.size())) 185 | 186 | 187 | def _test_training(): 188 | number_of_layers = 3 189 | T = 6 # sequence length 190 | max_epoch = 10 # number of epochs 191 | lr = 1e-1 # learning rate 192 | 193 | # set manual seed 194 | torch.manual_seed(0) 195 | 196 | L = number_of_layers - 1 197 | print('\n---------- Train a', str(L + 1), 'layer network ----------') 198 | print('Create the input image and target sequences') 199 | input_sequence = Variable(torch.rand(T, 1, 3, 4 * 2 ** L, 6 * 2 ** L)) 200 | print('Input has size', list(input_sequence.data.size())) 201 | 202 | error_init_size_list = tuple( 203 | (1, ERR_LAYER_SIZE[l], 4 * 2 ** (L - l), 6 * 2 ** (L - l)) for l in range(0, L + 1) 204 | ) 205 | print('The error initialisation sizes are', *error_init_size_list) 206 | target_sequence = Variable(torch.zeros(T, *error_init_size_list[0])) 207 | 208 | print('Define a', str(L + 1), 'layer Prednet') 209 | model = PrednetModel(error_init_size_list) 210 | 211 | print('Create a MSE criterion') 212 | loss_fn = nn.MSELoss() 213 | 214 | print('Run for', max_epoch, 'iterations') 215 | for epoch in range(0, max_epoch): 216 | state = [None] * (L + 1) 217 | error = [None] * (L + 1) 218 | loss = 0 219 | for t in range(0, T): 220 | error, state = model(input_sequence[t], error, state) 221 | loss += loss_fn(error[0], target_sequence[t]) 222 | 223 | print(' > Epoch {:2d} loss: {:.3f}'.format((epoch + 1), loss.data[0])) 224 | 225 | # zero grad parameters 226 | model.zero_grad() 227 | 228 | # compute new grad parameters through time! 229 | loss.backward() 230 | 231 | # learning_rate step against the gradient 232 | for p in model.parameters(): 233 | p.data.sub_(p.grad.data * lr) 234 | 235 | 236 | def _main(): 237 | _test_one_layer_model() 238 | _test_two_layer_model() 239 | _test_L_layer_model() 240 | _test_training() 241 | 242 | 243 | if __name__ == '__main__': 244 | _main() 245 | 246 | 247 | __author__ = "Alfredo Canziani" 248 | __credits__ = ["Alfredo Canziani"] 249 | __maintainer__ = "Alfredo Canziani" 250 | __email__ = "alfredo.canziani@gmail.com" 251 | __status__ = "Prototype" # "Prototype", "Development", or "Production" 252 | __date__ = "Feb 17" 253 | -------------------------------------------------------------------------------- /model/README.md: -------------------------------------------------------------------------------- 1 | # Network architectures 2 | 3 | Three main architectures are currently available: 4 | 5 | - [`PrednetModel.py`](PrednetModel.py): implementation of [*PredNet*](https://coxlab.github.io/prednet/) in *PyTorch* through the following smaller blocks: 6 | - [`DiscriminativeCell.py`](DiscriminativeCell.py): computes the error between the input and state projections; 7 | - [`GenerativeCell.py`](GenerativeCell.py): computes the new state given error, top down state, and current state. 
Uses the following custom module: 8 | - [`ConvLSTMCell.py`](ConvLSTMCell.py): a pretty standard LSTM that uses convolutions instead of FC layers; 9 | - [`Model01.py`](Model01.py): symmetric, additional feed-forward/back; 10 | - [`Model02.py`](Model02.py): AKA *CortexNet*, additional feed-forward, modulated feed-back. May use: 11 | - [`RG.py`](RG.py): recurrent generative block ⇒ *Model02G*. 12 | 13 | ![Model diagrams](https://cdn.rawgit.com/e-lab/pytorch-CortexNet/master/model/models.svg) -------------------------------------------------------------------------------- /model/RG.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class RG(nn.Module): 5 | """Recurrent Generative Module""" 6 | 7 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding): 8 | """ Initialise RG Module (parameters as nn.ConvTranspose2d)""" 9 | super().__init__() 10 | self.from_input = nn.ConvTranspose2d( 11 | in_channels=in_channels, out_channels=out_channels, 12 | kernel_size=kernel_size, stride=stride, padding=padding 13 | ) 14 | self.from_state = nn.Conv2d( 15 | in_channels=out_channels, out_channels=out_channels, 16 | kernel_size=kernel_size, padding=padding, bias=False 17 | ) 18 | 19 | def forward(self, x, state): 20 | """ 21 | Calling signature 22 | 23 | :param x: (input, output_size) 24 | :type x: tuple 25 | :param state: previous output 26 | :type state: torch.Tensor 27 | :return: current state 28 | :rtype: torch.Tensor 29 | """ 30 | x = self.from_input(*x) # the very first x is a tuple (input, expected_output_size) 31 | if state: x += self.from_state(state) 32 | return x 33 | -------------------------------------------------------------------------------- /model/utils: -------------------------------------------------------------------------------- 1 | ../utils/ -------------------------------------------------------------------------------- /new_experiment.sh: -------------------------------------------------------------------------------- 1 | # Prepares environment for new experiment 2 | # Alfredo Canziani, Mar 17 3 | 4 | # It expects a directory / link named "results" in the cwd containing 5 | # numerically increasing folders with 3 digits 6 | 7 | shopt -s extglob # use extended pattern matching capabilities 8 | 9 | old_CLI=$(awk -F': ' '/CLI/{print $2}' last/train.log) 10 | max_exp=$(ls results/ | tail -1) 11 | max_exp=${max_exp##*(0)} # remove possible leading 0s 12 | dst_path=$(printf "results/%03d" "$((++max_exp))") 13 | 14 | echo -n " > Creating folder: " 15 | mkdir $dst_path 16 | ls -d --color=always $dst_path 17 | 18 | echo -n " > Linking:" 19 | ln -snf $dst_path last 20 | ls -l --color=always last | awk -F"$USER" '{print $3}' 21 | 22 | echo " > Previously you've used the following options" 23 | echo " CUDA_VISIBLE_DEVICES=n python -u main.py $old_CLI | tee last/train.log" 24 | -------------------------------------------------------------------------------- /notebook/README.md: -------------------------------------------------------------------------------- 1 | # Collection of Jupyter Notebooks 2 | 3 | *Jupyter Notebooks* are a very effective tool for interactive data exploration and visualisation. 4 | 5 | ## Correct display 6 | 7 | I use dark styles for both *GitHub* and *Jupyter Notebook*. 
8 | To see the content as intended, install: 9 | 10 | - [*Jupyter Notebook* dark theme](https://userstyles.org/styles/98208/jupyter-notebook-dark-originally-from-ipython); 11 | - [*GitHub* dark theme](https://userstyles.org/styles/37035/github-dark) and comment out the `invert #fff to #181818` code block. 12 | 13 | ## Notebooks 14 | 15 | These notebooks are mostly for personal use. 16 | I used them for data exploration, getting familiar with PyTorch model graphs, and generating the paper figures and website animations. 17 | They all run, on my system, with the data generated from the experiments. 18 | I'm releasing them here for reference only, and you are welcome to browse them, even though you might not be able to execute them. 19 | 20 | - [`display_loss.ipynb`](display_loss.ipynb): displays train and validation losses for every experiment; 21 | - [`figures_generator.ipynb`](figures_generator.ipynb): generates paper figures and website animations; 22 | - [`frequency_analysis.ipynb`](frequency_analysis.ipynb): explores video temporal-frequency components; 23 | - [`get_all_embeddings.ipynb`](get_all_embeddings.ipynb): displays ResNet18 embeddings and probability *vs.* time; 24 | - [`get_data_stats.ipynb`](get_data_stats.ipynb): shows stats on the e-VDS35 data set; 25 | - [`network_bisection.ipynb`](network_bisection.ipynb): mucking around with the PyTorch model graph; 26 | - [`salient_regions.ipynb`](salient_regions.ipynb): implements and computes salient regions from [bojarski2017explaining](https://arxiv.org/abs/1704.07911); 27 | - [`stability_analysis.ipynb`](stability_analysis.ipynb): mucking around with ResNet18 and videos; 28 | - [`verify_resnet.ipynb`](verify_resnet.ipynb): checks whether I'm able to use PyTorch's ResNet18; 29 | - [`plot_conf.py`](plot_conf.py): `matplotlib` shared configuration (a minimal usage sketch follows below). 
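Below is a minimal usage sketch for the shared plotting configuration — an illustration only, not an excerpt from any of the notebooks above. It assumes the notebook is started from this folder (so `plot_conf.py` is importable) and runs on an IPython kernel, since `plot_conf.py` calls `get_ipython()` internally.

```python
# Hypothetical first cell of a notebook in this folder.
from plot_conf import plt_style  # importing already applies the dark ('k') style

plt_style('w')  # optionally switch to the bright-background style
```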
30 | -------------------------------------------------------------------------------- /notebook/data: -------------------------------------------------------------------------------- 1 | ../data/ -------------------------------------------------------------------------------- /notebook/model: -------------------------------------------------------------------------------- 1 | ../model/ -------------------------------------------------------------------------------- /notebook/network_bisection.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import torch\n", 12 | "import torchvision\n", 13 | "from torch.autograd import Variable as V\n", 14 | "from utils.visualise import make_dot" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "metadata": { 21 | "collapsed": true 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "resnet_18 = torchvision.models.resnet18(pretrained=True)\n", 26 | "resnet_18.eval();" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 3, 32 | "metadata": { 33 | "collapsed": true 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "# by setting the volatile flag to True, intermediate caches are not saved\n", 38 | "# making the inspection of the graph pretty boring / useless\n", 39 | "torch.manual_seed(0)\n", 40 | "x = V(torch.randn(1, 3, 224, 224))#, volatile=True)\n", 41 | "h_x = resnet_18(x)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 4, 47 | "metadata": { 48 | "collapsed": false 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "dot = make_dot(h_x) # generate network graph\n", 53 | "dot.render('net.dot'); # save DOT and PDF in the current directory\n", 54 | "# dot # uncomment for displaying the graph in the notebook" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 5, 60 | "metadata": { 61 | "collapsed": false 62 | }, 63 | "outputs": [ 64 | { 65 | "name": "stdout", 66 | "output_type": "stream", 67 | "text": [ 68 | "h_x creator -> \n", 69 | "h_x creator prev fun type -> \n", 70 | "h_x creator prev fun length -> 3\n", 71 | "\n", 72 | "--- content of h_x creator prev fun ---\n", 73 | "0 --> (, 0)\n", 74 | "1 --> (Parameter containing:\n", 75 | "-1.8474e-02 -7.0461e-02 -5.1772e-02 ... -3.9030e-02 1.7351e-01 -4.0976e-02\n", 76 | "-8.1792e-02 -9.4370e-02 1.7355e-02 ... 2.0284e-01 -2.4782e-02 3.7172e-02\n", 77 | "-3.3164e-02 -5.6569e-02 -2.4165e-02 ... -3.4402e-02 -2.2659e-02 1.9705e-02\n", 78 | " ... ⋱ ... \n", 79 | "-1.0300e-02 3.2804e-03 -3.5863e-02 ... -2.7923e-02 -1.1458e-02 1.2759e-02\n", 80 | "-3.5879e-02 -3.5296e-02 -2.9602e-02 ... -3.2961e-02 -1.1022e-02 -5.1256e-02\n", 81 | " 2.1277e-03 -2.4839e-02 -8.2920e-02 ... 
4.1731e-02 -5.0030e-02 6.6327e-02\n", 82 | "[torch.FloatTensor of size 1000x512]\n", 83 | ", 0)\n", 84 | "2 --> (Parameter containing:\n", 85 | "1.00000e-02 *\n", 86 | " -0.2634\n", 87 | " 0.3000\n", 88 | " 0.0656\n", 89 | " ⋮ \n", 90 | " -1.7868\n", 91 | " -0.0782\n", 92 | " -0.6345\n", 93 | "[torch.FloatTensor of size 1000]\n", 94 | ", 0)\n", 95 | "---------------------------------------\n", 96 | "\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "# explore network graph\n", 102 | "print('h_x creator ->',h_x.creator)\n", 103 | "print('h_x creator prev fun type ->', type(h_x.creator.previous_functions))\n", 104 | "print('h_x creator prev fun length ->', len(h_x.creator.previous_functions))\n", 105 | "print('\\n--- content of h_x creator prev fun ---')\n", 106 | "for a, b in enumerate(h_x.creator.previous_functions): print(a, '-->', b)\n", 107 | "print('---------------------------------------\\n')" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "The current node is a `torch.nn._functions.linear.Linear` object, fed by\n", 115 | "\n", 116 | "- 0 --> output of `torch.autograd._functions.tensor.View` object\n", 117 | "- 1 --> weight matrix of size `(1000, 512)`\n", 118 | "- 2 --> bias vector of size `(1000)`" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 6, 124 | "metadata": { 125 | "collapsed": false 126 | }, 127 | "outputs": [ 128 | { 129 | "name": "stdout", 130 | "output_type": "stream", 131 | "text": [ 132 | "ResNet (\n", 133 | " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", 134 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n", 135 | " (relu): ReLU (inplace)\n", 136 | " (maxpool): MaxPool2d (size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1))\n", 137 | " (layer1): Sequential (\n", 138 | " (0): BasicBlock (\n", 139 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 140 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n", 141 | " (relu): ReLU (inplace)\n", 142 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 143 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n", 144 | " )\n", 145 | " (1): BasicBlock (\n", 146 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 147 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n", 148 | " (relu): ReLU (inplace)\n", 149 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 150 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n", 151 | " )\n", 152 | " )\n", 153 | " (layer2): Sequential (\n", 154 | " (0): BasicBlock (\n", 155 | " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 156 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n", 157 | " (relu): ReLU (inplace)\n", 158 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 159 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n", 160 | " (downsample): Sequential (\n", 161 | " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 162 | " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n", 163 | " )\n", 164 | " )\n", 165 | " (1): BasicBlock (\n", 166 | " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 167 | " (bn1): 
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n", 168 | " (relu): ReLU (inplace)\n", 169 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 170 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n", 171 | " )\n", 172 | " )\n", 173 | " (layer3): Sequential (\n", 174 | " (0): BasicBlock (\n", 175 | " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 176 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n", 177 | " (relu): ReLU (inplace)\n", 178 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 179 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n", 180 | " (downsample): Sequential (\n", 181 | " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 182 | " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n", 183 | " )\n", 184 | " )\n", 185 | " (1): BasicBlock (\n", 186 | " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 187 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n", 188 | " (relu): ReLU (inplace)\n", 189 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 190 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n", 191 | " )\n", 192 | " )\n", 193 | " (layer4): Sequential (\n", 194 | " (0): BasicBlock (\n", 195 | " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 196 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n", 197 | " (relu): ReLU (inplace)\n", 198 | " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 199 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n", 200 | " (downsample): Sequential (\n", 201 | " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 202 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n", 203 | " )\n", 204 | " )\n", 205 | " (1): BasicBlock (\n", 206 | " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 207 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n", 208 | " (relu): ReLU (inplace)\n", 209 | " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 210 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n", 211 | " )\n", 212 | " )\n", 213 | " (avgpool): AvgPool2d (\n", 214 | " )\n", 215 | " (fc): Linear (512 -> 1000)\n", 216 | ")\n" 217 | ] 218 | } 219 | ], 220 | "source": [ 221 | "print(resnet_18)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 7, 227 | "metadata": { 228 | "collapsed": false 229 | }, 230 | "outputs": [ 231 | { 232 | "data": { 233 | "text/plain": [ 234 | "odict_keys(['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', 'fc'])" 235 | ] 236 | }, 237 | "execution_count": 7, 238 | "metadata": {}, 239 | "output_type": "execute_result" 240 | } 241 | ], 242 | "source": [ 243 | "resnet_18._modules.keys()" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 8, 249 | "metadata": { 250 | "collapsed": false 251 | }, 252 | "outputs": [ 253 | { 254 | "name": "stdout", 255 | "output_type": "stream", 256 | "text": [ 257 | "m: \n", 258 | "i: \n", 259 | " len: 1 \n", 260 | " type: \n", 261 | " data size: torch.Size([1, 512, 7, 7]) \n", 262 | " data type: 
torch.FloatTensor \n", 263 | "o: \n", 264 | " data size: torch.Size([1, 512, 1, 1]) \n", 265 | " data type: torch.FloatTensor\n" 266 | ] 267 | } 268 | ], 269 | "source": [ 270 | "avgpool_layer = resnet_18._modules.get('avgpool')\n", 271 | "h = avgpool_layer.register_forward_hook(\n", 272 | " lambda m, i, o: \\\n", 273 | " print(\n", 274 | " 'm:', type(m),\n", 275 | " '\\ni:', type(i),\n", 276 | " '\\n len:', len(i),\n", 277 | " '\\n type:', type(i[0]),\n", 278 | " '\\n data size:', i[0].data.size(),\n", 279 | " '\\n data type:', i[0].data.type(),\n", 280 | " '\\no:', type(o),\n", 281 | " '\\n data size:', o.data.size(),\n", 282 | " '\\n data type:', o.data.type(),\n", 283 | " )\n", 284 | ")\n", 285 | "h_x = resnet_18(x)\n", 286 | "h.remove()" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 9, 292 | "metadata": { 293 | "collapsed": false 294 | }, 295 | "outputs": [], 296 | "source": [ 297 | "my_embedding = torch.zeros(512)\n", 298 | "def fun(m, i, o): my_embedding.copy_(o.data)\n", 299 | "h = avgpool_layer.register_forward_hook(fun)\n", 300 | "h_x = resnet_18(x)\n", 301 | "h.remove()" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 10, 307 | "metadata": { 308 | "collapsed": false 309 | }, 310 | "outputs": [ 311 | { 312 | "data": { 313 | "text/plain": [ 314 | "\n", 315 | " 0.3879 0.0205 0.0268 2.9453 0.0234 0.0000 0.0000 0.7327 1.0997 0.0000\n", 316 | "[torch.FloatTensor of size 1x10]" 317 | ] 318 | }, 319 | "execution_count": 10, 320 | "metadata": {}, 321 | "output_type": "execute_result" 322 | } 323 | ], 324 | "source": [ 325 | "# print first values of the embedding\n", 326 | "my_embedding[:10].view(1, -1)" 327 | ] 328 | } 329 | ], 330 | "metadata": { 331 | "kernelspec": { 332 | "display_name": "Python 3", 333 | "language": "python", 334 | "name": "python3" 335 | }, 336 | "language_info": { 337 | "codemirror_mode": { 338 | "name": "ipython", 339 | "version": 3 340 | }, 341 | "file_extension": ".py", 342 | "mimetype": "text/x-python", 343 | "name": "python", 344 | "nbconvert_exporter": "python", 345 | "pygments_lexer": "ipython3", 346 | "version": "3.5.2" 347 | } 348 | }, 349 | "nbformat": 4, 350 | "nbformat_minor": 0 351 | } 352 | -------------------------------------------------------------------------------- /notebook/plot_conf.py: -------------------------------------------------------------------------------- 1 | # matplotlib and stuff 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def plt_style(c='k'): 6 | """ 7 | Set plotting style for bright (``c = 'w'``) or dark (``c = 'k'``) backgrounds 8 | 9 | :param c: colour, can be set to ``'w'`` or ``'k'`` (which is the default) 10 | :type c: str 11 | """ 12 | import matplotlib as mpl 13 | from matplotlib import rc 14 | 15 | # Reset previous configuration 16 | mpl.rcParams.update(mpl.rcParamsDefault) 17 | # %matplotlib inline # not from script 18 | get_ipython().run_line_magic('matplotlib', 'inline') 19 | 20 | # configuration for bright background 21 | if c == 'w': 22 | plt.style.use('bmh') 23 | 24 | # configurations for dark background 25 | if c == 'k': 26 | # noinspection PyTypeChecker 27 | plt.style.use(['dark_background', 'bmh']) 28 | 29 | # remove background colour, set figure size 30 | rc('figure', figsize=(16, 8), max_open_warning=False) 31 | rc('axes', facecolor='none') 32 | 33 | plt_style() 34 | -------------------------------------------------------------------------------- /notebook/utils: -------------------------------------------------------------------------------- 1 | 
../utils/ -------------------------------------------------------------------------------- /utils/README.md: -------------------------------------------------------------------------------- 1 | # Utility scripts 2 | 3 | This folder and this `README.md` contain some very useful scripts. 4 | The folder content is the following: 5 | 6 | - [`check_exp_diff.sh*`](check_exp_diff.sh): shows the difference between experiments' training CLI arguments; 7 | - [`update_experiments.sh*`](update_experiments.sh): synchronises experiments across multiple nodes; 8 | - [`show_error.plt*`](show_error.plt): shows the current error with *gnuplot* by `awk`-ing `last/train.log`; 9 | - [`show_error_exp.plt*`](show_error_exp.plt): shows the losses for a specific experiment; 10 | - [`image_plot.py`](image_plot.py): saves intermediate training visual outputs to PDF (may be used by `main.py`; a usage sketch appears at the end of this file); 11 | 12 | ## Script usage and arguments 13 | 14 | ### Compare CLI arguments 15 | 16 | To compare the CLI arguments used in two different experiments run `./check_exp_diff.sh <base exp> <new exp>`. 17 | The script will output the `--word-diff` between the calling arguments, and `git show` the newer experiment's `HEAD`. 18 | Replace `<{base,new} exp>` with the corresponding three-digit identifier. 19 | 20 | ### Plotting current training losses 21 | 22 | To plot the (*MSE*, *CE*, *replica MSE*, and *periodic CE*) cost functions interactively for the currently training network, run `./show_error.plt -i`. 23 | To plot them statically, after training, run `./show_error.plt`. 24 | 25 | ### Displaying any experiment's losses 26 | 27 | If the experiment data is not on the current machine, run `./update_experiments.sh -i` to iteratively sync the data every `5` seconds. 28 | To view the losses iteratively run `./show_error_exp.plt <exp> -i`, where `<exp>` is, for example, `012`. 29 | To run it statically, refrain from using the `-i` flag. 30 | 31 | ### Synchronising experiments 32 | 33 | To sync experiments across multiple nodes run `./update_experiments.sh`. 34 | If you wish to run it quietly (verbose is the default), run `./update_experiments.sh -q`. 35 | If you'd like to run it in a `5`-second loop, run `./update_experiments.sh -i`. 36 | 37 | ## Image manipulation scripts 38 | 39 | ### Get PNGs out of PDFs 40 | 41 | ```bash 42 | convert *.pdf *.png 43 | ``` 44 | 45 | ### Create GIFs 46 | 47 | Create animations from the content of `00{6,7}/PDFs` into `anim/`: 48 | 49 | ```bash 50 | for p in 006/PDFs/*; do 51 | g=anim/${p##*/} 52 | convert -delay 100 $p ${p/6/7} ${g/pdf/gif} 53 | done 54 | ``` 55 | 56 | `delay` is expressed in units of `10` ms. 
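Below is a minimal sketch of how `image_plot.py`'s helpers might be driven — an illustration only; the tensor shapes and the relative import are assumptions, and the actual wiring inside `main.py` is not shown here.

```python
# Hypothetical driver for image_plot.py (run from within utils/; PDFs/ must not exist yet,
# since show_ten() creates it on first use and exits if it is already there).
import torch
from image_plot import show_four, show_ten

x      = torch.rand(3, 256, 256)  # current frame x[t], assumed 3 x H x W CPU float tensor
next_x = torch.rand(3, 256, 256)  # next frame x[t + 1]
x_hat  = torch.rand(3, 256, 256)  # network prediction ~x[t + 1]

show_four(x, next_x, x_hat, fig=1)  # saves/overwrites 1.pdf in the current directory
for t in range(10):
    show_ten(x, x_hat)              # every 10th call writes PDFs/<k>_10.pdf
```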
-------------------------------------------------------------------------------- /utils/check_exp_diff.sh: -------------------------------------------------------------------------------- 1 | # Check what has changed across experiments 2 | # Alfredo Canziani, Mar 17 3 | 4 | # ./check_exp_diff.sh base_exp new_exp 5 | 6 | base_exp="../results/$1/train.log" 7 | new_exp="../results/$2/train.log" 8 | 9 | git diff --no-index --word-diff --color=always $base_exp $new_exp | head -7 10 | 11 | echo "" 12 | hash=$(awk -F ": " '/hash/{print $2}' $new_exp) 13 | git show --stat $hash 14 | -------------------------------------------------------------------------------- /utils/image_plot.py: -------------------------------------------------------------------------------- 1 | import torch # if torch is not imported BEFORE pyplot you get a FUCKING segmentation fault 2 | from matplotlib import pyplot as plt 3 | from os.path import isdir, join 4 | from os import mkdir 5 | 6 | 7 | def _hist_show(a, k): 8 | a = _to_view(a) 9 | plt.subplot(2, 3, k) 10 | plt.hist(a.reshape(-1), 50) 11 | plt.grid('on') 12 | plt.gca().axes.get_yaxis().set_visible(False) 13 | 14 | 15 | def show_four(x, next_x, x_hat, fig): 16 | """ 17 | Saves/overwrites a PDF named fig.pdf with x, next_x, x_hat histogram and x_hat 18 | 19 | :param x: x[t] 20 | :type x: torch.FloatTensor 21 | :param next_x: x[t + 1] 22 | :type next_x: torch.FloatTensor 23 | :param x_hat: ~x[t + 1] 24 | :type x_hat: torch.FloatTensor 25 | :param fig: figure number 26 | :type fig: int 27 | :return: nothing 28 | :rtype: None 29 | """ 30 | f = plt.figure(fig) 31 | plt.clf() 32 | _sub(x, 1) 33 | _sub(next_x, 4) 34 | dif = next_x - x 35 | _sub(dif, 2) 36 | _hist_show(dif, 3) 37 | _sub(x_hat, 5) 38 | _hist_show(x_hat, 6) 39 | plt.subplots_adjust(left=0.01, bottom=0.06, right=.99, top=1, wspace=0, hspace=.12) 40 | f.savefig(str(fig) + '.pdf') 41 | 42 | 43 | # Setup output folder for figures collection 44 | def _show_ten_setup(pdf_path): 45 | if isdir(pdf_path): 46 | print('Folder "{}" already existent. 
Exiting.'.format(pdf_path)) 47 | exit() 48 | mkdir(pdf_path) 49 | 50 | 51 | def show_ten(x, x_hat, pdf_path='PDFs'): 52 | """ 53 | First two rows 10 ~x[t + 1], second two rows 10 x[t] 54 | 55 | :param x: x[t] 56 | :type x: torch.FloatTensor 57 | :param x_hat: ~x[t + 1] 58 | :type x_hat: torch.FloatTensor 59 | :param pdf_path: saving path 60 | :type pdf_path: str 61 | :return: nothing 62 | :rtype: None 63 | """ 64 | if show_ten.c == 0 and pdf_path: _show_ten_setup(pdf_path) 65 | if show_ten.c % 10 == 0: show_ten.f = plt.figure() 66 | plt.figure(show_ten.f.number) 67 | plt.subplot(4, 5, 1 + show_ten.c % 10) 68 | _img_show(x_hat, y0=-.16, s=8) 69 | plt.subplot(4, 5, 11 + show_ten.c % 10) 70 | _img_show(x, y0=-.16, s=8) 71 | show_ten.c += 1 72 | plt.subplots_adjust(left=0, bottom=0.02, right=1, top=1, wspace=0, hspace=.12) 73 | if show_ten.c % 10 == 0: show_ten.f.savefig(join(pdf_path, str(show_ten.c // 10) + '_10.pdf')) 74 | show_ten.c = 0 75 | 76 | 77 | def _img_show(a, y0=-.13, s=12): 78 | a = _to_view(a) 79 | plt.imshow(a) 80 | plt.title('<{:.2f}> [{:.2f}, {:.2f}]'.format(a.mean(), a.min(), a.max()), y=y0, fontsize=s) 81 | plt.axis('off') 82 | 83 | 84 | def _sub(a, k): 85 | plt.subplot(2, 3, k) 86 | _img_show(a) 87 | 88 | 89 | def _to_view(a): 90 | return a.cpu().numpy().transpose((1, 2, 0)) 91 | 92 | 93 | def _test_4(): 94 | img = _test_setup() 95 | show_four(img, img, img, 1) 96 | 97 | 98 | def _test_10(): 99 | img = _test_setup() 100 | for i in range(20): show_ten(img, -img, '') 101 | 102 | 103 | def _test_setup(): 104 | from skimage.data import astronaut 105 | from skimage.transform import resize 106 | from matplotlib.figure import Figure 107 | Figure.savefig = lambda self, _: plt.show() # patch Figure class to simply display the figure 108 | img = torch.from_numpy(resize(astronaut(), (256, 256)).astype('f4').transpose((2, 0, 1))) 109 | return img 110 | 111 | 112 | if __name__ == '__main__': 113 | _test_4() 114 | _test_10() 115 | 116 | __author__ = "Alfredo Canziani" 117 | __credits__ = ["Alfredo Canziani"] 118 | __maintainer__ = "Alfredo Canziani" 119 | __email__ = "alfredo.canziani@gmail.com" 120 | __status__ = "Development" # "Prototype", "Development", or "Production" 121 | __date__ = "Mar 17" 122 | -------------------------------------------------------------------------------- /utils/show_error.plt: -------------------------------------------------------------------------------- 1 | #!/usr/bin/gnuplot -c 2 | 3 | # Plot MSE and CE train loss iteratively 4 | 5 | # Run as: 6 | # ./show_error3.plt -i # to run it iteratively 7 | # ./show_error3.plt # to run it statically 8 | 9 | # Alfredo Canziani, Mar 17 10 | 11 | # set white on black theme 12 | set terminal wxt background rgb "black" noraise 13 | set xlabel textcolor rgb "white" 14 | set ylabel textcolor rgb "white" 15 | set y2label textcolor rgb "white" 16 | set key textcolor rgb "white" 17 | set border lc rgb 'white' 18 | set grid lc rgb 'white' 19 | 20 | set grid 21 | set xlabel "mini batch index / 10" 22 | set ylabel "mMSE" 23 | set y2label "CE" 24 | set y2tics 25 | plot \ 26 | "< awk '/batches/{print $18,$21,$25,$29}' ../last/train.log" \ 27 | u 0:1 w lines lw 2 title "MSE", \ 28 | "" \ 29 | u 0:3 w lines lw 2 title "rpl MSE", \ 30 | "" \ 31 | u 0:2 w lines lw 2 title "CE" axis x1y2, \ 32 | "" \ 33 | u 0:4 w lines lw 2 title "per CE" axis x1y2 34 | 35 | if (ARG1 ne '-i') { 36 | pause -1 # just hang in there 37 | exit 38 | } 39 | 40 | pause 5 # wait 5 seconds 41 | reread # and start over 42 | 
-------------------------------------------------------------------------------- /utils/show_error_exp.plt: -------------------------------------------------------------------------------- 1 | #!/usr/bin/gnuplot -c 2 | 3 | # Plot MSE and CE train loss iteratively 4 | 5 | # Run as: 6 | # ./show_error_exp.plt EXP -i # to run it iteratively 7 | # ./show_error_exp.plt EXP # to run it statically 8 | 9 | # Alfredo Canziani, Mar 17 10 | 11 | # You need to run ./utils/update_experiments.sh -i 12 | 13 | # set white on black theme 14 | set terminal wxt background rgb "black" noraise 15 | set xlabel textcolor rgb "white" 16 | set ylabel textcolor rgb "white" 17 | set y2label textcolor rgb "white" 18 | set title textcolor rgb "white" 19 | set key textcolor rgb "white" 20 | set border lc rgb 'white' 21 | set grid lc rgb 'white' 22 | 23 | set grid 24 | set title "Network " . ARG1 25 | set xlabel "mini batch index / 10" 26 | set ylabel "mMSE" 27 | set y2label "CE" 28 | set y2tics 29 | if (ARG1 + 0 < 20) { # "+ 0" conversion string to number 30 | plot \ 31 | "< awk '/batches/{print $18,$21,$25}' ../results/".ARG1."/train.log" \ 32 | u 0:1 w lines lw 2 title "MSE", \ 33 | "" \ 34 | u 0:2 w lines lw 2 title "CE" axis x1y2, \ 35 | "" \ 36 | u 0:3 w lines lw 2 title "rpl MSE" 37 | } else { 38 | plot \ 39 | "< awk '/batches/{print $18,$21,$25,$29}' ../results/".ARG1."/train.log" \ 40 | u 0:1 w lines lw 2 title "MSE", \ 41 | "" \ 42 | u 0:3 w lines lw 2 title "rpl MSE", \ 43 | "" \ 44 | u 0:2 w lines lw 2 title "CE" axis x1y2, \ 45 | "" \ 46 | u 0:4 w lines lw 2 title "per CE" axis x1y2 47 | } 48 | 49 | if (ARG2 ne '-i') { 50 | pause -1 # just hang in there 51 | exit 52 | } 53 | 54 | pause 5 # wait 5 seconds 55 | reread # and start over 56 | -------------------------------------------------------------------------------- /utils/update_experiments.sh: -------------------------------------------------------------------------------- 1 | # Pull and push latest experiment results 2 | # Alfredo Canziani, Mar 17 3 | 4 | # Run it as 5 | # ./update_experiments.sh 6 | # ./update_experiments.sh -q # quiet 7 | # ./update_experiments.sh -i # iteratively / 5 seconds loop 8 | 9 | if [ "$1" == "-i" ]; then 10 | function ctrl_c { 11 | echo -e "\nExiting." 12 | exit 0 13 | } 14 | trap ctrl_c INT 15 | echo -n "Syncing experiments every 5 seconds" 16 | while true; do 17 | ./update_experiments.sh -q 18 | # prints one "." every second for 5 seconds, then removes them 19 | for (( i = 0; i < 5; i++ )); do 20 | printf "." 
21 | sleep 1 22 | done 23 | printf "%.s\b" {1..5} 24 | printf "%.s " {1..5} 25 | printf "%.s\b" {1..5} 26 | done 27 | fi 28 | 29 | if [ "$1" != "-q" ]; then verbose="--verbose"; fi 30 | 31 | local=$(hostname) 32 | case $local in 33 | "GPU0") remote=$GPU8;; 34 | "GPU8") remote=$GPU0;; 35 | *) echo "Something's wrong"; exit -1;; 36 | esac 37 | 38 | # Get experimental data 39 | if [ -n "$verbose" ]; then 40 | echo; printf "%.s#" {1..80}; echo 41 | echo -n "Getting experiments from $remote to $local" 42 | echo; printf "%.s#" {1..80}; echo; echo 43 | fi 44 | rsync \ 45 | --update \ 46 | --archive \ 47 | $verbose \ 48 | --human-readable \ 49 | $remote:MatchNet/results/ \ 50 | ../results 51 | 52 | # Send experimental data 53 | if [ -n "$verbose" ]; then 54 | echo; printf "%.s#" {1..80}; echo 55 | echo -n "Sending experiments to $remote from $local" 56 | echo; printf "%.s#" {1..80}; echo; echo 57 | fi 58 | rsync \ 59 | --update \ 60 | --archive \ 61 | $verbose \ 62 | --human-readable \ 63 | ../results/ \ 64 | $remote:MatchNet/results 65 | -------------------------------------------------------------------------------- /utils/visualise.py: -------------------------------------------------------------------------------- 1 | from graphviz import Digraph 2 | from torch.autograd import Variable 3 | import sys, subprocess 4 | import uuid 5 | 6 | 7 | def make_dot(root): 8 | node_attr = dict(style='filled', 9 | shape='box', 10 | align='left', 11 | fontsize='12', 12 | ranksep='0.1', 13 | height='0.2') 14 | dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12")) 15 | seen = set() 16 | 17 | def add_nodes(var): 18 | if var not in seen: 19 | if isinstance(var, Variable): 20 | value = '(' + ', '.join(['%d'% v for v in var.size()]) + ')' 21 | dot.node(str(id(var)), str(value), fillcolor='lightblue') 22 | else: 23 | dot.node(str(id(var)), str(type(var).__name__)) 24 | seen.add(var) 25 | if hasattr(var, 'previous_functions'): 26 | for u in var.previous_functions: 27 | dot.edge(str(id(u[0])), str(id(var))) 28 | add_nodes(u[0]) 29 | add_nodes(root.creator) 30 | return dot 31 | 32 | 33 | def show_graph(root): 34 | dot_file_name = '/tmp/' + str(uuid.uuid4()) 35 | make_dot(root).render(dot_file_name) 36 | pdf_file_name = dot_file_name + '.pdf' 37 | if sys.platform == 'darwin': 38 | subprocess.call(('open', pdf_file_name)) 39 | elif sys.platform == 'linux': 40 | subprocess.call(('xdg-open', pdf_file_name)) 41 | 42 | 43 | __author__ = "Sergey Zagoruyko and Alfredo Canziani" 44 | __credits__ = ["Sergey Zagoruyko", "Alfredo Canziani"] 45 | __maintainer__ = "Alfredo Canziani" 46 | __email__ = "alfredo.canziani@gmail.com" 47 | __status__ = "Production" # "Prototype", "Development", or "Production" 48 | __date__ = "Feb 17" 49 | --------------------------------------------------------------------------------
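For reference, a minimal sketch of how `visualise.py` might be exercised on its own — an illustration only, not part of the repository. It assumes the old `Variable`-based autograd API (with a `.creator` attribute), which is what both this module and the notebook above rely on, and it assumes the working directory is `utils/`.

```python
# Hypothetical smoke test for visualise.py (old PyTorch autograd API with .creator).
import torch
import torchvision
from torch.autograd import Variable

from visualise import make_dot, show_graph

model = torchvision.models.resnet18()
y = model(Variable(torch.randn(1, 3, 224, 224)))

dot = make_dot(y)      # build a graphviz Digraph of the autograd graph
dot.render('net.dot')  # writes net.dot and net.dot.pdf in the current directory
# show_graph(y)        # alternatively, render to /tmp/<uuid>.pdf and open it with the system viewer
```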