├── clusgan
│   ├── __version__.py
│   ├── __init__.py
│   ├── definitions.py
│   ├── plots.py
│   ├── datasets.py
│   ├── utils.py
│   └── models.py
├── docs
│   └── imgs
│       ├── tsne-mnist-pca.png
│       ├── gen_classes_000199-mnist.png
│       ├── training_cycle_loss-mnist.png
│       └── training_model_losses-mnist.png
├── requirements.txt
├── Makefile
├── LICENSE
├── .gitignore
├── README.md
├── setup.py
├── gen-examples.py
├── tsne-cluster.py
└── train.py
/clusgan/__version__.py:
--------------------------------------------------------------------------------
1 |
2 | __version__ = '0.1.dev0'
3 |
--------------------------------------------------------------------------------
/docs/imgs/tsne-mnist-pca.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhampel/clusterGAN/HEAD/docs/imgs/tsne-mnist-pca.png
--------------------------------------------------------------------------------
/docs/imgs/gen_classes_000199-mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhampel/clusterGAN/HEAD/docs/imgs/gen_classes_000199-mnist.png
--------------------------------------------------------------------------------
/docs/imgs/training_cycle_loss-mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhampel/clusterGAN/HEAD/docs/imgs/training_cycle_loss-mnist.png
--------------------------------------------------------------------------------
/docs/imgs/training_model_losses-mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhampel/clusterGAN/HEAD/docs/imgs/training_model_losses-mnist.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==2.2.2
2 | numpy==1.14.5
3 | pandas==0.23.2
4 | scipy==1.1.0
5 | seaborn==0.9.0
6 | scikit-learn
7 | torch==1.0.0
8 | torchvision==0.2.1
9 | tqdm==4.23.4
10 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | all: install
2 |
3 | install:
4 | python setup.py install
5 |
6 | develop:
7 | python setup.py develop
8 |
9 | develop-uninstall:
10 | python setup.py develop --uninstall
11 |
--------------------------------------------------------------------------------
/clusgan/__init__.py:
--------------------------------------------------------------------------------
1 | # __init__.py
2 |
3 | from .__version__ import __version__
4 |
5 | from .utils import *
6 | from .models import *
7 | from .plots import *
8 | from .datasets import *
9 | from .definitions import *
10 |
--------------------------------------------------------------------------------
/clusgan/definitions.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # Local directory of the clusgan package
4 | CLUSGAN_DIR = os.path.dirname(os.path.abspath(__file__))
5 |
6 | # Local directory containing entire repo
7 | REPO_DIR = os.path.split(CLUSGAN_DIR)[0]
8 |
9 | # Local directory for datasets
10 | DATASETS_DIR = os.path.join(REPO_DIR, 'datasets')
11 |
12 | # Local directory for runs
13 | RUNS_DIR = os.path.join(REPO_DIR, 'runs')
14 |
--------------------------------------------------------------------------------
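
Note: a quick way to check where these constants resolve on a given machine (the paths shown are illustrative; they depend on where the repository lives):

```
>>> from clusgan.definitions import REPO_DIR, DATASETS_DIR, RUNS_DIR
>>> REPO_DIR       # e.g. '/home/user/clusterGAN'
>>> DATASETS_DIR   # e.g. '/home/user/clusterGAN/datasets' (created by train.py as needed)
>>> RUNS_DIR       # e.g. '/home/user/clusterGAN/runs'
```
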
/clusgan/plots.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | try:
4 | import os
5 | import numpy as np
6 |
7 | import matplotlib
8 | import matplotlib.pyplot as plt
9 |
10 | except ImportError as e:
11 | print(e)
12 | raise ImportError
13 |
14 |
15 | def plot_train_loss(df, arr_list, figname='training_loss.png'):
16 |
17 | fig, ax = plt.subplots(figsize=(16,10))
18 | for arr in arr_list:
19 | label = df[arr][0]
20 | vals = df[arr][1]
21 | epochs = range(0, len(vals))
22 | ax.plot(epochs, vals, label=r'%s'%(label))
23 |
24 | ax.set_xlabel('Epoch', fontsize=18)
25 | ax.set_ylabel('Loss', fontsize=18)
26 | ax.set_title('Training Loss', fontsize=24)
27 | ax.grid()
28 | #plt.yscale('log')
29 | plt.legend(loc='upper right', numpoints=1, fontsize=16)
30 | print('Saving figure to %s'%(figname))
31 | plt.tight_layout()
32 | fig.savefig(figname)
33 |
--------------------------------------------------------------------------------
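
The df argument is expected to follow the layout train.py builds: each column named in arr_list holds a two-element sequence of (legend label, list of loss values). A minimal sketch with made-up numbers:

```
import pandas as pd
from clusgan.plots import plot_train_loss

# Row 0 holds the legend label, row 1 the loss history (the layout train.py uses)
df = pd.DataFrame({
    'gen_enc_loss' : ['G+E', [2.3, 1.9, 1.5, 1.2]],
    'disc_loss'    : ['D',   [0.9, 0.8, 0.75, 0.7]],
})
plot_train_loss(df=df, arr_list=['gen_enc_loss', 'disc_loss'], figname='toy_losses.png')
```
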
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Zigfried Hampel-Arias
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/clusgan/datasets.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | try:
4 | import numpy as np
5 |
6 | import torch
7 | import torchvision
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.utils.data import DataLoader
11 | from torchvision import datasets
12 | import torchvision.transforms as transforms
13 | except ImportError as e:
14 | print(e)
15 | raise ImportError
16 |
17 |
18 | DATASET_FN_DICT = {'mnist' : datasets.MNIST,
19 | 'fashion-mnist' : datasets.FashionMNIST
20 | }
21 |
22 |
23 | dataset_list = DATASET_FN_DICT.keys()
24 |
25 |
26 | def get_dataset(dataset_name='mnist'):
27 | """
28 | Convenience function for retrieving
29 | allowed datasets.
30 | Parameters
31 | ----------
32 | dataset_name : {'mnist', 'fashion-mnist'}
33 | Name of dataset
34 | Returns
35 | -------
36 | fn : torch.utils.data.Dataset subclass
37 | The requested PyTorch dataset class
38 | """
39 | if dataset_name in DATASET_FN_DICT:
40 | fn = DATASET_FN_DICT[dataset_name]
41 | return fn
42 | else:
43 | raise ValueError('Invalid dataset, {}, entered. Must be '
44 | 'in {}'.format(dataset_name, DATASET_FN_DICT.keys()))
45 |
46 |
47 |
48 | def get_dataloader(dataset_name='mnist', data_dir='', batch_size=64, train_set=True, num_workers=1):
49 |
50 | dset = get_dataset(dataset_name)
51 |
52 | dataloader = torch.utils.data.DataLoader(
53 | dset(data_dir, train=train_set, download=True,
54 | transform=transforms.Compose([
55 | transforms.ToTensor(),
56 | #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
57 | ])),
58 | num_workers=num_workers,
59 | batch_size=batch_size,
60 | shuffle=True)
61 |
62 | return dataloader
63 |
--------------------------------------------------------------------------------
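
For example, a single Fashion-MNIST batch can be pulled as below (the ./data directory is illustrative; train.py passes a path under DATASETS_DIR):

```
from clusgan.datasets import get_dataloader

# Downloads the dataset into ./data on first use
loader = get_dataloader(dataset_name='fashion-mnist', data_dir='./data', batch_size=32)
imgs, labels = next(iter(loader))
print(imgs.shape)    # torch.Size([32, 1, 28, 28]), pixel values in [0, 1]
print(labels.shape)  # torch.Size([32])
```
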
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore contents of virtual environment directory
2 | venv/*
3 |
4 | # Ignore downloaded datasets directory
5 | datasets/*
6 |
7 | # Ignore saved run information: models, figures, images
8 | runs/*
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # celery beat schedule file
88 | celerybeat-schedule
89 |
90 | # SageMath parsed files
91 | *.sage.py
92 |
93 | # Environments
94 | .env
95 | .venv
96 | env/
97 | venv/
98 | ENV/
99 | env.bak/
100 | venv.bak/
101 |
102 | # Spyder project settings
103 | .spyderproject
104 | .spyproject
105 |
106 | # Rope project settings
107 | .ropeproject
108 |
109 | # mkdocs documentation
110 | /site
111 |
112 | # mypy
113 | .mypy_cache/
114 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ClusterGAN: A PyTorch Implementation
2 |
3 | This is a PyTorch implementation of [ClusterGAN](https://arxiv.org/abs/1809.03627),
4 | an approach to unsupervised clustering using generative adversarial networks.
5 |
6 |
7 | ## Requirements
8 |
9 | The package as well as the necessary requirements can be installed by running `make` or via
10 | ```
11 | virtualenv -p /usr/local/bin/python3 venv
12 | source venv/bin/activate
13 | python setup.py install
14 | ```
15 |
16 | ## Run ClusterGAN on MNIST
17 |
18 | To run ClusterGAN on the MNIST dataset, ensure the package is set up and then run
19 | ```
20 | python train.py -r test_run -s mnist -b 64 -n 300
21 | ```
22 | where a directory `runs/mnist/300epoch_z30_van_bs64_test_run` will be made, containing the generated
23 | output (models, example generated instances, training figures) from the training run; the prefix encodes
24 | the epoch count, latent dimension, GAN variant, and batch size. The `-r` option denotes the run name,
25 | `-s` the dataset (currently MNIST and Fashion-MNIST), `-b` the batch size, and `-n` the number of training epochs.
26 |
27 |
28 | Below is an example set of training curves from a 200-epoch run with a batch size of 64 on the MNIST dataset.
29 |
30 | ![Training model losses](docs/imgs/training_model_losses-mnist.png)
31 |
32 | ![Training cycle losses](docs/imgs/training_cycle_loss-mnist.png)
33 |
34 |
35 |
36 | ## Generated Examples
37 | To generate examples from randomly sampled latent space variables, run
38 | ```
39 | python gen-examples.py -r runs/mnist/300epoch_z30_van_bs64_test_run -b 100
40 | ```
41 |
42 | Here are some example generated images, one row per specified class of the learned labels in latent space.
43 |
44 | ![Class-specified generated examples](docs/imgs/gen_classes_000199-mnist.png)
45 |
46 |
47 | ## TSNE Figure
48 | To produce a TSNE figure depicting the clustering of the latent space encoding of real images, run
49 | ```
50 | python tsne-cluster.py -r runs/mnist/300epoch_z30_van_bs64_test_run
51 | ```
52 |
53 | Below is the tSNE clustering figure of the latent space vectors of true MNIST images fed into the encoder.
54 |
55 | ![tSNE clustering of encoded MNIST test images](docs/imgs/tsne-mnist-pca.png)
56 |
57 |
58 |
59 |
60 | ## License
61 |
62 | [MIT License](LICENSE)
63 |
64 | Copyright (c) 2019 Zigfried Hampel-Arias
65 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import io
5 | import os
6 | import sys
7 | from shutil import rmtree
8 | from setuptools import setup, find_packages, Command
9 |
10 | NAME = 'clusgan'
11 | DESCRIPTION = 'A PyTorch Implementation of ClusterGAN'
12 | MAINTAINER = 'Zigfried Hampel-Arias'
13 | MAINTAINER_EMAIL = 'zhampel.github@gmail.com'
14 | URL = 'https://github.com/zhampel/clusterGAN'
15 | LICENSE = 'MIT'
16 |
17 |
18 | here = os.path.abspath(os.path.dirname(__file__))
19 |
20 | def read(path, encoding='utf-8'):
21 | with io.open(path, encoding=encoding) as f:
22 | content = f.read()
23 | return content
24 |
25 | def get_install_requirements(path):
26 | content = read(path)
27 | requirements = [req for req in content.split("\n")
28 | if req != '' and not req.startswith('#')]
29 | return requirements
30 |
31 | # README
32 | LONG_DESCRIPTION = read(os.path.join(here, 'README.md'))
33 |
34 |
35 | # Want to read in package version number from __version__.py
36 | about = {}
37 | with io.open(os.path.join(here, 'clusgan', '__version__.py'), encoding='utf-8') as f:
38 | exec(f.read(), about)
39 | VERSION = about['__version__']
40 |
41 | # requirements
42 | INSTALL_REQUIRES = get_install_requirements(os.path.join(here, 'requirements.txt'))
43 |
44 | class UploadCommand(Command):
45 | """Support setup.py upload."""
46 |
47 | description = 'Build and publish the package.'
48 | user_options = []
49 |
50 | @staticmethod
51 | def status(s):
52 | """Prints things in bold."""
53 | print('\033[1m{0}\033[0m'.format(s))
54 |
55 | def initialize_options(self):
56 | pass
57 |
58 | def finalize_options(self):
59 | pass
60 |
61 | def run(self):
62 | try:
63 | self.status('Removing previous builds…')
64 | rmtree(os.path.join(here, 'dist'))
65 | except OSError:
66 | pass
67 |
68 | self.status('Building Source and Wheel (universal) distribution…')
69 | os.system('{0} setup.py sdist bdist_wheel --universal'.format(sys.executable))
70 |
71 | self.status('Uploading the package to PyPI via Twine…')
72 | os.system('twine upload dist/*')
73 |
74 | self.status('Pushing git tags…')
75 | os.system('git tag v{0}'.format(about['__version__']))
76 | os.system('git push --tags')
77 |
78 | sys.exit()
79 |
80 |
81 | setup(
82 | name=NAME,
83 | version=VERSION,
84 | description=DESCRIPTION,
85 | license=LICENSE,
86 | long_description=LONG_DESCRIPTION,
87 | long_description_content_type='text/markdown',
88 | author=MAINTAINER,
89 | author_email=MAINTAINER_EMAIL,
90 | url=URL,
91 | packages=['clusgan'],
92 | install_requires=INSTALL_REQUIRES, # external packages as dependencies
93 | setup_requires=['setuptools>=38.6.0'],
94 | scripts=[
95 | 'train.py',
96 | 'gen-examples.py',
97 | 'tsne-cluster.py',
98 | ],
99 | # $ setup.py publish support.
100 | cmdclass={
101 | 'upload': UploadCommand,
102 | },
103 | )
104 |
--------------------------------------------------------------------------------
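
With the custom command registered, publishing is a single step; this assumes twine is installed and PyPI credentials are configured:

```
python setup.py upload
```

Per the UploadCommand code above, this removes any previous dist/ build, creates sdist and universal wheel distributions, uploads them via twine, and pushes a v<version> git tag.
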
/gen-examples.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | try:
4 | import argparse
5 | import os
6 | import numpy as np
7 | import sys
8 | np.set_printoptions(threshold=sys.maxsize)
9 |
10 | import matplotlib
11 | import matplotlib.pyplot as plt
12 |
13 | import pandas as pd
14 |
15 | from torch.autograd import Variable
16 | from torch.autograd import grad as torch_grad
17 |
18 | import torch
19 | import torchvision
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | from torch.utils.data import DataLoader
23 | from torchvision import datasets
24 | import torchvision.transforms as transforms
25 | from torchvision.utils import save_image
26 |
27 | from itertools import chain as ichain
28 |
29 | from clusgan.definitions import DATASETS_DIR, RUNS_DIR
30 | from clusgan.models import Generator_CNN, Encoder_CNN, Discriminator_CNN
31 | from clusgan.utils import sample_z
32 | from clusgan.datasets import get_dataloader, dataset_list
33 |
34 | from sklearn.manifold import TSNE
35 | except ImportError as e:
36 | print(e)
37 | raise ImportError
38 |
39 | def main():
40 | global args
41 | parser = argparse.ArgumentParser(description="Script to save generated examples from learned ClusterGAN generator")
42 | parser.add_argument("-r", "--run_dir", dest="run_dir", help="Training run directory")
43 | parser.add_argument("-b", "--batch_size", dest="batch_size", default=100, type=int, help="Batch size")
44 | args = parser.parse_args()
45 |
46 | batch_size = args.batch_size
47 |
48 | # Directory structure for this run
49 | run_dir = args.run_dir.rstrip("/")
50 | run_name = run_dir.split(os.sep)[-1]
51 | dataset_name = run_dir.split(os.sep)[-2]
52 |
53 | run_dir = os.path.join(RUNS_DIR, dataset_name, run_name)
54 | data_dir = os.path.join(DATASETS_DIR, dataset_name)
55 | imgs_dir = os.path.join(run_dir, 'images')
56 | models_dir = os.path.join(run_dir, 'models')
57 |
58 |
59 | # Latent space info
60 | train_df = pd.read_csv('%s/training_details.csv'%(run_dir))
61 | latent_dim = train_df['latent_dim'][0]
62 | n_c = train_df['n_classes'][0]
63 |
64 | cuda = True if torch.cuda.is_available() else False
65 |
66 | # Load encoder model
67 | encoder = Encoder_CNN(latent_dim, n_c)
68 | enc_fname = os.path.join(models_dir, encoder.name + '.pth.tar')
69 | encoder.load_state_dict(torch.load(enc_fname, map_location='cpu'))
70 | if cuda: encoder.cuda()
71 | encoder.eval()
72 |
73 | # Load generator model
74 | x_shape = (1, 28, 28)
75 | generator = Generator_CNN(latent_dim, n_c, x_shape)
76 | gen_fname = os.path.join(models_dir, generator.name + '.pth.tar')
77 | generator.load_state_dict(torch.load(gen_fname, map_location='cpu'))
78 | if cuda: generator.cuda()
79 | generator.eval()
80 |
81 | # Loop through specific classes
82 | for idx in range(n_c):
83 | zn, zc, zc_idx = sample_z(shape=batch_size, latent_dim=latent_dim, n_c=n_c, fix_class=idx, req_grad=False)
84 |
85 | # Generate a batch of images
86 | gen_imgs = generator(zn, zc)
87 |
88 | # Save some examples!
89 | save_image(gen_imgs.data, '%s/class_%i_gen.png' %(imgs_dir, idx),
90 | nrow=int(np.sqrt(batch_size)), normalize=True)
91 |
92 | enc_zn, enc_zc, enc_zc_logits = encoder(gen_imgs)
93 |
94 | # Generate a batch of images
95 | gen_imgs = generator(enc_zn, enc_zc)
96 |
97 | # Save some examples!
98 | save_image(gen_imgs.data, '%s/class_enc_%i_gen.png' %(imgs_dir, idx),
99 | nrow=int(np.sqrt(batch_size)), normalize=True)
100 |
101 |
102 |
103 | if __name__ == "__main__":
104 | main()
105 |
--------------------------------------------------------------------------------
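
Condensed, the per-class loop above amounts to the following sketch; the run path is illustrative, and latent_dim/n_c must match the values recorded in the run's training_details.csv:

```
import torch
from clusgan.models import Generator_CNN
from clusgan.utils import sample_z

latent_dim, n_c = 30, 10                      # defaults used by train.py
generator = Generator_CNN(latent_dim, n_c, x_shape=(1, 28, 28))
generator.load_state_dict(torch.load('runs/mnist/my_run/models/generator.pth.tar',
                                     map_location='cpu'))
generator.eval()

# 25 latent samples pinned to class 3, decoded to images
zn, zc, zc_idx = sample_z(shape=25, latent_dim=latent_dim, n_c=n_c, fix_class=3)
gen_imgs = generator(zn, zc)                  # shape (25, 1, 28, 28)
```
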
/tsne-cluster.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | try:
4 | import argparse
5 | import os
6 | import numpy as np
7 | import sys
8 | np.set_printoptions(threshold=sys.maxsize)
9 |
10 | import matplotlib
11 | import matplotlib.pyplot as plt
12 | import matplotlib.cm as cm
13 |
14 | import pandas as pd
15 |
16 | from torch.autograd import Variable
17 | from torch.autograd import grad as torch_grad
18 |
19 | import torch
20 | import torchvision
21 | import torch.nn as nn
22 | import torch.nn.functional as F
23 | from torch.utils.data import DataLoader
24 | from torchvision import datasets
25 | import torchvision.transforms as transforms
26 | from torchvision.utils import save_image
27 |
28 | from itertools import chain as ichain
29 |
30 | from clusgan.definitions import DATASETS_DIR, RUNS_DIR
31 | from clusgan.models import Generator_CNN, Encoder_CNN, Discriminator_CNN
32 | from clusgan.datasets import get_dataloader, dataset_list
33 |
34 | from sklearn.manifold import TSNE
35 | except ImportError as e:
36 | print(e)
37 | raise ImportError
38 |
39 |
40 | def main():
41 | global args
42 | parser = argparse.ArgumentParser(description="TSNE generation script")
43 | parser.add_argument("-r", "--run_dir", dest="run_dir", help="Training run directory")
44 | parser.add_argument("-p", "--perplexity", dest="perplexity", default=-1, type=int, help="TSNE perplexity")
45 | parser.add_argument("-n", "--n_samples", dest="n_samples", default=100, type=int, help="Number of samples")
46 | args = parser.parse_args()
47 |
48 | # TSNE setup
49 | n_samples = args.n_samples
50 | perplexity = args.perplexity
51 |
52 | # Directory structure for this run
53 | run_dir = args.run_dir.rstrip("/")
54 | run_name = run_dir.split(os.sep)[-1]
55 | dataset_name = run_dir.split(os.sep)[-2]
56 |
57 | run_dir = os.path.join(RUNS_DIR, dataset_name, run_name)
58 | data_dir = os.path.join(DATASETS_DIR, dataset_name)
59 | imgs_dir = os.path.join(run_dir, 'images')
60 | models_dir = os.path.join(run_dir, 'models')
61 |
62 |
63 | # Latent space info
64 | train_df = pd.read_csv('%s/training_details.csv'%(run_dir))
65 | latent_dim = train_df['latent_dim'][0]
66 | n_c = train_df['n_classes'][0]
67 |
68 | cuda = True if torch.cuda.is_available() else False
69 |
70 | # Load encoder model
71 | encoder = Encoder_CNN(latent_dim, n_c)
72 | enc_fname = os.path.join(models_dir, encoder.name + '.pth.tar')
73 | encoder.load_state_dict(torch.load(enc_fname, map_location='cpu'))
74 | if cuda: encoder.cuda()
75 | encoder.eval()
76 |
77 | # Configure data loader
78 | dataloader = get_dataloader(dataset_name=dataset_name, data_dir=data_dir, batch_size=n_samples, train_set=False)
79 |
80 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
81 |
82 | # Load TSNE
83 | if (perplexity < 0):
84 | tsne = TSNE(n_components=2, verbose=1, init='pca', random_state=0)
85 | fig_title = "PCA Initialization"
86 | figname = os.path.join(run_dir, 'tsne-pca.png')
87 | else:
88 | tsne = TSNE(n_components=2, verbose=1, perplexity=perplexity, n_iter=300)
89 | fig_title = "Perplexity = $%d$"%perplexity
90 | figname = os.path.join(run_dir, 'tsne-plex%i.png'%perplexity)
91 |
92 | # Get full batch for encoding
93 | imgs, labels = next(iter(dataloader))
94 | c_imgs = Variable(imgs.type(Tensor), requires_grad=False)
95 |
96 | # Encode real images
97 | enc_zn, enc_zc, enc_zc_logits = encoder(c_imgs)
98 | # Stack latent space encoding
99 | enc = np.hstack((enc_zn.cpu().detach().numpy(), enc_zc_logits.cpu().detach().numpy()))
100 | #enc = np.hstack((enc_zn.cpu().detach().numpy(), enc_zc.cpu().detach().numpy()))
101 |
102 | # Cluster with TSNE
103 | tsne_enc = tsne.fit_transform(enc)
104 |
105 | # Convert to numpy for indexing purposes
106 | labels = labels.cpu().data.numpy()
107 |
108 | # Color and marker for each true class
109 | colors = cm.rainbow(np.linspace(0, 1, n_c))
110 | markers = matplotlib.markers.MarkerStyle.filled_markers
111 |
112 | # Save TSNE figure to file
113 | fig, ax = plt.subplots(figsize=(16,10))
114 | for iclass in range(0, n_c):
115 | # Get indices for each class
116 | idxs = labels==iclass
117 | # Scatter those points in tsne dims
118 | ax.scatter(tsne_enc[idxs, 0],
119 | tsne_enc[idxs, 1],
120 | marker=markers[iclass],
121 | c=colors[iclass],
122 | edgecolor=None,
123 | label=r'$%i$'%iclass)
124 |
125 | ax.set_title(r'%s'%fig_title, fontsize=24)
126 | ax.set_xlabel(r'$X^{\mathrm{tSNE}}_1$', fontsize=18)
127 | ax.set_ylabel(r'$X^{\mathrm{tSNE}}_2$', fontsize=18)
128 | plt.legend(title=r'Class', loc='best', numpoints=1, fontsize=16)
129 | plt.tight_layout()
130 | fig.savefig(figname)
131 |
132 | if __name__ == "__main__":
133 | main()
134 |
--------------------------------------------------------------------------------
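
The run directory convention matches gen-examples.py: point -r at a completed run directory. A negative perplexity (the default) selects the PCA-initialized tSNE, otherwise the given value is used. For example, with the run directory from the README:

```
python tsne-cluster.py -r runs/mnist/300epoch_z30_van_bs64_test_run -p 40 -n 1000
```
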
/clusgan/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | try:
4 | import os
5 | import numpy as np
6 |
7 | from torch.autograd import Variable
8 | from torch.autograd import grad as torch_grad
9 |
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torch
13 |
14 | from itertools import chain as ichain
15 |
16 | except ImportError as e:
17 | print(e)
18 | raise ImportError
19 |
20 |
21 |
22 | # Nan-avoiding logarithm
23 | def tlog(x):
24 | return torch.log(x + 1e-8)
25 |
26 |
27 | # Softmax function
28 | def softmax(x):
29 | return F.softmax(x, dim=1)
30 |
31 |
32 | # Cross Entropy loss with two vector inputs
33 | def cross_entropy(pred, soft_targets):
34 | log_softmax_pred = torch.nn.functional.log_softmax(pred, dim=1)
35 | return torch.mean( torch.sum(- soft_targets * log_softmax_pred, 1) )
36 |
37 |
38 | # Save a provided model to file
39 | def save_model(models=[], out_dir=''):
40 |
41 | # Ensure at least one model to save
42 | assert len(models) > 0, "Must have at least one model to save."
43 |
44 | # Save models to directory out_dir
45 | for model in models:
46 | filename = model.name + '.pth.tar'
47 | outfile = os.path.join(out_dir, filename)
48 | torch.save(model.state_dict(), outfile)
49 |
50 |
51 | # Weight Initializer
52 | def initialize_weights(net):
53 | for m in net.modules():
54 | if isinstance(m, nn.Conv2d):
55 | m.weight.data.normal_(0, 0.02)
56 | m.bias.data.zero_()
57 | elif isinstance(m, nn.ConvTranspose2d):
58 | m.weight.data.normal_(0, 0.02)
59 | m.bias.data.zero_()
60 | elif isinstance(m, nn.Linear):
61 | m.weight.data.normal_(0, 0.02)
62 | m.bias.data.zero_()
63 |
64 |
65 | def weights_init(m):
66 | classname = m.__class__.__name__
67 | if classname.find('Conv') != -1:
68 | m.weight.data.normal_(0.0, 0.02)
69 | elif classname.find('BatchNorm') != -1:
70 | m.weight.data.normal_(1.0, 0.02)
71 | m.bias.data.fill_(0)
72 |
73 | # Sample a random latent space vector
74 | def sample_z(shape=64, latent_dim=10, n_c=10, fix_class=-1, req_grad=False):
75 |
76 | assert (fix_class == -1 or (fix_class >= 0 and fix_class < n_c) ), "Requested class %i outside bounds."%fix_class
77 |
78 | Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
79 |
80 | # Sample noise as generator input, zn
81 | zn = Variable(Tensor(0.75*np.random.normal(0, 1, (shape, latent_dim))), requires_grad=req_grad)
82 |
83 | ######### zc, zc_idx variables with grads, and zc to one-hot vector
84 | # Pure one-hot vector generation
85 | zc_FT = Tensor(shape, n_c).fill_(0)
86 | zc_idx = torch.empty(shape, dtype=torch.long)
87 |
88 | if (fix_class == -1):
89 | zc_idx = zc_idx.random_(n_c).to(zc_FT.device)
90 | zc_FT = zc_FT.scatter_(1, zc_idx.unsqueeze(1), 1.)
91 | #zc_idx = torch.empty(shape, dtype=torch.long).random_(n_c).cuda()
92 | #zc_FT = Tensor(shape, n_c).fill_(0).scatter_(1, zc_idx.unsqueeze(1), 1.)
93 | else:
94 | zc_idx[:] = fix_class
95 | zc_FT[:, fix_class] = 1
96 |
97 | zc_idx = zc_idx.to(zc_FT.device)
98 | # zc_FT already lives on the correct device
99 |
100 | zc = Variable(zc_FT, requires_grad=req_grad)
101 |
102 | ## Gaussian-noisey vector generation
103 | #zc = Variable(Tensor(np.random.normal(0, 1, (shape, n_c))), requires_grad=req_grad)
104 | #zc = softmax(zc)
105 | #zc_idx = torch.argmax(zc, dim=1)
106 |
107 | # Return components of latent space variable
108 | return zn, zc, zc_idx
109 |
110 |
111 | def calc_gradient_penalty(netD, real_data, generated_data):
112 | # GP strength
113 | LAMBDA = 10
114 |
115 | b_size = real_data.size()[0]
116 |
117 | # Calculate interpolation
118 | alpha = torch.rand(b_size, 1, 1, 1)
119 | alpha = alpha.expand_as(real_data)
120 | alpha = alpha.to(real_data.device)
121 |
122 | interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data
123 | interpolated = Variable(interpolated, requires_grad=True)
124 | interpolated = interpolated.to(real_data.device)
125 |
126 | # Calculate probability of interpolated examples
127 | prob_interpolated = netD(interpolated)
128 |
129 | # Calculate gradients of probabilities with respect to examples
130 | gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
131 | grad_outputs=torch.ones(prob_interpolated.size(), device=prob_interpolated.device),
132 | create_graph=True, retain_graph=True)[0]
133 |
134 | # Gradients have shape (batch_size, num_channels, img_width, img_height),
135 | # so flatten to easily take norm per example in batch
136 | gradients = gradients.view(b_size, -1)
137 |
138 | # Derivatives of the gradient close to 0 can cause problems because of
139 | # the square root, so manually calculate norm and add epsilon
140 | gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
141 |
142 | # Return gradient penalty
143 | return LAMBDA * ((gradients_norm - 1) ** 2).mean()
144 |
--------------------------------------------------------------------------------
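
A quick shape check of sample_z, assuming the train.py defaults latent_dim=30, n_c=10 (values are random):

```
from clusgan.utils import sample_z

zn, zc, zc_idx = sample_z(shape=8, latent_dim=30, n_c=10, fix_class=2)
print(zn.shape)    # torch.Size([8, 30]) -- Gaussian component, scaled by 0.75
print(zc.shape)    # torch.Size([8, 10]) -- one-hot class component
print(zc_idx)      # tensor([2, 2, 2, 2, 2, 2, 2, 2]) -- the fixed class index
```
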
/clusgan/models.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | try:
4 | import numpy as np
5 |
6 | from torch.autograd import Variable
7 | from torch.autograd import grad as torch_grad
8 |
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch
12 |
13 | from itertools import chain as ichain
14 |
15 | from clusgan.utils import tlog, softmax, initialize_weights, calc_gradient_penalty
16 | except ImportError as e:
17 | print(e)
18 | raise ImportError
19 |
20 |
21 | class Reshape(nn.Module):
22 | """
23 | Class for performing a reshape as a layer in a sequential model.
24 | """
25 | def __init__(self, shape=[]):
26 | super(Reshape, self).__init__()
27 | self.shape = shape
28 |
29 | def forward(self, x):
30 | return x.view(x.size(0), *self.shape)
31 |
32 | def extra_repr(self):
33 | # (Optional)Set the extra information about this module. You can test
34 | # it by printing an object of this class.
35 | return 'shape={}'.format(
36 | self.shape
37 | )
38 |
39 |
40 | class Generator_CNN(nn.Module):
41 | """
42 | CNN to model the generator of a ClusterGAN
43 | Input is a vector from representation space of dimension z_dim
44 | output is a vector from image space of dimension X_dim
45 | """
46 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
47 | def __init__(self, latent_dim, n_c, x_shape, verbose=False):
48 | super(Generator_CNN, self).__init__()
49 |
50 | self.name = 'generator'
51 | self.latent_dim = latent_dim
52 | self.n_c = n_c
53 | self.x_shape = x_shape
54 | self.ishape = (128, 7, 7)
55 | self.iels = int(np.prod(self.ishape))
56 | self.verbose = verbose
57 |
58 | self.model = nn.Sequential(
59 | # Fully connected layers
60 | torch.nn.Linear(self.latent_dim + self.n_c, 1024),
61 | nn.BatchNorm1d(1024),
62 | #torch.nn.ReLU(True),
63 | nn.LeakyReLU(0.2, inplace=True),
64 | torch.nn.Linear(1024, self.iels),
65 | nn.BatchNorm1d(self.iels),
66 | #torch.nn.ReLU(True),
67 | nn.LeakyReLU(0.2, inplace=True),
68 |
69 | # Reshape to 128 x (7x7)
70 | Reshape(self.ishape),
71 |
72 | # Upconvolution layers
73 | nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1, bias=True),
74 | nn.BatchNorm2d(64),
75 | #torch.nn.ReLU(True),
76 | nn.LeakyReLU(0.2, inplace=True),
77 |
78 | nn.ConvTranspose2d(64, 1, 4, stride=2, padding=1, bias=True),
79 | nn.Sigmoid()
80 | )
81 |
82 | initialize_weights(self)
83 |
84 | if self.verbose:
85 | print("Setting up {}...\n".format(self.name))
86 | print(self.model)
87 |
88 | def forward(self, zn, zc):
89 | z = torch.cat((zn, zc), 1)
90 | #z = z.unsqueeze(2).unsqueeze(3)
91 | x_gen = self.model(z)
92 | # Reshape for output
93 | x_gen = x_gen.view(x_gen.size(0), *self.x_shape)
94 | return x_gen
95 |
96 |
97 | class Encoder_CNN(nn.Module):
98 | """
99 | CNN to model the encoder of a ClusterGAN
100 | Input is vector X from image space of dimension X_dim
101 | Output is vector z from representation space of dimension z_dim
102 | """
103 | def __init__(self, latent_dim, n_c, verbose=False):
104 | super(Encoder_CNN, self).__init__()
105 |
106 | self.name = 'encoder'
107 | self.channels = 1
108 | self.latent_dim = latent_dim
109 | self.n_c = n_c
110 | self.cshape = (128, 5, 5)
111 | self.iels = int(np.prod(self.cshape))
112 | self.lshape = (self.iels,)
113 | self.verbose = verbose
114 |
115 | self.model = nn.Sequential(
116 | # Convolutional layers
117 | nn.Conv2d(self.channels, 64, 4, stride=2, bias=True),
118 | nn.LeakyReLU(0.2, inplace=True),
119 | nn.Conv2d(64, 128, 4, stride=2, bias=True),
120 | nn.LeakyReLU(0.2, inplace=True),
121 |
122 | # Flatten
123 | Reshape(self.lshape),
124 |
125 | # Fully connected layers
126 | torch.nn.Linear(self.iels, 1024),
127 | nn.LeakyReLU(0.2, inplace=True),
128 | torch.nn.Linear(1024, latent_dim + n_c)
129 | )
130 |
131 | initialize_weights(self)
132 |
133 | if self.verbose:
134 | print("Setting up {}...\n".format(self.name))
135 | print(self.model)
136 |
137 | def forward(self, in_feat):
138 | z_img = self.model(in_feat)
139 | # Reshape for output
140 | z = z_img.view(z_img.shape[0], -1)
141 | # Separate continuous and one-hot components
142 | zn = z[:, 0:self.latent_dim]
143 | zc_logits = z[:, self.latent_dim:]
144 | # Softmax on zc component
145 | zc = softmax(zc_logits)
146 | return zn, zc, zc_logits
147 |
148 |
149 | class Discriminator_CNN(nn.Module):
150 | """
151 | CNN to model the discriminator of a ClusterGAN
152 | Input is an image X, either real (drawn from the dataset) or
153 | generated, i.e. X = Generator(z) for a sampled latent vector z
154 | Output is a 1-dimensional validity score: a sigmoid probability for the
155 | vanilla GAN, or an unbounded critic value for the Wasserstein variant
156 | """
157 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
158 | def __init__(self, wass_metric=False, verbose=False):
159 | super(Discriminator_CNN, self).__init__()
160 |
161 | self.name = 'discriminator'
162 | self.channels = 1
163 | self.cshape = (128, 5, 5)
164 | self.iels = int(np.prod(self.cshape))
165 | self.lshape = (self.iels,)
166 | self.wass = wass_metric
167 | self.verbose = verbose
168 |
169 | self.model = nn.Sequential(
170 | # Convolutional layers
171 | nn.Conv2d(self.channels, 64, 4, stride=2, bias=True),
172 | nn.LeakyReLU(0.2, inplace=True),
173 | nn.Conv2d(64, 128, 4, stride=2, bias=True),
174 | nn.LeakyReLU(0.2, inplace=True),
175 |
176 | # Flatten
177 | Reshape(self.lshape),
178 |
179 | # Fully connected layers
180 | torch.nn.Linear(self.iels, 1024),
181 | nn.LeakyReLU(0.2, inplace=True),
182 | torch.nn.Linear(1024, 1),
183 | )
184 |
185 | # If NOT using Wasserstein metric, final Sigmoid
186 | if (not self.wass):
187 | self.model = nn.Sequential(self.model, torch.nn.Sigmoid())
188 |
189 | initialize_weights(self)
190 |
191 | if self.verbose:
192 | print("Setting up {}...\n".format(self.name))
193 | print(self.model)
194 |
195 | def forward(self, img):
196 | # Get output
197 | validity = self.model(img)
198 | return validity
199 |
--------------------------------------------------------------------------------
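
The three networks compose as follows; a minimal CPU smoke test of the expected tensor shapes with freshly initialized weights:

```
import torch
from clusgan.models import Generator_CNN, Encoder_CNN, Discriminator_CNN

latent_dim, n_c, x_shape = 30, 10, (1, 28, 28)
G = Generator_CNN(latent_dim, n_c, x_shape)
E = Encoder_CNN(latent_dim, n_c)
D = Discriminator_CNN(wass_metric=True)

zn, zc = torch.randn(4, latent_dim), torch.eye(n_c)[[0, 3, 5, 9]]
x_gen = G(zn, zc)                 # (4, 1, 28, 28)
zn_e, zc_e, zc_logits = E(x_gen)  # (4, 30), (4, 10), (4, 10)
validity = D(x_gen)               # (4, 1); unbounded critic score since wass_metric=True
```
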
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | try:
4 | import argparse
5 | import os
6 | import numpy as np
7 |
8 | import matplotlib
9 | import matplotlib.pyplot as plt
10 |
11 | import pandas as pd
12 |
13 | from torch.autograd import Variable
14 | from torch.autograd import grad as torch_grad
15 |
16 | import torch
17 | import torchvision
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 | from torch.utils.data import DataLoader
21 | from torchvision import datasets
22 | import torchvision.transforms as transforms
23 | from torchvision.utils import save_image
24 |
25 | from itertools import chain as ichain
26 |
27 | from clusgan.definitions import DATASETS_DIR, RUNS_DIR
28 | from clusgan.models import Generator_CNN, Encoder_CNN, Discriminator_CNN
29 | from clusgan.utils import save_model, calc_gradient_penalty, sample_z, cross_entropy
30 | from clusgan.datasets import get_dataloader, dataset_list
31 | from clusgan.plots import plot_train_loss
32 | except ImportError as e:
33 | print(e)
34 | raise ImportError
35 |
36 | def main():
37 | global args
38 | parser = argparse.ArgumentParser(description="Convolutional NN Training Script")
39 | parser.add_argument("-r", "--run_name", dest="run_name", default='clusgan', help="Name of training run")
40 | parser.add_argument("-n", "--n_epochs", dest="n_epochs", default=200, type=int, help="Number of epochs")
41 | parser.add_argument("-b", "--batch_size", dest="batch_size", default=64, type=int, help="Batch size")
42 | parser.add_argument("-s", "--dataset_name", dest="dataset_name", default='mnist', choices=dataset_list, help="Dataset name")
43 | parser.add_argument("-w", "--wass_metric", dest="wass_metric", action='store_true', help="Flag for Wasserstein metric")
44 | parser.add_argument("-g", "-–gpu", dest="gpu", default=0, type=int, help="GPU id to use")
45 | parser.add_argument("-k", "-–num_workers", dest="num_workers", default=1, type=int, help="Number of dataset workers")
46 | args = parser.parse_args()
47 |
48 | run_name = args.run_name
49 | dataset_name = args.dataset_name
50 | device_id = args.gpu
51 | num_workers = args.num_workers
52 |
53 | # Training details
54 | n_epochs = args.n_epochs
55 | batch_size = args.batch_size
56 | test_batch_size = 5000
57 | lr = 1e-4
58 | b1 = 0.5
59 | b2 = 0.9  # alternatively 0.99
60 | decay = 2.5*1e-5
61 | n_skip_iter = 1  # G/E stepped every n_skip_iter batches (e.g. 5 gives D more steps per G step)
62 |
63 | img_size = 28
64 | channels = 1
65 |
66 | # Latent space info
67 | latent_dim = 30
68 | n_c = 10
69 | betan = 10
70 | betac = 10
71 |
72 | # Wasserstein metric flag
73 | # 'van' = vanilla GAN, 'wass' = Wasserstein GAN with gradient penalty
74 | wass_metric = args.wass_metric
75 | mtype = 'van'
76 | if (wass_metric):
77 | mtype = 'wass'
78 |
79 | # Make directory structure for this run
80 | sep_und = '_'
81 | run_name_comps = ['%iepoch'%n_epochs, 'z%s'%str(latent_dim), mtype, 'bs%i'%batch_size, run_name]
82 | run_name = sep_und.join(run_name_comps)
83 |
84 | run_dir = os.path.join(RUNS_DIR, dataset_name, run_name)
85 | data_dir = os.path.join(DATASETS_DIR, dataset_name)
86 | imgs_dir = os.path.join(run_dir, 'images')
87 | models_dir = os.path.join(run_dir, 'models')
88 |
89 | os.makedirs(data_dir, exist_ok=True)
90 | os.makedirs(run_dir, exist_ok=True)
91 | os.makedirs(imgs_dir, exist_ok=True)
92 | os.makedirs(models_dir, exist_ok=True)
93 | print('\nResults to be saved in directory %s\n'%(run_dir))
94 |
95 | x_shape = (channels, img_size, img_size)
96 |
97 | cuda = True if torch.cuda.is_available() else False
98 | device = torch.device('cuda:%i'%device_id if cuda else 'cpu')
99 | if cuda: torch.cuda.set_device(device_id)
100 |
101 | # Loss function
102 | bce_loss = torch.nn.BCELoss()
103 | xe_loss = torch.nn.CrossEntropyLoss()
104 | mse_loss = torch.nn.MSELoss()
105 |
106 | # Initialize generator and discriminator
107 | generator = Generator_CNN(latent_dim, n_c, x_shape)
108 | encoder = Encoder_CNN(latent_dim, n_c)
109 | discriminator = Discriminator_CNN(wass_metric=wass_metric)
110 |
111 | if cuda:
112 | generator.cuda()
113 | encoder.cuda()
114 | discriminator.cuda()
115 | bce_loss.cuda()
116 | xe_loss.cuda()
117 | mse_loss.cuda()
118 |
119 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
120 |
121 | # Configure training data loader
122 | dataloader = get_dataloader(dataset_name=dataset_name,
123 | data_dir=data_dir,
124 | batch_size=batch_size,
125 | num_workers=num_workers)
126 |
127 | # Test data loader
128 | testdata = get_dataloader(dataset_name=dataset_name, data_dir=data_dir, batch_size=test_batch_size, train_set=False)
129 | test_imgs, test_labels = next(iter(testdata))
130 | test_imgs = Variable(test_imgs.type(Tensor))
131 |
132 | ge_chain = ichain(generator.parameters(),
133 | encoder.parameters())
134 | optimizer_GE = torch.optim.Adam(ge_chain, lr=lr, betas=(b1, b2), weight_decay=decay)
135 | optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
136 | #optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2), weight_decay=decay)
137 |
138 | # ----------
139 | # Training
140 | # ----------
141 | ge_l = []
142 | d_l = []
143 |
144 | c_zn = []
145 | c_zc = []
146 | c_i = []
147 |
148 | # Training loop
149 | print('\nBegin training session with %i epochs...\n'%(n_epochs))
150 | for epoch in range(n_epochs):
151 | for i, (imgs, itruth_label) in enumerate(dataloader):
152 |
153 | # Ensure generator/encoder are trainable
154 | generator.train()
155 | encoder.train()
156 | # Zero gradients for models
157 | generator.zero_grad()
158 | encoder.zero_grad()
159 | discriminator.zero_grad()
160 |
161 | # Configure input
162 | real_imgs = Variable(imgs.type(Tensor))
163 |
164 | # ---------------------------
165 | # Train Generator + Encoder
166 | # ---------------------------
167 |
168 | optimizer_GE.zero_grad()
169 |
170 | # Sample random latent variables
171 | zn, zc, zc_idx = sample_z(shape=imgs.shape[0],
172 | latent_dim=latent_dim,
173 | n_c=n_c)
174 |
175 | # Generate a batch of images
176 | gen_imgs = generator(zn, zc)
177 |
178 | # Discriminator output from real and generated samples
179 | D_gen = discriminator(gen_imgs)
180 | D_real = discriminator(real_imgs)
181 |
182 | # Step for Generator & Encoder, n_skip_iter times less than for discriminator
183 | if (i % n_skip_iter == 0):
184 | # Encode the generated images
185 | enc_gen_zn, enc_gen_zc, enc_gen_zc_logits = encoder(gen_imgs)
186 |
187 | # Calculate losses for z_n, z_c
188 | zn_loss = mse_loss(enc_gen_zn, zn)
189 | zc_loss = xe_loss(enc_gen_zc_logits, zc_idx)
190 | #zc_loss = cross_entropy(enc_gen_zc_logits, zc)
191 |
192 | # Check requested metric
193 | if wass_metric:
194 | # Wasserstein GAN loss
195 | ge_loss = torch.mean(D_gen) + betan * zn_loss + betac * zc_loss
196 | else:
197 | # Vanilla GAN loss
198 | valid = Variable(Tensor(gen_imgs.size(0), 1).fill_(1.0), requires_grad=False)
199 | v_loss = bce_loss(D_gen, valid)
200 | ge_loss = v_loss + betan * zn_loss + betac * zc_loss
201 |
202 | ge_loss.backward(retain_graph=True)
203 | optimizer_GE.step()
204 |
205 | # ---------------------
206 | # Train Discriminator
207 | # ---------------------
208 |
209 | optimizer_D.zero_grad()
210 |
211 | # Measure discriminator's ability to classify real from generated samples
212 | if wass_metric:
213 | # Gradient penalty term
214 | grad_penalty = calc_gradient_penalty(discriminator, real_imgs, gen_imgs)
215 |
216 | # Wasserstein GAN loss w/gradient penalty
217 | d_loss = torch.mean(D_real) - torch.mean(D_gen) + grad_penalty
218 |
219 | else:
220 | # Vanilla GAN loss; rebuild both targets so this branch never reuses a stale `valid`
221 | valid = Variable(Tensor(gen_imgs.size(0), 1).fill_(1.0), requires_grad=False)
222 | fake = Variable(Tensor(gen_imgs.size(0), 1).fill_(0.0), requires_grad=False)
223 | real_loss, fake_loss = bce_loss(D_real, valid), bce_loss(D_gen, fake)
224 | d_loss = (real_loss + fake_loss) / 2
225 |
226 | d_loss.backward()
227 | optimizer_D.step()
228 |
229 |
230 | # Save training losses
231 | d_l.append(d_loss.item())
232 | ge_l.append(ge_loss.item())
233 |
234 |
235 | # Generator in eval mode
236 | generator.eval()
237 | encoder.eval()
238 |
239 | # Set number of examples for cycle calcs
240 | n_sqrt_samp = 5
241 | n_samp = n_sqrt_samp * n_sqrt_samp
242 |
243 |
244 | ## Cycle through test real -> enc -> gen
245 | t_imgs, t_label = test_imgs.data, test_labels
246 | #r_imgs, i_label = real_imgs.data[:n_samp], itruth_label[:n_samp]
247 | # Encode sample real instances
248 | e_tzn, e_tzc, e_tzc_logits = encoder(t_imgs)
249 | # Generate sample instances from encoding
250 | teg_imgs = generator(e_tzn, e_tzc)
251 | # Calculate cycle reconstruction loss
252 | img_mse_loss = mse_loss(t_imgs, teg_imgs)
253 | # Save img reco cycle loss
254 | c_i.append(img_mse_loss.item())
255 |
256 |
257 | ## Cycle through randomly sampled encoding -> generator -> encoder
258 | zn_samp, zc_samp, zc_samp_idx = sample_z(shape=n_samp,
259 | latent_dim=latent_dim,
260 | n_c=n_c)
261 | # Generate sample instances
262 | gen_imgs_samp = generator(zn_samp, zc_samp)
263 | # Encode sample instances
264 | zn_e, zc_e, zc_e_logits = encoder(gen_imgs_samp)
265 | # Calculate cycle latent losses
266 | lat_mse_loss = mse_loss(zn_e, zn_samp)
267 | lat_xe_loss = xe_loss(zc_e_logits, zc_samp_idx)
268 | #lat_xe_loss = cross_entropy(zc_e_logits, zc_samp)
269 | # Save latent space cycle losses
270 | c_zn.append(lat_mse_loss.item())
271 | c_zc.append(lat_xe_loss.item())
272 |
273 | # Save cycled and generated examples!
274 | r_imgs, i_label = real_imgs.data[:n_samp], itruth_label[:n_samp]
275 | e_zn, e_zc, e_zc_logits = encoder(r_imgs)
276 | reg_imgs = generator(e_zn, e_zc)
277 | save_image(r_imgs.data[:n_samp],
278 | '%s/real_%06i.png' %(imgs_dir, epoch),
279 | nrow=n_sqrt_samp, normalize=True)
280 | save_image(reg_imgs.data[:n_samp],
281 | '%s/reg_%06i.png' %(imgs_dir, epoch),
282 | nrow=n_sqrt_samp, normalize=True)
283 | save_image(gen_imgs_samp.data[:n_samp],
284 | '%s/gen_%06i.png' %(imgs_dir, epoch),
285 | nrow=n_sqrt_samp, normalize=True)
286 |
287 | ## Generate samples for specified classes
288 | stack_imgs = []
289 | for idx in range(n_c):
290 | # Sample specific class
291 | zn_samp, zc_samp, zc_samp_idx = sample_z(shape=n_c,
292 | latent_dim=latent_dim,
293 | n_c=n_c,
294 | fix_class=idx)
295 |
296 | # Generate sample instances
297 | gen_imgs_samp = generator(zn_samp, zc_samp)
298 |
299 | if (len(stack_imgs) == 0):
300 | stack_imgs = gen_imgs_samp
301 | else:
302 | stack_imgs = torch.cat((stack_imgs, gen_imgs_samp), 0)
303 |
304 | # Save class-specified generated examples!
305 | save_image(stack_imgs,
306 | '%s/gen_classes_%06i.png' %(imgs_dir, epoch),
307 | nrow=n_c, normalize=True)
308 |
309 |
310 | print ("[Epoch %d/%d] \n"\
311 | "\tModel Losses: [D: %f] [GE: %f]" % (epoch,
312 | n_epochs,
313 | d_loss.item(),
314 | ge_loss.item())
315 | )
316 |
317 | print("\tCycle Losses: [x: %f] [z_n: %f] [z_c: %f]"%(img_mse_loss.item(),
318 | lat_mse_loss.item(),
319 | lat_xe_loss.item())
320 | )
321 |
322 |
323 |
324 |
325 | # Save training results
326 | train_df = pd.DataFrame({
327 | 'n_epochs' : n_epochs,
328 | 'learning_rate' : lr,
329 | 'beta_1' : b1,
330 | 'beta_2' : b2,
331 | 'weight_decay' : decay,
332 | 'n_skip_iter' : n_skip_iter,
333 | 'latent_dim' : latent_dim,
334 | 'n_classes' : n_c,
335 | 'beta_n' : betan,
336 | 'beta_c' : betac,
337 | 'wass_metric' : wass_metric,
338 | 'gen_enc_loss' : ['G+E', ge_l],
339 | 'disc_loss' : ['D', d_l],
340 | 'zn_cycle_loss' : ['$||Z_n-E(G(x))_n||$', c_zn],
341 | 'zc_cycle_loss' : ['$||Z_c-E(G(x))_c||$', c_zc],
342 | 'img_cycle_loss' : ['$||X-G(E(x))||$', c_i]
343 | })
344 |
345 | train_df.to_csv('%s/training_details.csv'%(run_dir))
346 |
347 |
348 | # Plot some training results
349 | plot_train_loss(df=train_df,
350 | arr_list=['gen_enc_loss', 'disc_loss'],
351 | figname='%s/training_model_losses.png'%(run_dir)
352 | )
353 |
354 | plot_train_loss(df=train_df,
355 | arr_list=['zn_cycle_loss', 'zc_cycle_loss', 'img_cycle_loss'],
356 | figname='%s/training_cycle_loss.png'%(run_dir)
357 | )
358 |
359 |
360 | # Save current state of trained models
361 | model_list = [discriminator, encoder, generator]
362 | save_model(models=model_list, out_dir=models_dir)
363 |
364 |
365 | if __name__ == "__main__":
366 | main()
367 |
--------------------------------------------------------------------------------
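
For reference, an example Wasserstein-metric invocation using the flags defined above (run and option names follow the README; `-w` switches to the WGAN-GP losses):

```
python train.py -r wass_run -s mnist -b 64 -n 300 -w
```
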