├── .gitignore ├── LICENSE ├── README.md └── transgan_pytorch ├── test.py └── transgan_pytorch.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Rishabh Anand 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TransGAN-PyTorch 2 | PyTorch implementation of the TransGAN paper 3 | 4 | The original paper can be found [here](https://arxiv.org/abs/2102.07074). 5 | 6 | ### Installation 7 | You can install the package via `pip`: 8 | 9 | ```bash 10 | pip install transgan-pytorch 11 | ``` 12 | 13 | ### Usage 14 | 15 | ```python 16 | import torch 17 | from transgan_pytorch import TransGAN 18 | 19 | tgan = TransGAN(...) 20 | 21 | z = torch.rand(100) # random noise 22 | pred = tgan(z) 23 | ``` 24 | 25 | ### License 26 | [MIT](https://github.com/rish-16/TransGAN-PyTorch/blob/main/LICENSE) -------------------------------------------------------------------------------- /transgan_pytorch/test.py: -------------------------------------------------------------------------------- 1 | from transgan_pytorch.transgan_pytorch import TransGAN 2 | 3 | tgan = TransGAN( 4 | z_dim=100, 5 | output_gim=32*32 6 | ) -------------------------------------------------------------------------------- /transgan_pytorch/transgan_pytorch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | """ 7 | Paper uses Vaswani (2017) Attention with minimal changes. 8 | 9 | Multi-head self-attention with a feed-forward MLP 10 | with GELU non-linearity. Layer normalisation is used 11 | before each segment and employs residual skip connections. 12 | """ 13 | class Attention(nn.Module): 14 | def __init__(self, D, heads=8): 15 | super().__init__() 16 | self.D = D 17 | self.heads = heads 18 | 19 | assert (D % heads == 0), "Embedding size should be divisble by number of heads" 20 | self.head_dim = self.D // heads 21 | 22 | self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False) 23 | self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False) 24 | self.values = nn.Linear(self.head_dim, self.head_dim, bias=False) 25 | self.H = nn.Linear(self.D, self.D) 26 | 27 | def forward(self, Q, K, V, mask): 28 | batch_size = Q.shape[0] 29 | q_len, k_len, v_len = Q.shape[1], K.shape[1], V.shape[1] 30 | 31 | Q = Q.reshape(batch_size, q_len, self.heads, self.head_dim) 32 | K = K.reshape(batch_size, k_len, self.heads, self.head_dim) 33 | V = V.reshape(batch_size, v_len, self.heads, self.head_dim) 34 | 35 | # performing batch-wise matrix multiplication 36 | raw_scores = torch.einsum("bqhd,bkhd->bhqk", [Q, K]) 37 | 38 | # shut off triangular matrix with very small value 39 | scores = raw_scores.masked_fill(mask == 0, -np.inf) if mask else raw_scores 40 | 41 | attn = torch.softmax(scores / np.sqrt(self.D), dim=3) 42 | attn_output = torch.einsum("bhql,blhd->bqhd", [attn, V]) 43 | attn_output = attn_output.reshape(batch_size, q_len, self.D) 44 | 45 | output = self.H(attn_output) 46 | 47 | return output 48 | 49 | class EncoderBlock(nn.Module): 50 | def __init__(self, D, heads, p, fwd_exp): 51 | super().__init__() 52 | self.mha = Attention(D, heads) 53 | self.drop_prob = p 54 | self.n1 = nn.LayerNorm(D) 55 | self.n2 = nn.LayerNorm(D) 56 | self.mlp = nn.Sequential( 57 | nn.Linear(D, fwd_exp*D), 58 | nn.ReLU(), 59 | nn.Linear(fwd_exp*D, D), 60 | ) 61 | self.dropout = nn.Dropout(p) 62 | 63 | def forward(self, Q, K, V, mask): 64 | attn = self.mha(Q, K, V, mask) 65 | 66 | """ 67 | Layer normalisation with residual connections 68 | """ 69 | x = self.n1(attn + Q) 70 | x = self.dropout(x) 71 | forward = self.mlp(x) 72 | x = self.n2(forward + x) 73 | out = self.dropout(x) 74 | 75 | return out 76 | 77 | class MLP(nn.Module): 78 | def __init__(self, noise_w, noise_h, channels): 79 | super().__init__() 80 | self.l1 = nn.Linear( 81 | noise_w*noise_h*channels, 82 | (8*8)*noise_w*noise_h*channels, 83 | bias=False 84 | ) 85 | 86 | def forward(self, x): 87 | out = self.l1(x) 88 | return out 89 | 90 | class PixelShuffle(nn.Module): 91 | def __init__(self): 92 | super().__init__() 93 | pass 94 | 95 | class Generator(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | self.mlp = MLP(32, 32, 1) 99 | 100 | # stage 1 101 | self.s1_enc = nn.ModuleList([ 102 | EncoderBlock(1024*8*8) 103 | for _ in range(5) 104 | ]) 105 | 106 | # stage 2 107 | self.s2_pix_shuffle = PixelShuffle() 108 | self.s2_enc = nn.ModuleList([ 109 | EncoderBlock(256*16*16) 110 | for _ in range (4) 111 | ]) 112 | 113 | # stage 3 114 | self.s3_pix_shuffle = PixelShuffle() 115 | self.s3_enc = nn.ModuleList([ 116 | EncoderBlock(64*32*32) 117 | for _ in range(2) 118 | ]) 119 | 120 | # stage 4 121 | self.linear = nn.Linear(32*32*64, 32*32*3) 122 | 123 | def forward(self, noise): 124 | x = self.mlp(noise) 125 | for layer in self.s1_enc: 126 | x = layer(x) 127 | 128 | x = self.s2_pix_shuffle(x) 129 | for layer in self.s2_enc: 130 | x = layer(x) 131 | 132 | x - self.s3_pix_shuffle(x) 133 | for layer in self.s3_enc: 134 | x = layer(x) 135 | 136 | img = self.linear(x) 137 | 138 | return img 139 | 140 | class Discriminator(nn.Module): 141 | def __init__(self): 142 | super().__init__() 143 | 144 | self.l1 = nn.Linear(32*32*3, (8*8+1)*384) 145 | self.s2_enc = nn.ModuleList([ 146 | EncoderBlock((8*8+1)*284) 147 | for _ in range(7) 148 | ]) 149 | 150 | self.classification_head = nn.Linear(1*384, 1) 151 | 152 | def forward(self, img): 153 | x = self.l1(img) 154 | for layer in self.s2_enc: 155 | x = layer(x) 156 | 157 | logits = self.classification_head(x) 158 | pred = F.softmax(logits) 159 | 160 | return pred 161 | 162 | class TransGAN_S(nn.Module): 163 | def __init__(self): 164 | super().__init__() 165 | self.gen = Generator() 166 | self.disc = Discriminator() 167 | 168 | def forward(self, noise): 169 | img = self.gen(noise) 170 | pred = self.disc(img) 171 | 172 | return img, pred --------------------------------------------------------------------------------