├── .gitignore
├── LICENSE.md
├── MNIST_results
│   ├── MNIST_plain.jpg
│   └── MNIST_trained.jpg
├── README.md
├── mnist.py
├── nets.py
├── semi_supervised.py
├── training_functions.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # Sublime
107 | sublime/
108 |
109 | # images
110 | *.jpg
111 | *.png
112 | *.JPG
113 | *.jpeg
114 | *.JPEG
115 | !MNIST_plain.jpg
116 | !MNIST_trained.jpg
117 |
118 | # nets
119 | *.pt
120 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Michal Nazarczuk
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MNIST_results/MNIST_plain.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/michaal94/Semisupervised-Clustering/555da3d49a97e54807a3fe2aca0b7b5039d0bd34/MNIST_results/MNIST_plain.jpg
--------------------------------------------------------------------------------
/MNIST_results/MNIST_trained.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/michaal94/Semisupervised-Clustering/555da3d49a97e54807a3fe2aca0b7b5039d0bd34/MNIST_results/MNIST_trained.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Semisupervised Clustering
2 |
3 | This repository contains the code for semi-supervised clustering developed for the Master's thesis "Automatic analysis of images from camera-traps" by Michal Nazarczuk from Imperial College London.
4 |
5 | The algorithm is inspired by the DCEC method ([Deep Clustering with Convolutional Autoencoders](https://xifengguo.github.io/papers/ICONIP17-DCEC.pdf)). The main change is an additional "labelling" loss component: the cross-entropy between labelled examples and their predictions.
6 |
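In rough terms, the full-training objective combines three terms (the reconstruction loss is fixed at weight 1; `gamma` and `gamma_lab` are described in the Options section below):

```
loss = loss_reconstruction + gamma * loss_clustering + gamma_lab * loss_labelling
```
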
7 | ## Prerequisites
8 |
9 | The following libraries need to be installed for the code to run properly:
10 |
11 | 1. PyTorch
12 | 2. NumPy
13 | 3. scikit-learn
14 | 4. [TensorboardX](https://github.com/lanpa/tensorboardX)
15 |
16 | The code was written and tested with Python 3.4.1.
17 |
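The dependencies can be installed, for example, with pip (a sketch; exact package names and versions may differ for your setup):

```
pip install torch numpy scikit-learn tensorboardX
```
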
18 | ## Installation and usage
19 |
20 | ### Installation
21 |
22 | Just copy the repository to your local folder:
23 | ```
24 | git clone https://github.com/michaal94/Semisupervised-Clustering
25 | ```
26 |
27 | ### Use of the algorithm
28 |
29 | To test the basic version of the semi-supervised clustering, run it with the Python distribution for which you installed the libraries (Anaconda, Virtualenv, etc.). In general, type:
30 |
31 | ```
32 | cd Semisupervised-Clustering
33 | python3 semi_supervised.py
34 | ```
35 | The example will run a sample clustering on the MNIST-train dataset.
36 |
37 | ## Options
38 |
39 | The algorithm offers plenty of options for adjustment:
40 | 1. Mode choice: full training or pretraining only, use:
41 | ```--mode train_full``` or ```--mode pretrain```
42 |
43 | For full training you can specify whether to use the pretraining phase ```--pretrain True``` or a saved network ```--pretrain False``` together with
44 | ```--pretrained_net ("path" or idx)``` giving the path or index (see catalog structure) of the pretrained network
45 | 2. Dataset choice:
46 | + MNIST - train, test, full
47 |    + Custom dataset - use the following data structure (characteristic of PyTorch):
48 | ```
49 | -data_directory (cluster folders must correspond to the real clustering only for statistics)
50 | -cluster_1
51 | -image_1
52 | -image_2
53 | -...
54 | -cluster_2
55 | -image_1
56 | -image_2
57 | -...
58 | -...
59 | -data_directory_l (data used as labelled; use at least one example in each class in the current version of the algorithm)
60 | -cluster_1
61 | -image_1
62 | -image_2
63 | -...
64 | -cluster_2
65 | -image_1
66 | -image_2
67 | -...
68 | -...
69 | ```
70 | Use the following: ```--dataset MNIST-train```,
71 | ```--dataset MNIST-test```,
72 | ```--dataset MNIST-full``` or
73 | ```--dataset custom``` (use the last one with path
74 | ```--dataset_path 'path to your dataset'```
75 | and the transformation you want for the images
76 | ```--custom_img_size [height, width, depth]```)
77 | 3. Different network architectures:
78 | + CAE 3 - convolutional autoencoder used in [DCEC](https://xifengguo.github.io/papers/ICONIP17-DCEC.pdf) ```--net_architecture CAE_3```
79 |    + CAE 3 BN - version with Batch Normalisation layers ```--net_architecture CAE_bn3```
80 |    + CAE 4 (BN) - convolutional autoencoder with 4 convolutional blocks ```--net_architecture CAE_4``` and ```--net_architecture CAE_bn4```
81 |    + CAE 5 (BN) - convolutional autoencoder with 5 convolutional blocks ```--net_architecture CAE_5``` and ```--net_architecture CAE_bn5``` (used for 128x128 photos)
82 |
83 |    The following options may be used for model changes:
84 | + LeakyReLU or ReLU usage: ```--leaky True/False``` (True provided better results)
85 | + Negative slope for Leaky ReLU: ```--neg_slope value``` (Values around 0.01 were used)
86 | + Use of sigmoid and tanh activations at the end of encoder and decoder: ```--activations True/False``` (False provided better results)
87 | + Use of bias in layers: ```--bias True/False```
88 | 4. Optimiser and scheduler settings (Adam optimiser):
89 |    + Learning rate: ```--rate value``` (0.001 is a reasonable value for Adam)
90 | + Learning rate for pretraining phase: ```--rate_pretrain value``` (0.001 can be used as well)
91 | + Weight decay: ```--weight value``` (0 was used)
92 | + Weight decay for pretraining phase: ```--weight_pretrain value```
93 | + Scheduler step (how many iterations till the rate is changed): ```--sched_step value```
94 | + Scheduler step for pretraining phase: ```--sched_step_pretrain value```
95 | + Scheduler gamma (multiplier of learning rate): ```--sched_gamma value```
96 | + Scheduler gamma for pretraining phase: ```--sched_gamma_pretrain value```
97 | 5. Algorithm specific parameters:
98 |    + Clustering loss weight (the reconstruction loss is fixed at weight 1): ```--gamma value``` (a value of 0.1 provided good results)
99 | + Labelling loss weight: ```--gamma_lab value``` (0.01 provided good results)
100 |    + Update interval for target distribution (in number of batches between updates): ```--update_interval value``` (the value may be chosen such that the distribution is updated every 1000-2000 photos)
101 |    + Label check interval: ```--label_upd_interval value``` (it is suggested to keep the update at every iteration)
102 |    + Stopping criterion tolerance: ```--tol value``` (depends on the dataset; 0.01 was used for small datasets, 0.001 for bigger ones, e.g. MNIST)
103 |    + Target number of clusters: ```--num_clusters value```
104 | 6. Other options:
105 |    + Batch size: ```--batch_size value``` (depends on your device, but remember that [too large a batch may be bad for convergence](https://towardsdatascience.com/recent-advances-for-a-better-understanding-of-deep-learning-part-i-5ce34d1cc914))
106 |    + Epochs if the stopping criterion is not met: ```--epochs value```
107 |    + Epochs of pretraining: ```--epochs_pretrain value``` (300 epochs were used: 200 with a 0.001 learning rate and 100 with a 10 times smaller one - ```--sched_step_pretrain 200```, ```--sched_gamma_pretrain 0.1```)
108 | + Report printing frequency (in batches): ```--printing_frequency value```
109 | + Tensorboard export: ```--tensorboard True/False```
110 |
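As an example, a hypothetical full training run on a custom dataset, combining several of the options above (the dataset path is a placeholder):

```
python3 semi_supervised.py --mode train_full --dataset custom \
    --dataset_path 'path/to/your/dataset' --custom_img_size 128 128 3 \
    --net_architecture CAE_bn5 --batch_size 64 --gamma 0.1 --gamma_lab 0.01
```
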
111 | ## Catalog structure
112 |
113 | The code creates the following catalog structure when reporting the statistics:
114 | ```
115 | -Reports
116 | -(net_architecture_name)_(index).txt
117 | -Nets (copies of weights)
118 |     -(net_architecture_name)_(index).pt
119 |     -(net_architecture_name)_(index)_pretrained.pt
120 | -Runs
121 | -(net_architecture_name)_(index) <- directory containing tensorboard event file
122 | ```
123 | The files are indexed automatically so that they are not accidentally overwritten.
124 |
125 | ## Performance
126 |
127 | The code was mainly used to cluster images coming from camera-trap events. However, some additional benchmarks were performed on MNIST datasets. The following table gathers some results (with 2% of the data labelled):
128 |
129 | Set | NMI | Acc
130 | ---|---|---
131 | MNIST-full | 95.13 | 98.22%
132 | MNIST-test | 89.59 | 95.29%
133 |
134 | In addition, _t-SNE_ plots of the plain and clustered full MNIST dataset are shown:
135 |
136 | Full set before clustering:
137 |
138 | ![MNIST full set before clustering](MNIST_results/MNIST_plain.jpg)
139 |
140 | After clustering:
141 |
142 | ![MNIST full set after clustering](MNIST_results/MNIST_trained.jpg)
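
Plots of this kind can be generated along these lines (a minimal sketch, assuming `embeddings` is an `(N, 10)` array of encoder outputs collected over the dataset and `labels` holds the corresponding cluster assignments; `plot_tsne` is a hypothetical helper, not part of this repository):

```
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def plot_tsne(embeddings, labels, out_file='tsne.jpg'):
    # Project the 10-D embeddings to 2-D and colour points by cluster.
    points = TSNE(n_components=2).fit_transform(embeddings)
    plt.scatter(points[:, 0], points[:, 1], c=labels, s=2, cmap='tab10')
    plt.savefig(out_file, dpi=150)
```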
--------------------------------------------------------------------------------
/mnist.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch.utils.data as data
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import errno
7 | import numpy as np
8 | import torch
9 | import codecs
10 |
11 |
12 | class MNIST(data.Dataset):
13 | """`MNIST `_ Dataset.
14 |
15 | Args:
16 | root (string): Root directory of dataset where ``processed/training.pt``
17 | and ``processed/test.pt`` exist.
18 | train (bool, optional): If True, creates dataset from ``training.pt``,
19 | otherwise from ``test.pt``.
20 | download (bool, optional): If true, downloads the dataset from the internet and
21 | puts it in root directory. If dataset is already downloaded, it is not
22 | downloaded again.
23 |         transform (callable, optional): A function/transform that takes in a PIL image
24 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
25 | target_transform (callable, optional): A function/transform that takes in the
26 | target and transforms it.
27 | """
28 | urls = [
29 | 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
30 | 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
31 | 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
32 | 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
33 | ]
34 | raw_folder = 'raw'
35 | processed_folder = 'processed'
36 | training_file = 'training.pt'
37 | test_file = 'test.pt'
38 |
39 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False, small=False, full=False):
40 | self.root = os.path.expanduser(root)
41 | self.transform = transform
42 | self.target_transform = target_transform
43 | self.train = train # training set or test set
44 | self.full = full
45 |
46 | if full:
47 | self.train = True
48 |
49 | if download:
50 | self.download()
51 |
52 | if not self._check_exists():
53 | raise RuntimeError('Dataset not found.' +
54 | ' You can use download=True to download it')
55 |
56 | self.train_data, self.train_labels = torch.load(os.path.join(self.root, self.processed_folder, self.training_file))
57 | self.test_data, self.test_labels = torch.load(os.path.join(self.root, self.processed_folder, self.test_file))
58 |
59 | if full:
60 | self.train_data = np.concatenate((self.train_data, self.test_data), axis=0)
61 | self.train_labels = np.concatenate((self.train_labels, self.test_labels), axis=0)
62 |
63 | if small:
64 | self.train_data = self.train_data[0:1400]
65 | self.train_labels = self.train_labels[0:1400]
66 | if not full:
67 | self.train_data = self.train_data[0:1200]
68 | self.train_labels = self.train_labels[0:1200]
69 | self.test_data = self.test_data[0:200]
70 | self.test_labels = self.test_labels[0:200]
71 |
72 | def __getitem__(self, index):
73 | """
74 | Args:
75 | index (int): Index
76 |
77 | Returns:
78 | tuple: (image, target) where target is index of the target class.
79 | """
80 | if self.train:
81 | img, target = self.train_data[index], self.train_labels[index]
82 | else:
83 | img, target = self.test_data[index], self.test_labels[index]
84 |
85 | # doing this so that it is consistent with all other datasets
86 | # to return a PIL Image
87 | if self.full:
88 | img = Image.fromarray(img, mode='L')
89 | else:
90 | img = Image.fromarray(img.numpy(), mode='L')
91 |
92 | if self.transform is not None:
93 | img = self.transform(img)
94 |
95 | if self.target_transform is not None:
96 | target = self.target_transform(target)
97 |
98 | return img, target
99 |
100 | def __len__(self):
101 | if self.train:
102 | return len(self.train_data)
103 | else:
104 | return len(self.test_data)
105 |
106 | def _check_exists(self):
107 | return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
108 | os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
109 |
110 | def download(self):
111 | """Download the MNIST data if it doesn't exist in processed_folder already."""
112 | from six.moves import urllib
113 | import gzip
114 |
115 | if self._check_exists():
116 | return
117 |
118 | # download files
119 | try:
120 | os.makedirs(os.path.join(self.root, self.raw_folder))
121 | os.makedirs(os.path.join(self.root, self.processed_folder))
122 | except OSError as e:
123 | if e.errno == errno.EEXIST:
124 | pass
125 | else:
126 | raise
127 |
128 | for url in self.urls:
129 | print('Downloading ' + url)
130 | data = urllib.request.urlopen(url)
131 | filename = url.rpartition('/')[2]
132 | file_path = os.path.join(self.root, self.raw_folder, filename)
133 | with open(file_path, 'wb') as f:
134 | f.write(data.read())
135 | with open(file_path.replace('.gz', ''), 'wb') as out_f, \
136 | gzip.GzipFile(file_path) as zip_f:
137 | out_f.write(zip_f.read())
138 | os.unlink(file_path)
139 |
140 | # process and save as torch files
141 | print('Processing...')
142 |
143 | training_set = (
144 | read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
145 | read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
146 | )
147 | test_set = (
148 | read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
149 | read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
150 | )
151 | with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
152 | torch.save(training_set, f)
153 | with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
154 | torch.save(test_set, f)
155 |
156 | print('Done!')
157 |
158 | def __repr__(self):
159 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
160 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
161 | tmp = 'train' if self.train is True else 'test'
162 | fmt_str += ' Split: {}\n'.format(tmp)
163 | fmt_str += ' Root Location: {}\n'.format(self.root)
164 | tmp = ' Transforms (if any): '
165 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
166 | tmp = ' Target Transforms (if any): '
167 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
168 | return fmt_str
169 |
170 |
171 | class FashionMNIST(MNIST):
172 | """`Fashion-MNIST `_ Dataset.
173 |
174 | Args:
175 | root (string): Root directory of dataset where ``processed/training.pt``
176 | and ``processed/test.pt`` exist.
177 | train (bool, optional): If True, creates dataset from ``training.pt``,
178 | otherwise from ``test.pt``.
179 | download (bool, optional): If true, downloads the dataset from the internet and
180 | puts it in root directory. If dataset is already downloaded, it is not
181 | downloaded again.
182 |         transform (callable, optional): A function/transform that takes in a PIL image
183 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
184 | target_transform (callable, optional): A function/transform that takes in the
185 | target and transforms it.
186 | """
187 | urls = [
188 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
189 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
190 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
191 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
192 | ]
193 |
194 |
195 | class EMNIST(MNIST):
196 | """`EMNIST `_ Dataset.
197 |
198 | Args:
199 | root (string): Root directory of dataset where ``processed/training.pt``
200 | and ``processed/test.pt`` exist.
201 | split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
202 | ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
203 | which one to use.
204 | train (bool, optional): If True, creates dataset from ``training.pt``,
205 | otherwise from ``test.pt``.
206 | download (bool, optional): If true, downloads the dataset from the internet and
207 | puts it in root directory. If dataset is already downloaded, it is not
208 | downloaded again.
209 |         transform (callable, optional): A function/transform that takes in a PIL image
210 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
211 | target_transform (callable, optional): A function/transform that takes in the
212 | target and transforms it.
213 | """
214 | url = 'http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip'
215 | splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')
216 |
217 | def __init__(self, root, split, **kwargs):
218 | if split not in self.splits:
219 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
220 | split, ', '.join(self.splits),
221 | ))
222 | self.split = split
223 | self.training_file = self._training_file(split)
224 | self.test_file = self._test_file(split)
225 | super(EMNIST, self).__init__(root, **kwargs)
226 |
227 | def _training_file(self, split):
228 | return 'training_{}.pt'.format(split)
229 |
230 | def _test_file(self, split):
231 | return 'test_{}.pt'.format(split)
232 |
233 | def download(self):
234 | """Download the EMNIST data if it doesn't exist in processed_folder already."""
235 | from six.moves import urllib
236 | import gzip
237 | import shutil
238 | import zipfile
239 |
240 | if self._check_exists():
241 | return
242 |
243 | # download files
244 | try:
245 | os.makedirs(os.path.join(self.root, self.raw_folder))
246 | os.makedirs(os.path.join(self.root, self.processed_folder))
247 | except OSError as e:
248 | if e.errno == errno.EEXIST:
249 | pass
250 | else:
251 | raise
252 |
253 | print('Downloading ' + self.url)
254 | data = urllib.request.urlopen(self.url)
255 | filename = self.url.rpartition('/')[2]
256 | raw_folder = os.path.join(self.root, self.raw_folder)
257 | file_path = os.path.join(raw_folder, filename)
258 | with open(file_path, 'wb') as f:
259 | f.write(data.read())
260 |
261 | print('Extracting zip archive')
262 | with zipfile.ZipFile(file_path) as zip_f:
263 | zip_f.extractall(raw_folder)
264 | os.unlink(file_path)
265 | gzip_folder = os.path.join(raw_folder, 'gzip')
266 | for gzip_file in os.listdir(gzip_folder):
267 | if gzip_file.endswith('.gz'):
268 | print('Extracting ' + gzip_file)
269 | with open(os.path.join(raw_folder, gzip_file.replace('.gz', '')), 'wb') as out_f, \
270 | gzip.GzipFile(os.path.join(gzip_folder, gzip_file)) as zip_f:
271 | out_f.write(zip_f.read())
272 | shutil.rmtree(gzip_folder)
273 |
274 | # process and save as torch files
275 | for split in self.splits:
276 | print('Processing ' + split)
277 | training_set = (
278 | read_image_file(os.path.join(raw_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
279 | read_label_file(os.path.join(raw_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
280 | )
281 | test_set = (
282 | read_image_file(os.path.join(raw_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
283 | read_label_file(os.path.join(raw_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
284 | )
285 | with open(os.path.join(self.root, self.processed_folder, self._training_file(split)), 'wb') as f:
286 | torch.save(training_set, f)
287 | with open(os.path.join(self.root, self.processed_folder, self._test_file(split)), 'wb') as f:
288 | torch.save(test_set, f)
289 |
290 | print('Done!')
291 |
292 |
293 | def get_int(b):
294 | return int(codecs.encode(b, 'hex'), 16)
295 |
296 |
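# IDX file layout (as parsed below): a 4-byte big-endian magic number
# (2049 for label files, 2051 for image files), followed by 4-byte
# big-endian dimension sizes (item count, plus rows and cols for images),
# and then the raw uint8 data itself.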
297 | def read_label_file(path):
298 | with open(path, 'rb') as f:
299 | data = f.read()
300 | assert get_int(data[:4]) == 2049
301 | length = get_int(data[4:8])
302 | parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
303 | return torch.from_numpy(parsed).view(length).long()
304 |
305 |
306 | def read_image_file(path):
307 | with open(path, 'rb') as f:
308 | data = f.read()
309 | assert get_int(data[:4]) == 2051
310 | length = get_int(data[4:8])
311 | num_rows = get_int(data[8:12])
312 | num_cols = get_int(data[12:16])
313 | images = []
314 | parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
315 | return torch.from_numpy(parsed).view(length, num_rows, num_cols)
316 |
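if __name__ == '__main__':
    # Quick sanity check (a sketch): '../data' mirrors the root used by
    # semi_supervised.py; the data is downloaded there on first run.
    from torchvision import transforms
    full = MNIST('../data', train=True, download=True,
                 transform=transforms.ToTensor())
    labelled = MNIST('../data', train=True, download=True, small=True,
                     transform=transforms.ToTensor())
    print(len(full), len(labelled))  # expected: 60000 and 1200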
--------------------------------------------------------------------------------
/nets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import copy
4 |
5 | # Clustering layer definition (see DCEC article for equations)
6 | class ClusterlingLayer(nn.Module):
7 | def __init__(self, in_features=10, out_features=10, alpha=1.0):
8 | super(ClusterlingLayer, self).__init__()
9 | self.in_features = in_features
10 | self.out_features = out_features
11 | self.alpha = alpha
12 | self.weight = nn.Parameter(torch.Tensor(self.out_features, self.in_features))
13 | self.weight = nn.init.xavier_uniform_(self.weight)
14 |
15 | def forward(self, x):
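        # Soft assignment, as in DEC/DCEC: squared Euclidean distance from
        # each embedded point to every cluster centre (rows of self.weight),
        # converted to Student's t-distribution similarities and normalised
        # per sample so each row sums to one.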
16 | x = x.unsqueeze(1) - self.weight
17 | x = torch.mul(x, x)
18 | x = torch.sum(x, dim=2)
19 | x = 1.0 + (x / self.alpha)
20 | x = 1.0 / x
21 |         x = x ** ((self.alpha + 1.0) / 2.0)
22 | x = torch.t(x) / torch.sum(x, dim=1)
23 | x = torch.t(x)
24 | return x
25 |
26 | def extra_repr(self):
27 | return 'in_features={}, out_features={}, alpha={}'.format(
28 | self.in_features, self.out_features, self.alpha
29 | )
30 |
31 | def set_weight(self, tensor):
32 | self.weight = nn.Parameter(tensor)
33 |
34 |
35 | # Convolutional autoencoder directly from DCEC article
36 | class CAE_3(nn.Module):
37 | def __init__(self, input_shape=[128,128,3], num_clusters=10, filters=[32, 64, 128], leaky=True, neg_slope=0.01, activations=False, bias=True):
38 | super(CAE_3, self).__init__()
39 | self.activations = activations
40 | # bias = True
41 | self.pretrained = False
42 | self.num_clusters = num_clusters
43 | self.input_shape = input_shape
44 | self.filters = filters
45 | self.conv1 = nn.Conv2d(input_shape[2], filters[0], 5, stride=2, padding=2, bias=bias)
46 | if leaky:
47 | self.relu = nn.LeakyReLU(negative_slope=neg_slope)
48 | else:
49 | self.relu = nn.ReLU(inplace=False)
50 | self.conv2 = nn.Conv2d(filters[0], filters[1], 5, stride=2, padding=2, bias=bias)
51 | self.conv3 = nn.Conv2d(filters[1], filters[2], 3, stride=2, padding=0, bias=bias)
52 | lin_features_len = ((input_shape[0]//2//2-1) // 2) * ((input_shape[0]//2//2-1) // 2) * filters[2]
53 | self.embedding = nn.Linear(lin_features_len, num_clusters, bias=bias)
54 | self.deembedding = nn.Linear(num_clusters, lin_features_len, bias=bias)
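        # output_padding compensates for the rounding of odd intermediate
        # sizes in the strided convolutions, so that each ConvTranspose2d
        # restores the corresponding encoder resolution exactly.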
55 | out_pad = 1 if input_shape[0] // 2 // 2 % 2 == 0 else 0
56 | self.deconv3 = nn.ConvTranspose2d(filters[2], filters[1], 3, stride=2, padding=0, output_padding=out_pad, bias=bias)
57 | out_pad = 1 if input_shape[0] // 2 % 2 == 0 else 0
58 | self.deconv2 = nn.ConvTranspose2d(filters[1], filters[0], 5, stride=2, padding=2, output_padding=out_pad, bias=bias)
59 | out_pad = 1 if input_shape[0] % 2 == 0 else 0
60 | self.deconv1 = nn.ConvTranspose2d(filters[0], input_shape[2], 5, stride=2, padding=2, output_padding=out_pad, bias=bias)
61 | self.clustering = ClusterlingLayer(num_clusters, num_clusters)
62 | # ReLU copies for graph representation in tensorboard
63 | self.relu1_1 = copy.deepcopy(self.relu)
64 | self.relu2_1 = copy.deepcopy(self.relu)
65 | self.relu3_1 = copy.deepcopy(self.relu)
66 | self.relu1_2 = copy.deepcopy(self.relu)
67 | self.relu2_2 = copy.deepcopy(self.relu)
68 | self.relu3_2 = copy.deepcopy(self.relu)
69 | self.sig = nn.Sigmoid()
70 | self.tanh = nn.Tanh()
71 |
72 | def forward(self, x):
73 | x = self.conv1(x)
74 | x = self.relu1_1(x)
75 | x = self.conv2(x)
76 | x = self.relu2_1(x)
77 | x = self.conv3(x)
78 | if self.activations:
79 | x = self.sig(x)
80 | else:
81 | x = self.relu3_1(x)
82 | x = x.view(x.size(0), -1)
83 | x = self.embedding(x)
84 | extra_out = x
85 | clustering_out = self.clustering(x)
86 | x = self.deembedding(x)
87 | x = self.relu1_2(x)
88 | x = x.view(x.size(0), self.filters[2], ((self.input_shape[0]//2//2-1) // 2), ((self.input_shape[0]//2//2-1) // 2))
89 | x = self.deconv3(x)
90 | x = self.relu2_2(x)
91 | x = self.deconv2(x)
92 | x = self.relu3_2(x)
93 | x = self.deconv1(x)
94 | if self.activations:
95 | x = self.tanh(x)
96 | return x, clustering_out, extra_out
97 |
98 |
99 | # Convolutional autoencoder from DCEC article with Batch Norms and Leaky ReLUs
100 | class CAE_bn3(nn.Module):
101 | def __init__(self, input_shape=[128,128,3], num_clusters=10, filters=[32, 64, 128], leaky=True, neg_slope=0.01, activations=False, bias=True):
102 | super(CAE_bn3, self).__init__()
103 | self.activations=activations
104 | self.pretrained = False
105 | self.num_clusters = num_clusters
106 | self.input_shape = input_shape
107 | self.filters = filters
108 | self.conv1 = nn.Conv2d(input_shape[2], filters[0], 5, stride=2, padding=2, bias=bias)
109 | self.bn1_1 = nn.BatchNorm2d(filters[0])
110 | if leaky:
111 | self.relu = nn.LeakyReLU(negative_slope=neg_slope)
112 | else:
113 | self.relu = nn.ReLU(inplace=False)
114 | self.conv2 = nn.Conv2d(filters[0], filters[1], 5, stride=2, padding=2, bias=bias)
115 | self.bn2_1 = nn.BatchNorm2d(filters[1])
116 | self.conv3 = nn.Conv2d(filters[1], filters[2], 3, stride=2, padding=0, bias=bias)
117 | lin_features_len = ((input_shape[0]//2//2-1) // 2) * ((input_shape[0]//2//2-1) // 2) * filters[2]
118 | self.embedding = nn.Linear(lin_features_len, num_clusters, bias=bias)
119 | self.deembedding = nn.Linear(num_clusters, lin_features_len, bias=bias)
120 | out_pad = 1 if input_shape[0] // 2 // 2 % 2 == 0 else 0
121 | self.deconv3 = nn.ConvTranspose2d(filters[2], filters[1], 3, stride=2, padding=0, output_padding=out_pad, bias=bias)
122 | out_pad = 1 if input_shape[0] // 2 % 2 == 0 else 0
123 | self.bn3_2 = nn.BatchNorm2d(filters[1])
124 | self.deconv2 = nn.ConvTranspose2d(filters[1], filters[0], 5, stride=2, padding=2, output_padding=out_pad, bias=bias)
125 | out_pad = 1 if input_shape[0] % 2 == 0 else 0
126 | self.bn2_2 = nn.BatchNorm2d(filters[0])
127 | self.deconv1 = nn.ConvTranspose2d(filters[0], input_shape[2], 5, stride=2, padding=2, output_padding=out_pad, bias=bias)
128 | self.clustering = ClusterlingLayer(num_clusters, num_clusters)
129 | # ReLU copies for graph representation in tensorboard
130 | self.relu1_1 = copy.deepcopy(self.relu)
131 | self.relu2_1 = copy.deepcopy(self.relu)
132 | self.relu3_1 = copy.deepcopy(self.relu)
133 | self.relu1_2 = copy.deepcopy(self.relu)
134 | self.relu2_2 = copy.deepcopy(self.relu)
135 | self.relu3_2 = copy.deepcopy(self.relu)
136 | self.sig = nn.Sigmoid()
137 | self.tanh = nn.Tanh()
138 |
139 | def forward(self, x):
140 | x = self.conv1(x)
141 | x = self.relu1_1(x)
142 | x = self.bn1_1(x)
143 | x = self.conv2(x)
144 | x = self.relu2_1(x)
145 | x = self.bn2_1(x)
146 | x = self.conv3(x)
147 | if self.activations:
148 | x = self.sig(x)
149 | else:
150 | x = self.relu3_1(x)
151 | x = x.view(x.size(0), -1)
152 | x = self.embedding(x)
153 | extra_out = x
154 | clustering_out = self.clustering(x)
155 | x = self.deembedding(x)
156 | x = self.relu1_2(x)
157 | x = x.view(x.size(0), self.filters[2], ((self.input_shape[0]//2//2-1) // 2), ((self.input_shape[0]//2//2-1) // 2))
158 | x = self.deconv3(x)
159 | x = self.relu2_2(x)
160 | x = self.bn3_2(x)
161 | x = self.deconv2(x)
162 | x = self.relu3_2(x)
163 | x = self.bn2_2(x)
164 | x = self.deconv1(x)
165 | if self.activations:
166 | x = self.tanh(x)
167 | return x, clustering_out, extra_out
168 |
169 |
170 | # Convolutional autoencoder with 4 convolutional blocks
171 | class CAE_4(nn.Module):
172 | def __init__(self, input_shape=[128,128,3], num_clusters=10, filters=[32, 64, 128, 256], leaky=True, neg_slope=0.01, activations=False, bias=True):
173 | super(CAE_4, self).__init__()
174 | self.activations = activations
175 | self.pretrained = False
176 | self.num_clusters = num_clusters
177 | self.input_shape = input_shape
178 | self.filters = filters
179 | if leaky:
180 | self.relu = nn.LeakyReLU(negative_slope=neg_slope)
181 | else:
182 | self.relu = nn.ReLU(inplace=False)
183 |
184 | self.conv1 = nn.Conv2d(input_shape[2], filters[0], 5, stride=2, padding=2, bias=bias)
185 | self.conv2 = nn.Conv2d(filters[0], filters[1], 5, stride=2, padding=2, bias=bias)
186 | self.conv3 = nn.Conv2d(filters[1], filters[2], 5, stride=2, padding=2, bias=bias)
187 | self.conv4 = nn.Conv2d(filters[2], filters[3], 3, stride=2, padding=0, bias=bias)
188 |
189 | lin_features_len = ((input_shape[0] // 2 // 2 // 2 - 1) // 2) * ((input_shape[0] // 2 // 2 // 2 - 1) // 2) * \
190 | filters[3]
191 | self.embedding = nn.Linear(lin_features_len, num_clusters, bias=bias)
192 | self.deembedding = nn.Linear(num_clusters, lin_features_len, bias=bias)
193 | out_pad = 1 if input_shape[0] // 2 // 2 // 2 % 2 == 0 else 0
194 | self.deconv4 = nn.ConvTranspose2d(filters[3], filters[2], 3, stride=2, padding=0, output_padding=out_pad,
195 | bias=bias)
196 | out_pad = 1 if input_shape[0] // 2 // 2 % 2 == 0 else 0
197 | self.deconv3 = nn.ConvTranspose2d(filters[2], filters[1], 5, stride=2, padding=2, output_padding=out_pad,
198 | bias=bias)
199 | out_pad = 1 if input_shape[0] // 2 % 2 == 0 else 0
200 | self.deconv2 = nn.ConvTranspose2d(filters[1], filters[0], 5, stride=2, padding=2, output_padding=out_pad,
201 | bias=bias)
202 | out_pad = 1 if input_shape[0] % 2 == 0 else 0
203 | self.deconv1 = nn.ConvTranspose2d(filters[0], input_shape[2], 5, stride=2, padding=2, output_padding=out_pad,
204 | bias=bias)
205 | self.clustering = ClusterlingLayer(num_clusters, num_clusters)
206 | # ReLU copies for graph representation in tensorboard
207 | self.relu1_1 = copy.deepcopy(self.relu)
208 | self.relu2_1 = copy.deepcopy(self.relu)
209 | self.relu3_1 = copy.deepcopy(self.relu)
210 | self.relu4_1 = copy.deepcopy(self.relu)
211 | self.relu1_2 = copy.deepcopy(self.relu)
212 | self.relu2_2 = copy.deepcopy(self.relu)
213 | self.relu3_2 = copy.deepcopy(self.relu)
214 | self.relu4_2 = copy.deepcopy(self.relu)
215 | self.sig = nn.Sigmoid()
216 | self.tanh = nn.Tanh()
217 |
218 | def forward(self, x):
219 | x = self.conv1(x)
220 | x = self.relu1_1(x)
221 | x = self.conv2(x)
222 | x = self.relu2_1(x)
223 | x = self.conv3(x)
224 | x = self.relu3_1(x)
225 | x = self.conv4(x)
226 | if self.activations:
227 | x = self.sig(x)
228 | else:
229 | x = self.relu4_1(x)
230 | x = x.view(x.size(0), -1)
231 | x = self.embedding(x)
232 | extra_out = x
233 | clustering_out = self.clustering(x)
234 | x = self.deembedding(x)
235 | x = self.relu4_2(x)
236 | x = x.view(x.size(0), self.filters[3], ((self.input_shape[0]//2//2//2-1) // 2), ((self.input_shape[0]//2//2//2-1) // 2))
237 | x = self.deconv4(x)
238 | x = self.relu3_2(x)
239 | x = self.deconv3(x)
240 | x = self.relu2_2(x)
241 | x = self.deconv2(x)
242 | x = self.relu1_2(x)
243 | x = self.deconv1(x)
244 | if self.activations:
245 | x = self.tanh(x)
246 | return x, clustering_out, extra_out
247 |
248 | # Convolutional autoencoder with 4 convolutional blocks (BN version)
249 | class CAE_bn4(nn.Module):
250 | def __init__(self, input_shape=[128,128,3], num_clusters=10, filters=[32, 64, 128, 256], leaky=True, neg_slope=0.01, activations=False, bias=True):
251 | super(CAE_bn4, self).__init__()
252 | self.activations = activations
253 | self.pretrained = False
254 | self.num_clusters = num_clusters
255 | self.input_shape = input_shape
256 | self.filters = filters
257 | if leaky:
258 | self.relu = nn.LeakyReLU(negative_slope=neg_slope)
259 | else:
260 | self.relu = nn.ReLU(inplace=False)
261 |
262 | self.conv1 = nn.Conv2d(input_shape[2], filters[0], 5, stride=2, padding=2, bias=bias)
263 | self.bn1_1 = nn.BatchNorm2d(filters[0])
264 | self.conv2 = nn.Conv2d(filters[0], filters[1], 5, stride=2, padding=2, bias=bias)
265 | self.bn2_1 = nn.BatchNorm2d(filters[1])
266 | self.conv3 = nn.Conv2d(filters[1], filters[2], 5, stride=2, padding=2, bias=bias)
267 | self.bn3_1 = nn.BatchNorm2d(filters[2])
268 | self.conv4 = nn.Conv2d(filters[2], filters[3], 3, stride=2, padding=0, bias=bias)
269 |
270 | lin_features_len = ((input_shape[0] // 2 // 2 // 2 - 1) // 2) * ((input_shape[0] // 2 // 2 // 2 - 1) // 2) * \
271 | filters[3]
272 | self.embedding = nn.Linear(lin_features_len, num_clusters, bias=bias)
273 | self.deembedding = nn.Linear(num_clusters, lin_features_len, bias=bias)
274 | out_pad = 1 if input_shape[0] // 2 // 2 // 2 % 2 == 0 else 0
275 | self.deconv4 = nn.ConvTranspose2d(filters[3], filters[2], 3, stride=2, padding=0, output_padding=out_pad,
276 | bias=bias)
277 | self.bn4_2 = nn.BatchNorm2d(filters[2])
278 | out_pad = 1 if input_shape[0] // 2 // 2 % 2 == 0 else 0
279 | self.deconv3 = nn.ConvTranspose2d(filters[2], filters[1], 5, stride=2, padding=2, output_padding=out_pad,
280 | bias=bias)
281 | self.bn3_2 = nn.BatchNorm2d(filters[1])
282 | out_pad = 1 if input_shape[0] // 2 % 2 == 0 else 0
283 | self.deconv2 = nn.ConvTranspose2d(filters[1], filters[0], 5, stride=2, padding=2, output_padding=out_pad,
284 | bias=bias)
285 | self.bn2_2 = nn.BatchNorm2d(filters[0])
286 | out_pad = 1 if input_shape[0] % 2 == 0 else 0
287 | self.deconv1 = nn.ConvTranspose2d(filters[0], input_shape[2], 5, stride=2, padding=2, output_padding=out_pad,
288 | bias=bias)
289 | self.clustering = ClusterlingLayer(num_clusters, num_clusters)
290 | # ReLU copies for graph representation in tensorboard
291 | self.relu1_1 = copy.deepcopy(self.relu)
292 | self.relu2_1 = copy.deepcopy(self.relu)
293 | self.relu3_1 = copy.deepcopy(self.relu)
294 | self.relu4_1 = copy.deepcopy(self.relu)
295 | self.relu1_2 = copy.deepcopy(self.relu)
296 | self.relu2_2 = copy.deepcopy(self.relu)
297 | self.relu3_2 = copy.deepcopy(self.relu)
298 | self.relu4_2 = copy.deepcopy(self.relu)
299 | self.sig = nn.Sigmoid()
300 | self.tanh = nn.Tanh()
301 |
302 | def forward(self, x):
303 | x = self.conv1(x)
304 | x = self.relu1_1(x)
305 | x = self.bn1_1(x)
306 | x = self.conv2(x)
307 | x = self.relu2_1(x)
308 | x = self.bn2_1(x)
309 | x = self.conv3(x)
310 | x = self.relu3_1(x)
311 | x = self.bn3_1(x)
312 | x = self.conv4(x)
313 | if self.activations:
314 | x = self.sig(x)
315 | else:
316 | x = self.relu4_1(x)
317 | x = x.view(x.size(0), -1)
318 | x = self.embedding(x)
319 | extra_out = x
320 | clustering_out = self.clustering(x)
321 | x = self.deembedding(x)
322 | x = self.relu4_2(x)
323 | x = x.view(x.size(0), self.filters[3], ((self.input_shape[0]//2//2//2-1) // 2), ((self.input_shape[0]//2//2//2-1) // 2))
324 | x = self.deconv4(x)
325 | x = self.relu3_2(x)
326 | x = self.bn4_2(x)
327 | x = self.deconv3(x)
328 | x = self.relu2_2(x)
329 | x = self.bn3_2(x)
330 | x = self.deconv2(x)
331 | x = self.relu1_2(x)
332 | x = self.bn2_2(x)
333 | x = self.deconv1(x)
334 | if self.activations:
335 | x = self.tanh(x)
336 | return x, clustering_out, extra_out
337 |
338 |
339 | # Convolutional autoencoder with 5 convolutional blocks
340 | class CAE_5(nn.Module):
341 | def __init__(self, input_shape=[128,128,3], num_clusters=10, filters=[32, 64, 128, 256, 512], leaky=True, neg_slope=0.01, activations=False, bias=True):
342 | super(CAE_5, self).__init__()
343 | self.activations = activations
344 | self.pretrained = False
345 | self.num_clusters = num_clusters
346 | self.input_shape = input_shape
347 | self.filters = filters
349 |         if leaky:
350 |             self.relu = nn.LeakyReLU(negative_slope=neg_slope)
351 |         else:
352 |             self.relu = nn.ReLU(inplace=False)
353 |
354 | self.conv1 = nn.Conv2d(input_shape[2], filters[0], 5, stride=2, padding=2, bias=bias)
355 | self.conv2 = nn.Conv2d(filters[0], filters[1], 5, stride=2, padding=2, bias=bias)
356 | self.conv3 = nn.Conv2d(filters[1], filters[2], 5, stride=2, padding=2, bias=bias)
357 | self.conv4 = nn.Conv2d(filters[2], filters[3], 5, stride=2, padding=2, bias=bias)
358 | self.conv5 = nn.Conv2d(filters[3], filters[4], 3, stride=2, padding=0, bias=bias)
359 |
360 | lin_features_len = ((input_shape[0] // 2 // 2 // 2 // 2 - 1) // 2) * (
361 | (input_shape[0] // 2 // 2 // 2 // 2 - 1) // 2) * filters[4]
362 | self.embedding = nn.Linear(lin_features_len, num_clusters, bias=bias)
363 | self.deembedding = nn.Linear(num_clusters, lin_features_len, bias=bias)
364 | out_pad = 1 if input_shape[0] // 2 // 2 // 2 // 2 % 2 == 0 else 0
365 | self.deconv5 = nn.ConvTranspose2d(filters[4], filters[3], 3, stride=2, padding=0, output_padding=out_pad,
366 | bias=bias)
367 | out_pad = 1 if input_shape[0] // 2 // 2 // 2 % 2 == 0 else 0
368 | self.deconv4 = nn.ConvTranspose2d(filters[3], filters[2], 5, stride=2, padding=2, output_padding=out_pad,
369 | bias=bias)
370 | out_pad = 1 if input_shape[0] // 2 // 2 % 2 == 0 else 0
371 | self.deconv3 = nn.ConvTranspose2d(filters[2], filters[1], 5, stride=2, padding=2, output_padding=out_pad,
372 | bias=bias)
373 | out_pad = 1 if input_shape[0] // 2 % 2 == 0 else 0
374 | self.deconv2 = nn.ConvTranspose2d(filters[1], filters[0], 5, stride=2, padding=2, output_padding=out_pad,
375 | bias=bias)
376 | out_pad = 1 if input_shape[0] % 2 == 0 else 0
377 | self.deconv1 = nn.ConvTranspose2d(filters[0], input_shape[2], 5, stride=2, padding=2, output_padding=out_pad,
378 | bias=bias)
379 | self.clustering = ClusterlingLayer(num_clusters, num_clusters)
380 | # ReLU copies for graph representation in tensorboard
381 | self.relu1_1 = copy.deepcopy(self.relu)
382 | self.relu2_1 = copy.deepcopy(self.relu)
383 | self.relu3_1 = copy.deepcopy(self.relu)
384 | self.relu4_1 = copy.deepcopy(self.relu)
385 | self.relu5_1 = copy.deepcopy(self.relu)
386 | self.relu1_2 = copy.deepcopy(self.relu)
387 | self.relu2_2 = copy.deepcopy(self.relu)
388 | self.relu3_2 = copy.deepcopy(self.relu)
389 | self.relu4_2 = copy.deepcopy(self.relu)
390 | self.relu5_2 = copy.deepcopy(self.relu)
391 | self.sig = nn.Sigmoid()
392 | self.tanh = nn.Tanh()
393 |
394 | def forward(self, x):
395 | x = self.conv1(x)
396 | x = self.relu1_1(x)
397 | x = self.conv2(x)
398 | x = self.relu2_1(x)
399 | x = self.conv3(x)
400 | x = self.relu3_1(x)
401 | x = self.conv4(x)
402 | x = self.relu4_1(x)
403 | x = self.conv5(x)
404 | if self.activations:
405 | x = self.sig(x)
406 | else:
407 | x = self.relu5_1(x)
408 | x = x.view(x.size(0), -1)
409 | x = self.embedding(x)
410 | extra_out = x
411 | clustering_out = self.clustering(x)
412 | x = self.deembedding(x)
413 |         x = self.relu5_2(x)
414 | x = x.view(x.size(0), self.filters[4], ((self.input_shape[0]//2//2//2//2-1) // 2), ((self.input_shape[0]//2//2//2//2-1) // 2))
415 | x = self.deconv5(x)
416 | x = self.relu4_2(x)
417 | x = self.deconv4(x)
418 | x = self.relu3_2(x)
419 | x = self.deconv3(x)
420 | x = self.relu2_2(x)
421 | x = self.deconv2(x)
422 | x = self.relu1_2(x)
423 | x = self.deconv1(x)
424 | if self.activations:
425 | x = self.tanh(x)
426 | return x, clustering_out, extra_out
427 |
428 |
429 | # Convolutional autoencoder with 5 convolutional blocks (BN version)
430 | class CAE_bn5(nn.Module):
431 | def __init__(self, input_shape=[128,128,3], num_clusters=10, filters=[32, 64, 128, 256, 512], leaky=True, neg_slope=0.01, activations=False, bias=True):
432 | super(CAE_bn5, self).__init__()
433 | self.activations = activations
434 | self.pretrained = False
435 | self.num_clusters = num_clusters
436 | self.input_shape = input_shape
437 | self.filters = filters
439 |         if leaky:
440 |             self.relu = nn.LeakyReLU(negative_slope=neg_slope)
441 |         else:
442 |             self.relu = nn.ReLU(inplace=False)
443 |
444 | self.conv1 = nn.Conv2d(input_shape[2], filters[0], 5, stride=2, padding=2, bias=bias)
445 | self.bn1_1 = nn.BatchNorm2d(filters[0])
446 | self.conv2 = nn.Conv2d(filters[0], filters[1], 5, stride=2, padding=2, bias=bias)
447 | self.bn2_1 = nn.BatchNorm2d(filters[1])
448 | self.conv3 = nn.Conv2d(filters[1], filters[2], 5, stride=2, padding=2, bias=bias)
449 | self.bn3_1 = nn.BatchNorm2d(filters[2])
450 | self.conv4 = nn.Conv2d(filters[2], filters[3], 5, stride=2, padding=2, bias=bias)
451 | self.bn4_1 = nn.BatchNorm2d(filters[3])
452 | self.conv5 = nn.Conv2d(filters[3], filters[4], 3, stride=2, padding=0, bias=bias)
453 |
454 | lin_features_len = ((input_shape[0] // 2 // 2 // 2 // 2 - 1) // 2) * (
455 | (input_shape[0] // 2 // 2 // 2 // 2 - 1) // 2) * filters[4]
456 | self.embedding = nn.Linear(lin_features_len, num_clusters, bias=bias)
457 | self.deembedding = nn.Linear(num_clusters, lin_features_len, bias=bias)
458 | out_pad = 1 if input_shape[0] // 2 // 2 // 2 // 2 % 2 == 0 else 0
459 | self.deconv5 = nn.ConvTranspose2d(filters[4], filters[3], 3, stride=2, padding=0, output_padding=out_pad,
460 | bias=bias)
461 | self.bn5_2 = nn.BatchNorm2d(filters[3])
462 | out_pad = 1 if input_shape[0] // 2 // 2 // 2 % 2 == 0 else 0
463 | self.deconv4 = nn.ConvTranspose2d(filters[3], filters[2], 5, stride=2, padding=2, output_padding=out_pad,
464 | bias=bias)
465 | self.bn4_2 = nn.BatchNorm2d(filters[2])
466 | out_pad = 1 if input_shape[0] // 2 // 2 % 2 == 0 else 0
467 | self.deconv3 = nn.ConvTranspose2d(filters[2], filters[1], 5, stride=2, padding=2, output_padding=out_pad,
468 | bias=bias)
469 | self.bn3_2 = nn.BatchNorm2d(filters[1])
470 | out_pad = 1 if input_shape[0] // 2 % 2 == 0 else 0
471 | self.deconv2 = nn.ConvTranspose2d(filters[1], filters[0], 5, stride=2, padding=2, output_padding=out_pad,
472 | bias=bias)
473 | self.bn2_2 = nn.BatchNorm2d(filters[0])
474 | out_pad = 1 if input_shape[0] % 2 == 0 else 0
475 | self.deconv1 = nn.ConvTranspose2d(filters[0], input_shape[2], 5, stride=2, padding=2, output_padding=out_pad,
476 | bias=bias)
477 | self.clustering = ClusterlingLayer(num_clusters, num_clusters)
478 | # ReLU copies for graph representation in tensorboard
479 | self.relu1_1 = copy.deepcopy(self.relu)
480 | self.relu2_1 = copy.deepcopy(self.relu)
481 | self.relu3_1 = copy.deepcopy(self.relu)
482 | self.relu4_1 = copy.deepcopy(self.relu)
483 | self.relu5_1 = copy.deepcopy(self.relu)
484 | self.relu1_2 = copy.deepcopy(self.relu)
485 | self.relu2_2 = copy.deepcopy(self.relu)
486 | self.relu3_2 = copy.deepcopy(self.relu)
487 | self.relu4_2 = copy.deepcopy(self.relu)
488 | self.relu5_2 = copy.deepcopy(self.relu)
489 | self.sig = nn.Sigmoid()
490 | self.tanh = nn.Tanh()
491 |
492 | def forward(self, x):
493 | x = self.conv1(x)
494 | x = self.relu1_1(x)
495 | x = self.bn1_1(x)
496 | x = self.conv2(x)
497 | x = self.relu2_1(x)
498 | x = self.bn2_1(x)
499 | x = self.conv3(x)
500 | x = self.relu3_1(x)
501 | x = self.bn3_1(x)
502 | x = self.conv4(x)
503 | x = self.relu4_1(x)
504 | x = self.bn4_1(x)
505 | x = self.conv5(x)
506 | if self.activations:
507 | x = self.sig(x)
508 | else:
509 | x = self.relu5_1(x)
510 | x = x.view(x.size(0), -1)
511 | x = self.embedding(x)
512 | extra_out = x
513 | clustering_out = self.clustering(x)
514 | x = self.deembedding(x)
515 | x = self.relu5_2(x)
516 | x = x.view(x.size(0), self.filters[4], ((self.input_shape[0]//2//2//2//2-1) // 2), ((self.input_shape[0]//2//2//2//2-1) // 2))
517 | x = self.deconv5(x)
518 | x = self.relu4_2(x)
519 | x = self.bn5_2(x)
520 | x = self.deconv4(x)
521 | x = self.relu3_2(x)
522 | x = self.bn4_2(x)
523 | x = self.deconv3(x)
524 | x = self.relu2_2(x)
525 | x = self.bn3_2(x)
526 | x = self.deconv2(x)
527 | x = self.relu1_2(x)
528 | x = self.bn2_2(x)
529 | x = self.deconv1(x)
530 | if self.activations:
531 | x = self.tanh(x)
532 | return x, clustering_out, extra_out
533 |
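if __name__ == "__main__":
    # Smoke test (a sketch): push an MNIST-sized batch through CAE_3 and
    # inspect the three outputs (reconstruction, soft cluster assignments,
    # and the embedding used by the clustering layer).
    model = CAE_3(input_shape=[28, 28, 1], num_clusters=10)
    x = torch.rand(4, 1, 28, 28)
    reconstruction, clusters, embedding = model(x)
    print(reconstruction.shape, clusters.shape, embedding.shape)
    # expected: torch.Size([4, 1, 28, 28]) torch.Size([4, 10]) torch.Size([4, 10])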
--------------------------------------------------------------------------------
/semi_supervised.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | if __name__ == "__main__":
4 |
5 | import argparse
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torch.optim import lr_scheduler
10 | from torchvision import datasets, transforms
11 | import os
12 | import math
13 | import fnmatch
14 | import nets
15 | import utils
16 | import training_functions
17 | from tensorboardX import SummaryWriter
18 |
19 | # Translate string entries to bool for parser
20 | def str2bool(v):
21 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
22 | return True
23 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
24 | return False
25 | else:
26 | raise argparse.ArgumentTypeError('Boolean value expected.')
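    # e.g. str2bool('yes') -> True, str2bool('0') -> False; any other string
    # raises, so invalid flag values fail at parse time.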
27 |
28 | parser = argparse.ArgumentParser(description='Use DCEC for clustering')
29 | parser.add_argument('--mode', default='train_full', choices=['train_full', 'pretrain'], help='mode')
30 |     parser.add_argument('--tensorboard', default=True, type=str2bool, help='export training stats to tensorboard')
31 | parser.add_argument('--pretrain', default=True, type=str2bool, help='perform autoencoder pretraining')
32 | parser.add_argument('--pretrained_net', default=1, help='index or path of pretrained net')
33 | parser.add_argument('--net_architecture', default='CAE_3', choices=['CAE_3', 'CAE_bn3', 'CAE_4', 'CAE_bn4', 'CAE_5', 'CAE_bn5'], help='network architecture used')
34 | parser.add_argument('--dataset', default='MNIST-train', choices=['MNIST-train', 'custom', 'MNIST-test', 'MNIST-full'],
35 | help='custom or prepared dataset')
36 | parser.add_argument('--dataset_path', default='data', help='path to dataset')
37 | parser.add_argument('--batch_size', default=256, type=int, help='batch size')
38 | parser.add_argument('--rate', default=0.001, type=float, help='learning rate for clustering')
39 | parser.add_argument('--rate_pretrain', default=0.001, type=float, help='learning rate for pretraining')
40 | parser.add_argument('--weight', default=0.0, type=float, help='weight decay for clustering')
41 |     parser.add_argument('--weight_pretrain', default=0.0, type=float, help='weight decay for pretraining')
42 | parser.add_argument('--sched_step', default=200, type=int, help='scheduler steps for rate update')
43 | parser.add_argument('--sched_step_pretrain', default=200, type=int,
44 | help='scheduler steps for rate update - pretrain')
45 | parser.add_argument('--sched_gamma', default=0.1, type=float, help='scheduler gamma for rate update')
46 | parser.add_argument('--sched_gamma_pretrain', default=0.1, type=float,
47 | help='scheduler gamma for rate update - pretrain')
48 | parser.add_argument('--epochs', default=1000, type=int, help='clustering epochs')
49 | parser.add_argument('--epochs_pretrain', default=300, type=int, help='pretraining epochs')
50 | parser.add_argument('--printing_frequency', default=10, type=int, help='training stats printing frequency')
51 | parser.add_argument('--gamma', default=0.1, type=float, help='clustering loss weight')
52 | parser.add_argument('--gamma_lab', default=0.01, type=float, help='labelled loss weight')
53 | parser.add_argument('--update_interval', default=80, type=int, help='update interval for target distribution')
54 |     parser.add_argument('--label_upd_interval', default=1, type=int, help='update interval for labelled loss')
55 |     parser.add_argument('--tol', default=1e-3, type=float, help='stopping criterion tolerance')
56 | parser.add_argument('--num_clusters', default=10, type=int, help='number of clusters')
57 | parser.add_argument('--custom_img_size', default=[128, 128, 3], nargs=3, type=int, help='size of custom images')
58 | parser.add_argument('--leaky', default=True, type=str2bool, help='use leaky version of relu')
59 | parser.add_argument('--neg_slope', default=0.01, type=float, help='negative slope for leaky relu')
60 | parser.add_argument('--activations', default=False, type=str2bool, help='use sigmoid and tanh activations in autoencoder')
61 | parser.add_argument('--bias', default=True, type=str2bool, help='use bias in layers')
62 | args = parser.parse_args()
63 | print(args)
64 |
65 | if args.mode == 'pretrain' and not args.pretrain:
66 | print("Nothing to do :(")
67 | exit()
68 |
69 | board = args.tensorboard
70 |
71 |     # Handle the pretraining option and the way the pretrained network is specified (path or index)
72 | pretrain = args.pretrain
73 | net_is_path = True
74 | if not pretrain:
75 | try:
76 | int(args.pretrained_net)
77 | idx = args.pretrained_net
78 | net_is_path = False
79 |         except ValueError:
80 | pass
81 | params = {'pretrain': pretrain}
82 |
83 | # Directories
84 | # Create directories structure
85 | dirs = ['runs', 'reports', 'nets']
86 | list(map(lambda x: os.makedirs(x, exist_ok=True), dirs))
87 |
88 | # Net architecture
89 | model_name = args.net_architecture
90 |     # Indexing (for automated report saving) - allows running many trainings and collecting all the reports
91 |     if pretrain or net_is_path:
92 | reports_list = sorted(os.listdir('reports'), reverse=True)
93 | if reports_list:
94 | for file in reports_list:
95 | # print(file)
96 | if fnmatch.fnmatch(file, model_name+'*'):
97 | print(file)
98 | idx = int(str(file)[-7:-4]) + 1
99 | print(idx)
100 | break
101 | try:
102 | idx
103 | except NameError:
104 | idx = 1
105 |
106 | # Base filename
107 | name = model_name + '_' + str(idx).zfill(3)
108 |
109 | # Filenames for report and weights
110 | name_txt = name + '.txt'
111 | name_net = name
112 | pretrained = name + '_pretrained.pt'
113 |
114 | print(name_txt)
115 |
116 | # Arrange filenames for report, network weights, pretrained network weights
117 | name_txt = os.path.join('reports', name_txt)
118 | name_net = os.path.join('nets', name_net)
119 | if net_is_path and not pretrain:
120 | pretrained = args.pretrained_net
121 | else:
122 | pretrained = os.path.join('nets', pretrained)
123 | if not pretrain and not os.path.isfile(pretrained):
124 |         print("No pretrained weights found; choose an existing pretrained network or create a new one with --pretrain True")
125 |
126 | model_files = [name_net, pretrained]
127 | params['model_files'] = model_files
128 |
129 | # Open file
130 | if pretrain:
131 | f = open(name_txt, 'w')
132 | else:
133 | f = open(name_txt, 'a')
134 | params['txt_file'] = f
135 |
136 |     # Delete an existing tensorboard entry (so that old charts do not overlap and become unreadable)
137 |     try:
138 |         os.system("rm -rf runs/" + name)
139 |     except OSError:
140 |         pass
141 |
142 | # Initialize tensorboard writer
143 | if board:
144 | writer = SummaryWriter('runs/' + name)
145 | params['writer'] = writer
146 | else:
147 | params['writer'] = None
148 |
149 | # Hyperparameters
150 |
151 | # Used dataset
152 | dataset = args.dataset
153 |
154 | # Batch size
155 | batch = args.batch_size
156 | params['batch'] = batch
157 | # Number of workers (typically 4*num_of_GPUs)
158 | workers = 4
159 | # Learning rate
160 | rate = args.rate
161 | rate_pretrain = args.rate_pretrain
162 | # Adam params
163 | # Weight decay
164 | weight = args.weight
165 | weight_pretrain = args.weight_pretrain
166 | # Scheduler steps for rate update
167 | sched_step = args.sched_step
168 | sched_step_pretrain = args.sched_step_pretrain
169 | # Scheduler gamma - multiplier for learning rate
170 | sched_gamma = args.sched_gamma
171 | sched_gamma_pretrain = args.sched_gamma_pretrain
172 |
173 | # Number of epochs
174 | epochs = args.epochs
175 | pretrain_epochs = args.epochs_pretrain
176 | params['pretrain_epochs'] = pretrain_epochs
177 |
178 | # Printing frequency
179 | print_freq = args.printing_frequency
180 | params['print_freq'] = print_freq
181 |
182 | # Clustering loss weight:
183 | gamma = args.gamma
184 | params['gamma'] = gamma
185 |
186 | # Labelled loss weight:
187 | gamma_lab = args.gamma_lab
188 | params['gamma_lab'] = gamma_lab
189 |
190 | # Update interval for target distribution:
191 | update_interval = args.update_interval
192 | params['update_interval'] = update_interval
193 |
194 | label_upd_interval = args.label_upd_interval
195 | params['label_upd_interval'] = label_upd_interval
196 |
197 | # Tolerance for label changes:
198 | tol = args.tol
199 | params['tol'] = tol
200 |
201 | # Number of clusters
202 | num_clusters = args.num_clusters
203 |
204 | # Report for settings
205 | tmp = "Training the '" + model_name + "' architecture"
206 | utils.print_both(f, tmp)
207 | tmp = "\n" + "The following parameters are used:"
208 | utils.print_both(f, tmp)
209 | tmp = "Batch size:\t" + str(batch)
210 | utils.print_both(f, tmp)
211 | tmp = "Number of workers:\t" + str(workers)
212 | utils.print_both(f, tmp)
213 | tmp = "Learning rate:\t" + str(rate)
214 | utils.print_both(f, tmp)
215 | tmp = "Pretraining learning rate:\t" + str(rate_pretrain)
216 | utils.print_both(f, tmp)
217 | tmp = "Weight decay:\t" + str(weight)
218 | utils.print_both(f, tmp)
219 | tmp = "Pretraining weight decay:\t" + str(weight_pretrain)
220 | utils.print_both(f, tmp)
221 | tmp = "Scheduler steps:\t" + str(sched_step)
222 | utils.print_both(f, tmp)
223 | tmp = "Scheduler gamma:\t" + str(sched_gamma)
224 | utils.print_both(f, tmp)
225 | tmp = "Pretraining scheduler steps:\t" + str(sched_step_pretrain)
226 | utils.print_both(f, tmp)
227 | tmp = "Pretraining scheduler gamma:\t" + str(sched_gamma_pretrain)
228 | utils.print_both(f, tmp)
229 | tmp = "Number of epochs of training:\t" + str(epochs)
230 | utils.print_both(f, tmp)
231 | tmp = "Number of epochs of pretraining:\t" + str(pretrain_epochs)
232 | utils.print_both(f, tmp)
233 | tmp = "Clustering loss weight:\t" + str(gamma)
234 | utils.print_both(f, tmp)
235 | tmp = "Labelled loss weight:\t" + str(gamma_lab)
236 | utils.print_both(f, tmp)
237 | tmp = "Update interval for target distribution:\t" + str(update_interval)
238 | utils.print_both(f, tmp)
239 | tmp = "Update interval for labelled loss:\t" + str(label_upd_interval)
240 | utils.print_both(f, tmp)
241 | tmp = "Stop criterium tolerance:\t" + str(tol)
242 | utils.print_both(f, tmp)
243 | tmp = "Number of clusters:\t" + str(num_clusters)
244 | utils.print_both(f, tmp)
245 | tmp = "Leaky relu:\t" + str(args.leaky)
246 | utils.print_both(f, tmp)
247 | tmp = "Leaky slope:\t" + str(args.neg_slope)
248 | utils.print_both(f, tmp)
249 | tmp = "Activations:\t" + str(args.activations)
250 | utils.print_both(f, tmp)
251 | tmp = "Bias:\t" + str(args.bias)
252 | utils.print_both(f, tmp)
253 |
254 | # Data preparation
255 | if dataset == 'MNIST-train':
256 |         # Uses a slightly modified torchvision MNIST class and creates dataloaders for
257 |         # the whole set and for a small labelled subset (2% of the data)
258 | import mnist
259 | tmp = "\nData preparation\nReading data from: MNIST train dataset"
260 | utils.print_both(f, tmp)
261 | img_size = [28, 28, 1]
262 | tmp = "Image size used:\t{0}x{1}".format(img_size[0], img_size[1])
263 | utils.print_both(f, tmp)
264 |
265 | dataset = mnist.MNIST('../data', train=True, download=True,
266 | transform=transforms.Compose([
267 | transforms.ToTensor(),
268 | # transforms.Normalize((0.1307,), (0.3081,))
269 | ]))
270 |
271 | dataloader = torch.utils.data.DataLoader(dataset,
272 | batch_size=batch, shuffle=False, num_workers=workers)
273 |
274 | dataset_size = len(dataset)
275 | tmp = "Training set size:\t" + str(dataset_size)
276 | utils.print_both(f, tmp)
277 |
278 | dataset_labelled = mnist.MNIST('../data', train=True, download=True, small=True,
279 | transform=transforms.Compose([
280 | transforms.ToTensor(),
281 | # transforms.Normalize((0.1307,), (0.3081,))
282 | ]))
283 |
284 | dataloader_labelled = torch.utils.data.DataLoader(dataset_labelled,
285 | batch_size=batch, shuffle=False, num_workers=workers)
286 |
287 | dataset_labelled_size = len(dataset_labelled)
288 | tmp = "Training set labelled size:\t" + str(dataset_labelled_size)
289 | utils.print_both(f, tmp)
290 |
291 | elif dataset == 'MNIST-test':
292 | import mnist
293 | tmp = "\nData preparation\nReading data from: MNIST test dataset"
294 | utils.print_both(f, tmp)
295 | img_size = [28, 28, 1]
296 | tmp = "Image size used:\t{0}x{1}".format(img_size[0], img_size[1])
297 | utils.print_both(f, tmp)
298 |
299 | dataset = mnist.MNIST('../data', train=False, download=True,
300 | transform=transforms.Compose([
301 | transforms.ToTensor(),
302 | # transforms.Normalize((0.1307,), (0.3081,))
303 | ]))
304 |
305 | dataloader = torch.utils.data.DataLoader(dataset,
306 | batch_size=batch, shuffle=False, num_workers=workers)
307 |
308 | dataset_size = len(dataset)
309 | tmp = "Training set size:\t" + str(dataset_size)
310 | utils.print_both(f, tmp)
311 |
312 | dataset_labelled = mnist.MNIST('../data', train=False, download=True, small=True,
313 | transform=transforms.Compose([
314 | transforms.ToTensor(),
315 | # transforms.Normalize((0.1307,), (0.3081,))
316 | ]))
317 |
318 | dataloader_labelled = torch.utils.data.DataLoader(dataset_labelled,
319 | batch_size=batch, shuffle=False, num_workers=workers)
320 |
321 | dataset_labelled_size = len(dataset_labelled)
322 | tmp = "Training set labelled size:\t" + str(dataset_labelled_size)
323 | utils.print_both(f, tmp)
324 |
325 | elif dataset == 'MNIST-full':
326 | import mnist
327 | tmp = "\nData preparation\nReading data from: MNIST full dataset"
328 | utils.print_both(f, tmp)
329 | img_size = [28, 28, 1]
330 | tmp = "Image size used:\t{0}x{1}".format(img_size[0], img_size[1])
331 | utils.print_both(f, tmp)
332 |
333 | dataset = mnist.MNIST('../data', full=True, download=True,
334 | transform=transforms.Compose([
335 | transforms.ToTensor(),
336 | # transforms.Normalize((0.1307,), (0.3081,))
337 | ]))
338 |
339 | dataloader = torch.utils.data.DataLoader(dataset,
340 | batch_size=batch, shuffle=False, num_workers=workers)
341 |
342 | dataset_size = len(dataset)
343 | tmp = "Training set size:\t" + str(dataset_size)
344 | utils.print_both(f, tmp)
345 |
346 | dataset_labelled = mnist.MNIST('../data', full=True, download=True, small=True,
347 | transform=transforms.Compose([
348 | transforms.ToTensor(),
349 | # transforms.Normalize((0.1307,), (0.3081,))
350 | ]))
351 |
352 | dataloader_labelled = torch.utils.data.DataLoader(dataset_labelled,
353 | batch_size=batch, shuffle=False, num_workers=workers)
354 |
355 | dataset_labelled_size = len(dataset_labelled)
356 | tmp = "Training set labelled size:\t" + str(dataset_labelled_size)
357 | utils.print_both(f, tmp)
358 |
359 | else:
360 |         # Custom dataset - arrange folders according to the README
361 |
362 | # Data folder
363 | data_dir = args.dataset_path
364 | tmp = "\nData preparation\nReading data from:\t./" + data_dir
365 | utils.print_both(f, tmp)
366 |
367 | # Image size
368 |         # args.custom_img_size is expected to be a list [height, width, channels]
369 |         custom_size = args.custom_img_size
370 | if isinstance(custom_size, list):
371 | img_size = custom_size
372 |
373 | tmp = "Image size used:\t{0}x{1}".format(img_size[0], img_size[1])
374 | utils.print_both(f, tmp)
375 |
376 | # Transformations
377 | data_transforms = transforms.Compose([
378 | transforms.Resize(img_size[0:2]),
379 | # transforms.RandomHorizontalFlip(),
380 | transforms.ToTensor(),
381 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
382 | ])
383 |
384 | # Read data from selected folder and apply transformations
385 | image_dataset = datasets.ImageFolder(data_dir, data_transforms)
386 |         # Prepare batches for the network (no shuffling, so batch order stays aligned with the target distribution)
387 | dataloader = torch.utils.data.DataLoader(image_dataset, batch_size=batch,
388 | shuffle=False, num_workers=workers)
389 |
390 | # Size of data sets
391 | dataset_size = len(image_dataset)
392 | tmp = "Training set size:\t" + str(dataset_size)
393 | utils.print_both(f, tmp)
394 |
395 | # Read data from selected folder and apply transformations
396 | image_dataset_l = datasets.ImageFolder(data_dir+'_l', data_transforms)
397 |         # Prepare batches for the network (no shuffling, so batch order stays aligned with the target distribution)
398 | dataloader_labelled = torch.utils.data.DataLoader(image_dataset_l, batch_size=batch,
399 | shuffle=False, num_workers=workers)
400 | dataset_labelled_size = len(image_dataset_l)
401 |
402 | params['dataset_size'] = dataset_size
403 | params['dataset_labelled_size'] = dataset_labelled_size
404 |
405 | # GPU check
406 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
407 | tmp = "\nPerforming calculations on:\t" + str(device)
408 | utils.print_both(f, tmp + '\n')
409 | params['device'] = device
410 |
411 |     # Instantiate the chosen model architecture by name
412 |     model_class = getattr(nets, model_name)
413 |     model = model_class(img_size, num_clusters=num_clusters, leaky=args.leaky, neg_slope=args.neg_slope)
414 |
415 | # Tensorboard model representation
416 | # if board:
417 | # writer.add_graph(model, torch.autograd.Variable(torch.Tensor(batch, img_size[2], img_size[0], img_size[1])))
418 |
419 | model = model.to(device)
420 | # Reconstruction loss
421 |     criterion_1 = nn.MSELoss(reduction='mean')
422 |     # Clustering loss
423 |     criterion_2 = nn.KLDivLoss(reduction='sum')
424 |     # Labelled loss
425 |     criterion_3 = nn.CrossEntropyLoss(reduction='sum')
426 |
427 | criteria = [criterion_1, criterion_2, criterion_3]
428 |
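    # For reference, train_semisupervised combines these three criteria as
    #   loss = loss_rec + gamma * KLDiv(log(q), p) / batch + gamma_lab * CE_sum / n_labelled
    # i.e. reconstruction + clustering + labelled terms (see training_functions.py).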
429 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=rate, weight_decay=weight)
430 |
431 | optimizer_pretrain = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=rate_pretrain, weight_decay=weight_pretrain)
432 |
433 | optimizers = [optimizer, optimizer_pretrain]
434 |
435 | scheduler = lr_scheduler.StepLR(optimizer, step_size=sched_step, gamma=sched_gamma)
436 | scheduler_pretrain = lr_scheduler.StepLR(optimizer_pretrain, step_size=sched_step_pretrain, gamma=sched_gamma_pretrain)
437 |
438 | schedulers = [scheduler, scheduler_pretrain]
439 |
440 | if args.mode == 'train_full':
441 | model = training_functions.train_semisupervised(model, [dataloader, dataloader_labelled], criteria, optimizers, schedulers, epochs, params)
442 | elif args.mode == 'pretrain':
443 |         model = training_functions.pretraining(model, dataloader, criteria[0], optimizers[1], schedulers[1], epochs, params)
444 |
445 | # Save final model
446 | torch.save(model.state_dict(), name_net + '.pt')
447 |
448 | # Close files
449 | f.close()
450 | if board:
451 | writer.close()
452 |
453 |
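# A hypothetical invocation, assuming the argparse flags defined earlier in this
# file mirror the attribute names used above (e.g. args.batch_size -> --batch_size):
#   python semi_supervised.py --mode train_full --dataset MNIST-train --batch_size 256 --epochs 100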
--------------------------------------------------------------------------------
/training_functions.py:
--------------------------------------------------------------------------------
1 | import utils
2 | import time
3 | import torch
4 | import numpy as np
5 | import copy
6 | from sklearn.cluster import KMeans
7 |
8 |
9 | # Training function (from my torch_DCEC implementation, kept for completeness)
10 | def train_model(model, dataloader, criteria, optimizers, schedulers, num_epochs, params):
11 |
12 | # Note the time
13 | since = time.time()
14 |
15 | # Unpack parameters
16 | writer = params['writer']
17 |     board = writer is not None
18 | txt_file = params['txt_file']
19 | pretrained = params['model_files'][1]
20 | pretrain = params['pretrain']
21 | print_freq = params['print_freq']
22 | dataset_size = params['dataset_size']
23 | device = params['device']
24 | batch = params['batch']
25 | pretrain_epochs = params['pretrain_epochs']
26 | gamma = params['gamma']
27 | update_interval = params['update_interval']
28 | tol = params['tol']
29 |
30 | dl = dataloader
31 |
32 | # Pretrain or load weights
33 | if pretrain:
34 | while True:
35 | pretrained_model = pretraining(model, copy.deepcopy(dl), criteria[0], optimizers[1], schedulers[1], pretrain_epochs, params)
36 | if pretrained_model:
37 | break
38 | else:
39 | for layer in model.children():
40 | if hasattr(layer, 'reset_parameters'):
41 | layer.reset_parameters()
42 | model = pretrained_model
43 | else:
44 | try:
45 | model.load_state_dict(torch.load(pretrained))
46 | utils.print_both(txt_file, 'Pretrained weights loaded from file: ' + str(pretrained))
47 |         except Exception:
48 | print("Couldn't load pretrained weights")
49 |
50 | # Initialise clusters
51 | utils.print_both(txt_file, '\nInitializing cluster centers based on K-means')
52 | kmeans(model, copy.deepcopy(dl), params)
53 |
54 | utils.print_both(txt_file, '\nBegin clusters training')
55 |
56 | # Prep variables for weights and accuracy of the best model
57 | best_model_wts = copy.deepcopy(model.state_dict())
58 | best_loss = 10000.0
59 |
60 | # Initial target distribution
61 | utils.print_both(txt_file, '\nUpdating target distribution')
62 | output_distribution, labels, preds_prev = calculate_predictions(model, copy.deepcopy(dl), params)
63 | target_distribution = target(output_distribution)
64 | nmi = utils.metrics.nmi(labels, preds_prev)
65 | ari = utils.metrics.ari(labels, preds_prev)
66 | acc = utils.metrics.acc(labels, preds_prev)
67 | utils.print_both(txt_file,
68 | 'NMI: {0:.5f}\tARI: {1:.5f}\tAcc {2:.5f}\n'.format(nmi, ari, acc))
69 |
70 | if board:
71 | niter = 0
72 | writer.add_scalar('/NMI', nmi, niter)
73 | writer.add_scalar('/ARI', ari, niter)
74 | writer.add_scalar('/Acc', acc, niter)
75 |
76 | update_iter = 1
77 | finished = False
78 |
79 | # Go through all epochs
80 | for epoch in range(num_epochs):
81 |
82 | utils.print_both(txt_file, 'Epoch {}/{}'.format(epoch + 1, num_epochs))
83 | utils.print_both(txt_file, '-' * 10)
84 |
85 | schedulers[0].step()
86 | model.train(True) # Set model to training mode
87 |
88 | running_loss = 0.0
89 | running_loss_rec = 0.0
90 | running_loss_clust = 0.0
91 |
92 | # Keep the batch number for inter-phase statistics
93 | batch_num = 1
94 | img_counter = 0
95 |
96 | # Iterate over data.
97 | for data in dataloader:
98 | # Get the inputs and labels
99 | inputs, _ = data
100 |
101 | inputs = inputs.to(device)
102 |
103 |             # Update target distribution, check and print performance
104 | if (batch_num - 1) % update_interval == 0 and not (batch_num == 1 and epoch == 0):
105 | utils.print_both(txt_file, '\nUpdating target distribution:')
106 | output_distribution, labels, preds = calculate_predictions(model, dataloader, params)
107 | target_distribution = target(output_distribution)
108 | nmi = utils.metrics.nmi(labels, preds)
109 | ari = utils.metrics.ari(labels, preds)
110 | acc = utils.metrics.acc(labels, preds)
111 | utils.print_both(txt_file,
112 | 'NMI: {0:.5f}\tARI: {1:.5f}\tAcc {2:.5f}\t'.format(nmi, ari, acc))
113 | if board:
114 | niter = update_iter
115 | writer.add_scalar('/NMI', nmi, niter)
116 | writer.add_scalar('/ARI', ari, niter)
117 | writer.add_scalar('/Acc', acc, niter)
118 | update_iter += 1
119 |
120 |                     # Stop criterion: fraction of samples whose assignment changed since the last update
121 | delta_label = np.sum(preds != preds_prev).astype(np.float32) / preds.shape[0]
122 | preds_prev = np.copy(preds)
123 | if delta_label < tol:
124 |                         utils.print_both(txt_file, 'Label divergence ' + str(delta_label) + ' < tol ' + str(tol))
125 | utils.print_both(txt_file, 'Reached tolerance threshold. Stopping training.')
126 | finished = True
127 | break
128 |
129 |                 tar_dist = target_distribution[((batch_num - 1) * batch):(batch_num * batch), :]  # valid because the loader uses shuffle=False
130 | tar_dist = torch.from_numpy(tar_dist).to(device)
131 | # print(tar_dist)
132 |
133 | # zero the parameter gradients
134 | optimizers[0].zero_grad()
135 |
136 | # Calculate losses and backpropagate
137 | with torch.set_grad_enabled(True):
138 | outputs, clusters, _ = model(inputs)
139 | loss_rec = criteria[0](outputs, inputs)
140 |                     loss_clust = gamma * criteria[1](torch.log(clusters), tar_dist) / batch
141 | loss = loss_rec + loss_clust
142 | loss.backward()
143 | optimizers[0].step()
144 |
145 | # For keeping statistics
146 | running_loss += loss.item() * inputs.size(0)
147 | running_loss_rec += loss_rec.item() * inputs.size(0)
148 |                 running_loss_clust += loss_clust.item() * inputs.size(0)
149 |
150 | # Some current stats
151 | loss_batch = loss.item()
152 | loss_batch_rec = loss_rec.item()
153 | loss_batch_clust = loss_clust.item()
154 | loss_accum = running_loss / ((batch_num - 1) * batch + inputs.size(0))
155 | loss_accum_rec = running_loss_rec / ((batch_num - 1) * batch + inputs.size(0))
156 | loss_accum_clust = running_loss_clust / ((batch_num - 1) * batch + inputs.size(0))
157 |
158 | if batch_num % print_freq == 0:
159 | utils.print_both(txt_file, 'Epoch: [{0}][{1}/{2}]\t'
160 | 'Loss {3:.4f} ({4:.4f})\t'
161 | 'Loss_recovery {5:.4f} ({6:.4f})\t'
162 | 'Loss clustering {7:.4f} ({8:.4f})\t'.format(epoch + 1, batch_num,
163 | len(dataloader),
164 | loss_batch,
165 | loss_accum, loss_batch_rec,
166 | loss_accum_rec,
167 | loss_batch_clust,
168 | loss_accum_clust))
169 | if board:
170 | niter = epoch * len(dataloader) + batch_num
171 | writer.add_scalar('/Loss', loss_accum, niter)
172 | writer.add_scalar('/Loss_recovery', loss_accum_rec, niter)
173 | writer.add_scalar('/Loss_clustering', loss_accum_clust, niter)
174 | batch_num = batch_num + 1
175 |
176 | # Print image to tensorboard
177 |             if batch_num == len(dataloader) and (epoch + 1) % 5:  # every epoch except multiples of 5
178 | inp = utils.tensor2img(inputs)
179 | out = utils.tensor2img(outputs)
180 | if board:
181 | img = np.concatenate((inp, out), axis=2)
182 | writer.add_image('Clustering/Epoch_' + str(epoch + 1).zfill(3) + '/Sample_' + str(img_counter).zfill(2), img)
183 | img_counter += 1
184 |
185 | if finished: break
186 |
187 | epoch_loss = running_loss / dataset_size
188 | epoch_loss_rec = running_loss_rec / dataset_size
189 | epoch_loss_clust = running_loss_clust / dataset_size
190 |
191 | if board:
192 | writer.add_scalar('/Loss' + '/Epoch', epoch_loss, epoch + 1)
193 | writer.add_scalar('/Loss_rec' + '/Epoch', epoch_loss_rec, epoch + 1)
194 | writer.add_scalar('/Loss_clust' + '/Epoch', epoch_loss_clust, epoch + 1)
195 |
196 | utils.print_both(txt_file, 'Loss: {0:.4f}\tLoss_recovery: {1:.4f}\tLoss_clustering: {2:.4f}'.format(epoch_loss,
197 | epoch_loss_rec,
198 | epoch_loss_clust))
199 |
200 |         # Placeholder for a model-selection criterion; the condition is always true, so the latest weights are kept
201 | if epoch_loss < best_loss or epoch_loss >= best_loss:
202 | best_loss = epoch_loss
203 | best_model_wts = copy.deepcopy(model.state_dict())
204 |
205 | utils.print_both(txt_file, '')
206 |
207 | time_elapsed = time.time() - since
208 | utils.print_both(txt_file, 'Training complete in {:.0f}m {:.0f}s'.format(
209 | time_elapsed // 60, time_elapsed % 60))
210 |
211 | # load best model weights
212 | model.load_state_dict(best_model_wts)
213 | return model
214 |
215 |
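# The loop above follows the DEC/DCEC recipe; schematically (illustrative
# pseudocode, not part of the module):
#   every update_interval batches:
#       q, labels, preds = calculate_predictions(...)   # soft assignments
#       p = target(q)                                   # sharpened targets
#       stop if the fraction of changed predictions < tol
#   every batch:
#       minimise MSE(decoder(x), x) + gamma * KLDiv(log(q), p) / batch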
216 | # Training function (proper semisupervised training)
217 | def train_semisupervised(model, dataloaders, criteria, optimizers, schedulers, num_epochs, params):
218 |
219 | # Note the time
220 | since = time.time()
221 |
222 | # Unpack parameters
223 | writer = params['writer']
224 |     board = writer is not None
225 | txt_file = params['txt_file']
226 | pretrained = params['model_files'][1]
227 | pretrain = params['pretrain']
228 | print_freq = params['print_freq']
229 | dataset_size = params['dataset_size']
230 | dataset_labelled_size = params['dataset_labelled_size']
231 | device = params['device']
232 | batch = params['batch']
233 | pretrain_epochs = params['pretrain_epochs']
234 | gamma = params['gamma']
235 | gamma_lab = params['gamma_lab']
236 | update_interval = params['update_interval']
237 | tol = params['tol']
238 | label_upd_interval = params['label_upd_interval']
239 |
240 | dataloader = dataloaders[0]
241 | dataloader_labelled = dataloaders[1]
242 |
243 | dl = dataloader
244 |
245 | # Pretrain or load weights
246 | if pretrain:
247 | while True:
248 | pretrained_model = pretraining(model, copy.deepcopy(dl), criteria[0], optimizers[1], schedulers[1], pretrain_epochs, params)
249 | if pretrained_model:
250 | break
251 | else:
252 | for layer in model.children():
253 | if hasattr(layer, 'reset_parameters'):
254 | layer.reset_parameters()
255 | model = pretrained_model
256 | else:
257 | try:
258 | model.load_state_dict(torch.load(pretrained))
259 | utils.print_both(txt_file, 'Pretrained weights loaded from file: ' + str(pretrained))
260 |         except Exception:
261 | print("Couldn't load pretrained weights")
262 |
263 | # Initialise clusters
264 | utils.print_both(txt_file, '\nInitializing cluster centers based on average')
265 | average_labelled_dist(model, copy.deepcopy(dataloader_labelled), params)
266 |
267 | utils.print_both(txt_file, '\nBegin clusters training')
268 |
269 | # Prep variables for weights and accuracy of the best model
270 | best_model_wts = copy.deepcopy(model.state_dict())
271 | best_loss = 10000.0
272 |
273 | # Initial target distribution
274 | utils.print_both(txt_file, '\nUpdating target distribution')
275 | output_distribution, labels, preds_prev = calculate_predictions(model, copy.deepcopy(dl), params)
276 | target_distribution = target(output_distribution)
277 | nmi = utils.metrics.nmi(labels, preds_prev)
278 | ari = utils.metrics.ari(labels, preds_prev)
279 | acc = utils.metrics.acc(labels, preds_prev)
280 | utils.print_both(txt_file,
281 | 'NMI: {0:.5f}\tARI: {1:.5f}\tAcc {2:.5f}\n'.format(nmi, ari, acc))
282 |
283 | if board:
284 | niter = 0
285 | writer.add_scalar('/NMI', nmi, niter)
286 | writer.add_scalar('/ARI', ari, niter)
287 | writer.add_scalar('/Acc', acc, niter)
288 |
289 | update_iter = 1
290 | finished = False
291 |
292 | # Go through all epochs
293 | for epoch in range(num_epochs):
294 |
295 | utils.print_both(txt_file, 'Epoch {}/{}'.format(epoch + 1, num_epochs))
296 | utils.print_both(txt_file, '-' * 10)
297 |
298 | schedulers[0].step()
299 | model.train(True) # Set model to training mode
300 |
301 | running_loss = 0.0
302 | running_loss_rec = 0.0
303 | running_loss_clust = 0.0
304 | running_loss_labels = 0.0
305 |
306 | # Keep the batch number for inter-phase statistics
307 | batch_num = 1
308 | img_counter = 0
309 |
310 | # print(dataloader)
311 | # Iterate over data.
312 | for data in dataloader:
313 | # Get the inputs and labels
314 | inputs, _ = data
315 |
316 | inputs = inputs.to(device)
317 |
318 |             # Update target distribution, check and print performance
319 | if (batch_num - 1) % update_interval == 0 and not (batch_num == 1 and epoch == 0):
320 | utils.print_both(txt_file, '\nUpdating target distribution:')
321 | output_distribution, labels, preds = calculate_predictions(model, dataloader, params)
322 | target_distribution = target(output_distribution)
323 | nmi = utils.metrics.nmi(labels, preds)
324 | ari = utils.metrics.ari(labels, preds)
325 | acc = utils.metrics.acc(labels, preds)
326 | utils.print_both(txt_file,
327 | 'NMI: {0:.5f}\tARI: {1:.5f}\tAcc {2:.5f}\t'.format(nmi, ari, acc))
328 | if board:
329 | niter = update_iter
330 | writer.add_scalar('/NMI', nmi, niter)
331 | writer.add_scalar('/ARI', ari, niter)
332 | writer.add_scalar('/Acc', acc, niter)
333 | update_iter += 1
334 |
335 |                     # Stop criterion: fraction of samples whose assignment changed since the last update
336 | delta_label = np.sum(preds != preds_prev).astype(np.float32) / preds.shape[0]
337 | preds_prev = np.copy(preds)
338 | if delta_label < tol:
339 | utils.print_both(txt_file, 'Label divergence ' + str(delta_label) + ' < tol ' + str(tol))
340 | utils.print_both(txt_file, 'Reached tolerance threshold. Stopping training.')
341 | finished = True
342 | break
343 |
344 |                 tar_dist = target_distribution[((batch_num - 1) * batch):(batch_num * batch), :]  # valid because the loader uses shuffle=False
345 | tar_dist = torch.from_numpy(tar_dist).to(device)
346 | # print(tar_dist)
347 |
348 | loss_labelled = 0
349 |
350 | # zero the parameter gradients
351 | optimizers[0].zero_grad()
352 |
353 | # Calculate losses and backpropagate
354 | with torch.set_grad_enabled(True):
355 | if (batch_num - 1) % label_upd_interval == 0 and not (batch_num == 1 and epoch == 0):
356 | # utils.print_both(txt_file, '\nUpdating labelled loss:')
357 | size = 0
358 | # Iterate through labelled part of the set
359 | for d in dataloader_labelled:
360 | inp, lab = d
361 | inp = inp.to(params['device'])
362 | lab = lab.to(params['device'])
363 | _, outs, _ = model(inp)
364 | loss_labelled += criteria[2](outs, lab)
365 | size += inp.size(0)
366 | loss_labelled = loss_labelled / size * gamma_lab
367 |
368 | outputs, clusters, _ = model(inputs)
369 | loss_rec = criteria[0](outputs, inputs)
370 |                     loss_clust = gamma * criteria[1](torch.log(clusters), tar_dist) / batch
371 | loss = loss_rec + loss_clust + loss_labelled
372 | loss.backward()
373 | optimizers[0].step()
374 |
375 | # For keeping statistics
376 | running_loss += loss.item() * inputs.size(0)
377 | running_loss_rec += loss_rec.item() * inputs.size(0)
378 |                 running_loss_clust += loss_clust.item() * inputs.size(0)
379 |                 running_loss_labels += float(loss_labelled) * inputs.size(0)  # float() detaches the labelled-loss tensor for bookkeeping
380 |
381 | # Some current stats
382 | loss_batch = loss.item()
383 | loss_batch_rec = loss_rec.item()
384 | loss_batch_clust = loss_clust.item()
385 |                 loss_batch_labels = float(loss_labelled)
386 | loss_accum = running_loss / ((batch_num - 1) * batch + inputs.size(0))
387 | loss_accum_rec = running_loss_rec / ((batch_num - 1) * batch + inputs.size(0))
388 | loss_accum_clust = running_loss_clust / ((batch_num - 1) * batch + inputs.size(0))
389 | loss_accum_labels = running_loss_labels / ((batch_num - 1) * batch + inputs.size(0))
390 |
391 | if batch_num % print_freq == 0:
392 | utils.print_both(txt_file, 'Epoch: [{0}][{1}/{2}]\t'
393 | 'Loss {3:.4f} ({4:.4f})\t'
394 | 'Loss_recovery {5:.4f} ({6:.4f})\t'
395 | 'Loss clustering {7:.4f} ({8:.4f})\t'
396 | 'Loss labels {9:.4f} ({10:.4f})\t'.format(epoch + 1, batch_num,
397 | len(dataloader),
398 | loss_batch,
399 | loss_accum, loss_batch_rec,
400 | loss_accum_rec,
401 | loss_batch_clust,
402 | loss_accum_clust,
403 | loss_batch_labels,
404 | loss_accum_labels))
405 | if board:
406 | niter = epoch * len(dataloader) + batch_num
407 | writer.add_scalar('/Loss', loss_accum, niter)
408 | writer.add_scalar('/Loss_recovery', loss_accum_rec, niter)
409 | writer.add_scalar('/Loss_clustering', loss_accum_clust, niter)
410 | writer.add_scalar('/Loss_labels', loss_accum_labels, niter)
411 | batch_num = batch_num + 1
412 |
413 | # Print image to tensorboard
414 |             if batch_num == len(dataloader) and (epoch + 1) % 5:  # every epoch except multiples of 5
415 | inp = utils.tensor2img(inputs)
416 | out = utils.tensor2img(outputs)
417 | if board:
418 | img = np.concatenate((inp, out), axis=2)
419 | writer.add_image('Clustering/Epoch_' + str(epoch + 1).zfill(3) + '/Sample_' + str(img_counter).zfill(2), img)
420 | img_counter += 1
421 |
422 | if finished: break
423 |
424 | epoch_loss = running_loss / dataset_size
425 | epoch_loss_rec = running_loss_rec / dataset_size
426 | epoch_loss_clust = running_loss_clust / dataset_size
427 | epoch_loss_labels = running_loss_labels / dataset_size
428 |
429 | if board:
430 | writer.add_scalar('/Loss' + '/Epoch', epoch_loss, epoch + 1)
431 | writer.add_scalar('/Loss_rec' + '/Epoch', epoch_loss_rec, epoch + 1)
432 | writer.add_scalar('/Loss_clust' + '/Epoch', epoch_loss_clust, epoch + 1)
433 | writer.add_scalar('/Loss_label' + '/Epoch', epoch_loss_labels, epoch + 1)
434 |
435 | utils.print_both(txt_file, 'Loss: {0:.4f}\tLoss_recovery: {1:.4f}\tLoss_clustering: {2:.4f}\tLoss labels: {3:.4f}'.format(
436 | epoch_loss,
437 | epoch_loss_rec,
438 | epoch_loss_clust, epoch_loss_labels))
439 |
440 |         # Placeholder for a model-selection criterion; the condition is always true, so the latest weights are kept
441 | if epoch_loss < best_loss or epoch_loss >= best_loss:
442 | best_loss = epoch_loss
443 | best_model_wts = copy.deepcopy(model.state_dict())
444 |
445 | utils.print_both(txt_file, '')
446 |
447 | time_elapsed = time.time() - since
448 | utils.print_both(txt_file, 'Training complete in {:.0f}m {:.0f}s'.format(
449 | time_elapsed // 60, time_elapsed % 60))
450 |
451 | # load best model weights
452 | model.load_state_dict(best_model_wts)
453 | return model
454 |
455 |
456 | # Pretraining function for recovery loss only
457 | def pretraining(model, dataloader, criterion, optimizer, scheduler, num_epochs, params):
458 | # Note the time
459 | since = time.time()
460 |
461 | # Unpack parameters
462 | writer = params['writer']
463 |     board = writer is not None
464 | txt_file = params['txt_file']
465 | pretrained = params['model_files'][1]
466 | print_freq = params['print_freq']
467 | dataset_size = params['dataset_size']
468 | device = params['device']
469 | batch = params['batch']
470 |
471 | # Prep variables for weights and accuracy of the best model
472 | best_model_wts = copy.deepcopy(model.state_dict())
473 | best_loss = 10000.0
474 |
475 | # Go through all epochs
476 | for epoch in range(num_epochs):
477 | utils.print_both(txt_file, 'Pretraining:\tEpoch {}/{}'.format(epoch + 1, num_epochs))
478 | utils.print_both(txt_file, '-' * 10)
479 |
480 | scheduler.step()
481 | model.train(True) # Set model to training mode
482 |
483 | running_loss = 0.0
484 |
485 | # Keep the batch number for inter-phase statistics
486 | batch_num = 1
487 | # Images to show
488 | img_counter = 0
489 |
490 | # Iterate over data.
491 | for data in dataloader:
492 | # Get the inputs and labels
493 | inputs, _ = data
494 | inputs = inputs.to(device)
495 |
496 | # zero the parameter gradients
497 | optimizer.zero_grad()
498 |
499 | with torch.set_grad_enabled(True):
500 | outputs, _, _ = model(inputs)
501 | loss = criterion(outputs, inputs)
502 | loss.backward()
503 | optimizer.step()
504 |
505 | # For keeping statistics
506 | running_loss += loss.item() * inputs.size(0)
507 |
508 | # Some current stats
509 | loss_batch = loss.item()
510 | loss_accum = running_loss / ((batch_num - 1) * batch + inputs.size(0))
511 |
512 | if batch_num % print_freq == 0:
513 | utils.print_both(txt_file, 'Pretraining:\tEpoch: [{0}][{1}/{2}]\t'
514 | 'Loss {3:.4f} ({4:.4f})\t'.format(epoch + 1, batch_num, len(dataloader),
515 | loss_batch,
516 | loss_accum))
517 | if board:
518 | niter = epoch * len(dataloader) + batch_num
519 | writer.add_scalar('Pretraining/Loss', loss_accum, niter)
520 | batch_num = batch_num + 1
521 |
522 | if batch_num in [len(dataloader), len(dataloader)//2, len(dataloader)//4, 3*len(dataloader)//4]:
523 | inp = utils.tensor2img(inputs)
524 | out = utils.tensor2img(outputs)
525 | if board:
526 | img = np.concatenate((inp, out), axis=2)
527 | writer.add_image('Pretraining/Epoch_' + str(epoch + 1).zfill(3) + '/Sample_' + str(img_counter).zfill(2), img)
528 | img_counter += 1
529 |
530 | epoch_loss = running_loss / dataset_size
531 | if epoch == 0: first_loss = epoch_loss
532 | if epoch == 4 and epoch_loss / first_loss > 1:
533 | utils.print_both(txt_file, "\nLoss not converging, starting pretraining again\n")
534 | return False
535 |
536 | if board:
537 | writer.add_scalar('Pretraining/Loss' + '/Epoch', epoch_loss, epoch + 1)
538 |
539 | utils.print_both(txt_file, 'Pretraining:\t Loss: {:.4f}'.format(epoch_loss))
540 |
541 |         # Placeholder for a model-selection criterion; the condition is always true, so the latest weights are kept
542 | if epoch_loss < best_loss or epoch_loss >= best_loss:
543 | best_loss = epoch_loss
544 | best_model_wts = copy.deepcopy(model.state_dict())
545 |
546 | utils.print_both(txt_file, '')
547 |
548 | time_elapsed = time.time() - since
549 | utils.print_both(txt_file, 'Pretraining complete in {:.0f}m {:.0f}s'.format(
550 | time_elapsed // 60, time_elapsed % 60))
551 |
552 | # load best model weights
553 | model.load_state_dict(best_model_wts)
554 | model.pretrained = True
555 | torch.save(model.state_dict(), pretrained)
556 |
557 | return model
558 |
559 |
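# Note on the restart logic: pretraining returns False when the loss after five
# epochs is higher than the first-epoch loss, and the `while True` loops in
# train_model / train_semisupervised then reset every layer that exposes
# reset_parameters() and retry.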
560 | # K-means clusters initialisation
561 | def kmeans(model, dataloader, params):
562 | km = KMeans(n_clusters=model.num_clusters, n_init=20)
563 | output_array = None
564 | model.eval()
565 |     # Iterate through the data and concatenate the latent-space representations of the images
566 | for data in dataloader:
567 | inputs, _ = data
568 | inputs = inputs.to(params['device'])
569 | _, _, outputs = model(inputs)
570 | if output_array is not None:
571 | output_array = np.concatenate((output_array, outputs.cpu().detach().numpy()), 0)
572 | else:
573 | output_array = outputs.cpu().detach().numpy()
574 | # print(output_array.shape)
575 | if output_array.shape[0] > 50000: break
576 |
577 | # Perform K-means
578 | km.fit_predict(output_array)
579 | # Update clustering layer weights
580 | weights = torch.from_numpy(km.cluster_centers_)
581 | model.clustering.set_weight(weights.to(params['device']))
582 | # torch.cuda.empty_cache()
583 |
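# `model.clustering.set_weight` is defined in nets.py; a minimal sketch of the
# assumed interface (hypothetical, for illustration only) would be a layer that
# stores the centres as a (num_clusters, latent_dim) parameter:
#   class ClusteringLayer(torch.nn.Module):
#       def set_weight(self, tensor):
#           self.weight = torch.nn.Parameter(tensor)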
584 |
585 | def average_labelled_dist(model, dataloader, params):
586 | output_array = None
587 | label_array = None
588 | model.eval()
589 |     # Iterate through the data and concatenate the latent-space representations of the images
590 | for data in dataloader:
591 | inputs, labels = data
592 | inputs = inputs.to(params['device'])
593 | _, _, outputs = model(inputs)
594 | if output_array is not None:
595 | output_array = np.concatenate((output_array, outputs.cpu().detach().numpy()), 0)
596 | label_array = np.concatenate((label_array, labels.cpu().detach().numpy()), 0)
597 | else:
598 | output_array = outputs.cpu().detach().numpy()
599 | label_array = labels.cpu().detach().numpy()
600 |
601 | # Initialise weights
602 |     weights = np.zeros((model.num_clusters, model.num_clusters))  # assumes the latent dimension equals num_clusters
603 | num_probes = np.zeros((model.num_clusters, 1))
604 |
605 |     # Iterate through the latent-space descriptors and sum them per class (keeping per-class counts)
606 |     for j, row in enumerate(output_array):
607 |         label = label_array[j]
608 |         weights[label, :] += row
609 |         num_probes[label] += 1
610 |
611 | # Divide by the number of elements to get average
612 | for i in range(0, weights.shape[0]):
613 | weights[i, :] /= num_probes[i]
614 |
615 |     print(num_probes)  # per-class sample counts (debug output)
616 |
617 | # Update weights in network
618 | weights = weights.astype(np.float32)
619 | weights = torch.from_numpy(weights)
620 | model.clustering.set_weight(weights.to(params['device']))
621 | # torch.cuda.empty_cache()
622 |
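# The per-class averaging above is equivalent to this vectorised form
# (illustrative, assuming integer labels in [0, num_clusters)):
#   for k in range(model.num_clusters):
#       weights[k] = output_array[label_array == k].mean(axis=0)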
623 |
624 | # Forward data through the network, collect the clustering-layer output, and return distributions, labels, and predictions
625 | def calculate_predictions(model, dataloader, params):
626 | output_array = None
627 | label_array = None
628 | model.eval()
629 | for data in dataloader:
630 | inputs, labels = data
631 | inputs = inputs.to(params['device'])
632 | labels = labels.to(params['device'])
633 | _, outputs, _ = model(inputs)
634 | if output_array is not None:
635 | output_array = np.concatenate((output_array, outputs.cpu().detach().numpy()), 0)
636 | label_array = np.concatenate((label_array, labels.cpu().detach().numpy()), 0)
637 | else:
638 | output_array = outputs.cpu().detach().numpy()
639 | label_array = labels.cpu().detach().numpy()
640 |
641 |     preds = np.argmax(output_array, axis=1)
642 | # print(output_array.shape)
643 | return output_array, label_array, preds
644 |
645 |
646 | # Calculate target distribution
647 | def target(out_distr):
648 | tar_dist = out_distr ** 2 / np.sum(out_distr, axis=0)
649 | tar_dist = np.transpose(np.transpose(tar_dist) / np.sum(tar_dist, axis=1))
650 | return tar_dist
651 |
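# Worked example of target(): for soft assignments q = [[0.9, 0.1], [0.6, 0.4]]
# the cluster frequencies are f = [1.5, 0.5]; q**2 / f = [[0.54, 0.02], [0.24, 0.32]],
# and row-normalising gives p ~= [[0.964, 0.036], [0.429, 0.571]] -- confident
# assignments are sharpened while large clusters are down-weighted.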
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sklearn.metrics
3 |
4 |
5 | # Simple tensor-to-image conversion (first image of the batch; undoes ImageNet-style normalization)
6 | def tensor2img(tensor):
7 | img = tensor.cpu().data[0].numpy().transpose((1, 2, 0))
8 | mean = np.array([0.485, 0.456, 0.406])
9 | std = np.array([0.229, 0.224, 0.225])
10 | img = std * img + mean
11 | img = np.clip(img, 0, 1)
12 | img = img.transpose((2, 0, 1))
13 | return img
14 |
15 |
16 | # Define printing to console and file
17 | def print_both(f, text):
18 | print(text)
19 | f.write(text + '\n')
20 |
21 |
22 | # Metrics class adapted from the DCEC authors' repository (link in README)
23 | class metrics:
24 | nmi = sklearn.metrics.normalized_mutual_info_score
25 | ari = sklearn.metrics.adjusted_rand_score
26 |
27 | @staticmethod
28 | def acc(labels_true, labels_pred):
29 | labels_true = labels_true.astype(np.int64)
30 | assert labels_pred.size == labels_true.size
31 | D = max(labels_pred.max(), labels_true.max()) + 1
32 | w = np.zeros((D, D), dtype=np.int64)
33 | for i in range(labels_pred.size):
34 | w[labels_pred[i], labels_true[i]] += 1
35 |         from scipy.optimize import linear_sum_assignment  # sklearn.utils.linear_assignment_ was removed upstream
36 |         row_ind, col_ind = linear_sum_assignment(w.max() - w)
37 |         return w[row_ind, col_ind].sum() * 1.0 / labels_pred.size
38 |
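# Example: metrics.acc(np.array([0, 0, 1, 1]), np.array([1, 1, 0, 0])) == 1.0,
# since the Hungarian matching makes clustering accuracy invariant to a
# permutation of the cluster indices.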
--------------------------------------------------------------------------------