├── .gitignore
├── .idea
│   ├── dictionaries
│   │   └── atcold.xml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── pytorch-MatchNet.iml
│   └── vcs.xml
├── README.md
├── data
│   ├── README.md
│   ├── VideoFolder.py
│   ├── add_frame_numbering.sh
│   ├── dump_data_set.sh
│   ├── objectify.sh
│   ├── resize_and_sample.sh
│   ├── resize_and_split.sh
│   ├── sample_video.sh
│   └── small_data_set
│       ├── cup
│       │   └── sfsdfs-nb.mp4
│       └── hand
│           ├── hand5-nb.mp4
│           └── hand_1-nb.mp4
├── image-pretraining
│   ├── README.md
│   ├── main.py
│   └── model
├── main.py
├── model
│   ├── ConvLSTMCell.py
│   ├── DiscriminativeCell.py
│   ├── GenerativeCell.py
│   ├── Model01.py
│   ├── Model02.py
│   ├── PrednetModel.py
│   ├── README.md
│   ├── RG.py
│   ├── models.svg
│   └── utils
├── new_experiment.sh
├── notebook
│   ├── README.md
│   ├── data
│   ├── display_loss.ipynb
│   ├── figures_generator.ipynb
│   ├── frequency_analysis.ipynb
│   ├── get_all_embeddings.ipynb
│   ├── get_data_stats.ipynb
│   ├── model
│   ├── network_bisection.ipynb
│   ├── plot_conf.py
│   ├── salient_regions.ipynb
│   ├── stability_analysis.ipynb
│   ├── utils
│   └── verify_resnet.ipynb
└── utils
    ├── README.md
    ├── check_exp_diff.sh
    ├── image_plot.py
    ├── show_error.plt
    ├── show_error_exp.plt
    ├── update_experiments.sh
    └── visualise.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Remove only PyCharm's workspace configuration
2 | workspace.xml
3 |
4 | # Remove [I]Python caching
5 | __pycache__
6 | .ipynb_checkpoints/
7 |
8 | # Remove Vim temp files
9 | *.sw*
10 |
11 | # Remove other unnecessary stuff
12 | *.dot*
13 | *.tar
14 | *.txt
15 | *.jpg
16 | *.png
17 | *.pdf
18 | *.mp4
19 | *.svg
20 | *.pth
21 |
22 | # Remove experiments
23 | last
24 | results
25 | *backup
26 |
--------------------------------------------------------------------------------
/.idea/dictionaries/atcold.xml:
--------------------------------------------------------------------------------
1 | <component name="ProjectDictionaryState">
2 |   <dictionary name="atcold">
3 |     <words>
4 |       <w>canziani</w>
5 |       <w>conv</w>
6 |       <w>convolutional</w>
7 |       <w>criterions</w>
8 |       <w>cuda</w>
9 |       <w>fromarray</w>
10 |       <w>intra</w>
11 |       <w>logits</w>
12 |       <w>lstm</w>
13 |       <w>matplotlib</w>
14 |       <w>optim</w>
15 |       <w>prednet</w>
16 |       <w>resizes</w>
17 |       <w>sergey</w>
18 |       <w>subsamples</w>
19 |       <w>tanh</w>
20 |       <w>zagoruyko</w>
21 |     </words>
22 |   </dictionary>
23 | </component>
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/pytorch-MatchNet.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # *CortexNet*
2 |
3 | This repo contains the *PyTorch* implementation of *CortexNet*.
4 | Check the [project website](https://engineering.purdue.edu/elab/CortexNet/) for further information.
5 |
6 | ## Project structure
7 |
8 | The project consists of the following folders and files:
9 |
10 | - [`data/`](data): contains *Bash* scripts and a *Python* class definition for video data loading;
11 | - [`image-pretraining/`](image-pretraining/): hosts the code for pre-training TempoNet's discriminative branch;
12 | - [`model/`](model): stores several network architectures, including [*PredNet*](https://coxlab.github.io/prednet/), an additive feedback *Model01*, and a modulatory feedback *Model02* ([*CortexNet*](https://engineering.purdue.edu/elab/CortexNet/));
13 | - [`notebook/`](notebook): collection of *Jupyter Notebook*s for data exploration and results visualisation;
14 | - [`utils/`](utils): scripts for
15 | - (current or former) training error plotting,
16 | - experiments `diff`,
17 | - multi-node synchronisation,
18 | - generative predictions visualisation,
19 | - network architecture graphing;
20 | - `results@`: link to the location where experimental results will be saved within 3-digit folders;
21 | - [`new_experiment.sh*`](new_experiment.sh): creates a new experiment folder, updates `last@`, prints a memo about last used settings;
22 | - `last@`: symbolic link pointing to a new results sub-directory created by `new_experiment.sh`;
23 | - [`main.py`](main.py): training script for *CortexNet* in *MatchNet* or *TempoNet* configuration;
24 |
25 | ## Dependencies
26 |
27 | + [*scikit-video*](https://github.com/scikit-video/scikit-video): accessing images / videos
28 |
29 | ```bash
30 | pip install sk-video
31 | ```
32 |
33 | + [*tqdm*](https://github.com/tqdm/tqdm): progress bar
34 |
35 | ```bash
36 | conda config --add channels conda-forge
37 | conda update --all
38 | conda install tqdm
39 | ```
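
A quick, optional way to verify that both packages can be imported:

```bash
python -c "import skvideo.io, tqdm; print('dependencies OK')"
```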
40 |
41 | ## IDE
42 |
43 | This project has been realised with [*PyCharm*](https://www.jetbrains.com/pycharm/) by *JetBrains* and the [*Vim*](http://www.vim.org/) editor.
44 | [*Grip*](https://github.com/joeyespo/grip) has also been fundamental for crafting decent documentation locally.
45 |
46 | ## Initialise environment
47 |
48 | Once you've determined where you'd like to save your experimental results — let's call this directory `<results location>` — run the following commands from the project's root directory:
49 |
50 | ```bash
51 | ln -s <results location> results  # replace <results location> with your actual path
52 | mkdir results/000 && touch results/000/train.log # init. placeholder
53 | ln -s results/000 last # create pointer to the most recent result
54 | ```
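
If you want to double-check the links (optional), `ls -ld` shows where they point — `results` to the location you chose, `last` to `results/000`:

```bash
ls -ld results last
```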
55 |
56 | ## Setup new experiment
57 |
58 | Ready to run your first experiment?
59 | Type the following:
60 |
61 | ```bash
62 | ./new_experiment.sh
63 | ```
64 |
65 | ### GPU selection
66 |
67 | Let's say your machine has `N` GPUs.
68 | You can choose any one of them by specifying its index `n = 0, ..., N-1`.
69 | Therefore, type `CUDA_VISIBLE_DEVICES=n` just before `python ...` in the following sections.
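
For example, to run the *MatchNet* training command from the section below on GPU `2` (an arbitrary index, assuming your machine has at least three GPUs):

```bash
CUDA_VISIBLE_DEVICES=2 python -u main.py --mode MatchNet | tee last/train.log
```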
70 |
71 | ## Train *MatchNet*
72 |
73 | + Download *e-VDS35* (*e.g.* `e-VDS35-May17.tar`) from [here](https://engineering.purdue.edu/elab/eVDS/).
74 | + Use [`data/resize_and_split.sh`](data/resize_and_split.sh) to prepare your (video) data for training.
75 | It resizes videos present in folders of folders (*i.e.* a directory of classes) and may split them into training and validation sets.
76 | It may also skip short videos and trim longer ones.
77 | Check [`data/README.md`](data/README.md#matchnet-mode) for more details.
78 | + Run the [`main.py`](main.py) script to start training.
79 | Use `-h` to print the command line interface (CLI) arguments help.
80 |
81 | ```bash
82 | python -u main.py --mode MatchNet | tee last/train.log
83 | ```
84 |
85 | ## Train *TempoNet*
86 |
87 | + Download *e-VDS35* (*e.g.* `e-VDS35-May17.tar`) from [here](https://engineering.purdue.edu/elab/eVDS/).
88 | + Pre-train the forward branch (see [`image-pretraining/`](image-pretraining)) on an image data set (*e.g.* `33-image-set.tar` from [here](https://engineering.purdue.edu/elab/eVDS/));
89 | + Use [`data/resize_and_sample.sh`](data/resize_and_sample.sh) to prepare your (video) data for training.
90 | It resizes videos present in folders of folders (*i.e.* a directory of classes) and samples them.
91 | Videos are then distributed across the training and validation sets.
92 | It may also skip short videos and trim longer ones.
93 | Check [`data/README.md`](data/README.md#temponet-mode) for more details.
94 | + Run the [`main.py`](main.py) script to start training.
95 | Use `-h` to print the CLI arguments help.
96 |
97 | ```bash
98 | python -u main.py --mode TempoNet --pre-trained | tee last/train.log
99 | ```
100 |
101 | ## GPU selection
102 |
103 | To run on a specific GPU, say `n`, type `CUDA_VISIBLE_DEVICES=n` just before `python ...`.
104 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Video data pre-processing
2 |
3 | This folder contains the following scripts:
4 |
5 | - [`add_frame_numbering.sh*`](add_frame_numbering.sh): draws a huge frame number on each frame of a specific video;
6 | - [`dump_data_set.sh*`](dump_data_set.sh): extracts images from videos for ["traditional training"](https://github.com/pytorch/examples/tree/master/imagenet);
7 | - [`objectify.sh*`](objectify.sh): converts sampled-data from video-indexed to object-indexed;
8 | - [`resize_and_split.sh*`](resize_and_split.sh): see [below](#matchnet-mode);
9 | - [`resize_and_sample.sh*`](resize_and_sample.sh): see [below](#temponet-mode);
10 | - [`sample_video.sh*`](sample_video.sh): samples the `<src_video>` into `k` time-subsampled `<dst_prefix>-<i>.mp4` videos;
11 | - [`VideoFolder.py`](VideoFolder.py): *PyTorch* `data.Dataset`'s sub-class for video data loading.
12 |
13 | ## Remove white spaces from file names
14 | White spaces and scripts are not good friends.
15 | Replace all spaces with `_` in your source data set with the following command, where you have to replace `<source data set path>` with the correct location.
16 |
17 | ```bash
18 | rename -n "s/ /_/g" <source data set path>/*/*  # for a dry run
19 | rename "s/ /_/g" <source data set path>/*/*  # to rename the files in <source data set path>!!!
20 | ```
21 |
22 | In *e-VDS35* we have already done this for you.
23 | If you plan to use your own data, this step is fundamental!
24 |
25 | ## MatchNet mode
26 | ### Resize and *split* videos in train-val
27 |
28 | In order to speed up data loading, we shall resize the video shortest side to, say, `256`.
29 | To do so run the following script.
30 |
31 | ```bash
32 | ./resize_and_split.sh <source dir> <destination dir>
33 | ```
34 |
35 | By default, the script will
36 |
37 | - skip videos shorter than `144` frames (`4.8`s)
38 | - trim videos longer than `654` frames (`21.8`s)
39 | - use the last `2`s for the validation split
40 | - resize the shortest side to `256`px
41 |
42 | These options can be varied and turned off by changing the header of the script, which now looks something like this.
43 |
44 | ```bash
45 | ################################################################################
46 | # SETTINGS #####################################################################
47 | ################################################################################
48 | # comment the next line if you don't want to skip short videos
49 | min_frames=144
50 | # comment the next line if you don't want to limit the max length
51 | max_frames=564
52 | # set split to 0 (seconds) if no splitting is required
53 | split=2
54 | ################################################################################
55 | ```
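
For example, to keep the 2-second validation split but stop trimming long videos, you could comment out the `max_frames` line (an illustrative edit, not a new default):

```bash
min_frames=144
# max_frames=564   # commented out: long videos are no longer trimmed
split=2
```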
56 |
57 | ### File system
58 |
59 | Our input `data_set` looks something like this.
60 |
61 | ```bash
62 | data_set
63 | ├── barcode
64 | │ ├── 20160613_140057.mp4
65 | │ ├── 20160613_140115.mp4
66 | │ ├── 20160613_140138.mp4
67 | │ ├── 20160721_023437.mp4
68 | ├── bicycle
69 | │ ├── 0707_2_(2).mov
70 | │ ├── 0707_2_(4).mov
71 | ```
72 |
73 | Running `./resize_and_split.sh data_set/ processed-data` yields
74 |
75 | ```bash
76 | processed-data/
77 | ├── train
78 | │ ├── barcode
79 | │ │ ├── 20160613_140057.mp4
80 | │ │ ├── 20160613_140115.mp4
81 | │ │ ├── 20160613_140138.mp4
82 | │ │ ├── 20160721_023437.mp4
83 | │ ├── bicycle
84 | │ │ ├── 0707_2_(2).mp4
85 | │ │ ├── 0707_2_(4).mp4
86 | ```
87 |
88 | ## TempoNet mode
89 | ### Resize and *sample* videos, then split into train-val
90 |
91 | In order to speed up data loading, we shall resize the video shortest side to, say, `256`.
92 | To do so run the following script.
93 |
94 | ```bash
95 | ./resize_and_sample.sh <source dir> <destination dir>
96 | ```
97 |
98 | By default, the script will
99 |
100 | - skip videos shorter than `144` frames (`4.8`s)
101 | - trim videos longer than `654` frames (`21.8`s)
102 | - perform `5` subsamples, use `4` for training, `1` for validation
103 | - resize the shortest side to `256`px
104 |
105 | These options can be varied and turned off by changing the header of the script, which now looks something like this.
106 |
107 | ```bash
108 | ################################################################################
109 | # SETTINGS #####################################################################
110 | ################################################################################
111 | # comment the next line if you don't want to skip short videos
112 | min_frames=144
113 | # comment the next line if you don't want to limit the max length
114 | max_frames=564
115 | # set sampling interval: k - 1 train, 1 val
116 | k=5
117 | ################################################################################
118 | ```
119 |
120 | The output directory will contain **as many folders as the total number of videos**.
121 | Each folder will contain the individual splits.
122 |
123 | ### From video-index- to class-major data organisation
124 |
125 | If you would like to train against object classes instead of video indices (as explained in the paper), you also need to run
126 |
127 | ```bash
128 | ./objectify.sh <source dir> <destination dir>
129 | ```
130 |
131 | This script generates a new directory containing **as many folders as classes**, filled with symbolic links from the source directory.
132 | You can use it for both videos and dumped-videos (images) data sets.
133 |
134 | ### File system
135 |
136 | Running `./resize_and_sample.sh data_set/ sampled-data` yields
137 |
138 | ```bash
139 | sampled-data/
140 | ├── train
141 | │ ├── barcode-20160613_140057
142 | │ │ ├── 1.mp4
143 | │ │ ├── 2.mp4
144 | │ │ ├── 3.mp4
145 | │ │ └── 4.mp4
146 | │ ├── barcode-20160613_140115
147 | │ │ ├── 1.mp4
148 | │ │ ├── 2.mp4
149 | ```
150 |
151 | We can "objectify" the structure with `./objectify.sh sampled-data/ object-sampled-data` and get
152 |
153 | ```bash
154 | object-sampled-data/
155 | ├── train
156 | │ ├── barcode
157 | │ │ ├── 20160613_140057-1.mp4 -> ../../../sampled-data/train/barcode-20160613_140057/1.mp4
158 | │ │ ├── 20160613_140057-2.mp4 -> ../../../sampled-data/train/barcode-20160613_140057/2.mp4
159 | │ │ ├── 20160613_140057-3.mp4 -> ../../../sampled-data/train/barcode-20160613_140057/3.mp4
160 | │ │ ├── 20160613_140057-4.mp4 -> ../../../sampled-data/train/barcode-20160613_140057/4.mp4
161 | │ │ ├── 20160613_140115-1.mp4 -> ../../../sampled-data/train/barcode-20160613_140115/1.mp4
162 | │ │ ├── 20160613_140115-2.mp4 -> ../../../sampled-data/train/barcode-20160613_140115/2.mp4
163 | │ │ ├── 20160613_140115-3.mp4 -> ../../../sampled-data/train/barcode-20160613_140115/3.mp4
164 | ```
165 |
166 | To train the discriminative feed-forward branch we need to dump our `sampled-data` with `./dump_data_set.sh sampled-data/ dumped-sampled-data`.
167 | The file system will look like this
168 |
169 | ```bash
170 | dumped-sampled-data/
171 | ├── train
172 | │ ├── barcode-20160613_140057
173 | │ │ ├── 1001.png
174 | │ │ ├── 1002.png
175 | │ │ ├── 1003.png
176 | │ │ ├── 1004.png
177 | │ │ ├── 1005.png
178 | │ │ ├── 1006.png
179 | │ │ ├── 1007.png
180 | ```
181 |
182 | where the "training class" corresponds to the **video name**.
183 | If we wish to train against **object classes**, then we can run `./objectify.sh dumped-sampled-data/ dumped-object-sampled-data` and get the following
184 |
185 | ```bash
186 | dumped-object-sampled-data/
187 | ├── train
188 | │ ├── barcode
189 | │ │ ├── 20160613_140057-1001.png -> ../../../dumped-sampled-data/train/barcode-20160613_140057/1001.png
190 | │ │ ├── 20160613_140057-1002.png -> ../../../dumped-sampled-data/train/barcode-20160613_140057/1002.png
191 | │ │ ├── 20160613_140057-1003.png -> ../../../dumped-sampled-data/train/barcode-20160613_140057/1003.png
192 | │ │ ├── 20160613_140057-1004.png -> ../../../dumped-sampled-data/train/barcode-20160613_140057/1004.png
193 | │ │ ├── 20160613_140057-1005.png -> ../../../dumped-sampled-data/train/barcode-20160613_140057/1005.png
194 | │ │ ├── 20160613_140057-1006.png -> ../../../dumped-sampled-data/train/barcode-20160613_140057/1006.png
195 | │ │ ├── 20160613_140057-1007.png -> ../../../dumped-sampled-data/train/barcode-20160613_140057/1007.png
196 | ```
197 |
--------------------------------------------------------------------------------
/data/VideoFolder.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import torch
3 | import torch.utils.data as data
4 |
5 | from random import shuffle as list_shuffle # for shuffling list
6 | from math import ceil
7 | from os import listdir
8 | from os.path import isdir, join, isfile
9 | from itertools import islice
10 | from numpy.core.multiarray import concatenate, ndarray
11 | from skvideo.io import FFmpegReader, ffprobe
12 | from torch.utils.data.sampler import Sampler
13 | from torchvision import transforms as trn
14 | from tqdm import tqdm
15 | from time import sleep
16 | from bisect import bisect
17 |
18 | # Implement object from https://discuss.pytorch.org/t/loading-videos-from-folders-as-a-dataset-object/568
19 |
20 | VIDEO_EXTENSIONS = ['.mp4'] # pre-processing outputs MP4s only
21 |
22 |
23 | class BatchSampler(Sampler):
24 | def __init__(self, data_source, batch_size):
25 | """
26 | Samples batches sequentially, always in the same order.
27 |
28 | :param data_source: data set to sample from
29 | :type data_source: Dataset
30 | :param batch_size: concurrent number of video streams
31 | :type batch_size: int
32 | """
33 | self.batch_size = batch_size
34 | self.samples_per_row = ceil(len(data_source) / batch_size)
35 | self.num_samples = self.samples_per_row * batch_size
36 |
37 | def __iter__(self):
38 | return (self.samples_per_row * i + j for j in range(self.samples_per_row) for i in range(self.batch_size))
39 |
40 | def __len__(self):
41 | return self.num_samples # fake nb of samples, transparent wrapping around
42 |
43 |
44 | class VideoCollate:
45 | def __init__(self, batch_size):
46 | self.batch_size = batch_size
47 |
48 | def __call__(self, batch: iter) -> torch.Tensor or list(torch.Tensor):
49 | """
50 | Puts each data field into a tensor with outer dimension batch size
51 |
52 | :param batch: samples from a Dataset object
53 | :type batch: list
54 | :return: temporal batch of frames of size (t, batch_size, *frame.size()), 0 <= t < T, most likely t = T - 1
55 | :rtype: tuple
56 | """
57 | if torch.is_tensor(batch[0]):
58 | return torch.cat(tuple(t.unsqueeze(0) for t in batch), 0).view(-1, self.batch_size, *batch[0].size())
59 | elif isinstance(batch[0], int):
60 | return torch.LongTensor(batch).view(-1, self.batch_size)
61 | elif isinstance(batch[0], collections.Iterable):
62 | # if each batch element is not a tensor, then it should be a tuple
63 | # of tensors; in that case we collate each element in the tuple
64 | transposed = zip(*batch)
65 | return tuple(self.__call__(samples) for samples in transposed)
66 |
67 | raise TypeError(("batch must contain tensors, numbers, or lists; found {}"
68 | .format(type(batch[0]))))
69 |
70 |
71 | class VideoFolder(data.Dataset):
72 | def __init__(self, root, transform=None, target_transform=None, video_index=False, shuffle=None):
73 | """
74 | Initialise a ``data.Dataset`` object for concurrent frame fetching from videos in a directory of folders of videos
75 |
76 | :param root: Data directory (train or validation folders path)
77 | :type root: str
78 | :param transform: image transform-ing object from ``torchvision.transforms``
79 | :type transform: object
80 | :param target_transform: label transformation / mapping
81 | :type target_transform: object
82 | :param video_index: if ``True``, the label will be the video index instead of target class
83 | :type video_index: bool
84 | :param shuffle: ``None``, ``'init'`` or ``True``
85 | :type shuffle: str
86 | """
87 | classes, class_to_idx = self._find_classes(root)
88 | video_paths = self._find_videos(root, classes)
89 | videos, frames, frames_per_video, frames_per_class = self._make_data_set(
90 | root, video_paths, class_to_idx, shuffle, video_index
91 | )
92 |
93 | self.root = root
94 | self.video_paths = video_paths
95 | self.videos = videos
96 | self.opened_videos = [[] for _ in videos]
97 | self.frames = frames
98 | self.frames_per_video = frames_per_video
99 | self.frames_per_class = frames_per_class
100 | self.classes = classes
101 | self.class_to_idx = class_to_idx
102 | self.transform = transform
103 | self.target_transform = target_transform
104 | self.alternative_target = video_index
105 | self.shuffle = shuffle
106 |
107 | def __getitem__(self, frame_idx):
108 | if frame_idx == 0:
109 | self.free()
110 | if self.shuffle is True:
111 | self._shuffle()
112 |
113 | frame_idx %= self.frames # wrap around indexing, if asking too much
114 | video_idx = bisect(self.videos, ((frame_idx,),)) # video to which frame_idx belongs
115 | (last, first), (path, target) = self.videos[video_idx] # get video metadata
116 | frame = self._get_frame(frame_idx - first, video_idx, frame_idx == last) # get frame from video
117 | if self.transform is not None: # image processing
118 | frame = self.transform(frame)
119 | if self.target_transform is not None: # target processing
120 | target = self.target_transform(target)
121 |
122 | if self.alternative_target: return frame, video_idx
123 |
124 | return frame, target
125 |
126 | def __len__(self):
127 | return self.frames
128 |
129 | def _get_frame(self, seek, video_idx, last):
130 |
131 | opened_video = None # handle to opened target video
132 | if self.opened_videos[video_idx]: # if handle(s) exists for target video
133 | current = self.opened_videos[video_idx] # get handles list
134 | opened_video = next((ov for ov in current if ov[0] == seek), None) # look for matching seek
135 |
136 | if opened_video is None: # no (matching) handle found
137 | video_path = join(self.root, self.videos[video_idx][1][0]) # build video path
138 | video_file = FFmpegReader(video_path) # get a video file pointer
139 | video_iter = video_file.nextFrame() # get an iterator
140 | opened_video = [seek, islice(video_iter, seek, None), video_file] # seek video and create o.v. item
141 | self.opened_videos[video_idx].append(opened_video) # add opened video object to o.v. list
142 |
143 | opened_video[0] = seek + 1 # update seek pointer
144 | frame = next(opened_video[1]) # cache output frame
145 | if last:
146 | opened_video[2]._close() # close video file (private method?!)
147 | self.opened_videos[video_idx].remove(opened_video) # remove o.v. item
148 |
149 | return frame
150 |
151 | def free(self):
152 | """
153 | Frees all video files' pointers
154 | """
155 | for video in self.opened_videos: # for every opened video
156 | for _ in range(len(video)): # for as many times as pointers
157 | opened_video = video.pop() # pop an item
158 | opened_video[2]._close() # close the file
159 |
160 | def _shuffle(self):
161 | """
162 | Shuffles the video list
163 | by regenerating the sequence to sample sequentially
164 | """
165 | def _is_video_file(filename_):
166 | return any(filename_.endswith(extension) for extension in VIDEO_EXTENSIONS)
167 |
168 | root = self.root
169 | video_paths = self.video_paths
170 | class_to_idx = self.class_to_idx
171 | list_shuffle(video_paths) # shuffle
172 |
173 | videos = list()
174 | frames_per_video = list()
175 | frames_counter = 0
176 | for filename in tqdm(video_paths, ncols=80):
177 | class_ = filename.split('/')[0]
178 | data_path = join(root, filename)
179 | if _is_video_file(data_path):
180 | video_meta = ffprobe(data_path)
181 | start_idx = frames_counter
182 | frames = int(video_meta['video'].get('@nb_frames'))
183 | frames_per_video.append(frames)
184 | frames_counter += frames
185 | item = ((frames_counter - 1, start_idx), (filename, class_to_idx[class_]))
186 | videos.append(item)
187 |
188 | sleep(0.5) # allows for progress bar completion
189 | # update the attributes with the altered sequence
190 | self.video_paths = video_paths
191 | self.videos = videos
192 | self.frames = frames_counter
193 | self.frames_per_video = frames_per_video
194 |
195 | @staticmethod
196 | def _find_classes(data_path):
197 | classes = [d for d in listdir(data_path) if isdir(join(data_path, d))]
198 | classes.sort()
199 | class_to_idx = {classes[i]: i for i in range(len(classes))}
200 | return classes, class_to_idx
201 |
202 | @staticmethod
203 | def _find_videos(root, classes):
204 | return [join(c, d) for c in classes for d in listdir(join(root, c))]
205 |
206 | @staticmethod
207 | def _make_data_set(root, video_paths, class_to_idx, init_shuffle, video_index):
208 | def _is_video_file(filename_):
209 | return any(filename_.endswith(extension) for extension in VIDEO_EXTENSIONS)
210 |
211 | if init_shuffle and not video_index:
212 | list_shuffle(video_paths) # shuffle
213 |
214 | videos = list()
215 | frames_per_video = list()
216 | frames_per_class = [0] * len(class_to_idx)
217 | frames_counter = 0
218 | for filename in tqdm(video_paths, ncols=80):
219 | class_ = filename.split('/')[0]
220 | data_path = join(root, filename)
221 | if _is_video_file(data_path):
222 | video_meta = ffprobe(data_path)
223 | start_idx = frames_counter
224 | frames = int(video_meta['video'].get('@nb_frames'))
225 | frames_per_video.append(frames)
226 | frames_per_class[class_to_idx[class_]] += frames
227 | frames_counter += frames
228 | item = ((frames_counter - 1, start_idx), (filename, class_to_idx[class_]))
229 | videos.append(item)
230 |
231 | sleep(0.5) # allows for progress bar completion
232 | return videos, frames_counter, frames_per_video, frames_per_class
233 |
234 |
235 | def _test_video_folder():
236 | from textwrap import fill, indent
237 |
238 | batch_size = 5
239 |
240 | video_data_set = VideoFolder('small_data_set/')
241 | nb_of_classes = len(video_data_set.classes)
242 | print('There are', nb_of_classes, 'classes')
243 | print(indent(fill(' '.join(video_data_set.classes), 77), ' '))
244 | print('There are {} frames'.format(len(video_data_set)))
245 | print('Videos in the data set:', *video_data_set.videos, sep='\n')
246 |
247 | import inflect
248 | ordinal = inflect.engine().ordinal
249 |
250 | def print_list(my_list):
251 | for a, b in enumerate(my_list):
252 | print(a, ':', end=' [')
253 | print(*b, sep=',\n ', end=']\n')
254 |
255 | # get first 3 batches
256 | n = ceil(len(video_data_set) / batch_size)
257 | print('Batch size:', batch_size)
258 | print('Frames per row:', n)
259 | for big_j in range(0, n, 90):
260 | batch = list()
261 | for j in range(big_j, big_j + 90):
262 | if j >= n: break # there are no more frames
263 | batch.append(tuple(video_data_set[i * n + j][0] for i in range(batch_size)))
264 | batch[-1] = concatenate(batch[-1], 0)
265 | batch = concatenate(batch, 1)
266 | _show_numpy(batch, 1e-1)
267 | print(ordinal(big_j // 90 + 1), '90 batches of shape', batch.shape)
268 | print_list(video_data_set.opened_videos)
269 |
270 | print('Freeing resources')
271 | video_data_set.free()
272 | print_list(video_data_set.opened_videos)
273 |
274 | # get frames 50 -> 52
275 | batch = list()
276 | for i in range(50, 53):
277 | batch.append(video_data_set[i][0])
278 | _show_numpy(concatenate(batch, 1))
279 | print_list(video_data_set.opened_videos)
280 |
281 |
282 | def _test_data_loader():
283 | big_t = 10
284 | batch_size = 5
285 | t = trn.Compose((trn.ToPILImage(), trn.ToTensor())) # <-- add trn.CenterCrop(224) in between for training
286 | data_set = VideoFolder('small_data_set', t)
287 | my_loader = data.DataLoader(dataset=data_set, batch_size=batch_size * big_t, shuffle=False,
288 | sampler=BatchSampler(data_set, batch_size), num_workers=0,
289 | collate_fn=VideoCollate(batch_size))
290 | print('Is my_loader an iterator [has __next__()]:', isinstance(my_loader, collections.Iterator))
291 | print('Is my_loader an iterable [has __iter__()]:', isinstance(my_loader, collections.Iterable))
292 | my_iter = iter(my_loader)
293 | my_batch = next(my_iter)
294 | print('my_batch is a', type(my_batch), 'of length', len(my_batch))
295 | print('my_batch[0] is a', my_batch[0].type(), 'of size', tuple(my_batch[0].size()), ' # will 224, 224')
296 | _show_torch(_tile_up(my_batch), .2)
297 | for i in range(3): _show_torch(_tile_up(next(my_iter)), .2)
298 |
299 |
300 | def _show_numpy(tensor: ndarray, zoom: float = 1.) -> None:
301 | """
302 | Display a ndarray image on screen
303 |
304 | :param tensor: image to visualise, of size (h, w, 1/3)
305 | :type tensor: ndarray
306 | :param zoom: zoom factor
307 | :type zoom: float
308 | """
309 | from PIL import Image
310 | shape = tuple(map(lambda s: round(s * zoom), tensor.shape))
311 | Image.fromarray(tensor).resize((shape[1], shape[0])).show()
312 |
313 |
314 | def _show_torch(tensor: torch.FloatTensor, zoom: float = 1.) -> None:
315 | numpy_tensor = tensor.clone().mul(255).int().numpy().astype('u1').transpose(1, 2, 0)
316 | _show_numpy(numpy_tensor, zoom)
317 |
318 |
319 | def _tile_up(temporal_batch):
320 | a = torch.cat(tuple(temporal_batch[0][:, i] for i in range(temporal_batch[0].size(1))), 2)
321 | a = torch.cat(tuple(a[j] for j in range(a.size(0))), 2)
322 | return a
323 |
324 |
325 | if __name__ == '__main__':
326 | _test_video_folder()
327 | _test_data_loader()
328 |
329 |
330 | __author__ = "Alfredo Canziani"
331 | __credits__ = ["Alfredo Canziani"]
332 | __maintainer__ = "Alfredo Canziani"
333 | __email__ = "alfredo.canziani@gmail.com"
334 | __status__ = "Production" # "Prototype", "Development", or "Production"
335 | __date__ = "Feb 17"
336 |
--------------------------------------------------------------------------------
/data/add_frame_numbering.sh:
--------------------------------------------------------------------------------
1 | # Add frame number to video, to check correct loading
2 | # ./add_frame_numbering.sh 256min_data_set/barcode/20160613_140057.mp4
3 | # generate labelled small_data_set/barcode/20160613_140057-nb.mp4
4 |
5 | src_video=$1
6 | dst_dir="small_data_set"
7 | dst_video="${src_video%.*}-nb.mp4"
8 | dst_video="$dst_dir/${dst_video#*/}"
9 | dst_dir="${dst_video%/*}"
10 |
11 | mkdir -p $dst_dir
12 | ffmpeg \
13 | -i $src_video \
14 | -filter:v "drawtext=fontsize=200:fontfile=Arial.ttf: text=%{n}: x=(w-tw)/2: y=50:fontcolor=white: box=1: boxcolor=0x00000099" \
15 | -loglevel quiet \
16 | $dst_video
17 | printf "$src_video --> $dst_video\n"
18 |
--------------------------------------------------------------------------------
/data/dump_data_set.sh:
--------------------------------------------------------------------------------
1 | ################################################################################
2 | # Dump data set
3 | ################################################################################
4 | # Alfredo Canziani, Apr 17
5 | ################################################################################
6 | # Run as:
7 | # ./dump_data_set.sh src_path/ dst_path/
8 | ################################################################################
9 |
10 | # some colours
11 | r='\033[0;31m' # red
12 | g='\033[0;32m' # green
13 | b='\033[0;34m' # blue
14 | n='\033[0m' # none
15 |
16 | # title
17 | echo "Dumping video data set into separate frames"
18 |
19 | # assert existence of source directory
20 | src_dir=${1%/*} # remove trailing /, if present
21 | if [ ! -d $src_dir ] || [ -z $src_dir ]; then
22 | echo -e "${r}Source directory/link \"$src_dir\" is missing. Exiting.${n}"
23 | exit 1
24 | fi
25 | echo -e " - Source directory/link set to \"$b$src_dir$n\""
26 |
27 | # assert existence of destination directory
28 | dst_dir="${2%/*}" # remove trailing /, if present
29 | if [ -d $dst_dir ]; then
30 | echo -e "${r}Destination directory \"$dst_dir\" already existent." \
31 | "Exiting.${n}"
32 | exit 1
33 | fi
34 | echo -e " - Destination directory set to \"$b$dst_dir$n\""
35 |
36 | # check if all is good
37 | printf "Does it look fine? (${g}y${g}${n}/${r}n${n}) "
38 | read ans
39 | if [ $ans == 'n' ]; then
40 | echo -e "${r}Exiting.${n}"
41 | exit 0
42 | fi
43 |
44 | # for every class
45 | for set_ in $(ls $src_dir); do
46 |
47 | printf "\nProcessing $set_ set\n"
48 |
49 | for class in $(ls $src_dir/$set_); do
50 |
51 | printf " > Processing class \"$class\":"
52 |
53 | # define src and dst
54 | src_class_dir="$src_dir/$set_/$class"
55 | dst_class_dir="$dst_dir/$set_/$class"
56 | mkdir -p $dst_class_dir
57 |
58 | # for each video in the class
59 | for video in $(ls $src_class_dir); do
60 |
61 | printf " \"$video\""
62 |
63 | # define src and dst video paths
64 | src_video_path="$src_class_dir/$video"
65 | dst_video_path="$dst_class_dir/${video%.*}%03d.png"
66 |
67 | ffmpeg \
68 | -loglevel error \
69 | -i $src_video_path \
70 | $dst_video_path
71 |
72 | done
73 | echo ""
74 | done
75 | done
76 |
77 | echo -e "${r}Done. Exiting.${n}"
78 | exit 0
79 |
--------------------------------------------------------------------------------
/data/objectify.sh:
--------------------------------------------------------------------------------
1 | ################################################################################
2 | # Link videos into object classes
3 | ################################################################################
4 | # Alfredo Canziani, Apr 17
5 | ################################################################################
6 | # Run as:
7 | # ./objectify.sh src_path/ dst_path/
8 | ################################################################################
9 |
10 | # Originally used these. Now they are arguments.
11 | src="sampled-data" # one folder per video
12 | dst="object-sampled-data" # one folder per object
13 | src="dumped-sampled-data" # one folder per video
14 | dst="dumped-object-sampled-data" # one folder per object
15 |
16 | src=${1%/*} # remove trailing /, if present
17 | dst=${2%/*} # remove trailing /, if present
18 |
19 | classes=$(ls processed-data/train/)  # NB: class names are listed from the split data set, since folders in $src are video-indexed
20 | sets="train val"
21 |
22 | for c in $classes; do
23 | echo "Processing class $c"
24 | for s in $sets; do
25 | dst_dir="$dst/$s/$c"
26 | mkdir -p $dst_dir
27 | for v in $(ls $src/$s/$c-*/*); do
28 | video_name=${v#$src/$s/$c-} # removes leading crap
29 | video_name=${video_name/'/'/-} # convert / into -
30 | ln -s ../../../$v $dst_dir/$video_name
31 | done
32 | done
33 | done
34 |
35 |
--------------------------------------------------------------------------------
/data/resize_and_sample.sh:
--------------------------------------------------------------------------------
1 | ################################################################################
2 | # Pre-process video data
3 | ################################################################################
4 | # Alfredo Canziani, Apr 17
5 | ################################################################################
6 |
7 | # Pre-process video data
8 | #
9 | # - resize video data minor side to specific size
10 | # - sample frames and split into train and val sets
11 | # - skip "too short" videos
12 | # - limit max length
13 | #
14 | # Run as:
15 | # ./resize_and_sample.sh src_path/ dst_path/
16 | #
17 | # It's better to perform the resizing and the sampling together since
18 | # re-encoding is necessary when a temporal sampling is performed.
19 | # Skipping and clipping max length are also easily achievable at this point in
20 | # time.
21 |
22 | # current object video data set
23 | # 95% interval: [144, 564] -> [4.8, 21.8] seconds
24 | # mean number of frames: 354 -> 11.8 seconds
25 |
26 | ################################################################################
27 | # SETTINGS #####################################################################
28 | ################################################################################
29 | # comment the next line if you don't want to skip short videos
30 | min_frames=144
31 | # comment the next line if you don't want to limit the max length
32 | max_frames=564
33 | # set sampling interval: k - 1 train, 1 val
34 | k=5
35 | ################################################################################
36 |
37 | # some colours
38 | r='\033[0;31m' # red
39 | g='\033[0;32m' # green
40 | b='\033[0;34m' # blue
41 | n='\033[0m' # none
42 |
43 | # check min_frames setting
44 | printf " - "
45 | if [ -n "$min_frames" ]; then
46 | echo -e "Skipping videos with < $b$min_frames$n frames"
47 | skip_count=0
48 | else
49 | echo "No skipping short vidos"
50 | min_frames=0
51 | fi
52 |
53 | # check max_frames setting
54 | printf " - "
55 | if [ -n "$max_frames" ]; then
56 | echo -e "Trimming videos with > $b$max_frames$n frames"
57 | trim_count=0
58 | else
59 | echo "No trimming long vidos"
60 | fi
61 |
62 | # check split setting
63 | printf " - "
64 | echo -e "Sampling every $b$k$n frames"
65 | kk=$(awk "BEGIN{print 1/$k}")
66 |
67 | # assert existence of source directory
68 | src_dir=${1%/*} # remove trailing /, if present
69 | if [ ! -d $src_dir ] || [ -z $src_dir ]; then
70 | echo -e "${r}Source directory/link \"$src_dir\" is missing. Exiting.${n}"
71 | exit 1
72 | fi
73 | echo -e " - Source directory/link set to \"$b$src_dir$n\""
74 |
75 | # assert existence of destination directory
76 | dst_dir="${2%/*}"
77 | if [ -d $dst_dir ]; then
78 | echo -e "${r}Destination directory \"$dst_dir\" already existent." \
79 | "Exiting.${n}"
80 | exit 1
81 | fi
82 | echo -e " - Destination directory set to \"$b$dst_dir$n\""
83 |
84 | # check if all is good
85 | printf "Does it look fine? (${g}y${g}${n}/${r}n${n}) "
86 | read ans
87 | if [ $ans == 'n' ]; then
88 | echo -e "${r}Exiting.${n}"
89 | exit 0
90 | fi
91 |
92 | # for every class
93 | for class in $(ls $src_dir); do
94 |
95 | printf "\nProcessing class \"$class\"\n"
96 |
97 | # define src
98 | src_class_dir="$src_dir/$class"
99 |
100 | # for each video in the class
101 | for video in $(ls $src_class_dir); do
102 |
103 | printf " > Loading video \"$video\". "
104 |
105 | # define src video path
106 | src_video_path="$src_class_dir/$video"
107 |
108 | # count the frames
109 | frames=$(ffprobe \
110 | -loglevel quiet \
111 | -show_streams \
112 | -select_streams v \
113 | $src_video_path | awk \
114 | '/nb_frames=/{sub(/nb_frames=/,""); print}')
115 |
116 | # skip if too short
117 | if ((frames < min_frames)); then
118 | printf "Frames: $b$frames$n < $b$min_frames$n min frames. "
119 | echo -e "${r}Skipping.$n"
120 | ((skip_count++))
121 | continue
122 | fi
123 |
124 | # define and make dst and val dir
125 | dst_video_path="$dst_dir/train/$class-${video%.*}"
126 | val_video_path="$dst_dir/val/$class-${video%.*}"
127 | mkdir -p $dst_video_path
128 | mkdir -p $val_video_path
129 |
130 | # get src_video frame rate
131 | fps=$(ffprobe \
132 | -loglevel error \
133 | -show_streams \
134 | -select_streams v \
135 | $src_video_path | awk \
136 | '/avg_frame_rate=/{sub(/avg_frame_rate=/,""); print}')
137 |
138 | # get my src_video_path
139 | # discard audio stream
140 | # be quiet (show errors only)
141 | # use complex filter (1 input file, multiple output files)
142 | # rescale the video stream min side to 256
143 | # select frames using frame_number % k and send it to streams 1, ..., k
144 | # as long as frame_number < max_frames
145 | # use the input average frame rate as output fps
146 | # send each stream to a separate output file dst_video_path/{1..k-1}.mp4
147 | # and val_video_path/k.mp4
148 | printf "Rescaling and sampling"
149 | ffmpeg \
150 | -i $src_video_path \
151 | -an \
152 | -loglevel error \
153 | -filter_complex \
154 | "setpts=$kk*PTS, \
155 | scale=w=2*trunc(128*max(1\, iw/ih)):h=2*trunc(128*max(1\, ih/iw))[m]; \
156 | [m]select=n=$k:e=(mod(n\,$k)+1)*lt(n\,$max_frames) \
157 | $(for ((i=1; i<=$k; i++)); do
158 | echo -n "[a$i]"
159 | done)" \
160 | $(for ((i=1; i<$k; i++)); do
161 | echo -n "-r $fps -map [a$i] $dst_video_path/$i.mp4 "
162 | done
163 | echo -n "-r $fps -map [a$k] $val_video_path/$k.mp4"
164 | )
165 |
166 | # check the output stream resolution
167 | printf ' --> '
168 | ffprobe \
169 | -loglevel quiet \
170 | -show_streams \
171 | "$dst_video_path/1.mp4" | awk \
172 | -F= \
173 | '/^width/{printf $2"x"}; /^height/{printf $2"."}'
174 |
175 | if ((frames > max_frames)); then
176 | echo -n " Trimming $frames --> $max_frames."
177 | ((trim_count++))
178 | fi
179 |
180 | # new line :)
181 | printf "\n"
182 | done
183 | done
184 |
185 | printf "\n---------------\n"
186 | printf "Skipped $b%d$n videos\n" "$skip_count"
187 | printf "Trimmed $b%d$n videos" "$trim_count"
188 | printf "\n---------------\n\n"
189 | echo -e "${r}Exiting.${n}"
190 | exit 0
191 |
--------------------------------------------------------------------------------
/data/resize_and_split.sh:
--------------------------------------------------------------------------------
1 | ################################################################################
2 | # Pre-process video data
3 | ################################################################################
4 | # Alfredo Canziani, Feb 17
5 | ################################################################################
6 |
7 | # Pre-process video data
8 | #
9 | # - resize video data minor side to specific size
10 | # - split into train and val sets
11 | # - skip "too short" videos
12 | # - limit max length
13 | #
14 | # Run as:
15 | # ./resize_and_split.sh src_path/ dst_path/
16 | #
17 | # It's better to perform the resizing and the splitting together since
18 | # re-encoding is necessary when a temporal split is performed.
19 | # Skipping and clipping max length are also easily achievable at this point in
20 | # time.
21 |
22 | # current object video data set
23 | # 95% interval: [144, 564] -> [4.8, 21.8] seconds
24 | # mean number of frames: 354 -> 11.8 seconds
25 |
26 | ################################################################################
27 | # SETTINGS #####################################################################
28 | ################################################################################
29 | # comment the next line if you don't want to skip short videos
30 | min_frames=144
31 | # comment the next line if you don't want to limit the max length
32 | max_frames=564
33 | # set split to 0 (seconds) if no splitting is required
34 | split=2
35 | ################################################################################
36 |
37 | # some colours
38 | r='\033[0;31m' # red
39 | g='\033[0;32m' # green
40 | b='\033[0;34m' # blue
41 | n='\033[0m' # none
42 |
43 | # check min_frames setting
44 | printf " - "
45 | if [ -n "$min_frames" ]; then
46 | echo -e "Skipping videos with < $b$min_frames$n frames"
47 | skip_count=0
48 | else
49 | echo "No skipping short vidos"
50 | min_frames=0
51 | fi
52 |
53 | # check max_frames setting
54 | printf " - "
55 | if [ -n "$max_frames" ]; then
56 | echo -e "Trimming videos with > $b$max_frames$n frames"
57 | trim_count=0
58 | else
59 | echo "No trimming long vidos"
60 | fi
61 |
62 | # check split setting
63 | printf " - "
64 | if [ $split != 0 ]; then
65 | echo -e "Using last $b$split$n seconds for validation"
66 | dst="train/"
67 | else
68 | echo "No train-validation splitting will be performed"
69 | dst=""
70 | fi
71 |
72 | # assert existence of source directory
73 | src_dir=${1%/*} # remove trailing /, if present
74 | if [ ! -d $src_dir ] || [ -z $src_dir ]; then
75 | echo -e "${r}Source directory/link \"$src_dir\" is missing. Exiting.${n}"
76 | exit 1
77 | fi
78 | echo -e " - Source directory/link set to \"$b$src_dir$n\""
79 |
80 | # assert existence of destination directory
81 | dst_dir="${2%/*}"
82 | if [ -d $dst_dir ]; then
83 | echo -e "${r}Destination directory \"$dst_dir\" already existent." \
84 | "Exiting.${n}"
85 | exit 1
86 | fi
87 | echo -e " - Destination directory set to \"$b$dst_dir$n\""
88 |
89 | # check if all is good
90 | printf "Does it look fine? (${g}y${g}${n}/${r}n${n}) "
91 | read ans
92 | if [ $ans == 'n' ]; then
93 | echo -e "${r}Exiting.${n}"
94 | exit 0
95 | fi
96 |
97 | # for every class
98 | for class in $(ls $src_dir); do
99 |
100 | printf "\nProcessing class \"$class\"\n"
101 |
102 | # define src and dst dir, make dst dir
103 | src_class_dir="$src_dir/$class"
104 | dst_class_dir="$dst_dir/$dst$class"
105 | mkdir -p $dst_class_dir
106 |
107 | # if split > 0, deal with validation dir too
108 | if [ $split != 0 ]; then
109 | val_class_dir="$dst_dir/val/$class"
110 | mkdir -p $val_class_dir
111 | fi
112 |
113 | # for each video in the class
114 | for video in $(ls $src_class_dir); do
115 |
116 | printf " > Loading video \"$video\". "
117 |
118 | # define src video path
119 | src_video_path="$src_class_dir/$video"
120 |
121 | # count the frames
122 | frames=$(ffprobe \
123 | -loglevel quiet \
124 | -show_streams \
125 | -select_streams v \
126 | $src_video_path | awk \
127 | '/nb_frames=/{sub(/nb_frames=/,""); print}')
128 |
129 | # skip if too short
130 | if ((frames < min_frames)); then
131 | printf "Frames: $b$frames$n < $b$min_frames$n min frames. "
132 | echo -e "${r}Skipping.$n"
133 | ((skip_count++))
134 | continue
135 | fi
136 |
137 | # get src_video duration
138 | tot_t=$(ffprobe \
139 | -loglevel quiet \
140 | -show_streams \
141 | -select_streams v \
142 | $src_video_path | awk \
143 | '/duration=/{sub(/duration=/,""); print}')
144 |
145 | # get src_video frame rate
146 | fps=$(ffprobe \
147 | -loglevel error \
148 | -show_streams \
149 | -select_streams v \
150 | $src_video_path | awk \
151 | '/avg_frame_rate=/{sub(/avg_frame_rate=/,""); print}')
152 |
153 | # if there is a max_frames and we are over it, redefine tot_t
154 | if [ -n "$max_frames" ] && ((frames > max_frames)); then
155 | printf "Frames: $b$frames$n > $b$max_frames$n max frames. "
156 | printf "Trimming %.2fs" "$tot_t"
157 | tot_t=$(awk \
158 | "BEGIN{printf (\"%.4f\",$tot_t-($frames-$max_frames)/($fps))}")
159 | printf " --> %.2fs. " "$tot_t"
160 | ((trim_count++))
161 | fi
162 |
163 | # compute duration in seconds
164 | end_t=$(awk "BEGIN{printf (\"%f\",$tot_t-$split)}")
165 |
166 | # compute duration in ffmpeg format
167 | ffmpeg_end_t=$(awk \
168 | "BEGIN{printf (\"%02d:%02d:%02.4f\",$end_t/3600,($end_t%3600)/60,($end_t%60))}")
169 |
170 | # get my src_video_path
171 | # discard audio stream
172 | # use it until ffmpeg_end_t
173 | # rescale the video stream min side to 256
174 | # use the input average frame rate as output fps
175 | # be quiet (show errors only)
176 | # save at dst_video_path
177 | printf "Rescaling"
178 | dst_video_path="$dst_class_dir/${video%.*}.mp4" # replace extension
179 | ffmpeg \
180 | -i $src_video_path \
181 | -an \
182 | -to $ffmpeg_end_t \
183 | -filter:v "scale=w=2*trunc(128*max(1\, iw/ih)):h=2*trunc(128*max(1\, ih/iw))" \
184 | -r $fps \
185 | -loglevel error \
186 | $dst_video_path
187 |
188 | # check the output stream resolution
189 | printf ' --> '
190 | ffprobe \
191 | -loglevel quiet \
192 | -show_streams \
193 | $dst_video_path | awk \
194 | -F= \
195 | '/^width/{printf $2"x"}; /^height/{printf $2"."}'
196 |
197 | if [ $split != 0 ]; then
198 | # start at Ns from the end
199 | # of my src_video_path
200 | # discard audio stream
201 | # rescale the video stream min side to 256
202 | # be quiet (show errors only)
203 | # save at val_video_path
204 | printf " Splitting"
205 | val_video_path="$val_class_dir/${video%.*}.mp4" # replace extension
206 | ffmpeg \
207 | -sseof -$split \
208 | -i $src_video_path \
209 | -an \
210 | -filter:v "scale=w=2*trunc(128*max(1\, iw/ih)):h=2*trunc(128*max(1\, ih/iw))" \
211 | -r $fps \
212 | -loglevel error \
213 | $val_video_path
214 |
215 | # print temporal split
216 | duration=$(awk \
217 | "BEGIN{printf (\"%02d:%02d:%02.4f\",$tot_t/3600,($tot_t%3600)/60,($tot_t%60))}")
218 | printf " 00:00:00 / $ffmpeg_end_t / $duration."
219 | fi
220 |
221 | # new line :)
222 | printf "\n"
223 | done
224 | done
225 |
226 | printf "\n---------------\n"
227 | printf "Skipped $b%d$n videos\n" "$skip_count"
228 | printf "Trimmed $b%d$n videos" "$trim_count"
229 | printf "\n---------------\n\n"
230 | echo -e "${r}Exiting.${n}"
231 | exit 0
232 |
--------------------------------------------------------------------------------
/data/sample_video.sh:
--------------------------------------------------------------------------------
1 | ################################################################################
2 | # Video sampler: y_i[n] = x[n_i + N * n], i < N
3 | ################################################################################
4 | # Alfredo Canziani, Mar 17
5 | ################################################################################
6 | # Run as
7 | # ./sample_video.sh src_video dst_prefix
8 | ################################################################################
9 |
10 | src="small_data_set/cup/sfsdfs-nb.mp4"
11 | dst="sampled/sfsdfs-nb"
12 | src="data_set/barcode/20160613_140057.mp4"
13 | dst="sampled/20160613_140057"
14 | src="data_set/floor/VID_20160605_094332.mp4"
15 | dst="sampled/VID_20160605_094332"
16 | src="/home/atcold/Videos/20170416_184611.mp4"
17 | dst="bme-car/20170416_184611"
18 | src="/home/atcold/Videos/20170418_113638.mp4"
19 | dst="bme-chair/20170418_113638"
20 | src="/home/atcold/Videos/20160603_133515.mp4"
21 | dst="abhi-car/20160603_133515"
22 | src="/home/atcold/Videos/20170419_125021.mp4"
23 | dst="bme-chair/20170419_125021"
24 |
25 | src=$1
26 | dst=$2
27 |
28 | k=5
29 | kk=$(awk "BEGIN{print 1/$k}")
30 | ffmpeg \
31 | -i $src \
32 | -an \
33 | -loglevel error \
34 | -filter_complex \
35 | "setpts=$kk*PTS, \
36 | scale=w=2*trunc(128*max(1\, iw/ih)):h=2*trunc(128*max(1\, ih/iw))[m]; \
37 | [m]select=n=$k:e=(mod(n\, $k)+1)*lt(n\, 564) \
38 | $(for ((i=1; i<=$k; i++)); do
39 | echo -n "[a$i]"
40 | done)" \
41 | $(for ((i=1; i<=$k; i++)); do
42 | echo -n "-r 31230000/1042111 -map [a$i] $dst-$i.mp4 "
43 | done)
44 |
--------------------------------------------------------------------------------
/data/small_data_set/cup/sfsdfs-nb.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/e-lab/pytorch-CortexNet/bc28dac4e6a1ad9abb11e2fbc48d310a85e9903a/data/small_data_set/cup/sfsdfs-nb.mp4
--------------------------------------------------------------------------------
/data/small_data_set/hand/hand5-nb.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/e-lab/pytorch-CortexNet/bc28dac4e6a1ad9abb11e2fbc48d310a85e9903a/data/small_data_set/hand/hand5-nb.mp4
--------------------------------------------------------------------------------
/data/small_data_set/hand/hand_1-nb.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/e-lab/pytorch-CortexNet/bc28dac4e6a1ad9abb11e2fbc48d310a85e9903a/data/small_data_set/hand/hand_1-nb.mp4
--------------------------------------------------------------------------------
/image-pretraining/README.md:
--------------------------------------------------------------------------------
1 | # Image pre-training
2 |
3 | Find the original code at [PyTorch ImageNet example](https://github.com/pytorch/examples/tree/master/imagenet).
4 | This adaptation trains the discriminative branch of CortexNet for TempoNet.
5 |
6 | ## Training
7 |
8 | To train the discriminative branch of CortexNet, run `main.py` with the path to an image data set:
9 |
10 | ```bash
11 | python main.py <image data set path> | tee train.log
12 | ```
13 |
14 | The default learning rate schedule starts at 0.1 and decays by a factor of 10 every 30 epochs.
15 |
16 | ## Usage
17 |
18 | ```
19 | usage: main.py [-h] [-j N] [--epochs N] [--start-epoch N] [-b N] [--lr LR]
20 | [--momentum M] [--weight-decay W] [--print-freq N]
21 | [--resume PATH] [-e] [--pretrained] [--size [S [S ...]]]
22 | DIR
23 |
24 | PyTorch ImageNet Training
25 |
26 | positional arguments:
27 | DIR path to dataset
28 |
29 | optional arguments:
30 | -h, --help show this help message and exit
31 | -j N, --workers N number of data loading workers (default: 4)
32 | --epochs N number of total epochs to run
33 | --start-epoch N manual epoch number (useful on restarts)
34 | -b N, --batch-size N mini-batch size (default: 256)
35 | --lr LR, --learning-rate LR
36 | initial learning rate
37 | --momentum M momentum
38 | --weight-decay W, --wd W
39 | weight decay (default: 1e-4)
40 | --print-freq N, -p N print frequency (default: 10)
41 | --resume PATH path to latest checkpoint (default: none)
42 | -e, --evaluate evaluate model on validation set
43 | --pretrained use pre-trained model
44 | --size [S [S ...]] number and size of hidden layers
45 | ```
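
For example, a hypothetical invocation using 8 loader workers and a batch size of 128 on the `dumped-object-sampled-data/` directory produced by the [`data/`](../data) scripts:

```bash
python main.py -j 8 -b 128 dumped-object-sampled-data/ | tee train.log
```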
46 |
--------------------------------------------------------------------------------
/image-pretraining/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import time
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.parallel
9 | import torch.backends.cudnn as cudnn
10 | import torch.optim
11 | import torch.utils.data
12 | import torchvision.transforms as transforms
13 | import torchvision.datasets as datasets
14 | import torchvision.models as models
15 |
16 |
17 | model_names = sorted(name for name in models.__dict__
18 | if name.islower() and not name.startswith("__")
19 | and callable(models.__dict__[name]))
20 |
21 |
22 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
23 | parser.add_argument('data', metavar='DIR',
24 | help='path to dataset')
25 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
26 | help='number of data loading workers (default: 4)')
27 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
28 | help='number of total epochs to run')
29 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
30 | help='manual epoch number (useful on restarts)')
31 | parser.add_argument('-b', '--batch-size', default=256, type=int,
32 | metavar='N', help='mini-batch size (default: 256)')
33 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
34 | metavar='LR', help='initial learning rate')
35 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
36 | help='momentum')
37 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
38 | metavar='W', help='weight decay (default: 1e-4)')
39 | parser.add_argument('--print-freq', '-p', default=10, type=int,
40 | metavar='N', help='print frequency (default: 10)')
41 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
42 | help='path to latest checkpoint (default: none)')
43 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
44 | help='evaluate model on validation set')
45 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
46 | help='use pre-trained model')
47 | parser.add_argument('--size', type=int, default=(3, 32, 64, 128, 256, 256, 256), nargs='*',
48 | help='number and size of hidden layers', metavar='S')
49 |
50 | best_prec1 = 0
51 |
52 |
53 | def main():
54 | global args, best_prec1
55 | args = parser.parse_args()
56 | args.size = tuple(args.size)
57 |
58 | # create model
59 | from model.Model02 import Model02 as Model
60 |
61 | class Capsule(nn.Module):
62 |
63 | def __init__(self):
64 | super().__init__()
65 | nb_of_classes = 33 # 970 (vid) or 35 (vid obj) or 33 (imgs)
66 | self.inner_model = Model(args.size + (nb_of_classes,), (256, 256))
67 |
68 | def forward(self, x):
69 | (_, _), (_, video_index) = self.inner_model(x, None)
70 | return video_index
71 |
72 | model = Capsule()
73 |
74 | model = torch.nn.DataParallel(model).cuda()
75 |
76 | cudnn.benchmark = True
77 |
78 | # Data loading code
79 | traindir = os.path.join(args.data, 'train')
80 | valdir = os.path.join(args.data, 'val')
81 | # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
82 | # std=[0.229, 0.224, 0.225])
83 |
84 | train_data = datasets.ImageFolder(traindir, transforms.Compose([
85 | transforms.CenterCrop(256),
86 | transforms.ToTensor(),
87 | ]))
88 | train_loader = torch.utils.data.DataLoader(
89 | train_data,
90 | batch_size=args.batch_size, shuffle=True,
91 | num_workers=args.workers, pin_memory=True
92 | )
93 |
94 | val_data = datasets.ImageFolder(valdir, transforms.Compose([transforms.CenterCrop(256), transforms.ToTensor(), ]))
95 | val_loader = torch.utils.data.DataLoader(
96 | val_data,
97 | batch_size=args.batch_size, shuffle=False,
98 | num_workers=args.workers, pin_memory=True
99 | )
100 |
101 | # define loss function (criterion) and optimizer
102 | class_count = [0] * len(train_data.classes)
103 | for i in train_data.imgs: class_count[i[1]] += 1
104 | train_crit_weight = torch.Tensor(class_count)
105 | train_crit_weight.div_(train_crit_weight.mean()).pow_(-1)
106 | train_criterion = nn.CrossEntropyLoss(train_crit_weight).cuda()
107 |
108 | class_count = [0] * len(val_data.classes)
109 | for i in val_data.imgs: class_count[i[1]] += 1
110 | val_crit_weight = torch.Tensor(class_count)
111 | val_crit_weight.div_(val_crit_weight.mean()).pow_(-1)
112 | val_criterion = nn.CrossEntropyLoss(val_crit_weight).cuda()
113 |
114 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
115 | momentum=args.momentum,
116 | weight_decay=args.weight_decay)
117 |
118 | if args.evaluate:
119 | validate(val_loader, model, val_criterion)
120 | return
121 |
122 | for epoch in range(args.start_epoch, args.epochs):
123 | adjust_learning_rate(optimizer, epoch)
124 |
125 | # train for one epoch
126 | train(train_loader, model, train_criterion, optimizer, epoch)
127 |
128 | # evaluate on validation set
129 | prec1 = validate(val_loader, model, val_criterion)
130 |
131 | # remember best prec@1 and save checkpoint
132 | is_best = prec1 > best_prec1
133 | best_prec1 = max(prec1, best_prec1)
134 | save_checkpoint({
135 | 'epoch': epoch + 1,
136 | 'state_dict': model.state_dict(),
137 | 'best_prec1': best_prec1,
138 | }, is_best)
139 |
140 |
141 | def train(train_loader, model, criterion, optimizer, epoch):
142 | batch_time = AverageMeter()
143 | data_time = AverageMeter()
144 | losses = AverageMeter()
145 | top1 = AverageMeter()
146 | top5 = AverageMeter()
147 |
148 | # switch to train mode
149 | model.train()
150 |
151 | end = time.time()
152 | for i, (input, target) in enumerate(train_loader):
153 | # measure data loading time
154 | data_time.update(time.time() - end)
155 |
156 | target = target.cuda(async=True)
157 | input_var = torch.autograd.Variable(input)
158 | target_var = torch.autograd.Variable(target)
159 |
160 | # compute output
161 | output = model(input_var)
162 | loss = criterion(output, target_var)
163 |
164 | # measure accuracy and record loss
165 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
166 | losses.update(loss.data[0], input.size(0))
167 | top1.update(prec1[0], input.size(0))
168 | top5.update(prec5[0], input.size(0))
169 |
170 | # compute gradient and do SGD step
171 | optimizer.zero_grad()
172 | loss.backward()
173 | optimizer.step()
174 |
175 | # measure elapsed time
176 | batch_time.update(time.time() - end)
177 | end = time.time()
178 |
179 | if i % args.print_freq == 0:
180 | print('Epoch: [{0}][{1}/{2}]\t'
181 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
182 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
183 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
184 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
185 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
186 | epoch, i, len(train_loader), batch_time=batch_time,
187 | data_time=data_time, loss=losses, top1=top1, top5=top5))
188 |
189 |
190 | def validate(val_loader, model, criterion):
191 | batch_time = AverageMeter()
192 | losses = AverageMeter()
193 | top1 = AverageMeter()
194 | top5 = AverageMeter()
195 |
196 | # switch to evaluate mode
197 | model.eval()
198 |
199 | end = time.time()
200 | for i, (input, target) in enumerate(val_loader):
201 | target = target.cuda(async=True)
202 | input_var = torch.autograd.Variable(input, volatile=True)
203 | target_var = torch.autograd.Variable(target, volatile=True)
204 |
205 | # compute output
206 | output = model(input_var)
207 | loss = criterion(output, target_var)
208 |
209 | # measure accuracy and record loss
210 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
211 | losses.update(loss.data[0], input.size(0))
212 | top1.update(prec1[0], input.size(0))
213 | top5.update(prec5[0], input.size(0))
214 |
215 | # measure elapsed time
216 | batch_time.update(time.time() - end)
217 | end = time.time()
218 |
219 | if i % args.print_freq == 0:
220 | print('Test: [{0}/{1}]\t'
221 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
222 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
223 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
224 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
225 | i, len(val_loader), batch_time=batch_time, loss=losses,
226 | top1=top1, top5=top5))
227 |
228 | print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
229 | .format(top1=top1, top5=top5))
230 |
231 | return top1.avg
232 |
233 |
234 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
235 | torch.save(state, filename)
236 | if is_best:
237 | shutil.copyfile(filename, 'model_best.pth.tar')
238 |
239 |
240 | class AverageMeter(object):
241 | """Computes and stores the average and current value"""
242 | def __init__(self):
243 | self.reset()
244 |
245 | def reset(self):
246 | self.val = 0
247 | self.avg = 0
248 | self.sum = 0
249 | self.count = 0
250 |
251 | def update(self, val, n=1):
252 | self.val = val
253 | self.sum += val * n
254 | self.count += n
255 | self.avg = self.sum / self.count
256 |
257 |
258 | def adjust_learning_rate(optimizer, epoch):
259 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
260 | lr = args.lr * (0.1 ** (epoch // 30))
261 | for param_group in optimizer.param_groups:
262 | param_group['lr'] = lr
263 |
264 |
265 | def accuracy(output, target, topk=(1,)):
266 | """Computes the precision@k for the specified values of k"""
267 | maxk = max(topk)
268 | batch_size = target.size(0)
269 |
270 | _, pred = output.topk(maxk, 1, True, True)
271 | pred = pred.t()
272 | correct = pred.eq(target.view(1, -1).expand_as(pred))
273 |
274 | res = []
275 | for k in topk:
276 | correct_k = correct[:k].view(-1).float().sum(0)
277 | res.append(correct_k.mul_(100.0 / batch_size))
278 | return res
279 |
280 |
281 | if __name__ == '__main__':
282 | main()
283 |
--------------------------------------------------------------------------------
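A note on the class-balanced criteria built in image-pretraining/main.py above: the per-class image counts are divided by their mean and then inverted, so under-represented classes receive weights above 1 and frequent ones below 1. A minimal standalone sketch of the same recipe, using a hypothetical three-class count vector (not part of the repository):

import torch

class_count = torch.Tensor([10, 40, 50])               # hypothetical images per class
weight = class_count.div(class_count.mean()).pow(-1)   # divide by the mean, then invert
print(weight)                                          # ~[3.33, 0.83, 0.67]: rare classes weigh more
--------------------------------------------------------------------------------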
/image-pretraining/model:
--------------------------------------------------------------------------------
1 | ../model/
--------------------------------------------------------------------------------
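The accuracy helper in image-pretraining/main.py keeps the k highest logits per sample and counts how often the target label appears among them. A tiny standalone sketch with made-up logits for two samples and three classes (the numbers are purely illustrative):

import torch

output = torch.Tensor([[0.1, 0.7, 0.2],     # sample 0: top class is 1
                       [0.6, 0.3, 0.1]])    # sample 1: top-2 classes are {0, 1}
target = torch.LongTensor([1, 2])

_, pred = output.topk(2, 1, True, True)     # top-2 class indices per sample
correct = pred.t().eq(target.view(1, -1).expand_as(pred.t()))
print(correct[:1].float().sum() / 2 * 100)  # prec@1 = 50.0: only sample 0 is right
print(correct[:2].float().sum() / 2 * 100)  # prec@2 = 50.0: class 2 is not in sample 1's top-2
--------------------------------------------------------------------------------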
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path as path
4 | import time
5 | from datetime import timedelta
6 | from sys import exit, argv
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | from torch.autograd import Variable as V
12 | from torch.utils.data import DataLoader
13 | from torchvision import transforms as trn
14 |
15 | from data.VideoFolder import VideoFolder, BatchSampler, VideoCollate
16 | from utils.image_plot import show_four, show_ten
17 |
18 | parser = argparse.ArgumentParser(description='PyTorch MatchNet generative model training script',
19 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
20 | _ = parser.add_argument # define add_argument shortcut
21 | _('--data', type=str, default='./data/processed-data', help='location of the video data')
22 | _('--model', type=str, default='CortexNet', help='type of auto-encoder')
23 | _('--mode', type=str, required=True, help='training mode [MatchNet|TempoNet]')
24 | _('--size', type=int, default=(3, 32, 64, 128, 256), nargs='*', help='number and size of hidden layers', metavar='S')
25 | _('--spatial-size', type=int, default=(256, 256), nargs=2, help='frame cropping size', metavar=('H', 'W'))
26 | _('--lr', type=float, default=0.1, help='initial learning rate')
27 | _('--momentum', type=float, default=0.9, metavar='M', help='momentum')
28 | _('--weight-decay', type=float, default=1e-4, metavar='W', help='weight decay')
29 | _('--mu', type=float, default=1, help='matching MSE multiplier', dest='mu', metavar='μ')
30 | _('--tau', type=float, default=0.1, help='temporal CE multiplier', dest='tau', metavar='τ')
31 | _('--pi', default='τ', help='periodical CE multiplier', dest='pi', metavar='π')
32 | _('--epochs', type=int, default=10, help='upper epoch limit')
33 | _('--batch-size', type=int, default=20, metavar='B', help='batch size')
34 | _('--big-t', type=int, default=10, help='sequence length', metavar='T')
35 | _('--seed', type=int, default=0, help='random seed')
36 | _('--log-interval', type=int, default=10, metavar='N', help='report interval')
37 | _('--save', type=str, default='last/model.pth.tar', help='path to save the final model')
38 | _('--cuda', action='store_true', help='use CUDA')
39 | _('--view', type=int, default=tuple(), help='samples to view at the end of every log-interval batches', metavar='V')
40 | _('--show-x_hat', action='store_true', help='show x_hat')
41 | _('--lr-decay', type=float, default=None, nargs=2, metavar=('D', 'E'),
42 |   help='divide the learning rate by D (e.g. 3.16 or 10) every E epochs (e.g. 3)')
43 | _('--pre-trained', type=str, default='', help='path to pre-trained model', metavar='P')
44 | args = parser.parse_args()
45 | args.size = tuple(args.size) # cast to tuple
46 | if args.lr_decay: args.lr_decay = tuple(args.lr_decay)
47 | if type(args.view) is int: args.view = (args.view,) # cast to tuple
48 | args.pi = args.tau if args.pi == 'τ' else float(args.pi)
49 |
50 | # Print current options
51 | print('CLI arguments:', ' '.join(argv[1:]))
52 |
53 | # Print current commit
54 | if path.isdir('.git'): # if we are in a repo
55 | with os.popen('git rev-parse HEAD') as pipe: # get the HEAD's hash
56 | print('Current commit hash:', pipe.read(), end='')
57 |
58 | # Set the random seed manually for reproducibility.
59 | torch.manual_seed(args.seed)
60 | if torch.cuda.is_available():
61 | if not args.cuda:
62 | print("WARNING: You have a CUDA device, so you should probably run with --cuda")
63 | else:
64 | torch.cuda.manual_seed(args.seed)
65 |
66 |
67 | def main():
68 | # Load data
69 | print('Define image pre-processing')
70 | # normalise? do we care?
71 | t = trn.Compose((trn.ToPILImage(), trn.CenterCrop(args.spatial_size), trn.ToTensor()))
72 |
73 | print('Define train data loader')
74 | train_data_name = 'train_data.tar'
75 | if os.access(train_data_name, os.R_OK):
76 | train_data = torch.load(train_data_name)
77 | else:
78 | train_path = path.join(args.data, 'train')
79 | if args.mode == 'MatchNet':
80 | train_data = VideoFolder(root=train_path, transform=t, video_index=True)
81 | elif args.mode == 'TempoNet':
82 | train_data = VideoFolder(root=train_path, transform=t, shuffle=True)
83 | torch.save(train_data, train_data_name)
84 |
85 | train_loader = DataLoader(
86 | dataset=train_data,
87 | batch_size=args.batch_size * args.big_t, # batch_size rows and T columns
88 | shuffle=False,
89 | sampler=BatchSampler(data_source=train_data, batch_size=args.batch_size), # given that BatchSampler knows it
90 | num_workers=1,
91 | collate_fn=VideoCollate(batch_size=args.batch_size),
92 | pin_memory=True
93 | )
94 |
95 | print('Define validation data loader')
96 | val_data_name = 'val_data.tar'
97 | if os.access(val_data_name, os.R_OK):
98 | val_data = torch.load(val_data_name)
99 | else:
100 | val_path = path.join(args.data, 'val')
101 | if args.mode == 'MatchNet':
102 | val_data = VideoFolder(root=val_path, transform=t, video_index=True)
103 | elif args.mode == 'TempoNet':
104 | val_data = VideoFolder(root=val_path, transform=t, shuffle='init')
105 | torch.save(val_data, val_data_name)
106 |
107 | val_loader = DataLoader(
108 | dataset=val_data,
109 | batch_size=args.batch_size, # just one column of size batch_size
110 | shuffle=False,
111 | sampler=BatchSampler(data_source=val_data, batch_size=args.batch_size),
112 | num_workers=1,
113 | collate_fn=VideoCollate(batch_size=args.batch_size),
114 | pin_memory=True
115 | )
116 |
117 | # Build the model
118 | if args.model == 'model_01':
119 | from model.Model01 import Model01 as Model
120 | elif args.model == 'model_02' or args.model == 'CortexNet':
121 | from model.Model02 import Model02 as Model
122 | elif args.model == 'model_02_rg':
123 | from model.Model02 import Model02RG as Model
124 | else:
125 | print('\n{:#^80}\n'.format(' Please select a valid model '))
126 | exit()
127 |
128 | print('Define model')
129 | if args.mode == 'MatchNet':
130 | nb_train_videos = len(train_data.videos)
131 | model = Model(args.size + (nb_train_videos,), args.spatial_size)
132 | elif args.mode == 'TempoNet':
133 | nb_classes = len(train_data.classes)
134 | model = Model(args.size + (nb_classes,), args.spatial_size)
135 |
136 | if args.pre_trained:
137 | print('Load pre-trained weights')
138 | # args.pre_trained = 'image-pretraining/model02D-33IS/model_best.pth.tar'
139 | dict_33 = torch.load(args.pre_trained)['state_dict']
140 |
141 | def load_state_dict(new_model, state_dict):
142 | own_state = new_model.state_dict()
143 | for name, param in state_dict.items():
144 | name = name[19:] # remove 'module.inner_model.' part
145 | if name not in own_state:
146 | raise KeyError('unexpected key "{}" in state_dict'
147 | .format(name))
148 | if name.startswith('stabiliser'):
149 | print('Skipping', name)
150 | continue
151 | if isinstance(param, nn.Parameter):
152 | # backwards compatibility for serialized parameters
153 | param = param.data
154 | own_state[name].copy_(param)
155 |
156 | missing = set(own_state.keys()) - set([k[19:] for k in state_dict.keys()])
157 | if len(missing) > 0:
158 | raise KeyError('missing keys in state_dict: "{}"'.format(missing))
159 |
160 | load_state_dict(model, dict_33)
161 |
162 | print('Create a MSE and balanced NLL criterions')
163 | mse = nn.MSELoss()
164 |
165 | # independent CE computation
166 | nll_final = nn.CrossEntropyLoss(size_average=False)
167 |     # balance classes by frames per video (MatchNet) or per class (TempoNet); default weight is 1.0
168 | w = torch.Tensor(train_data.frames_per_video if args.mode == 'MatchNet' else train_data.frames_per_class)
169 | w.div_(w.mean()).pow_(-1)
170 | nll_train = nn.CrossEntropyLoss(w)
171 | w = torch.Tensor(val_data.frames_per_video if args.mode == 'MatchNet' else val_data.frames_per_class)
172 | w.div_(w.mean()).pow_(-1)
173 | nll_val = nn.CrossEntropyLoss(w)
174 |
175 | if args.cuda:
176 | model.cuda()
177 | mse.cuda()
178 | nll_final.cuda()
179 | nll_train.cuda()
180 | nll_val.cuda()
181 |
182 | print('Instantiate a SGD optimiser')
183 | optimiser = optim.SGD(
184 | params=model.parameters(),
185 | lr=args.lr,
186 | momentum=args.momentum,
187 | weight_decay=args.weight_decay
188 | )
189 |
190 | # Loop over epochs
191 | for epoch in range(0, args.epochs):
192 | if args.lr_decay: adjust_learning_rate(optimiser, epoch)
193 | epoch_start_time = time.time()
194 | train(train_loader, model, (mse, nll_final, nll_train), optimiser, epoch)
195 | print(80 * '-', '| end of epoch {:3d} |'.format(epoch + 1), sep='\n', end=' ')
196 | val_loss = validate(val_loader, model, (mse, nll_final, nll_val))
197 | elapsed_time = str(timedelta(seconds=int(time.time() - epoch_start_time))) # HH:MM:SS time format
198 | print('time: {} | mMSE {:.2e} | CE {:.2e} | rpl mMSE {:.2e} | per CE {:.2e} |'.
199 | format(elapsed_time, val_loss['mse'] * 1e3, val_loss['ce'], val_loss['rpl'] * 1e3, val_loss['per_ce']))
200 | print(80 * '-')
201 |
202 | if args.save != '':
203 | torch.save(model, args.save)
204 |
205 |
206 | def adjust_learning_rate(opt, epoch):
207 | """Sets the learning rate to the initial LR decayed by D every E epochs"""
208 | d, e = args.lr_decay
209 | lr = args.lr * (d ** -(epoch // e))
210 | for param_group in opt.param_groups:
211 | param_group['lr'] = lr
212 |
213 |
214 | def selective_zero(s, new, forward=True):
215 | if new.any(): # if at least one video changed
216 | b = new.nonzero().squeeze(1) # get the list of indices
217 | if forward: # no state forward, no grad backward
218 | if isinstance(s[0], list): # recurrent G
219 | for layer in range(len(s[0])): # for every layer having a state
220 | s[0][layer] = s[0][layer].index_fill(0, V(b), 0) # mask state, zero selected indices
221 | for layer in range(len(s[1])): # for every layer having a state
222 | s[1][layer] = s[1][layer].index_fill(0, V(b), 0) # mask state, zero selected indices
223 | else: # simple convolutional G
224 | for layer in range(len(s)): # for every layer having a state
225 | s[layer] = s[layer].index_fill(0, V(b), 0) # mask state, zero selected indices
226 | else: # just no grad backward
227 | if isinstance(s[0], list): # recurrent G
228 | for layer in range(len(s[0])): # for every layer having a state
229 | s[0][layer].register_hook(lambda g: g.index_fill(0, V(b), 0)) # zero selected gradients
230 | for layer in range(len(s[1])): # for every layer having a state
231 | s[1][layer].register_hook(lambda g: g.index_fill(0, V(b), 0)) # zero selected gradients
232 | else: # simple convolutional G
233 | for layer in range(len(s)): # for every layer having a state
234 | s[layer].register_hook(lambda g: g.index_fill(0, V(b), 0)) # zero selected gradients
235 |
236 |
237 | def selective_match(x_hat, x, new):
238 | if new.any(): # if at least one video changed
239 | b = new.nonzero().squeeze(1) # get the list of indices
240 | for bb in b: x_hat[bb].copy_(x[bb]) # force the output to be the expected output
241 |
242 |
243 | def selective_cross_entropy(logits, y, new, loss, count):
244 | if not new.any():
245 | return V(logits.data.new(1).zero_()) # returns a variable, so we don't care about what happened here
246 | b = new.nonzero().squeeze(1) # get the list of indices
247 | count['ce_count'] += len(b)
248 | return loss(logits.index_select(0, V(b)), y.index_select(0, V(b))) # performs loss for selected indices only
249 |
250 |
251 | def train(train_loader, model, loss_fun, optimiser, epoch):
252 | print('Training epoch', epoch + 1)
253 | model.train() # set model in train mode
254 | total_loss = {'mse': 0, 'ce': 0, 'ce_count': 0, 'per_ce': 0, 'rpl': 0}
255 | mse, nll_final, nll_periodic = loss_fun
256 |
257 | def compute_loss(x_, next_x, y_, state_, periodic=False):
258 | nonlocal previous_mismatch # write access to variables of the enclosing function
259 | if args.mode == 'MatchNet':
260 | if not periodic and state_: selective_zero(state_, mismatch, forward=False) # no grad to the past
261 | (x_hat, state_), (_, idx) = model(V(x_), state_)
262 | selective_zero(state_, mismatch) # no state to the future, no grad from the future
263 | selective_match(x_hat.data, next_x, mismatch + previous_mismatch) # last frame or first frame
264 | previous_mismatch = mismatch # last frame <- first frame
265 | mse_loss_ = mse(x_hat, V(next_x))
266 | total_loss['mse'] += mse_loss_.data[0]
267 | ce_loss_ = selective_cross_entropy(idx, V(y_), mismatch, nll_final, total_loss)
268 | total_loss['ce'] += ce_loss_.data[0]
269 | if periodic:
270 | ce_loss_ = (ce_loss_, nll_periodic(idx, V(y_)))
271 | total_loss['per_ce'] += ce_loss_[1].data[0]
272 | total_loss['rpl'] += mse(x_hat, V(x_, volatile=True)).data[0]
273 | return ce_loss_, mse_loss_, state_, x_hat.data
274 |
275 | data_time = 0
276 | batch_time = 0
277 | end_time = time.time()
278 | state = None # reset state at the beginning of a new epoch
279 | from_past = None # performs only T - 1 steps for the first temporal batch
280 | previous_mismatch = torch.ByteTensor(args.batch_size).fill_(1) # ignore first prediction
281 | if args.cuda: previous_mismatch = previous_mismatch.cuda()
282 | for batch_nb, (x, y) in enumerate(train_loader):
283 | data_time += time.time() - end_time
284 | if args.cuda:
285 | x = x.cuda(async=True)
286 | y = y.cuda(async=True)
287 | state = repackage_state(state)
288 | loss = 0
289 |         # BPTT loop (truncated back-propagation through time)
290 | if args.mode == 'MatchNet':
291 | if from_past:
292 | mismatch = y[0] != from_past[1]
293 | ce_loss, mse_loss, state, _ = compute_loss(from_past[0], x[0], from_past[1], state, periodic=True)
294 | loss += mse_loss * args.mu + ce_loss[0] * args.tau + ce_loss[1] * args.pi
295 |             for t in range(0, min(args.big_t, x.size(0)) - 1):  # the first temporal batch does only T - 1 steps; later ones add one via from_past
296 | mismatch = y[t + 1] != y[t]
297 | ce_loss, mse_loss, state, x_hat_data = compute_loss(x[t], x[t + 1], y[t], state)
298 | loss += mse_loss * args.mu + ce_loss * args.tau
299 | elif args.mode == 'TempoNet':
300 | if from_past:
301 | mismatch = y[0] != from_past[1]
302 | ce_loss, mse_loss, state, _ = compute_loss(from_past[0], x[0], from_past[1], state)
303 | loss += mse_loss * args.mu + ce_loss * args.tau
304 |             for t in range(0, min(args.big_t, x.size(0)) - 1):  # the first temporal batch does only T - 1 steps; later ones add one via from_past
305 | mismatch = y[t + 1] != y[t]
306 | last = t == min(args.big_t, x.size(0)) - 2
307 | ce_loss, mse_loss, state, x_hat_data = compute_loss(x[t], x[t + 1], y[t], state, periodic=last)
308 | if not last:
309 | loss += mse_loss * args.mu + ce_loss * args.tau
310 | else:
311 | loss += mse_loss * args.mu + ce_loss[0] * args.tau + ce_loss[1] * args.pi
312 |
313 | # compute gradient and do SGD step
314 | model.zero_grad()
315 | loss.backward()
316 | optimiser.step()
317 |
318 | # save last column for future
319 | from_past = x[-1], y[-1]
320 |
321 | # measure batch time
322 | batch_time += time.time() - end_time
323 | end_time = time.time() # for computing data_time
324 |
325 | if (batch_nb + 1) % args.log_interval == 0:
326 | if args.view:
327 | for f in args.view:
328 | show_four(x[t][f], x[t + 1][f], x_hat_data[f], f + 1)
329 | if args.show_x_hat: show_ten(x[t][f], x_hat_data[f])
330 | total_loss['mse'] /= args.log_interval * args.big_t
331 | total_loss['rpl'] /= args.log_interval * args.big_t
332 | total_loss['per_ce'] /= args.log_interval
333 | if total_loss['ce_count']: total_loss['ce'] /= total_loss['ce_count']
334 | avg_batch_time = batch_time * 1e3 / args.log_interval
335 | avg_data_time = data_time * 1e3 / args.log_interval
336 | lr = optimiser.param_groups[0]['lr'] # assumes len(param_groups) == 1
337 | print('| epoch {:3d} | {:4d}/{:4d} batches | lr {:.3f} |'
338 | ' ms/batch {:7.2f} | ms/data {:7.2f} | mMSE {:.2e} | CE {:.2e} | rpl mMSE {:.2e} | per CE {:.2e} |'.
339 | format(epoch + 1, batch_nb + 1, len(train_loader), lr, avg_batch_time, avg_data_time,
340 | total_loss['mse'] * 1e3, total_loss['ce'], total_loss['rpl'] * 1e3, total_loss['per_ce']))
341 | for k in total_loss: total_loss[k] = 0 # zero the losses
342 | batch_time = 0
343 | data_time = 0
344 |
345 |
346 | def validate(val_loader, model, loss_fun):
347 | model.eval() # set model in evaluation mode
348 | total_loss = {'mse': 0, 'ce': 0, 'ce_count': 0, 'per_ce': 0, 'rpl': 0}
349 | mse, nll_final, nll_periodic = loss_fun
350 | batches = enumerate(val_loader)
351 |
352 | _, (x, y) = next(batches)
353 | if args.cuda:
354 | x = x.cuda(async=True)
355 | y = y.cuda(async=True)
356 | previous_mismatch = y[0].byte().fill_(1) # ignore first prediction
357 | state = None # reset state at the beginning of a new epoch
358 | for batch_nb, (next_x, next_y) in batches:
359 | if args.cuda:
360 | next_x = next_x.cuda(async=True)
361 | next_y = next_y.cuda(async=True)
362 | mismatch = next_y[0] != y[0]
363 | (x_hat, state), (_, idx) = model(V(x[0], volatile=True), state) # do not compute graph (volatile)
364 | selective_zero(state, mismatch) # no state to the future
365 | selective_match(x_hat.data, next_x[0], mismatch + previous_mismatch) # last frame or first frame
366 | previous_mismatch = mismatch # last frame <- first frame
367 | total_loss['mse'] += mse(x_hat, V(next_x[0])).data[0]
368 | ce_loss = selective_cross_entropy(idx, V(y[0]), mismatch, nll_final, total_loss)
369 | total_loss['ce'] += ce_loss.data[0]
370 | if batch_nb % args.big_t == 0: total_loss['per_ce'] += nll_periodic(idx, V(y[0])).data[0]
371 | total_loss['rpl'] += mse(x_hat, V(x[0])).data[0]
372 | x, y = next_x, next_y
373 |
374 | total_loss['mse'] /= len(val_loader) # average out
375 | total_loss['rpl'] /= len(val_loader) # average out
376 | total_loss['per_ce'] /= len(val_loader) / args.big_t # average out
377 | total_loss['ce'] /= total_loss['ce_count'] # average out
378 | return total_loss
379 |
380 |
381 | def repackage_state(h):
382 | """
383 | Wraps hidden states in new Variables, to detach them from their history.
384 | """
385 | if not h:
386 | return None
387 | elif type(h) == V:
388 | return V(h.data)
389 | else:
390 | return list(repackage_state(v) for v in h)
391 |
392 |
393 | if __name__ == '__main__':
394 | main()
395 |
396 | __author__ = "Alfredo Canziani"
397 | __credits__ = ["Alfredo Canziani"]
398 | __maintainer__ = "Alfredo Canziani"
399 | __email__ = "alfredo.canziani@gmail.com"
400 | __status__ = "Production" # "Prototype", "Development", or "Production"
401 | __date__ = "Feb 17"
402 |
--------------------------------------------------------------------------------
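A quick illustration of the --lr-decay schedule implemented by adjust_learning_rate above: with the values suggested in the help text (D = 3.16, E = 3), the learning rate is divided by D once every E epochs. A standalone sketch (the numbers are only an example):

base_lr, d, e = 0.1, 3.16, 3                 # --lr 0.1 --lr-decay 3.16 3

for epoch in range(0, 10):
    lr = base_lr * (d ** -(epoch // e))      # same formula as adjust_learning_rate
    print(epoch, round(lr, 5))
# epochs 0-2 -> 0.1, 3-5 -> ~0.0316, 6-8 -> ~0.01, 9 -> ~0.00317
--------------------------------------------------------------------------------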
/model/ConvLSTMCell.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as f
4 | from torch.autograd import Variable
5 |
6 |
7 | # Define some constants
8 | KERNEL_SIZE = 3
9 | PADDING = KERNEL_SIZE // 2
10 |
11 |
12 | class ConvLSTMCell(nn.Module):
13 | """
14 | Generate a convolutional LSTM cell
15 | """
16 |
17 | def __init__(self, input_size, hidden_size):
18 | super().__init__()
19 | self.input_size = input_size
20 | self.hidden_size = hidden_size
21 | self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size, KERNEL_SIZE, padding=PADDING)
22 |
23 | def forward(self, input_, prev_state):
24 |
25 | # get batch and spatial sizes
26 | batch_size = input_.data.size()[0]
27 | spatial_size = input_.data.size()[2:]
28 |
29 | # generate empty prev_state, if None is provided
30 | if prev_state is None:
31 | state_size = [batch_size, self.hidden_size] + list(spatial_size)
32 | prev_state = (
33 | Variable(torch.zeros(state_size)),
34 | Variable(torch.zeros(state_size))
35 | )
36 |
37 | prev_hidden, prev_cell = prev_state
38 |
39 | # data size is [batch, channel, height, width]
40 | stacked_inputs = torch.cat((input_, prev_hidden), 1)
41 | gates = self.Gates(stacked_inputs)
42 |
43 | # chunk across channel dimension
44 | in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)
45 |
46 |         # apply sigmoid non-linearity
47 | in_gate = f.sigmoid(in_gate)
48 | remember_gate = f.sigmoid(remember_gate)
49 | out_gate = f.sigmoid(out_gate)
50 |
51 |         # apply tanh non-linearity
52 | cell_gate = f.tanh(cell_gate)
53 |
54 | # compute current cell and hidden state
55 | cell = (remember_gate * prev_cell) + (in_gate * cell_gate)
56 | hidden = out_gate * f.tanh(cell)
57 |
58 | return hidden, cell
59 |
60 |
61 | def _main():
62 | """
63 | Run some basic tests on the API
64 | """
65 |
66 | # define batch_size, channels, height, width
67 | b, c, h, w = 1, 3, 4, 8
68 | d = 5 # hidden state size
69 | lr = 1e-1 # learning rate
70 | T = 6 # sequence length
71 | max_epoch = 20 # number of epochs
72 |
73 | # set manual seed
74 | torch.manual_seed(0)
75 |
76 | print('Instantiate model')
77 | model = ConvLSTMCell(c, d)
78 | print(repr(model))
79 |
80 | print('Create input and target Variables')
81 | x = Variable(torch.rand(T, b, c, h, w))
82 | y = Variable(torch.randn(T, b, d, h, w))
83 |
84 | print('Create a MSE criterion')
85 | loss_fn = nn.MSELoss()
86 |
87 | print('Run for', max_epoch, 'iterations')
88 | for epoch in range(0, max_epoch):
89 | state = None
90 | loss = 0
91 | for t in range(0, T):
92 | state = model(x[t], state)
93 | loss += loss_fn(state[0], y[t])
94 |
95 | print(' > Epoch {:2d} loss: {:.3f}'.format((epoch+1), loss.data[0]))
96 |
97 | # zero grad parameters
98 | model.zero_grad()
99 |
100 | # compute new grad parameters through time!
101 | loss.backward()
102 |
103 | # learning_rate step against the gradient
104 | for p in model.parameters():
105 | p.data.sub_(p.grad.data * lr)
106 |
107 | print('Input size:', list(x.data.size()))
108 | print('Target size:', list(y.data.size()))
109 | print('Last hidden state size:', list(state[0].size()))
110 |
111 |
112 | if __name__ == '__main__':
113 | _main()
114 |
115 |
116 | __author__ = "Alfredo Canziani"
117 | __credits__ = ["Alfredo Canziani"]
118 | __maintainer__ = "Alfredo Canziani"
119 | __email__ = "alfredo.canziani@gmail.com"
120 | __status__ = "Prototype" # "Prototype", "Development", or "Production"
121 | __date__ = "Jan 17"
122 |
--------------------------------------------------------------------------------
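The ConvLSTMCell above computes all four gates with a single convolution of 4 * hidden_size output channels and splits the result along the channel dimension. A standalone sketch of that split, assuming hidden_size = 5 (not part of the repository):

import torch

hidden_size = 5
gates = torch.randn(1, 4 * hidden_size, 4, 8)   # stand-in for the output of self.Gates
in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)
print([g.size(1) for g in (in_gate, remember_gate, out_gate, cell_gate)])  # [5, 5, 5, 5]
--------------------------------------------------------------------------------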
/model/DiscriminativeCell.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as f
4 | from torch.autograd import Variable
5 |
6 |
7 | # Define some constants
8 | KERNEL_SIZE = 3
9 | PADDING = KERNEL_SIZE // 2
10 | POOL = 2
11 |
12 |
13 | class DiscriminativeCell(nn.Module):
14 | """
15 | Single discriminative layer
16 | """
17 |
18 | def __init__(self, input_size, hidden_size, first=False):
19 | """
20 | Create a discriminative cell (bottom_up, r_state) -> error
21 |
22 | :param input_size: {'input': bottom_up_size, 'state': r_state_size}
23 | :param hidden_size: int, shooting dimensionality
24 | :param first: True/False
25 | """
26 | super().__init__()
27 | self.input_size = input_size
28 | self.hidden_size = hidden_size
29 | self.first = first
30 | if not first:
31 | self.from_bottom = nn.Conv2d(input_size['input'], hidden_size, KERNEL_SIZE, padding=PADDING)
32 | self.from_state = nn.Conv2d(input_size['state'], hidden_size, KERNEL_SIZE, padding=PADDING)
33 |
34 | def forward(self, bottom_up, state):
35 | input_projection = self.first and bottom_up or f.relu(f.max_pool2d(self.from_bottom(bottom_up), POOL, POOL))
36 | state_projection = f.relu(self.from_state(state))
37 | error = f.relu(torch.cat((input_projection - state_projection, state_projection - input_projection), 1))
38 | return error
39 |
40 |
41 | def _test_layer1():
42 | print('Define model for layer 1')
43 | discriminator = DiscriminativeCell(input_size={'input': 3, 'state': 3}, hidden_size=3, first=True)
44 |
45 | print('Define input and state')
46 | # at the first layer we have that system_state match the input_image dimensionality
47 | input_image = Variable(torch.rand(1, 3, 8, 12))
48 | system_state = Variable(torch.randn(1, 3, 8, 12))
49 |
50 | print('Input has size', list(input_image.data.size()))
51 |
52 | print('Forward input and state to the model')
53 | e = discriminator(input_image, system_state)
54 |
55 | # print output size
56 | print('Layer 1 error has size', list(e.data.size()))
57 |
58 | return e
59 |
60 |
61 | def _test_layer2(input_error):
62 | print('Define model for layer 2')
63 | discriminator = DiscriminativeCell(input_size={'input': 6, 'state': 32}, hidden_size=32, first=False)
64 |
65 | print('Define a new, smaller state')
66 | system_state = Variable(torch.randn(1, 32, 4, 6))
67 |
68 | print('Forward layer 1 output and state to the model')
69 | e = discriminator(input_error, system_state)
70 |
71 | # print output size
72 | print('Layer 2 error has size', list(e.data.size()))
73 |
74 |
75 | def _test_layers():
76 | error = _test_layer1()
77 | _test_layer2(input_error=error)
78 |
79 |
80 | if __name__ == '__main__':
81 | _test_layers()
82 |
83 |
84 | __author__ = "Alfredo Canziani"
85 | __credits__ = ["Alfredo Canziani"]
86 | __maintainer__ = "Alfredo Canziani"
87 | __email__ = "alfredo.canziani@gmail.com"
88 | __status__ = "Prototype" # "Prototype", "Development", or "Production"
89 | __date__ = "Feb 17"
90 |
--------------------------------------------------------------------------------
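The discriminative cell's error is PredNet-style: the positive and the negative difference between the two projections are both rectified and concatenated, which is why the error carries twice the channels of either projection. A toy numeric sketch of that single line (standalone, with made-up values):

import torch
import torch.nn.functional as f
from torch.autograd import Variable

a = Variable(torch.Tensor([[1.0, -2.0]]))  # stand-in for the input projection
b = Variable(torch.Tensor([[0.5,  3.0]]))  # stand-in for the state projection

error = f.relu(torch.cat((a - b, b - a), 1))
print(error.data)  # [[0.5, 0.0, 0.0, 5.0]]: rectified positive and negative differences
--------------------------------------------------------------------------------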
/model/GenerativeCell.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as f
4 | from torch.autograd import Variable
5 |
6 | from model.ConvLSTMCell import ConvLSTMCell
7 |
8 |
9 | class GenerativeCell(nn.Module):
10 | """
11 | Single generative layer
12 | """
13 |
14 | def __init__(self, input_size, hidden_size, error_init_size=None):
15 | """
16 | Create a generative cell (error, top_down_state, r_state) -> r_state
17 |
18 | :param input_size: {'error': error_size, 'up_state': r_state_size}, r_state_size can be 0
19 | :param hidden_size: int, shooting dimensionality
20 | :param error_init_size: tuple, full size of initial (null) error
21 | """
22 | super().__init__()
23 | self.input_size = input_size
24 | self.hidden_size = hidden_size
25 | self.error_init_size = error_init_size
26 | self.memory = ConvLSTMCell(input_size['error']+input_size['up_state'], hidden_size)
27 |
28 | def forward(self, error, top_down_state, state):
29 | if error is None: # we just started
30 | error = Variable(torch.zeros(self.error_init_size))
31 | model_input = error
32 | if top_down_state is not None:
33 | model_input = torch.cat((error, f.upsample_nearest(top_down_state, scale_factor=2)), 1)
34 | return self.memory(model_input, state)
35 |
36 |
37 | def _test_layer2():
38 | print('Define model for layer 2')
39 | generator = GenerativeCell(input_size={'error': 2*16, 'up_state': 0}, hidden_size=16)
40 |
41 | print('Define error and top down state')
42 | input_error = Variable(torch.randn(1, 2*16, 4, 6))
43 | topdown_state = None
44 |
45 | print('Input error has size', list(input_error.data.size()))
46 | print('Top down state is None')
47 |
48 | print('Forward error and top down state to the model')
49 | state = None
50 | state = generator(input_error, topdown_state, state)
51 |
52 | # print output size
53 | print('Layer 2 state has size', list(state[0].data.size()))
54 |
55 | return state[0] # the element 1 is the cell state
56 |
57 |
58 | def _test_layer1(top_down_state):
59 | print('Define model for layer 1')
60 | generator = GenerativeCell(input_size={'error': 2*3, 'up_state': 16}, hidden_size=3)
61 |
62 | print('Define error and top down state')
63 | input_error = Variable(torch.randn(1, 2*3, 8, 12))
64 |
65 | print('Input error has size', list(input_error.data.size()))
66 | print('Top down state has size', list(top_down_state.data.size()))
67 |
68 | print('Forward error and top down state to the model')
69 | state = None
70 | state = generator(input_error, top_down_state, state)
71 |
72 | # print output size
73 | print('Layer 1 state has size', list(state[0].data.size()))
74 |
75 |
76 | def _test_layers():
77 | state = _test_layer2()
78 | _test_layer1(top_down_state=state)
79 |
80 |
81 | if __name__ == '__main__':
82 | _test_layers()
83 |
84 |
85 | __author__ = "Alfredo Canziani"
86 | __credits__ = ["Alfredo Canziani"]
87 | __maintainer__ = "Alfredo Canziani"
88 | __email__ = "alfredo.canziani@gmail.com"
89 | __status__ = "Prototype" # "Prototype", "Development", or "Production"
90 | __date__ = "Feb 17"
91 |
--------------------------------------------------------------------------------
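Before concatenating the error with the top-down state, the generative cell doubles the top-down state's spatial resolution with nearest-neighbour upsampling so that the two tensors match. A standalone sketch with an arbitrary 4 x 6 state (not part of the repository):

import torch
import torch.nn.functional as f
from torch.autograd import Variable

top_down = Variable(torch.rand(1, 16, 4, 6))        # state coming from the layer above
up = f.upsample_nearest(top_down, scale_factor=2)   # same call as in GenerativeCell.forward
print(list(up.size()))                              # [1, 16, 8, 12]
--------------------------------------------------------------------------------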
/model/Model01.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as f
4 | from torch.autograd import Variable as V
5 | from math import ceil
6 |
7 |
8 | # Define some constants
9 | KERNEL_SIZE = 3
10 | PADDING = KERNEL_SIZE // 2
11 | KERNEL_STRIDE = 2
12 | OUTPUT_ADJUST = KERNEL_SIZE - 2 * PADDING
13 |
14 |
15 | class Model01(nn.Module):
16 | """
17 | Generate a constructor for model_01 type of network
18 | """
19 |
20 | def __init__(self, network_size: tuple, input_spatial_size: tuple) -> None:
21 | """
22 | Initialise Model01 constructor
23 |
24 | :param network_size: (n, h1, h2, ..., emb_size, nb_videos)
25 | :type network_size: tuple
26 | :param input_spatial_size: (height, width)
27 | :type input_spatial_size: tuple
28 | """
29 | super().__init__()
30 | self.hidden_layers = len(network_size) - 2
31 |
32 | print('\n{:-^80}'.format(' Building model '))
33 | print('Hidden layers:', self.hidden_layers)
34 | print('Net sizing:', network_size)
35 | print('Input spatial size: {} x {}'.format(network_size[0], input_spatial_size))
36 |
37 | # main auto-encoder blocks
38 | self.activation_size = [input_spatial_size]
39 | for layer in range(0, self.hidden_layers):
40 | # print some annotation when building model
41 | print('{:-<80}'.format('Layer ' + str(layer + 1) + ' '))
42 | print('Bottom size: {} x {}'.format(network_size[layer], self.activation_size[-1]))
43 | self.activation_size.append(tuple(ceil(s / 2) for s in self.activation_size[layer]))
44 | print('Top size: {} x {}'.format(network_size[layer + 1], self.activation_size[-1]))
45 |
46 | # init D (discriminative) blocks
47 | setattr(self, 'D_' + str(layer + 1), nn.Conv2d(
48 | in_channels=network_size[layer], out_channels=network_size[layer + 1],
49 | kernel_size=KERNEL_SIZE, stride=KERNEL_STRIDE, padding=PADDING
50 | ))
51 | setattr(self, 'BN_D_' + str(layer + 1), nn.BatchNorm2d(network_size[layer + 1]))
52 |
53 | # init G (generative) blocks
54 | setattr(self, 'G_' + str(layer + 1), nn.ConvTranspose2d(
55 | in_channels=network_size[layer + 1], out_channels=network_size[layer],
56 | kernel_size=KERNEL_SIZE, stride=KERNEL_STRIDE, padding=PADDING
57 | ))
58 | setattr(self, 'BN_G_' + str(layer + 1), nn.BatchNorm2d(network_size[layer]))
59 |
60 | # init auxiliary classifier
61 | print('{:-<80}'.format('Classifier '))
62 | print(network_size[-2], '-->', network_size[-1])
63 | self.average = nn.AvgPool2d(self.activation_size[-1])
64 | self.stabiliser = nn.Linear(network_size[-2], network_size[-1])
65 | print(80 * '-', end='\n\n')
66 |
67 | def forward(self, x, state):
68 | activation_sizes = [x.size()] # start from the input
69 | residuals = list()
70 | for layer in range(0, self.hidden_layers): # connect discriminative blocks
71 | x = getattr(self, 'D_' + str(layer + 1))(x)
72 | residuals.append(x)
73 | if layer < self.hidden_layers - 1 and state: x += state[layer]
74 | x = f.relu(x)
75 | x = getattr(self, 'BN_D_' + str(layer + 1))(x)
76 | activation_sizes.append(x.size()) # cache output size for later retrieval
77 | state = state or [None] * (self.hidden_layers - 1)
78 | for layer in reversed(range(0, self.hidden_layers)): # connect generative blocks
79 | x = getattr(self, 'G_' + str(layer + 1))(x, activation_sizes[layer])
80 | if layer:
81 | state[layer - 1] = x
82 | x += residuals[layer - 1]
83 | x = f.relu(x)
84 | x = getattr(self, 'BN_G_' + str(layer + 1))(x)
85 | x_mean = self.average(residuals[-1])
86 | video_index = self.stabiliser(x_mean.view(x_mean.size(0), -1))
87 |
88 | return (x, state), (x_mean, video_index)
89 |
90 |
91 | def _test_model():
92 | T = 2
93 | x = torch.rand(T + 1, 1, 3, 4 * 2**3 + 3, 6 * 2**3 + 5)
94 | K = 10
95 | y = torch.LongTensor(T, 1).random_(K)
96 | model_01 = Model01(network_size=(3, 6, 12, 18, K), input_spatial_size=x[0].size()[2:])
97 |
98 | state = None
99 | (x_hat, state), (emb, idx) = model_01(V(x[0]), state)
100 |
101 | print('Input size:', tuple(x.size()))
102 | print('Output size:', tuple(x_hat.data.size()))
103 | print('Video index size:', tuple(idx.size()))
104 | for i, s in enumerate(state):
105 | print('State', i + 1, 'has size:', tuple(s.size()))
106 | print('Embedding has size:', emb.data.numel())
107 |
108 | mse = nn.MSELoss()
109 | nll = nn.CrossEntropyLoss()
110 | x_next = V(x[1])
111 | y_var = V(y[0])
112 | loss_t1 = mse(x_hat, x_next) + nll(idx, y_var)
113 |
114 | from utils.visualise import show_graph
115 | show_graph(loss_t1)
116 |
117 | # run one more time
118 | (x_hat, _), (_, idx) = model_01(V(x[1]), state)
119 |
120 | x_next = V(x[2])
121 | y_var = V(y[1])
122 | loss_t2 = mse(x_hat, x_next) + nll(idx, y_var)
123 | loss_tot = loss_t2 + loss_t1
124 |
125 | show_graph(loss_tot)
126 |
127 |
128 | def _test_training():
129 | K = 10 # number of training videos
130 | network_size = (3, 6, 12, 18, K)
131 | T = 6 # sequence length
132 | max_epoch = 10 # number of epochs
133 | lr = 1e-1 # learning rate
134 |
135 | # set manual seed
136 | torch.manual_seed(0)
137 |
138 | print('\n{:-^80}'.format(' Train a ' + str(network_size[:-1]) + ' layer network '))
139 | print('Sequence length T:', T)
140 | print('Create the input image and target sequences')
141 | x = torch.rand(T + 1, 1, 3, 4 * 2**3 + 3, 6 * 2**3 + 5)
142 | y = torch.LongTensor(T, 1).random_(K)
143 | print('Input has size', tuple(x.size()))
144 | print('Target index has size', tuple(y.size()))
145 |
146 | print('Define model')
147 | model = Model01(network_size=network_size, input_spatial_size=x[0].size()[2:])
148 |
149 | print('Create a MSE and NLL criterions')
150 | mse = nn.MSELoss()
151 | nll = nn.CrossEntropyLoss()
152 |
153 | print('Run for', max_epoch, 'iterations')
154 | for epoch in range(0, max_epoch):
155 | state = None
156 | loss = 0
157 | for t in range(0, T):
158 | (x_hat, state), (emb, idx) = model(V(x[t]), state)
159 | loss += mse(x_hat, V(x[t + 1])) + nll(idx, V(y[t]))
160 |
161 | print(' > Epoch {:2d} loss: {:.3f}'.format((epoch + 1), loss.data[0]))
162 |
163 | # zero grad parameters
164 | model.zero_grad()
165 |
166 | # compute new grad parameters through time!
167 | loss.backward()
168 |
169 | # learning_rate step against the gradient
170 | for p in model.parameters():
171 | p.data.sub_(p.grad.data * lr)
172 |
173 |
174 | if __name__ == '__main__':
175 | _test_model()
176 | _test_training()
177 |
178 |
179 | __author__ = "Alfredo Canziani"
180 | __credits__ = ["Alfredo Canziani"]
181 | __maintainer__ = "Alfredo Canziani"
182 | __email__ = "alfredo.canziani@gmail.com"
183 | __status__ = "Production" # "Prototype", "Development", or "Production"
184 | __date__ = "Feb 17"
185 |
--------------------------------------------------------------------------------
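Model01 (and Model02 below) halves the spatial size with ceil at every hidden layer, so inputs whose sides are not powers of two still work; the tests use a 35 x 53 frame (4 * 2**3 + 3 by 6 * 2**3 + 5). A standalone sketch of how activation_size evolves for that input and the (3, 6, 12, 18, K) test network:

from math import ceil

spatial = (35, 53)                 # the _test_model input: 4 * 2**3 + 3 by 6 * 2**3 + 5
sizes = [spatial]
for _ in range(3):                 # hidden_layers = len((3, 6, 12, 18, K)) - 2 = 3
    sizes.append(tuple(ceil(s / 2) for s in sizes[-1]))
print(sizes)                       # [(35, 53), (18, 27), (9, 14), (5, 7)]
--------------------------------------------------------------------------------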
/model/Model02.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as f
4 | from torch.autograd import Variable as V
5 | from math import ceil
6 |
7 |
8 | # Define some constants
9 | from model.RG import RG
10 |
11 | KERNEL_SIZE = 3
12 | PADDING = KERNEL_SIZE // 2
13 | KERNEL_STRIDE = 2
14 | OUTPUT_ADJUST = KERNEL_SIZE - 2 * PADDING
15 |
16 |
17 | class Model02(nn.Module):
18 | """
19 | Generate a constructor for model_02 type of network
20 | """
21 |
22 | def __init__(self, network_size: tuple, input_spatial_size: tuple) -> None:
23 | """
24 | Initialise Model02 constructor
25 |
26 | :param network_size: (n, h1, h2, ..., emb_size, nb_videos)
27 | :type network_size: tuple
28 | :param input_spatial_size: (height, width)
29 | :type input_spatial_size: tuple
30 | """
31 | super().__init__()
32 | self.hidden_layers = len(network_size) - 2
33 |
34 | print('\n{:-^80}'.format(' Building model Model02 '))
35 | print('Hidden layers:', self.hidden_layers)
36 | print('Net sizing:', network_size)
37 | print('Input spatial size: {} x {}'.format(network_size[0], input_spatial_size))
38 |
39 | # main auto-encoder blocks
40 | self.activation_size = [input_spatial_size]
41 | for layer in range(0, self.hidden_layers):
42 | # print some annotation when building model
43 | print('{:-<80}'.format('Layer ' + str(layer + 1) + ' '))
44 | print('Bottom size: {} x {}'.format(network_size[layer], self.activation_size[-1]))
45 | self.activation_size.append(tuple(ceil(s / 2) for s in self.activation_size[layer]))
46 | print('Top size: {} x {}'.format(network_size[layer + 1], self.activation_size[-1]))
47 |
48 | # init D (discriminative) blocks
49 | multiplier = layer and 2 or 1 # D_n, n > 1, has intra-layer feedback
50 | setattr(self, 'D_' + str(layer + 1), nn.Conv2d(
51 | in_channels=network_size[layer] * multiplier, out_channels=network_size[layer + 1],
52 | kernel_size=KERNEL_SIZE, stride=KERNEL_STRIDE, padding=PADDING
53 | ))
54 | setattr(self, 'BN_D_' + str(layer + 1), nn.BatchNorm2d(network_size[layer + 1]))
55 |
56 | # init G (generative) blocks
57 | setattr(self, 'G_' + str(layer + 1), nn.ConvTranspose2d(
58 | in_channels=network_size[layer + 1], out_channels=network_size[layer],
59 | kernel_size=KERNEL_SIZE, stride=KERNEL_STRIDE, padding=PADDING
60 | ))
61 | setattr(self, 'BN_G_' + str(layer + 1), nn.BatchNorm2d(network_size[layer]))
62 |
63 | # init auxiliary classifier
64 | print('{:-<80}'.format('Classifier '))
65 | print(network_size[-2], '-->', network_size[-1])
66 | self.average = nn.AvgPool2d(self.activation_size[-1])
67 | self.stabiliser = nn.Linear(network_size[-2], network_size[-1])
68 | print(80 * '-', end='\n\n')
69 |
70 | def forward(self, x, state):
71 | activation_sizes = [x.size()] # start from the input
72 | residuals = list()
73 | state = state or [None] * (self.hidden_layers - 1)
74 | for layer in range(0, self.hidden_layers): # connect discriminative blocks
75 | if layer: # concat the input with the state for D_n, n > 1
76 | s = state[layer - 1] or V(x.data.clone().zero_())
77 | x = torch.cat((x, s), 1)
78 | x = getattr(self, 'D_' + str(layer + 1))(x)
79 | residuals.append(x)
80 | x = f.relu(x)
81 | x = getattr(self, 'BN_D_' + str(layer + 1))(x)
82 | activation_sizes.append(x.size()) # cache output size for later retrieval
83 | for layer in reversed(range(0, self.hidden_layers)): # connect generative blocks
84 | x = getattr(self, 'G_' + str(layer + 1))(x, activation_sizes[layer])
85 | if layer:
86 | state[layer - 1] = x
87 | x += residuals[layer - 1]
88 | x = f.relu(x)
89 | x = getattr(self, 'BN_G_' + str(layer + 1))(x)
90 | x_mean = self.average(residuals[-1])
91 | video_index = self.stabiliser(x_mean.view(x_mean.size(0), -1))
92 |
93 | return (x, state), (x_mean, video_index)
94 |
95 |
96 | class Model02RG(nn.Module):
97 | """
98 | Generate a constructor for model_02_rg type of network
99 | """
100 |
101 | def __init__(self, network_size: tuple, input_spatial_size: tuple) -> None:
102 | """
103 | Initialise Model02RG constructor
104 |
105 | :param network_size: (n, h1, h2, ..., emb_size, nb_videos)
106 | :type network_size: tuple
107 | :param input_spatial_size: (height, width)
108 | :type input_spatial_size: tuple
109 | """
110 | super().__init__()
111 | self.hidden_layers = len(network_size) - 2
112 |
113 | print('\n{:-^80}'.format(' Building model Model02RG '))
114 | print('Hidden layers:', self.hidden_layers)
115 | print('Net sizing:', network_size)
116 | print('Input spatial size: {} x {}'.format(network_size[0], input_spatial_size))
117 |
118 | # main auto-encoder blocks
119 | self.activation_size = [input_spatial_size]
120 | for layer in range(0, self.hidden_layers):
121 | # print some annotation when building model
122 | print('{:-<80}'.format('Layer ' + str(layer + 1) + ' '))
123 | print('Bottom size: {} x {}'.format(network_size[layer], self.activation_size[-1]))
124 | self.activation_size.append(tuple(ceil(s / 2) for s in self.activation_size[layer]))
125 | print('Top size: {} x {}'.format(network_size[layer + 1], self.activation_size[-1]))
126 |
127 | # init D (discriminative) blocks
128 | multiplier = layer and 2 or 1 # D_n, n > 1, has intra-layer feedback
129 | setattr(self, 'D_' + str(layer + 1), nn.Conv2d(
130 | in_channels=network_size[layer] * multiplier, out_channels=network_size[layer + 1],
131 | kernel_size=KERNEL_SIZE, stride=KERNEL_STRIDE, padding=PADDING
132 | ))
133 | setattr(self, 'BN_D_' + str(layer + 1), nn.BatchNorm2d(network_size[layer + 1]))
134 |
135 | # init G (generative) blocks
136 | setattr(self, 'G_' + str(layer + 1), RG(
137 | in_channels=network_size[layer + 1], out_channels=network_size[layer],
138 | kernel_size=KERNEL_SIZE, stride=KERNEL_STRIDE, padding=PADDING
139 | ))
140 | setattr(self, 'BN_G_' + str(layer + 1), nn.BatchNorm2d(network_size[layer]))
141 |
142 | # init auxiliary classifier
143 | print('{:-<80}'.format('Classifier '))
144 | print(network_size[-2], '-->', network_size[-1])
145 | self.average = nn.AvgPool2d(self.activation_size[-1])
146 | self.stabiliser = nn.Linear(network_size[-2], network_size[-1])
147 | print(80 * '-', end='\n\n')
148 |
149 | def forward(self, x, state):
150 | activation_sizes = [x.size()] # start from the input
151 | residuals = list()
152 | # state[0] --> network layer state; state[1] --> generative state
153 | state = state or [[None] * (self.hidden_layers - 1), [None] * self.hidden_layers]
154 | for layer in range(0, self.hidden_layers): # connect discriminative blocks
155 | if layer: # concat the input with the state for D_n, n > 1
156 | s = state[0][layer - 1] or V(x.data.clone().zero_())
157 | x = torch.cat((x, s), 1)
158 | x = getattr(self, 'D_' + str(layer + 1))(x)
159 | residuals.append(x)
160 | x = f.relu(x)
161 | x = getattr(self, 'BN_D_' + str(layer + 1))(x)
162 | activation_sizes.append(x.size()) # cache output size for later retrieval
163 | for layer in reversed(range(0, self.hidden_layers)): # connect generative blocks
164 | x = getattr(self, 'G_' + str(layer + 1))((x, activation_sizes[layer]), state[1][layer])
165 | state[1][layer] = x # h[t - 1] <- h[t]
166 | if layer:
167 | state[0][layer - 1] = x
168 | x += residuals[layer - 1]
169 | x = f.relu(x)
170 | x = getattr(self, 'BN_G_' + str(layer + 1))(x)
171 | x_mean = self.average(residuals[-1])
172 | video_index = self.stabiliser(x_mean.view(x_mean.size(0), -1))
173 |
174 | return (x, state), (x_mean, video_index)
175 |
176 |
177 | def _test_models():
178 | _test_model(Model02)
179 | _test_model(Model02RG)
180 |
181 |
182 | def _test_model(Model):
183 | big_t = 2
184 | x = torch.rand(big_t + 1, 1, 3, 4 * 2**3 + 3, 6 * 2**3 + 5)
185 | big_k = 10
186 | y = torch.LongTensor(big_t, 1).random_(big_k)
187 | model = Model(network_size=(3, 6, 12, 18, big_k), input_spatial_size=x[0].size()[2:])
188 |
189 | state = None
190 | (x_hat, state), (emb, idx) = model(V(x[0]), state)
191 |
192 | print('Input size:', tuple(x.size()))
193 | print('Output size:', tuple(x_hat.data.size()))
194 | print('Video index size:', tuple(idx.size()))
195 | for i, s in enumerate(state):
196 | if isinstance(s, list):
197 | for i, s in enumerate(state[0]):
198 | print('Net state', i + 1, 'has size:', tuple(s.size()))
199 | for i, s in enumerate(state[1]):
200 | print('G', i + 1, 'state has size:', tuple(s.size()))
201 | break
202 | else:
203 | print('State', i + 1, 'has size:', tuple(s.size()))
204 | print('Embedding has size:', emb.data.numel())
205 |
206 | mse = nn.MSELoss()
207 | nll = nn.CrossEntropyLoss()
208 | x_next = V(x[1])
209 | y_var = V(y[0])
210 | loss_t1 = mse(x_hat, x_next) + nll(idx, y_var)
211 |
212 | from utils.visualise import show_graph
213 | show_graph(loss_t1)
214 |
215 | # run one more time
216 | (x_hat, _), (_, idx) = model(V(x[1]), state)
217 |
218 | x_next = V(x[2])
219 | y_var = V(y[1])
220 | loss_t2 = mse(x_hat, x_next) + nll(idx, y_var)
221 | loss_tot = loss_t2 + loss_t1
222 |
223 | show_graph(loss_tot)
224 |
225 |
226 | def _test_training_models():
227 | _test_training(Model02)
228 | _test_training(Model02RG)
229 |
230 |
231 | def _test_training(Model):
232 | big_k = 10 # number of training videos
233 | network_size = (3, 6, 12, 18, big_k)
234 | big_t = 6 # sequence length
235 | max_epoch = 10 # number of epochs
236 | lr = 3.16e-2 # learning rate
237 |
238 | # set manual seed
239 | torch.manual_seed(0)
240 |
241 | print('\n{:-^80}'.format(' Train a ' + str(network_size[:-1]) + ' layer network '))
242 | print('Sequence length T:', big_t)
243 | print('Create the input image and target sequences')
244 | x = torch.rand(big_t + 1, 1, 3, 4 * 2**3 + 3, 6 * 2**3 + 5)
245 | y = torch.LongTensor(big_t, 1).random_(big_k)
246 | print('Input has size', tuple(x.size()))
247 | print('Target index has size', tuple(y.size()))
248 |
249 | print('Define model')
250 | model = Model(network_size=network_size, input_spatial_size=x[0].size()[2:])
251 |
252 | print('Create a MSE and NLL criterions')
253 | mse = nn.MSELoss()
254 | nll = nn.CrossEntropyLoss()
255 |
256 | print('Run for', max_epoch, 'iterations')
257 | for epoch in range(0, max_epoch):
258 | state = None
259 | loss = 0
260 | for t in range(0, big_t):
261 | (x_hat, state), (emb, idx) = model(V(x[t]), state)
262 | loss += mse(x_hat, V(x[t + 1])) + nll(idx, V(y[t]))
263 |
264 | print(' > Epoch {:2d} loss: {:.3f}'.format((epoch + 1), loss.data[0]))
265 |
266 | # zero grad parameters
267 | model.zero_grad()
268 |
269 | # compute new grad parameters through time!
270 | loss.backward()
271 |
272 | # learning_rate step against the gradient
273 | for p in model.parameters():
274 | p.data.sub_(p.grad.data * lr)
275 |
276 |
277 | if __name__ == '__main__':
278 | _test_models()
279 | _test_training_models()
280 |
281 |
282 | __author__ = "Alfredo Canziani"
283 | __credits__ = ["Alfredo Canziani"]
284 | __maintainer__ = "Alfredo Canziani"
285 | __email__ = "alfredo.canziani@gmail.com"
286 | __status__ = "Production" # "Prototype", "Development", or "Production"
287 | __date__ = "Feb, Mar 17"
288 |
--------------------------------------------------------------------------------
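In Model02 every D_n with n > 1 receives its bottom-up input concatenated with the state fed back from the generative branch, which is why the line multiplier = layer and 2 or 1 doubles its input channels. A standalone sketch of the resulting D-block channel counts for the default --size from main.py plus a hypothetical 33-class head:

network_size = (3, 32, 64, 128, 256) + (33,)    # backbone sizes + class count, as main.py builds it
for layer in range(len(network_size) - 2):      # hidden_layers = len(network_size) - 2
    multiplier = layer and 2 or 1               # D_n, n > 1, also sees the fed-back state
    print('D_{}: {} -> {}'.format(layer + 1, network_size[layer] * multiplier, network_size[layer + 1]))
# D_1: 3 -> 32 | D_2: 64 -> 64 | D_3: 128 -> 128 | D_4: 256 -> 256
--------------------------------------------------------------------------------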
/model/PrednetModel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Variable
4 |
5 | from model.DiscriminativeCell import DiscriminativeCell
6 | from model.GenerativeCell import GenerativeCell
7 |
8 |
9 | # Define some constants
10 | OUT_LAYER_SIZE = (3,) + tuple(2 ** p for p in range(4, 10))
11 | ERR_LAYER_SIZE = tuple(size * 2 for size in OUT_LAYER_SIZE)
12 | IN_LAYER_SIZE = (3,) + ERR_LAYER_SIZE
13 |
14 |
15 | class PrednetModel(nn.Module):
16 | """
17 | Build the Prednet model
18 | """
19 |
20 | def __init__(self, error_size_list):
21 | super().__init__()
22 | self.number_of_layers = len(error_size_list)
23 | for layer in range(0, self.number_of_layers):
24 | setattr(self, 'discriminator_' + str(layer + 1), DiscriminativeCell(
25 | input_size={'input': IN_LAYER_SIZE[layer], 'state': OUT_LAYER_SIZE[layer]},
26 | hidden_size=OUT_LAYER_SIZE[layer],
27 | first=(not layer)
28 | ))
29 | setattr(self, 'generator_' + str(layer + 1), GenerativeCell(
30 | input_size={'error': ERR_LAYER_SIZE[layer], 'up_state':
31 | OUT_LAYER_SIZE[layer + 1] if layer != self.number_of_layers - 1 else 0},
32 | hidden_size=OUT_LAYER_SIZE[layer],
33 | error_init_size=error_size_list[layer]
34 | ))
35 |
36 | def forward(self, bottom_up_input, error, state):
37 |
38 | # generative branch
39 | up_state = None
40 | for layer in reversed(range(0, self.number_of_layers)):
41 | state[layer] = getattr(self, 'generator_' + str(layer + 1))(
42 | error[layer], up_state, state[layer]
43 | )
44 | up_state = state[layer][0]
45 |
46 | # discriminative branch
47 | for layer in range(0, self.number_of_layers):
48 | error[layer] = getattr(self, 'discriminator_' + str(layer + 1))(
49 | layer and error[layer - 1] or bottom_up_input,
50 | state[layer][0]
51 | )
52 |
53 | return error, state
54 |
55 |
56 | class _BuildOneLayerModel(nn.Module):
57 | """
58 | Build a one layer Prednet model
59 | """
60 |
61 | def __init__(self, error_size_list):
62 | super().__init__()
63 | self.discriminator = DiscriminativeCell(
64 | input_size={'input': IN_LAYER_SIZE[0], 'state': OUT_LAYER_SIZE[0]},
65 | hidden_size=OUT_LAYER_SIZE[0],
66 | first=True
67 | )
68 | self.generator = GenerativeCell(
69 | input_size={'error': ERR_LAYER_SIZE[0], 'up_state': 0},
70 | hidden_size=OUT_LAYER_SIZE[0],
71 | error_init_size=error_size_list[0]
72 | )
73 |
74 | def forward(self, bottom_up_input, prev_error, state):
75 | state = self.generator(prev_error, None, state)
76 | error = self.discriminator(bottom_up_input, state[0])
77 | return error, state
78 |
79 |
80 | class _BuildTwoLayerModel(nn.Module):
81 | """
82 | Build a two layer Prednet model
83 | """
84 |
85 | def __init__(self, error_size_list):
86 | super().__init__()
87 | self.discriminator_1 = DiscriminativeCell(
88 | input_size={'input': IN_LAYER_SIZE[0], 'state': OUT_LAYER_SIZE[0]},
89 | hidden_size=OUT_LAYER_SIZE[0],
90 | first=True
91 | )
92 | self.discriminator_2 = DiscriminativeCell(
93 | input_size={'input': IN_LAYER_SIZE[1], 'state': OUT_LAYER_SIZE[1]},
94 | hidden_size=OUT_LAYER_SIZE[1]
95 | )
96 | self.generator_1 = GenerativeCell(
97 | input_size={'error': ERR_LAYER_SIZE[0], 'up_state': OUT_LAYER_SIZE[1]},
98 | hidden_size=OUT_LAYER_SIZE[0],
99 | error_init_size=error_size_list[0]
100 | )
101 | self.generator_2 = GenerativeCell(
102 | input_size={'error': ERR_LAYER_SIZE[1], 'up_state': 0},
103 | hidden_size=OUT_LAYER_SIZE[1],
104 | error_init_size=error_size_list[1]
105 | )
106 |
107 | def forward(self, bottom_up_input, error, state):
108 | state[1] = self.generator_2(error[1], None, state[1])
109 | state[0] = self.generator_1(error[0], state[1][0], state[0])
110 | error[0] = self.discriminator_1(bottom_up_input, state[0][0])
111 | error[1] = self.discriminator_2(error[0], state[1][0])
112 | return error, state
113 |
114 |
115 | def _test_one_layer_model():
116 | print('\nCreate the input image')
117 | input_image = Variable(torch.rand(1, 3, 8, 12))
118 |
119 | print('Input has size', list(input_image.data.size()))
120 |
121 | error_init_size = (1, 6, 8, 12)
122 | print('The error initialisation size is', error_init_size)
123 |
124 | print('Define a 1 layer Prednet')
125 | model = _BuildOneLayerModel((error_init_size,))
126 |
127 | print('Forward input and state to the model')
128 | state = None
129 | error = None
130 | error, state = model(input_image, prev_error=error, state=state)
131 |
132 | print('The error has size', list(error.data.size()))
133 | print('The state has size', list(state[0].data.size()))
134 |
135 |
136 | def _test_two_layer_model():
137 | print('\nCreate the input image')
138 | input_image = Variable(torch.rand(1, 3, 8, 12))
139 |
140 | print('Input has size', list(input_image.data.size()))
141 |
142 | error_init_size_list = ((1, 6, 8, 12), (1, 32, 4, 6))
143 | print('The error initialisation sizes are', *error_init_size_list)
144 |
145 | print('Define a 2 layer Prednet')
146 | model = _BuildTwoLayerModel(error_init_size_list)
147 |
148 | print('Forward input and state to the model')
149 | state = [None] * 2
150 | error = [None] * 2
151 | error, state = model(input_image, error=error, state=state)
152 |
153 | for layer in range(0, 2):
154 | print('Layer', layer + 1, 'error has size', list(error[layer].data.size()))
155 | print('Layer', layer + 1, 'state has size', list(state[layer][0].data.size()))
156 |
157 |
158 | def _test_L_layer_model():
159 |
160 | max_number_of_layers = 5
161 | for L in range(0, max_number_of_layers):
162 | print('\n---------- Test', str(L + 1), 'layer network ----------')
163 |
164 | print('Create the input image')
165 | input_image = Variable(torch.rand(1, 3, 4 * 2 ** L, 6 * 2 ** L))
166 |
167 | print('Input has size', list(input_image.data.size()))
168 |
169 | error_init_size_list = tuple(
170 | (1, ERR_LAYER_SIZE[l], 4 * 2 ** (L-l), 6 * 2 ** (L-l)) for l in range(0, L + 1)
171 | )
172 | print('The error initialisation sizes are', *error_init_size_list)
173 |
174 | print('Define a', str(L + 1), 'layer Prednet')
175 | model = PrednetModel(error_init_size_list)
176 |
177 | print('Forward input and state to the model')
178 | state = [None] * (L + 1)
179 | error = [None] * (L + 1)
180 | error, state = model(input_image, error=error, state=state)
181 |
182 | for layer in range(0, L + 1):
183 | print('Layer', layer + 1, 'error has size', list(error[layer].data.size()))
184 | print('Layer', layer + 1, 'state has size', list(state[layer][0].data.size()))
185 |
186 |
187 | def _test_training():
188 | number_of_layers = 3
189 | T = 6 # sequence length
190 | max_epoch = 10 # number of epochs
191 | lr = 1e-1 # learning rate
192 |
193 | # set manual seed
194 | torch.manual_seed(0)
195 |
196 | L = number_of_layers - 1
197 | print('\n---------- Train a', str(L + 1), 'layer network ----------')
198 | print('Create the input image and target sequences')
199 | input_sequence = Variable(torch.rand(T, 1, 3, 4 * 2 ** L, 6 * 2 ** L))
200 | print('Input has size', list(input_sequence.data.size()))
201 |
202 | error_init_size_list = tuple(
203 | (1, ERR_LAYER_SIZE[l], 4 * 2 ** (L - l), 6 * 2 ** (L - l)) for l in range(0, L + 1)
204 | )
205 | print('The error initialisation sizes are', *error_init_size_list)
206 | target_sequence = Variable(torch.zeros(T, *error_init_size_list[0]))
207 |
208 | print('Define a', str(L + 1), 'layer Prednet')
209 | model = PrednetModel(error_init_size_list)
210 |
211 | print('Create a MSE criterion')
212 | loss_fn = nn.MSELoss()
213 |
214 | print('Run for', max_epoch, 'iterations')
215 | for epoch in range(0, max_epoch):
216 | state = [None] * (L + 1)
217 | error = [None] * (L + 1)
218 | loss = 0
219 | for t in range(0, T):
220 | error, state = model(input_sequence[t], error, state)
221 | loss += loss_fn(error[0], target_sequence[t])
222 |
223 | print(' > Epoch {:2d} loss: {:.3f}'.format((epoch + 1), loss.data[0]))
224 |
225 | # zero grad parameters
226 | model.zero_grad()
227 |
228 | # compute new grad parameters through time!
229 | loss.backward()
230 |
231 | # learning_rate step against the gradient
232 | for p in model.parameters():
233 | p.data.sub_(p.grad.data * lr)
234 |
235 |
236 | def _main():
237 | _test_one_layer_model()
238 | _test_two_layer_model()
239 | _test_L_layer_model()
240 | _test_training()
241 |
242 |
243 | if __name__ == '__main__':
244 | _main()
245 |
246 |
247 | __author__ = "Alfredo Canziani"
248 | __credits__ = ["Alfredo Canziani"]
249 | __maintainer__ = "Alfredo Canziani"
250 | __email__ = "alfredo.canziani@gmail.com"
251 | __status__ = "Prototype" # "Prototype", "Development", or "Production"
252 | __date__ = "Feb 17"
253 |
--------------------------------------------------------------------------------
/model/README.md:
--------------------------------------------------------------------------------
1 | # Network architectures
2 |
3 | Three main architectures are currently available:
4 |
5 | - [`PrednetModel.py`](PrednetModel.py): implementation of [*PredNet*](https://coxlab.github.io/prednet/) in *PyTorch* (a minimal usage sketch is given at the end of this file) through the following smaller blocks:
6 | - [`DiscriminativeCell.py`](DiscriminativeCell.py): computes the error between the input and state projections;
7 | - [`GenerativeCell.py`](GenerativeCell.py): computes the new state given error, top down state, and current state. Uses the following custom module:
8 | - [`ConvLSTMCell.py`](ConvLSTMCell.py): a pretty standard LSTM that uses convolutions instead of FC layers;
9 | - [`Model01.py`](Model01.py): symmetric, additional feed-forward/back;
10 | - [`Model02.py`](Model02.py): AKA *CortexNet*, additional feed-forward, modulated feed-back. May use:
11 | - [`RG.py`](RG.py): recurrent generative block ⇒ *Model02G*.
12 |
13 | ![Network architectures](models.svg)
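14 | 
15 | For reference, here is a minimal usage sketch of the *PredNet* implementation; the sizes mirror the `_test_two_layer_model()` function in [`PrednetModel.py`](PrednetModel.py), and the import path (running from within `model/`) is an assumption:
16 | 
17 | ```python
18 | import torch
19 | from torch.autograd import Variable
20 | from PrednetModel import PrednetModel  # assuming the script is run from within model/
21 | 
22 | # one (batch, channels, height, width) error-initialisation size per layer
23 | error_init_size_list = ((1, 6, 8, 12), (1, 32, 4, 6))
24 | model = PrednetModel(error_init_size_list)  # a two-layer network
25 | 
26 | x = Variable(torch.rand(1, 3, 8, 12))  # bottom-up input frame
27 | error = [None] * 2                     # per-layer errors
28 | state = [None] * 2                     # per-layer recurrent states
29 | error, state = model(x, error=error, state=state)
30 | ```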
--------------------------------------------------------------------------------
/model/RG.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class RG(nn.Module):
5 | """Recurrent Generative Module"""
6 |
7 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
8 | """ Initialise RG Module (parameters as nn.ConvTranspose2d)"""
9 | super().__init__()
10 | self.from_input = nn.ConvTranspose2d(
11 | in_channels=in_channels, out_channels=out_channels,
12 | kernel_size=kernel_size, stride=stride, padding=padding
13 | )
14 | self.from_state = nn.Conv2d(
15 | in_channels=out_channels, out_channels=out_channels,
16 | kernel_size=kernel_size, padding=padding, bias=False
17 | )
18 |
19 | def forward(self, x, state):
20 | """
21 | Calling signature
22 |
23 | :param x: (input, output_size)
24 | :type x: tuple
25 | :param state: previous output
26 | :type state: torch.Tensor
27 | :return: current state
28 | :rtype: torch.Tensor
29 | """
30 | x = self.from_input(*x) # the very first x is a tuple (input, expected_output_size)
31 |         if state is not None: x += self.from_state(state)  # no previous state at the very first call
32 | return x
33 |
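34 | 
35 | # A minimal, self-contained usage sketch (hypothetical channel counts and sizes,
36 | # not part of the original interface), following the repository's `_test_*` convention.
37 | def _test_rg():
38 |     import torch
39 |     from torch.autograd import Variable
40 |     rg = RG(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1)
41 |     x = Variable(torch.randn(1, 64, 8, 8))  # upsample 1 x 64 x 8 x 8 to 1 x 32 x 16 x 16
42 |     state = None  # there is no previous output at t = 0
43 |     for t in range(2):
44 |         state = rg((x, (1, 32, 16, 16)), state)  # x is the tuple (input, expected_output_size)
45 |     print('RG output has size', list(state.data.size()))
46 | 
47 | 
48 | if __name__ == '__main__':
49 |     _test_rg()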
--------------------------------------------------------------------------------
/model/utils:
--------------------------------------------------------------------------------
1 | ../utils/
--------------------------------------------------------------------------------
/new_experiment.sh:
--------------------------------------------------------------------------------
1 | # Prepares environment for new experiment
2 | # Alfredo Canziani, Mar 17
3 |
4 | # It expects a directory / link named "results" in the cwd containing
5 | # numerically increasing folders with 3 digits
6 |
7 | shopt -s extglob # use extended pattern matching capabilities
8 |
9 | old_CLI=$(awk -F': ' '/CLI/{print $2}' last/train.log)
10 | max_exp=$(ls results/ | tail -1)
11 | max_exp=${max_exp##*(0)} # remove possible leading 0s
12 | dst_path=$(printf "results/%03d" "$((++max_exp))")
13 |
14 | echo -n " > Creating folder: "
15 | mkdir $dst_path
16 | ls -d --color=always $dst_path
17 |
18 | echo -n " > Linking:"
19 | ln -snf $dst_path last
20 | ls -l --color=always last | awk -F"$USER" '{print $3}'
21 |
22 | echo " > Previously you've used the following options"
23 | echo " CUDA_VISIBLE_DEVICES=n python -u main.py $old_CLI | tee last/train.log"
24 |
--------------------------------------------------------------------------------
/notebook/README.md:
--------------------------------------------------------------------------------
1 | # Collection of Jupyter Notebooks
2 |
3 | *Jupyter Notebooks* are a very effective tool for interactive data exploration and visualisation.
4 |
5 | ## Correct display
6 |
7 | I use dark styles for both *GitHub* and *Jupyter Notebook*.
8 | To see the content appropriately, install:
9 |
10 | - [*Jupyter Notebook* dark theme](https://userstyles.org/styles/98208/jupyter-notebook-dark-originally-from-ipython);
11 | - [*GitHub* dark theme](https://userstyles.org/styles/37035/github-dark) and comment out the `invert #fff to #181818` code block.
12 |
13 | ## Notebooks
14 |
15 | These notebooks are mostly for personal use.
16 | I used them for data exploration, for getting familiar with PyTorch model graphs, and to generate the paper figures and website animations.
17 | They all run, on my system, with the data generated from the experiments.
18 | I'm releasing them here for reference only, and you are welcome to browse them, even though you might not be able to execute them.
19 |
20 | - [`display_loss.ipynb`](display_loss.ipynb): display train and validation losses for every experiment;
21 | - [`figures_generator.ipynb`](figures_generator.ipynb): generates paper figures and website animations;
22 | - [`frequency_analysis.ipynb`](frequency_analysis.ipynb): explores video temporal-frequency components;
23 | - [`get_all_embeddings.ipynb`](get_all_embeddings.ipynb): display ResNet18 embeddings and probability *vs.* time;
24 | - [`get_data_stats.ipynb`](get_data_stats.ipynb): show stats on the e-VDS35 data set;
25 | - [`network_bisection.ipynb`](network_bisection.ipynb): mucking around with PyTorch model graph;
26 | - [`salient_regions.ipynb`](salient_regions.ipynb): implements and computes salient regions from [bojarski2017explaining](https://arxiv.org/abs/1704.07911);
27 | - [`stability_analysis.ipynb`](stability_analysis.ipynb): mucking around with ResNet18 and videos;
28 | - [`verify_resnet.ipynb`](verify_resnet.ipynb): checks whether I'm capable of using PyTorch's ResNet18;
29 | - [`plot_conf.py`](plot_conf.py): `matplotlib` shared configuration.
30 |
--------------------------------------------------------------------------------
/notebook/data:
--------------------------------------------------------------------------------
1 | ../data/
--------------------------------------------------------------------------------
/notebook/model:
--------------------------------------------------------------------------------
1 | ../model/
--------------------------------------------------------------------------------
/notebook/network_bisection.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": false
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import torch\n",
12 | "import torchvision\n",
13 | "from torch.autograd import Variable as V\n",
14 | "from utils.visualise import make_dot"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 2,
20 | "metadata": {
21 | "collapsed": true
22 | },
23 | "outputs": [],
24 | "source": [
25 | "resnet_18 = torchvision.models.resnet18(pretrained=True)\n",
26 | "resnet_18.eval();"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 3,
32 | "metadata": {
33 | "collapsed": true
34 | },
35 | "outputs": [],
36 | "source": [
37 | "# by setting the volatile flag to True, intermediate caches are not saved\n",
38 | "# making the inspection of the graph pretty boring / useless\n",
39 | "torch.manual_seed(0)\n",
40 | "x = V(torch.randn(1, 3, 224, 224))#, volatile=True)\n",
41 | "h_x = resnet_18(x)"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 4,
47 | "metadata": {
48 | "collapsed": false
49 | },
50 | "outputs": [],
51 | "source": [
52 | "dot = make_dot(h_x) # generate network graph\n",
53 | "dot.render('net.dot'); # save DOT and PDF in the current directory\n",
54 | "# dot # uncomment for displaying the graph in the notebook"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 5,
60 | "metadata": {
61 | "collapsed": false
62 | },
63 | "outputs": [
64 | {
65 | "name": "stdout",
66 | "output_type": "stream",
67 | "text": [
68 | "h_x creator -> \n",
69 | "h_x creator prev fun type -> \n",
70 | "h_x creator prev fun length -> 3\n",
71 | "\n",
72 | "--- content of h_x creator prev fun ---\n",
73 | "0 --> (, 0)\n",
74 | "1 --> (Parameter containing:\n",
75 | "-1.8474e-02 -7.0461e-02 -5.1772e-02 ... -3.9030e-02 1.7351e-01 -4.0976e-02\n",
76 | "-8.1792e-02 -9.4370e-02 1.7355e-02 ... 2.0284e-01 -2.4782e-02 3.7172e-02\n",
77 | "-3.3164e-02 -5.6569e-02 -2.4165e-02 ... -3.4402e-02 -2.2659e-02 1.9705e-02\n",
78 | " ... ⋱ ... \n",
79 | "-1.0300e-02 3.2804e-03 -3.5863e-02 ... -2.7923e-02 -1.1458e-02 1.2759e-02\n",
80 | "-3.5879e-02 -3.5296e-02 -2.9602e-02 ... -3.2961e-02 -1.1022e-02 -5.1256e-02\n",
81 | " 2.1277e-03 -2.4839e-02 -8.2920e-02 ... 4.1731e-02 -5.0030e-02 6.6327e-02\n",
82 | "[torch.FloatTensor of size 1000x512]\n",
83 | ", 0)\n",
84 | "2 --> (Parameter containing:\n",
85 | "1.00000e-02 *\n",
86 | " -0.2634\n",
87 | " 0.3000\n",
88 | " 0.0656\n",
89 | " ⋮ \n",
90 | " -1.7868\n",
91 | " -0.0782\n",
92 | " -0.6345\n",
93 | "[torch.FloatTensor of size 1000]\n",
94 | ", 0)\n",
95 | "---------------------------------------\n",
96 | "\n"
97 | ]
98 | }
99 | ],
100 | "source": [
101 | "# explore network graph\n",
102 | "print('h_x creator ->',h_x.creator)\n",
103 | "print('h_x creator prev fun type ->', type(h_x.creator.previous_functions))\n",
104 | "print('h_x creator prev fun length ->', len(h_x.creator.previous_functions))\n",
105 | "print('\\n--- content of h_x creator prev fun ---')\n",
106 | "for a, b in enumerate(h_x.creator.previous_functions): print(a, '-->', b)\n",
107 | "print('---------------------------------------\\n')"
108 | ]
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "metadata": {},
113 | "source": [
114 | "The current node is a `torch.nn._functions.linear.Linear` object, fed by\n",
115 | "\n",
116 | "- 0 --> output of `torch.autograd._functions.tensor.View` object\n",
117 | "- 1 --> weight matrix of size `(1000, 512)`\n",
118 | "- 2 --> bias vector of size `(1000)`"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": 6,
124 | "metadata": {
125 | "collapsed": false
126 | },
127 | "outputs": [
128 | {
129 | "name": "stdout",
130 | "output_type": "stream",
131 | "text": [
132 | "ResNet (\n",
133 | " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
134 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n",
135 | " (relu): ReLU (inplace)\n",
136 | " (maxpool): MaxPool2d (size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1))\n",
137 | " (layer1): Sequential (\n",
138 | " (0): BasicBlock (\n",
139 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
140 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n",
141 | " (relu): ReLU (inplace)\n",
142 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
143 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n",
144 | " )\n",
145 | " (1): BasicBlock (\n",
146 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
147 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n",
148 | " (relu): ReLU (inplace)\n",
149 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
150 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n",
151 | " )\n",
152 | " )\n",
153 | " (layer2): Sequential (\n",
154 | " (0): BasicBlock (\n",
155 | " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
156 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n",
157 | " (relu): ReLU (inplace)\n",
158 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
159 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n",
160 | " (downsample): Sequential (\n",
161 | " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
162 | " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n",
163 | " )\n",
164 | " )\n",
165 | " (1): BasicBlock (\n",
166 | " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
167 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n",
168 | " (relu): ReLU (inplace)\n",
169 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
170 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n",
171 | " )\n",
172 | " )\n",
173 | " (layer3): Sequential (\n",
174 | " (0): BasicBlock (\n",
175 | " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
176 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n",
177 | " (relu): ReLU (inplace)\n",
178 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
179 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n",
180 | " (downsample): Sequential (\n",
181 | " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
182 | " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n",
183 | " )\n",
184 | " )\n",
185 | " (1): BasicBlock (\n",
186 | " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
187 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n",
188 | " (relu): ReLU (inplace)\n",
189 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
190 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n",
191 | " )\n",
192 | " )\n",
193 | " (layer4): Sequential (\n",
194 | " (0): BasicBlock (\n",
195 | " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
196 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n",
197 | " (relu): ReLU (inplace)\n",
198 | " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
199 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n",
200 | " (downsample): Sequential (\n",
201 | " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
202 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n",
203 | " )\n",
204 | " )\n",
205 | " (1): BasicBlock (\n",
206 | " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
207 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n",
208 | " (relu): ReLU (inplace)\n",
209 | " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
210 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n",
211 | " )\n",
212 | " )\n",
213 | " (avgpool): AvgPool2d (\n",
214 | " )\n",
215 | " (fc): Linear (512 -> 1000)\n",
216 | ")\n"
217 | ]
218 | }
219 | ],
220 | "source": [
221 | "print(resnet_18)"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": 7,
227 | "metadata": {
228 | "collapsed": false
229 | },
230 | "outputs": [
231 | {
232 | "data": {
233 | "text/plain": [
234 | "odict_keys(['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', 'fc'])"
235 | ]
236 | },
237 | "execution_count": 7,
238 | "metadata": {},
239 | "output_type": "execute_result"
240 | }
241 | ],
242 | "source": [
243 | "resnet_18._modules.keys()"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": 8,
249 | "metadata": {
250 | "collapsed": false
251 | },
252 | "outputs": [
253 | {
254 | "name": "stdout",
255 | "output_type": "stream",
256 | "text": [
257 | "m: \n",
258 | "i: \n",
259 | " len: 1 \n",
260 | " type: \n",
261 | " data size: torch.Size([1, 512, 7, 7]) \n",
262 | " data type: torch.FloatTensor \n",
263 | "o: \n",
264 | " data size: torch.Size([1, 512, 1, 1]) \n",
265 | " data type: torch.FloatTensor\n"
266 | ]
267 | }
268 | ],
269 | "source": [
270 | "avgpool_layer = resnet_18._modules.get('avgpool')\n",
271 | "h = avgpool_layer.register_forward_hook(\n",
272 | " lambda m, i, o: \\\n",
273 | " print(\n",
274 | " 'm:', type(m),\n",
275 | " '\\ni:', type(i),\n",
276 | " '\\n len:', len(i),\n",
277 | " '\\n type:', type(i[0]),\n",
278 | " '\\n data size:', i[0].data.size(),\n",
279 | " '\\n data type:', i[0].data.type(),\n",
280 | " '\\no:', type(o),\n",
281 | " '\\n data size:', o.data.size(),\n",
282 | " '\\n data type:', o.data.type(),\n",
283 | " )\n",
284 | ")\n",
285 | "h_x = resnet_18(x)\n",
286 | "h.remove()"
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "execution_count": 9,
292 | "metadata": {
293 | "collapsed": false
294 | },
295 | "outputs": [],
296 | "source": [
297 | "my_embedding = torch.zeros(512)\n",
298 | "def fun(m, i, o): my_embedding.copy_(o.data)\n",
299 | "h = avgpool_layer.register_forward_hook(fun)\n",
300 | "h_x = resnet_18(x)\n",
301 | "h.remove()"
302 | ]
303 | },
304 | {
305 | "cell_type": "code",
306 | "execution_count": 10,
307 | "metadata": {
308 | "collapsed": false
309 | },
310 | "outputs": [
311 | {
312 | "data": {
313 | "text/plain": [
314 | "\n",
315 | " 0.3879 0.0205 0.0268 2.9453 0.0234 0.0000 0.0000 0.7327 1.0997 0.0000\n",
316 | "[torch.FloatTensor of size 1x10]"
317 | ]
318 | },
319 | "execution_count": 10,
320 | "metadata": {},
321 | "output_type": "execute_result"
322 | }
323 | ],
324 | "source": [
325 | "# print first values of the embedding\n",
326 | "my_embedding[:10].view(1, -1)"
327 | ]
328 | }
329 | ],
330 | "metadata": {
331 | "kernelspec": {
332 | "display_name": "Python 3",
333 | "language": "python",
334 | "name": "python3"
335 | },
336 | "language_info": {
337 | "codemirror_mode": {
338 | "name": "ipython",
339 | "version": 3
340 | },
341 | "file_extension": ".py",
342 | "mimetype": "text/x-python",
343 | "name": "python",
344 | "nbconvert_exporter": "python",
345 | "pygments_lexer": "ipython3",
346 | "version": "3.5.2"
347 | }
348 | },
349 | "nbformat": 4,
350 | "nbformat_minor": 0
351 | }
352 |
--------------------------------------------------------------------------------
/notebook/plot_conf.py:
--------------------------------------------------------------------------------
1 | # matplotlib and stuff
2 | import matplotlib.pyplot as plt
3 |
4 |
5 | def plt_style(c='k'):
6 | """
7 | Set plotting style for bright (``c = 'w'``) or dark (``c = 'k'``) backgrounds
8 |
9 | :param c: colour, can be set to ``'w'`` or ``'k'`` (which is the default)
10 | :type c: str
11 | """
12 | import matplotlib as mpl
13 | from matplotlib import rc
14 |
15 | # Reset previous configuration
16 | mpl.rcParams.update(mpl.rcParamsDefault)
17 | # %matplotlib inline # not from script
18 | get_ipython().run_line_magic('matplotlib', 'inline')
19 |
20 | # configuration for bright background
21 | if c == 'w':
22 | plt.style.use('bmh')
23 |
24 | # configurations for dark background
25 | if c == 'k':
26 | # noinspection PyTypeChecker
27 | plt.style.use(['dark_background', 'bmh'])
28 |
29 | # remove background colour, set figure size
30 | rc('figure', figsize=(16, 8), max_open_warning=False)
31 | rc('axes', facecolor='none')
32 |
33 | plt_style()
34 |
--------------------------------------------------------------------------------
/notebook/utils:
--------------------------------------------------------------------------------
1 | ../utils/
--------------------------------------------------------------------------------
/utils/README.md:
--------------------------------------------------------------------------------
1 | # Utility scripts
2 |
3 | This folder and this `README.md` contain some very useful scripts.
4 | The folder content is the following:
5 |
6 | - [`check_exp_diff.sh*`](check_exp_diff.sh): shows difference between experiments' training CLI arguments;
7 | - [`update_experiments.sh*`](update_experiments.sh): synchronises experiments across multiple nodes;
8 | - [`show_error.plt*`](show_error.plt): shows the current error with *gnuplot* by `awk`-ing `last/train.log`;
9 | - [`show_error_exp.plt*`](show_error_exp.plt): shows the losses for a specific experiment;
10 | - [`image_plot.py`](image_plot.py): saves intermediate training visual outputs to PDF (may be used by `main.py`);
11 |
12 | ## Script usage and arguments
13 |
14 | ### Compare CLI arguments
15 |
16 | To compare the CLI arguments used in two different experiments run `./check_exp_diff.sh <base exp> <new exp>` (see the worked example at the end of this file).
17 | The script will output the `--word-diff` between the calling arguments, and `git show` the newer experiment's `HEAD`.
18 | Replace `<{base,new} exp>` with the corresponding three-digit identifier.
19 |
20 | ### Plotting current training losses
21 |
22 | To plot the (*MSE*, *CE*, *replica MSE*, and *periodic CE*) cost functions interactively for the current training network run `./show_error.plt -i`.
23 | To statically plot them, after training, run `./show_error.plt`.
24 |
25 | ### Displaying any experiment losses
26 |
27 | If the experiment data is not on the current machine run `./update_experiments.sh -i` to iteratively sync the data every `5` seconds.
28 | To view the losses iteratively run `./show_error_exp.plt <exp> -i`, where `<exp>` is for example `012`.
29 | To plot them statically, omit the `-i` flag.
30 |
31 | ### Synchronising experiments
32 |
33 | To sync experiments across multiple nodes run `./update_experiments.sh`.
34 | If you wish to run it quietly (verbose is the default), run `./update_experiments.sh -q`.
35 | If you'd like to run it in a `5`-second loop, run `./update_experiments.sh -i`.
36 |
37 | ## Image manipulation scripts collection
38 |
39 | ### Get PNGs out of PDFs
40 |
41 | ```bash
42 | convert *.pdf *.png
43 | ```
44 |
45 | ### Create GIFs
46 |
47 | Create animations from the content of `00{6,7}/PDFs` into `anim/`:
48 |
49 | ```bash
50 | for p in 006/PDFs/*; do
51 | g=anim/${p##*/}
52 | convert -delay 100 $p ${p/6/7} ${g/pdf/gif}
53 | done
54 | ```
55 |
56 | `delay` is expressed in units of `10` ms, so `100` corresponds to `1` second between frames.
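57 | 
58 | ## Worked example
59 | 
60 | A short, hypothetical session tying the scripts above together; the experiment identifiers `006` and `012` are made up, so substitute your own three-digit identifiers:
61 | 
62 | ```bash
63 | # compare the CLI arguments of experiment 012 (new) against 006 (base)
64 | ./check_exp_diff.sh 006 012
65 | 
66 | # in a separate terminal, keep the local copy of the experiment data in sync every 5 seconds
67 | ./update_experiments.sh -i
68 | 
69 | # follow the losses of experiment 012 while it trains
70 | ./show_error_exp.plt 012 -i
71 | ```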
--------------------------------------------------------------------------------
/utils/check_exp_diff.sh:
--------------------------------------------------------------------------------
1 | # Check what has changed across experiments
2 | # Alfredo Canziani, Mar 17
3 |
4 | # ./check_exp_diff.sh base_exp new_exp
5 |
6 | base_exp="../results/$1/train.log"
7 | new_exp="../results/$2/train.log"
8 |
9 | git diff --no-index --word-diff --color=always $base_exp $new_exp | head -7
10 |
11 | echo ""
12 | hash=$(awk -F ": " '/hash/{print $2}' $new_exp)
13 | git show --stat $hash
14 |
--------------------------------------------------------------------------------
/utils/image_plot.py:
--------------------------------------------------------------------------------
1 | import torch  # torch must be imported BEFORE pyplot, otherwise you get a segmentation fault
2 | from matplotlib import pyplot as plt
3 | from os.path import isdir, join
4 | from os import mkdir
5 |
6 |
7 | def _hist_show(a, k):
8 | a = _to_view(a)
9 | plt.subplot(2, 3, k)
10 | plt.hist(a.reshape(-1), 50)
11 | plt.grid('on')
12 | plt.gca().axes.get_yaxis().set_visible(False)
13 |
14 |
15 | def show_four(x, next_x, x_hat, fig):
16 | """
17 | Saves/overwrites a PDF named <fig>.pdf showing x, next_x, their difference (with histogram), and x_hat (with histogram)
18 |
19 | :param x: x[t]
20 | :type x: torch.FloatTensor
21 | :param next_x: x[t + 1]
22 | :type next_x: torch.FloatTensor
23 | :param x_hat: ~x[t + 1]
24 | :type x_hat: torch.FloatTensor
25 | :param fig: figure number
26 | :type fig: int
27 | :return: nothing
28 | :rtype: None
29 | """
30 | f = plt.figure(fig)
31 | plt.clf()
32 | _sub(x, 1)
33 | _sub(next_x, 4)
34 | dif = next_x - x
35 | _sub(dif, 2)
36 | _hist_show(dif, 3)
37 | _sub(x_hat, 5)
38 | _hist_show(x_hat, 6)
39 | plt.subplots_adjust(left=0.01, bottom=0.06, right=.99, top=1, wspace=0, hspace=.12)
40 | f.savefig(str(fig) + '.pdf')
41 |
42 |
43 | # Setup output folder for figures collection
44 | def _show_ten_setup(pdf_path):
45 | if isdir(pdf_path):
46 | print('Folder "{}" already existent. Exiting.'.format(pdf_path))
47 | exit()
48 | mkdir(pdf_path)
49 |
50 |
51 | def show_ten(x, x_hat, pdf_path='PDFs'):
52 | """
53 | First two rows 10 ~x[t + 1], second two rows 10 x[t]
54 |
55 | :param x: x[t]
56 | :type x: torch.FloatTensor
57 | :param x_hat: ~x[t + 1]
58 | :type x_hat: torch.FloatTensor
59 | :param pdf_path: saving path
60 | :type pdf_path: str
61 | :return: nothing
62 | :rtype: None
63 | """
64 | if show_ten.c == 0 and pdf_path: _show_ten_setup(pdf_path)
65 | if show_ten.c % 10 == 0: show_ten.f = plt.figure()
66 | plt.figure(show_ten.f.number)
67 | plt.subplot(4, 5, 1 + show_ten.c % 10)
68 | _img_show(x_hat, y0=-.16, s=8)
69 | plt.subplot(4, 5, 11 + show_ten.c % 10)
70 | _img_show(x, y0=-.16, s=8)
71 | show_ten.c += 1
72 | plt.subplots_adjust(left=0, bottom=0.02, right=1, top=1, wspace=0, hspace=.12)
73 | if show_ten.c % 10 == 0: show_ten.f.savefig(join(pdf_path, str(show_ten.c // 10) + '_10.pdf'))
74 | show_ten.c = 0
75 |
76 |
77 | def _img_show(a, y0=-.13, s=12):
78 | a = _to_view(a)
79 | plt.imshow(a)
80 | plt.title('<{:.2f}> [{:.2f}, {:.2f}]'.format(a.mean(), a.min(), a.max()), y=y0, fontsize=s)
81 | plt.axis('off')
82 |
83 |
84 | def _sub(a, k):
85 | plt.subplot(2, 3, k)
86 | _img_show(a)
87 |
88 |
89 | def _to_view(a):
90 | return a.cpu().numpy().transpose((1, 2, 0))
91 |
92 |
93 | def _test_4():
94 | img = _test_setup()
95 | show_four(img, img, img, 1)
96 |
97 |
98 | def _test_10():
99 | img = _test_setup()
100 | for i in range(20): show_ten(img, -img, '')
101 |
102 |
103 | def _test_setup():
104 | from skimage.data import astronaut
105 | from skimage.transform import resize
106 | from matplotlib.figure import Figure
107 | Figure.savefig = lambda self, _: plt.show() # patch Figure class to simply display the figure
108 | img = torch.from_numpy(resize(astronaut(), (256, 256)).astype('f4').transpose((2, 0, 1)))
109 | return img
110 |
111 |
112 | if __name__ == '__main__':
113 | _test_4()
114 | _test_10()
115 |
116 | __author__ = "Alfredo Canziani"
117 | __credits__ = ["Alfredo Canziani"]
118 | __maintainer__ = "Alfredo Canziani"
119 | __email__ = "alfredo.canziani@gmail.com"
120 | __status__ = "Development" # "Prototype", "Development", or "Production"
121 | __date__ = "Mar 17"
122 |
--------------------------------------------------------------------------------
/utils/show_error.plt:
--------------------------------------------------------------------------------
1 | #!/usr/bin/gnuplot -c
2 |
3 | # Plot MSE and CE train loss iteratively
4 |
5 | # Run as:
6 | # ./show_error.plt -i # to run it iteratively
7 | # ./show_error.plt # to run it statically
8 |
9 | # Alfredo Canziani, Mar 17
10 |
11 | # set white on black theme
12 | set terminal wxt background rgb "black" noraise
13 | set xlabel textcolor rgb "white"
14 | set ylabel textcolor rgb "white"
15 | set y2label textcolor rgb "white"
16 | set key textcolor rgb "white"
17 | set border lc rgb 'white'
18 | set grid lc rgb 'white'
19 |
20 | set grid
21 | set xlabel "mini batch index / 10"
22 | set ylabel "mMSE"
23 | set y2label "CE"
24 | set y2tics
25 | plot \
26 | "< awk '/batches/{print $18,$21,$25,$29}' ../last/train.log" \
27 | u 0:1 w lines lw 2 title "MSE", \
28 | "" \
29 | u 0:3 w lines lw 2 title "rpl MSE", \
30 | "" \
31 | u 0:2 w lines lw 2 title "CE" axis x1y2, \
32 | "" \
33 | u 0:4 w lines lw 2 title "per CE" axis x1y2
34 |
35 | if (ARG1 ne '-i') {
36 | pause -1 # just hang in there
37 | exit
38 | }
39 |
40 | pause 5 # wait 5 seconds
41 | reread # and start over
42 |
--------------------------------------------------------------------------------
/utils/show_error_exp.plt:
--------------------------------------------------------------------------------
1 | #!/usr/bin/gnuplot -c
2 |
3 | # Plot MSE and CE train loss iteratively
4 |
5 | # Run as:
6 | # ./show_error_exp.plt EXP -i # to run it iteratively
7 | # ./show_error_exp.plt EXP # to run it statically
8 |
9 | # Alfredo Canziani, Mar 17
10 |
11 | # You need to run ./utils/update_experiments.sh -i
12 |
13 | # set white on black theme
14 | set terminal wxt background rgb "black" noraise
15 | set xlabel textcolor rgb "white"
16 | set ylabel textcolor rgb "white"
17 | set y2label textcolor rgb "white"
18 | set title textcolor rgb "white"
19 | set key textcolor rgb "white"
20 | set border lc rgb 'white'
21 | set grid lc rgb 'white'
22 |
23 | set grid
24 | set title "Network " . ARG1
25 | set xlabel "mini batch index / 10"
26 | set ylabel "mMSE"
27 | set y2label "CE"
28 | set y2tics
29 | if (ARG1 + 0 < 20) { # "+ 0" conversion string to number
30 | plot \
31 | "< awk '/batches/{print $18,$21,$25}' ../results/".ARG1."/train.log" \
32 | u 0:1 w lines lw 2 title "MSE", \
33 | "" \
34 | u 0:2 w lines lw 2 title "CE" axis x1y2, \
35 | "" \
36 | u 0:3 w lines lw 2 title "rpl MSE"
37 | } else {
38 | plot \
39 | "< awk '/batches/{print $18,$21,$25,$29}' ../results/".ARG1."/train.log" \
40 | u 0:1 w lines lw 2 title "MSE", \
41 | "" \
42 | u 0:3 w lines lw 2 title "rpl MSE", \
43 | "" \
44 | u 0:2 w lines lw 2 title "CE" axis x1y2, \
45 | "" \
46 | u 0:4 w lines lw 2 title "per CE" axis x1y2
47 | }
48 |
49 | if (ARG2 ne '-i') {
50 | pause -1 # just hang in there
51 | exit
52 | }
53 |
54 | pause 5 # wait 5 seconds
55 | reread # and start over
56 |
--------------------------------------------------------------------------------
/utils/update_experiments.sh:
--------------------------------------------------------------------------------
1 | # Pull and push latest experiment results
2 | # Alfredo Canziani, Mar 17
3 |
4 | # Run it as
5 | # ./update_experiments.sh
6 | # ./update_experiments.sh -q # quiet
7 | # ./update_experiments.sh -i # iteratively / 5 seconds loop
8 |
9 | if [ "$1" == "-i" ]; then
10 | function ctrl_c {
11 | echo -e "\nExiting."
12 | exit 0
13 | }
14 | trap ctrl_c INT
15 | echo -n "Syncing experiments every 5 seconds"
16 | while true; do
17 | ./update_experiments.sh -q
18 | # prints one "." every second for 5 seconds, then removes them
19 | for (( i = 0; i < 5; i++ )); do
20 | printf "."
21 | sleep 1
22 | done
23 | printf "%.s\b" {1..5}
24 | printf "%.s " {1..5}
25 | printf "%.s\b" {1..5}
26 | done
27 | fi
28 |
29 | if [ "$1" != "-q" ]; then verbose="--verbose"; fi
30 |
31 | local=$(hostname)
32 | case $local in
33 | "GPU0") remote=$GPU8;;
34 | "GPU8") remote=$GPU0;;
35 | *) echo "Something's wrong"; exit -1;;
36 | esac
37 |
38 | # Get experimental data
39 | if [ -n "$verbose" ]; then
40 | echo; printf "%.s#" {1..80}; echo
41 | echo -n "Getting experiments from $remote to $local"
42 | echo; printf "%.s#" {1..80}; echo; echo
43 | fi
44 | rsync \
45 | --update \
46 | --archive \
47 | $verbose \
48 | --human-readable \
49 | $remote:MatchNet/results/ \
50 | ../results
51 |
52 | # Send experimental data
53 | if [ -n "$verbose" ]; then
54 | echo; printf "%.s#" {1..80}; echo
55 | echo -n "Sending experiments to $remote from $local"
56 | echo; printf "%.s#" {1..80}; echo; echo
57 | fi
58 | rsync \
59 | --update \
60 | --archive \
61 | $verbose \
62 | --human-readable \
63 | ../results/ \
64 | $remote:MatchNet/results
65 |
--------------------------------------------------------------------------------
/utils/visualise.py:
--------------------------------------------------------------------------------
1 | from graphviz import Digraph
2 | from torch.autograd import Variable
3 | import sys, subprocess
4 | import uuid
5 |
6 |
7 | def make_dot(root):
8 | node_attr = dict(style='filled',
9 | shape='box',
10 | align='left',
11 | fontsize='12',
12 | ranksep='0.1',
13 | height='0.2')
14 | dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
15 | seen = set()
16 |
17 | def add_nodes(var):
18 | if var not in seen:
19 | if isinstance(var, Variable):
20 | value = '(' + ', '.join(['%d'% v for v in var.size()]) + ')'
21 | dot.node(str(id(var)), str(value), fillcolor='lightblue')
22 | else:
23 | dot.node(str(id(var)), str(type(var).__name__))
24 | seen.add(var)
25 | if hasattr(var, 'previous_functions'):
26 | for u in var.previous_functions:
27 | dot.edge(str(id(u[0])), str(id(var)))
28 | add_nodes(u[0])
29 | add_nodes(root.creator)
30 | return dot
31 |
32 |
33 | def show_graph(root):
34 | dot_file_name = '/tmp/' + str(uuid.uuid4())
35 | make_dot(root).render(dot_file_name)
36 | pdf_file_name = dot_file_name + '.pdf'
37 | if sys.platform == 'darwin':
38 | subprocess.call(('open', pdf_file_name))
39 | elif sys.platform == 'linux':
40 | subprocess.call(('xdg-open', pdf_file_name))
41 |
42 |
43 | __author__ = "Sergey Zagoruyko and Alfredo Canziani"
44 | __credits__ = ["Sergey Zagoruyko", "Alfredo Canziani"]
45 | __maintainer__ = "Alfredo Canziani"
46 | __email__ = "alfredo.canziani@gmail.com"
47 | __status__ = "Production" # "Prototype", "Development", or "Production"
48 | __date__ = "Feb 17"
49 |
--------------------------------------------------------------------------------