├── .gitignore ├── LICENSE.md ├── MNIST_results ├── MNIST_plain.jpg └── MNIST_trained.jpg ├── README.md ├── mnist.py ├── nets.py ├── semi_supervised.py ├── training_functions.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | 
#sublime 107 | sublime/ 108 | 109 | # images 110 | *.jpg 111 | *.png 112 | *.JPG 113 | *.jpeg 114 | *.JPEG 115 | !MNIST_plain.jpg 116 | !MNIST_trained.jpg 117 | 118 | # nets 119 | *.pt 120 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Michal Nazarczuk 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MNIST_results/MNIST_plain.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaal94/Semisupervised-Clustering/555da3d49a97e54807a3fe2aca0b7b5039d0bd34/MNIST_results/MNIST_plain.jpg -------------------------------------------------------------------------------- /MNIST_results/MNIST_trained.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaal94/Semisupervised-Clustering/555da3d49a97e54807a3fe2aca0b7b5039d0bd34/MNIST_results/MNIST_trained.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semisupervised Clustering 2 | 3 | This repository contains the code for semi-supervised clustering developed for the Master's thesis "Automatic analysis of images from camera-traps" by Michal Nazarczuk from Imperial College London. 4 | 5 | The algorithm is inspired by the DCEC method ([Deep Clustering with Convolutional Autoencoders](https://xifengguo.github.io/papers/ICONIP17-DCEC.pdf)). The main change adds a "labelling" loss (cross-entropy between labelled examples and their predictions) as an additional loss component. 6 | 7 | ## Prerequisites 8 | 9 | The following libraries must be installed to run the code: 10 | 11 | 1. PyTorch 12 | 2. NumPy 13 | 3. scikit-learn 14 | 4. 
[TensorboardX](https://github.com/lanpa/tensorboardX) 15 | 16 | The code was written and tested on Python 3.4.1 17 | 18 | ## Installation and usage 19 | 20 | ### Installation 21 | 22 | Clone the repository to a local folder: 23 | ``` 24 | git clone https://github.com/michaal94/Semisupervised-Clustering 25 | ``` 26 | 27 | ### Use of the algorithm 28 | 29 | To test the basic version of the semi-supervised clustering, run it with the Python distribution in which you installed the libraries (Anaconda, Virtualenv, etc.). In general, type: 30 | 31 | ``` 32 | cd Semisupervised-Clustering 33 | python3 semi_supervised.py 34 | ``` 35 | The example runs a sample clustering on the MNIST-train dataset. 36 | 37 | ## Options 38 | 39 | The algorithm offers plenty of options for adjustment: 40 | 1. Mode choice: full or pretraining only, use: 41 | ```--mode train_full``` or ```--mode pretrain``` 42 | 43 | For full training you can specify whether to run the pretraining phase ```--pretrain True``` or to use a saved network ```--pretrain False``` together with 44 | ```--pretrained net ("path" or idx)``` giving the path or index (see catalog structure) of the pretrained network 45 | 2. Dataset choice: 46 | + MNIST - train, test, full 47 | + Custom dataset - use the following data structure (characteristic for PyTorch): 48 | ``` 49 | -data_directory (clusters need to correspond to the real clustering only for computing statistics) 50 | -cluster_1 51 | -image_1 52 | -image_2 53 | -... 54 | -cluster_2 55 | -image_1 56 | -image_2 57 | -... 58 | -... 59 | -data_directory_l (data used as labelled, use at least one example in each class in the current version of the algorithm) 60 | -cluster_1 61 | -image_1 62 | -image_2 63 | -... 64 | -cluster_2 65 | -image_1 66 | -image_2 67 | -... 68 | -... 
69 | ``` 70 | Use the following: ```--dataset MNIST-train```, 71 | ```--dataset MNIST-test```, 72 | ```--dataset MNIST-full``` or 73 | ```--dataset custom``` (use the last one with the path 74 | ```--dataset_path 'path to your dataset'``` 75 | and the transformation you want for the images 76 | ```--custom_img_size [height, width, depth]```) 77 | 3. Different network architectures: 78 | + CAE 3 - convolutional autoencoder used in [DCEC](https://xifengguo.github.io/papers/ICONIP17-DCEC.pdf) ```--net_architecture CAE_3``` 79 | + CAE 3 BN - version with Batch Normalisation layers ```--net_architecture CAE_3bn``` 80 | + CAE 4 (BN) - convolutional autoencoder with 4 convolutional blocks ```--net_architecture CAE_4``` and ```--net_architecture CAE_4bn``` 81 | + CAE 5 (BN) - convolutional autoencoder with 5 convolutional blocks ```--net_architecture CAE_5``` and ```--net_architecture CAE_5bn``` (used for 128x128 photos) 82 | 83 | The following options may be used for model changes: 84 | + LeakyReLU or ReLU usage: ```--leaky True/False``` (True provided better results) 85 | + Negative slope for Leaky ReLU: ```--neg_slope value``` (Values around 0.01 were used) 86 | + Use of sigmoid and tanh activations at the end of encoder and decoder: ```--activations True/False``` (False provided better results) 87 | + Use of bias in layers: ```--bias True/False``` 88 | 4. 
Optimiser and scheduler settings (Adam optimiser): 89 | + Learning rate: ```--rate value``` (0.001 is a reasonable value for Adam) 90 | + Learning rate for pretraining phase: ```--rate_pretrain value``` (0.001 can be used as well) 91 | + Weight decay: ```--weight value``` (0 was used) 92 | + Weight decay for pretraining phase: ```--weight_pretrain value``` 93 | + Scheduler step (how many iterations till the rate is changed): ```--sched_step value``` 94 | + Scheduler step for pretraining phase: ```--sched_step_pretrain value``` 95 | + Scheduler gamma (multiplier of learning rate): ```--sched_gamma value``` 96 | + Scheduler gamma for pretraining phase: ```--sched_gamma_pretrain value``` 97 | 5. Algorithm specific parameters: 98 | + Clustering loss weight (the reconstruction loss has a fixed weight of 1): ```--gamma value``` (Value of 0.1 provided good results) 99 | + Labelling loss weight: ```--gamma_lab value``` (0.01 provided good results) 100 | + Update interval for target distribution (in number of batches between updates): ```--update_interval value``` (Value may be chosen such that the distribution is updated every 1000-2000 photos) 101 | + Label check interval ```--label_upd_interval value``` (Suggested to leave the update at each iteration) 102 | + Stop criterion tolerance ```--tol value``` (Depends on the dataset: 0.01 was used for small datasets, 0.001 for bigger ones, e.g. MNIST) 103 | + Target number of clusters ```--num_clusters value``` 104 | 6. 
Other options: 105 | + Batch size: ```--batch_size value``` (Depends on your device, but remember that [too large a batch may be bad for convergence](https://towardsdatascience.com/recent-advances-for-a-better-understanding-of-deep-learning-part-i-5ce34d1cc914)) 106 | + Epochs if the stop criterion is not met: ```--epochs value``` 107 | + Epochs of pretraining: ```--epochs_pretrain value``` (300 epochs were used, 200 with a learning rate of 0.001 and 100 with a rate 10 times smaller - ```--sched_step_pretrain 200```, ```--sched_gamma_pretrain 0.1```) 108 | + Report printing frequency (in batches): ```--printing_frequency value``` 109 | + Tensorboard export: ```--tensorboard True/False``` 110 | 111 | ## Catalog structure 112 | 113 | The code creates the following catalog structure when reporting the statistics: 114 | ``` 115 | -Reports 116 | -(net_architecture_name)_(index).txt 117 | -Nets (copies of weights) 118 | -(net_architecture_name)_(index).pt 119 | -(net_architecture_name)_(index)_pretrained.txt 120 | -Runs 121 | -(net_architecture_name)_(index) <- directory containing tensorboard event file 122 | ``` 123 | The files are indexed automatically so that they are not accidentally overwritten. 124 | 125 | ## Performance 126 | 127 | The code was mainly used to cluster images coming from camera-trap events. However, some additional benchmarks were performed on MNIST datasets. 
The following table gathers some results (with 2% of the data labelled): 128 | 129 | Set | NMI | Acc 130 | ---|---|--- 131 | MNIST-full | 95.13 | 98.22% 132 | MNIST-test | 89.59 | 95.29% 133 | 134 | In addition, the _t-SNE_ plots of the plain and clustered MNIST-full dataset are shown: 135 | 136 | Full set before clustering: 137 | 138 | MNIST full set t-SNE plot 139 | 140 | After clustering: 141 | 142 | MNIST full set after clustering t-SNE plot -------------------------------------------------------------------------------- /mnist.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.utils.data as data 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import errno 7 | import numpy as np 8 | import torch 9 | import codecs 10 | 11 | 12 | class MNIST(data.Dataset): 13 | """`MNIST `_ Dataset. 14 | 15 | Args: 16 | root (string): Root directory of dataset where ``processed/training.pt`` 17 | and ``processed/test.pt`` exist. 18 | train (bool, optional): If True, creates dataset from ``training.pt``, 19 | otherwise from ``test.pt``. 20 | download (bool, optional): If true, downloads the dataset from the internet and 21 | puts it in root directory. If dataset is already downloaded, it is not 22 | downloaded again. 23 | transform (callable, optional): A function/transform that takes in an PIL image 24 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 25 | target_transform (callable, optional): A function/transform that takes in the 26 | target and transforms it. 
27 | """ 28 | urls = [ 29 | 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', 30 | 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', 31 | 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', 32 | 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', 33 | ] 34 | raw_folder = 'raw' 35 | processed_folder = 'processed' 36 | training_file = 'training.pt' 37 | test_file = 'test.pt' 38 | 39 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False, small=False, full=False): 40 | self.root = os.path.expanduser(root) 41 | self.transform = transform 42 | self.target_transform = target_transform 43 | self.train = train # training set or test set 44 | self.full = full 45 | 46 | if full: 47 | self.train = True 48 | 49 | if download: 50 | self.download() 51 | 52 | if not self._check_exists(): 53 | raise RuntimeError('Dataset not found.' + 54 | ' You can use download=True to download it') 55 | 56 | self.train_data, self.train_labels = torch.load(os.path.join(self.root, self.processed_folder, self.training_file)) 57 | self.test_data, self.test_labels = torch.load(os.path.join(self.root, self.processed_folder, self.test_file)) 58 | 59 | if full: 60 | self.train_data = np.concatenate((self.train_data, self.test_data), axis=0) 61 | self.train_labels = np.concatenate((self.train_labels, self.test_labels), axis=0) 62 | 63 | if small: 64 | self.train_data = self.train_data[0:1400] 65 | self.train_labels = self.train_labels[0:1400] 66 | if not full: 67 | self.train_data = self.train_data[0:1200] 68 | self.train_labels = self.train_labels[0:1200] 69 | self.test_data = self.test_data[0:200] 70 | self.test_labels = self.test_labels[0:200] 71 | 72 | def __getitem__(self, index): 73 | """ 74 | Args: 75 | index (int): Index 76 | 77 | Returns: 78 | tuple: (image, target) where target is index of the target class. 
79 | """ 80 | if self.train: 81 | img, target = self.train_data[index], self.train_labels[index] 82 | else: 83 | img, target = self.test_data[index], self.test_labels[index] 84 | 85 | # doing this so that it is consistent with all other datasets 86 | # to return a PIL Image 87 | if self.full: 88 | img = Image.fromarray(img, mode='L') 89 | else: 90 | img = Image.fromarray(img.numpy(), mode='L') 91 | 92 | if self.transform is not None: 93 | img = self.transform(img) 94 | 95 | if self.target_transform is not None: 96 | target = self.target_transform(target) 97 | 98 | return img, target 99 | 100 | def __len__(self): 101 | if self.train: 102 | return len(self.train_data) 103 | else: 104 | return len(self.test_data) 105 | 106 | def _check_exists(self): 107 | return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \ 108 | os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file)) 109 | 110 | def download(self): 111 | """Download the MNIST data if it doesn't exist in processed_folder already.""" 112 | from six.moves import urllib 113 | import gzip 114 | 115 | if self._check_exists(): 116 | return 117 | 118 | # download files 119 | try: 120 | os.makedirs(os.path.join(self.root, self.raw_folder)) 121 | os.makedirs(os.path.join(self.root, self.processed_folder)) 122 | except OSError as e: 123 | if e.errno == errno.EEXIST: 124 | pass 125 | else: 126 | raise 127 | 128 | for url in self.urls: 129 | print('Downloading ' + url) 130 | data = urllib.request.urlopen(url) 131 | filename = url.rpartition('/')[2] 132 | file_path = os.path.join(self.root, self.raw_folder, filename) 133 | with open(file_path, 'wb') as f: 134 | f.write(data.read()) 135 | with open(file_path.replace('.gz', ''), 'wb') as out_f, \ 136 | gzip.GzipFile(file_path) as zip_f: 137 | out_f.write(zip_f.read()) 138 | os.unlink(file_path) 139 | 140 | # process and save as torch files 141 | print('Processing...') 142 | 143 | training_set = ( 144 | 
read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')), 145 | read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte')) 146 | ) 147 | test_set = ( 148 | read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')), 149 | read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte')) 150 | ) 151 | with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f: 152 | torch.save(training_set, f) 153 | with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f: 154 | torch.save(test_set, f) 155 | 156 | print('Done!') 157 | 158 | def __repr__(self): 159 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 160 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 161 | tmp = 'train' if self.train is True else 'test' 162 | fmt_str += ' Split: {}\n'.format(tmp) 163 | fmt_str += ' Root Location: {}\n'.format(self.root) 164 | tmp = ' Transforms (if any): ' 165 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 166 | tmp = ' Target Transforms (if any): ' 167 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 168 | return fmt_str 169 | 170 | 171 | class FashionMNIST(MNIST): 172 | """`Fashion-MNIST `_ Dataset. 173 | 174 | Args: 175 | root (string): Root directory of dataset where ``processed/training.pt`` 176 | and ``processed/test.pt`` exist. 177 | train (bool, optional): If True, creates dataset from ``training.pt``, 178 | otherwise from ``test.pt``. 179 | download (bool, optional): If true, downloads the dataset from the internet and 180 | puts it in root directory. If dataset is already downloaded, it is not 181 | downloaded again. 182 | transform (callable, optional): A function/transform that takes in an PIL image 183 | and returns a transformed version. 
E.g, ``transforms.RandomCrop`` 184 | target_transform (callable, optional): A function/transform that takes in the 185 | target and transforms it. 186 | """ 187 | urls = [ 188 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz', 189 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz', 190 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz', 191 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz', 192 | ] 193 | 194 | 195 | class EMNIST(MNIST): 196 | """`EMNIST `_ Dataset. 197 | 198 | Args: 199 | root (string): Root directory of dataset where ``processed/training.pt`` 200 | and ``processed/test.pt`` exist. 201 | split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``, 202 | ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies 203 | which one to use. 204 | train (bool, optional): If True, creates dataset from ``training.pt``, 205 | otherwise from ``test.pt``. 206 | download (bool, optional): If true, downloads the dataset from the internet and 207 | puts it in root directory. If dataset is already downloaded, it is not 208 | downloaded again. 209 | transform (callable, optional): A function/transform that takes in an PIL image 210 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 211 | target_transform (callable, optional): A function/transform that takes in the 212 | target and transforms it. 213 | """ 214 | url = 'http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip' 215 | splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist') 216 | 217 | def __init__(self, root, split, **kwargs): 218 | if split not in self.splits: 219 | raise ValueError('Split "{}" not found. 
Valid splits are: {}'.format( 220 | split, ', '.join(self.splits), 221 | )) 222 | self.split = split 223 | self.training_file = self._training_file(split) 224 | self.test_file = self._test_file(split) 225 | super(EMNIST, self).__init__(root, **kwargs) 226 | 227 | def _training_file(self, split): 228 | return 'training_{}.pt'.format(split) 229 | 230 | def _test_file(self, split): 231 | return 'test_{}.pt'.format(split) 232 | 233 | def download(self): 234 | """Download the EMNIST data if it doesn't exist in processed_folder already.""" 235 | from six.moves import urllib 236 | import gzip 237 | import shutil 238 | import zipfile 239 | 240 | if self._check_exists(): 241 | return 242 | 243 | # download files 244 | try: 245 | os.makedirs(os.path.join(self.root, self.raw_folder)) 246 | os.makedirs(os.path.join(self.root, self.processed_folder)) 247 | except OSError as e: 248 | if e.errno == errno.EEXIST: 249 | pass 250 | else: 251 | raise 252 | 253 | print('Downloading ' + self.url) 254 | data = urllib.request.urlopen(self.url) 255 | filename = self.url.rpartition('/')[2] 256 | raw_folder = os.path.join(self.root, self.raw_folder) 257 | file_path = os.path.join(raw_folder, filename) 258 | with open(file_path, 'wb') as f: 259 | f.write(data.read()) 260 | 261 | print('Extracting zip archive') 262 | with zipfile.ZipFile(file_path) as zip_f: 263 | zip_f.extractall(raw_folder) 264 | os.unlink(file_path) 265 | gzip_folder = os.path.join(raw_folder, 'gzip') 266 | for gzip_file in os.listdir(gzip_folder): 267 | if gzip_file.endswith('.gz'): 268 | print('Extracting ' + gzip_file) 269 | with open(os.path.join(raw_folder, gzip_file.replace('.gz', '')), 'wb') as out_f, \ 270 | gzip.GzipFile(os.path.join(gzip_folder, gzip_file)) as zip_f: 271 | out_f.write(zip_f.read()) 272 | shutil.rmtree(gzip_folder) 273 | 274 | # process and save as torch files 275 | for split in self.splits: 276 | print('Processing ' + split) 277 | training_set = ( 278 | read_image_file(os.path.join(raw_folder, 
'emnist-{}-train-images-idx3-ubyte'.format(split))), 279 | read_label_file(os.path.join(raw_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split))) 280 | ) 281 | test_set = ( 282 | read_image_file(os.path.join(raw_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))), 283 | read_label_file(os.path.join(raw_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split))) 284 | ) 285 | with open(os.path.join(self.root, self.processed_folder, self._training_file(split)), 'wb') as f: 286 | torch.save(training_set, f) 287 | with open(os.path.join(self.root, self.processed_folder, self._test_file(split)), 'wb') as f: 288 | torch.save(test_set, f) 289 | 290 | print('Done!') 291 | 292 | 293 | def get_int(b): 294 | return int(codecs.encode(b, 'hex'), 16) 295 | 296 | 297 | def read_label_file(path): 298 | with open(path, 'rb') as f: 299 | data = f.read() 300 | assert get_int(data[:4]) == 2049 301 | length = get_int(data[4:8]) 302 | parsed = np.frombuffer(data, dtype=np.uint8, offset=8) 303 | return torch.from_numpy(parsed).view(length).long() 304 | 305 | 306 | def read_image_file(path): 307 | with open(path, 'rb') as f: 308 | data = f.read() 309 | assert get_int(data[:4]) == 2051 310 | length = get_int(data[4:8]) 311 | num_rows = get_int(data[8:12]) 312 | num_cols = get_int(data[12:16]) 313 | images = [] 314 | parsed = np.frombuffer(data, dtype=np.uint8, offset=16) 315 | return torch.from_numpy(parsed).view(length, num_rows, num_cols) 316 | -------------------------------------------------------------------------------- /nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import copy 4 | 5 | # Clustering layer definition (see DCEC article for equations) 6 | class ClusterlingLayer(nn.Module): 7 | def __init__(self, in_features=10, out_features=10, alpha=1.0): 8 | super(ClusterlingLayer, self).__init__() 9 | self.in_features = in_features 10 | self.out_features = out_features 11 | 
self.alpha = alpha 12 | self.weight = nn.Parameter(torch.Tensor(self.out_features, self.in_features)) 13 | self.weight = nn.init.xavier_uniform_(self.weight) 14 | 15 | def forward(self, x): 16 | x = x.unsqueeze(1) - self.weight 17 | x = torch.mul(x, x) 18 | x = torch.sum(x, dim=2) 19 | x = 1.0 + (x / self.alpha) 20 | x = 1.0 / x 21 | x = x ** ((self.alpha +1.0) / 2.0) 22 | x = torch.t(x) / torch.sum(x, dim=1) 23 | x = torch.t(x) 24 | return x 25 | 26 | def extra_repr(self): 27 | return 'in_features={}, out_features={}, alpha={}'.format( 28 | self.in_features, self.out_features, self.alpha 29 | ) 30 | 31 | def set_weight(self, tensor): 32 | self.weight = nn.Parameter(tensor) 33 | 34 | 35 | # Convolutional autoencoder directly from DCEC article 36 | class CAE_3(nn.Module): 37 | def __init__(self, input_shape=[128,128,3], num_clusters=10, filters=[32, 64, 128], leaky=True, neg_slope=0.01, activations=False, bias=True): 38 | super(CAE_3, self).__init__() 39 | self.activations = activations 40 | # bias = True 41 | self.pretrained = False 42 | self.num_clusters = num_clusters 43 | self.input_shape = input_shape 44 | self.filters = filters 45 | self.conv1 = nn.Conv2d(input_shape[2], filters[0], 5, stride=2, padding=2, bias=bias) 46 | if leaky: 47 | self.relu = nn.LeakyReLU(negative_slope=neg_slope) 48 | else: 49 | self.relu = nn.ReLU(inplace=False) 50 | self.conv2 = nn.Conv2d(filters[0], filters[1], 5, stride=2, padding=2, bias=bias) 51 | self.conv3 = nn.Conv2d(filters[1], filters[2], 3, stride=2, padding=0, bias=bias) 52 | lin_features_len = ((input_shape[0]//2//2-1) // 2) * ((input_shape[0]//2//2-1) // 2) * filters[2] 53 | self.embedding = nn.Linear(lin_features_len, num_clusters, bias=bias) 54 | self.deembedding = nn.Linear(num_clusters, lin_features_len, bias=bias) 55 | out_pad = 1 if input_shape[0] // 2 // 2 % 2 == 0 else 0 56 | self.deconv3 = nn.ConvTranspose2d(filters[2], filters[1], 3, stride=2, padding=0, output_padding=out_pad, bias=bias) 57 | out_pad = 1 if 
input_shape[0] // 2 % 2 == 0 else 0 58 | self.deconv2 = nn.ConvTranspose2d(filters[1], filters[0], 5, stride=2, padding=2, output_padding=out_pad, bias=bias) 59 | out_pad = 1 if input_shape[0] % 2 == 0 else 0 60 | self.deconv1 = nn.ConvTranspose2d(filters[0], input_shape[2], 5, stride=2, padding=2, output_padding=out_pad, bias=bias) 61 | self.clustering = ClusterlingLayer(num_clusters, num_clusters) 62 | # ReLU copies for graph representation in tensorboard 63 | self.relu1_1 = copy.deepcopy(self.relu) 64 | self.relu2_1 = copy.deepcopy(self.relu) 65 | self.relu3_1 = copy.deepcopy(self.relu) 66 | self.relu1_2 = copy.deepcopy(self.relu) 67 | self.relu2_2 = copy.deepcopy(self.relu) 68 | self.relu3_2 = copy.deepcopy(self.relu) 69 | self.sig = nn.Sigmoid() 70 | self.tanh = nn.Tanh() 71 | 72 | def forward(self, x): 73 | x = self.conv1(x) 74 | x = self.relu1_1(x) 75 | x = self.conv2(x) 76 | x = self.relu2_1(x) 77 | x = self.conv3(x) 78 | if self.activations: 79 | x = self.sig(x) 80 | else: 81 | x = self.relu3_1(x) 82 | x = x.view(x.size(0), -1) 83 | x = self.embedding(x) 84 | extra_out = x 85 | clustering_out = self.clustering(x) 86 | x = self.deembedding(x) 87 | x = self.relu1_2(x) 88 | x = x.view(x.size(0), self.filters[2], ((self.input_shape[0]//2//2-1) // 2), ((self.input_shape[0]//2//2-1) // 2)) 89 | x = self.deconv3(x) 90 | x = self.relu2_2(x) 91 | x = self.deconv2(x) 92 | x = self.relu3_2(x) 93 | x = self.deconv1(x) 94 | if self.activations: 95 | x = self.tanh(x) 96 | return x, clustering_out, extra_out 97 | 98 | 99 | # Convolutional autoencoder from DCEC article with Batch Norms and Leaky ReLUs 100 | class CAE_bn3(nn.Module): 101 | def __init__(self, input_shape=[128,128,3], num_clusters=10, filters=[32, 64, 128], leaky=True, neg_slope=0.01, activations=False, bias=True): 102 | super(CAE_bn3, self).__init__() 103 | self.activations=activations 104 | self.pretrained = False 105 | self.num_clusters = num_clusters 106 | self.input_shape = input_shape 107 | 
self.filters = filters 108 | self.conv1 = nn.Conv2d(input_shape[2], filters[0], 5, stride=2, padding=2, bias=bias) 109 | self.bn1_1 = nn.BatchNorm2d(filters[0]) 110 | if leaky: 111 | self.relu = nn.LeakyReLU(negative_slope=neg_slope) 112 | else: 113 | self.relu = nn.ReLU(inplace=False) 114 | self.conv2 = nn.Conv2d(filters[0], filters[1], 5, stride=2, padding=2, bias=bias) 115 | self.bn2_1 = nn.BatchNorm2d(filters[1]) 116 | self.conv3 = nn.Conv2d(filters[1], filters[2], 3, stride=2, padding=0, bias=bias) 117 | lin_features_len = ((input_shape[0]//2//2-1) // 2) * ((input_shape[0]//2//2-1) // 2) * filters[2] 118 | self.embedding = nn.Linear(lin_features_len, num_clusters, bias=bias) 119 | self.deembedding = nn.Linear(num_clusters, lin_features_len, bias=bias) 120 | out_pad = 1 if input_shape[0] // 2 // 2 % 2 == 0 else 0 121 | self.deconv3 = nn.ConvTranspose2d(filters[2], filters[1], 3, stride=2, padding=0, output_padding=out_pad, bias=bias) 122 | out_pad = 1 if input_shape[0] // 2 % 2 == 0 else 0 123 | self.bn3_2 = nn.BatchNorm2d(filters[1]) 124 | self.deconv2 = nn.ConvTranspose2d(filters[1], filters[0], 5, stride=2, padding=2, output_padding=out_pad, bias=bias) 125 | out_pad = 1 if input_shape[0] % 2 == 0 else 0 126 | self.bn2_2 = nn.BatchNorm2d(filters[0]) 127 | self.deconv1 = nn.ConvTranspose2d(filters[0], input_shape[2], 5, stride=2, padding=2, output_padding=out_pad, bias=bias) 128 | self.clustering = ClusterlingLayer(num_clusters, num_clusters) 129 | # ReLU copies for graph representation in tensorboard 130 | self.relu1_1 = copy.deepcopy(self.relu) 131 | self.relu2_1 = copy.deepcopy(self.relu) 132 | self.relu3_1 = copy.deepcopy(self.relu) 133 | self.relu1_2 = copy.deepcopy(self.relu) 134 | self.relu2_2 = copy.deepcopy(self.relu) 135 | self.relu3_2 = copy.deepcopy(self.relu) 136 | self.sig = nn.Sigmoid() 137 | self.tanh = nn.Tanh() 138 | 139 | def forward(self, x): 140 | x = self.conv1(x) 141 | x = self.relu1_1(x) 142 | x = self.bn1_1(x) 143 | x = self.conv2(x) 
144 | x = self.relu2_1(x) 145 | x = self.bn2_1(x) 146 | x = self.conv3(x) 147 | if self.activations: 148 | x = self.sig(x) 149 | else: 150 | x = self.relu3_1(x) 151 | x = x.view(x.size(0), -1) 152 | x = self.embedding(x) 153 | extra_out = x 154 | clustering_out = self.clustering(x) 155 | x = self.deembedding(x) 156 | x = self.relu1_2(x) 157 | x = x.view(x.size(0), self.filters[2], ((self.input_shape[0]//2//2-1) // 2), ((self.input_shape[0]//2//2-1) // 2)) 158 | x = self.deconv3(x) 159 | x = self.relu2_2(x) 160 | x = self.bn3_2(x) 161 | x = self.deconv2(x) 162 | x = self.relu3_2(x) 163 | x = self.bn2_2(x) 164 | x = self.deconv1(x) 165 | if self.activations: 166 | x = self.tanh(x) 167 | return x, clustering_out, extra_out 168 | 169 | 170 | # Convolutional autoencoder with 4 convolutional blocks 171 | class CAE_4(nn.Module): 172 | def __init__(self, input_shape=[128,128,3], num_clusters=10, filters=[32, 64, 128, 256], leaky=True, neg_slope=0.01, activations=False, bias=True): 173 | super(CAE_4, self).__init__() 174 | self.activations = activations 175 | self.pretrained = False 176 | self.num_clusters = num_clusters 177 | self.input_shape = input_shape 178 | self.filters = filters 179 | if leaky: 180 | self.relu = nn.LeakyReLU(negative_slope=neg_slope) 181 | else: 182 | self.relu = nn.ReLU(inplace=False) 183 | 184 | self.conv1 = nn.Conv2d(input_shape[2], filters[0], 5, stride=2, padding=2, bias=bias) 185 | self.conv2 = nn.Conv2d(filters[0], filters[1], 5, stride=2, padding=2, bias=bias) 186 | self.conv3 = nn.Conv2d(filters[1], filters[2], 5, stride=2, padding=2, bias=bias) 187 | self.conv4 = nn.Conv2d(filters[2], filters[3], 3, stride=2, padding=0, bias=bias) 188 | 189 | lin_features_len = ((input_shape[0] // 2 // 2 // 2 - 1) // 2) * ((input_shape[0] // 2 // 2 // 2 - 1) // 2) * \ 190 | filters[3] 191 | self.embedding = nn.Linear(lin_features_len, num_clusters, bias=bias) 192 | self.deembedding = nn.Linear(num_clusters, lin_features_len, bias=bias) 193 | out_pad = 1 if 
input_shape[0] // 2 // 2 // 2 % 2 == 0 else 0 194 | self.deconv4 = nn.ConvTranspose2d(filters[3], filters[2], 3, stride=2, padding=0, output_padding=out_pad, 195 | bias=bias) 196 | out_pad = 1 if input_shape[0] // 2 // 2 % 2 == 0 else 0 197 | self.deconv3 = nn.ConvTranspose2d(filters[2], filters[1], 5, stride=2, padding=2, output_padding=out_pad, 198 | bias=bias) 199 | out_pad = 1 if input_shape[0] // 2 % 2 == 0 else 0 200 | self.deconv2 = nn.ConvTranspose2d(filters[1], filters[0], 5, stride=2, padding=2, output_padding=out_pad, 201 | bias=bias) 202 | out_pad = 1 if input_shape[0] % 2 == 0 else 0 203 | self.deconv1 = nn.ConvTranspose2d(filters[0], input_shape[2], 5, stride=2, padding=2, output_padding=out_pad, 204 | bias=bias) 205 | self.clustering = ClusterlingLayer(num_clusters, num_clusters) 206 | # ReLU copies for graph representation in tensorboard 207 | self.relu1_1 = copy.deepcopy(self.relu) 208 | self.relu2_1 = copy.deepcopy(self.relu) 209 | self.relu3_1 = copy.deepcopy(self.relu) 210 | self.relu4_1 = copy.deepcopy(self.relu) 211 | self.relu1_2 = copy.deepcopy(self.relu) 212 | self.relu2_2 = copy.deepcopy(self.relu) 213 | self.relu3_2 = copy.deepcopy(self.relu) 214 | self.relu4_2 = copy.deepcopy(self.relu) 215 | self.sig = nn.Sigmoid() 216 | self.tanh = nn.Tanh() 217 | 218 | def forward(self, x): 219 | x = self.conv1(x) 220 | x = self.relu1_1(x) 221 | x = self.conv2(x) 222 | x = self.relu2_1(x) 223 | x = self.conv3(x) 224 | x = self.relu3_1(x) 225 | x = self.conv4(x) 226 | if self.activations: 227 | x = self.sig(x) 228 | else: 229 | x = self.relu4_1(x) 230 | x = x.view(x.size(0), -1) 231 | x = self.embedding(x) 232 | extra_out = x 233 | clustering_out = self.clustering(x) 234 | x = self.deembedding(x) 235 | x = self.relu4_2(x) 236 | x = x.view(x.size(0), self.filters[3], ((self.input_shape[0]//2//2//2-1) // 2), ((self.input_shape[0]//2//2//2-1) // 2)) 237 | x = self.deconv4(x) 238 | x = self.relu3_2(x) 239 | x = self.deconv3(x) 240 | x = self.relu2_2(x) 241 
| x = self.deconv2(x) 242 | x = self.relu1_2(x) 243 | x = self.deconv1(x) 244 | if self.activations: 245 | x = self.tanh(x) 246 | return x, clustering_out, extra_out 247 | 248 | # Convolutional autoencoder with 4 convolutional blocks (BN version) 249 | class CAE_bn4(nn.Module): 250 | def __init__(self, input_shape=[128,128,3], num_clusters=10, filters=[32, 64, 128, 256], leaky=True, neg_slope=0.01, activations=False, bias=True): 251 | super(CAE_bn4, self).__init__() 252 | self.activations = activations 253 | self.pretrained = False 254 | self.num_clusters = num_clusters 255 | self.input_shape = input_shape 256 | self.filters = filters 257 | if leaky: 258 | self.relu = nn.LeakyReLU(negative_slope=neg_slope) 259 | else: 260 | self.relu = nn.ReLU(inplace=False) 261 | 262 | self.conv1 = nn.Conv2d(input_shape[2], filters[0], 5, stride=2, padding=2, bias=bias) 263 | self.bn1_1 = nn.BatchNorm2d(filters[0]) 264 | self.conv2 = nn.Conv2d(filters[0], filters[1], 5, stride=2, padding=2, bias=bias) 265 | self.bn2_1 = nn.BatchNorm2d(filters[1]) 266 | self.conv3 = nn.Conv2d(filters[1], filters[2], 5, stride=2, padding=2, bias=bias) 267 | self.bn3_1 = nn.BatchNorm2d(filters[2]) 268 | self.conv4 = nn.Conv2d(filters[2], filters[3], 3, stride=2, padding=0, bias=bias) 269 | 270 | lin_features_len = ((input_shape[0] // 2 // 2 // 2 - 1) // 2) * ((input_shape[0] // 2 // 2 // 2 - 1) // 2) * \ 271 | filters[3] 272 | self.embedding = nn.Linear(lin_features_len, num_clusters, bias=bias) 273 | self.deembedding = nn.Linear(num_clusters, lin_features_len, bias=bias) 274 | out_pad = 1 if input_shape[0] // 2 // 2 // 2 % 2 == 0 else 0 275 | self.deconv4 = nn.ConvTranspose2d(filters[3], filters[2], 3, stride=2, padding=0, output_padding=out_pad, 276 | bias=bias) 277 | self.bn4_2 = nn.BatchNorm2d(filters[2]) 278 | out_pad = 1 if input_shape[0] // 2 // 2 % 2 == 0 else 0 279 | self.deconv3 = nn.ConvTranspose2d(filters[2], filters[1], 5, stride=2, padding=2, output_padding=out_pad, 280 | bias=bias) 281 
| self.bn3_2 = nn.BatchNorm2d(filters[1]) 282 | out_pad = 1 if input_shape[0] // 2 % 2 == 0 else 0 283 | self.deconv2 = nn.ConvTranspose2d(filters[1], filters[0], 5, stride=2, padding=2, output_padding=out_pad, 284 | bias=bias) 285 | self.bn2_2 = nn.BatchNorm2d(filters[0]) 286 | out_pad = 1 if input_shape[0] % 2 == 0 else 0 287 | self.deconv1 = nn.ConvTranspose2d(filters[0], input_shape[2], 5, stride=2, padding=2, output_padding=out_pad, 288 | bias=bias) 289 | self.clustering = ClusterlingLayer(num_clusters, num_clusters) 290 | # ReLU copies for graph representation in tensorboard 291 | self.relu1_1 = copy.deepcopy(self.relu) 292 | self.relu2_1 = copy.deepcopy(self.relu) 293 | self.relu3_1 = copy.deepcopy(self.relu) 294 | self.relu4_1 = copy.deepcopy(self.relu) 295 | self.relu1_2 = copy.deepcopy(self.relu) 296 | self.relu2_2 = copy.deepcopy(self.relu) 297 | self.relu3_2 = copy.deepcopy(self.relu) 298 | self.relu4_2 = copy.deepcopy(self.relu) 299 | self.sig = nn.Sigmoid() 300 | self.tanh = nn.Tanh() 301 | 302 | def forward(self, x): 303 | x = self.conv1(x) 304 | x = self.relu1_1(x) 305 | x = self.bn1_1(x) 306 | x = self.conv2(x) 307 | x = self.relu2_1(x) 308 | x = self.bn2_1(x) 309 | x = self.conv3(x) 310 | x = self.relu3_1(x) 311 | x = self.bn3_1(x) 312 | x = self.conv4(x) 313 | if self.activations: 314 | x = self.sig(x) 315 | else: 316 | x = self.relu4_1(x) 317 | x = x.view(x.size(0), -1) 318 | x = self.embedding(x) 319 | extra_out = x 320 | clustering_out = self.clustering(x) 321 | x = self.deembedding(x) 322 | x = self.relu4_2(x) 323 | x = x.view(x.size(0), self.filters[3], ((self.input_shape[0]//2//2//2-1) // 2), ((self.input_shape[0]//2//2//2-1) // 2)) 324 | x = self.deconv4(x) 325 | x = self.relu3_2(x) 326 | x = self.bn4_2(x) 327 | x = self.deconv3(x) 328 | x = self.relu2_2(x) 329 | x = self.bn3_2(x) 330 | x = self.deconv2(x) 331 | x = self.relu1_2(x) 332 | x = self.bn2_2(x) 333 | x = self.deconv1(x) 334 | if self.activations: 335 | x = self.tanh(x) 336 | 
return x, clustering_out, extra_out 337 | 338 | 339 | # Convolutional autoencoder with 5 convolutional blocks 340 | class CAE_5(nn.Module): 341 | def __init__(self, input_shape=[128,128,3], num_clusters=10, filters=[32, 64, 128, 256, 512], leaky=True, neg_slope=0.01, activations=False, bias=True): 342 | super(CAE_5, self).__init__() 343 | self.activations = activations 344 | self.pretrained = False 345 | self.num_clusters = num_clusters 346 | self.input_shape = input_shape 347 | self.filters = filters 348 | self.relu = nn.ReLU(inplace=False) 349 | if leaky: 350 | self.relu = nn.LeakyReLU(negative_slope=neg_slope) 351 | else: 352 | self.relu = nn.ReLU(inplace=False) 353 | 354 | self.conv1 = nn.Conv2d(input_shape[2], filters[0], 5, stride=2, padding=2, bias=bias) 355 | self.conv2 = nn.Conv2d(filters[0], filters[1], 5, stride=2, padding=2, bias=bias) 356 | self.conv3 = nn.Conv2d(filters[1], filters[2], 5, stride=2, padding=2, bias=bias) 357 | self.conv4 = nn.Conv2d(filters[2], filters[3], 5, stride=2, padding=2, bias=bias) 358 | self.conv5 = nn.Conv2d(filters[3], filters[4], 3, stride=2, padding=0, bias=bias) 359 | 360 | lin_features_len = ((input_shape[0] // 2 // 2 // 2 // 2 - 1) // 2) * ( 361 | (input_shape[0] // 2 // 2 // 2 // 2 - 1) // 2) * filters[4] 362 | self.embedding = nn.Linear(lin_features_len, num_clusters, bias=bias) 363 | self.deembedding = nn.Linear(num_clusters, lin_features_len, bias=bias) 364 | out_pad = 1 if input_shape[0] // 2 // 2 // 2 // 2 % 2 == 0 else 0 365 | self.deconv5 = nn.ConvTranspose2d(filters[4], filters[3], 3, stride=2, padding=0, output_padding=out_pad, 366 | bias=bias) 367 | out_pad = 1 if input_shape[0] // 2 // 2 // 2 % 2 == 0 else 0 368 | self.deconv4 = nn.ConvTranspose2d(filters[3], filters[2], 5, stride=2, padding=2, output_padding=out_pad, 369 | bias=bias) 370 | out_pad = 1 if input_shape[0] // 2 // 2 % 2 == 0 else 0 371 | self.deconv3 = nn.ConvTranspose2d(filters[2], filters[1], 5, stride=2, padding=2, output_padding=out_pad, 
372 | bias=bias) 373 | out_pad = 1 if input_shape[0] // 2 % 2 == 0 else 0 374 | self.deconv2 = nn.ConvTranspose2d(filters[1], filters[0], 5, stride=2, padding=2, output_padding=out_pad, 375 | bias=bias) 376 | out_pad = 1 if input_shape[0] % 2 == 0 else 0 377 | self.deconv1 = nn.ConvTranspose2d(filters[0], input_shape[2], 5, stride=2, padding=2, output_padding=out_pad, 378 | bias=bias) 379 | self.clustering = ClusterlingLayer(num_clusters, num_clusters) 380 | # ReLU copies for graph representation in tensorboard 381 | self.relu1_1 = copy.deepcopy(self.relu) 382 | self.relu2_1 = copy.deepcopy(self.relu) 383 | self.relu3_1 = copy.deepcopy(self.relu) 384 | self.relu4_1 = copy.deepcopy(self.relu) 385 | self.relu5_1 = copy.deepcopy(self.relu) 386 | self.relu1_2 = copy.deepcopy(self.relu) 387 | self.relu2_2 = copy.deepcopy(self.relu) 388 | self.relu3_2 = copy.deepcopy(self.relu) 389 | self.relu4_2 = copy.deepcopy(self.relu) 390 | self.relu5_2 = copy.deepcopy(self.relu) 391 | self.sig = nn.Sigmoid() 392 | self.tanh = nn.Tanh() 393 | 394 | def forward(self, x): 395 | x = self.conv1(x) 396 | x = self.relu1_1(x) 397 | x = self.conv2(x) 398 | x = self.relu2_1(x) 399 | x = self.conv3(x) 400 | x = self.relu3_1(x) 401 | x = self.conv4(x) 402 | x = self.relu4_1(x) 403 | x = self.conv5(x) 404 | if self.activations: 405 | x = self.sig(x) 406 | else: 407 | x = self.relu5_1(x) 408 | x = x.view(x.size(0), -1) 409 | x = self.embedding(x) 410 | extra_out = x 411 | clustering_out = self.clustering(x) 412 | x = self.deembedding(x) 413 | x = self.relu5_2(x) 414 | x = x.view(x.size(0), self.filters[4], ((self.input_shape[0]//2//2//2//2-1) // 2), ((self.input_shape[0]//2//2//2//2-1) // 2)) 415 | x = self.deconv5(x) 416 | x = self.relu4_2(x) 417 | x = self.deconv4(x) 418 | x = self.relu3_2(x) 419 | x = self.deconv3(x) 420 | x = self.relu2_2(x) 421 | x = self.deconv2(x) 422 | x = self.relu1_2(x) 423 | x = self.deconv1(x) 424 | if self.activations: 425 | x = self.tanh(x) 426 | return x,
clustering_out, extra_out 427 | 428 | 429 | # Convolutional autoencoder with 5 convolutional blocks (BN version) 430 | class CAE_bn5(nn.Module): 431 | def __init__(self, input_shape=[128,128,3], num_clusters=10, filters=[32, 64, 128, 256, 512], leaky=True, neg_slope=0.01, activations=False, bias=True): 432 | super(CAE_bn5, self).__init__() 433 | self.activations = activations 434 | self.pretrained = False 435 | self.num_clusters = num_clusters 436 | self.input_shape = input_shape 437 | self.filters = filters 438 | self.relu = nn.ReLU(inplace=False) 439 | if leaky: 440 | self.relu = nn.LeakyReLU(negative_slope=neg_slope) 441 | else: 442 | self.relu = nn.ReLU(inplace=False) 443 | 444 | self.conv1 = nn.Conv2d(input_shape[2], filters[0], 5, stride=2, padding=2, bias=bias) 445 | self.bn1_1 = nn.BatchNorm2d(filters[0]) 446 | self.conv2 = nn.Conv2d(filters[0], filters[1], 5, stride=2, padding=2, bias=bias) 447 | self.bn2_1 = nn.BatchNorm2d(filters[1]) 448 | self.conv3 = nn.Conv2d(filters[1], filters[2], 5, stride=2, padding=2, bias=bias) 449 | self.bn3_1 = nn.BatchNorm2d(filters[2]) 450 | self.conv4 = nn.Conv2d(filters[2], filters[3], 5, stride=2, padding=2, bias=bias) 451 | self.bn4_1 = nn.BatchNorm2d(filters[3]) 452 | self.conv5 = nn.Conv2d(filters[3], filters[4], 3, stride=2, padding=0, bias=bias) 453 | 454 | lin_features_len = ((input_shape[0] // 2 // 2 // 2 // 2 - 1) // 2) * ( 455 | (input_shape[0] // 2 // 2 // 2 // 2 - 1) // 2) * filters[4] 456 | self.embedding = nn.Linear(lin_features_len, num_clusters, bias=bias) 457 | self.deembedding = nn.Linear(num_clusters, lin_features_len, bias=bias) 458 | out_pad = 1 if input_shape[0] // 2 // 2 // 2 // 2 % 2 == 0 else 0 459 | self.deconv5 = nn.ConvTranspose2d(filters[4], filters[3], 3, stride=2, padding=0, output_padding=out_pad, 460 | bias=bias) 461 | self.bn5_2 = nn.BatchNorm2d(filters[3]) 462 | out_pad = 1 if input_shape[0] // 2 // 2 // 2 % 2 == 0 else 0 463 | self.deconv4 = nn.ConvTranspose2d(filters[3], filters[2], 5, 
stride=2, padding=2, output_padding=out_pad, 464 | bias=bias) 465 | self.bn4_2 = nn.BatchNorm2d(filters[2]) 466 | out_pad = 1 if input_shape[0] // 2 // 2 % 2 == 0 else 0 467 | self.deconv3 = nn.ConvTranspose2d(filters[2], filters[1], 5, stride=2, padding=2, output_padding=out_pad, 468 | bias=bias) 469 | self.bn3_2 = nn.BatchNorm2d(filters[1]) 470 | out_pad = 1 if input_shape[0] // 2 % 2 == 0 else 0 471 | self.deconv2 = nn.ConvTranspose2d(filters[1], filters[0], 5, stride=2, padding=2, output_padding=out_pad, 472 | bias=bias) 473 | self.bn2_2 = nn.BatchNorm2d(filters[0]) 474 | out_pad = 1 if input_shape[0] % 2 == 0 else 0 475 | self.deconv1 = nn.ConvTranspose2d(filters[0], input_shape[2], 5, stride=2, padding=2, output_padding=out_pad, 476 | bias=bias) 477 | self.clustering = ClusterlingLayer(num_clusters, num_clusters) 478 | # ReLU copies for graph representation in tensorboard 479 | self.relu1_1 = copy.deepcopy(self.relu) 480 | self.relu2_1 = copy.deepcopy(self.relu) 481 | self.relu3_1 = copy.deepcopy(self.relu) 482 | self.relu4_1 = copy.deepcopy(self.relu) 483 | self.relu5_1 = copy.deepcopy(self.relu) 484 | self.relu1_2 = copy.deepcopy(self.relu) 485 | self.relu2_2 = copy.deepcopy(self.relu) 486 | self.relu3_2 = copy.deepcopy(self.relu) 487 | self.relu4_2 = copy.deepcopy(self.relu) 488 | self.relu5_2 = copy.deepcopy(self.relu) 489 | self.sig = nn.Sigmoid() 490 | self.tanh = nn.Tanh() 491 | 492 | def forward(self, x): 493 | x = self.conv1(x) 494 | x = self.relu1_1(x) 495 | x = self.bn1_1(x) 496 | x = self.conv2(x) 497 | x = self.relu2_1(x) 498 | x = self.bn2_1(x) 499 | x = self.conv3(x) 500 | x = self.relu3_1(x) 501 | x = self.bn3_1(x) 502 | x = self.conv4(x) 503 | x = self.relu4_1(x) 504 | x = self.bn4_1(x) 505 | x = self.conv5(x) 506 | if self.activations: 507 | x = self.sig(x) 508 | else: 509 | x = self.relu5_1(x) 510 | x = x.view(x.size(0), -1) 511 | x = self.embedding(x) 512 | extra_out = x 513 | clustering_out = self.clustering(x) 514 | x = 
self.deembedding(x) 515 | x = self.relu5_2(x) 516 | x = x.view(x.size(0), self.filters[4], ((self.input_shape[0]//2//2//2//2-1) // 2), ((self.input_shape[0]//2//2//2//2-1) // 2)) 517 | x = self.deconv5(x) 518 | x = self.relu4_2(x) 519 | x = self.bn5_2(x) 520 | x = self.deconv4(x) 521 | x = self.relu3_2(x) 522 | x = self.bn4_2(x) 523 | x = self.deconv3(x) 524 | x = self.relu2_2(x) 525 | x = self.bn3_2(x) 526 | x = self.deconv2(x) 527 | x = self.relu1_2(x) 528 | x = self.bn2_2(x) 529 | x = self.deconv1(x) 530 | if self.activations: 531 | x = self.tanh(x) 532 | return x, clustering_out, extra_out 533 | -------------------------------------------------------------------------------- /semi_supervised.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | if __name__ == "__main__": 4 | 5 | import argparse 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torch.optim import lr_scheduler 10 | from torchvision import datasets, transforms 11 | import os 12 | import math 13 | import fnmatch 14 | import nets 15 | import utils 16 | import training_functions 17 | from tensorboardX import SummaryWriter 18 | 19 | # Translate string entries to bool for parser 20 | def str2bool(v): 21 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 22 | return True 23 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 24 | return False 25 | else: 26 | raise argparse.ArgumentTypeError('Boolean value expected.') 27 | 28 | parser = argparse.ArgumentParser(description='Use DCEC for clustering') 29 | parser.add_argument('--mode', default='train_full', choices=['train_full', 'pretrain'], help='mode') 30 | parser.add_argument('--tensorboard', default=True, type=str2bool, help='export training stats to tensorboard') 31 | parser.add_argument('--pretrain', default=True, type=str2bool, help='perform autoencoder pretraining') 32 | parser.add_argument('--pretrained_net', default=1,
help='index or path of pretrained net') 33 | parser.add_argument('--net_architecture', default='CAE_3', choices=['CAE_3', 'CAE_bn3', 'CAE_4', 'CAE_bn4', 'CAE_5', 'CAE_bn5'], help='network architecture used') 34 | parser.add_argument('--dataset', default='MNIST-train', choices=['MNIST-train', 'custom', 'MNIST-test', 'MNIST-full'], 35 | help='custom or prepared dataset') 36 | parser.add_argument('--dataset_path', default='data', help='path to dataset') 37 | parser.add_argument('--batch_size', default=256, type=int, help='batch size') 38 | parser.add_argument('--rate', default=0.001, type=float, help='learning rate for clustering') 39 | parser.add_argument('--rate_pretrain', default=0.001, type=float, help='learning rate for pretraining') 40 | parser.add_argument('--weight', default=0.0, type=float, help='weight decay for clustering') 41 | parser.add_argument('--weight_pretrain', default=0.0, type=float, help='weight decay for pretraining') 42 | parser.add_argument('--sched_step', default=200, type=int, help='scheduler steps for rate update') 43 | parser.add_argument('--sched_step_pretrain', default=200, type=int, 44 | help='scheduler steps for rate update - pretrain') 45 | parser.add_argument('--sched_gamma', default=0.1, type=float, help='scheduler gamma for rate update') 46 | parser.add_argument('--sched_gamma_pretrain', default=0.1, type=float, 47 | help='scheduler gamma for rate update - pretrain') 48 | parser.add_argument('--epochs', default=1000, type=int, help='clustering epochs') 49 | parser.add_argument('--epochs_pretrain', default=300, type=int, help='pretraining epochs') 50 | parser.add_argument('--printing_frequency', default=10, type=int, help='training stats printing frequency') 51 | parser.add_argument('--gamma', default=0.1, type=float, help='clustering loss weight') 52 | parser.add_argument('--gamma_lab', default=0.01, type=float, help='labelled loss weight') 53 | parser.add_argument('--update_interval', default=80, type=int, help='update interval for
target distribution') 54 | parser.add_argument('--label_upd_interval', default=1, type=int, help='update interval for labelled loss') 55 | parser.add_argument('--tol', default=1e-3, type=float, help='stop criterion tolerance') 56 | parser.add_argument('--num_clusters', default=10, type=int, help='number of clusters') 57 | parser.add_argument('--custom_img_size', default=[128, 128, 3], nargs=3, type=int, help='size of custom images') 58 | parser.add_argument('--leaky', default=True, type=str2bool, help='use leaky version of relu') 59 | parser.add_argument('--neg_slope', default=0.01, type=float, help='negative slope for leaky relu') 60 | parser.add_argument('--activations', default=False, type=str2bool, help='use sigmoid and tanh activations in autoencoder') 61 | parser.add_argument('--bias', default=True, type=str2bool, help='use bias in layers') 62 | args = parser.parse_args() 63 | print(args) 64 | 65 | if args.mode == 'pretrain' and not args.pretrain: 66 | print("Nothing to do :(") 67 | exit() 68 | 69 | board = args.tensorboard 70 | 71 | # Deal with pretraining option and way of showing network path 72 | pretrain = args.pretrain 73 | net_is_path = True 74 | if not pretrain: 75 | try: 76 | int(args.pretrained_net) 77 | idx = args.pretrained_net 78 | net_is_path = False 79 | except (TypeError, ValueError): 80 | pass 81 | params = {'pretrain': pretrain} 82 | 83 | # Directories 84 | # Create directory structure 85 | dirs = ['runs', 'reports', 'nets'] 86 | list(map(lambda x: os.makedirs(x, exist_ok=True), dirs)) 87 | 88 | # Net architecture 89 | model_name = args.net_architecture 90 | # Indexing (for automated report saving) - allows running many trainings and collecting all the reports 91 | if pretrain or (not pretrain and net_is_path): 92 | reports_list = sorted(os.listdir('reports'), reverse=True) 93 | if reports_list: 94 | for file in reports_list: 95 | # print(file) 96 | if fnmatch.fnmatch(file, model_name+'*'): 97 | print(file) 98 | idx = int(str(file)[-7:-4]) + 1 99 |
print(idx) 100 | break 101 | try: 102 | idx 103 | except NameError: 104 | idx = 1 105 | 106 | # Base filename 107 | name = model_name + '_' + str(idx).zfill(3) 108 | 109 | # Filenames for report and weights 110 | name_txt = name + '.txt' 111 | name_net = name 112 | pretrained = name + '_pretrained.pt' 113 | 114 | print(name_txt) 115 | 116 | # Arrange filenames for report, network weights, pretrained network weights 117 | name_txt = os.path.join('reports', name_txt) 118 | name_net = os.path.join('nets', name_net) 119 | if net_is_path and not pretrain: 120 | pretrained = args.pretrained_net 121 | else: 122 | pretrained = os.path.join('nets', pretrained) 123 | if not pretrain and not os.path.isfile(pretrained): 124 | print("No pretrained weights, try again choosing pretrained network or create new with pretrain=True") 125 | 126 | model_files = [name_net, pretrained] 127 | params['model_files'] = model_files 128 | 129 | # Open file 130 | if pretrain: 131 | f = open(name_txt, 'w') 132 | else: 133 | f = open(name_txt, 'a') 134 | params['txt_file'] = f 135 | 136 | # Delete tensorboard entry if exist (not to overlap as the charts become unreadable) 137 | try: 138 | os.system("rm -rf runs/" + name) 139 | except: 140 | pass 141 | 142 | # Initialize tensorboard writer 143 | if board: 144 | writer = SummaryWriter('runs/' + name) 145 | params['writer'] = writer 146 | else: 147 | params['writer'] = None 148 | 149 | # Hyperparameters 150 | 151 | # Used dataset 152 | dataset = args.dataset 153 | 154 | # Batch size 155 | batch = args.batch_size 156 | params['batch'] = batch 157 | # Number of workers (typically 4*num_of_GPUs) 158 | workers = 4 159 | # Learning rate 160 | rate = args.rate 161 | rate_pretrain = args.rate_pretrain 162 | # Adam params 163 | # Weight decay 164 | weight = args.weight 165 | weight_pretrain = args.weight_pretrain 166 | # Scheduler steps for rate update 167 | sched_step = args.sched_step 168 | sched_step_pretrain = args.sched_step_pretrain 169 | # Scheduler 
gamma - multiplier for learning rate 170 | sched_gamma = args.sched_gamma 171 | sched_gamma_pretrain = args.sched_gamma_pretrain 172 | 173 | # Number of epochs 174 | epochs = args.epochs 175 | pretrain_epochs = args.epochs_pretrain 176 | params['pretrain_epochs'] = pretrain_epochs 177 | 178 | # Printing frequency 179 | print_freq = args.printing_frequency 180 | params['print_freq'] = print_freq 181 | 182 | # Clustering loss weight: 183 | gamma = args.gamma 184 | params['gamma'] = gamma 185 | 186 | # Labelled loss weight: 187 | gamma_lab = args.gamma_lab 188 | params['gamma_lab'] = gamma_lab 189 | 190 | # Update interval for target distribution: 191 | update_interval = args.update_interval 192 | params['update_interval'] = update_interval 193 | 194 | label_upd_interval = args.label_upd_interval 195 | params['label_upd_interval'] = label_upd_interval 196 | 197 | # Tolerance for label changes: 198 | tol = args.tol 199 | params['tol'] = tol 200 | 201 | # Number of clusters 202 | num_clusters = args.num_clusters 203 | 204 | # Report for settings 205 | tmp = "Training the '" + model_name + "' architecture" 206 | utils.print_both(f, tmp) 207 | tmp = "\n" + "The following parameters are used:" 208 | utils.print_both(f, tmp) 209 | tmp = "Batch size:\t" + str(batch) 210 | utils.print_both(f, tmp) 211 | tmp = "Number of workers:\t" + str(workers) 212 | utils.print_both(f, tmp) 213 | tmp = "Learning rate:\t" + str(rate) 214 | utils.print_both(f, tmp) 215 | tmp = "Pretraining learning rate:\t" + str(rate_pretrain) 216 | utils.print_both(f, tmp) 217 | tmp = "Weight decay:\t" + str(weight) 218 | utils.print_both(f, tmp) 219 | tmp = "Pretraining weight decay:\t" + str(weight_pretrain) 220 | utils.print_both(f, tmp) 221 | tmp = "Scheduler steps:\t" + str(sched_step) 222 | utils.print_both(f, tmp) 223 | tmp = "Scheduler gamma:\t" + str(sched_gamma) 224 | utils.print_both(f, tmp) 225 | tmp = "Pretraining scheduler steps:\t" + str(sched_step_pretrain) 226 | utils.print_both(f, tmp) 
227 | tmp = "Pretraining scheduler gamma:\t" + str(sched_gamma_pretrain) 228 | utils.print_both(f, tmp) 229 | tmp = "Number of epochs of training:\t" + str(epochs) 230 | utils.print_both(f, tmp) 231 | tmp = "Number of epochs of pretraining:\t" + str(pretrain_epochs) 232 | utils.print_both(f, tmp) 233 | tmp = "Clustering loss weight:\t" + str(gamma) 234 | utils.print_both(f, tmp) 235 | tmp = "Labelled loss weight:\t" + str(gamma_lab) 236 | utils.print_both(f, tmp) 237 | tmp = "Update interval for target distribution:\t" + str(update_interval) 238 | utils.print_both(f, tmp) 239 | tmp = "Update interval for labelled loss:\t" + str(label_upd_interval) 240 | utils.print_both(f, tmp) 241 | tmp = "Stop criterium tolerance:\t" + str(tol) 242 | utils.print_both(f, tmp) 243 | tmp = "Number of clusters:\t" + str(num_clusters) 244 | utils.print_both(f, tmp) 245 | tmp = "Leaky relu:\t" + str(args.leaky) 246 | utils.print_both(f, tmp) 247 | tmp = "Leaky slope:\t" + str(args.neg_slope) 248 | utils.print_both(f, tmp) 249 | tmp = "Activations:\t" + str(args.activations) 250 | utils.print_both(f, tmp) 251 | tmp = "Bias:\t" + str(args.bias) 252 | utils.print_both(f, tmp) 253 | 254 | # Data preparation 255 | if dataset == 'MNIST-train': 256 | # Uses slightly modified torchvision MNIST class and creates dataloader with whole sets 257 | # and sets of 2% of data (as labelled) 258 | import mnist 259 | tmp = "\nData preparation\nReading data from: MNIST train dataset" 260 | utils.print_both(f, tmp) 261 | img_size = [28, 28, 1] 262 | tmp = "Image size used:\t{0}x{1}".format(img_size[0], img_size[1]) 263 | utils.print_both(f, tmp) 264 | 265 | dataset = mnist.MNIST('../data', train=True, download=True, 266 | transform=transforms.Compose([ 267 | transforms.ToTensor(), 268 | # transforms.Normalize((0.1307,), (0.3081,)) 269 | ])) 270 | 271 | dataloader = torch.utils.data.DataLoader(dataset, 272 | batch_size=batch, shuffle=False, num_workers=workers) 273 | 274 | dataset_size = len(dataset) 275 | 
tmp = "Training set size:\t" + str(dataset_size) 276 | utils.print_both(f, tmp) 277 | 278 | dataset_labelled = mnist.MNIST('../data', train=True, download=True, small=True, 279 | transform=transforms.Compose([ 280 | transforms.ToTensor(), 281 | # transforms.Normalize((0.1307,), (0.3081,)) 282 | ])) 283 | 284 | dataloader_labelled = torch.utils.data.DataLoader(dataset_labelled, 285 | batch_size=batch, shuffle=False, num_workers=workers) 286 | 287 | dataset_labelled_size = len(dataset_labelled) 288 | tmp = "Training set labelled size:\t" + str(dataset_labelled_size) 289 | utils.print_both(f, tmp) 290 | 291 | elif dataset == 'MNIST-test': 292 | import mnist 293 | tmp = "\nData preparation\nReading data from: MNIST test dataset" 294 | utils.print_both(f, tmp) 295 | img_size = [28, 28, 1] 296 | tmp = "Image size used:\t{0}x{1}".format(img_size[0], img_size[1]) 297 | utils.print_both(f, tmp) 298 | 299 | dataset = mnist.MNIST('../data', train=False, download=True, 300 | transform=transforms.Compose([ 301 | transforms.ToTensor(), 302 | # transforms.Normalize((0.1307,), (0.3081,)) 303 | ])) 304 | 305 | dataloader = torch.utils.data.DataLoader(dataset, 306 | batch_size=batch, shuffle=False, num_workers=workers) 307 | 308 | dataset_size = len(dataset) 309 | tmp = "Training set size:\t" + str(dataset_size) 310 | utils.print_both(f, tmp) 311 | 312 | dataset_labelled = mnist.MNIST('../data', train=False, download=True, small=True, 313 | transform=transforms.Compose([ 314 | transforms.ToTensor(), 315 | # transforms.Normalize((0.1307,), (0.3081,)) 316 | ])) 317 | 318 | dataloader_labelled = torch.utils.data.DataLoader(dataset_labelled, 319 | batch_size=batch, shuffle=False, num_workers=workers) 320 | 321 | dataset_labelled_size = len(dataset_labelled) 322 | tmp = "Training set labelled size:\t" + str(dataset_labelled_size) 323 | utils.print_both(f, tmp) 324 | 325 | elif dataset == 'MNIST-full': 326 | import mnist 327 | tmp = "\nData preparation\nReading data from: MNIST full 
dataset" 328 | utils.print_both(f, tmp) 329 | img_size = [28, 28, 1] 330 | tmp = "Image size used:\t{0}x{1}".format(img_size[0], img_size[1]) 331 | utils.print_both(f, tmp) 332 | 333 | dataset = mnist.MNIST('../data', full=True, download=True, 334 | transform=transforms.Compose([ 335 | transforms.ToTensor(), 336 | # transforms.Normalize((0.1307,), (0.3081,)) 337 | ])) 338 | 339 | dataloader = torch.utils.data.DataLoader(dataset, 340 | batch_size=batch, shuffle=False, num_workers=workers) 341 | 342 | dataset_size = len(dataset) 343 | tmp = "Training set size:\t" + str(dataset_size) 344 | utils.print_both(f, tmp) 345 | 346 | dataset_labelled = mnist.MNIST('../data', full=True, download=True, small=True, 347 | transform=transforms.Compose([ 348 | transforms.ToTensor(), 349 | # transforms.Normalize((0.1307,), (0.3081,)) 350 | ])) 351 | 352 | dataloader_labelled = torch.utils.data.DataLoader(dataset_labelled, 353 | batch_size=batch, shuffle=False, num_workers=workers) 354 | 355 | dataset_labelled_size = len(dataset_labelled) 356 | tmp = "Training set labelled size:\t" + str(dataset_labelled_size) 357 | utils.print_both(f, tmp) 358 | 359 | else: 360 | # Custom dataset - arrange folders according to README 361 | 362 | # Data folder 363 | data_dir = args.dataset_path 364 | tmp = "\nData preparation\nReading data from:\t./" + data_dir 365 | utils.print_both(f, tmp) 366 | 367 | # Image size 368 | custom_size = math.nan 369 | custom_size = args.custom_img_size 370 | if isinstance(custom_size, list): 371 | img_size = custom_size 372 | 373 | tmp = "Image size used:\t{0}x{1}".format(img_size[0], img_size[1]) 374 | utils.print_both(f, tmp) 375 | 376 | # Transformations 377 | data_transforms = transforms.Compose([ 378 | transforms.Resize(img_size[0:2]), 379 | # transforms.RandomHorizontalFlip(), 380 | transforms.ToTensor(), 381 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 382 | ]) 383 | 384 | # Read data from selected folder and apply transformations 385 |
image_dataset = datasets.ImageFolder(data_dir, data_transforms) 386 | # Prepare data for network: arrange batches (shuffling is disabled) 387 | dataloader = torch.utils.data.DataLoader(image_dataset, batch_size=batch, 388 | shuffle=False, num_workers=workers) 389 | 390 | # Size of data sets 391 | dataset_size = len(image_dataset) 392 | tmp = "Training set size:\t" + str(dataset_size) 393 | utils.print_both(f, tmp) 394 | 395 | # Read labelled data from the folder with the '_l' suffix and apply transformations 396 | image_dataset_l = datasets.ImageFolder(data_dir+'_l', data_transforms) 397 | # Prepare data for network: arrange batches (shuffling is disabled) 398 | dataloader_labelled = torch.utils.data.DataLoader(image_dataset_l, batch_size=batch, 399 | shuffle=False, num_workers=workers) 400 | dataset_labelled_size = len(image_dataset_l) 401 | 402 | params['dataset_size'] = dataset_size 403 | params['dataset_labelled_size'] = dataset_labelled_size 404 | 405 | # GPU check 406 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 407 | tmp = "\nPerforming calculations on:\t" + str(device) 408 | utils.print_both(f, tmp + '\n') 409 | params['device'] = device 410 | 411 | # Evaluate the proper model 412 | to_eval = "nets."
+ model_name + "(img_size, num_clusters=num_clusters, leaky=args.leaky, neg_slope=args.neg_slope, activations=args.activations, bias=args.bias)" 413 | model = eval(to_eval) 414 | 415 | # Tensorboard model representation 416 | # if board: 417 | # writer.add_graph(model, torch.autograd.Variable(torch.Tensor(batch, img_size[2], img_size[0], img_size[1]))) 418 | 419 | model = model.to(device) 420 | # Reconstruction loss 421 | criterion_1 = nn.MSELoss(reduction='mean') 422 | # Clustering loss 423 | criterion_2 = nn.KLDivLoss(reduction='sum') 424 | # Labelled loss 425 | criterion_3 = nn.CrossEntropyLoss(reduction='sum') 426 | 427 | criteria = [criterion_1, criterion_2, criterion_3] 428 | 429 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=rate, weight_decay=weight) 430 | 431 | optimizer_pretrain = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=rate_pretrain, weight_decay=weight_pretrain) 432 | 433 | optimizers = [optimizer, optimizer_pretrain] 434 | 435 | scheduler = lr_scheduler.StepLR(optimizer, step_size=sched_step, gamma=sched_gamma) 436 | scheduler_pretrain = lr_scheduler.StepLR(optimizer_pretrain, step_size=sched_step_pretrain, gamma=sched_gamma_pretrain) 437 | 438 | schedulers = [scheduler, scheduler_pretrain] 439 | 440 | if args.mode == 'train_full': 441 | model = training_functions.train_semisupervised(model, [dataloader, dataloader_labelled], criteria, optimizers, schedulers, epochs, params) 442 | elif args.mode == 'pretrain': 443 | model = training_functions.pretraining(model, [dataloader, dataloader_labelled], criteria, optimizers, schedulers, epochs, params) 444 | 445 | # Save final model 446 | torch.save(model.state_dict(), name_net + '.pt') 447 | 448 | # Close files 449 | f.close() 450 | if board: 451 | writer.close() 452 | 453 | -------------------------------------------------------------------------------- /training_functions.py: -------------------------------------------------------------------------------- 1 | import utils 2 |
import time
import torch
import numpy as np
import copy
from sklearn.cluster import KMeans


# Training function (from my torch_DCEC implementation, kept for completeness)
def train_model(model, dataloader, criteria, optimizers, schedulers, num_epochs, params):

    # Note the time
    since = time.time()

    # Unpack parameters
    writer = params['writer']
    # TensorBoard logging is enabled only when a writer was supplied
    board = writer is not None
    txt_file = params['txt_file']
    pretrained = params['model_files'][1]
    pretrain = params['pretrain']
    print_freq = params['print_freq']
    dataset_size = params['dataset_size']
    device = params['device']
    batch = params['batch']
    pretrain_epochs = params['pretrain_epochs']
    gamma = params['gamma']
    update_interval = params['update_interval']
    tol = params['tol']

    dl = dataloader

    # Pretrain or load weights
    if pretrain:
        while True:
            pretrained_model = pretraining(model, copy.deepcopy(dl), criteria[0], optimizers[1], schedulers[1], pretrain_epochs, params)
            if pretrained_model:
                break
            else:
                # Pretraining did not converge: reset the weights and try again
                for layer in model.children():
                    if hasattr(layer, 'reset_parameters'):
                        layer.reset_parameters()
        model = pretrained_model
    else:
        try:
            model.load_state_dict(torch.load(pretrained))
            utils.print_both(txt_file, 'Pretrained weights loaded from file: ' + str(pretrained))
        except Exception:
            print("Couldn't load pretrained weights")

    # Initialise clusters
    utils.print_both(txt_file, '\nInitializing cluster centers based on K-means')
    kmeans(model, copy.deepcopy(dl), params)

    utils.print_both(txt_file, '\nBegin clusters training')

    # Prep variables for weights and accuracy of the best model
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 10000.0

    # Initial target distribution
    utils.print_both(txt_file, '\nUpdating target distribution')
    output_distribution, labels, preds_prev = calculate_predictions(model, copy.deepcopy(dl), params)
    target_distribution = target(output_distribution)
    nmi = utils.metrics.nmi(labels, preds_prev)
    ari = utils.metrics.ari(labels, preds_prev)
    acc = utils.metrics.acc(labels, preds_prev)
    utils.print_both(txt_file,
                     'NMI: {0:.5f}\tARI: {1:.5f}\tAcc {2:.5f}\n'.format(nmi, ari, acc))

    if board:
        niter = 0
        writer.add_scalar('/NMI', nmi, niter)
        writer.add_scalar('/ARI', ari, niter)
        writer.add_scalar('/Acc', acc, niter)

    update_iter = 1
    finished = False

    # Go through all epochs
    for epoch in range(num_epochs):

        utils.print_both(txt_file, 'Epoch {}/{}'.format(epoch + 1, num_epochs))
        utils.print_both(txt_file, '-' * 10)

        schedulers[0].step()
        model.train(True)  # Set model to training mode

        running_loss = 0.0
        running_loss_rec = 0.0
        running_loss_clust = 0.0

        # Keep the batch number for inter-phase statistics
        batch_num = 1
        img_counter = 0

        # Iterate over data.
        for data in dataloader:
            # Get the inputs and labels
            inputs, _ = data

            inputs = inputs.to(device)

            # Update target distribution, check and print performance
            if (batch_num - 1) % update_interval == 0 and not (batch_num == 1 and epoch == 0):
                utils.print_both(txt_file, '\nUpdating target distribution:')
                output_distribution, labels, preds = calculate_predictions(model, dataloader, params)
                target_distribution = target(output_distribution)
                nmi = utils.metrics.nmi(labels, preds)
                ari = utils.metrics.ari(labels, preds)
                acc = utils.metrics.acc(labels, preds)
                utils.print_both(txt_file,
                                 'NMI: {0:.5f}\tARI: {1:.5f}\tAcc {2:.5f}\t'.format(nmi, ari, acc))
                if board:
                    niter = update_iter
                    writer.add_scalar('/NMI', nmi, niter)
                    writer.add_scalar('/ARI', ari, niter)
                    writer.add_scalar('/Acc', acc, niter)
                    update_iter += 1

                # Check stop criterion: fraction of samples whose cluster assignment changed
                delta_label = np.sum(preds != preds_prev).astype(np.float32) / preds.shape[0]
                preds_prev = np.copy(preds)
                if delta_label < tol:
                    utils.print_both(txt_file, 'Label divergence ' + str(delta_label) + ' < tol ' + str(tol))
                    utils.print_both(txt_file, 'Reached tolerance threshold. Stopping training.')
                    finished = True
                    break

            # Slice of the target distribution matching the current batch (relies on shuffle=False)
            tar_dist = target_distribution[((batch_num - 1) * batch):(batch_num*batch), :]
            tar_dist = torch.from_numpy(tar_dist).to(device)
            # print(tar_dist)

            # zero the parameter gradients
            optimizers[0].zero_grad()

            # Calculate losses and backpropagate
            with torch.set_grad_enabled(True):
                outputs, clusters, _ = model(inputs)
                loss_rec = criteria[0](outputs, inputs)
                loss_clust = gamma * criteria[1](torch.log(clusters), tar_dist) / batch
                loss = loss_rec + loss_clust
                loss.backward()
                optimizers[0].step()

            # For keeping statistics
            running_loss += loss.item() * inputs.size(0)
            running_loss_rec += loss_rec.item() * inputs.size(0)
            running_loss_clust += loss_clust.item() * inputs.size(0)

            # Some current stats
            loss_batch = loss.item()
            loss_batch_rec = loss_rec.item()
            loss_batch_clust = loss_clust.item()
            loss_accum = running_loss / ((batch_num - 1) * batch + inputs.size(0))
            loss_accum_rec = running_loss_rec / ((batch_num - 1) * batch + inputs.size(0))
            loss_accum_clust = running_loss_clust / ((batch_num - 1) * batch + inputs.size(0))

            if batch_num % print_freq == 0:
                utils.print_both(txt_file, 'Epoch: [{0}][{1}/{2}]\t'
                                 'Loss {3:.4f} ({4:.4f})\t'
                                 'Loss_recovery {5:.4f} ({6:.4f})\t'
                                 'Loss clustering {7:.4f} ({8:.4f})\t'.format(epoch + 1, batch_num,
                                                                               len(dataloader),
                                                                               loss_batch,
                                                                               loss_accum, loss_batch_rec,
                                                                               loss_accum_rec,
                                                                               loss_batch_clust,
                                                                               loss_accum_clust))
                if board:
                    niter = epoch * len(dataloader) + batch_num
                    writer.add_scalar('/Loss', loss_accum, niter)
                    writer.add_scalar('/Loss_recovery', loss_accum_rec, niter)
                    writer.add_scalar('/Loss_clustering', loss_accum_clust, niter)
            batch_num = batch_num + 1

            # Print image to tensorboard
            if batch_num == len(dataloader) and (epoch+1) % 5:
                inp = utils.tensor2img(inputs)
                out = utils.tensor2img(outputs)
                if board:
                    img = np.concatenate((inp, out), axis=2)
                    writer.add_image('Clustering/Epoch_' + str(epoch + 1).zfill(3) + '/Sample_' + str(img_counter).zfill(2), img)
                    img_counter += 1

        if finished: break

        epoch_loss = running_loss / dataset_size
        epoch_loss_rec = running_loss_rec / dataset_size
        epoch_loss_clust = running_loss_clust / dataset_size

        if board:
            writer.add_scalar('/Loss' + '/Epoch', epoch_loss, epoch + 1)
            writer.add_scalar('/Loss_rec' + '/Epoch', epoch_loss_rec, epoch + 1)
            writer.add_scalar('/Loss_clust' + '/Epoch', epoch_loss_clust, epoch + 1)

        utils.print_both(txt_file, 'Loss: {0:.4f}\tLoss_recovery: {1:.4f}\tLoss_clustering: {2:.4f}'.format(epoch_loss,
                                                                                                          epoch_loss_rec,
                                                                                                          epoch_loss_clust))

        # Placeholder for a future selection criterion (currently always true, so the last epoch is kept)
        if epoch_loss < best_loss or epoch_loss >= best_loss:
            best_loss = epoch_loss
            best_model_wts = copy.deepcopy(model.state_dict())

        utils.print_both(txt_file, '')

    time_elapsed = time.time() - since
    utils.print_both(txt_file, 'Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


# Training function (proper semisupervised training)
def train_semisupervised(model, dataloaders, criteria, optimizers, schedulers, num_epochs, params):

    # Note the time
    since = time.time()

    # Unpack parameters
    writer = params['writer']
    # TensorBoard logging is enabled only when a writer was supplied
    board = writer is not None
    txt_file = params['txt_file']
    pretrained = params['model_files'][1]
    pretrain = params['pretrain']
    print_freq = params['print_freq']
    dataset_size = params['dataset_size']
    dataset_labelled_size = params['dataset_labelled_size']
    device = params['device']
    batch = params['batch']
    pretrain_epochs = params['pretrain_epochs']
    gamma = params['gamma']
    gamma_lab = params['gamma_lab']
    update_interval = params['update_interval']
    tol = params['tol']
    label_upd_interval = params['label_upd_interval']

    dataloader = dataloaders[0]
    dataloader_labelled = dataloaders[1]

    dl = dataloader

    # Pretrain or load weights
    if pretrain:
        while True:
            pretrained_model = pretraining(model, copy.deepcopy(dl), criteria[0], optimizers[1], schedulers[1], pretrain_epochs, params)
            if pretrained_model:
                break
            else:
                # Pretraining did not converge: reset the weights and try again
                for layer in model.children():
                    if hasattr(layer, 'reset_parameters'):
                        layer.reset_parameters()
        model = pretrained_model
    else:
        try:
            model.load_state_dict(torch.load(pretrained))
            utils.print_both(txt_file, 'Pretrained weights loaded from file: ' + str(pretrained))
        except Exception:
            print("Couldn't load pretrained weights")

    # Initialise clusters
    utils.print_both(txt_file, '\nInitializing cluster centers based on average')
    average_labelled_dist(model, copy.deepcopy(dataloader_labelled), params)

    utils.print_both(txt_file, '\nBegin clusters training')

    # Prep variables for weights and accuracy of the best model
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 10000.0

    # Initial target distribution
    utils.print_both(txt_file, '\nUpdating target distribution')
    output_distribution, labels, preds_prev = calculate_predictions(model, copy.deepcopy(dl), params)
    target_distribution = target(output_distribution)
    nmi = utils.metrics.nmi(labels, preds_prev)
    ari = utils.metrics.ari(labels, preds_prev)
    acc = utils.metrics.acc(labels, preds_prev)
    utils.print_both(txt_file,
                     'NMI: {0:.5f}\tARI: {1:.5f}\tAcc {2:.5f}\n'.format(nmi, ari, acc))

    if board:
        niter = 0
        writer.add_scalar('/NMI', nmi, niter)
        writer.add_scalar('/ARI', ari, niter)
        writer.add_scalar('/Acc', acc, niter)

    update_iter = 1
    finished = False

    # Go through all epochs
    for epoch in range(num_epochs):

        utils.print_both(txt_file, 'Epoch {}/{}'.format(epoch + 1, num_epochs))
        utils.print_both(txt_file, '-' * 10)

        schedulers[0].step()
        model.train(True)  # Set model to training mode

        running_loss = 0.0
        running_loss_rec = 0.0
        running_loss_clust = 0.0
        running_loss_labels = 0.0

        # Keep the batch number for inter-phase statistics
        batch_num = 1
        img_counter = 0

        # print(dataloader)
        # Iterate over data.
        for data in dataloader:
            # Get the inputs and labels
            inputs, _ = data

            inputs = inputs.to(device)

            # Update target distribution, check and print performance
            if (batch_num - 1) % update_interval == 0 and not (batch_num == 1 and epoch == 0):
                utils.print_both(txt_file, '\nUpdating target distribution:')
                output_distribution, labels, preds = calculate_predictions(model, dataloader, params)
                target_distribution = target(output_distribution)
                nmi = utils.metrics.nmi(labels, preds)
                ari = utils.metrics.ari(labels, preds)
                acc = utils.metrics.acc(labels, preds)
                utils.print_both(txt_file,
                                 'NMI: {0:.5f}\tARI: {1:.5f}\tAcc {2:.5f}\t'.format(nmi, ari, acc))
                if board:
                    niter = update_iter
                    writer.add_scalar('/NMI', nmi, niter)
                    writer.add_scalar('/ARI', ari, niter)
                    writer.add_scalar('/Acc', acc, niter)
                    update_iter += 1

                # Check stop criterion: fraction of samples whose cluster assignment changed
                delta_label = np.sum(preds != preds_prev).astype(np.float32) / preds.shape[0]
                preds_prev = np.copy(preds)
                if delta_label < tol:
                    utils.print_both(txt_file, 'Label divergence ' + str(delta_label) + ' < tol ' + str(tol))
                    utils.print_both(txt_file, 'Reached tolerance threshold. Stopping training.')
                    finished = True
                    break

            # Slice of the target distribution matching the current batch (relies on shuffle=False)
            tar_dist = target_distribution[((batch_num - 1) * batch):(batch_num*batch), :]
            tar_dist = torch.from_numpy(tar_dist).to(device)
            # print(tar_dist)

            loss_labelled = 0

            # zero the parameter gradients
            optimizers[0].zero_grad()

            # Calculate losses and backpropagate
            with torch.set_grad_enabled(True):
                if (batch_num - 1) % label_upd_interval == 0 and not (batch_num == 1 and epoch == 0):
                    # utils.print_both(txt_file, '\nUpdating labelled loss:')
                    size = 0
                    # Iterate through labelled part of the set
                    for d in dataloader_labelled:
                        inp, lab = d
                        inp = inp.to(params['device'])
                        lab = lab.to(params['device'])
                        _, outs, _ = model(inp)
                        loss_labelled += criteria[2](outs, lab)
                        size += inp.size(0)
                    # Average over the labelled samples and apply the labelled-loss weight
                    loss_labelled = loss_labelled / size * gamma_lab

                outputs, clusters, _ = model(inputs)
                loss_rec = criteria[0](outputs, inputs)
                loss_clust = gamma * criteria[1](torch.log(clusters), tar_dist) / batch
                loss = loss_rec + loss_clust + loss_labelled
                loss.backward()
                optimizers[0].step()

            # For keeping statistics (float() handles both the plain 0 and a 0-dim tensor)
            running_loss += loss.item() * inputs.size(0)
            running_loss_rec += loss_rec.item() * inputs.size(0)
            running_loss_clust += loss_clust.item() * inputs.size(0)
            running_loss_labels += float(loss_labelled) * inputs.size(0)

            # Some current stats
            loss_batch = loss.item()
            loss_batch_rec = loss_rec.item()
            loss_batch_clust = loss_clust.item()
            loss_batch_labels = float(loss_labelled)
            loss_accum = running_loss / ((batch_num - 1) * batch + inputs.size(0))
            loss_accum_rec = running_loss_rec / ((batch_num - 1) * batch + inputs.size(0))
            loss_accum_clust = running_loss_clust / ((batch_num - 1) * batch + inputs.size(0))
            loss_accum_labels = running_loss_labels / ((batch_num - 1) * batch + inputs.size(0))

            if batch_num % print_freq == 0:
                utils.print_both(txt_file, 'Epoch: [{0}][{1}/{2}]\t'
                                 'Loss {3:.4f} ({4:.4f})\t'
                                 'Loss_recovery {5:.4f} ({6:.4f})\t'
                                 'Loss clustering {7:.4f} ({8:.4f})\t'
                                 'Loss labels {9:.4f} ({10:.4f})\t'.format(epoch + 1, batch_num,
                                                                            len(dataloader),
                                                                            loss_batch,
                                                                            loss_accum, loss_batch_rec,
                                                                            loss_accum_rec,
                                                                            loss_batch_clust,
                                                                            loss_accum_clust,
                                                                            loss_batch_labels,
                                                                            loss_accum_labels))
                if board:
                    niter = epoch * len(dataloader) + batch_num
                    writer.add_scalar('/Loss', loss_accum, niter)
                    writer.add_scalar('/Loss_recovery', loss_accum_rec, niter)
                    writer.add_scalar('/Loss_clustering', loss_accum_clust, niter)
                    writer.add_scalar('/Loss_labels', loss_accum_labels, niter)
            batch_num = batch_num + 1

            # Print image to tensorboard
            if batch_num == len(dataloader) and (epoch+1) % 5:
                inp = utils.tensor2img(inputs)
                out = utils.tensor2img(outputs)
                if board:
                    img = np.concatenate((inp, out), axis=2)
                    writer.add_image('Clustering/Epoch_' + str(epoch + 1).zfill(3) + '/Sample_' + str(img_counter).zfill(2), img)
                    img_counter += 1

        if finished: break

        epoch_loss = running_loss / dataset_size
        epoch_loss_rec = running_loss_rec / dataset_size
        epoch_loss_clust = running_loss_clust / dataset_size
        epoch_loss_labels = running_loss_labels / dataset_size

        if board:
            writer.add_scalar('/Loss' + '/Epoch', epoch_loss, epoch + 1)
            writer.add_scalar('/Loss_rec' + '/Epoch', epoch_loss_rec, epoch + 1)
            writer.add_scalar('/Loss_clust' + '/Epoch', epoch_loss_clust, epoch + 1)
            writer.add_scalar('/Loss_label' + '/Epoch', epoch_loss_labels, epoch + 1)

        utils.print_both(txt_file, 'Loss: {0:.4f}\tLoss_recovery: {1:.4f}\tLoss_clustering: {2:.4f}\tLoss labels: {3:.4f}'.format(
            epoch_loss,
            epoch_loss_rec,
            epoch_loss_clust, epoch_loss_labels))

        # Placeholder for a future selection criterion (currently always true, so the last epoch is kept)
        if epoch_loss < best_loss or epoch_loss >= best_loss:
            best_loss = epoch_loss
            best_model_wts = copy.deepcopy(model.state_dict())

        utils.print_both(txt_file, '')

    time_elapsed = time.time() - since
    utils.print_both(txt_file, 'Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


# Pretraining function for recovery loss only
def pretraining(model, dataloader, criterion, optimizer, scheduler, num_epochs, params):
    # Note the time
    since = time.time()

    # Unpack parameters
    writer = params['writer']
    # TensorBoard logging is enabled only when a writer was supplied
    board = writer is not None
    txt_file = params['txt_file']
    pretrained = params['model_files'][1]
    print_freq = params['print_freq']
    dataset_size = params['dataset_size']
    device = params['device']
    batch = params['batch']

    # Prep variables for weights and accuracy of the best model
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 10000.0

    # Go through all epochs
    for epoch in range(num_epochs):
        utils.print_both(txt_file, 'Pretraining:\tEpoch {}/{}'.format(epoch + 1, num_epochs))
        utils.print_both(txt_file, '-' * 10)

        scheduler.step()
        model.train(True)  # Set model to training mode

        running_loss = 0.0

        # Keep the batch number for inter-phase statistics
        batch_num = 1
        # Images to show
        img_counter = 0

        # Iterate over data.
        for data in dataloader:
            # Get the inputs and labels
            inputs, _ = data
            inputs = inputs.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            with torch.set_grad_enabled(True):
                outputs, _, _ = model(inputs)
                loss = criterion(outputs, inputs)
                loss.backward()
                optimizer.step()

            # For keeping statistics
            running_loss += loss.item() * inputs.size(0)

            # Some current stats
            loss_batch = loss.item()
            loss_accum = running_loss / ((batch_num - 1) * batch + inputs.size(0))

            if batch_num % print_freq == 0:
                utils.print_both(txt_file, 'Pretraining:\tEpoch: [{0}][{1}/{2}]\t'
                                 'Loss {3:.4f} ({4:.4f})\t'.format(epoch + 1, batch_num, len(dataloader),
                                                                    loss_batch,
                                                                    loss_accum))
                if board:
                    niter = epoch * len(dataloader) + batch_num
                    writer.add_scalar('Pretraining/Loss', loss_accum, niter)
            batch_num = batch_num + 1

            # Log sample reconstructions at a quarter, half, three quarters and the end of the epoch
            if batch_num in [len(dataloader), len(dataloader)//2, len(dataloader)//4, 3*len(dataloader)//4]:
                inp = utils.tensor2img(inputs)
                out = utils.tensor2img(outputs)
                if board:
                    img = np.concatenate((inp, out), axis=2)
                    writer.add_image('Pretraining/Epoch_' + str(epoch + 1).zfill(3) + '/Sample_' + str(img_counter).zfill(2), img)
                    img_counter += 1

        epoch_loss = running_loss / dataset_size
        if epoch == 0: first_loss = epoch_loss
        # Restart signal: the loss after five epochs is higher than after the first one
        if epoch == 4 and epoch_loss / first_loss > 1:
            utils.print_both(txt_file, "\nLoss not converging, starting pretraining again\n")
            return False

        if board:
            writer.add_scalar('Pretraining/Loss' + '/Epoch', epoch_loss, epoch + 1)

        utils.print_both(txt_file, 'Pretraining:\t Loss: {:.4f}'.format(epoch_loss))

        # Placeholder for a future selection criterion (currently always true, so the last epoch is kept)
        if epoch_loss < best_loss or epoch_loss >= best_loss:
            best_loss = epoch_loss
            best_model_wts = copy.deepcopy(model.state_dict())

        utils.print_both(txt_file, '')

    time_elapsed = time.time() - since
    utils.print_both(txt_file, 'Pretraining complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    # load best model weights
    model.load_state_dict(best_model_wts)
    model.pretrained = True
    torch.save(model.state_dict(), pretrained)

    return model


# K-means clusters initialisation
def kmeans(model, dataloader, params):
    km = KMeans(n_clusters=model.num_clusters, n_init=20)
    output_array = None
    model.eval()
    # Iterate through the data and concatenate the latent space representations of images
    for data in dataloader:
        inputs, _ = data
        inputs = inputs.to(params['device'])
        _, _, outputs = model(inputs)
        if output_array is not None:
            output_array = np.concatenate((output_array, outputs.cpu().detach().numpy()), 0)
        else:
            output_array = outputs.cpu().detach().numpy()
        # print(output_array.shape)
        # Cap the number of samples used for K-means
        if output_array.shape[0] > 50000: break

    # Perform K-means
    km.fit_predict(output_array)
    # Update clustering layer weights
    weights = torch.from_numpy(km.cluster_centers_)
    model.clustering.set_weight(weights.to(params['device']))
    # torch.cuda.empty_cache()


# Cluster initialisation from the labelled subset: one centre per class, the mean latent vector
def average_labelled_dist(model, dataloader, params):
    output_array = None
    label_array = None
    model.eval()
    # Iterate through the data and concatenate the latent space representations of images
    for data in dataloader:
        inputs, labels = data
        inputs = inputs.to(params['device'])
        _, _, outputs = model(inputs)
        if output_array is not None:
            output_array = np.concatenate((output_array, outputs.cpu().detach().numpy()), 0)
            label_array = np.concatenate((label_array, labels.cpu().detach().numpy()), 0)
        else:
            output_array = outputs.cpu().detach().numpy()
            label_array = labels.cpu().detach().numpy()

    # Initialise weights
    weights = np.zeros((model.num_clusters, model.num_clusters))
    num_probes = np.zeros((model.num_clusters, 1))

    # Iterate through latent space descriptors and sum them per label (keep the number of elements per label)
    for j, row in enumerate(output_array):
        label = label_array[j]
        weights[label, :] += row
        num_probes[label] += 1

    # Divide by the number of elements to get average
    for i in range(0, weights.shape[0]):
        weights[i, :] /= num_probes[i]

    print(num_probes)

    # Update weights in network
    weights = weights.astype(np.float32)
    weights = torch.from_numpy(weights)
    model.clustering.set_weight(weights.to(params['device']))
    # torch.cuda.empty_cache()


# Forward data through the network, collect the clustering layer output and return predictions and labels
def calculate_predictions(model, dataloader, params):
    output_array = None
    label_array = None
    model.eval()
    for data in dataloader:
        inputs, labels = data
        inputs = inputs.to(params['device'])
        labels = labels.to(params['device'])
        _, outputs, _ = model(inputs)
        if output_array is not None:
            output_array = np.concatenate((output_array, outputs.cpu().detach().numpy()), 0)
            label_array = np.concatenate((label_array, labels.cpu().detach().numpy()), 0)
        else:
            output_array = outputs.cpu().detach().numpy()
            label_array = labels.cpu().detach().numpy()

    # Predicted cluster is the one with the highest soft assignment
    preds = np.argmax(output_array, axis=1)
    # print(output_array.shape)
    return output_array, label_array, preds


# Calculate target distribution
def target(out_distr):
    # Square the soft assignments and divide by the per-cluster frequency, then normalise each row
    tar_dist = out_distr ** 2 / np.sum(out_distr, axis=0)
    tar_dist = np.transpose(np.transpose(tar_dist) / np.sum(tar_dist, axis=1))
    return tar_dist

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import sklearn.metrics


# Simple tensor to image translation
def tensor2img(tensor):
    # Take the first image of the batch and move channels last
    img = tensor.cpu().data[0].numpy().transpose((1, 2, 0))
    # Undo the mean/std normalisation
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = std * img + mean
    img = np.clip(img, 0, 1)
    # Back to channels first
    img = img.transpose((2, 0, 1))
    return img


# Define printing to console and file
def print_both(f, text):
    print(text)
    f.write(text + '\n')


# Metrics class was copied from DCEC article authors repository (link in README)
class metrics:
    nmi = sklearn.metrics.normalized_mutual_info_score
    ari = sklearn.metrics.adjusted_rand_score

    @staticmethod
    def acc(labels_true, labels_pred):
        labels_true = labels_true.astype(np.int64)
        assert labels_pred.size == labels_true.size
        D = max(labels_pred.max(), labels_true.max()) + 1
        # w[i, j] counts samples assigned to cluster i whose true label is j
        w = np.zeros((D, D), dtype=np.int64)
        for i in range(labels_pred.size):
            w[labels_pred[i], labels_true[i]] += 1
        # Find the cluster-to-label mapping that maximises the number of matches
        from sklearn.utils.linear_assignment_ import linear_assignment
        ind = linear_assignment(w.max() - w)
        return sum([w[i, j] for i, j in ind]) * 1.0 / labels_pred.size

--------------------------------------------------------------------------------
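The `metrics.acc` helper in utils.py relies on the private `sklearn.utils.linear_assignment_` module, which has been removed from recent scikit-learn releases. The sketch below is not part of the repository: it shows how the same best-match clustering accuracy could be computed with SciPy's `scipy.optimize.linear_sum_assignment` instead. The function name `cluster_acc` and the sample arrays are hypothetical and chosen only for illustration.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def cluster_acc(labels_true, labels_pred):
    """Best-match clustering accuracy (hypothetical helper, not repository code)."""
    labels_true = np.asarray(labels_true).astype(np.int64)
    labels_pred = np.asarray(labels_pred).astype(np.int64)
    assert labels_pred.size == labels_true.size

    # Contingency matrix: w[p, t] counts samples in cluster p with true label t
    D = max(labels_pred.max(), labels_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    for p, t in zip(labels_pred, labels_true):
        w[p, t] += 1

    # linear_sum_assignment minimises total cost, so convert counts to costs
    row_ind, col_ind = linear_sum_assignment(w.max() - w)

    # Matched counts divided by the number of samples
    return w[row_ind, col_ind].sum() / labels_pred.size


if __name__ == "__main__":
    # Cluster ids are permuted relative to the labels; one sample is misassigned
    y_true = np.array([0, 0, 1, 1, 2, 2])
    y_pred = np.array([1, 1, 0, 0, 2, 1])
    print(cluster_acc(y_true, y_pred))
```

Because cluster indices carry no meaning of their own, the assignment step searches for the one-to-one mapping between cluster ids and class labels that yields the most matches, which is the same idea as in the original `acc` method. Swapping this in would add SciPy as an explicit dependency.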