├── clusgan
│   ├── __version__.py
│   ├── __init__.py
│   ├── definitions.py
│   ├── plots.py
│   ├── datasets.py
│   ├── utils.py
│   └── models.py
├── docs
│   └── imgs
│       ├── tsne-mnist-pca.png
│       ├── gen_classes_000199-mnist.png
│       ├── training_cycle_loss-mnist.png
│       └── training_model_losses-mnist.png
├── requirements.txt
├── Makefile
├── LICENSE
├── .gitignore
├── README.md
├── setup.py
├── gen-examples.py
├── tsne-cluster.py
└── train.py
/clusgan/__version__.py: -------------------------------------------------------------------------------- 1 | 2 | __version__ = '0.1.dev0' 3 | -------------------------------------------------------------------------------- /docs/imgs/tsne-mnist-pca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhampel/clusterGAN/HEAD/docs/imgs/tsne-mnist-pca.png -------------------------------------------------------------------------------- /docs/imgs/gen_classes_000199-mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhampel/clusterGAN/HEAD/docs/imgs/gen_classes_000199-mnist.png -------------------------------------------------------------------------------- /docs/imgs/training_cycle_loss-mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhampel/clusterGAN/HEAD/docs/imgs/training_cycle_loss-mnist.png -------------------------------------------------------------------------------- /docs/imgs/training_model_losses-mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhampel/clusterGAN/HEAD/docs/imgs/training_model_losses-mnist.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==2.2.2 2 | numpy==1.14.5 3 | pandas==0.23.2 4 | scipy==1.1.0 5 | seaborn==0.9.0 6 | sklearn==0.0 7 | torch==1.0.0 8 | torchvision==0.2.1 9 | tqdm==4.23.4 10 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: install 2 | 3 | install: 4 | python setup.py install 5 | 6 | develop: 7 | python setup.py develop 8 | 9 | develop-uninstall: 10 | python setup.py develop --uninstall 11 | -------------------------------------------------------------------------------- /clusgan/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__.py 2 | 3 | from .__version__ import __version__ 4 | 5 | from .utils import * 6 | from .models import * 7 | from .plots import * 8 | from .datasets import * 9 | from .definitions import * 10 | -------------------------------------------------------------------------------- /clusgan/definitions.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # Local directory of the clusgan package 4 | CLUSGAN_DIR = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | # Local directory containing entire repo 7 | REPO_DIR = os.path.split(CLUSGAN_DIR)[0] 8 | 9 | # Local directory for datasets 10 | DATASETS_DIR = os.path.join(REPO_DIR, 'datasets') 11 | 12 | # Local directory for runs 13 | RUNS_DIR = os.path.join(REPO_DIR, 'runs') 14 | -------------------------------------------------------------------------------- /clusgan/plots.py:
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | try: 4 | import os 5 | import numpy as np 6 | 7 | import matplotlib 8 | import matplotlib.pyplot as plt 9 | 10 | except ImportError as e: 11 | print(e) 12 | raise ImportError 13 | 14 | 15 | def plot_train_loss(df, arr_list, figname='training_loss.png'): 16 | 17 | fig, ax = plt.subplots(figsize=(16,10)) 18 | for arr in arr_list: 19 | label = df[arr][0] 20 | vals = df[arr][1] 21 | epochs = range(0, len(vals)) 22 | ax.plot(epochs, vals, label=r'%s'%(label)) 23 | 24 | ax.set_xlabel('Epoch', fontsize=18) 25 | ax.set_ylabel('Loss', fontsize=18) 26 | ax.set_title('Training Loss', fontsize=24) 27 | ax.grid() 28 | #plt.yscale('log') 29 | plt.legend(loc='upper right', numpoints=1, fontsize=16) 30 | print(figname) 31 | plt.tight_layout() 32 | fig.savefig(figname) 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Zigfried Hampel-Arias 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /clusgan/datasets.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | try: 4 | import numpy as np 5 | 6 | import torch 7 | import torchvision 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader 11 | from torchvision import datasets 12 | import torchvision.transforms as transforms 13 | except ImportError as e: 14 | print(e) 15 | raise ImportError 16 | 17 | 18 | DATASET_FN_DICT = {'mnist' : datasets.MNIST, 19 | 'fashion-mnist' : datasets.FashionMNIST 20 | } 21 | 22 | 23 | dataset_list = DATASET_FN_DICT.keys() 24 | 25 | 26 | def get_dataset(dataset_name='mnist'): 27 | """ 28 | Convenience function for retrieving 29 | allowed datasets. 30 | Parameters 31 | ---------- 32 | dataset_name : {'mnist', 'fashion-mnist'} 33 | Name of dataset 34 | Returns 35 | ------- 36 | fn : class 37 | PyTorch dataset class 38 | """ 39 | if dataset_name in DATASET_FN_DICT: 40 | fn = DATASET_FN_DICT[dataset_name] 41 | return fn 42 | else: 43 | raise ValueError('Invalid dataset, {}, entered.
Must be ' 44 | 'in {}'.format(dataset_name, DATASET_FN_DICT.keys())) 45 | 46 | 47 | 48 | def get_dataloader(dataset_name='mnist', data_dir='', batch_size=64, train_set=True, num_workers=1): 49 | 50 | dset = get_dataset(dataset_name) 51 | 52 | dataloader = torch.utils.data.DataLoader( 53 | dset(data_dir, train=train_set, download=True, 54 | transform=transforms.Compose([ 55 | transforms.ToTensor(), 56 | #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 57 | ])), 58 | num_workers=num_workers, 59 | batch_size=batch_size, 60 | shuffle=True) 61 | 62 | return dataloader 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore contents of virtual environment directory 2 | venv/* 3 | 4 | # Ignore downloaded datasets directory 5 | datasets/* 6 | 7 | # Ignore saved run information: models, figures, images 8 | runs/* 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Environments 94 | .env 95 | .venv 96 | env/ 97 | venv/ 98 | ENV/ 99 | env.bak/ 100 | venv.bak/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ClusterGAN: A PyTorch Implementation 2 | 3 | This is a PyTorch implementation of [ClusterGAN](https://arxiv.org/abs/1809.03627), 4 | an approach to unsupervised clustering using generative adversarial networks. 
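ClusterGAN samples its latent vectors as the concatenation of a continuous component `zn` and a one-hot cluster component `zc`. A minimal sketch of the sampling convention used throughout this package (`clusgan/utils.py`; a CUDA device is assumed, since the utilities allocate CUDA tensors):
```python
from clusgan.utils import sample_z

# zn: Gaussian noise scaled by 0.75, zc: one-hot cluster code, zc_idx: the sampled class indices
zn, zc, zc_idx = sample_z(shape=64, latent_dim=30, n_c=10)
```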
5 | 6 | 7 | ## Requirements 8 | 9 | The package and its requirements can be installed by running `make` or via 10 | ``` 11 | virtualenv -p /usr/local/bin/python3 venv 12 | source venv/bin/activate 13 | python setup.py install 14 | ``` 15 | 16 | ## Run ClusterGAN on MNIST 17 | 18 | To run ClusterGAN on the MNIST dataset, ensure the package is set up, then run 19 | ``` 20 | python train.py -r test_run -s mnist -b 64 -n 300 21 | ``` 22 | This creates a run directory under `runs/mnist/` whose name is prefixed with the training parameters (for the command above, `runs/mnist/300epoch_z30_van_bs64_test_run`), containing the generated output 23 | (models, example generated instances, training figures) from the training run. 24 | The `-r` option denotes the run name, `-s` the dataset (currently MNIST and Fashion-MNIST), 25 | `-b` the batch size, and `-n` the number of training epochs. 26 | 27 | 28 | Below is an example set of training curves for 200 epochs with a batch size of 64 on the MNIST dataset. 29 | 30 |
![Training model losses](docs/imgs/training_model_losses-mnist.png)
31 | 32 | 33 |
![Training cycle losses](docs/imgs/training_cycle_loss-mnist.png)
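Training with the Wasserstein distance and gradient penalty is also supported via the `-w` flag (see `train.py`), e.g.
```
python train.py -r test_run -s mnist -b 64 -n 300 -w
```
The `-g` option selects the GPU id and `-k` the number of dataloader workers.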
34 | 35 | 36 | ## Generated Examples 37 | To generate examples from randomly sampled latent space variables, run 38 | ``` 39 | python gen-examples.py -r runs/mnist/300epoch_z30_van_bs64_test_run -b 100 40 | ``` 41 | where `-r` points to the training run directory (the dataset name is inferred from the path) and `-b` sets the number of images generated per class. 42 | Here are some example images generated by specified class (one row per learned label in latent space). 43 |
![Generated examples by class](docs/imgs/gen_classes_000199-mnist.png)
44 | 45 |

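Generation can also be scripted directly against the package. A minimal sketch mirroring `gen-examples.py` (the run path is illustrative, and a CUDA device is assumed since the package utilities allocate CUDA tensors):
```python
import torch
from torchvision.utils import save_image

from clusgan.models import Generator_CNN
from clusgan.utils import sample_z

latent_dim, n_c = 30, 10  # values recorded in the run's training_details.csv

# Load the trained generator checkpoint saved by train.py
generator = Generator_CNN(latent_dim, n_c, x_shape=(1, 28, 28))
generator.load_state_dict(torch.load('runs/mnist/300epoch_z30_van_bs64_test_run/models/generator.pth.tar'))
generator.cuda().eval()

# Sample 25 latent vectors fixed to class 3 and save the generated digits in a 5x5 grid
zn, zc, zc_idx = sample_z(shape=25, latent_dim=latent_dim, n_c=n_c, fix_class=3, req_grad=False)
save_image(generator(zn, zc).data, 'class_3_gen.png', nrow=5, normalize=True)
```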
46 | 47 | ## t-SNE Figure 48 | To produce a t-SNE figure depicting the clustering of the latent-space encoding of real images, run 49 | ``` 50 | python tsne-cluster.py -r runs/mnist/300epoch_z30_van_bs64_test_run 51 | ``` 52 | 53 | Below is the t-SNE clustering figure of the latent-space vectors of real MNIST test images fed into the encoder. 54 | 55 |
![t-SNE embedding of encoded MNIST test images](docs/imgs/tsne-mnist-pca.png)
56 | 57 |

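When `-p` is omitted, the script uses PCA initialization for t-SNE (as in the figure above); the embedded vectors concatenate the encoder's continuous latent output with its class logits. A perplexity and sample count can also be specified, e.g.
```
python tsne-cluster.py -r runs/mnist/300epoch_z30_van_bs64_test_run -p 40 -n 1000
```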
58 | 59 | 60 | ## License 61 | 62 | [MIT License](LICENSE) 63 | 64 | Copyright (c) 2018 Zigfried Hampel-Arias 65 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import io 5 | import os 6 | import sys 7 | from shutil import rmtree 8 | from setuptools import setup, find_packages, Command 9 | 10 | NAME = 'clusgan' 11 | DESCRIPTION = 'A PyTorch Implementation of ClusterGAN' 12 | MAINTAINER = 'Zigfried Hampel-Arias' 13 | MAINTAINER_EMAIL = 'zhampel.github@gmail.com' 14 | URL = 'https://github.com/zhampel/clusterGAN' 15 | LICENSE = 'MIT' 16 | 17 | 18 | here = os.path.abspath(os.path.dirname(__file__)) 19 | 20 | def read(path, encoding='utf-8'): 21 | with io.open(path, encoding=encoding) as f: 22 | content = f.read() 23 | return content 24 | 25 | def get_install_requirements(path): 26 | content = read(path) 27 | requirements = [req for req in content.split("\n") 28 | if req != '' and not req.startswith('#')] 29 | return requirements 30 | 31 | # README 32 | LONG_DESCRIPTION = read(os.path.join(here, 'README.md')) 33 | 34 | 35 | # Want to read in package version number from __version__.py 36 | about = {} 37 | with io.open(os.path.join(here, 'clusgan', '__version__.py'), encoding='utf-8') as f: 38 | exec(f.read(), about) 39 | VERSION = about['__version__'] 40 | 41 | # requirements 42 | INSTALL_REQUIRES = get_install_requirements(os.path.join(here, 'requirements.txt')) 43 | 44 | class UploadCommand(Command): 45 | """Support setup.py upload.""" 46 | 47 | description = 'Build and publish the package.' 48 | user_options = [] 49 | 50 | @staticmethod 51 | def status(s): 52 | """Prints things in bold.""" 53 | print('\033[1m{0}\033[0m'.format(s)) 54 | 55 | def initialize_options(self): 56 | pass 57 | 58 | def finalize_options(self): 59 | pass 60 | 61 | def run(self): 62 | try: 63 | self.status('Removing previous builds…') 64 | rmtree(os.path.join(here, 'dist')) 65 | except OSError: 66 | pass 67 | 68 | self.status('Building Source and Wheel (universal) distribution…') 69 | os.system('{0} setup.py sdist bdist_wheel --universal'.format(sys.executable)) 70 | 71 | self.status('Uploading the package to PyPI via Twine…') 72 | os.system('twine upload dist/*') 73 | 74 | self.status('Pushing git tags…') 75 | os.system('git tag v{0}'.format(about['__version__'])) 76 | os.system('git push --tags') 77 | 78 | sys.exit() 79 | 80 | 81 | setup( 82 | name=NAME, 83 | version=VERSION, 84 | description=DESCRIPTION, 85 | license=LICENSE, 86 | long_description=LONG_DESCRIPTION, 87 | author=MAINTAINER, 88 | author_email=MAINTAINER_EMAIL, 89 | url=URL, 90 | packages=['clusgan'], 91 | install_requires=INSTALL_REQUIRES, #external packages as dependencies 92 | setup_requires=['setuptools>=38.6.0'], 93 | scripts=[ 94 | 'train.py', 95 | 'gen-examples.py', 96 | 'tsne-cluster.py', 97 | ], 98 | # $ setup.py publish support.
99 | cmdclass={ 100 | 'upload': UploadCommand, 101 | }, 102 | ) 103 | -------------------------------------------------------------------------------- /gen-examples.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | try: 4 | import argparse 5 | import os 6 | import numpy as np 7 | import sys 8 | np.set_printoptions(threshold=sys.maxsize) 9 | 10 | import matplotlib 11 | import matplotlib.pyplot as plt 12 | 13 | import pandas as pd 14 | 15 | from torch.autograd import Variable 16 | from torch.autograd import grad as torch_grad 17 | 18 | import torch 19 | import torchvision 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | from torch.utils.data import DataLoader 23 | from torchvision import datasets 24 | import torchvision.transforms as transforms 25 | from torchvision.utils import save_image 26 | 27 | from itertools import chain as ichain 28 | 29 | from clusgan.definitions import DATASETS_DIR, RUNS_DIR 30 | from clusgan.models import Generator_CNN, Encoder_CNN, Discriminator_CNN 31 | from clusgan.utils import sample_z 32 | from clusgan.datasets import get_dataloader, dataset_list 33 | 34 | from sklearn.manifold import TSNE 35 | except ImportError as e: 36 | print(e) 37 | raise ImportError 38 | 39 | def main(): 40 | global args 41 | parser = argparse.ArgumentParser(description="Script to save generated examples from learned ClusterGAN generator") 42 | parser.add_argument("-r", "--run_dir", dest="run_dir", help="Training run directory") 43 | parser.add_argument("-b", "--batch_size", dest="batch_size", default=100, type=int, help="Batch size") 44 | args = parser.parse_args() 45 | 46 | batch_size = args.batch_size 47 | 48 | # Directory structure for this run 49 | run_dir = args.run_dir.rstrip("/") 50 | run_name = run_dir.split(os.sep)[-1] 51 | dataset_name = run_dir.split(os.sep)[-2] 52 | 53 | run_dir = os.path.join(RUNS_DIR, dataset_name, run_name) 54 | data_dir = os.path.join(DATASETS_DIR, dataset_name) 55 | imgs_dir = os.path.join(run_dir, 'images') 56 | models_dir = os.path.join(run_dir, 'models') 57 | 58 | 59 | # Latent space info 60 | train_df = pd.read_csv('%s/training_details.csv'%(run_dir)) 61 | latent_dim = train_df['latent_dim'][0] 62 | n_c = train_df['n_classes'][0] 63 | 64 | cuda = True if torch.cuda.is_available() else False 65 | 66 | # Load encoder model 67 | encoder = Encoder_CNN(latent_dim, n_c) 68 | enc_fname = os.path.join(models_dir, encoder.name + '.pth.tar') 69 | encoder.load_state_dict(torch.load(enc_fname)) 70 | encoder.cuda() 71 | encoder.eval() 72 | 73 | # Load generator model 74 | x_shape = (1, 28, 28) 75 | generator = Generator_CNN(latent_dim, n_c, x_shape) 76 | gen_fname = os.path.join(models_dir, generator.name + '.pth.tar') 77 | generator.load_state_dict(torch.load(gen_fname)) 78 | generator.cuda() 79 | generator.eval() 80 | 81 | # Loop through specific classes 82 | for idx in range(n_c): 83 | zn, zc, zc_idx = sample_z(shape=batch_size, latent_dim=latent_dim, n_c=n_c, fix_class=idx, req_grad=False) 84 | 85 | # Generate a batch of images 86 | gen_imgs = generator(zn, zc) 87 | 88 | # Save some examples! 89 | save_image(gen_imgs.data, '%s/class_%i_gen.png' %(imgs_dir, idx), 90 | nrow=int(np.sqrt(batch_size)), normalize=True) 91 | 92 | enc_zn, enc_zc, enc_zc_logits = encoder(gen_imgs) 93 | 94 | # Generate a batch of images 95 | gen_imgs = generator(enc_zn, enc_zc) 96 | 97 | # Save some examples! 
98 | save_image(gen_imgs.data, '%s/class_enc_%i_gen.png' %(imgs_dir, idx), 99 | nrow=int(np.sqrt(batch_size)), normalize=True) 100 | enc_zn, enc_zc, enc_zc_logits = encoder(gen_imgs) 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /tsne-cluster.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | try: 4 | import argparse 5 | import os 6 | import numpy as np 7 | import sys 8 | np.set_printoptions(threshold=sys.maxsize) 9 | 10 | import matplotlib 11 | import matplotlib.pyplot as plt 12 | import matplotlib.cm as cm 13 | 14 | import pandas as pd 15 | 16 | from torch.autograd import Variable 17 | from torch.autograd import grad as torch_grad 18 | 19 | import torch 20 | import torchvision 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | from torch.utils.data import DataLoader 24 | from torchvision import datasets 25 | import torchvision.transforms as transforms 26 | from torchvision.utils import save_image 27 | 28 | from itertools import chain as ichain 29 | 30 | from clusgan.definitions import DATASETS_DIR, RUNS_DIR 31 | from clusgan.models import Generator_CNN, Encoder_CNN, Discriminator_CNN 32 | from clusgan.datasets import get_dataloader, dataset_list 33 | 34 | from sklearn.manifold import TSNE 35 | except ImportError as e: 36 | print(e) 37 | raise ImportError 38 | 39 | 40 | def main(): 41 | global args 42 | parser = argparse.ArgumentParser(description="TSNE generation script") 43 | parser.add_argument("-r", "--run_dir", dest="run_dir", help="Training run directory") 44 | parser.add_argument("-p", "--perplexity", dest="perplexity", default=-1, type=int, help="TSNE perplexity") 45 | parser.add_argument("-n", "--n_samples", dest="n_samples", default=100, type=int, help="Number of samples") 46 | args = parser.parse_args() 47 | 48 | # TSNE setup 49 | n_samples = args.n_samples 50 | perplexity = args.perplexity 51 | 52 | # Directory structure for this run 53 | run_dir = args.run_dir.rstrip("/") 54 | run_name = run_dir.split(os.sep)[-1] 55 | dataset_name = run_dir.split(os.sep)[-2] 56 | 57 | run_dir = os.path.join(RUNS_DIR, dataset_name, run_name) 58 | data_dir = os.path.join(DATASETS_DIR, dataset_name) 59 | imgs_dir = os.path.join(run_dir, 'images') 60 | models_dir = os.path.join(run_dir, 'models') 61 | 62 | 63 | # Latent space info 64 | train_df = pd.read_csv('%s/training_details.csv'%(run_dir)) 65 | latent_dim = train_df['latent_dim'][0] 66 | n_c = train_df['n_classes'][0] 67 | 68 | cuda = True if torch.cuda.is_available() else False 69 | 70 | # Load encoder model 71 | encoder = Encoder_CNN(latent_dim, n_c) 72 | enc_figname = os.path.join(models_dir, encoder.name + '.pth.tar') 73 | encoder.load_state_dict(torch.load(enc_figname)) 74 | encoder.cuda() 75 | encoder.eval() 76 | 77 | # Configure data loader 78 | dataloader = get_dataloader(dataset_name=dataset_name, data_dir=data_dir, batch_size=n_samples, train_set=False) 79 | 80 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 81 | 82 | # Load TSNE 83 | if (perplexity < 0): 84 | tsne = TSNE(n_components=2, verbose=1, init='pca', random_state=0) 85 | fig_title = "PCA Initialization" 86 | figname = os.path.join(run_dir, 'tsne-pca.png') 87 | else: 88 | tsne = TSNE(n_components=2, verbose=1, perplexity=perplexity, n_iter=300) 89 | fig_title = "Perplexity = $%d$"%perplexity 90 | figname = os.path.join(run_dir, 'tsne-plex%i.png'%perplexity) 
91 | 92 | # Get full batch for encoding 93 | imgs, labels = next(iter(dataloader)) 94 | c_imgs = Variable(imgs.type(Tensor), requires_grad=False) 95 | 96 | # Encode real images 97 | enc_zn, enc_zc, enc_zc_logits = encoder(c_imgs) 98 | # Stack latent space encoding 99 | enc = np.hstack((enc_zn.cpu().detach().numpy(), enc_zc_logits.cpu().detach().numpy())) 100 | #enc = np.hstack((enc_zn.cpu().detach().numpy(), enc_zc.cpu().detach().numpy())) 101 | 102 | # Cluster with TSNE 103 | tsne_enc = tsne.fit_transform(enc) 104 | 105 | # Convert to numpy for indexing purposes 106 | labels = labels.cpu().data.numpy() 107 | 108 | # Color and marker for each true class 109 | colors = cm.rainbow(np.linspace(0, 1, n_c)) 110 | markers = matplotlib.markers.MarkerStyle.filled_markers 111 | 112 | # Save TSNE figure to file 113 | fig, ax = plt.subplots(figsize=(16,10)) 114 | for iclass in range(0, n_c): 115 | # Get indices for each class 116 | idxs = labels==iclass 117 | # Scatter those points in tsne dims 118 | ax.scatter(tsne_enc[idxs, 0], 119 | tsne_enc[idxs, 1], 120 | marker=markers[iclass], 121 | c=colors[iclass], 122 | edgecolor=None, 123 | label=r'$%i$'%iclass) 124 | 125 | ax.set_title(r'%s'%fig_title, fontsize=24) 126 | ax.set_xlabel(r'$X^{\mathrm{tSNE}}_1$', fontsize=18) 127 | ax.set_ylabel(r'$X^{\mathrm{tSNE}}_2$', fontsize=18) 128 | plt.legend(title=r'Class', loc='best', numpoints=1, fontsize=16) 129 | plt.tight_layout() 130 | fig.savefig(figname) 131 | 132 | if __name__ == "__main__": 133 | main() 134 | -------------------------------------------------------------------------------- /clusgan/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | try: 4 | import os 5 | import numpy as np 6 | 7 | from torch.autograd import Variable 8 | from torch.autograd import grad as torch_grad 9 | 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch 13 | 14 | from itertools import chain as ichain 15 | 16 | except ImportError as e: 17 | print(e) 18 | raise ImportError 19 | 20 | 21 | 22 | # Nan-avoiding logarithm 23 | def tlog(x): 24 | return torch.log(x + 1e-8) 25 | 26 | 27 | # Softmax function 28 | def softmax(x): 29 | return F.softmax(x, dim=1) 30 | 31 | 32 | # Cross Entropy loss with two vector inputs 33 | def cross_entropy(pred, soft_targets): 34 | log_softmax_pred = torch.nn.functional.log_softmax(pred, dim=1) 35 | return torch.mean( torch.sum(- soft_targets * log_softmax_pred, 1) ) 36 | 37 | 38 | # Save a provided model to file 39 | def save_model(models=[], out_dir=''): 40 | 41 | # Ensure at least one model to save 42 | assert len(models) > 0, "Must have at least one model to save." 
43 | 44 | # Save models to directory out_dir 45 | for model in models: 46 | filename = model.name + '.pth.tar' 47 | outfile = os.path.join(out_dir, filename) 48 | torch.save(model.state_dict(), outfile) 49 | 50 | 51 | # Weight Initializer 52 | def initialize_weights(net): 53 | for m in net.modules(): 54 | if isinstance(m, nn.Conv2d): 55 | m.weight.data.normal_(0, 0.02) 56 | m.bias.data.zero_() 57 | elif isinstance(m, nn.ConvTranspose2d): 58 | m.weight.data.normal_(0, 0.02) 59 | m.bias.data.zero_() 60 | elif isinstance(m, nn.Linear): 61 | m.weight.data.normal_(0, 0.02) 62 | m.bias.data.zero_() 63 | 64 | 65 | def weights_init(m): 66 | classname = m.__class__.__name__ 67 | if classname.find('Conv') != -1: 68 | m.weight.data.normal_(0.0, 0.02) 69 | elif classname.find('BatchNorm') != -1: 70 | m.weight.data.normal_(1.0, 0.02) 71 | m.bias.data.fill_(0) 72 | 73 | # Sample a random latent space vector 74 | def sample_z(shape=64, latent_dim=10, n_c=10, fix_class=-1, req_grad=False): 75 | 76 | assert (fix_class == -1 or (fix_class >= 0 and fix_class < n_c) ), "Requested class %i outside bounds."%fix_class 77 | 78 | Tensor = torch.cuda.FloatTensor 79 | 80 | # Sample noise as generator input, zn 81 | zn = Variable(Tensor(0.75*np.random.normal(0, 1, (shape, latent_dim))), requires_grad=req_grad) 82 | 83 | ######### zc, zc_idx variables with grads, and zc to one-hot vector 84 | # Pure one-hot vector generation 85 | zc_FT = Tensor(shape, n_c).fill_(0) 86 | zc_idx = torch.empty(shape, dtype=torch.long) 87 | 88 | if (fix_class == -1): 89 | zc_idx = zc_idx.random_(n_c).cuda() 90 | zc_FT = zc_FT.scatter_(1, zc_idx.unsqueeze(1), 1.) 91 | #zc_idx = torch.empty(shape, dtype=torch.long).random_(n_c).cuda() 92 | #zc_FT = Tensor(shape, n_c).fill_(0).scatter_(1, zc_idx.unsqueeze(1), 1.) 
93 | else: 94 | zc_idx[:] = fix_class 95 | zc_FT[:, fix_class] = 1 96 | 97 | zc_idx = zc_idx.cuda() 98 | zc_FT = zc_FT.cuda() 99 | 100 | zc = Variable(zc_FT, requires_grad=req_grad) 101 | 102 | ## Gaussian-noisey vector generation 103 | #zc = Variable(Tensor(np.random.normal(0, 1, (shape, n_c))), requires_grad=req_grad) 104 | #zc = softmax(zc) 105 | #zc_idx = torch.argmax(zc, dim=1) 106 | 107 | # Return components of latent space variable 108 | return zn, zc, zc_idx 109 | 110 | 111 | def calc_gradient_penalty(netD, real_data, generated_data): 112 | # GP strength 113 | LAMBDA = 10 114 | 115 | b_size = real_data.size()[0] 116 | 117 | # Calculate interpolation 118 | alpha = torch.rand(b_size, 1, 1, 1) 119 | alpha = alpha.expand_as(real_data) 120 | alpha = alpha.cuda() 121 | 122 | interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data 123 | interpolated = Variable(interpolated, requires_grad=True) 124 | interpolated = interpolated.cuda() 125 | 126 | # Calculate probability of interpolated examples 127 | prob_interpolated = netD(interpolated) 128 | 129 | # Calculate gradients of probabilities with respect to examples 130 | gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated, 131 | grad_outputs=torch.ones(prob_interpolated.size()).cuda(), 132 | create_graph=True, retain_graph=True)[0] 133 | 134 | # Gradients have shape (batch_size, num_channels, img_width, img_height), 135 | # so flatten to easily take norm per example in batch 136 | gradients = gradients.view(b_size, -1) 137 | 138 | # Derivatives of the gradient close to 0 can cause problems because of 139 | # the square root, so manually calculate norm and add epsilon 140 | gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12) 141 | 142 | # Return gradient penalty 143 | return LAMBDA * ((gradients_norm - 1) ** 2).mean() 144 | -------------------------------------------------------------------------------- /clusgan/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | try: 4 | import numpy as np 5 | 6 | from torch.autograd import Variable 7 | from torch.autograd import grad as torch_grad 8 | 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch 12 | 13 | from itertools import chain as ichain 14 | 15 | from clusgan.utils import tlog, softmax, initialize_weights, calc_gradient_penalty 16 | except ImportError as e: 17 | print(e) 18 | raise ImportError 19 | 20 | 21 | class Reshape(nn.Module): 22 | """ 23 | Class for performing a reshape as a layer in a sequential model. 24 | """ 25 | def __init__(self, shape=[]): 26 | super(Reshape, self).__init__() 27 | self.shape = shape 28 | 29 | def forward(self, x): 30 | return x.view(x.size(0), *self.shape) 31 | 32 | def extra_repr(self): 33 | # (Optional)Set the extra information about this module. You can test 34 | # it by printing an object of this class. 
35 | return 'shape={}'.format( 36 | self.shape 37 | ) 38 | 39 | 40 | class Generator_CNN(nn.Module): 41 | """ 42 | CNN to model the generator of a ClusterGAN 43 | Input is a vector from the representation space of dimension z_dim; 44 | output is a vector from the image space of dimension X_dim 45 | """ 46 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S 47 | def __init__(self, latent_dim, n_c, x_shape, verbose=False): 48 | super(Generator_CNN, self).__init__() 49 | 50 | self.name = 'generator' 51 | self.latent_dim = latent_dim 52 | self.n_c = n_c 53 | self.x_shape = x_shape 54 | self.ishape = (128, 7, 7) 55 | self.iels = int(np.prod(self.ishape)) 56 | self.verbose = verbose 57 | 58 | self.model = nn.Sequential( 59 | # Fully connected layers 60 | torch.nn.Linear(self.latent_dim + self.n_c, 1024), 61 | nn.BatchNorm1d(1024), 62 | #torch.nn.ReLU(True), 63 | nn.LeakyReLU(0.2, inplace=True), 64 | torch.nn.Linear(1024, self.iels), 65 | nn.BatchNorm1d(self.iels), 66 | #torch.nn.ReLU(True), 67 | nn.LeakyReLU(0.2, inplace=True), 68 | 69 | # Reshape to 128 x (7x7) 70 | Reshape(self.ishape), 71 | 72 | # Upconvolution layers 73 | nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1, bias=True), 74 | nn.BatchNorm2d(64), 75 | #torch.nn.ReLU(True), 76 | nn.LeakyReLU(0.2, inplace=True), 77 | 78 | nn.ConvTranspose2d(64, 1, 4, stride=2, padding=1, bias=True), 79 | nn.Sigmoid() 80 | ) 81 | 82 | initialize_weights(self) 83 | 84 | if self.verbose: 85 | print("Setting up {}...\n".format(self.name)) 86 | print(self.model) 87 | 88 | def forward(self, zn, zc): 89 | z = torch.cat((zn, zc), 1) 90 | #z = z.unsqueeze(2).unsqueeze(3) 91 | x_gen = self.model(z) 92 | # Reshape for output 93 | x_gen = x_gen.view(x_gen.size(0), *self.x_shape) 94 | return x_gen 95 | 96 | 97 | class Encoder_CNN(nn.Module): 98 | """ 99 | CNN to model the encoder of a ClusterGAN 100 | Input is a vector X from the image space of dimension X_dim 101 | Output is a vector z from the representation space of dimension z_dim 102 | """ 103 | def __init__(self, latent_dim, n_c, verbose=False): 104 | super(Encoder_CNN, self).__init__() 105 | 106 | self.name = 'encoder' 107 | self.channels = 1 108 | self.latent_dim = latent_dim 109 | self.n_c = n_c 110 | self.cshape = (128, 5, 5) 111 | self.iels = int(np.prod(self.cshape)) 112 | self.lshape = (self.iels,) 113 | self.verbose = verbose 114 | 115 | self.model = nn.Sequential( 116 | # Convolutional layers 117 | nn.Conv2d(self.channels, 64, 4, stride=2, bias=True), 118 | nn.LeakyReLU(0.2, inplace=True), 119 | nn.Conv2d(64, 128, 4, stride=2, bias=True), 120 | nn.LeakyReLU(0.2, inplace=True), 121 | 122 | # Flatten 123 | Reshape(self.lshape), 124 | 125 | # Fully connected layers 126 | torch.nn.Linear(self.iels, 1024), 127 | nn.LeakyReLU(0.2, inplace=True), 128 | torch.nn.Linear(1024, latent_dim + n_c) 129 | ) 130 | 131 | initialize_weights(self) 132 | 133 | if self.verbose: 134 | print("Setting up {}...\n".format(self.name)) 135 | print(self.model) 136 | 137 | def forward(self, in_feat): 138 | z_img = self.model(in_feat) 139 | # Reshape for output 140 | z = z_img.view(z_img.shape[0], -1) 141 | # Separate continuous and one-hot components 142 | zn = z[:, 0:self.latent_dim] 143 | zc_logits = z[:, self.latent_dim:] 144 | # Softmax on zc component 145 | zc = softmax(zc_logits) 146 | return zn, zc, zc_logits 147 | 148 | 149 | class Discriminator_CNN(nn.Module): 150 | """ 151 | CNN to model the discriminator of a ClusterGAN 152 | Input is tuple (X,z) of an image vector and its corresponding 153 | representation z vector.
For example, if X comes from the dataset, the corresponding 154 | z is Encoder(X); if z is sampled from the representation space, X is Generator(z). 155 | Output is a 1-dimensional value 156 | """ 157 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S 158 | def __init__(self, wass_metric=False, verbose=False): 159 | super(Discriminator_CNN, self).__init__() 160 | 161 | self.name = 'discriminator' 162 | self.channels = 1 163 | self.cshape = (128, 5, 5) 164 | self.iels = int(np.prod(self.cshape)) 165 | self.lshape = (self.iels,) 166 | self.wass = wass_metric 167 | self.verbose = verbose 168 | 169 | self.model = nn.Sequential( 170 | # Convolutional layers 171 | nn.Conv2d(self.channels, 64, 4, stride=2, bias=True), 172 | nn.LeakyReLU(0.2, inplace=True), 173 | nn.Conv2d(64, 128, 4, stride=2, bias=True), 174 | nn.LeakyReLU(0.2, inplace=True), 175 | 176 | # Flatten 177 | Reshape(self.lshape), 178 | 179 | # Fully connected layers 180 | torch.nn.Linear(self.iels, 1024), 181 | nn.LeakyReLU(0.2, inplace=True), 182 | torch.nn.Linear(1024, 1), 183 | ) 184 | 185 | # If NOT using Wasserstein metric, final Sigmoid 186 | if (not self.wass): 187 | self.model = nn.Sequential(self.model, torch.nn.Sigmoid()) 188 | 189 | initialize_weights(self) 190 | 191 | if self.verbose: 192 | print("Setting up {}...\n".format(self.name)) 193 | print(self.model) 194 | 195 | def forward(self, img): 196 | # Get output 197 | validity = self.model(img) 198 | return validity 199 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | try: 4 | import argparse 5 | import os 6 | import numpy as np 7 | 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | 11 | import pandas as pd 12 | 13 | from torch.autograd import Variable 14 | from torch.autograd import grad as torch_grad 15 | 16 | import torch 17 | import torchvision 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | from torch.utils.data import DataLoader 21 | from torchvision import datasets 22 | import torchvision.transforms as transforms 23 | from torchvision.utils import save_image 24 | 25 | from itertools import chain as ichain 26 | 27 | from clusgan.definitions import DATASETS_DIR, RUNS_DIR 28 | from clusgan.models import Generator_CNN, Encoder_CNN, Discriminator_CNN 29 | from clusgan.utils import save_model, calc_gradient_penalty, sample_z, cross_entropy 30 | from clusgan.datasets import get_dataloader, dataset_list 31 | from clusgan.plots import plot_train_loss 32 | except ImportError as e: 33 | print(e) 34 | raise ImportError 35 | 36 | def main(): 37 | global args 38 | parser = argparse.ArgumentParser(description="Convolutional NN Training Script") 39 | parser.add_argument("-r", "--run_name", dest="run_name", default='clusgan', help="Name of training run") 40 | parser.add_argument("-n", "--n_epochs", dest="n_epochs", default=200, type=int, help="Number of epochs") 41 | parser.add_argument("-b", "--batch_size", dest="batch_size", default=64, type=int, help="Batch size") 42 | parser.add_argument("-s", "--dataset_name", dest="dataset_name", default='mnist', choices=dataset_list, help="Dataset name") 43 | parser.add_argument("-w", "--wass_metric", dest="wass_metric", action='store_true', help="Flag for Wasserstein metric") 44 | parser.add_argument("-g", "--gpu", dest="gpu", default=0, type=int, help="GPU id to use") 45 | parser.add_argument("-k", "--num_workers",
dest="num_workers", default=1, type=int, help="Number of dataset workers") 46 | args = parser.parse_args() 47 | 48 | run_name = args.run_name 49 | dataset_name = args.dataset_name 50 | device_id = args.gpu 51 | num_workers = args.num_workers 52 | 53 | # Training details 54 | n_epochs = args.n_epochs 55 | batch_size = args.batch_size 56 | test_batch_size = 5000 57 | lr = 1e-4 58 | b1 = 0.5 59 | b2 = 0.9 #99 60 | decay = 2.5*1e-5 61 | n_skip_iter = 1 #5 62 | 63 | img_size = 28 64 | channels = 1 65 | 66 | # Latent space info 67 | latent_dim = 30 68 | n_c = 10 69 | betan = 10 70 | betac = 10 71 | 72 | # Wasserstein metric flag 73 | # Wasserstein metric flag 74 | wass_metric = args.wass_metric 75 | mtype = 'van' 76 | if (wass_metric): 77 | mtype = 'wass' 78 | 79 | # Make directory structure for this run 80 | sep_und = '_' 81 | run_name_comps = ['%iepoch'%n_epochs, 'z%s'%str(latent_dim), mtype, 'bs%i'%batch_size, run_name] 82 | run_name = sep_und.join(run_name_comps) 83 | 84 | run_dir = os.path.join(RUNS_DIR, dataset_name, run_name) 85 | data_dir = os.path.join(DATASETS_DIR, dataset_name) 86 | imgs_dir = os.path.join(run_dir, 'images') 87 | models_dir = os.path.join(run_dir, 'models') 88 | 89 | os.makedirs(data_dir, exist_ok=True) 90 | os.makedirs(run_dir, exist_ok=True) 91 | os.makedirs(imgs_dir, exist_ok=True) 92 | os.makedirs(models_dir, exist_ok=True) 93 | print('\nResults to be saved in directory %s\n'%(run_dir)) 94 | 95 | x_shape = (channels, img_size, img_size) 96 | 97 | cuda = True if torch.cuda.is_available() else False 98 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 99 | torch.cuda.set_device(device_id) 100 | 101 | # Loss function 102 | bce_loss = torch.nn.BCELoss() 103 | xe_loss = torch.nn.CrossEntropyLoss() 104 | mse_loss = torch.nn.MSELoss() 105 | 106 | # Initialize generator and discriminator 107 | generator = Generator_CNN(latent_dim, n_c, x_shape) 108 | encoder = Encoder_CNN(latent_dim, n_c) 109 | discriminator = Discriminator_CNN(wass_metric=wass_metric) 110 | 111 | if cuda: 112 | generator.cuda() 113 | encoder.cuda() 114 | discriminator.cuda() 115 | bce_loss.cuda() 116 | xe_loss.cuda() 117 | mse_loss.cuda() 118 | 119 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 120 | 121 | # Configure training data loader 122 | dataloader = get_dataloader(dataset_name=dataset_name, 123 | data_dir=data_dir, 124 | batch_size=batch_size, 125 | num_workers=num_workers) 126 | 127 | # Test data loader 128 | testdata = get_dataloader(dataset_name=dataset_name, data_dir=data_dir, batch_size=test_batch_size, train_set=False) 129 | test_imgs, test_labels = next(iter(testdata)) 130 | test_imgs = Variable(test_imgs.type(Tensor)) 131 | 132 | ge_chain = ichain(generator.parameters(), 133 | encoder.parameters()) 134 | optimizer_GE = torch.optim.Adam(ge_chain, lr=lr, betas=(b1, b2), weight_decay=decay) 135 | optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2)) 136 | #optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2), weight_decay=decay) 137 | 138 | # ---------- 139 | # Training 140 | # ---------- 141 | ge_l = [] 142 | d_l = [] 143 | 144 | c_zn = [] 145 | c_zc = [] 146 | c_i = [] 147 | 148 | # Training loop 149 | print('\nBegin training session with %i epochs...\n'%(n_epochs)) 150 | for epoch in range(n_epochs): 151 | for i, (imgs, itruth_label) in enumerate(dataloader): 152 | 153 | # Ensure generator/encoder are trainable 154 | generator.train() 155 | encoder.train() 156 | # Zero gradients for models 
157 | generator.zero_grad() 158 | encoder.zero_grad() 159 | discriminator.zero_grad() 160 | 161 | # Configure input 162 | real_imgs = Variable(imgs.type(Tensor)) 163 | valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)  # adversarial target, defined every iteration so the discriminator step below can use it even when the G/E step is skipped 164 | # --------------------------- 165 | # Train Generator + Encoder 166 | # --------------------------- 167 | 168 | optimizer_GE.zero_grad() 169 | 170 | # Sample random latent variables 171 | zn, zc, zc_idx = sample_z(shape=imgs.shape[0], 172 | latent_dim=latent_dim, 173 | n_c=n_c) 174 | 175 | # Generate a batch of images 176 | gen_imgs = generator(zn, zc) 177 | 178 | # Discriminator output from real and generated samples 179 | D_gen = discriminator(gen_imgs) 180 | D_real = discriminator(real_imgs) 181 | 182 | # Step for Generator & Encoder, n_skip_iter times less than for discriminator 183 | if (i % n_skip_iter == 0): 184 | # Encode the generated images 185 | enc_gen_zn, enc_gen_zc, enc_gen_zc_logits = encoder(gen_imgs) 186 | 187 | # Calculate losses for z_n, z_c 188 | zn_loss = mse_loss(enc_gen_zn, zn) 189 | zc_loss = xe_loss(enc_gen_zc_logits, zc_idx) 190 | #zc_loss = cross_entropy(enc_gen_zc_logits, zc) 191 | 192 | # Check requested metric 193 | if wass_metric: 194 | # Wasserstein GAN loss 195 | ge_loss = torch.mean(D_gen) + betan * zn_loss + betac * zc_loss 196 | else: 197 | # Vanilla GAN loss 198 | valid = Variable(Tensor(gen_imgs.size(0), 1).fill_(1.0), requires_grad=False) 199 | v_loss = bce_loss(D_gen, valid) 200 | ge_loss = v_loss + betan * zn_loss + betac * zc_loss 201 | 202 | ge_loss.backward(retain_graph=True) 203 | optimizer_GE.step() 204 | 205 | # --------------------- 206 | # Train Discriminator 207 | # --------------------- 208 | 209 | optimizer_D.zero_grad() 210 | 211 | # Measure discriminator's ability to classify real from generated samples 212 | if wass_metric: 213 | # Gradient penalty term 214 | grad_penalty = calc_gradient_penalty(discriminator, real_imgs, gen_imgs) 215 | 216 | # Wasserstein GAN loss w/gradient penalty 217 | d_loss = torch.mean(D_real) - torch.mean(D_gen) + grad_penalty 218 | 219 | else: 220 | # Vanilla GAN loss 221 | fake = Variable(Tensor(gen_imgs.size(0), 1).fill_(0.0), requires_grad=False) 222 | real_loss = bce_loss(D_real, valid) 223 | fake_loss = bce_loss(D_gen, fake) 224 | d_loss = (real_loss + fake_loss) / 2 225 | 226 | d_loss.backward() 227 | optimizer_D.step() 228 | 229 | 230 | # Save training losses 231 | d_l.append(d_loss.item()) 232 | ge_l.append(ge_loss.item()) 233 | 234 | 235 | # Generator in eval mode 236 | generator.eval() 237 | encoder.eval() 238 | 239 | # Set number of examples for cycle calcs 240 | n_sqrt_samp = 5 241 | n_samp = n_sqrt_samp * n_sqrt_samp 242 | 243 | 244 | ## Cycle through test real -> enc -> gen 245 | t_imgs, t_label = test_imgs.data, test_labels 246 | #r_imgs, i_label = real_imgs.data[:n_samp], itruth_label[:n_samp] 247 | # Encode sample real instances 248 | e_tzn, e_tzc, e_tzc_logits = encoder(t_imgs) 249 | # Generate sample instances from encoding 250 | teg_imgs = generator(e_tzn, e_tzc) 251 | # Calculate cycle reconstruction loss 252 | img_mse_loss = mse_loss(t_imgs, teg_imgs) 253 | # Save img reco cycle loss 254 | c_i.append(img_mse_loss.item()) 255 | 256 | 257 | ## Cycle through randomly sampled encoding -> generator -> encoder 258 | zn_samp, zc_samp, zc_samp_idx = sample_z(shape=n_samp, 259 | latent_dim=latent_dim, 260 | n_c=n_c) 261 | # Generate sample instances 262 | gen_imgs_samp = generator(zn_samp, zc_samp) 263 | # Encode sample instances 264 | zn_e, zc_e, zc_e_logits = encoder(gen_imgs_samp) 265 | # Calculate cycle latent losses 266 |
lat_mse_loss = mse_loss(zn_e, zn_samp) 267 | lat_xe_loss = xe_loss(zc_e_logits, zc_samp_idx) 268 | #lat_xe_loss = cross_entropy(zc_e_logits, zc_samp) 269 | # Save latent space cycle losses 270 | c_zn.append(lat_mse_loss.item()) 271 | c_zc.append(lat_xe_loss.item()) 272 | 273 | # Save cycled and generated examples! 274 | r_imgs, i_label = real_imgs.data[:n_samp], itruth_label[:n_samp] 275 | e_zn, e_zc, e_zc_logits = encoder(r_imgs) 276 | reg_imgs = generator(e_zn, e_zc) 277 | save_image(r_imgs.data[:n_samp], 278 | '%s/real_%06i.png' %(imgs_dir, epoch), 279 | nrow=n_sqrt_samp, normalize=True) 280 | save_image(reg_imgs.data[:n_samp], 281 | '%s/reg_%06i.png' %(imgs_dir, epoch), 282 | nrow=n_sqrt_samp, normalize=True) 283 | save_image(gen_imgs_samp.data[:n_samp], 284 | '%s/gen_%06i.png' %(imgs_dir, epoch), 285 | nrow=n_sqrt_samp, normalize=True) 286 | 287 | ## Generate samples for specified classes 288 | stack_imgs = [] 289 | for idx in range(n_c): 290 | # Sample specific class 291 | zn_samp, zc_samp, zc_samp_idx = sample_z(shape=n_c, 292 | latent_dim=latent_dim, 293 | n_c=n_c, 294 | fix_class=idx) 295 | 296 | # Generate sample instances 297 | gen_imgs_samp = generator(zn_samp, zc_samp) 298 | 299 | if (len(stack_imgs) == 0): 300 | stack_imgs = gen_imgs_samp 301 | else: 302 | stack_imgs = torch.cat((stack_imgs, gen_imgs_samp), 0) 303 | 304 | # Save class-specified generated examples! 305 | save_image(stack_imgs, 306 | '%s/gen_classes_%06i.png' %(imgs_dir, epoch), 307 | nrow=n_c, normalize=True) 308 | 309 | 310 | print ("[Epoch %d/%d] \n"\ 311 | "\tModel Losses: [D: %f] [GE: %f]" % (epoch, 312 | n_epochs, 313 | d_loss.item(), 314 | ge_loss.item()) 315 | ) 316 | 317 | print("\tCycle Losses: [x: %f] [z_n: %f] [z_c: %f]"%(img_mse_loss.item(), 318 | lat_mse_loss.item(), 319 | lat_xe_loss.item()) 320 | ) 321 | 322 | 323 | 324 | 325 | # Save training results 326 | train_df = pd.DataFrame({ 327 | 'n_epochs' : n_epochs, 328 | 'learning_rate' : lr, 329 | 'beta_1' : b1, 330 | 'beta_2' : b2, 331 | 'weight_decay' : decay, 332 | 'n_skip_iter' : n_skip_iter, 333 | 'latent_dim' : latent_dim, 334 | 'n_classes' : n_c, 335 | 'beta_n' : betan, 336 | 'beta_c' : betac, 337 | 'wass_metric' : wass_metric, 338 | 'gen_enc_loss' : ['G+E', ge_l], 339 | 'disc_loss' : ['D', d_l], 340 | 'zn_cycle_loss' : ['$||Z_n-E(G(x))_n||$', c_zn], 341 | 'zc_cycle_loss' : ['$||Z_c-E(G(x))_c||$', c_zc], 342 | 'img_cycle_loss' : ['$||X-G(E(x))||$', c_i] 343 | }) 344 | 345 | train_df.to_csv('%s/training_details.csv'%(run_dir)) 346 | 347 | 348 | # Plot some training results 349 | plot_train_loss(df=train_df, 350 | arr_list=['gen_enc_loss', 'disc_loss'], 351 | figname='%s/training_model_losses.png'%(run_dir) 352 | ) 353 | 354 | plot_train_loss(df=train_df, 355 | arr_list=['zn_cycle_loss', 'zc_cycle_loss', 'img_cycle_loss'], 356 | figname='%s/training_cycle_loss.png'%(run_dir) 357 | ) 358 | 359 | 360 | # Save current state of trained models 361 | model_list = [discriminator, encoder, generator] 362 | save_model(models=model_list, out_dir=models_dir) 363 | 364 | 365 | if __name__ == "__main__": 366 | main() 367 | --------------------------------------------------------------------------------