├── .gitignore
├── LICENSE
├── README.md
├── attn_gan_pytorch
│   ├── ConfigManagement.py
│   ├── CustomLayers.py
│   ├── Losses.py
│   ├── Networks.py
│   ├── Utils.py
│   └── __init__.py
├── literature
│   └── self_attention_gan.pdf
├── samples
│   ├── .gitignore
│   ├── data_processing
│   │   ├── DataLoader.py
│   │   └── __init__.py
│   ├── generate_loss_plots.py
│   └── sample_celeba
│       ├── .gitignore
│       ├── configs
│       │   ├── 1
│       │   │   ├── dis.conf
│       │   │   └── gen.conf
│       │   ├── 2
│       │   │   ├── dis.conf
│       │   │   └── gen.conf
│       │   └── 3
│       │       ├── dis.conf
│       │       └── gen.conf
│       └── train.py
└── setup.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# ignore the pycharm setup
.idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Animesh Karnewar

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# attn_gan_pytorch
Python package for the self-attention GAN, implemented as an
extension of PyTorch's `nn.Module`.
paper -> https://arxiv.org/abs/1805.08318

Also includes generic layers for image-based attention mechanisms.
Includes a **`Full-Attention`** layer as proposed in another
project of mine [here](https://github.com/akanimax/fagan)

## Installation:
This is a python package available at
[**pypi.org**](https://pypi.org/project/attn-gan-pytorch/#description),
so installation is fairly straightforward. This package depends on
a suitable GPU build of **`torch`** and **`torchvision`** for your
architecture, so please install a suitable version of PyTorch prior
to installing this package. Follow the instructions at
[pytorch.org](https://pytorch.org/) to install your version of PyTorch.

Install with the following commands:

    $ workon [your virtual environment]
    $ pip install attn-gan-pytorch

## CelebA Samples:
Some CelebA samples generated using this code for the
fagan architecture:

*generated samples*
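## Quick usage:
A minimal sketch of the intended workflow, assuming the sample configs from
`samples/sample_celeba/` (the paths, latent size, optimizer settings and the
stand-in data below are illustrative; see the `samples/` directory for the
real training script):

    import torch as th
    from torch.optim import Adam
    from torch.utils.data import DataLoader
    from attn_gan_pytorch.ConfigManagement import get_config
    from attn_gan_pytorch.Utils import get_layer
    from attn_gan_pytorch.Networks import Generator, Discriminator, GAN
    from attn_gan_pytorch.Losses import RelativisticAverageHingeGAN

    device = th.device("cuda" if th.cuda.is_available() else "cpu")

    # build the generator and discriminator from their yaml configs
    gen_conf = get_config("samples/sample_celeba/configs/1/gen.conf")
    dis_conf = get_config("samples/sample_celeba/configs/1/dis.conf")
    gen = Generator([get_layer(lay) for lay in gen_conf.architecture],
                    latent_size=128)
    dis = Discriminator([get_layer(lay) for lay in dis_conf.architecture])

    # stand-in data: any DataLoader yielding raw image tensors (B x 3 x 64 x 64);
    # in practice, build one over CelebA with samples/data_processing/DataLoader.py
    data = DataLoader([th.randn(3, 64, 64) for _ in range(64)], batch_size=16)

    # wrap everything into a GAN object and train
    gan = GAN(gen, dis, device=device)
    gan.train(data, Adam(gen.parameters()), Adam(dis.parameters()),
              RelativisticAverageHingeGAN(device, dis), num_epochs=1)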


### Head over to the [**Fagan project**](https://github.com/akanimax/fagan) repo for more info!
Also, this repo contains the code for using this package
to build the `SAGAN` architecture as mentioned in the paper.
Please refer to the `samples/` directory for this.

## Thanks
Please feel free to open PRs here if you train on other datasets
using this package. Suggestions / Issues / Contributions are most
welcome.

Best regards,
@akanimax :)
--------------------------------------------------------------------------------
/attn_gan_pytorch/ConfigManagement.py:
--------------------------------------------------------------------------------
""" Module for reading and parsing configuration files """

import yaml


def get_config(conf_file):
    """
    parse and load the provided configuration
    :param conf_file: configuration file
    :return: conf => parsed configuration
    """
    from easydict import EasyDict as edict

    with open(conf_file, "r") as file_descriptor:
        # use safe_load; the configuration is plain data
        data = yaml.safe_load(file_descriptor)

    # convert the data into an easyDictionary
    return edict(data)


def parse2tuple(inp_str):
    """
    function for parsing a 2-tuple of integers
    :param inp_str: string of the form: '(3, 3)'
    :return: tuple => parsed tuple
    """
    inp_str = inp_str[1: -1]  # remove the parentheses
    args = inp_str.split(',')
    args = tuple(map(int, args))

    return args
--------------------------------------------------------------------------------
/attn_gan_pytorch/CustomLayers.py:
--------------------------------------------------------------------------------
""" Module implements the custom layers """

import torch as th


class SelfAttention(th.nn.Module):
    """
    Layer implements the self-attention module
    which is the main logic behind this architecture.

    args:
        channels: number of channels in the image tensor
        activation: activation function to be applied (default: None)
        squeeze_factor: squeeze factor for queries and keys (default: 8)
        bias: whether to apply bias or not (default: True)
    """

    def __init__(self, channels, activation=None, squeeze_factor=8, bias=True):
        """ constructor for the layer """

        from torch.nn import Conv2d, Parameter, Softmax

        # base constructor call
        super().__init__()

        # state of the layer
        self.activation = activation
        self.gamma = Parameter(th.zeros(1))

        # Modules required for computations
        self.query_conv = Conv2d(  # query convolution
            in_channels=channels,
            out_channels=channels // squeeze_factor,
            kernel_size=(1, 1),
            stride=1,
            padding=0,
            bias=bias
        )

        self.key_conv = Conv2d(
            in_channels=channels,
            out_channels=channels // squeeze_factor,
            kernel_size=(1, 1),
            stride=1,
            padding=0,
            bias=bias
        )

        self.value_conv = Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=(1, 1),
            stride=1,
            padding=0,
            bias=bias
        )

        # softmax module for applying attention
        self.softmax = Softmax(dim=-1)

    def forward(self, x):
        """
        forward computations of the layer
        :param x: input feature maps (B x C x H x W)
        :return:
            out: self-attention value + input feature (B x C x H x W)
            attention: attention map (B x (H*W) x H x W)
        """

        # extract the shape of the input tensor
        m_batchsize, c, height, width = x.size()

        # create the query projection
        proj_query = self.query_conv(x).view(
            m_batchsize, -1, width * height).permute(0, 2, 1)  # B x (N) x C

        # create the key projection
        proj_key = self.key_conv(x).view(
            m_batchsize, -1, width * height)  # B x C x (N)

        # calculate the attention maps
        energy = th.bmm(proj_query, proj_key)  # energy
        attention = self.softmax(energy)  # attention (B x (N) x (N))
        # create the value projection
        proj_value = self.value_conv(x).view(
            m_batchsize, -1, width * height)  # B x C x N

        # calculate the output
        out = th.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, c, height, width)

        attention = attention.view(m_batchsize, -1, height, width)

        if self.activation is not None:
            out = self.activation(out)

        out = self.gamma * out + x
        return out, attention


class FullAttention(th.nn.Module):
    """
    Layer implements my version of the self-attention module;
    it is mostly the same as self-attention, but generalizes to
    (k x k) convolutions instead of (1 x 1)
    args:
        in_channels: number of input channels
        out_channels: number of output channels
        activation: activation function to be applied (default: None)
        kernel_size: kernel size for convolution (default: (1 x 1))
        transpose_conv: whether to use transpose convolutions instead of
                        convolutions (default: False)
        use_spectral_norm: whether to spectrally normalize the residual
                           convolution (default: True)
        use_batch_norm: whether to batch-normalize the residual output
                        (default: True)
        squeeze_factor: squeeze factor for queries and keys (default: 8)
        stride: stride for the convolutions (default: 1)
        padding: padding for the applied convolutions (default: 0)
        bias: whether to apply bias or not (default: True)
    """

    def __init__(self, in_channels, out_channels,
                 activation=None, kernel_size=(1, 1), transpose_conv=False,
                 use_spectral_norm=True, use_batch_norm=True,
                 squeeze_factor=8, stride=1, padding=0, bias=True):
        """ constructor for the layer """

        from torch.nn import Conv2d, Parameter, \
            Softmax, ConvTranspose2d, BatchNorm2d

        # base constructor call
        super().__init__()

        # state of the layer
        self.activation = activation
        self.gamma = Parameter(th.zeros(1))

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.squeezed_channels = in_channels // squeeze_factor
        self.use_batch_norm = use_batch_norm

        # Modules required for computations
        if transpose_conv:
            self.query_conv = ConvTranspose2d(  # query convolution
                in_channels=in_channels,
                out_channels=in_channels // squeeze_factor,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            )

            self.key_conv = ConvTranspose2d(
                in_channels=in_channels,
                out_channels=in_channels // squeeze_factor,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            )

            self.value_conv = ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            )

            self.residual_conv = ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            ) if not use_spectral_norm else SpectralNorm(
                ConvTranspose2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias
                )
            )

        else:
            self.query_conv = Conv2d(  # query convolution
                in_channels=in_channels,
                out_channels=in_channels // squeeze_factor,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            )

            self.key_conv = Conv2d(
                in_channels=in_channels,
                out_channels=in_channels // squeeze_factor,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            )

            self.value_conv = Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            )

            self.residual_conv = Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            ) if not use_spectral_norm else SpectralNorm(
                Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias
                )
            )

        # softmax module for applying attention
        self.softmax = Softmax(dim=-1)
        self.batch_norm = BatchNorm2d(out_channels)

    def forward(self, x):
        """
        forward computations of the layer
        :param x: input feature maps (B x C x H x W)
        :return:
            out: self-attention value + residual feature (B x O x H x W)
            attention: attention map (B x (H*W) x H x W)
        """

        # extract the batch size of the input tensor
        m_batchsize, _, _, _ = x.size()

        # create the query projection
        proj_query = self.query_conv(x).view(
            m_batchsize, self.squeezed_channels, -1).permute(0, 2, 1)  # B x (N) x C

        # create the key projection
        proj_key = self.key_conv(x).view(
            m_batchsize, self.squeezed_channels, -1)  # B x C x (N)

        # calculate the attention maps
        energy = th.bmm(proj_query, proj_key)  # energy
        attention = self.softmax(energy)  # attention (B x (N) x (N))

        # create the value projection
        proj_value = self.value_conv(x).view(
            m_batchsize, self.out_channels, -1)  # B x C x N

        # calculate the output
        out = th.bmm(proj_value, attention.permute(0, 2, 1))

        # calculate the residual output
        res_out = self.residual_conv(x)

        out = out.view(m_batchsize, self.out_channels,
                       res_out.shape[-2], res_out.shape[-1])

        attention = attention.view(m_batchsize, -1,
                                   res_out.shape[-2], res_out.shape[-1])

        if self.use_batch_norm:
            res_out = self.batch_norm(res_out)

        if self.activation is not None:
            out = self.activation(out)
            res_out = self.activation(res_out)

        # apply the residual connections
        out = (self.gamma * out) + ((1 - self.gamma) * res_out)
        return out, attention


class SpectralNorm(th.nn.Module):
    """
    Wrapper around a Torch module which applies spectral normalization
    """

    # TODO complete the documentation for this Layer

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    @staticmethod
    def l2normalize(v, eps=1e-12):
        return v / (v.norm() + eps)

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = self.l2normalize(th.mv(th.t(w.view(height, -1).data), u.data))
            u.data = self.l2normalize(th.mv(w.view(height, -1).data, v.data))

        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            getattr(self.module, self.name + "_u")
            getattr(self.module, self.name + "_v")
            getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        from torch.nn import Parameter

        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = self.l2normalize(u.data)
        v.data = self.l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)


class IgnoreAttentionMap(th.nn.Module):
    """
    A small utility module to discard the attention
    map output by the self_attention layer
    """

    def __init__(self):
        """ has nothing and does nothing apart from super calls """
        super().__init__()

    def forward(self, inp):
        """
        ignores the attention_map in the obtained input and returns only the features
        :param inp: (features, attention_maps)
        :return: output => features
        """
        return inp[0]
--------------------------------------------------------------------------------
/attn_gan_pytorch/Losses.py:
--------------------------------------------------------------------------------
""" Module implementing various loss functions """

import torch as th


# TODOcomplete Major rewrite: change the interface to use only predictions
# for real and fake samples
# The interface doesn't need to change to use only predictions for real and
# fake samples, because losses such as WGAN-GP require the samples themselves
# to calculate the gradient penalty

class GANLoss:
    """
    Base class for all losses
    Note that gen_loss also receives the real samples,
    since relativistic losses require them
    """

    def __init__(self, device, dis):
        self.device = device
        self.dis = dis

    def dis_loss(self, real_samps, fake_samps):
        raise NotImplementedError("dis_loss method has not been implemented")

    def gen_loss(self, real_samps, fake_samps):
        raise NotImplementedError("gen_loss method has not been implemented")

    def conditional_dis_loss(self, real_samps, fake_samps, conditional_vectors):
        raise NotImplementedError("conditional_dis_loss method has not been implemented")

    def conditional_gen_loss(self, real_samps, fake_samps, conditional_vectors):
        raise NotImplementedError("conditional_gen_loss method has not been implemented")


class StandardGAN(GANLoss):

    def __init__(self, dev, dis):
        from torch.nn import BCELoss

        super().__init__(dev, dis)

        # define the criterion object
        self.criterion = BCELoss()

    def dis_loss(self, real_samps, fake_samps):
        # calculate the real loss:
        real_loss = self.criterion(th.squeeze(self.dis(real_samps)),
                                   th.ones(real_samps.shape[0]).to(self.device))
        # calculate the fake loss:
        fake_loss = self.criterion(th.squeeze(self.dis(fake_samps)),
                                   th.zeros(fake_samps.shape[0]).to(self.device))

        # return final loss as average of the two:
        return (real_loss + fake_loss) / 2

    def gen_loss(self, _, fake_samps):
        return self.criterion(th.squeeze(self.dis(fake_samps)),
                              th.ones(fake_samps.shape[0]).to(self.device))

    def conditional_dis_loss(self, real_samps, fake_samps, conditional_vectors):
        # calculate the real loss:
        real_loss = self.criterion(th.squeeze(self.dis(real_samps, conditional_vectors)),
                                   th.ones(real_samps.shape[0]).to(self.device))
        # calculate the fake loss:
        fake_loss = self.criterion(th.squeeze(self.dis(fake_samps, conditional_vectors)),
                                   th.zeros(fake_samps.shape[0]).to(self.device))

        # return final loss as average of the two:
        return (real_loss + fake_loss) / 2

    def conditional_gen_loss(self, real_samps, fake_samps, conditional_vectors):
        return self.criterion(th.squeeze(self.dis(fake_samps, conditional_vectors)),
                              th.ones(fake_samps.shape[0]).to(self.device))


class LSGAN(GANLoss):

    def __init__(self, device, dis):
        super().__init__(device, dis)

    def dis_loss(self, real_samps, fake_samps):
        return 0.5 * (((th.mean(self.dis(real_samps)) - 1) ** 2)
                      + (th.mean(self.dis(fake_samps))) ** 2)

    def gen_loss(self, _, fake_samps):
        return 0.5 * ((th.mean(self.dis(fake_samps)) - 1) ** 2)

    def conditional_dis_loss(self, real_samps, fake_samps, conditional_vectors):
        return 0.5 * (((th.mean(self.dis(real_samps, conditional_vectors)) - 1) ** 2)
                      + (th.mean(self.dis(fake_samps, conditional_vectors))) ** 2)

    def conditional_gen_loss(self, real_samps, fake_samps, conditional_vectors):
        return 0.5 * ((th.mean(self.dis(fake_samps, conditional_vectors)) - 1) ** 2)


class HingeGAN(GANLoss):

    def __init__(self, device, dis):
        super().__init__(device, dis)

    def dis_loss(self, real_samps, fake_samps):
        return (th.mean(th.nn.ReLU()(1 - self.dis(real_samps))) +
                th.mean(th.nn.ReLU()(1 + self.dis(fake_samps))))

    def gen_loss(self, real_samps, fake_samps):
        return -th.mean(self.dis(fake_samps))

    def conditional_dis_loss(self, real_samps, fake_samps, conditional_vectors):
        return (th.mean(th.nn.ReLU()(1 - self.dis(real_samps, conditional_vectors))) +
                th.mean(th.nn.ReLU()(1 + self.dis(fake_samps, conditional_vectors))))

    def conditional_gen_loss(self, real_samps, fake_samps, conditional_vectors):
        return -th.mean(self.dis(fake_samps, conditional_vectors))


class RelativisticAverageHingeGAN(GANLoss):

    def __init__(self, device, dis):
        super().__init__(device, dis)

    def dis_loss(self, real_samps, fake_samps):
        # difference between real and fake:
        r_f_diff = self.dis(real_samps) - th.mean(self.dis(fake_samps))

        # difference between fake and real samples
        f_r_diff = self.dis(fake_samps) - th.mean(self.dis(real_samps))

        # return the loss
        return (th.mean(th.nn.ReLU()(1 - r_f_diff))
                + th.mean(th.nn.ReLU()(1 + f_r_diff)))

    def gen_loss(self, real_samps, fake_samps):
        # difference between real and fake:
        r_f_diff = self.dis(real_samps) - th.mean(self.dis(fake_samps))

        # difference between fake and real samples
        f_r_diff = self.dis(fake_samps) - th.mean(self.dis(real_samps))

        # return the loss
        return (th.mean(th.nn.ReLU()(1 + r_f_diff))
                + th.mean(th.nn.ReLU()(1 - f_r_diff)))

    def conditional_dis_loss(self, real_samps, fake_samps, conditional_vectors):
        # difference between real and fake:
        r_f_diff = self.dis(real_samps, conditional_vectors) \
                   - th.mean(self.dis(fake_samps, conditional_vectors))

        # difference between fake and real samples
        f_r_diff = self.dis(fake_samps, conditional_vectors) \
                   - th.mean(self.dis(real_samps, conditional_vectors))

        # return the loss
        return (th.mean(th.nn.ReLU()(1 - r_f_diff))
                + th.mean(th.nn.ReLU()(1 + f_r_diff)))

    def conditional_gen_loss(self, real_samps, fake_samps, conditional_vectors):
        # difference between real and fake:
        r_f_diff = self.dis(real_samps, conditional_vectors) \
                   - th.mean(self.dis(fake_samps, conditional_vectors))

        # difference between fake and real samples
        f_r_diff = self.dis(fake_samps, conditional_vectors) \
                   - th.mean(self.dis(real_samps, conditional_vectors))

        # return the loss
        return (th.mean(th.nn.ReLU()(1 + r_f_diff))
                + th.mean(th.nn.ReLU()(1 - f_r_diff)))
--------------------------------------------------------------------------------
/attn_gan_pytorch/Networks.py:
--------------------------------------------------------------------------------
""" module implements the networks functionality """

import torch as th
import numpy as np
import timeit
import datetime
import time
import os


class Network(th.nn.Module):
    """ General module that creates a Network from the configuration provided
        Extends a PyTorch Module
        args:
            modules: list of PyTorch layers (nn.Modules)
    """

    def __init__(self, modules):
        """ derived constructor """

        # make a call to Module constructor for allowing
        # us to attach required modules
        super().__init__()

        self.model = th.nn.Sequential(*modules)

    def forward(self, x):
        """
        forward computations
        :param x: input
        :return: y => output features volume
        """
        return self.model(x)


class Generator(Network):
    """
    Generator is an extension of a Generic Network

    args:
        modules: same as for Network
        latent_size: latent size of the Generator (GAN)
    """

    def __init__(self, modules, latent_size):
        super().__init__(modules)

        # attach the latent size for the GAN here
        self.latent_size = latent_size


class ConditionalGenerator(Generator):
    """ Conditional Generator is a special case of a Generator;
        nothing differs apart from the name, since the conditioning
        is simply concatenated into the latent input.

        args:
            modules: same as for Network
            latent_size: latent size of the Generator (GAN)
                         Note that latent_size also includes the size of the
                         conditional labels
    """
    pass


class Discriminator(Network):
    pass


class ConditionalDiscriminator(Discriminator):
    """
    The conditional variant of the Discriminator, which
    sits just further down the Network class tree.

    args:
        modules: Note that this list of modules must not contain the final prediction
                 layer. This only reduces the spatial dimension to
                 (reduced_height x reduced_width) specifically.
        last_module: th.nn.Module which makes the conditional prediction
                     (the conditional embedding itself is passed to forward)
    """

    def __init__(self, modules, last_module):
        super().__init__(modules)

        # attach the last module separately here:
        self.last_module = last_module

        # the last_module acts as the projector layer which
        # consumes the concatenated conditional embedding prior
        # to prediction calculation.

    def forward(self, x, embedding):
        """
        The forward pass of the Conditional Discriminator.
        :param x: input images tensor
        :param embedding: conditional vector
        :return: predictions => scores for the inputs
        """
        # obtain the reduced volume:
        reduced_volume = super().forward(x)

        # concatenate the embeddings to reduced_volume here:
        cat = th.unsqueeze(th.unsqueeze(embedding, -1), -1)
        # spatial replication
        cat = cat.expand(cat.shape[0], cat.shape[1],
                         reduced_volume.shape[2], reduced_volume.shape[3])
        final_input = th.cat((reduced_volume, cat), dim=1)

        # apply the last module to obtain the predictions:
        prediction_scores = self.last_module(final_input)

        # return the prediction scores:
        return prediction_scores


class GAN:
    """
    Unconditional GAN

    args:
        gen: Generator object
        dis: Discriminator object
        device: torch.device() for running on GPU or CPU
                default = torch.device("cpu")
    """

    def __init__(self, gen, dis,
                 device=th.device("cpu")):
        """ constructor for the class """
        assert isinstance(gen, Generator), "gen is not an Unconditional Generator"
        assert isinstance(dis, Discriminator), "dis is not an Unconditional Discriminator"

        # define the state of the object
        self.generator = gen.to(device)
        self.discriminator = dis.to(device)
        self.device = device

        # by default the generator and discriminator are in eval mode
        self.generator.eval()
        self.discriminator.eval()

    def generate_samples(self, num_samples):
        """
        generate samples using this gan
        :param num_samples: number of samples to be generated
        :return: generated samples tensor: (B x H x W x C)
        """
        # the noise carries (1 x 1) spatial dims for the first
        # transpose convolution (same shape as used in train below)
        noise = th.randn(num_samples,
                         self.generator.latent_size, 1, 1).to(self.device)
        generated_images = self.generator(noise).detach()

        # reshape the generated images
        generated_images = generated_images.permute(0, 2, 3, 1)

        return generated_images

    def optimize_discriminator(self, dis_optim, noise, real_batch, loss_fn):
        """
        performs one step of weight update on discriminator using the batch of data
        :param dis_optim: discriminator optimizer
        :param noise: input noise for sample generation
        :param real_batch: real samples batch
        :param loss_fn: loss function to be used (object of GANLoss)
        :return: current loss
        """

        # generate a batch of samples
        fake_samples = self.generator(noise).detach()

        loss = loss_fn.dis_loss(real_batch, fake_samples)

        # optimize discriminator
        dis_optim.zero_grad()
        loss.backward()
        dis_optim.step()

        return loss.item()

    def optimize_generator(self, gen_optim, noise, real_batch, loss_fn):
        """
        performs one step of weight update on generator using the batch of data
        :param gen_optim: generator optimizer
        :param noise: input noise for sample generation
        :param real_batch: real samples batch
        :param loss_fn: loss function to be used (object of GANLoss)
        :return: current loss
        """

        # generate a batch of samples
        fake_samples = self.generator(noise)

        loss = loss_fn.gen_loss(real_batch, fake_samples)

        # optimize the generator
        gen_optim.zero_grad()
        loss.backward()
        gen_optim.step()

        return loss.item()

    @staticmethod
    def create_grid(samples, img_file):
        """
        utility function to create a grid of GAN samples
        :param samples: generated samples for storing
        :param img_file: name of file to write
        :return: None (saves a file)
        """
        from torchvision.utils import save_image
        from numpy import sqrt

        samples = th.clamp((samples / 2) + 0.5, min=0, max=1)

        # save the images:
        save_image(samples, img_file, nrow=int(sqrt(samples.shape[0])))

    def train(self, data, gen_optim, dis_optim, loss_fn,
              start=1, num_epochs=12, feedback_factor=10, checkpoint_factor=1,
              data_percentage=100, num_samples=36,
              log_dir=None, sample_dir="./samples",
              save_dir="./models"):

        # TODO write the documentation for this method

        # turn the generator and discriminator into train mode
        self.generator.train()
        self.discriminator.train()

        assert isinstance(gen_optim, th.optim.Optimizer), \
            "gen_optim is not an Optimizer"
        assert isinstance(dis_optim, th.optim.Optimizer), \
            "dis_optim is not an Optimizer"

        print("Starting the training process ... ")

        # create fixed_input for debugging
        fixed_input = th.randn(num_samples,
                               self.generator.latent_size, 1, 1).to(self.device)

        # create a global time counter
        global_time = time.time()

        for epoch in range(start, num_epochs + 1):
            start = timeit.default_timer()  # record time at the start of epoch

            print("\nEpoch: %d" % epoch)
            total_batches = len(iter(data))

            limit = int((data_percentage / 100) * total_batches)

            for (i, batch) in enumerate(data, 1):

                # extract current batch of data for training
                images = batch.to(self.device)

                gan_input = th.randn(images.shape[0],
                                     self.generator.latent_size, 1, 1).to(self.device)

                # optimize the discriminator:
                dis_loss = self.optimize_discriminator(dis_optim, gan_input,
                                                       images, loss_fn)

                # optimize the generator:
                # resample from the latent noise
                gan_input = th.randn(images.shape[0],
                                     self.generator.latent_size, 1, 1).to(self.device)
                gen_loss = self.optimize_generator(gen_optim, gan_input,
                                                   images, loss_fn)

                # provide a loss feedback
                # (max() avoids a modulo by zero when limit < feedback_factor)
                if i % max(int(limit / feedback_factor), 1) == 0 or i == 1:
                    elapsed = time.time() - global_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    print("Elapsed [%s] batch: %d d_loss: %f g_loss: %f"
                          % (elapsed, i, dis_loss, gen_loss))

                    # also write the losses to the log file:
                    if log_dir is not None:
                        log_file = os.path.join(log_dir, "loss.log")
                        os.makedirs(os.path.dirname(log_file), exist_ok=True)
                        with open(log_file, "a") as log:
                            log.write(str(dis_loss) + "\t" + str(gen_loss) + "\n")

                    # create a grid of samples and save it
                    os.makedirs(sample_dir, exist_ok=True)
                    gen_img_file = os.path.join(sample_dir, "gen_" +
                                                str(epoch) + "_" +
                                                str(i) + ".png")
                    self.create_grid(self.generator(fixed_input).detach(), gen_img_file)

                if i > limit:
                    break

            # calculate the time required for the epoch
            stop = timeit.default_timer()
            print("Time taken for epoch: %.3f secs" % (stop - start))

            if epoch % checkpoint_factor == 0 or epoch == 1 or epoch == num_epochs:
                os.makedirs(save_dir, exist_ok=True)
                gen_save_file = os.path.join(save_dir, "GAN_GEN_" + str(epoch) + ".pth")
                dis_save_file = os.path.join(save_dir, "GAN_DIS_" + str(epoch) + ".pth")

                th.save(self.generator.state_dict(), gen_save_file)
                th.save(self.discriminator.state_dict(), dis_save_file)

        print("Training completed ...")

        # return the generator and discriminator back to eval mode
        self.generator.eval()
        self.discriminator.eval()


# TODOcomplete implement conditional gan variant of this
# conditional gan implemented

class ConditionalGAN(GAN):
    """
    Conditional GAN. Modifies the calls
    for optimize_discriminator, optimize_generator and train

    args:
        gen: ConditionalGenerator object
        dis: ConditionalDiscriminator object
        device: torch.device() for running on GPU or CPU
                default = torch.device("cpu")
    """

    def __init__(self, gen, dis, device=th.device("cpu")):
        """ constructor for this derived class """

        # some more specific checks here
        assert isinstance(gen, ConditionalGenerator), \
            "gen is not a Conditional Generator"
        assert isinstance(dis, ConditionalDiscriminator), \
            "dis is not a Conditional Discriminator"

        super().__init__(gen, dis, device)

    @staticmethod
    def randomize(correct_labels):
        """
        static helper for mismatching the given labels
        :param correct_labels: input correct labels
        :return: shuffled labels
                 (Note, that this behaviour is not
                 guaranteed to create a mismatch for every sample)
        """
        # permute the labels along the batch dimension
        return correct_labels[np.random.permutation(correct_labels.shape[0]), :]

    def optimize_discriminator(self, dis_optim, noise, real_batch, loss_fn,
                               conditional_vectors, matching_aware=False,
                               randomizer=None):
        """
        performs one step of weight update on discriminator using the batch of data
        :param dis_optim: discriminator optimizer
        :param noise: input noise for sample generation
        :param real_batch: real samples batch
        :param loss_fn: loss function to be used (object of GANLoss)
        :param conditional_vectors: for conditional discrimination
        :param matching_aware: boolean for whether to use matching aware discriminator
        :param randomizer: function object for randomizing the conditional vectors,
                           i.e. to mismatch conditional vectors
                           (uses the default randomize helper when None)
        :return: current loss
        """

        # generate a batch of samples
        fake_samples = self.generator(noise).detach()

        loss = loss_fn.conditional_dis_loss(real_batch, fake_samples,
                                            conditional_vectors)

        # if matching aware discrimination is to be used:
        if matching_aware:
            loss += loss_fn.conditional_dis_loss(
                real_batch, real_batch,
                randomizer(conditional_vectors)
                if randomizer is not None
                else self.randomize(conditional_vectors)
            )
            loss = loss / 2

        # optimize discriminator
        dis_optim.zero_grad()
        loss.backward()
        dis_optim.step()

        return loss.item()

    def optimize_generator(self, gen_optim, noise, real_batch, loss_fn,
                           conditional_vectors):
        """
        performs one step of weight update on generator using the batch of data
        :param gen_optim: generator optimizer
        :param noise: input noise for sample generation
        :param real_batch: real samples batch
        :param loss_fn: loss function to be used (object of GANLoss)
        :param conditional_vectors: for conditional discrimination
        :return: current loss
        """

        # generate a batch of samples
        fake_samples = self.generator(noise)

        loss = loss_fn.conditional_gen_loss(real_batch, fake_samples,
                                            conditional_vectors)

        # optimize the generator
        gen_optim.zero_grad()
        # retain graph is true for applying regularization on the
        # conditional input
        loss.backward(retain_graph=True)
        gen_optim.step()

        return loss.item()

    def train(self, data, gen_optim, dis_optim, loss_fn,
              start=1, num_epochs=12, feedback_factor=10, checkpoint_factor=1,
              data_percentage=100, num_samples=36,
              matching_aware=False, mismatcher=None,
              log_dir=None, sample_dir="./samples",
              save_dir="./models"):

        # TODO write the documentation for this method
        # This is the limit of procrastination now :D
        # Just note that data here gives image, label (one-hot encoded)
        # in every batch

        # turn the generator and discriminator into train mode
        self.generator.train()
        self.discriminator.train()

        assert isinstance(gen_optim, th.optim.Optimizer), \
            "gen_optim is not an Optimizer"
        assert isinstance(dis_optim, th.optim.Optimizer), \
            "dis_optim is not an Optimizer"

        print("Starting the training process ... ")
") 440 | 441 | # create fixed_input for debugging 442 | _, debug_labels = iter(data).next() 443 | debug_labels = th.unsqueeze(th.unsqueeze(debug_labels, -1), -1).to(self.device) 444 | fixed_latent_vectors = th.randn( 445 | num_samples, 446 | self.generator.latent_size - debug_labels.shape[1], 447 | 1, 1 448 | ).to(self.device) 449 | 450 | fixed_input = th.cat((fixed_latent_vectors, debug_labels), dim=1) 451 | 452 | # create a global time counter 453 | global_time = time.time() 454 | 455 | for epoch in range(start, num_epochs + 1): 456 | start = timeit.default_timer() # record time at the start of epoch 457 | 458 | print("\nEpoch: %d" % epoch) 459 | total_batches = len(iter(data)) 460 | 461 | limit = int((data_percentage / 100) * total_batches) 462 | 463 | for (i, batch) in enumerate(data, 1): 464 | 465 | # extract current batch of data for training 466 | images, labels = batch 467 | images, labels = images.to(self.device), labels.to(self.device) 468 | expanded_labels = th.unsqueeze(th.unsqueeze(labels, -1), -1) 469 | 470 | latent_input = th.randn( 471 | images.shape[0], 472 | self.generator.latent_size - expanded_labels.shape[1], 473 | 1, 1 474 | ).to(self.device) 475 | 476 | gan_input = th.cat((latent_input, expanded_labels), dim=1) 477 | 478 | # optimize the discriminator: 479 | dis_loss = self.optimize_discriminator(dis_optim, gan_input, 480 | images, loss_fn, labels, 481 | matching_aware, mismatcher) 482 | 483 | # optimize the generator: 484 | # resample from the latent noise 485 | latent_input = th.randn( 486 | images.shape[0], 487 | self.generator.latent_size - expanded_labels.shape[1], 488 | 1, 1 489 | ).to(self.device) 490 | gan_input = th.cat((latent_input, expanded_labels), dim=1) 491 | gen_loss = self.optimize_generator(gen_optim, gan_input, 492 | images, loss_fn, labels) 493 | 494 | # provide a loss feedback 495 | if i % int(limit / feedback_factor) == 0 or i == 1: 496 | elapsed = time.time() - global_time 497 | elapsed = str(datetime.timedelta(seconds=elapsed)) 498 | print("Elapsed [%s] batch: %d d_loss: %f g_loss: %f" 499 | % (elapsed, i, dis_loss, gen_loss)) 500 | 501 | # also write the losses to the log file: 502 | if log_dir is not None: 503 | log_file = os.path.join(log_dir, "loss.log") 504 | os.makedirs(os.path.dirname(log_file), exist_ok=True) 505 | with open(log_file, "a") as log: 506 | log.write(str(dis_loss) + "\t" + str(gen_loss) + "\n") 507 | 508 | # create a grid of samples and save it 509 | os.makedirs(sample_dir, exist_ok=True) 510 | gen_img_file = os.path.join(sample_dir, "gen_" + 511 | str(epoch) + "_" + 512 | str(i) + ".png") 513 | self.create_grid(self.generator(fixed_input).detach(), gen_img_file) 514 | 515 | if i > limit: 516 | break 517 | 518 | # calculate the time required for the epoch 519 | stop = timeit.default_timer() 520 | print("Time taken for epoch: %.3f secs" % (stop - start)) 521 | 522 | if epoch % checkpoint_factor == 0 or epoch == 1 or epoch == num_epochs: 523 | os.makedirs(save_dir, exist_ok=True) 524 | gen_save_file = os.path.join(save_dir, "GAN_GEN_" + str(epoch) + ".pth") 525 | dis_save_file = os.path.join(save_dir, "GAN_DIS_" + str(epoch) + ".pth") 526 | 527 | th.save(self.generator.state_dict(), gen_save_file) 528 | th.save(self.discriminator.state_dict(), dis_save_file) 529 | 530 | print("Training completed ...") 531 | 532 | # return the generator and discriminator back to eval mode 533 | self.generator.eval() 534 | self.discriminator.eval() 535 | -------------------------------------------------------------------------------- 
/attn_gan_pytorch/Utils.py:
--------------------------------------------------------------------------------
""" module contains small utils for parsing configurations """

import torch as th


def get_act_fn(fn_name):
    """
    helper for creating the activation function
    :param fn_name: string containing act_fn name
                    currently supports: [tanh, sigmoid, relu, lrelu]
    :return: fn => PyTorch activation function
    """
    fn_name = fn_name.lower()

    if fn_name == "tanh":
        fn = th.nn.Tanh()

    elif fn_name == "sigmoid":
        fn = th.nn.Sigmoid()

    elif fn_name == "relu":
        fn = th.nn.ReLU()

    elif "lrelu" in fn_name:
        negative_slope = float(fn_name.split("(")[-1][:-1])
        fn = th.nn.LeakyReLU(negative_slope=negative_slope)

    else:
        raise NotImplementedError("requested activation function is not implemented")

    return fn


def get_layer(layer):
    """
    helper for creating a layer from the given conf
    :param layer: EasyDict containing the layer configuration
    :return: lay => PyTorch layer
    """
    from attn_gan_pytorch.CustomLayers import SelfAttention, \
        SpectralNorm, IgnoreAttentionMap, FullAttention
    from torch.nn import Sequential, Conv2d, Dropout2d, ConvTranspose2d, BatchNorm2d
    from attn_gan_pytorch.ConfigManagement import parse2tuple

    # lowercase the name
    name = layer.name.lower()

    if name == "conv":
        in_channels, out_channels = parse2tuple(layer.channels)
        kernel_size = parse2tuple(layer.kernel_dims)
        stride = parse2tuple(layer.stride)
        padding = parse2tuple(layer.padding)
        bias = layer.bias
        act_fn = get_act_fn(layer.activation)

        if hasattr(layer, "spectral_norm") and layer.spectral_norm:
            if layer.batch_norm:
                mod_layer = Sequential(
                    SpectralNorm(Conv2d(in_channels, out_channels, kernel_size,
                                        stride, padding, bias=bias)),
                    BatchNorm2d(out_channels),
                    act_fn
                )
            else:
                mod_layer = Sequential(
                    SpectralNorm(Conv2d(in_channels, out_channels, kernel_size,
                                        stride, padding, bias=bias)),
                    act_fn
                )
        else:
            if layer.batch_norm:
                mod_layer = Sequential(
                    Conv2d(in_channels, out_channels, kernel_size,
                           stride, padding, bias=bias),
                    BatchNorm2d(out_channels),
                    act_fn
                )
            else:
                mod_layer = Sequential(
                    Conv2d(in_channels, out_channels, kernel_size,
                           stride, padding, bias=bias),
                    act_fn
                )

    elif name == "conv_transpose":
        in_channels, out_channels = parse2tuple(layer.channels)
        kernel_size = parse2tuple(layer.kernel_dims)
        stride = parse2tuple(layer.stride)
        padding = parse2tuple(layer.padding)
        bias = layer.bias
        act_fn = get_act_fn(layer.activation)

        if hasattr(layer, "spectral_norm") and layer.spectral_norm:
            if layer.batch_norm:
                mod_layer = Sequential(
                    SpectralNorm(ConvTranspose2d(in_channels, out_channels, kernel_size,
                                                 stride, padding, bias=bias)),
                    BatchNorm2d(out_channels),
                    act_fn
                )
            else:
                mod_layer = Sequential(
                    SpectralNorm(ConvTranspose2d(in_channels, out_channels, kernel_size,
                                                 stride, padding, bias=bias)),
                    act_fn
                )
        else:
            if layer.batch_norm:
                mod_layer = Sequential(
                    ConvTranspose2d(in_channels, out_channels, kernel_size,
                                    stride, padding, bias=bias),
                    BatchNorm2d(out_channels),
                    act_fn
                )
            else:
                mod_layer = Sequential(
                    ConvTranspose2d(in_channels, out_channels, kernel_size,
                                    stride, padding, bias=bias),
                    act_fn
                )

    elif name == "dropout":
        drop_probability = layer.drop_prob
        mod_layer = Dropout2d(p=drop_probability, inplace=False)

    elif name == "batch_norm":
        channel_num = layer.num_channels
        mod_layer = BatchNorm2d(channel_num)

    elif name == "ignore_attn_maps":
        mod_layer = IgnoreAttentionMap()

    elif name == "self_attention":
        channels = layer.channels
        squeeze_factor = layer.squeeze_factor
        bias = layer.bias

        if hasattr(layer, "activation"):
            act_fn = get_act_fn(layer.activation)
            mod_layer = SelfAttention(channels, act_fn, squeeze_factor, bias)
        else:
            mod_layer = SelfAttention(channels, None, squeeze_factor, bias)

    elif name == "full_attention":
        in_channels, out_channels = parse2tuple(layer.channels)
        kernel_size = parse2tuple(layer.kernel_dims)
        squeeze_factor = layer.squeeze_factor
        stride = parse2tuple(layer.stride)
        use_batch_norm = layer.use_batch_norm
        use_spectral_norm = layer.use_spectral_norm
        padding = parse2tuple(layer.padding)
        transpose_conv = layer.transpose_conv
        bias = layer.bias

        if hasattr(layer, "activation"):
            act_fn = get_act_fn(layer.activation)
            mod_layer = FullAttention(in_channels, out_channels, act_fn,
                                      kernel_size, transpose_conv,
                                      use_spectral_norm, use_batch_norm,
                                      squeeze_factor, stride, padding, bias)
        else:
            mod_layer = FullAttention(in_channels, out_channels, None,
                                      kernel_size, transpose_conv,
                                      use_spectral_norm, use_batch_norm,
                                      squeeze_factor, stride, padding, bias)
    else:
        raise ValueError("unknown layer type requested")

    return mod_layer
--------------------------------------------------------------------------------
/attn_gan_pytorch/__init__.py:
--------------------------------------------------------------------------------
""" package implements the attentional GAN as an extension of PyTorch nn.Module """

# import everything for flat package access also
from attn_gan_pytorch import ConfigManagement
from attn_gan_pytorch import CustomLayers
from attn_gan_pytorch import Losses
from attn_gan_pytorch import Networks
from attn_gan_pytorch import Utils
--------------------------------------------------------------------------------
/literature/self_attention_gan.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/akanimax/attn_gan_pytorch/2cf3810963eaf00ebc642c9413a0a0ab79a4a7bc/literature/self_attention_gan.pdf
--------------------------------------------------------------------------------
/samples/.gitignore:
--------------------------------------------------------------------------------
# ignore the data folder
data/
--------------------------------------------------------------------------------
/samples/data_processing/DataLoader.py:
--------------------------------------------------------------------------------
""" Module for the data loading pipeline for the model to train """

import os
from torch.utils.data import Dataset


class FlatDirectoryImageDataset(Dataset):
    """ PyTorch Dataset wrapper for the generic flat directory images dataset """

    def __setup_files(self):
        """
        private helper for setting up the files_list
        :return: files => list of paths of files
        """
        file_names = os.listdir(self.data_dir)
        files = []  # initialize to empty list

        for file_name in file_names:
            possible_file = os.path.join(self.data_dir, file_name)
            if os.path.isfile(possible_file):
                files.append(possible_file)

        # return the files list
        return files

    def __init__(self, data_dir, transform=None):
        """
        constructor for the class
        :param data_dir: path to the directory containing the data
        :param transform: transforms to be applied to the images
        """
        # define the state of the object
        self.data_dir = data_dir
        self.transform = transform

        # setup the files for reading
        self.files = self.__setup_files()

    def __len__(self):
        """
        compute the length of the dataset
        :return: len => length of dataset
        """
        return len(self.files)

    def __getitem__(self, idx):
        """
        obtain the image (read and transform)
        :param idx: index of the file required
        :return: img => image array
        """
        from PIL import Image

        # read the image:
        img = Image.open(self.files[idx])

        # apply the transforms on the image
        if self.transform is not None:
            img = self.transform(img)

        # return the image:
        return img


class FoldersDistributedDataset(Dataset):
    """ PyTorch Dataset wrapper for datasets distributed across class folders """

    def __setup_files(self):
        """
        private helper for setting up the files_list
        :return: files => list of paths of files
        """

        dir_names = os.listdir(self.data_dir)
        files = []  # initialize to empty list

        for dir_name in dir_names:
            file_path = os.path.join(self.data_dir, dir_name)
            file_names = os.listdir(file_path)
            for file_name in file_names:
                possible_file = os.path.join(file_path, file_name)
                if os.path.isfile(possible_file):
                    files.append(possible_file)

        # return the files list
        return files

    def __init__(self, data_dir, transform=None):
        """
        constructor for the class
        :param data_dir: path to the directory containing the data
        :param transform: transforms to be applied to the images
        """
        # define the state of the object
        self.data_dir = data_dir
        self.transform = transform

        # setup the files for reading
        self.files = self.__setup_files()

    def __len__(self):
        """
        compute the length of the dataset
        :return: len => length of dataset
        """
        return len(self.files)

    def __getitem__(self, idx):
        """
        obtain the image (read and transform)
        :param idx: index of the file required
        :return: img => image array
        """
        from PIL import Image

        # read the image:
        img = Image.open(self.files[idx])

        # apply the transforms on the image
        if self.transform is not None:
            img = self.transform(img)

        # convert a grayscale image to RGB by replicating the channel:
        img = img.expand(3, -1, -1)

        # return the image:
        return img


def get_transform(new_size=None):
    """
    obtain the image transforms required for the input data
    :param new_size: size of the resized images
    :return: image_transform => transform object from TorchVision
    """
    from torchvision.transforms import ToTensor, Normalize, Compose, Resize

    if new_size is not None:
        image_transform = Compose([
            Resize(new_size),
            ToTensor(),
            Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

    else:
        image_transform = Compose([
            ToTensor(),
            Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
    return image_transform


def get_data_loader(dataset, batch_size, num_workers):
    """
    generate the data_loader from the given dataset
    :param dataset: dataset of images to be wrapped
    :param batch_size: batch size of the data
    :param num_workers: num of parallel readers
    :return: dl => dataloader for the dataset
    """
    from torch.utils.data import DataLoader

    dl = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )

    return dl
--------------------------------------------------------------------------------
/samples/data_processing/__init__.py:
--------------------------------------------------------------------------------
""" Package for keeping all the data processing utilities """
from data_processing import DataLoader
--------------------------------------------------------------------------------
/samples/generate_loss_plots.py:
--------------------------------------------------------------------------------
""" script for generating the loss plots from the Loss logs """

import argparse
import matplotlib.pyplot as plt


def read_loss_log(file_name, delimiter='\t'):
    """
    read and load the loss values from a loss.log file
    :param file_name: path of the loss.log file
    :param delimiter: delimiter used to delimit the two columns
    :return: loss_val => numpy array [Iterations x 2]
    """
    from numpy import genfromtxt
    losses = genfromtxt(file_name, delimiter=delimiter)
    return losses


def plot_loss(*loss_vals, plot_name="Loss plot",
              fig_size=(17, 7), save_path=None,
              legends=("discriminator", "generator")):
    """
    plot the discriminator loss values and save the plot if required
    :param loss_vals: (Variable Arg) numpy array or Sequence like for plotting values
    :param plot_name: Name of the plot
    :param fig_size: size of the generated figure (column_width, row_width)
    :param save_path: path to save the figure
    :param legends: list containing labels for loss plots' legends
                    len(legends) == len(loss_vals)
    :return:
    """
    assert len(loss_vals) == len(legends), "Not enough labels for legends"

    plt.figure(figsize=fig_size).suptitle(plot_name)
    plt.grid(True, which="both")
    plt.ylabel("loss value")
    plt.xlabel("spaced iterations")

    plt.axhline(y=0, color='k')
    plt.axvline(x=0, color='k')

    # plot all the provided loss values in a single plot
    plts = []
    for loss_val in loss_vals:
        plts.append(plt.plot(loss_val)[0])

    plt.legend(plts, legends, loc="upper right", fontsize=16)

    if save_path is not None:
        plt.savefig(save_path)


def parse_arguments():
    """
    command line arguments parser
    :return: args => parsed command line arguments
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--loss_file", action="store", type=str, default=None,
                        help="path to loss log file")

    # the default must be a file path (a bare directory would fail in savefig)
    parser.add_argument("--plot_file", action="store", type=str,
                        default="./loss_plot.png",
                        help="path to the file where plots are to be saved")

    args = parser.parse_args()

    return args

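# Example invocation of this script (the paths are illustrative), assuming a
# loss.log written by GAN.train(..., log_dir=...):
#
#   $ python generate_loss_plots.py --loss_file=models/loss.log \
#                                   --plot_file=loss_plot.png
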
/samples/sample_celeba/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore the generated samples and trained models 2 | models/ 3 | samples/
--------------------------------------------------------------------------------
/samples/sample_celeba/configs/1/dis.conf:
--------------------------------------------------------------------------------
1 | # configuration for the discriminator architecture 2 | 3 | architecture: 4 | - 5 | name: "conv" 6 | channels: (3, 64) 7 | kernel_dims: (4, 4) 8 | stride: (2, 2) 9 | padding: (1, 1) 10 | bias: True 11 | batch_norm: False 12 | spectral_norm: True 13 | activation: "lrelu(0.1)" 14 |
15 | - 16 | name: "conv" 17 | channels: (64, 128) 18 | kernel_dims: (4, 4) 19 | stride: (2, 2) 20 | padding: (1, 1) 21 | bias: True 22 | batch_norm: False 23 | spectral_norm: True 24 | activation: "lrelu(0.1)" 25 |
26 | - 27 | name: "conv" 28 | channels: (128, 256) 29 | kernel_dims: (4, 4) 30 | stride: (2, 2) 31 | padding: (1, 1) 32 | bias: True 33 | batch_norm: False 34 | spectral_norm: True 35 | activation: "lrelu(0.1)" 36 |
37 | - 38 | name: "self_attention" 39 | channels: 256 40 | bias: True 41 | squeeze_factor: 8 42 |
43 | - 44 | name: "ignore_attn_maps" 45 |
46 | - 47 | name: "conv" 48 | channels: (256, 512) 49 | kernel_dims: (4, 4) 50 | stride: (2, 2) 51 | padding: (1, 1) 52 | bias: True 53 | batch_norm: False 54 | spectral_norm: True 55 | activation: "lrelu(0.1)" 56 |
57 | - 58 | name: "self_attention" 59 | channels: 512 60 | bias: True 61 | squeeze_factor: 8 62 |
63 | - 64 | name: "ignore_attn_maps" 65 |
66 | - 67 | name: "conv" 68 | channels: (512, 1) 69 | kernel_dims: (4, 4) 70 | stride: (1, 1) 71 | padding: (0, 0) 72 | bias: True 73 | batch_norm: False 74 | spectral_norm: False 75 | activation: "lrelu(1.0)"
--------------------------------------------------------------------------------
/samples/sample_celeba/configs/1/gen.conf:
--------------------------------------------------------------------------------
1 | # configuration for the Generator architecture 2 | 3 | architecture: 4 | - 5 | name: "conv_transpose" 6 | channels: (128, 512) 7 | kernel_dims: (4, 4) 8 | stride: (1, 1) 9 | padding: (0, 0) 10 | bias: True 11 | batch_norm: True 12 | spectral_norm: True 13 | activation: "relu" 14 |
15 | - 16 | name: "conv_transpose" 17 | channels: (512, 256) 18 | kernel_dims: (4, 4) 19 | stride: (2, 2) 20 | padding: (1, 1) 21 | bias: True 22 | batch_norm: True 23 | spectral_norm: True 24 | activation: "relu" 25 |
26 | - 27 | name: "conv_transpose" 28 | channels: (256, 128) 29 | kernel_dims: (4, 4) 30 | stride: (2, 2) 31 | padding: (1, 1) 32 | bias: True 33 | batch_norm: True 34 | spectral_norm: True 35 | activation: "relu" 36 |
37 | - 38 | name: "self_attention" 39 | channels: 128 40 | bias: True 41 | squeeze_factor: 8 42 |
43 | - 44 | name: "ignore_attn_maps" 45 |
46 | - 47 | name: "conv_transpose" 48 | channels: (128, 64) 49 | kernel_dims: (4, 4) 50 | stride: (2, 2) 51 | padding: (1, 1) 52 | bias: True 53 | batch_norm: True 54 | spectral_norm: True 55 | activation: "relu" 56 |
57 | - 58 | name: "self_attention" 59 | channels: 64 60 | bias: True 61 | squeeze_factor: 8 62 |
63 | - 64 | name: "ignore_attn_maps" 65 |
66 | - 67 | name: "conv_transpose" 68 | channels: (64, 3) 69 | kernel_dims: (4, 4) 70 | stride: (2, 2) 71 | padding: (1, 1) 72 | bias: True 73 | batch_norm: False 74 | spectral_norm: False 75 | activation: "tanh"
--------------------------------------------------------------------------------
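A quick sanity check on these shapes: in the discriminator every 4x4, stride-2, padding-1 conv halves the spatial size, and the final 4x4, stride-1, padding-0 conv collapses a 4x4 map into a single critic score; in the generator the mirrored conv_transpose stack grows a 1x1 latent (128 channels, matching the latent size used by train.py below) into a 64x64 image:

def conv_out(size, kernel=4, stride=2, padding=1):
    # standard convolution output-size formula
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    # transposed-convolution output-size formula
    return (size - 1) * stride - 2 * padding + kernel

size = 64
for _ in range(4):                           # four stride-2 conv blocks
    size = conv_out(size)                    # 64 -> 32 -> 16 -> 8 -> 4
print(conv_out(size, stride=1, padding=0))   # 1: one score per image

size = conv_transpose_out(1, stride=1, padding=0)   # 1 -> 4
for _ in range(4):                           # four stride-2 upsampling stages
    size = conv_transpose_out(size)          # 4 -> 8 -> 16 -> 32 -> 64
print(size)                                  # 64: matches the training images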
/samples/sample_celeba/configs/2/dis.conf:
--------------------------------------------------------------------------------
1 | # configuration for the discriminator architecture 2 | 3 | architecture: 4 | - 5 | name: "conv" 6 | channels: (3, 64) 7 | kernel_dims: (4, 4) 8 | stride: (2, 2) 9 | padding: (1, 1) 10 | bias: True 11 | batch_norm: False 12 | spectral_norm: True 13 | activation: "lrelu(0.1)" 14 |
15 | - 16 | name: "conv" 17 | channels: (64, 128) 18 | kernel_dims: (4, 4) 19 | stride: (2, 2) 20 | padding: (1, 1) 21 | bias: True 22 | batch_norm: False 23 | spectral_norm: True 24 | activation: "lrelu(0.1)" 25 |
26 | - 27 | name: "full_attention" 28 | channels: (128, 256) 29 | kernel_dims: (4, 4) 30 | stride: (2, 2) 31 | padding: (1, 1) 32 | bias: True 33 | use_batch_norm: False 34 | use_spectral_norm: False 35 | squeeze_factor: 8 36 | transpose_conv: False 37 | activation: "lrelu(0.3)" 38 |
39 | - 40 | name: "ignore_attn_maps" 41 |
42 | - 43 | name: "self_attention" 44 | channels: 256 45 | bias: True 46 | squeeze_factor: 8 47 |
48 | - 49 | name: "ignore_attn_maps" 50 |
51 | - 52 | name: "conv" 53 | channels: (256, 512) 54 | kernel_dims: (4, 4) 55 | stride: (2, 2) 56 | padding: (1, 1) 57 | bias: True 58 | batch_norm: False 59 | spectral_norm: True 60 | activation: "lrelu(0.1)" 61 |
62 | - 63 | name: "self_attention" 64 | channels: 512 65 | bias: True 66 | squeeze_factor: 8 67 |
68 | - 69 | name: "ignore_attn_maps" 70 |
71 | - 72 | name: "conv" 73 | channels: (512, 1) 74 | kernel_dims: (4, 4) 75 | stride: (1, 1) 76 | padding: (0, 0) 77 | bias: True 78 | batch_norm: False 79 | spectral_norm: False 80 | activation: "lrelu(1.0)"
--------------------------------------------------------------------------------
/samples/sample_celeba/configs/2/gen.conf:
--------------------------------------------------------------------------------
1 | # configuration for the Generator architecture 2 | 3 | architecture: 4 | - 5 | name: "conv_transpose" 6 | channels: (128, 512) 7 | kernel_dims: (4, 4) 8 | stride: (1, 1) 9 | padding: (0, 0) 10 | bias: True 11 | batch_norm: True 12 | spectral_norm: True 13 | activation: "relu" 14 |
15 | - 16 | name: "full_attention" 17 | channels: (512, 256) 18 | kernel_dims: (4, 4) 19 | stride: (2, 2) 20 | padding: (1, 1) 21 | bias: True 22 | use_spectral_norm: False 23 | use_batch_norm: False 24 | squeeze_factor: 8 25 | transpose_conv: True 26 | activation: "lrelu(0.3)" 27 |
28 | - 29 | name: "ignore_attn_maps" 30 |
31 | - 32 | name: "conv_transpose" 33 | channels: (256, 128) 34 | kernel_dims: (4, 4) 35 | stride: (2, 2) 36 | padding: (1, 1) 37 | bias: True 38 | batch_norm: True 39 | spectral_norm: True 40 | activation: "relu" 41 |
42 | - 43 | name: "self_attention" 44 | channels: 128 45 | bias: True 46 | squeeze_factor: 8 47 |
48 | - 49 | name: "ignore_attn_maps" 50 |
51 | - 52 | name: "conv_transpose" 53 | channels: (128, 64) 54 | kernel_dims: (4, 4) 55 | stride: (2, 2) 56 | padding: (1, 1) 57 | bias: True 58 | batch_norm: True 59 | spectral_norm: True 60 | activation: "relu" 61 |
62 | - 63 | name: "self_attention" 64 | channels: 64 65 | bias: True 66 | squeeze_factor: 8 67 |
68 | - 69 | name: "ignore_attn_maps" 70 |
71 | - 72 | name: "conv_transpose" 73 | channels: (64, 3) 74 | kernel_dims: (4, 4) 75 | stride: (2, 2) 76 | padding: (1, 1) 77 | bias: True 78 | batch_norm: False 79 | spectral_norm: False 80 | activation: "tanh"
--------------------------------------------------------------------------------
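Config 2 swaps one plain conv stage in each network for a full_attention block; note its keys differ (use_batch_norm/use_spectral_norm rather than batch_norm/spectral_norm, plus a transpose_conv flag that makes the block upsample inside the generator). The .conf files appear to be YAML (the package depends on PyYAML). One thing worth knowing if you parse them outside the package's own get_config: a stock YAML loader returns entries like (3, 64) as plain strings, so tuple-valued fields need an extra literal_eval pass. A hedged sketch, assuming PyYAML is installed:

import yaml
from ast import literal_eval

with open("configs/2/gen.conf", "r") as f:
    conf = yaml.safe_load(f)

# each architecture entry comes back as a dict of scalars
first = conf["architecture"][0]
channels = literal_eval(first["channels"])   # "(128, 512)" -> (128, 512)
print(first["name"], channels)               # conv_transpose (128, 512)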
"conv_transpose" 53 | channels: (128, 64) 54 | kernel_dims: (4, 4) 55 | stride: (2, 2) 56 | padding: (1, 1) 57 | bias: True 58 | batch_norm: True 59 | spectral_norm: True 60 | activation: "relu" 61 | 62 | - 63 | name: "self_attention" 64 | channels: 64 65 | bias: True 66 | squeeze_factor: 8 67 | 68 | - 69 | name: "ignore_attn_maps" 70 | 71 | - 72 | name: "conv_transpose" 73 | channels: (64, 3) 74 | kernel_dims: (4, 4) 75 | stride: (2, 2) 76 | padding: (1, 1) 77 | bias: True 78 | batch_norm: False 79 | spectral_norm: False 80 | activation: "tanh" -------------------------------------------------------------------------------- /samples/sample_celeba/configs/3/dis.conf: -------------------------------------------------------------------------------- 1 | # configuration for the discriminator architecture 2 | 3 | architecture: 4 | - 5 | name: "conv" 6 | channels: (3, 64) 7 | kernel_dims: (4, 4) 8 | stride: (2, 2) 9 | padding: (1, 1) 10 | bias: True 11 | batch_norm: False 12 | spectral_norm: True 13 | activation: "lrelu(0.1)" 14 | 15 | - 16 | name: "conv" 17 | channels: (64, 128) 18 | kernel_dims: (4, 4) 19 | stride: (2, 2) 20 | padding: (1, 1) 21 | bias: True 22 | batch_norm: False 23 | spectral_norm: True 24 | activation: "lrelu(0.1)" 25 | 26 | - 27 | name: "full_attention" 28 | channels: (128, 256) 29 | kernel_dims: (4, 4) 30 | stride: (2, 2) 31 | padding: (1, 1) 32 | bias: True 33 | use_spectral_norm: True 34 | use_batch_norm: True 35 | squeeze_factor: 8 36 | transpose_conv: False 37 | activation: "lrelu(0.3)" 38 | 39 | - 40 | name: "ignore_attn_maps" 41 | 42 | - 43 | name: "self_attention" 44 | channels: 256 45 | bias: True 46 | squeeze_factor: 8 47 | 48 | - 49 | name: "ignore_attn_maps" 50 | 51 | - 52 | name: "conv" 53 | channels: (256, 512) 54 | kernel_dims: (4, 4) 55 | stride: (2, 2) 56 | padding: (1, 1) 57 | bias: True 58 | batch_norm: False 59 | spectral_norm: True 60 | activation: "lrelu(0.1)" 61 | 62 | - 63 | name: "self_attention" 64 | channels: 512 65 | bias: True 66 | squeeze_factor: 8 67 | 68 | - 69 | name: "ignore_attn_maps" 70 | 71 | - 72 | name: "conv" 73 | channels: (512, 1) 74 | kernel_dims: (4, 4) 75 | stride: (1, 1) 76 | padding: (0, 0) 77 | bias: True 78 | batch_norm: False 79 | spectral_norm: False 80 | activation: "lrelu(1.0)" -------------------------------------------------------------------------------- /samples/sample_celeba/configs/3/gen.conf: -------------------------------------------------------------------------------- 1 | # configuration for the Generator architecture 2 | 3 | architecture: 4 | - 5 | name: "conv_transpose" 6 | channels: (128, 512) 7 | kernel_dims: (4, 4) 8 | stride: (1, 1) 9 | padding: (0, 0) 10 | bias: True 11 | batch_norm: True 12 | spectral_norm: True 13 | activation: "relu" 14 | 15 | - 16 | name: "full_attention" 17 | channels: (512, 256) 18 | kernel_dims: (4, 4) 19 | stride: (2, 2) 20 | padding: (1, 1) 21 | bias: True 22 | use_spectral_norm: True 23 | use_batch_norm: True 24 | squeeze_factor: 8 25 | transpose_conv: True 26 | activation: "lrelu(0.3)" 27 | 28 | - 29 | name: "ignore_attn_maps" 30 | 31 | - 32 | name: "conv_transpose" 33 | channels: (256, 128) 34 | kernel_dims: (4, 4) 35 | stride: (2, 2) 36 | padding: (1, 1) 37 | bias: True 38 | batch_norm: True 39 | spectral_norm: True 40 | activation: "relu" 41 | 42 | - 43 | name: "self_attention" 44 | channels: 128 45 | bias: True 46 | squeeze_factor: 8 47 | 48 | - 49 | name: "ignore_attn_maps" 50 | 51 | - 52 | name: "conv_transpose" 53 | channels: (128, 64) 54 | kernel_dims: (4, 4) 55 | 
/samples/sample_celeba/train.py:
--------------------------------------------------------------------------------
1 | """ script for training a Self Attention GAN on celeba images """ 2 | 3 | import torch as th 4 | import argparse 5 | 6 | from torch.backends import cudnn 7 | 8 | # define the device for the training script 9 | device = th.device("cuda" if th.cuda.is_available() else "cpu") 10 | 11 | # enable fast training 12 | cudnn.benchmark = True 13 | 14 |
15 | def parse_arguments(): 16 | """ 17 | command line arguments parser 18 | :return: args => parsed command line arguments 19 | """ 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--generator_config", action="store", type=str, 22 | default="configs/3/gen.conf", 23 | help="default configuration for generator network") 24 | 25 | parser.add_argument("--discriminator_config", action="store", type=str, 26 | default="configs/3/dis.conf", 27 | help="default configuration for discriminator network") 28 | 29 | parser.add_argument("--images_dir", action="store", type=str, 30 | default="../data/celeba", 31 | help="path for the images directory") 32 | 33 | parser.add_argument("--latent_size", action="store", type=int, 34 | default=128, 35 | help="latent size for the generator") 36 | 37 | parser.add_argument("--batch_size", action="store", type=int, 38 | default=64, 39 | help="batch_size for training") 40 | 41 | parser.add_argument("--num_epochs", action="store", type=int, 42 | default=3, 43 | help="number of epochs for training") 44 | 45 | parser.add_argument("--checkpoint_factor", action="store", type=int, 46 | default=1, 47 | help="save model per n epochs") 48 | 49 | parser.add_argument("--g_lr", action="store", type=float, 50 | default=0.0001, 51 | help="learning rate for generator") 52 | 53 | parser.add_argument("--d_lr", action="store", type=float, 54 | default=0.0004, 55 | help="learning rate for discriminator") 56 | 57 | parser.add_argument("--data_percentage", action="store", type=float, 58 | default=100, 59 | help="percentage of data to use") 60 | 61 | parser.add_argument("--num_workers", action="store", type=int, 62 | default=3, 63 | help="number of parallel workers for reading files") 64 | 65 | args = parser.parse_args() 66 | 67 | return args 68 | 69 |
70 | def main(args): 71 | """ 72 | Main function for the script 73 | :param args: parsed command line arguments 74 | :return: None 75 | """ 76 | from attn_gan_pytorch.Utils import get_layer 77 | from attn_gan_pytorch.ConfigManagement import get_config 78 | from attn_gan_pytorch.Networks import Generator, Discriminator, GAN 79 | from data_processing.DataLoader import FlatDirectoryImageDataset, \ 80 | get_transform, get_data_loader 81 | from attn_gan_pytorch.Losses import RelativisticAverageHingeGAN 82 | 83 | # create a data source: 84 | celeba_dataset = FlatDirectoryImageDataset(args.images_dir, 85 | transform=get_transform((64, 64))) 86 | data = get_data_loader(celeba_dataset, args.batch_size, args.num_workers) 87 | 88 | # create generator object: 89 | gen_conf = get_config(args.generator_config) 90 | gen_layers = list(map(get_layer, gen_conf.architecture)) 91 | generator = Generator(gen_layers, args.latent_size) 92 | 93 | print("Generator Configuration: ") 94 | print(generator) 95 | 96 | # create discriminator object: 97 | dis_conf = get_config(args.discriminator_config) 98 | dis_layers = list(map(get_layer, dis_conf.architecture)) 99 | discriminator = Discriminator(dis_layers) 100 | 101 | print("Discriminator Configuration: ") 102 | print(discriminator) 103 | 104 | # create a gan from these 105 | sagan = GAN(generator, discriminator, device=device) 106 | 107 | # create optimizer for generator: 108 | gen_optim = th.optim.Adam(filter(lambda p: p.requires_grad, generator.parameters()), 109 | args.g_lr, betas=(0, 0.9)) 110 | 111 | # create optimizer for discriminator: 112 | dis_optim = th.optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()), 113 | args.d_lr, betas=(0, 0.9)) 114 | 115 | # train the GAN 116 | sagan.train( 117 | data, 118 | gen_optim, 119 | dis_optim, 120 | loss_fn=RelativisticAverageHingeGAN(device, discriminator), 121 | num_epochs=args.num_epochs, 122 | checkpoint_factor=args.checkpoint_factor, 123 | data_percentage=args.data_percentage, 124 | feedback_factor=31, 125 | num_samples=64, 126 | save_dir="models/relativistic/", 127 | sample_dir="samples/4/", 128 | log_dir="models/relativistic" 129 | ) 130 | 131 | 132 | if __name__ == '__main__': 133 | # invoke the main function of the script 134 | main(parse_arguments()) 135 |
--------------------------------------------------------------------------------
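train.py is driven from the command line, e.g. python train.py --images_dir ../data/celeba --num_epochs 3. Its loss_fn is the package's RelativisticAverageHingeGAN; as a reference for what that objective computes, here is a hedged stand-alone sketch of the relativistic average hinge loss from the RaGAN formulation (an illustration of the math, not the package's Losses class itself):

import torch as th

def rahinge_dis_loss(real_scores, fake_scores):
    # relativistic average hinge loss for the discriminator:
    # real samples should score above the average fake score,
    # and fake samples below the average real score
    r_f_diff = real_scores - th.mean(fake_scores)
    f_r_diff = fake_scores - th.mean(real_scores)
    return (th.mean(th.relu(1 - r_f_diff))
            + th.mean(th.relu(1 + f_r_diff)))

def rahinge_gen_loss(real_scores, fake_scores):
    # the generator objective swaps the roles of real and fake
    r_f_diff = real_scores - th.mean(fake_scores)
    f_r_diff = fake_scores - th.mean(real_scores)
    return (th.mean(th.relu(1 + r_f_diff))
            + th.mean(th.relu(1 - f_r_diff)))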
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='attn_gan_pytorch', 5 | version='0.6', 6 | packages=find_packages(exclude=("samples", "literature")), 7 | url='https://github.com/akanimax/attn_gan_pytorch', 8 | license='MIT', 9 | author='animesh karnewar', 10 | author_email='animeshsk3@gmail.com', 11 | description='python package for self-attention gan implemented as extension of ' + 12 | 'PyTorch nn.Module. paper -> https://arxiv.org/abs/1805.08318', 13 | install_requires=['torch', 'torchvision', 'numpy', 'PyYAML'] 14 | ) 15 |
--------------------------------------------------------------------------------
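With the package installed (pip install attn_gan_pytorch), the same building blocks used by the sample script are importable directly. A minimal consumption sketch, assuming one of the sample config files is available locally:

from attn_gan_pytorch.ConfigManagement import get_config
from attn_gan_pytorch.Utils import get_layer
from attn_gan_pytorch.Networks import Generator

# parse an architecture file and materialise each entry as a layer
conf = get_config("configs/3/gen.conf")
layers = list(map(get_layer, conf.architecture))

# a generator for 128-dimensional latent vectors, as in train.py
generator = Generator(layers, 128)
print(generator)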