├── assets
│   └── PyTorch
│       ├── loss_plot.png
│       ├── generated_images.gif
│       └── generated_images
│           ├── epoch_0.png
│           ├── epoch_120.png
│           ├── epoch_160.png
│           ├── epoch_200.png
│           ├── epoch_40.png
│           └── epoch_80.png
├── LICENSE
├── src
│   └── PyTorch
│       ├── input_data.py
│       └── gan-mnist-pytorch.py
├── .gitignore
└── README.md

/assets/PyTorch/loss_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vamsi3/simple-GAN/HEAD/assets/PyTorch/loss_plot.png
--------------------------------------------------------------------------------

/assets/PyTorch/generated_images.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vamsi3/simple-GAN/HEAD/assets/PyTorch/generated_images.gif
--------------------------------------------------------------------------------

/assets/PyTorch/generated_images/epoch_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vamsi3/simple-GAN/HEAD/assets/PyTorch/generated_images/epoch_0.png
--------------------------------------------------------------------------------

/assets/PyTorch/generated_images/epoch_120.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vamsi3/simple-GAN/HEAD/assets/PyTorch/generated_images/epoch_120.png
--------------------------------------------------------------------------------

/assets/PyTorch/generated_images/epoch_160.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vamsi3/simple-GAN/HEAD/assets/PyTorch/generated_images/epoch_160.png
--------------------------------------------------------------------------------

/assets/PyTorch/generated_images/epoch_200.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vamsi3/simple-GAN/HEAD/assets/PyTorch/generated_images/epoch_200.png
--------------------------------------------------------------------------------

/assets/PyTorch/generated_images/epoch_40.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vamsi3/simple-GAN/HEAD/assets/PyTorch/generated_images/epoch_40.png
--------------------------------------------------------------------------------

/assets/PyTorch/generated_images/epoch_80.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vamsi3/simple-GAN/HEAD/assets/PyTorch/generated_images/epoch_80.png
--------------------------------------------------------------------------------

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Vamsi Krishna

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/src/PyTorch/input_data.py:
--------------------------------------------------------------------------------
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import
import gzip
import os
import tempfile

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
# Note: `tensorflow.contrib` exists only in TensorFlow 1.x; this import fails on TensorFlow 2.x.
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
# pylint: enable=unused-import
--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# Simple GAN using PyTorch
> This project is a basic Generative Adversarial Network (GAN) implemented in PyTorch and trained on the MNIST database.

This is one of my first steps toward GANs in general. It largely follows the original GAN formulation published in [arXiv:1406.2661 [stat.ML]](https://arxiv.org/pdf/1406.2661.pdf) by Goodfellow _et al._

## Getting the code to work

Follow the instructions below to get the project running on your local machine.

1. Clone the repository and make sure Python 3 is installed.
2. Go to `src/PyTorch/` and run `python gan-mnist-pytorch.py`.
3. All outputs and related plots are written to the generated `src/PyTorch/output` folder.
4. Run `python gan-mnist-pytorch.py --help` to list the parameters that can be tweaked before a run.

### Prerequisites

* PyTorch 0.4.0 or above
* CUDA 9.1 (or the version matching your PyTorch build) to use a compatible GPU, if present, for faster training

### Results
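### GAN objective

For reference, this is the value function from the GAN paper cited in this README (arXiv:1406.2661). $D(x)$ is the discriminator's probability that $x$ is a real MNIST image, $G$ maps noise $z \sim p_z$ to images, and $p_g$ is the distribution of $G(z)$. The second line is the paper's optimal discriminator for a fixed $G$. This reproduces the paper's formulation. `gan-mnist-pytorch.py` may implement a variant, such as the non-saturating generator loss the paper also suggests.

```math
\begin{aligned}
\min_G \max_D V(D, G) &= \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big] \\
D^*_G(x) &= \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}
\end{aligned}
```

At the global optimum $p_g = p_{\text{data}}$, so $D^*_G(x) = 1/2$ everywhere and the value of the game is $-\log 4$.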
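### Computing the losses

The sketch below shows how the GAN objective from arXiv:1406.2661 becomes per-batch losses. It does not use any deep-learning framework. It assumes the discriminator outputs probabilities. All function names are hypothetical and are not taken from `gan-mnist-pytorch.py`.

In PyTorch, the discriminator loss here equals the sum of two `torch.nn.BCELoss` terms: one on $D(x)$ with target 1 and one on $D(G(z))$ with target 0. The gradient helpers show why the paper recommends maximizing $\log D(G(z))$ rather than minimizing $\log(1 - D(G(z)))$ early in training.

```python
import math
from typing import Sequence


def _mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def discriminator_loss(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    """Negated batch estimate of V(D, G), which D maximizes (so this is minimized).

    d_real: D's probabilities on real images, D(x).
    d_fake: D's probabilities on generated images, D(G(z)).
    """
    return -(_mean([math.log(p) for p in d_real])
             + _mean([math.log(1.0 - p) for p in d_fake]))


def generator_loss_minimax(d_fake: Sequence[float]) -> float:
    """Original minimax term: G minimizes mean log(1 - D(G(z)))."""
    return _mean([math.log(1.0 - p) for p in d_fake])


def generator_loss_non_saturating(d_fake: Sequence[float]) -> float:
    """Non-saturating variant: G minimizes -mean log D(G(z))."""
    return -_mean([math.log(p) for p in d_fake])


def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))


def grad_wrt_logit_minimax(a: float) -> float:
    """d/da log(1 - sigmoid(a)) = -sigmoid(a); vanishes as D(G(z)) -> 0."""
    return -sigmoid(a)


def grad_wrt_logit_non_saturating(a: float) -> float:
    """d/da -log(sigmoid(a)) = sigmoid(a) - 1; stays near -1 as D(G(z)) -> 0."""
    return sigmoid(a) - 1.0


if __name__ == "__main__":
    # At the paper's global optimum D outputs 1/2 everywhere: loss = log 4.
    print(discriminator_loss([0.5, 0.5], [0.5, 0.5]))  # ≈ 1.386
    # Early in training D confidently rejects fakes: D(G(z)) = 0.01.
    a = math.log(0.01 / 0.99)  # logit a with sigmoid(a) = 0.01
    print(grad_wrt_logit_minimax(a), grad_wrt_logit_non_saturating(a))
```

With $D(G(z)) = 0.01$, the minimax gradient with respect to the discriminator's logit is about $-0.01$, while the non-saturating one is about $-0.99$. Both losses share the same fixed point, but the second gives the generator a much stronger learning signal while its samples are still poor.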