├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── Playground.ipynb ├── README.md ├── data ├── font_data_loader.py └── font_dataset.py ├── environment.yml ├── models ├── loss.py └── models.py ├── results ├── results_1.png ├── results_2.png ├── results_3.png ├── results_4.png ├── results_5.png ├── results_6.png ├── results_7.png └── results_8.png ├── test.py ├── tests ├── font_dataset_test.py └── test_datasets │ └── valid │ └── Aaargh.0.0.png ├── train.py ├── trained ├── better_trained_d_state_dict.pt ├── better_trained_g_state_dict.pt ├── well_trained_d_state_dict.pt └── well_trained_g_state_dict.pt └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | datasets/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
 97 | __pypackages__/
 98 |
 99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "python.testing.unittestArgs": [
 3 |     "-v",
 4 |     "-s",
 5 |     "tests",
 6 |     "-p",
 7 |     "*_test.py"
 8 |   ],
 9 |   "python.testing.pytestEnabled": false,
10 |   "python.testing.nosetestsEnabled": false,
11 |   "python.testing.unittestEnabled": true,
12 |   "python.pythonPath": "C:\\Users\\joshpc\\Anaconda3\\python.exe",
13 |   "git.ignoreLimitWarning": true,
14 |   "python.testing.nosetestArgs": [
15 |     "tests"
16 |   ]
17 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 |
 3 | Copyright (c) 2020 Joshua Tessier
 4 |
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # StyledFontGAN
 2 | The original intent of this project was to create a generative adversarial network that, given a single image of a letter, will generate the remaining characters in the alphabet (A-Z).
 3 |
 4 | ## Motivation Behind This Project
 5 | This is a simple project that I've been using to teach myself all about GANs and to get familiar with Python. I don't claim to be talented when it comes to machine learning, but I want to understand its capabilities and the challenges that can arise.
 6 |
 7 | Please note that this is a side project, not academic research. Ultimately, the goal is to apply this to non-Latin alphabets, where producing fonts is extremely time-consuming.
 8 |
 9 | ## Prior Work and Acknowledgements
10 |
11 | A lot of this project is inspired by the following research papers and their associated code:
12 |
13 | - [MC-GAN](https://github.com/azadis/MC-GAN) - Multi-Content GAN for Few-Shot Font Style Transfer; Samaneh Azadi, Matthew Fisher, Vladimir Kim, Zhaowen Wang, Eli Shechtman, Trevor Darrell, in arXiv, 2017.
14 | - [GlyphGAN](https://arxiv.org/abs/1905.12502v1) - GlyphGAN: Style-Consistent Font Generation Based on Generative Adversarial Networks; Hideaki Hayashi, Kohtaro Abe, Seiichi Uchida, in arXiv, 2019.
15 | - [DC-GAN](https://arxiv.org/abs/1511.06434) - Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks; Alec Radford, Luke Metz, Soumith Chintala, in arXiv, 2016.
16 | - [WGAN-GP](https://arxiv.org/abs/1704.00028) - Improved Training of Wasserstein GANs; Ishaan Gulrajani, Faruk Ahmed, Martin Arjovsky, Vincent Dumoulin, Aaron Courville, in arXiv, 2017.
17 |
18 | - Some pieces of code, specifically the `Flatten` and `Unflatten` modules, were pulled from the [CS231N course](https://cs231n.github.io/) samples.
19 |
20 | This project was made possible by this research and the contributions made by the above authors. Thank you.
21 |
22 | ## Architecture
23 |
24 | The most successful model has been trained with the following generator and discriminator, using an **L1 loss** for the generator and the **WGAN-GP** loss (Wasserstein distance + gradient penalty) for the discriminator.
25 |
26 | ### Generator Architecture
27 |
28 | The thought process behind this network architecture is not well informed. The intuition is that we take an image, extract its features with the Conv layers, turn it into some intermediate representation within the Linear layers, then use that intermediate representation to generate a new image with the ConvTranspose layers.
29 |
30 | The result is good, but it's not ideal. I'm currently working on identifying which layers actually contribute and whether any are unnecessary.
31 |
32 | ```
33 | Sequential(
34 |   (0): Conv2d(1, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
35 |   (1): LeakyReLU(negative_slope=0.2)
36 |   (2): Conv2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
37 |   (3): LeakyReLU(negative_slope=0.2)
38 |   (4): Conv2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
39 |   (5): LeakyReLU(negative_slope=0.2)
40 |   (6): Flatten()
41 |   (7): Linear(in_features=1024, out_features=256, bias=True)
42 |   (8): ReLU()
43 |   (9): Linear(in_features=256, out_features=5120, bias=True)
44 |   (10): ReLU()
45 |   (11): Unflatten()
46 |   (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
47 |   (13): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
48 |   (14): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
49 |   (15): ReLU()
50 |   (16): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
51 |   (17): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
52 |   (18): ReLU()
53 |   (19): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
54 |   (20): Sigmoid()
55 | )
56 | ```
57 |
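As a quick sanity check on the shapes, the sketch below feeds a single 64x64 glyph through the generator built by `models/models.py`. It is illustrative only: it assumes the default arguments of `build_font_shape_generator`, not necessarily the exact configuration behind the printout above.

```
import torch
from models.models import build_font_shape_generator

G = build_font_shape_generator(glyph_size=(64, 64, 1), glyph_count=26, dimension=16)
G.eval()  # BatchNorm layers need eval mode (or a batch size > 1)

glyph = torch.randn(1, 1, 64, 64)  # (N, C, H, W): one grayscale letter
alphabet = G(glyph)
print(alphabet.shape)  # torch.Size([1, 1, 64, 1664]): 26 glyphs side by side
```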
58 | ### Discriminator Architecture
59 |
60 | This follows DC-GAN; no magic or additions here. As in the WGAN-GP paper, we do not apply batch normalization in the critic, which helps training with the gradient penalty.
61 |
62 | ```
63 | Sequential(
64 |   (0): Conv2d(1, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
65 |   (1): LeakyReLU(negative_slope=0.2)
66 |   (2): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
67 |   (3): LeakyReLU(negative_slope=0.2)
68 |   (4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
69 |   (5): LeakyReLU(negative_slope=0.2)
70 |   (6): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
71 |   (7): LeakyReLU(negative_slope=0.2)
72 |   (8): Flatten()
73 |   (9): Linear(in_features=640, out_features=1, bias=True)
74 |   (10): Sigmoid()
75 | )
76 | ```
77 |
78 | ## Results
79 |
80 | This was by no means an academic study, nor an attempt to push the boundaries of current research, but the results were good.
81 |
82 | This project used the data set from [MC-GAN](https://github.com/azadis/MC-GAN), which contains a variety of different fonts and samples. In some cases, the samples are all uppercase letters, while others are a mix of uppercase and lowercase letters.
83 |
84 | ### Samples
85 |
86 | #### Successes
87 | In these examples, the results are both legible and match the original style.
88 |
89 | ![First Font](results/results_2.png)
90 | ![First Font](results/results_5.png)
91 | ![First Font](results/results_6.png)
92 | ![First Font](results/results_7.png)
93 | ![First Font](results/results_8.png)
94 |
95 | #### Questionable Failures
96 | ![First Font](results/results_1.png)
97 | The network succeeded in capturing the style, with the dark shadows, but couldn't produce legible letters.
98 |
99 | ![First Font](results/results_4.png)
100 | The letters aren't very clear, but despite the strange style, the network seems to have respected it.
101 |
102 | #### Failures
103 | ![First Font](results/results_3.png)
104 | The network failed to produce legible letters, and failed to copy the style.
105 |
106 | ### Conclusions
107 | TBD
108 |
109 | ## Setup
110 |
111 | 1. Install `pytorch`: https://pytorch.org/get-started/locally/
112 | 2. For now, you will also need a tool to view notebooks. I use Jupyter.
113 | 3. Dependencies: TBD. For now, `environment.yml` captures the conda environment used during development.
114 |
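Once the dependencies are in place, a minimal training run can be wired together from the pieces in this repository. The sketch below is illustrative, not a tested recipe: the dataset path is a placeholder (any directory of 64x1664 alphabet strips, such as the MC-GAN data, should work) and the hyperparameters are unvalidated guesses.

```
import torch
from torch.utils.data import RandomSampler

from data.font_dataset import FontDataset
from data.font_data_loader import FontDataLoader
from models.models import (build_font_shape_generator, build_font_shape_discriminator,
                           get_optimizer, initialize_weights)
from models.loss import wasserstein_loss, l1_and_adversarial_loss
from train import train

# Placeholder path: put the MC-GAN alphabet strips (or similar) under datasets/
dataset = FontDataset('datasets/Capitals64', glyph_size=(64, 64), glyphs_per_image=26)
loader = FontDataLoader(dataset, sampler=RandomSampler(dataset), batch_size=16)

G = build_font_shape_generator(glyph_size=(64, 64, 1), glyph_count=26, dimension=16)
D = build_font_shape_discriminator(image_size=(64, 1664), dimension=16)
G.apply(initialize_weights)
D.apply(initialize_weights)

options = {
    'batch_size': 16,
    'epoch_count': 10,                 # unvalidated guess
    'data_type': torch.FloatTensor,    # torch.cuda.FloatTensor on a GPU
    'glyph_size': (64, 64),
    'glyphs_per_image': 26,
    'visualize': False                 # True requires a notebook + livelossplot
}

losses = train(D, G, get_optimizer(D), get_optimizer(G),
               wasserstein_loss, l1_and_adversarial_loss, loader, options)
```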
--------------------------------------------------------------------------------
/data/font_data_loader.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 |
 3 | class FontDataLoader():
 4 |     def __init__(self, dataset, sampler, batch_size):
 5 |         self.data_loader = torch.utils.data.DataLoader(
 6 |             dataset,
 7 |             sampler=sampler,
 8 |             batch_size=batch_size
 9 |         )
10 |
11 |     def __iter__(self):
12 |         self.data_loader_iterator = iter(self.data_loader)
13 |         return self
14 |
15 |     def __next__(self):
16 |         return next(self.data_loader_iterator)
17 |
--------------------------------------------------------------------------------
/data/font_dataset.py:
--------------------------------------------------------------------------------
 1 | import os
 2 |
 3 | import torch
 4 | from torch.utils.data import Dataset
 5 | import torchvision.transforms as transforms
 6 |
 7 | from PIL import Image
 8 |
 9 | class FontData():
10 |     def __init__(self, font_name, font_path, image=None):
11 |         self.font_name = font_name
12 |         self.font_path = font_path
13 |         self.image = image  # Keep a preloaded image if one was provided
14 |
15 |     def load_data(self, loader):
16 |         if self.image is None:
17 |             self.image = loader(self.font_path)
18 |         return self.image
19 |
20 |     def __repr__(self):
21 |         return "<FontData: %s>" % self.font_name
22 |
23 | class FontDataset(Dataset):
24 |     """The Font Dataset."""
25 |
26 |     def __init__(self, root_dir, glyph_size=(64, 64), glyphs_per_image=26):
27 |         self.fonts = self.load_font_filenames(root_dir)
28 |         self.root_dir = root_dir
29 |         self.glyph_size = glyph_size
30 |         self.glyphs_per_image = glyphs_per_image
31 |
32 |     def __len__(self):
33 |         return len(self.fonts)
34 |
35 |     def __getitem__(self, index):
36 |         _index = index
37 |         if torch.is_tensor(_index):
38 |             _index = _index.tolist()
39 |
40 |         font = self.fonts[_index]
41 |         font_data = font.load_data(image_loader)
42 |
43 |         transform = transforms.Compose([
44 |             transforms.Resize(self.glyph_size[0]),
45 |             transforms.Grayscale(num_output_channels=1),  # Drop to 1 channel
46 |             transforms.ToTensor()
47 |         ])
48 |
49 |         return transform(font_data)
50 |
51 |     def load_font_filenames(self, root_dir):
52 |         font_images = []
53 |         assert os.path.isdir(root_dir), '%s is not a valid directory!' 
% root_dir 54 | 55 | for root, _, filenames in sorted(os.walk(root_dir)): 56 | for filename in filenames: 57 | font_images.append(FontData(filename, os.path.join(root, filename))) 58 | 59 | return font_images 60 | 61 | # Helper Functions 62 | 63 | def image_loader(path): 64 | return Image.open(path).convert('RGB') 65 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: font-style-gan 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _pytorch_select=1.1.0=cpu 6 | - attrs=19.3.0=py_0 7 | - backcall=0.1.0=py37_0 8 | - blas=1.0=mkl 9 | - bleach=3.1.0=py37_0 10 | - ca-certificates=2020.1.1=0 11 | - certifi=2019.11.28=py37_1 12 | - cffi=1.14.0=py37h7a1dbc1_0 13 | - colorama=0.4.3=py_0 14 | - decorator=4.4.2=py_0 15 | - defusedxml=0.6.0=py_0 16 | - entrypoints=0.3=py37_0 17 | - icc_rt=2019.0.0=h0cc432a_1 18 | - icu=58.2=ha66f8fd_1 19 | - importlib_metadata=1.5.0=py37_0 20 | - intel-openmp=2020.0=166 21 | - ipykernel=5.1.4=py37h39e3cac_0 22 | - ipython=7.13.0=py37h5ca1d4c_0 23 | - ipython_genutils=0.2.0=py37_0 24 | - ipywidgets=7.5.1=py_0 25 | - jedi=0.16.0=py37_1 26 | - jinja2=2.11.1=py_0 27 | - jpeg=9b=hb83a4c4_2 28 | - jsonschema=3.2.0=py37_0 29 | - jupyter=1.0.0=py37_7 30 | - jupyter_client=6.1.2=py_0 31 | - jupyter_console=6.1.0=py_0 32 | - jupyter_core=4.6.3=py37_0 33 | - libpng=1.6.37=h2a8f88b_0 34 | - libsodium=1.0.16=h9d3ae62_0 35 | - m2w64-gcc-libgfortran=5.3.0=6 36 | - m2w64-gcc-libs=5.3.0=7 37 | - m2w64-gcc-libs-core=5.3.0=7 38 | - m2w64-gmp=6.1.0=2 39 | - m2w64-libwinpthread-git=5.0.0.4634.697f757=2 40 | - markupsafe=1.1.1=py37he774522_0 41 | - mistune=0.8.4=py37he774522_0 42 | - mkl=2020.0=166 43 | - mkl-service=2.3.0=py37hb782905_0 44 | - mkl_fft=1.0.15=py37h14836fe_0 45 | - mkl_random=1.1.0=py37h675688f_0 46 | - msys2-conda-epoch=20160418=1 47 | - nbconvert=5.6.1=py37_0 48 | - nbformat=5.0.4=py_0 49 | - ninja=1.9.0=py37h74a9793_0 50 | - notebook=6.0.3=py37_0 51 | - numpy=1.18.1=py37h93ca92e_0 52 | - numpy-base=1.18.1=py37hc3f5095_1 53 | - openssl=1.1.1f=he774522_0 54 | - pandoc=2.2.3.2=0 55 | - pandocfilters=1.4.2=py37_1 56 | - parso=0.6.2=py_0 57 | - pickleshare=0.7.5=py37_0 58 | - pip=20.0.2=py37_1 59 | - pip: 60 | - git+git://github.com/stared/livelossplot.git 61 | - prometheus_client=0.7.1=py_0 62 | - prompt-toolkit=3.0.4=py_0 63 | - prompt_toolkit=3.0.4=0 64 | - pycparser=2.20=py_0 65 | - pygments=2.6.1=py_0 66 | - pyqt=5.9.2=py37h6538335_2 67 | - pyrsistent=0.16.0=py37he774522_0 68 | - python=3.7.7=h60c2a47_0_cpython 69 | - python-dateutil=2.8.1=py_0 70 | - pytorch=1.3.1=cpu_py37h9f948e0_0 71 | - pywin32=227=py37he774522_1 72 | - pywinpty=0.5.7=py37_0 73 | - pyzmq=18.1.1=py37ha925a31_0 74 | - qt=5.9.7=vc14h73c81de_0 75 | - qtconsole=4.7.2=py_0 76 | - qtpy=1.9.0=py_0 77 | - send2trash=1.5.0=py37_0 78 | - setuptools=46.1.3=py37_0 79 | - sip=4.19.8=py37h6538335_0 80 | - six=1.14.0=py37_0 81 | - sqlite=3.31.1=he774522_0 82 | - terminado=0.8.3=py37_0 83 | - testpath=0.4.4=py_0 84 | - tornado=6.0.4=py37he774522_1 85 | - traitlets=4.3.3=py37_0 86 | - vc=14.1=h0510ff6_4 87 | - vs2015_runtime=14.16.27012=hf0eaf9b_1 88 | - wcwidth=0.1.9=py_0 89 | - webencodings=0.5.1=py37_1 90 | - wheel=0.34.2=py37_0 91 | - widgetsnbextension=3.5.1=py37_0 92 | - wincertstore=0.2=py37_0 93 | - winpty=0.4.3=4 94 | - zeromq=4.3.1=h33f27b4_3 95 | - zipp=2.2.0=py_0 96 | - zlib=1.2.11=h62dcd97_3 97 | prefix: C:\Users\joshpc\Anaconda3\envs\font-style-gan 98 | 99 | 
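# To recreate this environment with conda (standard conda usage, shown here for
# convenience) and activate it:
#
#   conda env create -f environment.yml
#   conda activate font-style-gan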
--------------------------------------------------------------------------------
/models/loss.py:
--------------------------------------------------------------------------------
 1 | from operator import itemgetter
 2 |
 3 | import torch
 4 | from torch.autograd import grad as torch_grad
 5 |
 6 | def l1_and_adversarial_loss(D, G, real_data, generated_data, losses, options):
 7 |     l1_lambda = 10
 8 |     return min_max_loss(D, G, real_data, generated_data, losses, options) + l1_lambda * l1_loss(D, G, real_data, generated_data, losses, options)
 9 |
10 | def wasserstein_loss(D, G, real_data, generated_data, losses, options):
11 |     real_loss = D(real_data)
12 |     generated_loss = D(generated_data)
13 |
14 |     batch_size, data_type = itemgetter('batch_size', 'data_type')(options)
15 |     gradient_penalty_weight = 10
16 |
17 |     # Calculate the gradient penalty
18 |     gradient_penalty = calculate_gradient_penalty(D, real_data, generated_data, batch_size, gradient_penalty_weight, losses, data_type)
19 |     losses['GP'].append(gradient_penalty.data)
20 |
21 |     # Calculate the Wasserstein distance.
22 |     loss = generated_loss.mean() - real_loss.mean() + gradient_penalty
23 |     losses['Generated'].append(generated_loss.mean().data)
24 |     losses['Real'].append(real_loss.mean().data)
25 |     losses['D'].append(loss.data)
26 |
27 |     return loss
28 |
29 | def min_max_loss(D, G, real_data, generated_data, losses, options):
30 |     # Forward pass with the discriminator
31 |     discriminator_loss = D(generated_data)
32 |
33 |     # Compute the loss. We're trying to fool the discriminator into saying "1, this is real".
34 |     loss = -discriminator_loss.mean()
35 |
36 |     return loss
37 |
38 | def l1_loss(D, G, real_data, generated_data, losses, options):
39 |     """
40 |     Computes the L1 loss between the generated data and the real data.
41 |
42 |     It is expected that both `real_data` and `generated_data` are of the same shape.
43 |     """
44 |     return torch.nn.L1Loss()(generated_data, real_data)
45 |
46 |
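# Illustrative usage (not part of this module): the losses above all share one
# signature, so train.py can plug any of them in as D_loss or G_loss, e.g.
#
#   losses = collections.defaultdict(list)
#   options = {'batch_size': 16, 'data_type': torch.FloatTensor}
#   d_loss = wasserstein_loss(D, G, real_batch, G(seed_glyphs), losses, options)
#   g_loss = l1_and_adversarial_loss(D, G, real_batch, G(seed_glyphs), losses, options)
#
# Only wasserstein_loss reads `options` (it needs 'batch_size' and 'data_type'
# for the gradient penalty below); the other losses ignore it.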
47 | def calculate_gradient_penalty(D, real_data, generated_data, batch_size, gradient_penalty_weight, losses, data_type):
48 |     # Calculate interpolation
49 |     alpha = torch.rand(batch_size, 1, 1, 1).expand_as(real_data).type(data_type)
50 |
51 |     # 'interpolated' is x-hat
52 |     interpolated = (alpha * real_data.data + (1 - alpha) * generated_data.data).type(data_type)
53 |     interpolated.requires_grad = True
54 |
55 |     # Calculate probability of interpolated examples
56 |     probability_interpolated = D(interpolated)
57 |
58 |     gradients = torch_grad(outputs=probability_interpolated,
59 |                            inputs=interpolated,
60 |                            grad_outputs=torch.ones(
61 |                                probability_interpolated.size()).type(data_type),
62 |                            create_graph=True,
63 |                            retain_graph=True)[0]
64 |
65 |     # Gradients have shape (batch_size, num_channels, img_height, img_width),
66 |     # so flatten to easily take norm per example in batch
67 |     gradients = gradients.view(batch_size, -1)
68 |     losses['gradient_norm'].append(gradients.norm(2, dim=1).mean().data)
69 |
70 |     # Derivatives of the gradient close to 0 can cause problems because of
71 |     # the square root, so manually calculate norm and add epsilon
72 |     gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
73 |
74 |     # Return gradient penalty
75 |     return gradient_penalty_weight * ((gradients_norm - 1) ** 2).mean()
--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.optim as optim
 4 |
 5 | def build_font_shape_generator(glyph_size=(64, 64, 1), glyph_count=26, dimension=16):
 6 |     """
 7 |     Generator model for our GAN.
 8 |
 9 |     Architecture is similar to DC-GAN, with the exception of the input being an image.
10 |
11 |     Inputs:
12 |     - `glyph_size`: A triple (W, H, C) for the size of a glyph image and its number of channels. This model generates glyphs the same size as the input, but for every character of the alphabet.
13 |     - `dimension`: Base filter depth; scales the width of every conv layer.
14 |
15 |     Output:
16 |     - A model mapping a single glyph to a strip of `glyph_count` glyphs.
17 |     """
18 |
19 |     return intermediate_generator_alt(glyph_size=glyph_size, glyph_count=glyph_count, dimension=dimension)
20 |
21 | def simple_upscale_generator(dimension):
22 |     """
23 |     A generator that performs several ConvTranspose2d operations to upscale a 64x64 input
24 |     image to a 256x512 sheet. The sizes are hard-coded in the layers below.
25 |
26 |     Inputs:
27 |     - `dimension`: This impacts the scale of the number of features in the upscale.
28 |
29 |     Output:
30 |     - An image that is 256 x 512
31 |     """
32 |
33 |     return nn.Sequential(
34 |         # We start with a simple image of size (W, H)
35 |         # K = 4, S = 2, P = 1 for each strided transposed conv below.
36 |         # This changes an image from 64x64 to 128x128 (2x2 grid = 4 images) - we start in a high dimension
37 |         nn.ConvTranspose2d(in_channels=1, out_channels=(8 * dimension), kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
38 |         nn.ReLU(),
39 |
40 |         # This changes an image from 128x128 to 256x256 (4x4 grid = 16 images) - scale down the dimensionality
41 |         nn.ConvTranspose2d(in_channels=(8 * dimension), out_channels=(4 * dimension), kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
42 |         nn.ReLU(),
43 |
44 |         # # Reduce dimensionality without changing the image. 
Stays at 4x4 grid. 45 | nn.ConvTranspose2d(in_channels=(4 * dimension), out_channels=(2 * dimension), kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), 46 | nn.ReLU(), 47 | 48 | # # This changes an image from 256 * 256 to 256 * 512 (4x8 grid = 32 images) 49 | nn.ConvTranspose2d(in_channels=(2 * dimension), out_channels=dimension, kernel_size=(3, 4), stride=(1, 2), padding=(1, 1)), 50 | nn.ReLU(), 51 | 52 | # # Reduce dimensionality back to 1! 53 | nn.ConvTranspose2d(in_channels=dimension, out_channels=1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), 54 | nn.Sigmoid() 55 | ) 56 | 57 | def intermediate_generator(glyph_size=(64, 64), glyph_count=26, dimension=16): 58 | linear_width = int(2 * dimension * glyph_size[0] / 4 * glyph_size[1] / 4) 59 | hidden_width = int(glyph_size[0] * glyph_size[1]) 60 | # Final 2 * 2 is because we Conv Trans 3 times: 2x2 -> 4x4 -> 8x8 -> 16x16 61 | final_width = int(4 * dimension * glyph_count * 2 * 2) 62 | 63 | return nn.Sequential( 64 | # (1, 64, 64) -> (D, 32, 32) 65 | nn.Conv2d(1, dimension, 4, 2, 1), 66 | nn.LeakyReLU(0.2), 67 | 68 | # (D, 32, 32) -> (2D, 16, 16) 69 | nn.Conv2d(dimension, 2 * dimension, 4, 2, 1), 70 | nn.LeakyReLU(0.2), 71 | 72 | # (2D, 16, 16) -> (1, 256 * D) 73 | Flatten(), 74 | 75 | # 256D -> 4096 76 | nn.Linear(in_features=linear_width, out_features=hidden_width), 77 | nn.ReLU(), 78 | 79 | # 4096 -> 16 * 16 * 26 * D 80 | # = 6656 D 81 | nn.Linear(hidden_width, final_width), 82 | nn.ReLU(), 83 | 84 | # 6656 D -> (4D, 16, 416) 85 | Unflatten(C=dimension * 4, H=2, W=2 * glyph_count), 86 | nn.BatchNorm2d(dimension * 4), 87 | 88 | # Fractionally Strided Conv 1 89 | nn.ConvTranspose2d(4 * dimension, 2 * dimension, 4, 2, 1), #4 * 4 * GC * 2D 90 | nn.BatchNorm2d(2 * dimension), 91 | nn.ReLU(), 92 | 93 | # Fractionally Strided Conv 2 94 | nn.ConvTranspose2d(2 * dimension, dimension, 4, 2, 1), #8 * 8 * GC * D 95 | nn.BatchNorm2d(dimension), 96 | nn.ReLU(), 97 | 98 | # Fractionally Strided Conv 3 99 | # (D, 16, 416) -> (1, 16, 416) 100 | nn.ConvTranspose2d(dimension, 1, 4, 2, 1), # 16 * 16 * GC * 1 101 | nn.Sigmoid() 102 | ) 103 | 104 | def intermediate_generator_alt(glyph_size=(16, 16), glyph_count=26, dimension=512): 105 | conv_dimensions = [dimension, int(dimension / 2), int(dimension / 4)] 106 | fc_layer_widths = [ 107 | int(conv_dimensions[2] * glyph_size[0] / 8 * glyph_size[1] / 8), 108 | int(glyph_size[0] * glyph_size[1]), 109 | int(glyph_size[0] / 8 * glyph_size[1] / 8 * glyph_count * dimension / 4) 110 | ] 111 | upconv_dimensions = [int(dimension / 4), int(dimension / 8), int(dimension / 16), 1] 112 | 113 | return nn.Sequential( 114 | # (1, 16, 16) -> (D, 8, 8) 115 | nn.Conv2d(1, conv_dimensions[0], 4, 2, 1), 116 | nn.LeakyReLU(0.2), 117 | 118 | # (D, 8, 8) -> (D/2, 4, 4) 119 | nn.Conv2d(conv_dimensions[0], conv_dimensions[1], 4, 2, 1), 120 | nn.LeakyReLU(0.2), 121 | 122 | # (D/2, 4, 4) -> (D/4, 2, 2) 123 | nn.Conv2d(conv_dimensions[1], conv_dimensions[2], 4, 2, 1), 124 | nn.LeakyReLU(0.2), 125 | 126 | # (D/4, 2, 2) -> (1, D) 127 | Flatten(), 128 | 129 | # D -> 256 (W * H) 130 | nn.Linear(in_features=fc_layer_widths[0], out_features=fc_layer_widths[1]), 131 | nn.ReLU(), 132 | 133 | # 256 -> 16 * 16 * 26 * D (W * H * GC * D/4) 134 | nn.Linear(fc_layer_widths[1], fc_layer_widths[2]), 135 | nn.ReLU(), 136 | 137 | # W * G * GC * D/8 -> (D/4, H, W * GC) 138 | Unflatten(C=upconv_dimensions[0], H=int(glyph_size[0] / 8), W=int(glyph_size[1] / 8) * glyph_count), 139 | nn.BatchNorm2d(upconv_dimensions[0]), 140 | 141 | # Fractionally Strided 
Conv 1 142 | nn.ConvTranspose2d(upconv_dimensions[0], upconv_dimensions[1], 4, 2, 1), #4 * 4 * GC * D/8 143 | nn.BatchNorm2d(upconv_dimensions[1]), 144 | nn.ReLU(), 145 | 146 | # Fractionally Strided Conv 2 147 | nn.ConvTranspose2d(upconv_dimensions[1], upconv_dimensions[2], 4, 2, 1), #8 * 8 * GC * D/16 148 | nn.BatchNorm2d(upconv_dimensions[2]), 149 | nn.ReLU(), 150 | 151 | # Fractionally Strided Conv 3 152 | # (D, 16, 416) -> (1, 16, 416) 153 | nn.ConvTranspose2d(upconv_dimensions[2], upconv_dimensions[3], 4, 2, 1), # 16 * 16 * GC * 1 154 | nn.Sigmoid() 155 | ) 156 | 157 | def build_font_shape_discriminator(image_size=(64, 1664), dimension=16): 158 | """ 159 | PyTorch model implementing the GlyphGAN critic. 160 | 161 | Inputs: 162 | - `image_size`: The size of the entire alphabet (usually (H, W * 26)) 163 | - `dimension`: The filter depth after each conv. Doubles per conv layer (1 - > 2 -> 4 -> 8) 164 | """ 165 | 166 | output_size = int(8 * dimension * (image_size[0] / 16) * (image_size[1] / 16)) 167 | 168 | return nn.Sequential( 169 | nn.Conv2d(1, dimension, 4, 2, 1), 170 | nn.LeakyReLU(0.2), 171 | 172 | nn.Conv2d(dimension, 2 * dimension, 4, 2, 1), 173 | nn.LeakyReLU(0.2), 174 | 175 | nn.Conv2d(2 * dimension, 4 * dimension, 4, 2, 1), 176 | nn.LeakyReLU(0.2), 177 | 178 | nn.Conv2d(4 * dimension, 8 * dimension, 4, 2, 1), 179 | nn.LeakyReLU(0.2), 180 | 181 | Flatten(), 182 | 183 | nn.Linear(output_size, 1), # Default size will be 3328 variable linear layer 184 | nn.Sigmoid() 185 | ) 186 | 187 | def get_optimizer(model, learning_rate=2e-4, beta1=0.5, beta2=0.99): 188 | """ 189 | Adam optimizer for model 190 | 191 | Input: 192 | - model: A PyTorch model that we want to optimize. 193 | 194 | Returns: 195 | - An Adam optimizer for the model with the desired hyperparameters. 196 | """ 197 | optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(beta1, beta2)) 198 | return optimizer 199 | 200 | class Flatten(nn.Module): 201 | def forward(self, x): 202 | N, _, _, _ = x.size() # read in N, C, H, W 203 | return x.view(N, -1) # "flatten" the C * H * W values into a single vector per image 204 | 205 | class Unflatten(nn.Module): 206 | """ 207 | An Unflatten module receives an input of shape (N, C*H*W) and reshapes it 208 | to produce an output of shape (N, C, H, W). 
209 |     """
210 |     def __init__(self, N=-1, C=128, H=7, W=7):
211 |         super(Unflatten, self).__init__()
212 |         self.N = N
213 |         self.C = C
214 |         self.H = H
215 |         self.W = W
216 |
217 |     def forward(self, x):
218 |         return x.view(self.N, self.C, self.H, self.W)
219 |
220 | def initialize_weights(m):
221 |     if isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
222 |         nn.init.xavier_uniform_(m.weight.data)
--------------------------------------------------------------------------------
/results/results_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshpc/StyledFontGAN/da0d1c9ec7d251be4ea0053d23dca692f39433db/results/results_1.png
--------------------------------------------------------------------------------
/results/results_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshpc/StyledFontGAN/da0d1c9ec7d251be4ea0053d23dca692f39433db/results/results_2.png
--------------------------------------------------------------------------------
/results/results_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshpc/StyledFontGAN/da0d1c9ec7d251be4ea0053d23dca692f39433db/results/results_3.png
--------------------------------------------------------------------------------
/results/results_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshpc/StyledFontGAN/da0d1c9ec7d251be4ea0053d23dca692f39433db/results/results_4.png
--------------------------------------------------------------------------------
/results/results_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshpc/StyledFontGAN/da0d1c9ec7d251be4ea0053d23dca692f39433db/results/results_5.png
--------------------------------------------------------------------------------
/results/results_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshpc/StyledFontGAN/da0d1c9ec7d251be4ea0053d23dca692f39433db/results/results_6.png
--------------------------------------------------------------------------------
/results/results_7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshpc/StyledFontGAN/da0d1c9ec7d251be4ea0053d23dca692f39433db/results/results_7.png
--------------------------------------------------------------------------------
/results/results_8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshpc/StyledFontGAN/da0d1c9ec7d251be4ea0053d23dca692f39433db/results/results_8.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshpc/StyledFontGAN/da0d1c9ec7d251be4ea0053d23dca692f39433db/test.py
--------------------------------------------------------------------------------
/tests/font_dataset_test.py:
--------------------------------------------------------------------------------
 1 | import sys
 2 | from os.path import dirname, join, abspath
 3 | sys.path.insert(0, abspath(join(dirname(__file__), '..')))  # Repo root, so `data.font_dataset` resolves
 4 | #TODO: Swap to syspath.append(path.dirname(path.dirname(path.realpath(file))))
 5 |
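# Illustrative: these tests can be run from the repository root with the same
# discovery pattern configured in .vscode/settings.json:
#
#   python -m unittest discover -s tests -p "*_test.py"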
 6 | import unittest
 7 |
 8 | from data.font_dataset import FontDataset
 9 |
10 | class TestFontDatasets(unittest.TestCase):
11 |     def test_cannot_create_invalid_font_dataset(self):
12 |         with self.assertRaises(AssertionError):
13 |             FontDataset('does_not_exist')
14 |
15 |     def test_can_create_font_dataset(self):
16 |         dataset = FontDataset(abspath(join(dirname(__file__), 'test_datasets/valid')))
17 |         self.assertEqual(1, len(dataset))
18 |
19 |     def test_length_of_empty_folder(self):
20 |         dataset = FontDataset(abspath(join(dirname(__file__), 'test_datasets/empty')))
21 |         self.assertEqual(0, len(dataset))
22 |
--------------------------------------------------------------------------------
/tests/test_datasets/valid/Aaargh.0.0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshpc/StyledFontGAN/da0d1c9ec7d251be4ea0053d23dca692f39433db/tests/test_datasets/valid/Aaargh.0.0.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
 1 | import time
 2 | import random
 3 | from operator import itemgetter
 4 |
 5 | import torch
 6 | import collections
 7 |
 8 | # Visualization
 9 | from livelossplot import PlotLosses
10 | from util import show_grayscale_image
11 |
12 | def train(D, G, D_optimizer, G_optimizer, D_loss, G_loss, data_loader, options):
13 |     """
14 |     Inputs:
15 |     - `options`: A dictionary of options to configure the GAN. Required values:
16 |         `batch_size` - (int) The size of each batch.
17 |         `epoch_count` - (int) The number of epochs to run.
18 |         `data_type` - (type) The tensor type batches are cast to, e.g. `torch.FloatTensor` or `torch.cuda.FloatTensor`.
19 |         `glyph_size` - (tuple or triple, [int, int, (int)]) The size of the image (H, W, C)
20 |         `glyphs_per_image` - (int) The number of glyphs found on each image
21 |         `visualize` - (bool) Whether to plot losses and sample images during training.
22 |
23 |     Returns: Dictionary of losses.
24 |     """
25 |     epoch_count = options['epoch_count']
26 |     visualize = options['visualize']
27 |     losses = collections.defaultdict(list)
28 |     loss_plot = PlotLosses()
29 |
30 |     if visualize:
31 |         real_test, static_test = prepare_static_test(data_loader, options)
32 |         visualize_progress(G, real_test, static_test)
33 |
34 |     for _ in range(epoch_count):
35 |         train_epoch(D, G, D_optimizer, G_optimizer, D_loss, G_loss, data_loader, losses, options)
36 |
37 |         if visualize:
38 |             record_losses(loss_plot, losses)
39 |             visualize_progress(G, real_test, static_test)
40 |
41 |     return losses
42 |
43 | def train_epoch(D, G, D_optimizer, G_optimizer, D_loss, G_loss, data_loader, losses, options):
44 |     steps = 0
45 |     batch_size = options['batch_size']
46 |     data_type = options['data_type']
47 |
48 |     for data in data_loader:
49 |         if len(data) % batch_size != 0:
50 |             continue
51 |         data = data.type(data_type)
52 |
53 |         steps += 1
54 |
55 |         train_discriminator(D, G, D_optimizer, D_loss, data, losses, options)
56 |
57 |         # TODO: Parameterize
58 |         if steps % 5 == 0:
59 |             train_generator(D, G, G_optimizer, G_loss, data, losses, options)
60 |
61 | def train_generator(D, G, G_optimizer, G_loss, data, losses, options):
62 |     """
63 |     Executes one iteration of training for the generator. This is a classic GAN setup.
64 |
65 |     No return value.
66 |     """
67 |     glyph_size, glyphs_per_image = itemgetter('glyph_size', 'glyphs_per_image')(options)
68 |
69 |     G_optimizer.zero_grad()
70 |
71 |     # Prepare our data. A single randomly chosen glyph seeds the entire process. 
72 | generator_input = prepare_generator_input(data, glyph_size, glyphs_per_image) 73 | generated_data = reshape_generated_data(G(generator_input), glyph_size, glyphs_per_image) 74 | real_data = reshape_real_data(data, glyph_size, glyphs_per_image) 75 | 76 | loss = G_loss(D, G, real_data, generated_data, losses, options) 77 | loss.backward() 78 | losses['G'].append(loss.data) 79 | 80 | G_optimizer.step() 81 | 82 | def train_discriminator(D, G, D_optimizer, D_loss, data, losses, options): 83 | """ 84 | Executes one iteration of training for the discriminator. 85 | 86 | No return value. 87 | """ 88 | glyph_size, glyphs_per_image = itemgetter('glyph_size', 'glyphs_per_image')(options) 89 | 90 | D_optimizer.zero_grad() 91 | 92 | # Prepare the data 93 | generator_input = prepare_generator_input(data, glyph_size, glyphs_per_image) 94 | generated_data = reshape_generated_data(G(generator_input), glyph_size, glyphs_per_image) 95 | real_data = reshape_real_data(data, glyph_size, glyphs_per_image) 96 | 97 | # Calculate the loss 98 | loss = D_loss(D, G, real_data, generated_data, losses, options) 99 | loss.backward() 100 | losses['D'].append(loss.data) 101 | 102 | # Perform backwards pass 103 | D_optimizer.step() 104 | 105 | # --- Helper Functions --- 106 | 107 | def visualize_progress(G, real, static): 108 | show_grayscale_image(real[0].cpu()) 109 | show_grayscale_image(static[0].cpu()) 110 | show_grayscale_image(G(static)[0].cpu()) 111 | 112 | def prepare_generator_input(image_data, glyph_size, glyphs_per_image): 113 | base = random.randint(0, glyphs_per_image - 1) 114 | image_width = glyph_size[1] 115 | return image_data[:, :, :, base * image_width:(base + 1) * image_width] 116 | 117 | def reshape_real_data(real_data, glyph_size, glyphs_per_image): 118 | return real_data[:, :, :, 0:(glyphs_per_image * glyph_size[1])] 119 | 120 | def reshape_generated_data(generated_output, glyph_size, glyphs_per_image): 121 | # generated_shape = generated_output.shape 122 | return generated_output[:, :, :, 0:(glyphs_per_image * glyph_size[1])] 123 | # # Flatten the output, then take only letters A-Z (64 x 26 = 1664) -- Ignore the dead space 124 | # return generated_output.reshape( 125 | # generated_shape[0], 126 | # generated_shape[1], 127 | # glyph_size[0], 128 | # glyph_size[1] * 32 129 | # )[:, :, :, 0:glyph_size[1] * glyphs_per_image] 130 | 131 | def record_losses(loss_plot, losses): 132 | record = {} 133 | for key in losses.keys(): 134 | record[key] = losses[key][-1] 135 | loss_plot.update(record) 136 | loss_plot.send() 137 | 138 | def prepare_static_test(data_loader, options): 139 | real_test = None 140 | static_test = None 141 | glyph_size, glyphs_per_image, data_type = itemgetter('glyph_size', 'glyphs_per_image', 'data_type')(options) 142 | 143 | for data in data_loader: 144 | real_test = reshape_real_data(data, glyph_size, glyphs_per_image) 145 | static_test = prepare_generator_input(data, glyph_size, glyphs_per_image).type(data_type) 146 | break 147 | 148 | return real_test, static_test 149 | -------------------------------------------------------------------------------- /trained/better_trained_d_state_dict.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshpc/StyledFontGAN/da0d1c9ec7d251be4ea0053d23dca692f39433db/trained/better_trained_d_state_dict.pt -------------------------------------------------------------------------------- /trained/better_trained_g_state_dict.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshpc/StyledFontGAN/da0d1c9ec7d251be4ea0053d23dca692f39433db/trained/better_trained_g_state_dict.pt -------------------------------------------------------------------------------- /trained/well_trained_d_state_dict.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshpc/StyledFontGAN/da0d1c9ec7d251be4ea0053d23dca692f39433db/trained/well_trained_d_state_dict.pt -------------------------------------------------------------------------------- /trained/well_trained_g_state_dict.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshpc/StyledFontGAN/da0d1c9ec7d251be4ea0053d23dca692f39433db/trained/well_trained_g_state_dict.pt -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plot 2 | import torchvision.transforms as transforms 3 | 4 | def show_grayscale_image(image): 5 | plot.imshow(transforms.Compose([ 6 | transforms.ToPILImage(), 7 | transforms.Grayscale(num_output_channels=3) 8 | ])(image)) 9 | plot.axis('off') 10 | plot.show() --------------------------------------------------------------------------------
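# Illustrative (not part of util.py): viewing a sample from a trained generator.
# This assumes the checkpoints in trained/ match the default architecture in
# models/models.py; if the state dict shapes disagree, adjust `dimension` and
# `glyph_size` accordingly.
#
#   import torch
#   from models.models import build_font_shape_generator
#   from util import show_grayscale_image
#
#   G = build_font_shape_generator()
#   G.load_state_dict(torch.load('trained/better_trained_g_state_dict.pt'))
#   G.eval()
#   seed = torch.rand(1, 1, 64, 64)   # one 64x64 glyph (ideally a real letter image)
#   show_grayscale_image(G(seed)[0])  # the full generated alphabet strip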