├── .gitattributes ├── .gitignore ├── LICENSE ├── On_the_Generation_and_Evaluation_of_Synthetic_Tabular_Data.pdf ├── Presentation - On the Generation and Evaluation of Synthetic Data.pptx.pdf ├── README.md ├── notebooks ├── Gaussian_mixture_models.ipynb ├── data_comparison.ipynb ├── data_comparison_berka.ipynb ├── data_comparison_census.ipynb ├── data_comparison_creditcard.ipynb ├── eval.py ├── evaluation_classifier.ipynb ├── gumbel_softmax.ipynb └── script.py ├── report_images ├── Cumsum-example.png ├── Extrema_example.png ├── chainrule_example.png ├── gmm.png ├── gradient_descent.jpg ├── mean_std_plotted.png ├── medgan-architecture.png ├── mnist_100_sva_20-500.pdf ├── mnist_auto_encoder.png ├── mnist_auto_encoder_combined_loss.png ├── mnist_gan_example.png ├── mnist_kl_only.png ├── normal-auto-encoder.png ├── results │ ├── berka_correlations.png │ ├── berka_mean_std.png │ ├── census_correlation.png │ ├── census_mean_std.png │ ├── column_correlations_berka_tgan.png │ ├── column_distributions_berka_medgan.png │ ├── column_distributions_berka_tgan.png │ ├── creditcard_correlation.png │ ├── creditcard_mean_std.png │ ├── cumsums_creditcard_tablegan.png │ └── cumsums_creditcard_tgan_wgan_gp.png ├── toy.png ├── vanishing-gradient-divergenes.jpeg └── variational-autoencoder-faces.jpg ├── tgan_org ├── __init__.py ├── cli.py ├── data.py ├── model.py └── trainer.py ├── tgan_skip ├── __init__.py ├── cli.py ├── data.py ├── model.py └── trainer.py └── tgan_wgan_gp ├── __init__.py ├── cli.py ├── data.py ├── model.py └── trainer.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.csv filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/macos,python 3 | demo/ 4 | output/ 5 | 6 | 7 | ### macOS ### 8 | *.DS_Store 9 | .AppleDouble 10 | .LSOverride 11 | 12 | # Icon must end with two \r 13 | Icon 14 | 15 | # Thumbnails 16 | ._* 17 | 18 | # Files that might appear in the root of a volume 19 | .DocumentRevisions-V100 20 | .fseventsd 21 | .Spotlight-V100 22 | .TemporaryItems 23 | .Trashes 24 | .VolumeIcon.icns 25 | .com.apple.timemachine.donotpresent 26 | 27 | # Directories potentially created on remote AFP share 28 | .AppleDB 29 | .AppleDesktop 30 | Network Trash Folder 31 | Temporary Items 32 | .apdisk 33 | 34 | ### Python ### 35 | # Byte-compiled / optimized / DLL files 36 | __pycache__/ 37 | *.py[cod] 38 | *$py.class 39 | 40 | # C extensions 41 | *.so 42 | 43 | # Distribution / packaging 44 | .Python 45 | build/ 46 | develop-eggs/ 47 | dist/ 48 | downloads/ 49 | eggs/ 50 | .eggs/ 51 | lib/ 52 | lib64/ 53 | parts/ 54 | sdist/ 55 | var/ 56 | wheels/ 57 | *.egg-info/ 58 | .installed.cfg 59 | *.egg 60 | 61 | # PyInstaller 62 | # Usually these files are written by a python script from a template 63 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
64 | *.manifest 65 | *.spec 66 | 67 | # Installer logs 68 | pip-log.txt 69 | pip-delete-this-directory.txt 70 | 71 | # Unit test / coverage reports 72 | htmlcov/ 73 | .tox/ 74 | .coverage 75 | .coverage.* 76 | .cache 77 | .pytest_cache/ 78 | nosetests.xml 79 | coverage.xml 80 | *.cover 81 | .hypothesis/ 82 | 83 | # Translations 84 | *.mo 85 | *.pot 86 | 87 | # Flask stuff: 88 | instance/ 89 | .webassets-cache 90 | 91 | # Scrapy stuff: 92 | .scrapy 93 | 94 | # Sphinx documentation 95 | docs/_build/ 96 | 97 | # PyBuilder 98 | target/ 99 | 100 | # Jupyter Notebook 101 | .ipynb_checkpoints 102 | 103 | # pyenv 104 | .python-version 105 | 106 | # celery beat schedule file 107 | celerybeat-schedule.* 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | 134 | 135 | # End of https://www.gitignore.io/api/macos,python 136 | 137 | train_log/ 138 | expdir/ 139 | *.npz 140 | logs/ 141 | model/ 142 | docs/api/ 143 | 144 | # Vim 145 | .*.swp 146 | 147 | \.idea/ 148 | 149 | final_data/ 150 | 151 | data/ 152 | 153 | data_samples/ 154 | 155 | abn_files/ 156 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019, MIT Data To AI Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
23 | 
--------------------------------------------------------------------------------
/On_the_Generation_and_Evaluation_of_Synthetic_Tabular_Data.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/On_the_Generation_and_Evaluation_of_Synthetic_Tabular_Data.pdf
--------------------------------------------------------------------------------
/Presentation - On the Generation and Evaluation of Synthetic Data.pptx.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/Presentation - On the Generation and Evaluation of Synthetic Data.pptx.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # On the Generation and Evaluation of Synthetic Tabular Data using GANs
2 | 
3 | ## Overview
4 | - Master Thesis Data Science, Radboud University 2019
5 | - License: MIT
6 | - Based on the awesome work from the guys at MIT Data to AI Lab. ([TGAN](https://github.com/DAI-Lab/TGAN), [SDGym](https://github.com/DAI-Lab/SDGym))
7 | 
8 | ## Abstract
9 | With privacy regulations becoming stricter, the opportunity to apply synthetic data is growing rapidly. Synthetic data can be used in any setting where access to data with personal information is not strictly necessary. However, many use cases require the synthetic data to preserve the same relations as the original data. Existing statistical models and anonymization tools often have adverse effects on the quality of the data for downstream tasks like classification. Deep-learning-based synthesis techniques like GANs provide solutions for cases where it is vital that these relations are kept. Inspired by GANs, we propose improvements to the state of the art in maintaining these relations in synthetic data. Our proposal includes three contributions. First, we propose the addition of skip connections in the generator, which increases gradient flow and modeling capacity. Second, we propose using the WGAN-GP architecture for training the GAN, which suffers less from mode collapse and has a more meaningful loss. Finally, we propose a new similarity metric for evaluating synthetic data, which better captures the different aspects of synthetic data when comparing it to real data. We study the behaviour of our proposed model adaptations against several baseline models on three datasets. Our results show that our proposals improve on the state-of-the-art models by creating higher-quality data. Our evaluation metric captures quality improvements in synthetic data and gives detailed insight into the strengths and weaknesses of the evaluated models. We conclude that our proposed adaptations should be used for data synthesis and that our evaluation metric is precise and gives a balanced view of the different aspects of the data.
10 | 
11 | The data evaluation library can be found in an additional repository: [https://github.com/Baukebrenninkmeijer/Table_Evaluator](https://github.com/Baukebrenninkmeijer/Table_Evaluator).
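A rough sketch of how that evaluation library can be combined with the output of this repository is shown below. The `TableEvaluator` class, its `visual_evaluation` and `evaluate` methods, and the file paths are assumptions based on the linked repository and on `notebooks/script.py`; check the Table_Evaluator README for the exact interface.

```python
import pandas as pd
from table_evaluator import TableEvaluator  # assumed package/class name

# Paths are illustrative: the real table used for training, and a synthetic
# sample as written by notebooks/script.py.
real = pd.read_csv('data/berka/berka_cat.csv', sep=';')
fake = pd.read_csv('samples/berka_sample_tgan-wgan-gp.csv')

te = TableEvaluator(real, fake)
te.visual_evaluation()                    # distribution and correlation plots
te.evaluate(target_col=real.columns[-1])  # similarity and ML-efficacy scores
```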
12 | 13 | ## Motivation 14 | To see the motivation for my decisions, please have a look at my master thesis, found at https://www.ru.nl/publish/pages/769526/z04_master_thesis_brenninkmeijer.pdf 15 | 16 | 17 | 18 | ## Using this work? 19 | If you're using this work, please cite the following work: 20 | 21 | ``` 22 | @article{brenninkmeijer2019synthetic, 23 | title={On the Generation and Evaluation of Synthetic Tabular Data using GANs}, 24 | author={Bauke Brenninkmeijer, Youri Hille, Arjen P. de Vries}, 25 | year={2019} 26 | } 27 | ``` 28 | -------------------------------------------------------------------------------- /notebooks/script.py: -------------------------------------------------------------------------------- 1 | from comet_ml import Experiment 2 | import pandas as pd 3 | import argparse 4 | import os 5 | from tgan_wgan_gp.model import TGANModel 6 | ## Change above line to which version you want to use. Choose from ['tgan_org', tgan_skip', 'tgan_wgan_gp'] 7 | 8 | def get_data(ds, drop=None, n_unique=20, sep=';', suffix='cat'): 9 | d = pd.read_csv(f'../data/{ds}/{ds}_{suffix}.csv', sep=sep) 10 | if drop is not None: 11 | d = d.drop(drop, axis=1) 12 | 13 | continuous_columns = [] 14 | for col in d._get_numeric_data().columns: 15 | if len(d[col].unique()) > n_unique: 16 | continuous_columns.append(d.columns.get_loc(col)) 17 | return d, continuous_columns 18 | 19 | parser = argparse.ArgumentParser(description='Evaluate data synthesizers') 20 | parser.add_argument('--dataset', nargs='*', help='Which dataset to choose. Options are berka, creditcard and ticket', default=['berka', 'census', 'creditcard']) 21 | 22 | args = parser.parse_args() 23 | datasets = args.dataset 24 | 25 | for ds in datasets: 26 | 27 | if ds == 'berka': 28 | d, continuous_columns = get_data(ds, drop=['trans_bank_partner', 'trans_account_partner']) 29 | elif ds == 'census': 30 | d, continuous_columns = get_data(ds, sep=',') 31 | elif ds == 'creditcard': 32 | d, continuous_columns = get_data(ds, sep=',', suffix='num') 33 | else: 34 | raise Exception('Unknown dataset mentioned') 35 | 36 | project_name = "tgan-wgan-gp" 37 | experiment = Experiment(api_key=os.environ['COMETML_API_KEY'], 38 | project_name=project_name, workspace="baukebrenninkmeijer") 39 | experiment.log_parameter('dataset', ds) 40 | print(f'ds: {ds}') 41 | 42 | batch_size = 200 43 | assert len(d) > batch_size, f'Batch size larger than data' 44 | steps_per_epoch = len(d)//batch_size 45 | print('Steps per epoch: ', steps_per_epoch) 46 | tgan = TGANModel(continuous_columns, 47 | restore_session=False, 48 | max_epoch=100, 49 | steps_per_epoch=steps_per_epoch, 50 | batch_size=batch_size, 51 | experiment=experiment, 52 | num_gen_rnn=50, 53 | num_gen_feature=64) 54 | tgan.fit(d) 55 | 56 | try: 57 | if os.path.exists('/mnt'): 58 | if not os.path.exists('/mnt/model'): 59 | os.mkdir('/mnt/model') 60 | model_path = f'/mnt/model/{ds}_{project_name}' 61 | else: 62 | model_path = f'model/{ds}_{project_name}' 63 | except: 64 | model_path = f'model/{ds}_{project_name}' 65 | 66 | # try: 67 | # tgan.save(model_path, force=True) 68 | # except Exception as e: 69 | # print(f'{e}\nModel could not be saved') 70 | # 71 | num_samples = 100000 72 | new_samples = tgan.sample(num_samples) 73 | new_samples.to_csv(f'temp_save_{ds}.csv', index=False) 74 | 75 | p = new_samples.copy() 76 | d.columns = p.columns 77 | if ds == 'berka' or ds == 'census': 78 | p[p._get_numeric_data().columns] = p[p._get_numeric_data().columns].astype('int') 79 | if ds == 'creditcard': 80 | p[['Time', 
'Class']] = p[['Time', 'Class']].astype('int') 81 | 82 | try: 83 | if os.path.exists('/mnt'): 84 | if not os.path.exists('/mnt/samples'): 85 | os.mkdir('/mnt/samples') 86 | p.to_csv(f'/mnt/samples/{ds}_sample_{project_name}.csv', index=False) 87 | else: 88 | p.to_csv(f'samples/{ds}_sample_{project_name}.csv', index=False) 89 | except: 90 | p.to_csv(f'samples/{ds}_sample_{project_name}.csv', index=False) 91 | 92 | try: 93 | os.remove(f'temp_save_{ds}.csv') 94 | except Exception as e: 95 | print(f'{e} -- Could not remove temp_save_{ds}.csv') 96 | 97 | experiment.end() 98 | 99 | 100 | import tensorflow as tf 101 | tf.keras.backend.clear_session() 102 | 103 | -------------------------------------------------------------------------------- /report_images/Cumsum-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/Cumsum-example.png -------------------------------------------------------------------------------- /report_images/Extrema_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/Extrema_example.png -------------------------------------------------------------------------------- /report_images/chainrule_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/chainrule_example.png -------------------------------------------------------------------------------- /report_images/gmm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/gmm.png -------------------------------------------------------------------------------- /report_images/gradient_descent.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/gradient_descent.jpg -------------------------------------------------------------------------------- /report_images/mean_std_plotted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/mean_std_plotted.png -------------------------------------------------------------------------------- /report_images/medgan-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/medgan-architecture.png -------------------------------------------------------------------------------- /report_images/mnist_100_sva_20-500.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/mnist_100_sva_20-500.pdf -------------------------------------------------------------------------------- /report_images/mnist_auto_encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/mnist_auto_encoder.png -------------------------------------------------------------------------------- /report_images/mnist_auto_encoder_combined_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/mnist_auto_encoder_combined_loss.png -------------------------------------------------------------------------------- /report_images/mnist_gan_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/mnist_gan_example.png -------------------------------------------------------------------------------- /report_images/mnist_kl_only.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/mnist_kl_only.png -------------------------------------------------------------------------------- /report_images/normal-auto-encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/normal-auto-encoder.png -------------------------------------------------------------------------------- /report_images/results/berka_correlations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/results/berka_correlations.png -------------------------------------------------------------------------------- /report_images/results/berka_mean_std.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/results/berka_mean_std.png -------------------------------------------------------------------------------- /report_images/results/census_correlation.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/results/census_correlation.png -------------------------------------------------------------------------------- /report_images/results/census_mean_std.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/results/census_mean_std.png -------------------------------------------------------------------------------- /report_images/results/column_correlations_berka_tgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/results/column_correlations_berka_tgan.png -------------------------------------------------------------------------------- /report_images/results/column_distributions_berka_medgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/results/column_distributions_berka_medgan.png -------------------------------------------------------------------------------- /report_images/results/column_distributions_berka_tgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/results/column_distributions_berka_tgan.png -------------------------------------------------------------------------------- /report_images/results/creditcard_correlation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/results/creditcard_correlation.png -------------------------------------------------------------------------------- /report_images/results/creditcard_mean_std.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/results/creditcard_mean_std.png -------------------------------------------------------------------------------- /report_images/results/cumsums_creditcard_tablegan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/results/cumsums_creditcard_tablegan.png -------------------------------------------------------------------------------- /report_images/results/cumsums_creditcard_tgan_wgan_gp.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/results/cumsums_creditcard_tgan_wgan_gp.png -------------------------------------------------------------------------------- /report_images/toy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/toy.png -------------------------------------------------------------------------------- /report_images/vanishing-gradient-divergenes.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/vanishing-gradient-divergenes.jpeg -------------------------------------------------------------------------------- /report_images/variational-autoencoder-faces.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baukebrenninkmeijer/On-the-Generation-and-Evaluation-of-Synthetic-Tabular-Data-using-GANs/94abf372ab5f53a70a4da8e0f311f0583fea4d51/report_images/variational-autoencoder-faces.jpg -------------------------------------------------------------------------------- /tgan_org/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Top-level package for TGAN.""" 4 | 5 | __author__ = """MIT Data To AI Lab""" 6 | __email__ = 'dailabmit@gmail.com' 7 | __version__ = '0.1.0' 8 | -------------------------------------------------------------------------------- /tgan_org/cli.py: -------------------------------------------------------------------------------- 1 | """Command Line Interface for TGAN.""" 2 | 3 | import argparse 4 | 5 | from tgan.research.experiments import run_experiments 6 | 7 | 8 | def get_train_parser(): 9 | """Build the ArgumentParser for CLI.""" 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') 12 | parser.add_argument('--load', help='load model') 13 | parser.add_argument('--sample', type=int, default=0, 14 | help='the number of samples in the synthetic output.') 15 | parser.add_argument('--data', required=True, help='a npz file') 16 | parser.add_argument('--output', type=str) 17 | parser.add_argument('--exp_name', type=str, default=None) 18 | 19 | # parameters for model tuning. 
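# Note (added): these tuning flags correspond to the TGANModel / GraphBuilder
# constructor arguments of the same name; candidate value ranges for several of
# them are listed in TUNABLE_VARIABLES in model.py.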
20 | parser.add_argument('--batch_size', type=int, default=200) 21 | parser.add_argument('--z_dim', type=int, default=100) 22 | parser.add_argument('--max_epoch', type=int, default=100) 23 | parser.add_argument('--steps_per_epoch', type=int, default=1000) 24 | 25 | parser.add_argument('--num_gen_rnn', type=int, default=400) 26 | parser.add_argument('--num_gen_feature', type=int, default=100) 27 | 28 | parser.add_argument('--num_dis_layers', type=int, default=2) 29 | parser.add_argument('--num_dis_hidden', type=int, default=200) 30 | 31 | parser.add_argument('--noise', type=float, default=0.2) 32 | 33 | parser.add_argument('--optimizer', type=str, default='AdamOptimizer', 34 | choices=['GradientDescentOptimizer', 'AdamOptimizer', 'AdadeltaOptimizer']) 35 | parser.add_argument('--learning_rate', type=float, default=0.001) 36 | 37 | parser.add_argument('--l2norm', type=float, default=0.00001) 38 | 39 | return parser 40 | 41 | 42 | def get_parser(): 43 | """Build argument parser for TGAN CLI utility.""" 44 | parser = argparse.ArgumentParser(description='TGAN Command Line Interface.') 45 | parser.set_defaults(function=None) 46 | 47 | action = parser.add_subparsers(title='action', dest='action') 48 | action.required = True 49 | 50 | experiments = action.add_parser('experiments', help='Run experiments using TGAN.') 51 | experiments.add_argument( 52 | 'input', type=str, help='Path to the JSON file with the configuration.') 53 | experiments.add_argument( 54 | 'output', type=str, help='Path to store the results.') 55 | 56 | return parser 57 | 58 | 59 | def main(): 60 | """Python Entry point for CLI.""" 61 | parser = get_parser() 62 | args = parser.parse_args() 63 | run_experiments(args.input, args.output) 64 | -------------------------------------------------------------------------------- /tgan_org/data.py: -------------------------------------------------------------------------------- 1 | """Data related functionalities. 2 | 3 | This modules contains the tools to preprare the data, from the raw csv files, to the DataFlow 4 | objects will be used to fit our models. 5 | """ 6 | import os 7 | import urllib 8 | 9 | import numpy as np 10 | import pandas as pd 11 | from sklearn.mixture import GaussianMixture 12 | from sklearn.preprocessing import LabelEncoder 13 | from tensorpack import DataFlow, RNGDataFlow 14 | 15 | DEMO_DATASETS = { 16 | 'census': ( 17 | 'http://hdi-project-tgan.s3.amazonaws.com/census-train.csv', 18 | 'data/census.csv', 19 | [0, 5, 16, 17, 18, 29, 38] 20 | ), 21 | 'covertype': ( 22 | 'http://hdi-project-tgan.s3.amazonaws.com/covertype-train.csv', 23 | 'data/covertype.csv', 24 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 25 | ) 26 | } 27 | 28 | 29 | def check_metadata(metadata): 30 | """Check that the given metadata has correct types for all its members. 31 | 32 | Args: 33 | metadata(dict): Description of the inputs. 34 | 35 | Returns: 36 | None 37 | 38 | Raises: 39 | AssertionError: If any of the details is not valid. 40 | 41 | """ 42 | message = 'The given metadata contains unsupported types.' 43 | assert all([item['type'] in ['category', 'value'] for item in metadata['details']]), message 44 | 45 | 46 | def check_inputs(function): 47 | """Validate inputs for functions whose first argument is a numpy.ndarray with shape (n,1). 48 | 49 | Args: 50 | function(callable): Method to validate. 51 | 52 | Returns: 53 | callable: Will check the inputs before calling :attr:`function`. 54 | 55 | Raises: 56 | ValueError: If first argument is not a valid :class:`numpy.array` of shape (n, 1). 
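    Example:
        A decorated ``transform(self, data)`` accepts a column vector such as
        ``np.array([[1.0], [2.0]])`` (shape ``(2, 1)``), but raises ``ValueError``
        for a flat array like ``np.array([1.0, 2.0])``, since the wrapper checks
        that ``data`` is a 2-d :class:`numpy.ndarray` with exactly one column.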
57 | 58 | """ 59 | def decorated(self, data, *args, **kwargs): 60 | if not (isinstance(data, np.ndarray) and len(data.shape) == 2 and data.shape[1] == 1): 61 | raise ValueError('The argument `data` must be a numpy.ndarray with shape (n, 1).') 62 | 63 | return function(self, data, *args, **kwargs) 64 | 65 | decorated.__doc__ = function.__doc__ 66 | return decorated 67 | 68 | 69 | class TGANDataFlow(RNGDataFlow): 70 | """Subclass of :class:`tensorpack.RNGDataFlow` prepared to work with :class:`numpy.ndarray`. 71 | 72 | Attributes: 73 | shuffle(bool): Wheter or not to shuffle the data. 74 | metadata(dict): Metadata for the given :attr:`data`. 75 | num_features(int): Number of features in given data. 76 | data(list): Prepared data from :attr:`filename`. 77 | distribution(list): DepecrationWarning? 78 | 79 | """ 80 | 81 | def __init__(self, data, metadata, shuffle=True): 82 | """Initialize object. 83 | 84 | Args: 85 | filename(str): Path to the json file containing the metadata. 86 | shuffle(bool): Wheter or not to shuffle the data. 87 | 88 | Raises: 89 | ValueError: If any column_info['type'] is not supported 90 | 91 | """ 92 | self.shuffle = shuffle 93 | if self.shuffle: 94 | self.reset_state() 95 | 96 | self.metadata = metadata 97 | self.num_features = self.metadata['num_features'] 98 | 99 | self.data = [] 100 | self.distribution = [] 101 | for column_id, column_info in enumerate(self.metadata['details']): 102 | if column_info['type'] == 'value': 103 | col_data = data['f%02d' % column_id] 104 | value = col_data[:, :1] 105 | cluster = col_data[:, 1:] 106 | self.data.append(value) 107 | self.data.append(cluster) 108 | 109 | elif column_info['type'] == 'category': 110 | col_data = np.asarray(data['f%02d' % column_id], dtype='int32') 111 | self.data.append(col_data) 112 | 113 | else: 114 | raise ValueError( 115 | "column_info['type'] must be either 'category' or 'value'." 116 | "Instead it was '{}'.".format(column_info['type']) 117 | ) 118 | 119 | self.data = list(zip(*self.data)) 120 | 121 | def size(self): 122 | """Return the number of rows in data. 123 | 124 | Returns: 125 | int: Number of rows in :attr:`data`. 126 | 127 | """ 128 | return len(self.data) 129 | 130 | def get_data(self): 131 | """Yield the rows from :attr:`data`. 132 | 133 | Yields: 134 | tuple: Row of data. 135 | 136 | """ 137 | idxs = np.arange(len(self.data)) 138 | if self.shuffle: 139 | self.rng.shuffle(idxs) 140 | 141 | for k in idxs: 142 | yield self.data[k] 143 | 144 | def __iter__(self): 145 | """Iterate over self.data.""" 146 | return self.get_data() 147 | 148 | def __len__(self): 149 | """Length of batches.""" 150 | return self.size() 151 | 152 | 153 | class RandomZData(DataFlow): 154 | """Random dataflow. 155 | 156 | Args: 157 | shape(tuple): Shape of the array to return on :meth:`get_data` 158 | 159 | """ 160 | 161 | def __init__(self, shape): 162 | """Initialize object.""" 163 | super(RandomZData, self).__init__() 164 | self.shape = shape 165 | 166 | def get_data(self): 167 | """Yield random normal vectors of shape :attr:`shape`.""" 168 | while True: 169 | yield [np.random.normal(0, 1, size=self.shape)] 170 | 171 | def __iter__(self): 172 | """Return data.""" 173 | return self.get_data() 174 | 175 | def __len__(self): 176 | """Length of batches.""" 177 | return self.shape[0] 178 | 179 | 180 | class MultiModalNumberTransformer: 181 | r"""Reversible transform for multimodal data. 
182 | 183 | To effectively sample values from a multimodal distribution, we cluster values of a 184 | numerical variable using a `skelarn.mixture.GaussianMixture`_ model (GMM). 185 | 186 | * We train a GMM with :attr:`n` components for each numerical variable :math:`C_i`. 187 | GMM models a distribution with a weighted sum of :attr:`n` Gaussian distributions. 188 | The means and standard deviations of the :attr:`n` Gaussian distributions are 189 | :math:`{\eta}^{(1)}_{i}, ..., {\eta}^{(n)}_{i}` and 190 | :math:`{\sigma}^{(1)}_{i}, ...,{\sigma}^{(n)}_{i}`. 191 | 192 | * We compute the probability of :math:`c_{i,j}` coming from each of the :attr:`n` Gaussian 193 | distributions as a vector :math:`{u}^{(1)}_{i,j}, ..., {u}^{(n)}_{i,j}`. u_{i,j} is a 194 | normalized probability distribution over :attr:`n` Gaussian distributions. 195 | 196 | * We normalize :math:`c_{i,j}` as :math:`v_{i,j} = (c_{i,j}−{\eta}^{(k)}_{i})/2{\sigma}^ 197 | {(k)}_{i}`, where :math:`k = arg max_k {u}^{(k)}_{i,j}`. We then clip :math:`v_{i,j}` to 198 | [−0.99, 0.99]. 199 | 200 | Then we use :math:`u_i` and :math:`v_i` to represent :math:`c_i`. For simplicity, 201 | we cluster all the numerical features, i.e. both uni-modal and multi-modal features are 202 | clustered to :attr:`n = 5` Gaussian distributions. 203 | 204 | The simplification is fair because GMM automatically weighs :attr:`n` components. 205 | For example, if a variable has only one mode and fits some Gaussian distribution, then GMM 206 | will assign a very low probability to :attr:`n − 1` components and only 1 remaining 207 | component actually works, which is equivalent to not clustering this feature. 208 | 209 | Args: 210 | num_modes(int): Number of modes on given data. 211 | 212 | Attributes: 213 | num_modes(int): Number of components in the `skelarn.mixture.GaussianMixture`_ model. 214 | 215 | .. _skelarn.mixture.GaussianMixture: https://scikit-learn.org/stable/modules/generated/ 216 | sklearn.mixture.GaussianMixture.html 217 | 218 | """ 219 | 220 | def __init__(self, num_modes=5): 221 | """Initialize instance.""" 222 | self.num_modes = num_modes 223 | 224 | @check_inputs 225 | def transform(self, data): 226 | """Cluster values using a `skelarn.mixture.GaussianMixture`_ model. 227 | 228 | Args: 229 | data(numpy.ndarray): Values to cluster in array of shape (n,1). 230 | 231 | Returns: 232 | tuple[numpy.ndarray, numpy.ndarray, list, list]: Tuple containg the features, 233 | probabilities, averages and stds of the given data. 234 | 235 | .. _skelarn.mixture.GaussianMixture: https://scikit-learn.org/stable/modules/generated/ 236 | sklearn.mixture.GaussianMixture.html 237 | 238 | """ 239 | model = GaussianMixture(self.num_modes) 240 | model.fit(data) 241 | 242 | means = model.means_.reshape((1, self.num_modes)) 243 | stds = np.sqrt(model.covariances_).reshape((1, self.num_modes)) 244 | 245 | features = (data - means) / (2 * stds) 246 | probs = model.predict_proba(data) 247 | argmax = np.argmax(probs, axis=1) 248 | idx = np.arange(len(features)) 249 | features = features[idx, argmax].reshape([-1, 1]) 250 | 251 | features = np.clip(features, -0.99, 0.99) 252 | 253 | return features, probs, list(means.flat), list(stds.flat) 254 | 255 | @staticmethod 256 | def inverse_transform(data, info): 257 | """Reverse the clustering of values. 258 | 259 | Args: 260 | data(numpy.ndarray): Transformed data to restore. 261 | info(dict): Metadata. 262 | 263 | Returns: 264 | numpy.ndarray: Values in the original space. 
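        Example:
            For a row ``[v, u_1, ..., u_n]`` the most likely mode
            ``k = argmax(u)`` is selected and the value is recovered as
            ``v * 2 * stds[k] + means[k]``, undoing the normalization applied in
            :meth:`transform` (up to the clipping to [-0.99, 0.99]).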
265 | 266 | """ 267 | features = data[:, 0] 268 | probs = data[:, 1:] 269 | p_argmax = np.argmax(probs, axis=1) 270 | 271 | mean = np.asarray(info['means']) 272 | std = np.asarray(info['stds']) 273 | 274 | select_mean = mean[p_argmax] 275 | select_std = std[p_argmax] 276 | 277 | return features * 2 * select_std + select_mean 278 | 279 | 280 | class Preprocessor: 281 | """Transform back and forth human-readable data into TGAN numerical features. 282 | 283 | Args: 284 | continous_columns(list): List of columns to be considered continuous 285 | metadata(dict): Metadata to initialize the object. 286 | 287 | Attributes: 288 | continous_columns(list): Same as constructor argument. 289 | metadata(dict): Information about the transformations applied to the data and its format. 290 | continous_transformer(MultiModalNumberTransformer): 291 | Transformer for columns in :attr:`continuous_columns` 292 | categorical_transformer(CategoricalTransformer): 293 | Transformer for categorical columns. 294 | columns(list): List of columns labels. 295 | 296 | """ 297 | 298 | def __init__(self, continuous_columns=None, metadata=None): 299 | """Initialize object, set arguments as attributes, initialize transformers.""" 300 | if continuous_columns is None: 301 | continuous_columns = [] 302 | 303 | self.continuous_columns = continuous_columns 304 | self.metadata = metadata 305 | self.continous_transformer = MultiModalNumberTransformer() 306 | self.categorical_transformer = LabelEncoder() 307 | self.columns = None 308 | 309 | def fit_transform(self, data, fitting=True): 310 | """Transform human-readable data into TGAN numerical features. 311 | 312 | Args: 313 | data(pandas.DataFrame): Data to transform. 314 | fitting(bool): Whether or not to update self.metadata. 315 | 316 | Returns: 317 | pandas.DataFrame: Model features 318 | 319 | """ 320 | num_cols = data.shape[1] 321 | self.columns = data.columns 322 | data.columns = list(range(num_cols)) 323 | 324 | transformed_data = {} 325 | details = [] 326 | 327 | for i in data.columns: 328 | if i in self.continuous_columns: 329 | column_data = data[i].values.reshape([-1, 1]) 330 | features, probs, means, stds = self.continous_transformer.transform(column_data) 331 | transformed_data['f%02d' % i] = np.concatenate((features, probs), axis=1) 332 | 333 | if fitting: 334 | details.append({ 335 | "type": "value", 336 | "means": means, 337 | "stds": stds, 338 | "n": 5 339 | }) 340 | 341 | else: 342 | column_data = data[i].astype(str).values 343 | features = self.categorical_transformer.fit_transform(column_data) 344 | transformed_data['f%02d' % i] = features.reshape([-1, 1]) 345 | 346 | if fitting: 347 | mapping = self.categorical_transformer.classes_ 348 | details.append({ 349 | "type": "category", 350 | "mapping": mapping, 351 | "n": mapping.shape[0], 352 | }) 353 | 354 | if fitting: 355 | metadata = { 356 | "num_features": num_cols, 357 | "details": details 358 | } 359 | check_metadata(metadata) 360 | self.metadata = metadata 361 | 362 | return transformed_data 363 | 364 | def transform(self, data): 365 | """Transform the given dataframe without generating new metadata. 366 | 367 | Args: 368 | data(pandas.DataFrame): Data to fit the object. 369 | 370 | """ 371 | return self.fit_transform(data, fitting=False) 372 | 373 | def fit(self, data): 374 | """Initialize the internal state of the object using :attr:`data`. 375 | 376 | Args: 377 | data(pandas.DataFrame): Data to fit the object. 
378 | 379 | """ 380 | self.fit_transform(data) 381 | 382 | def reverse_transform(self, data): 383 | """Transform TGAN numerical features back into human-readable data. 384 | 385 | Args: 386 | data(pandas.DataFrame): Data to transform. 387 | fitting(bool): Whether or not to update self.metadata. 388 | 389 | Returns: 390 | pandas.DataFrame: Model features 391 | 392 | """ 393 | table = [] 394 | 395 | for i in range(self.metadata['num_features']): 396 | column_data = data['f%02d' % i] 397 | column_metadata = self.metadata['details'][i] 398 | 399 | if column_metadata['type'] == 'value': 400 | column = self.continous_transformer.inverse_transform(column_data, column_metadata) 401 | 402 | if column_metadata['type'] == 'category': 403 | self.categorical_transformer.classes_ = column_metadata['mapping'] 404 | column = self.categorical_transformer.inverse_transform( 405 | column_data.ravel().astype(np.int32)) 406 | 407 | table.append(column) 408 | 409 | result = pd.DataFrame(dict(enumerate(table))) 410 | result.columns = self.columns 411 | return result 412 | 413 | 414 | def load_demo_data(name, header=None): 415 | """Fetch, load and prepare a dataset. 416 | 417 | If name is one of the demo datasets 418 | 419 | 420 | Args: 421 | name(str): Name or path of the dataset. 422 | header(): Header parameter when executing :attr:`pandas.read_csv` 423 | 424 | """ 425 | params = DEMO_DATASETS.get(name) 426 | if params: 427 | url, file_path, continuous_columns = params 428 | if not os.path.isfile(file_path): 429 | base_path = os.path.dirname(file_path) 430 | if not os.path.exists(base_path): 431 | os.makedirs(base_path) 432 | 433 | urllib.request.urlretrieve(url, file_path) 434 | 435 | else: 436 | message = ( 437 | '{} is not a valid dataset name. ' 438 | 'Supported values are: {}.'.format(name, list(DEMO_DATASETS.keys())) 439 | ) 440 | raise ValueError(message) 441 | 442 | return pd.read_csv(file_path, header=header), continuous_columns 443 | -------------------------------------------------------------------------------- /tgan_org/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """Module with the model for TGAN. 5 | 6 | This module contains two classes: 7 | 8 | - :attr:`GraphBuilder`: That defines the graph and implements a Tensorpack compatible API. 9 | - :attr:`TGANModel`: The public API for the model, that offers a simplified interface for the 10 | underlying operations with GraphBuilder and trainers in order to fit and sample data. 
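
A typical workflow, mirroring the one in ``notebooks/script.py`` (argument values are
illustrative)::

    tgan = TGANModel(continuous_columns=[0, 5], max_epoch=100, batch_size=200)
    tgan.fit(dataframe)            # dataframe is a pandas.DataFrame
    samples = tgan.sample(100000)  # returns a synthetic pandas.DataFrame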
11 | """ 12 | import json 13 | import os 14 | import pickle 15 | import tarfile 16 | 17 | import numpy as np 18 | from comet_ml import Experiment 19 | from tensorpack.callbacks import CometMLMonitor, MergeAllSummaries 20 | import tensorflow as tf 21 | from tensorpack import ( 22 | BatchData, BatchNorm, Dropout, FullyConnected, InputDesc, ModelDescBase, ModelSaver, 23 | PredictConfig, QueueInput, SaverRestore, SimpleDatasetPredictor, logger) 24 | from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope 25 | from tensorpack.tfutils.summary import add_moving_summary 26 | from tensorpack.utils.argtools import memoized 27 | 28 | from tgan.data import Preprocessor, RandomZData, TGANDataFlow 29 | from tgan.trainer import GANTrainer 30 | 31 | TUNABLE_VARIABLES = { 32 | 'batch_size': [50, 100, 200], 33 | 'z_dim': [50, 100, 200, 400], 34 | 'num_gen_rnn': [100, 200, 300, 400, 500, 600], 35 | 'num_gen_feature': [100, 200, 300, 400, 500, 600], 36 | 'num_dis_layers': [1, 2, 3, 4, 5], 37 | 'num_dis_hidden': [100, 200, 300, 400, 500], 38 | 'learning_rate': [0.0002, 0.0005, 0.001], 39 | 'noise': [0.05, 0.1, 0.2, 0.3] 40 | } 41 | 42 | 43 | class GraphBuilder(ModelDescBase): 44 | """Main model for TGAN. 45 | 46 | Args: 47 | None 48 | 49 | Attributes: 50 | 51 | """ 52 | 53 | def __init__( 54 | self, 55 | metadata, 56 | batch_size=200, 57 | z_dim=200, 58 | noise=0.2, 59 | l2norm=0.00001, 60 | learning_rate=0.001, 61 | num_gen_rnn=100, 62 | num_gen_feature=100, 63 | num_dis_layers=1, 64 | num_dis_hidden=100, 65 | optimizer='AdamOptimizer', 66 | training=True 67 | ): 68 | """Initialize the object, set arguments as attributes.""" 69 | self.metadata = metadata 70 | self.batch_size = batch_size 71 | self.z_dim = z_dim 72 | self.noise = noise 73 | self.l2norm = l2norm 74 | self.learning_rate = learning_rate 75 | self.num_gen_rnn = num_gen_rnn 76 | self.num_gen_feature = num_gen_feature 77 | self.num_dis_layers = num_dis_layers 78 | self.num_dis_hidden = num_dis_hidden 79 | self.optimizer = optimizer 80 | self.training = training 81 | 82 | def collect_variables(self, g_scope='gen', d_scope='discrim'): 83 | """Assign generator and discriminator variables from their scopes. 84 | 85 | Args: 86 | g_scope(str): Scope for the generator. 87 | d_scope(str): Scope for the discriminator. 88 | 89 | Raises: 90 | ValueError: If any of the assignments fails or the collections are empty. 91 | 92 | """ 93 | self.g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, g_scope) 94 | self.d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, d_scope) 95 | 96 | if not (self.g_vars or self.d_vars): 97 | raise ValueError('There are no variables defined in some of the given scopes') 98 | 99 | def build_losses(self, logits_real, logits_fake, extra_g=0, l2_norm=0.00001): 100 | r"""D and G play two-player minimax game with value function :math:`V(G,D)`. 101 | 102 | .. math:: 103 | 104 | min_G max_D V(D, G) = IE_{x \sim p_{data}} [log D(x)] + IE_{z \sim p_{fake}} 105 | [log (1 - D(G(z)))] 106 | 107 | Args: 108 | logits_real (tensorflow.Tensor): discrim logits from real samples. 109 | logits_fake (tensorflow.Tensor): discrim logits from fake samples from generator. 110 | extra_g(float): 111 | l2_norm(float): scale to apply L2 regularization. 
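        Note:
            As implemented below, the cross-entropy term for real samples is
            scaled by 0.7 and perturbed with uniform noise in [0, 0.3) before
            averaging, while the fake-sample term is left unchanged.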
112 | 113 | Returns: 114 | None 115 | 116 | """ 117 | with tf.name_scope("GAN_loss"): 118 | score_real = tf.sigmoid(logits_real) 119 | score_fake = tf.sigmoid(logits_fake) 120 | tf.summary.histogram('score-real', score_real) 121 | tf.summary.histogram('score-fake', score_fake) 122 | 123 | with tf.name_scope("discrim"): 124 | d_loss_pos = tf.reduce_mean( 125 | tf.nn.sigmoid_cross_entropy_with_logits( 126 | logits=logits_real, 127 | labels=tf.ones_like(logits_real)) * 0.7 + tf.random_uniform( 128 | tf.shape(logits_real), 129 | maxval=0.3 130 | ), 131 | name='loss_real' 132 | ) 133 | 134 | d_loss_neg = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 135 | logits=logits_fake, labels=tf.zeros_like(logits_fake)), name='loss_fake') 136 | 137 | d_pos_acc = tf.reduce_mean( 138 | tf.cast(score_real > 0.5, tf.float32), name='accuracy_real') 139 | 140 | d_neg_acc = tf.reduce_mean( 141 | tf.cast(score_fake < 0.5, tf.float32), name='accuracy_fake') 142 | 143 | d_loss = 0.5 * d_loss_pos + 0.5 * d_loss_neg + \ 144 | tf.contrib.layers.apply_regularization( 145 | tf.contrib.layers.l2_regularizer(l2_norm), 146 | tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "discrim")) 147 | 148 | self.d_loss = tf.identity(d_loss, name='loss') 149 | 150 | with tf.name_scope("gen"): 151 | g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 152 | logits=logits_fake, labels=tf.ones_like(logits_fake))) + \ 153 | tf.contrib.layers.apply_regularization( 154 | tf.contrib.layers.l2_regularizer(l2_norm), 155 | tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'gen')) 156 | 157 | g_loss = tf.identity(g_loss, name='loss') 158 | extra_g = tf.identity(extra_g, name='klloss') 159 | self.g_loss = tf.identity(g_loss + extra_g, name='final-g-loss') 160 | 161 | add_moving_summary( 162 | g_loss, extra_g, self.g_loss, self.d_loss, d_pos_acc, d_neg_acc, decay=0.) 163 | 164 | @memoized 165 | def get_optimizer(self): 166 | """Return optimizer of base class.""" 167 | return self._get_optimizer() 168 | 169 | def inputs(self): 170 | """Return metadata about entry data. 171 | 172 | Returns: 173 | list[tensorpack.InputDesc] 174 | 175 | Raises: 176 | ValueError: If any of the elements in self.metadata['details'] has an unsupported 177 | value in the `type` key. 178 | 179 | """ 180 | inputs = [] 181 | for col_id, col_info in enumerate(self.metadata['details']): 182 | if col_info['type'] == 'value': 183 | gaussian_components = col_info['n'] 184 | inputs.append( 185 | InputDesc(tf.float32, (self.batch_size, 1), 'input%02dvalue' % col_id)) 186 | 187 | inputs.append( 188 | InputDesc( 189 | tf.float32, 190 | (self.batch_size, gaussian_components), 191 | 'input%02dcluster' % col_id 192 | ) 193 | ) 194 | 195 | elif col_info['type'] == 'category': 196 | inputs.append(InputDesc(tf.int32, (self.batch_size, 1), 'input%02d' % col_id)) 197 | 198 | else: 199 | raise ValueError( 200 | "self.metadata['details'][{}]['type'] must be either `category` or " 201 | "`values`. Instead it was {}.".format(col_id, col_info['type']) 202 | ) 203 | 204 | return inputs 205 | 206 | def generator(self, z): 207 | r"""Build generator graph. 208 | 209 | We generate a numerical variable in 2 steps. We first generate the value scalar 210 | :math:`v_i`, then generate the cluster vector :math:`u_i`. We generate categorical 211 | feature in 1 step as a probability distribution over all possible labels. 212 | 213 | The output and hidden state size of LSTM is :math:`n_h`. 
The input to the LSTM in each 214 | step :math:`t` is the random variable :math:`z`, the previous hidden vector :math:`f_{t−1}` 215 | or an embedding vector :math:`f^{\prime}_{t−1}` depending on the type of previous output, 216 | and the weighted context vector :math:`a_{t−1}`. The random variable :math:`z` has 217 | :math:`n_z` dimensions. 218 | Each dimension is sampled from :math:`\mathcal{N}(0, 1)`. The attention-based context 219 | vector at is a weighted average over all the previous LSTM outputs :math:`h_{1:t}`. 220 | So :math:`a_t` is a :math:`n_h`-dimensional vector. 221 | We learn a attention weight vector :math:`α_t \in \mathbb{R}^t` and compute context as 222 | 223 | .. math:: 224 | a_t = \sum_{k=1}^{t} \frac{\textrm{exp} {\alpha}_{t, j}} 225 | {\sum_{j} \textrm{exp} \alpha_{t,j}} h_k. 226 | 227 | We set :math: `a_0` = 0. The output of LSTM is :math:`h_t` and we project the output to 228 | a hidden vector :math:`f_t = \textrm{tanh}(W_h h_t)`, where :math:`W_h` is a learned 229 | parameter in the network. The size of :math:`f_t` is :math:`n_f` . 230 | We further convert the hidden vector to an output variable. 231 | 232 | * If the output is the value part of a continuous variable, we compute the output as 233 | :math:`v_i = \textrm{tanh}(W_t f_t)`. The hidden vector for :math:`t + 1` step is 234 | :math:`f_t`. 235 | 236 | * If the output is the cluster part of a continuous variable, we compute the output as 237 | :math:`u_i = \textrm{softmax}(W_t f_t)`. The feature vector for :math:`t + 1` step is 238 | :math:`f_t`. 239 | 240 | * If the output is a discrete variable, we compute the output as 241 | :math:`d_i = \textrm{softmax}(W_t f_t)`. The hidden vector for :math:`t + 1` step is 242 | :math:`f^{\prime}_{t} = E_i [arg_k \hspace{0.25em} \textrm{max} \hspace{0.25em} d_i ]`, 243 | where :math:`E \in R^{|D_i|×n_f}` is an embedding matrix for discrete variable 244 | :math:`D_i`. 245 | 246 | * :math:`f_0` is a special vector :math:`\texttt{}` and we learn it during the 247 | training. 248 | 249 | Args: 250 | z: 251 | 252 | Returns: 253 | list[tensorflow.Tensor]: Outpu 254 | 255 | Raises: 256 | ValueError: If any of the elements in self.metadata['details'] has an unsupported 257 | value in the `type` key. 
258 | 259 | """ 260 | with tf.variable_scope('LSTM'): 261 | cell = tf.nn.rnn_cell.LSTMCell(self.num_gen_rnn) 262 | 263 | state = cell.zero_state(self.batch_size, dtype='float32') 264 | attention = tf.zeros( 265 | shape=(self.batch_size, self.num_gen_rnn), dtype='float32') 266 | input = tf.get_variable(name='go', shape=(1, self.num_gen_feature)) # 267 | input = tf.tile(input, [self.batch_size, 1]) 268 | input = tf.concat([input, z], axis=1) 269 | 270 | ptr = 0 271 | outputs = [] 272 | states = [] 273 | for col_id, col_info in enumerate(self.metadata['details']): 274 | if col_info['type'] == 'value': 275 | output, state = cell(tf.concat([input, attention], axis=1), state) 276 | states.append(state[1]) 277 | 278 | gaussian_components = col_info['n'] 279 | with tf.variable_scope("%02d" % ptr): 280 | h = FullyConnected('FC', output, self.num_gen_feature, nl=tf.tanh) 281 | outputs.append(FullyConnected('FC2', h, 1, nl=tf.tanh)) 282 | input = tf.concat([h, z], axis=1) 283 | attw = tf.get_variable("attw", shape=(len(states), 1, 1)) 284 | attw = tf.nn.softmax(attw, axis=0) 285 | attention = tf.reduce_sum(tf.stack(states, axis=0) * attw, axis=0) 286 | 287 | ptr += 1 288 | 289 | output, state = cell(tf.concat([input, attention], axis=1), state) 290 | states.append(state[1]) 291 | with tf.variable_scope("%02d" % ptr): 292 | h = FullyConnected('FC', output, self.num_gen_feature, nl=tf.tanh) 293 | w = FullyConnected('FC2', h, gaussian_components, nl=tf.nn.softmax) 294 | outputs.append(w) 295 | input = FullyConnected('FC3', w, self.num_gen_feature, nl=tf.identity) 296 | input = tf.concat([input, z], axis=1) 297 | attw = tf.get_variable("attw", shape=(len(states), 1, 1)) 298 | attw = tf.nn.softmax(attw, axis=0) 299 | attention = tf.reduce_sum(tf.stack(states, axis=0) * attw, axis=0) 300 | 301 | ptr += 1 302 | 303 | elif col_info['type'] == 'category': 304 | output, state = cell(tf.concat([input, attention], axis=1), state) 305 | states.append(state[1]) 306 | with tf.variable_scope("%02d" % ptr): 307 | h = FullyConnected('FC', output, self.num_gen_feature, nl=tf.tanh) 308 | w = FullyConnected('FC2', h, col_info['n'], nl=tf.nn.softmax) 309 | outputs.append(w) 310 | one_hot = tf.one_hot(tf.argmax(w, axis=1), col_info['n']) 311 | input = FullyConnected( 312 | 'FC3', one_hot, self.num_gen_feature, nl=tf.identity) 313 | input = tf.concat([input, z], axis=1) 314 | attw = tf.get_variable("attw", shape=(len(states), 1, 1)) 315 | attw = tf.nn.softmax(attw, axis=0) 316 | attention = tf.reduce_sum(tf.stack(states, axis=0) * attw, axis=0) 317 | 318 | ptr += 1 319 | 320 | else: 321 | raise ValueError( 322 | "self.metadata['details'][{}]['type'] must be either `category` or " 323 | "`values`. Instead it was {}.".format(col_id, col_info['type']) 324 | ) 325 | 326 | return outputs 327 | 328 | @staticmethod 329 | def batch_diversity(l, n_kernel=10, kernel_dim=10): 330 | r"""Return the minibatch discrimination vector. 331 | 332 | Let :math:`f(x_i) \in \mathbb{R}^A` denote a vector of features for input :math:`x_i`, 333 | produced by some intermediate layer in the discriminator. We then multiply the vector 334 | :math:`f(x_i)` by a tensor :math:`T \in \mathbb{R}^{A×B×C}`, which results in a matrix 335 | :math:`M_i \in \mathbb{R}^{B×C}`. We then compute the :math:`L_1`-distance between the 336 | rows of the resulting matrix :math:`M_i` across samples :math:`i \in {1, 2, ... , n}` 337 | and apply a negative exponential: 338 | 339 | .. math:: 340 | 341 | cb(x_i, x_j) = exp(−||M_{i,b} − M_{j,b}||_{L_1} ) \in \mathbb{R}. 
342 | 343 | The output :math:`o(x_i)` for this *minibatch layer* for a sample :math:`x_i` is then 344 | defined as the sum of the cb(xi, xj )’s to all other samples: 345 | 346 | .. math:: 347 | :nowrap: 348 | 349 | \begin{aligned} 350 | 351 | &o(x_i)_b = \sum^{n}_{j=1} cb(x_i , x_j) \in \mathbb{R}\\ 352 | &o(x_i) = \Big[ o(x_i)_1, o(x_i)_2, . . . , o(x_i)_B \Big] \in \mathbb{R}^B\\ 353 | &o(X) ∈ R^{n×B}\\ 354 | 355 | \end{aligned} 356 | 357 | Note: 358 | This is extracted from `Improved techniques for training GANs`_ (Section 3.2) by 359 | Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, and 360 | Xi Chen. 361 | 362 | .. _Improved techniques for training GANs: https://arxiv.org/pdf/1606.03498.pdf 363 | 364 | Args: 365 | l(tf.Tensor) 366 | n_kernel(int) 367 | kernel_dim(int) 368 | 369 | Returns: 370 | tensorflow.Tensor 371 | 372 | """ 373 | M = FullyConnected('fc_diversity', l, n_kernel * kernel_dim, nl=tf.identity) 374 | M = tf.reshape(M, [-1, n_kernel, kernel_dim]) 375 | M1 = tf.reshape(M, [-1, 1, n_kernel, kernel_dim]) 376 | M2 = tf.reshape(M, [1, -1, n_kernel, kernel_dim]) 377 | diff = tf.exp(-tf.reduce_sum(tf.abs(M1 - M2), axis=3)) 378 | return tf.reduce_sum(diff, axis=0) 379 | 380 | @auto_reuse_variable_scope 381 | def discriminator(self, vecs): 382 | r"""Build discriminator. 383 | 384 | We use a :math:`l`-layer fully connected neural network as the discriminator. 385 | We concatenate :math:`v_{1:n_c}`, :math:`u_{1:n_c}` and :math:`d_{1:n_d}` together as the 386 | input. We compute the internal layers as 387 | 388 | .. math:: 389 | \begin{aligned} 390 | 391 | f^{(D)}_{1} &= \textrm{LeakyReLU}(\textrm{BN}(W^{(D)}_{1}(v_{1:n_c} \oplus u_{1:n_c} 392 | \oplus d_{1:n_d}) 393 | 394 | f^{(D)}_{1} &= \textrm{LeakyReLU}(\textrm{BN}(W^{(D)}_{i}(f^{(D)}_{i−1} \oplus 395 | \textrm{diversity}(f^{(D)}_{i−1})))), i = 2:l 396 | 397 | \end{aligned} 398 | 399 | where :math:`\oplus` is the concatenation operation. :math:`\textrm{diversity}(·)` is the 400 | mini-batch discrimination vector [42]. Each dimension of the diversity vector is the total 401 | distance between one sample and all other samples in the mini-batch using some learned 402 | distance metric. :math:`\textrm{BN}(·)` is batch normalization, and 403 | :math:`\textrm{LeakyReLU}(·)` is the leaky reflect linear activation function. We further 404 | compute the output of discriminator as :math:`W^{(D)}(f^{(D)}_{l} \oplus \textrm{diversity} 405 | (f^{(D)}_{l}))` which is a scalar. 406 | 407 | Args: 408 | vecs(list[tensorflow.Tensor]): List of tensors matching the spec of :meth:`inputs` 409 | 410 | Returns: 411 | tensorpack.FullyConected: a (b, 1) logits 412 | 413 | """ 414 | logits = tf.concat(vecs, axis=1) 415 | for i in range(self.num_dis_layers): 416 | with tf.variable_scope('dis_fc{}'.format(i)): 417 | if i == 0: 418 | logits = FullyConnected( 419 | 'fc', logits, self.num_dis_hidden, nl=tf.identity, 420 | kernel_initializer=tf.truncated_normal_initializer(stddev=0.1) 421 | ) 422 | 423 | else: 424 | logits = FullyConnected('fc', logits, self.num_dis_hidden, nl=tf.identity) 425 | 426 | logits = tf.concat([logits, self.batch_diversity(logits)], axis=1) 427 | logits = BatchNorm('bn', logits, center=True, scale=False) 428 | logits = Dropout(logits) 429 | logits = tf.nn.leaky_relu(logits) 430 | 431 | return FullyConnected('dis_fc_top', logits, 1, nl=tf.identity) 432 | 433 | @staticmethod 434 | def compute_kl(real, pred): 435 | r"""Compute the Kullback–Leibler divergence, :math:`D_{KL}(\textrm{pred} || \textrm{real})`. 
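
        Concretely, as implemented below this returns
        :math:`\sum_i \textrm{pred}_i \, (\log(\textrm{pred}_i + 10^{-4}) - \log(\textrm{real}_i + 10^{-4}))`,
        where the small constant keeps the logarithms finite when a category has
        zero mass.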
436 | 437 | Args: 438 | real(tensorflow.Tensor): Real values. 439 | pred(tensorflow.Tensor): Predicted values. 440 | 441 | Returns: 442 | float: Computed divergence for the given values. 443 | 444 | """ 445 | return tf.reduce_sum((tf.log(pred + 1e-4) - tf.log(real + 1e-4)) * pred) 446 | 447 | def build_graph(self, *inputs): 448 | """Build the whole graph. 449 | 450 | Args: 451 | inputs(list[tensorflow.Tensor]): 452 | 453 | Returns: 454 | None 455 | 456 | """ 457 | z = tf.random_normal( 458 | [self.batch_size, self.z_dim], name='z_train') 459 | 460 | z = tf.placeholder_with_default(z, [None, self.z_dim], name='z') 461 | 462 | with tf.variable_scope('gen'): 463 | vecs_gen = self.generator(z) 464 | 465 | vecs_denorm = [] 466 | ptr = 0 467 | for col_id, col_info in enumerate(self.metadata['details']): 468 | if col_info['type'] == 'category': 469 | t = tf.argmax(vecs_gen[ptr], axis=1) 470 | t = tf.cast(tf.reshape(t, [-1, 1]), 'float32') 471 | vecs_denorm.append(t) 472 | ptr += 1 473 | 474 | elif col_info['type'] == 'value': 475 | vecs_denorm.append(vecs_gen[ptr]) 476 | ptr += 1 477 | vecs_denorm.append(vecs_gen[ptr]) 478 | ptr += 1 479 | 480 | else: 481 | raise ValueError( 482 | "self.metadata['details'][{}]['type'] must be either `category` or " 483 | "`values`. Instead it was {}.".format(col_id, col_info['type']) 484 | ) 485 | 486 | tf.identity(tf.concat(vecs_denorm, axis=1), name='gen') 487 | 488 | vecs_pos = [] 489 | ptr = 0 490 | for col_id, col_info in enumerate(self.metadata['details']): 491 | if col_info['type'] == 'category': 492 | one_hot = tf.one_hot(tf.reshape(inputs[ptr], [-1]), col_info['n']) 493 | noise_input = one_hot 494 | 495 | if self.training: 496 | noise = tf.random_uniform(tf.shape(one_hot), minval=0, maxval=self.noise) 497 | noise_input = (one_hot + noise) / tf.reduce_sum( 498 | one_hot + noise, keepdims=True, axis=1) 499 | 500 | vecs_pos.append(noise_input) 501 | ptr += 1 502 | 503 | elif col_info['type'] == 'value': 504 | vecs_pos.append(inputs[ptr]) 505 | ptr += 1 506 | vecs_pos.append(inputs[ptr]) 507 | ptr += 1 508 | 509 | else: 510 | raise ValueError( 511 | "self.metadata['details'][{}]['type'] must be either `category` or " 512 | "`values`. Instead it was {}.".format(col_id, col_info['type']) 513 | ) 514 | 515 | KL = 0. 516 | ptr = 0 517 | if self.training: 518 | for col_id, col_info in enumerate(self.metadata['details']): 519 | if col_info['type'] == 'category': 520 | dist = tf.reduce_sum(vecs_gen[ptr], axis=0) 521 | dist = dist / tf.reduce_sum(dist) 522 | 523 | real = tf.reduce_sum(vecs_pos[ptr], axis=0) 524 | real = real / tf.reduce_sum(real) 525 | KL += self.compute_kl(real, dist) 526 | ptr += 1 527 | 528 | elif col_info['type'] == 'value': 529 | ptr += 1 530 | dist = tf.reduce_sum(vecs_gen[ptr], axis=0) 531 | dist = dist / tf.reduce_sum(dist) 532 | real = tf.reduce_sum(vecs_pos[ptr], axis=0) 533 | real = real / tf.reduce_sum(real) 534 | KL += self.compute_kl(real, dist) 535 | 536 | ptr += 1 537 | 538 | else: 539 | raise ValueError( 540 | "self.metadata['details'][{}]['type'] must be either `category` or " 541 | "`values`. 
Instead it was {}.".format(col_id, col_info['type']) 542 | ) 543 | 544 | with tf.variable_scope('discrim'): 545 | discrim_pos = self.discriminator(vecs_pos) 546 | discrim_neg = self.discriminator(vecs_gen) 547 | 548 | self.build_losses(discrim_pos, discrim_neg, extra_g=KL, l2_norm=self.l2norm) 549 | self.collect_variables() 550 | 551 | def _get_optimizer(self): 552 | if self.optimizer == 'AdamOptimizer': 553 | return tf.train.AdamOptimizer(self.learning_rate, 0.5) 554 | 555 | elif self.optimizer == 'AdadeltaOptimizer': 556 | return tf.train.AdadeltaOptimizer(self.learning_rate, 0.95) 557 | 558 | else: 559 | return tf.train.GradientDescentOptimizer(self.learning_rate) 560 | 561 | 562 | class TGANModel: 563 | """Main model from TGAN. 564 | 565 | Args: 566 | continuous_columns (list[int]): 0-index list of column indices to be considered continuous. 567 | output (str, optional): Path to store the model and its artifacts. Defaults to 568 | :attr:`output`. 569 | gpu (list[str], optional):Comma separated list of GPU(s) to use. Defaults to :attr:`None`. 570 | max_epoch (int, optional): Number of epochs to use during training. Defaults to :attr:`5`. 571 | steps_per_epoch (int, optional): Number of steps to run on each epoch. Defaults to 572 | :attr:`10000`. 573 | save_checkpoints(bool, optional): Whether or not to store checkpoints of the model after 574 | each training epoch. Defaults to :attr:`True` 575 | restore_session(bool, optional): Whether or not continue training from the last checkpoint. 576 | Defaults to :attr:`True`. 577 | batch_size (int, optional): Size of the batch to feed the model at each step. Defaults to 578 | :attr:`200`. 579 | z_dim (int, optional): Number of dimensions in the noise input for the generator. 580 | Defaults to :attr:`100`. 581 | noise (float, optional): Upper bound to the gaussian noise added to categorical columns. 582 | Defaults to :attr:`0.2`. 583 | l2norm (float, optional): 584 | L2 reguralization coefficient when computing losses. Defaults to :attr:`0.00001`. 585 | learning_rate (float, optional): Learning rate for the optimizer. Defaults to 586 | :attr:`0.001`. 587 | num_gen_rnn (int, optional): Defaults to :attr:`400`. 588 | num_gen_feature (int, optional): Number of features of in the generator. Defaults to 589 | :attr:`100` 590 | num_dis_layers (int, optional): Defaults to :attr:`2`. 591 | num_dis_hidden (int, optional): Defaults to :attr:`200`. 592 | optimizer (str, optional): Name of the optimizer to use during `fit`,possible values are: 593 | [`GradientDescentOptimizer`, `AdamOptimizer`, `AdadeltaOptimizer`]. Defaults to 594 | :attr:`AdamOptimizer`. 
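    Example:
        Minimal end-to-end sketch (illustrative only; the csv path, column indices and
        model path are hypothetical)::

            import pandas as pd

            data = pd.read_csv('data/census.csv')
            tgan = TGANModel(continuous_columns=[0, 5], output='output', max_epoch=5)
            tgan.fit(data)
            samples = tgan.sample(1000)             # pandas.DataFrame with synthetic rows
            tgan.save('models/census_model.tar.gz')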
595 | """ 596 | 597 | def __init__( 598 | self, continuous_columns, output='output', gpu=None, max_epoch=5, steps_per_epoch=10000, 599 | save_checkpoints=True, restore_session=True, batch_size=200, z_dim=200, noise=0.2, 600 | l2norm=0.00001, learning_rate=0.001, num_gen_rnn=100, num_gen_feature=100, 601 | num_dis_layers=1, num_dis_hidden=100, optimizer='AdamOptimizer', comet_ml_key=None, experiment=None, ds=None 602 | ): 603 | """Initialize object.""" 604 | # Output 605 | self.continuous_columns = continuous_columns 606 | self.log_dir = os.path.join(output, 'logs') 607 | self.model_dir = os.path.join(output, 'model') 608 | self.output = output 609 | 610 | # Training params 611 | self.max_epoch = max_epoch 612 | self.steps_per_epoch = steps_per_epoch 613 | self.save_checkpoints = save_checkpoints 614 | self.restore_session = restore_session 615 | 616 | # Model params 617 | self.model = None 618 | self.batch_size = batch_size 619 | self.z_dim = z_dim 620 | self.noise = noise 621 | self.l2norm = l2norm 622 | self.learning_rate = learning_rate 623 | self.num_gen_rnn = num_gen_rnn 624 | self.num_gen_feature = num_gen_feature 625 | self.num_dis_layers = num_dis_layers 626 | self.num_dis_hidden = num_dis_hidden 627 | self.optimizer = optimizer 628 | 629 | if gpu: 630 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu 631 | 632 | if experiment is not None: 633 | self.experiment = experiment 634 | elif comet_ml_key is not None: 635 | self.comet_ml_key = comet_ml_key 636 | self.experiment = Experiment(api_key=comet_ml_key, project_name='tgan-wgan-gp', workspace="baukebrenninkmeijer") 637 | if ds is not None: 638 | experiment.log_dataset_info(name=ds) 639 | self.gpu = gpu 640 | 641 | def get_model(self, training=True): 642 | """Return a new instance of the model.""" 643 | return GraphBuilder( 644 | metadata=self.metadata, 645 | batch_size=self.batch_size, 646 | z_dim=self.z_dim, 647 | noise=self.noise, 648 | l2norm=self.l2norm, 649 | learning_rate=self.learning_rate, 650 | num_gen_rnn=self.num_gen_rnn, 651 | num_gen_feature=self.num_gen_feature, 652 | num_dis_layers=self.num_dis_layers, 653 | num_dis_hidden=self.num_dis_hidden, 654 | optimizer=self.optimizer, 655 | training=training 656 | ) 657 | 658 | def prepare_sampling(self): 659 | """Prepare model for generate samples.""" 660 | if self.model is None: 661 | self.model = self.get_model(training=False) 662 | 663 | else: 664 | self.model.training = False 665 | 666 | predict_config = PredictConfig( 667 | session_init=SaverRestore(self.restore_path), 668 | model=self.model, 669 | input_names=['z'], 670 | output_names=['gen/gen', 'z'], 671 | ) 672 | 673 | self.simple_dataset_predictor = SimpleDatasetPredictor( 674 | predict_config, 675 | RandomZData((self.batch_size, self.z_dim)) 676 | ) 677 | 678 | def fit(self, data): 679 | """Fit the model to the given data. 680 | 681 | Args: 682 | data(pandas.DataFrame): dataset to fit the model. 
683 | 684 | Returns: 685 | None 686 | 687 | """ 688 | self.preprocessor = Preprocessor(continuous_columns=self.continuous_columns) 689 | data = self.preprocessor.fit_transform(data) 690 | self.metadata = self.preprocessor.metadata 691 | dataflow = TGANDataFlow(data, self.metadata) 692 | batch_data = BatchData(dataflow, self.batch_size) 693 | input_queue = QueueInput(batch_data) 694 | 695 | self.model = self.get_model(training=True) 696 | 697 | trainer = GANTrainer( 698 | model=self.model, 699 | input_queue=input_queue, 700 | ) 701 | 702 | self.restore_path = os.path.join(self.model_dir, 'checkpoint') 703 | 704 | if os.path.isfile(self.restore_path) and self.restore_session: 705 | session_init = SaverRestore(self.restore_path) 706 | with open(os.path.join(self.log_dir, 'stats.json')) as f: 707 | starting_epoch = json.load(f)[-1]['epoch_num'] + 1 708 | 709 | else: 710 | session_init = None 711 | starting_epoch = 1 712 | 713 | action = 'k' if self.restore_session else None 714 | # logger.set_logger_dir(self.log_dir, action=action) 715 | 716 | callbacks = [] 717 | monitors = [] 718 | if self.save_checkpoints: 719 | callbacks.append(ModelSaver(checkpoint_dir=self.model_dir)) 720 | callbacks.append(MergeAllSummaries(period=10)) 721 | 722 | if self.experiment is not None: 723 | monitors.append(CometMLMonitor(experiment=self.experiment)) 724 | 725 | trainer.train_with_defaults( 726 | callbacks=callbacks, 727 | monitors=monitors, 728 | steps_per_epoch=self.steps_per_epoch, 729 | max_epoch=self.max_epoch, 730 | session_init=session_init, 731 | starting_epoch=starting_epoch 732 | ) 733 | 734 | self.prepare_sampling() 735 | 736 | def sample(self, num_samples): 737 | """Generate samples from model. 738 | 739 | Args: 740 | num_samples(int) 741 | 742 | Returns: 743 | None 744 | 745 | Raises: 746 | ValueError 747 | 748 | """ 749 | max_iters = (num_samples // self.batch_size) 750 | 751 | results = [] 752 | for idx, o in enumerate(self.simple_dataset_predictor.get_result()): 753 | results.append(o[0]) 754 | if idx + 1 == max_iters: 755 | break 756 | 757 | results = np.concatenate(results, axis=0) 758 | 759 | ptr = 0 760 | features = {} 761 | for col_id, col_info in enumerate(self.metadata['details']): 762 | if col_info['type'] == 'category': 763 | features['f%02d' % col_id] = results[:, ptr:ptr + 1] 764 | ptr += 1 765 | 766 | elif col_info['type'] == 'value': 767 | gaussian_components = col_info['n'] 768 | val = results[:, ptr:ptr + 1] 769 | ptr += 1 770 | pro = results[:, ptr:ptr + gaussian_components] 771 | ptr += gaussian_components 772 | features['f%02d' % col_id] = np.concatenate([val, pro], axis=1) 773 | 774 | else: 775 | raise ValueError( 776 | "self.metadata['details'][{}]['type'] must be either `category` or " 777 | "`values`. 
Instead it was {}.".format(col_id, col_info['type']) 778 | ) 779 | 780 | return self.preprocessor.reverse_transform(features)[:num_samples].copy() 781 | 782 | def tar_folder(self, tar_name): 783 | """Generate a tar of :self.output:.""" 784 | with tarfile.open(tar_name, 'w:gz') as tar_handle: 785 | for root, dirs, files in os.walk(self.output): 786 | for file_ in files: 787 | tar_handle.add(os.path.join(root, file_)) 788 | 789 | tar_handle.close() 790 | 791 | @classmethod 792 | def load(cls, path): 793 | """Load a pretrained model from a given path.""" 794 | with tarfile.open(path, 'r:gz') as tar_handle: 795 | destination_dir = os.path.dirname(tar_handle.getmembers()[0].name) 796 | tar_handle.extractall() 797 | 798 | with open('{}/TGANModel'.format(destination_dir), 'rb+') as f: 799 | instance = pickle.load(f) 800 | 801 | instance.prepare_sampling() 802 | return instance 803 | 804 | def save(self, path, force=False): 805 | """Save the fitted model in the given path.""" 806 | if os.path.exists(path) and not force: 807 | logger.info('The indicated path already exists. Use `force=True` to overwrite.') 808 | return 809 | 810 | base_path = os.path.dirname(path) 811 | if not os.path.exists(base_path): 812 | os.makedirs(base_path) 813 | 814 | model = self.model 815 | dataset_predictor = self.simple_dataset_predictor 816 | 817 | self.model = None 818 | self.simple_dataset_predictor = None 819 | 820 | with open('{}/TGANModel'.format(self.output), 'wb') as f: 821 | pickle.dump(self, f) 822 | 823 | self.model = model 824 | self.simple_dataset_predictor = dataset_predictor 825 | 826 | self.tar_folder(path) 827 | 828 | logger.info('Model saved successfully.') 829 | -------------------------------------------------------------------------------- /tgan_org/trainer.py: -------------------------------------------------------------------------------- 1 | """GAN Models.""" 2 | 3 | import tensorflow as tf 4 | from tensorpack import StagingInput, TowerTrainer 5 | from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter 6 | from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper 7 | 8 | 9 | class GANTrainer(TowerTrainer): 10 | """GanTrainer model. 11 | 12 | We need to set :meth:`tower_func` because it's a :class:`TowerTrainer`, and only 13 | :class:`TowerTrainer` supports automatic graph creation for inference during training. 14 | 15 | If we don't care about inference during training, using :meth:`tower_func` is not needed. 16 | Just calling :meth:`model.build_graph` directly is OK. 17 | 18 | Args: 19 | input_queue(tensorpack.input_source.QueueInput): Data input. 20 | model(tgan.GAN.GANModelDesc): Model to train. 
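    Example:
        Construction sketch (illustrative; ``model`` is a built :class:`GraphBuilder` and
        ``batch_data`` an existing tensorpack DataFlow, both assumed to exist)::

            from tensorpack import QueueInput

            trainer = GANTrainer(model=model, input_queue=QueueInput(batch_data))
            trainer.train_with_defaults(steps_per_epoch=10000, max_epoch=5)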
21 | 22 | """ 23 | 24 | def __init__(self, model, input_queue): 25 | """Initialize object.""" 26 | super().__init__() 27 | inputs_desc = model.get_inputs_desc() 28 | 29 | # Setup input 30 | cbs = input_queue.setup(inputs_desc) 31 | self.register_callback(cbs) 32 | 33 | # Build the graph 34 | self.tower_func = TowerFuncWrapper(model.build_graph, inputs_desc) 35 | with TowerContext('', is_training=True): 36 | self.tower_func(*input_queue.get_input_tensors()) 37 | 38 | opt = model.get_optimizer() 39 | 40 | # Define the training iteration by default, run one d_min after one g_min 41 | with tf.name_scope('optimize'): 42 | g_min_grad = opt.compute_gradients(model.g_loss, var_list=model.g_vars) 43 | g_min_grad_clip = [ 44 | (tf.clip_by_value(grad, -5.0, 5.0), var) 45 | for grad, var in g_min_grad 46 | ] 47 | 48 | g_min_train_op = opt.apply_gradients(g_min_grad_clip, name='g_op') 49 | with tf.control_dependencies([g_min_train_op]): 50 | d_min_grad = opt.compute_gradients(model.d_loss, var_list=model.d_vars) 51 | d_min_grad_clip = [ 52 | (tf.clip_by_value(grad, -5.0, 5.0), var) 53 | for grad, var in d_min_grad 54 | ] 55 | 56 | d_min_train_op = opt.apply_gradients(d_min_grad_clip, name='d_op') 57 | 58 | self.train_op = d_min_train_op 59 | 60 | 61 | class SeparateGANTrainer(TowerTrainer): 62 | """A GAN trainer which runs two optimization ops with a certain ratio. 63 | 64 | Args: 65 | input(tensorpack.input_source.QueueInput): Data input. 66 | model(tgan.GAN.GANModelDesc): Model to train. 67 | d_period(int): period of each d_opt run 68 | g_period(int): period of each g_opt run 69 | 70 | """ 71 | 72 | def __init__(self, input, model, d_period=1, g_period=1): 73 | """Initialize object.""" 74 | super(SeparateGANTrainer, self).__init__() 75 | self._d_period = int(d_period) 76 | self._g_period = int(g_period) 77 | if not min(d_period, g_period) == 1: 78 | raise ValueError('The minimum between d_period and g_period must be 1.') 79 | 80 | # Setup input 81 | cbs = input.setup(model.get_inputs_desc()) 82 | self.register_callback(cbs) 83 | 84 | # Build the graph 85 | self.tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc()) 86 | with TowerContext('', is_training=True): 87 | self.tower_func(*input.get_input_tensors()) 88 | 89 | opt = model.get_optimizer() 90 | with tf.name_scope('optimize'): 91 | self.d_min = opt.minimize( 92 | model.d_loss, var_list=model.d_vars, name='d_min') 93 | self.g_min = opt.minimize( 94 | model.g_loss, var_list=model.g_vars, name='g_min') 95 | 96 | def run_step(self): 97 | """Define the training iteration.""" 98 | if self.global_step % (self._d_period) == 0: 99 | self.hooked_sess.run(self.d_min) 100 | if self.global_step % (self._g_period) == 0: 101 | self.hooked_sess.run(self.g_min) 102 | 103 | 104 | class MultiGPUGANTrainer(TowerTrainer): 105 | """A replacement of GANTrainer (optimize d and g one by one) with multi-gpu support. 106 | 107 | Args: 108 | nr_gpu(int): 109 | input(tensorpack.input_source.QueueInput): Data input. 110 | model(tgan.GAN.GANModelDesc): Model to train. 
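    Example:
        Construction sketch (illustrative; ``model`` and ``input_queue`` are assumed to
        exist, and at least two GPUs are required)::

            trainer = MultiGPUGANTrainer(nr_gpu=2, input=input_queue, model=model)
            trainer.train_with_defaults(steps_per_epoch=10000, max_epoch=5)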
111 | 112 | """ 113 | 114 | def __init__(self, nr_gpu, input, model): 115 | """Initialize object.""" 116 | super(MultiGPUGANTrainer, self).__init__() 117 | if nr_gpu <= 1: 118 | raise ValueError('nr_gpu must be strictly greater than 1.') 119 | 120 | raw_devices = ['/gpu:{}'.format(k) for k in range(nr_gpu)] 121 | 122 | # Setup input 123 | input = StagingInput(input) 124 | cbs = input.setup(model.get_inputs_desc()) 125 | self.register_callback(cbs) 126 | 127 | # Build the graph with multi-gpu replication 128 | def get_cost(*inputs): 129 | model.build_graph(*inputs) 130 | return [model.d_loss, model.g_loss] 131 | 132 | self.tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc()) 133 | devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices] 134 | 135 | cost_list = DataParallelBuilder.build_on_towers( 136 | list(range(nr_gpu)), 137 | lambda: self.tower_func(*input.get_input_tensors()), 138 | devices) 139 | 140 | # Simply average the cost here. It might be faster to average the gradients 141 | with tf.name_scope('optimize'): 142 | d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu) 143 | g_loss = tf.add_n([x[1] for x in cost_list]) * (1.0 / nr_gpu) 144 | 145 | opt = model.get_optimizer() 146 | # run one d_min after one g_min 147 | g_min = opt.minimize(g_loss, var_list=model.g_vars, 148 | colocate_gradients_with_ops=True, name='g_op') 149 | 150 | with tf.control_dependencies([g_min]): 151 | d_min = opt.minimize(d_loss, var_list=model.d_vars, 152 | colocate_gradients_with_ops=True, name='d_op') 153 | 154 | # Define the training iteration 155 | self.train_op = d_min 156 | -------------------------------------------------------------------------------- /tgan_skip/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Top-level package for TGAN.""" 4 | 5 | __author__ = """MIT Data To AI Lab""" 6 | __email__ = 'dailabmit@gmail.com' 7 | __version__ = '0.1.0' 8 | -------------------------------------------------------------------------------- /tgan_skip/cli.py: -------------------------------------------------------------------------------- 1 | """Command Line Interface for TGAN.""" 2 | 3 | import argparse 4 | 5 | from tgan.research.experiments import run_experiments 6 | 7 | 8 | def get_train_parser(): 9 | """Build the ArgumentParser for CLI.""" 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') 12 | parser.add_argument('--load', help='load model') 13 | parser.add_argument('--sample', type=int, default=0, 14 | help='the number of samples in the synthetic output.') 15 | parser.add_argument('--data', required=True, help='a npz file') 16 | parser.add_argument('--output', type=str) 17 | parser.add_argument('--exp_name', type=str, default=None) 18 | 19 | # parameters for model tuning. 
20 | parser.add_argument('--batch_size', type=int, default=200) 21 | parser.add_argument('--z_dim', type=int, default=100) 22 | parser.add_argument('--max_epoch', type=int, default=100) 23 | parser.add_argument('--steps_per_epoch', type=int, default=1000) 24 | 25 | parser.add_argument('--num_gen_rnn', type=int, default=400) 26 | parser.add_argument('--num_gen_feature', type=int, default=100) 27 | 28 | parser.add_argument('--num_dis_layers', type=int, default=2) 29 | parser.add_argument('--num_dis_hidden', type=int, default=200) 30 | 31 | parser.add_argument('--noise', type=float, default=0.2) 32 | 33 | parser.add_argument('--optimizer', type=str, default='AdamOptimizer', 34 | choices=['GradientDescentOptimizer', 'AdamOptimizer', 'AdadeltaOptimizer']) 35 | parser.add_argument('--learning_rate', type=float, default=0.001) 36 | 37 | parser.add_argument('--l2norm', type=float, default=0.00001) 38 | 39 | return parser 40 | 41 | 42 | def get_parser(): 43 | """Build argument parser for TGAN CLI utility.""" 44 | parser = argparse.ArgumentParser(description='TGAN Command Line Interface.') 45 | parser.set_defaults(function=None) 46 | 47 | action = parser.add_subparsers(title='action', dest='action') 48 | action.required = True 49 | 50 | experiments = action.add_parser('experiments', help='Run experiments using TGAN.') 51 | experiments.add_argument( 52 | 'input', type=str, help='Path to the JSON file with the configuration.') 53 | experiments.add_argument( 54 | 'output', type=str, help='Path to store the results.') 55 | 56 | return parser 57 | 58 | 59 | def main(): 60 | """Python Entry point for CLI.""" 61 | parser = get_parser() 62 | args = parser.parse_args() 63 | run_experiments(args.input, args.output) 64 | -------------------------------------------------------------------------------- /tgan_skip/data.py: -------------------------------------------------------------------------------- 1 | """Data related functionalities. 2 | 3 | This modules contains the tools to preprare the data, from the raw csv files, to the DataFlow 4 | objects will be used to fit our models. 5 | """ 6 | import os 7 | import urllib 8 | 9 | import numpy as np 10 | import pandas as pd 11 | from sklearn.mixture import GaussianMixture 12 | from sklearn.preprocessing import LabelEncoder 13 | from tensorpack import DataFlow, RNGDataFlow 14 | 15 | DEMO_DATASETS = { 16 | 'census': ( 17 | 'http://hdi-project-tgan.s3.amazonaws.com/census-train.csv', 18 | 'data/census.csv', 19 | [0, 5, 16, 17, 18, 29, 38] 20 | ), 21 | 'covertype': ( 22 | 'http://hdi-project-tgan.s3.amazonaws.com/covertype-train.csv', 23 | 'data/covertype.csv', 24 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 25 | ) 26 | } 27 | 28 | 29 | def check_metadata(metadata): 30 | """Check that the given metadata has correct types for all its members. 31 | 32 | Args: 33 | metadata(dict): Description of the inputs. 34 | 35 | Returns: 36 | None 37 | 38 | Raises: 39 | AssertionError: If any of the details is not valid. 40 | 41 | """ 42 | message = 'The given metadata contains unsupported types.' 43 | assert all([item['type'] in ['category', 'value'] for item in metadata['details']]), message 44 | 45 | 46 | def check_inputs(function): 47 | """Validate inputs for functions whose first argument is a numpy.ndarray with shape (n,1). 48 | 49 | Args: 50 | function(callable): Method to validate. 51 | 52 | Returns: 53 | callable: Will check the inputs before calling :attr:`function`. 54 | 55 | Raises: 56 | ValueError: If first argument is not a valid :class:`numpy.array` of shape (n, 1). 
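    Example:
        A decorated method only accepts a single column shaped ``(n, 1)`` (illustrative;
        ``values`` is a hypothetical 1-D numpy array)::

            transformer = MultiModalNumberTransformer()
            transformer.transform(values.reshape([-1, 1]))   # accepted, shape (n, 1)
            transformer.transform(values)                    # raises ValueError, shape (n,)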
57 | 58 | """ 59 | def decorated(self, data, *args, **kwargs): 60 | if not (isinstance(data, np.ndarray) and len(data.shape) == 2 and data.shape[1] == 1): 61 | raise ValueError('The argument `data` must be a numpy.ndarray with shape (n, 1).') 62 | 63 | return function(self, data, *args, **kwargs) 64 | 65 | decorated.__doc__ = function.__doc__ 66 | return decorated 67 | 68 | 69 | class TGANDataFlow(RNGDataFlow): 70 | """Subclass of :class:`tensorpack.RNGDataFlow` prepared to work with :class:`numpy.ndarray`. 71 | 72 | Attributes: 73 | shuffle(bool): Wheter or not to shuffle the data. 74 | metadata(dict): Metadata for the given :attr:`data`. 75 | num_features(int): Number of features in given data. 76 | data(list): Prepared data from :attr:`filename`. 77 | distribution(list): DepecrationWarning? 78 | 79 | """ 80 | 81 | def __init__(self, data, metadata, shuffle=True): 82 | """Initialize object. 83 | 84 | Args: 85 | filename(str): Path to the json file containing the metadata. 86 | shuffle(bool): Wheter or not to shuffle the data. 87 | 88 | Raises: 89 | ValueError: If any column_info['type'] is not supported 90 | 91 | """ 92 | self.shuffle = shuffle 93 | if self.shuffle: 94 | self.reset_state() 95 | 96 | self.metadata = metadata 97 | self.num_features = self.metadata['num_features'] 98 | 99 | self.data = [] 100 | self.distribution = [] 101 | for column_id, column_info in enumerate(self.metadata['details']): 102 | if column_info['type'] == 'value': 103 | col_data = data['f%02d' % column_id] 104 | value = col_data[:, :1] 105 | cluster = col_data[:, 1:] 106 | self.data.append(value) 107 | self.data.append(cluster) 108 | 109 | elif column_info['type'] == 'category': 110 | col_data = np.asarray(data['f%02d' % column_id], dtype='int32') 111 | self.data.append(col_data) 112 | 113 | else: 114 | raise ValueError( 115 | "column_info['type'] must be either 'category' or 'value'." 116 | "Instead it was '{}'.".format(column_info['type']) 117 | ) 118 | 119 | self.data = list(zip(*self.data)) 120 | 121 | def size(self): 122 | """Return the number of rows in data. 123 | 124 | Returns: 125 | int: Number of rows in :attr:`data`. 126 | 127 | """ 128 | return len(self.data) 129 | 130 | def get_data(self): 131 | """Yield the rows from :attr:`data`. 132 | 133 | Yields: 134 | tuple: Row of data. 135 | 136 | """ 137 | idxs = np.arange(len(self.data)) 138 | if self.shuffle: 139 | self.rng.shuffle(idxs) 140 | 141 | for k in idxs: 142 | yield self.data[k] 143 | 144 | def __iter__(self): 145 | """Iterate over self.data.""" 146 | return self.get_data() 147 | 148 | def __len__(self): 149 | """Length of batches.""" 150 | return self.size() 151 | 152 | 153 | class RandomZData(DataFlow): 154 | """Random dataflow. 155 | 156 | Args: 157 | shape(tuple): Shape of the array to return on :meth:`get_data` 158 | 159 | """ 160 | 161 | def __init__(self, shape): 162 | """Initialize object.""" 163 | super(RandomZData, self).__init__() 164 | self.shape = shape 165 | 166 | def get_data(self): 167 | """Yield random normal vectors of shape :attr:`shape`.""" 168 | while True: 169 | yield [np.random.normal(0, 1, size=self.shape)] 170 | 171 | def __iter__(self): 172 | """Return data.""" 173 | return self.get_data() 174 | 175 | def __len__(self): 176 | """Length of batches.""" 177 | return self.shape[0] 178 | 179 | 180 | class MultiModalNumberTransformer: 181 | r"""Reversible transform for multimodal data. 
182 | 183 | To effectively sample values from a multimodal distribution, we cluster values of a 184 | numerical variable using a `skelarn.mixture.GaussianMixture`_ model (GMM). 185 | 186 | * We train a GMM with :attr:`n` components for each numerical variable :math:`C_i`. 187 | GMM models a distribution with a weighted sum of :attr:`n` Gaussian distributions. 188 | The means and standard deviations of the :attr:`n` Gaussian distributions are 189 | :math:`{\eta}^{(1)}_{i}, ..., {\eta}^{(n)}_{i}` and 190 | :math:`{\sigma}^{(1)}_{i}, ...,{\sigma}^{(n)}_{i}`. 191 | 192 | * We compute the probability of :math:`c_{i,j}` coming from each of the :attr:`n` Gaussian 193 | distributions as a vector :math:`{u}^{(1)}_{i,j}, ..., {u}^{(n)}_{i,j}`. u_{i,j} is a 194 | normalized probability distribution over :attr:`n` Gaussian distributions. 195 | 196 | * We normalize :math:`c_{i,j}` as :math:`v_{i,j} = (c_{i,j}−{\eta}^{(k)}_{i})/2{\sigma}^ 197 | {(k)}_{i}`, where :math:`k = arg max_k {u}^{(k)}_{i,j}`. We then clip :math:`v_{i,j}` to 198 | [−0.99, 0.99]. 199 | 200 | Then we use :math:`u_i` and :math:`v_i` to represent :math:`c_i`. For simplicity, 201 | we cluster all the numerical features, i.e. both uni-modal and multi-modal features are 202 | clustered to :attr:`n = 5` Gaussian distributions. 203 | 204 | The simplification is fair because GMM automatically weighs :attr:`n` components. 205 | For example, if a variable has only one mode and fits some Gaussian distribution, then GMM 206 | will assign a very low probability to :attr:`n − 1` components and only 1 remaining 207 | component actually works, which is equivalent to not clustering this feature. 208 | 209 | Args: 210 | num_modes(int): Number of modes on given data. 211 | 212 | Attributes: 213 | num_modes(int): Number of components in the `skelarn.mixture.GaussianMixture`_ model. 214 | 215 | .. _skelarn.mixture.GaussianMixture: https://scikit-learn.org/stable/modules/generated/ 216 | sklearn.mixture.GaussianMixture.html 217 | 218 | """ 219 | 220 | def __init__(self, num_modes=5): 221 | """Initialize instance.""" 222 | self.num_modes = num_modes 223 | 224 | @check_inputs 225 | def transform(self, data): 226 | """Cluster values using a `skelarn.mixture.GaussianMixture`_ model. 227 | 228 | Args: 229 | data(numpy.ndarray): Values to cluster in array of shape (n,1). 230 | 231 | Returns: 232 | tuple[numpy.ndarray, numpy.ndarray, list, list]: Tuple containg the features, 233 | probabilities, averages and stds of the given data. 234 | 235 | .. _skelarn.mixture.GaussianMixture: https://scikit-learn.org/stable/modules/generated/ 236 | sklearn.mixture.GaussianMixture.html 237 | 238 | """ 239 | model = GaussianMixture(self.num_modes) 240 | model.fit(data) 241 | 242 | means = model.means_.reshape((1, self.num_modes)) 243 | stds = np.sqrt(model.covariances_).reshape((1, self.num_modes)) 244 | 245 | features = (data - means) / (2 * stds) 246 | probs = model.predict_proba(data) 247 | argmax = np.argmax(probs, axis=1) 248 | idx = np.arange(len(features)) 249 | features = features[idx, argmax].reshape([-1, 1]) 250 | 251 | features = np.clip(features, -0.99, 0.99) 252 | 253 | return features, probs, list(means.flat), list(stds.flat) 254 | 255 | @staticmethod 256 | def inverse_transform(data, info): 257 | """Reverse the clustering of values. 258 | 259 | Args: 260 | data(numpy.ndarray): Transformed data to restore. 261 | info(dict): Metadata. 262 | 263 | Returns: 264 | numpy.ndarray: Values in the original space. 
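        Example:
            Round-trip sketch on a synthetic bimodal column (illustrative only; recovery is
            approximate because of the clipping applied in :meth:`transform`)::

                import numpy as np

                data = np.concatenate([
                    np.random.normal(0, 1, 500),
                    np.random.normal(10, 2, 500),
                ]).reshape([-1, 1])

                transformer = MultiModalNumberTransformer(num_modes=5)
                features, probs, means, stds = transformer.transform(data)

                encoded = np.concatenate([features, probs], axis=1)
                info = {'means': means, 'stds': stds}
                recovered = transformer.inverse_transform(encoded, info)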
265 | 266 | """ 267 | features = data[:, 0] 268 | probs = data[:, 1:] 269 | p_argmax = np.argmax(probs, axis=1) 270 | 271 | mean = np.asarray(info['means']) 272 | std = np.asarray(info['stds']) 273 | 274 | select_mean = mean[p_argmax] 275 | select_std = std[p_argmax] 276 | 277 | return features * 2 * select_std + select_mean 278 | 279 | 280 | class Preprocessor: 281 | """Transform back and forth human-readable data into TGAN numerical features. 282 | 283 | Args: 284 | continous_columns(list): List of columns to be considered continuous 285 | metadata(dict): Metadata to initialize the object. 286 | 287 | Attributes: 288 | continous_columns(list): Same as constructor argument. 289 | metadata(dict): Information about the transformations applied to the data and its format. 290 | continous_transformer(MultiModalNumberTransformer): 291 | Transformer for columns in :attr:`continuous_columns` 292 | categorical_transformer(CategoricalTransformer): 293 | Transformer for categorical columns. 294 | columns(list): List of columns labels. 295 | 296 | """ 297 | 298 | def __init__(self, continuous_columns=None, metadata=None): 299 | """Initialize object, set arguments as attributes, initialize transformers.""" 300 | if continuous_columns is None: 301 | continuous_columns = [] 302 | 303 | self.continuous_columns = continuous_columns 304 | self.metadata = metadata 305 | self.continous_transformer = MultiModalNumberTransformer() 306 | self.categorical_transformer = LabelEncoder() 307 | self.columns = None 308 | 309 | def fit_transform(self, data, fitting=True): 310 | """Transform human-readable data into TGAN numerical features. 311 | 312 | Args: 313 | data(pandas.DataFrame): Data to transform. 314 | fitting(bool): Whether or not to update self.metadata. 315 | 316 | Returns: 317 | pandas.DataFrame: Model features 318 | 319 | """ 320 | num_cols = data.shape[1] 321 | self.columns = data.columns 322 | data.columns = list(range(num_cols)) 323 | 324 | transformed_data = {} 325 | details = [] 326 | 327 | for i in data.columns: 328 | if i in self.continuous_columns: 329 | column_data = data[i].values.reshape([-1, 1]) 330 | features, probs, means, stds = self.continous_transformer.transform(column_data) 331 | transformed_data['f%02d' % i] = np.concatenate((features, probs), axis=1) 332 | 333 | if fitting: 334 | details.append({ 335 | "type": "value", 336 | "means": means, 337 | "stds": stds, 338 | "n": 5 339 | }) 340 | 341 | else: 342 | column_data = data[i].astype(str).values 343 | features = self.categorical_transformer.fit_transform(column_data) 344 | transformed_data['f%02d' % i] = features.reshape([-1, 1]) 345 | 346 | if fitting: 347 | mapping = self.categorical_transformer.classes_ 348 | details.append({ 349 | "type": "category", 350 | "mapping": mapping, 351 | "n": mapping.shape[0], 352 | }) 353 | 354 | if fitting: 355 | metadata = { 356 | "num_features": num_cols, 357 | "details": details 358 | } 359 | check_metadata(metadata) 360 | self.metadata = metadata 361 | 362 | return transformed_data 363 | 364 | def transform(self, data): 365 | """Transform the given dataframe without generating new metadata. 366 | 367 | Args: 368 | data(pandas.DataFrame): Data to fit the object. 369 | 370 | """ 371 | return self.fit_transform(data, fitting=False) 372 | 373 | def fit(self, data): 374 | """Initialize the internal state of the object using :attr:`data`. 375 | 376 | Args: 377 | data(pandas.DataFrame): Data to fit the object. 
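        Example:
            Sketch of the resulting encoding (illustrative; the column names and
            distributions are made up)::

                import numpy as np
                import pandas as pd

                df = pd.DataFrame({
                    'amount': np.random.lognormal(3, 1, size=1000),
                    'type': np.random.choice(['credit', 'debit'], size=1000),
                })
                preprocessor = Preprocessor(continuous_columns=[0])
                features = preprocessor.fit_transform(df)
                # features['f00']: shape (n, 6) -> normalized value + 5 cluster probabilities
                # features['f01']: shape (n, 1) -> integer labels for 'type'
                # preprocessor.metadata stores the means/stds and label mapping per column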
378 | 379 | """ 380 | self.fit_transform(data) 381 | 382 | def reverse_transform(self, data): 383 | """Transform TGAN numerical features back into human-readable data. 384 | 385 | Args: 386 | data(pandas.DataFrame): Data to transform. 387 | fitting(bool): Whether or not to update self.metadata. 388 | 389 | Returns: 390 | pandas.DataFrame: Model features 391 | 392 | """ 393 | table = [] 394 | 395 | for i in range(self.metadata['num_features']): 396 | column_data = data['f%02d' % i] 397 | column_metadata = self.metadata['details'][i] 398 | 399 | if column_metadata['type'] == 'value': 400 | column = self.continous_transformer.inverse_transform(column_data, column_metadata) 401 | 402 | if column_metadata['type'] == 'category': 403 | self.categorical_transformer.classes_ = column_metadata['mapping'] 404 | column = self.categorical_transformer.inverse_transform( 405 | column_data.ravel().astype(np.int32)) 406 | 407 | table.append(column) 408 | 409 | result = pd.DataFrame(dict(enumerate(table))) 410 | result.columns = self.columns 411 | return result 412 | 413 | 414 | def load_demo_data(name, header=None): 415 | """Fetch, load and prepare a dataset. 416 | 417 | If name is one of the demo datasets 418 | 419 | 420 | Args: 421 | name(str): Name or path of the dataset. 422 | header(): Header parameter when executing :attr:`pandas.read_csv` 423 | 424 | """ 425 | params = DEMO_DATASETS.get(name) 426 | if params: 427 | url, file_path, continuous_columns = params 428 | if not os.path.isfile(file_path): 429 | base_path = os.path.dirname(file_path) 430 | if not os.path.exists(base_path): 431 | os.makedirs(base_path) 432 | 433 | urllib.request.urlretrieve(url, file_path) 434 | 435 | else: 436 | message = ( 437 | '{} is not a valid dataset name. ' 438 | 'Supported values are: {}.'.format(name, list(DEMO_DATASETS.keys())) 439 | ) 440 | raise ValueError(message) 441 | 442 | return pd.read_csv(file_path, header=header), continuous_columns 443 | -------------------------------------------------------------------------------- /tgan_skip/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """Module with the model for TGAN. 5 | 6 | This module contains two classes: 7 | 8 | - :attr:`GraphBuilder`: That defines the graph and implements a Tensorpack compatible API. 9 | - :attr:`TGANModel`: The public API for the model, that offers a simplified interface for the 10 | underlying operations with GraphBuilder and trainers in order to fit and sample data. 
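One visible difference from the generator defined earlier in this repository is that this
variant also concatenates the noise vector ``z`` onto the projected hidden vector before
each output layer (a skip connection), e.g.::

    h = FullyConnected('FC', output, self.num_gen_feature, nl=tf.tanh)
    h = tf.concat([h, z], axis=1)   # skip connection on the noise vector
    w = FullyConnected('FC2', h, gaussian_components, nl=tf.nn.softmax)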
11 | """ 12 | import json 13 | import os 14 | import pickle 15 | import tarfile 16 | 17 | import numpy as np 18 | from comet_ml import Experiment 19 | from tensorpack.callbacks import CometMLMonitor, MergeAllSummaries 20 | import tensorflow as tf 21 | from tensorpack import ( 22 | BatchData, BatchNorm, Dropout, FullyConnected, InputDesc, ModelDescBase, ModelSaver, 23 | PredictConfig, QueueInput, SaverRestore, SimpleDatasetPredictor, logger) 24 | from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope 25 | from tensorpack.tfutils.summary import add_moving_summary 26 | from tensorpack.utils.argtools import memoized 27 | 28 | from tgan.data import Preprocessor, RandomZData, TGANDataFlow 29 | from tgan.trainer import GANTrainer 30 | 31 | TUNABLE_VARIABLES = { 32 | 'batch_size': [50, 100, 200], 33 | 'z_dim': [50, 100, 200, 400], 34 | 'num_gen_rnn': [100, 200, 300, 400, 500, 600], 35 | 'num_gen_feature': [100, 200, 300, 400, 500, 600], 36 | 'num_dis_layers': [1, 2, 3, 4, 5], 37 | 'num_dis_hidden': [100, 200, 300, 400, 500], 38 | 'learning_rate': [0.0002, 0.0005, 0.001], 39 | 'noise': [0.05, 0.1, 0.2, 0.3] 40 | } 41 | 42 | 43 | class GraphBuilder(ModelDescBase): 44 | """Main model for TGAN. 45 | 46 | Args: 47 | None 48 | 49 | Attributes: 50 | 51 | """ 52 | 53 | def __init__( 54 | self, 55 | metadata, 56 | batch_size=200, 57 | z_dim=200, 58 | noise=0.2, 59 | l2norm=0.00001, 60 | learning_rate=0.001, 61 | num_gen_rnn=100, 62 | num_gen_feature=100, 63 | num_dis_layers=1, 64 | num_dis_hidden=100, 65 | optimizer='AdamOptimizer', 66 | training=True 67 | ): 68 | """Initialize the object, set arguments as attributes.""" 69 | self.metadata = metadata 70 | self.batch_size = batch_size 71 | self.z_dim = z_dim 72 | self.noise = noise 73 | self.l2norm = l2norm 74 | self.learning_rate = learning_rate 75 | self.num_gen_rnn = num_gen_rnn 76 | self.num_gen_feature = num_gen_feature 77 | self.num_dis_layers = num_dis_layers 78 | self.num_dis_hidden = num_dis_hidden 79 | self.optimizer = optimizer 80 | self.training = training 81 | 82 | def collect_variables(self, g_scope='gen', d_scope='discrim'): 83 | """Assign generator and discriminator variables from their scopes. 84 | 85 | Args: 86 | g_scope(str): Scope for the generator. 87 | d_scope(str): Scope for the discriminator. 88 | 89 | Raises: 90 | ValueError: If any of the assignments fails or the collections are empty. 91 | 92 | """ 93 | self.g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, g_scope) 94 | self.d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, d_scope) 95 | 96 | if not (self.g_vars or self.d_vars): 97 | raise ValueError('There are no variables defined in some of the given scopes') 98 | 99 | def build_losses(self, logits_real, logits_fake, extra_g=0, l2_norm=0.00001): 100 | r"""D and G play two-player minimax game with value function :math:`V(G,D)`. 101 | 102 | .. math:: 103 | 104 | min_G max_D V(D, G) = IE_{x \sim p_{data}} [log D(x)] + IE_{z \sim p_{fake}} 105 | [log (1 - D(G(z)))] 106 | 107 | Args: 108 | logits_real (tensorflow.Tensor): discrim logits from real samples. 109 | logits_fake (tensorflow.Tensor): discrim logits from fake samples from generator. 110 | extra_g(float): 111 | l2_norm(float): scale to apply L2 regularization. 
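        Example:
            NumPy sketch of the two cross-entropy terms (illustrative only; the label
            noise, KL term and L2 regularization used by the implementation are omitted)::

                import numpy as np

                def bce(logits, label):
                    p = 1.0 / (1.0 + np.exp(-logits))
                    return -(label * np.log(p) + (1 - label) * np.log(1 - p))

                logits_real = np.array([2.0, 1.5])    # discriminator outputs on real rows
                logits_fake = np.array([-1.0, 0.5])   # discriminator outputs on generated rows

                d_loss = 0.5 * bce(logits_real, 1).mean() + 0.5 * bce(logits_fake, 0).mean()
                g_loss = bce(logits_fake, 1).mean()   # generator wants fakes scored as real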
112 | 113 | Returns: 114 | None 115 | 116 | """ 117 | with tf.name_scope("GAN_loss"): 118 | score_real = tf.sigmoid(logits_real) 119 | score_fake = tf.sigmoid(logits_fake) 120 | tf.summary.histogram('score-real', score_real) 121 | tf.summary.histogram('score-fake', score_fake) 122 | 123 | with tf.name_scope("discrim"): 124 | d_loss_pos = tf.reduce_mean( 125 | tf.nn.sigmoid_cross_entropy_with_logits( 126 | logits=logits_real, 127 | labels=tf.ones_like(logits_real)) * 0.7 + tf.random_uniform( 128 | tf.shape(logits_real), 129 | maxval=0.3 130 | ), 131 | name='loss_real' 132 | ) 133 | 134 | d_loss_neg = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 135 | logits=logits_fake, labels=tf.zeros_like(logits_fake)), name='loss_fake') 136 | 137 | d_pos_acc = tf.reduce_mean( 138 | tf.cast(score_real > 0.5, tf.float32), name='accuracy_real') 139 | 140 | d_neg_acc = tf.reduce_mean( 141 | tf.cast(score_fake < 0.5, tf.float32), name='accuracy_fake') 142 | 143 | d_loss = 0.5 * d_loss_pos + 0.5 * d_loss_neg + \ 144 | tf.contrib.layers.apply_regularization( 145 | tf.contrib.layers.l2_regularizer(l2_norm), 146 | tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "discrim")) 147 | 148 | self.d_loss = tf.identity(d_loss, name='loss') 149 | 150 | with tf.name_scope("gen"): 151 | g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 152 | logits=logits_fake, labels=tf.ones_like(logits_fake))) + \ 153 | tf.contrib.layers.apply_regularization( 154 | tf.contrib.layers.l2_regularizer(l2_norm), 155 | tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'gen')) 156 | 157 | g_loss = tf.identity(g_loss, name='loss') 158 | extra_g = tf.identity(extra_g, name='klloss') 159 | self.g_loss = tf.identity(g_loss + extra_g, name='final-g-loss') 160 | 161 | add_moving_summary( 162 | g_loss, extra_g, self.g_loss, self.d_loss, d_pos_acc, d_neg_acc, decay=0.) 163 | 164 | @memoized 165 | def get_optimizer(self): 166 | """Return optimizer of base class.""" 167 | return self._get_optimizer() 168 | 169 | def inputs(self): 170 | """Return metadata about entry data. 171 | 172 | Returns: 173 | list[tensorpack.InputDesc] 174 | 175 | Raises: 176 | ValueError: If any of the elements in self.metadata['details'] has an unsupported 177 | value in the `type` key. 178 | 179 | """ 180 | inputs = [] 181 | for col_id, col_info in enumerate(self.metadata['details']): 182 | if col_info['type'] == 'value': 183 | gaussian_components = col_info['n'] 184 | inputs.append( 185 | InputDesc(tf.float32, (self.batch_size, 1), 'input%02dvalue' % col_id)) 186 | 187 | inputs.append( 188 | InputDesc( 189 | tf.float32, 190 | (self.batch_size, gaussian_components), 191 | 'input%02dcluster' % col_id 192 | ) 193 | ) 194 | 195 | elif col_info['type'] == 'category': 196 | inputs.append(InputDesc(tf.int32, (self.batch_size, 1), 'input%02d' % col_id)) 197 | 198 | else: 199 | raise ValueError( 200 | "self.metadata['details'][{}]['type'] must be either `category` or " 201 | "`values`. Instead it was {}.".format(col_id, col_info['type']) 202 | ) 203 | 204 | return inputs 205 | 206 | def generator(self, z): 207 | r"""Build generator graph. 208 | 209 | We generate a numerical variable in 2 steps. We first generate the value scalar 210 | :math:`v_i`, then generate the cluster vector :math:`u_i`. We generate categorical 211 | feature in 1 step as a probability distribution over all possible labels. 212 | 213 | The output and hidden state size of LSTM is :math:`n_h`. 
The input to the LSTM in each 214 | step :math:`t` is the random variable :math:`z`, the previous hidden vector :math:`f_{t−1}` 215 | or an embedding vector :math:`f^{\prime}_{t−1}` depending on the type of previous output, 216 | and the weighted context vector :math:`a_{t−1}`. The random variable :math:`z` has 217 | :math:`n_z` dimensions. 218 | Each dimension is sampled from :math:`\mathcal{N}(0, 1)`. The attention-based context 219 | vector :math:`a_t` is a weighted average over all the previous LSTM outputs :math:`h_{1:t}`. 220 | So :math:`a_t` is an :math:`n_h`-dimensional vector. 221 | We learn an attention weight vector :math:`\alpha_t \in \mathbb{R}^t` and compute the context as 222 | 223 | .. math:: 224 | a_t = \sum_{k=1}^{t} \frac{\textrm{exp}\, {\alpha}_{t, k}} 225 | {\sum_{j} \textrm{exp}\, \alpha_{t,j}} h_k. 226 | 227 | We set :math:`a_0 = 0`. The output of the LSTM is :math:`h_t` and we project the output to 228 | a hidden vector :math:`f_t = \textrm{tanh}(W_h h_t)`, where :math:`W_h` is a learned 229 | parameter in the network. The size of :math:`f_t` is :math:`n_f`. 230 | We further convert the hidden vector to an output variable. 231 | 232 | * If the output is the value part of a continuous variable, we compute the output as 233 | :math:`v_i = \textrm{tanh}(W_t f_t)`. The hidden vector for step :math:`t + 1` is 234 | :math:`f_t`. 235 | 236 | * If the output is the cluster part of a continuous variable, we compute the output as 237 | :math:`u_i = \textrm{softmax}(W_t f_t)`. The feature vector for step :math:`t + 1` is 238 | :math:`f_t`. 239 | 240 | * If the output is a discrete variable, we compute the output as 241 | :math:`d_i = \textrm{softmax}(W_t f_t)`. The hidden vector for step :math:`t + 1` is 242 | :math:`f^{\prime}_{t} = E_i [\textrm{arg max}_k \hspace{0.25em} d_i ]`, 243 | where :math:`E \in \mathbb{R}^{|D_i| \times n_f}` is an embedding matrix for the discrete variable 244 | :math:`D_i`. 245 | 246 | * :math:`f_0` is a special vector :math:`\texttt{<GO>}` and we learn it during 247 | training. 248 | 249 | Args: 250 | z: Noise tensor of shape ``(batch_size, z_dim)``. 251 | 252 | Returns: 253 | list[tensorflow.Tensor]: Output tensors, one for each part of each generated column. 254 | 255 | Raises: 256 | ValueError: If any of the elements in self.metadata['details'] has an unsupported 257 | value in the `type` key.
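        Example:
            NumPy sketch of the attention context described above (illustrative; the
            previous outputs and logits are made up)::

                import numpy as np

                t, n_h = 3, 4
                states = np.random.randn(t, n_h)              # previous LSTM outputs h_1..h_t
                alpha = np.random.randn(t)                    # learned attention logits
                attw = np.exp(alpha) / np.exp(alpha).sum()    # softmax over steps
                a_t = (attw[:, None] * states).sum(axis=0)    # context vector, shape (n_h,)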
258 | 259 | """ 260 | with tf.variable_scope('LSTM'): 261 | cell = tf.nn.rnn_cell.LSTMCell(self.num_gen_rnn) 262 | 263 | state = cell.zero_state(self.batch_size, dtype='float32') 264 | attention = tf.zeros( 265 | shape=(self.batch_size, self.num_gen_rnn), dtype='float32') 266 | input = tf.get_variable(name='go', shape=(1, self.num_gen_feature)) # 267 | input = tf.tile(input, [self.batch_size, 1]) 268 | input = tf.concat([input, z], axis=1) 269 | 270 | ptr = 0 271 | outputs = [] 272 | states = [] 273 | for col_id, col_info in enumerate(self.metadata['details']): 274 | if col_info['type'] == 'value': 275 | output, state = cell(tf.concat([input, attention], axis=1), state) 276 | states.append(state[1]) 277 | 278 | gaussian_components = col_info['n'] 279 | with tf.variable_scope("%02d" % ptr): 280 | h = FullyConnected('FC', output, self.num_gen_feature, nl=tf.tanh) 281 | outputs.append(FullyConnected('FC2', h, 1, nl=tf.tanh)) 282 | input = tf.concat([h, z], axis=1) 283 | attw = tf.get_variable("attw", shape=(len(states), 1, 1)) 284 | attw = tf.nn.softmax(attw, axis=0) 285 | attention = tf.reduce_sum(tf.stack(states, axis=0) * attw, axis=0) 286 | 287 | ptr += 1 288 | 289 | output, state = cell(tf.concat([input, attention], axis=1), state) 290 | states.append(state[1]) 291 | with tf.variable_scope("%02d" % ptr): 292 | h = FullyConnected('FC', output, self.num_gen_feature, nl=tf.tanh) 293 | h = tf.concat([h, z], axis=1) 294 | w = FullyConnected('FC2', h, gaussian_components, nl=tf.nn.softmax) 295 | outputs.append(w) 296 | input = FullyConnected('FC3', w, self.num_gen_feature, nl=tf.identity) 297 | input = tf.concat([input, z], axis=1) 298 | attw = tf.get_variable("attw", shape=(len(states), 1, 1)) 299 | attw = tf.nn.softmax(attw, axis=0) 300 | attention = tf.reduce_sum(tf.stack(states, axis=0) * attw, axis=0) 301 | 302 | ptr += 1 303 | 304 | elif col_info['type'] == 'category': 305 | output, state = cell(tf.concat([input, attention], axis=1), state) 306 | states.append(state[1]) 307 | with tf.variable_scope("%02d" % ptr): 308 | h = FullyConnected('FC', output, self.num_gen_feature, nl=tf.tanh) 309 | h = tf.concat([h, z], axis=1) 310 | w = FullyConnected('FC2', h, col_info['n'], nl=tf.nn.softmax) 311 | outputs.append(w) 312 | one_hot = tf.one_hot(tf.argmax(w, axis=1), col_info['n']) 313 | input = FullyConnected( 314 | 'FC3', one_hot, self.num_gen_feature, nl=tf.identity) 315 | input = tf.concat([input, z], axis=1) 316 | attw = tf.get_variable("attw", shape=(len(states), 1, 1)) 317 | attw = tf.nn.softmax(attw, axis=0) 318 | attention = tf.reduce_sum(tf.stack(states, axis=0) * attw, axis=0) 319 | 320 | ptr += 1 321 | 322 | else: 323 | raise ValueError( 324 | "self.metadata['details'][{}]['type'] must be either `category` or " 325 | "`values`. Instead it was {}.".format(col_id, col_info['type']) 326 | ) 327 | 328 | return outputs 329 | 330 | @staticmethod 331 | def batch_diversity(l, n_kernel=10, kernel_dim=10): 332 | r"""Return the minibatch discrimination vector. 333 | 334 | Let :math:`f(x_i) \in \mathbb{R}^A` denote a vector of features for input :math:`x_i`, 335 | produced by some intermediate layer in the discriminator. We then multiply the vector 336 | :math:`f(x_i)` by a tensor :math:`T \in \mathbb{R}^{A×B×C}`, which results in a matrix 337 | :math:`M_i \in \mathbb{R}^{B×C}`. We then compute the :math:`L_1`-distance between the 338 | rows of the resulting matrix :math:`M_i` across samples :math:`i \in {1, 2, ... , n}` 339 | and apply a negative exponential: 340 | 341 | .. 
math:: 342 | 343 | cb(x_i, x_j) = exp(−||M_{i,b} − M_{j,b}||_{L_1} ) \in \mathbb{R}. 344 | 345 | The output :math:`o(x_i)` for this *minibatch layer* for a sample :math:`x_i` is then 346 | defined as the sum of the cb(xi, xj )’s to all other samples: 347 | 348 | .. math:: 349 | :nowrap: 350 | 351 | \begin{aligned} 352 | 353 | &o(x_i)_b = \sum^{n}_{j=1} cb(x_i , x_j) \in \mathbb{R}\\ 354 | &o(x_i) = \Big[ o(x_i)_1, o(x_i)_2, . . . , o(x_i)_B \Big] \in \mathbb{R}^B\\ 355 | &o(X) ∈ R^{n×B}\\ 356 | 357 | \end{aligned} 358 | 359 | Note: 360 | This is extracted from `Improved techniques for training GANs`_ (Section 3.2) by 361 | Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, and 362 | Xi Chen. 363 | 364 | .. _Improved techniques for training GANs: https://arxiv.org/pdf/1606.03498.pdf 365 | 366 | Args: 367 | l(tf.Tensor) 368 | n_kernel(int) 369 | kernel_dim(int) 370 | 371 | Returns: 372 | tensorflow.Tensor 373 | 374 | """ 375 | M = FullyConnected('fc_diversity', l, n_kernel * kernel_dim, nl=tf.identity) 376 | M = tf.reshape(M, [-1, n_kernel, kernel_dim]) 377 | M1 = tf.reshape(M, [-1, 1, n_kernel, kernel_dim]) 378 | M2 = tf.reshape(M, [1, -1, n_kernel, kernel_dim]) 379 | diff = tf.exp(-tf.reduce_sum(tf.abs(M1 - M2), axis=3)) 380 | return tf.reduce_sum(diff, axis=0) 381 | 382 | @auto_reuse_variable_scope 383 | def discriminator(self, vecs): 384 | r"""Build discriminator. 385 | 386 | We use a :math:`l`-layer fully connected neural network as the discriminator. 387 | We concatenate :math:`v_{1:n_c}`, :math:`u_{1:n_c}` and :math:`d_{1:n_d}` together as the 388 | input. We compute the internal layers as 389 | 390 | .. math:: 391 | \begin{aligned} 392 | 393 | f^{(D)}_{1} &= \textrm{LeakyReLU}(\textrm{BN}(W^{(D)}_{1}(v_{1:n_c} \oplus u_{1:n_c} 394 | \oplus d_{1:n_d}) 395 | 396 | f^{(D)}_{1} &= \textrm{LeakyReLU}(\textrm{BN}(W^{(D)}_{i}(f^{(D)}_{i−1} \oplus 397 | \textrm{diversity}(f^{(D)}_{i−1})))), i = 2:l 398 | 399 | \end{aligned} 400 | 401 | where :math:`\oplus` is the concatenation operation. :math:`\textrm{diversity}(·)` is the 402 | mini-batch discrimination vector [42]. Each dimension of the diversity vector is the total 403 | distance between one sample and all other samples in the mini-batch using some learned 404 | distance metric. :math:`\textrm{BN}(·)` is batch normalization, and 405 | :math:`\textrm{LeakyReLU}(·)` is the leaky reflect linear activation function. We further 406 | compute the output of discriminator as :math:`W^{(D)}(f^{(D)}_{l} \oplus \textrm{diversity} 407 | (f^{(D)}_{l}))` which is a scalar. 
408 | 409 | Args: 410 | vecs(list[tensorflow.Tensor]): List of tensors matching the spec of :meth:`inputs` 411 | 412 | Returns: 413 | tensorpack.FullyConected: a (b, 1) logits 414 | 415 | """ 416 | logits = tf.concat(vecs, axis=1) 417 | for i in range(self.num_dis_layers): 418 | with tf.variable_scope('dis_fc{}'.format(i)): 419 | if i == 0: 420 | logits = FullyConnected( 421 | 'fc', logits, self.num_dis_hidden, nl=tf.identity, 422 | kernel_initializer=tf.truncated_normal_initializer(stddev=0.1) 423 | ) 424 | 425 | else: 426 | logits = FullyConnected('fc', logits, self.num_dis_hidden, nl=tf.identity) 427 | 428 | logits = tf.concat([logits, self.batch_diversity(logits)], axis=1) 429 | logits = BatchNorm('bn', logits, center=True, scale=False) 430 | logits = Dropout(logits) 431 | logits = tf.nn.leaky_relu(logits) 432 | 433 | return FullyConnected('dis_fc_top', logits, 1, nl=tf.identity) 434 | 435 | @staticmethod 436 | def compute_kl(real, pred): 437 | r"""Compute the Kullback–Leibler divergence, :math:`D_{KL}(\textrm{pred} || \textrm{real})`. 438 | 439 | Args: 440 | real(tensorflow.Tensor): Real values. 441 | pred(tensorflow.Tensor): Predicted values. 442 | 443 | Returns: 444 | float: Computed divergence for the given values. 445 | 446 | """ 447 | return tf.reduce_sum((tf.log(pred + 1e-4) - tf.log(real + 1e-4)) * pred) 448 | 449 | def build_graph(self, *inputs): 450 | """Build the whole graph. 451 | 452 | Args: 453 | inputs(list[tensorflow.Tensor]): 454 | 455 | Returns: 456 | None 457 | 458 | """ 459 | z = tf.random_normal( 460 | [self.batch_size, self.z_dim], name='z_train') 461 | 462 | z = tf.placeholder_with_default(z, [None, self.z_dim], name='z') 463 | 464 | with tf.variable_scope('gen'): 465 | vecs_gen = self.generator(z) 466 | 467 | vecs_denorm = [] 468 | ptr = 0 469 | for col_id, col_info in enumerate(self.metadata['details']): 470 | if col_info['type'] == 'category': 471 | t = tf.argmax(vecs_gen[ptr], axis=1) 472 | t = tf.cast(tf.reshape(t, [-1, 1]), 'float32') 473 | vecs_denorm.append(t) 474 | ptr += 1 475 | 476 | elif col_info['type'] == 'value': 477 | vecs_denorm.append(vecs_gen[ptr]) 478 | ptr += 1 479 | vecs_denorm.append(vecs_gen[ptr]) 480 | ptr += 1 481 | 482 | else: 483 | raise ValueError( 484 | "self.metadata['details'][{}]['type'] must be either `category` or " 485 | "`values`. Instead it was {}.".format(col_id, col_info['type']) 486 | ) 487 | 488 | tf.identity(tf.concat(vecs_denorm, axis=1), name='gen') 489 | 490 | vecs_pos = [] 491 | ptr = 0 492 | for col_id, col_info in enumerate(self.metadata['details']): 493 | if col_info['type'] == 'category': 494 | one_hot = tf.one_hot(tf.reshape(inputs[ptr], [-1]), col_info['n']) 495 | noise_input = one_hot 496 | 497 | if self.training: 498 | noise = tf.random_uniform(tf.shape(one_hot), minval=0, maxval=self.noise) 499 | noise_input = (one_hot + noise) / tf.reduce_sum( 500 | one_hot + noise, keepdims=True, axis=1) 501 | 502 | vecs_pos.append(noise_input) 503 | ptr += 1 504 | 505 | elif col_info['type'] == 'value': 506 | vecs_pos.append(inputs[ptr]) 507 | ptr += 1 508 | vecs_pos.append(inputs[ptr]) 509 | ptr += 1 510 | 511 | else: 512 | raise ValueError( 513 | "self.metadata['details'][{}]['type'] must be either `category` or " 514 | "`values`. Instead it was {}.".format(col_id, col_info['type']) 515 | ) 516 | 517 | KL = 0. 
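        # The loop below (only run during training) accumulates a KL regularizer: for every
        # column it compares the batch-marginal distribution of the generated softmax
        # outputs (cluster probabilities for continuous columns, category probabilities for
        # categorical ones) with the marginal of the corresponding real inputs; the sum is
        # later passed to build_losses() as `extra_g`, so it is added to the generator loss.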
518 | ptr = 0 519 | if self.training: 520 | for col_id, col_info in enumerate(self.metadata['details']): 521 | if col_info['type'] == 'category': 522 | dist = tf.reduce_sum(vecs_gen[ptr], axis=0) 523 | dist = dist / tf.reduce_sum(dist) 524 | 525 | real = tf.reduce_sum(vecs_pos[ptr], axis=0) 526 | real = real / tf.reduce_sum(real) 527 | KL += self.compute_kl(real, dist) 528 | ptr += 1 529 | 530 | elif col_info['type'] == 'value': 531 | ptr += 1 532 | dist = tf.reduce_sum(vecs_gen[ptr], axis=0) 533 | dist = dist / tf.reduce_sum(dist) 534 | real = tf.reduce_sum(vecs_pos[ptr], axis=0) 535 | real = real / tf.reduce_sum(real) 536 | KL += self.compute_kl(real, dist) 537 | 538 | ptr += 1 539 | 540 | else: 541 | raise ValueError( 542 | "self.metadata['details'][{}]['type'] must be either `category` or " 543 | "`values`. Instead it was {}.".format(col_id, col_info['type']) 544 | ) 545 | 546 | with tf.variable_scope('discrim'): 547 | discrim_pos = self.discriminator(vecs_pos) 548 | discrim_neg = self.discriminator(vecs_gen) 549 | 550 | self.build_losses(discrim_pos, discrim_neg, extra_g=KL, l2_norm=self.l2norm) 551 | self.collect_variables() 552 | 553 | def _get_optimizer(self): 554 | if self.optimizer == 'AdamOptimizer': 555 | return tf.train.AdamOptimizer(self.learning_rate, 0.5) 556 | 557 | elif self.optimizer == 'AdadeltaOptimizer': 558 | return tf.train.AdadeltaOptimizer(self.learning_rate, 0.95) 559 | 560 | else: 561 | return tf.train.GradientDescentOptimizer(self.learning_rate) 562 | 563 | 564 | class TGANModel: 565 | """Main model from TGAN. 566 | 567 | Args: 568 | continuous_columns (list[int]): 0-index list of column indices to be considered continuous. 569 | output (str, optional): Path to store the model and its artifacts. Defaults to 570 | :attr:`output`. 571 | gpu (list[str], optional):Comma separated list of GPU(s) to use. Defaults to :attr:`None`. 572 | max_epoch (int, optional): Number of epochs to use during training. Defaults to :attr:`5`. 573 | steps_per_epoch (int, optional): Number of steps to run on each epoch. Defaults to 574 | :attr:`10000`. 575 | save_checkpoints(bool, optional): Whether or not to store checkpoints of the model after 576 | each training epoch. Defaults to :attr:`True` 577 | restore_session(bool, optional): Whether or not continue training from the last checkpoint. 578 | Defaults to :attr:`True`. 579 | batch_size (int, optional): Size of the batch to feed the model at each step. Defaults to 580 | :attr:`200`. 581 | z_dim (int, optional): Number of dimensions in the noise input for the generator. 582 | Defaults to :attr:`100`. 583 | noise (float, optional): Upper bound to the gaussian noise added to categorical columns. 584 | Defaults to :attr:`0.2`. 585 | l2norm (float, optional): 586 | L2 reguralization coefficient when computing losses. Defaults to :attr:`0.00001`. 587 | learning_rate (float, optional): Learning rate for the optimizer. Defaults to 588 | :attr:`0.001`. 589 | num_gen_rnn (int, optional): Defaults to :attr:`400`. 590 | num_gen_feature (int, optional): Number of features of in the generator. Defaults to 591 | :attr:`100` 592 | num_dis_layers (int, optional): Defaults to :attr:`2`. 593 | num_dis_hidden (int, optional): Defaults to :attr:`200`. 594 | optimizer (str, optional): Name of the optimizer to use during `fit`,possible values are: 595 | [`GradientDescentOptimizer`, `AdamOptimizer`, `AdadeltaOptimizer`]. Defaults to 596 | :attr:`AdamOptimizer`. 
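        comet_ml_key (str, optional): comet.ml API key. When given and no ``experiment``
            is passed, a comet.ml :class:`Experiment` is created for logging. Defaults to
            :attr:`None`.
        experiment (comet_ml.Experiment, optional): Existing comet.ml experiment to log
            training metrics to. Defaults to :attr:`None`.
        ds (str, optional): Dataset name logged to the comet.ml experiment. Defaults to
            :attr:`None`.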
597 | """ 598 | 599 | def __init__( 600 | self, continuous_columns, output='output', gpu=None, max_epoch=5, steps_per_epoch=10000, 601 | save_checkpoints=True, restore_session=True, batch_size=200, z_dim=200, noise=0.2, 602 | l2norm=0.00001, learning_rate=0.001, num_gen_rnn=100, num_gen_feature=100, 603 | num_dis_layers=1, num_dis_hidden=100, optimizer='AdamOptimizer', comet_ml_key=None, experiment=None, ds=None 604 | ): 605 | """Initialize object.""" 606 | # Output 607 | self.continuous_columns = continuous_columns 608 | self.log_dir = os.path.join(output, 'logs') 609 | self.model_dir = os.path.join(output, 'model') 610 | self.output = output 611 | 612 | # Training params 613 | self.max_epoch = max_epoch 614 | self.steps_per_epoch = steps_per_epoch 615 | self.save_checkpoints = save_checkpoints 616 | self.restore_session = restore_session 617 | 618 | # Model params 619 | self.model = None 620 | self.batch_size = batch_size 621 | self.z_dim = z_dim 622 | self.noise = noise 623 | self.l2norm = l2norm 624 | self.learning_rate = learning_rate 625 | self.num_gen_rnn = num_gen_rnn 626 | self.num_gen_feature = num_gen_feature 627 | self.num_dis_layers = num_dis_layers 628 | self.num_dis_hidden = num_dis_hidden 629 | self.optimizer = optimizer 630 | 631 | if gpu: 632 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu 633 | 634 | if experiment is not None: 635 | self.experiment = experiment 636 | elif comet_ml_key is not None: 637 | self.comet_ml_key = comet_ml_key 638 | self.experiment = Experiment(api_key=comet_ml_key, project_name='tgan-wgan-gp', workspace="baukebrenninkmeijer") 639 | if ds is not None: 640 | experiment.log_dataset_info(name=ds) 641 | self.gpu = gpu 642 | 643 | def get_model(self, training=True): 644 | """Return a new instance of the model.""" 645 | return GraphBuilder( 646 | metadata=self.metadata, 647 | batch_size=self.batch_size, 648 | z_dim=self.z_dim, 649 | noise=self.noise, 650 | l2norm=self.l2norm, 651 | learning_rate=self.learning_rate, 652 | num_gen_rnn=self.num_gen_rnn, 653 | num_gen_feature=self.num_gen_feature, 654 | num_dis_layers=self.num_dis_layers, 655 | num_dis_hidden=self.num_dis_hidden, 656 | optimizer=self.optimizer, 657 | training=training 658 | ) 659 | 660 | def prepare_sampling(self): 661 | """Prepare model for generate samples.""" 662 | if self.model is None: 663 | self.model = self.get_model(training=False) 664 | 665 | else: 666 | self.model.training = False 667 | 668 | predict_config = PredictConfig( 669 | session_init=SaverRestore(self.restore_path), 670 | model=self.model, 671 | input_names=['z'], 672 | output_names=['gen/gen', 'z'], 673 | ) 674 | 675 | self.simple_dataset_predictor = SimpleDatasetPredictor( 676 | predict_config, 677 | RandomZData((self.batch_size, self.z_dim)) 678 | ) 679 | 680 | def fit(self, data): 681 | """Fit the model to the given data. 682 | 683 | Args: 684 | data(pandas.DataFrame): dataset to fit the model. 
685 | 686 | Returns: 687 | None 688 | 689 | """ 690 | self.preprocessor = Preprocessor(continuous_columns=self.continuous_columns) 691 | data = self.preprocessor.fit_transform(data) 692 | self.metadata = self.preprocessor.metadata 693 | dataflow = TGANDataFlow(data, self.metadata) 694 | batch_data = BatchData(dataflow, self.batch_size) 695 | input_queue = QueueInput(batch_data) 696 | 697 | self.model = self.get_model(training=True) 698 | 699 | trainer = GANTrainer( 700 | model=self.model, 701 | input_queue=input_queue, 702 | ) 703 | 704 | self.restore_path = os.path.join(self.model_dir, 'checkpoint') 705 | 706 | if os.path.isfile(self.restore_path) and self.restore_session: 707 | session_init = SaverRestore(self.restore_path) 708 | with open(os.path.join(self.log_dir, 'stats.json')) as f: 709 | starting_epoch = json.load(f)[-1]['epoch_num'] + 1 710 | 711 | else: 712 | session_init = None 713 | starting_epoch = 1 714 | 715 | action = 'k' if self.restore_session else None 716 | # logger.set_logger_dir(self.log_dir, action=action) 717 | 718 | callbacks = [] 719 | monitors = [] 720 | if self.save_checkpoints: 721 | callbacks.append(ModelSaver(checkpoint_dir=self.model_dir)) 722 | callbacks.append(MergeAllSummaries(period=10)) 723 | 724 | if self.experiment is not None: 725 | monitors.append(CometMLMonitor(experiment=self.experiment)) 726 | 727 | trainer.train_with_defaults( 728 | callbacks=callbacks, 729 | monitors=monitors, 730 | steps_per_epoch=self.steps_per_epoch, 731 | max_epoch=self.max_epoch, 732 | session_init=session_init, 733 | starting_epoch=starting_epoch 734 | ) 735 | 736 | self.prepare_sampling() 737 | 738 | def sample(self, num_samples): 739 | """Generate samples from model. 740 | 741 | Args: 742 | num_samples(int) 743 | 744 | Returns: 745 | None 746 | 747 | Raises: 748 | ValueError 749 | 750 | """ 751 | max_iters = (num_samples // self.batch_size) 752 | 753 | results = [] 754 | for idx, o in enumerate(self.simple_dataset_predictor.get_result()): 755 | results.append(o[0]) 756 | if idx + 1 == max_iters: 757 | break 758 | 759 | results = np.concatenate(results, axis=0) 760 | 761 | ptr = 0 762 | features = {} 763 | for col_id, col_info in enumerate(self.metadata['details']): 764 | if col_info['type'] == 'category': 765 | features['f%02d' % col_id] = results[:, ptr:ptr + 1] 766 | ptr += 1 767 | 768 | elif col_info['type'] == 'value': 769 | gaussian_components = col_info['n'] 770 | val = results[:, ptr:ptr + 1] 771 | ptr += 1 772 | pro = results[:, ptr:ptr + gaussian_components] 773 | ptr += gaussian_components 774 | features['f%02d' % col_id] = np.concatenate([val, pro], axis=1) 775 | 776 | else: 777 | raise ValueError( 778 | "self.metadata['details'][{}]['type'] must be either `category` or " 779 | "`values`. 
Instead it was {}.".format(col_id, col_info['type']) 780 | ) 781 | 782 | return self.preprocessor.reverse_transform(features)[:num_samples].copy() 783 | 784 | def tar_folder(self, tar_name): 785 | """Generate a tar of :self.output:.""" 786 | with tarfile.open(tar_name, 'w:gz') as tar_handle: 787 | for root, dirs, files in os.walk(self.output): 788 | for file_ in files: 789 | tar_handle.add(os.path.join(root, file_)) 790 | 791 | tar_handle.close() 792 | 793 | @classmethod 794 | def load(cls, path): 795 | """Load a pretrained model from a given path.""" 796 | with tarfile.open(path, 'r:gz') as tar_handle: 797 | destination_dir = os.path.dirname(tar_handle.getmembers()[0].name) 798 | tar_handle.extractall() 799 | 800 | with open('{}/TGANModel'.format(destination_dir), 'rb+') as f: 801 | instance = pickle.load(f) 802 | 803 | instance.prepare_sampling() 804 | return instance 805 | 806 | def save(self, path, force=False): 807 | """Save the fitted model in the given path.""" 808 | if os.path.exists(path) and not force: 809 | logger.info('The indicated path already exists. Use `force=True` to overwrite.') 810 | return 811 | 812 | base_path = os.path.dirname(path) 813 | if not os.path.exists(base_path): 814 | os.makedirs(base_path) 815 | 816 | model = self.model 817 | dataset_predictor = self.simple_dataset_predictor 818 | 819 | self.model = None 820 | self.simple_dataset_predictor = None 821 | 822 | with open('{}/TGANModel'.format(self.output), 'wb') as f: 823 | pickle.dump(self, f) 824 | 825 | self.model = model 826 | self.simple_dataset_predictor = dataset_predictor 827 | 828 | self.tar_folder(path) 829 | 830 | logger.info('Model saved successfully.') 831 | -------------------------------------------------------------------------------- /tgan_skip/trainer.py: -------------------------------------------------------------------------------- 1 | """GAN Models.""" 2 | 3 | import tensorflow as tf 4 | from tensorpack import StagingInput, TowerTrainer 5 | from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter 6 | from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper 7 | 8 | 9 | class GANTrainer(TowerTrainer): 10 | """GanTrainer model. 11 | 12 | We need to set :meth:`tower_func` because it's a :class:`TowerTrainer`, and only 13 | :class:`TowerTrainer` supports automatic graph creation for inference during training. 14 | 15 | If we don't care about inference during training, using :meth:`tower_func` is not needed. 16 | Just calling :meth:`model.build_graph` directly is OK. 17 | 18 | Args: 19 | input_queue(tensorpack.input_source.QueueInput): Data input. 20 | model(tgan.GAN.GANModelDesc): Model to train. 
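    Example:
        A sketch of how ``TGANModel.fit`` wires this trainer up (object names are illustrative)::

            input_queue = QueueInput(BatchData(dataflow, batch_size))
            trainer = GANTrainer(model=graph_builder, input_queue=input_queue)
            trainer.train_with_defaults(steps_per_epoch=10000, max_epoch=5)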
21 | 22 | """ 23 | 24 | def __init__(self, model, input_queue): 25 | """Initialize object.""" 26 | super().__init__() 27 | inputs_desc = model.get_inputs_desc() 28 | 29 | # Setup input 30 | cbs = input_queue.setup(inputs_desc) 31 | self.register_callback(cbs) 32 | 33 | # Build the graph 34 | self.tower_func = TowerFuncWrapper(model.build_graph, inputs_desc) 35 | with TowerContext('', is_training=True): 36 | self.tower_func(*input_queue.get_input_tensors()) 37 | 38 | opt = model.get_optimizer() 39 | 40 | # Define the training iteration by default, run one d_min after one g_min 41 | with tf.name_scope('optimize'): 42 | g_min_grad = opt.compute_gradients(model.g_loss, var_list=model.g_vars) 43 | g_min_grad_clip = [ 44 | (tf.clip_by_value(grad, -5.0, 5.0), var) 45 | for grad, var in g_min_grad 46 | ] 47 | 48 | g_min_train_op = opt.apply_gradients(g_min_grad_clip, name='g_op') 49 | with tf.control_dependencies([g_min_train_op]): 50 | d_min_grad = opt.compute_gradients(model.d_loss, var_list=model.d_vars) 51 | d_min_grad_clip = [ 52 | (tf.clip_by_value(grad, -5.0, 5.0), var) 53 | for grad, var in d_min_grad 54 | ] 55 | 56 | d_min_train_op = opt.apply_gradients(d_min_grad_clip, name='d_op') 57 | 58 | self.train_op = d_min_train_op 59 | 60 | 61 | class SeparateGANTrainer(TowerTrainer): 62 | """A GAN trainer which runs two optimization ops with a certain ratio. 63 | 64 | Args: 65 | input(tensorpack.input_source.QueueInput): Data input. 66 | model(tgan.GAN.GANModelDesc): Model to train. 67 | d_period(int): period of each d_opt run 68 | g_period(int): period of each g_opt run 69 | 70 | """ 71 | 72 | def __init__(self, input, model, d_period=1, g_period=1): 73 | """Initialize object.""" 74 | super(SeparateGANTrainer, self).__init__() 75 | self._d_period = int(d_period) 76 | self._g_period = int(g_period) 77 | if not min(d_period, g_period) == 1: 78 | raise ValueError('The minimum between d_period and g_period must be 1.') 79 | 80 | # Setup input 81 | cbs = input.setup(model.get_inputs_desc()) 82 | self.register_callback(cbs) 83 | 84 | # Build the graph 85 | self.tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc()) 86 | with TowerContext('', is_training=True): 87 | self.tower_func(*input.get_input_tensors()) 88 | 89 | opt = model.get_optimizer() 90 | with tf.name_scope('optimize'): 91 | self.d_min = opt.minimize( 92 | model.d_loss, var_list=model.d_vars, name='d_min') 93 | self.g_min = opt.minimize( 94 | model.g_loss, var_list=model.g_vars, name='g_min') 95 | 96 | def run_step(self): 97 | """Define the training iteration.""" 98 | if self.global_step % (self._d_period) == 0: 99 | self.hooked_sess.run(self.d_min) 100 | if self.global_step % (self._g_period) == 0: 101 | self.hooked_sess.run(self.g_min) 102 | 103 | 104 | class MultiGPUGANTrainer(TowerTrainer): 105 | """A replacement of GANTrainer (optimize d and g one by one) with multi-gpu support. 106 | 107 | Args: 108 | nr_gpu(int): 109 | input(tensorpack.input_source.QueueInput): Data input. 110 | model(tgan.GAN.GANModelDesc): Model to train. 
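    Example:
        A hedged sketch of a two-GPU setup (object names are illustrative)::

            trainer = MultiGPUGANTrainer(nr_gpu=2, input=input_queue, model=graph_builder)
            trainer.train_with_defaults(steps_per_epoch=10000, max_epoch=5)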
111 | 112 | """ 113 | 114 | def __init__(self, nr_gpu, input, model): 115 | """Initialize object.""" 116 | super(MultiGPUGANTrainer, self).__init__() 117 | if nr_gpu <= 1: 118 | raise ValueError('nr_gpu must be strictly greater than 1.') 119 | 120 | raw_devices = ['/gpu:{}'.format(k) for k in range(nr_gpu)] 121 | 122 | # Setup input 123 | input = StagingInput(input) 124 | cbs = input.setup(model.get_inputs_desc()) 125 | self.register_callback(cbs) 126 | 127 | # Build the graph with multi-gpu replication 128 | def get_cost(*inputs): 129 | model.build_graph(*inputs) 130 | return [model.d_loss, model.g_loss] 131 | 132 | self.tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc()) 133 | devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices] 134 | 135 | cost_list = DataParallelBuilder.build_on_towers( 136 | list(range(nr_gpu)), 137 | lambda: self.tower_func(*input.get_input_tensors()), 138 | devices) 139 | 140 | # Simply average the cost here. It might be faster to average the gradients 141 | with tf.name_scope('optimize'): 142 | d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu) 143 | g_loss = tf.add_n([x[1] for x in cost_list]) * (1.0 / nr_gpu) 144 | 145 | opt = model.get_optimizer() 146 | # run one d_min after one g_min 147 | g_min = opt.minimize(g_loss, var_list=model.g_vars, 148 | colocate_gradients_with_ops=True, name='g_op') 149 | 150 | with tf.control_dependencies([g_min]): 151 | d_min = opt.minimize(d_loss, var_list=model.d_vars, 152 | colocate_gradients_with_ops=True, name='d_op') 153 | 154 | # Define the training iteration 155 | self.train_op = d_min 156 | -------------------------------------------------------------------------------- /tgan_wgan_gp/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Top-level package for TGAN.""" 4 | 5 | __author__ = """MIT Data To AI Lab""" 6 | __email__ = 'dailabmit@gmail.com' 7 | __version__ = '0.1.0' 8 | -------------------------------------------------------------------------------- /tgan_wgan_gp/cli.py: -------------------------------------------------------------------------------- 1 | """Command Line Interface for TGAN.""" 2 | 3 | import argparse 4 | 5 | from tgan.research.experiments import run_experiments 6 | 7 | 8 | def get_train_parser(): 9 | """Build the ArgumentParser for CLI.""" 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') 12 | parser.add_argument('--load', help='load model') 13 | parser.add_argument('--sample', type=int, default=0, 14 | help='the number of samples in the synthetic output.') 15 | parser.add_argument('--data', required=True, help='a npz file') 16 | parser.add_argument('--output', type=str) 17 | parser.add_argument('--exp_name', type=str, default=None) 18 | 19 | # parameters for model tuning. 
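    # Illustrative flag combination (assuming some entry-point script consumes this parser;
    # the values shown simply echo the defaults defined in this function):
    #   --data data/census.npz --output output/ --max_epoch 100 --steps_per_epoch 1000 \
    #   --batch_size 200 --z_dim 100 --num_gen_rnn 400 --num_dis_layers 2 --optimizer AdamOptimizer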
20 | parser.add_argument('--batch_size', type=int, default=200) 21 | parser.add_argument('--z_dim', type=int, default=100) 22 | parser.add_argument('--max_epoch', type=int, default=100) 23 | parser.add_argument('--steps_per_epoch', type=int, default=1000) 24 | 25 | parser.add_argument('--num_gen_rnn', type=int, default=400) 26 | parser.add_argument('--num_gen_feature', type=int, default=100) 27 | 28 | parser.add_argument('--num_dis_layers', type=int, default=2) 29 | parser.add_argument('--num_dis_hidden', type=int, default=200) 30 | 31 | parser.add_argument('--noise', type=float, default=0.2) 32 | 33 | parser.add_argument('--optimizer', type=str, default='AdamOptimizer', 34 | choices=['GradientDescentOptimizer', 'AdamOptimizer', 'AdadeltaOptimizer']) 35 | parser.add_argument('--learning_rate', type=float, default=0.001) 36 | 37 | parser.add_argument('--l2norm', type=float, default=0.00001) 38 | 39 | return parser 40 | 41 | 42 | def get_parser(): 43 | """Build argument parser for TGAN CLI utility.""" 44 | parser = argparse.ArgumentParser(description='TGAN Command Line Interface.') 45 | parser.set_defaults(function=None) 46 | 47 | action = parser.add_subparsers(title='action', dest='action') 48 | action.required = True 49 | 50 | experiments = action.add_parser('experiments', help='Run experiments using TGAN.') 51 | experiments.add_argument( 52 | 'input', type=str, help='Path to the JSON file with the configuration.') 53 | experiments.add_argument( 54 | 'output', type=str, help='Path to store the results.') 55 | 56 | return parser 57 | 58 | 59 | def main(): 60 | """Python Entry point for CLI.""" 61 | parser = get_parser() 62 | args = parser.parse_args() 63 | run_experiments(args.input, args.output) 64 | -------------------------------------------------------------------------------- /tgan_wgan_gp/data.py: -------------------------------------------------------------------------------- 1 | """Data related functionalities. 2 | 3 | This modules contains the tools to preprare the data, from the raw csv files, to the DataFlow 4 | objects will be used to fit our models. 5 | """ 6 | import os 7 | import urllib 8 | 9 | import numpy as np 10 | import pandas as pd 11 | from sklearn.mixture import GaussianMixture 12 | from sklearn.preprocessing import LabelEncoder 13 | from tensorpack import DataFlow, RNGDataFlow 14 | 15 | DEMO_DATASETS = { 16 | 'census': ( 17 | 'http://hdi-project-tgan.s3.amazonaws.com/census-train.csv', 18 | 'data/census.csv', 19 | [0, 5, 16, 17, 18, 29, 38] 20 | ), 21 | 'covertype': ( 22 | 'http://hdi-project-tgan.s3.amazonaws.com/covertype-train.csv', 23 | 'data/covertype.csv', 24 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 25 | ) 26 | } 27 | 28 | 29 | def check_metadata(metadata): 30 | """Check that the given metadata has correct types for all its members. 31 | 32 | Args: 33 | metadata(dict): Description of the inputs. 34 | 35 | Returns: 36 | None 37 | 38 | Raises: 39 | AssertionError: If any of the details is not valid. 40 | 41 | """ 42 | message = 'The given metadata contains unsupported types.' 43 | assert all([item['type'] in ['category', 'value'] for item in metadata['details']]), message 44 | 45 | 46 | def check_inputs(function): 47 | """Validate inputs for functions whose first argument is a numpy.ndarray with shape (n,1). 48 | 49 | Args: 50 | function(callable): Method to validate. 51 | 52 | Returns: 53 | callable: Will check the inputs before calling :attr:`function`. 54 | 55 | Raises: 56 | ValueError: If first argument is not a valid :class:`numpy.array` of shape (n, 1). 
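    Example:
        A hedged sketch of the validation behaviour (``transformer`` stands for any object whose
        ``transform`` method is wrapped by this decorator)::

            transformer.transform(np.ones((100, 1)))   # passes the shape check
            transformer.transform(np.ones(100))        # raises ValueError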
57 | 58 | """ 59 | def decorated(self, data, *args, **kwargs): 60 | if not (isinstance(data, np.ndarray) and len(data.shape) == 2 and data.shape[1] == 1): 61 | raise ValueError('The argument `data` must be a numpy.ndarray with shape (n, 1).') 62 | 63 | return function(self, data, *args, **kwargs) 64 | 65 | decorated.__doc__ = function.__doc__ 66 | return decorated 67 | 68 | 69 | class TGANDataFlow(RNGDataFlow): 70 | """Subclass of :class:`tensorpack.RNGDataFlow` prepared to work with :class:`numpy.ndarray`. 71 | 72 | Attributes: 73 | shuffle(bool): Wheter or not to shuffle the data. 74 | metadata(dict): Metadata for the given :attr:`data`. 75 | num_features(int): Number of features in given data. 76 | data(list): Prepared data from :attr:`filename`. 77 | distribution(list): DepecrationWarning? 78 | 79 | """ 80 | 81 | def __init__(self, data, metadata, shuffle=True): 82 | """Initialize object. 83 | 84 | Args: 85 | filename(str): Path to the json file containing the metadata. 86 | shuffle(bool): Wheter or not to shuffle the data. 87 | 88 | Raises: 89 | ValueError: If any column_info['type'] is not supported 90 | 91 | """ 92 | self.shuffle = shuffle 93 | if self.shuffle: 94 | self.reset_state() 95 | 96 | self.metadata = metadata 97 | self.num_features = self.metadata['num_features'] 98 | 99 | self.data = [] 100 | self.distribution = [] 101 | for column_id, column_info in enumerate(self.metadata['details']): 102 | if column_info['type'] == 'value': 103 | col_data = data['f%02d' % column_id] 104 | value = col_data[:, :1] 105 | cluster = col_data[:, 1:] 106 | self.data.append(value) 107 | self.data.append(cluster) 108 | 109 | elif column_info['type'] == 'category': 110 | col_data = np.asarray(data['f%02d' % column_id], dtype='int32') 111 | self.data.append(col_data) 112 | 113 | else: 114 | raise ValueError( 115 | "column_info['type'] must be either 'category' or 'value'." 116 | "Instead it was '{}'.".format(column_info['type']) 117 | ) 118 | 119 | self.data = list(zip(*self.data)) 120 | 121 | def size(self): 122 | """Return the number of rows in data. 123 | 124 | Returns: 125 | int: Number of rows in :attr:`data`. 126 | 127 | """ 128 | return len(self.data) 129 | 130 | def get_data(self): 131 | """Yield the rows from :attr:`data`. 132 | 133 | Yields: 134 | tuple: Row of data. 135 | 136 | """ 137 | idxs = np.arange(len(self.data)) 138 | if self.shuffle: 139 | self.rng.shuffle(idxs) 140 | 141 | for k in idxs: 142 | yield self.data[k] 143 | 144 | def __iter__(self): 145 | """Iterate over self.data.""" 146 | return self.get_data() 147 | 148 | def __len__(self): 149 | """Length of batches.""" 150 | return self.size() 151 | 152 | 153 | class RandomZData(DataFlow): 154 | """Random dataflow. 155 | 156 | Args: 157 | shape(tuple): Shape of the array to return on :meth:`get_data` 158 | 159 | """ 160 | 161 | def __init__(self, shape): 162 | """Initialize object.""" 163 | super(RandomZData, self).__init__() 164 | self.shape = shape 165 | 166 | def get_data(self): 167 | """Yield random normal vectors of shape :attr:`shape`.""" 168 | while True: 169 | yield [np.random.normal(0, 1, size=self.shape)] 170 | 171 | def __iter__(self): 172 | """Return data.""" 173 | return self.get_data() 174 | 175 | def __len__(self): 176 | """Length of batches.""" 177 | return self.shape[0] 178 | 179 | 180 | class MultiModalNumberTransformer: 181 | r"""Reversible transform for multimodal data. 
182 | 183 | To effectively sample values from a multimodal distribution, we cluster values of a 184 | numerical variable using a `skelarn.mixture.GaussianMixture`_ model (GMM). 185 | 186 | * We train a GMM with :attr:`n` components for each numerical variable :math:`C_i`. 187 | GMM models a distribution with a weighted sum of :attr:`n` Gaussian distributions. 188 | The means and standard deviations of the :attr:`n` Gaussian distributions are 189 | :math:`{\eta}^{(1)}_{i}, ..., {\eta}^{(n)}_{i}` and 190 | :math:`{\sigma}^{(1)}_{i}, ...,{\sigma}^{(n)}_{i}`. 191 | 192 | * We compute the probability of :math:`c_{i,j}` coming from each of the :attr:`n` Gaussian 193 | distributions as a vector :math:`{u}^{(1)}_{i,j}, ..., {u}^{(n)}_{i,j}`. u_{i,j} is a 194 | normalized probability distribution over :attr:`n` Gaussian distributions. 195 | 196 | * We normalize :math:`c_{i,j}` as :math:`v_{i,j} = (c_{i,j}−{\eta}^{(k)}_{i})/2{\sigma}^ 197 | {(k)}_{i}`, where :math:`k = arg max_k {u}^{(k)}_{i,j}`. We then clip :math:`v_{i,j}` to 198 | [−0.99, 0.99]. 199 | 200 | Then we use :math:`u_i` and :math:`v_i` to represent :math:`c_i`. For simplicity, 201 | we cluster all the numerical features, i.e. both uni-modal and multi-modal features are 202 | clustered to :attr:`n = 5` Gaussian distributions. 203 | 204 | The simplification is fair because GMM automatically weighs :attr:`n` components. 205 | For example, if a variable has only one mode and fits some Gaussian distribution, then GMM 206 | will assign a very low probability to :attr:`n − 1` components and only 1 remaining 207 | component actually works, which is equivalent to not clustering this feature. 208 | 209 | Args: 210 | num_modes(int): Number of modes on given data. 211 | 212 | Attributes: 213 | num_modes(int): Number of components in the `skelarn.mixture.GaussianMixture`_ model. 214 | 215 | .. _skelarn.mixture.GaussianMixture: https://scikit-learn.org/stable/modules/generated/ 216 | sklearn.mixture.GaussianMixture.html 217 | 218 | """ 219 | 220 | def __init__(self, num_modes=5): 221 | """Initialize instance.""" 222 | self.num_modes = num_modes 223 | 224 | @check_inputs 225 | def transform(self, data): 226 | """Cluster values using a `skelarn.mixture.GaussianMixture`_ model. 227 | 228 | Args: 229 | data(numpy.ndarray): Values to cluster in array of shape (n,1). 230 | 231 | Returns: 232 | tuple[numpy.ndarray, numpy.ndarray, list, list]: Tuple containg the features, 233 | probabilities, averages and stds of the given data. 234 | 235 | .. _skelarn.mixture.GaussianMixture: https://scikit-learn.org/stable/modules/generated/ 236 | sklearn.mixture.GaussianMixture.html 237 | 238 | """ 239 | model = GaussianMixture(self.num_modes) 240 | model.fit(data) 241 | 242 | means = model.means_.reshape((1, self.num_modes)) 243 | stds = np.sqrt(model.covariances_).reshape((1, self.num_modes)) 244 | 245 | features = (data - means) / (2 * stds) 246 | probs = model.predict_proba(data) 247 | argmax = np.argmax(probs, axis=1) 248 | idx = np.arange(len(features)) 249 | features = features[idx, argmax].reshape([-1, 1]) 250 | 251 | features = np.clip(features, -0.99, 0.99) 252 | 253 | return features, probs, list(means.flat), list(stds.flat) 254 | 255 | @staticmethod 256 | def inverse_transform(data, info): 257 | """Reverse the clustering of values. 258 | 259 | Args: 260 | data(numpy.ndarray): Transformed data to restore. 261 | info(dict): Metadata. 262 | 263 | Returns: 264 | numpy.ndarray: Values in the original space. 
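        Example:
            A small worked example with made-up metadata::

                data = np.array([[0.25, 0.1, 0.9]])               # scalar part, then cluster probs
                info = {'means': [0.0, 10.0], 'stds': [1.0, 2.0]}
                # cluster 1 has the highest probability, so the value is 0.25 * 2 * 2.0 + 10.0
                MultiModalNumberTransformer.inverse_transform(data, info)   # array([11.])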
265 | 266 | """ 267 | features = data[:, 0] 268 | probs = data[:, 1:] 269 | p_argmax = np.argmax(probs, axis=1) 270 | 271 | mean = np.asarray(info['means']) 272 | std = np.asarray(info['stds']) 273 | 274 | select_mean = mean[p_argmax] 275 | select_std = std[p_argmax] 276 | 277 | return features * 2 * select_std + select_mean 278 | 279 | 280 | class Preprocessor: 281 | """Transform back and forth human-readable data into TGAN numerical features. 282 | 283 | Args: 284 | continous_columns(list): List of columns to be considered continuous 285 | metadata(dict): Metadata to initialize the object. 286 | 287 | Attributes: 288 | continous_columns(list): Same as constructor argument. 289 | metadata(dict): Information about the transformations applied to the data and its format. 290 | continous_transformer(MultiModalNumberTransformer): 291 | Transformer for columns in :attr:`continuous_columns` 292 | categorical_transformer(CategoricalTransformer): 293 | Transformer for categorical columns. 294 | columns(list): List of columns labels. 295 | 296 | """ 297 | 298 | def __init__(self, continuous_columns=None, metadata=None): 299 | """Initialize object, set arguments as attributes, initialize transformers.""" 300 | if continuous_columns is None: 301 | continuous_columns = [] 302 | 303 | self.continuous_columns = continuous_columns 304 | self.metadata = metadata 305 | self.continous_transformer = MultiModalNumberTransformer() 306 | self.categorical_transformer = LabelEncoder() 307 | self.columns = None 308 | 309 | def fit_transform(self, data, fitting=True): 310 | """Transform human-readable data into TGAN numerical features. 311 | 312 | Args: 313 | data(pandas.DataFrame): Data to transform. 314 | fitting(bool): Whether or not to update self.metadata. 315 | 316 | Returns: 317 | pandas.DataFrame: Model features 318 | 319 | """ 320 | num_cols = data.shape[1] 321 | self.columns = data.columns 322 | data.columns = list(range(num_cols)) 323 | 324 | transformed_data = {} 325 | details = [] 326 | 327 | for i in data.columns: 328 | if i in self.continuous_columns: 329 | column_data = data[i].values.reshape([-1, 1]) 330 | features, probs, means, stds = self.continous_transformer.transform(column_data) 331 | transformed_data['f%02d' % i] = np.concatenate((features, probs), axis=1) 332 | 333 | if fitting: 334 | details.append({ 335 | "type": "value", 336 | "means": means, 337 | "stds": stds, 338 | "n": 5 339 | }) 340 | 341 | else: 342 | column_data = data[i].astype(str).values 343 | features = self.categorical_transformer.fit_transform(column_data) 344 | transformed_data['f%02d' % i] = features.reshape([-1, 1]) 345 | 346 | if fitting: 347 | mapping = self.categorical_transformer.classes_ 348 | details.append({ 349 | "type": "category", 350 | "mapping": mapping, 351 | "n": mapping.shape[0], 352 | }) 353 | 354 | if fitting: 355 | metadata = { 356 | "num_features": num_cols, 357 | "details": details 358 | } 359 | check_metadata(metadata) 360 | self.metadata = metadata 361 | 362 | return transformed_data 363 | 364 | def transform(self, data): 365 | """Transform the given dataframe without generating new metadata. 366 | 367 | Args: 368 | data(pandas.DataFrame): Data to fit the object. 369 | 370 | """ 371 | return self.fit_transform(data, fitting=False) 372 | 373 | def fit(self, data): 374 | """Initialize the internal state of the object using :attr:`data`. 375 | 376 | Args: 377 | data(pandas.DataFrame): Data to fit the object. 
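        Example:
            A hedged sketch of the fitted state, assuming ``df`` is a DataFrame whose first column
            is numerical and whose second column is categorical::

                preprocessor = Preprocessor(continuous_columns=[0])
                preprocessor.fit(df)
                preprocessor.metadata['details'][0]['type']   # 'value', with 5 GMM components
                preprocessor.metadata['details'][1]['type']   # 'category', with a label mapping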
378 | 379 | """ 380 | self.fit_transform(data) 381 | 382 | def reverse_transform(self, data): 383 | """Transform TGAN numerical features back into human-readable data. 384 | 385 | Args: 386 | data(pandas.DataFrame): Data to transform. 387 | fitting(bool): Whether or not to update self.metadata. 388 | 389 | Returns: 390 | pandas.DataFrame: Model features 391 | 392 | """ 393 | table = [] 394 | 395 | for i in range(self.metadata['num_features']): 396 | column_data = data['f%02d' % i] 397 | column_metadata = self.metadata['details'][i] 398 | 399 | if column_metadata['type'] == 'value': 400 | column = self.continous_transformer.inverse_transform(column_data, column_metadata) 401 | 402 | if column_metadata['type'] == 'category': 403 | self.categorical_transformer.classes_ = column_metadata['mapping'] 404 | column = self.categorical_transformer.inverse_transform( 405 | column_data.ravel().astype(np.int32)) 406 | 407 | table.append(column) 408 | 409 | result = pd.DataFrame(dict(enumerate(table))) 410 | result.columns = self.columns 411 | return result 412 | 413 | 414 | def load_demo_data(name, header=None): 415 | """Fetch, load and prepare a dataset. 416 | 417 | If name is one of the demo datasets 418 | 419 | 420 | Args: 421 | name(str): Name or path of the dataset. 422 | header(): Header parameter when executing :attr:`pandas.read_csv` 423 | 424 | """ 425 | params = DEMO_DATASETS.get(name) 426 | if params: 427 | url, file_path, continuous_columns = params 428 | if not os.path.isfile(file_path): 429 | base_path = os.path.dirname(file_path) 430 | if not os.path.exists(base_path): 431 | os.makedirs(base_path) 432 | 433 | urllib.request.urlretrieve(url, file_path) 434 | 435 | else: 436 | message = ( 437 | '{} is not a valid dataset name. ' 438 | 'Supported values are: {}.'.format(name, list(DEMO_DATASETS.keys())) 439 | ) 440 | raise ValueError(message) 441 | 442 | return pd.read_csv(file_path, header=header), continuous_columns 443 | -------------------------------------------------------------------------------- /tgan_wgan_gp/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """Module with the model for TGAN. 5 | 6 | This module contains two classes: 7 | 8 | - :attr:`GraphBuilder`: That defines the graph and implements a Tensorpack compatible API. 9 | - :attr:`TGANModel`: The public API for the model, that offers a simplified interface for the 10 | underlying operations with GraphBuilder and trainers in order to fit and sample data. 
11 | """ 12 | import json 13 | import os 14 | import pickle 15 | import tarfile 16 | 17 | import numpy as np 18 | from comet_ml import Experiment 19 | from tensorpack.callbacks import CometMLMonitor, MergeAllSummaries 20 | import tensorflow as tf 21 | from tensorpack import ( 22 | BatchData, BatchNorm, Dropout, FullyConnected, InputDesc, ModelDescBase, ModelSaver, 23 | PredictConfig, QueueInput, SaverRestore, SimpleDatasetPredictor, logger, LayerNorm) 24 | from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope 25 | from tensorpack.tfutils.summary import add_moving_summary 26 | from tensorpack.utils.argtools import memoized 27 | 28 | from tgan.data import Preprocessor, RandomZData, TGANDataFlow 29 | from tgan.trainer import GANTrainer, SeparateGANTrainer 30 | 31 | TUNABLE_VARIABLES = { 32 | 'batch_size': [50, 100, 200], 33 | 'z_dim': [50, 100, 200, 400], 34 | 'num_gen_rnn': [100, 200, 300, 400, 500, 600], 35 | 'num_gen_feature': [100, 200, 300, 400, 500, 600], 36 | 'num_dis_layers': [1, 2, 3, 4, 5], 37 | 'num_dis_hidden': [100, 200, 300, 400, 500], 38 | 'learning_rate': [0.0002, 0.0005, 0.001], 39 | 'noise': [0.05, 0.1, 0.2, 0.3] 40 | } 41 | 42 | 43 | class GraphBuilder(ModelDescBase): 44 | """Main model for TGAN. 45 | 46 | Args: 47 | None 48 | 49 | Attributes: 50 | 51 | """ 52 | 53 | def __init__( 54 | self, 55 | metadata, 56 | batch_size=200, 57 | z_dim=200, 58 | noise=0.2, 59 | l2norm=0.00001, 60 | learning_rate=0.001, 61 | num_gen_rnn=100, 62 | num_gen_feature=100, 63 | num_dis_layers=1, 64 | num_dis_hidden=100, 65 | optimizer='AdamOptimizer', 66 | training=True 67 | ): 68 | """Initialize the object, set arguments as attributes.""" 69 | self.metadata = metadata 70 | self.batch_size = batch_size 71 | self.z_dim = z_dim 72 | self.noise = noise 73 | self.l2norm = l2norm 74 | self.learning_rate = learning_rate 75 | self.num_gen_rnn = num_gen_rnn 76 | self.num_gen_feature = num_gen_feature 77 | self.num_dis_layers = num_dis_layers 78 | self.num_dis_hidden = num_dis_hidden 79 | self.optimizer = optimizer 80 | self.training = training 81 | 82 | def collect_variables(self, g_scope='gen', d_scope='discrim'): 83 | """Assign generator and discriminator variables from their scopes. 84 | 85 | Args: 86 | g_scope(str): Scope for the generator. 87 | d_scope(str): Scope for the discriminator. 88 | 89 | Raises: 90 | ValueError: If any of the assignments fails or the collections are empty. 91 | 92 | """ 93 | self.g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, g_scope) 94 | self.d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, d_scope) 95 | 96 | if not (self.g_vars or self.d_vars): 97 | raise ValueError('There are no variables defined in some of the given scopes') 98 | 99 | def build_losses(self, logits_real, logits_fake, x_fake, x_real, 100 | extra_g=0, l2_norm=0.00001): 101 | r"""D and G play two-player minimax game with value function :math:`V(G,D)`. 102 | 103 | .. math:: 104 | 105 | min_G max_D V(D, G) = IE_{x \sim p_{data}} [log D(x)] + IE_{z \sim p_{fake}} 106 | [log (1 - D(G(z)))] 107 | 108 | Args: 109 | logits_real (tensorflow.Tensor): discrim logits from real samples. 110 | logits_fake (tensorflow.Tensor): discrim logits from fake samples from generator. 111 | extra_g(float): 112 | l2_norm(float): scale to apply L2 regularization. 
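        Note:
            As implemented below, the losses follow the WGAN-GP objective rather than the minimax
            value function above. With :math:`x` real, :math:`\tilde{x}` generated and
            :math:`\hat{x} = x + \epsilon(\tilde{x} - x)` for :math:`\epsilon \sim U[0, 1]`,

            .. math::

                L_D = \mathbb{E}[D(\tilde{x})] - \mathbb{E}[D(x)]
                    + 10 \, \mathbb{E}\big[(||\nabla_{\hat{x}} D(\hat{x})||_2 - 1)^2\big],
                \qquad
                L_G = -\mathbb{E}[D(\tilde{x})].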
113 | 114 | Returns: 115 | None 116 | 117 | """ 118 | 119 | x_fake = tf.concat(x_fake, axis=1) 120 | x_real = tf.concat(x_real, axis=1) 121 | # print('x_real shape: ', x_real) 122 | 123 | with tf.name_scope("GAN_loss"): 124 | score_real = tf.sigmoid(logits_real) 125 | score_fake = tf.sigmoid(logits_fake) 126 | tf.summary.histogram('score-real', score_real) 127 | tf.summary.histogram('score-fake', score_fake) 128 | tf.summary.histogram('logits_real', logits_real) 129 | tf.summary.histogram('logits_fake', logits_fake) 130 | 131 | # with tf.name_scope("discrim"): 132 | self.epsilon = tf.random_uniform( 133 | shape=[self.batch_size, 1], 134 | minval=0., 135 | maxval=1.) 136 | 137 | X_hat = x_real + self.epsilon * (x_fake - x_real) 138 | # print('X_hat shape: ', X_hat.shape) 139 | D_X_hat = self.discriminator(X_hat) 140 | grad_D_X_hat = tf.gradients(D_X_hat, [X_hat])[0] 141 | red_idx = list(range(1, X_hat.shape.ndims)) 142 | slopes = tf.sqrt(tf.reduce_sum(tf.square(grad_D_X_hat), reduction_indices=red_idx)) 143 | gradient_penalty = tf.identity(tf.reduce_mean((slopes - 1.) ** 2), name='GP') 144 | self.d_loss = tf.reduce_mean(logits_fake) - tf.reduce_mean(logits_real) 145 | self.d_loss = self.d_loss + 10 * gradient_penalty 146 | self.d_loss_sum = tf.summary.scalar("Discriminator_loss", self.d_loss) 147 | self.gp_sum = tf.summary.scalar("Gradient_penalty", gradient_penalty) 148 | 149 | self.d_loss_sum = tf.summary.scalar("Discriminator_loss", self.d_loss) 150 | self.gp_sum = tf.summary.scalar("Gradient_penalty", gradient_penalty) 151 | 152 | with tf.name_scope("gen"): 153 | self.g_loss = -tf.reduce_mean(logits_fake) 154 | self.g_loss_sum = tf.summary.scalar("Generator_loss", self.g_loss) 155 | 156 | add_moving_summary(self.g_loss, self.d_loss) 157 | 158 | @memoized 159 | def get_optimizer(self): 160 | """Return optimizer of base class.""" 161 | return self._get_optimizer() 162 | 163 | def inputs(self): 164 | """Return metadata about entry data. 165 | 166 | Returns: 167 | list[tensorpack.InputDesc] 168 | 169 | Raises: 170 | ValueError: If any of the elements in self.metadata['details'] has an unsupported 171 | value in the `type` key. 172 | 173 | """ 174 | inputs = [] 175 | for col_id, col_info in enumerate(self.metadata['details']): 176 | if col_info['type'] == 'value': 177 | gaussian_components = col_info['n'] 178 | inputs.append( 179 | InputDesc(tf.float32, (self.batch_size, 1), 'input%02dvalue' % col_id)) 180 | 181 | inputs.append( 182 | InputDesc( 183 | tf.float32, 184 | (self.batch_size, gaussian_components), 185 | 'input%02dcluster' % col_id 186 | ) 187 | ) 188 | 189 | elif col_info['type'] == 'category': 190 | inputs.append(InputDesc(tf.int32, (self.batch_size, 1), 'input%02d' % col_id)) 191 | 192 | else: 193 | raise ValueError( 194 | "self.metadata['details'][{}]['type'] must be either `category` or " 195 | "`values`. Instead it was {}.".format(col_id, col_info['type']) 196 | ) 197 | 198 | return inputs 199 | 200 | def generator(self, z): 201 | r"""Build generator graph. 202 | 203 | We generate a numerical variable in 2 steps. We first generate the value scalar 204 | :math:`v_i`, then generate the cluster vector :math:`u_i`. We generate categorical 205 | feature in 1 step as a probability distribution over all possible labels. 206 | 207 | The output and hidden state size of LSTM is :math:`n_h`. 
The input to the LSTM in each 208 | step :math:`t` is the random variable :math:`z`, the previous hidden vector :math:`f_{t−1}` 209 | or an embedding vector :math:`f^{\prime}_{t−1}` depending on the type of previous output, 210 | and the weighted context vector :math:`a_{t−1}`. The random variable :math:`z` has 211 | :math:`n_z` dimensions. 212 | Each dimension is sampled from :math:`\mathcal{N}(0, 1)`. The attention-based context 213 | vector at is a weighted average over all the previous LSTM outputs :math:`h_{1:t}`. 214 | So :math:`a_t` is a :math:`n_h`-dimensional vector. 215 | We learn a attention weight vector :math:`α_t \in \mathbb{R}^t` and compute context as 216 | 217 | .. math:: 218 | a_t = \sum_{k=1}^{t} \frac{\textrm{exp} {\alpha}_{t, j}} 219 | {\sum_{j} \textrm{exp} \alpha_{t,j}} h_k. 220 | 221 | We set :math: `a_0` = 0. The output of LSTM is :math:`h_t` and we project the output to 222 | a hidden vector :math:`f_t = \textrm{tanh}(W_h h_t)`, where :math:`W_h` is a learned 223 | parameter in the network. The size of :math:`f_t` is :math:`n_f` . 224 | We further convert the hidden vector to an output variable. 225 | 226 | * If the output is the value part of a continuous variable, we compute the output as 227 | :math:`v_i = \textrm{tanh}(W_t f_t)`. The hidden vector for :math:`t + 1` step is 228 | :math:`f_t`. 229 | 230 | * If the output is the cluster part of a continuous variable, we compute the output as 231 | :math:`u_i = \textrm{softmax}(W_t f_t)`. The feature vector for :math:`t + 1` step is 232 | :math:`f_t`. 233 | 234 | * If the output is a discrete variable, we compute the output as 235 | :math:`d_i = \textrm{softmax}(W_t f_t)`. The hidden vector for :math:`t + 1` step is 236 | :math:`f^{\prime}_{t} = E_i [arg_k \hspace{0.25em} \textrm{max} \hspace{0.25em} d_i ]`, 237 | where :math:`E \in R^{|D_i|×n_f}` is an embedding matrix for discrete variable 238 | :math:`D_i`. 239 | 240 | * :math:`f_0` is a special vector :math:`\texttt{}` and we learn it during the 241 | training. 242 | 243 | Args: 244 | z: 245 | 246 | Returns: 247 | list[tensorflow.Tensor]: Outpu 248 | 249 | Raises: 250 | ValueError: If any of the elements in self.metadata['details'] has an unsupported 251 | value in the `type` key. 
252 | 253 | """ 254 | with tf.variable_scope('LSTM'): 255 | cell = tf.nn.rnn_cell.LSTMCell(self.num_gen_rnn) 256 | 257 | state = cell.zero_state(self.batch_size, dtype='float32') 258 | attention = tf.zeros( 259 | shape=(self.batch_size, self.num_gen_rnn), dtype='float32') 260 | input = tf.get_variable(name='go', shape=(1, self.num_gen_feature)) # 261 | input = tf.tile(input, [self.batch_size, 1]) 262 | input = tf.concat([input, z], axis=1) 263 | 264 | ptr = 0 265 | outputs = [] 266 | states = [] 267 | for col_id, col_info in enumerate(self.metadata['details']): 268 | if col_info['type'] == 'value': 269 | output, state = cell(tf.concat([input, attention], axis=1), state) 270 | states.append(state[1]) 271 | 272 | gaussian_components = col_info['n'] 273 | with tf.variable_scope("%02d" % ptr): 274 | h = FullyConnected('FC', output, self.num_gen_feature, nl=tf.tanh) 275 | outputs.append(FullyConnected('FC2', h, 1, nl=tf.tanh)) 276 | input = tf.concat([h, z], axis=1) 277 | attw = tf.get_variable("attw", shape=(len(states), 1, 1)) 278 | attw = tf.nn.softmax(attw, axis=0) 279 | attention = tf.reduce_sum(tf.stack(states, axis=0) * attw, axis=0) 280 | 281 | ptr += 1 282 | 283 | output, state = cell(tf.concat([input, attention], axis=1), state) 284 | states.append(state[1]) 285 | with tf.variable_scope("%02d" % ptr): 286 | h = FullyConnected('FC', output, self.num_gen_feature, nl=tf.tanh) 287 | w = FullyConnected('FC2', h, gaussian_components, nl=tf.nn.softmax) 288 | outputs.append(w) 289 | input = FullyConnected('FC3', w, self.num_gen_feature, nl=tf.identity) 290 | input = tf.concat([input, z], axis=1) 291 | attw = tf.get_variable("attw", shape=(len(states), 1, 1)) 292 | attw = tf.nn.softmax(attw, axis=0) 293 | attention = tf.reduce_sum(tf.stack(states, axis=0) * attw, axis=0) 294 | 295 | ptr += 1 296 | 297 | elif col_info['type'] == 'category': 298 | output, state = cell(tf.concat([input, attention], axis=1), state) 299 | states.append(state[1]) 300 | with tf.variable_scope("%02d" % ptr): 301 | h = FullyConnected('FC', output, self.num_gen_feature, nl=tf.tanh) 302 | w = FullyConnected('FC2', h, col_info['n'], nl=tf.nn.softmax) 303 | outputs.append(w) 304 | one_hot = tf.one_hot(tf.argmax(w, axis=1), col_info['n']) 305 | input = FullyConnected( 306 | 'FC3', one_hot, self.num_gen_feature, nl=tf.identity) 307 | input = tf.concat([input, z], axis=1) 308 | attw = tf.get_variable("attw", shape=(len(states), 1, 1)) 309 | attw = tf.nn.softmax(attw, axis=0) 310 | attention = tf.reduce_sum(tf.stack(states, axis=0) * attw, axis=0) 311 | 312 | ptr += 1 313 | 314 | else: 315 | raise ValueError( 316 | "self.metadata['details'][{}]['type'] must be either `category` or " 317 | "`values`. Instead it was {}.".format(col_id, col_info['type']) 318 | ) 319 | 320 | return outputs 321 | 322 | @staticmethod 323 | def batch_diversity(l, n_kernel=10, kernel_dim=10): 324 | r"""Return the minibatch discrimination vector. 325 | 326 | Let :math:`f(x_i) \in \mathbb{R}^A` denote a vector of features for input :math:`x_i`, 327 | produced by some intermediate layer in the discriminator. We then multiply the vector 328 | :math:`f(x_i)` by a tensor :math:`T \in \mathbb{R}^{A×B×C}`, which results in a matrix 329 | :math:`M_i \in \mathbb{R}^{B×C}`. We then compute the :math:`L_1`-distance between the 330 | rows of the resulting matrix :math:`M_i` across samples :math:`i \in {1, 2, ... , n}` 331 | and apply a negative exponential: 332 | 333 | .. math:: 334 | 335 | cb(x_i, x_j) = exp(−||M_{i,b} − M_{j,b}||_{L_1} ) \in \mathbb{R}. 
336 | 337 | The output :math:`o(x_i)` for this *minibatch layer* for a sample :math:`x_i` is then 338 | defined as the sum of the cb(xi, xj )’s to all other samples: 339 | 340 | .. math:: 341 | :nowrap: 342 | 343 | \begin{aligned} 344 | 345 | &o(x_i)_b = \sum^{n}_{j=1} cb(x_i , x_j) \in \mathbb{R}\\ 346 | &o(x_i) = \Big[ o(x_i)_1, o(x_i)_2, . . . , o(x_i)_B \Big] \in \mathbb{R}^B\\ 347 | &o(X) ∈ R^{n×B}\\ 348 | 349 | \end{aligned} 350 | 351 | Note: 352 | This is extracted from `Improved techniques for training GANs`_ (Section 3.2) by 353 | Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, and 354 | Xi Chen. 355 | 356 | .. _Improved techniques for training GANs: https://arxiv.org/pdf/1606.03498.pdf 357 | 358 | Args: 359 | l(tf.Tensor) 360 | n_kernel(int) 361 | kernel_dim(int) 362 | 363 | Returns: 364 | tensorflow.Tensor 365 | 366 | """ 367 | M = FullyConnected('fc_diversity', l, n_kernel * kernel_dim, nl=tf.identity) 368 | M = tf.reshape(M, [-1, n_kernel, kernel_dim]) 369 | M1 = tf.reshape(M, [-1, 1, n_kernel, kernel_dim]) 370 | M2 = tf.reshape(M, [1, -1, n_kernel, kernel_dim]) 371 | diff = tf.exp(-tf.reduce_sum(tf.abs(M1 - M2), axis=3)) 372 | return tf.reduce_sum(diff, axis=0) 373 | 374 | @auto_reuse_variable_scope 375 | def discriminator(self, vecs): 376 | r"""Build discriminator. 377 | 378 | We use a :math:`l`-layer fully connected neural network as the discriminator. 379 | We concatenate :math:`v_{1:n_c}`, :math:`u_{1:n_c}` and :math:`d_{1:n_d}` together as the 380 | input. We compute the internal layers as 381 | 382 | .. math:: 383 | \begin{aligned} 384 | 385 | f^{(D)}_{1} &= \textrm{LeakyReLU}(\textrm{BN}(W^{(D)}_{1}(v_{1:n_c} \oplus u_{1:n_c} 386 | \oplus d_{1:n_d}) 387 | 388 | f^{(D)}_{1} &= \textrm{LeakyReLU}(\textrm{BN}(W^{(D)}_{i}(f^{(D)}_{i−1} \oplus 389 | \textrm{diversity}(f^{(D)}_{i−1})))), i = 2:l 390 | 391 | \end{aligned} 392 | 393 | where :math:`\oplus` is the concatenation operation. :math:`\textrm{diversity}(·)` is the 394 | mini-batch discrimination vector [42]. Each dimension of the diversity vector is the total 395 | distance between one sample and all other samples in the mini-batch using some learned 396 | distance metric. :math:`\textrm{BN}(·)` is batch normalization, and 397 | :math:`\textrm{LeakyReLU}(·)` is the leaky reflect linear activation function. We further 398 | compute the output of discriminator as :math:`W^{(D)}(f^{(D)}_{l} \oplus \textrm{diversity} 399 | (f^{(D)}_{l}))` which is a scalar. 
400 | 401 | Args: 402 | vecs(list[tensorflow.Tensor]): List of tensors matching the spec of :meth:`inputs` 403 | 404 | Returns: 405 | tensorpack.FullyConected: a (b, 1) logits 406 | 407 | """ 408 | logits = tf.concat(vecs, axis=1) 409 | with tf.variable_scope('discrim'): 410 | for i in range(self.num_dis_layers): 411 | with tf.variable_scope('dis_fc{}'.format(i)): 412 | if i == 0: 413 | logits = FullyConnected( 414 | 'fc', logits, self.num_dis_hidden, nl=tf.identity, 415 | kernel_initializer=tf.truncated_normal_initializer(stddev=0.1) 416 | ) 417 | 418 | else: 419 | logits = FullyConnected('fc', logits, self.num_dis_hidden, nl=tf.identity) 420 | 421 | logits = tf.concat([logits, self.batch_diversity(logits)], axis=1) 422 | logits = LayerNorm('ln', logits) 423 | logits = Dropout(logits) 424 | logits = tf.nn.leaky_relu(logits) 425 | 426 | return FullyConnected('dis_fc_top', logits, 1, nl=tf.identity) 427 | 428 | @staticmethod 429 | def compute_kl(real, pred): 430 | r"""Compute the Kullback–Leibler divergence, :math:`D_{KL}(\textrm{pred} || \textrm{real})`. 431 | 432 | Args: 433 | real(tensorflow.Tensor): Real values. 434 | pred(tensorflow.Tensor): Predicted values. 435 | 436 | Returns: 437 | float: Computed divergence for the given values. 438 | 439 | """ 440 | return tf.reduce_sum((tf.log(pred + 1e-4) - tf.log(real + 1e-4)) * pred) 441 | 442 | def build_graph(self, *inputs): 443 | """Build the whole graph. 444 | 445 | Args: 446 | inputs(list[tensorflow.Tensor]): 447 | 448 | Returns: 449 | None 450 | 451 | """ 452 | z = tf.random_normal( 453 | [self.batch_size, self.z_dim], name='z_train') 454 | 455 | z = tf.placeholder_with_default(z, [None, self.z_dim], name='z') 456 | 457 | with tf.variable_scope('gen'): 458 | vecs_gen = self.generator(z) 459 | 460 | vecs_denorm = [] 461 | ptr = 0 462 | for col_id, col_info in enumerate(self.metadata['details']): 463 | if col_info['type'] == 'category': 464 | t = tf.argmax(vecs_gen[ptr], axis=1) 465 | t = tf.cast(tf.reshape(t, [-1, 1]), 'float32') 466 | vecs_denorm.append(t) 467 | ptr += 1 468 | 469 | elif col_info['type'] == 'value': 470 | vecs_denorm.append(vecs_gen[ptr]) 471 | ptr += 1 472 | vecs_denorm.append(vecs_gen[ptr]) 473 | ptr += 1 474 | 475 | else: 476 | raise ValueError( 477 | "self.metadata['details'][{}]['type'] must be either `category` or " 478 | "`values`. Instead it was {}.".format(col_id, col_info['type']) 479 | ) 480 | 481 | tf.identity(tf.concat(vecs_denorm, axis=1), name='gen') 482 | 483 | vecs_pos = [] 484 | ptr = 0 485 | for col_id, col_info in enumerate(self.metadata['details']): 486 | if col_info['type'] == 'category': 487 | one_hot = tf.one_hot(tf.reshape(inputs[ptr], [-1]), col_info['n']) 488 | noise_input = one_hot 489 | 490 | if self.training: 491 | noise = tf.random_uniform(tf.shape(one_hot), minval=0, maxval=self.noise) 492 | noise_input = (one_hot + noise) / tf.reduce_sum( 493 | one_hot + noise, keepdims=True, axis=1) 494 | 495 | vecs_pos.append(noise_input) 496 | ptr += 1 497 | 498 | elif col_info['type'] == 'value': 499 | vecs_pos.append(inputs[ptr]) 500 | ptr += 1 501 | vecs_pos.append(inputs[ptr]) 502 | ptr += 1 503 | 504 | else: 505 | raise ValueError( 506 | "self.metadata['details'][{}]['type'] must be either `category` or " 507 | "`values`. Instead it was {}.".format(col_id, col_info['type']) 508 | ) 509 | 510 | KL = 0. 
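        # Pointer bookkeeping for the KL loop below: `vecs_gen`/`vecs_pos` hold one tensor per
        # categorical column but two tensors per continuous column (the scalar part followed by
        # the cluster probabilities). For 'value' columns the first `ptr += 1` skips the scalar,
        # so the KL term compares only the cluster-probability marginals; for 'category' columns
        # it compares the label marginals directly.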
511 | ptr = 0 512 | if self.training: 513 | for col_id, col_info in enumerate(self.metadata['details']): 514 | if col_info['type'] == 'category': 515 | dist = tf.reduce_sum(vecs_gen[ptr], axis=0) 516 | dist = dist / tf.reduce_sum(dist) 517 | 518 | real = tf.reduce_sum(vecs_pos[ptr], axis=0) 519 | real = real / tf.reduce_sum(real) 520 | KL += self.compute_kl(real, dist) 521 | ptr += 1 522 | 523 | elif col_info['type'] == 'value': 524 | ptr += 1 525 | dist = tf.reduce_sum(vecs_gen[ptr], axis=0) 526 | dist = dist / tf.reduce_sum(dist) 527 | real = tf.reduce_sum(vecs_pos[ptr], axis=0) 528 | real = real / tf.reduce_sum(real) 529 | KL += self.compute_kl(real, dist) 530 | 531 | ptr += 1 532 | 533 | else: 534 | raise ValueError( 535 | "self.metadata['details'][{}]['type'] must be either `category` or " 536 | "`values`. Instead it was {}.".format(col_id, col_info['type']) 537 | ) 538 | # with tf.variable_scope('discrim'): 539 | # print('vecs pos: ', len(vecs_pos)) 540 | # print('vecs neg: ', len(vecs_gen)) 541 | discrim_pos = self.discriminator(vecs_pos) 542 | discrim_neg = self.discriminator(vecs_gen) 543 | 544 | self.build_losses(discrim_pos, discrim_neg, vecs_gen, vecs_pos, extra_g=KL, l2_norm=self.l2norm) 545 | self.collect_variables() 546 | 547 | def _get_optimizer(self): 548 | if self.optimizer == 'AdamOptimizer': 549 | return tf.train.AdamOptimizer(self.learning_rate, 0, 0.9) 550 | 551 | elif self.optimizer == 'AdadeltaOptimizer': 552 | return tf.train.AdadeltaOptimizer(self.learning_rate, 0.95) 553 | 554 | else: 555 | return tf.train.GradientDescentOptimizer(self.learning_rate) 556 | 557 | 558 | class TGANModel: 559 | """Main model from TGAN. 560 | 561 | Args: 562 | continuous_columns (list[int]): 0-index list of column indices to be considered continuous. 563 | output (str, optional): Path to store the model and its artifacts. Defaults to 564 | :attr:`output`. 565 | gpu (list[str], optional):Comma separated list of GPU(s) to use. Defaults to :attr:`None`. 566 | max_epoch (int, optional): Number of epochs to use during training. Defaults to :attr:`5`. 567 | steps_per_epoch (int, optional): Number of steps to run on each epoch. Defaults to 568 | :attr:`10000`. 569 | save_checkpoints(bool, optional): Whether or not to store checkpoints of the model after 570 | each training epoch. Defaults to :attr:`True` 571 | restore_session(bool, optional): Whether or not continue training from the last checkpoint. 572 | Defaults to :attr:`True`. 573 | batch_size (int, optional): Size of the batch to feed the model at each step. Defaults to 574 | :attr:`200`. 575 | z_dim (int, optional): Number of dimensions in the noise input for the generator. 576 | Defaults to :attr:`100`. 577 | noise (float, optional): Upper bound to the gaussian noise added to categorical columns. 578 | Defaults to :attr:`0.2`. 579 | l2norm (float, optional): 580 | L2 reguralization coefficient when computing losses. Defaults to :attr:`0.00001`. 581 | learning_rate (float, optional): Learning rate for the optimizer. Defaults to 582 | :attr:`0.001`. 583 | num_gen_rnn (int, optional): Defaults to :attr:`400`. 584 | num_gen_feature (int, optional): Number of features of in the generator. Defaults to 585 | :attr:`100` 586 | num_dis_layers (int, optional): Defaults to :attr:`2`. 587 | num_dis_hidden (int, optional): Defaults to :attr:`200`. 588 | optimizer (str, optional): Name of the optimizer to use during `fit`,possible values are: 589 | [`GradientDescentOptimizer`, `AdamOptimizer`, `AdadeltaOptimizer`]. 
Defaults to 590 | :attr:`AdamOptimizer`. 591 | """ 592 | 593 | def __init__( 594 | self, continuous_columns, output='output', gpu=None, max_epoch=5, steps_per_epoch=10000, 595 | save_checkpoints=True, restore_session=True, batch_size=200, z_dim=200, noise=0.2, 596 | l2norm=0.00001, learning_rate=0.001, num_gen_rnn=100, num_gen_feature=100, 597 | num_dis_layers=1, num_dis_hidden=100, optimizer='AdamOptimizer', comet_ml_key=None, experiment=None, ds=None 598 | ): 599 | """Initialize object.""" 600 | # Output 601 | self.continuous_columns = continuous_columns 602 | self.log_dir = os.path.join(output, 'logs') 603 | self.model_dir = os.path.join(output, 'model') 604 | self.output = output 605 | 606 | # Training params 607 | self.max_epoch = max_epoch 608 | self.steps_per_epoch = steps_per_epoch 609 | self.save_checkpoints = save_checkpoints 610 | self.restore_session = restore_session 611 | 612 | # Model params 613 | self.model = None 614 | self.batch_size = batch_size 615 | self.z_dim = z_dim 616 | self.noise = noise 617 | self.l2norm = l2norm 618 | self.learning_rate = learning_rate 619 | self.num_gen_rnn = num_gen_rnn 620 | self.num_gen_feature = num_gen_feature 621 | self.num_dis_layers = num_dis_layers 622 | self.num_dis_hidden = num_dis_hidden 623 | self.optimizer = optimizer 624 | 625 | if gpu: 626 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu 627 | 628 | if experiment is not None: 629 | self.experiment = experiment 630 | elif comet_ml_key is not None: 631 | self.comet_ml_key = comet_ml_key 632 | self.experiment = Experiment(api_key=comet_ml_key, project_name='tgan-wgan-gp', workspace="baukebrenninkmeijer") 633 | if ds is not None: 634 | experiment.log_dataset_info(name=ds) 635 | self.gpu = gpu 636 | 637 | def get_model(self, training=True): 638 | """Return a new instance of the model.""" 639 | return GraphBuilder( 640 | metadata=self.metadata, 641 | batch_size=self.batch_size, 642 | z_dim=self.z_dim, 643 | noise=self.noise, 644 | l2norm=self.l2norm, 645 | learning_rate=self.learning_rate, 646 | num_gen_rnn=self.num_gen_rnn, 647 | num_gen_feature=self.num_gen_feature, 648 | num_dis_layers=self.num_dis_layers, 649 | num_dis_hidden=self.num_dis_hidden, 650 | optimizer=self.optimizer, 651 | training=training 652 | ) 653 | 654 | def prepare_sampling(self): 655 | """Prepare model for generate samples.""" 656 | if self.model is None: 657 | self.model = self.get_model(training=False) 658 | 659 | else: 660 | self.model.training = False 661 | 662 | predict_config = PredictConfig( 663 | session_init=SaverRestore(self.restore_path), 664 | model=self.model, 665 | input_names=['z'], 666 | output_names=['gen/gen', 'z'], 667 | ) 668 | 669 | self.simple_dataset_predictor = SimpleDatasetPredictor( 670 | predict_config, 671 | RandomZData((self.batch_size, self.z_dim)) 672 | ) 673 | 674 | def fit(self, data): 675 | """Fit the model to the given data. 676 | 677 | Args: 678 | data(pandas.DataFrame): dataset to fit the model. 
679 | 680 | Returns: 681 | None 682 | 683 | """ 684 | 685 | self.preprocessor = Preprocessor(continuous_columns=self.continuous_columns) 686 | data = self.preprocessor.fit_transform(data) 687 | self.metadata = self.preprocessor.metadata 688 | dataflow = TGANDataFlow(data, self.metadata) 689 | batch_data = BatchData(dataflow, self.batch_size) 690 | input_queue = QueueInput(batch_data) 691 | 692 | self.model = self.get_model(training=True) 693 | 694 | from tensorpack.callbacks import CometMLMonitor 695 | 696 | trainer = SeparateGANTrainer( 697 | model=self.model, 698 | input_queue=input_queue, 699 | g_period=6, 700 | ) 701 | 702 | self.restore_path = os.path.join(self.model_dir, 'checkpoint') 703 | 704 | if os.path.isfile(self.restore_path) and self.restore_session: 705 | session_init = SaverRestore(self.restore_path) 706 | with open(os.path.join(self.log_dir, 'stats.json')) as f: 707 | starting_epoch = json.load(f)[-1]['epoch_num'] + 1 708 | 709 | else: 710 | session_init = None 711 | starting_epoch = 1 712 | 713 | action = 'k' if self.restore_session else 'd' 714 | # logger.set_logger_dir(self.log_dir, action=action) 715 | 716 | callbacks = [] 717 | monitors = [] 718 | if self.save_checkpoints: 719 | callbacks.append(ModelSaver(checkpoint_dir=self.model_dir)) 720 | callbacks.append(MergeAllSummaries(period=10)) 721 | 722 | if self.experiment is not None: 723 | monitors.append(CometMLMonitor(experiment=self.experiment)) 724 | 725 | trainer.train_with_defaults( 726 | callbacks=callbacks, 727 | monitors=monitors, 728 | steps_per_epoch=self.steps_per_epoch, 729 | max_epoch=self.max_epoch, 730 | session_init=session_init, 731 | starting_epoch=starting_epoch 732 | ) 733 | 734 | self.prepare_sampling() 735 | 736 | def sample(self, num_samples): 737 | """Generate samples from model. 738 | 739 | Args: 740 | num_samples(int) 741 | 742 | Returns: 743 | None 744 | 745 | Raises: 746 | ValueError 747 | 748 | """ 749 | max_iters = (num_samples // self.batch_size) 750 | 751 | results = [] 752 | for idx, o in enumerate(self.simple_dataset_predictor.get_result()): 753 | results.append(o[0]) 754 | if idx + 1 == max_iters: 755 | break 756 | 757 | results = np.concatenate(results, axis=0) 758 | 759 | ptr = 0 760 | features = {} 761 | for col_id, col_info in enumerate(self.metadata['details']): 762 | if col_info['type'] == 'category': 763 | features['f%02d' % col_id] = results[:, ptr:ptr + 1] 764 | ptr += 1 765 | 766 | elif col_info['type'] == 'value': 767 | gaussian_components = col_info['n'] 768 | val = results[:, ptr:ptr + 1] 769 | ptr += 1 770 | pro = results[:, ptr:ptr + gaussian_components] 771 | ptr += gaussian_components 772 | features['f%02d' % col_id] = np.concatenate([val, pro], axis=1) 773 | 774 | else: 775 | raise ValueError( 776 | "self.metadata['details'][{}]['type'] must be either `category` or " 777 | "`values`. 
Instead it was {}.".format(col_id, col_info['type']) 778 | ) 779 | 780 | return self.preprocessor.reverse_transform(features)[:num_samples].copy() 781 | 782 | def tar_folder(self, tar_name): 783 | """Generate a tar of :self.output:.""" 784 | with tarfile.open(tar_name, 'w:gz') as tar_handle: 785 | for root, dirs, files in os.walk(self.output): 786 | for file_ in files: 787 | tar_handle.add(os.path.join(root, file_)) 788 | 789 | tar_handle.close() 790 | 791 | @classmethod 792 | def load(cls, path): 793 | """Load a pretrained model from a given path.""" 794 | with tarfile.open(path, 'r:gz') as tar_handle: 795 | destination_dir = os.path.dirname(tar_handle.getmembers()[0].name) 796 | tar_handle.extractall() 797 | 798 | with open('{}/TGANModel'.format(destination_dir), 'rb+') as f: 799 | instance = pickle.load(f) 800 | 801 | instance.prepare_sampling() 802 | return instance 803 | 804 | def save(self, path, force=False): 805 | """Save the fitted model in the given path.""" 806 | if os.path.exists(path) and not force: 807 | logger.info('The indicated path already exists. Use `force=True` to overwrite.') 808 | return 809 | 810 | base_path = os.path.dirname(path) 811 | if not os.path.exists(base_path): 812 | os.makedirs(base_path) 813 | 814 | model = self.model 815 | dataset_predictor = self.simple_dataset_predictor 816 | 817 | self.model = None 818 | self.simple_dataset_predictor = None 819 | 820 | with open('{}/TGANModel'.format(self.output), 'wb') as f: 821 | pickle.dump(self, f) 822 | 823 | self.model = model 824 | self.simple_dataset_predictor = dataset_predictor 825 | 826 | self.tar_folder(path) 827 | 828 | logger.info('Model saved successfully.') 829 | -------------------------------------------------------------------------------- /tgan_wgan_gp/trainer.py: -------------------------------------------------------------------------------- 1 | """GAN Models.""" 2 | 3 | import tensorflow as tf 4 | from tensorpack import StagingInput, TowerTrainer 5 | from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter 6 | from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper 7 | 8 | 9 | class GANTrainer(TowerTrainer): 10 | """GanTrainer model. 11 | 12 | We need to set :meth:`tower_func` because it's a :class:`TowerTrainer`, and only 13 | :class:`TowerTrainer` supports automatic graph creation for inference during training. 14 | 15 | If we don't care about inference during training, using :meth:`tower_func` is not needed. 16 | Just calling :meth:`model.build_graph` directly is OK. 17 | 18 | Args: 19 | input_queue(tensorpack.input_source.QueueInput): Data input. 20 | model(tgan.GAN.GANModelDesc): Model to train. 
21 | 22 | """ 23 | 24 | def __init__(self, model, input_queue): 25 | """Initialize object.""" 26 | super().__init__() 27 | inputs_desc = model.get_inputs_desc() 28 | 29 | # Setup input 30 | cbs = input_queue.setup(inputs_desc) 31 | self.register_callback(cbs) 32 | 33 | # Build the graph 34 | self.tower_func = TowerFuncWrapper(model.build_graph, inputs_desc) 35 | with TowerContext('', is_training=True): 36 | self.tower_func(*input_queue.get_input_tensors()) 37 | 38 | opt = model.get_optimizer() 39 | 40 | # Define the training iteration by default, run one d_min after one g_min 41 | with tf.name_scope('optimize'): 42 | g_min_grad = opt.compute_gradients(model.g_loss, var_list=model.g_vars) 43 | # g_min_grad_clip = [ 44 | # (tf.clip_by_value(grad, -5.0, 5.0), var) 45 | # for grad, var in g_min_grad 46 | # ] 47 | 48 | g_min_train_op = opt.apply_gradients(g_min_grad, name='g_op') 49 | with tf.control_dependencies([g_min_train_op]): 50 | d_min_grad = opt.compute_gradients(model.d_loss, var_list=model.d_vars) 51 | # d_min_grad_clip = [ 52 | # (tf.clip_by_value(grad, -5.0, 5.0), var) 53 | # for grad, var in d_min_grad 54 | # ] 55 | 56 | d_min_train_op = opt.apply_gradients(d_min_grad, name='d_op') 57 | 58 | self.train_op = d_min_train_op 59 | 60 | 61 | class SeparateGANTrainer(TowerTrainer): 62 | """A GAN trainer which runs two optimization ops with a certain ratio. 63 | 64 | Args: 65 | input(tensorpack.input_source.QueueInput): Data input. 66 | model(tgan.GAN.GANModelDesc): Model to train. 67 | d_period(int): period of each d_opt run 68 | g_period(int): period of each g_opt run 69 | 70 | """ 71 | 72 | def __init__(self, input_queue, model, d_period=1, g_period=1): 73 | """Initialize object.""" 74 | super(SeparateGANTrainer, self).__init__() 75 | self._d_period = int(d_period) 76 | self._g_period = int(g_period) 77 | if not min(d_period, g_period) == 1: 78 | raise ValueError('The minimum between d_period and g_period must be 1.') 79 | 80 | # Setup input 81 | cbs = input_queue.setup(model.get_inputs_desc()) 82 | self.register_callback(cbs) 83 | 84 | # Build the graph 85 | self.tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc()) 86 | with TowerContext('', is_training=True): 87 | self.tower_func(*input_queue.get_input_tensors()) 88 | 89 | opt = model.get_optimizer() 90 | with tf.name_scope('optimize'): 91 | self.d_min = opt.minimize( 92 | model.d_loss, var_list=model.d_vars, name='d_min') 93 | self.g_min = opt.minimize( 94 | model.g_loss, var_list=model.g_vars, name='g_min') 95 | 96 | def run_step(self): 97 | """Define the training iteration.""" 98 | if self.global_step % (self._d_period) == 0: 99 | self.hooked_sess.run(self.d_min) 100 | if self.global_step % (self._g_period) == 0: 101 | self.hooked_sess.run(self.g_min) 102 | 103 | 104 | class MultiGPUGANTrainer(TowerTrainer): 105 | """A replacement of GANTrainer (optimize d and g one by one) with multi-gpu support. 106 | 107 | Args: 108 | nr_gpu(int): 109 | input(tensorpack.input_source.QueueInput): Data input. 110 | model(tgan.GAN.GANModelDesc): Model to train. 
111 | 112 | """ 113 | 114 | def __init__(self, nr_gpu, input, model): 115 | """Initialize object.""" 116 | super(MultiGPUGANTrainer, self).__init__() 117 | if nr_gpu <= 1: 118 | raise ValueError('nr_gpu must be strictly greater than 1.') 119 | 120 | raw_devices = ['/gpu:{}'.format(k) for k in range(nr_gpu)] 121 | 122 | # Setup input 123 | input = StagingInput(input) 124 | cbs = input.setup(model.get_inputs_desc()) 125 | self.register_callback(cbs) 126 | 127 | # Build the graph with multi-gpu replication 128 | def get_cost(*inputs): 129 | model.build_graph(*inputs) 130 | return [model.d_loss, model.g_loss] 131 | 132 | self.tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc()) 133 | devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices] 134 | 135 | cost_list = DataParallelBuilder.build_on_towers( 136 | list(range(nr_gpu)), 137 | lambda: self.tower_func(*input.get_input_tensors()), 138 | devices) 139 | 140 | # Simply average the cost here. It might be faster to average the gradients 141 | with tf.name_scope('optimize'): 142 | d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu) 143 | g_loss = tf.add_n([x[1] for x in cost_list]) * (1.0 / nr_gpu) 144 | 145 | opt = model.get_optimizer() 146 | # run one d_min after one g_min 147 | g_min = opt.minimize(g_loss, var_list=model.g_vars, 148 | colocate_gradients_with_ops=True, name='g_op') 149 | 150 | with tf.control_dependencies([g_min]): 151 | d_min = opt.minimize(d_loss, var_list=model.d_vars, 152 | colocate_gradients_with_ops=True, name='d_op') 153 | 154 | # Define the training iteration 155 | self.train_op = d_min 156 | --------------------------------------------------------------------------------
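
A note on the update schedule used above: `SeparateGANTrainer.run_step` decides with a simple modulo test which optimizer to run on each global step, and `TGANModel.fit` constructs the trainer with `g_period=6`, so the discriminator op (`d_min`, period 1) runs every step while the generator op (`g_min`) runs only every sixth step, in line with the WGAN-GP practice of training the critic more often than the generator. The following self-contained sketch of that scheduling logic is illustrative only and not part of the repository:

    def update_schedule(num_steps, d_period=1, g_period=6):
        """Return, for each global step, which optimizer ops run_step would trigger."""
        schedule = []
        for step in range(num_steps):
            ops = []
            if step % d_period == 0:
                ops.append('d_min')  # discriminator/critic update
            if step % g_period == 0:
                ops.append('g_min')  # generator update
            schedule.append((step, ops))
        return schedule

    # With the values used in TGANModel.fit (d_period=1, g_period=6), the critic is
    # updated on every step and the generator once per six steps:
    # update_schedule(7) -> [(0, ['d_min', 'g_min']), (1, ['d_min']), ..., (6, ['d_min', 'g_min'])]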